Add Evaluation of Reader, Retriever and Finder (#92)

This commit is contained in:
bogdankostic 2020-05-29 15:57:07 +02:00 committed by GitHub
parent ca6778d934
commit bbfccf5cf6
7 changed files with 1801 additions and 6 deletions

View File

@@ -55,10 +55,13 @@ Components
Resources
=========
- Tutorial 1 - Basic QA Pipeline: `Jupyter notebook <https://github.com/deepset-ai/haystack/blob/master/tutorials/Tutorial1_Basic_QA_Pipeline.ipynb>`__ or `Colab <https://colab.research.google.com/github/deepset-ai/haystack/blob/master/tutorials/Tutorial1_Basic_QA_Pipeline.ipynb>`__
- Tutorial 2 - Fine-tuning a model on own data: `Jupyter notebook <https://github.com/deepset-ai/haystack/blob/master/tutorials/Tutorial2_Finetune_a_model_on_your_data.ipynb>`__ or `Colab <https://colab.research.google.com/github/deepset-ai/haystack/blob/master/tutorials/Tutorial2_Finetune_a_model_on_your_data.ipynb>`__
- Tutorial 3 - Basic QA Pipeline without Elasticsearch: `Jupyter notebook <https://github.com/deepset-ai/haystack/blob/master/tutorials/Tutorial3_Basic_QA_Pipeline_without_Elasticsearch.ipynb>`__ or `Colab <https://colab.research.google.com/github/deepset-ai/haystack/blob/master/tutorials/Tutorial3_Basic_QA_Pipeline_without_Elasticsearch.ipynb>`__
- Tutorial 1 - Basic QA Pipeline: `Jupyter notebook <https://github.com/deepset-ai/haystack/blob/master/tutorials/Tutorial1_Basic_QA_Pipeline.ipynb>`_ or `Colab <https://colab.research.google.com/github/deepset-ai/haystack/blob/master/tutorials/Tutorial1_Basic_QA_Pipeline.ipynb>`_
- Tutorial 2 - Fine-tuning a model on own data: `Jupyter notebook <https://github.com/deepset-ai/haystack/blob/master/tutorials/Tutorial2_Finetune_a_model_on_your_data.ipynb>`_ or `Colab <https://colab.research.google.com/github/deepset-ai/haystack/blob/master/tutorials/Tutorial2_Finetune_a_model_on_your_data.ipynb>`_
- Tutorial 3 - Basic QA Pipeline without Elasticsearch: `Jupyter notebook <https://github.com/deepset-ai/haystack/blob/master/tutorials/Tutorial3_Basic_QA_Pipeline_without_Elasticsearch.py>`_ or `Colab <https://colab.research.google.com/github/deepset-ai/haystack/blob/update-tutorials/tutorials/Tutorial3_Basic_QA_Pipeline_without_Elasticsearch.ipynb>`_
- Tutorial 4 - FAQ-style QA: `Jupyter notebook <https://github.com/deepset-ai/haystack/blob/master/tutorials/Tutorial4_FAQ_style_QA.ipynb>`__ or `Colab <https://colab.research.google.com/github/deepset-ai/haystack/blob/master/tutorials/Tutorial4_FAQ_style_QA.ipynb>`__
- Tutorial 5 - Evaluation of the whole QA-Pipeline: `Jupyter notebook <https://github.com/deepset-ai/haystack/blob/master/tutorials/Tutorial5_Evaluation.ipynb>`_ or `Colab <https://colab.research.google.com/github/deepset-ai/haystack/blob/master/tutorials/Tutorial5_Evaluation.ipynb>`_
Quick Start
===========

View File

@@ -108,8 +108,12 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
filters: dict = None,
top_k: int = 10,
custom_query: str = None,
index: str = None,
) -> [Document]:
if index is None:
index = self.index
if custom_query: # substitute placeholder for question and filters for the custom_query template string
template = Template(custom_query)
@@ -145,7 +149,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
body["_source"] = {"excludes": self.excluded_meta_data}
logger.debug(f"Retriever query: {body}")
result = self.client.search(index=self.index, body=body)["hits"]["hits"]
result = self.client.search(index=index, body=body)["hits"]["hits"]
documents = [self._convert_es_hit_to_document(hit) for hit in result]
return documents
@@ -199,3 +203,76 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
query_score=hit["_score"] + score_adjustment if hit["_score"] else None,
)
return document
def add_eval_data(self, filename: str, doc_index: str = "eval_document", label_index: str = "feedback"):
"""
Adds a SQuAD-formatted file to the DocumentStore in order to be able to perform evaluation on it.
:param filename: Name of the file containing evaluation data
:type filename: str
:param doc_index: Elasticsearch index where evaluation documents should be stored
:type doc_index: str
:param label_index: Elasticsearch index where labeled questions should be stored
:type label_index: str
"""
eval_docs_to_index = []
questions_to_index = []
with open(filename, "r") as file:
data = json.load(file)
for document in data["data"]:
for paragraph in document["paragraphs"]:
doc_to_index= {}
id = hash(paragraph["context"])
for fieldname, value in paragraph.items():
# write docs to doc_index
if fieldname == "context":
doc_to_index[self.text_field] = value
doc_to_index["doc_id"] = str(id)
doc_to_index["_op_type"] = "create"
doc_to_index["_index"] = doc_index
# write questions to label_index
elif fieldname == "qas":
for qa in value:
question_to_index = {
"question": qa["question"],
"answers": qa["answers"],
"doc_id": str(id),
"origin": "gold_label",
"index_name": doc_index,
"_op_type": "create",
"_index": label_index
}
questions_to_index.append(question_to_index)
# additional fields for docs
else:
doc_to_index[fieldname] = value
for key, value in document.items():
if key == "title":
doc_to_index[self.name_field] = value
elif key != "paragraphs":
doc_to_index[key] = value
eval_docs_to_index.append(doc_to_index)
bulk(self.client, eval_docs_to_index)
bulk(self.client, questions_to_index)
def get_all_documents_in_index(self, index, filters=None):
body = {
"query": {
"bool": {
"must": {
"match_all" : {}
}
}
}
}
if filters:
body["query"]["bool"]["filter"] = {"term": filters}
result = scan(self.client, query=body, index=index)
return result
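
As a quick illustration of the two new DocumentStore methods, the sketch below indexes a SQuAD-formatted file and then iterates over the stored gold labels. The connection settings and the file name are assumptions for this sketch, not part of the diff.

from haystack.database.elasticsearch import ElasticsearchDocumentStore

# Connect to a local Elasticsearch instance (host and credentials are assumed defaults)
document_store = ElasticsearchDocumentStore(host="localhost", username="", password="", index="document")

# Index the documents and gold labels of a SQuAD-formatted file into the default evaluation indices
document_store.add_eval_data("nq_dev_subset.json", doc_index="eval_document", label_index="feedback")

# get_all_documents_in_index returns the raw Elasticsearch hits (a generator produced by helpers.scan)
for hit in document_store.get_all_documents_in_index(index="feedback", filters={"origin": "gold_label"}):
    print(hit["_source"]["question"], "->", [answer["text"] for answer in hit["_source"]["answers"]])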

View File

@@ -2,6 +2,8 @@ import logging
import numpy as np
from scipy.special import expit
import time
from statistics import mean
logger = logging.getLogger(__name__)
@@ -96,3 +98,239 @@ class Finder:
results["answers"].append(cur_answer)
return results
def eval(self, label_index: str = "feedback", doc_index: str = "eval_document", label_origin: str = "gold_label",
top_k_retriever: int = 10, top_k_reader: int = 10):
"""
Evaluation of the whole pipeline by first evaluating the Retriever and then evaluating the Reader on the result
of the Retriever.
Returns a dict containing the following metrics:
- "retriever_recall": Proportion of questions for which correct document is among retrieved documents
- "retriever_map": Mean of average precision for each question. Rewards retrievers that give relevant
documents a higher rank.
- "reader_top1_accuracy": Proportion of highest ranked predicted answers that overlap with corresponding correct answer
- "reader_top1_accuracy_has_answer": Proportion of highest ranked predicted answers that overlap
with corresponding correct answer for answerable questions
- "reader_top_k_accuracy": Proportion of predicted answers that overlap with corresponding correct answer
- "reader_topk_accuracy_has_answer": Proportion of predicted answers that overlap with corresponding correct answer
for answerable questions
- "reader_top1_em": Proportion of exact matches of highest ranked predicted answers with their corresponding
correct answers
- "reader_top1_em_has_answer": Proportion of exact matches of highest ranked predicted answers with their corresponding
correct answers for answerable questions
- "reader_topk_em": Proportion of exact matches of predicted answers with their corresponding correct answers
- "reader_topk_em_has_answer": Proportion of exact matches of predicted answers with their corresponding
correct answers for answerable questions
- "reader_top1_f1": Average overlap between highest ranked predicted answers and their corresponding correct answers
- "reader_top1_f1_has_answer": Average overlap between highest ranked predicted answers and their corresponding
correct answers for answerable questions
- "reader_topk_f1": Average overlap between predicted answers and their corresponding correct answers
- "reader_topk_f1_has_answer": Average overlap between predicted answers and their corresponding correct answers
for answerable questions
- "reader_top1_no_answer_accuracy": Proportion of correct predicting unanswerable question at highest ranked prediction
- "reader_topk_no_answer_accuracy": Proportion of correct predicting unanswerable question among all predictions
- "total_retrieve_time": Time retriever needed to retrieve documents for all questions
- "avg_retrieve_time": Average time needed to retrieve documents for one question
- "total_reader_time": Time reader needed to extract answer out of retrieved documents for all questions
where the correct document is among the retrieved ones
- "avg_reader_time": Average time needed to extract answer out of retrieved documents for one question
- "total_finder_time": Total time for whole pipeline
:param label_index: Elasticsearch index where labeled questions are stored
:type label_index: str
:param doc_index: Elasticsearch index where documents that are used for evaluation are stored
:type doc_index: str
:param label_origin: Field value that labeled questions must carry in their "origin" field to be used for evaluation
:type label_origin: str
:param top_k_retriever: How many documents per question to return and pass to reader
:type top_k_retriever: int
:param top_k_reader: How many answers to return per question
:type top_k_reader: int
"""
finder_start_time = time.time()
# extract all questions for evaluation
filter = {"origin": label_origin}
questions = self.retriever.document_store.get_all_documents_in_index(index=label_index, filters=filter)
correct_retrievals = 0
summed_avg_precision_retriever = 0
retrieve_times = []
correct_readings_top1 = 0
correct_readings_topk = 0
correct_readings_top1_has_answer = 0
correct_readings_topk_has_answer = 0
exact_matches_top1 = 0
exact_matches_topk = 0
exact_matches_top1_has_answer = 0
exact_matches_topk_has_answer = 0
summed_f1_top1 = 0
summed_f1_topk = 0
summed_f1_top1_has_answer = 0
summed_f1_topk_has_answer = 0
correct_no_answers_top1 = 0
correct_no_answers_topk = 0
read_times = []
# retrieve documents
questions_with_docs = []
retriever_start_time = time.time()
for q_idx, question in enumerate(questions):
question_string = question["_source"]["question"]
single_retrieve_start = time.time()
retrieved_docs = self.retriever.retrieve(question_string, top_k=top_k_retriever, index=doc_index)
retrieve_times.append(time.time() - single_retrieve_start)
for doc_idx, doc in enumerate(retrieved_docs):
# check if correct doc among retrieved docs
if doc.meta["doc_id"] == question["_source"]["doc_id"]:
correct_retrievals += 1
summed_avg_precision_retriever += 1 / (doc_idx + 1)
questions_with_docs.append({
"question": question,
"docs": retrieved_docs,
"correct_es_doc_id": doc.id})
break
retriever_total_time = time.time() - retriever_start_time
number_of_questions = q_idx + 1
number_of_no_answer = 0
previous_return_no_answers = self.reader.return_no_answers
self.reader.return_no_answers = True
# extract answers
reader_start_time = time.time()
for q_idx, question in enumerate(questions_with_docs):
if (q_idx + 1) % 100 == 0:
print(f"Processed {q_idx+1} questions.")
question_string = question["question"]["_source"]["question"]
docs = question["docs"]
single_reader_start = time.time()
predicted_answers = self.reader.predict(question_string, docs, top_k_reader)
read_times.append(time.time() - single_reader_start)
# check if question is answerable
if question["question"]["_source"]["answers"]:
for answer_idx, answer in enumerate(predicted_answers["answers"]):
found_answer = False
found_em = False
best_f1 = 0
# check if correct document
if answer["document_id"] == question["correct_es_doc_id"]:
gold_spans = [(gold_answer["answer_start"], gold_answer["answer_start"] + len(gold_answer["text"]) + 1)
for gold_answer in question["question"]["_source"]["answers"]]
predicted_span = (answer["offset_start_in_doc"], answer["offset_end_in_doc"])
for gold_span in gold_spans:
# check if overlap between gold answer and predicted answer
# top-1 answer
if not found_answer:
if (gold_span[0] <= predicted_span[1]) and (predicted_span[0] <= gold_span[1]):
# top-1 answer
if answer_idx == 0:
correct_readings_top1 += 1
correct_readings_top1_has_answer += 1
# top-k answers
correct_readings_topk += 1
correct_readings_topk_has_answer += 1
found_answer = True
# check for exact match
if not found_em:
if (gold_span[0] == predicted_span[0]) and (gold_span[1] == predicted_span[1]):
# top-1-answer
if answer_idx == 0:
exact_matches_top1 += 1
exact_matches_top1_has_answer += 1
# top-k answers
exact_matches_topk += 1
exact_matches_topk_has_answer += 1
found_em = True
# calculate f1
pred_indices = list(range(predicted_span[0], predicted_span[1] + 1))
gold_indices = list(range(gold_span[0], gold_span[1] + 1))
n_overlap = len([x for x in pred_indices if x in gold_indices])
if pred_indices and gold_indices and n_overlap:
precision = n_overlap / len(pred_indices)
recall = n_overlap / len(gold_indices)
current_f1 = (2 * precision * recall) / (precision + recall)
# top-1 answer
if answer_idx == 0:
summed_f1_top1 += current_f1
summed_f1_top1_has_answer += current_f1
if current_f1 > best_f1:
best_f1 = current_f1
# top-k answers: use best f1-score
summed_f1_topk += best_f1
summed_f1_topk_has_answer += best_f1
if found_answer and found_em:
break
# question not answerable
else:
number_of_no_answer += 1
# As the question is not answerable, EM, F1 and accuracy cannot be computed from answer spans.
# For now, we score such questions based on the rank of the 'no answer' prediction:
# it counts as correct (EM = F1 = 1) if 'no answer' is the top-1 prediction / among the top-k predictions.
for answer_idx, answer in enumerate(predicted_answers["answers"]):
# check if 'no answer'
if answer["answer"] is None:
if answer_idx == 0:
correct_no_answers_top1 += 1
correct_readings_top1 += 1
exact_matches_top1 += 1
summed_f1_top1 += 1
correct_no_answers_topk += 1
correct_readings_topk += 1
exact_matches_topk += 1
summed_f1_topk += 1
break
number_of_has_answer = correct_retrievals - number_of_no_answer
reader_total_time = time.time() - reader_start_time
finder_total_time = time.time() - finder_start_time
retriever_recall = correct_retrievals / number_of_questions
retriever_map = summed_avg_precision_retriever / number_of_questions
reader_top1_accuracy = correct_readings_top1 / correct_retrievals
reader_top1_accuracy_has_answer = correct_readings_top1_has_answer / number_of_has_answer
reader_top_k_accuracy = correct_readings_topk / correct_retrievals
reader_topk_accuracy_has_answer = correct_readings_topk_has_answer / number_of_has_answer
reader_top1_em = exact_matches_top1 / correct_retrievals
reader_top1_em_has_answer = exact_matches_top1_has_answer / number_of_has_answer
reader_topk_em = exact_matches_topk / correct_retrievals
reader_topk_em_has_answer = exact_matches_topk_has_answer / number_of_has_answer
reader_top1_f1 = summed_f1_top1 / correct_retrievals
reader_top1_f1_has_answer = summed_f1_top1_has_answer / number_of_has_answer
reader_topk_f1 = summed_f1_topk / correct_retrievals
reader_topk_f1_has_answer = summed_f1_topk_has_answer / number_of_has_answer
reader_top1_no_answer_accuracy = correct_no_answers_top1 / number_of_no_answer
reader_topk_no_answer_accuracy = correct_no_answers_topk / number_of_no_answer
self.reader.return_no_answers = previous_return_no_answers
logger.info((f"{correct_readings_topk} out of {number_of_questions} questions were correctly answered "
f"({(correct_readings_topk/number_of_questions):.2%})."))
logger.info(f"{number_of_questions-correct_retrievals} questions could not be answered due to the retriever.")
logger.info(f"{correct_retrievals-correct_readings_topk} questions could not be answered due to the reader.")
results = {
"retriever_recall": retriever_recall,
"retriever_map": retriever_map,
"reader_top1_accuracy": reader_top1_accuracy,
"reader_top1_accuracy_has_answer": reader_top1_accuracy_has_answer,
"reader_top_k_accuracy": reader_top_k_accuracy,
"reader_topk_accuracy_has_answer": reader_topk_accuracy_has_answer,
"reader_top1_em": reader_top1_em,
"reader_top1_em_has_answer": reader_top1_em_has_answer,
"reader_topk_em": reader_topk_em,
"reader_topk_em_has_answer": reader_topk_em_has_answer,
"reader_top1_f1": reader_top1_f1,
"reader_top1_f1_has_answer": reader_top1_f1_has_answer,
"reader_topk_f1": reader_topk_f1,
"reader_topk_f1_has_answer": reader_topk_f1_has_answer,
"reader_top1_no_answer_accuracy": reader_top1_no_answer_accuracy,
"reader_topk_no_answer_accuracy": reader_topk_no_answer_accuracy,
"total_retrieve_time": retriever_total_time,
"avg_retrieve_time": mean(retrieve_times),
"total_reader_time": reader_total_time,
"avg_reader_time": mean(read_times),
"total_finder_time": finder_total_time
}
return results
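
A minimal usage sketch for the new Finder.eval(), assuming a document_store that was already filled via add_eval_data as in the sketch above; the metric keys printed here are taken from the results dict defined in this method.

from haystack.retriever.elasticsearch import ElasticsearchRetriever
from haystack.reader.farm import FARMReader
from haystack.finder import Finder

retriever = ElasticsearchRetriever(document_store=document_store)
reader = FARMReader("deepset/roberta-base-squad2")
finder = Finder(reader, retriever)

# Evaluate the whole pipeline on the default evaluation indices
eval_results = finder.eval(label_index="feedback", doc_index="eval_document",
                           top_k_retriever=10, top_k_reader=10)

print("Retriever recall:", eval_results["retriever_recall"])
print("Retriever MAP:", eval_results["retriever_map"])
print("Reader top-k accuracy:", eval_results["reader_top_k_accuracy"])
print("Reader top-k F1:", eval_results["reader_topk_f1"])
print("Avg. retrieval time per question:", eval_results["avg_retrieve_time"])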

View File

@@ -4,13 +4,16 @@ from pathlib import Path
import numpy as np
from farm.data_handler.data_silo import DataSilo
from farm.data_handler.processor import SquadProcessor
from farm.data_handler.dataloader import NamedDataLoader
from farm.infer import Inferencer
from farm.modeling.optimization import initialize_optimizer
from farm.train import Trainer
from farm.eval import Evaluator
from farm.utils import set_all_seeds, initialize_device_settings
from scipy.special import expit
from haystack.database.base import Document
from haystack.database.elasticsearch import ElasticsearchDocumentStore
logger = logging.getLogger(__name__)
@@ -272,6 +275,116 @@ class FARMReader:
return result
def eval_on_file(self, data_dir: str, test_filename: str, device: str):
"""
Performs evaluation on a SQuAD-formatted file.
Returns a dict containing the following metrics:
- "EM": exact match score
- "f1": F1-Score
- "top_n_recall": Proportion of predicted answers that overlap with correct answer
:param data_dir: The directory in which the test set can be found
:type data_dir: Path or str
:param test_filename: The name of the file containing the test data in SQuAD format.
:type test_filename: str
:param device: The device on which the tensors should be processed. Choose from "cpu" and "cuda".
:type device: str
"""
eval_processor = SquadProcessor(
tokenizer=self.inferencer.processor.tokenizer,
max_seq_len=self.inferencer.processor.max_seq_len,
label_list=self.inferencer.processor.tasks["question_answering"]["label_list"],
metric=self.inferencer.processor.tasks["question_answering"]["metric"],
train_filename=None,
dev_filename=None,
dev_split=0,
test_filename=test_filename,
data_dir=Path(data_dir),
)
data_silo = DataSilo(processor=eval_processor, batch_size=self.inferencer.batch_size, distributed=False)
data_loader = data_silo.get_data_loader("test")
evaluator = Evaluator(data_loader=data_loader, tasks=eval_processor.tasks, device=device)
eval_results = evaluator.eval(self.inferencer.model)
results = {
"EM": eval_results[0]["EM"],
"f1": eval_results[0]["f1"],
"top_n_recall": eval_results[0]["top_n_recall"]
}
return results
def eval(self, document_store: ElasticsearchDocumentStore, device: str, label_index: str = "feedback",
doc_index: str = "eval_document", label_origin: str = "gold_label"):
"""
Performs evaluation on evaluation documents in Elasticsearch DocumentStore.
Returns a dict containing the following metrics:
- "EM": Proportion of exact matches of predicted answers with their corresponding correct answers
- "f1": Average overlap between predicted answers and their corresponding correct answers
- "top_n_recall": Proportion of predicted answers that overlap with correct answer
:param document_store: The ElasticsearchDocumentStore containing the evaluation documents
:type document_store: ElasticsearchDocumentStore
:param device: The device on which the tensors should be processed. Choose from "cpu" and "cuda".
:type device: str
:param label_index: Elasticsearch index where labeled questions are stored
:type label_index: str
:param doc_index: Elasticsearch index where documents that are used for evaluation are stored
:type doc_index: str
"""
# extract all questions for evaluation
filter = {"origin": label_origin}
questions = document_store.get_all_documents_in_index(index=label_index, filters=filter)
# mapping from doc_id to questions
doc_questions_dict = {}
id = 0
for question in questions:
doc_id = question["_source"]["doc_id"]
if doc_id not in doc_questions_dict:
doc_questions_dict[doc_id] = [{
"id": id,
"question" : question["_source"]["question"],
"answers" : question["_source"]["answers"],
"is_impossible" : False if question["_source"]["answers"] else True
}]
else:
doc_questions_dict[doc_id].append({
"id": id,
"question" : question["_source"]["question"],
"answers" : question["_source"]["answers"],
"is_impossible" : False if question["_source"]["answers"] else True
})
id += 1
# extract eval documents and convert data back to SQuAD-like format
documents = document_store.get_all_documents_in_index(index=doc_index)
dicts = []
for document in documents:
doc_id = document["_source"]["doc_id"]
text = document["_source"]["text"]
questions = doc_questions_dict[doc_id]
dicts.append({"qas" : questions, "context" : text})
# Create DataLoader that can be passed to the Evaluator
indices = range(len(dicts))
dataset, tensor_names = self.inferencer.processor.dataset_from_dicts(dicts, indices=indices)
data_loader = NamedDataLoader(dataset=dataset, batch_size=self.inferencer.batch_size, tensor_names=tensor_names)
evaluator = Evaluator(data_loader=data_loader, tasks=self.inferencer.processor.tasks, device=device)
eval_results = evaluator.eval(self.inferencer.model)
results = {
"EM": eval_results[0]["EM"],
"f1": eval_results[0]["f1"],
"top_n_recall": eval_results[0]["top_n_recall"]
}
return results
@staticmethod
def _calc_no_answer(no_ans_gaps,best_score_answer):
# "no answer" scores and positive answers scores are difficult to compare, because

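For reference, FARMReader.eval() rebuilds SQuAD-like dicts from the label and document indices before handing them to FARM's Evaluator. A single entry of that list looks roughly like the following (the values are illustrative, not taken from the diff):

example_dict = {
    "context": "The quick brown fox jumps over the lazy dog.",
    "qas": [
        {
            "id": 0,
            "question": "What does the fox jump over?",
            "answers": [{"text": "the lazy dog", "answer_start": 31}],
            "is_impossible": False,
        }
    ],
}
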
View File

@@ -41,12 +41,58 @@ class ElasticsearchRetriever(BaseRetriever):
self.document_store = document_store
self.custom_query = custom_query
def retrieve(self, query: str, filters: dict = None, top_k: int = 10) -> [Document]:
documents = self.document_store.query(query, filters, top_k, self.custom_query)
def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> [Document]:
if index is None:
index = self.document_store.index
documents = self.document_store.query(query, filters, top_k, self.custom_query, index)
logger.info(f"Got {len(documents)} candidates from retriever")
return documents
def eval(self, label_index: str = "feedback", doc_index: str = "eval_document", label_origin: str = "gold_label",
top_k: int = 10) -> dict:
"""
Performs evaluation on the Retriever.
Retriever is evaluated based on whether it finds the correct document given the question string and at which
position in the ranking of documents the correct document is.
Returns a dict containing the following metrics:
- "recall": Proportion of questions for which correct document is among retrieved documents
- "mean avg precision": Mean of average precision for each question. Rewards retrievers that give relevant
documents a higher rank.
:param label_index: Index/Table in DocumentStore where labeled questions are stored
:param doc_index: Index/Table in DocumentStore where documents that are used for evaluation are stored
:param label_origin: Field value that labeled questions must carry in their "origin" field to be used for evaluation
:param top_k: How many documents to return per question
"""
# extract all questions for evaluation
filter = {"origin": label_origin}
questions = self.document_store.get_all_documents_in_index(index=label_index, filters=filter)
# calculate recall and mean-average-precision
correct_retrievals = 0
summed_avg_precision = 0
for q_idx, question in enumerate(questions):
question_string = question["_source"]["question"]
retrieved_docs = self.retrieve(question_string, top_k=top_k, index=doc_index)
# check if correct doc in retrieved docs
for doc_idx, doc in enumerate(retrieved_docs):
if doc.meta["doc_id"] == question["_source"]["doc_id"]:
correct_retrievals += 1
summed_avg_precision += 1 / (doc_idx + 1)
break
number_of_questions = q_idx + 1
recall = correct_retrievals / number_of_questions
mean_avg_precision = summed_avg_precision / number_of_questions
logger.info((f"For {correct_retrievals} out of {number_of_questions} questions ({recall:.2%}), the answer was in"
f" the top-{top_k} candidate passages selected by the retriever."))
return {"recall": recall, "map": mean_avg_precision}
class EmbeddingRetriever(BaseRetriever):
def __init__(

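Since each question has exactly one relevant document in this setup, the average precision for a question reduces to the reciprocal rank of the correct document, which is what the 1 / (doc_idx + 1) term above computes. A small worked example with illustrative numbers:

# 1-based ranks at which the correct document was retrieved for four questions;
# None means it was not among the top-k retrieved documents
ranks = [1, 3, None, 2]

recall = sum(1 for rank in ranks if rank is not None) / len(ranks)                     # 3 / 4 = 0.75
mean_avg_precision = sum(1 / rank for rank in ranks if rank is not None) / len(ranks)  # (1 + 1/3 + 1/2) / 4 ≈ 0.458
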
File diff suppressed because it is too large

View File

@@ -0,0 +1,80 @@
from haystack.database.elasticsearch import ElasticsearchDocumentStore
from haystack.indexing.io import fetch_archive_from_http
from haystack.retriever.elasticsearch import ElasticsearchRetriever
from haystack.reader.farm import FARMReader
from haystack.finder import Finder
from farm.utils import initialize_device_settings
import logging
import subprocess
import time
LAUNCH_ELASTICSEARCH = False
device, n_gpu = initialize_device_settings(use_cuda=True)
# Start an Elasticsearch server
# You can start Elasticsearch on your local machine using Docker. If Docker is not readily available in
# your environment (e.g., in Colab notebooks), you can manually download and run Elasticsearch from source (see the sketch after this block).
if LAUNCH_ELASTICSEARCH:
logging.info("Starting Elasticsearch ...")
status = subprocess.run(
['docker run -d -p 9200:9200 -e "discovery.type=single-node" elasticsearch:7.6.2'], shell=True
)
if status.returncode:
raise Exception("Failed to launch Elasticsearch. If you want to connect to an existing Elasticsearch instance"
"then set LAUNCH_ELASTICSEARCH in the script to False.")
time.sleep(30)
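# If Docker is not available (e.g. on Colab), Elasticsearch can instead be downloaded and started directly.
# The commands below are only an illustrative sketch; the archive URL, version and paths are assumptions
# and not part of this tutorial:
#
#   wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.6.2-linux-x86_64.tar.gz
#   tar -xzf elasticsearch-7.6.2-linux-x86_64.tar.gz
#   elasticsearch-7.6.2/bin/elasticsearch -d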
# Download evaluation data, which is a subset of Natural Questions development set containing 50 documents
doc_dir = "../data/nq"
s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/nq_dev_subset.json.zip"
fetch_archive_from_http(url=s3_url, output_dir=doc_dir)
# Connect to Elasticsearch
document_store = ElasticsearchDocumentStore(host="localhost", username="", password="", index="document", create_index=False)
# Add evaluation data to Elasticsearch database
document_store.add_eval_data("../data/nq/nq_dev_subset.json")
# Initialize Retriever
retriever = ElasticsearchRetriever(document_store=document_store)
# Initialize Reader
reader = FARMReader("deepset/roberta-base-squad2")
# Initialize Finder which sticks together Reader and Retriever
finder = Finder(reader, retriever)
# Evaluate Retriever on its own
retriever_eval_results = retriever.eval()
## Retriever Recall is the proportion of questions for which the correct document containing the answer is
## among the retrieved documents
print("Retriever Recall:", retriever_eval_results["recall"])
## Retriever Mean Avg Precision rewards retrievers that give relevant documents a higher rank
print("Retriever Mean Avg Precision:", retriever_eval_results["mean avg precision"])
# Evaluate Reader on its own
reader_start = time.time()
reader_eval_results = reader.eval(document_store=document_store, device=device)
reader_total = time.time() - reader_start
# Evaluation of Reader can also be done directly on a SQuAD-formatted file without passing the data to Elasticsearch
# reader_eval_results = reader.eval_on_file("../data/nq", "nq_dev_subset.json", device=device)
## Reader Top-N-Recall is the proportion of predicted answers that overlap with their corresponding correct answer
print("Reader Top-N-Recall:", reader_eval_results["top_n_recall"])
## Reader Exact Match is the proportion of questions where the predicted answer is exactly the same as the correct answer
print("Reader Exact Match:", reader_eval_results["EM"])
## Reader F1-Score is the average overlap between the predicted answers and the correct answers
print("Reader F1-Score:", reader_eval_results["f1"])
# Evaluate combination of Reader and Retriever through Finder
finder_eval_results = finder.eval()
print("Retriever Recall in Finder:", finder_eval_results["retriever_recall"])
print("Retriever Mean Avg Precision in Finder:", finder_eval_results["retriever_map"])
# The Reader is only evaluated on those questions where the correct document is among the retrieved ones
print("Reader Top-1 Accuracy in Finder:", finder_eval_results["reader_top1_accuracy"])
print("Reader Top-k Accuracy in Finder:", finder_eval_results["reader_top_k_accuracy"])
print("Reader Top-1 Exact Match in Finder:", finder_eval_results["reader_top1_em"])
print("Reader Top-k Exact Match in Finder:", finder_eval_results["reader_topk_em"])
print("Reader Top-1 F1-Score in Finder:", finder_eval_results["reader_top1_f1"])
print("Reader Top-k F1-Score in Finder:", finder_eval_results["reader_topk_f1"])
print(f"Finder time: {finder_eval_results['total_finder_time']}s")