mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-29 07:59:27 +00:00
Allow SQLDocumentStore to filter by many filters (#1776)
* Aliasing the join is not sufficient yet
* Update the filter query in some other functions of SQLDocumentStore - this functionality should be centralized
* Adding tests for get_all_documents, now failing
* Fix tests
* Fix typo spotted by mypy
This commit is contained in:
parent
c5540d05ed
commit
e39d015a59
@ -180,7 +180,7 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
filters: Optional[Dict[str, List[str]]] = None,
|
||||
return_embedding: Optional[bool] = None,
|
||||
) -> List[Document]:
|
||||
documents = list(self.get_all_documents_generator(index=index, filters=filters))
|
||||
documents = list(self.get_all_documents_generator(index=index, filters=filters, return_embedding=return_embedding))
|
||||
return documents
|
||||
|
||||
def get_all_documents_generator(
|
||||
@ -202,7 +202,6 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
:param return_embedding: Whether to return the document embeddings.
|
||||
:param batch_size: When working with large number of documents, batching can help reduce memory footprint.
|
||||
"""
|
||||
|
||||
if return_embedding is True:
|
||||
raise Exception("return_embeddings is not supported by SQLDocumentStore.")
|
||||
result = self._query(
|
||||
@ -241,13 +240,14 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
).filter_by(index=index)
|
||||
|
||||
if filters:
|
||||
documents_query = documents_query.join(MetaDocumentORM)
|
||||
for key, values in filters.items():
|
||||
documents_query = documents_query.filter(
|
||||
MetaDocumentORM.name == key,
|
||||
MetaDocumentORM.value.in_(values),
|
||||
DocumentORM.id == MetaDocumentORM.document_id
|
||||
)
|
||||
documents_query = documents_query. \
|
||||
join(MetaDocumentORM, aliased=True). \
|
||||
filter(
|
||||
MetaDocumentORM.name == key,
|
||||
MetaDocumentORM.value.in_(values),
|
||||
)
|
||||
|
||||
if only_documents_without_embedding:
|
||||
documents_query = documents_query.filter(DocumentORM.vector_id.is_(None))
|
||||
if vector_ids:
|
||||
@ -450,9 +450,13 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
query = self.session.query(DocumentORM).filter_by(index=index)
|
||||
|
||||
if filters:
|
||||
query = query.join(MetaDocumentORM)
|
||||
for key, values in filters.items():
|
||||
query = query.filter(MetaDocumentORM.name == key, MetaDocumentORM.value.in_(values))
|
||||
query = query. \
|
||||
join(MetaDocumentORM, aliased=True). \
|
||||
filter(
|
||||
MetaDocumentORM.name == key,
|
||||
MetaDocumentORM.value.in_(values),
|
||||
)
|
||||
|
||||
count = query.count()
|
||||
return count
|
||||
@ -544,11 +548,12 @@ class SQLDocumentStore(BaseDocumentStore):
|
||||
document_ids_to_delete = self.session.query(DocumentORM.id).filter(DocumentORM.index==index)
|
||||
if filters:
|
||||
for key, values in filters.items():
|
||||
document_ids_to_delete = document_ids_to_delete.filter(
|
||||
MetaDocumentORM.name == key,
|
||||
MetaDocumentORM.value.in_(values),
|
||||
DocumentORM.id == MetaDocumentORM.document_id
|
||||
)
|
||||
document_ids_to_delete = document_ids_to_delete. \
|
||||
join(MetaDocumentORM, aliased=True). \
|
||||
filter(
|
||||
MetaDocumentORM.name == key,
|
||||
MetaDocumentORM.value.in_(values),
|
||||
)
|
||||
if ids:
|
||||
document_ids_to_delete = document_ids_to_delete.filter(DocumentORM.id.in_(ids))
|
||||
|
||||
|
||||
@ -14,12 +14,12 @@ from haystack.pipelines import Pipeline
|
||||
from haystack.nodes.retriever.dense import EmbeddingRetriever
|
||||
|
||||
DOCUMENTS = [
|
||||
{"name": "name_1", "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
|
||||
{"name": "name_2", "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
|
||||
{"name": "name_3", "content": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
|
||||
{"name": "name_4", "content": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
|
||||
{"name": "name_5", "content": "text_5", "embedding": np.random.rand(768).astype(np.float32)},
|
||||
{"name": "name_6", "content": "text_6", "embedding": np.random.rand(768).astype(np.float64)},
|
||||
{"meta": {"name": "name_1", "year": "2020", "month": "01"}, "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)},
|
||||
{"meta": {"name": "name_2", "year": "2020", "month": "02"}, "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)},
|
||||
{"meta": {"name": "name_3", "year": "2020", "month": "03"}, "content": "text_3", "embedding": np.random.rand(768).astype(np.float64)},
|
||||
{"meta": {"name": "name_4", "year": "2021", "month": "01"}, "content": "text_4", "embedding": np.random.rand(768).astype(np.float32)},
|
||||
{"meta": {"name": "name_5", "year": "2021", "month": "02"}, "content": "text_5", "embedding": np.random.rand(768).astype(np.float32)},
|
||||
{"meta": {"name": "name_6", "year": "2021", "month": "03"}, "content": "text_6", "embedding": np.random.rand(768).astype(np.float64)},
|
||||
]
|
||||
|
||||
|
||||
@ -252,6 +252,38 @@ def test_delete_docs_with_filters(document_store, retriever):
|
||||
assert {doc.meta["name"] for doc in documents} == {"name_5", "name_6"}
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
|
||||
def test_delete_docs_with_filters(document_store, retriever):
|
||||
document_store.write_documents(DOCUMENTS)
|
||||
document_store.update_embeddings(retriever=retriever, batch_size=4)
|
||||
assert document_store.get_embedding_count() == 6
|
||||
|
||||
document_store.delete_documents(filters={"year": ["2020"]})
|
||||
|
||||
documents = document_store.get_all_documents()
|
||||
assert len(documents) == 3
|
||||
assert document_store.get_embedding_count() == 3
|
||||
assert all("2021" == doc.meta["year"] for doc in documents)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
|
||||
def test_delete_docs_with_many_filters(document_store, retriever):
|
||||
document_store.write_documents(DOCUMENTS)
|
||||
document_store.update_embeddings(retriever=retriever, batch_size=4)
|
||||
assert document_store.get_embedding_count() == 6
|
||||
|
||||
document_store.delete_documents(filters={"month": ["01"], "year": ["2020"]})
|
||||
|
||||
documents = document_store.get_all_documents()
|
||||
assert len(documents) == 5
|
||||
assert document_store.get_embedding_count() == 5
|
||||
assert "name_1" not in {doc.meta["name"] for doc in documents}
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
|
||||
@ -296,6 +328,50 @@ def test_delete_docs_by_id_with_filters(document_store, retriever):
|
||||
assert all(doc_id in all_ids_left for doc_id in ids_not_to_delete)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
|
||||
def test_get_docs_with_filters_one_value(document_store, retriever):
|
||||
document_store.write_documents(DOCUMENTS)
|
||||
document_store.update_embeddings(retriever=retriever, batch_size=4)
|
||||
assert document_store.get_embedding_count() == 6
|
||||
|
||||
documents = document_store.get_all_documents(filters={"year": ["2020"]})
|
||||
|
||||
assert len(documents) == 3
|
||||
assert all("2020" == doc.meta["year"] for doc in documents)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
|
||||
def test_get_docs_with_filters_many_values(document_store, retriever):
|
||||
document_store.write_documents(DOCUMENTS)
|
||||
document_store.update_embeddings(retriever=retriever, batch_size=4)
|
||||
assert document_store.get_embedding_count() == 6
|
||||
|
||||
documents = document_store.get_all_documents(filters={"name": ["name_5", "name_6"]})
|
||||
|
||||
assert len(documents) == 2
|
||||
assert {doc.meta["name"] for doc in documents} == {"name_5", "name_6"}
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("retriever", ["dpr"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
|
||||
def test_get_docs_with_many_filters(document_store, retriever):
|
||||
document_store.write_documents(DOCUMENTS)
|
||||
document_store.update_embeddings(retriever=retriever, batch_size=4)
|
||||
assert document_store.get_embedding_count() == 6
|
||||
|
||||
documents = document_store.get_all_documents(filters={"month": ["01"], "year": ["2020"]})
|
||||
|
||||
assert len(documents) == 1
|
||||
assert "name_1" == documents[0].meta["name"]
|
||||
assert "01" == documents[0].meta["month"]
|
||||
assert "2020" == documents[0].meta["year"]
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize("retriever", ["embedding"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user