mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-11-04 11:50:14 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			41 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			41 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
 | 
						|
# Source for "Build a Large Language Model From Scratch"
 | 
						|
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
 | 
						|
# Code: https://github.com/rasbt/LLMs-from-scratch
 | 
						|
 | 
						|
# File for internal use (unit tests)
 | 
						|
 | 
						|
from gpt import main
 | 
						|
 | 
						|
expected = """
 | 
						|
==================================================
 | 
						|
                      IN
 | 
						|
==================================================
 | 
						|
 | 
						|
Input text: Hello, I am
 | 
						|
Encoded input text: [15496, 11, 314, 716]
 | 
						|
encoded_tensor.shape: torch.Size([1, 4])
 | 
						|
 | 
						|
 | 
						|
==================================================
 | 
						|
                      OUT
 | 
						|
==================================================
 | 
						|
 | 
						|
Output: tensor([[15496,    11,   314,   716, 27018, 24086, 47843, 30961, 42348,  7267,
 | 
						|
         49706, 43231, 47062, 34657]])
 | 
						|
Output length: 14
 | 
						|
Output text: Hello, I am Featureiman Byeswickattribute argue logger Normandy Compton analogous
 | 
						|
"""
 | 
						|
 | 
						|
 | 
						|
def test_main(capsys):
 | 
						|
    main()
 | 
						|
    captured = capsys.readouterr()
 | 
						|
 | 
						|
    # Normalize line endings and strip trailing whitespace from each line
 | 
						|
    normalized_expected = '\n'.join(line.rstrip() for line in expected.splitlines())
 | 
						|
    normalized_output = '\n'.join(line.rstrip() for line in captured.out.splitlines())
 | 
						|
 | 
						|
    # Compare normalized strings
 | 
						|
    assert normalized_output == normalized_expected
 |