mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-25 16:15:35 +00:00

* move embedding backends * use token in Sentence Transformers embeddings * more compact token handling * token parameter in reader * add token to ranker * release note * add test for reader
142 lines
6.0 KiB
Python
142 lines
6.0 KiB
Python
from typing import List, Optional, Union, Dict, Any
|
|
|
|
from haystack.preview import component, Document, default_to_dict, default_from_dict
|
|
from haystack.preview.components.embedders.backends.sentence_transformers_backend import (
|
|
_SentenceTransformersEmbeddingBackendFactory,
|
|
)
|
|
|
|
|
|
@component
class SentenceTransformersDocumentEmbedder:
    """
    A component for computing Document embeddings using Sentence Transformers models.

    The embedding of each Document is stored in the `embedding` field of the Document.
    The embedding backend is not loaded by the constructor: `warm_up()` must be called
    before `run()`.
    """

    def __init__(
        self,
        model_name_or_path: str = "sentence-transformers/all-mpnet-base-v2",
        device: Optional[str] = None,
        token: Union[bool, str, None] = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        normalize_embeddings: bool = False,
        metadata_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
    ):
        """
        Create a SentenceTransformersDocumentEmbedder component.

        :param model_name_or_path: Local path or name of the model in Hugging Face's model hub,
            such as ``'sentence-transformers/all-mpnet-base-v2'``.
        :param device: Device (like 'cuda' / 'cpu') that should be used for computation.
            Defaults to CPU.
        :param token: The API token used to download private models from Hugging Face.
            If this parameter is set to `True`, then the token generated when running
            `transformers-cli login` (stored in ~/.huggingface) will be used.
        :param prefix: A string to add to the beginning of each Document text before embedding.
        :param suffix: A string to add to the end of each Document text before embedding.
        :param batch_size: Number of strings to encode at once.
        :param progress_bar: If true, displays progress bar during embedding.
        :param normalize_embeddings: If set to true, returned vectors will have length 1.
        :param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document content.
            `None` is treated as an empty list.
        :param embedding_separator: Separator used to concatenate the meta fields to the Document content.
        """
        self.model_name_or_path = model_name_or_path
        # TODO: remove device parameter and use Haystack's device management once migrated
        self.device = device or "cpu"
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.normalize_embeddings = normalize_embeddings
        self.metadata_fields_to_embed = metadata_fields_to_embed or []
        self.embedding_separator = embedding_separator

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.

        :return: A dictionary containing only the model name or path.
        """
        return {"model": self.model_name_or_path}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        A string token is replaced by `None` so that secrets never end up in the serialized
        data; `True`, `False` and `None` are serialized unchanged.

        :return: The dictionary produced by `default_to_dict` with the init parameters.
        """
        # Only non-string values (True / False / None) are safe to write out.
        serializable_token = None if isinstance(self.token, str) else self.token
        return default_to_dict(
            self,
            model_name_or_path=self.model_name_or_path,
            device=self.device,
            token=serializable_token,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
            metadata_fields_to_embed=self.metadata_fields_to_embed,
            embedding_separator=self.embedding_separator,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
        """
        Deserialize this component from a dictionary.

        :param data: Dictionary to deserialize from; it is handed to `default_from_dict` as is.
        :return: The deserialized component.
        """
        return default_from_dict(cls, data)

    def warm_up(self):
        """
        Load the embedding backend.

        Does nothing if the `embedding_backend` attribute has already been set.
        """
        if hasattr(self, "embedding_backend"):
            return
        self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
            model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.token
        )

    def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
        """
        Build, for each Document, the string that is passed to the embedding backend.

        The selected metadata values and the Document text are joined with `embedding_separator`
        and wrapped in `prefix` / `suffix`.

        :param documents: Documents to build the texts for.
        :return: One string per Document, in the same order.
        """
        texts_to_embed = []
        for doc in documents:
            meta_values_to_embed = []
            for key in self.metadata_fields_to_embed:
                if key not in doc.metadata:
                    continue
                value = doc.metadata[key]
                # Numbers and booleans are always kept: 0, 0.0 and False are meaningful values
                # that a plain truthiness test would silently drop. Everything else (None,
                # empty strings, empty containers) is still skipped when falsy.
                if isinstance(value, (int, float)) or value:
                    meta_values_to_embed.append(str(value))

            # A Document without text contributes an empty string.
            joined = self.embedding_separator.join(meta_values_to_embed + [doc.text or ""])
            texts_to_embed.append(self.prefix + joined + self.suffix)
        return texts_to_embed

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        """
        Embed a list of Documents.

        The embedding of each Document is stored in the `embedding` field of the Document.
        The input Documents are not modified: new Documents are built from their dictionary
        representation with the embedding added.

        :param documents: The Documents to embed.
        :return: A dictionary with the key `documents` holding the Documents with embeddings.
        :raises TypeError: If `documents` is not a list, or its first element is not a Document.
        :raises RuntimeError: If `warm_up()` has not been called before running.
        """
        # Only the first element is inspected, which keeps the check cheap for long lists.
        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            raise TypeError(
                "SentenceTransformersDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder."
            )
        if not hasattr(self, "embedding_backend"):
            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")

        # TODO: once non textual Documents are properly supported, we should also prepare them for embedding here
        texts_to_embed = self._prepare_texts_to_embed(documents)

        embeddings = self.embedding_backend.embed(
            texts_to_embed,
            batch_size=self.batch_size,
            show_progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
        )

        documents_with_embeddings = []
        for doc, emb in zip(documents, embeddings):
            doc_as_dict = doc.to_dict()
            doc_as_dict["embedding"] = emb
            documents_with_embeddings.append(Document.from_dict(doc_as_dict))

        return {"documents": documents_with_embeddings}
|