Fix embeddings from sentence-transformers (type cast & gpu flags) (#121)

This commit is contained in:
Malte Pietsch 2020-05-26 16:43:05 +02:00 committed by GitHub
parent 5c68a5d755
commit df554fcb45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -83,7 +83,11 @@ class EmbeddingRetriever(BaseRetriever):
# pretrained embedding models coming from: https://github.com/UKPLab/sentence-transformers#pretrained-models
# e.g. 'roberta-base-nli-stsb-mean-tokens'
self.embedding_model = SentenceTransformer(embedding_model)
if gpu:
device = "gpu"
else:
device = "cpu"
self.embedding_model = SentenceTransformer(embedding_model, device=device)
else:
raise NotImplementedError
@ -111,5 +115,5 @@ class EmbeddingRetriever(BaseRetriever):
elif self.model_format == "sentence_transformers":
# text is single string, sentence-transformers needs a list of strings
res = self.embedding_model.encode(texts) # get back list of numpy embedding vectors
emb = [list(r) for r in res] #cast from numpy
emb = [list(r.astype('float64')) for r in res] #cast from numpy
return emb