mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-04 05:43:29 +00:00
Fix embeddings from sentence-transformers (type cast & gpu flags) (#121)
This commit is contained in:
parent
5c68a5d755
commit
df554fcb45
@ -83,7 +83,11 @@ class EmbeddingRetriever(BaseRetriever):
|
|||||||
|
|
||||||
# pretrained embedding models coming from: https://github.com/UKPLab/sentence-transformers#pretrained-models
|
# pretrained embedding models coming from: https://github.com/UKPLab/sentence-transformers#pretrained-models
|
||||||
# e.g. 'roberta-base-nli-stsb-mean-tokens'
|
# e.g. 'roberta-base-nli-stsb-mean-tokens'
|
||||||
self.embedding_model = SentenceTransformer(embedding_model)
|
if gpu:
|
||||||
|
device = "gpu"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
self.embedding_model = SentenceTransformer(embedding_model, device=device)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -111,5 +115,5 @@ class EmbeddingRetriever(BaseRetriever):
|
|||||||
elif self.model_format == "sentence_transformers":
|
elif self.model_format == "sentence_transformers":
|
||||||
# text is single string, sentence-transformers needs a list of strings
|
# text is single string, sentence-transformers needs a list of strings
|
||||||
res = self.embedding_model.encode(texts) # get back list of numpy embedding vectors
|
res = self.embedding_model.encode(texts) # get back list of numpy embedding vectors
|
||||||
emb = [list(r) for r in res] #cast from numpy
|
emb = [list(r.astype('float64')) for r in res] #cast from numpy
|
||||||
return emb
|
return emb
|
||||||
|
Loading…
x
Reference in New Issue
Block a user