Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-12-25 05:58:57 +00:00

chore: add DenseRetriever abstraction (#3252)

* support cosine similiarity with faiss
* update docs
* update api docs
* fix tests
* Revert "update api docs". This reverts commit 6138fdfefb3beaee2d55c5729cd4a2745ea6b143.
* fix api docs
* collapse test
* rename similairity to space_type mappings
* only normalize for faiss
* fix merge
* fix docs normalization
* get rid of List[np.array]
* update docs
* fix tests and tutorials
* fix mypy
* fix mypy
* fix mypy again
* again mypy
* blacken
* update tutorial 4 docs
* fix embeddingretriever
* fix faiss
* move dense specific logic to DenseRetriever
* fix mypy
* cosine tests for all documents stores
* fix pinecone
* add docstring
* docstring corrections
* update docs
* add integration test marker
* docstrings update
* update docs
* fix typo
* update docs
* fix MockDenseRetriever
* run integration tests for all documentstores
* fix test_update_embeddings_cosine_similarity
* fix faiss tests not running
* blacken
* make test_cosine_sanity_check integration test
* update docs
* fix imports
* import DenseRetriever normally
* update docs
* fix deepcopy of documents
* update schema
* Revert "update schema". This reverts commit 83cf8f323648468e1c322d54852bec084d637e3f.
* fix schema for ci manually

This commit is contained in:
parent 492a8046d8
commit b10e2c392e
@ -1269,7 +1269,7 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.

#### BaseElasticsearchDocumentStore.update\_embeddings

```python
def update_embeddings(retriever,
def update_embeddings(retriever: DenseRetriever,
                      index: Optional[str] = None,
                      filters: Optional[Dict[str, Union[Dict, List, str, int,
                                                        float, bool]]] = None,

@ -2097,7 +2097,7 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.

#### InMemoryDocumentStore.update\_embeddings

```python
def update_embeddings(retriever: "BaseRetriever",
def update_embeddings(retriever: DenseRetriever,
                      index: Optional[str] = None,
                      filters: Optional[Dict[str, Any]] = None,
                      update_existing_embeddings: bool = True,

@ -2913,7 +2913,7 @@ None

#### FAISSDocumentStore.update\_embeddings

```python
def update_embeddings(retriever: "BaseRetriever",
def update_embeddings(retriever: DenseRetriever,
                      index: Optional[str] = None,
                      update_existing_embeddings: bool = True,
                      filters: Optional[Dict[str, Any]] = None,

@ -3277,7 +3277,7 @@ None

#### Milvus1DocumentStore.update\_embeddings

```python
def update_embeddings(retriever: "BaseRetriever",
def update_embeddings(retriever: DenseRetriever,
                      index: Optional[str] = None,
                      batch_size: int = 10_000,
                      update_existing_embeddings: bool = True,

@ -3681,7 +3681,7 @@ exists.

#### Milvus2DocumentStore.update\_embeddings

```python
def update_embeddings(retriever: "BaseRetriever",
def update_embeddings(retriever: DenseRetriever,
                      index: Optional[str] = None,
                      batch_size: int = 10_000,
                      update_existing_embeddings: bool = True,

@ -4398,7 +4398,7 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.

#### WeaviateDocumentStore.update\_embeddings

```python
def update_embeddings(retriever,
def update_embeddings(retriever: DenseRetriever,
                      index: Optional[str] = None,
                      filters: Optional[Dict[str, Union[Dict, List, str, int,
                                                        float, bool]]] = None,

@ -5452,7 +5452,7 @@ Parameter options:

#### PineconeDocumentStore.update\_embeddings

```python
def update_embeddings(retriever: "BaseRetriever",
def update_embeddings(retriever: DenseRetriever,
                      index: Optional[str] = None,
                      update_existing_embeddings: bool = True,
                      filters: Optional[Dict[str, Union[Dict, List, str, int,
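# Usage sketch (illustrative, written as comments because the signature above is truncated
# by the hunk): after this change, update_embeddings expects a DenseRetriever subclass such
# as EmbeddingRetriever or DensePassageRetriever. Assuming a document store that already
# holds documents, and a hypothetical model choice:
#
#     retriever = EmbeddingRetriever(
#         document_store=document_store,
#         embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # illustrative choice
#     )
#     document_store.update_embeddings(retriever=retriever)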
@ -547,12 +547,60 @@ Performing training on this class according to the TF-IDF algorithm.

# Module dense

<a id="dense.DenseRetriever"></a>

## DenseRetriever

```python
class DenseRetriever(BaseRetriever)
```

Base class for all dense retrievers.

<a id="dense.DenseRetriever.embed_queries"></a>

#### DenseRetriever.embed\_queries

```python
@abstractmethod
def embed_queries(queries: List[str]) -> np.ndarray
```

Create embeddings for a list of queries.

**Arguments**:

- `queries`: List of queries to embed.

**Returns**:

Embeddings, one per input query, shape: (queries, embedding_dim)

<a id="dense.DenseRetriever.embed_documents"></a>

#### DenseRetriever.embed\_documents

```python
@abstractmethod
def embed_documents(documents: List[Document]) -> np.ndarray
```

Create embeddings for a list of documents.

**Arguments**:

- `documents`: List of documents to embed.

**Returns**:

Embeddings of documents, one per input document, shape: (documents, embedding_dim)
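For orientation, here is a minimal sketch of what a subclass of this new base class has to provide. `ToyDenseRetriever`, its `_encode` helper, and the 8-dimensional vectors are purely hypothetical; the point is only that both methods return one `np.ndarray` of shape `(n, embedding_dim)` instead of a `List[np.ndarray]`, and that `BaseRetriever`'s own abstract retrieval methods still need (stub) implementations.

```python
from typing import List

import numpy as np

from haystack.schema import Document
from haystack.nodes.retriever import DenseRetriever


class ToyDenseRetriever(DenseRetriever):
    """Hypothetical retriever used only to illustrate the DenseRetriever contract."""

    def __init__(self, embedding_dim: int = 8):
        super().__init__()
        self.embedding_dim = embedding_dim

    def _encode(self, texts: List[str]) -> np.ndarray:
        # Stand-in for a real encoder model: one deterministic vector per text.
        return np.stack([np.full(self.embedding_dim, float(len(t))) for t in texts])

    def embed_queries(self, queries: List[str]) -> np.ndarray:
        # Contract: a single array of shape (len(queries), embedding_dim)
        return self._encode(queries)

    def embed_documents(self, documents: List[Document]) -> np.ndarray:
        # Contract: a single array of shape (len(documents), embedding_dim)
        return self._encode([d.content for d in documents])

    # BaseRetriever also declares abstract retrieval methods; stubbed out here for brevity.
    def retrieve(self, query, **kwargs):
        raise NotImplementedError

    def retrieve_batch(self, queries, **kwargs):
        raise NotImplementedError


retriever = ToyDenseRetriever()
embeddings = retriever.embed_documents([Document(content="hello"), Document(content="world")])
assert embeddings.shape == (2, 8)  # one row per document, not a list of arrays
```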
<a id="dense.DensePassageRetriever"></a>
|
||||
|
||||
## DensePassageRetriever
|
||||
|
||||
```python
|
||||
class DensePassageRetriever(BaseRetriever)
|
||||
class DensePassageRetriever(DenseRetriever)
|
||||
```
|
||||
|
||||
Retriever that uses a bi-encoder (one transformer for query, one transformer for passage).
|
||||
@ -842,36 +890,36 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
|
||||
#### DensePassageRetriever.embed\_queries
|
||||
|
||||
```python
|
||||
def embed_queries(texts: List[str]) -> List[np.ndarray]
|
||||
def embed_queries(queries: List[str]) -> np.ndarray
|
||||
```
|
||||
|
||||
Create embeddings for a list of queries using the query encoder
|
||||
Create embeddings for a list of queries using the query encoder.
|
||||
|
||||
**Arguments**:
|
||||
|
||||
- `texts`: Queries to embed
|
||||
- `queries`: List of queries to embed.
|
||||
|
||||
**Returns**:
|
||||
|
||||
Embeddings, one per input queries
|
||||
Embeddings, one per input query, shape: (queries, embedding_dim)
|
||||
|
||||
<a id="dense.DensePassageRetriever.embed_documents"></a>
|
||||
|
||||
#### DensePassageRetriever.embed\_documents
|
||||
|
||||
```python
|
||||
def embed_documents(docs: List[Document]) -> List[np.ndarray]
|
||||
def embed_documents(documents: List[Document]) -> np.ndarray
|
||||
```
|
||||
|
||||
Create embeddings for a list of documents using the passage encoder
|
||||
Create embeddings for a list of documents using the passage encoder.
|
||||
|
||||
**Arguments**:
|
||||
|
||||
- `docs`: List of Document objects used to represent documents / passages in a standardized way within Haystack.
|
||||
- `documents`: List of documents to embed.
|
||||
|
||||
**Returns**:
|
||||
|
||||
Embeddings of documents / passages shape (batch_size, embedding_dim)
|
||||
Embeddings of documents, one per input document, shape: (documents, embedding_dim)
|
||||
|
||||
<a id="dense.DensePassageRetriever.train"></a>
|
||||
|
||||
@ -1005,7 +1053,7 @@ Load DensePassageRetriever from the specified directory.
|
||||
## TableTextRetriever
|
||||
|
||||
```python
|
||||
class TableTextRetriever(BaseRetriever)
|
||||
class TableTextRetriever(DenseRetriever)
|
||||
```
|
||||
|
||||
Retriever that uses a tri-encoder to jointly retrieve among a database consisting of text passages and tables
|
||||
@ -1198,25 +1246,25 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
|
||||
#### TableTextRetriever.embed\_queries
|
||||
|
||||
```python
|
||||
def embed_queries(texts: List[str]) -> List[np.ndarray]
|
||||
def embed_queries(queries: List[str]) -> np.ndarray
|
||||
```
|
||||
|
||||
Create embeddings for a list of queries using the query encoder
|
||||
Create embeddings for a list of queries using the query encoder.
|
||||
|
||||
**Arguments**:
|
||||
|
||||
- `texts`: Queries to embed
|
||||
- `queries`: List of queries to embed.
|
||||
|
||||
**Returns**:
|
||||
|
||||
Embeddings, one per input queries
|
||||
Embeddings, one per input query, shape: (queries, embedding_dim)
|
||||
|
||||
<a id="dense.TableTextRetriever.embed_documents"></a>
|
||||
|
||||
#### TableTextRetriever.embed\_documents
|
||||
|
||||
```python
|
||||
def embed_documents(docs: List[Document]) -> List[np.ndarray]
|
||||
def embed_documents(documents: List[Document]) -> np.ndarray
|
||||
```
|
||||
|
||||
Create embeddings for a list of text documents and / or tables using the text passage encoder and
|
||||
@ -1225,12 +1273,11 @@ the table encoder.
|
||||
|
||||
**Arguments**:
|
||||
|
||||
- `docs`: List of Document objects used to represent documents / passages in
|
||||
a standardized way within Haystack.
|
||||
- `documents`: List of documents to embed.
|
||||
|
||||
**Returns**:
|
||||
|
||||
Embeddings of documents / passages. Shape: (batch_size, embedding_dim)
|
||||
Embeddings of documents, one per input document, shape: (documents, embedding_dim)
|
||||
|
||||
<a id="dense.TableTextRetriever.train"></a>
|
||||
|
||||
@ -1370,7 +1417,7 @@ Load TableTextRetriever from the specified directory.
|
||||
## EmbeddingRetriever
|
||||
|
||||
```python
|
||||
class EmbeddingRetriever(BaseRetriever)
|
||||
class EmbeddingRetriever(DenseRetriever)
|
||||
```
|
||||
|
||||
<a id="dense.EmbeddingRetriever.__init__"></a>
|
||||
@ -1638,36 +1685,36 @@ Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
|
||||
#### EmbeddingRetriever.embed\_queries
|
||||
|
||||
```python
|
||||
def embed_queries(texts: List[str]) -> List[np.ndarray]
|
||||
def embed_queries(queries: List[str]) -> np.ndarray
|
||||
```
|
||||
|
||||
Create embeddings for a list of queries.
|
||||
|
||||
**Arguments**:
|
||||
|
||||
- `texts`: Queries to embed
|
||||
- `queries`: List of queries to embed.
|
||||
|
||||
**Returns**:
|
||||
|
||||
Embeddings, one per input queries
|
||||
Embeddings, one per input query, shape: (queries, embedding_dim)
|
||||
|
||||
<a id="dense.EmbeddingRetriever.embed_documents"></a>
|
||||
|
||||
#### EmbeddingRetriever.embed\_documents
|
||||
|
||||
```python
|
||||
def embed_documents(docs: List[Document]) -> List[np.ndarray]
|
||||
def embed_documents(documents: List[Document]) -> np.ndarray
|
||||
```
|
||||
|
||||
Create embeddings for a list of documents.
|
||||
|
||||
**Arguments**:
|
||||
|
||||
- `docs`: List of documents to embed
|
||||
- `documents`: List of documents to embed.
|
||||
|
||||
**Returns**:
|
||||
|
||||
Embeddings, one per input document
|
||||
Embeddings, one per input document, shape: (docs, embedding_dim)
|
||||
|
||||
<a id="dense.EmbeddingRetriever.train"></a>
|
||||
|
||||
|
||||
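Taken together, the documentation changes above reflect the new return type: `embed_queries` and `embed_documents` now hand back a single 2-D array rather than a list of per-item vectors. A short usage sketch of the documented API; the model name and 384-dimensional store are illustrative choices, not something this commit prescribes:

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever
from haystack.schema import Document

document_store = InMemoryDocumentStore(embedding_dim=384, similarity="cosine")
retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # illustrative choice
)

query_embs = retriever.embed_queries(queries=["what is a dense retriever?"])
doc_embs = retriever.embed_documents(
    documents=[Document(content="A dense retriever embeds text into vectors.")]
)

print(query_embs.shape)  # (1, 384) -- a single numpy array, not a list of arrays
print(doc_embs.shape)    # (1, 384)
```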
@ -12,7 +12,7 @@ import numpy as np
|
||||
|
||||
from haystack.schema import Document, Label, MultiLabel
|
||||
from haystack.nodes.base import BaseComponent
|
||||
from haystack.errors import DuplicateDocumentError
|
||||
from haystack.errors import DuplicateDocumentError, DocumentStoreError
|
||||
from haystack.nodes.preprocessor import PreProcessor
|
||||
from haystack.document_stores.utils import eval_data_from_json, eval_data_from_jsonl, squad_json_to_jsonl
|
||||
|
||||
@ -698,6 +698,27 @@ class BaseDocumentStore(BaseComponent):
|
||||
|
||||
return [label for label in labels if label.id in duplicate_ids]
|
||||
|
||||
    @classmethod
    def _validate_embeddings_shape(cls, embeddings: np.ndarray, num_documents: int, embedding_dim: int):
        """
        Validates the shape of model-generated embeddings against expected values for indexing.

        :param embeddings: Embeddings to validate
        :param num_documents: Number of documents the embeddings were generated for
        :param embedding_dim: Number of embedding dimensions to expect
        """
        num_embeddings, embedding_size = embeddings.shape
        if num_embeddings != num_documents:
            raise DocumentStoreError(
                "The number of embeddings does not match the number of documents: "
                f"({num_embeddings} != {num_documents})"
            )
        if embedding_size != embedding_dim:
            raise RuntimeError(
                f"Embedding dimensions of the model ({embedding_size}) don't match the embedding dimensions of the document store ({embedding_dim}). "
                f"Initiate {cls.__name__} again with arg embedding_dim={embedding_size}."
            )
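As a quick illustration of how this shared helper behaves (it simply exercises the method above, assuming any `BaseDocumentStore` subclass such as `InMemoryDocumentStore` with the default 768-dimensional embeddings):

```python
import numpy as np

from haystack.document_stores import InMemoryDocumentStore
from haystack.errors import DocumentStoreError

store = InMemoryDocumentStore(embedding_dim=768)

# Correct shape: 3 documents, 768 dimensions -> passes silently
store._validate_embeddings_shape(embeddings=np.zeros((3, 768)), num_documents=3, embedding_dim=768)

# Wrong number of rows -> DocumentStoreError
try:
    store._validate_embeddings_shape(embeddings=np.zeros((2, 768)), num_documents=3, embedding_dim=768)
except DocumentStoreError as err:
    print(err)

# Wrong dimensionality -> RuntimeError asking to re-initialize the store with the model's embedding_dim
try:
    store._validate_embeddings_shape(embeddings=np.zeros((3, 512)), num_documents=3, embedding_dim=768)
except RuntimeError as err:
    print(err)
```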
|
||||
|
||||
|
||||
class KeywordDocumentStore(BaseDocumentStore):
|
||||
"""
|
||||
|
||||
@ -27,6 +27,7 @@ from haystack.schema import Document, Label
|
||||
from haystack.document_stores.base import get_batches_from_generator
|
||||
from haystack.document_stores.filter_utils import LogicalFilterClause
|
||||
from haystack.errors import DocumentStoreError, HaystackError
|
||||
from haystack.nodes.retriever import DenseRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -1400,7 +1401,7 @@ class BaseElasticsearchDocumentStore(KeywordDocumentStore):
|
||||
|
||||
def update_embeddings(
|
||||
self,
|
||||
retriever,
|
||||
retriever: DenseRetriever,
|
||||
index: Optional[str] = None,
|
||||
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
|
||||
update_existing_embeddings: bool = True,
|
||||
@ -1482,16 +1483,10 @@ class BaseElasticsearchDocumentStore(KeywordDocumentStore):
|
||||
with tqdm(total=document_count, position=0, unit=" Docs", desc="Updating embeddings") as progress_bar:
|
||||
for result_batch in get_batches_from_generator(result, batch_size):
|
||||
document_batch = [self._convert_es_hit_to_document(hit, return_embedding=False) for hit in result_batch]
|
||||
embeddings = retriever.embed_documents(document_batch) # type: ignore
|
||||
if len(document_batch) != len(embeddings):
|
||||
raise DocumentStoreError(
|
||||
"The number of embeddings does not match the number of documents in the batch "
|
||||
f"({len(embeddings)} != {len(document_batch)})"
|
||||
)
|
||||
if embeddings[0].shape[0] != self.embedding_dim:
|
||||
raise RuntimeError(
|
||||
f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate ElasticsearchDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}."
|
||||
)
|
||||
embeddings = retriever.embed_documents(document_batch)
|
||||
self._validate_embeddings_shape(
|
||||
embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
|
||||
)
|
||||
|
||||
doc_updates = []
|
||||
for doc, emb in zip(document_batch, embeddings):
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Any, Union, List, Optional, Dict, Generator
|
||||
from typing import Any, Union, List, Optional, Dict, Generator
|
||||
|
||||
import json
|
||||
import logging
|
||||
@ -22,10 +22,7 @@ except (ImportError, ModuleNotFoundError) as ie:
|
||||
|
||||
from haystack.schema import Document
|
||||
from haystack.document_stores.base import get_batches_from_generator
|
||||
from haystack.errors import DocumentStoreError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from haystack.nodes.retriever import BaseRetriever
|
||||
from haystack.nodes.retriever import DenseRetriever
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -308,7 +305,7 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
|
||||
def update_embeddings(
|
||||
self,
|
||||
retriever: "BaseRetriever",
|
||||
retriever: DenseRetriever,
|
||||
index: Optional[str] = None,
|
||||
update_existing_embeddings: bool = True,
|
||||
filters: Optional[Dict[str, Any]] = None, # TODO: Adapt type once we allow extended filters in FAISSDocStore
|
||||
@ -361,23 +358,15 @@ class FAISSDocumentStore(SQLDocumentStore):
|
||||
total=document_count, disable=not self.progress_bar, position=0, unit=" docs", desc="Updating Embedding"
|
||||
) as progress_bar:
|
||||
for document_batch in batched_documents:
|
||||
embeddings = retriever.embed_documents(document_batch) # type: ignore
|
||||
if len(document_batch) != len(embeddings):
|
||||
raise DocumentStoreError(
|
||||
"The number of embeddings does not match the number of documents in the batch "
|
||||
f"({len(embeddings)} != {len(document_batch)})"
|
||||
)
|
||||
if embeddings[0].shape[0] != self.embedding_dim:
|
||||
raise RuntimeError(
|
||||
f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate FAISSDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}."
|
||||
)
|
||||
|
||||
embeddings_to_index = np.array(embeddings, dtype="float32")
|
||||
embeddings = retriever.embed_documents(document_batch)
|
||||
self._validate_embeddings_shape(
|
||||
embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
|
||||
)
|
||||
|
||||
if self.similarity == "cosine":
|
||||
self.normalize_embedding(embeddings_to_index)
|
||||
self.normalize_embedding(embeddings)
|
||||
|
||||
self.faiss_indexes[index].add(embeddings_to_index)
|
||||
self.faiss_indexes[index].add(embeddings.astype(np.float32))
|
||||
|
||||
vector_id_map = {}
|
||||
for doc in document_batch:
|
||||
|
||||
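The cosine-similarity support this commit wires into FAISS (and the other vector stores touched here) rests on a standard identity: L2-normalizing vectors turns an inner-product index into a cosine-similarity index. A standalone NumPy sketch of that equivalence, independent of any Haystack class; the `l2_normalize` helper mirrors roughly what `normalize_embedding` does, but out of place rather than in place:

```python
import numpy as np

rng = np.random.default_rng(0)
docs = rng.normal(size=(4, 8)).astype("float32")   # 4 "document" embeddings
query = rng.normal(size=(8,)).astype("float32")    # 1 "query" embedding


def l2_normalize(x: np.ndarray) -> np.ndarray:
    # Divide each vector by its L2 norm.
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


cosine = docs @ query / (np.linalg.norm(docs, axis=1) * np.linalg.norm(query))
inner_product_of_normalized = l2_normalize(docs) @ l2_normalize(query)

# Inner product on normalized vectors equals cosine similarity on the originals.
assert np.allclose(cosine, inner_product_of_normalized, atol=1e-6)
```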
@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Generator
|
||||
from typing import Any, Dict, List, Optional, Union, Generator
|
||||
|
||||
import time
|
||||
import logging
|
||||
@ -10,14 +10,12 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from haystack.schema import Document, Label
|
||||
from haystack.errors import DuplicateDocumentError, DocumentStoreError
|
||||
from haystack.errors import DuplicateDocumentError
|
||||
from haystack.document_stores import BaseDocumentStore
|
||||
from haystack.document_stores.base import get_batches_from_generator
|
||||
from haystack.modeling.utils import initialize_device_settings
|
||||
from haystack.document_stores.filter_utils import LogicalFilterClause
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from haystack.nodes.retriever import BaseRetriever
|
||||
from haystack.nodes.retriever import DenseRetriever
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -398,7 +396,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
||||
|
||||
def update_embeddings(
|
||||
self,
|
||||
retriever: "BaseRetriever",
|
||||
retriever: DenseRetriever,
|
||||
index: Optional[str] = None,
|
||||
filters: Optional[Dict[str, Any]] = None, # TODO: Adapt type once we allow extended filters in InMemoryDocStore
|
||||
update_existing_embeddings: bool = True,
|
||||
@ -457,16 +455,10 @@ class InMemoryDocumentStore(BaseDocumentStore):
|
||||
total=len(result), disable=not self.progress_bar, position=0, unit=" docs", desc="Updating Embedding"
|
||||
) as progress_bar:
|
||||
for document_batch in batched_documents:
|
||||
embeddings = retriever.embed_documents(document_batch) # type: ignore
|
||||
if len(document_batch) != len(embeddings):
|
||||
raise DocumentStoreError(
|
||||
"The number of embeddings does not match the number of documents in the batch "
|
||||
f"({len(embeddings)} != {len(document_batch)})"
|
||||
)
|
||||
if embeddings[0].shape[0] != self.embedding_dim:
|
||||
raise RuntimeError(
|
||||
f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate InMemoryDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}."
|
||||
)
|
||||
embeddings = retriever.embed_documents(document_batch)
|
||||
self._validate_embeddings_shape(
|
||||
embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
|
||||
)
|
||||
|
||||
for doc, emb in zip(document_batch, embeddings):
|
||||
self.indexes[index][doc.id].embedding = emb
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
@ -15,10 +15,7 @@ except (ImportError, ModuleNotFoundError) as ie:
|
||||
|
||||
from haystack.schema import Document
|
||||
from haystack.document_stores.base import get_batches_from_generator
|
||||
from haystack.errors import DocumentStoreError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from haystack.nodes.retriever import BaseRetriever
|
||||
from haystack.nodes.retriever import DenseRetriever
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -289,7 +286,7 @@ class Milvus1DocumentStore(SQLDocumentStore):
|
||||
|
||||
def update_embeddings(
|
||||
self,
|
||||
retriever: "BaseRetriever",
|
||||
retriever: DenseRetriever,
|
||||
index: Optional[str] = None,
|
||||
batch_size: int = 10_000,
|
||||
update_existing_embeddings: bool = True,
|
||||
@ -335,25 +332,15 @@ class Milvus1DocumentStore(SQLDocumentStore):
|
||||
for document_batch in batched_documents:
|
||||
self._delete_vector_ids_from_milvus(documents=document_batch, index=index)
|
||||
|
||||
embeddings = retriever.embed_documents(document_batch) # type: ignore
|
||||
if len(document_batch) != len(embeddings):
|
||||
raise DocumentStoreError(
|
||||
"The number of embeddings does not match the number of documents in the batch "
|
||||
f"({len(embeddings)} != {len(document_batch)})"
|
||||
)
|
||||
if embeddings[0].shape[0] != self.embedding_dim:
|
||||
raise RuntimeError(
|
||||
f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate MilvusDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}."
|
||||
)
|
||||
embeddings = retriever.embed_documents(document_batch)
|
||||
self._validate_embeddings_shape(
|
||||
embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
|
||||
)
|
||||
|
||||
if self.similarity == "cosine":
|
||||
for embedding in embeddings:
|
||||
self.normalize_embedding(embedding)
|
||||
self.normalize_embedding(embeddings)
|
||||
|
||||
embeddings_list = [embedding.tolist() for embedding in embeddings]
|
||||
assert len(document_batch) == len(embeddings_list)
|
||||
|
||||
status, vector_ids = self.milvus_server.insert(collection_name=index, records=embeddings_list)
|
||||
status, vector_ids = self.milvus_server.insert(collection_name=index, records=embeddings.tolist())
|
||||
if status.code != Status.SUCCESS:
|
||||
raise RuntimeError(f"Vector embedding insertion failed: {status}")
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
@ -18,10 +18,7 @@ except (ImportError, ModuleNotFoundError) as ie:
|
||||
from haystack.schema import Document
|
||||
from haystack.document_stores.sql import SQLDocumentStore
|
||||
from haystack.document_stores.base import get_batches_from_generator
|
||||
from haystack.errors import DocumentStoreError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from haystack.nodes.retriever.base import BaseRetriever
|
||||
from haystack.nodes.retriever import DenseRetriever
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -325,7 +322,7 @@ class Milvus2DocumentStore(SQLDocumentStore):
|
||||
|
||||
def update_embeddings(
|
||||
self,
|
||||
retriever: "BaseRetriever",
|
||||
retriever: DenseRetriever,
|
||||
index: Optional[str] = None,
|
||||
batch_size: int = 10_000,
|
||||
update_existing_embeddings: bool = True,
|
||||
@ -369,23 +366,15 @@ class Milvus2DocumentStore(SQLDocumentStore):
|
||||
for document_batch in batched_documents:
|
||||
self._delete_vector_ids_from_milvus(documents=document_batch, index=index)
|
||||
|
||||
embeddings = retriever.embed_documents(document_batch) # type: ignore
|
||||
if len(document_batch) != len(embeddings):
|
||||
raise DocumentStoreError(
|
||||
"The number of embeddings does not match the number of documents in the batch "
|
||||
f"({len(embeddings)} != {len(document_batch)})"
|
||||
)
|
||||
if embeddings[0].shape[0] != self.embedding_dim:
|
||||
raise RuntimeError(
|
||||
f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate MilvusDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}."
|
||||
)
|
||||
embeddings = retriever.embed_documents(document_batch)
|
||||
self._validate_embeddings_shape(
|
||||
embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
|
||||
)
|
||||
|
||||
if self.cosine:
|
||||
embeddings = [embedding / np.linalg.norm(embedding) for embedding in embeddings]
|
||||
embeddings_list = [embedding.tolist() for embedding in embeddings]
|
||||
assert len(document_batch) == len(embeddings_list)
|
||||
self.normalize_embedding(embeddings)
|
||||
|
||||
mutation_result = self.collection.insert([embeddings_list])
|
||||
mutation_result = self.collection.insert([embeddings.tolist()])
|
||||
|
||||
vector_id_map = {}
|
||||
for vector_id, doc in zip(mutation_result.primary_keys, document_batch):
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Set, Union, List, Optional, Dict, Generator, Any
|
||||
from typing import Set, Union, List, Optional, Dict, Generator, Any
|
||||
|
||||
import logging
|
||||
from itertools import islice
|
||||
@ -11,10 +11,8 @@ from haystack.schema import Document, Label, Answer, Span
|
||||
from haystack.document_stores import BaseDocumentStore
|
||||
|
||||
from haystack.document_stores.filter_utils import LogicalFilterClause
|
||||
from haystack.errors import PineconeDocumentStoreError, DuplicateDocumentError, DocumentStoreError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from haystack.nodes.retriever import BaseRetriever
|
||||
from haystack.errors import PineconeDocumentStoreError, DuplicateDocumentError
|
||||
from haystack.nodes.retriever import DenseRetriever
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -416,7 +414,7 @@ class PineconeDocumentStore(BaseDocumentStore):
|
||||
|
||||
def update_embeddings(
|
||||
self,
|
||||
retriever: "BaseRetriever",
|
||||
retriever: DenseRetriever,
|
||||
index: Optional[str] = None,
|
||||
update_existing_embeddings: bool = True,
|
||||
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
|
||||
@ -489,21 +487,13 @@ class PineconeDocumentStore(BaseDocumentStore):
|
||||
) as progress_bar:
|
||||
for _ in range(0, document_count, batch_size):
|
||||
document_batch = list(islice(documents, batch_size))
|
||||
embeddings = retriever.embed_documents(document_batch) # type: ignore
|
||||
if len(document_batch) != len(embeddings):
|
||||
raise DocumentStoreError(
|
||||
"The number of embeddings does not match the number of documents in the batch "
|
||||
f"({len(embeddings)} != {len(document_batch)})"
|
||||
)
|
||||
if embeddings[0].shape[0] != self.embedding_dim:
|
||||
raise RuntimeError(
|
||||
f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate PineconeDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}."
|
||||
)
|
||||
embeddings = retriever.embed_documents(document_batch)
|
||||
self._validate_embeddings_shape(
|
||||
embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
|
||||
)
|
||||
|
||||
embeddings_to_index = np.array(embeddings, dtype="float32")
|
||||
if self.similarity == "cosine":
|
||||
self.normalize_embedding(embeddings_to_index)
|
||||
embeddings = embeddings_to_index.tolist()
|
||||
self.normalize_embedding(embeddings)
|
||||
|
||||
metadata = []
|
||||
ids = []
|
||||
@ -512,7 +502,7 @@ class PineconeDocumentStore(BaseDocumentStore):
|
||||
ids.append(doc.id)
|
||||
# Update existing vectors in pinecone index
|
||||
self.pinecone_indexes[index].upsert(
|
||||
vectors=zip(ids, embeddings, metadata), namespace=self.embedding_namespace
|
||||
vectors=zip(ids, embeddings.tolist(), metadata), namespace=self.embedding_namespace
|
||||
)
|
||||
# Delete existing vectors from document namespace if they exist there
|
||||
self.delete_documents(index=index, ids=ids, namespace=self.document_namespace)
|
||||
|
||||
@ -23,6 +23,7 @@ from haystack.document_stores.base import get_batches_from_generator
|
||||
from haystack.document_stores.filter_utils import LogicalFilterClause
|
||||
from haystack.document_stores.utils import convert_date_to_rfc3339
|
||||
from haystack.errors import DocumentStoreError
|
||||
from haystack.nodes.retriever import DenseRetriever
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -1166,7 +1167,7 @@ class WeaviateDocumentStore(BaseDocumentStore):
|
||||
|
||||
def update_embeddings(
|
||||
self,
|
||||
retriever,
|
||||
retriever: DenseRetriever,
|
||||
index: Optional[str] = None,
|
||||
filters: Optional[Dict[str, Union[Dict, List, str, int, float, bool]]] = None,
|
||||
update_existing_embeddings: bool = True,
|
||||
@ -1230,21 +1231,16 @@ class WeaviateDocumentStore(BaseDocumentStore):
|
||||
document_batch = [
|
||||
self._convert_weaviate_result_to_document(hit, return_embedding=False) for hit in result_batch
|
||||
]
|
||||
embeddings = retriever.embed_documents(document_batch) # type: ignore
|
||||
if len(document_batch) != len(embeddings):
|
||||
raise DocumentStoreError(
|
||||
"The number of embeddings does not match the number of documents in the batch "
|
||||
f"({len(embeddings)} != {len(document_batch)})"
|
||||
)
|
||||
if embeddings[0].shape[0] != self.embedding_dim:
|
||||
raise RuntimeError(
|
||||
f"Embedding dimensions of the model ({embeddings[0].shape[0]}) doesn't match the embedding dimensions of the document store ({self.embedding_dim}). Please reinitiate WeaviateDocumentStore() with arg embedding_dim={embeddings[0].shape[0]}."
|
||||
)
|
||||
embeddings = retriever.embed_documents(document_batch)
|
||||
self._validate_embeddings_shape(
|
||||
embeddings=embeddings, num_documents=len(document_batch), embedding_dim=self.embedding_dim
|
||||
)
|
||||
|
||||
if self.similarity == "cosine":
|
||||
self.normalize_embedding(embeddings)
|
||||
|
||||
for doc, emb in zip(document_batch, embeddings):
|
||||
# Using update method to only update the embeddings, other properties will be in tact
|
||||
if self.similarity == "cosine":
|
||||
self.normalize_embedding(emb)
|
||||
self.weaviate_client.data_object.update({}, class_name=index, uuid=doc.id, vector=emb)
|
||||
|
||||
def delete_all_documents(
|
||||
|
||||
@ -5570,10 +5570,10 @@
|
||||
"title": "Use Auth Token",
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "boolean"
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "string"
|
||||
"type": "boolean"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
|
||||
@ -29,6 +29,7 @@ from haystack.nodes.ranker import BaseRanker, SentenceTransformersRanker
|
||||
from haystack.nodes.reader import BaseReader, FARMReader, TransformersReader, TableReader, RCIReader
|
||||
from haystack.nodes.retriever import (
|
||||
BaseRetriever,
|
||||
DenseRetriever,
|
||||
DensePassageRetriever,
|
||||
EmbeddingRetriever,
|
||||
BM25Retriever,
|
||||
|
||||
@ -188,7 +188,7 @@ class RAGenerator(BaseGenerator):
|
||||
self.devices[0]
|
||||
)
|
||||
|
||||
def _prepare_passage_embeddings(self, docs: List[Document], embeddings: List[numpy.ndarray]) -> torch.Tensor:
|
||||
def _prepare_passage_embeddings(self, docs: List[Document], embeddings: numpy.ndarray) -> torch.Tensor:
|
||||
|
||||
# If document missing embedding, then need embedding for all the documents
|
||||
is_embedding_required = embeddings is None or any(embedding is None for embedding in embeddings)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from haystack.nodes.retriever.base import BaseRetriever
|
||||
from haystack.nodes.retriever.dense import (
|
||||
DenseRetriever,
|
||||
DensePassageRetriever,
|
||||
EmbeddingRetriever,
|
||||
MultihopEmbeddingRetriever,
|
||||
|
||||
@ -26,22 +26,22 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class _BaseEmbeddingEncoder:
|
||||
@abstractmethod
|
||||
def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
|
||||
def embed_queries(self, queries: List[str]) -> np.ndarray:
|
||||
"""
|
||||
Create embeddings for a list of queries.
|
||||
|
||||
:param texts: Queries to embed
|
||||
:return: Embeddings, one per input queries
|
||||
:param queries: List of queries to embed.
|
||||
:return: Embeddings, one per input query, shape: (queries, embedding_dim)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
|
||||
def embed_documents(self, docs: List[Document]) -> np.ndarray:
|
||||
"""
|
||||
Create embeddings for a list of documents.
|
||||
|
||||
:param docs: List of documents to embed
|
||||
:return: Embeddings, one per input document
|
||||
:param docs: List of documents to embed.
|
||||
:return: Embeddings, one per input document, shape: (documents, embedding_dim)
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -118,18 +118,30 @@ class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
f"This can be set when initializing the DocumentStore"
|
||||
)
|
||||
|
||||
def embed(self, texts: Union[List[List[str]], List[str], str]) -> List[np.ndarray]:
|
||||
def embed(self, texts: Union[List[List[str]], List[str], str]) -> np.ndarray:
|
||||
# TODO: FARM's `sample_to_features_text` need to fix following warning -
|
||||
# tokenization_utils.py:460: FutureWarning: `is_pretokenized` is deprecated and will be removed in a future version, use `is_split_into_words` instead.
|
||||
emb = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts])
|
||||
emb = [(r["vec"]) for r in emb]
|
||||
emb = np.stack([r["vec"] for r in emb])
|
||||
return emb
|
||||
|
||||
def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
|
||||
return self.embed(texts)
|
||||
def embed_queries(self, queries: List[str]) -> np.ndarray:
|
||||
"""
|
||||
Create embeddings for a list of queries.
|
||||
|
||||
def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
|
||||
passages = [d.content for d in docs] # type: ignore
|
||||
:param queries: List of queries to embed.
|
||||
:return: Embeddings, one per input query, shape: (queries, embedding_dim)
|
||||
"""
|
||||
return self.embed(queries)
|
||||
|
||||
def embed_documents(self, docs: List[Document]) -> np.ndarray:
|
||||
"""
|
||||
Create embeddings for a list of documents.
|
||||
|
||||
:param docs: List of documents to embed.
|
||||
:return: Embeddings, one per input document, shape: (documents, embedding_dim)
|
||||
"""
|
||||
passages = [d.content for d in docs]
|
||||
return self.embed(passages)
|
||||
|
||||
def train(
|
||||
@ -175,18 +187,31 @@ class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
f"This can be set when initializing the DocumentStore"
|
||||
)
|
||||
|
||||
def embed(self, texts: Union[List[List[str]], List[str], str]) -> List[np.ndarray]:
|
||||
def embed(self, texts: Union[List[List[str]], List[str], str]) -> np.ndarray:
|
||||
# texts can be a list of strings or a list of [title, text]
|
||||
# get back list of numpy embedding vectors
|
||||
emb = self.embedding_model.encode(texts, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar)
|
||||
emb = [r for r in emb]
|
||||
emb = self.embedding_model.encode(
|
||||
texts, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True
|
||||
)
|
||||
return emb
|
||||
|
||||
def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
|
||||
return self.embed(texts)
|
||||
def embed_queries(self, queries: List[str]) -> np.ndarray:
|
||||
"""
|
||||
Create embeddings for a list of queries.
|
||||
|
||||
def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
|
||||
passages = [[d.meta["name"] if d.meta and "name" in d.meta else "", d.content] for d in docs] # type: ignore
|
||||
:param queries: List of queries to embed.
|
||||
:return: Embeddings, one per input query, shape: (queries, embedding_dim)
|
||||
"""
|
||||
return self.embed(queries)
|
||||
|
||||
def embed_documents(self, docs: List[Document]) -> np.ndarray:
|
||||
"""
|
||||
Create embeddings for a list of documents.
|
||||
|
||||
:param docs: List of documents to embed.
|
||||
:return: Embeddings, one per input document, shape: (documents, embedding_dim)
|
||||
"""
|
||||
passages = [[d.meta["name"] if d.meta and "name" in d.meta else "", d.content] for d in docs]
|
||||
return self.embed(passages)
|
||||
|
||||
def train(
|
||||
@ -250,10 +275,15 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
retriever.embedding_model, use_auth_token=retriever.use_auth_token
|
||||
).to(str(retriever.devices[0]))
|
||||
|
||||
def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
|
||||
def embed_queries(self, queries: List[str]) -> np.ndarray:
|
||||
"""
|
||||
Create embeddings for a list of queries.
|
||||
|
||||
queries = [{"text": q} for q in texts]
|
||||
dataloader = self._create_dataloader(queries)
|
||||
:param queries: List of queries to embed.
|
||||
:return: Embeddings, one per input query, shape: (queries, embedding_dim)
|
||||
"""
|
||||
query_text = [{"text": q} for q in queries]
|
||||
dataloader = self._create_dataloader(query_text)
|
||||
|
||||
embeddings: List[np.ndarray] = []
|
||||
disable_tqdm = True if len(dataloader) == 1 else not self.progress_bar
|
||||
@ -272,8 +302,13 @@ class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
|
||||
|
||||
return np.concatenate(embeddings)
|
||||
|
||||
def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
|
||||
def embed_documents(self, docs: List[Document]) -> np.ndarray:
|
||||
"""
|
||||
Create embeddings for a list of documents.
|
||||
|
||||
:param docs: List of documents to embed.
|
||||
:return: Embeddings, one per input document, shape: (documents, embedding_dim)
|
||||
"""
|
||||
doc_text = [{"text": d.content} for d in docs]
|
||||
dataloader = self._create_dataloader(doc_text)
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ import logging
|
||||
from abc import abstractmethod
|
||||
from time import perf_counter
|
||||
from functools import wraps
|
||||
from copy import deepcopy
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
@ -263,7 +262,7 @@ class BaseRetriever(BaseComponent):
|
||||
query: Optional[str] = None,
|
||||
filters: Optional[dict] = None,
|
||||
top_k: Optional[int] = None,
|
||||
documents: Optional[List[dict]] = None,
|
||||
documents: Optional[List[Document]] = None,
|
||||
index: Optional[str] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
scale_score: bool = None,
|
||||
@ -279,7 +278,7 @@ class BaseRetriever(BaseComponent):
|
||||
query=query, filters=filters, top_k=top_k, index=index, headers=headers, scale_score=scale_score
|
||||
)
|
||||
elif root_node == "File":
|
||||
self.index_count += len(documents) # type: ignore
|
||||
self.index_count += len(documents) if documents else 0
|
||||
run_indexing = self.timing(self.run_indexing, "index_time")
|
||||
output, stream = run_indexing(documents=documents)
|
||||
else:
|
||||
@ -362,13 +361,7 @@ class BaseRetriever(BaseComponent):
|
||||
|
||||
return output, "output_1"
|
||||
|
||||
def run_indexing(self, documents: List[Union[dict, Document]]):
|
||||
if self.__class__.__name__ in ["DensePassageRetriever", "EmbeddingRetriever"]:
|
||||
documents = deepcopy(documents)
|
||||
document_objects = [Document.from_dict(doc) if isinstance(doc, dict) else doc for doc in documents]
|
||||
embeddings = self.embed_documents(document_objects) # type: ignore
|
||||
for doc, emb in zip(document_objects, embeddings):
|
||||
doc.embedding = emb
|
||||
def run_indexing(self, documents: List[Document]):
|
||||
output = {"documents": documents}
|
||||
return output, "output_1"
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List, Dict, Union, Optional, Any
|
||||
|
||||
import logging
|
||||
@ -42,7 +43,40 @@ from haystack.modeling.utils import initialize_device_settings
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DensePassageRetriever(BaseRetriever):
class DenseRetriever(BaseRetriever):
    """
    Base class for all dense retrievers.
    """

    @abstractmethod
    def embed_queries(self, queries: List[str]) -> np.ndarray:
        """
        Create embeddings for a list of queries.

        :param queries: List of queries to embed.
        :return: Embeddings, one per input query, shape: (queries, embedding_dim)
        """
        pass

    @abstractmethod
    def embed_documents(self, documents: List[Document]) -> np.ndarray:
        """
        Create embeddings for a list of documents.

        :param documents: List of documents to embed.
        :return: Embeddings of documents, one per input document, shape: (documents, embedding_dim)
        """
        pass

    def run_indexing(self, documents: List[Document]):
        embeddings = self.embed_documents(documents)
        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb
        output = {"documents": documents}
        return output, "output_1"

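The `run_indexing` hook defined above is how a dense retriever attaches embeddings to documents when it sits in an indexing pipeline (root node "File"). A small sketch of calling it directly; the model name and the 384-dimensional store are illustrative choices, not something this commit prescribes:

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever
from haystack.schema import Document

document_store = InMemoryDocumentStore(embedding_dim=384, similarity="cosine")
retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # illustrative model choice
)

docs = [Document(content="Dense retrievers embed documents into vectors.")]
output, edge = retriever.run_indexing(documents=docs)

assert output["documents"][0].embedding.shape == (384,)  # one row of the (1, 384) batch
assert edge == "output_1"
```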
class DensePassageRetriever(DenseRetriever):
|
||||
"""
|
||||
Retriever that uses a bi-encoder (one transformer for query, one transformer for passage).
|
||||
See the original paper for more details:
|
||||
@ -302,7 +336,7 @@ class DensePassageRetriever(BaseRetriever):
|
||||
index = self.document_store.index
|
||||
if scale_score is None:
|
||||
scale_score = self.scale_score
|
||||
query_emb = self.embed_queries(texts=[query])
|
||||
query_emb = self.embed_queries(queries=[query])
|
||||
documents = self.document_store.query_by_embedding(
|
||||
query_emb=query_emb[0], top_k=top_k, filters=filters, index=index, headers=headers, scale_score=scale_score
|
||||
)
|
||||
@ -430,9 +464,9 @@ class DensePassageRetriever(BaseRetriever):
|
||||
return [[] * len(queries)] # type: ignore
|
||||
|
||||
documents = []
|
||||
query_embs = []
|
||||
query_embs: List[np.ndarray] = []
|
||||
for batch in self._get_batches(queries=queries, batch_size=batch_size):
|
||||
query_embs.extend(self.embed_queries(texts=batch))
|
||||
query_embs.extend(self.embed_queries(queries=batch))
|
||||
for query_emb, cur_filters in tqdm(
|
||||
zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying"
|
||||
):
|
||||
@ -448,7 +482,7 @@ class DensePassageRetriever(BaseRetriever):
|
||||
|
||||
return documents
|
||||
|
||||
def _get_predictions(self, dicts):
|
||||
def _get_predictions(self, dicts: List[Dict[str, Any]]) -> Dict[str, np.ndarray]:
|
||||
"""
|
||||
Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
|
||||
|
||||
@ -472,7 +506,8 @@ class DensePassageRetriever(BaseRetriever):
|
||||
data_loader = NamedDataLoader(
|
||||
dataset=dataset, sampler=SequentialSampler(dataset), batch_size=self.batch_size, tensor_names=tensor_names
|
||||
)
|
||||
all_embeddings = {"query": [], "passages": []}
|
||||
query_embeddings_batched = []
|
||||
passage_embeddings_batched = []
|
||||
self.model.eval()
|
||||
|
||||
# When running evaluations etc., we don't want a progress bar for every single query
|
||||
@ -503,34 +538,35 @@ class DensePassageRetriever(BaseRetriever):
|
||||
passage_attention_mask=batch.get("passage_attention_mask", None),
|
||||
)[0]
|
||||
if query_embeddings is not None:
|
||||
all_embeddings["query"].append(query_embeddings.cpu().numpy())
|
||||
query_embeddings_batched.append(query_embeddings.cpu().numpy())
|
||||
if passage_embeddings is not None:
|
||||
all_embeddings["passages"].append(passage_embeddings.cpu().numpy())
|
||||
passage_embeddings_batched.append(passage_embeddings.cpu().numpy())
|
||||
progress_bar.update(self.batch_size)
|
||||
|
||||
if all_embeddings["passages"]:
|
||||
all_embeddings["passages"] = np.concatenate(all_embeddings["passages"])
|
||||
if all_embeddings["query"]:
|
||||
all_embeddings["query"] = np.concatenate(all_embeddings["query"])
|
||||
all_embeddings: Dict[str, np.ndarray] = {}
|
||||
if passage_embeddings_batched:
|
||||
all_embeddings["passages"] = np.concatenate(passage_embeddings_batched)
|
||||
if query_embeddings_batched:
|
||||
all_embeddings["query"] = np.concatenate(query_embeddings_batched)
|
||||
return all_embeddings
|
||||
|
||||
def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
|
||||
def embed_queries(self, queries: List[str]) -> np.ndarray:
|
||||
"""
|
||||
Create embeddings for a list of queries using the query encoder
|
||||
Create embeddings for a list of queries using the query encoder.
|
||||
|
||||
:param texts: Queries to embed
|
||||
:return: Embeddings, one per input queries
|
||||
:param queries: List of queries to embed.
|
||||
:return: Embeddings, one per input query, shape: (queries, embedding_dim)
|
||||
"""
|
||||
queries = [{"query": q} for q in texts]
|
||||
result = self._get_predictions(queries)["query"]
|
||||
query_dicts = [{"query": q} for q in queries]
|
||||
result = self._get_predictions(query_dicts)["query"]
|
||||
return result
|
||||
|
||||
def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
|
||||
def embed_documents(self, documents: List[Document]) -> np.ndarray:
|
||||
"""
|
||||
Create embeddings for a list of documents using the passage encoder
|
||||
Create embeddings for a list of documents using the passage encoder.
|
||||
|
||||
:param docs: List of Document objects used to represent documents / passages in a standardized way within Haystack.
|
||||
:return: Embeddings of documents / passages shape (batch_size, embedding_dim)
|
||||
:param documents: List of documents to embed.
|
||||
:return: Embeddings of documents, one per input document, shape: (documents, embedding_dim)
|
||||
"""
|
||||
if self.processor.num_hard_negatives != 0:
|
||||
logger.warning(
|
||||
@ -550,7 +586,7 @@ class DensePassageRetriever(BaseRetriever):
|
||||
}
|
||||
]
|
||||
}
|
||||
for d in docs
|
||||
for d in documents
|
||||
]
|
||||
embeddings = self._get_predictions(passages)["passages"]
|
||||
return embeddings
|
||||
@ -757,7 +793,7 @@ class DensePassageRetriever(BaseRetriever):
|
||||
return dpr
|
||||
|
||||
|
||||
class TableTextRetriever(BaseRetriever):
|
||||
class TableTextRetriever(DenseRetriever):
|
||||
"""
|
||||
Retriever that uses a tri-encoder to jointly retrieve among a database consisting of text passages and tables
|
||||
(one transformer for query, one transformer for text passages, one transformer for tables).
|
||||
@ -950,7 +986,7 @@ class TableTextRetriever(BaseRetriever):
|
||||
index = self.document_store.index
|
||||
if scale_score is None:
|
||||
scale_score = self.scale_score
|
||||
query_emb = self.embed_queries(texts=[query])
|
||||
query_emb = self.embed_queries(queries=[query])
|
||||
documents = self.document_store.query_by_embedding(
|
||||
query_emb=query_emb[0], top_k=top_k, filters=filters, index=index, headers=headers, scale_score=scale_score
|
||||
)
|
||||
@ -1078,9 +1114,9 @@ class TableTextRetriever(BaseRetriever):
|
||||
return [[] * len(queries)] # type: ignore
|
||||
|
||||
documents = []
|
||||
query_embs = []
|
||||
query_embs: List[np.ndarray] = []
|
||||
for batch in self._get_batches(queries=queries, batch_size=batch_size):
|
||||
query_embs.extend(self.embed_queries(texts=batch))
|
||||
query_embs.extend(self.embed_queries(queries=batch))
|
||||
for query_emb, cur_filters in tqdm(
|
||||
zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying"
|
||||
):
|
||||
@ -1096,7 +1132,7 @@ class TableTextRetriever(BaseRetriever):
|
||||
|
||||
return documents
|
||||
|
||||
def _get_predictions(self, dicts: List[Dict]) -> Dict[str, List[np.ndarray]]:
|
||||
def _get_predictions(self, dicts: List[Dict[str, Any]]) -> Dict[str, np.ndarray]:
|
||||
"""
|
||||
Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
|
||||
|
||||
@ -1121,7 +1157,8 @@ class TableTextRetriever(BaseRetriever):
|
||||
data_loader = NamedDataLoader(
|
||||
dataset=dataset, sampler=SequentialSampler(dataset), batch_size=self.batch_size, tensor_names=tensor_names
|
||||
)
|
||||
all_embeddings: Dict = {"query": [], "passages": []}
|
||||
query_embeddings_batched = []
|
||||
passage_embeddings_batched = []
|
||||
self.model.eval()
|
||||
|
||||
# When running evaluations etc., we don't want a progress bar for every single query
|
||||
@ -1145,36 +1182,36 @@ class TableTextRetriever(BaseRetriever):
|
||||
with torch.no_grad():
|
||||
query_embeddings, passage_embeddings = self.model.forward(**batch)[0]
|
||||
if query_embeddings is not None:
|
||||
all_embeddings["query"].append(query_embeddings.cpu().numpy())
|
||||
query_embeddings_batched.append(query_embeddings.cpu().numpy())
|
||||
if passage_embeddings is not None:
|
||||
all_embeddings["passages"].append(passage_embeddings.cpu().numpy())
|
||||
passage_embeddings_batched.append(passage_embeddings.cpu().numpy())
|
||||
progress_bar.update(self.batch_size)
|
||||
|
||||
if all_embeddings["passages"]:
|
||||
all_embeddings["passages"] = np.concatenate(all_embeddings["passages"])
|
||||
if all_embeddings["query"]:
|
||||
all_embeddings["query"] = np.concatenate(all_embeddings["query"])
|
||||
all_embeddings: Dict[str, np.ndarray] = {}
|
||||
if passage_embeddings_batched:
|
||||
all_embeddings["passages"] = np.concatenate(passage_embeddings_batched)
|
||||
if query_embeddings_batched:
|
||||
all_embeddings["query"] = np.concatenate(query_embeddings_batched)
|
||||
return all_embeddings
|
||||
|
||||
def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
|
||||
def embed_queries(self, queries: List[str]) -> np.ndarray:
|
||||
"""
|
||||
Create embeddings for a list of queries using the query encoder
|
||||
Create embeddings for a list of queries using the query encoder.
|
||||
|
||||
:param texts: Queries to embed
|
||||
:return: Embeddings, one per input queries
|
||||
:param queries: List of queries to embed.
|
||||
:return: Embeddings, one per input query, shape: (queries, embedding_dim)
|
||||
"""
|
||||
queries = [{"query": q} for q in texts]
|
||||
result = self._get_predictions(queries)["query"]
|
||||
query_dicts = [{"query": q} for q in queries]
|
||||
result = self._get_predictions(query_dicts)["query"]
|
||||
return result
|
||||
|
||||
def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
|
||||
def embed_documents(self, documents: List[Document]) -> np.ndarray:
|
||||
"""
|
||||
Create embeddings for a list of text documents and / or tables using the text passage encoder and
|
||||
the table encoder.
|
||||
|
||||
:param docs: List of Document objects used to represent documents / passages in
|
||||
a standardized way within Haystack.
|
||||
:return: Embeddings of documents / passages. Shape: (batch_size, embedding_dim)
|
||||
:param documents: List of documents to embed.
|
||||
:return: Embeddings of documents, one per input document, shape: (documents, embedding_dim)
|
||||
"""
|
||||
|
||||
if self.processor.num_hard_negatives != 0:
|
||||
@ -1185,7 +1222,7 @@ class TableTextRetriever(BaseRetriever):
|
||||
self.processor.num_hard_negatives = 0
|
||||
|
||||
model_input = []
|
||||
for doc in docs:
|
||||
for doc in documents:
|
||||
if doc.content_type == "table":
|
||||
model_input.append(
|
||||
{
|
||||
@ -1440,7 +1477,7 @@ class TableTextRetriever(BaseRetriever):
|
||||
return mm_retriever
|
||||
|
||||
|
||||
class EmbeddingRetriever(BaseRetriever):
|
||||
class EmbeddingRetriever(DenseRetriever):
|
||||
def __init__(
|
||||
self,
|
||||
document_store: BaseDocumentStore,
|
||||
@ -1639,7 +1676,7 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
index = self.document_store.index
|
||||
if scale_score is None:
|
||||
scale_score = self.scale_score
|
||||
query_emb = self.embed_queries(texts=[query])
|
||||
query_emb = self.embed_queries(queries=[query])
|
||||
documents = self.document_store.query_by_embedding(
|
||||
query_emb=query_emb[0], filters=filters, top_k=top_k, index=index, headers=headers, scale_score=scale_score
|
||||
)
|
||||
@ -1767,9 +1804,9 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
return [[] * len(queries)] # type: ignore
|
||||
|
||||
documents = []
|
||||
query_embs = []
|
||||
query_embs: List[np.ndarray] = []
|
||||
for batch in self._get_batches(queries=queries, batch_size=batch_size):
|
||||
query_embs.extend(self.embed_queries(texts=batch))
|
||||
query_embs.extend(self.embed_queries(queries=batch))
|
||||
for query_emb, cur_filters in tqdm(
|
||||
zip(query_embs, filters), total=len(query_embs), disable=not self.progress_bar, desc="Querying"
|
||||
):
|
||||
@ -1785,28 +1822,28 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
|
||||
return documents
|
||||
|
||||
def embed_queries(self, texts: List[str]) -> List[np.ndarray]:
|
||||
def embed_queries(self, queries: List[str]) -> np.ndarray:
|
||||
"""
|
||||
Create embeddings for a list of queries.
|
||||
|
||||
:param texts: Queries to embed
|
||||
:return: Embeddings, one per input queries
|
||||
:param queries: List of queries to embed.
|
||||
:return: Embeddings, one per input query, shape: (queries, embedding_dim)
|
||||
"""
|
||||
# for backward compatibility: cast pure str input
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
assert isinstance(texts, list), "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
|
||||
return self.embedding_encoder.embed_queries(texts)
|
||||
if isinstance(queries, str):
|
||||
queries = [queries]
|
||||
assert isinstance(queries, list), "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
|
||||
return self.embedding_encoder.embed_queries(queries)
|
||||
|
||||
def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
|
||||
def embed_documents(self, documents: List[Document]) -> np.ndarray:
|
||||
"""
|
||||
Create embeddings for a list of documents.
|
||||
|
||||
:param docs: List of documents to embed
|
||||
:return: Embeddings, one per input document
|
||||
:param documents: List of documents to embed.
|
||||
:return: Embeddings, one per input document, shape: (docs, embedding_dim)
|
||||
"""
|
||||
docs = self._preprocess_documents(docs)
|
||||
return self.embedding_encoder.embed_documents(docs)
|
||||
documents = self._preprocess_documents(documents)
|
||||
return self.embedding_encoder.embed_documents(documents)
|
||||
|
||||
def _preprocess_documents(self, docs: List[Document]) -> List[Document]:
|
||||
"""
|
||||
|
||||
@ -57,6 +57,7 @@ from haystack.nodes import (
|
||||
BaseGenerator,
|
||||
BaseSummarizer,
|
||||
BaseTranslator,
|
||||
DenseRetriever,
|
||||
)
|
||||
from haystack.nodes.answer_generator.transformers import Seq2SeqGenerator
|
||||
from haystack.nodes.answer_generator.transformers import RAGenerator
|
||||
@ -308,16 +309,16 @@ class MockTranslator(BaseTranslator):
|
||||
pass
|
||||
|
||||
|
||||
class MockDenseRetriever(MockRetriever):
|
||||
class MockDenseRetriever(MockRetriever, DenseRetriever):
|
||||
def __init__(self, document_store: BaseDocumentStore, embedding_dim: int = 768):
|
||||
self.embedding_dim = embedding_dim
|
||||
self.document_store = document_store
|
||||
|
||||
def embed_queries(self, texts):
|
||||
return [np.random.rand(self.embedding_dim)] * len(texts)
|
||||
def embed_queries(self, queries):
|
||||
return np.random.rand(len(queries), self.embedding_dim)
|
||||
|
||||
def embed_documents(self, docs):
|
||||
return [np.random.rand(self.embedding_dim)] * len(docs)
|
||||
def embed_documents(self, documents):
|
||||
return np.random.rand(len(documents), self.embedding_dim)
|
||||
|
||||
|
||||
class MockQuestionGenerator(QuestionGenerator):
|
||||
|
||||
@ -576,7 +576,7 @@ def test_update_embeddings(document_store, retriever):
|
||||
"content": "text_7",
|
||||
"id": "7",
|
||||
"meta_field": "value_7",
|
||||
"embedding": retriever.embed_queries(texts=["a random string"])[0],
|
||||
"embedding": retriever.embed_queries(queries=["a random string"])[0],
|
||||
}
|
||||
document_store.write_documents([doc])
|
||||
|
||||
|
||||
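One detail worth noting in this test change: because `embed_queries` now returns a single 2-D array, indexing with `[0]` picks out the vector for the one query that was passed. A tiny stand-in sketch (the 768 dimension is just an example):

```python
import numpy as np

# Stand-in for retriever.embed_queries(queries=["a random string"]) after this commit:
embeddings = np.random.rand(1, 768)

single_vector = embeddings[0]  # [0] selects the one query's vector from the 2-D array
assert single_vector.shape == (768,)
```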