2023-02-01 09:28:16 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  sys  
						 
					
						
							
								
									
										
										
										
											2022-10-04 12:32:22 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  contextlib  
						 
					
						
							
								
									
										
										
										
											2023-05-21 21:55:14 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  functools  import  lru_cache  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-11 07:11:27 +02:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								import  torch  
						 
					
						
							
								
									
										
										
										
											2022-09-12 16:34:13 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  modules  import  errors  
						 
					
						
							
								
									
										
										
										
											2023-02-01 09:28:16 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								if  sys . platform  ==  " darwin " :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    from  modules  import  mac_specific 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-12 16:34:13 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-11-12 10:00:49 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-11-12 11:02:40 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def has_mps() -> bool:
								    """Report whether the Apple MPS backend is available.

								    ``mac_specific`` is only imported on darwin, so on every other
								    platform this answers False without touching it. On darwin the
								    ``mac_specific.has_mps`` attribute is returned as-is (presumably a
								    flag computed when that module was imported — TODO confirm).
								    """
								    on_macos = sys.platform == "darwin"
								    if on_macos:
								        return mac_specific.has_mps
								    return False
							 
						 
					
						
							
								
									
										
										
										
											2022-09-11 18:48:36 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-22 00:11:07 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def extract_device_id(args, name):
								    """Return the argument that follows the first element containing *name*.

								    The elements of *args* are scanned in order. The first one that contains
								    *name* as a substring is treated as the flag, and the element right
								    after it is returned as the flag's value.

								    Args:
								        args: Sequence of argument strings (presumably ``sys.argv``-style).
								        name: Flag text to look for.

								    Returns:
								        The following argument, or ``None`` when no element matches or when
								        the matching element is the last one (flag given without a value).
								    """
								    for index, arg in enumerate(args):
								        if name in arg:
								            # A matching flag in the final position has no value to read;
								            # report it as absent instead of raising IndexError.
								            return args[index + 1] if index + 1 < len(args) else None

								    return None
							 
						 
					
						
							
								
									
										
										
										
											2022-09-11 18:48:36 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-11-12 10:00:49 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-11-27 13:08:54 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def get_cuda_device_string():
								    """Build the torch CUDA device string.

								    Returns ``"cuda:<id>"`` when ``shared.cmd_opts.device_id`` is set and
								    plain ``"cuda"`` otherwise.
								    """
								    from modules import shared

								    device_id = shared.cmd_opts.device_id
								    return "cuda" if device_id is None else f"cuda:{device_id}"
							 
						 
					
						
							
								
									
										
										
										
											2022-10-22 14:04:14 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-11-27 13:08:54 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def get_optimal_device_name():
								    """Pick the name of the best available torch device.

								    Preference order: CUDA (via ``get_cuda_device_string``), then Apple
								    MPS, then ``"cpu"``.
								    """
								    if not torch.cuda.is_available():
								        return "mps" if has_mps() else "cpu"
								    return get_cuda_device_string()
							 
						 
					
						
							
								
									
										
										
										
											2022-09-11 18:48:36 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def get_optimal_device():
								    """Return a ``torch.device`` for the name chosen by ``get_optimal_device_name``."""
								    name = get_optimal_device_name()
								    return torch.device(name)
							 
						 
					
						
							
								
									
										
										
										
											2022-09-11 23:24:24 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-12-03 18:06:33 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def get_device_for(task):
								    """Return the device a given task should run on.

								    Tasks listed in ``shared.cmd_opts.use_cpu`` are forced onto the
								    module-level ``cpu`` device. Every other task gets
								    ``get_optimal_device()``.
								    """
								    from modules import shared

								    forced_to_cpu = task in shared.cmd_opts.use_cpu
								    return cpu if forced_to_cpu else get_optimal_device()
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-11 23:24:24 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def torch_gc():
								    """Free cached CUDA memory; does nothing when CUDA is unavailable."""
								    if not torch.cuda.is_available():
								        return

								    # Scope both calls to the configured CUDA device string.
								    with torch.cuda.device(get_cuda_device_string()):
								        torch.cuda.empty_cache()
								        torch.cuda.ipc_collect()
							 
						 
					
						
							
								
									
										
										
										
											2022-09-12 16:34:13 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def enable_tf32():
								    """Allow TF32 math for CUDA matmul and cuDNN when CUDA is present.

								    Also sets ``cudnn.benchmark`` if any visible GPU reports compute
								    capability (7, 5).
								    """
								    if not torch.cuda.is_available():
								        return

								    # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
								    # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
								    capabilities = (torch.cuda.get_device_capability(idx) for idx in range(torch.cuda.device_count()))
								    if (7, 5) in capabilities:
								        torch.backends.cudnn.benchmark = True

								    torch.backends.cuda.matmul.allow_tf32 = True
								    torch.backends.cudnn.allow_tf32 = True
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-11-06 17:05:51 -08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-12 16:34:13 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								errors.run(enable_tf32, "Enabling TF32")

								# Device handles. The task-specific ones start as None; presumably they are
								# assigned during startup elsewhere — TODO confirm.
								cpu = torch.device("cpu")
								device = None
								device_interrogate = None
								device_gfpgan = None
								device_esrgan = None
								device_codeformer = None

								# Initial (half-precision) dtypes for the model parts.
								dtype = torch.float16
								dtype_vae = torch.float16
								dtype_unet = torch.float16

								# Read by cond_cast_unet / cond_cast_float to decide whether to cast inputs.
								unet_needs_upcast = False
						 
					
						
							
								
									
										
										
										
											2022-09-12 20:09:32 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-11-12 10:00:49 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 10:19:43 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def cond_cast_unet(input):
								    """Cast *input* to ``dtype_unet`` when ``unet_needs_upcast`` is set.

								    Otherwise the very same object is returned unchanged.
								    """
								    if not unet_needs_upcast:
								        return input
								    return input.to(dtype_unet)
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def cond_cast_float(input):
								    """Return ``input.float()`` when ``unet_needs_upcast`` is set, else *input* itself."""
								    if not unet_needs_upcast:
								        return input
								    return input.float()
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-12 20:09:32 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def randn(seed, shape):
								    """Seed torch with *seed*, then draw a standard-normal tensor of *shape* on ``device``.

								    When the ``randn_source`` option is ``"CPU"`` or ``device`` is MPS, the
								    noise is generated on ``cpu`` and moved to ``device``. Otherwise it is
								    generated directly on ``device``.
								    """
								    from modules.shared import opts

								    torch.manual_seed(seed)
								    generate_on_cpu = opts.randn_source == "CPU" or device.type == 'mps'
								    if not generate_on_cpu:
								        return torch.randn(shape, device=device)
								    return torch.randn(shape, device=cpu).to(device)
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-13 21:49:58 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def randn_without_seed(shape):
								    """Draw a standard-normal tensor of *shape* on ``device`` using the current RNG state.

								    Uses the same placement rule as ``randn``: generate on ``cpu`` and move
								    when ``randn_source`` is ``"CPU"`` or ``device`` is MPS.
								    """
								    from modules.shared import opts

								    generate_on_cpu = opts.randn_source == "CPU" or device.type == 'mps'
								    if not generate_on_cpu:
								        return torch.randn(shape, device=device)
								    return torch.randn(shape, device=cpu).to(device)
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-04 12:32:22 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-10 16:11:14 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def autocast(disable=False):
								    """Return the context manager to wrap mixed-precision work in.

								    Returns ``contextlib.nullcontext()`` when *disable* is true, when the
								    global ``dtype`` is float32, or when ``shared.cmd_opts.precision`` is
								    ``"full"``. Otherwise returns ``torch.autocast("cuda")``.
								    """
								    from modules import shared

								    # Short-circuit order matters: cmd_opts is only consulted when needed.
								    if disable or dtype == torch.float32 or shared.cmd_opts.precision == "full":
								        return contextlib.nullcontext()

								    return torch.autocast("cuda")
							 
						 
					
						
							
								
									
										
										
										
											2022-10-25 02:01:57 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-11-12 10:00:49 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-25 00:23:10 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def without_autocast(disable=False):
								    """Return a context that switches CUDA autocast off for its body.

								    Turning autocast off is only needed while it is active. When it is not
								    active, or when *disable* is true, a ``contextlib.nullcontext()`` is
								    returned instead.
								    """
								    if torch.is_autocast_enabled() and not disable:
								        return torch.autocast("cuda", enabled=False)
								    return contextlib.nullcontext()
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-16 22:59:46 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								class NansException(Exception):
								    """Exception type for tensors that came out entirely NaN.

								    Presumably raised by ``test_for_nans`` with its diagnostic message.
								    """
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  test_for_nans ( x ,  where ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    from  modules  import  shared 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-17 11:04:56 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  shared . cmd_opts . disable_nan_check : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-16 22:59:46 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  not  torch . all ( torch . isnan ( x ) ) . item ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  where  ==  " unet " : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        message  =  " A tensor with all NaNs was produced in Unet. " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  not  shared . cmd_opts . no_half : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-25 00:23:10 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            message  + =  "  This could be either because there ' s not enough precision to represent the picture, or because your video card does not support half type. Try setting the  \" Upcast cross attention layer to float32 \"  option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this. " 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-16 22:59:46 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    elif  where  ==  " vae " : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        message  =  " A tensor with all NaNs was produced in VAE. " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  not  shared . cmd_opts . no_half  and  not  shared . cmd_opts . no_half_vae : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            message  + =  "  This could be because there ' s not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this. " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        message  =  " A tensor with all NaNs was produced. " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 13:08:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    message  + =  "  Use --disable-nan-check commandline argument to disable this check. " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-16 22:59:46 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    raise  NansException ( message ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-21 21:55:14 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								@lru_cache  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  first_time_calculation ( ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    """ 
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    just  do  any  calculation  with  pytorch  layers  -  the  first  time  this  is  done  it  allocaltes  about  700 MB  of  memory  and 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    spends  about  2.7  seconds  doing  that ,  at  least  wih  NVidia . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    """ 
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    x  =  torch . zeros ( ( 1 ,  1 ) ) . to ( device ,  dtype ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    linear  =  torch . nn . Linear ( 1 ,  1 ) . to ( device ,  dtype ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    linear ( x ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    x  =  torch . zeros ( ( 1 ,  1 ,  3 ,  3 ) ) . to ( device ,  dtype ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    conv2d  =  torch . nn . Conv2d ( 1 ,  1 ,  ( 3 ,  3 ) ) . to ( device ,  dtype ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    conv2d ( x )