Fix update_embeddings() for FAISSDocumentStore (#978)

oryx1729 2021-04-21 09:56:35 +02:00 committed by GitHub
parent 0051a34ff9
commit 8c1e411380
7 changed files with 90 additions and 7 deletions

View File

@@ -77,7 +77,7 @@ jobs:
        run: docker run -d -p 19530:19530 -p 19121:19121 milvusdb/milvus:1.0.0-cpu-d030521-1ea92e
      - name: Run GraphDB
-       run: docker run -d -p 7200:7200 --name haystack_test_graphdb docker-registry.ontotext.com/graphdb-free:9.4.1-adoptopenjdk11
+       run: docker run -d -p 7200:7200 --name haystack_test_graphdb deepset/graphdb-free:9.4.1-adoptopenjdk11
      - name: Run Apache Tika
        run: docker run -d -p 9998:9998 -e "TIKA_CHILD_JAVA_OPTS=-JXms128m" -e "TIKA_CHILD_JAVA_OPTS=-JXmx128m" apache/tika:1.24.1

View File

@@ -490,6 +490,32 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
        """
        return self.get_document_count(index=index)

    def get_embedding_count(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> int:
        """
        Return the count of embeddings in the document store.
        """
        index = index or self.index

        body: dict = {"query": {"bool": {"must": [{"exists": {"field": self.embedding_field}}]}}}
        if filters:
            filter_clause = []
            for key, values in filters.items():
                if type(values) != list:
                    raise ValueError(
                        f'Wrong filter format for key "{key}": Please provide a list of allowed values for each key. '
                        'Example: {"name": ["some", "more"], "category": ["only_one"]} ')
                filter_clause.append(
                    {
                        "terms": {key: values}
                    }
                )
            body["query"]["bool"]["filter"] = filter_clause

        result = self.client.count(index=index, body=body)
        count = result["count"]
        return count

    def get_all_documents(
        self,
        index: Optional[str] = None,
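
For orientation, a minimal usage sketch of the new counter (the store configuration and filter values below are illustrative, not taken from this commit): it issues an Elasticsearch count request whose query requires the embedding field to exist, optionally narrowed by the "terms" filter clauses built above.

# Illustrative sketch, not part of this commit: assumes a running Elasticsearch
# instance and a store created with the default embedding_field="embedding".
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore

document_store = ElasticsearchDocumentStore(index="haystack_test")

# count of documents that have a value in the embedding field
total = document_store.get_embedding_count()

# same count, restricted to documents whose "category" metadata matches the filter
filtered = document_store.get_embedding_count(filters={"category": ["only_one"]})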

View File

@@ -199,6 +199,14 @@ class FAISSDocumentStore(SQLDocumentStore):
        """
        index = index or self.index

        if update_existing_embeddings is True:
            if filters is None:
                self.faiss_indexes[index].reset()
                self.reset_vector_ids(index)
            else:
                raise Exception("update_existing_embeddings=True is not supported with filters.")

        if not self.faiss_indexes.get(index):
            raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
@@ -290,6 +298,15 @@ class FAISSDocumentStore(SQLDocumentStore):
                doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
        return documents

    def get_embedding_count(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> int:
        """
        Return the count of embeddings in the document store.
        """
        if filters:
            raise Exception("filters are not supported for get_embedding_count in FAISSDocumentStore")
        index = index or self.index
        return self.faiss_indexes[index].ntotal

    def train_index(
        self,
        documents: Optional[Union[List[dict], List[Document]]],
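
Taken together, the two hunks give FAISSDocumentStore the semantics exercised by the updated test further down. A rough caller-side sketch (the store and retriever setup is assumed, not shown in this commit):

# Illustrative sketch, not part of this commit: document_store is a FAISSDocumentStore
# with documents already written, retriever is any embedding retriever.
document_store.update_embeddings(retriever, update_existing_embeddings=True)
# -> resets the FAISS index and the stored vector ids, then re-embeds every document

document_store.update_embeddings(retriever, update_existing_embeddings=False)
# -> existing vectors are kept; only documents without an embedding are processed (see the test below)

document_store.get_embedding_count()
# -> self.faiss_indexes[index].ntotal; passing filters raises an Exception

document_store.update_embeddings(retriever, update_existing_embeddings=True, filters={"name": ["some"]})
# -> raises, because a filtered reset of the FAISS index is not supported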

View File

@@ -220,6 +220,14 @@ class InMemoryDocumentStore(BaseDocumentStore):
        documents = self.get_all_documents(index=index, filters=filters)
        return len(documents)

    def get_embedding_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
        """
        Return the count of embeddings in the document store.
        """
        documents = self.get_all_documents(filters=filters, index=index)
        embedding_count = sum(doc.embedding is not None for doc in documents)
        return embedding_count

    def get_label_count(self, index: Optional[str] = None) -> int:
        """
        Return the number of labels in the document store
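
The in-memory store counts embeddings in plain Python rather than with a backend query. A self-contained sketch of the same idiom (the documents below are made up for illustration):

# Illustrative sketch, not part of this commit.
from haystack import Document
from haystack.document_store.memory import InMemoryDocumentStore

document_store = InMemoryDocumentStore()
document_store.write_documents([
    Document(text="already embedded", embedding=[0.1, 0.2, 0.3]),
    Document(text="not embedded yet"),
])
assert document_store.get_embedding_count() == 1  # only one doc has a non-None embedding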

View File

@@ -520,3 +520,15 @@ class MilvusDocumentStore(SQLDocumentStore):
            return list()

        return vectors

    def get_embedding_count(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> int:
        """
        Return the count of embeddings in the document store.
        """
        if filters:
            raise Exception("filters are not supported for get_embedding_count in MilvusDocumentStore.")
        index = index or self.index
        _, embedding_count = self.milvus_server.count_entities(index)
        if embedding_count is None:
            embedding_count = 0
        return embedding_count
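
Milvus's count_entities() returns a (status, count) pair and the count can be None, which is why the result is normalised to 0 here. From the caller's side this looks roughly like the sketch below (not part of the commit; assumes a running Milvus server, and the constructor arguments are illustrative):

# Illustrative sketch, not part of this commit.
from haystack.document_store.milvus import MilvusDocumentStore

document_store = MilvusDocumentStore(sql_url="sqlite://", index="haystack_test")
document_store.get_embedding_count()                       # 0 while nothing has been indexed
document_store.get_embedding_count(filters={"a": ["b"]})   # raises Exception: filters unsupported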

View File

@@ -347,6 +347,10 @@ def get_document_store(document_store_type, embedding_field="embedding"):
            embedding_field=embedding_field,
            index="haystack_test",
        )
        _, collections = document_store.milvus_server.list_collections()
        for collection in collections:
            if collection.startswith("haystack_test"):
                document_store.milvus_server.drop_collection(collection)
        return document_store
    else:
        raise Exception(f"No document store fixture for '{document_store_type}'")

View File

@@ -5,6 +5,7 @@ from elasticsearch import Elasticsearch
from conftest import get_document_store
from haystack import Document, Label
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
from haystack.document_store.faiss import FAISSDocumentStore
@pytest.mark.elasticsearch
@@ -274,19 +275,34 @@ def test_update_embeddings(document_store, retriever):
    np.testing.assert_array_equal(embedding_before_update, embedding_after_update)

    # test updating with filters
-   document_store.update_embeddings(
-       retriever, index="haystack_test_1", batch_size=3, filters={"meta_field": ["value_0", "value_1"]}
-   )
-   doc_after_update = document_store.get_all_documents(index="haystack_test_1", filters={"meta_field": ["value_7"]})[0]
-   embedding_after_update = doc_after_update.embedding
-   np.testing.assert_array_equal(embedding_before_update, embedding_after_update)
+   if isinstance(document_store, FAISSDocumentStore):
+       with pytest.raises(Exception):
+           document_store.update_embeddings(
+               retriever, index="haystack_test_1", update_existing_embeddings=True, filters={"meta_field": ["value"]}
+           )
+   else:
+       document_store.update_embeddings(
+           retriever, index="haystack_test_1", batch_size=3, filters={"meta_field": ["value_0", "value_1"]}
+       )
+       doc_after_update = document_store.get_all_documents(index="haystack_test_1", filters={"meta_field": ["value_7"]})[0]
+       embedding_after_update = doc_after_update.embedding
+       np.testing.assert_array_equal(embedding_before_update, embedding_after_update)

    # test update all embeddings
    document_store.update_embeddings(retriever, index="haystack_test_1", batch_size=3, update_existing_embeddings=True)
    assert document_store.get_embedding_count(index="haystack_test_1") == 11
    doc_after_update = document_store.get_all_documents(index="haystack_test_1", filters={"meta_field": ["value_7"]})[0]
    embedding_after_update = doc_after_update.embedding
    np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, embedding_before_update, embedding_after_update)

    # test update embeddings for newly added docs
    documents = []
    for i in range(12, 15):
        documents.append({"text": f"text_{i}", "id": str(i), "meta_field": f"value_{i}"})
    document_store.write_documents(documents, index="haystack_test_1")
    document_store.update_embeddings(retriever, index="haystack_test_1", batch_size=3, update_existing_embeddings=False)
    assert document_store.get_embedding_count(index="haystack_test_1") == 14
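
For reference, the numbers in the two new assertions follow from the fixture data used by this test: "haystack_test_1" holds 11 documents when the full refresh runs, and the loop then writes 3 more (ids 12-14), which are the only ones embedded by the second, non-destructive pass.

# Not part of the diff: the count arithmetic behind the assertions above.
# full refresh (update_existing_embeddings=True)      -> all 11 docs embedded -> get_embedding_count() == 11
# write 3 new docs, update_existing_embeddings=False  -> only the new docs embedded -> 11 + 3 == 14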
@pytest.mark.elasticsearch
def test_delete_all_documents(document_store_with_docs):