mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-05 20:33:58 +00:00
* Add basic tutorial for FAQ-based QA and switch to batch computation of embeddings * update readme & haystack version in tutorial
116 lines
5.2 KiB
Python
116 lines
5.2 KiB
Python
import logging
from typing import List, Optional, Type, Union

from farm.infer import Inferencer

from haystack.database.base import Document, BaseDocumentStore
from haystack.retriever.base import BaseRetriever
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ElasticsearchRetriever(BaseRetriever):
    """Sparse retriever that delegates querying to an Elasticsearch-backed DocumentStore."""

    def __init__(self, document_store: BaseDocumentStore, custom_query: Optional[str] = None):
        """
        :param document_store: an instance of a DocumentStore to retrieve documents from.
        :param custom_query: query string as per Elasticsearch DSL with a mandatory question placeholder($question).

                             Optionally, ES `filter` clause can be added where the values of `terms` are placeholders
                             that get substituted during runtime. The placeholder(${filter_name_1}, ${filter_name_2}..)
                             names must match with the filters dict supplied in self.retrieve().

                             An example custom_query:
                             {
                                "size": 10,
                                "query": {
                                    "bool": {
                                        "should": [{"multi_match": {
                                            "query": "${question}",                 // mandatory $question placeholder
                                            "type": "most_fields",
                                            "fields": ["text", "title"]}}],
                                        "filter": [                                 // optional custom filters
                                            {"terms": {"year": "${years}"}},
                                            {"terms": {"quarter": "${quarters}"}}],
                                    }
                                },
                             }

                             For this custom_query, a sample retrieve() could be:
                             self.retrieve(query="Why did the revenue increase?",
                                           filters={"years": ["2019"], "quarters": ["Q1", "Q2"]})
        """
        self.document_store = document_store
        self.custom_query = custom_query

    def retrieve(self, query: str, filters: Optional[dict] = None, top_k: int = 10) -> List[Document]:
        """
        Retrieve documents matching ``query`` from the underlying DocumentStore.

        :param query: the question / query string.
        :param filters: dict of filter values substituted into the placeholders of ``custom_query``
                        (or passed through to the store's default query).
        :param top_k: maximum number of documents to return.
        :return: list of candidate Documents.
        """
        documents = self.document_store.query(query, filters, top_k, self.custom_query)
        # lazy %-formatting so the message is only built when INFO is enabled
        logger.info("Got %s candidates from retriever", len(documents))

        return documents
|
|
|
|
|
|
class EmbeddingRetriever(BaseRetriever):
    """Dense retriever that embeds the query and matches it against precomputed document embeddings."""

    def __init__(
        self,
        document_store: BaseDocumentStore,
        embedding_model: str,
        gpu: bool = True,
        model_format: str = "farm",
        pooling_strategy: str = "reduce_mean",
        emb_extraction_layer: int = -1,
    ):
        """
        :param document_store: an instance of a DocumentStore (must support ``query_by_embedding``).
        :param embedding_model: name or path of the model used to compute embeddings
                                (e.g. 'roberta-base-nli-stsb-mean-tokens' for sentence-transformers).
        :param gpu: whether to run model inference on a GPU.
        :param model_format: 'farm', 'transformers' (both loaded via FARM's Inferencer)
                             or 'sentence_transformers'.
        :param pooling_strategy: strategy for combining token embeddings (FARM/transformers formats only).
        :param emb_extraction_layer: model layer from which embeddings are extracted; -1 = last layer
                                     (FARM/transformers formats only).
        :raises NotImplementedError: if ``model_format`` is not one of the supported values.
        """
        self.document_store = document_store
        self.model_format = model_format
        self.pooling_strategy = pooling_strategy
        self.emb_extraction_layer = emb_extraction_layer

        logger.info(f"Init retriever using embeddings of model {embedding_model}")
        if model_format in ("farm", "transformers"):
            self.embedding_model = Inferencer.load(
                embedding_model, task_type="embeddings", extraction_strategy=self.pooling_strategy,
                extraction_layer=self.emb_extraction_layer, gpu=gpu, batch_size=4, max_seq_len=512, num_processes=0
            )

        elif model_format == "sentence_transformers":
            from sentence_transformers import SentenceTransformer

            # pretrained embedding models coming from: https://github.com/UKPLab/sentence-transformers#pretrained-models
            # e.g. 'roberta-base-nli-stsb-mean-tokens'
            self.embedding_model = SentenceTransformer(embedding_model)
        else:
            raise NotImplementedError

    def retrieve(self, query: str, candidate_doc_ids: Optional[List[str]] = None, top_k: int = 10) -> List[Document]:
        """
        Retrieve the ``top_k`` documents most similar to ``query`` by embedding similarity.

        :param query: the question / query string.
        :param candidate_doc_ids: optional list of document ids to restrict the search to.
        :param top_k: maximum number of documents to return.
        :return: list of candidate Documents.
        """
        query_emb = self.create_embedding(texts=[query])
        documents = self.document_store.query_by_embedding(query_emb[0], top_k, candidate_doc_ids)

        return documents

    def create_embedding(self, texts: Union[List[str], str]) -> List[List[float]]:
        """
        Create embeddings for each text in a list of texts using the retrievers model (`self.embedding_model`)

        :param texts: texts to embed (a single string is accepted for backward compatibility).
        :return: list of embeddings (one per input text). Each embedding is a list of floats.
        :raises NotImplementedError: if ``self.model_format`` is not a supported format.
        """
        # for backward compatibility: cast pure str input
        if isinstance(texts, str):
            texts = [texts]
        assert isinstance(texts, list), "Expecting a list of texts, i.e. create_embeddings(texts=['text1',...])"

        if self.model_format == "farm":
            res = self.embedding_model.inference_from_dicts(dicts=[{"text": t} for t in texts])
            emb = [list(r["vec"]) for r in res]  # cast from numpy
        elif self.model_format == "sentence_transformers":
            # sentence-transformers returns a list of numpy embedding vectors
            res = self.embedding_model.encode(texts)
            emb = [list(r) for r in res]  # cast from numpy
        else:
            # previously this path fell through and raised an opaque NameError on `emb`
            raise NotImplementedError(f"Unsupported model_format: {self.model_format}")
        return emb
|