Implement proper FK in MetaDocumentORM and MetaLabelORM to work on PostgreSQL (#1990)

* Properly fix MetaDocumentORM and MetaLabelORM with composite foreign key constraints

* update_document_meta() was not using index properly

* Exclude ES and Memory from the cosine_sanity_check test

* move ensure_ids_are_correct_uuids in conftest and move one test back to faiss & milvus suite

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Sara Zan 2022-01-14 13:48:58 +01:00 committed by GitHub
parent 3e4dbbb32c
commit e28bf618d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 208 additions and 101 deletions

View File

@ -380,7 +380,7 @@ Write annotation labels into document store.
#### update\_document\_meta
```python
| update_document_meta(id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None)
| update_document_meta(id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None, index: str = None)
```
Update the metadata dictionary of a document by specifying its string id
@ -952,7 +952,7 @@ class SQLDocumentStore(BaseDocumentStore)
#### \_\_init\_\_
```python
| __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite", check_same_thread: bool = False)
| __init__(url: str = "sqlite://", index: str = "document", label_index: str = "label", duplicate_documents: str = "overwrite", check_same_thread: bool = False, isolation_level: str = None)
```
An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.
@ -970,6 +970,7 @@ An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL bac
fail: an error is raised if the document ID of the document being added already
exists.
- `check_same_thread`: Set to False to mitigate multithreading issues in older SQLite versions (see https://docs.sqlalchemy.org/en/14/dialects/sqlite.html?highlight=check_same_thread#threading-pooling-behavior)
- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
<a name="sql.SQLDocumentStore.get_document_by_id"></a>
#### get\_document\_by\_id
@ -1094,7 +1095,7 @@ Set vector IDs for all documents as None
#### update\_document\_meta
```python
| update_document_meta(id: str, meta: Dict[str, str])
| update_document_meta(id: str, meta: Dict[str, str], index: str = None)
```
Update the metadata dictionary of a document by specifying its string id
@ -1202,7 +1203,7 @@ the vector embeddings are indexed in a FAISS Index.
#### \_\_init\_\_
```python
| __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, **kwargs, ,)
| __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, isolation_level: str = None, **kwargs, ,)
```
**Arguments**:
@ -1248,6 +1249,7 @@ the vector embeddings are indexed in a FAISS Index.
If specified no other params besides faiss_config_path must be specified.
- `faiss_config_path`: Stored FAISS initial configuration parameters.
Can be created via calling `save()`
- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
<a name="faiss.FAISSDocumentStore.write_documents"></a>
#### write\_documents
@ -1479,7 +1481,7 @@ Usage:
#### \_\_init\_\_
```python
| __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = None, embedding_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs, ,)
| __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = None, embedding_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', isolation_level: str = None, **kwargs, ,)
```
**Arguments**:
@ -1525,6 +1527,7 @@ Note that an overly large index_file_size value may cause failure to load a segm
overwrite: Update any existing documents with the same ID when adding documents.
fail: an error is raised if the document ID of the document being added already
exists.
- `isolation_level`: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
<a name="milvus.MilvusDocumentStore.write_documents"></a>
#### write\_documents
@ -1863,7 +1866,7 @@ None
#### update\_document\_meta
```python
| update_document_meta(id: str, meta: Dict[str, str])
| update_document_meta(id: str, meta: Dict[str, str], index: str = None)
```
Update the metadata dictionary of a document by specifying its string id.

View File

@ -551,12 +551,14 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
if labels_to_index:
bulk(self.client, labels_to_index, request_timeout=300, refresh=self.refresh_type, headers=headers)
def update_document_meta(self, id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None):
def update_document_meta(self, id: str, meta: Dict[str, str], headers: Optional[Dict[str, str]] = None, index: str = None):
"""
Update the metadata dictionary of a document by specifying its string id
"""
if not index:
index = self.index
body = {"doc": meta}
self.client.update(index=self.index, id=id, body=body, refresh=self.refresh_type, headers=headers)
self.client.update(index=index, id=id, body=body, refresh=self.refresh_type, headers=headers)
def get_document_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None,
only_documents_without_embedding: bool = False, headers: Optional[Dict[str, str]] = None) -> int:

View File

@ -50,6 +50,7 @@ class FAISSDocumentStore(SQLDocumentStore):
duplicate_documents: str = 'overwrite',
faiss_index_path: Union[str, Path] = None,
faiss_config_path: Union[str, Path] = None,
isolation_level: str = None,
**kwargs,
):
"""
@ -94,6 +95,7 @@ class FAISSDocumentStore(SQLDocumentStore):
If specified no other params besides faiss_config_path must be specified.
:param faiss_config_path: Stored FAISS initial configuration parameters.
Can be created via calling `save()`
:param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
"""
# special case if we want to load an existing index from disk
# load init params from disk and run init again
@ -115,7 +117,8 @@ class FAISSDocumentStore(SQLDocumentStore):
index=index,
similarity=similarity,
embedding_field=embedding_field,
progress_bar=progress_bar
progress_bar=progress_bar,
isolation_level=isolation_level
)
if similarity in ("dot_product", "cosine"):
@ -155,7 +158,8 @@ class FAISSDocumentStore(SQLDocumentStore):
super().__init__(
url=sql_url,
index=index,
duplicate_documents=duplicate_documents
duplicate_documents=duplicate_documents,
isolation_level=isolation_level
)
self._validate_index_sync()

