haystack/test/modeling/test_dataloader.py

48 lines
1.5 KiB
Python
Raw Normal View History

import pytest
import torch
from haystack.modeling.data_handler.dataloader import NamedDataLoader
@pytest.fixture
def named_dataloader():
tensor_names = ["input_ids", "labels"]
return NamedDataLoader(None, 1, tensor_names=tensor_names)
@pytest.fixture
def batch():
# batch containing tensors of different lengths
return [
(torch.tensor([1, 2, 3]), torch.tensor([[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]])),
(torch.tensor([4, 5, 6]), torch.tensor([[0, 0], [-1, -1], [-1, -1]])),
(torch.tensor([7, 8, 9]), torch.tensor([[0, 0], [-1, -1]])),
]
@pytest.mark.unit
def test_compute_max_number_of_labels(named_dataloader, batch):
tensor_names = ["input_ids", "labels"]
max_num_labels = named_dataloader._compute_max_number_of_labels(batch, tensor_names)
assert max_num_labels == 6
@pytest.mark.unit
def test_collate_fn(named_dataloader, batch):
collated_batch = named_dataloader.collate_fn(batch)
expected_collated_batch = {
"input_ids": torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
"labels": torch.tensor(
[
[[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]],
[[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]],
[[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]],
]
),
}
for key in collated_batch:
assert torch.equal(collated_batch[key], expected_collated_batch[key])