fix: fixed InMemoryDocumentStore.get_embedding_count to return correct number (#3980)

* Fix the embedding count function of InMemoryDocumentStore

* Adding some doc strings explaining how many docs with embeddings to expect.
This commit is contained in:
Sebastian 2023-01-30 12:38:30 +01:00 committed by GitHub
parent fa17f0973e
commit 71de0524de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 55 additions and 2 deletions

View File

@ -573,7 +573,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
""" """
Return the count of embeddings in the document store. Return the count of embeddings in the document store.
""" """
documents = self.get_all_documents(filters=filters, index=index) documents = self.get_all_documents_generator(filters=filters, index=index, return_embedding=True)
embedding_count = sum(doc.embedding is not None for doc in documents) embedding_count = sum(doc.embedding is not None for doc in documents)
return embedding_count return embedding_count

View File

@ -94,6 +94,15 @@ class DocumentStoreBaseTestAbstract:
with pytest.raises(Exception): with pytest.raises(Exception):
ds.write_documents(duplicate_documents, duplicate_documents="fail") ds.write_documents(duplicate_documents, duplicate_documents="fail")
@pytest.mark.integration
def test_get_embedding_count(self, ds, documents):
"""
We expect 6 docs with embeddings because only 6 documents in the documents fixture for this class contain
embeddings.
"""
ds.write_documents(documents)
assert ds.get_embedding_count() == 6
@pytest.mark.skip @pytest.mark.skip
@pytest.mark.integration @pytest.mark.integration
def test_get_all_documents_without_filters(self, ds, documents): def test_get_all_documents_without_filters(self, ds, documents):

View File

@ -133,6 +133,8 @@ class TestFAISSDocumentStore(DocumentStoreBaseTestAbstract):
documents_indexed = document_store.get_all_documents() documents_indexed = document_store.get_all_documents()
assert len(documents_indexed) == len(documents_with_embeddings) assert len(documents_indexed) == len(documents_with_embeddings)
assert all(doc.embedding is not None for doc in documents_indexed) assert all(doc.embedding is not None for doc in documents_indexed)
# Check that get_embedding_count works as expected
assert document_store.get_embedding_count() == len(documents_with_embeddings)
@pytest.mark.integration @pytest.mark.integration
def test_write_docs_different_indexes(self, ds, documents_with_embeddings): def test_write_docs_different_indexes(self, ds, documents_with_embeddings):
@ -264,3 +266,9 @@ class TestFAISSDocumentStore(DocumentStoreBaseTestAbstract):
@pytest.mark.integration @pytest.mark.integration
def test_multilabel_meta_aggregations(self): def test_multilabel_meta_aggregations(self):
pass pass
@pytest.mark.skip
@pytest.mark.integration
def test_get_embedding_count(self):
"""Skipped b/c most easily tested in test_write_index_docs"""
pass

View File

@ -56,6 +56,15 @@ class TestMilvusDocumentStore(DocumentStoreBaseTestAbstract):
ds.delete_index(index="custom_index") ds.delete_index(index="custom_index")
assert ds.get_document_count(index="custom_index") == 0 assert ds.get_document_count(index="custom_index") == 0
@pytest.mark.integration
def test_get_embedding_count(self, ds, documents):
"""
We expect 9 docs with embeddings because all documents in the documents fixture for this class contain
embeddings.
"""
ds.write_documents(documents)
assert ds.get_embedding_count() == 9
# NOTE: MilvusDocumentStore derives from the SQL one and behaves differently to the others when filters are applied. # NOTE: MilvusDocumentStore derives from the SQL one and behaves differently to the others when filters are applied.
# While this should be considered a bug, the relative tests are skipped in the meantime # While this should be considered a bug, the relative tests are skipped in the meantime

View File

@ -1,6 +1,7 @@
from typing import List, Union, Dict, Any from typing import List, Union, Dict, Any
import os import os
import numpy as np
from inspect import getmembers, isclass, isfunction from inspect import getmembers, isclass, isfunction
from unittest.mock import MagicMock from unittest.mock import MagicMock
@ -428,7 +429,9 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
"parent1": {"parent2": {"parent3": {"child1": 1, "child2": 2}}}, "parent1": {"parent2": {"parent3": {"child1": 1, "child2": 2}}},
"meta_field": "multilayer-test", "meta_field": "multilayer-test",
} }
doc = Document(content=f"Multilayered dict", meta=multilayer_meta, embedding=[0.0] * 768) doc = Document(
content=f"Multilayered dict", meta=multilayer_meta, embedding=np.random.rand(768).astype(np.float32)
)
doc_store_with_docs.write_documents([doc]) doc_store_with_docs.write_documents([doc])
retrieved_docs = doc_store_with_docs.get_all_documents(filters={"meta_field": {"$eq": "multilayer-test"}}) retrieved_docs = doc_store_with_docs.get_all_documents(filters={"meta_field": {"$eq": "multilayer-test"}})
@ -447,3 +450,13 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
ds._validate_embeddings_shape.assert_called_once() ds._validate_embeddings_shape.assert_called_once()
ds.update_embeddings(retriever, update_existing_embeddings=False) ds.update_embeddings(retriever, update_existing_embeddings=False)
ds._validate_embeddings_shape.assert_called_once() ds._validate_embeddings_shape.assert_called_once()
@pytest.mark.integration
def test_get_embedding_count(self, doc_store_with_docs: PineconeDocumentStore):
"""
We expect 1 doc with an embeddings because all documents in already written in doc_store_with_docs contain no
embeddings.
"""
doc = Document(content=f"Doc with embedding", embedding=np.random.rand(768).astype(np.float32))
doc_store_with_docs.write_documents([doc])
assert doc_store_with_docs.get_embedding_count() == 1

View File

@ -127,3 +127,8 @@ class TestSQLDocumentStore(DocumentStoreBaseTestAbstract):
@pytest.mark.integration @pytest.mark.integration
def test_multilabel_meta_aggregations(self): def test_multilabel_meta_aggregations(self):
pass pass
@pytest.mark.skip(reason="embeddings are not supported")
@pytest.mark.integration
def test_get_embedding_count(self):
pass

View File

@ -251,3 +251,12 @@ class TestWeaviateDocumentStore(DocumentStoreBaseTestAbstract):
def test_cant_write_top_level_fields_in_meta(self, ds): def test_cant_write_top_level_fields_in_meta(self, ds):
with pytest.raises(ValueError, match='"meta" info contains duplicate key "content"'): with pytest.raises(ValueError, match='"meta" info contains duplicate key "content"'):
ds.write_documents([Document(content="test", meta={"content": "test-id"})]) ds.write_documents([Document(content="test", meta={"content": "test-id"})])
@pytest.mark.integration
def test_get_embedding_count(self, ds, documents):
"""
We expect 9 docs with embeddings because all documents in the documents fixture for this class contain
embeddings.
"""
ds.write_documents(documents)
assert ds.get_embedding_count() == 9