mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-02 18:59:28 +00:00
fix: fixed InMemoryDocumentStore.get_embedding_count to return correct number (#3980)
* Fix the embedding count function of InMemoryDocumentStore * Adding some doc strings explaining how many docs with embeddings to expect.
This commit is contained in:
parent
fa17f0973e
commit
71de0524de
@ -573,7 +573,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
|
||||
"""
|
||||
Return the count of embeddings in the document store.
|
||||
"""
|
||||
documents = self.get_all_documents(filters=filters, index=index)
|
||||
documents = self.get_all_documents_generator(filters=filters, index=index, return_embedding=True)
|
||||
embedding_count = sum(doc.embedding is not None for doc in documents)
|
||||
return embedding_count
|
||||
|
||||
|
||||
@ -94,6 +94,15 @@ class DocumentStoreBaseTestAbstract:
|
||||
with pytest.raises(Exception):
|
||||
ds.write_documents(duplicate_documents, duplicate_documents="fail")
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_embedding_count(self, ds, documents):
|
||||
"""
|
||||
We expect 6 docs with embeddings because only 6 documents in the documents fixture for this class contain
|
||||
embeddings.
|
||||
"""
|
||||
ds.write_documents(documents)
|
||||
assert ds.get_embedding_count() == 6
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.integration
|
||||
def test_get_all_documents_without_filters(self, ds, documents):
|
||||
|
||||
@ -133,6 +133,8 @@ class TestFAISSDocumentStore(DocumentStoreBaseTestAbstract):
|
||||
documents_indexed = document_store.get_all_documents()
|
||||
assert len(documents_indexed) == len(documents_with_embeddings)
|
||||
assert all(doc.embedding is not None for doc in documents_indexed)
|
||||
# Check that get_embedding_count works as expected
|
||||
assert document_store.get_embedding_count() == len(documents_with_embeddings)
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_write_docs_different_indexes(self, ds, documents_with_embeddings):
|
||||
@ -264,3 +266,9 @@ class TestFAISSDocumentStore(DocumentStoreBaseTestAbstract):
|
||||
@pytest.mark.integration
|
||||
def test_multilabel_meta_aggregations(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.integration
|
||||
def test_get_embedding_count(self):
|
||||
"""Skipped b/c most easily tested in test_write_index_docs"""
|
||||
pass
|
||||
|
||||
@ -56,6 +56,15 @@ class TestMilvusDocumentStore(DocumentStoreBaseTestAbstract):
|
||||
ds.delete_index(index="custom_index")
|
||||
assert ds.get_document_count(index="custom_index") == 0
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_embedding_count(self, ds, documents):
|
||||
"""
|
||||
We expect 9 docs with embeddings because all documents in the documents fixture for this class contain
|
||||
embeddings.
|
||||
"""
|
||||
ds.write_documents(documents)
|
||||
assert ds.get_embedding_count() == 9
|
||||
|
||||
# NOTE: MilvusDocumentStore derives from the SQL one and behaves differently to the others when filters are applied.
|
||||
# While this should be considered a bug, the related tests are skipped in the meantime
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from typing import List, Union, Dict, Any
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from inspect import getmembers, isclass, isfunction
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@ -428,7 +429,9 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
|
||||
"parent1": {"parent2": {"parent3": {"child1": 1, "child2": 2}}},
|
||||
"meta_field": "multilayer-test",
|
||||
}
|
||||
doc = Document(content=f"Multilayered dict", meta=multilayer_meta, embedding=[0.0] * 768)
|
||||
doc = Document(
|
||||
content=f"Multilayered dict", meta=multilayer_meta, embedding=np.random.rand(768).astype(np.float32)
|
||||
)
|
||||
|
||||
doc_store_with_docs.write_documents([doc])
|
||||
retrieved_docs = doc_store_with_docs.get_all_documents(filters={"meta_field": {"$eq": "multilayer-test"}})
|
||||
@ -447,3 +450,13 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
|
||||
ds._validate_embeddings_shape.assert_called_once()
|
||||
ds.update_embeddings(retriever, update_existing_embeddings=False)
|
||||
ds._validate_embeddings_shape.assert_called_once()
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_embedding_count(self, doc_store_with_docs: PineconeDocumentStore):
|
||||
"""
|
||||
We expect 1 doc with an embedding because the documents already written in doc_store_with_docs contain no
|
||||
embeddings.
|
||||
"""
|
||||
doc = Document(content=f"Doc with embedding", embedding=np.random.rand(768).astype(np.float32))
|
||||
doc_store_with_docs.write_documents([doc])
|
||||
assert doc_store_with_docs.get_embedding_count() == 1
|
||||
|
||||
@ -127,3 +127,8 @@ class TestSQLDocumentStore(DocumentStoreBaseTestAbstract):
|
||||
@pytest.mark.integration
|
||||
def test_multilabel_meta_aggregations(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="embeddings are not supported")
|
||||
@pytest.mark.integration
|
||||
def test_get_embedding_count(self):
|
||||
pass
|
||||
|
||||
@ -251,3 +251,12 @@ class TestWeaviateDocumentStore(DocumentStoreBaseTestAbstract):
|
||||
def test_cant_write_top_level_fields_in_meta(self, ds):
|
||||
with pytest.raises(ValueError, match='"meta" info contains duplicate key "content"'):
|
||||
ds.write_documents([Document(content="test", meta={"content": "test-id"})])
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_get_embedding_count(self, ds, documents):
|
||||
"""
|
||||
We expect 9 docs with embeddings because all documents in the documents fixture for this class contain
|
||||
embeddings.
|
||||
"""
|
||||
ds.write_documents(documents)
|
||||
assert ds.get_embedding_count() == 9
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user