Fix return type of EmbeddingRetriever to numpy array (#245)

This commit is contained in:
Malte Pietsch 2020-07-17 19:03:31 +02:00 committed by GitHub
parent 4da480aa15
commit 355be293b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 7 deletions

View File

@@ -4,6 +4,7 @@ from string import Template
from typing import List, Optional, Union, Dict, Any from typing import List, Optional, Union, Dict, Any
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk, scan from elasticsearch.helpers import bulk, scan
import numpy as np
from haystack.database.base import BaseDocumentStore, Document from haystack.database.base import BaseDocumentStore, Document
@@ -236,7 +237,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
return documents return documents
def query_by_embedding(self, def query_by_embedding(self,
query_emb: List[float], query_emb: np.array,
filters: Optional[dict] = None, filters: Optional[dict] = None,
top_k: int = 10, top_k: int = 10,
index: Optional[str] = None) -> List[Document]: index: Optional[str] = None) -> List[Document]:
@@ -255,7 +256,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
"script": { "script": {
"source": f"cosineSimilarity(params.query_vector,doc['{self.embedding_field}']) + 1.0", "source": f"cosineSimilarity(params.query_vector,doc['{self.embedding_field}']) + 1.0",
"params": { "params": {
"query_vector": query_emb "query_vector": query_emb.tolist()
} }
} }
} }

View File

@@ -246,7 +246,7 @@ class EmbeddingRetriever(BaseRetriever):
top_k=top_k, index=index) top_k=top_k, index=index)
return documents return documents
def embed(self, texts: Union[List[str], str]) -> List[List[float]]: def embed(self, texts: Union[List[str], str]) -> List[np.array]:
""" """
Create embeddings for each text in a list of texts using the retrievers model (`self.embedding_model`) Create embeddings for each text in a list of texts using the retrievers model (`self.embedding_model`)
:param texts: texts to embed :param texts: texts to embed
@@ -259,13 +259,14 @@ class EmbeddingRetriever(BaseRetriever):
assert type(texts) == list, "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])" assert type(texts) == list, "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
if self.model_format == "farm" or self.model_format == "transformers": if self.model_format == "farm" or self.model_format == "transformers":
res = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts]) # type: ignore emb = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts]) # type: ignore
emb = [list(r["vec"]) for r in res] #cast from numpy emb = [(r["vec"]) for r in emb]
elif self.model_format == "sentence_transformers": elif self.model_format == "sentence_transformers":
# text is single string, sentence-transformers needs a list of strings # text is single string, sentence-transformers needs a list of strings
# get back list of numpy embedding vectors # get back list of numpy embedding vectors
res = self.embedding_model.encode(texts) # type: ignore emb = self.embedding_model.encode(texts) # type: ignore
emb = [list(r.astype('float64')) for r in res] #cast from numpy # cast to float64 as float32 can cause trouble when serializing for ES
emb = [(r.astype('float64')) for r in emb]
return emb return emb
def embed_queries(self, texts: List[str]) -> List[np.array]: def embed_queries(self, texts: List[str]) -> List[np.array]: