Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-10-31 09:49:48 +00:00
	Implement proper FK in MetaDocumentORM and MetaLabelORM to work on PostgreSQL (#1990)
				
					
				
* Properly fix MetaDocumentORM and MetaLabelORM with composite foreign key constraints
* update_document_meta() was not using index properly
* Exclude ES and Memory from the cosine_sanity_check test
* Move ensure_ids_are_correct_uuids into conftest and move one test back to the faiss & milvus suite

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent 3e4dbbb32c
commit e28bf618d7
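The heart of the change is replacing single-column `ForeignKey` references with a table-level `ForeignKeyConstraint` spanning the parent's composite primary key (`id`, `index`). PostgreSQL rejects a foreign key that targets `document.id` alone, because `id` by itself is not unique once the same document may live in several indices. Below is a minimal self-contained sketch of the pattern; the `Parent`/`Child` names are illustrative, not Haystack's actual schema:

```python
# Minimal sketch of a composite foreign key in SQLAlchemy 1.4 (names are illustrative).
from sqlalchemy import Column, ForeignKeyConstraint, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

Base = declarative_base()


class Parent(Base):
    __tablename__ = "parent"
    # Composite primary key: the same id may appear under different indices
    id = Column(String(100), primary_key=True)
    index = Column(String(100), primary_key=True)


class Child(Base):
    __tablename__ = "child"
    id = Column(String(100), primary_key=True)
    parent_id = Column(String(100), nullable=False)
    parent_index = Column(String(100), nullable=False)
    # Both columns together must match an (id, index) pair in parent.
    # A column-level ForeignKey("parent.id") alone is rejected by PostgreSQL,
    # since parent.id on its own is not unique.
    __table_args__ = (
        ForeignKeyConstraint(
            [parent_id, parent_index],
            [Parent.id, Parent.index],
            ondelete="CASCADE", onupdate="CASCADE",
        ),
    )


engine = create_engine("sqlite://")  # swap in a postgresql:// URL to actually enforce the FK
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()
```

SQLite tolerates the old single-column reference (it does not enforce foreign keys by default), which is presumably why the bug only surfaced on PostgreSQL.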
@@ -380,7 +380,7 @@ Write annotation labels into document store.
 #### update\_document\_meta
 
 ```python
- | update_document_meta(id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None)
+ | update_document_meta(id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None, index: str = None)
 ```
 
 Update the metadata dictionary of a document by specifying its string id
@@ -952,7 +952,7 @@ class SQLDocumentStore(BaseDocumentStore)
 #### \_\_init\_\_
 
 ```python
- | __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite", check_same_thread: bool = False)
+ | __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite", check_same_thread: bool = False, isolation_level: str = None)
 ```
 
 An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.
@@ -970,6 +970,7 @@ An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL bac
                             fail: an error is raised if the document ID of the document being added already
                             exists.
 - `check_same_thread`: Set to False to mitigate multithreading issues in older SQLite versions (see https://docs.sqlalchemy.org/en/14/dialects/sqlite.html?highlight=check_same_thread#threading-pooling-behavior)
+- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
 
 <a name="sql.SQLDocumentStore.get_document_by_id"></a>
 #### get\_document\_by\_id
@@ -1094,7 +1095,7 @@ Set vector IDs for all documents as None
 #### update\_document\_meta
 
 ```python
- | update_document_meta(id: str, meta: Dict[str, str])
+ | update_document_meta(id: str, meta: Dict[str, str], index: str = None)
 ```
 
 Update the metadata dictionary of a document by specifying its string id
@@ -1202,7 +1203,7 @@ the vector embeddings are indexed in a FAISS Index.
 #### \_\_init\_\_
 
 ```python
- | __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, **kwargs)
+ | __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, isolation_level: str = None, **kwargs)
 ```
 
 **Arguments**:
@@ -1248,6 +1249,7 @@ the vector embeddings are indexed in a FAISS Index.
     If specified no other params besides faiss_config_path must be specified.
 - `faiss_config_path`: Stored FAISS initial configuration parameters.
     Can be created via calling `save()`
+- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
 
 <a name="faiss.FAISSDocumentStore.write_documents"></a>
 #### write\_documents
@@ -1479,7 +1481,7 @@ Usage:
 #### \_\_init\_\_
 
 ```python
- | __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = None, embedding_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs)
+ | __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = None, embedding_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', isolation_level: str = None, **kwargs)
 ```
 
 **Arguments**:
@@ -1525,6 +1527,7 @@ Note that an overly large index_file_size value may cause failure to load a segm
                             overwrite: Update any existing documents with the same ID when adding documents.
                             fail: an error is raised if the document ID of the document being added already
                             exists.
+- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
 
 <a name="milvus.MilvusDocumentStore.write_documents"></a>
 #### write\_documents
@@ -1863,7 +1866,7 @@ None
 #### update\_document\_meta
 
 ```python
- | update_document_meta(id: str, meta: Dict[str, str])
+ | update_document_meta(id: str, meta: Dict[str, str], index: str = None)
 ```
 
 Update the metadata dictionary of a document by specifying its string id.
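Throughout the diff, the new `isolation_level` argument is a plain pass-through to SQLAlchemy's `create_engine()`. A hedged usage sketch (the connection URL is a placeholder; any level `create_engine()` accepts works):

```python
# Usage sketch: isolation_level is forwarded to SQLAlchemy's create_engine().
from haystack.document_stores import SQLDocumentStore

document_store = SQLDocumentStore(
    url="postgresql://postgres:postgres@127.0.0.1/postgres",  # placeholder URL
    index="document",
    isolation_level="AUTOCOMMIT",  # e.g. "READ COMMITTED" would also be accepted
)
```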
@@ -551,12 +551,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         if labels_to_index:
             bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type, headers=headers)
 
-    def update_document_meta(self, id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None):
+    def update_document_meta(self, id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None, index: str = None):
         """
         Update the metadata dictionary of a document by specifying its string id
         """
+        if not index:
+            index = self.index
         body = {"doc": meta}
-        self.client.update(index=self.index, id=id, body=body, refresh=self.refresh_type, headers=headers)
+        self.client.update(index=index, id=id, body=body, refresh=self.refresh_type, headers=headers)
 
     def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None,
                            only_documents_without_embedding: bool = False, headers: Optional[Dict[str, str]] = None) -> int:
@@ -50,6 +50,7 @@ class FAISSDocumentStore(SQLDocumentStore):
         duplicate_documents: str = 'overwrite',
         faiss_index_path: Union[str, Path] = None,
         faiss_config_path: Union[str, Path] = None,
+        isolation_level: str = None,
         **kwargs,
     ):
         """
@@ -94,6 +95,7 @@ class FAISSDocumentStore(SQLDocumentStore):
             If specified no other params besides faiss_config_path must be specified.
         :param faiss_config_path: Stored FAISS initial configuration parameters.
             Can be created via calling `save()`
+        :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
         """
         # special case if we want to load an existing index from disk
         # load init params from disk and run init again
@@ -115,7 +117,8 @@ class FAISSDocumentStore(SQLDocumentStore):
             index=index,
             similarity=similarity,
             embedding_field=embedding_field,
-            progress_bar=progress_bar
+            progress_bar=progress_bar,
+            isolation_level=isolation_level
         )
 
         if similarity in ("dot_product", "cosine"):
@@ -155,7 +158,8 @@ class FAISSDocumentStore(SQLDocumentStore):
         super().__init__(
             url=sql_url,
             index=index,
-            duplicate_documents=duplicate_documents
+            duplicate_documents=duplicate_documents,
+            isolation_level=isolation_level
         )
 
         self._validate_index_sync()
@@ -53,6 +53,7 @@ class MilvusDocumentStore(SQLDocumentStore):
             embedding_field: str = "embedding",
             progress_bar: bool = True,
             duplicate_documents: str = 'overwrite',
+            isolation_level: str = None,
             **kwargs,
     ):
         """
@@ -97,6 +98,7 @@ class MilvusDocumentStore(SQLDocumentStore):
                                     overwrite: Update any existing documents with the same ID when adding documents.
                                     fail: an error is raised if the document ID of the document being added already
                                     exists.
+        :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
         """
         # save init parameters to enable export of component config as YAML
         self.set_config(
@@ -104,6 +106,7 @@ class MilvusDocumentStore(SQLDocumentStore):
             embedding_dim=embedding_dim, index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
             search_param=search_param, duplicate_documents=duplicate_documents,
             return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar,
+            isolation_level=isolation_level
         )
 
         self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool)
@@ -139,7 +142,8 @@ class MilvusDocumentStore(SQLDocumentStore):
         super().__init__(
             url=sql_url,
             index=index,
-            duplicate_documents=duplicate_documents
+            duplicate_documents=duplicate_documents,
+            isolation_level=isolation_level,
         )
 
     def __del__(self):
@@ -73,6 +73,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
             custom_fields: Optional[List[Any]] = None,
             progress_bar: bool = True,
             duplicate_documents: str = 'overwrite',
+            isolation_level: str = None
     ):
         """
         :param sql_url: SQL connection URL for storing document texts and metadata. It defaults to a local, file based SQLite DB. For large scale
@@ -118,6 +119,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
                                     overwrite: Update any existing documents with the same ID when adding documents.
                                     fail: an error is raised if the document ID of the document being added already
                                     exists.
+        :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
         """
 
         # save init parameters to enable export of component config as YAML
@@ -127,6 +129,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
             search_param=search_param, duplicate_documents=duplicate_documents, id_field=id_field,
             return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar,
             custom_fields=custom_fields,
+            isolation_level=isolation_level
         )
 
         logger.warning("Milvus2DocumentStore is in experimental state until Milvus 2.0 is released")
@@ -173,7 +176,8 @@ class Milvus2DocumentStore(SQLDocumentStore):
         super().__init__(
             url=sql_url,
             index=index,
-            duplicate_documents=duplicate_documents
+            duplicate_documents=duplicate_documents,
+            isolation_level=isolation_level,
         )
 
     def _create_collection_and_index_if_not_exist(
@@ -4,7 +4,7 @@ import logging
 import itertools
 import numpy as np
 from uuid import uuid4
-from sqlalchemy import and_, func, create_engine, Column, String, DateTime, ForeignKey, Boolean, Text, text, JSON
+from sqlalchemy import and_, func, create_engine, Column, String, DateTime, ForeignKey, Boolean, Text, text, JSON, ForeignKeyConstraint
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import relationship, sessionmaker
 from sqlalchemy.sql import case, null
@@ -33,9 +33,6 @@ class DocumentORM(ORMBase):
     # primary key in combination with id to allow the same doc in different indices
     index = Column(String(100), nullable=False, primary_key=True)
     vector_id = Column(String(100), unique=True, nullable=True)
-
-    # labels = relationship("LabelORM", back_populates="document")
-
     # speeds up queries for get_documents_by_vector_ids() by having a single query that returns joined metadata
     meta = relationship("MetaDocumentORM", back_populates="documents", lazy="joined")
 
@@ -45,36 +42,18 @@ class MetaDocumentORM(ORMBase):
 
     name = Column(String(100), index=True)
     value = Column(String(1000), index=True)
-    document_id = Column(
-        String(100),
-        ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"),
-        nullable=False,
-        index=True
-    )
-
     documents = relationship("DocumentORM", back_populates="meta")
 
-
-class MetaLabelORM(ORMBase):
-    __tablename__ = "meta_label"
-
-    name = Column(String(100), index=True)
-    value = Column(String(1000), index=True)
-    label_id = Column(
-        String(100),
-        ForeignKey("label.id", ondelete="CASCADE", onupdate="CASCADE"),
-        nullable=False,
-        index=True
-    )
-
-    labels = relationship("LabelORM", back_populates="meta")
+    document_id = Column(String(100), nullable=False, index=True)
+    document_index = Column(String(100), nullable=False, index=True)
+    __table_args__ = (ForeignKeyConstraint([document_id, document_index],
+                                           [DocumentORM.id, DocumentORM.index],
+                                           ondelete="CASCADE", onupdate="CASCADE"), {})  # type: ignore
 
 
 class LabelORM(ORMBase):
     __tablename__ = "label"
 
-    # document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"), nullable=False)
-
     index = Column(String(100), nullable=False, primary_key=True)
     query = Column(Text, nullable=False)
     answer = Column(JSON, nullable=True)
@@ -86,7 +65,21 @@ class LabelORM(ORMBase):
     pipeline_id = Column(String(500), nullable=True)
 
     meta = relationship("MetaLabelORM", back_populates="labels", lazy="joined")
-    # document = relationship("DocumentORM", back_populates="labels")
+
+
+class MetaLabelORM(ORMBase):
+    __tablename__ = "meta_label"
+
+    name = Column(String(100), index=True)
+    value = Column(String(1000), index=True)
+    labels = relationship("LabelORM", back_populates="meta")
+
+    label_id = Column(String(100), nullable=False, index=True)
+    label_index = Column(String(100), nullable=False, index=True)
+    __table_args__ = (ForeignKeyConstraint([label_id, label_index],
+                                           [LabelORM.id, LabelORM.index],
+                                           ondelete="CASCADE", onupdate="CASCADE"), {})  # type: ignore
 
 
 class SQLDocumentStore(BaseDocumentStore):
@@ -96,7 +89,8 @@ class SQLDocumentStore(BaseDocumentStore):
         index: str = "document",
         label_index: str = "label",
         duplicate_documents: str = "overwrite",
-        check_same_thread: bool = False
+        check_same_thread: bool = False,
+        isolation_level: str = None
     ):
         """
         An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.
@@ -112,18 +106,21 @@ class SQLDocumentStore(BaseDocumentStore):
                                     fail: an error is raised if the document ID of the document being added already
                                     exists.
         :param check_same_thread: Set to False to mitigate multithreading issues in older SQLite versions (see https://docs.sqlalchemy.org/en/14/dialects/sqlite.html?highlight=check_same_thread#threading-pooling-behavior)
+        :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
         """
 
         # save init parameters to enable export of component config as YAML
         self.set_config(
                 url=url, index=index, label_index=label_index, duplicate_documents=duplicate_documents, check_same_thread=check_same_thread
         )
 
+        create_engine_params = {}
+        if isolation_level:
+            create_engine_params["isolation_level"] = isolation_level
         if "sqlite" in url:
-            engine = create_engine(url, connect_args={'check_same_thread': check_same_thread})
+            engine = create_engine(url, connect_args={'check_same_thread': check_same_thread}, **create_engine_params)
         else:
-            engine = create_engine(url)
-        ORMBase.metadata.create_all(engine)
+            engine = create_engine(url, **create_engine_params)
+        Base.metadata.create_all(engine)
         Session = sessionmaker(bind=engine)
         self.session = Session()
         self.index: str = index
@@ -461,12 +458,14 @@ class SQLDocumentStore(BaseDocumentStore):
         self.session.query(DocumentORM).filter_by(index=index).update({DocumentORM.vector_id: null()})
         self.session.commit()
 
-    def update_document_meta(self, id: str, meta: Dict[str, str]):
+    def update_document_meta(self, id: str, meta: Dict[str, str], index: str = None):
         """
         Update the metadata dictionary of a document by specifying its string id
         """
-        self.session.query(MetaDocumentORM).filter_by(document_id=id).delete()
-        meta_orms = [MetaDocumentORM(name=key, value=value, document_id=id) for key, value in meta.items()]
+        if not index:
+            index = self.index
+        self.session.query(MetaDocumentORM).filter_by(document_id=id, document_index=index).delete()
+        meta_orms = [MetaDocumentORM(name=key, value=value, document_id=id, document_index=index) for key, value in meta.items()]
         for m in meta_orms:
             self.session.add(m)
         self.session.commit()
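With the fix, `update_document_meta()` can target a caller-supplied index instead of always using the store's default. A short hedged sketch of the extended signature (the document id and metadata values are made up):

```python
# Hypothetical usage of the new index parameter (id and meta values are made up).
from haystack.document_stores import ElasticsearchDocumentStore

document_store = ElasticsearchDocumentStore(index="document")
document_store.update_document_meta(
    id="doc-42",                           # made-up document id
    meta={"year": "2021", "month": "02"},
    index="document_archive",              # omit to fall back to document_store.index
)
```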
@@ -486,11 +486,13 @@ class WeaviateDocumentStore(BaseDocumentStore):
                 progress_bar.update(batch_size)
         progress_bar.close()
 
-    def update_document_meta(self, id: str, meta: Dict[str, str]):
+    def update_document_meta(self, id: str, meta: Dict[str, str], index: str = None):
         """
         Update the metadata dictionary of a document by specifying its string id.
         """
-        self.weaviate_client.data_object.update(meta, class_name=self.index, uuid=id)
+        if not index:
+            index = self.index
+        self.weaviate_client.data_object.update(meta, class_name=index, uuid=id)
 
     def get_embedding_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
         """
test/conftest.py (109 lines changed)
@@ -3,6 +3,9 @@ import time
 from subprocess import run
 from sys import platform
 import gc
+import uuid
+import logging
+from sqlalchemy import create_engine, text
 
 import numpy as np
 import psutil
@@ -40,6 +43,11 @@ from haystack.nodes.translator import TransformersTranslator
 from haystack.nodes.question_generator import QuestionGenerator
 
 
+# To manually run the tests with PostgreSQL instead of the default SQLite, switch the lines below
+SQL_TYPE = "sqlite"
+# SQL_TYPE = "postgres"
+
+
 def pytest_addoption(parser):
     parser.addoption("--document_store_type", action="store", default="elasticsearch, faiss, memory, milvus, weaviate")
 
@@ -477,58 +485,112 @@ def get_retriever(retriever_type, document_store):
     return retriever
 
 
+def ensure_ids_are_correct_uuids(docs: list, document_store: object) -> None:
+    # Weaviate currently only supports UUIDs
+    if type(document_store) == WeaviateDocumentStore:
+        for d in docs:
+            d["id"] = str(uuid.uuid4())
+
+
 @pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate"])
-def document_store_with_docs(request, test_docs_xs):
+def document_store_with_docs(request, test_docs_xs, tmp_path):
     embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
-    document_store = get_document_store(request.param, embedding_dim.args[0])
+    document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], tmp_path=tmp_path)
     document_store.write_documents(test_docs_xs)
     yield document_store
     document_store.delete_documents()
 
 
 @pytest.fixture
-def document_store(request):
+def document_store(request, tmp_path):
     embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
-    document_store = get_document_store(request.param, embedding_dim.args[0])
+    document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], tmp_path=tmp_path)
     yield document_store
     document_store.delete_documents()
 
 @pytest.fixture(params=["memory", "faiss", "milvus", "elasticsearch"])
-def document_store_dot_product(request):
+def document_store_dot_product(request, tmp_path):
     embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
-    document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
+    document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path)
     yield document_store
     document_store.delete_documents()
 
 @pytest.fixture(params=["memory", "faiss", "milvus", "elasticsearch"])
-def document_store_dot_product_with_docs(request, test_docs_xs):
+def document_store_dot_product_with_docs(request, test_docs_xs, tmp_path):
     embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
-    document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
+    document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path)
     document_store.write_documents(test_docs_xs)
     yield document_store
     document_store.delete_documents()
 
 @pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus"])
-def document_store_dot_product_small(request):
+def document_store_dot_product_small(request, tmp_path):
     embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(3))
-    document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
+    document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path)
     yield document_store
     document_store.delete_documents()
 
 @pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate"])
-def document_store_small(request):
+def document_store_small(request, tmp_path):
     embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(3))
-    document_store = get_document_store(request.param, embedding_dim.args[0], similarity="cosine")
+    document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="cosine", tmp_path=tmp_path)
     yield document_store
     document_store.delete_documents()
 
-def get_document_store(document_store_type, embedding_dim=768, embedding_field="embedding", index="haystack_test", similarity:str="cosine"): # cosine is default similarity as dot product is not supported by Weaviate
+
+@pytest.fixture(scope="function", autouse=True)
+def postgres_fixture():
+    if SQL_TYPE == "postgres":
+        setup_postgres()
+        yield
+        teardown_postgres()
+    else:
+        yield
+
+
+@pytest.fixture
+def sql_url(tmp_path):
+    return get_sql_url(tmp_path)
+
+
+def get_sql_url(tmp_path):
+    if SQL_TYPE == "postgres":
+        return "postgresql://postgres:postgres@127.0.0.1/postgres"
+    else:
+        return f"sqlite:///{tmp_path}/haystack_test.db"
+
+
+def setup_postgres():
+    # status = subprocess.run(["docker run --name postgres_test -d -e POSTGRES_HOST_AUTH_METHOD=trust -p 5432:5432 postgres"], shell=True)
+    # if status.returncode:
+    #     logging.warning("Tried to start PostgreSQL through Docker but this failed. It is likely that there is already an existing instance running.")
+    # else:
+    #     sleep(5)
+    engine = create_engine('postgresql://postgres:postgres@127.0.0.1/postgres', isolation_level='AUTOCOMMIT')
 
+    with engine.connect() as connection:
+        try:
+            connection.execute(text('DROP SCHEMA public CASCADE'))
+        except Exception as e:
+            logging.error(e)
+        connection.execute(text('CREATE SCHEMA public;'))
+        connection.execute(text('SET SESSION idle_in_transaction_session_timeout = "1s";'))
+
+
+def teardown_postgres():
+    engine = create_engine('postgresql://postgres:postgres@127.0.0.1/postgres', isolation_level='AUTOCOMMIT')
+    with engine.connect() as connection:
+        connection.execute(text('DROP SCHEMA public CASCADE'))
+        connection.close()
+
+
+def get_document_store(document_store_type, tmp_path, embedding_dim=768, embedding_field="embedding", index="haystack_test", similarity: str = "cosine"):  # cosine is default similarity as dot product is not supported by Weaviate
     if document_store_type == "sql":
-        document_store = SQLDocumentStore(url="sqlite://", index=index)
+        document_store = SQLDocumentStore(url=get_sql_url(tmp_path), index=index, isolation_level="AUTOCOMMIT")
 
     elif document_store_type == "memory":
         document_store = InMemoryDocumentStore(
-            return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, index=index, similarity=similarity
-        )
+            return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, index=index, similarity=similarity)
+
     elif document_store_type == "elasticsearch":
         # make sure we start from a fresh index
         client = Elasticsearch()
@@ -536,28 +598,33 @@ def get_document_store(document_store_type, embedding_dim=768, embedding_field="
         document_store = ElasticsearchDocumentStore(
             index=index, return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, similarity=similarity
         )
 
     elif document_store_type == "faiss":
         document_store = FAISSDocumentStore(
             embedding_dim=embedding_dim,
-            sql_url="sqlite://",
+            sql_url=get_sql_url(tmp_path),
             return_embedding=True,
             embedding_field=embedding_field,
             index=index,
-            similarity=similarity
+            similarity=similarity,
+            isolation_level="AUTOCOMMIT"
         )
 
     elif document_store_type == "milvus":
         document_store = MilvusDocumentStore(
             embedding_dim=embedding_dim,
-            sql_url="sqlite://",
+            sql_url=get_sql_url(tmp_path),
             return_embedding=True,
             embedding_field=embedding_field,
             index=index,
-            similarity=similarity
+            similarity=similarity,
+            isolation_level="AUTOCOMMIT"
         )
         _, collections = document_store.milvus_server.list_collections()
         for collection in collections:
            if collection.startswith(index):
                document_store.milvus_server.drop_collection(collection)
 
     elif document_store_type == "weaviate":
         document_store = WeaviateDocumentStore(
             weaviate_url="http://localhost:8080",
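A note on the fixtures above: as the commented-out Docker lines in `setup_postgres()` hint, the Postgres path assumes a server is already reachable at `postgresql://postgres:postgres@127.0.0.1/postgres`, e.g. one started with `docker run --name postgres_test -d -e POSTGRES_HOST_AUTH_METHOD=trust -p 5432:5432 postgres`. Flipping `SQL_TYPE` to `"postgres"` then routes every SQL-backed store in the suite through that instance, while the autouse `postgres_fixture` recreates the `public` schema around each test.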
@@ -1,4 +1,6 @@
 from unittest import mock
+import uuid
+import math
 import numpy as np
 import pandas as pd
 import pytest
@@ -7,7 +9,7 @@ from elasticsearch import Elasticsearch
 from elasticsearch.exceptions import RequestError
 
 
-from conftest import get_document_store
+from conftest import get_document_store, ensure_ids_are_correct_uuids
 from haystack.document_stores import WeaviateDocumentStore
 from haystack.document_stores.base import BaseDocumentStore
 from haystack.errors import DuplicateDocumentError
@@ -17,6 +19,18 @@ from haystack.document_stores.faiss import FAISSDocumentStore
 from haystack.nodes import EmbeddingRetriever
 from haystack.pipelines import DocumentSearchPipeline
 
 
+DOCUMENTS = [
+    {"meta": {"name": "name_1", "year": "2020", "month": "01"}, "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
+    {"meta": {"name": "name_2", "year": "2020", "month": "02"}, "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
+    {"meta": {"name": "name_3", "year": "2020", "month": "03"}, "content": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
+    {"meta": {"name": "name_4", "year": "2021", "month": "01"}, "content": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
+    {"meta": {"name": "name_5", "year": "2021", "month": "02"}, "content": "text_5", "embedding": np.random.rand(768).astype(np.float32)},
+    {"meta": {"name": "name_6", "year": "2021", "month": "03"}, "content": "text_6", "embedding": np.random.rand(768).astype(np.float64)},
+]
+
+
 @pytest.mark.elasticsearch
 def test_init_elastic_client():
     # defaults
@@ -148,8 +162,8 @@ def test_get_all_documents_with_correct_filters(document_store_with_docs):
     assert {d.meta["meta_field"] for d in documents} == {"test1", "test3"}
 
 
-def test_get_all_documents_with_correct_filters_legacy_sqlite(test_docs_xs):
-    document_store_with_docs = get_document_store("sql")
+def test_get_all_documents_with_correct_filters_legacy_sqlite(test_docs_xs, tmp_path):
+    document_store_with_docs = get_document_store("sql", tmp_path)
     document_store_with_docs.write_documents(test_docs_xs)
 
     document_store_with_docs.use_windowed_query = False
@@ -791,7 +805,7 @@ def test_multilabel_no_answer(document_store):
     assert len(multi_labels[0].answers) == 1
 
 
-@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss"], indirect=True)
+@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "milvus", "weaviate"], indirect=True)
 # Currently update_document_meta() is not implemented for Memory doc store
 def test_update_meta(document_store):
     documents = [
@@ -818,10 +832,8 @@ def test_update_meta(document_store):
 
 
 @pytest.mark.parametrize("document_store_type", ["elasticsearch", "memory"])
-def test_custom_embedding_field(document_store_type):
-    document_store = get_document_store(
-        document_store_type=document_store_type, embedding_field="custom_embedding_field"
-    )
+def test_custom_embedding_field(document_store_type, tmp_path):
+    document_store = get_document_store(document_store_type=document_store_type, tmp_path=tmp_path, embedding_field="custom_embedding_field")
     doc_to_write = {"content": "test", "custom_embedding_field": np.random.rand(768).astype(np.float32)}
     document_store.write_documents([doc_to_write])
     documents = document_store.get_all_documents(return_embedding=True)
@@ -993,4 +1005,5 @@ def test_custom_headers(document_store_with_docs: BaseDocumentStore):
         args, kwargs = mock_client.search.call_args
         assert "headers" in kwargs
         assert kwargs["headers"] == custom_headers
-        assert len(documents) > 0
\ No newline at end of file
+        assert len(documents) > 0
@@ -13,6 +13,9 @@ from haystack.document_stores.weaviate import WeaviateDocumentStore
 from haystack.pipelines import Pipeline
 from haystack.nodes.retriever.dense import EmbeddingRetriever
 
+from conftest import ensure_ids_are_correct_uuids
+
 
 DOCUMENTS = [
     {"meta": {"name": "name_1", "year": "2020", "month": "01"}, "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
     {"meta": {"name": "name_2", "year": "2020", "month": "02"}, "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
@@ -23,12 +26,14 @@ DOCUMENTS = [
 ]
 
 
 @pytest.mark.skipif(sys.platform in ['win32', 'cygwin'], reason="Test with tmp_path not working on windows runner")
-def test_faiss_index_save_and_load(tmp_path):
+def test_faiss_index_save_and_load(tmp_path, sql_url):
     document_store = FAISSDocumentStore(
-        sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
+        sql_url=sql_url,
         index="haystack_test",
-        progress_bar=False  # Just to check if the init parameters are kept
+        progress_bar=False,  # Just to check if the init parameters are kept
+        isolation_level="AUTOCOMMIT"
     )
     document_store.write_documents(DOCUMENTS)
 
@@ -74,11 +79,12 @@ def test_faiss_index_save_and_load(tmp_path):
 
 
 @pytest.mark.skipif(sys.platform in ['win32', 'cygwin'], reason="Test with tmp_path not working on windows runner")
-def test_faiss_index_save_and_load_custom_path(tmp_path):
+def test_faiss_index_save_and_load_custom_path(tmp_path, sql_url):
     document_store = FAISSDocumentStore(
-        sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
+        sql_url=sql_url,
         index="haystack_test",
-        progress_bar=False  # Just to check if the init parameters are kept
+        progress_bar=False,  # Just to check if the init parameters are kept
+        isolation_level="AUTOCOMMIT"
     )
     document_store.write_documents(DOCUMENTS)
 
@@ -128,13 +134,15 @@ def test_faiss_index_mutual_exclusive_args(tmp_path):
     with pytest.raises(ValueError):
         FAISSDocumentStore(
             sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
-            faiss_index_path=f"{tmp_path/'haystack_test'}"
+            faiss_index_path=f"{tmp_path/'haystack_test'}",
+            isolation_level="AUTOCOMMIT"
         )
 
     with pytest.raises(ValueError):
         FAISSDocumentStore(
             f"sqlite:////{tmp_path/'haystack_test.db'}",
-            faiss_index_path=f"{tmp_path/'haystack_test'}"
+            faiss_index_path=f"{tmp_path/'haystack_test'}",
+            isolation_level="AUTOCOMMIT"
         )
 
 
@@ -227,7 +235,9 @@ def test_update_with_empty_store(document_store, retriever):
 @pytest.mark.parametrize("index_factory", ["Flat", "HNSW", "IVF1,Flat"])
 def test_faiss_retrieving(index_factory, tmp_path):
     document_store = FAISSDocumentStore(
-        sql_url=f"sqlite:////{tmp_path/'test_faiss_retrieving.db'}", faiss_index_factory_str=index_factory
+        sql_url=f"sqlite:////{tmp_path/'test_faiss_retrieving.db'}",
+        faiss_index_factory_str=index_factory,
+        isolation_level="AUTOCOMMIT"
     )
 
     document_store.delete_all_documents(index="document")
@@ -396,7 +406,6 @@ def test_get_docs_with_many_filters(document_store, retriever):
     assert "2020" == documents[0].meta["year"]
 
 
-
 @pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
 @pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
 def test_pipeline(document_store, retriever):
@@ -423,7 +432,9 @@ def test_faiss_passing_index_from_outside(tmp_path):
     faiss_index.set_direct_map_type(faiss.DirectMap.Hashtable)
     faiss_index.nprobe = 2
     document_store = FAISSDocumentStore(
-        sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}", faiss_index=faiss_index, index=index
+        sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}",
+        faiss_index=faiss_index, index=index,
+        isolation_level="AUTOCOMMIT"
     )
 
     document_store.delete_documents()
@@ -437,12 +448,8 @@ def test_faiss_passing_index_from_outside(tmp_path):
     for doc in documents_indexed:
         assert 0 <= int(doc.meta["vector_id"]) <= 7
 
-def ensure_ids_are_correct_uuids(docs:list,document_store:object)->None:
-    # Weaviate currently only supports UUIDs
-    if type(document_store)==WeaviateDocumentStore:
-        for d in docs:
-            d["id"] = str(uuid.uuid4())
-
 @pytest.mark.parametrize("document_store", ["faiss", "milvus", "weaviate"], indirect=True)
 def test_cosine_similarity(document_store):
     # below we will write documents to the store and then query it to see if vectors were normalized
 
@@ -487,6 +494,7 @@ def test_cosine_similarity(document_store):
         assert not np.allclose(original_emb[0], doc.embedding, rtol=0.01)
 
 
+@pytest.mark.parametrize("document_store_dot_product_small", ["faiss", "milvus"], indirect=True)
 def test_normalize_embeddings_diff_shapes(document_store_dot_product_small):
     VEC_1 = np.array([.1, .2, .3], dtype="float32")
     document_store_dot_product_small.normalize_embedding(VEC_1)
@@ -497,6 +505,7 @@ def test_normalize_embeddings_diff_shapes(document_store_dot_product_small):
     assert np.linalg.norm(VEC_1) - 1 < 0.01
 
 
+@pytest.mark.parametrize("document_store_small", ["faiss", "milvus", "weaviate"], indirect=True)
 def test_cosine_sanity_check(document_store_small):
     VEC_1 = np.array([.1, .2, .3], dtype="float32")
     VEC_2 = np.array([.4, .5, .6], dtype="float32")
@@ -512,4 +521,4 @@ def test_cosine_sanity_check(document_store_small):
     query_results = document_store_small.query_by_embedding(query_emb=VEC_2, top_k=1, return_embedding=True)
 
     # check if faiss returns the same cosine similarity. Manual testing with faiss yielded 0.9746318
-    assert math.isclose(query_results[0].score, KNOWN_COSINE, abs_tol=0.00002)
\ No newline at end of file
+    assert math.isclose(query_results[0].score, KNOWN_COSINE, abs_tol=0.00002)
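The `KNOWN_COSINE` constant asserted above is easy to re-derive; a quick standalone check with pure NumPy, no document store involved:

```python
# Re-derive the KNOWN_COSINE constant asserted in test_cosine_sanity_check.
import numpy as np

vec_1 = np.array([.1, .2, .3], dtype="float32")
vec_2 = np.array([.4, .5, .6], dtype="float32")
cosine = np.dot(vec_1, vec_2) / (np.linalg.norm(vec_1) * np.linalg.norm(vec_2))
print(float(cosine))  # ~0.9746318, up to float32 rounding
```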
@@ -30,16 +30,16 @@ DOCUMENTS_XS = [
 
 
 @pytest.fixture(params=["weaviate"])
-def document_store_with_docs(request):
-    document_store = get_document_store(request.param)
+def document_store_with_docs(request, tmp_path):
+    document_store = get_document_store(request.param, tmp_path=tmp_path)
     document_store.write_documents(DOCUMENTS_XS)
     yield document_store
     document_store.delete_documents()
 
 
 @pytest.fixture(params=["weaviate"])
-def document_store(request):
-    document_store = get_document_store(request.param)
+def document_store(request, tmp_path):
+    document_store = get_document_store(request.param, tmp_path=tmp_path)
     yield document_store
     document_store.delete_documents()
Sara Zan