mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-31 01:41:26 +00:00 
			
		
		
		
	add mps runtime (#223)
This commit is contained in:
		
							parent
							
								
									b114053378
								
							
						
					
					
						commit
						f4c8bb024c
					
				| @ -1111,6 +1111,15 @@ | |||||||
|     "from functools import partial\n", |     "from functools import partial\n", | ||||||
|     "\n", |     "\n", | ||||||
|     "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", |     "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | ||||||
|  |     "\n", | ||||||
|  |     "# If you have a Mac with an Apple Silicon chip, you can uncomment the next lines of code\n", | ||||||
|  |     "# to train the model on the Mac's GPU cores. However, as of this writing, this results in\n", | ||||||
|  |     "# larger numerical deviations from the results shown in this chapter, because Apple Silicon\n", | ||||||
|  |     "# support in PyTorch is still experimental.\n", | ||||||
|  |     "\n", | ||||||
|  |     "# if torch.backends.mps.is_available():\n", | ||||||
|  |     "#     device = torch.device(\"mps\")\n", | ||||||
|  |     "\n", | ||||||
|     "print(\"Device:\", device)\n", |     "print(\"Device:\", device)\n", | ||||||
|     "\n", |     "\n", | ||||||
|     "customized_collate_fn = partial(custom_collate_fn, device=device, allowed_max_length=1024)" |     "customized_collate_fn = partial(custom_collate_fn, device=device, allowed_max_length=1024)" | ||||||
| @ -1743,9 +1752,11 @@ | |||||||
|     "| Model              | Device                | Runtime for 2 Epochs |\n", |     "| Model              | Device                | Runtime for 2 Epochs |\n", | ||||||
|     "|--------------------|-----------------------|----------------------|\n", |     "|--------------------|-----------------------|----------------------|\n", | ||||||
|     "| gpt2-medium (355M) | CPU (M3 MacBook Air)  | 15.78 minutes        |\n", |     "| gpt2-medium (355M) | CPU (M3 MacBook Air)  | 15.78 minutes        |\n", | ||||||
|  |     "| gpt2-medium (355M) | GPU (M3 MacBook Air)  | 10.77 minutes        |\n", | ||||||
|     "| gpt2-medium (355M) | GPU (L4)              | 1.83 minutes         |\n", |     "| gpt2-medium (355M) | GPU (L4)              | 1.83 minutes         |\n", | ||||||
|     "| gpt2-medium (355M) | GPU (A100)            | 0.86 minutes         |\n", |     "| gpt2-medium (355M) | GPU (A100)            | 0.86 minutes         |\n", | ||||||
|     "| gpt2-small (124M)  | CPU (M3 MacBook Air)  | 5.74 minutes         |\n", |     "| gpt2-small (124M)  | CPU (M3 MacBook Air)  | 5.74 minutes         |\n", | ||||||
|  |     "| gpt2-small (124M)  | GPU (M3 MacBook Air)  | 3.73 minutes         |\n", | ||||||
|     "| gpt2-small (124M)  | GPU (L4)              | 0.69 minutes         |\n", |     "| gpt2-small (124M)  | GPU (L4)              | 0.69 minutes         |\n", | ||||||
|     "| gpt2-small (124M)  | GPU (A100)            | 0.39 minutes         |\n", |     "| gpt2-small (124M)  | GPU (A100)            | 0.39 minutes         |\n", | ||||||
|     "\n", |     "\n", | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Sebastian Raschka
						Sebastian Raschka