mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-06 20:17:14 +00:00
fix: Dynamic max_answers for SquadProcessor (fixes IndexError when max_answers is less than the number of answers in the dataset) (#4817)
* #4320 implemented dynamic max_answers for SquadProcessor, fixed IndexError when max_answers is less than the number of answers in the dataset
* #4320 added two unit tests for dataset_from_dicts testing default and manual max_answers
* apply suggestions from code review

Co-authored-by: bogdankostic <bogdankostic@web.de>

* simplify comment, fix mypy & pylint errors, fix old test
* adjust max_answers to each dataset individually

---------

Co-authored-by: bogdankostic <bogdankostic@web.de>
This commit is contained in:
parent
8fbfca9ebb
commit
099d0deb86
@ -35,7 +35,6 @@ from haystack.modeling.data_handler.samples import (
|
||||
from haystack.modeling.data_handler.input_features import sample_to_features_text
|
||||
from haystack.utils.experiment_tracking import Tracker as tracker
|
||||
|
||||
|
||||
DOWNSTREAM_TASK_MAP = {
|
||||
"squad20": "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-downstream/squad20.tar.gz",
|
||||
"covidqa": "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-downstream/covidqa.tar.gz",
|
||||
@ -381,7 +380,7 @@ class SquadProcessor(Processor):
|
||||
doc_stride: int = 128,
|
||||
max_query_length: int = 64,
|
||||
proxies: Optional[dict] = None,
|
||||
max_answers: int = 6,
|
||||
max_answers: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@ -403,7 +402,9 @@ class SquadProcessor(Processor):
|
||||
:param max_query_length: Maximum length of the question (in number of subword tokens)
|
||||
:param proxies: proxy configuration to allow downloads of remote datasets.
|
||||
Format as in "requests" library: https://2.python-requests.org//en/latest/user/advanced/#proxies
|
||||
:param max_answers: number of answers to be converted. QA dev or train sets can contain multi-way annotations, which are converted to arrays of max_answer length
|
||||
:param max_answers: Number of answers to be converted. QA sets can contain multi-way annotations, which are converted to arrays of max_answer length.
|
||||
Adjusts to the maximum number of answers found in each processed dataset if not set.
|
||||
Truncates or pads to max_answer length if set.
|
||||
:param kwargs: placeholder for passing generic parameters
|
||||
"""
|
||||
self.ph_output_type = "per_token_squad"
|
||||
@ -469,12 +470,19 @@ class SquadProcessor(Processor):
|
||||
# Split documents into smaller passages to fit max_seq_len
|
||||
baskets = self._split_docs_into_passages(baskets)
|
||||
|
||||
# Determine max_answers if not set
|
||||
max_answers = (
|
||||
self.max_answers
|
||||
if self.max_answers is not None
|
||||
else max(max(len(basket.raw["answers"]) for basket in baskets), 1)
|
||||
)
|
||||
|
||||
# Convert answers from string to token space, skip this step for inference
|
||||
if not return_baskets:
|
||||
baskets = self._convert_answers(baskets)
|
||||
baskets = self._convert_answers(baskets, max_answers)
|
||||
|
||||
# Convert internal representation (nested baskets + samples with mixed types) to pytorch features (arrays of numbers)
|
||||
baskets = self._passages_to_pytorch_features(baskets, return_baskets)
|
||||
baskets = self._passages_to_pytorch_features(baskets, return_baskets, max_answers)
|
||||
|
||||
# Convert features into pytorch dataset, this step also removes potential errors during preprocessing
|
||||
dataset, tensor_names, baskets = self._create_dataset(baskets)
|
||||
@ -607,7 +615,7 @@ class SquadProcessor(Processor):
|
||||
|
||||
return baskets
|
||||
|
||||
def _convert_answers(self, baskets: List[SampleBasket]):
|
||||
def _convert_answers(self, baskets: List[SampleBasket], max_answers: int):
|
||||
"""
|
||||
Converts answers that are pure strings into the token based representation with start and end token offset.
|
||||
Can handle multiple answers per question document pair as is common for development/text sets
|
||||
@ -617,7 +625,7 @@ class SquadProcessor(Processor):
|
||||
for sample in basket.samples: # type: ignore
|
||||
# Dealing with potentially multiple answers (e.g. Squad dev set)
|
||||
# Initializing a numpy array of shape (max_answers, 2), filled with -1 for missing values
|
||||
label_idxs = np.full((self.max_answers, 2), fill_value=-1)
|
||||
label_idxs = np.full((max_answers, 2), fill_value=-1)
|
||||
|
||||
if error_in_answer or (len(basket.raw["answers"]) == 0):
|
||||
# If there are no answers we set
|
||||
@ -625,6 +633,14 @@ class SquadProcessor(Processor):
|
||||
else:
|
||||
# For all other cases we use start and end token indices, that are relative to the passage
|
||||
for i, answer in enumerate(basket.raw["answers"]):
|
||||
if i >= max_answers:
|
||||
logger.warning(
|
||||
"Found a sample with more answers (%d) than "
|
||||
"max_answers (%d). These will be ignored.",
|
||||
len(basket.raw["answers"]),
|
||||
max_answers,
|
||||
)
|
||||
break
|
||||
# Calculate start and end relative to document
|
||||
answer_len_c = len(answer["text"])
|
||||
answer_start_c = answer["answer_start"]
|
||||
@ -691,7 +707,7 @@ class SquadProcessor(Processor):
|
||||
|
||||
return baskets
|
||||
|
||||
def _passages_to_pytorch_features(self, baskets: List[SampleBasket], return_baskets: bool):
|
||||
def _passages_to_pytorch_features(self, baskets: List[SampleBasket], return_baskets: bool, max_answers: int):
|
||||
"""
|
||||
Convert internal representation (nested baskets + samples with mixed types) to python features (arrays of numbers).
|
||||
We first join question and passages into one large vector.
|
||||
@ -769,7 +785,7 @@ class SquadProcessor(Processor):
|
||||
len(input_ids) == len(padding_mask) == len(segment_ids) == len(start_of_word) == len(span_mask)
|
||||
)
|
||||
id_check = len(sample_id) == 3
|
||||
label_check = return_baskets or len(sample.tokenized.get("labels", [])) == self.max_answers # type: ignore
|
||||
label_check = return_baskets or len(sample.tokenized.get("labels", [])) == max_answers # type: ignore
|
||||
# labels are set to -100 when answer cannot be found
|
||||
label_check2 = return_baskets or np.all(sample.tokenized["labels"] > -99) # type: ignore
|
||||
if len_check and id_check and label_check and label_check2:
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import copy
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from haystack.modeling.data_handler.processor import SquadProcessor
|
||||
@ -233,7 +235,7 @@ def test_batch_encoding_flatten_rename():
|
||||
pass
|
||||
|
||||
|
||||
def test_dataset_from_dicts_qa_labelconversion(samples_path, caplog=None):
|
||||
def test_dataset_from_dicts_qa_label_conversion(samples_path, caplog=None):
|
||||
if caplog:
|
||||
caplog.set_level(logging.CRITICAL)
|
||||
|
||||
@ -248,7 +250,7 @@ def test_dataset_from_dicts_qa_labelconversion(samples_path, caplog=None):
|
||||
|
||||
for model in models:
|
||||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model)
|
||||
processor = SquadProcessor(tokenizer, max_seq_len=256, data_dir=None)
|
||||
processor = SquadProcessor(tokenizer, max_seq_len=256, data_dir=None, max_answers=6)
|
||||
|
||||
for sample_type in sample_types:
|
||||
dicts = processor.file_to_dicts(samples_path / "qa" / f"{sample_type}.json")
|
||||
@ -296,3 +298,36 @@ def test_dataset_from_dicts_qa_labelconversion(samples_path, caplog=None):
|
||||
12,
|
||||
12,
|
||||
], f"Processing labels for {model} has changed."
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_dataset_from_dicts_auto_determine_max_answers(samples_path, caplog=None):
    """
    When max_answers is not passed, SquadProcessor should size the label tensor of the pytorch
    dataset to the largest number of answers given for any question in the processed data.
    vanilla.json has one question with two answers, so two label slots are expected.
    """
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="deepset/roberta-base-squad2")
    processor = SquadProcessor(tokenizer, max_seq_len=256, data_dir=None)
    vanilla_dicts = processor.file_to_dicts(samples_path / "qa" / "vanilla.json")

    dataset, tensor_names, _ = processor.dataset_from_dicts(vanilla_dicts, indices=[1])
    labels = dataset[0][tensor_names.index("labels")]
    assert len(labels) == 2

    # Reusing the same SquadProcessor on a dataset with more answers per question
    # must yield a correspondingly larger label tensor (2 answers * 3 = 6).
    tripled_dicts = copy.deepcopy(vanilla_dicts)
    first_qa = tripled_dicts[0]["qas"][0]
    first_qa["answers"] = first_qa["answers"] * 3

    dataset, tensor_names, _ = processor.dataset_from_dicts(tripled_dicts, indices=[1])
    labels = dataset[0][tensor_names.index("labels")]
    assert len(labels) == 6
|
||||
|
||||
|
||||
@pytest.mark.integration
def test_dataset_from_dicts_truncate_max_answers(samples_path, caplog=None):
    """
    An explicitly configured max_answers should win over the data: with max_answers=1 the label
    tensor holds a single answer slot even though the sample provides more answers.
    """
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="deepset/roberta-base-squad2")
    processor = SquadProcessor(tokenizer, max_seq_len=256, data_dir=None, max_answers=1)
    vanilla_dicts = processor.file_to_dicts(samples_path / "qa" / "vanilla.json")

    dataset, tensor_names, _ = processor.dataset_from_dicts(vanilla_dicts, indices=[1])
    labels = dataset[0][tensor_names.index("labels")]
    assert len(labels) == 1
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user