From df554fcb458515c0fd5b08dab46cc11b73acd3bb Mon Sep 17 00:00:00 2001 From: Malte Pietsch Date: Tue, 26 May 2020 16:43:05 +0200 Subject: [PATCH] Fix embeddings from sentence-transformers (type cast & gpu flags) (#121) --- haystack/retriever/elasticsearch.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/haystack/retriever/elasticsearch.py b/haystack/retriever/elasticsearch.py index 37020fbe1..bb658cd19 100644 --- a/haystack/retriever/elasticsearch.py +++ b/haystack/retriever/elasticsearch.py @@ -83,7 +83,11 @@ class EmbeddingRetriever(BaseRetriever): # pretrained embedding models coming from: https://github.com/UKPLab/sentence-transformers#pretrained-models # e.g. 'roberta-base-nli-stsb-mean-tokens' - self.embedding_model = SentenceTransformer(embedding_model) + if gpu: + device = "gpu" + else: + device = "cpu" + self.embedding_model = SentenceTransformer(embedding_model, device=device) else: raise NotImplementedError @@ -111,5 +115,5 @@ class EmbeddingRetriever(BaseRetriever): elif self.model_format == "sentence_transformers": # text is single string, sentence-transformers needs a list of strings res = self.embedding_model.encode(texts) # get back list of numpy embedding vectors - emb = [list(r) for r in res] #cast from numpy + emb = [list(r.astype('float64')) for r in res] #cast from numpy return emb