fix: fixed InMemoryDocumentStore.get_embedding_count to return correct number (#3980)

* Fix the embedding count function of InMemoryDocumentStore

* Adding some doc strings explaining how many docs with embeddings to expect.
This commit is contained in:
Sebastian 2023-01-30 12:38:30 +01:00 committed by GitHub
parent fa17f0973e
commit 71de0524de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 55 additions and 2 deletions

View File

@ -573,7 +573,7 @@ class InMemoryDocumentStore(KeywordDocumentStore):
"""
Return the count of embeddings in the document store.
"""
documents = self.get_all_documents(filters=filters, index=index)
documents = self.get_all_documents_generator(filters=filters, index=index, return_embedding=True)
embedding_count = sum(doc.embedding is not None for doc in documents)
return embedding_count

View File

@ -94,6 +94,15 @@ class DocumentStoreBaseTestAbstract:
with pytest.raises(Exception):
ds.write_documents(duplicate_documents, duplicate_documents="fail")
@pytest.mark.integration
def test_get_embedding_count(self, ds, documents):
"""
We expect 6 docs with embeddings because only 6 documents in the documents fixture for this class contain
embeddings.
"""
ds.write_documents(documents)
assert ds.get_embedding_count() == 6
@pytest.mark.skip
@pytest.mark.integration
def test_get_all_documents_without_filters(self, ds, documents):

View File

@ -133,6 +133,8 @@ class TestFAISSDocumentStore(DocumentStoreBaseTestAbstract):
documents_indexed = document_store.get_all_documents()
assert len(documents_indexed) == len(documents_with_embeddings)
assert all(doc.embedding is not None for doc in documents_indexed)
# Check that get_embedding_count works as expected
assert document_store.get_embedding_count() == len(documents_with_embeddings)
@pytest.mark.integration
def test_write_docs_different_indexes(self, ds, documents_with_embeddings):
@ -264,3 +266,9 @@ class TestFAISSDocumentStore(DocumentStoreBaseTestAbstract):
@pytest.mark.integration
def test_multilabel_meta_aggregations(self):
pass
@pytest.mark.skip
@pytest.mark.integration
def test_get_embedding_count(self):
"""Skipped b/c most easily tested in test_write_index_docs"""
pass

View File

@ -56,6 +56,15 @@ class TestMilvusDocumentStore(DocumentStoreBaseTestAbstract):
ds.delete_index(index="custom_index")
assert ds.get_document_count(index="custom_index") == 0
@pytest.mark.integration
def test_get_embedding_count(self, ds, documents):
"""
We expect 9 docs with embeddings because all documents in the documents fixture for this class contain
embeddings.
"""
ds.write_documents(documents)
assert ds.get_embedding_count() == 9
# NOTE: MilvusDocumentStore derives from the SQL one and behaves differently to the others when filters are applied.
# While this should be considered a bug, the relative tests are skipped in the meantime

View File

@ -1,6 +1,7 @@
from typing import List, Union, Dict, Any
import os
import numpy as np
from inspect import getmembers, isclass, isfunction
from unittest.mock import MagicMock
@ -428,7 +429,9 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
"parent1": {"parent2": {"parent3": {"child1": 1, "child2": 2}}},
"meta_field": "multilayer-test",
}
doc = Document(content=f"Multilayered dict", meta=multilayer_meta, embedding=[0.0] * 768)
doc = Document(
content=f"Multilayered dict", meta=multilayer_meta, embedding=np.random.rand(768).astype(np.float32)
)
doc_store_with_docs.write_documents([doc])
retrieved_docs = doc_store_with_docs.get_all_documents(filters={"meta_field": {"$eq": "multilayer-test"}})
@ -447,3 +450,13 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
ds._validate_embeddings_shape.assert_called_once()
ds.update_embeddings(retriever, update_existing_embeddings=False)
ds._validate_embeddings_shape.assert_called_once()
@pytest.mark.integration
def test_get_embedding_count(self, doc_store_with_docs: PineconeDocumentStore):
"""
We expect 1 doc with an embeddings because all documents in already written in doc_store_with_docs contain no
embeddings.
"""
doc = Document(content=f"Doc with embedding", embedding=np.random.rand(768).astype(np.float32))
doc_store_with_docs.write_documents([doc])
assert doc_store_with_docs.get_embedding_count() == 1

View File

@ -127,3 +127,8 @@ class TestSQLDocumentStore(DocumentStoreBaseTestAbstract):
@pytest.mark.integration
def test_multilabel_meta_aggregations(self):
pass
@pytest.mark.skip(reason="embeddings are not supported")
@pytest.mark.integration
def test_get_embedding_count(self):
pass

View File

@ -251,3 +251,12 @@ class TestWeaviateDocumentStore(DocumentStoreBaseTestAbstract):
def test_cant_write_top_level_fields_in_meta(self, ds):
with pytest.raises(ValueError, match='"meta" info contains duplicate key "content"'):
ds.write_documents([Document(content="test", meta={"content": "test-id"})])
@pytest.mark.integration
def test_get_embedding_count(self, ds, documents):
"""
We expect 9 docs with embeddings because all documents in the documents fixture for this class contain
embeddings.
"""
ds.write_documents(documents)
assert ds.get_embedding_count() == 9