View File

@ -53,6 +53,7 @@ class MilvusDocumentStore(SQLDocumentStore):
embedding_field: str = "embedding",
progress_bar: bool = True,
duplicate_documents: str = 'overwrite',
isolation_level: str = None,
**kwargs,
):
"""
@ -97,6 +98,7 @@ class MilvusDocumentStore(SQLDocumentStore):
overwrite: Update any existing documents with the same ID when adding documents.
fail: an error is raised if the document ID of the document being added already
exists.
:param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
"""
# save init parameters to enable export of component config as YAML
self.set_config(
@ -104,6 +106,7 @@ class MilvusDocumentStore(SQLDocumentStore):
embedding_dim=embedding_dim, index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
search_param=search_param, duplicate_documents=duplicate_documents,
return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar,
isolation_level=isolation_level
)
self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool)
@ -139,7 +142,8 @@ class MilvusDocumentStore(SQLDocumentStore):
super().__init__(
url=sql_url,
index=index,
duplicate_documents=duplicate_documents
duplicate_documents=duplicate_documents,
isolation_level=isolation_level,
)
def __del__(self):

View File

@ -73,6 +73,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
custom_fields: Optional[List[Any]] = None,
progress_bar: bool = True,
duplicate_documents: str = 'overwrite',
isolation_level: str = None
):
"""
:param sql_url: SQL connection URL for storing document texts and metadata. It defaults to a local, file based SQLite DB. For large scale
@ -118,6 +119,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
overwrite: Update any existing documents with the same ID when adding documents.
fail: an error is raised if the document ID of the document being added already
exists.
:param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
"""
# save init parameters to enable export of component config as YAML
@ -127,6 +129,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
search_param=search_param, duplicate_documents=duplicate_documents, id_field=id_field,
return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar,
custom_fields=custom_fields,
isolation_level=isolation_level
)
logger.warning("Milvus2DocumentStore is in experimental state until Milvus 2.0 is released")
@ -173,7 +176,8 @@ class Milvus2DocumentStore(SQLDocumentStore):
super().__init__(
url=sql_url,
index=index,
duplicate_documents=duplicate_documents
duplicate_documents=duplicate_documents,
isolation_level=isolation_level,
)
def _create_collection_and_index_if_not_exist(

View File

@ -4,7 +4,7 @@ import logging
import itertools
import numpy as np
from uuid import uuid4
from sqlalchemy import and_, func, create_engine, Column, String, DateTime, ForeignKey, Boolean, Text, text, JSON
from sqlalchemy import and_, func, create_engine, Column, String, DateTime, ForeignKey, Boolean, Text, text, JSON, ForeignKeyConstraint
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy.sql import case, null
@ -33,9 +33,6 @@ class DocumentORM(ORMBase):
# primary key in combination with id to allow the same doc in different indices
index = Column(String(100), nullable=False, primary_key=True)
vector_id = Column(String(100), unique=True, nullable=True)
# labels = relationship("LabelORM", back_populates="document")
# speeds up queries for get_documents_by_vector_ids() by having a single query that returns joined metadata
meta = relationship("MetaDocumentORM", back_populates="documents", lazy="joined")
@ -45,36 +42,18 @@ class MetaDocumentORM(ORMBase):
name = Column(String(100), index=True)
value = Column(String(1000), index=True)
document_id = Column(
String(100),
ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"),
nullable=False,
index=True
)
documents = relationship("DocumentORM", back_populates="meta")
class MetaLabelORM(ORMBase):
__tablename__ = "meta_label"
name = Column(String(100), index=True)
value = Column(String(1000), index=True)
label_id = Column(
String(100),
ForeignKey("label.id", ondelete="CASCADE", onupdate="CASCADE"),
nullable=False,
index=True
)
labels = relationship("LabelORM", back_populates="meta")
document_id = Column(String(100), nullable=False, index=True)
document_index = Column(String(100), nullable=False, index=True)
__table_args__ = (ForeignKeyConstraint([document_id, document_index],
[DocumentORM.id, DocumentORM.index],
ondelete="CASCADE", onupdate="CASCADE"), {}) #type: ignore
class LabelORM(ORMBase):
__tablename__ = "label"
# document_id = Column(String(100), ForeignKey("document.id", ondelete="CASCADE", onupdate="CASCADE"), nullable=False)
index = Column(String(100), nullable=False, primary_key=True)
query = Column(Text, nullable=False)
answer = Column(JSON, nullable=True)
@ -86,7 +65,21 @@ class LabelORM(ORMBase):
pipeline_id = Column(String(500), nullable=True)
meta = relationship("MetaLabelORM", back_populates="labels", lazy="joined")
# document = relationship("DocumentORM", back_populates="labels")
class MetaLabelORM(ORMBase):
__tablename__ = "meta_label"
name = Column(String(100), index=True)
value = Column(String(1000), index=True)
labels = relationship("LabelORM", back_populates="meta")
label_id = Column(String(100), nullable=False, index=True)
label_index = Column(String(100), nullable=False, index=True)
__table_args__ = (ForeignKeyConstraint([label_id, label_index],
[LabelORM.id, LabelORM.index],
ondelete="CASCADE", onupdate="CASCADE"), {}) #type: ignore
class SQLDocumentStore(BaseDocumentStore):
@ -96,7 +89,8 @@ class SQLDocumentStore(BaseDocumentStore):
index: str = "document",
label_index: str = "label",
duplicate_documents: str = "overwrite",
check_same_thread: bool = False
check_same_thread: bool = False,
isolation_level: str = None
):
"""
An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.
@ -112,18 +106,21 @@ class SQLDocumentStore(BaseDocumentStore):
fail: an error is raised if the document ID of the document being added already
exists.
:param check_same_thread: Set to False to mitigate multithreading issues in older SQLite versions (see https://docs.sqlalchemy.org/en/14/dialects/sqlite.html?highlight=check_same_thread#threading-pooling-behavior)
:param isolation_level: see SQLAlchemy's `isolation_level` parameter for `create_engine()` (https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.isolation_level)
"""
# save init parameters to enable export of component config as YAML
self.set_config(
url=url, index=index, label_index=label_index, duplicate_documents=duplicate_documents, check_same_thread=check_same_thread
)
create_engine_params = {}
if isolation_level:
create_engine_params["isolation_level"] = isolation_level
if "sqlite" in url:
engine = create_engine(url, connect_args={'check_same_thread': check_same_thread})
engine = create_engine(url, connect_args={'check_same_thread': check_same_thread}, **create_engine_params)
else:
engine = create_engine(url)
ORMBase.metadata.create_all(engine)
engine = create_engine(url, **create_engine_params)
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
self.session = Session()
self.index: str = index
@ -461,12 +458,14 @@ class SQLDocumentStore(BaseDocumentStore):
self.session.query(DocumentORM).filter_by(index=index).update({DocumentORM.vector_id: null()})
self.session.commit()
def update_document_meta(self, id: str, meta: Dict[str, str]):
def update_document_meta(self, id: str, meta: Dict[str, str], index: str = None):
"""
Update the metadata dictionary of a document by specifying its string id
"""
self.session.query(MetaDocumentORM).filter_by(document_id=id).delete()
meta_orms = [MetaDocumentORM(name=key, value=value, document_id=id) for key, value in meta.items()]
if not index:
index = self.index
self.session.query(MetaDocumentORM).filter_by(document_id=id, document_index=index).delete()
meta_orms = [MetaDocumentORM(name=key, value=value, document_id=id, document_index=index) for key, value in meta.items()]
for m in meta_orms:
self.session.add(m)
self.session.commit()

