import sys
import contextlib
from functools import lru_cache

import torch
from modules import errors, shared

if sys.platform == "darwin":
    from modules import mac_specific


def has_mps() -> bool:
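    """Report whether the macOS MPS (Metal Performance Shaders) backend is available; always False on other platforms."""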
    if sys.platform != "darwin":
        return False
    else:
        return mac_specific.has_mps


def get_cuda_device_string():
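    """Return the CUDA device string, using the index from --device-id when that option is set."""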
    if shared.cmd_opts.device_id is not None:
        return f"cuda:{shared.cmd_opts.device_id}"

    return "cuda"


def get_optimal_device_name():
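    """Return the best available device name: the CUDA device string if available, then "mps", otherwise "cpu"."""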
    if torch.cuda.is_available():
        return get_cuda_device_string()

    if has_mps():
        return "mps"

    return "cpu"


def get_optimal_device():
    return torch.device(get_optimal_device_name())


def get_device_for(task):
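    """Return the device for the named task: cpu when the task is listed in --use-cpu, otherwise the optimal device."""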
    if task in shared.cmd_opts.use_cpu:
        return cpu

    return get_optimal_device()


def torch_gc():
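    """Free cached GPU memory: empty the CUDA cache and IPC handles, and run the MPS-specific collector on macOS."""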

    if torch.cuda.is_available():
        with torch.cuda.device(get_cuda_device_string()):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    if has_mps():
        mac_specific.torch_mps_gc()


def enable_tf32():
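    """Enable TF32 for CUDA matmul and cuDNN, and turn on cuDNN benchmark mode for compute capability 7.5 cards (see comment below)."""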
    if torch.cuda.is_available():

        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
        if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):
            torch.backends.cudnn.benchmark = True

        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True


errors.run(enable_tf32, "Enabling TF32")
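
# Module-level device and dtype state. The None values below are placeholders that
# are expected to be filled in at startup; float16 is the default working precision.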
cpu: torch.device = torch.device("cpu")
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
device_esrgan: torch.device = None
device_codeformer: torch.device = None
dtype: torch.dtype = torch.float16
dtype_vae: torch.dtype = torch.float16
dtype_unet: torch.dtype = torch.float16
unet_needs_upcast = False


def cond_cast_unet(input):
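    """Cast input to the UNet dtype when upcasting is required, otherwise return it unchanged."""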
    return input.to(dtype_unet) if unet_needs_upcast else input


def cond_cast_float(input):
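    """Cast input to float32 when upcasting is required, otherwise return it unchanged."""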
    return input.float() if unet_needs_upcast else input


nv_rng = None


def autocast(disable=False):
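    """Return a CUDA autocast context manager for half-precision work, or a null context when disabled or when running at full precision."""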
    if disable:
        return contextlib.nullcontext()

    if dtype == torch.float32 or shared.cmd_opts.precision == "full":
        return contextlib.nullcontext()

    return torch.autocast("cuda")


def without_autocast(disable=False):
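    """Return a context that temporarily turns autocast off when it is currently enabled, otherwise a null context."""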
    return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()


class NansException(Exception):
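    """Raised by test_for_nans when a tensor turns out to be entirely NaN."""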
    pass


def test_for_nans(x, where):
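    """Raise NansException with a troubleshooting hint when tensor x, produced in `where` ("unet" or "vae"), is entirely NaN; does nothing when --disable-nan-check is set."""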
    if shared.cmd_opts.disable_nan_check:
        return

    if not torch.all(torch.isnan(x)).item():
        return

    if where == "unet":
        message = "A tensor with all NaNs was produced in Unet."

        if not shared.cmd_opts.no_half:
            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."

    elif where == "vae":
        message = "A tensor with all NaNs was produced in VAE."

        if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae:
            message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this."
    else:
        message = "A tensor with all NaNs was produced."

    message += " Use --disable-nan-check commandline argument to disable this check."

    raise NansException(message)


@lru_cache
def first_time_calculation():
    """
    just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
    spends about 2.7 seconds doing that, at least with NVidia.
    """

    x = torch.zeros((1, 1)).to(device, dtype)
    linear = torch.nn.Linear(1, 1).to(device, dtype)
    linear(x)

    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
    conv2d(x)