mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-16 01:54:35 +00:00
Fix return type of EmbeddingRetriever to numpy array (#245)
This commit is contained in:
parent
4da480aa15
commit
355be293b6
@ -4,6 +4,7 @@ from string import Template
|
|||||||
from typing import List, Optional, Union, Dict, Any
|
from typing import List, Optional, Union, Dict, Any
|
||||||
from elasticsearch import Elasticsearch
|
from elasticsearch import Elasticsearch
|
||||||
from elasticsearch.helpers import bulk, scan
|
from elasticsearch.helpers import bulk, scan
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from haystack.database.base import BaseDocumentStore, Document
|
from haystack.database.base import BaseDocumentStore, Document
|
||||||
|
|
||||||
@ -236,7 +237,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
|||||||
return documents
|
return documents
|
||||||
|
|
||||||
def query_by_embedding(self,
|
def query_by_embedding(self,
|
||||||
query_emb: List[float],
|
query_emb: np.array,
|
||||||
filters: Optional[dict] = None,
|
filters: Optional[dict] = None,
|
||||||
top_k: int = 10,
|
top_k: int = 10,
|
||||||
index: Optional[str] = None) -> List[Document]:
|
index: Optional[str] = None) -> List[Document]:
|
||||||
@ -255,7 +256,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
|
|||||||
"script": {
|
"script": {
|
||||||
"source": f"cosineSimilarity(params.query_vector,doc['{self.embedding_field}']) + 1.0",
|
"source": f"cosineSimilarity(params.query_vector,doc['{self.embedding_field}']) + 1.0",
|
||||||
"params": {
|
"params": {
|
||||||
"query_vector": query_emb
|
"query_vector": query_emb.tolist()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -246,7 +246,7 @@ class EmbeddingRetriever(BaseRetriever):
|
|||||||
top_k=top_k, index=index)
|
top_k=top_k, index=index)
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
def embed(self, texts: Union[List[str], str]) -> List[List[float]]:
|
def embed(self, texts: Union[List[str], str]) -> List[np.array]:
|
||||||
"""
|
"""
|
||||||
Create embeddings for each text in a list of texts using the retrievers model (`self.embedding_model`)
|
Create embeddings for each text in a list of texts using the retrievers model (`self.embedding_model`)
|
||||||
:param texts: texts to embed
|
:param texts: texts to embed
|
||||||
@ -259,13 +259,14 @@ class EmbeddingRetriever(BaseRetriever):
|
|||||||
assert type(texts) == list, "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
|
assert type(texts) == list, "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"
|
||||||
|
|
||||||
if self.model_format == "farm" or self.model_format == "transformers":
|
if self.model_format == "farm" or self.model_format == "transformers":
|
||||||
res = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts]) # type: ignore
|
emb = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts]) # type: ignore
|
||||||
emb = [list(r["vec"]) for r in res] #cast from numpy
|
emb = [(r["vec"]) for r in emb]
|
||||||
elif self.model_format == "sentence_transformers":
|
elif self.model_format == "sentence_transformers":
|
||||||
# text is single string, sentence-transformers needs a list of strings
|
# text is single string, sentence-transformers needs a list of strings
|
||||||
# get back list of numpy embedding vectors
|
# get back list of numpy embedding vectors
|
||||||
res = self.embedding_model.encode(texts) # type: ignore
|
emb = self.embedding_model.encode(texts) # type: ignore
|
||||||
emb = [list(r.astype('float64')) for r in res] #cast from numpy
|
# cast to float64 as float32 can cause trouble when serializing for ES
|
||||||
|
emb = [(r.astype('float64')) for r in emb]
|
||||||
return emb
|
return emb
|
||||||
|
|
||||||
def embed_queries(self, texts: List[str]) -> List[np.array]:
|
def embed_queries(self, texts: List[str]) -> List[np.array]:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user