From 25d5dedb4662b6ae2ed12d8b49e4fb0e29741aa2 Mon Sep 17 00:00:00 2001
From: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Date: Mon, 26 Jun 2023 10:14:21 +0200
Subject: [PATCH] Fix: `FARMReader` - Consider the max number of labels/answers
 during training (#5197)

* first draft

* improve it a bit

* unit tests

* PR review, improved tests

* PR review, improved tests 2
---
 haystack/modeling/data_handler/dataloader.py | 27 +++++++++--
 test/modeling/test_dataloader.py             | 47 ++++++++++++++++++++
 2 files changed, 71 insertions(+), 3 deletions(-)
 create mode 100644 test/modeling/test_dataloader.py

diff --git a/haystack/modeling/data_handler/dataloader.py b/haystack/modeling/data_handler/dataloader.py
index bae318b54..52e95bc6d 100644
--- a/haystack/modeling/data_handler/dataloader.py
+++ b/haystack/modeling/data_handler/dataloader.py
@@ -3,6 +3,7 @@ from typing import Optional, List
 from math import ceil
 
 import torch
+import torch.nn.functional as F
 from torch.utils.data import DataLoader, Dataset, Sampler
 
 from haystack.errors import ModelingError
@@ -40,18 +41,26 @@ class NamedDataLoader(DataLoader):
             else:
                 _tensor_names = tensor_names
 
-            if type(batch[0]) == list:
+            if isinstance(batch[0], list):
                 batch = batch[0]
 
             if len(batch[0]) != len(_tensor_names):
                 raise ModelingError(
                     f"Dataset contains {len(batch[0])} tensors while there are {len(_tensor_names)} tensor names supplied: {_tensor_names}"
                 )
 
-            lists_temp = [[] for _ in range(len(_tensor_names))]
-            ret = dict(zip(_tensor_names, lists_temp))
+            max_num_labels = self._compute_max_number_of_labels(batch=batch, tensor_names=_tensor_names)
+
+            ret = {name: [] for name in _tensor_names}
             for example in batch:
                 for name, tensor in zip(_tensor_names, example):
+                    # each example may have a different number of answers/labels,
+                    # so we need to pad the corresponding tensors to the max number of labels
+                    if name == "labels" and tensor.ndim > 0:
+                        num_labels = tensor.size(0)
+                        if num_labels < max_num_labels:
+                            padding = (0, 0, 0, max_num_labels - num_labels)
+                            tensor = F.pad(tensor, padding, value=-1)
                     ret[name].append(tensor)
 
             for key in ret:
@@ -75,3 +84,15 @@ class NamedDataLoader(DataLoader):
             return num_batches
         else:
             return super().__len__()
+
+    def _compute_max_number_of_labels(self, batch, tensor_names) -> int:
+        """
+        Compute the maximum number of labels in a batch.
+        Each example may have a different number of labels, depending on the number of answers.
+        """
+        max_num_labels = 0
+        for example in batch:
+            for name, tensor in zip(tensor_names, example):
+                if name == "labels" and tensor.ndim > 0:
+                    max_num_labels = max(max_num_labels, tensor.size(0))
+        return max_num_labels
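
For context, a minimal sketch of the padding behavior the hunk above relies on
(tensor values invented for illustration). F.pad reads its padding tuple from
the last dimension backwards, so (0, 0, 0, k) leaves the trailing (start, end)
axis untouched and appends k rows of -1 along the leading answers axis:

    import torch
    import torch.nn.functional as F

    labels = torch.tensor([[0, 0], [5, 9]])  # one example with 2 answers
    max_num_labels = 6

    # padding pairs are (left, right) per dimension, last dimension first:
    # (0, 0) keeps the (start, end) axis, (0, 4) appends 4 rows of -1
    padding = (0, 0, 0, max_num_labels - labels.size(0))
    padded = F.pad(labels, padding, value=-1)

    print(padded.shape)  # torch.Size([6, 2])
    print(padded[-1])    # tensor([-1, -1])
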
+ """ + max_num_labels = 0 + for example in batch: + for name, tensor in zip(tensor_names, example): + if name == "labels" and tensor.ndim > 0: + max_num_labels = max(max_num_labels, tensor.size(0)) + return max_num_labels diff --git a/test/modeling/test_dataloader.py b/test/modeling/test_dataloader.py new file mode 100644 index 000000000..fc2b95b3a --- /dev/null +++ b/test/modeling/test_dataloader.py @@ -0,0 +1,47 @@ +import pytest + +import torch + +from haystack.modeling.data_handler.dataloader import NamedDataLoader + + +@pytest.fixture +def named_dataloader(): + tensor_names = ["input_ids", "labels"] + return NamedDataLoader(None, 1, tensor_names=tensor_names) + + +@pytest.fixture +def batch(): + # batch containing tensors of different lengths + return [ + (torch.tensor([1, 2, 3]), torch.tensor([[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]])), + (torch.tensor([4, 5, 6]), torch.tensor([[0, 0], [-1, -1], [-1, -1]])), + (torch.tensor([7, 8, 9]), torch.tensor([[0, 0], [-1, -1]])), + ] + + +@pytest.mark.unit +def test_compute_max_number_of_labels(named_dataloader, batch): + tensor_names = ["input_ids", "labels"] + max_num_labels = named_dataloader._compute_max_number_of_labels(batch, tensor_names) + assert max_num_labels == 6 + + +@pytest.mark.unit +def test_collate_fn(named_dataloader, batch): + collated_batch = named_dataloader.collate_fn(batch) + + expected_collated_batch = { + "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + "labels": torch.tensor( + [ + [[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]], + [[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]], + [[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]], + ] + ), + } + + for key in collated_batch: + assert torch.equal(collated_batch[key], expected_collated_batch[key])