diff --git a/docs/_src/api/api/document_store.md b/docs/_src/api/api/document_store.md
index aec2bdc88..fa78eb7f8 100644
--- a/docs/_src/api/api/document_store.md
+++ b/docs/_src/api/api/document_store.md
@@ -1269,7 +1269,7 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### BaseElasticsearchDocumentStore.update\_embeddings
```python
-def update_embeddings(retriever,
+def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
filters: Optional[Dict[str, Union[Dict, List, str, int,
float, bool]]] = None,
@@ -2097,7 +2097,7 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### InMemoryDocumentStore.update\_embeddings
```python
-def update_embeddings(retriever: "BaseRetriever",
+def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None,
update_existing_embeddings: bool = True,
@@ -2913,7 +2913,7 @@ None
#### FAISSDocumentStore.update\_embeddings
```python
-def update_embeddings(retriever: "BaseRetriever",
+def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
update_existing_embeddings: bool = True,
filters: Optional[Dict[str, Any]] = None,
@@ -3277,7 +3277,7 @@ None
#### Milvus1DocumentStore.update\_embeddings
```python
-def update_embeddings(retriever: "BaseRetriever",
+def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
batch_size: int = 10_000,
update_existing_embeddings: bool = True,
@@ -3681,7 +3681,7 @@ exists.
#### Milvus2DocumentStore.update\_embeddings
```python
-def update_embeddings(retriever: "BaseRetriever",
+def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
batch_size: int = 10_000,
update_existing_embeddings: bool = True,
@@ -4398,7 +4398,7 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### WeaviateDocumentStore.update\_embeddings
```python
-def update_embeddings(retriever,
+def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
filters: Optional[Dict[str, Union[Dict, List, str, int,
float, bool]]] = None,
@@ -5452,7 +5452,7 @@ Parameter options:
#### PineconeDocumentStore.update\_embeddings
```python
-def update_embeddings(retriever: "BaseRetriever",
+def update_embeddings(retriever: DenseRetriever,
index: Optional[str] = None,
update_existing_embeddings: bool = True,
filters: Optional[Dict[str, Union[Dict, List, str, int,
diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md
index 7601e7647..65dd156cc 100644
--- a/docs/_src/api/api/retriever.md
+++ b/docs/_src/api/api/retriever.md
@@ -547,12 +547,60 @@ Performing training on this class according to the TF-IDF algorithm.
# Module dense
+
+
+## DenseRetriever
+
+```python
+class DenseRetriever(BaseRetriever)
+```
+
+Base class for all dense retrievers.
+
+
+
+#### DenseRetriever.embed\_queries
+
+```python
+@abstractmethod
+def embed_queries(queries: List[str]) -> np.ndarray
+```
+
+Create embeddings for a list of queries.
+
+**Arguments**:
+
+- `queries`: List of queries to embed.
+
+**Returns**:
+
+Embeddings, one per input query, shape: (queries, embedding_dim)
+
+
+
+#### DenseRetriever.embed\_documents
+
+```python
+@abstractmethod
+def embed_documents(documents: List[Document]) -> np.ndarray
+```
+
+Create embeddings for a list of documents.
+
+**Arguments**:
+
+- `documents`: List of documents to embed.
+
+**Returns**:
+
+Embeddings of documents, one per input document, shape: (documents, embedding_dim)
+
## DensePassageRetriever
```python
-class DensePassageRetriever(BaseRetriever)
+class DensePassageRetriever(DenseRetriever)
```
Retriever that uses a bi-encoder (one transformer for query, one transformer for passage).
@@ -842,36 +890,36 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### DensePassageRetriever.embed\_queries
```python
-def embed_queries(texts: List[str]) -> List[np.ndarray]
+def embed_queries(queries: List[str]) -> np.ndarray
```
-Create embeddings for a list of queries using the query encoder
+Create embeddings for a list of queries using the query encoder.
**Arguments**:
-- `texts`: Queries to embed
+- `queries`: List of queries to embed.
**Returns**:
-Embeddings, one per input queries
+Embeddings, one per input query, shape: (queries, embedding_dim)
#### DensePassageRetriever.embed\_documents
```python
-def embed_documents(docs: List[Document]) -> List[np.ndarray]
+def embed_documents(documents: List[Document]) -> np.ndarray
```
-Create embeddings for a list of documents using the passage encoder
+Create embeddings for a list of documents using the passage encoder.
**Arguments**:
-- `docs`: List of Document objects used to represent documents / passages in a standardized way within Haystack.
+- `documents`: List of documents to embed.
**Returns**:
-Embeddings of documents / passages shape (batch_size, embedding_dim)
+Embeddings of documents, one per input document, shape: (documents, embedding_dim)
@@ -1005,7 +1053,7 @@ Load DensePassageRetriever from the specified directory.
## TableTextRetriever
```python
-class TableTextRetriever(BaseRetriever)
+class TableTextRetriever(DenseRetriever)
```
Retriever that uses a tri-encoder to jointly retrieve among a database consisting of text passages and tables
@@ -1198,25 +1246,25 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### TableTextRetriever.embed\_queries
```python
-def embed_queries(texts: List[str]) -> List[np.ndarray]
+def embed_queries(queries: List[str]) -> np.ndarray
```
-Create embeddings for a list of queries using the query encoder
+Create embeddings for a list of queries using the query encoder.
**Arguments**:
-- `texts`: Queries to embed
+- `queries`: List of queries to embed.
**Returns**:
-Embeddings, one per input queries
+Embeddings, one per input query, shape: (queries, embedding_dim)
#### TableTextRetriever.embed\_documents
```python
-def embed_documents(docs: List[Document]) -> List[np.ndarray]
+def embed_documents(documents: List[Document]) -> np.ndarray
```
Create embeddings for a list of text documents and / or tables using the text passage encoder and
@@ -1225,12 +1273,11 @@ the table encoder.
**Arguments**:
-- `docs`: List of Document objects used to represent documents / passages in
-a standardized way within Haystack.
+- `documents`: List of documents to embed.
**Returns**:
-Embeddings of documents / passages. Shape: (batch_size, embedding_dim)
+Embeddings of documents, one per input document, shape: (documents, embedding_dim)
@@ -1370,7 +1417,7 @@ Load TableTextRetriever from the specified directory.
## EmbeddingRetriever
```python
-class EmbeddingRetriever(BaseRetriever)
+class EmbeddingRetriever(DenseRetriever)
```
@@ -1638,36 +1685,36 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
#### EmbeddingRetriever.embed\_queries
```python
-def embed_queries(texts: List[str]) -> List[np.ndarray]
+def embed_queries(queries: List[str]) -> np.ndarray
```
Create embeddings for a list of queries.
**Arguments**:
-- `texts`: Queries to embed
+- `queries`: List of queries to embed.
**Returns**:
-Embeddings, one per input queries
+Embeddings, one per input query, shape: (queries, embedding_dim)
#### EmbeddingRetriever.embed\_documents
```python
-def embed_documents(docs: List[Document]) -> List[np.ndarray]
+def embed_documents(documents: List[Document]) -> np.ndarray
```
Create embeddings for a list of documents.
**Arguments**:
-- `docs`: List of documents to embed
+- `documents`: List of documents to embed.
**Returns**:
-Embeddings, one per input document
+Embeddings, one per input document, shape: (docs, embedding_dim)
diff --git a/haystack/document_stores/base.py b/haystack/document_stores/base.py
index c17553a27..5325a05ea 100644
--- a/haystack/document_stores/base.py
+++ b/haystack/document_stores/base.py
@@ -12,7 +12,7 @@ import numpy as np
from haystack.schema import Document, Label, MultiLabel
from haystack.nodes.base import BaseComponent
-from haystack.errors import DuplicateDocumentError
+from haystack.errors import DuplicateDocumentError, DocumentStoreError
from haystack.nodes.preprocessor import PreProcessor
from haystack.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl
@@ -698,6 +698,27 @@ class BaseDocumentStore(BaseComponent):
return [label for label in labels if label.id in duplicate_ids]
+ @classmethod
+ def _validate_embeddings_shape(cls, embeddings: np.ndarray, num_documents: int, embedding_dim: int):
+ """
+ Validates the shape of model-generated embeddings against expected values for indexing.
+
+ :param embeddings: Embeddings to validate
+ :param num_documents: Number of documents the embeddings were generated for
+ :param embedding_dim: Number of embedding dimensions to expect
+ """
+ num_embeddings, embedding_size = embeddings.shape
+ if num_embeddings != num_documents:
+ raise DocumentStoreError(
+ "The number of embeddings does not match the number of documents: "
+ f"({num_embeddings} != {num_documents})"
+ )
+ if embedding_size != embedding_dim:
+ raise RuntimeError(
+ f"Embedding dimensions of the model ({embedding_size}) don't match the embedding dimensions of the document store ({embedding_dim}). "
+ f"Initiate {cls.__name__} again with arg embedding_dim={embedding_size}."
+ )
+
class KeywordDocumentStore(BaseDocumentStore):
"""
diff --git a/haystack/document_stores/elasticsearch.py b/haystack/document_stores/elasticsearch.py
index ba224aaa0..4df4db2f0 100644
--- a/haystack/document_stores/elasticsearch.py
+++ b/haystack/document_stores/elasticsearch.py
@@ -27,6 +27,7 @@ from haystack.schema import Document, Label
from haystack.document_stores.base import get_batches_from_generator
from haystack.document_stores.filter_utils import LogicalFilterClause
from haystack.errors import DocumentStoreError, HaystackError
+from haystack.nodes.retriever import DenseRetriever
logger = logging.getLogger(__name__)
@@ -1400,7 +1401,7 @@ class BaseElasticsearchDocumentStore(KeywordDocumentStore):
def update_embeddings(
self,
- retriever,
+ retriever: DenseRetriever,
index: Optional[str] = None,
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
update_existing_embeddings: bool = True,
@@ -1482,16 +1483,10 @@ class BaseElasticsearchDocumentStore(KeywordDocumentStore):
with tqdm(total=document_count, position=0, unit=" Docs", desc="Updating embeddings") as progress_bar:
for result_batch in get_batches_from_generator(result, batch_size):
document_batch = [self._convert_es_hit_to_document(hit, return_embedding=False) for hit in result_batch]
- embeddings = retriever.embed_documents(document_batch) # type: ignore
- if len(document_batch) != len(embeddings):
- raise DocumentStoreError(
- "The number of embeddings does not match the number of documents in the batch "
- f"({len(embeddings)} != {len(document_batch)})"
- )
- if embeddings[0].shape[0] != self.embedding_dim:
- raise RuntimeError(
- f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate ElasticsearchDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}."
- )
+ embeddings = retriever.embed_documents(document_batch)
+ self._validate_embeddings_shape(
+ embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
+ )
doc_updates = []
for doc, emb in zip(document_batch, embeddings):
diff --git a/haystack/document_stores/faiss.py b/haystack/document_stores/faiss.py
index 729c8e00a..33af7f0d2 100644
--- a/haystack/document_stores/faiss.py
+++ b/haystack/document_stores/faiss.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Union, List, Optional, Dict, Generator
+from typing import Any, Union, List, Optional, Dict, Generator
import json
import logging
@@ -22,10 +22,7 @@ except (ImportError, ModuleNotFoundError) as ie:
from haystack.schema import Document
from haystack.document_stores.base import get_batches_from_generator
-from haystack.errors import DocumentStoreError
-
-if TYPE_CHECKING:
- from haystack.nodes.retriever import BaseRetriever
+from haystack.nodes.retriever import DenseRetriever
logger = logging.getLogger(__name__)
@@ -308,7 +305,7 @@ class FAISSDocumentStore(SQLDocumentStore):
def update_embeddings(
self,
- retriever: "BaseRetriever",
+ retriever: DenseRetriever,
index: Optional[str] = None,
update_existing_embeddings: bool = True,
filters: Optional[Dict[str, Any]] = None, # TODO: Adapt type once we allow extended filters in FAISSDocStore
@@ -361,23 +358,15 @@ class FAISSDocumentStore(SQLDocumentStore):
total=document_count, disable=not self.progress_bar, position=0, unit=" docs", desc="Updating Embedding"
) as progress_bar:
for document_batch in batched_documents:
- embeddings = retriever.embed_documents(document_batch) # type: ignore
- if len(document_batch) != len(embeddings):
- raise DocumentStoreError(
- "The number of embeddings does not match the number of documents in the batch "
- f"({len(embeddings)} != {len(document_batch)})"
- )
- if embeddings[0].shape[0] != self.embedding_dim:
- raise RuntimeError(
- f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate FAISSDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}."
- )
-
- embeddings_to_index = np.array(embeddings, dtype="float32")
+ embeddings = retriever.embed_documents(document_batch)
+ self._validate_embeddings_shape(
+ embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
+ )
if self.similarity == "cosine":
- self.normalize_embedding(embeddings_to_index)
+ self.normalize_embedding(embeddings)
- self.faiss_indexes[index].add(embeddings_to_index)
+ self.faiss_indexes[index].add(embeddings.astype(np.float32))
vector_id_map = {}
for doc in document_batch:
diff --git a/haystack/document_stores/memory.py b/haystack/document_stores/memory.py
index f99efb294..4840dc8fe 100644
--- a/haystack/document_stores/memory.py
+++ b/haystack/document_stores/memory.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Generator
+from typing import Any, Dict, List, Optional, Union, Generator
import time
import logging
@@ -10,14 +10,12 @@ import torch
from tqdm import tqdm
from haystack.schema import Document, Label
-from haystack.errors import DuplicateDocumentError, DocumentStoreError
+from haystack.errors import DuplicateDocumentError
from haystack.document_stores import BaseDocumentStore
from haystack.document_stores.base import get_batches_from_generator
from haystack.modeling.utils import initialize_device_settings
from haystack.document_stores.filter_utils import LogicalFilterClause
-
-if TYPE_CHECKING:
- from haystack.nodes.retriever import BaseRetriever
+from haystack.nodes.retriever import DenseRetriever
logger = logging.getLogger(__name__)
@@ -398,7 +396,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
def update_embeddings(
self,
- retriever: "BaseRetriever",
+ retriever: DenseRetriever,
index: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None, # TODO: Adapt type once we allow extended filters in InMemoryDocStore
update_existing_embeddings: bool = True,
@@ -457,16 +455,10 @@ class InMemoryDocumentStore(BaseDocumentStore):
total=len(result), disable=not self.progress_bar, position=0, unit=" docs", desc="Updating Embedding"
) as progress_bar:
for document_batch in batched_documents:
- embeddings = retriever.embed_documents(document_batch) # type: ignore
- if len(document_batch) != len(embeddings):
- raise DocumentStoreError(
- "The number of embeddings does not match the number of documents in the batch "
- f"({len(embeddings)} != {len(document_batch)})"
- )
- if embeddings[0].shape[0] != self.embedding_dim:
- raise RuntimeError(
- f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate InMemoryDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}."
- )
+ embeddings = retriever.embed_documents(document_batch)
+ self._validate_embeddings_shape(
+ embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
+ )
for doc, emb in zip(document_batch, embeddings):
self.indexes[index][doc.id].embedding = emb
diff --git a/haystack/document_stores/milvus1.py b/haystack/document_stores/milvus1.py
index a04513f16..e9582abd9 100644
--- a/haystack/document_stores/milvus1.py
+++ b/haystack/document_stores/milvus1.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
+from typing import Any, Dict, Generator, List, Optional, Union
import logging
import warnings
@@ -15,10 +15,7 @@ except (ImportError, ModuleNotFoundError) as ie:
from haystack.schema import Document
from haystack.document_stores.base import get_batches_from_generator
-from haystack.errors import DocumentStoreError
-
-if TYPE_CHECKING:
- from haystack.nodes.retriever import BaseRetriever
+from haystack.nodes.retriever import DenseRetriever
logger = logging.getLogger(__name__)
@@ -289,7 +286,7 @@ class Milvus1DocumentStore(SQLDocumentStore):
def update_embeddings(
self,
- retriever: "BaseRetriever",
+ retriever: DenseRetriever,
index: Optional[str] = None,
batch_size: int = 10_000,
update_existing_embeddings: bool = True,
@@ -335,25 +332,15 @@ class Milvus1DocumentStore(SQLDocumentStore):
for document_batch in batched_documents:
self._delete_vector_ids_from_milvus(documents=document_batch, index=index)
- embeddings = retriever.embed_documents(document_batch) # type: ignore
- if len(document_batch) != len(embeddings):
- raise DocumentStoreError(
- "The number of embeddings does not match the number of documents in the batch "
- f"({len(embeddings)} != {len(document_batch)})"
- )
- if embeddings[0].shape[0] != self.embedding_dim:
- raise RuntimeError(
- f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate MilvusDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}."
- )
+ embeddings = retriever.embed_documents(document_batch)
+ self._validate_embeddings_shape(
+ embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
+ )
if self.similarity == "cosine":
- for embedding in embeddings:
- self.normalize_embedding(embedding)
+ self.normalize_embedding(embeddings)
- embeddings_list = [embedding.tolist() for embedding in embeddings]
- assert len(document_batch) == len(embeddings_list)
-
- status, vector_ids = self.milvus_server.insert(collection_name=index, records=embeddings_list)
+ status, vector_ids = self.milvus_server.insert(collection_name=index, records=embeddings.tolist())
if status.code != Status.SUCCESS:
raise RuntimeError(f"Vector embedding insertion failed: {status}")
diff --git a/haystack/document_stores/milvus2.py b/haystack/document_stores/milvus2.py
index 980c0b314..a687d344b 100644
--- a/haystack/document_stores/milvus2.py
+++ b/haystack/document_stores/milvus2.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
+from typing import Any, Dict, Generator, List, Optional, Union
import logging
import warnings
@@ -18,10 +18,7 @@ except (ImportError, ModuleNotFoundError) as ie:
from haystack.schema import Document
from haystack.document_stores.sql import SQLDocumentStore
from haystack.document_stores.base import get_batches_from_generator
-from haystack.errors import DocumentStoreError
-
-if TYPE_CHECKING:
- from haystack.nodes.retriever.base import BaseRetriever
+from haystack.nodes.retriever import DenseRetriever
logger = logging.getLogger(__name__)
@@ -325,7 +322,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
def update_embeddings(
self,
- retriever: "BaseRetriever",
+ retriever: DenseRetriever,
index: Optional[str] = None,
batch_size: int = 10_000,
update_existing_embeddings: bool = True,
@@ -369,23 +366,15 @@ class Milvus2DocumentStore(SQLDocumentStore):
for document_batch in batched_documents:
self._delete_vector_ids_from_milvus(documents=document_batch, index=index)
- embeddings = retriever.embed_documents(document_batch) # type: ignore
- if len(document_batch) != len(embeddings):
- raise DocumentStoreError(
- "The number of embeddings does not match the number of documents in the batch "
- f"({len(embeddings)} != {len(document_batch)})"
- )
- if embeddings[0].shape[0] != self.embedding_dim:
- raise RuntimeError(
- f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate MilvusDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}."
- )
+ embeddings = retriever.embed_documents(document_batch)
+ self._validate_embeddings_shape(
+ embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
+ )
if self.cosine:
- embeddings = [embedding / np.linalg.norm(embedding) for embedding in embeddings]
- embeddings_list = [embedding.tolist() for embedding in embeddings]
- assert len(document_batch) == len(embeddings_list)
+ self.normalize_embedding(embeddings)
- mutation_result = self.collection.insert([embeddings_list])
+ mutation_result = self.collection.insert([embeddings.tolist()])
vector_id_map = {}
for vector_id, doc in zip(mutation_result.primary_keys, document_batch):
diff --git a/haystack/document_stores/pinecone.py b/haystack/document_stores/pinecone.py
index a4616b638..0883db888 100644
--- a/haystack/document_stores/pinecone.py
+++ b/haystack/document_stores/pinecone.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Set, Union, List, Optional, Dict, Generator, Any
+from typing import Set, Union, List, Optional, Dict, Generator, Any
import logging
from itertools import islice
@@ -11,10 +11,8 @@ from haystack.schema import Document, Label, Answer, Span
from haystack.document_stores import BaseDocumentStore
from haystack.document_stores.filter_utils import LogicalFilterClause
-from haystack.errors import PineconeDocumentStoreError, DuplicateDocumentError, DocumentStoreError
-
-if TYPE_CHECKING:
- from haystack.nodes.retriever import BaseRetriever
+from haystack.errors import PineconeDocumentStoreError, DuplicateDocumentError
+from haystack.nodes.retriever import DenseRetriever
logger = logging.getLogger(__name__)
@@ -416,7 +414,7 @@ class PineconeDocumentStore(BaseDocumentStore):
def update_embeddings(
self,
- retriever: "BaseRetriever",
+ retriever: DenseRetriever,
index: Optional[str] = None,
update_existing_embeddings: bool = True,
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
@@ -489,21 +487,13 @@ class PineconeDocumentStore(BaseDocumentStore):
) as progress_bar:
for _ in range(0, document_count, batch_size):
document_batch = list(islice(documents, batch_size))
- embeddings = retriever.embed_documents(document_batch) # type: ignore
- if len(document_batch) != len(embeddings):
- raise DocumentStoreError(
- "The number of embeddings does not match the number of documents in the batch "
- f"({len(embeddings)} != {len(document_batch)})"
- )
- if embeddings[0].shape[0] != self.embedding_dim:
- raise RuntimeError(
- f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate PineconeDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}."
- )
+ embeddings = retriever.embed_documents(document_batch)
+ self._validate_embeddings_shape(
+ embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
+ )
- embeddings_to_index = np.array(embeddings, dtype="float32")
if self.similarity == "cosine":
- self.normalize_embedding(embeddings_to_index)
- embeddings = embeddings_to_index.tolist()
+ self.normalize_embedding(embeddings)
metadata = []
ids = []
@@ -512,7 +502,7 @@ class PineconeDocumentStore(BaseDocumentStore):
ids.append(doc.id)
# Update existing vectors in pinecone index
self.pinecone_indexes[index].upsert(
- vectors=zip(ids, embeddings, metadata), namespace=self.embedding_namespace
+ vectors=zip(ids, embeddings.tolist(), metadata), namespace=self.embedding_namespace
)
# Delete existing vectors from document namespace if they exist there
self.delete_documents(index=index, ids=ids, namespace=self.document_namespace)
diff --git a/haystack/document_stores/weaviate.py b/haystack/document_stores/weaviate.py
index 11dcce0a7..162f98997 100644
--- a/haystack/document_stores/weaviate.py
+++ b/haystack/document_stores/weaviate.py
@@ -23,6 +23,7 @@ from haystack.document_stores.base import get_batches_from_generator
from haystack.document_stores.filter_utils import LogicalFilterClause
from haystack.document_stores.utils import convert_date_to_rfc3339
from haystack.errors import DocumentStoreError
+from haystack.nodes.retriever import DenseRetriever
logger = logging.getLogger(__name__)
@@ -1166,7 +1167,7 @@ class WeaviateDocumentStore(BaseDocumentStore):
def update_embeddings(
self,
- retriever,
+ retriever: DenseRetriever,
index: Optional[str] = None,
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
update_existing_embeddings: bool = True,
@@ -1230,21 +1231,16 @@ class WeaviateDocumentStore(BaseDocumentStore):
document_batch = [
self._convert_weaviate_result_to_document(hit, return_embedding=False) for hit in result_batch
]
- embeddings = retriever.embed_documents(document_batch) # type: ignore
- if len(document_batch) != len(embeddings):
- raise DocumentStoreError(
- "The number of embeddings does not match the number of documents in the batch "
- f"({len(embeddings)} != {len(document_batch)})"
- )
- if embeddings[0].shape[0] != self.embedding_dim:
- raise RuntimeError(
- f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate WeaviateDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}."
- )
+ embeddings = retriever.embed_documents(document_batch)
+ self._validate_embeddings_shape(
+ embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
+ )
+
+ if self.similarity == "cosine":
+ self.normalize_embedding(embeddings)
for doc, emb in zip(document_batch, embeddings):
# Using update method to only update the embeddings, other properties will be in tact
- if self.similarity == "cosine":
- self.normalize_embedding(emb)
self.weaviate_client.data_object.update({}, class_name=index, uuid=doc.id, vector=emb)
def delete_all_documents(
diff --git a/haystack/json-schemas/haystack-pipeline-main.schema.json b/haystack/json-schemas/haystack-pipeline-main.schema.json
index f00c87f82..28a644af3 100644
--- a/haystack/json-schemas/haystack-pipeline-main.schema.json
+++ b/haystack/json-schemas/haystack-pipeline-main.schema.json
@@ -5570,10 +5570,10 @@
"title": "Use Auth Token",
"anyOf": [
{
- "type": "boolean"
+ "type": "string"
},
{
- "type": "string"
+ "type": "boolean"
},
{
"type": "null"
diff --git a/haystack/nodes/__init__.py b/haystack/nodes/__init__.py
index d10b7d10a..6fd9b63c4 100644
--- a/haystack/nodes/__init__.py
+++ b/haystack/nodes/__init__.py
@@ -29,6 +29,7 @@ from haystack.nodes.ranker import BaseRanker, SentenceTransformersRanker
from haystack.nodes.reader import BaseReader, FARMReader, TransformersReader, TableReader, RCIReader
from haystack.nodes.retriever import (
BaseRetriever,
+ DenseRetriever,
DensePassageRetriever,
EmbeddingRetriever,
BM25Retriever,
diff --git a/haystack/nodes/answer_generator/transformers.py b/haystack/nodes/answer_generator/transformers.py
index ab4a70bdf..7d8fa4285 100644
--- a/haystack/nodes/answer_generator/transformers.py
+++ b/haystack/nodes/answer_generator/transformers.py
@@ -188,7 +188,7 @@ class RAGenerator(BaseGenerator):
self.devices[0]
)
- def _prepare_passage_embeddings(self, docs: List[Document], embeddings: List[numpy.ndarray]) -> torch.Tensor:
+ def _prepare_passage_embeddings(self, docs: List[Document], embeddings: numpy.ndarray) -> torch.Tensor:
# If document missing embedding, then need embedding for all the documents
is_embedding_required = embeddings is None or any(embedding is None for embedding in embeddings)
diff --git a/haystack/nodes/retriever/__init__.py b/haystack/nodes/retriever/__init__.py
index f1805f6f6..c7c261a04 100644
--- a/haystack/nodes/retriever/__init__.py
+++ b/haystack/nodes/retriever/__init__.py
@@ -1,5 +1,6 @@
from haystack.nodes.retriever.base import BaseRetriever
from haystack.nodes.retriever.dense import (
+ DenseRetriever,
DensePassageRetriever,
EmbeddingRetriever,
MultihopEmbeddingRetriever,
diff --git a/haystack/nodes/retriever/_embedding_encoder.py b/haystack/nodes/retriever/_embedding_encoder.py
index 46e4de1bc..a7c30f4c0 100644
--- a/haystack/nodes/retriever/_embedding_encoder.py
+++ b/haystack/nodes/retriever/_embedding_encoder.py
@@ -26,22 +26,22 @@ logger = logging.getLogger(__name__)
class _BaseEmbeddingEncoder:
@abstractmethod
- def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
+ def embed_queries(self, queries: List[str]) -> np.ndarray:
"""
Create embeddings for a list of queries.
- :param texts: Queries to embed
- :return: Embeddings, one per input queries
+ :param queries: List of queries to embed.
+ :return: Embeddings, one per input query, shape: (queries, embedding_dim)
"""
pass
@abstractmethod
- def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
+ def embed_documents(self, docs: List[Document]) -> np.ndarray:
"""
Create embeddings for a list of documents.
- :param docs: List of documents to embed
- :return: Embeddings, one per input document
+ :param docs: List of documents to embed.
+ :return: Embeddings, one per input document, shape: (documents, embedding_dim)
"""
pass
@@ -118,18 +118,30 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
f"This can be set when initializing the DocumentStore"
)
- def embed(self, texts: Union[List[List[str]], List[str], str]) -> List[np.ndarray]:
+ def embed(self, texts: Union[List[List[str]], List[str], str]) -> np.ndarray:
# TODO: FARM's `sample_to_features_text` need to fix following warning -
# tokenization_utils.py:460: FutureWarning: `is_pretokenized` is deprecated and will be removed in a future version, use `is_split_into_words` instead.
emb = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts])
- emb = [(r["vec"]) for r in emb]
+ emb = np.stack([r["vec"] for r in emb])
return emb
- def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
- return self.embed(texts)
+ def embed_queries(self, queries: List[str]) -> np.ndarray:
+ """
+ Create embeddings for a list of queries.
- def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
- passages = [d.content for d in docs] # type: ignore
+ :param queries: List of queries to embed.
+ :return: Embeddings, one per input query, shape: (queries, embedding_dim)
+ """
+ return self.embed(queries)
+
+ def embed_documents(self, docs: List[Document]) -> np.ndarray:
+ """
+ Create embeddings for a list of documents.
+
+ :param docs: List of documents to embed.
+ :return: Embeddings, one per input document, shape: (documents, embedding_dim)
+ """
+ passages = [d.content for d in docs]
return self.embed(passages)
def train(
@@ -175,18 +187,31 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
f"This can be set when initializing the DocumentStore"
)
- def embed(self, texts: Union[List[List[str]], List[str], str]) -> List[np.ndarray]:
+ def embed(self, texts: Union[List[List[str]], List[str], str]) -> np.ndarray:
# texts can be a list of strings or a list of [title, text]
# get back list of numpy embedding vectors
- emb = self.embedding_model.encode(texts, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar)
- emb = [r for r in emb]
+ emb = self.embedding_model.encode(
+ texts, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True
+ )
return emb
- def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
- return self.embed(texts)
+ def embed_queries(self, queries: List[str]) -> np.ndarray:
+ """
+ Create embeddings for a list of queries.
- def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
- passages = [[d.meta["name"] if d.meta and "name" in d.meta else "", d.content] for d in docs] # type: ignore
+ :param queries: List of queries to embed.
+ :return: Embeddings, one per input query, shape: (queries, embedding_dim)
+ """
+ return self.embed(queries)
+
+ def embed_documents(self, docs: List[Document]) -> np.ndarray:
+ """
+ Create embeddings for a list of documents.
+
+ :param docs: List of documents to embed.
+ :return: Embeddings, one per input document, shape: (documents, embedding_dim)
+ """
+ passages = [[d.meta["name"] if d.meta and "name" in d.meta else "", d.content] for d in docs]
return self.embed(passages)
def train(
@@ -250,10 +275,15 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
retriever.embedding_model, use_auth_token=retriever.use_auth_token
).to(str(retriever.devices[0]))
- def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
+ def embed_queries(self, queries: List[str]) -> np.ndarray:
+ """
+ Create embeddings for a list of queries.
- queries = [{"text": q} for q in texts]
- dataloader = self._create_dataloader(queries)
+ :param queries: List of queries to embed.
+ :return: Embeddings, one per input query, shape: (queries, embedding_dim)
+ """
+ query_text = [{"text": q} for q in queries]
+ dataloader = self._create_dataloader(query_text)
embeddings: List[np.ndarray] = []
disable_tqdm = True if len(dataloader) == 1 else not self.progress_bar
@@ -272,8 +302,13 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
return np.concatenate(embeddings)
- def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
+ def embed_documents(self, docs: List[Document]) -> np.ndarray:
+ """
+ Create embeddings for a list of documents.
+ :param docs: List of documents to embed.
+ :return: Embeddings, one per input document, shape: (documents, embedding_dim)
+ """
doc_text = [{"text": d.content} for d in docs]
dataloader = self._create_dataloader(doc_text)
diff --git a/haystack/nodes/retriever/base.py b/haystack/nodes/retriever/base.py
index f5644d5e7..149cb5024 100644
--- a/haystack/nodes/retriever/base.py
+++ b/haystack/nodes/retriever/base.py
@@ -4,7 +4,6 @@ import logging
from abc import abstractmethod
from time import perf_counter
from functools import wraps
-from copy import deepcopy
from tqdm import tqdm
@@ -263,7 +262,7 @@ class BaseRetriever(BaseComponent):
query: Optional[str] = None,
filters: Optional[dict] = None,
top_k: Optional[int] = None,
- documents: Optional[List[dict]] = None,
+ documents: Optional[List[Document]] = None,
index: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
scale_score: bool = None,
@@ -279,7 +278,7 @@ class BaseRetriever(BaseComponent):
query=query, filters=filters, top_k=top_k, index=index, headers=headers, scale_score=scale_score
)
elif root_node == "File":
- self.index_count += len(documents) # type: ignore
+ self.index_count += len(documents) if documents else 0
run_indexing = self.timing(self.run_indexing, "index_time")
output, stream = run_indexing(documents=documents)
else:
@@ -362,13 +361,7 @@ class BaseRetriever(BaseComponent):
return output, "output_1"
- def run_indexing(self, documents: List[Union[dict, Document]]):
- if self.__class__.__name__ in ["DensePassageRetriever", "EmbeddingRetriever"]:
- documents = deepcopy(documents)
- document_objects = [Document.from_dict(doc) if isinstance(doc, dict) else doc for doc in documents]
- embeddings = self.embed_documents(document_objects) # type: ignore
- for doc, emb in zip(document_objects, embeddings):
- doc.embedding = emb
+ def run_indexing(self, documents: List[Document]):
output = {"documents": documents}
return output, "output_1"
diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py
index 4491da525..132337e6c 100644
--- a/haystack/nodes/retriever/dense.py
+++ b/haystack/nodes/retriever/dense.py
@@ -1,3 +1,4 @@
+from abc import abstractmethod
from typing import List, Dict, Union, Optional, Any
import logging
@@ -42,7 +43,40 @@ from haystack.modeling.utils import initialize_device_settings
logger = logging.getLogger(__name__)
-class DensePassageRetriever(BaseRetriever):
+class DenseRetriever(BaseRetriever):
+ """
+ Base class for all dense retrievers.
+ """
+
+ @abstractmethod
+ def embed_queries(self, queries: List[str]) -> np.ndarray:
+ """
+ Create embeddings for a list of queries.
+
+ :param queries: List of queries to embed.
+ :return: Embeddings, one per input query, shape: (queries, embedding_dim)
+ """
+ pass
+
+ @abstractmethod
+ def embed_documents(self, documents: List[Document]) -> np.ndarray:
+ """
+ Create embeddings for a list of documents.
+
+ :param documents: List of documents to embed.
+ :return: Embeddings of documents, one per input document, shape: (documents, embedding_dim)
+ """
+ pass
+
+ def run_indexing(self, documents: List[Document]):
+ embeddings = self.embed_documents(documents)
+ for doc, emb in zip(documents, embeddings):
+ doc.embedding = emb
+ output = {"documents": documents}
+ return output, "output_1"
+
+
+class DensePassageRetriever(DenseRetriever):
"""
Retriever that uses a bi-encoder (one transformer for query, one transformer for passage).
See the original paper for more details:
@@ -302,7 +336,7 @@ class DensePassageRetriever(BaseRetriever):
index = self.document_store.index
if scale_score is None:
scale_score = self.scale_score
- query_emb = self.embed_queries(texts=[query])
+ query_emb = self.embed_queries(queries=[query])
documents = self.document_store.query_by_embedding(
query_emb=query_emb[0], top_k=top_k, filters=filters, index=index, headers=headers, scale_score=scale_score
)
@@ -430,9 +464,9 @@ class DensePassageRetriever(BaseRetriever):
return [[] * len(queries)] # type: ignore
documents = []
- query_embs = []
+ query_embs: List[np.ndarray] = []
for batch in self._get_batches(queries=queries, batch_size=batch_size):
- query_embs.extend(self.embed_queries(texts=batch))
+ query_embs.extend(self.embed_queries(queries=batch))
for query_emb, cur_filters in tqdm(
zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying"
):
@@ -448,7 +482,7 @@ class DensePassageRetriever(BaseRetriever):
return documents
- def _get_predictions(self, dicts):
+ def _get_predictions(self, dicts: List[Dict[str, Any]]) -> Dict[str, np.ndarray]:
"""
Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
@@ -472,7 +506,8 @@ class DensePassageRetriever(BaseRetriever):
data_loader = NamedDataLoader(
dataset=dataset, sampler=SequentialSampler(dataset), batch_size=self.batch_size, tensor_names=tensor_names
)
- all_embeddings = {"query": [], "passages": []}
+ query_embeddings_batched = []
+ passage_embeddings_batched = []
self.model.eval()
# When running evaluations etc., we don't want a progress bar for every single query
@@ -503,34 +538,35 @@ class DensePassageRetriever(BaseRetriever):
passage_attention_mask=batch.get("passage_attention_mask", None),
)[0]
if query_embeddings is not None:
- all_embeddings["query"].append(query_embeddings.cpu().numpy())
+ query_embeddings_batched.append(query_embeddings.cpu().numpy())
if passage_embeddings is not None:
- all_embeddings["passages"].append(passage_embeddings.cpu().numpy())
+ passage_embeddings_batched.append(passage_embeddings.cpu().numpy())
progress_bar.update(self.batch_size)
- if all_embeddings["passages"]:
- all_embeddings["passages"] = np.concatenate(all_embeddings["passages"])
- if all_embeddings["query"]:
- all_embeddings["query"] = np.concatenate(all_embeddings["query"])
+ all_embeddings: Dict[str, np.ndarray] = {}
+ if passage_embeddings_batched:
+ all_embeddings["passages"] = np.concatenate(passage_embeddings_batched)
+ if query_embeddings_batched:
+ all_embeddings["query"] = np.concatenate(query_embeddings_batched)
return all_embeddings
- def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
+ def embed_queries(self, queries: List[str]) -> np.ndarray:
"""
- Create embeddings for a list of queries using the query encoder
+ Create embeddings for a list of queries using the query encoder.
- :param texts: Queries to embed
- :return: Embeddings, one per input queries
+ :param queries: List of queries to embed.
+ :return: Embeddings, one per input query, shape: (queries, embedding_dim)
"""
- queries = [{"query": q} for q in texts]
- result = self._get_predictions(queries)["query"]
+ query_dicts = [{"query": q} for q in queries]
+ result = self._get_predictions(query_dicts)["query"]
return result
- def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
+ def embed_documents(self, documents: List[Document]) -> np.ndarray:
"""
- Create embeddings for a list of documents using the passage encoder
+ Create embeddings for a list of documents using the passage encoder.
- :param docs: List of Document objects used to represent documents / passages in a standardized way within Haystack.
- :return: Embeddings of documents / passages shape (batch_size, embedding_dim)
+ :param documents: List of documents to embed.
+ :return: Embeddings of documents, one per input document, shape: (documents, embedding_dim)
"""
if self.processor.num_hard_negatives != 0:
logger.warning(
@@ -550,7 +586,7 @@ class DensePassageRetriever(BaseRetriever):
}
]
}
- for d in docs
+ for d in documents
]
embeddings = self._get_predictions(passages)["passages"]
return embeddings
@@ -757,7 +793,7 @@ class DensePassageRetriever(BaseRetriever):
return dpr
-class TableTextRetriever(BaseRetriever):
+class TableTextRetriever(DenseRetriever):
"""
Retriever that uses a tri-encoder to jointly retrieve among a database consisting of text passages and tables
(one transformer for query, one transformer for text passages, one transformer for tables).
@@ -950,7 +986,7 @@ class TableTextRetriever(BaseRetriever):
index = self.document_store.index
if scale_score is None:
scale_score = self.scale_score
- query_emb = self.embed_queries(texts=[query])
+ query_emb = self.embed_queries(queries=[query])
documents = self.document_store.query_by_embedding(
query_emb=query_emb[0], top_k=top_k, filters=filters, index=index, headers=headers, scale_score=scale_score
)
@@ -1078,9 +1114,9 @@ class TableTextRetriever(BaseRetriever):
return [[] * len(queries)] # type: ignore
documents = []
- query_embs = []
+ query_embs: List[np.ndarray] = []
for batch in self._get_batches(queries=queries, batch_size=batch_size):
- query_embs.extend(self.embed_queries(texts=batch))
+ query_embs.extend(self.embed_queries(queries=batch))
for query_emb, cur_filters in tqdm(
zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying"
):
@@ -1096,7 +1132,7 @@ class TableTextRetriever(BaseRetriever):
return documents
- def _get_predictions(self, dicts: List[Dict]) -> Dict[str, List[np.ndarray]]:
+ def _get_predictions(self, dicts: List[Dict[str, Any]]) -> Dict[str, np.ndarray]:
"""
Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
@@ -1121,7 +1157,8 @@ class TableTextRetriever(BaseRetriever):
data_loader = NamedDataLoader(
dataset=dataset, sampler=SequentialSampler(dataset), batch_size=self.batch_size, tensor_names=tensor_names
)
- all_embeddings: Dict = {"query": [], "passages": []}
+ query_embeddings_batched = []
+ passage_embeddings_batched = []
self.model.eval()
# When running evaluations etc., we don't want a progress bar for every single query
@@ -1145,36 +1182,36 @@ class TableTextRetriever(BaseRetriever):
with torch.no_grad():
query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
if query_embeddings is not None:
- all_embeddings["query"].append(query_embeddings.cpu().numpy())
+ query_embeddings_batched.append(query_embeddings.cpu().numpy())
if passage_embeddings is not None:
- all_embeddings["passages"].append(passage_embeddings.cpu().numpy())
+ passage_embeddings_batched.append(passage_embeddings.cpu().numpy())
progress_bar.update(self.batch_size)
- if all_embeddings["passages"]:
- all_embeddings["passages"] = np.concatenate(all_embeddings["passages"])
- if all_embeddings["query"]:
- all_embeddings["query"] = np.concatenate(all_embeddings["query"])
+ all_embeddings: Dict[str, np.ndarray] = {}
+ if passage_embeddings_batched:
+ all_embeddings["passages"] = np.concatenate(passage_embeddings_batched)
+ if query_embeddings_batched:
+ all_embeddings["query"] = np.concatenate(query_embeddings_batched)
return all_embeddings
- def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
+ def embed_queries(self, queries: List[str]) -> np.ndarray:
"""
- Create embeddings for a list of queries using the query encoder
+ Create embeddings for a list of queries using the query encoder.
- :param texts: Queries to embed
- :return: Embeddings, one per input queries
+ :param queries: List of queries to embed.
+ :return: Embeddings, one per input query, shape: (queries, embedding_dim)
"""
- queries = [{"query": q} for q in texts]
- result = self._get_predictions(queries)["query"]
+ query_dicts = [{"query": q} for q in queries]
+ result = self._get_predictions(query_dicts)["query"]
return result
- def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
+ def embed_documents(self, documents: List[Document]) -> np.ndarray:
"""
Create embeddings for a list of text documents and / or tables using the text passage encoder and
the table encoder.
- :param docs: List of Document objects used to represent documents / passages in
- a standardized way within Haystack.
- :return: Embeddings of documents / passages. Shape: (batch_size, embedding_dim)
+ :param documents: List of documents to embed.
+ :return: Embeddings of documents, one per input document, shape: (documents, embedding_dim)
"""
if self.processor.num_hard_negatives != 0:
@@ -1185,7 +1222,7 @@ class TableTextRetriever(BaseRetriever):
self.processor.num_hard_negatives = 0
model_input = []
- for doc in docs:
+ for doc in documents:
if doc.content_type == "table":
model_input.append(
{
@@ -1440,7 +1477,7 @@ class TableTextRetriever(BaseRetriever):
return mm_retriever
-class EmbeddingRetriever(BaseRetriever):
+class EmbeddingRetriever(DenseRetriever):
def __init__(
self,
document_store: BaseDocumentStore,
@@ -1639,7 +1676,7 @@ class EmbeddingRetriever(BaseRetriever):
index = self.document_store.index
if scale_score is None:
scale_score = self.scale_score
- query_emb = self.embed_queries(texts=[query])
+ query_emb = self.embed_queries(queries=[query])
documents = self.document_store.query_by_embedding(
query_emb=query_emb[0], filters=filters, top_k=top_k, index=index, headers=headers, scale_score=scale_score
)
@@ -1767,9 +1804,9 @@ class EmbeddingRetriever(BaseRetriever):
return [[] * len(queries)] # type: ignore
documents = []
- query_embs = []
+ query_embs: List[np.ndarray] = []
for batch in self._get_batches(queries=queries, batch_size=batch_size):
- query_embs.extend(self.embed_queries(texts=batch))
+ query_embs.extend(self.embed_queries(queries=batch))
for query_emb, cur_filters in tqdm(
zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying"
):
@@ -1785,28 +1822,28 @@ class EmbeddingRetriever(BaseRetriever):
return documents
- def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
+ def embed_queries(self, queries: List[str]) -> np.ndarray:
"""
Create embeddings for a list of queries.
- :param texts: Queries to embed
- :return: Embeddings, one per input queries
+ :param queries: List of queries to embed.
+ :return: Embeddings, one per input query, shape: (queries, embedding_dim)
"""
# for backward compatibility: cast pure str input
- if isinstance(texts, str):
- texts = [texts]
- assert isinstance(texts, list), "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
- return self.embedding_encoder.embed_queries(texts)
+ if isinstance(queries, str):
+ queries = [queries]
+ assert isinstance(queries, list), "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
+ return self.embedding_encoder.embed_queries(queries)
- def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
+ def embed_documents(self, documents: List[Document]) -> np.ndarray:
"""
Create embeddings for a list of documents.
- :param docs: List of documents to embed
- :return: Embeddings, one per input document
+ :param documents: List of documents to embed.
+ :return: Embeddings, one per input document, shape: (docs, embedding_dim)
"""
- docs = self._preprocess_documents(docs)
- return self.embedding_encoder.embed_documents(docs)
+ documents = self._preprocess_documents(documents)
+ return self.embedding_encoder.embed_documents(documents)
def _preprocess_documents(self, docs: List[Document]) -> List[Document]:
"""
diff --git a/test/conftest.py b/test/conftest.py
index 4180cfccb..e28adc39b 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -57,6 +57,7 @@ from haystack.nodes import (
BaseGenerator,
BaseSummarizer,
BaseTranslator,
+ DenseRetriever,
)
from haystack.nodes.answer_generator.transformers import Seq2SeqGenerator
from haystack.nodes.answer_generator.transformers import RAGenerator
@@ -308,16 +309,16 @@ class MockTranslator(BaseTranslator):
pass
-class MockDenseRetriever(MockRetriever):
+class MockDenseRetriever(MockRetriever, DenseRetriever):
def __init__(self, document_store: BaseDocumentStore, embedding_dim: int = 768):
self.embedding_dim = embedding_dim
self.document_store = document_store
- def embed_queries(self, texts):
- return [np.random.rand(self.embedding_dim)] * len(texts)
+ def embed_queries(self, queries):
+ return np.random.rand(len(queries), self.embedding_dim)
- def embed_documents(self, docs):
- return [np.random.rand(self.embedding_dim)] * len(docs)
+ def embed_documents(self, documents):
+ return np.random.rand(len(documents), self.embedding_dim)
class MockQuestionGenerator(QuestionGenerator):
diff --git a/test/document_stores/test_document_store.py b/test/document_stores/test_document_store.py
index 5db4f6163..71d7134d6 100644
--- a/test/document_stores/test_document_store.py
+++ b/test/document_stores/test_document_store.py
@@ -576,7 +576,7 @@ def test_update_embeddings(document_store, retriever):
"content": "text_7",
"id": "7",
"meta_field": "value_7",
- "embedding": retriever.embed_queries(texts=["a random string"])[0],
+ "embedding": retriever.embed_queries(queries=["a random string"])[0],
}
document_store.write_documents([doc])