mirror of https://github.com/deepset-ai/haystack.git
synced 2025-10-14 17:38:24 +00:00

Fix update_embeddings() for FAISSDocumentStore (#978)

parent 0051a34ff9
commit 8c1e411380
.github/workflows/ci.yml (vendored, 2 lines changed)
@@ -77,7 +77,7 @@ jobs:
         run: docker run -d -p 19530:19530 -p 19121:19121 milvusdb/milvus:1.0.0-cpu-d030521-1ea92e
 
       - name: Run GraphDB
-        run: docker run -d -p 7200:7200 --name haystack_test_graphdb docker-registry.ontotext.com/graphdb-free:9.4.1-adoptopenjdk11
+        run: docker run -d -p 7200:7200 --name haystack_test_graphdb deepset/graphdb-free:9.4.1-adoptopenjdk11
 
       - name: Run Apache Tika
         run: docker run -d -p 9998:9998 -e "TIKA_CHILD_JAVA_OPTS=-JXms128m" -e "TIKA_CHILD_JAVA_OPTS=-JXmx128m" apache/tika:1.24.1
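The only functional change in this hunk is the GraphDB image source, which moves from docker-registry.ontotext.com to the deepset mirror on Docker Hub; the Milvus and Tika services are untouched context. For orientation only, here is a minimal sketch (not part of the commit) of how a local run might wait for these service containers before starting the tests; the ports come from the docker commands above, everything else is illustrative.

# Hedged helper: block until the CI service containers above accept TCP
# connections. Ports are taken from the docker run commands; host and timeout
# are assumptions for local use.
import socket
import time

def wait_for(port: int, host: str = "localhost", timeout: float = 60.0) -> None:
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            with socket.create_connection((host, port), timeout=2):
                return
        except OSError:
            time.sleep(1)
    raise TimeoutError(f"{host}:{port} did not come up within {timeout}s")

for port in (19530, 7200, 9998):  # Milvus, GraphDB, Apache Tika
    wait_for(port)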
@@ -490,6 +490,32 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         """
         return self.get_document_count(index=index)
 
+    def get_embedding_count(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> int:
+        """
+        Return the count of embeddings in the document store.
+        """
+
+        index = index or self.index
+
+        body: dict = {"query": {"bool": {"must": [{"exists": {"field": self.embedding_field}}]}}}
+        if filters:
+            filter_clause = []
+            for key, values in filters.items():
+                if type(values) != list:
+                    raise ValueError(
+                        f'Wrong filter format for key "{key}": Please provide a list of allowed values for each key. '
+                        'Example: {"name": ["some", "more"], "category": ["only_one"]} ')
+                filter_clause.append(
+                    {
+                        "terms": {key: values}
+                    }
+                )
+            body["query"]["bool"]["filter"] = filter_clause
+
+        result = self.client.count(index=index, body=body)
+        count = result["count"]
+        return count
+
     def get_all_documents(
         self,
         index: Optional[str] = None,
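The new ElasticsearchDocumentStore.get_embedding_count() issues a count query with an exists clause on the embedding field, optionally narrowed by terms filters. Below is a minimal sketch of the same query run directly through elasticsearch-py; the host, index name ("haystack_test"), and field name ("embedding") are assumptions, not taken from the commit.

# Sketch of the count query built by get_embedding_count(), issued directly
# with elasticsearch-py. Host, index name and field name are assumptions.
from elasticsearch import Elasticsearch

client = Elasticsearch(hosts=["http://localhost:9200"])

body = {"query": {"bool": {"must": [{"exists": {"field": "embedding"}}]}}}
# Optional metadata filter, same shape as the filters argument of the method above
body["query"]["bool"]["filter"] = [{"terms": {"meta_field": ["value_0", "value_1"]}}]

result = client.count(index="haystack_test", body=body)
print(result["count"])  # number of documents that currently have an embedding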
@@ -199,6 +199,14 @@ class FAISSDocumentStore(SQLDocumentStore):
         """
 
         index = index or self.index
+
+        if update_existing_embeddings is True:
+            if filters is None:
+                self.faiss_indexes[index].reset()
+                self.reset_vector_ids(index)
+            else:
+                raise Exception("update_existing_embeddings=True is not supported with filters.")
+
         if not self.faiss_indexes.get(index):
             raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
 
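With this change, update_existing_embeddings=True first clears the FAISS index (faiss_indexes[index].reset()) and the vector_ids stored in SQL before re-embedding, and combining it with filters is rejected explicitly. A minimal usage sketch under assumptions not in the commit (the EmbeddingRetriever and its model name are illustrative; any dense retriever works the same way):

# Hedged sketch of the new behaviour; retriever choice and model name are assumptions.
from haystack import Document
from haystack.document_store.faiss import FAISSDocumentStore
from haystack.retriever.dense import EmbeddingRetriever

document_store = FAISSDocumentStore(sql_url="sqlite://")
retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/sentence_bert")

document_store.write_documents([Document(text="hello world", meta={"meta_field": "value"})])

# Full refresh: the FAISS index is reset and all vector_ids are cleared before re-embedding
document_store.update_embeddings(retriever, update_existing_embeddings=True)

# update_existing_embeddings=True combined with filters is now rejected
try:
    document_store.update_embeddings(
        retriever, update_existing_embeddings=True, filters={"meta_field": ["value"]}
    )
except Exception as exc:
    print(exc)  # update_existing_embeddings=True is not supported with filters.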
@@ -290,6 +298,15 @@ class FAISSDocumentStore(SQLDocumentStore):
                 doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
         return documents
 
+    def get_embedding_count(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> int:
+        """
+        Return the count of embeddings in the document store.
+        """
+        if filters:
+            raise Exception("filters are not supported for get_embedding_count in FAISSDocumentStore")
+        index = index or self.index
+        return self.faiss_indexes[index].ntotal
+
     def train_index(
         self,
         documents: Optional[Union[List[dict], List[Document]]],
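For FAISS the embedding count is simply the ntotal attribute of the underlying index, which is also why filters cannot be supported here. A small standalone illustration with the faiss library (the flat inner-product index mirrors the store's default "Flat" factory string; dimension and data are made up):

# Standalone illustration of ntotal, the value get_embedding_count() returns.
import numpy as np
import faiss

dim = 768
index = faiss.IndexFlatIP(dim)                       # flat inner-product index
index.add(np.random.rand(5, dim).astype("float32"))  # add 5 vectors
print(index.ntotal)                                  # 5
index.reset()                                        # what update_existing_embeddings=True now triggers
print(index.ntotal)                                  # 0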
@@ -220,6 +220,14 @@ class InMemoryDocumentStore(BaseDocumentStore):
         documents = self.get_all_documents(index=index, filters=filters)
         return len(documents)
 
+    def get_embedding_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
+        """
+        Return the count of embeddings in the document store.
+        """
+        documents = self.get_all_documents(filters=filters, index=index)
+        embedding_count = sum(doc.embedding is not None for doc in documents)
+        return embedding_count
+
     def get_label_count(self, index: Optional[str] = None) -> int:
         """
        Return the number of labels in the document store
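The in-memory variant has no separate vector index, so it simply counts documents whose embedding attribute is set. A quick usage sketch (the module path and the 768-dimensional dummy embedding are assumptions based on this version's layout):

# Hedged usage sketch for InMemoryDocumentStore.get_embedding_count().
from haystack import Document
from haystack.document_store.memory import InMemoryDocumentStore

store = InMemoryDocumentStore()
store.write_documents([
    Document(text="with embedding", embedding=[0.1] * 768),
    Document(text="without embedding"),
])
print(store.get_document_count())   # 2
print(store.get_embedding_count())  # 1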
@@ -520,3 +520,15 @@ class MilvusDocumentStore(SQLDocumentStore):
             return list()
 
         return vectors
+
+    def get_embedding_count(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> int:
+        """
+        Return the count of embeddings in the document store.
+        """
+        if filters:
+            raise Exception("filters are not supported for get_embedding_count in MilvusDocumentStore.")
+        index = index or self.index
+        _, embedding_count = self.milvus_server.count_entities(index)
+        if embedding_count is None:
+            embedding_count = 0
+        return embedding_count
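MilvusDocumentStore delegates the count to the Milvus server via count_entities, falling back to 0 when the collection reports None. A hedged sketch of the same call with the pymilvus 1.x client (host, port, and collection name are assumptions matching the CI container earlier in the diff):

# Hedged sketch: the server-side call behind get_embedding_count() for Milvus.
# Assumes a Milvus 1.x server (see the CI docker command) and the pymilvus 1.x client.
from milvus import Milvus

client = Milvus(host="localhost", port="19530")
status, embedding_count = client.count_entities("haystack_test")
print(embedding_count or 0)  # mirror the None -> 0 fallback in the method above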
@@ -347,6 +347,10 @@ def get_document_store(document_store_type, embedding_field="embedding"):
             embedding_field=embedding_field,
             index="haystack_test",
         )
+        _, collections = document_store.milvus_server.list_collections()
+        for collection in collections:
+            if collection.startswith("haystack_test"):
+                document_store.milvus_server.drop_collection(collection)
         return document_store
     else:
         raise Exception(f"No document store fixture for '{document_store_type}'")
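The test fixture now drops any leftover haystack_test* collections on the Milvus server so repeated runs start clean. Roughly the same clean-up expressed as a standalone pytest fixture could look like the sketch below; the fixture name and module path are illustrative, only the milvus_server calls are taken from the diff.

# Hedged sketch: the same Milvus clean-up as a standalone pytest fixture.
import pytest
from haystack.document_store.milvus import MilvusDocumentStore

@pytest.fixture
def milvus_document_store():
    store = MilvusDocumentStore(index="haystack_test")
    _, collections = store.milvus_server.list_collections()
    for collection in collections:
        if collection.startswith("haystack_test"):
            store.milvus_server.drop_collection(collection)
    yield store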
@@ -5,6 +5,7 @@ from elasticsearch import Elasticsearch
 from conftest import get_document_store
 from haystack import Document, Label
 from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
+from haystack.document_store.faiss import FAISSDocumentStore
 
 
 @pytest.mark.elasticsearch
@@ -274,19 +275,34 @@ def test_update_embeddings(document_store, retriever):
     np.testing.assert_array_equal(embedding_before_update, embedding_after_update)
 
     # test updating with filters
-    document_store.update_embeddings(
-        retriever, index="haystack_test_1", batch_size=3, filters={"meta_field": ["value_0", "value_1"]}
-    )
-    doc_after_update = document_store.get_all_documents(index="haystack_test_1", filters={"meta_field": ["value_7"]})[0]
-    embedding_after_update = doc_after_update.embedding
-    np.testing.assert_array_equal(embedding_before_update, embedding_after_update)
+    if isinstance(document_store, FAISSDocumentStore):
+        with pytest.raises(Exception):
+            document_store.update_embeddings(
+                retriever, index="haystack_test_1", update_existing_embeddings=True, filters={"meta_field": ["value"]}
+            )
+    else:
+        document_store.update_embeddings(
+            retriever, index="haystack_test_1", batch_size=3, filters={"meta_field": ["value_0", "value_1"]}
+        )
+        doc_after_update = document_store.get_all_documents(index="haystack_test_1", filters={"meta_field": ["value_7"]})[0]
+        embedding_after_update = doc_after_update.embedding
+        np.testing.assert_array_equal(embedding_before_update, embedding_after_update)
 
     # test update all embeddings
     document_store.update_embeddings(retriever, index="haystack_test_1", batch_size=3, update_existing_embeddings=True)
+    assert document_store.get_embedding_count(index="haystack_test_1") == 11
     doc_after_update = document_store.get_all_documents(index="haystack_test_1", filters={"meta_field": ["value_7"]})[0]
     embedding_after_update = doc_after_update.embedding
     np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, embedding_before_update, embedding_after_update)
+
+    # test update embeddings for newly added docs
+    documents = []
+    for i in range(12, 15):
+        documents.append({"text": f"text_{i}", "id": str(i), "meta_field": f"value_{i}"})
+    document_store.write_documents(documents, index="haystack_test_1")
+    document_store.update_embeddings(retriever, index="haystack_test_1", batch_size=3, update_existing_embeddings=False)
+    assert document_store.get_embedding_count(index="haystack_test_1") == 14
 
 
 @pytest.mark.elasticsearch
 def test_delete_all_documents(document_store_with_docs):
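Taken together, the test now exercises three paths: FAISS rejects update_existing_embeddings=True combined with filters, a full refresh re-embeds all 11 documents, and an incremental run embeds only the 3 newly written ones (hence the 14). A hedged end-to-end sketch of the invariant this guards, outside the test fixtures (retriever and model name are assumptions):

# Hedged sketch: repeated full refreshes must not inflate the FAISS vector count.
from haystack import Document
from haystack.document_store.faiss import FAISSDocumentStore
from haystack.retriever.dense import EmbeddingRetriever

store = FAISSDocumentStore(sql_url="sqlite://")
retriever = EmbeddingRetriever(document_store=store, embedding_model="deepset/sentence_bert")

store.write_documents([Document(text=f"text_{i}") for i in range(3)])
store.update_embeddings(retriever, update_existing_embeddings=True)
store.update_embeddings(retriever, update_existing_embeddings=True)  # index is reset, not appended to

assert store.get_embedding_count() == store.get_document_count() == 3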