diff --git a/haystack/document_stores/sql.py b/haystack/document_stores/sql.py index 9b95490db..014c31d13 100644 --- a/haystack/document_stores/sql.py +++ b/haystack/document_stores/sql.py @@ -180,7 +180,7 @@ class SQLDocumentStore(BaseDocumentStore): filters: Optional[Dict[str, List[str]]] = None, return_embedding: Optional[bool] = None, ) -> List[Document]: - documents = list(self.get_all_documents_generator(index=index, filters=filters)) + documents = list(self.get_all_documents_generator(index=index, filters=filters, return_embedding=return_embedding)) return documents def get_all_documents_generator( @@ -202,7 +202,6 @@ class SQLDocumentStore(BaseDocumentStore): :param return_embedding: Whether to return the document embeddings. :param batch_size: When working with large number of documents, batching can help reduce memory footprint. """ - if return_embedding is True: raise Exception("return_embeddings is not supported by SQLDocumentStore.") result = self._query( @@ -241,13 +240,14 @@ class SQLDocumentStore(BaseDocumentStore): ).filter_by(index=index) if filters: - documents_query = documents_query.join(MetaDocumentORM) for key, values in filters.items(): - documents_query = documents_query.filter( - MetaDocumentORM.name == key, - MetaDocumentORM.value.in_(values), - DocumentORM.id == MetaDocumentORM.document_id - ) + documents_query = documents_query. \ + join(MetaDocumentORM, aliased=True). \ + filter( + MetaDocumentORM.name == key, + MetaDocumentORM.value.in_(values), + ) + if only_documents_without_embedding: documents_query = documents_query.filter(DocumentORM.vector_id.is_(None)) if vector_ids: @@ -450,9 +450,13 @@ class SQLDocumentStore(BaseDocumentStore): query = self.session.query(DocumentORM).filter_by(index=index) if filters: - query = query.join(MetaDocumentORM) for key, values in filters.items(): - query = query.filter(MetaDocumentORM.name == key, MetaDocumentORM.value.in_(values)) + query = query. \ + join(MetaDocumentORM, aliased=True). \ + filter( + MetaDocumentORM.name == key, + MetaDocumentORM.value.in_(values), + ) count = query.count() return count @@ -544,11 +548,12 @@ class SQLDocumentStore(BaseDocumentStore): document_ids_to_delete = self.session.query(DocumentORM.id).filter(DocumentORM.index==index) if filters: for key, values in filters.items(): - document_ids_to_delete = document_ids_to_delete.filter( - MetaDocumentORM.name == key, - MetaDocumentORM.value.in_(values), - DocumentORM.id == MetaDocumentORM.document_id - ) + document_ids_to_delete = document_ids_to_delete. \ + join(MetaDocumentORM, aliased=True). \ + filter( + MetaDocumentORM.name == key, + MetaDocumentORM.value.in_(values), + ) if ids: document_ids_to_delete = document_ids_to_delete.filter(DocumentORM.id.in_(ids)) diff --git a/test/test_faiss_and_milvus.py b/test/test_faiss_and_milvus.py index fabd94cc7..81bcca809 100644 --- a/test/test_faiss_and_milvus.py +++ b/test/test_faiss_and_milvus.py @@ -14,12 +14,12 @@ from haystack.pipelines import Pipeline from haystack.nodes.retriever.dense import EmbeddingRetriever DOCUMENTS = [ - {"name": "name_1", "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)}, - {"name": "name_2", "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)}, - {"name": "name_3", "content": "text_3", "embedding": np.random.rand(768).astype(np.float64)}, - {"name": "name_4", "content": "text_4", "embedding": np.random.rand(768).astype(np.float32)}, - {"name": "name_5", "content": "text_5", "embedding": np.random.rand(768).astype(np.float32)}, - {"name": "name_6", "content": "text_6", "embedding": np.random.rand(768).astype(np.float64)}, + {"meta": {"name": "name_1", "year": "2020", "month": "01"}, "content": "text_1", "embedding": np.random.rand(768).astype(np.float32)}, + {"meta": {"name": "name_2", "year": "2020", "month": "02"}, "content": "text_2", "embedding": np.random.rand(768).astype(np.float32)}, + {"meta": {"name": "name_3", "year": "2020", "month": "03"}, "content": "text_3", "embedding": np.random.rand(768).astype(np.float64)}, + {"meta": {"name": "name_4", "year": "2021", "month": "01"}, "content": "text_4", "embedding": np.random.rand(768).astype(np.float32)}, + {"meta": {"name": "name_5", "year": "2021", "month": "02"}, "content": "text_5", "embedding": np.random.rand(768).astype(np.float32)}, + {"meta": {"name": "name_6", "year": "2021", "month": "03"}, "content": "text_6", "embedding": np.random.rand(768).astype(np.float64)}, ] @@ -252,6 +252,38 @@ def test_delete_docs_with_filters(document_store, retriever): assert {doc.meta["name"] for doc in documents} == {"name_5", "name_6"} +@pytest.mark.slow +@pytest.mark.parametrize("retriever", ["dpr"], indirect=True) +@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True) +def test_delete_docs_with_filters(document_store, retriever): + document_store.write_documents(DOCUMENTS) + document_store.update_embeddings(retriever=retriever, batch_size=4) + assert document_store.get_embedding_count() == 6 + + document_store.delete_documents(filters={"year": ["2020"]}) + + documents = document_store.get_all_documents() + assert len(documents) == 3 + assert document_store.get_embedding_count() == 3 + assert all("2021" == doc.meta["year"] for doc in documents) + + +@pytest.mark.slow +@pytest.mark.parametrize("retriever", ["dpr"], indirect=True) +@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True) +def test_delete_docs_with_many_filters(document_store, retriever): + document_store.write_documents(DOCUMENTS) + document_store.update_embeddings(retriever=retriever, batch_size=4) + assert document_store.get_embedding_count() == 6 + + document_store.delete_documents(filters={"month": ["01"], "year": ["2020"]}) + + documents = document_store.get_all_documents() + assert len(documents) == 5 + assert document_store.get_embedding_count() == 5 + assert "name_1" not in {doc.meta["name"] for doc in documents} + + @pytest.mark.slow @pytest.mark.parametrize("retriever", ["dpr"], indirect=True) @pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True) @@ -296,6 +328,50 @@ def test_delete_docs_by_id_with_filters(document_store, retriever): assert all(doc_id in all_ids_left for doc_id in ids_not_to_delete) +@pytest.mark.slow +@pytest.mark.parametrize("retriever", ["dpr"], indirect=True) +@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True) +def test_get_docs_with_filters_one_value(document_store, retriever): + document_store.write_documents(DOCUMENTS) + document_store.update_embeddings(retriever=retriever, batch_size=4) + assert document_store.get_embedding_count() == 6 + + documents = document_store.get_all_documents(filters={"year": ["2020"]}) + + assert len(documents) == 3 + assert all("2020" == doc.meta["year"] for doc in documents) + + +@pytest.mark.slow +@pytest.mark.parametrize("retriever", ["dpr"], indirect=True) +@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True) +def test_get_docs_with_filters_many_values(document_store, retriever): + document_store.write_documents(DOCUMENTS) + document_store.update_embeddings(retriever=retriever, batch_size=4) + assert document_store.get_embedding_count() == 6 + + documents = document_store.get_all_documents(filters={"name": ["name_5", "name_6"]}) + + assert len(documents) == 2 + assert {doc.meta["name"] for doc in documents} == {"name_5", "name_6"} + + +@pytest.mark.slow +@pytest.mark.parametrize("retriever", ["dpr"], indirect=True) +@pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True) +def test_get_docs_with_many_filters(document_store, retriever): + document_store.write_documents(DOCUMENTS) + document_store.update_embeddings(retriever=retriever, batch_size=4) + assert document_store.get_embedding_count() == 6 + + documents = document_store.get_all_documents(filters={"month": ["01"], "year": ["2020"]}) + + assert len(documents) == 1 + assert "name_1" == documents[0].meta["name"] + assert "01" == documents[0].meta["month"] + assert "2020" == documents[0].meta["year"] + + @pytest.mark.parametrize("retriever", ["embedding"], indirect=True) @pytest.mark.parametrize("document_store", ["faiss", "milvus"], indirect=True)