mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-24 13:38:53 +00:00
Fix: FARMReader - Consider the max number of labels/answers during training (#5197)
* first draft * improve it a bit * unit tests * PR review, improved tests * PR review, improved tests 2
This commit is contained in:
parent
f1932492f1
commit
25d5dedb46
@ -3,6 +3,7 @@ from typing import Optional, List
|
||||
from math import ceil
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, Dataset, Sampler
|
||||
|
||||
from haystack.errors import ModelingError
|
||||
@ -40,18 +41,26 @@ class NamedDataLoader(DataLoader):
|
||||
else:
|
||||
_tensor_names = tensor_names
|
||||
|
||||
if type(batch[0]) == list:
|
||||
if isinstance(batch[0], list):
|
||||
batch = batch[0]
|
||||
|
||||
if len(batch[0]) != len(_tensor_names):
|
||||
raise ModelingError(
|
||||
f"Dataset contains {len(batch[0])} tensors while there are {len(_tensor_names)} tensor names supplied: {_tensor_names}"
|
||||
)
|
||||
lists_temp = [[] for _ in range(len(_tensor_names))]
|
||||
ret = dict(zip(_tensor_names, lists_temp))
|
||||
|
||||
max_num_labels = self._compute_max_number_of_labels(batch=batch, tensor_names=_tensor_names)
|
||||
|
||||
ret = {name: [] for name in tensor_names}
|
||||
for example in batch:
|
||||
for name, tensor in zip(_tensor_names, example):
|
||||
# each example may have a different number of answers/labels,
|
||||
# so we need to pad the corresponding tensors to the max number of labels
|
||||
if name == "labels" and tensor.ndim > 0:
|
||||
num_labels = tensor.size(0)
|
||||
if num_labels < max_num_labels:
|
||||
padding = (0, 0, 0, max_num_labels - num_labels)
|
||||
tensor = F.pad(tensor, padding, value=-1)
|
||||
ret[name].append(tensor)
|
||||
|
||||
for key in ret:
|
||||
@ -75,3 +84,15 @@ class NamedDataLoader(DataLoader):
|
||||
return num_batches
|
||||
else:
|
||||
return super().__len__()
|
||||
|
||||
def _compute_max_number_of_labels(self, batch, tensor_names) -> int:
|
||||
"""
|
||||
Compute the maximum number of labels in a batch.
|
||||
Each example may have a different number of labels, depending on the number of answers.
|
||||
"""
|
||||
max_num_labels = 0
|
||||
for example in batch:
|
||||
for name, tensor in zip(tensor_names, example):
|
||||
if name == "labels" and tensor.ndim > 0:
|
||||
max_num_labels = max(max_num_labels, tensor.size(0))
|
||||
return max_num_labels
|
||||
|
||||
47
test/modeling/test_dataloader.py
Normal file
47
test/modeling/test_dataloader.py
Normal file
@ -0,0 +1,47 @@
|
||||
import pytest
|
||||
|
||||
import torch
|
||||
|
||||
from haystack.modeling.data_handler.dataloader import NamedDataLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def named_dataloader():
|
||||
tensor_names = ["input_ids", "labels"]
|
||||
return NamedDataLoader(None, 1, tensor_names=tensor_names)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def batch():
|
||||
# batch containing tensors of different lengths
|
||||
return [
|
||||
(torch.tensor([1, 2, 3]), torch.tensor([[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]])),
|
||||
(torch.tensor([4, 5, 6]), torch.tensor([[0, 0], [-1, -1], [-1, -1]])),
|
||||
(torch.tensor([7, 8, 9]), torch.tensor([[0, 0], [-1, -1]])),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_compute_max_number_of_labels(named_dataloader, batch):
|
||||
tensor_names = ["input_ids", "labels"]
|
||||
max_num_labels = named_dataloader._compute_max_number_of_labels(batch, tensor_names)
|
||||
assert max_num_labels == 6
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_collate_fn(named_dataloader, batch):
|
||||
collated_batch = named_dataloader.collate_fn(batch)
|
||||
|
||||
expected_collated_batch = {
|
||||
"input_ids": torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
|
||||
"labels": torch.tensor(
|
||||
[
|
||||
[[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]],
|
||||
[[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]],
|
||||
[[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]],
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
for key in collated_batch:
|
||||
assert torch.equal(collated_batch[key], expected_collated_batch[key])
|
||||
Loading…
x
Reference in New Issue
Block a user