bug: skip validating empty embeddings (#3774)

* skip validating empty embeddings

* skip batches without embeddings to update

* add unit test with mocked retriever
This commit is contained in:
Julian Risch 2023-01-05 15:13:57 +01:00 committed by GitHub
parent e84fae2894
commit 0c2d13f1b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 4 deletions

View File

@ -505,6 +505,12 @@ class PineconeDocumentStore(BaseDocumentStore):
for _ in range(0, document_count, batch_size): for _ in range(0, document_count, batch_size):
document_batch = list(islice(documents, batch_size)) document_batch = list(islice(documents, batch_size))
embeddings = retriever.embed_documents(document_batch) embeddings = retriever.embed_documents(document_batch)
if embeddings.size == 0:
# Skip batch if there are no embeddings. Otherwise, incorrect embedding shape will be inferred and
# Pinecone APi will return a "No vectors provided" Bad Request Error
progress_bar.set_description_str("Documents Processed")
progress_bar.update(batch_size)
continue
self._validate_embeddings_shape( self._validate_embeddings_shape(
embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
) )

View File

@ -2,6 +2,7 @@ from typing import List, Union, Dict, Any
import os import os
from inspect import getmembers, isclass, isfunction from inspect import getmembers, isclass, isfunction
from unittest.mock import MagicMock
import pytest import pytest
@ -12,8 +13,7 @@ from haystack.errors import FilterError
from .test_base import DocumentStoreBaseTestAbstract from .test_base import DocumentStoreBaseTestAbstract
from ..mocks import pinecone as pinecone_mock from ..mocks import pinecone as pinecone_mock
from ..conftest import SAMPLES_PATH from ..nodes.test_retriever import MockBaseRetriever
# Set metadata fields used during testing for PineconeDocumentStore meta_config # Set metadata fields used during testing for PineconeDocumentStore meta_config
META_FIELDS = ["meta_field", "name", "date", "numeric_field", "odd_document"] META_FIELDS = ["meta_field", "name", "date", "numeric_field", "odd_document"]
@ -417,3 +417,15 @@ class TestPineconeDocumentStore(DocumentStoreBaseTestAbstract):
assert len(retrieved_docs) == 1 assert len(retrieved_docs) == 1
assert retrieved_docs[0].meta == multilayer_meta assert retrieved_docs[0].meta == multilayer_meta
@pytest.mark.unit
def test_skip_validating_empty_embeddings(self, ds: PineconeDocumentStore):
document = Document(id="0", content="test")
retriever = MockBaseRetriever(document_store=ds, mock_document=document)
ds.write_documents(documents=[document])
ds._validate_embeddings_shape = MagicMock()
ds.update_embeddings(retriever)
ds._validate_embeddings_shape.assert_called_once()
ds.update_embeddings(retriever, update_existing_embeddings=False)
ds._validate_embeddings_shape.assert_called_once()

View File

@ -20,8 +20,6 @@ from haystack.nodes.retriever.base import BaseRetriever
from haystack.pipelines import DocumentSearchPipeline from haystack.pipelines import DocumentSearchPipeline
from haystack.schema import Document from haystack.schema import Document
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.document_stores import MilvusDocumentStore
from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever from haystack.nodes.retriever.dense import DensePassageRetriever, EmbeddingRetriever, TableTextRetriever
from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever from haystack.nodes.retriever.sparse import BM25Retriever, FilterRetriever, TfidfRetriever
from haystack.nodes.retriever.multimodal import MultiModalRetriever from haystack.nodes.retriever.multimodal import MultiModalRetriever
@ -160,6 +158,9 @@ class MockBaseRetriever(MockRetriever):
): ):
return [[self.mock_document] for _ in range(len(queries))] return [[self.mock_document] for _ in range(len(queries))]
def embed_documents(self, documents: List[Document]):
return np.full((len(documents), 768), 0.5)
def test_retrieval_empty_query(document_store: BaseDocumentStore): def test_retrieval_empty_query(document_store: BaseDocumentStore):
# test with empty query using the run() method # test with empty query using the run() method