diff --git a/docs/_src/api/api/document_store.md b/docs/_src/api/api/document_store.md index 7ad468855..93d9f3564 100644 --- a/docs/_src/api/api/document_store.md +++ b/docs/_src/api/api/document_store.md @@ -380,7 +380,7 @@ Write annotation labels into document store. #### update\_document\_meta ```python - | update_document_meta(id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None) + | update_document_meta(id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None, index: str = None) ``` Update the metadata dictionary of a document by specifying its string id @@ -952,7 +952,7 @@ class SQLDocumentStore(BaseDocumentStore) #### \_\_init\_\_ ```python - | __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite", check_same_thread: bool = False) + | __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite", check_same_thread: bool = False, isolation_level: str = None) ``` An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends. @@ -970,6 +970,7 @@ An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL bac fail: an error is raised if the document ID of the document being added already exists. - `check_same_thread`: Set to False to mitigate multithreading issues in older SQLite versions (see https://docs.sqlalchemy.org/en/14/dialects/sqlite.html?highlight=check_same_thread#threading-pooling-behavior) +- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level) #### get\_document\_by\_id @@ -1094,7 +1095,7 @@ Set vector IDs for all documents as None #### update\_document\_meta ```python - | update_document_meta(id: str, meta: Dict[str, str]) + | update_document_meta(id: str, meta: Dict[str, str], index: str = None) ``` Update the metadata dictionary of a document by specifying its string id @@ -1202,7 +1203,7 @@ the vector embeddings are indexed in a FAISS Index. #### \_\_init\_\_ ```python - | __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, **kwargs, ,) + | __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, isolation_level: str = None, **kwargs, ,) ``` **Arguments**: @@ -1248,6 +1249,7 @@ the vector embeddings are indexed in a FAISS Index. If specified no other params besides faiss_config_path must be specified. - `faiss_config_path`: Stored FAISS initial configuration parameters. Can be created via calling `save()` +- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level) #### write\_documents @@ -1479,7 +1481,7 @@ Usage: #### \_\_init\_\_ ```python - | __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = None, embedding_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs, ,) + | __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = None, embedding_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', isolation_level: str = None, **kwargs, ,) ``` **Arguments**: @@ -1525,6 +1527,7 @@ Note that an overly large index_file_size value may cause failure to load a segm overwrite: Update any existing documents with the same ID when adding documents. fail: an error is raised if the document ID of the document being added already exists. +- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level) #### write\_documents @@ -1863,7 +1866,7 @@ None #### update\_document\_meta ```python - | update_document_meta(id: str, meta: Dict[str, str]) + | update_document_meta(id: str, meta: Dict[str, str], index: str = None) ``` Update the metadata dictionary of a document by specifying its string id. diff --git a/haystack/document_stores/elasticsearch.py b/haystack/document_stores/elasticsearch.py index 6e3951bba..4a4072004 100644 --- a/haystack/document_stores/elasticsearch.py +++ b/haystack/document_stores/elasticsearch.py @@ -551,12 +551,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore): if labels_to_index: bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type, headers=headers) - def update_document_meta(self, id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None): + def update_document_meta(self, id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None, index: str = None): """ Update the metadata dictionary of a document by specifying its string id """ + if not index: + index = self.index body = {"doc": meta} - self.client.update(index=self.index, id=id, body=body, refresh=self.refresh_type, headers=headers) + self.client.update(index=index, id=id, body=body, refresh=self.refresh_type, headers=headers) def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None, only_documents_without_embedding: bool = False, headers: Optional[Dict[str, str]] = None) -> int: diff --git a/haystack/document_stores/faiss.py b/haystack/document_stores/faiss.py index b4fa1e80d..1702330c8 100644 --- a/haystack/document_stores/faiss.py +++ b/haystack/document_stores/faiss.py @@ -50,6 +50,7 @@ class FAISSDocumentStore(SQLDocumentStore): duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, + isolation_level: str = None, **kwargs, ): """ @@ -94,6 +95,7 @@ class FAISSDocumentStore(SQLDocumentStore): If specified no other params besides faiss_config_path must be specified. :param faiss_config_path: Stored FAISS initial configuration parameters. Can be created via calling `save()` + :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level) """ # special case if we want to load an existing index from disk # load init params from disk and run init again @@ -115,7 +117,8 @@ class FAISSDocumentStore(SQLDocumentStore): index=index, similarity=similarity, embedding_field=embedding_field, - progress_bar=progress_bar + progress_bar=progress_bar, + isolation_level=isolation_level ) if similarity in ("dot_product", "cosine"): @@ -155,7 +158,8 @@ class FAISSDocumentStore(SQLDocumentStore): super().__init__( url=sql_url, index=index, - duplicate_documents=duplicate_documents + duplicate_documents=duplicate_documents, + isolation_level=isolation_level ) self._validate_index_sync() diff --git a/haystack/document_stores/milvus.py b/haystack/document_stores/milvus.py index cce2e1279..4eeec700b 100644 --- a/haystack/document_stores/milvus.py +++ b/haystack/document_stores/milvus.py @@ -53,6 +53,7 @@ class MilvusDocumentStore(SQLDocumentStore): embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', + isolation_level: str = None, **kwargs, ): """ @@ -97,6 +98,7 @@ class MilvusDocumentStore(SQLDocumentStore): overwrite: Update any existing documents with the same ID when adding documents. fail: an error is raised if the document ID of the document being added already exists. + :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level) """ # save init parameters to enable export of component config as YAML self.set_config( @@ -104,6 +106,7 @@ class MilvusDocumentStore(SQLDocumentStore): embedding_dim=embedding_dim, index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param, search_param=search_param, duplicate_documents=duplicate_documents, return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar, + isolation_level=isolation_level ) self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool) @@ -139,7 +142,8 @@ class MilvusDocumentStore(SQLDocumentStore): super().__init__( url=sql_url, index=index, - duplicate_documents=duplicate_documents + duplicate_documents=duplicate_documents, + isolation_level=isolation_level, ) def __del__(self): diff --git a/haystack/document_stores/milvus2x.py b/haystack/document_stores/milvus2x.py index 163198a3f..4ea662661 100644 --- a/haystack/document_stores/milvus2x.py +++ b/haystack/document_stores/milvus2x.py @@ -73,6 +73,7 @@ class Milvus2DocumentStore(SQLDocumentStore): custom_fields: Optional[List[Any]] = None, progress_bar: bool = True, duplicate_documents: str = 'overwrite', + isolation_level: str = None ): """ :param sql_url: SQL connection URL for storing document texts and metadata. It defaults to a local, file based SQLite DB. For large scale @@ -118,6 +119,7 @@ class Milvus2DocumentStore(SQLDocumentStore): overwrite: Update any existing documents with the same ID when adding documents. fail: an error is raised if the document ID of the document being added already exists. + :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level) """ # save init parameters to enable export of component config as YAML @@ -127,6 +129,7 @@ class Milvus2DocumentStore(SQLDocumentStore): search_param=search_param, duplicate_documents=duplicate_documents, id_field=id_field, return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar, custom_fields=custom_fields, + isolation_level=isolation_level ) logger.warning("Milvus2DocumentStore is in experimental state until Milvus 2.0 is released") @@ -173,7 +176,8 @@ class Milvus2DocumentStore(SQLDocumentStore): super().__init__( url=sql_url, index=index, - duplicate_documents=duplicate_documents + duplicate_documents=duplicate_documents, + isolation_level=isolation_level, ) def _create_collection_and_index_if_not_exist( diff --git a/haystack/document_stores/sql.py b/haystack/document_stores/sql.py index 666870238..7ef5db791 100644 --- a/haystack/document_stores/sql.py +++ b/haystack/document_stores/sql.py @@ -4,7 +4,7 @@ import logging import itertools import numpy as np from uuid import uuid4 -from sqlalchemy import and_, func, create_engine, Column, String, DateTime, ForeignKey, Boolean, Text, text, JSON +from sqlalchemy import and_, func, create_engine, Column, String, DateTime, ForeignKey, Boolean, Text, text, JSON, ForeignKeyConstraint from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship, sessionmaker from sqlalchemy.sql import case, null @@ -33,9 +33,6 @@ class DocumentORM(ORMBase): # primary key in combination with id to allow the same doc in different indices index = Column(String(100), nullable=False, primary_key=True) vector_id = Column(String(100), unique=True, nullable=True) - - # labels = relationship("LabelORM", back_populates="document") - # speeds up queries for get_documents_by_vector_ids() by having a single query that returns joined metadata meta = relationship("MetaDocumentORM", back_populates="documents", lazy="joined") @@ -45,36 +42,18 @@ class MetaDocumentORM(ORMBase): name = Column(String(100), index=True) value = Column(String(1000), index=True) - document_id = Column( - String(100), - ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"), - nullable=False, - index=True - ) - documents = relationship("DocumentORM", back_populates="meta") - -class MetaLabelORM(ORMBase): - __tablename__ = "meta_label" - - name = Column(String(100), index=True) - value = Column(String(1000), index=True) - label_id = Column( - String(100), - ForeignKey("label.id", ondelete="CASCADE", onupdate="CASCADE"), - nullable=False, - index=True - ) - - labels = relationship("LabelORM", back_populates="meta") + document_id = Column(String(100), nullable=False, index=True) + document_index = Column(String(100), nullable=False, index=True) + __table_args__ = (ForeignKeyConstraint([document_id, document_index], + [DocumentORM.id, DocumentORM.index], + ondelete="CASCADE", onupdate="CASCADE"), {}) #type: ignore class LabelORM(ORMBase): __tablename__ = "label" - # document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"), nullable=False) - index = Column(String(100), nullable=False, primary_key=True) query = Column(Text, nullable=False) answer = Column(JSON, nullable=True) @@ -86,7 +65,21 @@ class LabelORM(ORMBase): pipeline_id = Column(String(500), nullable=True) meta = relationship("MetaLabelORM", back_populates="labels", lazy="joined") - # document = relationship("DocumentORM", back_populates="labels") + + +class MetaLabelORM(ORMBase): + __tablename__ = "meta_label" + + name = Column(String(100), index=True) + value = Column(String(1000), index=True) + labels = relationship("LabelORM", back_populates="meta") + + label_id = Column(String(100), nullable=False, index=True) + label_index = Column(String(100), nullable=False, index=True) + __table_args__ = (ForeignKeyConstraint([label_id, label_index], + [LabelORM.id, LabelORM.index], + ondelete="CASCADE", onupdate="CASCADE"), {}) #type: ignore + class SQLDocumentStore(BaseDocumentStore): @@ -96,7 +89,8 @@ class SQLDocumentStore(BaseDocumentStore): index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite", - check_same_thread: bool = False + check_same_thread: bool = False, + isolation_level: str = None ): """ An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends. @@ -112,18 +106,21 @@ class SQLDocumentStore(BaseDocumentStore): fail: an error is raised if the document ID of the document being added already exists. :param check_same_thread: Set to False to mitigate multithreading issues in older SQLite versions (see https://docs.sqlalchemy.org/en/14/dialects/sqlite.html?highlight=check_same_thread#threading-pooling-behavior) + :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level) """ # save init parameters to enable export of component config as YAML self.set_config( url=url, index=index, label_index=label_index, duplicate_documents=duplicate_documents, check_same_thread=check_same_thread ) - + create_engine_params = {} + if isolation_level: + create_engine_params["isolation_level"] = isolation_level if "sqlite" in url: - engine = create_engine(url, connect_args={'check_same_thread': check_same_thread}) + engine = create_engine(url, connect_args={'check_same_thread': check_same_thread}, **create_engine_params) else: - engine = create_engine(url) - ORMBase.metadata.create_all(engine) + engine = create_engine(url, **create_engine_params) + Base.metadata.create_all(engine) Session = sessionmaker(bind=engine) self.session = Session() self.index: str = index @@ -461,12 +458,14 @@ class SQLDocumentStore(BaseDocumentStore): self.session.query(DocumentORM).filter_by(index=index).update({DocumentORM.vector_id: null()}) self.session.commit() - def update_document_meta(self, id: str, meta: Dict[str, str]): + def update_document_meta(self, id: str, meta: Dict[str, str], index: str = None): """ Update the metadata dictionary of a document by specifying its string id """ - self.session.query(MetaDocumentORM).filter_by(document_id=id).delete() - meta_orms = [MetaDocumentORM(name=key, value=value, document_id=id) for key, value in meta.items()] + if not index: + index = self.index + self.session.query(MetaDocumentORM).filter_by(document_id=id, document_index=index).delete() + meta_orms = [MetaDocumentORM(name=key, value=value, document_id=id, document_index=index) for key, value in meta.items()] for m in meta_orms: self.session.add(m) self.session.commit() diff --git a/haystack/document_stores/weaviate.py b/haystack/document_stores/weaviate.py index b163d5c25..8b2ec54ea 100644 --- a/haystack/document_stores/weaviate.py +++ b/haystack/document_stores/weaviate.py @@ -486,11 +486,13 @@ class WeaviateDocumentStore(BaseDocumentStore): progress_bar.update(batch_size) progress_bar.close() - def update_document_meta(self, id: str, meta: Dict[str, str]): + def update_document_meta(self, id: str, meta: Dict[str, str], index: str = None): """ Update the metadata dictionary of a document by specifying its string id. """ - self.weaviate_client.data_object.update(meta, class_name=self.index, uuid=id) + if not index: + index = self.index + self.weaviate_client.data_object.update(meta, class_name=index, uuid=id) def get_embedding_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int: """ diff --git a/test/conftest.py b/test/conftest.py index 67e9c2b35..a9a530e92 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,6 +3,9 @@ import time from subprocess import run from sys import platform import gc +import uuid +import logging +from sqlalchemy import create_engine, text import numpy as np import psutil @@ -40,6 +43,11 @@ from haystack.nodes.translator import TransformersTranslator from haystack.nodes.question_generator import QuestionGenerator +# To manually run the tests with default PostgreSQL instead of SQLite, switch the lines below +SQL_TYPE = "sqlite" +# SQL_TYPE = "postgres" + + def pytest_addoption(parser): parser.addoption("--document_store_type", action="store", default="elasticsearch, faiss, memory, milvus, weaviate") @@ -477,58 +485,112 @@ def get_retriever(retriever_type, document_store): return retriever +def ensure_ids_are_correct_uuids(docs:list,document_store:object)->None: + # Weaviate currently only supports UUIDs + if type(document_store)==WeaviateDocumentStore: + for d in docs: + d["id"] = str(uuid.uuid4()) + + @pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate"]) -def document_store_with_docs(request, test_docs_xs): +def document_store_with_docs(request, test_docs_xs, tmp_path): embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768)) - document_store = get_document_store(request.param, embedding_dim.args[0]) + document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], tmp_path=tmp_path) document_store.write_documents(test_docs_xs) yield document_store document_store.delete_documents() - @pytest.fixture -def document_store(request): +def document_store(request, tmp_path): embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768)) - document_store = get_document_store(request.param, embedding_dim.args[0]) + document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], tmp_path=tmp_path) yield document_store document_store.delete_documents() @pytest.fixture(params=["memory", "faiss", "milvus", "elasticsearch"]) -def document_store_dot_product(request): +def document_store_dot_product(request, tmp_path): embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768)) - document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product") + document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path) yield document_store document_store.delete_documents() @pytest.fixture(params=["memory", "faiss", "milvus", "elasticsearch"]) -def document_store_dot_product_with_docs(request, test_docs_xs): +def document_store_dot_product_with_docs(request, test_docs_xs, tmp_path): embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768)) - document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product") + document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path) document_store.write_documents(test_docs_xs) yield document_store document_store.delete_documents() @pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus"]) -def document_store_dot_product_small(request): +def document_store_dot_product_small(request, tmp_path): embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(3)) - document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product") + document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path) yield document_store document_store.delete_documents() @pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate"]) -def document_store_small(request): +def document_store_small(request, tmp_path): embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(3)) - document_store = get_document_store(request.param, embedding_dim.args[0], similarity="cosine") + document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="cosine", tmp_path=tmp_path) yield document_store document_store.delete_documents() -def get_document_store(document_store_type, embedding_dim=768, embedding_field="embedding", index="haystack_test", similarity:str="cosine"): # cosine is default similarity as dot product is not supported by Weaviate + +@pytest.fixture(scope="function", autouse=True) +def postgres_fixture(): + if SQL_TYPE == "postgres": + setup_postgres() + yield + teardown_postgres() + else: + yield + + +@pytest.fixture +def sql_url(tmp_path): + return get_sql_url(tmp_path) + + +def get_sql_url(tmp_path): + if SQL_TYPE == "postgres": + return "postgresql://postgres:postgres@127.0.0.1/postgres" + else: + return f"sqlite:///{tmp_path}/haystack_test.db" + + +def setup_postgres(): + # status = subprocess.run(["docker run --name postgres_test -d -e POSTGRES_HOST_AUTH_METHOD=trust -p 5432:5432 postgres"], shell=True) + # if status.returncode: + # logging.warning("Tried to start PostgreSQL through Docker but this failed. It is likely that there is already an existing instance running.") + # else: + # sleep(5) + engine = create_engine('postgresql://postgres:postgres@127.0.0.1/postgres', isolation_level='AUTOCOMMIT') + + with engine.connect() as connection: + try: + connection.execute(text('DROP SCHEMA public CASCADE')) + except Exception as e: + logging.error(e) + connection.execute(text('CREATE SCHEMA public;')) + connection.execute(text('SET SESSION idle_in_transaction_session_timeout = "1s";')) + + +def teardown_postgres(): + engine = create_engine('postgresql://postgres:postgres@127.0.0.1/postgres', isolation_level='AUTOCOMMIT') + with engine.connect() as connection: + connection.execute(text('DROP SCHEMA public CASCADE')) + connection.close() + + +def get_document_store(document_store_type, tmp_path, embedding_dim=768, embedding_field="embedding", index="haystack_test", similarity:str="cosine"): # cosine is default similarity as dot product is not supported by Weaviate if document_store_type == "sql": - document_store = SQLDocumentStore(url="sqlite://", index=index) + document_store = SQLDocumentStore(url=get_sql_url(tmp_path), index=index, isolation_level="AUTOCOMMIT") + elif document_store_type == "memory": document_store = InMemoryDocumentStore( - return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, index=index, similarity=similarity - ) + return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, index=index, similarity=similarity) + elif document_store_type == "elasticsearch": # make sure we start from a fresh index client = Elasticsearch() @@ -536,28 +598,33 @@ def get_document_store(document_store_type, embedding_dim=768, embedding_field=" document_store = ElasticsearchDocumentStore( index=index, return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, similarity=similarity ) + elif document_store_type == "faiss": document_store = FAISSDocumentStore( embedding_dim=embedding_dim, - sql_url="sqlite://", + sql_url=get_sql_url(tmp_path), return_embedding=True, embedding_field=embedding_field, index=index, - similarity=similarity + similarity=similarity, + isolation_level="AUTOCOMMIT" ) + elif document_store_type == "milvus": document_store = MilvusDocumentStore( embedding_dim=embedding_dim, - sql_url="sqlite://", + sql_url=get_sql_url(tmp_path), return_embedding=True, embedding_field=embedding_field, index=index, - similarity=similarity + similarity=similarity, + isolation_level="AUTOCOMMIT" ) _, collections = document_store.milvus_server.list_collections() for collection in collections: if collection.startswith(index): document_store.milvus_server.drop_collection(collection) + elif document_store_type == "weaviate": document_store = WeaviateDocumentStore( weaviate_url="http://localhost:8080", diff --git a/test/test_document_store.py b/test/test_document_store.py index a39859e10..30d55f4ba 100644 --- a/test/test_document_store.py +++ b/test/test_document_store.py @@ -1,4 +1,6 @@ from unittest import mock +import uuid +import math import numpy as np import pandas as pd import pytest @@ -7,7 +9,7 @@ from elasticsearch import Elasticsearch from elasticsearch.exceptions import RequestError -from conftest import get_document_store +from conftest import get_document_store, ensure_ids_are_correct_uuids from haystack.document_stores import WeaviateDocumentStore from haystack.document_stores.base import BaseDocumentStore from haystack.errors import DuplicateDocumentError @@ -17,6 +19,18 @@ from haystack.document_stores.faiss import FAISSDocumentStore from haystack.nodes import EmbeddingRetriever from haystack.pipelines import DocumentSearchPipeline + +DOCUMENTS = [ + {"meta": {"name": "name_1", "year": "2020", "month": "01"}, "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)}, + {"meta": {"name": "name_2", "year": "2020", "month": "02"}, "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)}, + {"meta": {"name": "name_3", "year": "2020", "month": "03"}, "content": "text_3", "embedding": np.random.rand(768).astype(np.float64)}, + {"meta": {"name": "name_4", "year": "2021", "month": "01"}, "content": "text_4", "embedding": np.random.rand(768).astype(np.float32)}, + {"meta": {"name": "name_5", "year": "2021", "month": "02"}, "content": "text_5", "embedding": np.random.rand(768).astype(np.float32)}, + {"meta": {"name": "name_6", "year": "2021", "month": "03"}, "content": "text_6", "embedding": np.random.rand(768).astype(np.float64)}, +] + + + @pytest.mark.elasticsearch def test_init_elastic_client(): # defaults @@ -148,8 +162,8 @@ def test_get_all_documents_with_correct_filters(document_store_with_docs): assert {d.meta["meta_field"] for d in documents} == {"test1", "test3"} -def test_get_all_documents_with_correct_filters_legacy_sqlite(test_docs_xs): - document_store_with_docs = get_document_store("sql") +def test_get_all_documents_with_correct_filters_legacy_sqlite(test_docs_xs, tmp_path): + document_store_with_docs = get_document_store("sql", tmp_path) document_store_with_docs.write_documents(test_docs_xs) document_store_with_docs.use_windowed_query = False @@ -791,7 +805,7 @@ def test_multilabel_no_answer(document_store): assert len(multi_labels[0].answers) == 1 -@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss"], indirect=True) +@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "milvus", "weaviate"], indirect=True) # Currently update_document_meta() is not implemented for Memory doc store def test_update_meta(document_store): documents = [ @@ -818,10 +832,8 @@ def test_update_meta(document_store): @pytest.mark.parametrize("document_store_type", ["elasticsearch", "memory"]) -def test_custom_embedding_field(document_store_type): - document_store = get_document_store( - document_store_type=document_store_type, embedding_field="custom_embedding_field" - ) +def test_custom_embedding_field(document_store_type, tmp_path): + document_store = get_document_store(document_store_type=document_store_type, tmp_path=tmp_path, embedding_field="custom_embedding_field") doc_to_write = {"content": "test", "custom_embedding_field": np.random.rand(768).astype(np.float32)} document_store.write_documents([doc_to_write]) documents = document_store.get_all_documents(return_embedding=True) @@ -993,4 +1005,5 @@ def test_custom_headers(document_store_with_docs: BaseDocumentStore): args, kwargs = mock_client.search.call_args assert "headers" in kwargs assert kwargs["headers"] == custom_headers - assert len(documents) > 0 \ No newline at end of file + assert len(documents) > 0 + diff --git a/test/test_faiss_and_milvus.py b/test/test_faiss_and_milvus.py index 8e19dcfde..221ce6e4d 100644 --- a/test/test_faiss_and_milvus.py +++ b/test/test_faiss_and_milvus.py @@ -13,6 +13,9 @@ from haystack.document_stores.weaviate import WeaviateDocumentStore from haystack.pipelines import Pipeline from haystack.nodes.retriever.dense import EmbeddingRetriever +from conftest import ensure_ids_are_correct_uuids + + DOCUMENTS = [ {"meta": {"name": "name_1", "year": "2020", "month": "01"}, "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)}, {"meta": {"name": "name_2", "year": "2020", "month": "02"}, "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)}, @@ -23,12 +26,14 @@ DOCUMENTS = [ ] + @pytest.mark.skipif(sys.platform in ['win32', 'cygwin'], reason="Test with tmp_path not working on windows runner") -def test_faiss_index_save_and_load(tmp_path): +def test_faiss_index_save_and_load(tmp_path, sql_url): document_store = FAISSDocumentStore( - sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}", + sql_url=sql_url, index="haystack_test", - progress_bar=False # Just to check if the init parameters are kept + progress_bar=False, # Just to check if the init parameters are kept + isolation_level="AUTOCOMMIT" ) document_store.write_documents(DOCUMENTS) @@ -74,11 +79,12 @@ def test_faiss_index_save_and_load(tmp_path): @pytest.mark.skipif(sys.platform in ['win32', 'cygwin'], reason="Test with tmp_path not working on windows runner") -def test_faiss_index_save_and_load_custom_path(tmp_path): +def test_faiss_index_save_and_load_custom_path(tmp_path, sql_url): document_store = FAISSDocumentStore( - sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}", + sql_url=sql_url, index="haystack_test", - progress_bar=False # Just to check if the init parameters are kept + progress_bar=False, # Just to check if the init parameters are kept + isolation_level="AUTOCOMMIT" ) document_store.write_documents(DOCUMENTS) @@ -128,13 +134,15 @@ def test_faiss_index_mutual_exclusive_args(tmp_path): with pytest.raises(ValueError): FAISSDocumentStore( sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}", - faiss_index_path=f"{tmp_path/'haystack_test'}" + faiss_index_path=f"{tmp_path/'haystack_test'}", + isolation_level="AUTOCOMMIT" ) with pytest.raises(ValueError): FAISSDocumentStore( f"sqlite:////{tmp_path/'haystack_test.db'}", - faiss_index_path=f"{tmp_path/'haystack_test'}" + faiss_index_path=f"{tmp_path/'haystack_test'}", + isolation_level="AUTOCOMMIT" ) @@ -227,7 +235,9 @@ def test_update_with_empty_store(document_store, retriever): @pytest.mark.parametrize("index_factory", ["Flat", "HNSW", "IVF1,Flat"]) def test_faiss_retrieving(index_factory, tmp_path): document_store = FAISSDocumentStore( - sql_url=f"sqlite:////{tmp_path/'test_faiss_retrieving.db'}", faiss_index_factory_str=index_factory + sql_url=f"sqlite:////{tmp_path/'test_faiss_retrieving.db'}", + faiss_index_factory_str=index_factory, + isolation_level="AUTOCOMMIT" ) document_store.delete_all_documents(index="document") @@ -396,7 +406,6 @@ def test_get_docs_with_many_filters(document_store, retriever): assert "2020" == documents[0].meta["year"] - @pytest.mark.parametrize("retriever", ["embedding"], indirect=True) @pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True) def test_pipeline(document_store, retriever): @@ -423,7 +432,9 @@ def test_faiss_passing_index_from_outside(tmp_path): faiss_index.set_direct_map_type(faiss.DirectMap.Hashtable) faiss_index.nprobe = 2 document_store = FAISSDocumentStore( - sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}", faiss_index=faiss_index, index=index + sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}", + faiss_index=faiss_index, index=index, + isolation_level="AUTOCOMMIT" ) document_store.delete_documents() @@ -437,12 +448,8 @@ def test_faiss_passing_index_from_outside(tmp_path): for doc in documents_indexed: assert 0 <= int(doc.meta["vector_id"]) <= 7 -def ensure_ids_are_correct_uuids(docs:list,document_store:object)->None: - # Weaviate currently only supports UUIDs - if type(document_store)==WeaviateDocumentStore: - for d in docs: - d["id"] = str(uuid.uuid4()) +@pytest.mark.parametrize("document_store", ["faiss", "milvus", "weaviate"], indirect=True) def test_cosine_similarity(document_store): # below we will write documents to the store and then query it to see if vectors were normalized @@ -487,6 +494,7 @@ def test_cosine_similarity(document_store): assert not np.allclose(original_emb[0], doc.embedding, rtol=0.01) +@pytest.mark.parametrize("document_store_dot_product_small", ["faiss", "milvus"], indirect=True) def test_normalize_embeddings_diff_shapes(document_store_dot_product_small): VEC_1 = np.array([.1, .2, .3], dtype="float32") document_store_dot_product_small.normalize_embedding(VEC_1) @@ -497,6 +505,7 @@ def test_normalize_embeddings_diff_shapes(document_store_dot_product_small): assert np.linalg.norm(VEC_1) - 1 < 0.01 +@pytest.mark.parametrize("document_store_small", ["faiss", "milvus", "weaviate"], indirect=True) def test_cosine_sanity_check(document_store_small): VEC_1 = np.array([.1, .2, .3], dtype="float32") VEC_2 = np.array([.4, .5, .6], dtype="float32") @@ -512,4 +521,4 @@ def test_cosine_sanity_check(document_store_small): query_results = document_store_small.query_by_embedding(query_emb=VEC_2, top_k=1, return_embedding=True) # check if faiss returns the same cosine similarity. Manual testing with faiss yielded 0.9746318 - assert math.isclose(query_results[0].score, KNOWN_COSINE, abs_tol=0.00002) + assert math.isclose(query_results[0].score, KNOWN_COSINE, abs_tol=0.00002) \ No newline at end of file diff --git a/test/test_weaviate.py b/test/test_weaviate.py index b6a3de977..49ece2f52 100644 --- a/test/test_weaviate.py +++ b/test/test_weaviate.py @@ -30,16 +30,16 @@ DOCUMENTS_XS = [ @pytest.fixture(params=["weaviate"]) -def document_store_with_docs(request): - document_store = get_document_store(request.param) +def document_store_with_docs(request, tmp_path): + document_store = get_document_store(request.param, tmp_path=tmp_path) document_store.write_documents(DOCUMENTS_XS) yield document_store document_store.delete_documents() @pytest.fixture(params=["weaviate"]) -def document_store(request): - document_store = get_document_store(request.param) +def document_store(request, tmp_path): + document_store = get_document_store(request.param, tmp_path=tmp_path) yield document_store document_store.delete_documents()