Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-08-21 06:58:27 +00:00)

Upgrade to new FARM / Transformers / PyTorch versions (#212)

parent 17c1b84c21
commit 99a6a34047
@@ -6,8 +6,10 @@ import numpy as np
 from farm.data_handler.data_silo import DataSilo
 from farm.data_handler.processor import SquadProcessor
 from farm.data_handler.dataloader import NamedDataLoader
-from farm.infer import Inferencer
+from farm.data_handler.inputs import QAInput, Question
+from farm.infer import QAInferencer
 from farm.modeling.optimization import initialize_optimizer
+from farm.modeling.predictions import QAPred, QACandidate
 from farm.train import Trainer
 from farm.eval import Evaluator
 from farm.utils import set_all_seeds, initialize_device_settings
@@ -85,7 +87,7 @@ class FARMReader(BaseReader):
         else:
             self.return_no_answers = True
         self.top_k_per_candidate = top_k_per_candidate
-        self.inferencer = Inferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu,
+        self.inferencer = QAInferencer.load(model_name_or_path, batch_size=batch_size, gpu=use_gpu,
                                           task_type="question_answering", max_seq_len=max_seq_len,
                                           doc_stride=doc_stride, num_processes=num_processes)
         self.inferencer.model.prediction_heads[0].context_window_size = context_window_size
@@ -231,18 +233,16 @@ class FARMReader(BaseReader):
         """

         # convert input to FARM format
-        input_dicts = []
+        inputs = []
         for doc in documents:
-            cur = {
-                "text": doc.text,
-                "questions": [question],
-                "document_id": doc.id
-            }
-            input_dicts.append(cur)
+            cur = QAInput(doc_text=doc.text,
+                          questions=Question(text=question,
+                                             uid=doc.id))
+            inputs.append(cur)

         # get answers from QA model
-        predictions = self.inferencer.inference_from_dicts(
-            dicts=input_dicts, return_json=True, multiprocessing_chunksize=1
+        predictions = self.inferencer.inference_from_objects(
+            objects=inputs, return_json=False, multiprocessing_chunksize=1
         )
         # assemble answers from all the different documents & format them.
         # For the "no answer" option, we collect all no_ans_gaps and decide how likely
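Note: the hunk above moves FARMReader from FARM's dict-based inference to the object-based QA API. A minimal standalone sketch of that call pattern, assuming the FARM 0.4.6 interfaces shown in this diff (model name, document text, and question are illustrative only):

from farm.data_handler.inputs import QAInput, Question
from farm.infer import QAInferencer

# Load a QA model the same way the reader does (batch_size/gpu kwargs omitted here).
inferencer = QAInferencer.load("deepset/roberta-base-squad2", task_type="question_answering")

# One QAInput per document; the document id travels along as the question uid.
qa_input = QAInput(doc_text="My name is Carla and I live in Berlin",
                   questions=Question(text="Who lives in Berlin?", uid="doc-1"))

# return_json=False yields prediction objects (QAPred) instead of plain dicts.
predictions = inferencer.inference_from_objects(objects=[qa_input], return_json=False,
                                                multiprocessing_chunksize=1)
for pred in predictions:                 # one QAPred per input
    for candidate in pred.prediction:    # QACandidate objects, sorted by decreasing relevance
        print(candidate.answer, candidate.score, candidate.context_window)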
@@ -250,29 +250,28 @@ class FARMReader(BaseReader):
         answers = []
         no_ans_gaps = []
         best_score_answer = 0
-        # TODO once FARM returns doc ids again we can revert to using them inside the preds and remove
-        for pred, inp in zip(predictions, input_dicts):
+        for pred, inp in zip(predictions, inputs):
             answers_per_document = []
-            no_ans_gaps.append(pred["predictions"][0]["no_ans_gap"])
-            for ans in pred["predictions"][0]["answers"]:
+            no_ans_gaps.append(pred.no_answer_gap)
+            for ans in pred.prediction:
                 # skip "no answers" here
                 if self._check_no_answer(ans):
                     pass
                 else:
-                    cur = {"answer": ans["answer"],
-                           "score": ans["score"],
+                    cur = {"answer": ans.answer,
+                           "score": ans.score,
                            # just a pseudo prob for now
-                           "probability": float(expit(np.asarray([ans["score"]]) / 8)),  # type: ignore
-                           "context": ans["context"],
-                           "offset_start": ans["offset_answer_start"] - ans["offset_context_start"],
-                           "offset_end": ans["offset_answer_end"] - ans["offset_context_start"],
-                           "offset_start_in_doc": ans["offset_answer_start"],
-                           "offset_end_in_doc": ans["offset_answer_end"],
-                           "document_id": inp["document_id"]} #TODO revert to ans["docid"] once it is populated
+                           "probability": float(expit(np.asarray([ans.score]) / 8)),  # type: ignore
+                           "context": ans.context_window,
+                           "offset_start": ans.offset_answer_start - ans.offset_context_window_start,
+                           "offset_end": ans.offset_answer_end - ans.offset_context_window_start,
+                           "offset_start_in_doc": ans.offset_answer_start,
+                           "offset_end_in_doc": ans.offset_answer_end,
+                           "document_id": pred.id}
                     answers_per_document.append(cur)

-                    if ans["score"] > best_score_answer:
-                        best_score_answer = ans["score"]
+                    if ans.score > best_score_answer:
+                        best_score_answer = ans.score
             # only take n best candidates. Answers coming back from FARM are sorted with decreasing relevance.
             answers += answers_per_document[:self.top_k_per_candidate]

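The "probability" field above is only a pseudo-probability: the raw QA score is squashed through a logistic sigmoid (expit is scipy's logistic function). A small numeric illustration with a made-up score:

import numpy as np
from scipy.special import expit  # logistic sigmoid, as used in the hunk above

score = 8.0                                           # illustrative raw QA score
probability = float(expit(np.asarray([score]) / 8))   # sigmoid(score / 8) ~= 0.73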
@@ -299,7 +298,7 @@ class FARMReader(BaseReader):
         Returns a dict containing the following metrics:
             - "EM": exact match score
             - "f1": F1-Score
-            - "top_n_recall": Proportion of predicted answers that overlap with correct answer
+            - "top_n_accuracy": Proportion of predicted answers that match with correct answer

         :param data_dir: The directory in which the test set can be found
         :type data_dir: Path or str
@@ -329,7 +328,7 @@ class FARMReader(BaseReader):
         results = {
             "EM": eval_results[0]["EM"],
             "f1": eval_results[0]["f1"],
-            "top_n_recall": eval_results[0]["top_n_recall"]
+            "top_n_accuracy": eval_results[0]["top_n_accuracy"]
         }
         return results

@@ -347,7 +346,7 @@ class FARMReader(BaseReader):
         Returns a dict containing the following metrics:
             - "EM": Proportion of exact matches of predicted answers with their corresponding correct answers
             - "f1": Average overlap between predicted answers and their corresponding correct answers
-            - "top_n_recall": Proportion of predicted answers that overlap with correct answer
+            - "top_n_accuracy": Proportion of predicted answers that match with correct answer

         :param document_store: The ElasticsearchDocumentStore containing the evaluation documents
         :type document_store: ElasticsearchDocumentStore
@@ -404,23 +403,23 @@ class FARMReader(BaseReader):
         results = {
             "EM": eval_results[0]["EM"],
             "f1": eval_results[0]["f1"],
-            "top_n_recall": eval_results[0]["top_n_recall"]
+            "top_n_accuracy": eval_results[0]["top_n_accuracy"]
         }
         return results


     @staticmethod
-    def _check_no_answer(d: dict):
+    def _check_no_answer(c: QACandidate):
         # check for correct value in "answer"
-        if d["offset_answer_start"] == 0 and d["offset_answer_end"] == 0:
-            assert d["answer"] == "is_impossible", f"Check for no answer is not working"
-        # check weather the model thinks there is no answer
-        if d["answer"] == "is_impossible":
+        if c.offset_answer_start == 0 and c.offset_answer_end == 0:
+            if c.answer != "no_answer":
+                logger.error("Invalid 'no_answer': Got a prediction for position 0, but answer string is not 'no_answer'")
+        if c.answer == "no_answer":
             return True
         else:
             return False


     @staticmethod
     def _calc_no_answer(no_ans_gaps: List[float], best_score_answer: float):
         # "no answer" scores and positive answers scores are difficult to compare, because
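The rewritten check above encodes the new no-answer convention: FARM marks its "no answer" prediction with the answer string "no_answer", and such a candidate is expected to have both answer offsets at 0. A hypothetical standalone helper with the same logic (the function name is mine, not part of the diff):

import logging
logger = logging.getLogger(__name__)

def is_no_answer_candidate(candidate) -> bool:
    """Return True if a FARM QACandidate represents the "no answer" prediction."""
    if candidate.offset_answer_start == 0 and candidate.offset_answer_end == 0 \
            and candidate.answer != "no_answer":
        # Mirrors the sanity check in the diff: offsets (0, 0) should only occur for "no_answer".
        logger.error("Got a prediction for position 0, but answer string is not 'no_answer'")
    return candidate.answer == "no_answer"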
@@ -476,5 +475,5 @@ class FARMReader(BaseReader):
         are "gpu_tensor_core" (GPUs with tensor core like V100 or T4),
         "gpu_without_tensor_core" (most other GPUs), and "cpu".
         """
-        inferencer = Inferencer.load(model_name_or_path, task_type="question_answering")
+        inferencer = QAInferencer.load(model_name_or_path, task_type="question_answering")
         inferencer.model.convert_to_onnx(output_path=Path("onnx-export"), opset_version=opset_version, optimize_for=optimize_for)
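The ONNX export path only swaps the loader class. A minimal sketch of the same calls made directly through FARM, under the 0.4.6 API assumed by this diff (model name, opset version, and optimization target are illustrative; "cpu" is one of the options named in the docstring above):

from pathlib import Path
from farm.infer import QAInferencer

# Load a QA model and export its underlying model to ONNX, as in the hunk above.
inferencer = QAInferencer.load("deepset/roberta-base-squad2", task_type="question_answering")
inferencer.model.convert_to_onnx(output_path=Path("onnx-export"),
                                 opset_version=11,
                                 optimize_for="cpu")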
@@ -1,6 +1,6 @@
 from typing import List, Optional

-from transformers import pipeline
+from haystack.reader.transformers_utils import pipeline

 from haystack.database.base import Document
 from haystack.reader.base import BaseReader
@@ -40,10 +40,11 @@ class TransformersReader(BaseReader):
         :param use_gpu: < 0  -> use cpu
                         >= 0 -> ordinal of the gpu to use
         """
-        self.model = pipeline("question-answering", model=model, tokenizer=tokenizer, device=use_gpu)
+        self.model = pipeline('question-answering', model=model, tokenizer=tokenizer, device=use_gpu)
         self.context_window_size = context_window_size
         self.n_best_per_passage = n_best_per_passage
         #TODO param to modify bias for no_answer
+        # TODO context_window_size behaviour different from behavior in FARMReader

     def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
         """
@@ -76,6 +77,9 @@ class TransformersReader(BaseReader):
         for doc in documents:
             query = {"context": doc.text, "question": question}
             predictions = self.model(query, topk=self.n_best_per_passage)
+            # for single preds (e.g. via top_k=1) transformers returns a dict instead of a list
+            if type(predictions) == dict:
+                predictions = [predictions]
             # assemble and format all answers
             for pred in predictions:
                 if pred["answer"]:
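The guard added above exists because, for a single result, the transformers QA pipeline hands back a dict rather than a list of dicts. A minimal illustration outside the reader, assuming that behaviour and using an example model name:

from transformers import pipeline

qa = pipeline("question-answering",
              model="distilbert-base-uncased-distilled-squad",
              tokenizer="distilbert-base-uncased-distilled-squad",
              device=-1)  # < 0 -> CPU, as documented in the reader above

result = qa({"context": "My name is Carla and I live in Berlin",
             "question": "Who lives in Berlin?"},
            topk=1)

# With topk=1 a single dict comes back; normalize to a list before iterating.
if type(result) == dict:
    result = [result]
for pred in result:
    print(pred["answer"], pred["score"])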
haystack/reader/transformers_utils.py (new file, 1912 lines)
File diff suppressed because it is too large
@@ -1,4 +1,4 @@
-farm==0.4.5
+farm==0.4.6
 --find-links=https://download.pytorch.org/whl/torch_stable.html
 fastapi
 uvicorn
@@ -1,3 +1,4 @@
+
 import tarfile
 import time
 import urllib.request
@@ -10,6 +11,7 @@ from elasticsearch import Elasticsearch
 from haystack.reader.farm import FARMReader
 from haystack.reader.transformers import TransformersReader

+from haystack.database.base import Document
 from haystack.database.sql import SQLDocumentStore
 from haystack.database.memory import InMemoryDocumentStore
 from haystack.database.elasticsearch import ElasticsearchDocumentStore
@@ -72,6 +74,39 @@ def reader(request):
                                   use_gpu=-1)


+# TODO Fix bug in test_no_answer_output when using
+# @pytest.fixture(params=["farm", "transformers"])
+@pytest.fixture(params=["farm"])
+def no_answer_reader(request):
+    if request.param == "farm":
+        return FARMReader(model_name_or_path="deepset/roberta-base-squad2",
+                          use_gpu=False, top_k_per_sample=5, no_ans_boost=0, num_processes=0)
+    if request.param == "transformers":
+        return TransformersReader(model="deepset/roberta-base-squad2",
+                                  tokenizer="deepset/roberta-base-squad2",
+                                  use_gpu=-1, n_best_per_passage=5)
+
+
+@pytest.fixture()
+def prediction(reader, test_docs_xs):
+    docs = []
+    for d in test_docs_xs:
+        doc = Document(id=d["meta"]["name"], text=d["text"], meta=d["meta"])
+        docs.append(doc)
+    prediction = reader.predict(question="Who lives in Berlin?", documents=docs, top_k=5)
+    return prediction
+
+
+@pytest.fixture()
+def no_answer_prediction(no_answer_reader, test_docs_xs):
+    docs = []
+    for d in test_docs_xs:
+        doc = Document(id=d["meta"]["name"], text=d["text"], meta=d["meta"])
+        docs.append(doc)
+    prediction = no_answer_reader.predict(question="What is the meaning of life?", documents=docs, top_k=5)
+    return prediction
+
+
 @pytest.fixture(params=["sql", "memory", "elasticsearch"])
 def document_store_with_docs(request, test_docs_xs, elasticsearch_fixture):
     if request.param == "sql":
@@ -28,17 +28,17 @@ def test_finder_offsets(reader, document_store_with_docs):
                                     top_k_reader=5)

     assert prediction["answers"][0]["offset_start"] == 11
-    #TODO enable again when FARM is upgraded incl. the new offset calc
-    # assert prediction["answers"][0]["offset_end"] == 16
+    assert prediction["answers"][0]["offset_end"] == 16
     start = prediction["answers"][0]["offset_start"]
     end = prediction["answers"][0]["offset_end"]
-    #assert prediction["answers"][0]["context"][start:end] == prediction["answers"][0]["answer"]
+    assert prediction["answers"][0]["context"][start:end] == prediction["answers"][0]["answer"]


 def test_finder_get_answers_single_result(reader, document_store_with_docs):
     retriever = TfidfRetriever(document_store=document_store_with_docs)
     finder = Finder(reader, retriever)
-    prediction = finder.get_answers(question="testing finder", top_k_retriever=1,
+    query = "testing finder"
+    prediction = finder.get_answers(question=query, top_k_retriever=1,
                                     top_k_reader=1)
     assert prediction is not None
     assert len(prediction["answers"]) == 1
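The re-enabled assertions above rely on the invariant that the answer really sits at the reported offsets inside the returned context. A tiny standalone illustration using the same values that appear in these tests:

context = "My name is Carla and I live in Berlin"
answer = "Carla"
offset_start, offset_end = 11, 16

# The slice of the context at the reported offsets must reproduce the answer string.
assert context[offset_start:offset_end] == answer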
@@ -1,7 +1,10 @@
-import pytest
+import math

-from haystack.reader.base import BaseReader
 from haystack.database.base import Document
+from haystack.reader.base import BaseReader
+from haystack.reader.farm import FARMReader
+from haystack.reader.transformers import TransformersReader
+


 def test_reader_basic(reader):
@@ -9,20 +12,89 @@ def test_reader_basic(reader):
     assert isinstance(reader, BaseReader)


-def test_output(reader, test_docs_xs):
+def test_output(prediction):
+    assert prediction is not None
+    assert prediction["question"] == "Who lives in Berlin?"
+    assert prediction["answers"][0]["answer"] == "Carla"
+    assert prediction["answers"][0]["offset_start"] == 11
+    assert prediction["answers"][0]["offset_end"] == 16
+    assert prediction["answers"][0]["probability"] <= 1
+    assert prediction["answers"][0]["probability"] >= 0
+    assert prediction["answers"][0]["context"] == "My name is Carla and I live in Berlin"
+    assert prediction["answers"][0]["document_id"] == "filename1"
+    assert len(prediction["answers"]) == 5
+
+
+def test_no_answer_output(no_answer_prediction):
+    assert no_answer_prediction is not None
+    assert no_answer_prediction["question"] == "What is the meaning of life?"
+    assert math.isclose(no_answer_prediction["no_ans_gap"], -14.4729533, rel_tol=0.0001)
+    assert no_answer_prediction["answers"][0]["answer"] is None
+    assert no_answer_prediction["answers"][0]["offset_start"] == 0
+    assert no_answer_prediction["answers"][0]["offset_end"] == 0
+    assert no_answer_prediction["answers"][0]["probability"] <= 1
+    assert no_answer_prediction["answers"][0]["probability"] >= 0
+    assert no_answer_prediction["answers"][0]["context"] == None
+    assert no_answer_prediction["answers"][0]["document_id"] == None
+    answers = [x["answer"] for x in no_answer_prediction["answers"]]
+    assert answers.count(None) == 1
+    assert len(no_answer_prediction["answers"]) == 5
+
+# TODO Directly compare farm and transformers reader outputs
+# TODO checks to see that model is responsive to input arguments e.g. context_window_size - topk
+
+def test_prediction_attributes(prediction):
+    # TODO FARM's prediction also has no_ans_gap
+    attributes_gold = ["question", "answers"]
+    for ag in attributes_gold:
+        assert ag in prediction
+
+
+def test_answer_attributes(prediction):
+    # TODO Transformers answer also has meta key
+    # TODO FARM answer has offset_start_in_doc, offset_end_in_doc
+    answer = prediction["answers"][0]
+    attributes_gold = ['answer', 'score', 'probability', 'context', 'offset_start', 'offset_end', 'document_id']
+    for ag in attributes_gold:
+        assert ag in answer
+
+
+def test_context_window_size(test_docs_xs):
+    # TODO parametrize window_size and farm/transformers reader using pytest
     docs = []
     for d in test_docs_xs:
         doc = Document(id=d["meta"]["name"], text=d["text"], meta=d["meta"])
         docs.append(doc)
-    results = reader.predict(question="Who lives in Berlin?", documents=docs, top_k=5)
-    assert results is not None
-    assert results["question"] == "Who lives in Berlin?"
-    assert results["answers"][0]["answer"] == "Carla"
-    assert results["answers"][0]["offset_start"] == 11
-    #TODO enable again when FARM is upgraded incl. the new offset calc
-    # assert results["answers"][0]["offset_end"] == 16
-    assert results["answers"][0]["probability"] <= 1
-    assert results["answers"][0]["probability"] >= 0
-    assert results["answers"][0]["context"] == "My name is Carla and I live in Berlin"
-    assert results["answers"][0]["document_id"] == "filename1"
-    assert len(results["answers"]) == 5
+    for window_size in [10, 15, 20]:
+        farm_reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad",
+                                 use_gpu=False, top_k_per_sample=5, no_ans_boost=None, context_window_size=window_size)
+        prediction = farm_reader.predict(question="Who lives in Berlin?", documents=docs, top_k=5)
+        for answer in prediction["answers"]:
+            # If the extracted answer is larger than the context window, the context window is expanded.
+            # If the extracted answer is odd in length, the resulting context window is one less than context_window_size
+            # due to rounding (See FARM's QACandidate)
+            # TODO Currently the behaviour of context_window_size in FARMReader and TransformerReader is different
+            if len(answer["answer"]) <= window_size:
+                assert len(answer["context"]) in [window_size, window_size-1]
+            else:
+                assert len(answer["answer"]) == len(answer["context"])
+
+# TODO Need to test transformers reader
+# TODO Currently the behaviour of context_window_size in FARMReader and TransformerReader is different
+
+
+def test_top_k(test_docs_xs):
+    # TODO parametrize top_k and farm/transformers reader using pytest
+    # TODO transformers reader was crashing when tested on this
+    docs = []
+    for d in test_docs_xs:
+        doc = Document(id=d["meta"]["name"], text=d["text"], meta=d["meta"])
+        docs.append(doc)
+    farm_reader = FARMReader(model_name_or_path="distilbert-base-uncased-distilled-squad",
+                             use_gpu=False, top_k_per_sample=4, no_ans_boost=None, top_k_per_candidate=4)
+    for top_k in [2, 5, 10]:
+        prediction = farm_reader.predict(question="Who lives in Berlin?", documents=docs, top_k=top_k)
+        assert len(prediction["answers"]) == top_k
+
+
+
@@ -76,8 +76,8 @@ if eval_reader_only:
     # Evaluation of Reader can also be done directly on a SQuAD-formatted file without passing the data to Elasticsearch
     #reader_eval_results = reader.eval_on_file("../data/natural_questions", "dev_subset.json", device=device)

-    ## Reader Top-N-Recall is the proportion of predicted answers that overlap with their corresponding correct answer
-    print("Reader Top-N-Recall:", reader_eval_results["top_n_recall"])
+    ## Reader Top-N-Accuracy is the proportion of predicted answers that match with their corresponding correct answer
+    print("Reader Top-N-Accuracy:", reader_eval_results["top_n_accuracy"])
     ## Reader Exact Match is the proportion of questions where the predicted answer is exactly the same as the correct answer
     print("Reader Exact Match:", reader_eval_results["EM"])
     ## Reader F1-Score is the average overlap between the predicted answers and the correct answers