mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-30 17:29:59 +00:00 
			
		
		
		
	Note about ch05 mps support (#324)
This commit is contained in:
		
							parent
							
								
									1a962f3983
								
							
						
					
					
						commit
						0991c1ff24
					
				| @ -1142,6 +1142,20 @@ | |||||||
|    ], |    ], | ||||||
|    "source": [ |    "source": [ | ||||||
|     "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", |     "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | ||||||
|  |     "\n", | ||||||
|  |     "# Note:\n", | ||||||
|  |     "# Uncommenting the following lines will allow the code to run on Apple Silicon chips, if applicable,\n", | ||||||
|  |     "# which is approximately 2x faster than on an Apple CPU (as measured on an M3 MacBook Air).\n", | ||||||
|  |     "# However, the resulting loss values may be slightly different.\n", | ||||||
|  |     "\n", | ||||||
|  |     "#if torch.cuda.is_available():\n", | ||||||
|  |     "#    device = torch.device(\"cuda\")\n", | ||||||
|  |     "#elif torch.backends.mps.is_available():\n", | ||||||
|  |     "#    device = torch.device(\"mps\")\n", | ||||||
|  |     "#else:\n", | ||||||
|  |     "#    device = torch.device(\"cpu\")\n", | ||||||
|  |     "\n", | ||||||
|  |     "\n", | ||||||
|     "model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes\n", |     "model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes\n", | ||||||
|     "\n", |     "\n", | ||||||
|     "\n", |     "\n", | ||||||
| @ -1308,6 +1322,11 @@ | |||||||
|     } |     } | ||||||
|    ], |    ], | ||||||
|    "source": [ |    "source": [ | ||||||
|  |     "# Note:\n", | ||||||
|  |     "# Uncomment the following code to calculate the execution time\n", | ||||||
|  |     "# import time\n", | ||||||
|  |     "# start_time = time.time()\n", | ||||||
|  |     "\n", | ||||||
|     "torch.manual_seed(123)\n", |     "torch.manual_seed(123)\n", | ||||||
|     "model = GPTModel(GPT_CONFIG_124M)\n", |     "model = GPTModel(GPT_CONFIG_124M)\n", | ||||||
|     "model.to(device)\n", |     "model.to(device)\n", | ||||||
| @ -1318,7 +1337,13 @@ | |||||||
|     "    model, train_loader, val_loader, optimizer, device,\n", |     "    model, train_loader, val_loader, optimizer, device,\n", | ||||||
|     "    num_epochs=num_epochs, eval_freq=5, eval_iter=5,\n", |     "    num_epochs=num_epochs, eval_freq=5, eval_iter=5,\n", | ||||||
|     "    start_context=\"Every effort moves you\", tokenizer=tokenizer\n", |     "    start_context=\"Every effort moves you\", tokenizer=tokenizer\n", | ||||||
|     ")" |     ")\n", | ||||||
|  |     "\n", | ||||||
|  |     "# Note:\n", | ||||||
|  |     "# Uncomment the following code to show the execution time\n", | ||||||
|  |     "# end_time = time.time()\n", | ||||||
|  |     "# execution_time_minutes = (end_time - start_time) / 60\n", | ||||||
|  |     "# print(f\"Training completed in {execution_time_minutes:.2f} minutes.\")" | ||||||
|    ] |    ] | ||||||
|   }, |   }, | ||||||
|   { |   { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Sebastian Raschka
						Sebastian Raschka