diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index fad0b03a0..a2039eba5 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -77,7 +77,7 @@ jobs:
         run: docker run -d -p 19530:19530 -p 19121:19121 milvusdb/milvus:1.0.0-cpu-d030521-1ea92e
 
       - name: Run GraphDB
-        run: docker run -d -p 7200:7200 --name haystack_test_graphdb docker-registry.ontotext.com/graphdb-free:9.4.1-adoptopenjdk11
+        run: docker run -d -p 7200:7200 --name haystack_test_graphdb deepset/graphdb-free:9.4.1-adoptopenjdk11
 
       - name: Run Apache Tika
         run: docker run -d -p 9998:9998 -e "TIKA_CHILD_JAVA_OPTS=-JXms128m" -e "TIKA_CHILD_JAVA_OPTS=-JXmx128m" apache/tika:1.24.1
diff --git a/haystack/document_store/elasticsearch.py b/haystack/document_store/elasticsearch.py
index 83bdacad9..9ec567e20 100644
--- a/haystack/document_store/elasticsearch.py
+++ b/haystack/document_store/elasticsearch.py
@@ -490,6 +490,32 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         """
         return self.get_document_count(index=index)
 
+    def get_embedding_count(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> int:
+        """
+        Return the count of embeddings in the document store.
+        """
+
+        index = index or self.index
+
+        body: dict = {"query": {"bool": {"must": [{"exists": {"field": self.embedding_field}}]}}}
+        if filters:
+            filter_clause = []
+            for key, values in filters.items():
+                if type(values) != list:
+                    raise ValueError(
+                        f'Wrong filter format for key "{key}": Please provide a list of allowed values for each key. '
+                        'Example: {"name": ["some", "more"], "category": ["only_one"]} ')
+                filter_clause.append(
+                    {
+                        "terms": {key: values}
+                    }
+                )
+            body["query"]["bool"]["filter"] = filter_clause
+
+        result = self.client.count(index=index, body=body)
+        count = result["count"]
+        return count
+
     def get_all_documents(
         self,
         index: Optional[str] = None,
diff --git a/haystack/document_store/faiss.py b/haystack/document_store/faiss.py
index 011a0210f..26e2ef5fe 100644
--- a/haystack/document_store/faiss.py
+++ b/haystack/document_store/faiss.py
@@ -199,6 +199,14 @@ class FAISSDocumentStore(SQLDocumentStore):
 
         """
         index = index or self.index
+
+        if update_existing_embeddings is True:
+            if filters is None:
+                self.faiss_indexes[index].reset()
+                self.reset_vector_ids(index)
+            else:
+                raise Exception("update_existing_embeddings=True is not supported with filters.")
+
         if not self.faiss_indexes.get(index):
             raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
 
@@ -290,6 +298,15 @@ class FAISSDocumentStore(SQLDocumentStore):
                 doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
         return documents
 
