from typing import List, Optional, Union, Dict, Any

from haystack.preview import component, Document, default_to_dict, default_from_dict
from haystack.preview.components.embedders.backends.sentence_transformers_backend import (
    _SentenceTransformersEmbeddingBackendFactory,
)


@component
class SentenceTransformersDocumentEmbedder:
    """
    A component for computing Document embeddings using Sentence Transformers models.
    The embedding of each Document is stored in the `embedding` field of the Document.
    """

    def __init__(
        self,
        model_name_or_path: str = "sentence-transformers/all-mpnet-base-v2",
        device: Optional[str] = None,
        token: Union[bool, str, None] = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        normalize_embeddings: bool = False,
        metadata_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
    ):
        """
        Create a SentenceTransformersDocumentEmbedder component.

        :param model_name_or_path: Local path or name of the model in Hugging Face's model hub,
            such as ``'sentence-transformers/all-mpnet-base-v2'``.
        :param device: Device (like 'cuda' / 'cpu') that should be used for computation. Defaults to CPU.
        :param token: The API token used to download private models from Hugging Face.
            If this parameter is set to `True`, then the token generated when running
            `transformers-cli login` (stored in ~/.huggingface) will be used.
        :param prefix: A string to add to the beginning of each Document text before embedding.
        :param suffix: A string to add to the end of each Document text before embedding.
        :param batch_size: Number of strings to encode at once.
        :param progress_bar: If true, displays progress bar during embedding.
        :param normalize_embeddings: If set to true, returned vectors will have length 1.
        :param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document content.
        :param embedding_separator: Separator used to concatenate the meta fields to the Document content.
        """
        self.model_name_or_path = model_name_or_path
        # TODO: remove device parameter and use Haystack's device management once migrated
        self.device = device or "cpu"
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.normalize_embeddings = normalize_embeddings
        self.metadata_fields_to_embed = metadata_fields_to_embed or []
        self.embedding_separator = embedding_separator

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model_name_or_path}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.
        """
        return default_to_dict(
            self,
            model_name_or_path=self.model_name_or_path,
            device=self.device,
            token=self.token if not isinstance(self.token, str) else None,  # don't serialize valid tokens
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
            metadata_fields_to_embed=self.metadata_fields_to_embed,
            embedding_separator=self.embedding_separator,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDocumentEmbedder":
        """
        Deserialize this component from a dictionary.
        """
        return default_from_dict(cls, data)

    def warm_up(self):
        """
        Load the embedding backend.

        Idempotent: the backend is created only on the first call.
        """
        if not hasattr(self, "embedding_backend"):
            self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
                model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.token
            )

    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        """
        Embed a list of Documents.
        The embedding of each Document is stored in the `embedding` field of the Document.

        :param documents: The Documents to embed.
        :return: A dictionary with a ``documents`` key holding new Document instances
            carrying the computed embeddings.
        :raises TypeError: If ``documents`` is not a list of Document objects.
        :raises RuntimeError: If ``warm_up()`` was not called before ``run()``.
        """
        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            raise TypeError(
                # Note the trailing space: these adjacent literals are concatenated.
                "SentenceTransformersDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a list of strings, please use the SentenceTransformersTextEmbedder."
            )
        if not hasattr(self, "embedding_backend"):
            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")

        # TODO: once non textual Documents are properly supported, we should also prepare them for embedding here

        # Build one string per Document: prefix + selected meta values + text + suffix,
        # joined with the configured separator. Empty/falsy meta values are skipped.
        texts_to_embed = []
        for doc in documents:
            meta_values_to_embed = [
                str(doc.metadata[key])
                for key in self.metadata_fields_to_embed
                if key in doc.metadata and doc.metadata[key]
            ]
            text_to_embed = (
                self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.text or ""]) + self.suffix
            )
            texts_to_embed.append(text_to_embed)

        embeddings = self.embedding_backend.embed(
            texts_to_embed,
            batch_size=self.batch_size,
            show_progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
        )

        # Documents are immutable-by-convention here: rebuild each one with its embedding
        # instead of mutating the input objects in place.
        documents_with_embeddings = []
        for doc, emb in zip(documents, embeddings):
            doc_as_dict = doc.to_dict()
            doc_as_dict["embedding"] = emb
            documents_with_embeddings.append(Document.from_dict(doc_as_dict))

        return {"documents": documents_with_embeddings}