mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-08 04:56:45 +00:00
WIP: Add evaluation nodes for Pipelines (#904)
* Add main eval fns
* WIP: make pipeline_eval.py run
* Fix typo
* Add support for no_answers
* Add latest docstring and tutorial changes
* Working pipeline eval
* Add timing of nodes
* Add latest docstring and tutorial changes
* Refactor and clean
* Update tutorial script
* Set default params
* Update tutorials
* Fix indent
* Add latest docstring and tutorial changes
* Address mypy issues
* Add test
* Fix mypy error
* Clear outputs
* Add doc strings
* Incorporate reviewer feedback
* Add latest docstring and tutorial changes
* Revert query counting
* Fix typo

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in: parent 32050fdce3, commit d77152c469
@@ -56,7 +56,7 @@ Example: {"name": ["some", "more"], "category": ["only_one"]}
#### add\_eval\_data

```python
| add_eval_data(filename: str, doc_index: str = "eval_document", label_index: str = "label", batch_size: Optional[int] = None, preprocessor: Optional[PreProcessor] = None, max_docs: Union[int, bool] = None)
| add_eval_data(filename: str, doc_index: str = "eval_document", label_index: str = "label", batch_size: Optional[int] = None, preprocessor: Optional[PreProcessor] = None, max_docs: Union[int, bool] = None, open_domain: bool = False)
```

Adds a SQuAD-formatted file to the DocumentStore in order to be able to perform evaluation on it.

@@ -76,6 +76,8 @@ Currently the PreProcessor does not support split_by sentence, cleaning nor spli
When set to None (default) preprocessing is disabled.
- `max_docs`: Optional number of documents that will be loaded.
When set to None (default) all available eval documents are used.
- `open_domain`: Set this to True if your file is an open domain dataset where two different answers to the
same question might be found in different contexts.
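
To make the new flag concrete, here is a small usage sketch (the file path is a placeholder; the index names are the defaults shown above):

```python
document_store.add_eval_data(
    filename="data/my_open_domain_eval_set.json",  # placeholder path to an open-domain SQuAD-style file
    doc_index="eval_document",
    label_index="label",
    open_domain=True,  # use when the same answer may appear in several contexts and span offsets are not reliable
)
```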

<a name="elasticsearch"></a>
# Module elasticsearch

@@ -92,7 +92,7 @@ the parameters passed into PreProcessor.__init__(). Takes a single document as i
#### eval\_data\_from\_json

```python
eval_data_from_json(filename: str, max_docs: Union[int, bool] = None, preprocessor: PreProcessor = None) -> Tuple[List[Document], List[Label]]
eval_data_from_json(filename: str, max_docs: Union[int, bool] = None, preprocessor: PreProcessor = None, open_domain: bool = False) -> Tuple[List[Document], List[Label]]
```

Read Documents + Labels from a SQuAD-style file.

@@ -102,6 +102,7 @@ Document and Labels can then be indexed to the DocumentStore and be used for eva

- `filename`: Path to file in SQuAD format
- `max_docs`: This sets the number of documents that will be loaded. By default, this is set to None, thus reading in all available eval documents.
- `open_domain`: Set this to True if your file is an open domain dataset where two different answers to the same question might be found in different contexts.

**Returns**:

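A minimal usage sketch, assuming the module path used elsewhere in this commit (haystack.preprocessor.utils); the file path and max_docs value are illustrative:

```python
from haystack.preprocessor.utils import eval_data_from_json

docs, labels = eval_data_from_json("my_eval_set.json", max_docs=100, open_domain=True)
print(f"Loaded {len(docs)} documents and {len(labels)} labels")
```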
@@ -111,7 +112,7 @@ Document and Labels can then be indexed to the DocumentStore and be used for eva
#### eval\_data\_from\_jsonl

```python
eval_data_from_jsonl(filename: str, batch_size: Optional[int] = None, max_docs: Union[int, bool] = None, preprocessor: PreProcessor = None) -> Generator[Tuple[List[Document], List[Label]], None, None]
eval_data_from_jsonl(filename: str, batch_size: Optional[int] = None, max_docs: Union[int, bool] = None, preprocessor: PreProcessor = None, open_domain: bool = False) -> Generator[Tuple[List[Document], List[Label]], None, None]
```

Read Documents + Labels from a SQuAD-style file in jsonl format, i.e. one document per line.

@@ -125,6 +126,7 @@ If batch_size is set to None, this method will yield all documents and labels.

- `filename`: Path to file in SQuAD format
- `max_docs`: This sets the number of documents that will be loaded. By default, this is set to None, thus reading in all available eval documents.
- `open_domain`: Set this to True if your file is an open domain dataset where two different answers to the same question might be found in different contexts.

**Returns**:

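Because the jsonl variant returns a generator, batches can be indexed as they are read. A sketch, assuming a document_store instance and the default index names (file name and batch size are illustrative):

```python
from haystack.preprocessor.utils import eval_data_from_jsonl

for docs, labels in eval_data_from_jsonl("my_eval_set.jsonl", batch_size=1000, open_domain=True):
    if docs:
        document_store.write_documents(docs, index="eval_document")
    if labels:
        document_store.write_labels(labels, index="label")
```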
@@ -1,6 +1,22 @@
<a name="base"></a>
# Module base

<a name="base.BaseReader"></a>
## BaseReader Objects

```python
class BaseReader(BaseComponent)
```

<a name="base.BaseReader.timing"></a>
#### timing

```python
| timing(fn, attr_name)
```

Wrapper method used to time functions.

<a name="farm"></a>
# Module farm


@@ -30,7 +30,7 @@ that are most relevant to the query.
#### timing

```python
| timing(fn)
| timing(fn, attr_name)
```

Wrapper method used to time functions.

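For reference, the pattern behind this wrapper, as implemented for BaseReader and BaseRetriever further down in this commit, accumulates the elapsed time on an instance attribute:

```python
from functools import wraps
from time import perf_counter

def timing(self, fn, attr_name):
    """Wrapper method used to time functions."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if attr_name not in self.__dict__:
            self.__dict__[attr_name] = 0
        tic = perf_counter()
        ret = fn(*args, **kwargs)
        toc = perf_counter()
        self.__dict__[attr_name] += toc - tic
        return ret
    return wrapper

# Usage inside a component, e.g. in BaseReader.run():
# predict = self.timing(self.predict, "query_time")
# results = predict(query=query, documents=documents, top_k=top_k_reader)
```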
@@ -94,11 +94,35 @@ document_store = ElasticsearchDocumentStore(host="localhost", username="", passw


```python
from haystack.preprocessor import PreProcessor

# Add evaluation data to Elasticsearch Document Store
# We first delete the custom tutorial indices to not have duplicate elements
# and also split our documents into shorter passages using the PreProcessor
preprocessor = PreProcessor(
    split_length=500,
    split_overlap=0,
    split_respect_sentence_boundary=False,
    clean_empty_lines=False,
    clean_whitespace=False
)
document_store.delete_all_documents(index=doc_index)
document_store.delete_all_documents(index=label_index)
document_store.add_eval_data(filename="../data/nq/nq_dev_subset_v2.json", doc_index=doc_index, label_index=label_index)
document_store.add_eval_data(
    filename="../data/nq/nq_dev_subset_v2.json",
    doc_index=doc_index,
    label_index=label_index,
    preprocessor=preprocessor
)

# Let's prepare the labels that we need for the retriever and the reader
labels = document_store.get_all_labels_aggregated(index=label_index)
q_to_l_dict = {
    l.question: {
        "retriever": l,
        "reader": l
    } for l in labels
}
```

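Each entry of q_to_l_dict holds an aggregated MultiLabel. For orientation, this sketch lists the attributes that the evaluation nodes introduced below read from it:

```python
# Illustrative only: inspect one aggregated label
question, label_dict = next(iter(q_to_l_dict.items()))
retriever_label = label_dict["retriever"]
print(retriever_label.no_answer)              # True for is_impossible samples
print(retriever_label.multiple_answers)       # gold answer strings for this question
print(retriever_label.multiple_document_ids)  # ids of the documents containing the answers
```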
## Initialize components of QA-System
@@ -129,18 +153,21 @@ retriever = ElasticsearchRetriever(document_store=document_store)
# Initialize Reader
from haystack.reader.farm import FARMReader

reader = FARMReader("deepset/roberta-base-squad2", top_k_per_candidate=4)
reader = FARMReader("deepset/roberta-base-squad2", top_k_per_candidate=4, return_no_answer=True)

```


```python
# Initialize Finder which sticks together Reader and Retriever
from haystack.finder import Finder
from haystack.eval import EvalReader, EvalRetriever

finder = Finder(reader, retriever)
# Here we initialize the nodes that perform evaluation
eval_retriever = EvalRetriever()
eval_reader = EvalReader()
```

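Both evaluation nodes accept a few options, defined in haystack/eval.py further down in this commit. For example:

```python
# open_domain=True counts a document as correctly retrieved if it contains the answer string;
# debug=True keeps a per-query record in eval_retriever.log / eval_reader.log
eval_retriever = EvalRetriever(open_domain=True, debug=True)
eval_reader = EvalReader(skip_incorrect_retrieval=True, open_domain=True)
```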
## Evaluation of Retriever
Here we evaluate only the retriever, based on whether the gold_label document is retrieved.


```python
@@ -154,6 +181,9 @@ print("Retriever Mean Avg Precision:", retriever_eval_results["map"])
```

## Evaluation of Reader
Here we evaluate only the reader in a closed domain fashion i.e. the reader is given one query
and one document and metrics are calculated on whether the right position in this text is selected by
the model as the answer span (i.e. SQuAD style)


```python
@@ -170,16 +200,50 @@ print("Reader Exact Match:", reader_eval_results["EM"])
print("Reader F1-Score:", reader_eval_results["f1"])
```

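The evaluation calls elided by the diff hunks above correspond to the stand-alone retriever and reader evaluation; the tutorial script further down in this commit (the tutorial5_evaluation() function) runs them roughly like this:

```python
from farm.utils import initialize_device_settings

device, n_gpu = initialize_device_settings(use_cuda=True)

# Closed domain retriever evaluation: was the gold document retrieved?
retriever_eval_results = retriever.eval(top_k=10, label_index=label_index, doc_index=doc_index)
print("Retriever Mean Avg Precision:", retriever_eval_results["map"])

# Closed domain (SQuAD-style) reader evaluation
reader_eval_results = reader.eval(document_store=document_store, device=device,
                                  label_index=label_index, doc_index=doc_index)
print("Reader Exact Match:", reader_eval_results["EM"])
print("Reader F1-Score:", reader_eval_results["f1"])
```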
## Evaluation of Finder
## Evaluation of Retriever and Reader (Open Domain)
Here we evaluate retriever and reader in open domain fashion i.e. a document is considered
correctly retrieved if it contains the answer string within it. The reader is evaluated based purely on the
predicted string, regardless of which document this came from and the position of the extracted span.


```python
# Evaluate combination of Reader and Retriever through Finder
finder_eval_results = finder.eval(top_k_retriever=1, top_k_reader=10, label_index=label_index, doc_index=doc_index)
finder.print_eval_results(finder_eval_results)
from haystack import Pipeline

# Here is the pipeline definition
p = Pipeline()
p.add_node(component=retriever, name="ESRetriever", inputs=["Query"])
p.add_node(component=eval_retriever, name="EvalRetriever", inputs=["ESRetriever"])
p.add_node(component=reader, name="QAReader", inputs=["EvalRetriever"])
p.add_node(component=eval_reader, name="EvalReader", inputs=["QAReader"])
results = []
```

```python
# This is how to run the pipeline
for q, l in q_to_l_dict.items():
    res = p.run(
        query=q,
        top_k_retriever=10,
        labels=l,
        top_k_reader=10,
        index=doc_index,
    )
    results.append(res)
```

```python
# When we have run evaluation using the pipeline, we can print the results
n_queries = len(labels)
eval_retriever.print()
print()
retriever.print_time()
print()
eval_reader.print(mode="reader")
print()
reader.print_time()
print()
eval_reader.print(mode="pipeline")

```

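For reference, the "pipeline" mode of eval_reader.print() combines answerable and no_answer queries. The numbers it reports are computed as follows in EvalReader (haystack/eval.py, below):

```python
# As computed in EvalReader.print(mode="pipeline")
pipeline_top_1_em = (top_1_em_count + top_1_no_answer_count) / query_count
pipeline_top_k_em = (top_k_em_count + no_answer_count) / query_count
pipeline_top_1_f1 = (top_1_f1_sum + top_1_no_answer_count) / query_count
pipeline_top_k_f1 = (top_k_f1_sum + no_answer_count) / query_count
```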
@ -144,7 +144,7 @@ class BaseDocumentStore(BaseComponent):
|
||||
|
||||
def add_eval_data(self, filename: str, doc_index: str = "eval_document", label_index: str = "label",
|
||||
batch_size: Optional[int] = None, preprocessor: Optional[PreProcessor] = None,
|
||||
max_docs: Union[int, bool] = None):
|
||||
max_docs: Union[int, bool] = None, open_domain: bool = False):
|
||||
"""
|
||||
Adds a SQuAD-formatted file to the DocumentStore in order to be able to perform evaluation on it.
|
||||
If a jsonl file and a batch_size is passed to the function, documents are loaded batchwise
|
||||
@ -161,6 +161,8 @@ class BaseDocumentStore(BaseComponent):
|
||||
When set to None (default) preprocessing is disabled.
|
||||
:param max_docs: Optional number of documents that will be loaded.
|
||||
When set to None (default) all available eval documents are used.
|
||||
:param open_domain: Set this to True if your file is an open domain dataset where two different answers to the
|
||||
same question might be found in different contexts.
|
||||
|
||||
"""
|
||||
# TODO improve support for PreProcessor when adding eval data
|
||||
@ -182,7 +184,7 @@ class BaseDocumentStore(BaseComponent):
|
||||
file_path = Path(filename)
|
||||
if file_path.suffix == ".json":
|
||||
if batch_size is None:
|
||||
docs, labels = eval_data_from_json(filename, max_docs=max_docs, preprocessor=preprocessor)
|
||||
docs, labels = eval_data_from_json(filename, max_docs=max_docs, preprocessor=preprocessor, open_domain=open_domain)
|
||||
self.write_documents(docs, index=doc_index)
|
||||
self.write_labels(labels, index=label_index)
|
||||
else:
|
||||
@ -190,10 +192,10 @@ class BaseDocumentStore(BaseComponent):
|
||||
logger.info(f"Adding evaluation data batch-wise is not compatible with json-formatted SQuAD files. "
|
||||
f"Converting json to jsonl to: {jsonl_filename}")
|
||||
squad_json_to_jsonl(filename, jsonl_filename)
|
||||
self.add_eval_data(jsonl_filename, doc_index, label_index, batch_size)
|
||||
self.add_eval_data(jsonl_filename, doc_index, label_index, batch_size, open_domain=open_domain)
|
||||
|
||||
elif file_path.suffix == ".jsonl":
|
||||
for docs, labels in eval_data_from_jsonl(filename, batch_size, max_docs=max_docs, preprocessor=preprocessor):
|
||||
for docs, labels in eval_data_from_jsonl(filename, batch_size, max_docs=max_docs, preprocessor=preprocessor, open_domain=open_domain):
|
||||
if docs:
|
||||
self.write_documents(docs, index=doc_index)
|
||||
if labels:
|
||||
|
||||
255  haystack/eval.py
@ -1,7 +1,247 @@
|
||||
from typing import List, Tuple, Dict, Any
|
||||
import logging
|
||||
|
||||
from haystack import MultiLabel
|
||||
|
||||
from farm.evaluation.squad_evaluation import compute_f1 as calculate_f1_str
|
||||
from farm.evaluation.squad_evaluation import compute_exact as calculate_em_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EvalRetriever:
|
||||
"""
|
||||
This is a pipeline node that should be placed after a Retriever in order to assess its performance. Performance
|
||||
metrics are stored in this class and updated as each sample passes through it. To view the results of the evaluation,
|
||||
call EvalRetriever.print()
|
||||
"""
|
||||
def __init__(self, debug=False, open_domain=True):
|
||||
"""
|
||||
:param open_domain: When True, a document is considered correctly retrieved so long as the answer string can be found within it.
|
||||
When False, correct retrieval is evaluated based on document_id.
|
||||
:type open_domain: bool
|
||||
:param debug: When True, a record of each sample and its evaluation will be stored in EvalRetriever.log
|
||||
:type debug: bool
|
||||
"""
|
||||
self.outgoing_edges = 1
|
||||
self.init_counts()
|
||||
self.no_answer_warning = False
|
||||
self.debug = debug
|
||||
self.log = []
|
||||
self.open_domain = open_domain
|
||||
|
||||
def init_counts(self):
|
||||
self.correct_retrieval_count = 0
|
||||
self.query_count = 0
|
||||
self.has_answer_count = 0
|
||||
self.has_answer_correct = 0
|
||||
self.has_answer_recall = 0
|
||||
self.no_answer_count = 0
|
||||
self.recall = 0.0
|
||||
|
||||
def run(self, documents, labels: dict, **kwargs):
|
||||
"""Run this node on one sample and its labels"""
|
||||
self.query_count += 1
|
||||
retriever_labels = labels["retriever"]
|
||||
# TODO retriever_labels is currently a Multilabel object but should eventually be a RetrieverLabel object
|
||||
# If this sample is impossible to answer and expects a no_answer response
|
||||
if retriever_labels.no_answer:
|
||||
self.no_answer_count += 1
|
||||
correct_retrieval = 1
|
||||
if not self.no_answer_warning:
|
||||
self.no_answer_warning = True
|
||||
logger.warning("There seem to be empty string labels in the dataset suggesting that there "
|
||||
"are samples with is_impossible=True. "
|
||||
"Retrieval of these samples is always treated as correct.")
|
||||
# If there are answer span annotations in the labels
|
||||
else:
|
||||
self.has_answer_count += 1
|
||||
correct_retrieval = self.is_correctly_retrieved(retriever_labels, documents)
|
||||
self.has_answer_correct += int(correct_retrieval)
|
||||
|
||||
self.correct_retrieval_count += correct_retrieval
|
||||
self.recall = self.correct_retrieval_count / self.query_count
|
||||
self.has_answer_recall = self.has_answer_correct / self.has_answer_count
|
||||
if self.debug:
|
||||
self.log.append({"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, **kwargs})
|
||||
return {"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, **kwargs}, "output_1"
|
||||
|
||||
def is_correctly_retrieved(self, retriever_labels, predictions):
|
||||
if self.open_domain:
|
||||
for label in retriever_labels.multiple_answers:
|
||||
for p in predictions:
|
||||
if label.lower() in p.text.lower():
|
||||
return True
|
||||
return False
|
||||
else:
|
||||
prediction_ids = [p.id for p in predictions]
|
||||
label_ids = retriever_labels.multiple_document_ids
|
||||
for l in label_ids:
|
||||
if l in prediction_ids:
|
||||
return True
|
||||
return False
|
||||
|
||||
def print(self):
|
||||
"""Print the evaluation results"""
|
||||
print("Retriever")
|
||||
print("-----------------")
|
||||
if self.no_answer_count:
|
||||
print(
|
||||
f"has_answer recall: {self.has_answer_recall} ({self.has_answer_correct}/{self.has_answer_count})")
|
||||
print(
|
||||
f"no_answer recall: 1.00 ({self.no_answer_count}/{self.no_answer_count}) (no_answer samples are always treated as correctly retrieved)")
|
||||
print(f"recall: {self.recall} ({self.correct_retrieval_count} / {self.query_count})")
|
||||
|
||||
|
||||
class EvalReader:
|
||||
"""
|
||||
This is a pipeline node that should be placed after a Reader in order to assess the performance of the Reader
|
||||
individually or to assess the extractive QA performance of the whole pipeline. Performance metrics are stored in
|
||||
this class and updated as each sample passes through it. To view the results of the evaluation, call EvalReader.print()
|
||||
"""
|
||||
|
||||
def __init__(self, skip_incorrect_retrieval=True, open_domain=True, debug=False):
|
||||
"""
|
||||
:param skip_incorrect_retrieval: When set to True, this eval will ignore the cases where the retriever returned no correct documents
|
||||
:type skip_incorrect_retrieval: bool
|
||||
:param open_domain: When True, extracted answers are evaluated purely on string similarity rather than the position of the extracted answer
|
||||
:type open_domain: bool
|
||||
:param debug: When True, a record of each sample and its evaluation will be stored in EvalReader.log
|
||||
:type debug: bool
|
||||
"""
|
||||
self.outgoing_edges = 1
|
||||
self.init_counts()
|
||||
self.log = []
|
||||
self.debug = debug
|
||||
self.skip_incorrect_retrieval = skip_incorrect_retrieval
|
||||
self.open_domain = open_domain
|
||||
|
||||
def init_counts(self):
|
||||
self.query_count = 0
|
||||
self.correct_retrieval_count = 0
|
||||
self.no_answer_count = 0
|
||||
self.has_answer_count = 0
|
||||
self.top_1_no_answer_count = 0
|
||||
self.top_1_em_count = 0
|
||||
self.top_k_em_count = 0
|
||||
self.top_1_f1_sum = 0
|
||||
self.top_k_f1_sum = 0
|
||||
self.top_1_no_answer = 0
|
||||
self.top_1_em = 0.0
|
||||
self.top_k_em = 0.0
|
||||
self.top_1_f1 = 0.0
|
||||
self.top_k_f1 = 0.0
|
||||
|
||||
def run(self, labels, answers, **kwargs):
|
||||
"""Run this node on one sample and its labels"""
|
||||
self.query_count += 1
|
||||
predictions = answers
|
||||
skip = self.skip_incorrect_retrieval and not kwargs.get("correct_retrieval")
|
||||
if predictions and not skip:
|
||||
self.correct_retrieval_count += 1
|
||||
multi_labels = labels["reader"]
|
||||
# If this sample is impossible to answer and expects a no_answer response
|
||||
if multi_labels.no_answer:
|
||||
self.no_answer_count += 1
|
||||
if predictions[0]["answer"] is None:
|
||||
self.top_1_no_answer_count += 1
|
||||
if self.debug:
|
||||
self.log.append({"predictions": predictions,
|
||||
"gold_labels": multi_labels,
|
||||
"top_1_no_answer": int(predictions[0] == ""),
|
||||
})
|
||||
self.update_no_answer_metrics()
|
||||
# If there are answer span annotations in the labels
|
||||
else:
|
||||
self.has_answer_count += 1
|
||||
predictions = [p for p in predictions if p["answer"]]
|
||||
top_1_em, top_1_f1, top_k_em, top_k_f1 = self.evaluate_extraction(multi_labels, predictions)
|
||||
if self.debug:
|
||||
self.log.append({"predictions": predictions,
|
||||
"gold_labels": multi_labels,
|
||||
"top_k_f1": top_k_f1,
|
||||
"top_k_em": top_k_em
|
||||
})
|
||||
|
||||
self.top_1_em_count += top_1_em
|
||||
self.top_1_f1_sum += top_1_f1
|
||||
self.top_k_em_count += top_k_em
|
||||
self.top_k_f1_sum += top_k_f1
|
||||
self.update_has_answer_metrics()
|
||||
return {**kwargs}, "output_1"
|
||||
|
||||
def evaluate_extraction(self, gold_labels, predictions):
|
||||
if self.open_domain:
|
||||
gold_labels_list = gold_labels.multiple_answers
|
||||
predictions_str = [p["answer"] for p in predictions]
|
||||
top_1_em = calculate_em_str_multi(gold_labels_list, predictions_str[0])
|
||||
top_1_f1 = calculate_f1_str_multi(gold_labels_list, predictions_str[0])
|
||||
top_k_em = max([calculate_em_str_multi(gold_labels_list, p) for p in predictions_str])
|
||||
top_k_f1 = max([calculate_f1_str_multi(gold_labels_list, p) for p in predictions_str])
|
||||
else:
|
||||
logger.error("Closed Domain Reader Evaluation not yet implemented")
|
||||
return 0,0,0,0
|
||||
return top_1_em, top_1_f1, top_k_em, top_k_f1
|
||||
|
||||
def update_has_answer_metrics(self):
|
||||
self.top_1_em = self.top_1_em_count / self.has_answer_count
|
||||
self.top_k_em = self.top_k_em_count / self.has_answer_count
|
||||
self.top_1_f1 = self.top_1_f1_sum / self.has_answer_count
|
||||
self.top_k_f1 = self.top_k_f1_sum / self.has_answer_count
|
||||
|
||||
def update_no_answer_metrics(self):
|
||||
self.top_1_no_answer = self.top_1_no_answer_count / self.no_answer_count
|
||||
|
||||
def print(self, mode):
|
||||
"""Print the evaluation results"""
|
||||
if mode == "reader":
|
||||
print("Reader")
|
||||
print("-----------------")
|
||||
# print(f"answer in retrieved docs: {correct_retrieval}")
|
||||
print(f"has answer queries: {self.has_answer_count}")
|
||||
print(f"top 1 EM: {self.top_1_em}")
|
||||
print(f"top k EM: {self.top_k_em}")
|
||||
print(f"top 1 F1: {self.top_1_f1}")
|
||||
print(f"top k F1: {self.top_k_f1}")
|
||||
if self.no_answer_count:
|
||||
print()
|
||||
print(f"no_answer queries: {self.no_answer_count}")
|
||||
print(f"top 1 no_answer accuracy: {self.top_1_no_answer}")
|
||||
elif mode == "pipeline":
|
||||
print("Pipeline")
|
||||
print("-----------------")
|
||||
|
||||
pipeline_top_1_em = (self.top_1_em_count + self.top_1_no_answer_count) / self.query_count
|
||||
pipeline_top_k_em = (self.top_k_em_count + self.no_answer_count) / self.query_count
|
||||
pipeline_top_1_f1 = (self.top_1_f1_sum + self.top_1_no_answer_count) / self.query_count
|
||||
pipeline_top_k_f1 = (self.top_k_f1_sum + self.no_answer_count) / self.query_count
|
||||
|
||||
print(f"queries: {self.query_count}")
|
||||
print(f"top 1 EM: {pipeline_top_1_em}")
|
||||
print(f"top k EM: {pipeline_top_k_em}")
|
||||
print(f"top 1 F1: {pipeline_top_1_f1}")
|
||||
print(f"top k F1: {pipeline_top_k_f1}")
|
||||
if self.no_answer_count:
|
||||
print(
|
||||
"(top k results are likely inflated since the Reader always returns a no_answer prediction in its top k)"
|
||||
)
|
||||
|
||||
|
||||
def calculate_em_str_multi(gold_labels, prediction):
|
||||
for gold_label in gold_labels:
|
||||
result = calculate_em_str(gold_label, prediction)
|
||||
if result == 1.0:
|
||||
return 1.0
|
||||
return 0.0
|
||||
|
||||
|
||||
def calculate_f1_str_multi(gold_labels, prediction):
|
||||
results = []
|
||||
for gold_label in gold_labels:
|
||||
result = calculate_f1_str(gold_label, prediction)
|
||||
results.append(result)
|
||||
return max(results)
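# A minimal, hypothetical invocation of the two helpers above. The FARM
# squad_evaluation functions they wrap apply SQuAD-style answer normalization
# (lowercasing, stripping articles and punctuation) before comparing, e.g.:
#
#   gold_answers = ["Berlin", "the city of Berlin"]
#   calculate_em_str_multi(gold_answers, "berlin")           # 1.0 after normalization
#   calculate_f1_str_multi(gold_answers, "city of Berlin")   # token-level F1 against the best-matching gold answer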
|
||||
|
||||
|
||||
def calculate_reader_metrics(metric_counts: Dict[str, float], correct_retrievals: int):
|
||||
number_of_has_answer = correct_retrievals - metric_counts["number_of_no_answer"]
|
||||
@ -212,13 +452,14 @@ def _count_exact_match(
|
||||
|
||||
if (gold_span["offset_start"] == predicted_span["offset_start"]) and \
|
||||
(gold_span["offset_end"] == predicted_span["offset_end"]):
|
||||
# top-1 answer
|
||||
if answer_idx == 0:
|
||||
metric_counts["exact_matches_top1"] += 1
|
||||
metric_counts["exact_matches_top1_has_answer"] += 1
|
||||
# top-k answers
|
||||
metric_counts["exact_matches_topk"] += 1
|
||||
metric_counts["exact_matches_topk_has_answer"] += 1
|
||||
if metric_counts:
|
||||
# top-1 answer
|
||||
if answer_idx == 0:
|
||||
metric_counts["exact_matches_top1"] += 1
|
||||
metric_counts["exact_matches_top1_has_answer"] += 1
|
||||
# top-k answers
|
||||
metric_counts["exact_matches_topk"] += 1
|
||||
metric_counts["exact_matches_topk_has_answer"] += 1
|
||||
found_em = True
|
||||
|
||||
return metric_counts, found_em
|
||||
|
||||
99  haystack/pipeline_eval.py  (new file)
@ -0,0 +1,99 @@
|
||||
from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||
from haystack.preprocessor.utils import fetch_archive_from_http
|
||||
from haystack.retriever.sparse import ElasticsearchRetriever
|
||||
from haystack.retriever.dense import DensePassageRetriever
|
||||
from haystack.reader.farm import FARMReader
|
||||
from haystack import Pipeline
|
||||
from farm.utils import initialize_device_settings
|
||||
from haystack.preprocessor import PreProcessor
|
||||
from haystack.eval import EvalReader, EvalRetriever
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LAUNCH_ELASTICSEARCH = True
|
||||
doc_index = "documents"
|
||||
label_index = "labels"
|
||||
top_k_retriever = 10
|
||||
open_domain = False
|
||||
|
||||
def launch_es():
|
||||
logger.info("Starting Elasticsearch ...")
|
||||
status = subprocess.run(
|
||||
['docker run -d -p 9200:9200 -e "discovery.type=single-node" elasticsearch:7.9.2'], shell=True
|
||||
)
|
||||
if status.returncode:
|
||||
logger.warning("Tried to start Elasticsearch through Docker but this failed. "
|
||||
"It is likely that there is already an existing Elasticsearch instance running. ")
|
||||
else:
|
||||
time.sleep(15)
|
||||
|
||||
def main():
|
||||
|
||||
launch_es()
|
||||
|
||||
document_store = ElasticsearchDocumentStore()
|
||||
es_retriever = ElasticsearchRetriever(document_store=document_store)
|
||||
eval_retriever = EvalRetriever(open_domain=open_domain)
|
||||
reader = FARMReader("deepset/roberta-base-squad2", top_k_per_candidate=4, num_processes=1, return_no_answer=True)
|
||||
eval_reader = EvalReader(debug=True, open_domain=open_domain)
|
||||
|
||||
# Download evaluation data, which is a subset of Natural Questions development set containing 50 documents
|
||||
doc_dir = "../data/nq"
|
||||
s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/nq_dev_subset_v2.json.zip"
|
||||
fetch_archive_from_http(url=s3_url, output_dir=doc_dir)
|
||||
|
||||
# Add evaluation data to Elasticsearch document store
|
||||
# We first delete the custom tutorial indices to not have duplicate elements
|
||||
preprocessor = PreProcessor(split_length=500, split_overlap=0, split_respect_sentence_boundary=False, clean_empty_lines=False, clean_whitespace=False)
|
||||
document_store.delete_all_documents(index=doc_index)
|
||||
document_store.delete_all_documents(index=label_index)
|
||||
document_store.add_eval_data(
|
||||
filename="../data/nq/nq_dev_subset_v2.json",
|
||||
doc_index=doc_index,
|
||||
label_index=label_index,
|
||||
preprocessor=preprocessor
|
||||
)
|
||||
labels = document_store.get_all_labels_aggregated(index=label_index)
|
||||
q_to_l_dict = {
|
||||
l.question: {
|
||||
"retriever": l,
|
||||
"reader": l
|
||||
} for l in labels
|
||||
}
|
||||
|
||||
# Here is the pipeline definition
|
||||
p = Pipeline()
|
||||
p.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
|
||||
p.add_node(component=eval_retriever, name="EvalRetriever", inputs=["ESRetriever"])
|
||||
p.add_node(component=reader, name="QAReader", inputs=["EvalRetriever"])
|
||||
p.add_node(component=eval_reader, name="EvalReader", inputs=["QAReader"])
|
||||
|
||||
results = []
|
||||
for i, (q, l) in enumerate(q_to_l_dict.items()):
|
||||
res = p.run(query=q,
|
||||
top_k_retriever=top_k_retriever,
|
||||
labels=l,
|
||||
top_k_reader=10,
|
||||
index=doc_index,
|
||||
# skip_incorrect_retrieval=True
|
||||
)
|
||||
results.append(res)
|
||||
|
||||
eval_retriever.print()
|
||||
print()
|
||||
es_retriever.print_time()
|
||||
print()
|
||||
eval_reader.print(mode="reader")
|
||||
print()
|
||||
reader.print_time()
|
||||
print()
|
||||
eval_reader.print(mode="pipeline")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -22,13 +22,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
def eval_data_from_json(filename: str, max_docs: Union[int, bool] = None, preprocessor: PreProcessor = None) -> Tuple[List[Document], List[Label]]:
|
||||
def eval_data_from_json(filename: str, max_docs: Union[int, bool] = None, preprocessor: PreProcessor = None, open_domain: bool =False) -> Tuple[List[Document], List[Label]]:
|
||||
"""
|
||||
Read Documents + Labels from a SQuAD-style file.
|
||||
Document and Labels can then be indexed to the DocumentStore and be used for evaluation.
|
||||
|
||||
:param filename: Path to file in SQuAD format
|
||||
:param max_docs: This sets the number of documents that will be loaded. By default, this is set to None, thus reading in all available eval documents.
|
||||
:param max_docs: This sets the number of documents that will be loaded. By default, this is set to None, thus reading in all available eval documents.
|
||||
:param open_domain: Set this to True if your file is an open domain dataset where two different answers to the same question might be found in different contexts.
|
||||
:return: (List of Documents, List of Labels)
|
||||
"""
|
||||
|
||||
@ -46,7 +47,10 @@ def eval_data_from_json(filename: str, max_docs: Union[int, bool] = None, prepro
|
||||
if len(docs) > max_docs:
|
||||
break
|
||||
# Extracting paragraphs and their labels from a SQuAD document dict
|
||||
cur_docs, cur_labels, cur_problematic_ids = _extract_docs_and_labels_from_dict(document, preprocessor)
|
||||
cur_docs, cur_labels, cur_problematic_ids = _extract_docs_and_labels_from_dict(
|
||||
document,
|
||||
preprocessor
|
||||
)
|
||||
docs.extend(cur_docs)
|
||||
labels.extend(cur_labels)
|
||||
problematic_ids.extend(cur_problematic_ids)
|
||||
@ -57,7 +61,8 @@ def eval_data_from_json(filename: str, max_docs: Union[int, bool] = None, prepro
|
||||
|
||||
|
||||
def eval_data_from_jsonl(filename: str, batch_size: Optional[int] = None,
|
||||
max_docs: Union[int, bool] = None, preprocessor: PreProcessor = None) -> Generator[Tuple[List[Document], List[Label]], None, None]:
|
||||
max_docs: Union[int, bool] = None, preprocessor: PreProcessor = None,
|
||||
open_domain: bool = False) -> Generator[Tuple[List[Document], List[Label]], None, None]:
|
||||
"""
|
||||
Read Documents + Labels from a SQuAD-style file in jsonl format, i.e. one document per line.
|
||||
Document and Labels can then be indexed to the DocumentStore and be used for evaluation.
|
||||
@ -68,6 +73,7 @@ def eval_data_from_jsonl(filename: str, batch_size: Optional[int] = None,
|
||||
|
||||
:param filename: Path to file in SQuAD format
|
||||
:param max_docs: This sets the number of documents that will be loaded. By default, this is set to None, thus reading in all available eval documents.
|
||||
:param open_domain: Set this to True if your file is an open domain dataset where two different answers to the same question might be found in different contexts.
|
||||
:return: (List of Documents, List of Labels)
|
||||
"""
|
||||
|
||||
@ -82,7 +88,7 @@ def eval_data_from_jsonl(filename: str, batch_size: Optional[int] = None,
|
||||
break
|
||||
# Extracting paragraphs and their labels from a SQuAD document dict
|
||||
document_dict = json.loads(document)
|
||||
cur_docs, cur_labels, cur_problematic_ids = _extract_docs_and_labels_from_dict(document_dict, preprocessor)
|
||||
cur_docs, cur_labels, cur_problematic_ids = _extract_docs_and_labels_from_dict(document_dict, preprocessor, open_domain)
|
||||
docs.extend(cur_docs)
|
||||
labels.extend(cur_labels)
|
||||
problematic_ids.extend(cur_problematic_ids)
|
||||
@ -100,7 +106,8 @@ def eval_data_from_jsonl(filename: str, batch_size: Optional[int] = None,
|
||||
yield docs, labels
|
||||
|
||||
|
||||
def _extract_docs_and_labels_from_dict(document_dict: Dict, preprocessor: PreProcessor = None):
|
||||
def _extract_docs_and_labels_from_dict(document_dict: Dict, preprocessor: PreProcessor = None, open_domain: bool=False):
|
||||
"""Set open_domain to True if you are trying to load open_domain labels (i.e. labels without doc id or start idx)"""
|
||||
docs = []
|
||||
labels = []
|
||||
problematic_ids = []
|
||||
@ -122,7 +129,7 @@ def _extract_docs_and_labels_from_dict(document_dict: Dict, preprocessor: PrePro
|
||||
splits_dicts = preprocessor.process(cur_doc.to_dict())
|
||||
# we need to pull in _split_id into the document id for unique reference in labels
|
||||
# todo: PreProcessor should work on Documents instead of dicts
|
||||
splits = []
|
||||
splits: List[Document] = []
|
||||
offset = 0
|
||||
for d in splits_dicts:
|
||||
id = f"{d['id']}-{d['meta']['_split_id']}"
|
||||
@ -148,25 +155,34 @@ def _extract_docs_and_labels_from_dict(document_dict: Dict, preprocessor: PrePro
|
||||
if not qa.get("is_impossible", False):
|
||||
for answer in qa["answers"]:
|
||||
ans = answer["text"]
|
||||
ans_position = cur_doc.text[answer["answer_start"]:answer["answer_start"]+len(ans)]
|
||||
if ans != ans_position:
|
||||
# do not use answer
|
||||
problematic_ids.append(qa.get("id","missing"))
|
||||
break
|
||||
# find corresponding document or split
|
||||
if len(splits) == 1:
|
||||
cur_id = splits[0].id
|
||||
cur_ans_start = answer["answer_start"]
|
||||
cur_ans_start = None
|
||||
# TODO The following block of code means that answer_start is never calculated
|
||||
# and cur_id is always None for open_domain
|
||||
# This can be rewritten so that this function could try to calculate offsets
|
||||
# and populate id in open_domain mode
|
||||
if open_domain:
|
||||
cur_ans_start = answer.get("answer_start", 0)
|
||||
cur_id = '0'
|
||||
else:
|
||||
for s in splits:
|
||||
# If answer start offset is contained in passage we assign the label to that passage
|
||||
if (answer["answer_start"] >= s.meta["_split_offset"]) and (answer["answer_start"] < (s.meta["_split_offset"] + len(s.text))):
|
||||
cur_id = s.id
|
||||
cur_ans_start = answer["answer_start"] - s.meta["_split_offset"]
|
||||
# If a document is splitting an answer we add the whole answer text to the document
|
||||
if s.text[cur_ans_start:cur_ans_start+len(ans)] != ans:
|
||||
s.text = s.text[:cur_ans_start] + ans
|
||||
break
|
||||
ans_position = cur_doc.text[answer["answer_start"]:answer["answer_start"]+len(ans)]
|
||||
if ans != ans_position:
|
||||
# do not use answer
|
||||
problematic_ids.append(qa.get("id","missing"))
|
||||
break
|
||||
# find corresponding document or split
|
||||
if len(splits) == 1:
|
||||
cur_id = splits[0].id
|
||||
cur_ans_start = answer["answer_start"]
|
||||
else:
|
||||
for s in splits:
|
||||
# If answer start offset is contained in passage we assign the label to that passage
|
||||
if (answer["answer_start"] >= s.meta["_split_offset"]) and (answer["answer_start"] < (s.meta["_split_offset"] + len(s.text))):
|
||||
cur_id = s.id
|
||||
cur_ans_start = answer["answer_start"] - s.meta["_split_offset"]
|
||||
# If a document is splitting an answer we add the whole answer text to the document
|
||||
if s.text[cur_ans_start:cur_ans_start+len(ans)] != ans:
|
||||
s.text = s.text[:cur_ans_start] + ans
|
||||
break
|
||||
label = Label(
|
||||
question=qa["question"],
|
||||
answer=ans,
|
||||
|
||||
@ -3,6 +3,9 @@ from scipy.special import expit
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional, Sequence
|
||||
from functools import wraps
|
||||
from time import perf_counter
|
||||
|
||||
|
||||
from haystack import Document, BaseComponent
|
||||
|
||||
@ -10,6 +13,8 @@ from haystack import Document, BaseComponent
|
||||
class BaseReader(BaseComponent):
|
||||
return_no_answers: bool
|
||||
outgoing_edges = 1
|
||||
query_count = 0
|
||||
query_time = 0
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None):
|
||||
@ -47,8 +52,10 @@ class BaseReader(BaseComponent):
|
||||
return no_ans_prediction, max_no_ans_gap
|
||||
|
||||
def run(self, query: str, documents: List[Document], top_k_reader: Optional[int] = None, **kwargs): # type: ignore
|
||||
self.query_count += 1
|
||||
if documents:
|
||||
results = self.predict(query=query, documents=documents, top_k=top_k_reader)
|
||||
predict = self.timing(self.predict, "query_time")
|
||||
results = predict(query=query, documents=documents, top_k=top_k_reader)
|
||||
else:
|
||||
results = {"answers": [], "query": query}
|
||||
|
||||
@ -62,4 +69,27 @@ class BaseReader(BaseComponent):
|
||||
results.update(**kwargs)
|
||||
return results, "output_1"
|
||||
|
||||
def timing(self, fn, attr_name):
|
||||
"""Wrapper method used to time functions. """
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if attr_name not in self.__dict__:
|
||||
self.__dict__[attr_name] = 0
|
||||
tic = perf_counter()
|
||||
ret = fn(*args, **kwargs)
|
||||
toc = perf_counter()
|
||||
self.__dict__[attr_name] += toc - tic
|
||||
return ret
|
||||
return wrapper
|
||||
|
||||
def print_time(self):
|
||||
print("Reader (Speed)")
|
||||
print("---------------")
|
||||
if not self.query_count:
|
||||
print("No querying performed via Retriever.run()")
|
||||
else:
|
||||
print(f"Queries Performed: {self.query_count}")
|
||||
print(f"Query time: {self.query_time}s")
|
||||
print(f"{self.query_time / self.query_count} seconds per query")
|
||||
|
||||
|
||||
|
||||
@ -99,7 +99,8 @@ class FARMReader(BaseReader):
|
||||
self.inferencer = QAInferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu,
|
||||
task_type="question_answering", max_seq_len=max_seq_len,
|
||||
doc_stride=doc_stride, num_processes=num_processes, revision=model_version,
|
||||
disable_tqdm=not progress_bar)
|
||||
disable_tqdm=not progress_bar,
|
||||
strict=False)
|
||||
self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
|
||||
self.inferencer.model.prediction_heads[0].no_ans_boost = no_ans_boost
|
||||
self.inferencer.model.prediction_heads[0].n_best = top_k_per_candidate + 1 # including possible no_answer
|
||||
|
||||
@ -14,6 +14,11 @@ logger = logging.getLogger(__name__)
|
||||
class BaseRetriever(BaseComponent):
|
||||
document_store: BaseDocumentStore
|
||||
outgoing_edges = 1
|
||||
query_count = 0
|
||||
index_count = 0
|
||||
query_time = 0.0
|
||||
index_time = 0.0
|
||||
retrieve_time = 0.0
|
||||
|
||||
@abstractmethod
|
||||
def retrieve(self, query: str, filters: dict = None, top_k: Optional[int] = None, index: str = None) -> List[Document]:
|
||||
@ -28,16 +33,16 @@ class BaseRetriever(BaseComponent):
|
||||
"""
|
||||
pass
|
||||
|
||||
def timing(self, fn):
|
||||
def timing(self, fn, attr_name):
|
||||
"""Wrapper method used to time functions. """
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if "retrieve_time" not in self.__dict__:
|
||||
self.retrieve_time = 0
|
||||
if attr_name not in self.__dict__:
|
||||
self.__dict__[attr_name] = 0
|
||||
tic = perf_counter()
|
||||
ret = fn(*args, **kwargs)
|
||||
toc = perf_counter()
|
||||
self.retrieve_time += toc - tic
|
||||
self.__dict__[attr_name] += toc - tic
|
||||
return ret
|
||||
return wrapper
|
||||
|
||||
@ -80,7 +85,7 @@ class BaseRetriever(BaseComponent):
|
||||
# Extract all questions for evaluation
|
||||
filters = {"origin": [label_origin]}
|
||||
|
||||
timed_retrieve = self.timing(self.retrieve)
|
||||
timed_retrieve = self.timing(self.retrieve, "retrieve_time")
|
||||
|
||||
labels = self.document_store.get_all_labels_aggregated(index=label_index, filters=filters)
|
||||
|
||||
@ -170,9 +175,13 @@ class BaseRetriever(BaseComponent):
|
||||
|
||||
def run(self, pipeline_type: str, **kwargs): # type: ignore
|
||||
if pipeline_type == "Query":
|
||||
output, stream = self.run_query(**kwargs)
|
||||
self.query_count += 1
|
||||
run_query_timed = self.timing(self.run_query, "query_time")
|
||||
output, stream = run_query_timed(**kwargs)
|
||||
elif pipeline_type == "Indexing":
|
||||
output, stream = self.run_indexing(**kwargs)
|
||||
self.index_count += len(kwargs["documents"])
|
||||
run_indexing = self.timing(self.run_indexing, "index_time")
|
||||
output, stream = run_indexing(**kwargs)
|
||||
else:
|
||||
raise Exception(f"Invalid pipeline_type '{pipeline_type}'.")
|
||||
return output, stream
|
||||
@ -184,7 +193,8 @@ class BaseRetriever(BaseComponent):
|
||||
top_k_retriever: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
documents = self.retrieve(query=query, filters=filters, top_k=top_k_retriever)
|
||||
index = kwargs.get("index", None)
|
||||
documents = self.retrieve(query=query, filters=filters, top_k=top_k_retriever, index=index)
|
||||
document_ids = [doc.id for doc in documents]
|
||||
logger.debug(f"Retrieved documents with IDs: {document_ids}")
|
||||
output = {
|
||||
@ -202,6 +212,21 @@ class BaseRetriever(BaseComponent):
|
||||
embeddings = self.embed_passages(document_objects) # type: ignore
|
||||
for doc, emb in zip(documents, embeddings):
|
||||
doc["embedding"] = emb
|
||||
|
||||
output = {**kwargs, "documents": documents}
|
||||
return output, "output_1"
|
||||
|
||||
def print_time(self):
|
||||
print("Retriever (Speed)")
|
||||
print("---------------")
|
||||
if not self.index_count:
|
||||
print("No indexing performed via Retriever.run()")
|
||||
else:
|
||||
print(f"Documents indexed: {self.index_count}")
|
||||
print(f"Index time: {self.index_time}s")
|
||||
print(f"{self.query_time / self.query_count} seconds per document")
|
||||
if not self.query_count:
|
||||
print("No querying performed via Retriever.run()")
|
||||
else:
|
||||
print(f"Queries Performed: {self.query_count}")
|
||||
print(f"Query time: {self.query_time}s")
|
||||
print(f"{self.query_time / self.query_count} seconds per query")
|
||||
|
||||
@ -6,11 +6,29 @@ import pprint
|
||||
import pandas as pd
|
||||
from typing import Dict, Any, List
|
||||
from haystack.document_store.sql import DocumentORM
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def launch_es():
|
||||
# Start an Elasticsearch server
|
||||
# You can start Elasticsearch on your local machine instance using Docker. If Docker is not readily available in
|
||||
# your environment (eg., in Colab notebooks), then you can manually download and execute Elasticsearch from source.
|
||||
|
||||
logger.info("Starting Elasticsearch ...")
|
||||
status = subprocess.run(
|
||||
['docker run -d -p 9200:9200 -e "discovery.type=single-node" elasticsearch:7.9.2'], shell=True
|
||||
)
|
||||
if status.returncode:
|
||||
logger.warning("Tried to start Elasticsearch through Docker but this failed. "
|
||||
"It is likely that there is already an existing Elasticsearch instance running. ")
|
||||
else:
|
||||
time.sleep(15)
|
||||
|
||||
|
||||
def print_answers(results: dict, details: str = "all"):
|
||||
answers = results["answers"]
|
||||
pp = pprint.PrettyPrinter(indent=4)
|
||||
|
||||
@ -2,6 +2,9 @@ import pytest
|
||||
from haystack.document_store.base import BaseDocumentStore
|
||||
from haystack.preprocessor.preprocessor import PreProcessor
|
||||
from haystack.finder import Finder
|
||||
from haystack.eval import EvalReader, EvalRetriever
|
||||
from haystack import Pipeline
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [None, 20])
|
||||
@ -91,6 +94,48 @@ def test_eval_elastic_retriever(document_store: BaseDocumentStore, open_domain,
|
||||
assert results["map"] == 1.0
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
|
||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||
@pytest.mark.parametrize("retriever", ["elasticsearch"], indirect=True)
|
||||
def test_eval_pipeline(document_store: BaseDocumentStore, reader, retriever):
|
||||
# add eval data (SQUAD format)
|
||||
document_store.add_eval_data(
|
||||
filename="samples/squad/tiny.json",
|
||||
doc_index="haystack_test_eval_document",
|
||||
label_index="haystack_test_feedback",
|
||||
)
|
||||
|
||||
labels = document_store.get_all_labels_aggregated(index="haystack_test_feedback")
|
||||
q_to_l_dict = {
|
||||
l.question: {
|
||||
"retriever": l,
|
||||
"reader": l
|
||||
} for l in labels
|
||||
}
|
||||
|
||||
eval_retriever = EvalRetriever()
|
||||
eval_reader = EvalReader()
|
||||
|
||||
assert document_store.get_document_count(index="haystack_test_eval_document") == 2
|
||||
p = Pipeline()
|
||||
p.add_node(component=retriever, name="ESRetriever", inputs=["Query"])
|
||||
p.add_node(component=eval_retriever, name="EvalRetriever", inputs=["ESRetriever"])
|
||||
p.add_node(component=reader, name="QAReader", inputs=["EvalRetriever"])
|
||||
p.add_node(component=eval_reader, name="EvalReader", inputs=["QAReader"])
|
||||
for q, l in q_to_l_dict.items():
|
||||
res = p.run(
|
||||
query=q,
|
||||
top_k_retriever=10,
|
||||
labels=l,
|
||||
top_k_reader=10,
|
||||
index="haystack_test_eval_document",
|
||||
)
|
||||
assert eval_retriever.recall == 1.0
|
||||
assert eval_reader.top_k_f1 == 0.7
|
||||
assert eval_reader.top_k_em == 0.5
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("document_store", ["elasticsearch"], indirect=True)
|
||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||
|
||||
@ -201,17 +201,40 @@
|
||||
"id": "bRFsQUAJOhu_",
|
||||
"outputId": "56b84800-c524-4418-9664-e2720b66a1af",
|
||||
"pycharm": {
|
||||
"is_executing": true,
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from haystack.preprocessor import PreProcessor\n",
|
||||
"\n",
|
||||
"# Add evaluation data to Elasticsearch Document Store\n",
|
||||
"# We first delete the custom tutorial indices to not have duplicate elements\n",
|
||||
"# and also split our documents into shorter passages using the PreProcessor\n",
|
||||
"preprocessor = PreProcessor(\n",
|
||||
" split_length=500,\n",
|
||||
" split_overlap=0,\n",
|
||||
" split_respect_sentence_boundary=False,\n",
|
||||
" clean_empty_lines=False,\n",
|
||||
" clean_whitespace=False\n",
|
||||
")\n",
|
||||
"document_store.delete_all_documents(index=doc_index)\n",
|
||||
"document_store.delete_all_documents(index=label_index)\n",
|
||||
"document_store.add_eval_data(filename=\"../data/nq/nq_dev_subset_v2.json\", doc_index=doc_index, label_index=label_index)"
|
||||
"document_store.add_eval_data(\n",
|
||||
" filename=\"../data/nq/nq_dev_subset_v2.json\",\n",
|
||||
" doc_index=doc_index,\n",
|
||||
" label_index=label_index,\n",
|
||||
" preprocessor=preprocessor\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Let's prepare the labels that we need for the retriever and the reader\n",
|
||||
"labels = document_store.get_all_labels_aggregated(index=label_index)\n",
|
||||
"q_to_l_dict = {\n",
|
||||
" l.question: {\n",
|
||||
" \"retriever\": l,\n",
|
||||
" \"reader\": l\n",
|
||||
" } for l in labels\n",
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -321,7 +344,6 @@
|
||||
"id": "cW3Ypn_gOhvK",
|
||||
"outputId": "89ad5598-1017-499f-c986-72bba2a3a6cb",
|
||||
"pycharm": {
|
||||
"is_executing": true,
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
@ -330,28 +352,26 @@
|
||||
"# Initialize Reader\n",
|
||||
"from haystack.reader.farm import FARMReader\n",
|
||||
"\n",
|
||||
"reader = FARMReader(\"deepset/roberta-base-squad2\", top_k_per_candidate=4)"
|
||||
"reader = FARMReader(\"deepset/roberta-base-squad2\", top_k_per_candidate=4, return_no_answer=True)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "gOs7qy4xOhvO",
|
||||
"pycharm": {
|
||||
"is_executing": true,
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Initialize Finder which sticks together Reader and Retriever\n",
|
||||
"from haystack.finder import Finder\n",
|
||||
"from haystack.eval import EvalReader, EvalRetriever\n",
|
||||
"\n",
|
||||
"finder = Finder(reader, retriever)"
|
||||
]
|
||||
"# Here we initialize the nodes that perform evaluation\n",
|
||||
"eval_retriever = EvalRetriever()\n",
|
||||
"eval_reader = EvalReader()"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
@ -363,7 +383,8 @@
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Evaluation of Retriever"
|
||||
"## Evaluation of Retriever\n",
|
||||
"Here we evaluate only the retriever, based on whether the gold_label document is retrieved."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -378,7 +399,6 @@
|
||||
"id": "YzvLhnx3OhvS",
|
||||
"outputId": "1d45f072-0ae0-4864-8ccc-aa12303a8d04",
|
||||
"pycharm": {
|
||||
"is_executing": true,
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
@ -403,7 +423,10 @@
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Evaluation of Reader"
|
||||
"## Evaluation of Reader\n",
|
||||
"Here we evaluate only the reader in a closed domain fashion i.e. the reader is given one query\n",
|
||||
"and one document and metrics are calculated on whether the right position in this text is selected by\n",
|
||||
"the model as the answer span (i.e. SQuAD style)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -418,7 +441,6 @@
|
||||
"id": "Lgsgf4KaOhvY",
|
||||
"outputId": "24d3755e-bf2e-4396-f1a2-59c925cc54d3",
|
||||
"pycharm": {
|
||||
"is_executing": true,
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
@ -447,7 +469,10 @@
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Evaluation of Finder"
|
||||
"## Evaluation of Retriever and Reader (Open Domain)\n",
|
||||
"Here we evaluate retriever and reader in open domain fashion i.e. a document is considered\n",
|
||||
"correctly retrieved if it contains the answer string within it. The reader is evaluated based purely on the\n",
|
||||
"predicted string, regardless of which document this came from and the position of the extracted span."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -462,31 +487,68 @@
|
||||
"id": "yLpMHAexOhvd",
|
||||
"outputId": "fd74be7d-5c8e-4eb9-a653-062427b74347",
|
||||
"pycharm": {
|
||||
"is_executing": true,
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Evaluate combination of Reader and Retriever through Finder\n",
|
||||
"finder_eval_results = finder.eval(top_k_retriever=1, top_k_reader=10, label_index=label_index, doc_index=doc_index)\n",
|
||||
"finder.print_eval_results(finder_eval_results)"
|
||||
"from haystack import Pipeline\n",
|
||||
"\n",
|
||||
"# Here is the pipeline definition\n",
|
||||
"p = Pipeline()\n",
|
||||
"p.add_node(component=retriever, name=\"ESRetriever\", inputs=[\"Query\"])\n",
|
||||
"p.add_node(component=eval_retriever, name=\"EvalRetriever\", inputs=[\"ESRetriever\"])\n",
|
||||
"p.add_node(component=reader, name=\"QAReader\", inputs=[\"EvalRetriever\"])\n",
|
||||
"p.add_node(component=eval_reader, name=\"EvalReader\", inputs=[\"QAReader\"])\n",
|
||||
"results = []"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# This is how to run the pipeline\n",
|
||||
"for q, l in q_to_l_dict.items():\n",
|
||||
" res = p.run(\n",
|
||||
" query=q,\n",
|
||||
" top_k_retriever=10,\n",
|
||||
" labels=l,\n",
|
||||
" top_k_reader=10,\n",
|
||||
" index=doc_index,\n",
|
||||
" )\n",
|
||||
" results.append(res)"
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "DD57b_LkOhvg",
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"is_executing": true,
|
||||
"name": "#%%\n"
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"# When we have run evaluation using the pipeline, we can print the results\n",
|
||||
"n_queries = len(labels)\n",
|
||||
"eval_retriever.print()\n",
|
||||
"print()\n",
|
||||
"retriever.print_time()\n",
|
||||
"print()\n",
|
||||
"eval_reader.print(mode=\"reader\")\n",
|
||||
"print()\n",
|
||||
"reader.print_time()\n",
|
||||
"print()\n",
|
||||
"eval_reader.print(mode=\"pipeline\")\n"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
@ -2,46 +2,44 @@ from haystack.document_store.elasticsearch import ElasticsearchDocumentStore
|
||||
from haystack.preprocessor.utils import fetch_archive_from_http
|
||||
from haystack.retriever.sparse import ElasticsearchRetriever
|
||||
from haystack.retriever.dense import DensePassageRetriever
|
||||
from haystack.eval import EvalReader, EvalRetriever
|
||||
from haystack.reader.farm import FARMReader
|
||||
from haystack.finder import Finder
|
||||
from haystack.preprocessor import PreProcessor
|
||||
from haystack.utils import launch_es
|
||||
from haystack import Pipeline
|
||||
|
||||
from farm.utils import initialize_device_settings
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def tutorial5_evaluation():
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
##############################################
|
||||
# Settings
|
||||
##############################################
|
||||
LAUNCH_ELASTICSEARCH = True
|
||||
|
||||
eval_retriever_only = True
|
||||
eval_reader_only = False
|
||||
eval_both = False
|
||||
# Choose from Evaluation style from ['retriever_closed', 'reader_closed', 'retriever_reader_open']
|
||||
# 'retriever_closed' - evaluates only the retriever, based on whether the gold_label document is retrieved.
|
||||
# 'reader_closed' - evaluates only the reader in a closed domain fashion i.e. the reader is given one query
|
||||
# and one document and metrics are calculated on whether the right position in this text is selected by
|
||||
# the model as the answer span (i.e. SQuAD style)
|
||||
# 'retriever_reader_open' - evaluates retriever and reader in open domain fashion i.e. a document is considered
|
||||
# correctly retrieved if it contains the answer string within it. The reader is evaluated based purely on the
|
||||
# predicted string, regardless of which document this came from and the position of the extracted span.
|
||||
style = "retriever_reader_open"
|
||||
|
||||
# make sure these indices do not collide with existing ones, the indices will be wiped clean before data is inserted
|
||||
doc_index = "tutorial5_docs"
|
||||
label_index = "tutorial5_labels"
|
||||
|
||||
##############################################
|
||||
# Code
|
||||
##############################################
|
||||
launch_es()
|
||||
device, n_gpu = initialize_device_settings(use_cuda=True)
|
||||
# Start an Elasticsearch server
|
||||
# You can start Elasticsearch on your local machine instance using Docker. If Docker is not readily available in
|
||||
# your environment (eg., in Colab notebooks), then you can manually download and execute Elasticsearch from source.
|
||||
if LAUNCH_ELASTICSEARCH:
|
||||
logging.info("Starting Elasticsearch ...")
|
||||
status = subprocess.run(
|
||||
['docker run -d -p 9200:9200 -e "discovery.type=single-node" elasticsearch:7.9.2'], shell=True
|
||||
)
|
||||
if status.returncode:
|
||||
raise Exception("Failed to launch Elasticsearch. If you want to connect to an existing Elasticsearch instance"
|
||||
"then set LAUNCH_ELASTICSEARCH in the script to False.")
|
||||
time.sleep(30)
|
||||
|
||||
# Download evaluation data, which is a subset of Natural Questions development set containing 50 documents
|
||||
doc_dir = "../data/nq"
|
||||
@ -49,16 +47,39 @@ def tutorial5_evaluation():
|
||||
fetch_archive_from_http(url=s3_url, output_dir=doc_dir)
|
||||
|
||||
# Connect to Elasticsearch
|
||||
document_store = ElasticsearchDocumentStore(host="localhost", username="", password="", index="document",
|
||||
create_index=False, embedding_field="emb",
|
||||
embedding_dim=768, excluded_meta_data=["emb"])
|
||||
|
||||
document_store = ElasticsearchDocumentStore(
|
||||
host="localhost", username="", password="", index="document",
|
||||
create_index=False, embedding_field="emb",
|
||||
embedding_dim=768, excluded_meta_data=["emb"]
|
||||
)
|
||||
|
||||
# Add evaluation data to Elasticsearch document store
|
||||
# We first delete the custom tutorial indices to not have duplicate elements
|
||||
# and also split our documents into shorter passages using the PreProcessor
|
||||
preprocessor = PreProcessor(
|
||||
split_length=500,
|
||||
split_overlap=0,
|
||||
split_respect_sentence_boundary=False,
|
||||
clean_empty_lines=False,
|
||||
clean_whitespace=False
|
||||
)
|
||||
document_store.delete_all_documents(index=doc_index)
|
||||
document_store.delete_all_documents(index=label_index)
|
||||
document_store.add_eval_data(filename="../data/nq/nq_dev_subset_v2.json", doc_index=doc_index, label_index=label_index)
|
||||
document_store.add_eval_data(
|
||||
filename="../data/nq/nq_dev_subset_v2.json",
|
||||
doc_index=doc_index,
|
||||
label_index=label_index,
|
||||
preprocessor=preprocessor
|
||||
)
|
||||
|
||||
# Let's prepare the labels that we need for the retriever and the reader
|
||||
labels = document_store.get_all_labels_aggregated(index=label_index)
|
||||
q_to_l_dict = {
|
||||
l.question: {
|
||||
"retriever": l,
|
||||
"reader": l
|
||||
} for l in labels
|
||||
}
|
||||
|
||||
# Initialize Retriever
|
||||
retriever = ElasticsearchRetriever(document_store=document_store)
|
||||
@ -68,23 +89,27 @@ def tutorial5_evaluation():
|
||||
# Here, for nq_dev_subset_v2.json we have avg. num of tokens = 5220(!).
|
||||
# DPR still outperforms Elastic's BM25 by a small margin here.
|
||||
# retriever = DensePassageRetriever(document_store=document_store,
|
||||
# query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
||||
# query_embedding_model="facebook/dpr-qPreProuestion_encoder-single-nq-base",
|
||||
# passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
|
||||
# use_gpu=True,
|
||||
# embed_title=True,
|
||||
# remove_sep_tok_from_untitled_passages=True)
|
||||
# document_store.update_embeddings(retriever, index=doc_index)
|
||||
|
||||
|
||||
# Initialize Reader
|
||||
reader = FARMReader("deepset/roberta-base-squad2", top_k_per_candidate=4)
|
||||
reader = FARMReader(
|
||||
model_name_or_path="deepset/roberta-base-squad2",
|
||||
top_k_per_candidate=4,
|
||||
return_no_answer=True
|
||||
)
|
||||
|
||||
# Initialize Finder which sticks together Reader and Retriever
|
||||
finder = Finder(reader, retriever)
|
||||
# Here we initialize the nodes that perform evaluation
|
||||
eval_retriever = EvalRetriever()
|
||||
eval_reader = EvalReader()
|
||||
|
||||
|
||||
## Evaluate Retriever on its own
|
||||
if eval_retriever_only:
|
||||
## Evaluate Retriever on its own in closed domain fashion
|
||||
if style == "retriever_closed":
|
||||
retriever_eval_results = retriever.eval(top_k=10, label_index=label_index, doc_index=doc_index)
|
||||
## Retriever Recall is the proportion of questions for which the correct document containing the answer is
|
||||
## among the correct documents
|
||||
@ -92,8 +117,8 @@ def tutorial5_evaluation():
|
||||
## Retriever Mean Avg Precision rewards retrievers that give relevant documents a higher rank
|
||||
print("Retriever Mean Avg Precision:", retriever_eval_results["map"])
|
||||
|
||||
# Evaluate Reader on its own
|
||||
if eval_reader_only:
|
||||
# Evaluate Reader on its own in closed domain fashion (i.e. SQuAD style)
|
||||
elif style == "reader_closed":
|
||||
reader_eval_results = reader.eval(document_store=document_store, device=device, label_index=label_index, doc_index=doc_index)
|
||||
# Evaluation of Reader can also be done directly on a SQuAD-formatted file without passing the data to Elasticsearch
|
||||
#reader_eval_results = reader.eval_on_file("../data/nq", "nq_dev_subset_v2.json", device=device)
|
||||
@ -106,10 +131,37 @@ def tutorial5_evaluation():
|
||||
print("Reader F1-Score:", reader_eval_results["f1"])
|
||||
|
||||
|
||||
# Evaluate combination of Reader and Retriever through Finder
|
||||
if eval_both:
|
||||
finder_eval_results = finder.eval(top_k_retriever=1, top_k_reader=10, label_index=label_index, doc_index=doc_index)
|
||||
finder.print_eval_results(finder_eval_results)
|
||||
# Evaluate combination of Reader and Retriever in open domain fashion
|
||||
elif style == "retriever_reader_open":
|
||||
|
||||
# Here is the pipeline definition
|
||||
p = Pipeline()
|
||||
p.add_node(component=retriever, name="ESRetriever", inputs=["Query"])
|
||||
p.add_node(component=eval_retriever, name="EvalRetriever", inputs=["ESRetriever"])
|
||||
p.add_node(component=reader, name="QAReader", inputs=["EvalRetriever"])
|
||||
p.add_node(component=eval_reader, name="EvalReader", inputs=["QAReader"])
|
||||
results = []
|
||||
|
||||
for q, l in q_to_l_dict.items():
|
||||
res = p.run(
|
||||
query=q,
|
||||
top_k_retriever=10,
|
||||
labels=l,
|
||||
top_k_reader=10,
|
||||
index=doc_index,
|
||||
)
|
||||
results.append(res)
|
||||
|
||||
n_queries = len(labels)
|
||||
eval_retriever.print()
|
||||
print()
|
||||
retriever.print_time()
|
||||
print()
|
||||
eval_reader.print(mode="reader")
|
||||
print()
|
||||
reader.print_time()
|
||||
print()
|
||||
eval_reader.print(mode="pipeline")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||