From 099d0deb86b1e004d1a4d031bb112cff5b677683 Mon Sep 17 00:00:00 2001 From: Ben Heckmann <79015931+benheckmann@users.noreply.github.com> Date: Mon, 15 May 2023 14:34:23 +0200 Subject: [PATCH] fix: Dynamic `max_answers` for SquadProcessor (fixes IndexError when max_answers is less than the number of answers in the dataset) (#4817) * #4320 implemented dynamic max_answers for SquadProcessor, fixed IndexError when max_answers is less than the number of answers in the dataset * #4320 added two unit tests for dataset_from_dicts testing default and manual max_answers * apply suggestions from code review Co-authored-by: bogdankostic * simplify comment, fix mypy & pylint errors, fix old test * adjust max_answers to each dataset individually --------- Co-authored-by: bogdankostic --- haystack/modeling/data_handler/processor.py | 34 +++++++++++++----- test/modeling/test_processor.py | 39 +++++++++++++++++++-- 2 files changed, 62 insertions(+), 11 deletions(-) diff --git a/haystack/modeling/data_handler/processor.py b/haystack/modeling/data_handler/processor.py index 9055180b7..08e6a7af9 100644 --- a/haystack/modeling/data_handler/processor.py +++ b/haystack/modeling/data_handler/processor.py @@ -35,7 +35,6 @@ from haystack.modeling.data_handler.samples import ( from haystack.modeling.data_handler.input_features import sample_to_features_text from haystack.utils.experiment_tracking import Tracker as tracker - DOWNSTREAM_TASK_MAP = { "squad20": "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-downstream/squad20.tar.gz", "covidqa": "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-downstream/covidqa.tar.gz", @@ -381,7 +380,7 @@ class SquadProcessor(Processor): doc_stride: int = 128, max_query_length: int = 64, proxies: Optional[dict] = None, - max_answers: int = 6, + max_answers: Optional[int] = None, **kwargs, ): """ @@ -403,7 +402,9 @@ class SquadProcessor(Processor): :param max_query_length: Maximum length of the question (in number of subword tokens) :param proxies: proxy configuration to allow downloads of remote datasets. Format as in "requests" library: https://2.python-requests.org//en/latest/user/advanced/#proxies - :param max_answers: number of answers to be converted. QA dev or train sets can contain multi-way annotations, which are converted to arrays of max_answer length + :param max_answers: Number of answers to be converted. QA sets can contain multi-way annotations, which are converted to arrays of max_answer length. + Adjusts to maximum number of answers in the first processed datasets if not set. + Truncates or pads to max_answer length if set. :param kwargs: placeholder for passing generic parameters """ self.ph_output_type = "per_token_squad" @@ -469,12 +470,19 @@ class SquadProcessor(Processor): # Split documents into smaller passages to fit max_seq_len baskets = self._split_docs_into_passages(baskets) + # Determine max_answers if not set + max_answers = ( + self.max_answers + if self.max_answers is not None + else max(max(len(basket.raw["answers"]) for basket in baskets), 1) + ) + # Convert answers from string to token space, skip this step for inference if not return_baskets: - baskets = self._convert_answers(baskets) + baskets = self._convert_answers(baskets, max_answers) # Convert internal representation (nested baskets + samples with mixed types) to pytorch features (arrays of numbers) - baskets = self._passages_to_pytorch_features(baskets, return_baskets) + baskets = self._passages_to_pytorch_features(baskets, return_baskets, max_answers) # Convert features into pytorch dataset, this step also removes potential errors during preprocessing dataset, tensor_names, baskets = self._create_dataset(baskets) @@ -607,7 +615,7 @@ class SquadProcessor(Processor): return baskets - def _convert_answers(self, baskets: List[SampleBasket]): + def _convert_answers(self, baskets: List[SampleBasket], max_answers: int): """ Converts answers that are pure strings into the token based representation with start and end token offset. Can handle multiple answers per question document pair as is common for development/text sets @@ -617,7 +625,7 @@ class SquadProcessor(Processor): for sample in basket.samples: # type: ignore # Dealing with potentially multiple answers (e.g. Squad dev set) # Initializing a numpy array of shape (max_answers, 2), filled with -1 for missing values - label_idxs = np.full((self.max_answers, 2), fill_value=-1) + label_idxs = np.full((max_answers, 2), fill_value=-1) if error_in_answer or (len(basket.raw["answers"]) == 0): # If there are no answers we set @@ -625,6 +633,14 @@ class SquadProcessor(Processor): else: # For all other cases we use start and end token indices, that are relative to the passage for i, answer in enumerate(basket.raw["answers"]): + if i >= max_answers: + logger.warning( + "Found a sample with more answers (%d) than " + "max_answers (%d). These will be ignored.", + len(basket.raw["answers"]), + max_answers, + ) + break # Calculate start and end relative to document answer_len_c = len(answer["text"]) answer_start_c = answer["answer_start"] @@ -691,7 +707,7 @@ class SquadProcessor(Processor): return baskets - def _passages_to_pytorch_features(self, baskets: List[SampleBasket], return_baskets: bool): + def _passages_to_pytorch_features(self, baskets: List[SampleBasket], return_baskets: bool, max_answers: int): """ Convert internal representation (nested baskets + samples with mixed types) to python features (arrays of numbers). We first join question and passages into one large vector. @@ -769,7 +785,7 @@ class SquadProcessor(Processor): len(input_ids) == len(padding_mask) == len(segment_ids) == len(start_of_word) == len(span_mask) ) id_check = len(sample_id) == 3 - label_check = return_baskets or len(sample.tokenized.get("labels", [])) == self.max_answers # type: ignore + label_check = return_baskets or len(sample.tokenized.get("labels", [])) == max_answers # type: ignore # labels are set to -100 when answer cannot be found label_check2 = return_baskets or np.all(sample.tokenized["labels"] > -99) # type: ignore if len_check and id_check and label_check and label_check2: diff --git a/test/modeling/test_processor.py b/test/modeling/test_processor.py index 2ae817605..2f053fefc 100644 --- a/test/modeling/test_processor.py +++ b/test/modeling/test_processor.py @@ -1,5 +1,7 @@ +import copy import logging +import pytest from transformers import AutoTokenizer from haystack.modeling.data_handler.processor import SquadProcessor @@ -233,7 +235,7 @@ def test_batch_encoding_flatten_rename(): pass -def test_dataset_from_dicts_qa_labelconversion(samples_path, caplog=None): +def test_dataset_from_dicts_qa_label_conversion(samples_path, caplog=None): if caplog: caplog.set_level(logging.CRITICAL) @@ -248,7 +250,7 @@ def test_dataset_from_dicts_qa_labelconversion(samples_path, caplog=None): for model in models: tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model) - processor = SquadProcessor(tokenizer, max_seq_len=256, data_dir=None) + processor = SquadProcessor(tokenizer, max_seq_len=256, data_dir=None, max_answers=6) for sample_type in sample_types: dicts = processor.file_to_dicts(samples_path / "qa" / f"{sample_type}.json") @@ -296,3 +298,36 @@ def test_dataset_from_dicts_qa_labelconversion(samples_path, caplog=None): 12, 12, ], f"Processing labels for {model} has changed." + + +@pytest.mark.integration +def test_dataset_from_dicts_auto_determine_max_answers(samples_path, caplog=None): + """ + SquadProcessor should determine the number of answers for the pytorch dataset based on + the maximum number of answers for each question. Vanilla.json has one question with two answers, + so the number of answers should be two. + """ + model = "deepset/roberta-base-squad2" + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model) + processor = SquadProcessor(tokenizer, max_seq_len=256, data_dir=None) + dicts = processor.file_to_dicts(samples_path / "qa" / "vanilla.json") + dataset, tensor_names, problematic_sample_ids = processor.dataset_from_dicts(dicts, indices=[1]) + assert len(dataset[0][tensor_names.index("labels")]) == 2 + # check that a max_answers will be adjusted when processing a different dataset with the same SquadProcessor + dicts_more_answers = copy.deepcopy(dicts) + dicts_more_answers[0]["qas"][0]["answers"] = dicts_more_answers[0]["qas"][0]["answers"] * 3 + dataset, tensor_names, problematic_sample_ids = processor.dataset_from_dicts(dicts_more_answers, indices=[1]) + assert len(dataset[0][tensor_names.index("labels")]) == 6 + + +@pytest.mark.integration +def test_dataset_from_dicts_truncate_max_answers(samples_path, caplog=None): + """ + Test that it is possible to manually set the number of answers, truncating the answers in the data. + """ + model = "deepset/roberta-base-squad2" + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model) + processor = SquadProcessor(tokenizer, max_seq_len=256, data_dir=None, max_answers=1) + dicts = processor.file_to_dicts(samples_path / "qa" / "vanilla.json") + dataset, tensor_names, problematic_sample_ids = processor.dataset_from_dicts(dicts, indices=[1]) + assert len(dataset[0][tensor_names.index("labels")]) == 1