From b10e2c392e05e0baede3d2ea9673e045b14ab422 Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Wed, 21 Sep 2022 19:08:54 +0200 Subject: [PATCH] chore: add `DenseRetriever` abstraction (#3252) * support cosine similiarity with faiss * update docs * update api docs * fix tests * Revert "update api docs" This reverts commit 6138fdfefb3beaee2d55c5729cd4a2745ea6b143. * fix api docs * collapse test * rename similairity to space_type mappings * only normalize for faiss * fix merge * fix docs normalization * get rid of List[np.array] * update docs * fix tests and tutorials * fix mypy * fix mypy * fix mypy again * again mypy * blacken * update tutorial 4 docs * fix embeddingretriever * fix faiss * move dense specific logic to DenseRetriever * fix mypy * cosine tests for all documents stores * fix pinecone * add docstring * docstring corrections * update docs * add integration test marker * docstrings update * update docs * fix typo * update docs * fix MockDenseRetriever * run integration tests for all documentstores * fix test_update_embeddings_cosine_similarity * fix faiss tests not running * blacken * make test_cosine_sanity_check integration test * update docs * fix imports * import DenseRetriever normally * update docs * fix deepcopy of documents * update schema * Revert "update schema" This reverts commit 83cf8f323648468e1c322d54852bec084d637e3f. * fix schema for ci manually --- docs/_src/api/api/document_store.md | 14 +- docs/_src/api/api/retriever.md | 97 ++++++++--- haystack/document_stores/base.py | 23 ++- haystack/document_stores/elasticsearch.py | 17 +- haystack/document_stores/faiss.py | 29 +--- haystack/document_stores/memory.py | 24 +-- haystack/document_stores/milvus1.py | 31 +--- haystack/document_stores/milvus2.py | 29 +--- haystack/document_stores/pinecone.py | 30 ++-- haystack/document_stores/weaviate.py | 22 +-- .../haystack-pipeline-main.schema.json | 4 +- haystack/nodes/__init__.py | 1 + .../nodes/answer_generator/transformers.py | 2 +- haystack/nodes/retriever/__init__.py | 1 + .../nodes/retriever/_embedding_encoder.py | 81 ++++++--- haystack/nodes/retriever/base.py | 13 +- haystack/nodes/retriever/dense.py | 161 +++++++++++------- test/conftest.py | 11 +- test/document_stores/test_document_store.py | 2 +- 19 files changed, 333 insertions(+), 259 deletions(-) diff --git a/docs/_src/api/api/document_store.md b/docs/_src/api/api/document_store.md index aec2bdc88..fa78eb7f8 100644 --- a/docs/_src/api/api/document_store.md +++ b/docs/_src/api/api/document_store.md @@ -1269,7 +1269,7 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. #### BaseElasticsearchDocumentStore.update\_embeddings ```python -def update_embeddings(retriever, +def update_embeddings(retriever: DenseRetriever, index: Optional[str] = None, filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None, @@ -2097,7 +2097,7 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. 
#### InMemoryDocumentStore.update\_embeddings ```python -def update_embeddings(retriever: "BaseRetriever", +def update_embeddings(retriever: DenseRetriever, index: Optional[str] = None, filters: Optional[Dict[str, Any]] = None, update_existing_embeddings: bool = True, @@ -2913,7 +2913,7 @@ None #### FAISSDocumentStore.update\_embeddings ```python -def update_embeddings(retriever: "BaseRetriever", +def update_embeddings(retriever: DenseRetriever, index: Optional[str] = None, update_existing_embeddings: bool = True, filters: Optional[Dict[str, Any]] = None, @@ -3277,7 +3277,7 @@ None #### Milvus1DocumentStore.update\_embeddings ```python -def update_embeddings(retriever: "BaseRetriever", +def update_embeddings(retriever: DenseRetriever, index: Optional[str] = None, batch_size: int = 10_000, update_existing_embeddings: bool = True, @@ -3681,7 +3681,7 @@ exists. #### Milvus2DocumentStore.update\_embeddings ```python -def update_embeddings(retriever: "BaseRetriever", +def update_embeddings(retriever: DenseRetriever, index: Optional[str] = None, batch_size: int = 10_000, update_existing_embeddings: bool = True, @@ -4398,7 +4398,7 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. #### WeaviateDocumentStore.update\_embeddings ```python -def update_embeddings(retriever, +def update_embeddings(retriever: DenseRetriever, index: Optional[str] = None, filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None, @@ -5452,7 +5452,7 @@ Parameter options: #### PineconeDocumentStore.update\_embeddings ```python -def update_embeddings(retriever: "BaseRetriever", +def update_embeddings(retriever: DenseRetriever, index: Optional[str] = None, update_existing_embeddings: bool = True, filters: Optional[Dict[str, Union[Dict, List, str, int, diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md index 7601e7647..65dd156cc 100644 --- a/docs/_src/api/api/retriever.md +++ b/docs/_src/api/api/retriever.md @@ -547,12 +547,60 @@ Performing training on this class according to the TF-IDF algorithm. # Module dense + + +## DenseRetriever + +```python +class DenseRetriever(BaseRetriever) +``` + +Base class for all dense retrievers. + + + +#### DenseRetriever.embed\_queries + +```python +@abstractmethod +def embed_queries(queries: List[str]) -> np.ndarray +``` + +Create embeddings for a list of queries. + +**Arguments**: + +- `queries`: List of queries to embed. + +**Returns**: + +Embeddings, one per input query, shape: (queries, embedding_dim) + + + +#### DenseRetriever.embed\_documents + +```python +@abstractmethod +def embed_documents(documents: List[Document]) -> np.ndarray +``` + +Create embeddings for a list of documents. + +**Arguments**: + +- `documents`: List of documents to embed. + +**Returns**: + +Embeddings of documents, one per input document, shape: (documents, embedding_dim) + ## DensePassageRetriever ```python -class DensePassageRetriever(BaseRetriever) +class DensePassageRetriever(DenseRetriever) ``` Retriever that uses a bi-encoder (one transformer for query, one transformer for passage). @@ -842,36 +890,36 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. #### DensePassageRetriever.embed\_queries ```python -def embed_queries(texts: List[str]) -> List[np.ndarray] +def embed_queries(queries: List[str]) -> np.ndarray ``` -Create embeddings for a list of queries using the query encoder +Create embeddings for a list of queries using the query encoder. 
**Arguments**: -- `texts`: Queries to embed +- `queries`: List of queries to embed. **Returns**: -Embeddings, one per input queries +Embeddings, one per input query, shape: (queries, embedding_dim) #### DensePassageRetriever.embed\_documents ```python -def embed_documents(docs: List[Document]) -> List[np.ndarray] +def embed_documents(documents: List[Document]) -> np.ndarray ``` -Create embeddings for a list of documents using the passage encoder +Create embeddings for a list of documents using the passage encoder. **Arguments**: -- `docs`: List of Document objects used to represent documents / passages in a standardized way within Haystack. +- `documents`: List of documents to embed. **Returns**: -Embeddings of documents / passages shape (batch_size, embedding_dim) +Embeddings of documents, one per input document, shape: (documents, embedding_dim) @@ -1005,7 +1053,7 @@ Load DensePassageRetriever from the specified directory. ## TableTextRetriever ```python -class TableTextRetriever(BaseRetriever) +class TableTextRetriever(DenseRetriever) ``` Retriever that uses a tri-encoder to jointly retrieve among a database consisting of text passages and tables @@ -1198,25 +1246,25 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. #### TableTextRetriever.embed\_queries ```python -def embed_queries(texts: List[str]) -> List[np.ndarray] +def embed_queries(queries: List[str]) -> np.ndarray ``` -Create embeddings for a list of queries using the query encoder +Create embeddings for a list of queries using the query encoder. **Arguments**: -- `texts`: Queries to embed +- `queries`: List of queries to embed. **Returns**: -Embeddings, one per input queries +Embeddings, one per input query, shape: (queries, embedding_dim) #### TableTextRetriever.embed\_documents ```python -def embed_documents(docs: List[Document]) -> List[np.ndarray] +def embed_documents(documents: List[Document]) -> np.ndarray ``` Create embeddings for a list of text documents and / or tables using the text passage encoder and @@ -1225,12 +1273,11 @@ the table encoder. **Arguments**: -- `docs`: List of Document objects used to represent documents / passages in -a standardized way within Haystack. +- `documents`: List of documents to embed. **Returns**: -Embeddings of documents / passages. Shape: (batch_size, embedding_dim) +Embeddings of documents, one per input document, shape: (documents, embedding_dim) @@ -1370,7 +1417,7 @@ Load TableTextRetriever from the specified directory. ## EmbeddingRetriever ```python -class EmbeddingRetriever(BaseRetriever) +class EmbeddingRetriever(DenseRetriever) ``` @@ -1638,36 +1685,36 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used. #### EmbeddingRetriever.embed\_queries ```python -def embed_queries(texts: List[str]) -> List[np.ndarray] +def embed_queries(queries: List[str]) -> np.ndarray ``` Create embeddings for a list of queries. **Arguments**: -- `texts`: Queries to embed +- `queries`: List of queries to embed. **Returns**: -Embeddings, one per input queries +Embeddings, one per input query, shape: (queries, embedding_dim) #### EmbeddingRetriever.embed\_documents ```python -def embed_documents(docs: List[Document]) -> List[np.ndarray] +def embed_documents(documents: List[Document]) -> np.ndarray ``` Create embeddings for a list of documents. **Arguments**: -- `docs`: List of documents to embed +- `documents`: List of documents to embed. 
**Returns**: -Embeddings, one per input document +Embeddings, one per input document, shape: (docs, embedding_dim) diff --git a/haystack/document_stores/base.py b/haystack/document_stores/base.py index c17553a27..5325a05ea 100644 --- a/haystack/document_stores/base.py +++ b/haystack/document_stores/base.py @@ -12,7 +12,7 @@ import numpy as np from haystack.schema import Document, Label, MultiLabel from haystack.nodes.base import BaseComponent -from haystack.errors import DuplicateDocumentError +from haystack.errors import DuplicateDocumentError, DocumentStoreError from haystack.nodes.preprocessor import PreProcessor from haystack.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl @@ -698,6 +698,27 @@ class BaseDocumentStore(BaseComponent): return [label for label in labels if label.id in duplicate_ids] + @classmethod + def _validate_embeddings_shape(cls, embeddings: np.ndarray, num_documents: int, embedding_dim: int): + """ + Validates the shape of model-generated embeddings against expected values for indexing. + + :param embeddings: Embeddings to validate + :param num_documents: Number of documents the embeddings were generated for + :param embedding_dim: Number of embedding dimensions to expect + """ + num_embeddings, embedding_size = embeddings.shape + if num_embeddings != num_documents: + raise DocumentStoreError( + "The number of embeddings does not match the number of documents: " + f"({num_embeddings} != {num_documents})" + ) + if embedding_size != embedding_dim: + raise RuntimeError( + f"Embedding dimensions of the model ({embedding_size}) don't match the embedding dimensions of the document store ({embedding_dim}). " + f"Initiate {cls.__name__} again with arg embedding_dim={embedding_size}." + ) + class KeywordDocumentStore(BaseDocumentStore): """ diff --git a/haystack/document_stores/elasticsearch.py b/haystack/document_stores/elasticsearch.py index ba224aaa0..4df4db2f0 100644 --- a/haystack/document_stores/elasticsearch.py +++ b/haystack/document_stores/elasticsearch.py @@ -27,6 +27,7 @@ from haystack.schema import Document, Label from haystack.document_stores.base import get_batches_from_generator from haystack.document_stores.filter_utils import LogicalFilterClause from haystack.errors import DocumentStoreError, HaystackError +from haystack.nodes.retriever import DenseRetriever logger = logging.getLogger(__name__) @@ -1400,7 +1401,7 @@ class BaseElasticsearchDocumentStore(KeywordDocumentStore): def update_embeddings( self, - retriever, + retriever: DenseRetriever, index: Optional[str] = None, filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None, update_existing_embeddings: bool = True, @@ -1482,16 +1483,10 @@ class BaseElasticsearchDocumentStore(KeywordDocumentStore): with tqdm(total=document_count, position=0, unit=" Docs", desc="Updating embeddings") as progress_bar: for result_batch in get_batches_from_generator(result, batch_size): document_batch = [self._convert_es_hit_to_document(hit, return_embedding=False) for hit in result_batch] - embeddings = retriever.embed_documents(document_batch) # type: ignore - if len(document_batch) != len(embeddings): - raise DocumentStoreError( - "The number of embeddings does not match the number of documents in the batch " - f"({len(embeddings)} != {len(document_batch)})" - ) - if embeddings[0].shape[0] != self.embedding_dim: - raise RuntimeError( - f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store 
({self.embedding_dim}). Please reinitiate ElasticsearchDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}." - ) + embeddings = retriever.embed_documents(document_batch) + self._validate_embeddings_shape( + embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim + ) doc_updates = [] for doc, emb in zip(document_batch, embeddings): diff --git a/haystack/document_stores/faiss.py b/haystack/document_stores/faiss.py index 729c8e00a..33af7f0d2 100644 --- a/haystack/document_stores/faiss.py +++ b/haystack/document_stores/faiss.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Union, List, Optional, Dict, Generator +from typing import Any, Union, List, Optional, Dict, Generator import json import logging @@ -22,10 +22,7 @@ except (ImportError, ModuleNotFoundError) as ie: from haystack.schema import Document from haystack.document_stores.base import get_batches_from_generator -from haystack.errors import DocumentStoreError - -if TYPE_CHECKING: - from haystack.nodes.retriever import BaseRetriever +from haystack.nodes.retriever import DenseRetriever logger = logging.getLogger(__name__) @@ -308,7 +305,7 @@ class FAISSDocumentStore(SQLDocumentStore): def update_embeddings( self, - retriever: "BaseRetriever", + retriever: DenseRetriever, index: Optional[str] = None, update_existing_embeddings: bool = True, filters: Optional[Dict[str, Any]] = None, # TODO: Adapt type once we allow extended filters in FAISSDocStore @@ -361,23 +358,15 @@ class FAISSDocumentStore(SQLDocumentStore): total=document_count, disable=not self.progress_bar, position=0, unit=" docs", desc="Updating Embedding" ) as progress_bar: for document_batch in batched_documents: - embeddings = retriever.embed_documents(document_batch) # type: ignore - if len(document_batch) != len(embeddings): - raise DocumentStoreError( - "The number of embeddings does not match the number of documents in the batch " - f"({len(embeddings)} != {len(document_batch)})" - ) - if embeddings[0].shape[0] != self.embedding_dim: - raise RuntimeError( - f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate FAISSDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}." 
- ) - - embeddings_to_index = np.array(embeddings, dtype="float32") + embeddings = retriever.embed_documents(document_batch) + self._validate_embeddings_shape( + embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim + ) if self.similarity == "cosine": - self.normalize_embedding(embeddings_to_index) + self.normalize_embedding(embeddings) - self.faiss_indexes[index].add(embeddings_to_index) + self.faiss_indexes[index].add(embeddings.astype(np.float32)) vector_id_map = {} for doc in document_batch: diff --git a/haystack/document_stores/memory.py b/haystack/document_stores/memory.py index f99efb294..4840dc8fe 100644 --- a/haystack/document_stores/memory.py +++ b/haystack/document_stores/memory.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Generator +from typing import Any, Dict, List, Optional, Union, Generator import time import logging @@ -10,14 +10,12 @@ import torch from tqdm import tqdm from haystack.schema import Document, Label -from haystack.errors import DuplicateDocumentError, DocumentStoreError +from haystack.errors import DuplicateDocumentError from haystack.document_stores import BaseDocumentStore from haystack.document_stores.base import get_batches_from_generator from haystack.modeling.utils import initialize_device_settings from haystack.document_stores.filter_utils import LogicalFilterClause - -if TYPE_CHECKING: - from haystack.nodes.retriever import BaseRetriever +from haystack.nodes.retriever import DenseRetriever logger = logging.getLogger(__name__) @@ -398,7 +396,7 @@ class InMemoryDocumentStore(BaseDocumentStore): def update_embeddings( self, - retriever: "BaseRetriever", + retriever: DenseRetriever, index: Optional[str] = None, filters: Optional[Dict[str, Any]] = None, # TODO: Adapt type once we allow extended filters in InMemoryDocStore update_existing_embeddings: bool = True, @@ -457,16 +455,10 @@ class InMemoryDocumentStore(BaseDocumentStore): total=len(result), disable=not self.progress_bar, position=0, unit=" docs", desc="Updating Embedding" ) as progress_bar: for document_batch in batched_documents: - embeddings = retriever.embed_documents(document_batch) # type: ignore - if len(document_batch) != len(embeddings): - raise DocumentStoreError( - "The number of embeddings does not match the number of documents in the batch " - f"({len(embeddings)} != {len(document_batch)})" - ) - if embeddings[0].shape[0] != self.embedding_dim: - raise RuntimeError( - f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate InMemoryDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}." 
- ) + embeddings = retriever.embed_documents(document_batch) + self._validate_embeddings_shape( + embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim + ) for doc, emb in zip(document_batch, embeddings): self.indexes[index][doc.id].embedding = emb diff --git a/haystack/document_stores/milvus1.py b/haystack/document_stores/milvus1.py index a04513f16..e9582abd9 100644 --- a/haystack/document_stores/milvus1.py +++ b/haystack/document_stores/milvus1.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union +from typing import Any, Dict, Generator, List, Optional, Union import logging import warnings @@ -15,10 +15,7 @@ except (ImportError, ModuleNotFoundError) as ie: from haystack.schema import Document from haystack.document_stores.base import get_batches_from_generator -from haystack.errors import DocumentStoreError - -if TYPE_CHECKING: - from haystack.nodes.retriever import BaseRetriever +from haystack.nodes.retriever import DenseRetriever logger = logging.getLogger(__name__) @@ -289,7 +286,7 @@ class Milvus1DocumentStore(SQLDocumentStore): def update_embeddings( self, - retriever: "BaseRetriever", + retriever: DenseRetriever, index: Optional[str] = None, batch_size: int = 10_000, update_existing_embeddings: bool = True, @@ -335,25 +332,15 @@ class Milvus1DocumentStore(SQLDocumentStore): for document_batch in batched_documents: self._delete_vector_ids_from_milvus(documents=document_batch, index=index) - embeddings = retriever.embed_documents(document_batch) # type: ignore - if len(document_batch) != len(embeddings): - raise DocumentStoreError( - "The number of embeddings does not match the number of documents in the batch " - f"({len(embeddings)} != {len(document_batch)})" - ) - if embeddings[0].shape[0] != self.embedding_dim: - raise RuntimeError( - f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate MilvusDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}." 
- ) + embeddings = retriever.embed_documents(document_batch) + self._validate_embeddings_shape( + embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim + ) if self.similarity == "cosine": - for embedding in embeddings: - self.normalize_embedding(embedding) + self.normalize_embedding(embeddings) - embeddings_list = [embedding.tolist() for embedding in embeddings] - assert len(document_batch) == len(embeddings_list) - - status, vector_ids = self.milvus_server.insert(collection_name=index, records=embeddings_list) + status, vector_ids = self.milvus_server.insert(collection_name=index, records=embeddings.tolist()) if status.code != Status.SUCCESS: raise RuntimeError(f"Vector embedding insertion failed: {status}") diff --git a/haystack/document_stores/milvus2.py b/haystack/document_stores/milvus2.py index 980c0b314..a687d344b 100644 --- a/haystack/document_stores/milvus2.py +++ b/haystack/document_stores/milvus2.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union +from typing import Any, Dict, Generator, List, Optional, Union import logging import warnings @@ -18,10 +18,7 @@ except (ImportError, ModuleNotFoundError) as ie: from haystack.schema import Document from haystack.document_stores.sql import SQLDocumentStore from haystack.document_stores.base import get_batches_from_generator -from haystack.errors import DocumentStoreError - -if TYPE_CHECKING: - from haystack.nodes.retriever.base import BaseRetriever +from haystack.nodes.retriever import DenseRetriever logger = logging.getLogger(__name__) @@ -325,7 +322,7 @@ class Milvus2DocumentStore(SQLDocumentStore): def update_embeddings( self, - retriever: "BaseRetriever", + retriever: DenseRetriever, index: Optional[str] = None, batch_size: int = 10_000, update_existing_embeddings: bool = True, @@ -369,23 +366,15 @@ class Milvus2DocumentStore(SQLDocumentStore): for document_batch in batched_documents: self._delete_vector_ids_from_milvus(documents=document_batch, index=index) - embeddings = retriever.embed_documents(document_batch) # type: ignore - if len(document_batch) != len(embeddings): - raise DocumentStoreError( - "The number of embeddings does not match the number of documents in the batch " - f"({len(embeddings)} != {len(document_batch)})" - ) - if embeddings[0].shape[0] != self.embedding_dim: - raise RuntimeError( - f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate MilvusDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}." 
- ) + embeddings = retriever.embed_documents(document_batch) + self._validate_embeddings_shape( + embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim + ) if self.cosine: - embeddings = [embedding / np.linalg.norm(embedding) for embedding in embeddings] - embeddings_list = [embedding.tolist() for embedding in embeddings] - assert len(document_batch) == len(embeddings_list) + self.normalize_embedding(embeddings) - mutation_result = self.collection.insert([embeddings_list]) + mutation_result = self.collection.insert([embeddings.tolist()]) vector_id_map = {} for vector_id, doc in zip(mutation_result.primary_keys, document_batch): diff --git a/haystack/document_stores/pinecone.py b/haystack/document_stores/pinecone.py index a4616b638..0883db888 100644 --- a/haystack/document_stores/pinecone.py +++ b/haystack/document_stores/pinecone.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Set, Union, List, Optional, Dict, Generator, Any +from typing import Set, Union, List, Optional, Dict, Generator, Any import logging from itertools import islice @@ -11,10 +11,8 @@ from haystack.schema import Document, Label, Answer, Span from haystack.document_stores import BaseDocumentStore from haystack.document_stores.filter_utils import LogicalFilterClause -from haystack.errors import PineconeDocumentStoreError, DuplicateDocumentError, DocumentStoreError - -if TYPE_CHECKING: - from haystack.nodes.retriever import BaseRetriever +from haystack.errors import PineconeDocumentStoreError, DuplicateDocumentError +from haystack.nodes.retriever import DenseRetriever logger = logging.getLogger(__name__) @@ -416,7 +414,7 @@ class PineconeDocumentStore(BaseDocumentStore): def update_embeddings( self, - retriever: "BaseRetriever", + retriever: DenseRetriever, index: Optional[str] = None, update_existing_embeddings: bool = True, filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None, @@ -489,21 +487,13 @@ class PineconeDocumentStore(BaseDocumentStore): ) as progress_bar: for _ in range(0, document_count, batch_size): document_batch = list(islice(documents, batch_size)) - embeddings = retriever.embed_documents(document_batch) # type: ignore - if len(document_batch) != len(embeddings): - raise DocumentStoreError( - "The number of embeddings does not match the number of documents in the batch " - f"({len(embeddings)} != {len(document_batch)})" - ) - if embeddings[0].shape[0] != self.embedding_dim: - raise RuntimeError( - f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate PineconeDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}." 
- ) + embeddings = retriever.embed_documents(document_batch) + self._validate_embeddings_shape( + embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim + ) - embeddings_to_index = np.array(embeddings, dtype="float32") if self.similarity == "cosine": - self.normalize_embedding(embeddings_to_index) - embeddings = embeddings_to_index.tolist() + self.normalize_embedding(embeddings) metadata = [] ids = [] @@ -512,7 +502,7 @@ class PineconeDocumentStore(BaseDocumentStore): ids.append(doc.id) # Update existing vectors in pinecone index self.pinecone_indexes[index].upsert( - vectors=zip(ids, embeddings, metadata), namespace=self.embedding_namespace + vectors=zip(ids, embeddings.tolist(), metadata), namespace=self.embedding_namespace ) # Delete existing vectors from document namespace if they exist there self.delete_documents(index=index, ids=ids, namespace=self.document_namespace) diff --git a/haystack/document_stores/weaviate.py b/haystack/document_stores/weaviate.py index 11dcce0a7..162f98997 100644 --- a/haystack/document_stores/weaviate.py +++ b/haystack/document_stores/weaviate.py @@ -23,6 +23,7 @@ from haystack.document_stores.base import get_batches_from_generator from haystack.document_stores.filter_utils import LogicalFilterClause from haystack.document_stores.utils import convert_date_to_rfc3339 from haystack.errors import DocumentStoreError +from haystack.nodes.retriever import DenseRetriever logger = logging.getLogger(__name__) @@ -1166,7 +1167,7 @@ class WeaviateDocumentStore(BaseDocumentStore): def update_embeddings( self, - retriever, + retriever: DenseRetriever, index: Optional[str] = None, filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None, update_existing_embeddings: bool = True, @@ -1230,21 +1231,16 @@ class WeaviateDocumentStore(BaseDocumentStore): document_batch = [ self._convert_weaviate_result_to_document(hit, return_embedding=False) for hit in result_batch ] - embeddings = retriever.embed_documents(document_batch) # type: ignore - if len(document_batch) != len(embeddings): - raise DocumentStoreError( - "The number of embeddings does not match the number of documents in the batch " - f"({len(embeddings)} != {len(document_batch)})" - ) - if embeddings[0].shape[0] != self.embedding_dim: - raise RuntimeError( - f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate WeaviateDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}." 
- ) + embeddings = retriever.embed_documents(document_batch) + self._validate_embeddings_shape( + embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim + ) + + if self.similarity == "cosine": + self.normalize_embedding(embeddings) for doc, emb in zip(document_batch, embeddings): # Using update method to only update the embeddings, other properties will be in tact - if self.similarity == "cosine": - self.normalize_embedding(emb) self.weaviate_client.data_object.update({}, class_name=index, uuid=doc.id, vector=emb) def delete_all_documents( diff --git a/haystack/json-schemas/haystack-pipeline-main.schema.json b/haystack/json-schemas/haystack-pipeline-main.schema.json index f00c87f82..28a644af3 100644 --- a/haystack/json-schemas/haystack-pipeline-main.schema.json +++ b/haystack/json-schemas/haystack-pipeline-main.schema.json @@ -5570,10 +5570,10 @@ "title": "Use Auth Token", "anyOf": [ { - "type": "boolean" + "type": "string" }, { - "type": "string" + "type": "boolean" }, { "type": "null" diff --git a/haystack/nodes/__init__.py b/haystack/nodes/__init__.py index d10b7d10a..6fd9b63c4 100644 --- a/haystack/nodes/__init__.py +++ b/haystack/nodes/__init__.py @@ -29,6 +29,7 @@ from haystack.nodes.ranker import BaseRanker, SentenceTransformersRanker from haystack.nodes.reader import BaseReader, FARMReader, TransformersReader, TableReader, RCIReader from haystack.nodes.retriever import ( BaseRetriever, + DenseRetriever, DensePassageRetriever, EmbeddingRetriever, BM25Retriever, diff --git a/haystack/nodes/answer_generator/transformers.py b/haystack/nodes/answer_generator/transformers.py index ab4a70bdf..7d8fa4285 100644 --- a/haystack/nodes/answer_generator/transformers.py +++ b/haystack/nodes/answer_generator/transformers.py @@ -188,7 +188,7 @@ class RAGenerator(BaseGenerator): self.devices[0] ) - def _prepare_passage_embeddings(self, docs: List[Document], embeddings: List[numpy.ndarray]) -> torch.Tensor: + def _prepare_passage_embeddings(self, docs: List[Document], embeddings: numpy.ndarray) -> torch.Tensor: # If document missing embedding, then need embedding for all the documents is_embedding_required = embeddings is None or any(embedding is None for embedding in embeddings) diff --git a/haystack/nodes/retriever/__init__.py b/haystack/nodes/retriever/__init__.py index f1805f6f6..c7c261a04 100644 --- a/haystack/nodes/retriever/__init__.py +++ b/haystack/nodes/retriever/__init__.py @@ -1,5 +1,6 @@ from haystack.nodes.retriever.base import BaseRetriever from haystack.nodes.retriever.dense import ( + DenseRetriever, DensePassageRetriever, EmbeddingRetriever, MultihopEmbeddingRetriever, diff --git a/haystack/nodes/retriever/_embedding_encoder.py b/haystack/nodes/retriever/_embedding_encoder.py index 46e4de1bc..a7c30f4c0 100644 --- a/haystack/nodes/retriever/_embedding_encoder.py +++ b/haystack/nodes/retriever/_embedding_encoder.py @@ -26,22 +26,22 @@ logger = logging.getLogger(__name__) class _BaseEmbeddingEncoder: @abstractmethod - def embed_queries(self, texts: List[str]) -> List[np.ndarray]: + def embed_queries(self, queries: List[str]) -> np.ndarray: """ Create embeddings for a list of queries. - :param texts: Queries to embed - :return: Embeddings, one per input queries + :param queries: List of queries to embed. 
+ :return: Embeddings, one per input query, shape: (queries, embedding_dim) """ pass @abstractmethod - def embed_documents(self, docs: List[Document]) -> List[np.ndarray]: + def embed_documents(self, docs: List[Document]) -> np.ndarray: """ Create embeddings for a list of documents. - :param docs: List of documents to embed - :return: Embeddings, one per input document + :param docs: List of documents to embed. + :return: Embeddings, one per input document, shape: (documents, embedding_dim) """ pass @@ -118,18 +118,30 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder): f"This can be set when initializing the DocumentStore" ) - def embed(self, texts: Union[List[List[str]], List[str], str]) -> List[np.ndarray]: + def embed(self, texts: Union[List[List[str]], List[str], str]) -> np.ndarray: # TODO: FARM's `sample_to_features_text` need to fix following warning - # tokenization_utils.py:460: FutureWarning: `is_pretokenized` is deprecated and will be removed in a future version, use `is_split_into_words` instead. emb = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts]) - emb = [(r["vec"]) for r in emb] + emb = np.stack([r["vec"] for r in emb]) return emb - def embed_queries(self, texts: List[str]) -> List[np.ndarray]: - return self.embed(texts) + def embed_queries(self, queries: List[str]) -> np.ndarray: + """ + Create embeddings for a list of queries. - def embed_documents(self, docs: List[Document]) -> List[np.ndarray]: - passages = [d.content for d in docs] # type: ignore + :param queries: List of queries to embed. + :return: Embeddings, one per input query, shape: (queries, embedding_dim) + """ + return self.embed(queries) + + def embed_documents(self, docs: List[Document]) -> np.ndarray: + """ + Create embeddings for a list of documents. + + :param docs: List of documents to embed. + :return: Embeddings, one per input document, shape: (documents, embedding_dim) + """ + passages = [d.content for d in docs] return self.embed(passages) def train( @@ -175,18 +187,31 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder): f"This can be set when initializing the DocumentStore" ) - def embed(self, texts: Union[List[List[str]], List[str], str]) -> List[np.ndarray]: + def embed(self, texts: Union[List[List[str]], List[str], str]) -> np.ndarray: # texts can be a list of strings or a list of [title, text] # get back list of numpy embedding vectors - emb = self.embedding_model.encode(texts, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar) - emb = [r for r in emb] + emb = self.embedding_model.encode( + texts, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True + ) return emb - def embed_queries(self, texts: List[str]) -> List[np.ndarray]: - return self.embed(texts) + def embed_queries(self, queries: List[str]) -> np.ndarray: + """ + Create embeddings for a list of queries. - def embed_documents(self, docs: List[Document]) -> List[np.ndarray]: - passages = [[d.meta["name"] if d.meta and "name" in d.meta else "", d.content] for d in docs] # type: ignore + :param queries: List of queries to embed. + :return: Embeddings, one per input query, shape: (queries, embedding_dim) + """ + return self.embed(queries) + + def embed_documents(self, docs: List[Document]) -> np.ndarray: + """ + Create embeddings for a list of documents. + + :param docs: List of documents to embed. 
+ :return: Embeddings, one per input document, shape: (documents, embedding_dim) + """ + passages = [[d.meta["name"] if d.meta and "name" in d.meta else "", d.content] for d in docs] return self.embed(passages) def train( @@ -250,10 +275,15 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder): retriever.embedding_model, use_auth_token=retriever.use_auth_token ).to(str(retriever.devices[0])) - def embed_queries(self, texts: List[str]) -> List[np.ndarray]: + def embed_queries(self, queries: List[str]) -> np.ndarray: + """ + Create embeddings for a list of queries. - queries = [{"text": q} for q in texts] - dataloader = self._create_dataloader(queries) + :param queries: List of queries to embed. + :return: Embeddings, one per input query, shape: (queries, embedding_dim) + """ + query_text = [{"text": q} for q in queries] + dataloader = self._create_dataloader(query_text) embeddings: List[np.ndarray] = [] disable_tqdm = True if len(dataloader) == 1 else not self.progress_bar @@ -272,8 +302,13 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder): return np.concatenate(embeddings) - def embed_documents(self, docs: List[Document]) -> List[np.ndarray]: + def embed_documents(self, docs: List[Document]) -> np.ndarray: + """ + Create embeddings for a list of documents. + :param docs: List of documents to embed. + :return: Embeddings, one per input document, shape: (documents, embedding_dim) + """ doc_text = [{"text": d.content} for d in docs] dataloader = self._create_dataloader(doc_text) diff --git a/haystack/nodes/retriever/base.py b/haystack/nodes/retriever/base.py index f5644d5e7..149cb5024 100644 --- a/haystack/nodes/retriever/base.py +++ b/haystack/nodes/retriever/base.py @@ -4,7 +4,6 @@ import logging from abc import abstractmethod from time import perf_counter from functools import wraps -from copy import deepcopy from tqdm import tqdm @@ -263,7 +262,7 @@ class BaseRetriever(BaseComponent): query: Optional[str] = None, filters: Optional[dict] = None, top_k: Optional[int] = None, - documents: Optional[List[dict]] = None, + documents: Optional[List[Document]] = None, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None, scale_score: bool = None, @@ -279,7 +278,7 @@ class BaseRetriever(BaseComponent): query=query, filters=filters, top_k=top_k, index=index, headers=headers, scale_score=scale_score ) elif root_node == "File": - self.index_count += len(documents) # type: ignore + self.index_count += len(documents) if documents else 0 run_indexing = self.timing(self.run_indexing, "index_time") output, stream = run_indexing(documents=documents) else: @@ -362,13 +361,7 @@ class BaseRetriever(BaseComponent): return output, "output_1" - def run_indexing(self, documents: List[Union[dict, Document]]): - if self.__class__.__name__ in ["DensePassageRetriever", "EmbeddingRetriever"]: - documents = deepcopy(documents) - document_objects = [Document.from_dict(doc) if isinstance(doc, dict) else doc for doc in documents] - embeddings = self.embed_documents(document_objects) # type: ignore - for doc, emb in zip(document_objects, embeddings): - doc.embedding = emb + def run_indexing(self, documents: List[Document]): output = {"documents": documents} return output, "output_1" diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py index 4491da525..132337e6c 100644 --- a/haystack/nodes/retriever/dense.py +++ b/haystack/nodes/retriever/dense.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from typing import List, Dict, Union, Optional, Any import logging @@ 
-42,7 +43,40 @@ from haystack.modeling.utils import initialize_device_settings logger = logging.getLogger(__name__) -class DensePassageRetriever(BaseRetriever): +class DenseRetriever(BaseRetriever): + """ + Base class for all dense retrievers. + """ + + @abstractmethod + def embed_queries(self, queries: List[str]) -> np.ndarray: + """ + Create embeddings for a list of queries. + + :param queries: List of queries to embed. + :return: Embeddings, one per input query, shape: (queries, embedding_dim) + """ + pass + + @abstractmethod + def embed_documents(self, documents: List[Document]) -> np.ndarray: + """ + Create embeddings for a list of documents. + + :param documents: List of documents to embed. + :return: Embeddings of documents, one per input document, shape: (documents, embedding_dim) + """ + pass + + def run_indexing(self, documents: List[Document]): + embeddings = self.embed_documents(documents) + for doc, emb in zip(documents, embeddings): + doc.embedding = emb + output = {"documents": documents} + return output, "output_1" + + +class DensePassageRetriever(DenseRetriever): """ Retriever that uses a bi-encoder (one transformer for query, one transformer for passage). See the original paper for more details: @@ -302,7 +336,7 @@ class DensePassageRetriever(BaseRetriever): index = self.document_store.index if scale_score is None: scale_score = self.scale_score - query_emb = self.embed_queries(texts=[query]) + query_emb = self.embed_queries(queries=[query]) documents = self.document_store.query_by_embedding( query_emb=query_emb[0], top_k=top_k, filters=filters, index=index, headers=headers, scale_score=scale_score ) @@ -430,9 +464,9 @@ class DensePassageRetriever(BaseRetriever): return [[] * len(queries)] # type: ignore documents = [] - query_embs = [] + query_embs: List[np.ndarray] = [] for batch in self._get_batches(queries=queries, batch_size=batch_size): - query_embs.extend(self.embed_queries(texts=batch)) + query_embs.extend(self.embed_queries(queries=batch)) for query_emb, cur_filters in tqdm( zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying" ): @@ -448,7 +482,7 @@ class DensePassageRetriever(BaseRetriever): return documents - def _get_predictions(self, dicts): + def _get_predictions(self, dicts: List[Dict[str, Any]]) -> Dict[str, np.ndarray]: """ Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting). 
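The `DenseRetriever` contract introduced above returns a single 2D array per call instead of a list of 1D vectors. A minimal sketch of that shape contract with placeholder data, mirroring the two checks that the new `BaseDocumentStore._validate_embeddings_shape` helper performs before indexing (the array values and dimensions below are illustrative only, not taken from the patch):

```python
import numpy as np

# Placeholder batch: embed_documents now yields one array of shape (num_documents, embedding_dim)
num_documents, embedding_dim = 4, 768
embeddings = np.random.rand(num_documents, embedding_dim).astype(np.float32)

# The same two checks _validate_embeddings_shape enforces before writing to the index
num_embeddings, embedding_size = embeddings.shape
assert num_embeddings == num_documents, "number of embeddings must match number of documents"
assert embedding_size == embedding_dim, "model output size must match the store's embedding_dim"
```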
@@ -472,7 +506,8 @@ class DensePassageRetriever(BaseRetriever): data_loader = NamedDataLoader( dataset=dataset, sampler=SequentialSampler(dataset), batch_size=self.batch_size, tensor_names=tensor_names ) - all_embeddings = {"query": [], "passages": []} + query_embeddings_batched = [] + passage_embeddings_batched = [] self.model.eval() # When running evaluations etc., we don't want a progress bar for every single query @@ -503,34 +538,35 @@ class DensePassageRetriever(BaseRetriever): passage_attention_mask=batch.get("passage_attention_mask", None), )[0] if query_embeddings is not None: - all_embeddings["query"].append(query_embeddings.cpu().numpy()) + query_embeddings_batched.append(query_embeddings.cpu().numpy()) if passage_embeddings is not None: - all_embeddings["passages"].append(passage_embeddings.cpu().numpy()) + passage_embeddings_batched.append(passage_embeddings.cpu().numpy()) progress_bar.update(self.batch_size) - if all_embeddings["passages"]: - all_embeddings["passages"] = np.concatenate(all_embeddings["passages"]) - if all_embeddings["query"]: - all_embeddings["query"] = np.concatenate(all_embeddings["query"]) + all_embeddings: Dict[str, np.ndarray] = {} + if passage_embeddings_batched: + all_embeddings["passages"] = np.concatenate(passage_embeddings_batched) + if query_embeddings_batched: + all_embeddings["query"] = np.concatenate(query_embeddings_batched) return all_embeddings - def embed_queries(self, texts: List[str]) -> List[np.ndarray]: + def embed_queries(self, queries: List[str]) -> np.ndarray: """ - Create embeddings for a list of queries using the query encoder + Create embeddings for a list of queries using the query encoder. - :param texts: Queries to embed - :return: Embeddings, one per input queries + :param queries: List of queries to embed. + :return: Embeddings, one per input query, shape: (queries, embedding_dim) """ - queries = [{"query": q} for q in texts] - result = self._get_predictions(queries)["query"] + query_dicts = [{"query": q} for q in queries] + result = self._get_predictions(query_dicts)["query"] return result - def embed_documents(self, docs: List[Document]) -> List[np.ndarray]: + def embed_documents(self, documents: List[Document]) -> np.ndarray: """ - Create embeddings for a list of documents using the passage encoder + Create embeddings for a list of documents using the passage encoder. - :param docs: List of Document objects used to represent documents / passages in a standardized way within Haystack. - :return: Embeddings of documents / passages shape (batch_size, embedding_dim) + :param documents: List of documents to embed. + :return: Embeddings of documents, one per input document, shape: (documents, embedding_dim) """ if self.processor.num_hard_negatives != 0: logger.warning( @@ -550,7 +586,7 @@ class DensePassageRetriever(BaseRetriever): } ] } - for d in docs + for d in documents ] embeddings = self._get_predictions(passages)["passages"] return embeddings @@ -757,7 +793,7 @@ class DensePassageRetriever(BaseRetriever): return dpr -class TableTextRetriever(BaseRetriever): +class TableTextRetriever(DenseRetriever): """ Retriever that uses a tri-encoder to jointly retrieve among a database consisting of text passages and tables (one transformer for query, one transformer for text passages, one transformer for tables). 
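As a small illustration of the batching pattern used in `_get_predictions` above: per-batch model outputs are collected as NumPy arrays and concatenated once at the end into the matrix that `embed_queries` returns (the batch sizes and dimensions below are made up):

```python
import numpy as np

# Stand-ins for query_embeddings.cpu().numpy() collected per batch
query_embeddings_batched = [
    np.random.rand(16, 768).astype(np.float32),
    np.random.rand(16, 768).astype(np.float32),
    np.random.rand(8, 768).astype(np.float32),
]

# One concatenation yields the (num_queries, embedding_dim) result of embed_queries
all_query_embeddings = np.concatenate(query_embeddings_batched)
print(all_query_embeddings.shape)  # (40, 768)
```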
@@ -950,7 +986,7 @@ class TableTextRetriever(BaseRetriever): index = self.document_store.index if scale_score is None: scale_score = self.scale_score - query_emb = self.embed_queries(texts=[query]) + query_emb = self.embed_queries(queries=[query]) documents = self.document_store.query_by_embedding( query_emb=query_emb[0], top_k=top_k, filters=filters, index=index, headers=headers, scale_score=scale_score ) @@ -1078,9 +1114,9 @@ class TableTextRetriever(BaseRetriever): return [[] * len(queries)] # type: ignore documents = [] - query_embs = [] + query_embs: List[np.ndarray] = [] for batch in self._get_batches(queries=queries, batch_size=batch_size): - query_embs.extend(self.embed_queries(texts=batch)) + query_embs.extend(self.embed_queries(queries=batch)) for query_emb, cur_filters in tqdm( zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying" ): @@ -1096,7 +1132,7 @@ class TableTextRetriever(BaseRetriever): return documents - def _get_predictions(self, dicts: List[Dict]) -> Dict[str, List[np.ndarray]]: + def _get_predictions(self, dicts: List[Dict[str, Any]]) -> Dict[str, np.ndarray]: """ Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting). @@ -1121,7 +1157,8 @@ class TableTextRetriever(BaseRetriever): data_loader = NamedDataLoader( dataset=dataset, sampler=SequentialSampler(dataset), batch_size=self.batch_size, tensor_names=tensor_names ) - all_embeddings: Dict = {"query": [], "passages": []} + query_embeddings_batched = [] + passage_embeddings_batched = [] self.model.eval() # When running evaluations etc., we don't want a progress bar for every single query @@ -1145,36 +1182,36 @@ class TableTextRetriever(BaseRetriever): with torch.no_grad(): query_embeddings, passage_embeddings = self.model.forward(**batch)[0] if query_embeddings is not None: - all_embeddings["query"].append(query_embeddings.cpu().numpy()) + query_embeddings_batched.append(query_embeddings.cpu().numpy()) if passage_embeddings is not None: - all_embeddings["passages"].append(passage_embeddings.cpu().numpy()) + passage_embeddings_batched.append(passage_embeddings.cpu().numpy()) progress_bar.update(self.batch_size) - if all_embeddings["passages"]: - all_embeddings["passages"] = np.concatenate(all_embeddings["passages"]) - if all_embeddings["query"]: - all_embeddings["query"] = np.concatenate(all_embeddings["query"]) + all_embeddings: Dict[str, np.ndarray] = {} + if passage_embeddings_batched: + all_embeddings["passages"] = np.concatenate(passage_embeddings_batched) + if query_embeddings_batched: + all_embeddings["query"] = np.concatenate(query_embeddings_batched) return all_embeddings - def embed_queries(self, texts: List[str]) -> List[np.ndarray]: + def embed_queries(self, queries: List[str]) -> np.ndarray: """ - Create embeddings for a list of queries using the query encoder + Create embeddings for a list of queries using the query encoder. - :param texts: Queries to embed - :return: Embeddings, one per input queries + :param queries: List of queries to embed. 
+ :return: Embeddings, one per input query, shape: (queries, embedding_dim) """ - queries = [{"query": q} for q in texts] - result = self._get_predictions(queries)["query"] + query_dicts = [{"query": q} for q in queries] + result = self._get_predictions(query_dicts)["query"] return result - def embed_documents(self, docs: List[Document]) -> List[np.ndarray]: + def embed_documents(self, documents: List[Document]) -> np.ndarray: """ Create embeddings for a list of text documents and / or tables using the text passage encoder and the table encoder. - :param docs: List of Document objects used to represent documents / passages in - a standardized way within Haystack. - :return: Embeddings of documents / passages. Shape: (batch_size, embedding_dim) + :param documents: List of documents to embed. + :return: Embeddings of documents, one per input document, shape: (documents, embedding_dim) """ if self.processor.num_hard_negatives != 0: @@ -1185,7 +1222,7 @@ class TableTextRetriever(BaseRetriever): self.processor.num_hard_negatives = 0 model_input = [] - for doc in docs: + for doc in documents: if doc.content_type == "table": model_input.append( { @@ -1440,7 +1477,7 @@ class TableTextRetriever(BaseRetriever): return mm_retriever -class EmbeddingRetriever(BaseRetriever): +class EmbeddingRetriever(DenseRetriever): def __init__( self, document_store: BaseDocumentStore, @@ -1639,7 +1676,7 @@ class EmbeddingRetriever(BaseRetriever): index = self.document_store.index if scale_score is None: scale_score = self.scale_score - query_emb = self.embed_queries(texts=[query]) + query_emb = self.embed_queries(queries=[query]) documents = self.document_store.query_by_embedding( query_emb=query_emb[0], filters=filters, top_k=top_k, index=index, headers=headers, scale_score=scale_score ) @@ -1767,9 +1804,9 @@ class EmbeddingRetriever(BaseRetriever): return [[] * len(queries)] # type: ignore documents = [] - query_embs = [] + query_embs: List[np.ndarray] = [] for batch in self._get_batches(queries=queries, batch_size=batch_size): - query_embs.extend(self.embed_queries(texts=batch)) + query_embs.extend(self.embed_queries(queries=batch)) for query_emb, cur_filters in tqdm( zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying" ): @@ -1785,28 +1822,28 @@ class EmbeddingRetriever(BaseRetriever): return documents - def embed_queries(self, texts: List[str]) -> List[np.ndarray]: + def embed_queries(self, queries: List[str]) -> np.ndarray: """ Create embeddings for a list of queries. - :param texts: Queries to embed - :return: Embeddings, one per input queries + :param queries: List of queries to embed. + :return: Embeddings, one per input query, shape: (queries, embedding_dim) """ # for backward compatibility: cast pure str input - if isinstance(texts, str): - texts = [texts] - assert isinstance(texts, list), "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])" - return self.embedding_encoder.embed_queries(texts) + if isinstance(queries, str): + queries = [queries] + assert isinstance(queries, list), "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])" + return self.embedding_encoder.embed_queries(queries) - def embed_documents(self, docs: List[Document]) -> List[np.ndarray]: + def embed_documents(self, documents: List[Document]) -> np.ndarray: """ Create embeddings for a list of documents. - :param docs: List of documents to embed - :return: Embeddings, one per input document + :param documents: List of documents to embed. 
+ :return: Embeddings, one per input document, shape: (docs, embedding_dim) """ - docs = self._preprocess_documents(docs) - return self.embedding_encoder.embed_documents(docs) + documents = self._preprocess_documents(documents) + return self.embedding_encoder.embed_documents(documents) def _preprocess_documents(self, docs: List[Document]) -> List[Document]: """ diff --git a/test/conftest.py b/test/conftest.py index 4180cfccb..e28adc39b 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -57,6 +57,7 @@ from haystack.nodes import ( BaseGenerator, BaseSummarizer, BaseTranslator, + DenseRetriever, ) from haystack.nodes.answer_generator.transformers import Seq2SeqGenerator from haystack.nodes.answer_generator.transformers import RAGenerator @@ -308,16 +309,16 @@ class MockTranslator(BaseTranslator): pass -class MockDenseRetriever(MockRetriever): +class MockDenseRetriever(MockRetriever, DenseRetriever): def __init__(self, document_store: BaseDocumentStore, embedding_dim: int = 768): self.embedding_dim = embedding_dim self.document_store = document_store - def embed_queries(self, texts): - return [np.random.rand(self.embedding_dim)] * len(texts) + def embed_queries(self, queries): + return np.random.rand(len(queries), self.embedding_dim) - def embed_documents(self, docs): - return [np.random.rand(self.embedding_dim)] * len(docs) + def embed_documents(self, documents): + return np.random.rand(len(documents), self.embedding_dim) class MockQuestionGenerator(QuestionGenerator): diff --git a/test/document_stores/test_document_store.py b/test/document_stores/test_document_store.py index 5db4f6163..71d7134d6 100644 --- a/test/document_stores/test_document_store.py +++ b/test/document_stores/test_document_store.py @@ -576,7 +576,7 @@ def test_update_embeddings(document_store, retriever): "content": "text_7", "id": "7", "meta_field": "value_7", - "embedding": retriever.embed_queries(texts=["a random string"])[0], + "embedding": retriever.embed_queries(queries=["a random string"])[0], } document_store.write_documents([doc])
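Taken together, a hedged end-to-end sketch of how the updated API reads after this change. The model name, embedding dimension, and store settings are illustrative assumptions, not prescribed by the patch; only the keyword names (`queries`, `documents`) and the `np.ndarray` return type come from the diff above.

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever
from haystack.schema import Document

# Assumed example configuration: a 384-dim sentence-transformers model with cosine similarity
document_store = InMemoryDocumentStore(embedding_dim=384, similarity="cosine")
document_store.write_documents([Document(content="Berlin is the capital of Germany.")])

retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # assumed example model
    model_format="sentence_transformers",
)

# embed_queries now takes `queries` (not `texts`) and returns a (num_queries, embedding_dim) ndarray
query_embeddings = retriever.embed_queries(queries=["What is the capital of Germany?"])
print(query_embeddings.shape)  # e.g. (1, 384)

# update_embeddings is now typed against DenseRetriever and validates the embedding shape internally
document_store.update_embeddings(retriever=retriever)

results = retriever.retrieve(query="What is the capital of Germany?", top_k=1)
print(results[0].content)
```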