make SentenceTransformersTextEmbedder non batch (#5811)

This commit is contained in:
Stefano Fiorucci 2023-09-14 12:38:24 +02:00 committed by GitHub
parent 4bad202197
commit ad5b615503
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 24 deletions

View File

@@ -80,22 +80,22 @@ class SentenceTransformersTextEmbedder:
model_name_or_path=self.model_name_or_path, device=self.device, use_auth_token=self.use_auth_token
)
@component.output_types(embeddings=List[List[float]])
def run(self, texts: List[str]):
"""Embed a list of strings."""
if not isinstance(texts, list) or not isinstance(texts[0], str):
@component.output_types(embedding=List[float])
def run(self, text: str):
"""Embed a string."""
if not isinstance(text, str):
raise TypeError(
"SentenceTransformersTextEmbedder expects a list of strings as input."
"SentenceTransformersTextEmbedder expects a string as input."
"In case you want to embed a list of Documents, please use the SentenceTransformersDocumentEmbedder."
)
if not hasattr(self, "embedding_backend"):
raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")
texts_to_embed = [self.prefix + text + self.suffix for text in texts]
embeddings = self.embedding_backend.embed(
texts_to_embed,
text_to_embed = self.prefix + text + self.suffix
embedding = self.embedding_backend.embed(
[text_to_embed],
batch_size=self.batch_size,
show_progress_bar=self.progress_bar,
normalize_embeddings=self.normalize_embeddings,
)
return {"embeddings": embeddings}
)[0]
return {"embedding": embedding}

View File

@@ -139,27 +139,20 @@ class TestSentenceTransformersTextEmbedder:
embedder.embedding_backend = MagicMock()
embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist()
texts = ["sentence1", "sentence2"]
text = "a nice text to embed"
result = embedder.run(texts=texts)
embeddings = result["embeddings"]
result = embedder.run(text=text)
embedding = result["embedding"]
assert isinstance(embeddings, list)
assert len(embeddings) == len(texts)
for embedding in embeddings:
assert isinstance(embedding, list)
assert isinstance(embedding[0], float)
assert isinstance(embedding, list)
assert all(isinstance(el, float) for el in embedding)
@pytest.mark.unit
def test_run_wrong_input_format(self):
embedder = SentenceTransformersTextEmbedder(model_name_or_path="model")
embedder.embedding_backend = MagicMock()
string_input = "text"
list_integers_input = [1, 2, 3]
with pytest.raises(TypeError, match="SentenceTransformersTextEmbedder expects a list of strings as input"):
embedder.run(texts=string_input)
with pytest.raises(TypeError, match="SentenceTransformersTextEmbedder expects a list of strings as input"):
embedder.run(texts=list_integers_input)
with pytest.raises(TypeError, match="SentenceTransformersTextEmbedder expects a string as input"):
embedder.run(text=list_integers_input)