| 
									
										
										
										
											2024-03-19 09:26:26 -05:00
										 |  |  | # Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). | 
					
						
							|  |  |  | # Source for "Build a Large Language Model From Scratch" | 
					
						
							|  |  |  | #   - https://www.manning.com/books/build-a-large-language-model-from-scratch | 
					
						
							|  |  |  | # Code: https://github.com/rasbt/LLMs-from-scratch | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # File for internal use (unit tests) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-18 11:58:37 -05:00
										 |  |  | from gpt import main | 
					
						
							|  |  |  | 
 | 
					
						
# Reference stdout for ``test_main``: the text that ``main()`` is expected to
# print. The test strips trailing whitespace from every line of both this
# literal and the captured output before comparing them, so only the line
# content and leading spaces are significant here.
# NOTE(review): the token IDs and the generated text presumably depend on the
# seed and model configuration used inside gpt.main -- confirm against gpt.py
# before editing any of the values below.
expected = """
==================================================
                      IN
==================================================

Input text: Hello, I am
Encoded input text: [15496, 11, 314, 716]
encoded_tensor.shape: torch.Size([1, 4])


==================================================
                      OUT
==================================================

Output: tensor([[15496,    11,   314,   716, 27018, 24086, 47843, 30961, 42348,  7267,
         49706, 43231, 47062, 34657]])
Output length: 14
Output text: Hello, I am Featureiman Byeswickattribute argue logger Normandy Compton analogous
"""
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def test_main(capsys): | 
					
						
							|  |  |  |     main() | 
					
						
							|  |  |  |     captured = capsys.readouterr() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Normalize line endings and strip trailing whitespace from each line | 
					
						
							|  |  |  |     normalized_expected = '\n'.join(line.rstrip() for line in expected.splitlines()) | 
					
						
							|  |  |  |     normalized_output = '\n'.join(line.rstrip() for line in captured.out.splitlines()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Compare normalized strings | 
					
						
							|  |  |  |     assert normalized_output == normalized_expected |