| 
									
										
										
										
											2022-10-22 00:11:07 +02:00
										 |  |  | import sys, os, shlex | 
					
						
							| 
									
										
										
										
											2022-10-04 12:32:22 +03:00
										 |  |  | import contextlib | 
					
						
							| 
									
										
										
										
											2022-09-11 07:11:27 +02:00
										 |  |  | import torch | 
					
						
							| 
									
										
										
										
											2022-09-12 16:34:13 +03:00
										 |  |  | from modules import errors | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-04 04:24:35 -04:00
										 |  |  | # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility | 
					
						
							| 
									
										
										
										
											2022-09-11 07:11:27 +02:00
										 |  |  | has_mps = getattr(torch, 'has_mps', False) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-11 18:48:36 +03:00
										 |  |  | cpu = torch.device("cpu") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-22 00:11:07 +02:00
										 |  |  | def extract_device_id(args, name): | 
					
						
							|  |  |  |     for x in range(len(args)): | 
					
						
							|  |  |  |         if name in args[x]: return args[x+1] | 
					
						
							|  |  |  |     return None | 
					
						
							| 
									
										
										
										
											2022-09-11 18:48:36 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-11 07:11:27 +02:00
										 |  |  | def get_optimal_device(): | 
					
						
							| 
									
										
										
										
											2022-09-11 18:48:36 +03:00
										 |  |  |     if torch.cuda.is_available(): | 
					
						
							| 
									
										
										
										
											2022-10-22 14:04:14 +03:00
										 |  |  |         from modules import shared | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         device_id = shared.cmd_opts.device_id | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-22 00:11:07 +02:00
										 |  |  |         if device_id is not None: | 
					
						
							|  |  |  |             cuda_device = f"cuda:{device_id}" | 
					
						
							|  |  |  |             return torch.device(cuda_device) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             return torch.device("cuda") | 
					
						
							| 
									
										
										
										
											2022-09-11 18:48:36 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     if has_mps: | 
					
						
							|  |  |  |         return torch.device("mps") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return cpu | 
					
						
							| 
									
										
										
										
											2022-09-11 23:24:24 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def torch_gc(): | 
					
						
							|  |  |  |     if torch.cuda.is_available(): | 
					
						
							|  |  |  |         torch.cuda.empty_cache() | 
					
						
							|  |  |  |         torch.cuda.ipc_collect() | 
					
						
							| 
									
										
										
										
											2022-09-12 16:34:13 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def enable_tf32(): | 
					
						
							|  |  |  |     if torch.cuda.is_available(): | 
					
						
							|  |  |  |         torch.backends.cuda.matmul.allow_tf32 = True | 
					
						
							|  |  |  |         torch.backends.cudnn.allow_tf32 = True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
# Apply the TF32 settings once, at import time. errors.run comes from the project
# module modules.errors (not shown here) and receives the function plus a task
# description; presumably it reports failures instead of raising — confirm there.
errors.run(enable_tf32, "Enabling TF32")
					
						
							| 
									
										
										
										
											2022-09-12 20:09:32 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-24 23:04:50 -04:00
										 |  |  | device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | dtype = torch.float16 | 
					
						
							| 
									
										
										
										
											2022-10-10 16:11:14 +03:00
										 |  |  | dtype_vae = torch.float16 | 
					
						
							| 
									
										
										
										
											2022-09-12 20:09:32 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | def randn(seed, shape): | 
					
						
							|  |  |  |     # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. | 
					
						
							|  |  |  |     if device.type == 'mps': | 
					
						
							|  |  |  |         generator = torch.Generator(device=cpu) | 
					
						
							|  |  |  |         generator.manual_seed(seed) | 
					
						
							|  |  |  |         noise = torch.randn(shape, generator=generator, device=cpu).to(device) | 
					
						
							|  |  |  |         return noise | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     torch.manual_seed(seed) | 
					
						
							|  |  |  |     return torch.randn(shape, device=device) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-13 21:49:58 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | def randn_without_seed(shape): | 
					
						
							|  |  |  |     # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. | 
					
						
							|  |  |  |     if device.type == 'mps': | 
					
						
							|  |  |  |         generator = torch.Generator(device=cpu) | 
					
						
							|  |  |  |         noise = torch.randn(shape, generator=generator, device=cpu).to(device) | 
					
						
							|  |  |  |         return noise | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return torch.randn(shape, device=device) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-04 12:32:22 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-10 16:11:14 +03:00
										 |  |  | def autocast(disable=False): | 
					
						
							| 
									
										
										
										
											2022-10-04 12:32:22 +03:00
										 |  |  |     from modules import shared | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-10 16:11:14 +03:00
										 |  |  |     if disable: | 
					
						
							|  |  |  |         return contextlib.nullcontext() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-04 12:32:22 +03:00
										 |  |  |     if dtype == torch.float32 or shared.cmd_opts.precision == "full": | 
					
						
							|  |  |  |         return contextlib.nullcontext() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return torch.autocast("cuda") | 
					
						
							| 
									
										
										
										
											2022-10-25 02:01:57 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  | # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 | 
					
						
							|  |  |  | def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor | 
					
						
							|  |  |  | def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device) |