+    def get_embedding_count(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> int:
+        """
+        Return the count of embeddings in the document store.
+        """
+        if filters:
+            raise Exception("filters are not supported for get_embedding_count in FAISSDocumentStore")
+        index = index or self.index
+        return self.faiss_indexes[index].ntotal
+
     def train_index(
         self,
         documents: Optional[Union[List[dict], List[Document]]],
diff --git a/haystack/document_store/memory.py b/haystack/document_store/memory.py
index 721b637d6..3a6c417d2 100644
--- a/haystack/document_store/memory.py
+++ b/haystack/document_store/memory.py
@@ -220,6 +220,14 @@ class InMemoryDocumentStore(BaseDocumentStore):
         documents = self.get_all_documents(index=index, filters=filters)
         return len(documents)
 
+    def get_embedding_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
+        """
+        Return the count of embeddings in the document store.
+        """
+        documents = self.get_all_documents(filters=filters, index=index)
+        embedding_count = sum(doc.embedding is not None for doc in documents)
+        return embedding_count
+
     def get_label_count(self, index: Optional[str] = None) -> int:
         """
         Return the number of labels in the document store
diff --git a/haystack/document_store/milvus.py b/haystack/document_store/milvus.py
index d89b986b0..549f73693 100644
--- a/haystack/document_store/milvus.py
+++ b/haystack/document_store/milvus.py
@@ -520,3 +520,15 @@ class MilvusDocumentStore(SQLDocumentStore):
             return list()
 
         return vectors
+
+    def get_embedding_count(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> int:
+        """
+        Return the count of embeddings in the document store.
+        """
+        if filters:
+            raise Exception("filters are not supported for get_embedding_count in MilvusDocumentStore.")
+        index = index or self.index
+        _, embedding_count = self.milvus_server.count_entities(index)
+        if embedding_count is None:
+            embedding_count = 0
+        return embedding_count
diff --git a/test/conftest.py b/test/conftest.py
index 9128b6386..42aac7e06 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -347,6 +347,10 @@ def get_document_store(document_store_type, embedding_field="embedding"):
             embedding_field=embedding_field,
             index="haystack_test",
         )
+        _, collections = document_store.milvus_server.list_collections()
+        for collection in collections:
+            if collection.startswith("haystack_test"):
+                document_store.milvus_server.drop_collection(collection)
         return document_store
     else:
         raise Exception(f"No document store fixture for '{document_store_type}'")
diff --git a/test/test_document_store.py b/test/test_document_store.py
index 9080ee82f..d42618927 100644
--- a/test/test_document_store.py
+++ b/test/test_document_store.py
@@ -5,6 +5,7 @@ from elasticsearch import Elasticsearch
 from conftest import get_document_store
 from haystack import Document, Label
 from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
+from haystack.document_store.faiss import FAISSDocumentStore
 
 
 @pytest.mark.elasticsearch
@@ -274,19 +275,34 @@ def test_update_embeddings(document_store, retriever):
     np.testing.assert_array_equal(embedding_before_update, embedding_after_update)
 
     # test updating with filters
-    document_store.update_embeddings(
-        retriever, index="haystack_test_1", batch_size=3, filters={"meta_field": ["value_0", "value_1"]}
-    )
-    doc_after_update = document_store.get_all_documents(index="haystack_test_1", filters={"meta_field": ["value_7"]})[0]
-    embedding_after_update = doc_after_update.embedding
-    np.testing.assert_array_equal(embedding_before_update, embedding_after_update)
+    if isinstance(document_store, FAISSDocumentStore):
+        with pytest.raises(Exception):
+            document_store.update_embeddings(
+                retriever, index="haystack_test_1", update_existing_embeddings=True, filters={"meta_field": ["value"]}
+            )
+    else:
+        document_store.update_embeddings(
+            retriever, index="haystack_test_1", batch_size=3, filters={"meta_field": ["value_0", "value_1"]}
+        )
+        doc_after_update = document_store.get_all_documents(index="haystack_test_1", filters={"meta_field": ["value_7"]})[0]
+        embedding_after_update = doc_after_update.embedding
+        np.testing.assert_array_equal(embedding_before_update, embedding_after_update)
 
     # test update all embeddings
     document_store.update_embeddings(retriever, index="haystack_test_1", batch_size=3, update_existing_embeddings=True)
+    assert document_store.get_embedding_count(index="haystack_test_1") == 11
     doc_after_update = document_store.get_all_documents(index="haystack_test_1", filters={"meta_field": ["value_7"]})[0]
     embedding_after_update = doc_after_update.embedding
     np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, embedding_before_update, embedding_after_update)
 
+    # test update embeddings for newly added docs
+    documents = []
+    for i in range(12, 15):
+        documents.append({"text": f"text_{i}", "id": str(i), "meta_field": f"value_{i}"})
+    document_store.write_documents(documents, index="haystack_test_1")
+    document_store.update_embeddings(retriever, index="haystack_test_1", batch_size=3, update_existing_embeddings=False)
+    assert document_store.get_embedding_count(index="haystack_test_1") == 14
+
 
 @pytest.mark.elasticsearch
 def test_delete_all_documents(document_store_with_docs):