Allow SQLDocumentStore to filter by many filters (#1776)

* Aliasing the join is not sufficient yet

* Update the filter query in some other functions of SQLDocumentStore - this functionality should be centralized

* Adding tests for get_all_documents, now failing

* Fix tests

* Fix typo spotted by mypy
This commit is contained in:
Sara Zan 2021-12-01 16:16:17 +01:00 committed by GitHub
parent c5540d05ed
commit e39d015a59
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 102 additions and 21 deletions

View File

@ -180,7 +180,7 @@ class SQLDocumentStore(BaseDocumentStore):
filters: Optional[Dict[str, List[str]]] = None,
return_embedding: Optional[bool] = None,
) -> List[Document]:
documents = list(self.get_all_documents_generator(index=index, filters=filters))
documents = list(self.get_all_documents_generator(index=index, filters=filters, return_embedding=return_embedding))
return documents
def get_all_documents_generator(
@ -202,7 +202,6 @@ class SQLDocumentStore(BaseDocumentStore):
:param return_embedding: Whether to return the document embeddings.
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
"""
if return_embedding is True:
raise Exception("return_embeddings is not supported by SQLDocumentStore.")
result = self._query(
@ -241,13 +240,14 @@ class SQLDocumentStore(BaseDocumentStore):
).filter_by(index=index)
if filters:
documents_query = documents_query.join(MetaDocumentORM)
for key, values in filters.items():
documents_query = documents_query.filter(
MetaDocumentORM.name == key,
MetaDocumentORM.value.in_(values),
DocumentORM.id == MetaDocumentORM.document_id
)
documents_query = documents_query. \
join(MetaDocumentORM, aliased=True). \
filter(
MetaDocumentORM.name == key,
MetaDocumentORM.value.in_(values),
)
if only_documents_without_embedding:
documents_query = documents_query.filter(DocumentORM.vector_id.is_(None))
if vector_ids:
@ -450,9 +450,13 @@ class SQLDocumentStore(BaseDocumentStore):
query = self.session.query(DocumentORM).filter_by(index=index)
if filters:
query = query.join(MetaDocumentORM)
for key, values in filters.items():
query = query.filter(MetaDocumentORM.name == key, MetaDocumentORM.value.in_(values))
query = query. \
join(MetaDocumentORM, aliased=True). \
filter(
MetaDocumentORM.name == key,
MetaDocumentORM.value.in_(values),
)
count = query.count()
return count
@ -544,11 +548,12 @@ class SQLDocumentStore(BaseDocumentStore):
document_ids_to_delete = self.session.query(DocumentORM.id).filter(DocumentORM.index==index)
if filters:
for key, values in filters.items():
document_ids_to_delete = document_ids_to_delete.filter(
MetaDocumentORM.name == key,
MetaDocumentORM.value.in_(values),
DocumentORM.id == MetaDocumentORM.document_id
)
document_ids_to_delete = document_ids_to_delete. \
join(MetaDocumentORM, aliased=True). \
filter(
MetaDocumentORM.name == key,
MetaDocumentORM.value.in_(values),
)
if ids:
document_ids_to_delete = document_ids_to_delete.filter(DocumentORM.id.in_(ids))

View File

@ -14,12 +14,12 @@ from haystack.pipelines import Pipeline
from haystack.nodes.retriever.dense import EmbeddingRetriever
DOCUMENTS = [
{"name": "name_1", "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
{"name": "name_2", "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
{"name": "name_3", "content": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
{"name": "name_4", "content": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
{"name": "name_5", "content": "text_5", "embedding": np.random.rand(768).astype(np.float32)},
{"name": "name_6", "content": "text_6", "embedding": np.random.rand(768).astype(np.float64)},
{"meta": {"name": "name_1", "year": "2020", "month": "01"}, "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
{"meta": {"name": "name_2", "year": "2020", "month": "02"}, "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
{"meta": {"name": "name_3", "year": "2020", "month": "03"}, "content": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
{"meta": {"name": "name_4", "year": "2021", "month": "01"}, "content": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
{"meta": {"name": "name_5", "year": "2021", "month": "02"}, "content": "text_5", "embedding": np.random.rand(768).astype(np.float32)},
{"meta": {"name": "name_6", "year": "2021", "month": "03"}, "content": "text_6", "embedding": np.random.rand(768).astype(np.float64)},
]
@ -252,6 +252,38 @@ def test_delete_docs_with_filters(document_store, retriever):
assert {doc.meta["name"] for doc in documents} == {"name_5", "name_6"}
@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
def test_delete_docs_with_filters(document_store, retriever):
document_store.write_documents(DOCUMENTS)
document_store.update_embeddings(retriever=retriever, batch_size=4)
assert document_store.get_embedding_count() == 6
document_store.delete_documents(filters={"year": ["2020"]})
documents = document_store.get_all_documents()
assert len(documents) == 3
assert document_store.get_embedding_count() == 3
assert all("2021" == doc.meta["year"] for doc in documents)
@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
def test_delete_docs_with_many_filters(document_store, retriever):
document_store.write_documents(DOCUMENTS)
document_store.update_embeddings(retriever=retriever, batch_size=4)
assert document_store.get_embedding_count() == 6
document_store.delete_documents(filters={"month": ["01"], "year": ["2020"]})
documents = document_store.get_all_documents()
assert len(documents) == 5
assert document_store.get_embedding_count() == 5
assert "name_1" not in {doc.meta["name"] for doc in documents}
@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
@ -296,6 +328,50 @@ def test_delete_docs_by_id_with_filters(document_store, retriever):
assert all(doc_id in all_ids_left for doc_id in ids_not_to_delete)
@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
def test_get_docs_with_filters_one_value(document_store, retriever):
document_store.write_documents(DOCUMENTS)
document_store.update_embeddings(retriever=retriever, batch_size=4)
assert document_store.get_embedding_count() == 6
documents = document_store.get_all_documents(filters={"year": ["2020"]})
assert len(documents) == 3
assert all("2020" == doc.meta["year"] for doc in documents)
@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
def test_get_docs_with_filters_many_values(document_store, retriever):
document_store.write_documents(DOCUMENTS)
document_store.update_embeddings(retriever=retriever, batch_size=4)
assert document_store.get_embedding_count() == 6
documents = document_store.get_all_documents(filters={"name": ["name_5", "name_6"]})
assert len(documents) == 2
assert {doc.meta["name"] for doc in documents} == {"name_5", "name_6"}
@pytest.mark.slow
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
def test_get_docs_with_many_filters(document_store, retriever):
document_store.write_documents(DOCUMENTS)
document_store.update_embeddings(retriever=retriever, batch_size=4)
assert document_store.get_embedding_count() == 6
documents = document_store.get_all_documents(filters={"month": ["01"], "year": ["2020"]})
assert len(documents) == 1
assert "name_1" == documents[0].meta["name"]
assert "01" == documents[0].meta["month"]
assert "2020" == documents[0].meta["year"]
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)