Mirror of https://github.com/deepset-ai/haystack.git (synced 2025-10-31 01:39:45 +00:00)

Commit 25d5dedb46:
* first draft
* improve it a bit
* unit tests
* PR review, improved tests
* PR review, improved tests 2

48 lines | 1.5 KiB | Python

import pytest

import torch

from haystack.modeling.data_handler.dataloader import NamedDataLoader


@pytest.fixture
def named_dataloader():
    # no dataset is needed: the tests below call the collate helpers directly
    tensor_names = ["input_ids", "labels"]
    return NamedDataLoader(None, 1, tensor_names=tensor_names)


@pytest.fixture
def batch():
    # batch containing label tensors of different lengths (6, 3 and 2 rows)
    return [
        (torch.tensor([1, 2, 3]), torch.tensor([[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]])),
        (torch.tensor([4, 5, 6]), torch.tensor([[0, 0], [-1, -1], [-1, -1]])),
        (torch.tensor([7, 8, 9]), torch.tensor([[0, 0], [-1, -1]])),
    ]


@pytest.mark.unit
def test_compute_max_number_of_labels(named_dataloader, batch):
    tensor_names = ["input_ids", "labels"]
    max_num_labels = named_dataloader._compute_max_number_of_labels(batch, tensor_names)
    # the longest label tensor in the batch has 6 rows
    assert max_num_labels == 6


@pytest.mark.unit
def test_collate_fn(named_dataloader, batch):
    collated_batch = named_dataloader.collate_fn(batch)

    # input_ids are stacked as-is; each label tensor is padded to the maximum
    # number of label rows in the batch (6) before stacking
    expected_collated_batch = {
        "input_ids": torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
        "labels": torch.tensor(
            [
                [[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]],
                [[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]],
                [[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]],
            ]
        ),
    }

    for key in collated_batch:
        assert torch.equal(collated_batch[key], expected_collated_batch[key])
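As a minimal usage sketch (hypothetical driver code, not part of the repository), the behaviour asserted in test_collate_fn can be reproduced outside pytest by building the same loader and batch and calling collate_fn directly; per the assertions above, the collated input_ids have shape (3, 3) and the collated labels have shape (3, 6, 2).

import torch

from haystack.modeling.data_handler.dataloader import NamedDataLoader

# Hypothetical driver script: it only uses the API already exercised in the tests above.
loader = NamedDataLoader(None, 1, tensor_names=["input_ids", "labels"])

batch = [
    (torch.tensor([1, 2, 3]), torch.tensor([[0, 0], [-1, -1], [-1, -1], [-1, -1], [-1, -1], [-1, -1]])),
    (torch.tensor([4, 5, 6]), torch.tensor([[0, 0], [-1, -1], [-1, -1]])),
    (torch.tensor([7, 8, 9]), torch.tensor([[0, 0], [-1, -1]])),
]

collated = loader.collate_fn(batch)

# Per the assertions in test_collate_fn, labels are padded to the longest
# label tensor in the batch before being stacked.
print(collated["input_ids"].shape)  # expected: torch.Size([3, 3])
print(collated["labels"].shape)     # expected: torch.Size([3, 6, 2])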