mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-15 01:47:45 +00:00
Fix update_embeddings() for FAISSDocumentStore (#978)
This commit is contained in:
parent
0051a34ff9
commit
8c1e411380
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@ -77,7 +77,7 @@ jobs:
|
|||||||
run: docker run -d -p 19530:19530 -p 19121:19121 milvusdb/milvus:1.0.0-cpu-d030521-1ea92e
|
run: docker run -d -p 19530:19530 -p 19121:19121 milvusdb/milvus:1.0.0-cpu-d030521-1ea92e
|
||||||
|
|
||||||
- name: Run GraphDB
|
- name: Run GraphDB
|
||||||
run: docker run -d -p 7200:7200 --name haystack_test_graphdb docker-registry.ontotext.com/graphdb-free:9.4.1-adoptopenjdk11
|
run: docker run -d -p 7200:7200 --name haystack_test_graphdb deepset/graphdb-free:9.4.1-adoptopenjdk11
|
||||||
|
|
||||||
- name: Run Apache Tika
|
- name: Run Apache Tika
|
||||||
run: docker run -d -p 9998:9998 -e "TIKA_CHILD_JAVA_OPTS=-JXms128m" -e "TIKA_CHILD_JAVA_OPTS=-JXmx128m" apache/tika:1.24.1
|
run: docker run -d -p 9998:9998 -e "TIKA_CHILD_JAVA_OPTS=-JXms128m" -e "TIKA_CHILD_JAVA_OPTS=-JXmx128m" apache/tika:1.24.1
|
||||||
|
@ -490,6 +490,32 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
|||||||
"""
|
"""
|
||||||
return self.get_document_count(index=index)
|
return self.get_document_count(index=index)
|
||||||
|
|
||||||
|
def get_embedding_count(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> int:
    """
    Return the count of embeddings in the document store.

    Only documents for which the embedding field (``self.embedding_field``) exists are counted.

    :param index: Name of the index to count in. Falls back to ``self.index`` if not given.
    :param filters: Optional filters to narrow down the counted documents.
                    Each key must map to a list of allowed values.
                    Example: {"name": ["some", "more"], "category": ["only_one"]}
    :raises ValueError: If the value of a filter key is not a list.
    :return: The value of the "count" entry returned by the client's count call.
    """
    index = index or self.index

    # Base query: the document must have a value in the embedding field.
    body: dict = {"query": {"bool": {"must": [{"exists": {"field": self.embedding_field}}]}}}

    if filters:
        # Validate with isinstance (not an exact type comparison) so that list subclasses are accepted as well.
        for key, values in filters.items():
            if not isinstance(values, list):
                raise ValueError(
                    f'Wrong filter format for key "{key}": Please provide a list of allowed values for each key. '
                    'Example: {"name": ["some", "more"], "category": ["only_one"]} ')
        # One "terms" clause per filter key; all clauses are attached to the bool query's "filter" section.
        body["query"]["bool"]["filter"] = [{"terms": {key: values}} for key, values in filters.items()]

    result = self.client.count(index=index, body=body)
    return result["count"]
|
||||||
|
|
||||||
def get_all_documents(
|
def get_all_documents(
|
||||||
self,
|
self,
|
||||||
index: Optional[str] = None,
|
index: Optional[str] = None,
|
||||||
|
@ -199,6 +199,14 @@ class FAISSDocumentStore(SQLDocumentStore):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
index = index or self.index
|
index = index or self.index
|
||||||
|
|
||||||
|
if update_existing_embeddings is True:
|
||||||
|
if filters is None:
|
||||||
|
self.faiss_indexes[index].reset()
|
||||||
|
self.reset_vector_ids(index)
|
||||||
|
else:
|
||||||
|
raise Exception("update_existing_embeddings=True is not supported with filters.")
|
||||||
|
|
||||||
if not self.faiss_indexes.get(index):
|
if not self.faiss_indexes.get(index):
|
||||||
raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
|
raise ValueError("Couldn't find a FAISS index. Try to init the FAISSDocumentStore() again ...")
|
||||||
|
|
||||||
@ -290,6 +298,15 @@ class FAISSDocumentStore(SQLDocumentStore):
|
|||||||
doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
|
doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
|
def get_embedding_count(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> int:
    """
    Return the count of embeddings in the document store.

    :param index: Name of the index whose FAISS index is inspected. Falls back to ``self.index`` if not given.
    :param filters: Not supported for this document store; passing a non-empty value raises an Exception.
    :return: The ``ntotal`` attribute of the FAISS index registered under the given index name.
    """
    if not filters:
        target_index = index or self.index
        faiss_index = self.faiss_indexes[target_index]
        return faiss_index.ntotal
    raise Exception("filters are not supported for get_embedding_count in FAISSDocumentStore")
|
||||||
|
|
||||||
def train_index(
|
def train_index(
|
||||||
self,
|
self,
|
||||||
documents: Optional[Union[List[dict], List[Document]]],
|
documents: Optional[Union[List[dict], List[Document]]],
|
||||||
|
@ -220,6 +220,14 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
|||||||
documents = self.get_all_documents(index=index, filters=filters)
|
documents = self.get_all_documents(index=index, filters=filters)
|
||||||
return len(documents)
|
return len(documents)
|
||||||
|
|
||||||
|
def get_embedding_count(self, filters: Optional[Dict[str, List[str]]] = None, index: Optional[str] = None) -> int:
    """
    Return the count of embeddings in the document store.

    :param filters: Optional filters, passed on unchanged to ``get_all_documents``.
    :param index: Optional index name, passed on unchanged to ``get_all_documents``.
    :return: Number of retrieved documents whose ``embedding`` attribute is not None.
    """
    docs = self.get_all_documents(filters=filters, index=index)
    # Identity check against None on purpose: the embedding itself is never evaluated for truthiness.
    return sum(1 for doc in docs if doc.embedding is not None)
|
||||||
|
|
||||||
def get_label_count(self, index: Optional[str] = None) -> int:
|
def get_label_count(self, index: Optional[str] = None) -> int:
|
||||||
"""
|
"""
|
||||||
Return the number of labels in the document store
|
Return the number of labels in the document store
|
||||||
|
@ -520,3 +520,15 @@ class MilvusDocumentStore(SQLDocumentStore):
|
|||||||
return list()
|
return list()
|
||||||
|
|
||||||
return vectors
|
return vectors
|
||||||
|
|
||||||
|
def get_embedding_count(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None) -> int:
    """
    Return the count of embeddings in the document store.

    :param index: Name passed to ``milvus_server.count_entities``. Falls back to ``self.index`` if not given.
    :param filters: Not supported for this document store; passing a non-empty value raises an Exception.
    :return: The count reported by ``count_entities``, or 0 if it reports None.
    """
    if filters:
        raise Exception("filters are not supported for get_embedding_count in MilvusDocumentStore.")

    collection_name = index if index else self.index
    # count_entities returns a pair; the first element (presumably a status object) is not inspected here.
    _status, entity_count = self.milvus_server.count_entities(collection_name)
    return 0 if entity_count is None else entity_count
|
||||||
|
@ -347,6 +347,10 @@ def get_document_store(document_store_type, embedding_field="embedding"):
|
|||||||
embedding_field=embedding_field,
|
embedding_field=embedding_field,
|
||||||
index="haystack_test",
|
index="haystack_test",
|
||||||
)
|
)
|
||||||
|
_, collections = document_store.milvus_server.list_collections()
|
||||||
|
for collection in collections:
|
||||||
|
if collection.startswith("haystack_test"):
|
||||||
|
document_store.milvus_server.drop_collection(collection)
|
||||||
return document_store
|
return document_store
|
||||||
else:
|
else:
|
||||||
raise Exception(f"No document store fixture for '{document_store_type}'")
|
raise Exception(f"No document store fixture for '{document_store_type}'")
|
||||||
|
@ -5,6 +5,7 @@ from elasticsearch import Elasticsearch
|
|||||||
from conftest import get_document_store
|
from conftest import get_document_store
|
||||||
from haystack import Document, Label
|
from haystack import Document, Label
|
||||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||||
|
from haystack.document_store.faiss import FAISSDocumentStore
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.elasticsearch
|
@pytest.mark.elasticsearch
|
||||||
@ -274,19 +275,34 @@ def test_update_embeddings(document_store, retriever):
|
|||||||
np.testing.assert_array_equal(embedding_before_update, embedding_after_update)
|
np.testing.assert_array_equal(embedding_before_update, embedding_after_update)
|
||||||
|
|
||||||
# test updating with filters
|
# test updating with filters
|
||||||
document_store.update_embeddings(
|
if isinstance(document_store, FAISSDocumentStore):
|
||||||
retriever, index="haystack_test_1", batch_size=3, filters={"meta_field": ["value_0", "value_1"]}
|
with pytest.raises(Exception):
|
||||||
)
|
document_store.update_embeddings(
|
||||||
doc_after_update = document_store.get_all_documents(index="haystack_test_1", filters={"meta_field": ["value_7"]})[0]
|
retriever, index="haystack_test_1", update_existing_embeddings=True, filters={"meta_field": ["value"]}
|
||||||
embedding_after_update = doc_after_update.embedding
|
)
|
||||||
np.testing.assert_array_equal(embedding_before_update, embedding_after_update)
|
else:
|
||||||
|
document_store.update_embeddings(
|
||||||
|
retriever, index="haystack_test_1", batch_size=3, filters={"meta_field": ["value_0", "value_1"]}
|
||||||
|
)
|
||||||
|
doc_after_update = document_store.get_all_documents(index="haystack_test_1", filters={"meta_field": ["value_7"]})[0]
|
||||||
|
embedding_after_update = doc_after_update.embedding
|
||||||
|
np.testing.assert_array_equal(embedding_before_update, embedding_after_update)
|
||||||
|
|
||||||
# test update all embeddings
|
# test update all embeddings
|
||||||
document_store.update_embeddings(retriever, index="haystack_test_1", batch_size=3, update_existing_embeddings=True)
|
document_store.update_embeddings(retriever, index="haystack_test_1", batch_size=3, update_existing_embeddings=True)
|
||||||
|
assert document_store.get_embedding_count(index="haystack_test_1") == 11
|
||||||
doc_after_update = document_store.get_all_documents(index="haystack_test_1", filters={"meta_field": ["value_7"]})[0]
|
doc_after_update = document_store.get_all_documents(index="haystack_test_1", filters={"meta_field": ["value_7"]})[0]
|
||||||
embedding_after_update = doc_after_update.embedding
|
embedding_after_update = doc_after_update.embedding
|
||||||
np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, embedding_before_update, embedding_after_update)
|
np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, embedding_before_update, embedding_after_update)
|
||||||
|
|
||||||
|
# test update embeddings for newly added docs
|
||||||
|
documents = []
|
||||||
|
for i in range(12, 15):
|
||||||
|
documents.append({"text": f"text_{i}", "id": str(i), "meta_field": f"value_{i}"})
|
||||||
|
document_store.write_documents(documents, index="haystack_test_1")
|
||||||
|
document_store.update_embeddings(retriever, index="haystack_test_1", batch_size=3, update_existing_embeddings=False)
|
||||||
|
assert document_store.get_embedding_count(index="haystack_test_1") == 14
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.elasticsearch
|
@pytest.mark.elasticsearch
|
||||||
def test_delete_all_documents(document_store_with_docs):
|
def test_delete_all_documents(document_store_with_docs):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user