View File

@ -486,11 +486,13 @@ class WeaviateDocumentStore(BaseDocumentStore):
progress_bar.update(batch_size)
progress_bar.close()
def update_document_meta(self, id: str, meta: Dict[str, str]):
def update_document_meta(self, id: str, meta: Dict[str, str], index: str = None):
"""
Update the metadata dictionary of a document by specifying its string id.
"""
self.weaviate_client.data_object.update(meta, class_name=self.index, uuid=id)
if not index:
index = self.index
self.weaviate_client.data_object.update(meta, class_name=index, uuid=id)
def get_embedding_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
"""

View File

@ -3,6 +3,9 @@ import time
from subprocess import run
from sys import platform
import gc
import uuid
import logging
from sqlalchemy import create_engine, text
import numpy as np
import psutil
@ -40,6 +43,11 @@ from haystack.nodes.translator import TransformersTranslator
from haystack.nodes.question_generator import QuestionGenerator
# To manually run the tests with default PostgreSQL instead of SQLite, switch the lines below
SQL_TYPE = "sqlite"
# SQL_TYPE = "postgres"
def pytest_addoption(parser):
parser.addoption("--document_store_type", action="store", default="elasticsearch, faiss, memory, milvus, weaviate")
@ -477,58 +485,112 @@ def get_retriever(retriever_type, document_store):
return retriever
def ensure_ids_are_correct_uuids(docs:list,document_store:object)->None:
# Weaviate currently only supports UUIDs
if type(document_store)==WeaviateDocumentStore:
for d in docs:
d["id"] = str(uuid.uuid4())
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate"])
def document_store_with_docs(request, test_docs_xs):
def document_store_with_docs(request, test_docs_xs, tmp_path):
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
document_store = get_document_store(request.param, embedding_dim.args[0])
document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], tmp_path=tmp_path)
document_store.write_documents(test_docs_xs)
yield document_store
document_store.delete_documents()
@pytest.fixture
def document_store(request):
def document_store(request, tmp_path):
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
document_store = get_document_store(request.param, embedding_dim.args[0])
document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], tmp_path=tmp_path)
yield document_store
document_store.delete_documents()
@pytest.fixture(params=["memory", "faiss", "milvus", "elasticsearch"])
def document_store_dot_product(request):
def document_store_dot_product(request, tmp_path):
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path)
yield document_store
document_store.delete_documents()
@pytest.fixture(params=["memory", "faiss", "milvus", "elasticsearch"])
def document_store_dot_product_with_docs(request, test_docs_xs):
def document_store_dot_product_with_docs(request, test_docs_xs, tmp_path):
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(768))
document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path)
document_store.write_documents(test_docs_xs)
yield document_store
document_store.delete_documents()
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus"])
def document_store_dot_product_small(request):
def document_store_dot_product_small(request, tmp_path):
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(3))
document_store = get_document_store(request.param, embedding_dim.args[0], similarity="dot_product")
document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="dot_product", tmp_path=tmp_path)
yield document_store
document_store.delete_documents()
@pytest.fixture(params=["elasticsearch", "faiss", "memory", "milvus", "weaviate"])
def document_store_small(request):
def document_store_small(request, tmp_path):
embedding_dim = request.node.get_closest_marker("embedding_dim", pytest.mark.embedding_dim(3))
document_store = get_document_store(request.param, embedding_dim.args[0], similarity="cosine")
document_store = get_document_store(document_store_type=request.param, embedding_dim=embedding_dim.args[0], similarity="cosine", tmp_path=tmp_path)
yield document_store
document_store.delete_documents()
def get_document_store(document_store_type, embedding_dim=768, embedding_field="embedding", index="haystack_test", similarity:str="cosine"): # cosine is default similarity as dot product is not supported by Weaviate
@pytest.fixture(scope="function", autouse=True)
def postgres_fixture():
if SQL_TYPE == "postgres":
setup_postgres()
yield
teardown_postgres()
else:
yield
@pytest.fixture
def sql_url(tmp_path):
return get_sql_url(tmp_path)
def get_sql_url(tmp_path):
if SQL_TYPE == "postgres":
return "postgresql://postgres:postgres@127.0.0.1/postgres"
else:
return f"sqlite:///{tmp_path}/haystack_test.db"
def setup_postgres():
# status = subprocess.run(["docker run --name postgres_test -d -e POSTGRES_HOST_AUTH_METHOD=trust -p 5432:5432 postgres"], shell=True)
# if status.returncode:
# logging.warning("Tried to start PostgreSQL through Docker but this failed. It is likely that there is already an existing instance running.")
# else:
# sleep(5)
engine = create_engine('postgresql://postgres:postgres@127.0.0.1/postgres', isolation_level='AUTOCOMMIT')
with engine.connect() as connection:
try:
connection.execute(text('DROP SCHEMA public CASCADE'))
except Exception as e:
logging.error(e)
connection.execute(text('CREATE SCHEMA public;'))
connection.execute(text('SET SESSION idle_in_transaction_session_timeout = "1s";'))
def teardown_postgres():
engine = create_engine('postgresql://postgres:postgres@127.0.0.1/postgres', isolation_level='AUTOCOMMIT')
with engine.connect() as connection:
connection.execute(text('DROP SCHEMA public CASCADE'))
connection.close()
def get_document_store(document_store_type, tmp_path, embedding_dim=768, embedding_field="embedding", index="haystack_test", similarity:str="cosine"): # cosine is default similarity as dot product is not supported by Weaviate
if document_store_type == "sql":
document_store = SQLDocumentStore(url="sqlite://", index=index)
document_store = SQLDocumentStore(url=get_sql_url(tmp_path), index=index, isolation_level="AUTOCOMMIT")
elif document_store_type == "memory":
document_store = InMemoryDocumentStore(
return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, index=index, similarity=similarity
)
return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, index=index, similarity=similarity)
elif document_store_type == "elasticsearch":
# make sure we start from a fresh index
client = Elasticsearch()
@ -536,28 +598,33 @@ def get_document_store(document_store_type, embedding_dim=768, embedding_field="
document_store = ElasticsearchDocumentStore(
index=index, return_embedding=True, embedding_dim=embedding_dim, embedding_field=embedding_field, similarity=similarity
)
elif document_store_type == "faiss":
document_store = FAISSDocumentStore(
embedding_dim=embedding_dim,
sql_url="sqlite://",
sql_url=get_sql_url(tmp_path),
return_embedding=True,
embedding_field=embedding_field,
index=index,
similarity=similarity
similarity=similarity,
isolation_level="AUTOCOMMIT"
)
elif document_store_type == "milvus":
document_store = MilvusDocumentStore(
embedding_dim=embedding_dim,
sql_url="sqlite://",
sql_url=get_sql_url(tmp_path),
return_embedding=True,
embedding_field=embedding_field,
index=index,
similarity=similarity
similarity=similarity,
isolation_level="AUTOCOMMIT"
)
_, collections = document_store.milvus_server.list_collections()
for collection in collections:
if collection.startswith(index):
document_store.milvus_server.drop_collection(collection)
elif document_store_type == "weaviate":
document_store = WeaviateDocumentStore(
weaviate_url="http://localhost:8080",

View File

@ -1,4 +1,6 @@
from unittest import mock
import uuid
import math
import numpy as np
import pandas as pd
import pytest
@ -7,7 +9,7 @@ from elasticsearch import Elasticsearch
from elasticsearch.exceptions import RequestError
from conftest import get_document_store
from conftest import get_document_store, ensure_ids_are_correct_uuids
from haystack.document_stores import WeaviateDocumentStore
from haystack.document_stores.base import BaseDocumentStore
from haystack.errors import DuplicateDocumentError
@ -17,6 +19,18 @@ from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.nodes import EmbeddingRetriever
from haystack.pipelines import DocumentSearchPipeline
DOCUMENTS = [
{"meta": {"name": "name_1", "year": "2020", "month": "01"}, "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
{"meta": {"name": "name_2", "year": "2020", "month": "02"}, "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
{"meta": {"name": "name_3", "year": "2020", "month": "03"}, "content": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
{"meta": {"name": "name_4", "year": "2021", "month": "01"}, "content": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
{"meta": {"name": "name_5", "year": "2021", "month": "02"}, "content": "text_5", "embedding": np.random.rand(768).astype(np.float32)},
{"meta": {"name": "name_6", "year": "2021", "month": "03"}, "content": "text_6", "embedding": np.random.rand(768).astype(np.float64)},
]
@pytest.mark.elasticsearch
def test_init_elastic_client():
# defaults
@ -148,8 +162,8 @@ def test_get_all_documents_with_correct_filters(document_store_with_docs):
assert {d.meta["meta_field"] for d in documents} == {"test1", "test3"}
def test_get_all_documents_with_correct_filters_legacy_sqlite(test_docs_xs):
document_store_with_docs = get_document_store("sql")
def test_get_all_documents_with_correct_filters_legacy_sqlite(test_docs_xs, tmp_path):
document_store_with_docs = get_document_store("sql", tmp_path)
document_store_with_docs.write_documents(test_docs_xs)
document_store_with_docs.use_windowed_query = False
@ -791,7 +805,7 @@ def test_multilabel_no_answer(document_store):
assert len(multi_labels[0].answers) == 1
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss"], indirect=True)
@pytest.mark.parametrize("document_store", ["elasticsearch", "faiss", "milvus", "weaviate"], indirect=True)
# Currently update_document_meta() is not implemented for Memory doc store
def test_update_meta(document_store):
documents = [
@ -818,10 +832,8 @@ def test_update_meta(document_store):
@pytest.mark.parametrize("document_store_type", ["elasticsearch", "memory"])
def test_custom_embedding_field(document_store_type):
document_store = get_document_store(
document_store_type=document_store_type, embedding_field="custom_embedding_field"
)
def test_custom_embedding_field(document_store_type, tmp_path):
document_store = get_document_store(document_store_type=document_store_type, tmp_path=tmp_path, embedding_field="custom_embedding_field")
doc_to_write = {"content": "test", "custom_embedding_field": np.random.rand(768).astype(np.float32)}
document_store.write_documents([doc_to_write])
documents = document_store.get_all_documents(return_embedding=True)
@ -993,4 +1005,5 @@ def test_custom_headers(document_store_with_docs: BaseDocumentStore):
args, kwargs = mock_client.search.call_args
assert "headers" in kwargs
assert kwargs["headers"] == custom_headers
assert len(documents) > 0
assert len(documents) > 0

View File

@ -13,6 +13,9 @@ from haystack.document_stores.weaviate import WeaviateDocumentStore
from haystack.pipelines import Pipeline
from haystack.nodes.retriever.dense import EmbeddingRetriever
from conftest import ensure_ids_are_correct_uuids
DOCUMENTS = [
{"meta": {"name": "name_1", "year": "2020", "month": "01"}, "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
{"meta": {"name": "name_2", "year": "2020", "month": "02"}, "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
@ -23,12 +26,14 @@ DOCUMENTS = [
]
@pytest.mark.skipif(sys.platform in ['win32', 'cygwin'], reason="Test with tmp_path not working on windows runner")
def test_faiss_index_save_and_load(tmp_path):
def test_faiss_index_save_and_load(tmp_path, sql_url):
document_store = FAISSDocumentStore(
sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
sql_url=sql_url,
index="haystack_test",
progress_bar=False # Just to check if the init parameters are kept
progress_bar=False, # Just to check if the init parameters are kept
isolation_level="AUTOCOMMIT"
)
document_store.write_documents(DOCUMENTS)
@ -74,11 +79,12 @@ def test_faiss_index_save_and_load(tmp_path):
@pytest.mark.skipif(sys.platform in ['win32', 'cygwin'], reason="Test with tmp_path not working on windows runner")
def test_faiss_index_save_and_load_custom_path(tmp_path):
def test_faiss_index_save_and_load_custom_path(tmp_path, sql_url):
document_store = FAISSDocumentStore(
sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
sql_url=sql_url,
index="haystack_test",
progress_bar=False # Just to check if the init parameters are kept
progress_bar=False, # Just to check if the init parameters are kept
isolation_level="AUTOCOMMIT"
)
document_store.write_documents(DOCUMENTS)
@ -128,13 +134,15 @@ def test_faiss_index_mutual_exclusive_args(tmp_path):
with pytest.raises(ValueError):
FAISSDocumentStore(
sql_url=f"sqlite:////{tmp_path/'haystack_test.db'}",
faiss_index_path=f"{tmp_path/'haystack_test'}"
faiss_index_path=f"{tmp_path/'haystack_test'}",
isolation_level="AUTOCOMMIT"
)
with pytest.raises(ValueError):
FAISSDocumentStore(
f"sqlite:////{tmp_path/'haystack_test.db'}",
faiss_index_path=f"{tmp_path/'haystack_test'}"
faiss_index_path=f"{tmp_path/'haystack_test'}",
isolation_level="AUTOCOMMIT"
)
@ -227,7 +235,9 @@ def test_update_with_empty_store(document_store, retriever):
@pytest.mark.parametrize("index_factory", ["Flat", "HNSW", "IVF1,Flat"])
def test_faiss_retrieving(index_factory, tmp_path):
document_store = FAISSDocumentStore(
sql_url=f"sqlite:////{tmp_path/'test_faiss_retrieving.db'}", faiss_index_factory_str=index_factory
sql_url=f"sqlite:////{tmp_path/'test_faiss_retrieving.db'}",
faiss_index_factory_str=index_factory,
isolation_level="AUTOCOMMIT"
)
document_store.delete_all_documents(index="document")
@ -396,7 +406,6 @@ def test_get_docs_with_many_filters(document_store, retriever):
assert "2020" == documents[0].meta["year"]
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
def test_pipeline(document_store, retriever):
@ -423,7 +432,9 @@ def test_faiss_passing_index_from_outside(tmp_path):
faiss_index.set_direct_map_type(faiss.DirectMap.Hashtable)
faiss_index.nprobe = 2
document_store = FAISSDocumentStore(
sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}", faiss_index=faiss_index, index=index
sql_url=f"sqlite:////{tmp_path/'haystack_test_faiss.db'}",
faiss_index=faiss_index, index=index,
isolation_level="AUTOCOMMIT"
)
document_store.delete_documents()
@ -437,12 +448,8 @@ def test_faiss_passing_index_from_outside(tmp_path):
for doc in documents_indexed:
assert 0 <= int(doc.meta["vector_id"]) <= 7
def ensure_ids_are_correct_uuids(docs:list,document_store:object)->None:
# Weaviate currently only supports UUIDs
if type(document_store)==WeaviateDocumentStore:
for d in docs:
d["id"] = str(uuid.uuid4())
@pytest.mark.parametrize("document_store", ["faiss", "milvus", "weaviate"], indirect=True)
def test_cosine_similarity(document_store):
# below we will write documents to the store and then query it to see if vectors were normalized
@ -487,6 +494,7 @@ def test_cosine_similarity(document_store):
assert not np.allclose(original_emb[0], doc.embedding, rtol=0.01)
@pytest.mark.parametrize("document_store_dot_product_small", ["faiss", "milvus"], indirect=True)
def test_normalize_embeddings_diff_shapes(document_store_dot_product_small):
VEC_1 = np.array([.1, .2, .3], dtype="float32")
document_store_dot_product_small.normalize_embedding(VEC_1)
@ -497,6 +505,7 @@ def test_normalize_embeddings_diff_shapes(document_store_dot_product_small):
assert np.linalg.norm(VEC_1) - 1 < 0.01
@pytest.mark.parametrize("document_store_small", ["faiss", "milvus", "weaviate"], indirect=True)
def test_cosine_sanity_check(document_store_small):
VEC_1 = np.array([.1, .2, .3], dtype="float32")
VEC_2 = np.array([.4, .5, .6], dtype="float32")
@ -512,4 +521,4 @@ def test_cosine_sanity_check(document_store_small):
query_results = document_store_small.query_by_embedding(query_emb=VEC_2, top_k=1, return_embedding=True)
# check if faiss returns the same cosine similarity. Manual testing with faiss yielded 0.9746318
assert math.isclose(query_results[0].score, KNOWN_COSINE, abs_tol=0.00002)
assert math.isclose(query_results[0].score, KNOWN_COSINE, abs_tol=0.00002)

View File

@ -30,16 +30,16 @@ DOCUMENTS_XS = [
@pytest.fixture(params=["weaviate"])
def document_store_with_docs(request):
document_store = get_document_store(request.param)
def document_store_with_docs(request, tmp_path):
document_store = get_document_store(request.param, tmp_path=tmp_path)
document_store.write_documents(DOCUMENTS_XS)
yield document_store
document_store.delete_documents()
@pytest.fixture(params=["weaviate"])
def document_store(request):
document_store = get_document_store(request.param)
def document_store(request, tmp_path):
document_store = get_document_store(request.param, tmp_path=tmp_path)
yield document_store
document_store.delete_documents()