mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-24 17:30:38 +00:00
Implement proper FK in MetaDocumentORM
and MetaLabelORM
to work on PostgreSQL (#1990)
* Properly fix MetaDocumentORM and MetaLabelORM with composite foreign key constraints * update_document_meta() was not using index properly * Exclude ES and Memory from the cosine_sanity_check test * move ensure_ids_are_correct_uuids in conftest and move one test back to faiss & milvus suite Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
3e4dbbb32c
commit
e28bf618d7
@ -380,7 +380,7 @@ Write annotation labels into document store.
|
||||
#### update\_document\_meta
|
||||
|
||||
```python
|
||||
| update_document_meta(id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None)
|
||||
| update_document_meta(id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None, index: str = None)
|
||||
```
|
||||
|
||||
Update the metadata dictionary of a document by specifying its string id
|
||||
@ -952,7 +952,7 @@ class SQLDocumentStore(BaseDocumentStore)
|
||||
#### \_\_init\_\_
|
||||
|
||||
```python
|
||||
| __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite", check_same_thread: bool = False)
|
||||
| __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite", check_same_thread: bool = False, isolation_level: str = None)
|
||||
```
|
||||
|
||||
An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.
|
||||
@ -970,6 +970,7 @@ An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL bac
|
||||
fail: an error is raised if the document ID of the document being added already
|
||||
exists.
|
||||
- `check_same_thread`: Set to False to mitigate multithreading issues in older SQLite versions (see https://docs.sqlalchemy.org/en/14/dialects/sqlite.html?highlight=check_same_thread#threading-pooling-behavior)
|
||||
- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
|
||||
|
||||
<a name="sql.SQLDocumentStore.get_document_by_id"></a>
|
||||
#### get\_document\_by\_id
|
||||
@ -1094,7 +1095,7 @@ Set vector IDs for all documents as None
|
||||
#### update\_document\_meta
|
||||
|
||||
```python
|
||||
| update_document_meta(id: str, meta: Dict[str, str])
|
||||
| update_document_meta(id: str, meta: Dict[str, str], index: str = None)
|
||||
```
|
||||
|
||||
Update the metadata dictionary of a document by specifying its string id
|
||||
@ -1202,7 +1203,7 @@ the vector embeddings are indexed in a FAISS Index.
|
||||
#### \_\_init\_\_
|
||||
|
||||
```python
|
||||
| __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, **kwargs, ,)
|
||||
| __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, isolation_level: str = None, **kwargs, ,)
|
||||
```
|
||||
|
||||
**Arguments**:
|
||||
@ -1248,6 +1249,7 @@ the vector embeddings are indexed in a FAISS Index.
|
||||
If specified no other params besides faiss_config_path must be specified.
|
||||
- `faiss_config_path`: Stored FAISS initial configuration parameters.
|
||||
Can be created via calling `save()`
|
||||
- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
|
||||
|
||||
<a name="faiss.FAISSDocumentStore.write_documents"></a>
|
||||
#### write\_documents
|
||||
@ -1479,7 +1481,7 @@ Usage:
|
||||
#### \_\_init\_\_
|
||||
|
||||
```python
|
||||
| __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = None, embedding_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs, ,)
|
||||
| __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = None, embedding_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', isolation_level: str = None, **kwargs, ,)
|
||||
```
|
||||
|
||||
**Arguments**:
|
||||
@ -1525,6 +1527,7 @@ Note that an overly large index_file_size value may cause failure to load a segm
|
||||
overwrite: Update any existing documents with the same ID when adding documents.
|
||||
fail: an error is raised if the document ID of the document being added already
|
||||
exists.
|
||||
- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
|
||||
|
||||
<a name="milvus.MilvusDocumentStore.write_documents"></a>
|
||||
#### write\_documents
|
||||
@ -1863,7 +1866,7 @@ None
|
||||
#### update\_document\_meta
|
||||
|
||||
```python
|
||||
| update_document_meta(id: str, meta: Dict[str, str])
|
||||
| update_document_meta(id: str, meta: Dict[str, str], index: str = None)
|
||||
```
|
||||
|
||||
Update the metadata dictionary of a document by specifying its string id.
|
||||
|
@ -551,12 +551,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
if labels_to_index:
|
||||
bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type, headers=headers)
|
||||
|
||||
def update_document_meta(self, id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None):
|
||||
def update_document_meta(self, id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None, index: str = None):
|
||||
"""
|
||||
Update the metadata dictionary of a document by specifying its string id
|
||||
"""
|
||||
if not index:
|
||||
index = self.index
|
||||
body = {"doc": meta}
|
||||
self.client.update(index=self.index, id=id, body=body, refresh=self.refresh_type, headers=headers)
|
||||
self.client.update(index=index, id=id, body=body, refresh=self.refresh_type, headers=headers)
|
||||
|
||||
def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None,
|
||||
only_documents_without_embedding: bool = False, headers: Optional[Dict[str, str]] = None) -> int:
|
||||
|
@ -50,6 +50,7 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
duplicate_documents: str = 'overwrite',
|
||||
faiss_index_path: Union[str, Path] = None,
|
||||
faiss_config_path: Union[str, Path] = None,
|
||||
isolation_level: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -94,6 +95,7 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
If specified no other params besides faiss_config_path must be specified.
|
||||
:param faiss_config_path: Stored FAISS initial configuration parameters.
|
||||
Can be created via calling `save()`
|
||||
:param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
|
||||
"""
|
||||
# special case if we want to load an existing index from disk
|
||||
# load init params from disk and run init again
|
||||
@ -115,7 +117,8 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
index=index,
|
||||
similarity=similarity,
|
||||
embedding_field=embedding_field,
|
||||
progress_bar=progress_bar
|
||||
progress_bar=progress_bar,
|
||||
isolation_level=isolation_level
|
||||
)
|
||||
|
||||
if similarity in ("dot_product", "cosine"):
|
||||
@ -155,7 +158,8 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
super().__init__(
|
||||
url=sql_url,
|
||||
index=index,
|
||||
duplicate_documents=duplicate_documents
|
||||
duplicate_documents=duplicate_documents,
|
||||
isolation_level=isolation_level
|
||||
)
|
||||
|
||||
self._validate_index_sync()
|
||||
|
@ -53,6 +53,7 @@ class MilvusDocumentStore(SQLDocumentStore):
|
||||
embedding_field: str = "embedding",
|
||||
progress_bar: bool = True,
|
||||
duplicate_documents: str = 'overwrite',
|
||||
isolation_level: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -97,6 +98,7 @@ class MilvusDocumentStore(SQLDocumentStore):
|
||||
overwrite: Update any existing documents with the same ID when adding documents.
|
||||
fail: an error is raised if the document ID of the document being added already
|
||||
exists.
|
||||
:param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
|
||||
"""
|
||||
# save init parameters to enable export of component config as YAML
|
||||
self.set_config(
|
||||
@ -104,6 +106,7 @@ class MilvusDocumentStore(SQLDocumentStore):
|
||||
embedding_dim=embedding_dim, index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
|
||||
search_param=search_param, duplicate_documents=duplicate_documents,
|
||||
return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar,
|
||||
isolation_level=isolation_level
|
||||
)
|
||||
|
||||
self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool)
|
||||
@ -139,7 +142,8 @@ class MilvusDocumentStore(SQLDocumentStore):
|
||||
super().__init__(
|
||||
url=sql_url,
|
||||
index=index,
|
||||
duplicate_documents=duplicate_documents
|
||||
duplicate_documents=duplicate_documents,
|
||||
isolation_level=isolation_level,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
|
@ -73,6 +73,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
|
||||
custom_fields: Optional[List[Any]] = None,
|
||||
progress_bar: bool = True,
|
||||
duplicate_documents: str = 'overwrite',
|
||||
isolation_level: str = None
|
||||
):
|
||||
"""
|
||||
:param sql_url: SQL connection URL for storing document texts and metadata. It defaults to a local, file based SQLite DB. For large scale
|
||||
@ -118,6 +119,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
|
||||
overwrite: Update any existing documents with the same ID when adding documents.
|
||||
fail: an error is raised if the document ID of the document being added already
|
||||
exists.
|
||||
:param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
|
||||
"""
|
||||
|
||||
# save init parameters to enable export of component config as YAML
|
||||
@ -127,6 +129,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
|
||||
search_param=search_param, duplicate_documents=duplicate_documents, id_field=id_field,
|
||||
return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar,
|
||||
custom_fields=custom_fields,
|
||||
isolation_level=isolation_level
|
||||
)
|
||||
|
||||
logger.warning("Milvus2DocumentStore is in experimental state until Milvus 2.0 is released")
|
||||
@ -173,7 +176,8 @@ class Milvus2DocumentStore(SQLDocumentStore):
|
||||
super().__init__(
|
||||
url=sql_url,
|
||||
index=index,
|
||||
duplicate_documents=duplicate_documents
|
||||
duplicate_documents=duplicate_documents,
|
||||
isolation_level=isolation_level,
|
||||
)
|
||||
|
||||
def _create_collection_and_index_if_not_exist(
|
||||
|
@ -4,7 +4,7 @@ import logging
|
||||
import itertools
|
||||
import numpy as np
|
||||
from uuid import uuid4
|
||||
from sqlalchemy import and_, func, create_engine, Column, String, DateTime, ForeignKey, Boolean, Text, text, JSON
|
||||
from sqlalchemy import and_, func, create_engine, Column, String, DateTime, ForeignKey, Boolean, Text, text, JSON, ForeignKeyConstraint
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import relationship, sessionmaker
|
||||
from sqlalchemy.sql import case, null
|
||||
@ -33,9 +33,6 @@ class DocumentORM(ORMBase):
|
||||
# primary key in combination with id to allow the same doc in different indices
|
||||
index = Column(String(100), nullable=False, primary_key=True)
|
||||
vector_id = Column(String(100), unique=True, nullable=True)
|
||||
|
||||
# labels = relationship("LabelORM", back_populates="document")
|
||||
|
||||
# speeds up queries for get_documents_by_vector_ids() by having a single query that returns joined metadata
|
||||
meta = relationship("MetaDocumentORM", back_populates="documents", lazy="joined")
|
||||
|
||||
@ -45,36 +42,18 @@ class MetaDocumentORM(ORMBase):
|
||||
|
||||
name = Column(String(100), index=True)
|
||||
value = Column(String(1000), index=True)
|
||||
document_id = Column(
|
||||
String(100),
|
||||
ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
|
||||
documents = relationship("DocumentORM", back_populates="meta")
|
||||
|
||||
|
||||
class MetaLabelORM(ORMBase):
|
||||
__tablename__ = "meta_label"
|
||||
|
||||
name = Column(String(100), index=True)
|
||||
value = Column(String(1000), index=True)
|
||||
label_id = Column(
|
||||
String(100),
|
||||
ForeignKey("label.id", ondelete="CASCADE", onupdate="CASCADE"),
|
||||
nullable=False,
|
||||
index=True
|
||||
)
|
||||
|
||||
labels = relationship("LabelORM", back_populates="meta")
|
||||
document_id = Column(String(100), nullable=False, index=True)
|
||||
document_index = Column(String(100), nullable=False, index=True)
|
||||
__table_args__ = (ForeignKeyConstraint([document_id, document_index],
|
||||
[DocumentORM.id, DocumentORM.index],
|
||||
ondelete="CASCADE", onupdate="CASCADE"), {}) #type: ignore
|
||||
|
||||
|
||||
class LabelORM(ORMBase):
|
||||
__tablename__ = "label"
|
||||
|
||||
# document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"), nullable=False)
|
||||
|
||||
index = Column(String(100), nullable=False, primary_key=True)
|
||||
query = Column(Text, nullable=False)
|
||||
answer = Column(JSON, nullable=True)
|
||||
@ -86,7 +65,21 @@ class LabelORM(ORMBase):
|
||||
pipeline_id = Column(String(500), nullable=True)
|
||||
|
||||
meta = relationship("MetaLabelORM", back_populates="labels", lazy="joined")
|
||||
# document = relationship("DocumentORM", back_populates="labels")
|
||||
|
||||
|
||||
class MetaLabelORM(ORMBase):
|
||||
__tablename__ = "meta_label"
|
||||
|
||||
name = Column(String(100), index=True)
|
||||
value = Column(String(1000), index=True)
|
||||
labels = relationship("LabelORM", back_populates="meta")
|
||||
|
||||
label_id = Column(String(100), nullable=False, index=True)
|
||||
label_index = Column(String(100), nullable=False, index=True)
|
||||
__table_args__ = (ForeignKeyConstraint([label_id, label_index],
|
||||
[LabelORM.id, LabelORM.index],
|
||||
ondelete="CASCADE", onupdate="CASCADE"), {}) #type: ignore
|
||||
|
||||
|
||||
|
||||
class SQLDocumentStore(BaseDocumentStore):
|
||||
@ -96,7 +89,8 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
index: str = "document",
|
||||
label_index: str = "label",
|
||||
duplicate_documents: str = "overwrite",
|
||||
check_same_thread: bool = False
|
||||
check_same_thread: bool = False,
|
||||
isolation_level: str = None
|
||||
):
|
||||
"""
|
||||
An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.
|
||||
@ -112,18 +106,21 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
fail: an error is raised if the document ID of the document being added already
|
||||
exists.
|
||||
:param check_same_thread: Set to False to mitigate multithreading issues in older SQLite versions (see https://docs.sqlalchemy.org/en/14/dialects/sqlite.html?highlight=check_same_thread#threading-pooling-behavior)
|
||||
:param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
|
||||
"""
|
||||
|
||||
# save init parameters to enable export of component config as YAML
|
||||
self.set_config(
|
||||
url=url, index=index, label_index=label_index, duplicate_documents=duplicate_documents, check_same_thread=check_same_thread
|
||||
)
|
||||
|
||||
create_engine_params = {}
|
||||
if isolation_level:
|
||||
create_engine_params["isolation_level"] = isolation_level
|
||||
if "sqlite" in url:
|
||||
engine = create_engine(url, connect_args={'check_same_thread': check_same_thread})
|
||||
engine = create_engine(url, connect_args={'check_same_thread': check_same_thread}, **create_engine_params)
|
||||
else:
|
||||
engine = create_engine(url)
|
||||
ORMBase.metadata.create_all(engine)
|
||||
engine = create_engine(url, **create_engine_params)
|
||||
Base.metadata.create_all(engine)
|
||||
Session = sessionmaker(bind=engine)
|
||||
self.session = Session()
|
||||
self.index: str = index
|
||||
@ -461,12 +458,14 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
self.session.query(DocumentORM).filter_by(index=index).update({DocumentORM.vector_id: null()})
|
||||
self.session.commit()
|
||||
|
||||
def update_document_meta(self, id: str, meta: Dict[str, str]):
|
||||
def update_document_meta(self, id: str, meta: Dict[str, str], index: str = None):
|
||||
"""
|
||||
Update the metadata dictionary of a document by specifying its string id
|
||||
"""
|
||||
self.session.query(MetaDocumentORM).filter_by(document_id=id).delete()
|
||||
meta_orms = [MetaDocumentORM(name=key, value=value, document_id=id) for key, value in meta.items()]
|
||||
if not index:
|
||||
index = self.index
|
||||
self.session.query(MetaDocumentORM).filter_by(document_id=id, document_index=index).delete()
|
||||
meta_orms = [MetaDocumentORM(name=key, value=value, document_id=id, document_index=index) for key, value in meta.items()]
|
||||
for m in meta_orms:
|
||||
self.session.add(m)
|
||||
self.session.commit()
|
||||
|
@ -486,11 +486,13 @@ class WeaviateDocumentStore(BaseDocumentStore):
|
||||
progress_bar.update(batch_size)
|
||||
progress_bar.close()
|
||||
|
||||
def update_document_meta(self, id: str, meta: Dict[str, str]):
|
||||
def update_document_meta(self, id: str, meta: Dict[str, str], index: str = None):
|
||||
"""
|
||||
Update the metadata dictionary of a document by specifying its string id.
|
||||
"""
|
||||
self.weaviate_client.data_object.update(meta, class_name=self.index, uuid=id)
|
||||
if not index:
|
||||
index = self.index
|
||||
self.weaviate_client.data_object.update(meta, class_name=index, uuid=id)
|
||||
|
||||
def get_embedding_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
|
||||
"""
|
||||
|
109
test/conftest.py
109
test/conftest.py
@ -3,6 +3,9 @@ import time
|
||||
from subprocess import run
|
||||
from sys import platform
|
||||
import gc
|
||||
import uuid
|
||||
import logging
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
@ -40,6 +43,11 @@ from haystack.nodes.translator import TransformersTranslator
|
||||
from haystack.nodes.question_generator import QuestionGenerator
|
||||
|
||||
|
||||
# To manually run the tests with default PostgreSQL instead of SQLite, switch the lines below
|
||||
SQL_TYPE = "sqlite"
|
||||
# SQL_TYPE = "postgres"
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption("--document_store_type", action="store", default="elasticsearch, faiss, memory, milvus, weaviate")
|
||||
|
||||
@ -477,58 +485,112 @@ def get_retriever(retriever_type, document_store):
|
||||
return retriever
|
||||
|
||||
|
||||
def ensure_ids_are_correct_uuids(docs:list,document_store:object)->None:
|
||||
# Weaviate currently only supports UUIDs
|
||||
if type(document_store)==WeaviateDocumentStore:
|
||||
for d in docs:
|
||||
d["id"] = str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate"])
|
||||
def document_store_with_docs(request, test_docs_xs):
|
||||
def document_store_with_docs(request, test_docs_xs, tmp_path):
|
||||
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
|
||||
document_store = get_document_store(request.param, embedding_dim.args[0])
|
||||
document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], tmp_path=tmp_path)
|
||||
document_store.write_documents(test_docs_xs)
|
||||
yield document_store
|
||||
document_store.delete_documents()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def document_store(request):
|
||||
def document_store(request, tmp_path):
|
||||
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
|
||||
document_store = get_document_store(request.param, embedding_dim.args[0])
|
||||
document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], tmp_path=tmp_path)
|
||||
yield document_store
|
||||
document_store.delete_documents()
|
||||
|
||||
@pytest.fixture(params=["memory", "faiss", "milvus", "elasticsearch"])
|
||||
def document_store_dot_product(request):
|
||||
def document_store_dot_product(request, tmp_path):
|
||||
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
|
||||
document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
|
||||
document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path)
|
||||
yield document_store
|
||||
document_store.delete_documents()
|
||||
|
||||
@pytest.fixture(params=["memory", "faiss", "milvus", "elasticsearch"])
|
||||
def document_store_dot_product_with_docs(request, test_docs_xs):
|
||||
def document_store_dot_product_with_docs(request, test_docs_xs, tmp_path):
|
||||
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
|
||||
document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
|
||||
document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path)
|
||||
document_store.write_documents(test_docs_xs)
|
||||
yield document_store
|
||||
document_store.delete_documents()
|
||||
|
||||
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus"])
|
||||
def document_store_dot_product_small(request):
|
||||
def document_store_dot_product_small(request, tmp_path):
|
||||
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(3))
|
||||
document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
|
||||
document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path)
|
||||
yield document_store
|
||||
document_store.delete_documents()
|
||||
|
||||
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate"])
|
||||
def document_store_small(request):
|
||||
def document_store_small(request, tmp_path):
|
||||
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(3))
|
||||
document_store = get_document_store(request.param, embedding_dim.args[0], similarity="cosine")
|
||||
document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="cosine", tmp_path=tmp_path)
|
||||
yield document_store
|
||||
document_store.delete_documents()
|
||||
|
||||
def get_document_store(document_store_type, embedding_dim=768, embedding_field="embedding", index="haystack_test", similarity:str="cosine"): # cosine is default similarity as dot product is not supported by Weaviate
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def postgres_fixture():
|
||||
if SQL_TYPE == "postgres":
|
||||
setup_postgres()
|
||||
yield
|
||||
teardown_postgres()
|
||||
else:
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sql_url(tmp_path):
|
||||
return get_sql_url(tmp_path)
|
||||
|
||||
|
||||
def get_sql_url(tmp_path):
|
||||
if SQL_TYPE == "postgres":
|
||||
return "postgresql://postgres:postgres@127.0.0.1/postgres"
|
||||
else:
|
||||
return f"sqlite:///{tmp_path}/haystack_test.db"
|
||||
|
||||
|
||||
def setup_postgres():
|
||||
# status = subprocess.run(["docker run --name postgres_test -d -e POSTGRES_HOST_AUTH_METHOD=trust -p 5432:5432 postgres"], shell=True)
|
||||
# if status.returncode:
|
||||
# logging.warning("Tried to start PostgreSQL through Docker but this failed. It is likely that there is already an existing instance running.")
|
||||
# else:
|
||||
# sleep(5)
|
||||
engine = create_engine('postgresql://postgres:postgres@127.0.0.1/postgres', isolation_level='AUTOCOMMIT')
|
||||
|
||||
with engine.connect() as connection:
|
||||
try:
|
||||
connection.execute(text('DROP SCHEMA public CASCADE'))
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
connection.execute(text('CREATE SCHEMA public;'))
|
||||
connection.execute(text('SET SESSION idle_in_transaction_session_timeout = "1s";'))
|
||||
|
||||
|
||||
def teardown_postgres():
|
||||
engine = create_engine('postgresql://postgres:postgres@127.0.0.1/postgres', isolation_level='AUTOCOMMIT')
|
||||
with engine.connect() as connection:
|
||||
connection.execute(text('DROP SCHEMA public CASCADE'))
|
||||
connection.close()
|
||||
|
||||
|
||||
def get_document_store(document_store_type, tmp_path, embedding_dim=768, embedding_field="embedding", index="haystack_test", similarity:str="cosine"): # cosine is default similarity as dot product is not supported by Weaviate
|
||||
if document_store_type == "sql":
|
||||
document_store = SQLDocumentStore(url="sqlite://", index=index)
|
||||
document_store = SQLDocumentStore(url=get_sql_url(tmp_path), index=index, isolation_level="AUTOCOMMIT")
|
||||
|
||||
elif document_store_type == "memory":
|
||||
document_store = InMemoryDocumentStore(
|
||||
return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, index=index, similarity=similarity
|
||||
)
|
||||
return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, index=index, similarity=similarity)
|
||||
|
||||
elif document_store_type == "elasticsearch":
|
||||
# make sure we start from a fresh index
|
||||
client = Elasticsearch()
|
||||
@ -536,28 +598,33 @@ def get_document_store(document_store_type, embedding_dim=768, embedding_field="
|
||||
document_store = ElasticsearchDocumentStore(
|
||||
index=index, return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, similarity=similarity
|
||||
)
|
||||
|
||||
elif document_store_type == "faiss":
|
||||
document_store = FAISSDocumentStore(
|
||||
embedding_dim=embedding_dim,
|
||||
sql_url="sqlite://",
|
||||
sql_url=get_sql_url(tmp_path),
|
||||
return_embedding=True,
|
||||
embedding_field=embedding_field,
|
||||
index=index,
|
||||
similarity=similarity
|
||||
similarity=similarity,
|
||||
isolation_level="AUTOCOMMIT"
|
||||
)
|
||||
|
||||
elif document_store_type == "milvus":
|
||||
document_store = MilvusDocumentStore(
|
||||
embedding_dim=embedding_dim,
|
||||
sql_url="sqlite://",
|
||||
sql_url=get_sql_url(tmp_path),
|
||||
return_embedding=True,
|
||||
embedding_field=embedding_field,
|
||||
index=index,
|
||||
similarity=similarity
|
||||
similarity=similarity,
|
||||
isolation_level="AUTOCOMMIT"
|
||||
)
|
||||
_, collections = document_store.milvus_server.list_collections()
|
||||
for collection in collections:
|
||||
if collection.startswith(index):
|
||||
document_store.milvus_server.drop_collection(collection)
|
||||
|
||||
elif document_store_type == "weaviate":
|
||||
document_store = WeaviateDocumentStore(
|
||||
weaviate_url="http://localhost:8080",
|
||||
|
@ -1,4 +1,6 @@
|
||||
from unittest import mock
|
||||
import uuid
|
||||
import math
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
@ -7,7 +9,7 @@ from elasticsearch import Elasticsearch
|
||||
from elasticsearch.exceptions import RequestError
|
||||
|
||||
|
||||
from conftest import get_document_store
|
||||
from conftest import get_document_store, ensure_ids_are_correct_uuids
|
||||
from haystack.document_stores import WeaviateDocumentStore
|
||||
from haystack.document_stores.base import BaseDocumentStore
|
||||
from haystack.errors import DuplicateDocumentError
|
||||
@ -17,6 +19,18 @@ from haystack.document_stores.faiss import FAISSDocumentStore
|
||||
from haystack.nodes import EmbeddingRetriever
|
||||
from haystack.pipelines import DocumentSearchPipeline
|
||||
|
||||
|
||||
DOCUMENTS = [
|
||||
{"meta": {"name": "name_1", "year": "2020", "month": "01"}, "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
|
||||
{"meta": {"name": "name_2", "year": "2020", "month": "02"}, "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
|
||||
{"meta": {"name": "name_3", "year": "2020", "month": "03"}, "content": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
|
||||
{"meta": {"name": "name_4", "year": "2021", "month": "01"}, "content": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
|
||||
{"meta": {"name": "name_5", "year": "2021", "month": "02"}, "content": "text_5", "embedding": np.random.rand(768).astype(np.float32)},
|
||||
{"meta": {"name": "name_6", "year": "2021", "month": "03"}, "content": "text_6", "embedding": np.random.rand(768).astype(np.float64)},
|
||||
]
|
||||
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
|
||||
def test_init_elastic_client():
|
||||
# defaults
|
||||
@ -148,8 +162,8 @@ def test_get_all_documents_with_correct_filters(document_store_with_docs):
|
||||
assert {d.meta["meta_field"] for d in documents} == {"test1", "test3"}
|
||||
|
||||
|
||||
def test_get_all_documents_with_correct_filters_legacy_sqlite(test_docs_xs):
|
||||
document_store_with_docs = get_document_store("sql")
|
||||
def test_get_all_documents_with_correct_filters_legacy_sqlite(test_docs_xs, tmp_path):
|
||||
document_store_with_docs = get_document_store("sql", tmp_path)
|
||||
document_store_with_docs.write_documents(test_docs_xs)
|
||||
|
||||
document_store_with_docs.use_windowed_query = False
|
||||
@ -791,7 +805,7 @@ def test_multilabel_no_answer(document_store):
|
||||
assert len(multi_labels[0].answers) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "milvus", "weaviate"], indirect=True)
|
||||
# Currently update_document_meta() is not implemented for Memory doc store
|
||||
def test_update_meta(document_store):
|
||||
documents = [
|
||||
@ -818,10 +832,8 @@ def test_update_meta(document_store):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store_type", ["elasticsearch", "memory"])
|
||||
def test_custom_embedding_field(document_store_type):
|
||||
document_store = get_document_store(
|
||||
document_store_type=document_store_type, embedding_field="custom_embedding_field"
|
||||
)
|
||||
def test_custom_embedding_field(document_store_type, tmp_path):
|
||||
document_store = get_document_store(document_store_type=document_store_type, tmp_path=tmp_path, embedding_field="custom_embedding_field")
|
||||
doc_to_write = {"content": "test", "custom_embedding_field": np.random.rand(768).astype(np.float32)}
|
||||
document_store.write_documents([doc_to_write])
|
||||
documents = document_store.get_all_documents(return_embedding=True)
|
||||
@ -993,4 +1005,5 @@ def test_custom_headers(document_store_with_docs: BaseDocumentStore):
|
||||
args, kwargs = mock_client.search.call_args
|
||||
assert "headers" in kwargs
|
||||
assert kwargs["headers"] == custom_headers
|
||||
assert len(documents) > 0
|
||||
assert len(documents) > 0
|
||||
|
||||
|
@ -13,6 +13,9 @@ from haystack.document_stores.weaviate import WeaviateDocumentStore
|
||||
from haystack.pipelines import Pipeline
|
||||
from haystack.nodes.retriever.dense import EmbeddingRetriever
|
||||
|
||||
from conftest import ensure_ids_are_correct_uuids
|
||||
|
||||
|
||||
DOCUMENTS = [
|
||||
{"meta": {"name": "name_1", "year": "2020", "month": "01"}, "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
|
||||
{"meta": {"name": "name_2", "year": "2020", "month": "02"}, "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
|
||||
@ -23,12 +26,14 @@ DOCUMENTS = [
|
||||
]
|
||||
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform in ['win32', 'cygwin'], reason="Test with tmp_path not working on windows runner")
|
||||
def test_faiss_index_save_and_load(tmp_path):
|
||||
def test_faiss_index_save_and_load(tmp_path, sql_url):
|
||||
document_store = FAISSDocumentStore(
|
||||
sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
|
||||
sql_url=sql_url,
|
||||
index="haystack_test",
|
||||
progress_bar=False # Just to check if the init parameters are kept
|
||||
progress_bar=False, # Just to check if the init parameters are kept
|
||||
isolation_level="AUTOCOMMIT"
|
||||
)
|
||||
document_store.write_documents(DOCUMENTS)
|
||||
|
||||
@ -74,11 +79,12 @@ def test_faiss_index_save_and_load(tmp_path):
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform in ['win32', 'cygwin'], reason="Test with tmp_path not working on windows runner")
|
||||
def test_faiss_index_save_and_load_custom_path(tmp_path):
|
||||
def test_faiss_index_save_and_load_custom_path(tmp_path, sql_url):
|
||||
document_store = FAISSDocumentStore(
|
||||
sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
|
||||
sql_url=sql_url,
|
||||
index="haystack_test",
|
||||
progress_bar=False # Just to check if the init parameters are kept
|
||||
progress_bar=False, # Just to check if the init parameters are kept
|
||||
isolation_level="AUTOCOMMIT"
|
||||
)
|
||||
document_store.write_documents(DOCUMENTS)
|
||||
|
||||
@ -128,13 +134,15 @@ def test_faiss_index_mutual_exclusive_args(tmp_path):
|
||||
with pytest.raises(ValueError):
|
||||
FAISSDocumentStore(
|
||||
sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
|
||||
faiss_index_path=f"{tmp_path/'haystack_test'}"
|
||||
faiss_index_path=f"{tmp_path/'haystack_test'}",
|
||||
isolation_level="AUTOCOMMIT"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
FAISSDocumentStore(
|
||||
f"sqlite:////{tmp_path/'haystack_test.db'}",
|
||||
faiss_index_path=f"{tmp_path/'haystack_test'}"
|
||||
faiss_index_path=f"{tmp_path/'haystack_test'}",
|
||||
isolation_level="AUTOCOMMIT"
|
||||
)
|
||||
|
||||
|
||||
@ -227,7 +235,9 @@ def test_update_with_empty_store(document_store, retriever):
|
||||
@pytest.mark.parametrize("index_factory", ["Flat", "HNSW", "IVF1,Flat"])
|
||||
def test_faiss_retrieving(index_factory, tmp_path):
|
||||
document_store = FAISSDocumentStore(
|
||||
sql_url=f"sqlite:////{tmp_path/'test_faiss_retrieving.db'}", faiss_index_factory_str=index_factory
|
||||
sql_url=f"sqlite:////{tmp_path/'test_faiss_retrieving.db'}",
|
||||
faiss_index_factory_str=index_factory,
|
||||
isolation_level="AUTOCOMMIT"
|
||||
)
|
||||
|
||||
document_store.delete_all_documents(index="document")
|
||||
@ -396,7 +406,6 @@ def test_get_docs_with_many_filters(document_store, retriever):
|
||||
assert "2020" == documents[0].meta["year"]
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
|
||||
def test_pipeline(document_store, retriever):
|
||||
@ -423,7 +432,9 @@ def test_faiss_passing_index_from_outside(tmp_path):
|
||||
faiss_index.set_direct_map_type(faiss.DirectMap.Hashtable)
|
||||
faiss_index.nprobe = 2
|
||||
document_store = FAISSDocumentStore(
|
||||
sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}", faiss_index=faiss_index, index=index
|
||||
sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}",
|
||||
faiss_index=faiss_index, index=index,
|
||||
isolation_level="AUTOCOMMIT"
|
||||
)
|
||||
|
||||
document_store.delete_documents()
|
||||
@ -437,12 +448,8 @@ def test_faiss_passing_index_from_outside(tmp_path):
|
||||
for doc in documents_indexed:
|
||||
assert 0 <= int(doc.meta["vector_id"]) <= 7
|
||||
|
||||
def ensure_ids_are_correct_uuids(docs:list,document_store:object)->None:
|
||||
# Weaviate currently only supports UUIDs
|
||||
if type(document_store)==WeaviateDocumentStore:
|
||||
for d in docs:
|
||||
d["id"] = str(uuid.uuid4())
|
||||
|
||||
@pytest.mark.parametrize("document_store", ["faiss", "milvus", "weaviate"], indirect=True)
|
||||
def test_cosine_similarity(document_store):
|
||||
# below we will write documents to the store and then query it to see if vectors were normalized
|
||||
|
||||
@ -487,6 +494,7 @@ def test_cosine_similarity(document_store):
|
||||
assert not np.allclose(original_emb[0], doc.embedding, rtol=0.01)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store_dot_product_small", ["faiss", "milvus"], indirect=True)
|
||||
def test_normalize_embeddings_diff_shapes(document_store_dot_product_small):
|
||||
VEC_1 = np.array([.1, .2, .3], dtype="float32")
|
||||
document_store_dot_product_small.normalize_embedding(VEC_1)
|
||||
@ -497,6 +505,7 @@ def test_normalize_embeddings_diff_shapes(document_store_dot_product_small):
|
||||
assert np.linalg.norm(VEC_1) - 1 < 0.01
|
||||
|
||||
|
||||
@pytest.mark.parametrize("document_store_small", ["faiss", "milvus", "weaviate"], indirect=True)
|
||||
def test_cosine_sanity_check(document_store_small):
|
||||
VEC_1 = np.array([.1, .2, .3], dtype="float32")
|
||||
VEC_2 = np.array([.4, .5, .6], dtype="float32")
|
||||
@ -512,4 +521,4 @@ def test_cosine_sanity_check(document_store_small):
|
||||
query_results = document_store_small.query_by_embedding(query_emb=VEC_2, top_k=1, return_embedding=True)
|
||||
|
||||
# check if faiss returns the same cosine similarity. Manual testing with faiss yielded 0.9746318
|
||||
assert math.isclose(query_results[0].score, KNOWN_COSINE, abs_tol=0.00002)
|
||||
assert math.isclose(query_results[0].score, KNOWN_COSINE, abs_tol=0.00002)
|
@ -30,16 +30,16 @@ DOCUMENTS_XS = [
|
||||
|
||||
|
||||
@pytest.fixture(params=["weaviate"])
|
||||
def document_store_with_docs(request):
|
||||
document_store = get_document_store(request.param)
|
||||
def document_store_with_docs(request, tmp_path):
|
||||
document_store = get_document_store(request.param, tmp_path=tmp_path)
|
||||
document_store.write_documents(DOCUMENTS_XS)
|
||||
yield document_store
|
||||
document_store.delete_documents()
|
||||
|
||||
|
||||
@pytest.fixture(params=["weaviate"])
|
||||
def document_store(request):
|
||||
document_store = get_document_store(request.param)
|
||||
def document_store(request, tmp_path):
|
||||
document_store = get_document_store(request.param, tmp_path=tmp_path)
|
||||
yield document_store
|
||||
document_store.delete_documents()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user