mirror of https://github.com/deepset-ai/haystack.git
Re-ranking component for document search without QA (#1025)
* Add Ranker node similar to Retriever and Reader
* Sort documents according to query-document similarity scores
* Re-ranking and model training run on a small example
* Add EvalRanker node
* Calculate recall@k in EvalRetriever and EvalRanker nodes
* Rename EvalRetriever to EvalDocuments and EvalReader to EvalAnswers
* Add mean reciprocal rank as a metric for EvalDocuments
* Fix bug that appeared when ranking documents with the same score
* Remove commented code for the unimplemented eval() of the Ranker node
* Add documentation of the k parameter in EvalDocuments
* Add Ranker documentation and rename the top_k parameter
This commit is contained in:
parent b5cae20ddb
commit 84c34295a1

README.md: 11 lines changed
@@ -164,12 +164,13 @@ We recommend Elasticsearch or FAISS but also have more light-weight options for
 Retrievers narrow down the search space significantly and are therefore crucial for scalable QA.
 Haystack supports sparse methods (TF-IDF, BM25, custom Elasticsearch queries)
 and state-of-the-art dense methods (e.g., sentence-transformers and Dense Passage Retrieval)
-5. **Reader**: Neural network (e.g., BERT or RoBERTa) that reads through texts in detail
+5. **Ranker**: Neural network (e.g., BERT or RoBERTa) that re-ranks the top-k retrieved documents. The Ranker is an optional component and uses a TextPairClassification model under the hood. This model calculates the semantic similarity of each of the top-k retrieved documents with the query.
+6. **Reader**: Neural network (e.g., BERT or RoBERTa) that reads through texts in detail
 to find an answer. The Reader takes multiple passages of text as input and returns top-n answers. Models are trained via [FARM](https://github.com/deepset-ai/FARM) or [Transformers](https://github.com/huggingface/transformers) on SQuAD-like tasks. You can load a pre-trained model from [Hugging Face's model hub](https://huggingface.co/models) or fine-tune it on your domain data.
-6. **Generator**: Neural network (e.g., RAG) that *generates* an answer for a given question conditioned on the retrieved documents from the retriever.
-6. **Pipeline**: Stick building blocks together into highly custom pipelines that are represented as Directed Acyclic Graphs (DAGs). Think of it as "Apache Airflow for search".
-7. **REST API**: Exposes a simple API based on FastAPI for running QA search, uploading files, and collecting user feedback for continuous learning.
-8. **Haystack Annotate**: Create custom QA labels to improve the performance of your domain-specific models. [Hosted version](https://annotate.deepset.ai/login) or [Docker images](https://github.com/deepset-ai/haystack/tree/master/annotation_tool).
+7. **Generator**: Neural network (e.g., RAG) that *generates* an answer for a given question conditioned on the retrieved documents from the retriever.
+8. **Pipeline**: Stick building blocks together into highly custom pipelines that are represented as Directed Acyclic Graphs (DAGs). Think of it as "Apache Airflow for search".
+9. **REST API**: Exposes a simple API based on FastAPI for running QA search, uploading files, and collecting user feedback for continuous learning.
+10. **Haystack Annotate**: Create custom QA labels to improve the performance of your domain-specific models. [Hosted version](https://annotate.deepset.ai/login) or [Docker images](https://github.com/deepset-ai/haystack/tree/master/annotation_tool).

 It's quite simple to begin experimenting with Haystack. We'd recommend going through the [Tutorials](https://github.com/deepset-ai/haystack/#tutorials) section below, but here's an example code structure describing how to approach Haystack with a DocumentStore based on Elasticsearch.
@@ -157,9 +157,8 @@ reader = FARMReader("deepset/roberta-base-squad2", top_k_per_candidate=4, return

 ```

 ```python
-from haystack.eval import EvalReader, EvalRetriever
+from haystack.eval import EvalAnswers, EvalDocuments

 # Here we initialize the nodes that perform evaluation
 eval_retriever = EvalRetriever()
@@ -212,9 +211,9 @@ from haystack import Pipeline
 # Here is the pipeline definition
 p = Pipeline()
 p.add_node(component=retriever, name="ESRetriever", inputs=["Query"])
-p.add_node(component=eval_retriever, name="EvalRetriever", inputs=["ESRetriever"])
-p.add_node(component=reader, name="QAReader", inputs=["EvalRetriever"])
-p.add_node(component=eval_reader, name="EvalReader", inputs=["QAReader"])
+p.add_node(component=eval_retriever, name="EvalDocuments", inputs=["ESRetriever"])
+p.add_node(component=reader, name="QAReader", inputs=["EvalDocuments"])
+p.add_node(component=eval_reader, name="EvalAnswers", inputs=["QAReader"])
 results = []
 ```
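The hunks above only wire up the renamed evaluation nodes. As a hedged sketch of how they might be used after the pipeline definition (the run() keyword arguments and the q_to_l_dict mapping are assumptions based on the surrounding examples, not part of this diff):

```python
# Hypothetical follow-up to the pipeline defined above; argument names are assumptions.
for question, label in q_to_l_dict.items():
    res = p.run(query=question, labels=label, top_k_retriever=10, top_k_reader=5)
    results.append(res)

# Both evaluation nodes accumulate metrics internally as samples pass through them.
eval_retriever.print()  # recall@k and mean_reciprocal_rank@k over the retrieved documents
eval_reader.print()     # answer-level metrics for the reader
```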
docs/_src/usage/usage/ranker.md: 54 lines (new file)
@@ -0,0 +1,54 @@
<!---
title: "Ranker"
metaTitle: "Ranker"
metaDescription: ""
slug: "/docs/ranker"
date: "2020-05-26"
id: "rankermd"
--->

# Ranker

There are pure "semantic document search" use cases that do not need question answering functionality but only document ranking.
While the [Retriever](/docs/latest/retrievermd) is a perfect fit for document retrieval, we can further improve its results with the Ranker.
For example, BM25 (a sparse retriever) does not take the semantics of the documents and the query into account but only their keywords.
The Ranker can re-rank the results of the retriever step by taking semantics into account.
Similar to the Reader, it is based on the latest language models.
Instead of returning answers, it returns documents in re-ranked order.

Without a Ranker and its re-ranking step, the querying process is faster, but the query results might be of lower quality.
If you want to do "semantic document search" instead of question answering, first try a Retriever on its own.
If the semantic similarity between the query and the resulting documents is low, add a Ranker.

Note that a Ranker needs to be initialised with a model trained on a text pair classification task.
You can also train the model with the Ranker's train() method.
Alternatively, [this example](https://github.com/deepset-ai/FARM/blob/master/examples/text_pair_classification.py) shows how to train a text pair classification model in FARM.


## FARMRanker

### Description

The FARMRanker consists of a Transformer-based model for document re-ranking using the TextPairClassifier of [FARM](https://github.com/deepset-ai/FARM).
While the underlying model can vary (BERT, RoBERTa, DistilBERT, ...), the interface remains the same.
With a FARMRanker, you can:
* Directly get predictions (a re-ranked version of the supplied list of Document) via predict() when supplying a pre-trained model
* Take a plain language model (e.g. `bert-base-cased`) and train it for TextPairClassification via train()

### Initialisation

```python
from haystack.document_store import ElasticsearchDocumentStore
from haystack.retriever import ElasticsearchRetriever
from haystack.ranker import FARMRanker
from haystack import Pipeline

document_store = ElasticsearchDocumentStore()
...
retriever = ElasticsearchRetriever(document_store)
ranker = FARMRanker(model_name_or_path="saved_models/roberta-base-asnq-binary")
...
p = Pipeline()
p.add_node(component=retriever, name="ESRetriever", inputs=["Query"])
p.add_node(component=ranker, name="Ranker", inputs=["ESRetriever"])
```
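Beyond the pipeline wiring shown in ranker.md, the Ranker can also be called on its own. Below is a minimal, hedged sketch of stand-alone re-ranking via predict(); the saved model path and the Document(text=...) construction are illustrative assumptions:

```python
from haystack import Document
from haystack.ranker import FARMRanker

# Assumed: a text pair classification model saved locally, e.g. produced by
# FARMRanker.train() or by the FARM text_pair_classification example.
ranker = FARMRanker(model_name_or_path="saved_models/roberta-base-asnq-binary")

docs = [
    Document(text="Haystack is an open-source framework for building search systems."),
    Document(text="Elasticsearch is a distributed search and analytics engine."),
    Document(text="BM25 is a ranking function based on term frequency statistics."),
]

# Re-rank by semantic similarity with the query and keep the two best documents.
reranked = ranker.predict(query="What is Haystack?", documents=docs, top_k=2)
for doc in reranked:
    print(doc.text)
```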
@@ -157,9 +157,8 @@ reader = FARMReader("deepset/roberta-base-squad2", top_k_per_candidate=4, return

 ```

 ```python
-from haystack.eval import EvalReader, EvalRetriever
+from haystack.eval import EvalAnswers, EvalDocuments

 # Here we initialize the nodes that perform evaluation
 eval_retriever = EvalRetriever()
@@ -212,9 +211,9 @@ from haystack import Pipeline
 # Here is the pipeline definition
 p = Pipeline()
 p.add_node(component=retriever, name="ESRetriever", inputs=["Query"])
-p.add_node(component=eval_retriever, name="EvalRetriever", inputs=["ESRetriever"])
-p.add_node(component=reader, name="QAReader", inputs=["EvalRetriever"])
-p.add_node(component=eval_reader, name="EvalReader", inputs=["QAReader"])
+p.add_node(component=eval_retriever, name="EvalDocuments", inputs=["ESRetriever"])
+p.add_node(component=reader, name="QAReader", inputs=["EvalDocuments"])
+p.add_node(component=eval_reader, name="EvalAnswers", inputs=["QAReader"])
 results = []
 ```
@@ -9,19 +9,21 @@ from farm.evaluation.squad_evaluation import compute_exact as calculate_em_str
 logger = logging.getLogger(__name__)


-class EvalRetriever:
+class EvalDocuments:
     """
-    This is a pipeline node that should be placed after a Retriever in order to assess its performance. Performance
-    metrics are stored in this class and updated as each sample passes through it. To view the results of the evaluation,
-    call EvalRetriever.print(). Note that results from this Node may differ from that when calling Retriever.eval()
-    since that is a closed domain evaluation. Have a look at our evaluation tutorial for more info about
-    open vs closed domain eval (https://haystack.deepset.ai/docs/latest/tutorial5md).
+    This is a pipeline node that should be placed after a node that returns a List of Document, e.g., Retriever or
+    Ranker, in order to assess its performance. Performance metrics are stored in this class and updated as each
+    sample passes through it. To view the results of the evaluation, call EvalDocuments.print(). Note that results
+    from this Node may differ from that when calling Retriever.eval() since that is a closed domain evaluation. Have
+    a look at our evaluation tutorial for more info about open vs closed domain eval (
+    https://haystack.deepset.ai/docs/latest/tutorial5md).
     """
-    def __init__(self, debug: bool=False, open_domain: bool=True):
+    def __init__(self, debug: bool=False, open_domain: bool=True, top_k: int=10, name="EvalDocuments"):
         """
         :param open_domain: When True, a document is considered correctly retrieved so long as the answer string can be found within it.
                             When False, correct retrieval is evaluated based on document_id.
-        :param debug: When True, a record of each sample and its evaluation will be stored in EvalRetriever.log
+        :param debug: When True, a record of each sample and its evaluation will be stored in EvalDocuments.log
+        :param top_k: calculate eval metrics for top k results, e.g., recall@k
         """
         self.outgoing_edges = 1
         self.init_counts()
@@ -29,6 +31,8 @@ class EvalRetriever:
         self.debug = debug
         self.log: List = []
         self.open_domain = open_domain
+        self.top_k = top_k
+        self.name = name

     def init_counts(self):
         self.correct_retrieval_count = 0
@@ -38,6 +42,10 @@ class EvalRetriever:
         self.has_answer_recall = 0
         self.no_answer_count = 0
         self.recall = 0.0
+        self.mean_reciprocal_rank = 0.0
+        self.has_answer_mean_reciprocal_rank = 0.0
+        self.reciprocal_rank_sum = 0.0
+        self.has_answer_reciprocal_rank_sum = 0.0

     def run(self, documents, labels: dict, **kwargs):
         """Run this node on one sample and its labels"""
@@ -48,6 +56,8 @@ class EvalRetriever:
         if retriever_labels.no_answer:
             self.no_answer_count += 1
             correct_retrieval = 1
+            retrieved_reciprocal_rank = 1
+            self.reciprocal_rank_sum += 1
             if not self.no_answer_warning:
                 self.no_answer_warning = True
                 logger.warning("There seem to be empty string labels in the dataset suggesting that there "
@@ -56,48 +66,62 @@ class EvalRetriever:
         # If there are answer span annotations in the labels
         else:
             self.has_answer_count += 1
-            correct_retrieval = self.is_correctly_retrieved(retriever_labels, documents)
+            retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents)
+            self.reciprocal_rank_sum += retrieved_reciprocal_rank
+            correct_retrieval = True if retrieved_reciprocal_rank > 0 else False
             self.has_answer_correct += int(correct_retrieval)
+            self.has_answer_reciprocal_rank_sum += retrieved_reciprocal_rank
             self.has_answer_recall = self.has_answer_correct / self.has_answer_count
+            self.has_answer_mean_reciprocal_rank = self.has_answer_reciprocal_rank_sum / self.has_answer_count

         self.correct_retrieval_count += correct_retrieval
         self.recall = self.correct_retrieval_count / self.query_count
+        self.mean_reciprocal_rank = self.reciprocal_rank_sum / self.query_count

         if self.debug:
-            self.log.append({"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, **kwargs})
-        return {"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, **kwargs}, "output_1"
+            self.log.append({"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs})
+        return {"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs}, "output_1"

     def is_correctly_retrieved(self, retriever_labels, predictions):
+        return self.reciprocal_rank_retrieved(retriever_labels, predictions) > 0
+
+    def reciprocal_rank_retrieved(self, retriever_labels, predictions):
         if self.open_domain:
             for label in retriever_labels.multiple_answers:
-                for p in predictions:
+                for rank, p in enumerate(predictions[:self.top_k]):
                     if label.lower() in p.text.lower():
-                        return True
+                        return 1/(rank+1)
             return False
         else:
-            prediction_ids = [p.id for p in predictions]
+            prediction_ids = [p.id for p in predictions[:self.top_k]]
             label_ids = retriever_labels.multiple_document_ids
-            for l in label_ids:
-                if l in prediction_ids:
-                    return True
-            return False
+            for rank, p in enumerate(prediction_ids):
+                if p in label_ids:
+                    return 1/(rank+1)
+            return 0

     def print(self):
         """Print the evaluation results"""
-        print("Retriever")
+        print(self.name)
         print("-----------------")
         if self.no_answer_count:
             print(
-                f"has_answer recall: {self.has_answer_recall:.4f} ({self.has_answer_correct}/{self.has_answer_count})")
+                f"has_answer recall@{self.top_k}: {self.has_answer_recall:.4f} ({self.has_answer_correct}/{self.has_answer_count})")
             print(
-                f"no_answer recall: 1.00 ({self.no_answer_count}/{self.no_answer_count}) (no_answer samples are always treated as correctly retrieved)")
-        print(f"recall: {self.recall:.4f} ({self.correct_retrieval_count} / {self.query_count})")
+                f"no_answer recall@{self.top_k}: 1.00 ({self.no_answer_count}/{self.no_answer_count}) (no_answer samples are always treated as correctly retrieved)")
+            print(
+                f"has_answer mean_reciprocal_rank@{self.top_k}: {self.has_answer_mean_reciprocal_rank:.4f}")
+            print(
+                f"no_answer mean_reciprocal_rank@{self.top_k}: 1.0000 (no_answer samples are always treated as correctly retrieved at rank 1)")
+        print(f"recall@{self.top_k}: {self.recall:.4f} ({self.correct_retrieval_count} / {self.query_count})")
+        print(f"mean_reciprocal_rank@{self.top_k}: {self.mean_reciprocal_rank:.4f}")


-class EvalReader:
+class EvalAnswers:
     """
     This is a pipeline node that should be placed after a Reader in order to assess the performance of the Reader
     individually or to assess the extractive QA performance of the whole pipeline. Performance metrics are stored in
-    this class and updated as each sample passes through it. To view the results of the evaluation, call EvalReader.print().
+    this class and updated as each sample passes through it. To view the results of the evaluation, call EvalAnswers.print().
     Note that results from this Node may differ from that when calling Reader.eval()
     since that is a closed domain evaluation. Have a look at our evaluation tutorial for more info about
     open vs closed domain eval (https://haystack.deepset.ai/docs/latest/tutorial5md).
@@ -107,7 +131,7 @@ class EvalReader:
         """
         :param skip_incorrect_retrieval: When set to True, this eval will ignore the cases where the retriever returned no correct documents
         :param open_domain: When True, extracted answers are evaluated purely on string similarity rather than the position of the extracted answer
-        :param debug: When True, a record of each sample and its evaluation will be stored in EvalReader.log
+        :param debug: When True, a record of each sample and its evaluation will be stored in EvalAnswers.log
         """
         self.outgoing_edges = 1
         self.init_counts()
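To make the new metrics concrete: the reciprocal rank of a query is 1/rank of the first relevant document within the top k, recall@k counts a query as correct if any relevant document appears in the top k, and both are averaged over all queries. A small self-contained sketch mirroring EvalDocuments.reciprocal_rank_retrieved() above:

```python
# Toy illustration of recall@k and mean reciprocal rank (standalone sketch, top_k=3).
ranked_doc_ids_per_query = [
    ["d7", "d2", "d9"],   # relevant doc "d2" at rank 2 -> reciprocal rank 1/2
    ["d1", "d4", "d5"],   # relevant doc "d5" at rank 3 -> reciprocal rank 1/3
    ["d8", "d6", "d3"],   # no relevant doc in top 3    -> reciprocal rank 0
]
relevant_ids_per_query = [{"d2"}, {"d5"}, {"d0"}]

top_k = 3
reciprocal_ranks = []
for ranked_ids, relevant_ids in zip(ranked_doc_ids_per_query, relevant_ids_per_query):
    rr = 0.0
    for rank, doc_id in enumerate(ranked_ids[:top_k]):
        if doc_id in relevant_ids:
            rr = 1 / (rank + 1)
            break
    reciprocal_ranks.append(rr)

recall_at_k = sum(rr > 0 for rr in reciprocal_ranks) / len(reciprocal_ranks)
mrr_at_k = sum(reciprocal_ranks) / len(reciprocal_ranks)
print(f"recall@{top_k}: {recall_at_k:.4f}")             # 0.6667
print(f"mean_reciprocal_rank@{top_k}: {mrr_at_k:.4f}")  # (1/2 + 1/3 + 0) / 3 = 0.2778
```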
@@ -6,7 +6,7 @@ from haystack.reader.farm import FARMReader
 from haystack import Pipeline
 from farm.utils import initialize_device_settings
 from haystack.preprocessor import PreProcessor
-from haystack.eval import EvalReader, EvalRetriever
+from haystack.eval import EvalAnswers, EvalDocuments

 import logging
 import subprocess
@@ -37,9 +37,9 @@ def main():

     document_store = ElasticsearchDocumentStore()
     es_retriever = ElasticsearchRetriever(document_store=document_store)
-    eval_retriever = EvalRetriever(open_domain=open_domain)
+    eval_retriever = EvalDocuments(open_domain=open_domain)
     reader = FARMReader("deepset/roberta-base-squad2", top_k_per_candidate=4, num_processes=1, return_no_answer=True)
-    eval_reader = EvalReader(debug=True, open_domain=open_domain)
+    eval_reader = EvalAnswers(debug=True, open_domain=open_domain)

     # Download evaluation data, which is a subset of Natural Questions development set containing 50 documents
     doc_dir = "../data/nq"
@@ -68,9 +68,9 @@ def main():
     # Here is the pipeline definition
     p = Pipeline()
     p.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
-    p.add_node(component=eval_retriever, name="EvalRetriever", inputs=["ESRetriever"])
-    p.add_node(component=reader, name="QAReader", inputs=["EvalRetriever"])
-    p.add_node(component=eval_reader, name="EvalReader", inputs=["QAReader"])
+    p.add_node(component=eval_retriever, name="EvalDocuments", inputs=["ESRetriever"])
+    p.add_node(component=reader, name="QAReader", inputs=["EvalDocuments"])
+    p.add_node(component=eval_reader, name="EvalAnswers", inputs=["QAReader"])

     results = []
     for i, (q, l) in enumerate(q_to_l_dict.items()):
haystack/ranker/__init__.py: 1 line (new file)
@@ -0,0 +1 @@
from haystack.ranker.farm import FARMRanker
haystack/ranker/base.py: 105 lines (new file)
@@ -0,0 +1,105 @@
import logging
from abc import abstractmethod
from copy import deepcopy
from typing import List, Optional
from functools import wraps
from time import perf_counter

from tqdm import tqdm

from haystack import Document, BaseComponent

logger = logging.getLogger(__name__)


class BaseRanker(BaseComponent):
    return_no_answers: bool
    outgoing_edges = 1
    query_count = 0
    query_time = 0

    @abstractmethod
    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None):
        pass

    @abstractmethod
    def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
        pass

    def run(self, query: str, documents: List[Document], top_k_ranker: Optional[int] = None, **kwargs):  # type: ignore
        self.query_count += 1
        if documents:
            predict = self.timing(self.predict, "query_time")
            results = predict(query=query, documents=documents, top_k=top_k_ranker)
        else:
            results = []

        document_ids = [doc.id for doc in results]
        logger.debug(f"Retrieved documents with IDs: {document_ids}")
        output = {
            "query": query,
            "documents": results,
            **kwargs
        }

        return output, "output_1"

    def timing(self, fn, attr_name):
        """Wrapper method used to time functions."""
        @wraps(fn)
        def wrapper(*args, **kwargs):
            if attr_name not in self.__dict__:
                self.__dict__[attr_name] = 0
            tic = perf_counter()
            ret = fn(*args, **kwargs)
            toc = perf_counter()
            self.__dict__[attr_name] += toc - tic
            return ret
        return wrapper

    def print_time(self):
        print("Ranker (Speed)")
        print("---------------")
        if not self.query_count:
            print("No querying performed via Ranker.run()")
        else:
            print(f"Queries Performed: {self.query_count}")
            print(f"Query time: {self.query_time}s")
            print(f"{self.query_time / self.query_count} seconds per query")

    def eval(
        self,
        label_index: str = "label",
        doc_index: str = "eval_document",
        label_origin: str = "gold_label",
        top_k: int = 10,
        open_domain: bool = False,
        return_preds: bool = False,
    ) -> dict:
        """
        Performs evaluation of the Ranker.
        A Ranker is evaluated in the same way as a Retriever: based on whether it finds the correct document given the
        query string and at which position in the document ranking the correct document is.

        | Returns a dict containing the following metrics:

        - "recall": Proportion of questions for which the correct document is among the retrieved documents
        - "mrr": Mean of reciprocal rank. Rewards retrievers that give relevant documents a higher rank.
                 Only considers the highest ranked relevant document.
        - "map": Mean of average precision for each question. Rewards retrievers that give relevant
                 documents a higher rank. Considers all retrieved relevant documents. If ``open_domain=True``,
                 average precision is normalized by the number of retrieved relevant documents per query.
                 If ``open_domain=False``, average precision is normalized by the number of all relevant documents
                 per query.

        :param label_index: Index/Table in DocumentStore where labeled questions are stored
        :param doc_index: Index/Table in DocumentStore where documents that are used for evaluation are stored
        :param top_k: How many documents to return per query
        :param open_domain: If ``True``, retrieval will be evaluated by checking if the answer string to a question is
                            contained in the retrieved docs (common approach in open-domain QA).
                            If ``False``, retrieval uses a stricter evaluation that checks if the retrieved document ids
                            are within ids explicitly stated in the labels.
        :param return_preds: Whether to add predictions in the returned dictionary. If True, the returned dictionary
                             contains the keys "predictions" and "metrics".
        """
        raise NotImplementedError
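As a rough illustration of how the abstract interface above is meant to be filled in, here is a hedged sketch of a custom ranker. SimpleLexicalRanker is hypothetical and not part of Haystack; it only assumes the BaseRanker contract shown in this file and may need extra wiring (e.g., set_config) for YAML export.

```python
# Hypothetical example, not part of Haystack: a trivial ranker that scores documents
# by naive term overlap with the query, implementing the BaseRanker contract above.
from typing import List, Optional

from haystack import Document
from haystack.ranker.base import BaseRanker


class SimpleLexicalRanker(BaseRanker):
    def __init__(self, top_k: int = 10):
        self.top_k = top_k

    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None):
        top_k = top_k or self.top_k
        query_terms = set(query.lower().split())
        # Score each document by the number of query terms it contains, then sort descending.
        scored = [(len(query_terms & set(doc.text.lower().split())), doc) for doc in documents]
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [doc for _, doc in scored[:top_k]]

    def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
        raise NotImplementedError


# Usage sketch: ranker = SimpleLexicalRanker(top_k=3)
#               ranked = ranker.predict(query="open source search", documents=my_docs)
```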
haystack/ranker/farm.py: 268 lines (new file)
@@ -0,0 +1,268 @@
import logging
import multiprocessing
from pathlib import Path
from typing import List, Optional, Union

from farm.data_handler.data_silo import DataSilo
from farm.data_handler.processor import TextPairClassificationProcessor
from farm.infer import Inferencer
from farm.modeling.optimization import initialize_optimizer
from farm.modeling.adaptive_model import BaseAdaptiveModel
from farm.train import Trainer
from farm.utils import set_all_seeds, initialize_device_settings
import shutil

from haystack import Document
from haystack.ranker.base import BaseRanker

logger = logging.getLogger(__name__)


class FARMRanker(BaseRanker):
    """
    Transformer-based model for document re-ranking using the TextPairClassifier of the FARM framework (https://github.com/deepset-ai/FARM).
    While the underlying model can vary (BERT, RoBERTa, DistilBERT, ...), the interface remains the same.

    | With a FARMRanker, you can:

    - directly get predictions via predict()
    - fine-tune the model on TextPair data via train()
    """

    def __init__(
        self,
        model_name_or_path: Union[str, Path],
        model_version: Optional[str] = None,
        batch_size: int = 50,
        use_gpu: bool = True,
        top_k: int = 10,
        num_processes: Optional[int] = None,
        max_seq_len: int = 256,
        progress_bar: bool = True
    ):

        """
        :param model_name_or_path: Directory of a saved model or the name of a public model e.g. 'bert-base-cased',
                                   'deepset/bert-base-cased-squad2', 'distilbert-base-uncased-distilled-squad'.
                                   See https://huggingface.co/models for a full list of available models.
        :param model_version: The version of the model to use from the HuggingFace model hub. Can be a tag name, branch name, or commit hash.
        :param batch_size: Number of samples the model receives in one batch for inference.
                           Memory consumption is much lower in inference mode. Recommendation: Increase the batch size
                           to a value so only a single batch is used.
        :param use_gpu: Whether to use GPU (if available)
        :param top_k: The maximum number of documents to return
        :param num_processes: The number of processes for `multiprocessing.Pool`. Set to a value of 0 to disable
                              multiprocessing. Set to None to let the Inferencer determine the optimum number. If you
                              want to debug the Language Model, you might need to disable multiprocessing!
        :param max_seq_len: Max sequence length of one input text for the model
        :param progress_bar: Whether to show a tqdm progress bar or not.
                             Can be helpful to disable in production deployments to keep the logs clean.
        """

        # save init parameters to enable export of component config as YAML
        self.set_config(
            model_name_or_path=model_name_or_path, model_version=model_version,
            batch_size=batch_size, use_gpu=use_gpu, top_k=top_k,
            num_processes=num_processes, max_seq_len=max_seq_len, progress_bar=progress_bar,
        )

        self.top_k = top_k

        self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu,
                                          task_type="text_classification", max_seq_len=max_seq_len,
                                          num_processes=num_processes, revision=model_version,
                                          disable_tqdm=not progress_bar,
                                          strict=False)

        self.max_seq_len = max_seq_len
        self.use_gpu = use_gpu
        self.progress_bar = progress_bar

    def train(
        self,
        data_dir: str,
        train_filename: str,
        dev_filename: Optional[str] = None,
        test_filename: Optional[str] = None,
        use_gpu: Optional[bool] = None,
        batch_size: int = 10,
        n_epochs: int = 2,
        learning_rate: float = 1e-5,
        max_seq_len: Optional[int] = None,
        warmup_proportion: float = 0.2,
        dev_split: float = 0,
        evaluate_every: int = 300,
        save_dir: Optional[str] = None,
        num_processes: Optional[int] = None,
        use_amp: str = None,
    ):
        """
        Fine-tune a model on a TextPairClassification dataset. Options:

        - Take a plain language model (e.g. `bert-base-cased`) and train it for TextPairClassification
        - Take a TextPairClassification model and fine-tune it for your domain

        :param data_dir: Path to directory containing your training data
        :param train_filename: Filename of training data
        :param dev_filename: Filename of dev / eval data
        :param test_filename: Filename of test data
        :param dev_split: Instead of specifying a dev_filename, you can also specify a ratio (e.g. 0.1) here
                          that gets split off from training data for eval.
        :param use_gpu: Whether to use GPU (if available)
        :param batch_size: Number of samples the model receives in one batch for training
        :param n_epochs: Number of iterations on the whole training data set
        :param learning_rate: Learning rate of the optimizer
        :param max_seq_len: Maximum text length (in tokens). Everything longer gets cut down.
        :param warmup_proportion: Proportion of training steps until the maximum learning rate is reached.
                                  Until that point the LR increases linearly. After that it decreases linearly again.
                                  Options for different schedules are available in FARM.
        :param evaluate_every: Evaluate the model every X steps on the hold-out eval dataset
        :param save_dir: Path to store the final model
        :param num_processes: The number of processes for `multiprocessing.Pool` during preprocessing.
                              Set to a value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from the train set.
                              Set to None to use all CPU cores minus one.
        :param use_amp: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model.
                        Available options:
                        None (Don't use AMP)
                        "O0" (Normal FP32 training)
                        "O1" (Mixed Precision => Recommended)
                        "O2" (Almost FP16)
                        "O3" (Pure FP16).
                        See details on: https://nvidia.github.io/apex/amp.html
        :return: None
        """

        if dev_filename:
            dev_split = 0

        if num_processes is None:
            num_processes = multiprocessing.cpu_count() - 1 or 1

        set_all_seeds(seed=42)

        # For these variables, by default, we use the value set when initializing the FARMRanker.
        # These can also be manually set when train() is called if you want a different value at train vs inference
        if use_gpu is None:
            use_gpu = self.use_gpu
        if max_seq_len is None:
            max_seq_len = self.max_seq_len

        device, n_gpu = initialize_device_settings(use_cuda=use_gpu, use_amp=use_amp)

        if not save_dir:
            save_dir = f"saved_models/{self.inferencer.model.language_model.name}"

        # 1. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
        label_list = ["0", "1"]
        metric = "f1_macro"
        processor = TextPairClassificationProcessor(
            tokenizer=self.inferencer.processor.tokenizer,
            max_seq_len=max_seq_len,
            label_list=label_list,
            metric=metric,
            train_filename=train_filename,
            dev_filename=dev_filename,
            dev_split=dev_split,
            test_filename=test_filename,
            data_dir=Path(data_dir),
            delimiter="\t"
        )

        # 2. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them
        # and calculates a few descriptive statistics of our datasets
        data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False, max_processes=num_processes)

        # Quick-fix until this is fixed upstream in FARM:
        # We must avoid applying DataParallel twice (once when loading the inferencer,
        # once when calling initialize_optimizer)
        self.inferencer.model.save("tmp_model")
        model = BaseAdaptiveModel.load(load_dir="tmp_model", device=device, strict=True)
        shutil.rmtree('tmp_model')

        # 3. Create an optimizer and pass the already initialized model
        model, optimizer, lr_schedule = initialize_optimizer(
            model=model,
            learning_rate=learning_rate,
            schedule_opts={"name": "LinearWarmup", "warmup_proportion": warmup_proportion},
            n_batches=len(data_silo.loaders["train"]),
            n_epochs=n_epochs,
            device=device,
            use_amp=use_amp,
        )
        # 4. Feed everything to the Trainer, which takes care of growing our model and evaluates it from time to time
        trainer = Trainer(
            model=model,
            optimizer=optimizer,
            data_silo=data_silo,
            epochs=n_epochs,
            n_gpu=n_gpu,
            lr_schedule=lr_schedule,
            evaluate_every=evaluate_every,
            device=device,
            use_amp=use_amp,
            disable_tqdm=not self.progress_bar
        )

        # 5. Let it grow!
        self.inferencer.model = trainer.train()
        self.save(Path(save_dir))

    def update_parameters(
        self,
        max_seq_len: Optional[int] = None,
    ):
        """
        Hot update parameters of a loaded Ranker. It may not be safe when processing concurrent requests.
        """
        if max_seq_len is not None:
            self.inferencer.processor.max_seq_len = max_seq_len
            self.max_seq_len = max_seq_len

    def save(self, directory: Path):
        """
        Saves the Ranker model so that it can be reused at a later point in time.

        :param directory: Directory where the Ranker model should be saved
        """
        logger.info(f"Saving ranker model to {directory}")
        self.inferencer.model.save(directory)
        self.inferencer.processor.save(directory)

    def predict_batch(self, query_doc_list: List[dict], top_k: int = None, batch_size: int = None):
        """
        Use the loaded Ranker model to rank, for a list of queries, each query's supplied list of Document.

        Returns a list of dictionaries containing the query and its list of Document sorted by (descending) similarity with the query.

        :param query_doc_list: List of dictionaries containing queries with their retrieved documents
        :param top_k: The maximum number of answers to return for each query
        :param batch_size: Number of samples the model receives in one batch for inference
        :return: List of dictionaries containing query and ranked list of Document
        """
        raise NotImplementedError

    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None):
        """
        Use the loaded ranker model to re-rank the supplied list of Document.

        Returns a list of Document sorted by (descending) TextPairClassification similarity with the query.

        :param query: Query string
        :param documents: List of Document to be re-ranked
        :param top_k: The maximum number of documents to return
        :return: List of Document
        """
        if top_k is None:
            top_k = self.top_k

        # calculate similarity of query and each document
        query_and_docs = [{"text": (query, doc.text)} for doc in documents]
        result = self.inferencer.inference_from_dicts(dicts=query_and_docs)
        similarity_scores = [pred["probability"] for preds in result for pred in preds["predictions"]]

        # rank documents according to scores
        sorted_scores_and_documents = sorted(zip(similarity_scores, documents),
                                             key=lambda similarity_document_tuple: similarity_document_tuple[0],
                                             reverse=True)
        sorted_documents = [doc for _, doc in sorted_scores_and_documents]
        return sorted_documents[:top_k]
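For completeness, a hedged usage sketch of the train() API defined above, following the documented workflow of fine-tuning a plain language model for text pair classification. The data directory, file names, and TSV layout are assumptions; the files must match what FARM's TextPairClassificationProcessor expects (tab-separated, labels "0"/"1").

```python
from haystack.ranker import FARMRanker

# Start from a plain language model and fine-tune it for text pair classification.
ranker = FARMRanker(model_name_or_path="bert-base-cased")

# Hypothetical paths and filenames for illustration only.
ranker.train(
    data_dir="data/text_pair_classification",
    train_filename="train.tsv",
    dev_filename="dev.tsv",
    n_epochs=2,
    batch_size=10,
    learning_rate=1e-5,
    save_dir="saved_models/my_farm_ranker",
)

# Reload the fine-tuned model later for inference.
ranker = FARMRanker(model_name_or_path="saved_models/my_farm_ranker")
```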
@@ -2,7 +2,7 @@ import pytest
 from haystack.document_store.base import BaseDocumentStore
 from haystack.preprocessor.preprocessor import PreProcessor
 from haystack.finder import Finder
-from haystack.eval import EvalReader, EvalRetriever
+from haystack.eval import EvalAnswers, EvalDocuments
 from haystack import Pipeline


@@ -114,15 +114,15 @@ def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever):
        } for l in labels
    }

-    eval_retriever = EvalRetriever()
-    eval_reader = EvalReader()
+    eval_retriever = EvalDocuments()
+    eval_reader = EvalAnswers()

     assert document_store.get_document_count(index="haystack_test_eval_document") == 2
     p = Pipeline()
     p.add_node(component=retriever, name="ESRetriever", inputs=["Query"])
-    p.add_node(component=eval_retriever, name="EvalRetriever", inputs=["ESRetriever"])
-    p.add_node(component=reader, name="QAReader", inputs=["EvalRetriever"])
-    p.add_node(component=eval_reader, name="EvalReader", inputs=["QAReader"])
+    p.add_node(component=eval_retriever, name="EvalDocuments", inputs=["ESRetriever"])
+    p.add_node(component=reader, name="QAReader", inputs=["EvalDocuments"])
+    p.add_node(component=eval_reader, name="EvalAnswers", inputs=["QAReader"])
     for q, l in q_to_l_dict.items():
         res = p.run(
             query=q,
@@ -2,7 +2,7 @@ from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
 from haystack.preprocessor.utils import fetch_archive_from_http
 from haystack.retriever.sparse import ElasticsearchRetriever
 from haystack.retriever.dense import DensePassageRetriever
-from haystack.eval import EvalReader, EvalRetriever
+from haystack.eval import EvalAnswers, EvalDocuments
 from haystack.reader.farm import FARMReader
 from haystack.preprocessor import PreProcessor
 from haystack.utils import launch_es
@@ -104,8 +104,8 @@ def tutorial5_evaluation():
     )

     # Here we initialize the nodes that perform evaluation
-    eval_retriever = EvalRetriever()
-    eval_reader = EvalReader()
+    eval_retriever = EvalDocuments()
+    eval_reader = EvalAnswers()


     ## Evaluate Retriever on its own in closed domain fashion
@@ -137,9 +137,9 @@ def tutorial5_evaluation():
     # Here is the pipeline definition
     p = Pipeline()
     p.add_node(component=retriever, name="ESRetriever", inputs=["Query"])
-    p.add_node(component=eval_retriever, name="EvalRetriever", inputs=["ESRetriever"])
-    p.add_node(component=reader, name="QAReader", inputs=["EvalRetriever"])
-    p.add_node(component=eval_reader, name="EvalReader", inputs=["QAReader"])
+    p.add_node(component=eval_retriever, name="EvalDocuments", inputs=["ESRetriever"])
+    p.add_node(component=reader, name="QAReader", inputs=["EvalDocuments"])
+    p.add_node(component=eval_reader, name="EvalAnswers", inputs=["QAReader"])
     results = []

     for q, l in q_to_l_dict.items():