Fix return type of EmbeddingRetriever to numpy array (#245)

This commit is contained in:
Malte Pietsch 2020-07-17 19:03:31 +02:00 committed by GitHub
parent 4da480aa15
commit 355be293b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 7 deletions

View File

@ -4,6 +4,7 @@ from string import Template
from typing import List, Optional, Union, Dict, Any
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk, scan
import numpy as np
from haystack.database.base import BaseDocumentStore, Document
@ -236,7 +237,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
return documents
def query_by_embedding(self,
query_emb: List[float],
query_emb: np.array,
filters: Optional[dict] = None,
top_k: int = 10,
index: Optional[str] = None) -> List[Document]:
@ -255,7 +256,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
"script": {
"source": f"cosineSimilarity(params.query_vector,doc['{self.embedding_field}']) + 1.0",
"params": {
"query_vector": query_emb
"query_vector": query_emb.tolist()
}
}
}

View File

@ -246,7 +246,7 @@ class EmbeddingRetriever(BaseRetriever):
top_k=top_k, index=index)
return documents
def embed(self, texts: Union[List[str], str]) -> List[List[float]]:
def embed(self, texts: Union[List[str], str]) -> List[np.array]:
"""
Create embeddings for each text in a list of texts using the retrievers model (`self.embedding_model`)
:param texts: texts to embed
@ -259,13 +259,14 @@ class EmbeddingRetriever(BaseRetriever):
assert type(texts) == list, "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
if self.model_format == "farm" or self.model_format == "transformers":
res = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts]) # type: ignore
emb = [list(r["vec"]) for r in res] #cast from numpy
emb = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts]) # type: ignore
emb = [(r["vec"]) for r in emb]
elif self.model_format == "sentence_transformers":
# text is single string, sentence-transformers needs a list of strings
# get back list of numpy embedding vectors
res = self.embedding_model.encode(texts) # type: ignore
emb = [list(r.astype('float64')) for r in res] #cast from numpy
emb = self.embedding_model.encode(texts) # type: ignore
# cast to float64 as float32 can cause trouble when serializing for ES
emb = [(r.astype('float64')) for r in emb]
return emb
def embed_queries(self, texts: List[str]) -> List[np.array]: