mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-23 05:49:26 +00:00 
			
		
		
		
	add chapter 6 unit test
This commit is contained in:
		
							parent
							
								
									6b5bc7a1cd
								
							
						
					
					
						commit
						37c33d6fee
					
				
							
								
								
									
										3
									
								
								.github/workflows/basic-tests-linux.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/basic-tests-linux.yml
									
									
									
									
										vendored
									
									
								
							| @ -38,9 +38,10 @@ jobs: | |||||||
| 
 | 
 | ||||||
|     - name: Test Selected Python Scripts |     - name: Test Selected Python Scripts | ||||||
|       run: | |       run: | | ||||||
|  |         pytest setup/02_installing-python-libraries/tests.py | ||||||
|         pytest ch04/01_main-chapter-code/tests.py |         pytest ch04/01_main-chapter-code/tests.py | ||||||
|         pytest ch05/01_main-chapter-code/tests.py |         pytest ch05/01_main-chapter-code/tests.py | ||||||
|         pytest setup/02_installing-python-libraries/tests.py |         pytest ch06/01_main-chapter-code/gpt-class-finetune.py --test_mode | ||||||
| 
 | 
 | ||||||
|     - name: Validate Selected Jupyter Notebooks |     - name: Validate Selected Jupyter Notebooks | ||||||
|       run: | |       run: | | ||||||
|  | |||||||
							
								
								
									
										3
									
								
								.github/workflows/basic-tests-macos.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/basic-tests-macos.yml
									
									
									
									
										vendored
									
									
								
							| @ -38,9 +38,10 @@ jobs: | |||||||
| 
 | 
 | ||||||
|     - name: Test Selected Python Scripts |     - name: Test Selected Python Scripts | ||||||
|       run: | |       run: | | ||||||
|  |         pytest setup/02_installing-python-libraries/tests.py | ||||||
|         pytest ch04/01_main-chapter-code/tests.py |         pytest ch04/01_main-chapter-code/tests.py | ||||||
|         pytest ch05/01_main-chapter-code/tests.py |         pytest ch05/01_main-chapter-code/tests.py | ||||||
|         pytest setup/02_installing-python-libraries/tests.py |         pytest ch06/01_main-chapter-code/gpt-class-finetune.py --test_mode | ||||||
| 
 | 
 | ||||||
|     - name: Validate Selected Jupyter Notebooks |     - name: Validate Selected Jupyter Notebooks | ||||||
|       run: | |       run: | | ||||||
|  | |||||||
							
								
								
									
										3
									
								
								.github/workflows/basic-tests-windows.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/workflows/basic-tests-windows.yml
									
									
									
									
										vendored
									
									
								
							| @ -41,9 +41,10 @@ jobs: | |||||||
|       - name: Test Selected Python Scripts |       - name: Test Selected Python Scripts | ||||||
|         shell: bash |         shell: bash | ||||||
|         run: | |         run: | | ||||||
|  |           pytest setup/02_installing-python-libraries/tests.py | ||||||
|           pytest ch04/01_main-chapter-code/tests.py |           pytest ch04/01_main-chapter-code/tests.py | ||||||
|           pytest ch05/01_main-chapter-code/tests.py |           pytest ch05/01_main-chapter-code/tests.py | ||||||
|           pytest setup/02_installing-python-libraries/tests.py |           pytest ch06/01_main-chapter-code/gpt-class-finetune.py --test_mode | ||||||
| 
 | 
 | ||||||
|       - name: Validate Selected Jupyter Notebooks |       - name: Validate Selected Jupyter Notebooks | ||||||
|         shell: bash |         shell: bash | ||||||
|  | |||||||
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -6,6 +6,8 @@ appendix-D/01_main-chapter-code/3.pdf | |||||||
| ch05/01_main-chapter-code/loss-plot.pdf | ch05/01_main-chapter-code/loss-plot.pdf | ||||||
| ch05/01_main-chapter-code/temperature-plot.pdf | ch05/01_main-chapter-code/temperature-plot.pdf | ||||||
| ch05/01_main-chapter-code/the-verdict.txt | ch05/01_main-chapter-code/the-verdict.txt | ||||||
|  | ch06/01_main-chapter-code/loss-plot.pdf | ||||||
|  | ch06/01_main-chapter-code/accuracy-plot.pdf | ||||||
| 
 | 
 | ||||||
