mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-31 11:56:35 +00:00
Fix embeddings from sentence-transformers (type cast & gpu flags) (#121)
This commit is contained in:
parent
5c68a5d755
commit
df554fcb45
@ -83,7 +83,11 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
|
||||
# pretrained embedding models coming from: https://github.com/UKPLab/sentence-transformers#pretrained-models
|
||||
# e.g. 'roberta-base-nli-stsb-mean-tokens'
|
||||
self.embedding_model = SentenceTransformer(embedding_model)
|
||||
if gpu:
|
||||
device = "gpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
self.embedding_model = SentenceTransformer(embedding_model, device=device)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -111,5 +115,5 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
elif self.model_format == "sentence_transformers":
|
||||
# text is single string, sentence-transformers needs a list of strings
|
||||
res = self.embedding_model.encode(texts) # get back list of numpy embedding vectors
|
||||
emb = [list(r) for r in res] #cast from numpy
|
||||
emb = [list(r.astype('float64')) for r in res] #cast from numpy
|
||||
return emb
|
||||
|
Loading…
x
Reference in New Issue
Block a user