diff --git a/docs/_src/api/api/document_store.md b/docs/_src/api/api/document_store.md
index 7ad468855..93d9f3564 100644
--- a/docs/_src/api/api/document_store.md
+++ b/docs/_src/api/api/document_store.md
@@ -380,7 +380,7 @@ Write annotation labels into document store.
#### update\_document\_meta
```python
- | update_document_meta(id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None)
+ | update_document_meta(id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None, index: str = None)
```
Update the metadata dictionary of a document by specifying its string id
@@ -952,7 +952,7 @@ class SQLDocumentStore(BaseDocumentStore)
#### \_\_init\_\_
```python
- | __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite", check_same_thread: bool = False)
+ | __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite", check_same_thread: bool = False, isolation_level: str = None)
```
An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.
@@ -970,6 +970,7 @@ An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL bac
fail: an error is raised if the document ID of the document being added already
exists.
- `check_same_thread`: Set to False to mitigate multithreading issues in older SQLite versions (see https://docs.sqlalchemy.org/en/14/dialects/sqlite.html?highlight=check_same_thread#threading-pooling-behavior)
+- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
#### get\_document\_by\_id
@@ -1094,7 +1095,7 @@ Set vector IDs for all documents as None
#### update\_document\_meta
```python
- | update_document_meta(id: str, meta: Dict[str, str])
+ | update_document_meta(id: str, meta: Dict[str, str], index: str = None)
```
Update the metadata dictionary of a document by specifying its string id
@@ -1202,7 +1203,7 @@ the vector embeddings are indexed in a FAISS Index.
#### \_\_init\_\_
```python
- | __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, **kwargs, ,)
+ | __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, isolation_level: str = None, **kwargs, ,)
```
**Arguments**:
@@ -1248,6 +1249,7 @@ the vector embeddings are indexed in a FAISS Index.
If specified no other params besides faiss_config_path must be specified.
- `faiss_config_path`: Stored FAISS initial configuration parameters.
Can be created via calling `save()`
+- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
#### write\_documents
@@ -1479,7 +1481,7 @@ Usage:
#### \_\_init\_\_
```python
- | __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = None, embedding_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs, ,)
+ | __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = None, embedding_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', isolation_level: str = None, **kwargs, ,)
```
**Arguments**:
@@ -1525,6 +1527,7 @@ Note that an overly large index_file_size value may cause failure to load a segm
overwrite: Update any existing documents with the same ID when adding documents.
fail: an error is raised if the document ID of the document being added already
exists.
+- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
#### write\_documents
@@ -1863,7 +1866,7 @@ None
#### update\_document\_meta
```python
- | update_document_meta(id: str, meta: Dict[str, str])
+ | update_document_meta(id: str, meta: Dict[str, str], index: str = None)
```
Update the metadata dictionary of a document by specifying its string id.
diff --git a/haystack/document_stores/elasticsearch.py b/haystack/document_stores/elasticsearch.py
index 6e3951bba..4a4072004 100644
--- a/haystack/document_stores/elasticsearch.py
+++ b/haystack/document_stores/elasticsearch.py
@@ -551,12 +551,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
if labels_to_index:
bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type, headers=headers)
- def update_document_meta(self, id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None):
+ def update_document_meta(self, id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None, index: str = None):
"""
Update the metadata dictionary of a document by specifying its string id
"""
+ if not index:
+ index = self.index
body = {"doc": meta}
- self.client.update(index=self.index, id=id, body=body, refresh=self.refresh_type, headers=headers)
+ self.client.update(index=index, id=id, body=body, refresh=self.refresh_type, headers=headers)
def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None,
only_documents_without_embedding: bool = False, headers: Optional[Dict[str, str]] = None) -> int:
diff --git a/haystack/document_stores/faiss.py b/haystack/document_stores/faiss.py
index b4fa1e80d..1702330c8 100644
--- a/haystack/document_stores/faiss.py
+++ b/haystack/document_stores/faiss.py
@@ -50,6 +50,7 @@ class FAISSDocumentStore(SQLDocumentStore):
duplicate_documents: str = 'overwrite',
faiss_index_path: Union[str, Path] = None,
faiss_config_path: Union[str, Path] = None,
+ isolation_level: str = None,
**kwargs,
):
"""
@@ -94,6 +95,7 @@ class FAISSDocumentStore(SQLDocumentStore):
If specified no other params besides faiss_config_path must be specified.
:param faiss_config_path: Stored FAISS initial configuration parameters.
Can be created via calling `save()`
+ :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
"""
# special case if we want to load an existing index from disk
# load init params from disk and run init again
@@ -115,7 +117,8 @@ class FAISSDocumentStore(SQLDocumentStore):
index=index,
similarity=similarity,
embedding_field=embedding_field,
- progress_bar=progress_bar
+ progress_bar=progress_bar,
+ isolation_level=isolation_level
)
if similarity in ("dot_product", "cosine"):
@@ -155,7 +158,8 @@ class FAISSDocumentStore(SQLDocumentStore):
super().__init__(
url=sql_url,
index=index,
- duplicate_documents=duplicate_documents
+ duplicate_documents=duplicate_documents,
+ isolation_level=isolation_level
)
self._validate_index_sync()
diff --git a/haystack/document_stores/milvus.py b/haystack/document_stores/milvus.py
index cce2e1279..4eeec700b 100644
--- a/haystack/document_stores/milvus.py
+++ b/haystack/document_stores/milvus.py
@@ -53,6 +53,7 @@ class MilvusDocumentStore(SQLDocumentStore):
embedding_field: str = "embedding",
progress_bar: bool = True,
duplicate_documents: str = 'overwrite',
+ isolation_level: str = None,
**kwargs,
):
"""
@@ -97,6 +98,7 @@ class MilvusDocumentStore(SQLDocumentStore):
overwrite: Update any existing documents with the same ID when adding documents.
fail: an error is raised if the document ID of the document being added already
exists.
+ :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
"""
# save init parameters to enable export of component config as YAML
self.set_config(
@@ -104,6 +106,7 @@ class MilvusDocumentStore(SQLDocumentStore):
embedding_dim=embedding_dim, index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
search_param=search_param, duplicate_documents=duplicate_documents,
return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar,
+ isolation_level=isolation_level
)
self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool)
@@ -139,7 +142,8 @@ class MilvusDocumentStore(SQLDocumentStore):
super().__init__(
url=sql_url,
index=index,
- duplicate_documents=duplicate_documents
+ duplicate_documents=duplicate_documents,
+ isolation_level=isolation_level,
)
def __del__(self):
diff --git a/haystack/document_stores/milvus2x.py b/haystack/document_stores/milvus2x.py
index 163198a3f..4ea662661 100644
--- a/haystack/document_stores/milvus2x.py
+++ b/haystack/document_stores/milvus2x.py
@@ -73,6 +73,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
custom_fields: Optional[List[Any]] = None,
progress_bar: bool = True,
duplicate_documents: str = 'overwrite',
+ isolation_level: str = None
):
"""
:param sql_url: SQL connection URL for storing document texts and metadata. It defaults to a local, file based SQLite DB. For large scale
@@ -118,6 +119,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
overwrite: Update any existing documents with the same ID when adding documents.
fail: an error is raised if the document ID of the document being added already
exists.
+ :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
"""
# save init parameters to enable export of component config as YAML
@@ -127,6 +129,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
search_param=search_param, duplicate_documents=duplicate_documents, id_field=id_field,
return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar,
custom_fields=custom_fields,
+ isolation_level=isolation_level
)
logger.warning("Milvus2DocumentStore is in experimental state until Milvus 2.0 is released")
@@ -173,7 +176,8 @@ class Milvus2DocumentStore(SQLDocumentStore):
super().__init__(
url=sql_url,
index=index,
- duplicate_documents=duplicate_documents
+ duplicate_documents=duplicate_documents,
+ isolation_level=isolation_level,
)
def _create_collection_and_index_if_not_exist(
diff --git a/haystack/document_stores/sql.py b/haystack/document_stores/sql.py
index 666870238..7ef5db791 100644
--- a/haystack/document_stores/sql.py
+++ b/haystack/document_stores/sql.py
@@ -4,7 +4,7 @@ import logging
import itertools
import numpy as np
from uuid import uuid4
-from sqlalchemy import and_, func, create_engine, Column, String, DateTime, ForeignKey, Boolean, Text, text, JSON
+from sqlalchemy import and_, func, create_engine, Column, String, DateTime, ForeignKey, Boolean, Text, text, JSON, ForeignKeyConstraint
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy.sql import case, null
@@ -33,9 +33,6 @@ class DocumentORM(ORMBase):
# primary key in combination with id to allow the same doc in different indices
index = Column(String(100), nullable=False, primary_key=True)
vector_id = Column(String(100), unique=True, nullable=True)
-
- # labels = relationship("LabelORM", back_populates="document")
-
# speeds up queries for get_documents_by_vector_ids() by having a single query that returns joined metadata
meta = relationship("MetaDocumentORM", back_populates="documents", lazy="joined")
@@ -45,36 +42,18 @@ class MetaDocumentORM(ORMBase):
name = Column(String(100), index=True)
value = Column(String(1000), index=True)
- document_id = Column(
- String(100),
- ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"),
- nullable=False,
- index=True
- )
-
documents = relationship("DocumentORM", back_populates="meta")
-
-class MetaLabelORM(ORMBase):
- __tablename__ = "meta_label"
-
- name = Column(String(100), index=True)
- value = Column(String(1000), index=True)
- label_id = Column(
- String(100),
- ForeignKey("label.id", ondelete="CASCADE", onupdate="CASCADE"),
- nullable=False,
- index=True
- )
-
- labels = relationship("LabelORM", back_populates="meta")
+ document_id = Column(String(100), nullable=False, index=True)
+ document_index = Column(String(100), nullable=False, index=True)
+ __table_args__ = (ForeignKeyConstraint([document_id, document_index],
+ [DocumentORM.id, DocumentORM.index],
+ ondelete="CASCADE", onupdate="CASCADE"), {}) #type: ignore
class LabelORM(ORMBase):
__tablename__ = "label"
- # document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"), nullable=False)
-
index = Column(String(100), nullable=False, primary_key=True)
query = Column(Text, nullable=False)
answer = Column(JSON, nullable=True)
@@ -86,7 +65,21 @@ class LabelORM(ORMBase):
pipeline_id = Column(String(500), nullable=True)
meta = relationship("MetaLabelORM", back_populates="labels", lazy="joined")
- # document = relationship("DocumentORM", back_populates="labels")
+
+
+class MetaLabelORM(ORMBase):
+ __tablename__ = "meta_label"
+
+ name = Column(String(100), index=True)
+ value = Column(String(1000), index=True)
+ labels = relationship("LabelORM", back_populates="meta")
+
+ label_id = Column(String(100), nullable=False, index=True)
+ label_index = Column(String(100), nullable=False, index=True)
+ __table_args__ = (ForeignKeyConstraint([label_id, label_index],
+ [LabelORM.id, LabelORM.index],
+ ondelete="CASCADE", onupdate="CASCADE"), {}) #type: ignore
+
class SQLDocumentStore(BaseDocumentStore):
@@ -96,7 +89,8 @@ class SQLDocumentStore(BaseDocumentStore):
index: str = "document",
label_index: str = "label",
duplicate_documents: str = "overwrite",
- check_same_thread: bool = False
+ check_same_thread: bool = False,
+ isolation_level: str = None
):
"""
An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.
@@ -112,18 +106,21 @@ class SQLDocumentStore(BaseDocumentStore):
fail: an error is raised if the document ID of the document being added already
exists.
:param check_same_thread: Set to False to mitigate multithreading issues in older SQLite versions (see https://docs.sqlalchemy.org/en/14/dialects/sqlite.html?highlight=check_same_thread#threading-pooling-behavior)
+ :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
"""
# save init parameters to enable export of component config as YAML
self.set_config(
url=url, index=index, label_index=label_index, duplicate_documents=duplicate_documents, check_same_thread=check_same_thread
)
-
+ create_engine_params = {}
+ if isolation_level:
+ create_engine_params["isolation_level"] = isolation_level
if "sqlite" in url:
- engine = create_engine(url, connect_args={'check_same_thread': check_same_thread})
+ engine = create_engine(url, connect_args={'check_same_thread': check_same_thread}, **create_engine_params)
else:
- engine = create_engine(url)
- ORMBase.metadata.create_all(engine)
+ engine = create_engine(url, **create_engine_params)
+ Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
self.session = Session()
self.index: str = index
@@ -461,12 +458,14 @@ class SQLDocumentStore(BaseDocumentStore):
self.session.query(DocumentORM).filter_by(index=index).update({DocumentORM.vector_id: null()})
self.session.commit()
- def update_document_meta(self, id: str, meta: Dict[str, str]):
+ def update_document_meta(self, id: str, meta: Dict[str, str], index: str = None):
"""
Update the metadata dictionary of a document by specifying its string id
"""
- self.session.query(MetaDocumentORM).filter_by(document_id=id).delete()
- meta_orms = [MetaDocumentORM(name=key, value=value, document_id=id) for key, value in meta.items()]
+ if not index:
+ index = self.index
+ self.session.query(MetaDocumentORM).filter_by(document_id=id, document_index=index).delete()
+ meta_orms = [MetaDocumentORM(name=key, value=value, document_id=id, document_index=index) for key, value in meta.items()]
for m in meta_orms:
self.session.add(m)
self.session.commit()
diff --git a/haystack/document_stores/weaviate.py b/haystack/document_stores/weaviate.py
index b163d5c25..8b2ec54ea 100644
--- a/haystack/document_stores/weaviate.py
+++ b/haystack/document_stores/weaviate.py
@@ -486,11 +486,13 @@ class WeaviateDocumentStore(BaseDocumentStore):
progress_bar.update(batch_size)
progress_bar.close()
- def update_document_meta(self, id: str, meta: Dict[str, str]):
+ def update_document_meta(self, id: str, meta: Dict[str, str], index: str = None):
"""
Update the metadata dictionary of a document by specifying its string id.
"""
- self.weaviate_client.data_object.update(meta, class_name=self.index, uuid=id)
+ if not index:
+ index = self.index
+ self.weaviate_client.data_object.update(meta, class_name=index, uuid=id)
def get_embedding_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
"""
diff --git a/test/conftest.py b/test/conftest.py
index 67e9c2b35..a9a530e92 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -3,6 +3,9 @@ import time
from subprocess import run
from sys import platform
import gc
+import uuid
+import logging
+from sqlalchemy import create_engine, text
import numpy as np
import psutil
@@ -40,6 +43,11 @@ from haystack.nodes.translator import TransformersTranslator
from haystack.nodes.question_generator import QuestionGenerator
+# To manually run the tests with default PostgreSQL instead of SQLite, switch the lines below
+SQL_TYPE = "sqlite"
+# SQL_TYPE = "postgres"
+
+
def pytest_addoption(parser):
parser.addoption("--document_store_type", action="store", default="elasticsearch, faiss, memory, milvus, weaviate")
@@ -477,58 +485,112 @@ def get_retriever(retriever_type, document_store):
return retriever
+def ensure_ids_are_correct_uuids(docs:list,document_store:object)->None:
+ # Weaviate currently only supports UUIDs
+ if type(document_store)==WeaviateDocumentStore:
+ for d in docs:
+ d["id"] = str(uuid.uuid4())
+
+
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate"])
-def document_store_with_docs(request, test_docs_xs):
+def document_store_with_docs(request, test_docs_xs, tmp_path):
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
- document_store = get_document_store(request.param, embedding_dim.args[0])
+ document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], tmp_path=tmp_path)
document_store.write_documents(test_docs_xs)
yield document_store
document_store.delete_documents()
-
@pytest.fixture
-def document_store(request):
+def document_store(request, tmp_path):
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
- document_store = get_document_store(request.param, embedding_dim.args[0])
+ document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], tmp_path=tmp_path)
yield document_store
document_store.delete_documents()
@pytest.fixture(params=["memory", "faiss", "milvus", "elasticsearch"])
-def document_store_dot_product(request):
+def document_store_dot_product(request, tmp_path):
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
- document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
+ document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path)
yield document_store
document_store.delete_documents()
@pytest.fixture(params=["memory", "faiss", "milvus", "elasticsearch"])
-def document_store_dot_product_with_docs(request, test_docs_xs):
+def document_store_dot_product_with_docs(request, test_docs_xs, tmp_path):
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
- document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
+ document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path)
document_store.write_documents(test_docs_xs)
yield document_store
document_store.delete_documents()
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus"])
-def document_store_dot_product_small(request):
+def document_store_dot_product_small(request, tmp_path):
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(3))
- document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
+ document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path)
yield document_store
document_store.delete_documents()
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate"])
-def document_store_small(request):
+def document_store_small(request, tmp_path):
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(3))
- document_store = get_document_store(request.param, embedding_dim.args[0], similarity="cosine")
+ document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="cosine", tmp_path=tmp_path)
yield document_store
document_store.delete_documents()
-def get_document_store(document_store_type, embedding_dim=768, embedding_field="embedding", index="haystack_test", similarity:str="cosine"): # cosine is default similarity as dot product is not supported by Weaviate
+
+@pytest.fixture(scope="function", autouse=True)
+def postgres_fixture():
+ if SQL_TYPE == "postgres":
+ setup_postgres()
+ yield
+ teardown_postgres()
+ else:
+ yield
+
+
+@pytest.fixture
+def sql_url(tmp_path):
+ return get_sql_url(tmp_path)
+
+
+def get_sql_url(tmp_path):
+ if SQL_TYPE == "postgres":
+ return "postgresql://postgres:postgres@127.0.0.1/postgres"
+ else:
+ return f"sqlite:///{tmp_path}/haystack_test.db"
+
+
+def setup_postgres():
+ # status = subprocess.run(["docker run --name postgres_test -d -e POSTGRES_HOST_AUTH_METHOD=trust -p 5432:5432 postgres"], shell=True)
+ # if status.returncode:
+ # logging.warning("Tried to start PostgreSQL through Docker but this failed. It is likely that there is already an existing instance running.")
+ # else:
+ # sleep(5)
+ engine = create_engine('postgresql://postgres:postgres@127.0.0.1/postgres', isolation_level='AUTOCOMMIT')
+
+ with engine.connect() as connection:
+ try:
+ connection.execute(text('DROP SCHEMA public CASCADE'))
+ except Exception as e:
+ logging.error(e)
+ connection.execute(text('CREATE SCHEMA public;'))
+ connection.execute(text('SET SESSION idle_in_transaction_session_timeout = "1s";'))
+
+
+def teardown_postgres():
+ engine = create_engine('postgresql://postgres:postgres@127.0.0.1/postgres', isolation_level='AUTOCOMMIT')
+ with engine.connect() as connection:
+ connection.execute(text('DROP SCHEMA public CASCADE'))
+ connection.close()
+
+
+def get_document_store(document_store_type, tmp_path, embedding_dim=768, embedding_field="embedding", index="haystack_test", similarity:str="cosine"): # cosine is default similarity as dot product is not supported by Weaviate
if document_store_type == "sql":
- document_store = SQLDocumentStore(url="sqlite://", index=index)
+ document_store = SQLDocumentStore(url=get_sql_url(tmp_path), index=index, isolation_level="AUTOCOMMIT")
+
elif document_store_type == "memory":
document_store = InMemoryDocumentStore(
- return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, index=index, similarity=similarity
- )
+ return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, index=index, similarity=similarity)
+
elif document_store_type == "elasticsearch":
# make sure we start from a fresh index
client = Elasticsearch()
@@ -536,28 +598,33 @@ def get_document_store(document_store_type, embedding_dim=768, embedding_field="
document_store = ElasticsearchDocumentStore(
index=index, return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, similarity=similarity
)
+
elif document_store_type == "faiss":
document_store = FAISSDocumentStore(
embedding_dim=embedding_dim,
- sql_url="sqlite://",
+ sql_url=get_sql_url(tmp_path),
return_embedding=True,
embedding_field=embedding_field,
index=index,
- similarity=similarity
+ similarity=similarity,
+ isolation_level="AUTOCOMMIT"
)
+
elif document_store_type == "milvus":
document_store = MilvusDocumentStore(
embedding_dim=embedding_dim,
- sql_url="sqlite://",
+ sql_url=get_sql_url(tmp_path),
return_embedding=True,
embedding_field=embedding_field,
index=index,
- similarity=similarity
+ similarity=similarity,
+ isolation_level="AUTOCOMMIT"
)
_, collections = document_store.milvus_server.list_collections()
for collection in collections:
if collection.startswith(index):
document_store.milvus_server.drop_collection(collection)
+
elif document_store_type == "weaviate":
document_store = WeaviateDocumentStore(
weaviate_url="http://localhost:8080",
diff --git a/test/test_document_store.py b/test/test_document_store.py
index a39859e10..30d55f4ba 100644
--- a/test/test_document_store.py
+++ b/test/test_document_store.py
@@ -1,4 +1,6 @@
from unittest import mock
+import uuid
+import math
import numpy as np
import pandas as pd
import pytest
@@ -7,7 +9,7 @@ from elasticsearch import Elasticsearch
from elasticsearch.exceptions import RequestError
-from conftest import get_document_store
+from conftest import get_document_store, ensure_ids_are_correct_uuids
from haystack.document_stores import WeaviateDocumentStore
from haystack.document_stores.base import BaseDocumentStore
from haystack.errors import DuplicateDocumentError
@@ -17,6 +19,18 @@ from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.nodes import EmbeddingRetriever
from haystack.pipelines import DocumentSearchPipeline
+
+DOCUMENTS = [
+ {"meta": {"name": "name_1", "year": "2020", "month": "01"}, "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
+ {"meta": {"name": "name_2", "year": "2020", "month": "02"}, "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
+ {"meta": {"name": "name_3", "year": "2020", "month": "03"}, "content": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
+ {"meta": {"name": "name_4", "year": "2021", "month": "01"}, "content": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
+ {"meta": {"name": "name_5", "year": "2021", "month": "02"}, "content": "text_5", "embedding": np.random.rand(768).astype(np.float32)},
+ {"meta": {"name": "name_6", "year": "2021", "month": "03"}, "content": "text_6", "embedding": np.random.rand(768).astype(np.float64)},
+]
+
+
+
@pytest.mark.elasticsearch
def test_init_elastic_client():
# defaults
@@ -148,8 +162,8 @@ def test_get_all_documents_with_correct_filters(document_store_with_docs):
assert {d.meta["meta_field"] for d in documents} == {"test1", "test3"}
-def test_get_all_documents_with_correct_filters_legacy_sqlite(test_docs_xs):
- document_store_with_docs = get_document_store("sql")
+def test_get_all_documents_with_correct_filters_legacy_sqlite(test_docs_xs, tmp_path):
+ document_store_with_docs = get_document_store("sql", tmp_path)
document_store_with_docs.write_documents(test_docs_xs)
document_store_with_docs.use_windowed_query = False
@@ -791,7 +805,7 @@ def test_multilabel_no_answer(document_store):
assert len(multi_labels[0].answers) == 1
-@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss"], indirect=True)
+@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "milvus", "weaviate"], indirect=True)
# Currently update_document_meta() is not implemented for Memory doc store
def test_update_meta(document_store):
documents = [
@@ -818,10 +832,8 @@ def test_update_meta(document_store):
@pytest.mark.parametrize("document_store_type", ["elasticsearch", "memory"])
-def test_custom_embedding_field(document_store_type):
- document_store = get_document_store(
- document_store_type=document_store_type, embedding_field="custom_embedding_field"
- )
+def test_custom_embedding_field(document_store_type, tmp_path):
+ document_store = get_document_store(document_store_type=document_store_type, tmp_path=tmp_path, embedding_field="custom_embedding_field")
doc_to_write = {"content": "test", "custom_embedding_field": np.random.rand(768).astype(np.float32)}
document_store.write_documents([doc_to_write])
documents = document_store.get_all_documents(return_embedding=True)
@@ -993,4 +1005,5 @@ def test_custom_headers(document_store_with_docs: BaseDocumentStore):
args, kwargs = mock_client.search.call_args
assert "headers" in kwargs
assert kwargs["headers"] == custom_headers
- assert len(documents) > 0
\ No newline at end of file
+ assert len(documents) > 0
+
diff --git a/test/test_faiss_and_milvus.py b/test/test_faiss_and_milvus.py
index 8e19dcfde..221ce6e4d 100644
--- a/test/test_faiss_and_milvus.py
+++ b/test/test_faiss_and_milvus.py
@@ -13,6 +13,9 @@ from haystack.document_stores.weaviate import WeaviateDocumentStore
from haystack.pipelines import Pipeline
from haystack.nodes.retriever.dense import EmbeddingRetriever
+from conftest import ensure_ids_are_correct_uuids
+
+
DOCUMENTS = [
{"meta": {"name": "name_1", "year": "2020", "month": "01"}, "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
{"meta": {"name": "name_2", "year": "2020", "month": "02"}, "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
@@ -23,12 +26,14 @@ DOCUMENTS = [
]
+
@pytest.mark.skipif(sys.platform in ['win32', 'cygwin'], reason="Test with tmp_path not working on windows runner")
-def test_faiss_index_save_and_load(tmp_path):
+def test_faiss_index_save_and_load(tmp_path, sql_url):
document_store = FAISSDocumentStore(
- sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
+ sql_url=sql_url,
index="haystack_test",
- progress_bar=False # Just to check if the init parameters are kept
+ progress_bar=False, # Just to check if the init parameters are kept
+ isolation_level="AUTOCOMMIT"
)
document_store.write_documents(DOCUMENTS)
@@ -74,11 +79,12 @@ def test_faiss_index_save_and_load(tmp_path):
@pytest.mark.skipif(sys.platform in ['win32', 'cygwin'], reason="Test with tmp_path not working on windows runner")
-def test_faiss_index_save_and_load_custom_path(tmp_path):
+def test_faiss_index_save_and_load_custom_path(tmp_path, sql_url):
document_store = FAISSDocumentStore(
- sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
+ sql_url=sql_url,
index="haystack_test",
- progress_bar=False # Just to check if the init parameters are kept
+ progress_bar=False, # Just to check if the init parameters are kept
+ isolation_level="AUTOCOMMIT"
)
document_store.write_documents(DOCUMENTS)
@@ -128,13 +134,15 @@ def test_faiss_index_mutual_exclusive_args(tmp_path):
with pytest.raises(ValueError):
FAISSDocumentStore(
sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
- faiss_index_path=f"{tmp_path/'haystack_test'}"
+ faiss_index_path=f"{tmp_path/'haystack_test'}",
+ isolation_level="AUTOCOMMIT"
)
with pytest.raises(ValueError):
FAISSDocumentStore(
f"sqlite:////{tmp_path/'haystack_test.db'}",
- faiss_index_path=f"{tmp_path/'haystack_test'}"
+ faiss_index_path=f"{tmp_path/'haystack_test'}",
+ isolation_level="AUTOCOMMIT"
)
@@ -227,7 +235,9 @@ def test_update_with_empty_store(document_store, retriever):
@pytest.mark.parametrize("index_factory", ["Flat", "HNSW", "IVF1,Flat"])
def test_faiss_retrieving(index_factory, tmp_path):
document_store = FAISSDocumentStore(
- sql_url=f"sqlite:////{tmp_path/'test_faiss_retrieving.db'}", faiss_index_factory_str=index_factory
+ sql_url=f"sqlite:////{tmp_path/'test_faiss_retrieving.db'}",
+ faiss_index_factory_str=index_factory,
+ isolation_level="AUTOCOMMIT"
)
document_store.delete_all_documents(index="document")
@@ -396,7 +406,6 @@ def test_get_docs_with_many_filters(document_store, retriever):
assert "2020" == documents[0].meta["year"]
-
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
def test_pipeline(document_store, retriever):
@@ -423,7 +432,9 @@ def test_faiss_passing_index_from_outside(tmp_path):
faiss_index.set_direct_map_type(faiss.DirectMap.Hashtable)
faiss_index.nprobe = 2
document_store = FAISSDocumentStore(
- sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}", faiss_index=faiss_index, index=index
+ sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}",
+ faiss_index=faiss_index, index=index,
+ isolation_level="AUTOCOMMIT"
)
document_store.delete_documents()
@@ -437,12 +448,8 @@ def test_faiss_passing_index_from_outside(tmp_path):
for doc in documents_indexed:
assert 0 <= int(doc.meta["vector_id"]) <= 7
-def ensure_ids_are_correct_uuids(docs:list,document_store:object)->None:
- # Weaviate currently only supports UUIDs
- if type(document_store)==WeaviateDocumentStore:
- for d in docs:
- d["id"] = str(uuid.uuid4())
+@pytest.mark.parametrize("document_store", ["faiss", "milvus", "weaviate"], indirect=True)
def test_cosine_similarity(document_store):
# below we will write documents to the store and then query it to see if vectors were normalized
@@ -487,6 +494,7 @@ def test_cosine_similarity(document_store):
assert not np.allclose(original_emb[0], doc.embedding, rtol=0.01)
+@pytest.mark.parametrize("document_store_dot_product_small", ["faiss", "milvus"], indirect=True)
def test_normalize_embeddings_diff_shapes(document_store_dot_product_small):
VEC_1 = np.array([.1, .2, .3], dtype="float32")
document_store_dot_product_small.normalize_embedding(VEC_1)
@@ -497,6 +505,7 @@ def test_normalize_embeddings_diff_shapes(document_store_dot_product_small):
assert np.linalg.norm(VEC_1) - 1 < 0.01
+@pytest.mark.parametrize("document_store_small", ["faiss", "milvus", "weaviate"], indirect=True)
def test_cosine_sanity_check(document_store_small):
VEC_1 = np.array([.1, .2, .3], dtype="float32")
VEC_2 = np.array([.4, .5, .6], dtype="float32")
@@ -512,4 +521,4 @@ def test_cosine_sanity_check(document_store_small):
query_results = document_store_small.query_by_embedding(query_emb=VEC_2, top_k=1, return_embedding=True)
# check if faiss returns the same cosine similarity. Manual testing with faiss yielded 0.9746318
- assert math.isclose(query_results[0].score, KNOWN_COSINE, abs_tol=0.00002)
+ assert math.isclose(query_results[0].score, KNOWN_COSINE, abs_tol=0.00002)
\ No newline at end of file
diff --git a/test/test_weaviate.py b/test/test_weaviate.py
index b6a3de977..49ece2f52 100644
--- a/test/test_weaviate.py
+++ b/test/test_weaviate.py
@@ -30,16 +30,16 @@ DOCUMENTS_XS = [
@pytest.fixture(params=["weaviate"])
-def document_store_with_docs(request):
- document_store = get_document_store(request.param)
+def document_store_with_docs(request, tmp_path):
+ document_store = get_document_store(request.param, tmp_path=tmp_path)
document_store.write_documents(DOCUMENTS_XS)
yield document_store
document_store.delete_documents()
@pytest.fixture(params=["weaviate"])
-def document_store(request):
- document_store = get_document_store(request.param)
+def document_store(request, tmp_path):
+ document_store = get_document_store(request.param, tmp_path=tmp_path)
yield document_store
document_store.delete_documents()