Mirror of https://github.com/deepset-ai/haystack.git, synced 2026-01-07 12:37:27 +00:00
make SentenceTransformersTextEmbedder non batch (#5811)
This commit is contained in:
parent
4bad202197
commit
ad5b615503
@@ -80,22 +80,22 @@ class SentenceTransformersTextEmbedder:
                 model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.use_auth_token
             )
 
-    @component.output_types(embeddings=List[List[float]])
-    def run(self, texts: List[str]):
-        """Embed a list of strings."""
-        if not isinstance(texts, list) or not isinstance(texts[0], str):
+    @component.output_types(embedding=List[float])
+    def run(self, text: str):
+        """Embed a string."""
+        if not isinstance(text, str):
             raise TypeError(
-                "SentenceTransformersTextEmbedder expects a list of strings as input."
+                "SentenceTransformersTextEmbedder expects a string as input."
                 "In case you want to embed a list of Documents, please use the SentenceTransformersDocumentEmbedder."
             )
         if not hasattr(self, "embedding_backend"):
             raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")
 
-        texts_to_embed = [self.prefix + text + self.suffix for text in texts]
-        embeddings = self.embedding_backend.embed(
-            texts_to_embed,
+        text_to_embed = self.prefix + text + self.suffix
+        embedding = self.embedding_backend.embed(
+            [text_to_embed],
             batch_size=self.batch_size,
             show_progress_bar=self.progress_bar,
             normalize_embeddings=self.normalize_embeddings,
-        )
-        return {"embeddings": embeddings}
+        )[0]
+        return {"embedding": embedding}
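For reference, a minimal usage sketch of the component after this change. The import path and model name below are assumptions; warm_up() and run(text=...) returning {"embedding": List[float]} follow the diff above.

# Minimal usage sketch; not part of the commit.
# Assumptions: the import path and the model name.
from haystack.preview.components.embedders import SentenceTransformersTextEmbedder

embedder = SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")
embedder.warm_up()  # loads the embedding backend; run() raises RuntimeError otherwise

result = embedder.run(text="a nice text to embed")  # a single string, no longer a list
embedding = result["embedding"]  # List[float] for this one text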
@@ -139,27 +139,20 @@ class TestSentenceTransformersTextEmbedder:
         embedder.embedding_backend = MagicMock()
         embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist()
 
-        texts = ["sentence1", "sentence2"]
+        text = "a nice text to embed"
 
-        result = embedder.run(texts=texts)
-        embeddings = result["embeddings"]
+        result = embedder.run(text=text)
+        embedding = result["embedding"]
 
-        assert isinstance(embeddings, list)
-        assert len(embeddings) == len(texts)
-        for embedding in embeddings:
-            assert isinstance(embedding, list)
-            assert isinstance(embedding[0], float)
+        assert isinstance(embedding, list)
+        assert all(isinstance(el, float) for el in embedding)
 
     @pytest.mark.unit
     def test_run_wrong_input_format(self):
         embedder = SentenceTransformersTextEmbedder(model_name_or_path="model")
         embedder.embedding_backend = MagicMock()
 
-        string_input = "text"
         list_integers_input = [1, 2, 3]
 
-        with pytest.raises(TypeError, match="SentenceTransformersTextEmbedder expects a list of strings as input"):
-            embedder.run(texts=string_input)
-
-        with pytest.raises(TypeError, match="SentenceTransformersTextEmbedder expects a list of strings as input"):
-            embedder.run(texts=list_integers_input)
+        with pytest.raises(TypeError, match="SentenceTransformersTextEmbedder expects a string as input"):
+            embedder.run(text=list_integers_input)