bug: skip validating empty embeddings (#3774)

* skip validating empty embeddings

* skip batches without embeddings to update

* add unit test with mocked retriever
This commit is contained in:
Julian Risch 2023-01-05 15:13:57 +01:00 committed by GitHub
parent e84fae2894
commit 0c2d13f1b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 4 deletions

View File

@ -505,6 +505,12 @@ class PineconeDocumentStore(BaseDocumentStore):
for _ in range(0, document_count, batch_size):
document_batch = list(islice(documents, batch_size))
embeddings = retriever.embed_documents(document_batch)
if embeddings.size == 0:
# Skip the batch if there are no embeddings. Otherwise, an incorrect embedding shape will be inferred and
# the Pinecone API will return a "No vectors provided" Bad Request Error
progress_bar.set_description_str("Documents Processed")
progress_bar.update(batch_size)
continue
self._validate_embeddings_shape(
embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
)

View File

@ -2,6 +2,7 @@ from typing import List, Union, Dict, Any
import os
from inspect import getmembers, isclass, isfunction
from unittest.mock import MagicMock
import pytest
@ -12,8 +13,7 @@ from haystack.errors import FilterError
from .test_base import DocumentStoreBaseTestAbstract
from ..mocks import pinecone as pinecone_mock
from ..conftest import SAMPLES_PATH
from ..nodes.test_retriever import MockBaseRetriever
# Set metadata fields used during testing for PineconeDocumentStore meta_config
META_FIELDS = ["meta_field", "name", "date", "numeric_field", "odd_document"]
@ -417,3 +417,15 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
assert len(retrieved_docs) == 1
assert retrieved_docs[0].meta == multilayer_meta
@pytest.mark.unit
def test_skip_validating_empty_embeddings(self, ds: PineconeDocumentStore):
    """
    Check that embedding-shape validation runs exactly once across two `update_embeddings` calls.

    The first call (default arguments) is expected to validate once. The second call, made with
    `update_existing_embeddings=False`, must not trigger validation again, so the call count
    stays at one (presumably because no document is left without an embedding at that point).
    """
    doc = Document(id="0", content="test")
    mock_retriever = MockBaseRetriever(document_store=ds, mock_document=doc)
    ds.write_documents(documents=[doc])

    # Swap the validator for a mock so that its invocations can be counted.
    shape_validator = MagicMock()
    ds._validate_embeddings_shape = shape_validator

    # First run: validation is expected to happen exactly once.
    ds.update_embeddings(mock_retriever)
    shape_validator.assert_called_once()

    # Second run: the call count must not increase.
    ds.update_embeddings(mock_retriever, update_existing_embeddings=False)
    shape_validator.assert_called_once()

View File

@ -20,8 +20,6 @@ from haystack.nodes.retriever.base import BaseRetriever
from haystack.pipelines import DocumentSearchPipeline
from haystack.schema import Document
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.document_stores import MilvusDocumentStore
from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever
from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever
from haystack.nodes.retriever.multimodal import MultiModalRetriever
@ -160,6 +158,9 @@ class MockBaseRetriever(MockRetriever):
):
return [[self.mock_document] for _ in range(len(queries))]
def embed_documents(self, documents: List[Document]):
    """
    Return a constant mock embedding matrix for the given documents.

    :param documents: The documents to "embed"; only the length of the list is used.
    :return: A numpy array of shape `(len(documents), 768)` in which every value is 0.5.
    """
    num_documents = len(documents)
    return np.full(shape=(num_documents, 768), fill_value=0.5)
def test_retrieval_empty_query(document_store: BaseDocumentStore):
# test with empty query using the run() method