| # Checkpoint files | # Checkpoint files | ||||||
| ch05/01_main-chapter-code/gpt2/ | ch05/01_main-chapter-code/gpt2/ | ||||||
|  | |||||||
| @ -226,11 +226,24 @@ def plot_values(epochs_seen, examples_seen, train_values, val_values, label="los | |||||||
| 
 | 
 | ||||||
|     fig.tight_layout()  # Adjust layout to make room |     fig.tight_layout()  # Adjust layout to make room | ||||||
|     plt.savefig(f"{label}-plot.pdf") |     plt.savefig(f"{label}-plot.pdf") | ||||||
|     plt.show() |     #plt.show() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
| 
 | 
 | ||||||
|  |     import argparse | ||||||
|  | 
 | ||||||
|  |     parser = argparse.ArgumentParser( | ||||||
|  |         description="Finetune a GPT model for classification" | ||||||
|  |     ) | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--test_mode", | ||||||
|  |         action="store_true", | ||||||
|  |         help=("This flag runs the model in test mode for internal testing purposes. " | ||||||
|  |               "Otherwise, it runs the model as it is used in the chapter (recommended).") | ||||||
|  |     ) | ||||||
|  |     args = parser.parse_args() | ||||||
|  | 
 | ||||||
|     ######################################## |     ######################################## | ||||||
|     # Download and prepare dataset |     # Download and prepare dataset | ||||||
|     ######################################## |     ######################################## | ||||||
| @ -304,6 +317,25 @@ if __name__ == "__main__": | |||||||
|     # Load pretrained model |     # Load pretrained model | ||||||
|     ######################################## |     ######################################## | ||||||
| 
 | 
 | ||||||
|  |     # Small GPT model for testing purposes | ||||||
|  |     if args.test_mode: | ||||||
|  |         BASE_CONFIG = { | ||||||
|  |             "vocab_size": 50257, | ||||||
|  |             "context_length": 120, | ||||||
|  |             "drop_rate": 0.0, | ||||||
|  |             "qkv_bias": False, | ||||||
|  |             "emb_dim": 12, | ||||||
|  |             "n_layers": 1, | ||||||
|  |             "n_heads": 2 | ||||||
|  |         } | ||||||
|  |         model = GPTModel(BASE_CONFIG) | ||||||
|  |         model.eval() | ||||||
|  | 
 | ||||||
|  |         device = "cpu" | ||||||
|  |         model.to(device) | ||||||
|  | 
 | ||||||
|  |     # Code as it is used in the main chapter | ||||||
|  |     else: | ||||||
|         CHOOSE_MODEL = "gpt2-small (124M)" |         CHOOSE_MODEL = "gpt2-small (124M)" | ||||||
|         INPUT_PROMPT = "Every effort moves" |         INPUT_PROMPT = "Every effort moves" | ||||||
| 
 | 
 | ||||||
| @ -375,7 +407,12 @@ if __name__ == "__main__": | |||||||
|     # Plot results |     # Plot results | ||||||
|     ######################################## |     ######################################## | ||||||
| 
 | 
 | ||||||
|  |     # loss plot | ||||||
|     epochs_tensor = torch.linspace(0, num_epochs, len(train_losses)) |     epochs_tensor = torch.linspace(0, num_epochs, len(train_losses)) | ||||||
|     examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses)) |     examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses)) | ||||||
| 
 |  | ||||||
|     plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses) |     plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses) | ||||||
|  | 
 | ||||||
|  |     # accuracy plot | ||||||
|  |     epochs_tensor = torch.linspace(0, num_epochs, len(train_accs)) | ||||||
|  |     examples_seen_tensor = torch.linspace(0, examples_seen, len(train_accs)) | ||||||
|  |     plot_values(epochs_tensor, examples_seen_tensor, train_accs, val_accs, label="accuracy") | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 rasbt
						rasbt