From 84c34295a1b40449f4eca5e531832aa586bf05dc Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Mon, 31 May 2021 15:31:36 +0200 Subject: [PATCH] Re-ranking component for document search without QA (#1025) * Adding ranker similar to retriever and reader * Sort documents according to query-document similarity scores * Reranking and model training runs for small example * Added EvalRanker node * Calculate recall@k in EvalRetriever and EvalRanker nodes * Renaming EvalRetriever to EvalDocuments and EvalReader to EvalAnswers * Added mean reciprocal rank as metric for EvalDocuments * Fix bug that appeared when ranking documents with same score * Remove commented code for unimplmented eval() of Ranker node * Add documentation of k parameter in EvalDocuments * Add Ranker docu and renaming top_k param --- README.md | 11 +- docs/_src/tutorials/tutorials/5.md | 9 +- docs/_src/usage/usage/ranker.md | 54 +++++ docs/v0.8.0/_src/tutorials/tutorials/5.md | 9 +- haystack/eval.py | 74 ++++-- haystack/pipeline_eval.py | 12 +- haystack/ranker/__init__.py | 1 + haystack/ranker/base.py | 105 +++++++++ haystack/ranker/farm.py | 268 ++++++++++++++++++++++ test/test_eval.py | 12 +- tutorials/Tutorial5_Evaluation.py | 12 +- 11 files changed, 509 insertions(+), 58 deletions(-) create mode 100644 docs/_src/usage/usage/ranker.md create mode 100644 haystack/ranker/__init__.py create mode 100644 haystack/ranker/base.py create mode 100644 haystack/ranker/farm.py diff --git a/README.md b/README.md index 86eed58ba..8afc24230 100644 --- a/README.md +++ b/README.md @@ -164,12 +164,13 @@ We recommend Elasticsearch or FAISS but also have more light-weight options for Retrievers narrow down the search space significantly and are therefore crucial for scalable QA. Haystack supports sparse methods (TF-IDF, BM25, custom Elasticsearch queries) and state of the art dense methods (e.g., sentence-transformers and Dense Passage Retrieval) -5. **Reader**: Neural network (e.g., BERT or RoBERTA) that reads through texts in detail +5. **Ranker**: Neural network (e.g., BERT or RoBERTA) that re-ranks top-k retrieved documents. The Ranker is an optional component and uses a TextPairClassification model under the hood. This model calculates semantic similarity of each of the top-k retrieved documents with the query. +6. **Reader**: Neural network (e.g., BERT or RoBERTA) that reads through texts in detail to find an answer. The Reader takes multiple passages of text as input and returns top-n answers. Models are trained via [FARM](https://github.com/deepset-ai/FARM) or [Transformers](https://github.com/huggingface/transformers) on SQuAD like tasks. You can load a pre-trained model from [Hugging Face's model hub](https://huggingface.co/models) or fine-tune it on your domain data. -6. **Generator**: Neural network (e.g., RAG) that *generates* an answer for a given question conditioned on the retrieved documents from the retriever. -6. **Pipeline**: Stick building blocks together to highly custom pipelines that are represented as Directed Acyclic Graphs (DAG). Think of it as "Apache Airflow for search". -7. **REST API**: Exposes a simple API based on fastAPI for running QA search, uploading files, and collecting user feedback for continuous learning. -8. **Haystack Annotate**: Create custom QA labels to improve the performance of your domain-specific models. [Hosted version](https://annotate.deepset.ai/login) or [Docker images](https://github.com/deepset-ai/haystack/tree/master/annotation_tool). +7. **Generator**: Neural network (e.g., RAG) that *generates* an answer for a given question conditioned on the retrieved documents from the retriever. +8. **Pipeline**: Stick building blocks together to highly custom pipelines that are represented as Directed Acyclic Graphs (DAG). Think of it as "Apache Airflow for search". +9. **REST API**: Exposes a simple API based on fastAPI for running QA search, uploading files, and collecting user feedback for continuous learning. +10. **Haystack Annotate**: Create custom QA labels to improve the performance of your domain-specific models. [Hosted version](https://annotate.deepset.ai/login) or [Docker images](https://github.com/deepset-ai/haystack/tree/master/annotation_tool). It's quite simple to begin experimenting with Haystack. We'd recommend going through the [Tutorials](https://github.com/deepset-ai/haystack/#tutorials) section below, but here's an example code structure describing how to approach Haystack with the DocumentStore based on Elasticsearch. diff --git a/docs/_src/tutorials/tutorials/5.md b/docs/_src/tutorials/tutorials/5.md index 97668294a..9ddcf525a 100644 --- a/docs/_src/tutorials/tutorials/5.md +++ b/docs/_src/tutorials/tutorials/5.md @@ -157,9 +157,8 @@ reader = FARMReader("deepset/roberta-base-squad2", top_k_per_candidate=4, return ``` - ```python -from haystack.eval import EvalReader, EvalRetriever +from haystack.eval import EvalAnswers, EvalDocuments # Here we initialize the nodes that perform evaluation eval_retriever = EvalRetriever() @@ -212,9 +211,9 @@ from haystack import Pipeline # Here is the pipeline definition p = Pipeline() p.add_node(component=retriever, name="ESRetriever", inputs=["Query"]) -p.add_node(component=eval_retriever, name="EvalRetriever", inputs=["ESRetriever"]) -p.add_node(component=reader, name="QAReader", inputs=["EvalRetriever"]) -p.add_node(component=eval_reader, name="EvalReader", inputs=["QAReader"]) +p.add_node(component=eval_retriever, name="EvalDocuments", inputs=["ESRetriever"]) +p.add_node(component=reader, name="QAReader", inputs=["EvalDocuments"]) +p.add_node(component=eval_reader, name="EvalAnswers", inputs=["QAReader"]) results = [] ``` diff --git a/docs/_src/usage/usage/ranker.md b/docs/_src/usage/usage/ranker.md new file mode 100644 index 000000000..d21dc21ee --- /dev/null +++ b/docs/_src/usage/usage/ranker.md @@ -0,0 +1,54 @@ + + +# Ranker + +There are pure "semantic document search" use cases that do not need question answering functionality but only document ranking. +While the [Retriever](/docs/latest/retrievermd) is a perfect fit for document retrieval, we can further improve its results with the Ranker. +For example, BM25 (sparse retriever) does not take into account semantics of the documents and the query but only their keywords. +The Ranker can re-rank the results of the retriever step by taking semantics into account. +Similar to the Reader, it is based on the latest language models. +Instead of returning answers, it returns documents in re-ranked order. + +Without a Ranker and its re-ranking step, the querying process is faster but the query results might be of lower quality. +If you want to do "semantic document search" instead of a question answering, try first with a Retriever only. +In case the semantic similarity of the query and the resulting documents is low, add a Ranker. + +Note that a Ranker needs to be initialised with a model trained on a text pair classification task. +You can train the model also with the train() method of the Ranker. +Alternatively, [this example](https://github.com/deepset-ai/FARM/blob/master/examples/text_pair_classification.py) shows how to train a text pair classification model in FARM. + + +## FARMRanker + +### Description + +The FARMRanker consists of a Transformer-based model for document re-ranking using the TextPairClassifier of [FARM](https://github.com/deepset-ai/FARM). +While the underlying model can vary (BERT, Roberta, DistilBERT, ...), the interface remains the same. +With a FARMRanker, you can: +* Directly get predictions (re-ranked version of the supplied list of Document) via predict() if supplying a pre-trained model +* Take a plain language model (e.g. `bert-base-cased`) and train it for TextPairClassification via train() + +### Initialisation + +```python +from haystack.document_store import ElasticsearchDocumentStore +from haystack.retriever import ElasticsearchRetriever +from haystack.ranker import FARMRanker +from haystack import Pipeline + +document_store = ElasticsearchDocumentStore() +... +retriever = ElasticsearchRetriever(document_store) +ranker = FARMRanker(model_name_or_path="saved_models/roberta-base-asnq-binary") +... +p = Pipeline() +p.add_node(component=retriever, name="ESRetriever", inputs=["Query"]) +p.add_node(component=ranker, name="Ranker", inputs=["ESRetriever"]) +``` diff --git a/docs/v0.8.0/_src/tutorials/tutorials/5.md b/docs/v0.8.0/_src/tutorials/tutorials/5.md index 97668294a..9ddcf525a 100644 --- a/docs/v0.8.0/_src/tutorials/tutorials/5.md +++ b/docs/v0.8.0/_src/tutorials/tutorials/5.md @@ -157,9 +157,8 @@ reader = FARMReader("deepset/roberta-base-squad2", top_k_per_candidate=4, return ``` - ```python -from haystack.eval import EvalReader, EvalRetriever +from haystack.eval import EvalAnswers, EvalDocuments # Here we initialize the nodes that perform evaluation eval_retriever = EvalRetriever() @@ -212,9 +211,9 @@ from haystack import Pipeline # Here is the pipeline definition p = Pipeline() p.add_node(component=retriever, name="ESRetriever", inputs=["Query"]) -p.add_node(component=eval_retriever, name="EvalRetriever", inputs=["ESRetriever"]) -p.add_node(component=reader, name="QAReader", inputs=["EvalRetriever"]) -p.add_node(component=eval_reader, name="EvalReader", inputs=["QAReader"]) +p.add_node(component=eval_retriever, name="EvalDocuments", inputs=["ESRetriever"]) +p.add_node(component=reader, name="QAReader", inputs=["EvalDocuments"]) +p.add_node(component=eval_reader, name="EvalAnswers", inputs=["QAReader"]) results = [] ``` diff --git a/haystack/eval.py b/haystack/eval.py index 2aaf26321..3b7ce5755 100644 --- a/haystack/eval.py +++ b/haystack/eval.py @@ -9,19 +9,21 @@ from farm.evaluation.squad_evaluation import compute_exact as calculate_em_str logger = logging.getLogger(__name__) -class EvalRetriever: +class EvalDocuments: """ - This is a pipeline node that should be placed after a Retriever in order to assess its performance. Performance - metrics are stored in this class and updated as each sample passes through it. To view the results of the evaluation, - call EvalRetriever.print(). Note that results from this Node may differ from that when calling Retriever.eval() - since that is a closed domain evaluation. Have a look at our evaluation tutorial for more info about - open vs closed domain eval (https://haystack.deepset.ai/docs/latest/tutorial5md). + This is a pipeline node that should be placed after a node that returns a List of Document, e.g., Retriever or + Ranker, in order to assess its performance. Performance metrics are stored in this class and updated as each + sample passes through it. To view the results of the evaluation, call EvalDocuments.print(). Note that results + from this Node may differ from that when calling Retriever.eval() since that is a closed domain evaluation. Have + a look at our evaluation tutorial for more info about open vs closed domain eval ( + https://haystack.deepset.ai/docs/latest/tutorial5md). """ - def __init__(self, debug: bool=False, open_domain: bool=True): + def __init__(self, debug: bool=False, open_domain: bool=True, top_k: int=10, name="EvalDocuments"): """ :param open_domain: When True, a document is considered correctly retrieved so long as the answer string can be found within it. When False, correct retrieval is evaluated based on document_id. - :param debug: When True, a record of each sample and its evaluation will be stored in EvalRetriever.log + :param debug: When True, a record of each sample and its evaluation will be stored in EvalDocuments.log + :param top_k: calculate eval metrics for top k results, e.g., recall@k """ self.outgoing_edges = 1 self.init_counts() @@ -29,6 +31,8 @@ class EvalRetriever: self.debug = debug self.log: List = [] self.open_domain = open_domain + self.top_k = top_k + self.name = name def init_counts(self): self.correct_retrieval_count = 0 @@ -38,6 +42,10 @@ class EvalRetriever: self.has_answer_recall = 0 self.no_answer_count = 0 self.recall = 0.0 + self.mean_reciprocal_rank = 0.0 + self.has_answer_mean_reciprocal_rank = 0.0 + self.reciprocal_rank_sum = 0.0 + self.has_answer_reciprocal_rank_sum = 0.0 def run(self, documents, labels: dict, **kwargs): """Run this node on one sample and its labels""" @@ -48,6 +56,8 @@ class EvalRetriever: if retriever_labels.no_answer: self.no_answer_count += 1 correct_retrieval = 1 + retrieved_reciprocal_rank = 1 + self.reciprocal_rank_sum += 1 if not self.no_answer_warning: self.no_answer_warning = True logger.warning("There seem to be empty string labels in the dataset suggesting that there " @@ -56,48 +66,62 @@ class EvalRetriever: # If there are answer span annotations in the labels else: self.has_answer_count += 1 - correct_retrieval = self.is_correctly_retrieved(retriever_labels, documents) + retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents) + self.reciprocal_rank_sum += retrieved_reciprocal_rank + correct_retrieval = True if retrieved_reciprocal_rank > 0 else False self.has_answer_correct += int(correct_retrieval) + self.has_answer_reciprocal_rank_sum += retrieved_reciprocal_rank self.has_answer_recall = self.has_answer_correct / self.has_answer_count + self.has_answer_mean_reciprocal_rank = self.has_answer_reciprocal_rank_sum / self.has_answer_count self.correct_retrieval_count += correct_retrieval self.recall = self.correct_retrieval_count / self.query_count + self.mean_reciprocal_rank = self.reciprocal_rank_sum / self.query_count + if self.debug: - self.log.append({"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, **kwargs}) - return {"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, **kwargs}, "output_1" + self.log.append({"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs}) + return {"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs}, "output_1" def is_correctly_retrieved(self, retriever_labels, predictions): + return self.reciprocal_rank_retrieved(retriever_labels, predictions) > 0 + + def reciprocal_rank_retrieved(self, retriever_labels, predictions): if self.open_domain: for label in retriever_labels.multiple_answers: - for p in predictions: + for rank, p in enumerate(predictions[:self.top_k]): if label.lower() in p.text.lower(): - return True + return 1/(rank+1) return False else: - prediction_ids = [p.id for p in predictions] + prediction_ids = [p.id for p in predictions[:self.top_k]] label_ids = retriever_labels.multiple_document_ids - for l in label_ids: - if l in prediction_ids: - return True - return False + for rank, p in enumerate(prediction_ids): + if p in label_ids: + return 1/(rank+1) + return 0 def print(self): """Print the evaluation results""" - print("Retriever") + print(self.name) print("-----------------") if self.no_answer_count: print( - f"has_answer recall: {self.has_answer_recall:.4f} ({self.has_answer_correct}/{self.has_answer_count})") + f"has_answer recall@{self.top_k}: {self.has_answer_recall:.4f} ({self.has_answer_correct}/{self.has_answer_count})") print( - f"no_answer recall: 1.00 ({self.no_answer_count}/{self.no_answer_count}) (no_answer samples are always treated as correctly retrieved)") - print(f"recall: {self.recall:.4f} ({self.correct_retrieval_count} / {self.query_count})") + f"no_answer recall@{self.top_k}: 1.00 ({self.no_answer_count}/{self.no_answer_count}) (no_answer samples are always treated as correctly retrieved)") + print( + f"has_answer mean_reciprocal_rank@{self.top_k}: {self.has_answer_mean_reciprocal_rank:.4f}") + print( + f"no_answer mean_reciprocal_rank@{self.top_k}: 1.0000 (no_answer samples are always treated as correctly retrieved at rank 1)") + print(f"recall@{self.top_k}: {self.recall:.4f} ({self.correct_retrieval_count} / {self.query_count})") + print(f"mean_reciprocal_rank@{self.top_k}: {self.mean_reciprocal_rank:.4f}") -class EvalReader: +class EvalAnswers: """ This is a pipeline node that should be placed after a Reader in order to assess the performance of the Reader individually or to assess the extractive QA performance of the whole pipeline. Performance metrics are stored in - this class and updated as each sample passes through it. To view the results of the evaluation, call EvalReader.print(). + this class and updated as each sample passes through it. To view the results of the evaluation, call EvalAnswers.print(). Note that results from this Node may differ from that when calling Reader.eval() since that is a closed domain evaluation. Have a look at our evaluation tutorial for more info about open vs closed domain eval (https://haystack.deepset.ai/docs/latest/tutorial5md). @@ -107,7 +131,7 @@ class EvalReader: """ :param skip_incorrect_retrieval: When set to True, this eval will ignore the cases where the retriever returned no correct documents :param open_domain: When True, extracted answers are evaluated purely on string similarity rather than the position of the extracted answer - :param debug: When True, a record of each sample and its evaluation will be stored in EvalReader.log + :param debug: When True, a record of each sample and its evaluation will be stored in EvalAnswers.log """ self.outgoing_edges = 1 self.init_counts() diff --git a/haystack/pipeline_eval.py b/haystack/pipeline_eval.py index c8926ae90..1cec75ee2 100644 --- a/haystack/pipeline_eval.py +++ b/haystack/pipeline_eval.py @@ -6,7 +6,7 @@ from haystack.reader.farm import FARMReader from haystack import Pipeline from farm.utils import initialize_device_settings from haystack.preprocessor import PreProcessor -from haystack.eval import EvalReader, EvalRetriever +from haystack.eval import EvalAnswers, EvalDocuments import logging import subprocess @@ -37,9 +37,9 @@ def main(): document_store = ElasticsearchDocumentStore() es_retriever = ElasticsearchRetriever(document_store=document_store) - eval_retriever = EvalRetriever(open_domain=open_domain) + eval_retriever = EvalDocuments(open_domain=open_domain) reader = FARMReader("deepset/roberta-base-squad2", top_k_per_candidate=4, num_processes=1, return_no_answer=True) - eval_reader = EvalReader(debug=True, open_domain=open_domain) + eval_reader = EvalAnswers(debug=True, open_domain=open_domain) # Download evaluation data, which is a subset of Natural Questions development set containing 50 documents doc_dir = "../data/nq" @@ -68,9 +68,9 @@ def main(): # Here is the pipeline definition p = Pipeline() p.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"]) - p.add_node(component=eval_retriever, name="EvalRetriever", inputs=["ESRetriever"]) - p.add_node(component=reader, name="QAReader", inputs=["EvalRetriever"]) - p.add_node(component=eval_reader, name="EvalReader", inputs=["QAReader"]) + p.add_node(component=eval_retriever, name="EvalDocuments", inputs=["ESRetriever"]) + p.add_node(component=reader, name="QAReader", inputs=["EvalDocuments"]) + p.add_node(component=eval_reader, name="EvalAnswers", inputs=["QAReader"]) results = [] for i, (q, l) in enumerate(q_to_l_dict.items()): diff --git a/haystack/ranker/__init__.py b/haystack/ranker/__init__.py new file mode 100644 index 000000000..c203ea562 --- /dev/null +++ b/haystack/ranker/__init__.py @@ -0,0 +1 @@ +from haystack.ranker.farm import FARMRanker diff --git a/haystack/ranker/base.py b/haystack/ranker/base.py new file mode 100644 index 000000000..8c24e28bc --- /dev/null +++ b/haystack/ranker/base.py @@ -0,0 +1,105 @@ +import logging +from abc import abstractmethod +from copy import deepcopy +from typing import List, Optional +from functools import wraps +from time import perf_counter + +from tqdm import tqdm + +from haystack import Document, BaseComponent + +logger = logging.getLogger(__name__) + + +class BaseRanker(BaseComponent): + return_no_answers: bool + outgoing_edges = 1 + query_count = 0 + query_time = 0 + + @abstractmethod + def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None): + pass + + @abstractmethod + def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None): + pass + + def run(self, query: str, documents: List[Document], top_k_ranker: Optional[int] = None, **kwargs): # type: ignore + self.query_count += 1 + if documents: + predict = self.timing(self.predict, "query_time") + results = predict(query=query, documents=documents, top_k=top_k_ranker) + else: + results = [] + + document_ids = [doc.id for doc in results] + logger.debug(f"Retrieved documents with IDs: {document_ids}") + output = { + "query": query, + "documents": results, + **kwargs + } + + return output, "output_1" + + def timing(self, fn, attr_name): + """Wrapper method used to time functions. """ + @wraps(fn) + def wrapper(*args, **kwargs): + if attr_name not in self.__dict__: + self.__dict__[attr_name] = 0 + tic = perf_counter() + ret = fn(*args, **kwargs) + toc = perf_counter() + self.__dict__[attr_name] += toc - tic + return ret + return wrapper + + def print_time(self): + print("Ranker (Speed)") + print("---------------") + if not self.query_count: + print("No querying performed via Retriever.run()") + else: + print(f"Queries Performed: {self.query_count}") + print(f"Query time: {self.query_time}s") + print(f"{self.query_time / self.query_count} seconds per query") + + def eval( + self, + label_index: str = "label", + doc_index: str = "eval_document", + label_origin: str = "gold_label", + top_k: int = 10, + open_domain: bool = False, + return_preds: bool = False, + ) -> dict: + """ + Performs evaluation of the Ranker. + Ranker is evaluated in the same way as a Retriever based on whether it finds the correct document given the query string and at which + position in the ranking of documents the correct document is. + + | Returns a dict containing the following metrics: + + - "recall": Proportion of questions for which correct document is among retrieved documents + - "mrr": Mean of reciprocal rank. Rewards retrievers that give relevant documents a higher rank. + Only considers the highest ranked relevant document. + - "map": Mean of average precision for each question. Rewards retrievers that give relevant + documents a higher rank. Considers all retrieved relevant documents. If ``open_domain=True``, + average precision is normalized by the number of retrieved relevant documents per query. + If ``open_domain=False``, average precision is normalized by the number of all relevant documents + per query. + + :param label_index: Index/Table in DocumentStore where labeled questions are stored + :param doc_index: Index/Table in DocumentStore where documents that are used for evaluation are stored + :param top_k: How many documents to return per query + :param open_domain: If ``True``, retrieval will be evaluated by checking if the answer string to a question is + contained in the retrieved docs (common approach in open-domain QA). + If ``False``, retrieval uses a stricter evaluation that checks if the retrieved document ids + are within ids explicitly stated in the labels. + :param return_preds: Whether to add predictions in the returned dictionary. If True, the returned dictionary + contains the keys "predictions" and "metrics". + """ + raise NotImplementedError diff --git a/haystack/ranker/farm.py b/haystack/ranker/farm.py new file mode 100644 index 000000000..68a4c67ca --- /dev/null +++ b/haystack/ranker/farm.py @@ -0,0 +1,268 @@ +import logging +import multiprocessing +from pathlib import Path +from typing import List, Optional, Union + +from farm.data_handler.data_silo import DataSilo +from farm.data_handler.processor import TextPairClassificationProcessor +from farm.infer import Inferencer +from farm.modeling.optimization import initialize_optimizer +from farm.modeling.adaptive_model import BaseAdaptiveModel +from farm.train import Trainer +from farm.utils import set_all_seeds, initialize_device_settings +import shutil + +from haystack import Document +from haystack.ranker.base import BaseRanker + +logger = logging.getLogger(__name__) + + +class FARMRanker(BaseRanker): + """ + Transformer based model for Document Re-ranking using the TextPairClassifier of FARM framework (https://github.com/deepset-ai/FARM). + While the underlying model can vary (BERT, Roberta, DistilBERT, ...), the interface remains the same. + + | With a FARMRanker, you can: + + - directly get predictions via predict() + - fine-tune the model on TextPair data via train() + """ + + def __init__( + self, + model_name_or_path: Union[str, Path], + model_version: Optional[str] = None, + batch_size: int = 50, + use_gpu: bool = True, + top_k: int = 10, + num_processes: Optional[int] = None, + max_seq_len: int = 256, + progress_bar: bool = True + ): + + """ + :param model_name_or_path: Directory of a saved model or the name of a public model e.g. 'bert-base-cased', + 'deepset/bert-base-cased-squad2', 'deepset/bert-base-cased-squad2', 'distilbert-base-uncased-distilled-squad'. + See https://huggingface.co/models for full list of available models. + :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash. + :param batch_size: Number of samples the model receives in one batch for inference. + Memory consumption is much lower in inference mode. Recommendation: Increase the batch size + to a value so only a single batch is used. + :param use_gpu: Whether to use GPU (if available) + :param top_k: The maximum number of documents to return + :param num_processes: The number of processes for `multiprocessing.Pool`. Set to value of 0 to disable + multiprocessing. Set to None to let Inferencer determine optimum number. If you + want to debug the Language Model, you might need to disable multiprocessing! + :param max_seq_len: Max sequence length of one input text for the model + :param progress_bar: Whether to show a tqdm progress bar or not. + Can be helpful to disable in production deployments to keep the logs clean. + """ + + # save init parameters to enable export of component config as YAML + self.set_config( + model_name_or_path=model_name_or_path, model_version=model_version, + batch_size=batch_size, use_gpu=use_gpu, top_k=top_k, + num_processes=num_processes, max_seq_len=max_seq_len, progress_bar=progress_bar, + ) + + self.top_k = top_k + + self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu, + task_type="text_classification", max_seq_len=max_seq_len, + num_processes=num_processes, revision=model_version, + disable_tqdm=not progress_bar, + strict=False) + + self.max_seq_len = max_seq_len + self.use_gpu = use_gpu + self.progress_bar = progress_bar + + def train( + self, + data_dir: str, + train_filename: str, + dev_filename: Optional[str] = None, + test_filename: Optional[str] = None, + use_gpu: Optional[bool] = None, + batch_size: int = 10, + n_epochs: int = 2, + learning_rate: float = 1e-5, + max_seq_len: Optional[int] = None, + warmup_proportion: float = 0.2, + dev_split: float = 0, + evaluate_every: int = 300, + save_dir: Optional[str] = None, + num_processes: Optional[int] = None, + use_amp: str = None, + ): + """ + Fine-tune a model on a TextPairClassification dataset. Options: + + - Take a plain language model (e.g. `bert-base-cased`) and train it for TextPairClassification + - Take a TextPairClassification model and fine-tune it for your domain + + :param data_dir: Path to directory containing your training data in SQuAD style + :param train_filename: Filename of training data + :param dev_filename: Filename of dev / eval data + :param test_filename: Filename of test data + :param dev_split: Instead of specifying a dev_filename, you can also specify a ratio (e.g. 0.1) here + that gets split off from training data for eval. + :param use_gpu: Whether to use GPU (if available) + :param batch_size: Number of samples the model receives in one batch for training + :param n_epochs: Number of iterations on the whole training data set + :param learning_rate: Learning rate of the optimizer + :param max_seq_len: Maximum text length (in tokens). Everything longer gets cut down. + :param warmup_proportion: Proportion of training steps until maximum learning rate is reached. + Until that point LR is increasing linearly. After that it's decreasing again linearly. + Options for different schedules are available in FARM. + :param evaluate_every: Evaluate the model every X steps on the hold-out eval dataset + :param save_dir: Path to store the final model + :param num_processes: The number of processes for `multiprocessing.Pool` during preprocessing. + Set to value of 1 to disable multiprocessing. When set to 1, you cannot split away a dev set from train set. + Set to None to use all CPU cores minus one. + :param use_amp: Optimization level of NVIDIA's automatic mixed precision (AMP). The higher the level, the faster the model. + Available options: + None (Don't use AMP) + "O0" (Normal FP32 training) + "O1" (Mixed Precision => Recommended) + "O2" (Almost FP16) + "O3" (Pure FP16). + See details on: https://nvidia.github.io/apex/amp.html + :return: None + """ + + if dev_filename: + dev_split = 0 + + if num_processes is None: + num_processes = multiprocessing.cpu_count() - 1 or 1 + + set_all_seeds(seed=42) + + # For these variables, by default, we use the value set when initializing the FARMRanker. + # These can also be manually set when train() is called if you want a different value at train vs inference + if use_gpu is None: + use_gpu = self.use_gpu + if max_seq_len is None: + max_seq_len = self.max_seq_len + + device, n_gpu = initialize_device_settings(use_cuda=use_gpu, use_amp=use_amp) + + if not save_dir: + save_dir = f"saved_models/{self.inferencer.model.language_model.name}" + + # 1. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset + label_list = ["0", "1"] + metric = "f1_macro" + processor = TextPairClassificationProcessor( + tokenizer=self.inferencer.processor.tokenizer, + max_seq_len=max_seq_len, + label_list=label_list, + metric=metric, + train_filename=train_filename, + dev_filename=dev_filename, + dev_split=dev_split, + test_filename=test_filename, + data_dir=Path(data_dir), + delimiter="\t" + ) + + # 2. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them + # and calculates a few descriptive statistics of our datasets + data_silo = DataSilo(processor=processor, batch_size=batch_size, distributed=False, max_processes=num_processes) + + # Quick-fix until this is fixed upstream in FARM: + # We must avoid applying DataParallel twice (once when loading the inferencer, + # once when calling initalize_optimizer) + self.inferencer.model.save("tmp_model") + model = BaseAdaptiveModel.load(load_dir="tmp_model", device=device, strict=True) + shutil.rmtree('tmp_model') + + # 3. Create an optimizer and pass the already initialized model + model, optimizer, lr_schedule = initialize_optimizer( + model=model, + learning_rate=learning_rate, + schedule_opts={"name": "LinearWarmup", "warmup_proportion": warmup_proportion}, + n_batches=len(data_silo.loaders["train"]), + n_epochs=n_epochs, + device=device, + use_amp=use_amp, + ) + # 4. Feed everything to the Trainer, which keeps care of growing our model and evaluates it from time to time + trainer = Trainer( + model=model, + optimizer=optimizer, + data_silo=data_silo, + epochs=n_epochs, + n_gpu=n_gpu, + lr_schedule=lr_schedule, + evaluate_every=evaluate_every, + device=device, + use_amp=use_amp, + disable_tqdm=not self.progress_bar + ) + + # 5. Let it grow! + self.inferencer.model = trainer.train() + self.save(Path(save_dir)) + + def update_parameters( + self, + max_seq_len: Optional[int] = None, + ): + """ + Hot update parameters of a loaded Ranker. It may not to be safe when processing concurrent requests. + """ + if max_seq_len is not None: + self.inferencer.processor.max_seq_len = max_seq_len + self.max_seq_len = max_seq_len + + def save(self, directory: Path): + """ + Saves the Ranker model so that it can be reused at a later point in time. + + :param directory: Directory where the Ranker model should be saved + """ + logger.info(f"Saving ranker model to {directory}") + self.inferencer.model.save(directory) + self.inferencer.processor.save(directory) + + def predict_batch(self, query_doc_list: List[dict], top_k: int = None, batch_size: int = None): + """ + Use loaded Ranker model to, for a list of queries, rank each query's supplied list of Document. + + Returns list of dictionary of query and list of document sorted by (desc.) similarity with query + + :param query_doc_list: List of dictionaries containing queries with their retrieved documents + :param top_k: The maximum number of answers to return for each query + :param batch_size: Number of samples the model receives in one batch for inference + :return: List of dictionaries containing query and ranked list of Document + """ + raise NotImplementedError + + def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None): + """ + Use loaded ranker model to re-rank the supplied list of Document. + + Returns list of Document sorted by (desc.) TextPairClassification similarity with the query. + + :param query: Query string + :param documents: List of Document to be re-ranked + :param top_k: The maximum number of documents to return + :return: List of Document + """ + if top_k is None: + top_k = self.top_k + + # calculate similarity of query and each document + query_and_docs = [{"text": (query, doc.text)} for doc in documents] + result = self.inferencer.inference_from_dicts(dicts=query_and_docs) + similarity_scores = [pred["probability"] for preds in result for pred in preds["predictions"]] + + # rank documents according to scores + sorted_scores_and_documents = sorted(zip(similarity_scores, documents), + key=lambda similarity_document_tuple: similarity_document_tuple[0], + reverse=True) + sorted_documents = [doc for _, doc in sorted_scores_and_documents] + return sorted_documents[:top_k] diff --git a/test/test_eval.py b/test/test_eval.py index 211215f10..e363df48c 100644 --- a/test/test_eval.py +++ b/test/test_eval.py @@ -2,7 +2,7 @@ import pytest from haystack.document_store.base import BaseDocumentStore from haystack.preprocessor.preprocessor import PreProcessor from haystack.finder import Finder -from haystack.eval import EvalReader, EvalRetriever +from haystack.eval import EvalAnswers, EvalDocuments from haystack import Pipeline @@ -114,15 +114,15 @@ def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever): } for l in labels } - eval_retriever = EvalRetriever() - eval_reader = EvalReader() + eval_retriever = EvalDocuments() + eval_reader = EvalAnswers() assert document_store.get_document_count(index="haystack_test_eval_document") == 2 p = Pipeline() p.add_node(component=retriever, name="ESRetriever", inputs=["Query"]) - p.add_node(component=eval_retriever, name="EvalRetriever", inputs=["ESRetriever"]) - p.add_node(component=reader, name="QAReader", inputs=["EvalRetriever"]) - p.add_node(component=eval_reader, name="EvalReader", inputs=["QAReader"]) + p.add_node(component=eval_retriever, name="EvalDocuments", inputs=["ESRetriever"]) + p.add_node(component=reader, name="QAReader", inputs=["EvalDocuments"]) + p.add_node(component=eval_reader, name="EvalAnswers", inputs=["QAReader"]) for q, l in q_to_l_dict.items(): res = p.run( query=q, diff --git a/tutorials/Tutorial5_Evaluation.py b/tutorials/Tutorial5_Evaluation.py index 8af355414..8a2c0cb41 100644 --- a/tutorials/Tutorial5_Evaluation.py +++ b/tutorials/Tutorial5_Evaluation.py @@ -2,7 +2,7 @@ from haystack.document_store.elasticsearch import ElasticsearchDocumentStore from haystack.preprocessor.utils import fetch_archive_from_http from haystack.retriever.sparse import ElasticsearchRetriever from haystack.retriever.dense import DensePassageRetriever -from haystack.eval import EvalReader, EvalRetriever +from haystack.eval import EvalAnswers, EvalDocuments from haystack.reader.farm import FARMReader from haystack.preprocessor import PreProcessor from haystack.utils import launch_es @@ -104,8 +104,8 @@ def tutorial5_evaluation(): ) # Here we initialize the nodes that perform evaluation - eval_retriever = EvalRetriever() - eval_reader = EvalReader() + eval_retriever = EvalDocuments() + eval_reader = EvalAnswers() ## Evaluate Retriever on its own in closed domain fashion @@ -137,9 +137,9 @@ def tutorial5_evaluation(): # Here is the pipeline definition p = Pipeline() p.add_node(component=retriever, name="ESRetriever", inputs=["Query"]) - p.add_node(component=eval_retriever, name="EvalRetriever", inputs=["ESRetriever"]) - p.add_node(component=reader, name="QAReader", inputs=["EvalRetriever"]) - p.add_node(component=eval_reader, name="EvalReader", inputs=["QAReader"]) + p.add_node(component=eval_retriever, name="EvalDocuments", inputs=["ESRetriever"]) + p.add_node(component=reader, name="QAReader", inputs=["EvalDocuments"]) + p.add_node(component=eval_reader, name="EvalAnswers", inputs=["QAReader"]) results = [] for q, l in q_to_l_dict.items():