Implement proper FK in MetaDocumentORM and MetaLabelORM to work on PostgreSQL (#1990)

* Properly fix MetaDocumentORM and MetaLabelORM with composite foreign key constraints

* update_document_meta() was not using index properly

* Exclude ES and Memory from the cosine_sanity_check test

* Move ensure_ids_are_correct_uuids into conftest and move one test back to the FAISS & Milvus suite

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Authored by Sara Zan on 2022-01-14 13:48:58 +01:00, committed by GitHub
parent 3e4dbbb32c
commit e28bf618d7
11 changed files with 208 additions and 101 deletions


@@ -380,7 +380,7 @@ Write annotation labels into document store.
 #### update\_document\_meta
 ```python
-| update_document_meta(id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None)
+| update_document_meta(id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None, index: str = None)
 ```
 Update the metadata dictionary of a document by specifying its string id
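A hedged usage sketch of the new `index` parameter, shown with the SQL-backed store so it runs without a server (the document content and meta values are illustrative):

```python
from haystack.document_stores import SQLDocumentStore

store = SQLDocumentStore(url="sqlite://", index="document")
store.write_documents([{"content": "some text", "meta": {"year": "2020"}}])

doc = store.get_all_documents()[0]
# Without `index`, the store's configured index is used; passing it lets the
# caller address the same document id in a different index.
store.update_document_meta(id=doc.id, meta={"year": "2021"}, index="document")
```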
@@ -952,7 +952,7 @@ class SQLDocumentStore(BaseDocumentStore)
 #### \_\_init\_\_
 ```python
-| __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite", check_same_thread: bool = False)
+| __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite", check_same_thread: bool = False, isolation_level: str = None)
 ```
 An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.
@@ -970,6 +970,7 @@ An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL bac
     fail: an error is raised if the document ID of the document being added already
     exists.
 - `check_same_thread`: Set to False to mitigate multithreading issues in older SQLite versions (see https://docs.sqlalchemy.org/en/14/dialects/sqlite.html?highlight=check_same_thread#threading-pooling-behavior)
+- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)

 <a name="sql.SQLDocumentStore.get_document_by_id"></a>
 #### get\_document\_by\_id
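A minimal sketch of passing the new option through (in-memory SQLite so it runs anywhere; `SERIALIZABLE` is one of the levels SQLAlchemy accepts for this backend):

```python
from haystack.document_stores import SQLDocumentStore

document_store = SQLDocumentStore(
    url="sqlite://",
    index="document",
    isolation_level="SERIALIZABLE",  # handed through to SQLAlchemy's create_engine()
)
```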
@@ -1094,7 +1095,7 @@ Set vector IDs for all documents as None
 #### update\_document\_meta
 ```python
-| update_document_meta(id: str, meta: Dict[str, str])
+| update_document_meta(id: str, meta: Dict[str, str], index: str = None)
 ```
 Update the metadata dictionary of a document by specifying its string id
@@ -1202,7 +1203,7 @@ the vector embeddings are indexed in a FAISS Index.
 #### \_\_init\_\_
 ```python
-| __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, **kwargs, ,)
+| __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, isolation_level: str = None, **kwargs, ,)
 ```
 **Arguments**:
@@ -1248,6 +1249,7 @@ the vector embeddings are indexed in a FAISS Index.
     If specified no other params besides faiss_config_path must be specified.
 - `faiss_config_path`: Stored FAISS initial configuration parameters.
     Can be created via calling `save()`
+- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)

 <a name="faiss.FAISSDocumentStore.write_documents"></a>
 #### write\_documents
@@ -1479,7 +1481,7 @@ Usage:
 #### \_\_init\_\_
 ```python
-| __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = None, embedding_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs, ,)
+| __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = None, embedding_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', isolation_level: str = None, **kwargs, ,)
 ```
 **Arguments**:
@@ -1525,6 +1527,7 @@ Note that an overly large index_file_size value may cause failure to load a segm
     overwrite: Update any existing documents with the same ID when adding documents.
     fail: an error is raised if the document ID of the document being added already
     exists.
+- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)

 <a name="milvus.MilvusDocumentStore.write_documents"></a>
 #### write\_documents
@@ -1863,7 +1866,7 @@ None
 #### update\_document\_meta
 ```python
-| update_document_meta(id: str, meta: Dict[str, str])
+| update_document_meta(id: str, meta: Dict[str, str], index: str = None)
 ```
 Update the metadata dictionary of a document by specifying its string id.


@@ -551,12 +551,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         if labels_to_index:
             bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type, headers=headers)

-    def update_document_meta(self, id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None):
+    def update_document_meta(self, id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None, index: str = None):
         """
         Update the metadata dictionary of a document by specifying its string id
         """
+        if not index:
+            index = self.index
         body = {"doc": meta}
-        self.client.update(index=self.index, id=id, body=body, refresh=self.refresh_type, headers=headers)
+        self.client.update(index=index, id=id, body=body, refresh=self.refresh_type, headers=headers)

     def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None,
                            only_documents_without_embedding: bool = False, headers: Optional[Dict[str, str]] = None) -> int:
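The pattern is the same across stores: fall back to the instance's configured index when the caller passes none. A hedged usage sketch (assumes a reachable Elasticsearch instance; the document id and the second index name are illustrative):

```python
from haystack.document_stores import ElasticsearchDocumentStore

store = ElasticsearchDocumentStore(index="document")
# Updates the store's own index, as before ...
store.update_document_meta(id="some-doc-id", meta={"year": "2021"})
# ... or an explicitly named index, without touching store.index:
store.update_document_meta(id="some-doc-id", meta={"year": "2021"}, index="document_archive")
```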


@@ -50,6 +50,7 @@ class FAISSDocumentStore(SQLDocumentStore):
         duplicate_documents: str = 'overwrite',
         faiss_index_path: Union[str, Path] = None,
         faiss_config_path: Union[str, Path] = None,
+        isolation_level: str = None,
         **kwargs,
     ):
         """
@@ -94,6 +95,7 @@ class FAISSDocumentStore(SQLDocumentStore):
            If specified no other params besides faiss_config_path must be specified.
         :param faiss_config_path: Stored FAISS initial configuration parameters.
            Can be created via calling `save()`
+        :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
         """
         # special case if we want to load an existing index from disk
         # load init params from disk and run init again
@@ -115,7 +117,8 @@ class FAISSDocumentStore(SQLDocumentStore):
             index=index,
             similarity=similarity,
             embedding_field=embedding_field,
-            progress_bar=progress_bar
+            progress_bar=progress_bar,
+            isolation_level=isolation_level
         )

         if similarity in ("dot_product", "cosine"):
@@ -155,7 +158,8 @@ class FAISSDocumentStore(SQLDocumentStore):
         super().__init__(
             url=sql_url,
             index=index,
-            duplicate_documents=duplicate_documents
+            duplicate_documents=duplicate_documents,
+            isolation_level=isolation_level
         )
         self._validate_index_sync()
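With the passthrough in place, a FAISS store can run its SQL side against PostgreSQL in autocommit mode. A hedged sketch (assumes `faiss` is installed and a local PostgreSQL is reachable; the connection URL is illustrative):

```python
from haystack.document_stores import FAISSDocumentStore

document_store = FAISSDocumentStore(
    sql_url="postgresql://postgres:postgres@127.0.0.1/postgres",
    faiss_index_factory_str="Flat",
    isolation_level="AUTOCOMMIT",  # forwarded to create_engine() via SQLDocumentStore
)
```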


@@ -53,6 +53,7 @@ class MilvusDocumentStore(SQLDocumentStore):
         embedding_field: str = "embedding",
         progress_bar: bool = True,
         duplicate_documents: str = 'overwrite',
+        isolation_level: str = None,
         **kwargs,
     ):
         """
@@ -97,6 +98,7 @@ class MilvusDocumentStore(SQLDocumentStore):
            overwrite: Update any existing documents with the same ID when adding documents.
            fail: an error is raised if the document ID of the document being added already
            exists.
+        :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
         """
         # save init parameters to enable export of component config as YAML
         self.set_config(
@@ -104,6 +106,7 @@ class MilvusDocumentStore(SQLDocumentStore):
             embedding_dim=embedding_dim, index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
             search_param=search_param, duplicate_documents=duplicate_documents,
             return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar,
+            isolation_level=isolation_level
         )
         self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool)
@@ -139,7 +142,8 @@ class MilvusDocumentStore(SQLDocumentStore):
         super().__init__(
             url=sql_url,
             index=index,
-            duplicate_documents=duplicate_documents
+            duplicate_documents=duplicate_documents,
+            isolation_level=isolation_level,
         )

     def __del__(self):


@@ -73,6 +73,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
         custom_fields: Optional[List[Any]] = None,
         progress_bar: bool = True,
         duplicate_documents: str = 'overwrite',
+        isolation_level: str = None
     ):
         """
         :param sql_url: SQL connection URL for storing document texts and metadata. It defaults to a local, file based SQLite DB. For large scale
@@ -118,6 +119,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
            overwrite: Update any existing documents with the same ID when adding documents.
            fail: an error is raised if the document ID of the document being added already
            exists.
+        :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
         """
         # save init parameters to enable export of component config as YAML
@@ -127,6 +129,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
             search_param=search_param, duplicate_documents=duplicate_documents, id_field=id_field,
             return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar,
             custom_fields=custom_fields,
+            isolation_level=isolation_level
         )

         logger.warning("Milvus2DocumentStore is in experimental state until Milvus 2.0 is released")
@@ -173,7 +176,8 @@ class Milvus2DocumentStore(SQLDocumentStore):
         super().__init__(
             url=sql_url,
             index=index,
-            duplicate_documents=duplicate_documents
+            duplicate_documents=duplicate_documents,
+            isolation_level=isolation_level,
         )

     def _create_collection_and_index_if_not_exist(


@@ -4,7 +4,7 @@ import logging
 import itertools
 import numpy as np
 from uuid import uuid4
-from sqlalchemy import and_, func, create_engine, Column, String, DateTime, ForeignKey, Boolean, Text, text, JSON
+from sqlalchemy import and_, func, create_engine, Column, String, DateTime, ForeignKey, Boolean, Text, text, JSON, ForeignKeyConstraint
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import relationship, sessionmaker
 from sqlalchemy.sql import case, null
@@ -33,9 +33,6 @@ class DocumentORM(ORMBase):
     # primary key in combination with id to allow the same doc in different indices
     index = Column(String(100), nullable=False, primary_key=True)
     vector_id = Column(String(100), unique=True, nullable=True)
-    # labels = relationship("LabelORM", back_populates="document")

     # speeds up queries for get_documents_by_vector_ids() by having a single query that returns joined metadata
     meta = relationship("MetaDocumentORM", back_populates="documents", lazy="joined")
@@ -45,36 +42,18 @@ class MetaDocumentORM(ORMBase):
     name = Column(String(100), index=True)
     value = Column(String(1000), index=True)

-    document_id = Column(
-        String(100),
-        ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"),
-        nullable=False,
-        index=True
-    )
     documents = relationship("DocumentORM", back_populates="meta")

+    document_id = Column(String(100), nullable=False, index=True)
+    document_index = Column(String(100), nullable=False, index=True)
+    __table_args__ = (ForeignKeyConstraint([document_id, document_index],
+                                           [DocumentORM.id, DocumentORM.index],
+                                           ondelete="CASCADE", onupdate="CASCADE"), {}) #type: ignore

-class MetaLabelORM(ORMBase):
-    __tablename__ = "meta_label"
-
-    name = Column(String(100), index=True)
-    value = Column(String(1000), index=True)
-    label_id = Column(
-        String(100),
-        ForeignKey("label.id", ondelete="CASCADE", onupdate="CASCADE"),
-        nullable=False,
-        index=True
-    )
-    labels = relationship("LabelORM", back_populates="meta")

 class LabelORM(ORMBase):
     __tablename__ = "label"

-    # document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"), nullable=False)
     index = Column(String(100), nullable=False, primary_key=True)
     query = Column(Text, nullable=False)
     answer = Column(JSON, nullable=True)
@@ -86,7 +65,21 @@ class LabelORM(ORMBase):
     pipeline_id = Column(String(500), nullable=True)

     meta = relationship("MetaLabelORM", back_populates="labels", lazy="joined")
-    # document = relationship("DocumentORM", back_populates="labels")

+
+class MetaLabelORM(ORMBase):
+    __tablename__ = "meta_label"
+
+    name = Column(String(100), index=True)
+    value = Column(String(1000), index=True)
+    labels = relationship("LabelORM", back_populates="meta")
+
+    label_id = Column(String(100), nullable=False, index=True)
+    label_index = Column(String(100), nullable=False, index=True)
+    __table_args__ = (ForeignKeyConstraint([label_id, label_index],
+                                           [LabelORM.id, LabelORM.index],
+                                           ondelete="CASCADE", onupdate="CASCADE"), {}) #type: ignore


 class SQLDocumentStore(BaseDocumentStore):
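The core of the fix: `document.id` alone is not unique (the primary key is the composite `(id, index)`), so PostgreSQL rejects a single-column `ForeignKey` against it. A self-contained sketch of the composite-key technique (column names mirror the diff; the surrogate `id` primary key on the meta table stands in for the one `ORMBase` provides):

```python
from sqlalchemy import Column, String, ForeignKeyConstraint, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker

Base = declarative_base()

class DocumentORM(Base):
    __tablename__ = "document"
    id = Column(String(100), primary_key=True)
    index = Column(String(100), primary_key=True)  # composite primary key (id, index)
    meta = relationship("MetaDocumentORM", back_populates="documents")

class MetaDocumentORM(Base):
    __tablename__ = "meta_document"
    id = Column(String(100), primary_key=True)  # stands in for ORMBase's surrogate key
    name = Column(String(100), index=True)
    value = Column(String(1000), index=True)
    documents = relationship("DocumentORM", back_populates="meta")
    document_id = Column(String(100), nullable=False, index=True)
    document_index = Column(String(100), nullable=False, index=True)
    # Both referencing columns live in one constraint, so the pair
    # (document_id, document_index) points at the parent's composite key.
    __table_args__ = (
        ForeignKeyConstraint(
            [document_id, document_index],
            [DocumentORM.id, DocumentORM.index],
            ondelete="CASCADE", onupdate="CASCADE",
        ),
    )

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()
session.add(DocumentORM(id="doc-1", index="document"))
session.add(MetaDocumentORM(id="meta-1", name="year", value="2020",
                            document_id="doc-1", document_index="document"))
session.commit()
print(session.query(MetaDocumentORM).one().documents.id)  # -> doc-1
```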
@@ -96,7 +89,8 @@ class SQLDocumentStore(BaseDocumentStore):
         index: str = "document",
         label_index: str = "label",
         duplicate_documents: str = "overwrite",
-        check_same_thread: bool = False
+        check_same_thread: bool = False,
+        isolation_level: str = None
     ):
         """
         An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.
@@ -112,18 +106,21 @@ class SQLDocumentStore(BaseDocumentStore):
            fail: an error is raised if the document ID of the document being added already
            exists.
         :param check_same_thread: Set to False to mitigate multithreading issues in older SQLite versions (see https://docs.sqlalchemy.org/en/14/dialects/sqlite.html?highlight=check_same_thread#threading-pooling-behavior)
+        :param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
         """
         # save init parameters to enable export of component config as YAML
         self.set_config(
             url=url, index=index, label_index=label_index, duplicate_documents=duplicate_documents, check_same_thread=check_same_thread
         )
+        create_engine_params = {}
+        if isolation_level:
+            create_engine_params["isolation_level"] = isolation_level
         if "sqlite" in url:
-            engine = create_engine(url, connect_args={'check_same_thread': check_same_thread})
+            engine = create_engine(url, connect_args={'check_same_thread': check_same_thread}, **create_engine_params)
         else:
-            engine = create_engine(url)
+            engine = create_engine(url, **create_engine_params)
-        ORMBase.metadata.create_all(engine)
+        Base.metadata.create_all(engine)
         Session = sessionmaker(bind=engine)
         self.session = Session()
         self.index: str = index
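The `create_engine_params` step above only forwards `isolation_level` when the caller actually set it, leaving the dialect's default behaviour untouched otherwise. The same pattern in isolation (a minimal sketch; `make_engine` is a made-up helper name):

```python
from sqlalchemy import create_engine

def make_engine(url: str, isolation_level: str = None):
    create_engine_params = {}
    if isolation_level:
        # Only pass the kwarg when set, keeping the dialect default otherwise
        create_engine_params["isolation_level"] = isolation_level
    return create_engine(url, **create_engine_params)

default_engine = make_engine("sqlite://")
serializable_engine = make_engine("sqlite://", isolation_level="SERIALIZABLE")
```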
@@ -461,12 +458,14 @@ class SQLDocumentStore(BaseDocumentStore):
         self.session.query(DocumentORM).filter_by(index=index).update({DocumentORM.vector_id: null()})
         self.session.commit()

-    def update_document_meta(self, id: str, meta: Dict[str, str]):
+    def update_document_meta(self, id: str, meta: Dict[str, str], index: str = None):
         """
         Update the metadata dictionary of a document by specifying its string id
         """
-        self.session.query(MetaDocumentORM).filter_by(document_id=id).delete()
-        meta_orms = [MetaDocumentORM(name=key, value=value, document_id=id) for key, value in meta.items()]
+        if not index:
+            index = self.index
+        self.session.query(MetaDocumentORM).filter_by(document_id=id, document_index=index).delete()
+        meta_orms = [MetaDocumentORM(name=key, value=value, document_id=id, document_index=index) for key, value in meta.items()]
         for m in meta_orms:
             self.session.add(m)
         self.session.commit()
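Scoping the delete by `document_index` as well matters because, with the composite key, the same document id can exist in several indices; deleting by `document_id` alone would wipe metadata in all of them. A minimal illustration with plain `sqlite3` (the table layout is simplified, not Haystack's exact schema):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE meta_document (name TEXT, value TEXT, document_id TEXT, document_index TEXT)")
con.executemany(
    "INSERT INTO meta_document VALUES (?, ?, ?, ?)",
    [("year", "2020", "doc-1", "document"),
     ("year", "1999", "doc-1", "archive")],  # same id, different index
)
# The scoped delete only touches the targeted index ...
con.execute("DELETE FROM meta_document WHERE document_id = ? AND document_index = ?",
            ("doc-1", "document"))
# ... so the same document id in the other index keeps its metadata.
assert con.execute("SELECT document_index, value FROM meta_document").fetchall() == [("archive", "1999")]
```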


@@ -486,11 +486,13 @@ class WeaviateDocumentStore(BaseDocumentStore):
                 progress_bar.update(batch_size)
         progress_bar.close()

-    def update_document_meta(self, id: str, meta: Dict[str, str]):
+    def update_document_meta(self, id: str, meta: Dict[str, str], index: str = None):
         """
         Update the metadata dictionary of a document by specifying its string id.
         """
-        self.weaviate_client.data_object.update(meta, class_name=self.index, uuid=id)
+        if not index:
+            index = self.index
+        self.weaviate_client.data_object.update(meta, class_name=index, uuid=id)

     def get_embedding_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
         """


@@ -3,6 +3,9 @@ import time
 from subprocess import run
 from sys import platform
 import gc
+import uuid
+import logging
+from sqlalchemy import create_engine, text

 import numpy as np
 import psutil
@@ -40,6 +43,11 @@ from haystack.nodes.translator import TransformersTranslator
 from haystack.nodes.question_generator import QuestionGenerator

+# To manually run the tests with default PostgreSQL instead of SQLite, switch the lines below
+SQL_TYPE = "sqlite"
+# SQL_TYPE = "postgres"

 def pytest_addoption(parser):
     parser.addoption("--document_store_type", action="store", default="elasticsearch, faiss, memory, milvus, weaviate")
@@ -477,58 +485,112 @@ def get_retriever(retriever_type, document_store):
     return retriever

+def ensure_ids_are_correct_uuids(docs:list,document_store:object)->None:
+    # Weaviate currently only supports UUIDs
+    if type(document_store)==WeaviateDocumentStore:
+        for d in docs:
+            d["id"] = str(uuid.uuid4())

 @pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate"])
-def document_store_with_docs(request, test_docs_xs):
+def document_store_with_docs(request, test_docs_xs, tmp_path):
     embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
-    document_store = get_document_store(request.param, embedding_dim.args[0])
+    document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], tmp_path=tmp_path)
     document_store.write_documents(test_docs_xs)
     yield document_store
     document_store.delete_documents()

 @pytest.fixture
-def document_store(request):
+def document_store(request, tmp_path):
     embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
-    document_store = get_document_store(request.param, embedding_dim.args[0])
+    document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], tmp_path=tmp_path)
     yield document_store
     document_store.delete_documents()

 @pytest.fixture(params=["memory", "faiss", "milvus", "elasticsearch"])
-def document_store_dot_product(request):
+def document_store_dot_product(request, tmp_path):
     embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
-    document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
+    document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path)
     yield document_store
     document_store.delete_documents()

 @pytest.fixture(params=["memory", "faiss", "milvus", "elasticsearch"])
-def document_store_dot_product_with_docs(request, test_docs_xs):
+def document_store_dot_product_with_docs(request, test_docs_xs, tmp_path):
     embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
-    document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
+    document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path)
     document_store.write_documents(test_docs_xs)
     yield document_store
     document_store.delete_documents()

 @pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus"])
-def document_store_dot_product_small(request):
+def document_store_dot_product_small(request, tmp_path):
     embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(3))
-    document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
+    document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path)
     yield document_store
     document_store.delete_documents()

 @pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate"])
-def document_store_small(request):
+def document_store_small(request, tmp_path):
     embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(3))
-    document_store = get_document_store(request.param, embedding_dim.args[0], similarity="cosine")
+    document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="cosine", tmp_path=tmp_path)
     yield document_store
     document_store.delete_documents()

+@pytest.fixture(scope="function", autouse=True)
+def postgres_fixture():
+    if SQL_TYPE == "postgres":
+        setup_postgres()
+        yield
+        teardown_postgres()
+    else:
+        yield

+@pytest.fixture
+def sql_url(tmp_path):
+    return get_sql_url(tmp_path)

+def get_sql_url(tmp_path):
+    if SQL_TYPE == "postgres":
+        return "postgresql://postgres:postgres@127.0.0.1/postgres"
+    else:
+        return f"sqlite:///{tmp_path}/haystack_test.db"

+def setup_postgres():
+    # status = subprocess.run(["docker run --name postgres_test -d -e POSTGRES_HOST_AUTH_METHOD=trust -p 5432:5432 postgres"], shell=True)
+    # if status.returncode:
+    #     logging.warning("Tried to start PostgreSQL through Docker but this failed. It is likely that there is already an existing instance running.")
+    # else:
+    #     sleep(5)
+    engine = create_engine('postgresql://postgres:postgres@127.0.0.1/postgres', isolation_level='AUTOCOMMIT')
+    with engine.connect() as connection:
+        try:
+            connection.execute(text('DROP SCHEMA public CASCADE'))
+        except Exception as e:
+            logging.error(e)
+        connection.execute(text('CREATE SCHEMA public;'))
+        connection.execute(text('SET SESSION idle_in_transaction_session_timeout = "1s";'))

+def teardown_postgres():
+    engine = create_engine('postgresql://postgres:postgres@127.0.0.1/postgres', isolation_level='AUTOCOMMIT')
+    with engine.connect() as connection:
+        connection.execute(text('DROP SCHEMA public CASCADE'))
+        connection.close()

-def get_document_store(document_store_type, embedding_dim=768, embedding_field="embedding", index="haystack_test", similarity:str="cosine"): # cosine is default similarity as dot product is not supported by Weaviate
+def get_document_store(document_store_type, tmp_path, embedding_dim=768, embedding_field="embedding", index="haystack_test", similarity:str="cosine"): # cosine is default similarity as dot product is not supported by Weaviate
     if document_store_type == "sql":
-        document_store = SQLDocumentStore(url="sqlite://", index=index)
+        document_store = SQLDocumentStore(url=get_sql_url(tmp_path), index=index, isolation_level="AUTOCOMMIT")
     elif document_store_type == "memory":
         document_store = InMemoryDocumentStore(
-            return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, index=index, similarity=similarity
-        )
+            return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, index=index, similarity=similarity)
     elif document_store_type == "elasticsearch":
         # make sure we start from a fresh index
         client = Elasticsearch()
@@ -536,28 +598,33 @@ def get_document_store(document_store_type, embedding_dim=768, embedding_field="
         document_store = ElasticsearchDocumentStore(
             index=index, return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, similarity=similarity
         )

     elif document_store_type == "faiss":
         document_store = FAISSDocumentStore(
             embedding_dim=embedding_dim,
-            sql_url="sqlite://",
+            sql_url=get_sql_url(tmp_path),
             return_embedding=True,
             embedding_field=embedding_field,
             index=index,
-            similarity=similarity
+            similarity=similarity,
+            isolation_level="AUTOCOMMIT"
         )

     elif document_store_type == "milvus":
         document_store = MilvusDocumentStore(
             embedding_dim=embedding_dim,
-            sql_url="sqlite://",
+            sql_url=get_sql_url(tmp_path),
             return_embedding=True,
             embedding_field=embedding_field,
             index=index,
-            similarity=similarity
+            similarity=similarity,
+            isolation_level="AUTOCOMMIT"
         )
         _, collections = document_store.milvus_server.list_collections()
         for collection in collections:
             if collection.startswith(index):
                 document_store.milvus_server.drop_collection(collection)

     elif document_store_type == "weaviate":
         document_store = WeaviateDocumentStore(
             weaviate_url="http://localhost:8080",
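Together these fixtures give every test a clean database: flipping `SQL_TYPE` to `"postgres"` makes `get_sql_url()` return the PostgreSQL URL and the autouse fixture recreate the schema around each test. A condensed, self-contained version of that reset cycle (assumes a local PostgreSQL reachable with these throwaway credentials):

```python
import logging
from sqlalchemy import create_engine, text

PG_URL = "postgresql://postgres:postgres@127.0.0.1/postgres"

def reset_public_schema():
    # AUTOCOMMIT so the DDL runs outside a transaction block
    engine = create_engine(PG_URL, isolation_level="AUTOCOMMIT")
    with engine.connect() as connection:
        try:
            connection.execute(text("DROP SCHEMA public CASCADE"))
        except Exception as e:
            logging.error(e)  # first run: nothing to drop yet
        connection.execute(text("CREATE SCHEMA public;"))

reset_public_schema()   # what setup_postgres() does before a test
# ... a test would create its document stores against PG_URL here ...
reset_public_schema()   # teardown_postgres() drops the schema again
```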


@@ -1,4 +1,6 @@
 from unittest import mock
+import uuid
+import math

 import numpy as np
 import pandas as pd
 import pytest
@@ -7,7 +9,7 @@ from elasticsearch import Elasticsearch
 from elasticsearch.exceptions import RequestError

-from conftest import get_document_store
+from conftest import get_document_store, ensure_ids_are_correct_uuids
 from haystack.document_stores import WeaviateDocumentStore
 from haystack.document_stores.base import BaseDocumentStore
 from haystack.errors import DuplicateDocumentError
@@ -17,6 +19,18 @@ from haystack.document_stores.faiss import FAISSDocumentStore
 from haystack.nodes import EmbeddingRetriever
 from haystack.pipelines import DocumentSearchPipeline

+DOCUMENTS = [
+    {"meta": {"name": "name_1", "year": "2020", "month": "01"}, "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
+    {"meta": {"name": "name_2", "year": "2020", "month": "02"}, "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
+    {"meta": {"name": "name_3", "year": "2020", "month": "03"}, "content": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
+    {"meta": {"name": "name_4", "year": "2021", "month": "01"}, "content": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
+    {"meta": {"name": "name_5", "year": "2021", "month": "02"}, "content": "text_5", "embedding": np.random.rand(768).astype(np.float32)},
+    {"meta": {"name": "name_6", "year": "2021", "month": "03"}, "content": "text_6", "embedding": np.random.rand(768).astype(np.float64)},
+]

 @pytest.mark.elasticsearch
 def test_init_elastic_client():
     # defaults
@@ -148,8 +162,8 @@ def test_get_all_documents_with_correct_filters(document_store_with_docs):
     assert {d.meta["meta_field"] for d in documents} == {"test1", "test3"}

-def test_get_all_documents_with_correct_filters_legacy_sqlite(test_docs_xs):
-    document_store_with_docs = get_document_store("sql")
+def test_get_all_documents_with_correct_filters_legacy_sqlite(test_docs_xs, tmp_path):
+    document_store_with_docs = get_document_store("sql", tmp_path)
     document_store_with_docs.write_documents(test_docs_xs)
     document_store_with_docs.use_windowed_query = False
@@ -791,7 +805,7 @@ def test_multilabel_no_answer(document_store):
     assert len(multi_labels[0].answers) == 1

-@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss"], indirect=True)
+@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "milvus", "weaviate"], indirect=True)
 # Currently update_document_meta() is not implemented for Memory doc store
 def test_update_meta(document_store):
     documents = [
@@ -818,10 +832,8 @@ def test_update_meta(document_store):
 @pytest.mark.parametrize("document_store_type", ["elasticsearch", "memory"])
-def test_custom_embedding_field(document_store_type):
-    document_store = get_document_store(
-        document_store_type=document_store_type, embedding_field="custom_embedding_field"
-    )
+def test_custom_embedding_field(document_store_type, tmp_path):
+    document_store = get_document_store(document_store_type=document_store_type, tmp_path=tmp_path, embedding_field="custom_embedding_field")
     doc_to_write = {"content": "test", "custom_embedding_field": np.random.rand(768).astype(np.float32)}
     document_store.write_documents([doc_to_write])
     documents = document_store.get_all_documents(return_embedding=True)
@@ -994,3 +1006,4 @@ def test_custom_headers(document_store_with_docs: BaseDocumentStore):
     assert "headers" in kwargs
     assert kwargs["headers"] == custom_headers
     assert len(documents) > 0


@@ -13,6 +13,9 @@ from haystack.document_stores.weaviate import WeaviateDocumentStore
 from haystack.pipelines import Pipeline
 from haystack.nodes.retriever.dense import EmbeddingRetriever

+from conftest import ensure_ids_are_correct_uuids

 DOCUMENTS = [
     {"meta": {"name": "name_1", "year": "2020", "month": "01"}, "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
     {"meta": {"name": "name_2", "year": "2020", "month": "02"}, "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
@@ -23,12 +26,14 @@ DOCUMENTS = [
 ]

 @pytest.mark.skipif(sys.platform in ['win32', 'cygwin'], reason="Test with tmp_path not working on windows runner")
-def test_faiss_index_save_and_load(tmp_path):
+def test_faiss_index_save_and_load(tmp_path, sql_url):
     document_store = FAISSDocumentStore(
-        sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
+        sql_url=sql_url,
         index="haystack_test",
-        progress_bar=False  # Just to check if the init parameters are kept
+        progress_bar=False,  # Just to check if the init parameters are kept
+        isolation_level="AUTOCOMMIT"
     )
     document_store.write_documents(DOCUMENTS)
@@ -74,11 +79,12 @@ def test_faiss_index_save_and_load(tmp_path):
 @pytest.mark.skipif(sys.platform in ['win32', 'cygwin'], reason="Test with tmp_path not working on windows runner")
-def test_faiss_index_save_and_load_custom_path(tmp_path):
+def test_faiss_index_save_and_load_custom_path(tmp_path, sql_url):
     document_store = FAISSDocumentStore(
-        sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
+        sql_url=sql_url,
         index="haystack_test",
-        progress_bar=False  # Just to check if the init parameters are kept
+        progress_bar=False,  # Just to check if the init parameters are kept
+        isolation_level="AUTOCOMMIT"
     )
     document_store.write_documents(DOCUMENTS)
@@ -128,13 +134,15 @@ def test_faiss_index_mutual_exclusive_args(tmp_path):
     with pytest.raises(ValueError):
         FAISSDocumentStore(
             sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
-            faiss_index_path=f"{tmp_path/'haystack_test'}"
+            faiss_index_path=f"{tmp_path/'haystack_test'}",
+            isolation_level="AUTOCOMMIT"
         )

     with pytest.raises(ValueError):
         FAISSDocumentStore(
             f"sqlite:////{tmp_path/'haystack_test.db'}",
-            faiss_index_path=f"{tmp_path/'haystack_test'}"
+            faiss_index_path=f"{tmp_path/'haystack_test'}",
+            isolation_level="AUTOCOMMIT"
         )
@@ -227,7 +235,9 @@ def test_update_with_empty_store(document_store, retriever):
 @pytest.mark.parametrize("index_factory", ["Flat", "HNSW", "IVF1,Flat"])
 def test_faiss_retrieving(index_factory, tmp_path):
     document_store = FAISSDocumentStore(
-        sql_url=f"sqlite:////{tmp_path/'test_faiss_retrieving.db'}", faiss_index_factory_str=index_factory
+        sql_url=f"sqlite:////{tmp_path/'test_faiss_retrieving.db'}",
+        faiss_index_factory_str=index_factory,
+        isolation_level="AUTOCOMMIT"
     )
     document_store.delete_all_documents(index="document")
@@ -396,7 +406,6 @@ def test_get_docs_with_many_filters(document_store, retriever):
     assert "2020" == documents[0].meta["year"]

 @pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
-@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
 def test_pipeline(document_store, retriever):
@@ -423,7 +432,9 @@ def test_faiss_passing_index_from_outside(tmp_path):
     faiss_index.set_direct_map_type(faiss.DirectMap.Hashtable)
     faiss_index.nprobe = 2
     document_store = FAISSDocumentStore(
-        sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}", faiss_index=faiss_index, index=index
+        sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}",
+        faiss_index=faiss_index, index=index,
+        isolation_level="AUTOCOMMIT"
     )
     document_store.delete_documents()
@@ -437,12 +448,8 @@ def test_faiss_passing_index_from_outside(tmp_path):
     for doc in documents_indexed:
         assert 0 <= int(doc.meta["vector_id"]) <= 7

-def ensure_ids_are_correct_uuids(docs:list,document_store:object)->None:
-    # Weaviate currently only supports UUIDs
-    if type(document_store)==WeaviateDocumentStore:
-        for d in docs:
-            d["id"] = str(uuid.uuid4())

+@pytest.mark.parametrize("document_store", ["faiss", "milvus", "weaviate"], indirect=True)
 def test_cosine_similarity(document_store):
     # below we will write documents to the store and then query it to see if vectors were normalized
@@ -487,6 +494,7 @@ def test_cosine_similarity(document_store):
     assert not np.allclose(original_emb[0], doc.embedding, rtol=0.01)

+@pytest.mark.parametrize("document_store_dot_product_small", ["faiss", "milvus"], indirect=True)
 def test_normalize_embeddings_diff_shapes(document_store_dot_product_small):
     VEC_1 = np.array([.1, .2, .3], dtype="float32")
     document_store_dot_product_small.normalize_embedding(VEC_1)
@@ -497,6 +505,7 @@ def test_normalize_embeddings_diff_shapes(document_store_dot_product_small):
     assert np.linalg.norm(VEC_1) - 1 < 0.01

+@pytest.mark.parametrize("document_store_small", ["faiss", "milvus", "weaviate"], indirect=True)
 def test_cosine_sanity_check(document_store_small):
     VEC_1 = np.array([.1, .2, .3], dtype="float32")
     VEC_2 = np.array([.4, .5, .6], dtype="float32")
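For reference, what the sanity check pins down numerically: the cosine similarity of those two vectors, computed directly (the expected value below is computed here from the diff's vectors, not copied from the test file):

```python
import numpy as np

VEC_1 = np.array([.1, .2, .3], dtype="float32")
VEC_2 = np.array([.4, .5, .6], dtype="float32")

# cos(v1, v2) = (v1 . v2) / (|v1| * |v2|)
cosine = np.dot(VEC_1, VEC_2) / (np.linalg.norm(VEC_1) * np.linalg.norm(VEC_2))
print(round(float(cosine), 4))  # ~0.9746
```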


@@ -30,16 +30,16 @@ DOCUMENTS_XS = [

 @pytest.fixture(params=["weaviate"])
-def document_store_with_docs(request):
-    document_store = get_document_store(request.param)
+def document_store_with_docs(request, tmp_path):
+    document_store = get_document_store(request.param, tmp_path=tmp_path)
     document_store.write_documents(DOCUMENTS_XS)
     yield document_store
     document_store.delete_documents()

 @pytest.fixture(params=["weaviate"])
-def document_store(request):
-    document_store = get_document_store(request.param)
+def document_store(request, tmp_path):
+    document_store = get_document_store(request.param, tmp_path=tmp_path)
     yield document_store
     document_store.delete_documents()