diff --git a/haystack/preview/components/embedders/sentence_transformers_text_embedder.py b/haystack/preview/components/embedders/sentence_transformers_text_embedder.py
index 08ced5a8c..c22e04379 100644
--- a/haystack/preview/components/embedders/sentence_transformers_text_embedder.py
+++ b/haystack/preview/components/embedders/sentence_transformers_text_embedder.py
@@ -80,22 +80,22 @@ class SentenceTransformersTextEmbedder:
                 model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.use_auth_token
             )
 
-    @component.output_types(embeddings=List[List[float]])
-    def run(self, texts: List[str]):
-        """Embed a list of strings."""
-        if not isinstance(texts, list) or not isinstance(texts[0], str):
+    @component.output_types(embedding=List[float])
+    def run(self, text: str):
+        """Embed a string."""
+        if not isinstance(text, str):
             raise TypeError(
-                "SentenceTransformersTextEmbedder expects a list of strings as input."
+                "SentenceTransformersTextEmbedder expects a string as input."
                 "In case you want to embed a list of Documents, please use the SentenceTransformersDocumentEmbedder."
             )
 
         if not hasattr(self, "embedding_backend"):
             raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")
-        texts_to_embed = [self.prefix + text + self.suffix for text in texts]
-        embeddings = self.embedding_backend.embed(
-            texts_to_embed,
+        text_to_embed = self.prefix + text + self.suffix
+        embedding = self.embedding_backend.embed(
+            [text_to_embed],
             batch_size=self.batch_size,
             show_progress_bar=self.progress_bar,
             normalize_embeddings=self.normalize_embeddings,
-        )
-        return {"embeddings": embeddings}
+        )[0]
+        return {"embedding": embedding}
diff --git a/test/preview/components/embedders/test_sentence_transformers_text_embedder.py b/test/preview/components/embedders/test_sentence_transformers_text_embedder.py
index 9aa20696a..70ab959c7 100644
--- a/test/preview/components/embedders/test_sentence_transformers_text_embedder.py
+++ b/test/preview/components/embedders/test_sentence_transformers_text_embedder.py
@@ -139,27 +139,20 @@ class TestSentenceTransformersTextEmbedder:
         embedder.embedding_backend = MagicMock()
         embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist()
 
-        texts = ["sentence1", "sentence2"]
+        text = "a nice text to embed"
 
-        result = embedder.run(texts=texts)
-        embeddings = result["embeddings"]
+        result = embedder.run(text=text)
+        embedding = result["embedding"]
 
-        assert isinstance(embeddings, list)
-        assert len(embeddings) == len(texts)
-        for embedding in embeddings:
-            assert isinstance(embedding, list)
-            assert isinstance(embedding[0], float)
+        assert isinstance(embedding, list)
+        assert all(isinstance(el, float) for el in embedding)
 
     @pytest.mark.unit
     def test_run_wrong_input_format(self):
         embedder = SentenceTransformersTextEmbedder(model_name_or_path="model")
         embedder.embedding_backend = MagicMock()
 
-        string_input = "text"
         list_integers_input = [1, 2, 3]
 
-        with pytest.raises(TypeError, match="SentenceTransformersTextEmbedder expects a list of strings as input"):
-            embedder.run(texts=string_input)
-
-        with pytest.raises(TypeError, match="SentenceTransformersTextEmbedder expects a list of strings as input"):
-            embedder.run(texts=list_integers_input)
+        with pytest.raises(TypeError, match="SentenceTransformersTextEmbedder expects a string as input"):
+            embedder.run(text=list_integers_input)