mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-27 02:40:41 +00:00
48 lines
1.5 KiB
Python
48 lines
1.5 KiB
Python
![]() |
import pytest
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from haystack.modeling.data_handler.dataloader import NamedDataLoader
|
||
|
|
||
|
|
||
|
@pytest.fixture
|
||
|
def named_dataloader():
|
||
|
tensor_names = ["input_ids", "labels"]
|
||
|
return NamedDataLoader(None, 1, tensor_names=tensor_names)
|
||
|
|
||
|
|
||
|
@pytest.fixture
|
||
|
def batch():
|
||
|
# batch containing tensors of different lengths
|
||
|
return [
|
||
|
(torch.tensor([1, 2, 3]), torch.tensor([[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]])),
|
||
|
(torch.tensor([4, 5, 6]), torch.tensor([[0, 0], [-1, -1], [-1, -1]])),
|
||
|
(torch.tensor([7, 8, 9]), torch.tensor([[0, 0], [-1, -1]])),
|
||
|
]
|
||
|
|
||
|
|
||
|
@pytest.mark.unit
|
||
|
def test_compute_max_number_of_labels(named_dataloader, batch):
|
||
|
tensor_names = ["input_ids", "labels"]
|
||
|
max_num_labels = named_dataloader._compute_max_number_of_labels(batch, tensor_names)
|
||
|
assert max_num_labels == 6
|
||
|
|
||
|
|
||
|
@pytest.mark.unit
|
||
|
def test_collate_fn(named_dataloader, batch):
|
||
|
collated_batch = named_dataloader.collate_fn(batch)
|
||
|
|
||
|
expected_collated_batch = {
|
||
|
"input_ids": torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
|
||
|
"labels": torch.tensor(
|
||
|
[
|
||
|
[[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]],
|
||
|
[[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]],
|
||
|
[[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]],
|
||
|
]
|
||
|
),
|
||
|
}
|
||
|
|
||
|
for key in collated_batch:
|
||
|
assert torch.equal(collated_batch[key], expected_collated_batch[key])
|