mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-02 02:39:51 +00:00
Fix return type of EmbeddingRetriever to numpy array (#245)
This commit is contained in:
parent
4da480aa15
commit
355be293b6
@ -4,6 +4,7 @@ from string import Template
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
from elasticsearch import Elasticsearch
|
||||
from elasticsearch.helpers import bulk, scan
|
||||
import numpy as np
|
||||
|
||||
from haystack.database.base import BaseDocumentStore, Document
|
||||
|
||||
@ -236,7 +237,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
return documents
|
||||
|
||||
def query_by_embedding(self,
|
||||
query_emb: List[float],
|
||||
query_emb: np.array,
|
||||
filters: Optional[dict] = None,
|
||||
top_k: int = 10,
|
||||
index: Optional[str] = None) -> List[Document]:
|
||||
@ -255,7 +256,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
||||
"script": {
|
||||
"source": f"cosineSimilarity(params.query_vector,doc['{self.embedding_field}']) + 1.0",
|
||||
"params": {
|
||||
"query_vector": query_emb
|
||||
"query_vector": query_emb.tolist()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -246,7 +246,7 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
top_k=top_k, index=index)
|
||||
return documents
|
||||
|
||||
def embed(self, texts: Union[List[str], str]) -> List[List[float]]:
|
||||
def embed(self, texts: Union[List[str], str]) -> List[np.array]:
|
||||
"""
|
||||
Create embeddings for each text in a list of texts using the retrievers model (`self.embedding_model`)
|
||||
:param texts: texts to embed
|
||||
@ -259,13 +259,14 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
assert type(texts) == list, "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
|
||||
|
||||
if self.model_format == "farm" or self.model_format == "transformers":
|
||||
res = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts]) # type: ignore
|
||||
emb = [list(r["vec"]) for r in res] #cast from numpy
|
||||
emb = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts]) # type: ignore
|
||||
emb = [(r["vec"]) for r in emb]
|
||||
elif self.model_format == "sentence_transformers":
|
||||
# text is single string, sentence-transformers needs a list of strings
|
||||
# get back list of numpy embedding vectors
|
||||
res = self.embedding_model.encode(texts) # type: ignore
|
||||
emb = [list(r.astype('float64')) for r in res] #cast from numpy
|
||||
emb = self.embedding_model.encode(texts) # type: ignore
|
||||
# cast to float64 as float32 can cause trouble when serializing for ES
|
||||
emb = [(r.astype('float64')) for r in emb]
|
||||
return emb
|
||||
|
||||
def embed_queries(self, texts: List[str]) -> List[np.array]:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user