2022-10-13 23:00:38 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  collections  
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								import  os . path  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								import  sys  
						 
					
						
							
								
									
										
										
										
											2022-11-01 04:01:49 -03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  gc  
						 
					
						
							
								
									
										
										
										
											2023-05-02 09:08:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  threading  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								import  torch  
						 
					
						
							
								
									
										
										
										
											2022-10-28 05:49:39 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  re  
						 
					
						
							
								
									
										
										
										
											2022-11-27 14:46:40 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  safetensors . torch  
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								from  omegaconf  import  OmegaConf  
						 
					
						
							
								
									
										
										
										
											2022-12-08 18:14:35 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  os  import  mkdir  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  urllib  import  request  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								import  ldm . modules . midas  as  midas  
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  ldm . util  import  instantiate_from_config  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-05-27 15:47:33 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  modules  import  paths ,  shared ,  modelloader ,  devices ,  script_callbacks ,  sd_vae ,  sd_disable_initialization ,  errors ,  hashes ,  sd_models_config ,  sd_unet  
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  modules . sd_hijack_inpainting  import  do_inpainting_hijack  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  modules . timer  import  Timer  
						 
					
						
							
								
									
										
										
										
											2023-04-04 02:26:44 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  tomesd  
						 
					
						
							
								
									
										
										
										
											2022-09-27 11:01:13 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								model_dir  =  " Stable-diffusion "  
						 
					
						
							
								
									
										
										
										
											2023-01-25 17:15:42 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								model_path  =  os . path . abspath ( os . path . join ( paths . models_path ,  model_dir ) )  
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								checkpoints_list  =  { }  
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								checkpoint_alisases  =  { }  
						 
					
						
							
								
									
										
										
										
											2022-10-13 23:00:38 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								checkpoints_loaded  =  collections . OrderedDict ( )  
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class  CheckpointInfo :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    def  __init__ ( self ,  filename ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . filename  =  filename 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        abspath  =  os . path . abspath ( filename ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  shared . cmd_opts . ckpt_dir  is  not  None  and  abspath . startswith ( shared . cmd_opts . ckpt_dir ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            name  =  abspath . replace ( shared . cmd_opts . ckpt_dir ,  ' ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        elif  abspath . startswith ( model_path ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            name  =  abspath . replace ( model_path ,  ' ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            name  =  os . path . basename ( filename ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  name . startswith ( " \\ " )  or  name . startswith ( " / " ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            name  =  name [ 1 : ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 18:58:08 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . name  =  name 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-29 10:20:19 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . name_for_extra  =  os . path . splitext ( os . path . basename ( filename ) ) [ 0 ] 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . model_name  =  os . path . splitext ( name . replace ( " / " ,  " _ " ) . replace ( " \\ " ,  " _ " ) ) [ 0 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . hash  =  model_hash ( filename ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 15:55:40 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-05-09 22:17:58 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . sha256  =  hashes . sha256_from_cache ( self . filename ,  f " checkpoint/ { name } " ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 15:55:40 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . shorthash  =  self . sha256 [ 0 : 10 ]  if  self . sha256  else  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 18:58:08 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . title  =  name  if  self . shorthash  is  None  else  f ' { name }  [ { self . shorthash } ] ' 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . ids  =  [ self . hash ,  self . model_name ,  self . title ,  name ,  f ' { name }  [ { self . hash } ] ' ]  +  ( [ self . shorthash ,  self . sha256 ,  f ' { self . name }  [ { self . shorthash } ] ' ]  if  self . shorthash  else  [ ] ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-04-02 17:41:55 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . metadata  =  { } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        _ ,  ext  =  os . path . splitext ( self . filename ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  ext . lower ( )  ==  " .safetensors " : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            try : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                self . metadata  =  read_metadata_from_safetensors ( filename ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            except  Exception  as  e : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                errors . display ( e ,  f " reading checkpoint metadata:  { filename } " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    def  register ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        checkpoints_list [ self . title ]  =  self 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        for  id  in  self . ids : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            checkpoint_alisases [ id ]  =  self 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    def  calculate_shorthash ( self ) : 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-09 22:17:58 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . sha256  =  hashes . sha256 ( self . filename ,  f " checkpoint/ { self . name } " ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-02-04 11:38:56 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        if  self . sha256  is  None : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . shorthash  =  self . sha256 [ 0 : 10 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  self . shorthash  not  in  self . ids : 
							 
						 
					
						
							
								
									
										
										
										
											2023-02-04 15:23:16 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            self . ids  + =  [ self . shorthash ,  self . sha256 ,  f ' { self . name }  [ { self . shorthash } ] ' ] 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-02-04 15:23:16 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        checkpoints_list . pop ( self . title ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 18:58:08 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . title  =  f ' { self . name }  [ { self . shorthash } ] ' 
							 
						 
					
						
							
								
									
										
										
										
											2023-02-04 15:23:16 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . register ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 18:58:08 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        return  self . shorthash 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
    from transformers import logging, CLIPModel  # noqa: F401

    logging.set_verbosity_error()
except Exception:
    # Best effort only: if the import or the verbosity call fails for any
    # reason, carry on with whatever logging level is already in effect.
    pass
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-02 21:09:10 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  setup_model ( ) :  
						 
					
						
							
								
									
										
										
										
											2022-09-27 11:01:13 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  not  os . path . exists ( model_path ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        os . makedirs ( model_path ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-02 21:09:10 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-12-08 18:14:35 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    enable_midas_autodownload ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-29 19:59:36 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  checkpoint_tiles ( ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    def  convert ( name ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  int ( name )  if  name . isdigit ( )  else  name . lower ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    def  alphanumeric_key ( key ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  [ convert ( c )  for  c  in  re . split ( ' ([0-9]+) ' ,  key ) ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    return  sorted ( [ x . title  for  x  in  checkpoints_list . values ( ) ] ,  key = alphanumeric_key ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-29 00:59:44 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								def  list_models ( ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    checkpoints_list . clear ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    checkpoint_alisases . clear ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    cmd_ckpt  =  shared . cmd_opts . ckpt 
							 
						 
					
						
							
								
									
										
										
										
											2023-02-19 20:49:07 +09:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  shared . cmd_opts . no_download_sd_model  or  cmd_ckpt  !=  shared . sd_model_file  or  os . path . exists ( cmd_ckpt ) : 
							 
						 
					
						
							
								
									
										
										
										
											2023-02-19 20:37:40 +09:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        model_url  =  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        model_url  =  " https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    model_list  =  modelloader . load_models ( model_path = model_path ,  model_url = model_url ,  command_path = shared . cmd_opts . ckpt_dir ,  ext_filter = [ " .ckpt " ,  " .safetensors " ] ,  download_name = " v1-5-pruned-emaonly.safetensors " ,  ext_blacklist = [ " .vae.ckpt " ,  " .vae.safetensors " ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    if  os . path . exists ( cmd_ckpt ) : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        checkpoint_info  =  CheckpointInfo ( cmd_ckpt ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        checkpoint_info . register ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        shared . opts . data [ ' sd_model_checkpoint ' ]  =  checkpoint_info . title 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    elif  cmd_ckpt  is  not  None  and  cmd_ckpt  !=  shared . default_sd_model_file : 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-27 11:01:13 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        print ( f " Checkpoint in --ckpt argument not found (Possible it was moved to  { model_path } :  { cmd_ckpt } " ,  file = sys . stderr ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-03-28 20:03:57 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    for  filename  in  sorted ( model_list ,  key = str . lower ) : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        checkpoint_info  =  CheckpointInfo ( filename ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        checkpoint_info . register ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-08 23:26:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  get_closet_checkpoint_match ( search_string ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    checkpoint_info  =  checkpoint_alisases . get ( search_string ,  None ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  checkpoint_info  is  not  None : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 10:25:21 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        return  checkpoint_info 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-30 11:42:40 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    found  =  sorted ( [ info  for  info  in  checkpoints_list . values ( )  if  search_string  in  info . title ] ,  key = lambda  x :  len ( x . title ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  found : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  found [ 0 ] 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-28 22:30:09 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    return  None 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-30 11:42:40 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								def  model_hash ( filename ) :  
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    """ old hash that only looks at a small part of the file and is prone to collisions """ 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    try : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        with  open ( filename ,  " rb " )  as  file : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            import  hashlib 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            m  =  hashlib . sha256 ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            file . seek ( 0x100000 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            m . update ( file . read ( 0x10000 ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return  m . hexdigest ( ) [ 0 : 8 ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    except  FileNotFoundError : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  ' NOFILE ' 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  select_checkpoint ( ) :  
						 
					
						
							
								
									
										
										
										
											2023-05-26 15:08:53 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    """ Raises `FileNotFoundError` if no checkpoints are found. """ 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    model_checkpoint  =  shared . opts . sd_model_checkpoint 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-11 18:28:15 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    checkpoint_info  =  checkpoint_alisases . get ( model_checkpoint ,  None ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    if  checkpoint_info  is  not  None : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  checkpoint_info 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  len ( checkpoints_list )  ==  0 : 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-26 15:08:53 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        error_message  =  " No checkpoints found. When searching for checkpoints, looked at: " 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-02 21:09:10 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        if  shared . cmd_opts . ckpt  is  not  None : 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-26 15:08:53 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            error_message  + =  f " \n  - file  { os . path . abspath ( shared . cmd_opts . ckpt ) } " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        error_message  + =  f " \n  - directory  { model_path } " 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-02 21:09:10 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        if  shared . cmd_opts . ckpt_dir  is  not  None : 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-26 15:08:53 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            error_message  + =  f " \n  - directory  { os . path . abspath ( shared . cmd_opts . ckpt_dir ) } " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        error_message  + =  " Can ' t run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        raise  FileNotFoundError ( error_message ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    checkpoint_info  =  next ( iter ( checkpoints_list . values ( ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  model_checkpoint  is  not  None : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( f " Checkpoint  { model_checkpoint }  not found; loading fallback  { checkpoint_info . title } " ,  file = sys . stderr ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    return  checkpoint_info 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-03-23 14:28:08 +09:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								checkpoint_dict_replacements  =  {  
						 
					
						
							
								
									
										
										
										
											2022-10-19 08:42:22 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    ' cond_stage_model.transformer.embeddings. ' :  ' cond_stage_model.transformer.text_model.embeddings. ' , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ' cond_stage_model.transformer.encoder. ' :  ' cond_stage_model.transformer.text_model.encoder. ' , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ' cond_stage_model.transformer.final_layer_norm. ' :  ' cond_stage_model.transformer.text_model.final_layer_norm. ' , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								}  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  transform_checkpoint_dict_key ( k ) :  
						 
					
						
							
								
									
										
										
										
											2023-03-23 14:28:08 +09:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    for  text ,  replacement  in  checkpoint_dict_replacements . items ( ) : 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-19 08:42:22 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        if  k . startswith ( text ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            k  =  replacement  +  k [ len ( text ) : ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    return  k 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-09 10:23:31 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  get_state_dict_from_checkpoint ( pl_sd ) :  
						 
					
						
							
								
									
										
										
										
											2022-11-28 08:39:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    pl_sd  =  pl_sd . pop ( " state_dict " ,  pl_sd ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    pl_sd . pop ( " state_dict " ,  None ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-19 08:42:22 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    sd  =  { } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  k ,  v  in  pl_sd . items ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        new_key  =  transform_checkpoint_dict_key ( k ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  new_key  is  not  None : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            sd [ new_key ]  =  v 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-09 10:23:31 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-19 12:45:30 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    pl_sd . clear ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    pl_sd . update ( sd ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    return  pl_sd 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-09 10:23:31 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-03-14 09:10:26 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  read_metadata_from_safetensors ( filename ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    import  json 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    with  open ( filename ,  mode = " rb " )  as  file : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        metadata_len  =  file . read ( 8 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        metadata_len  =  int . from_bytes ( metadata_len ,  " little " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        json_start  =  file . read ( 2 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        assert  metadata_len  >  2  and  json_start  in  ( b ' { " ' ,  b " { ' " ) ,  f " { filename }  is not a safetensors file " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        json_data  =  json_start  +  file . read ( metadata_len - 2 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        json_obj  =  json . loads ( json_data ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        res  =  { } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        for  k ,  v  in  json_obj . get ( " __metadata__ " ,  { } ) . items ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            res [ k ]  =  v 
							 
						 
					
						
							
								
									
										
										
										
											2023-03-14 11:22:29 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            if  isinstance ( v ,  str )  and  v [ 0 : 1 ]  ==  ' { ' : 
							 
						 
					
						
							
								
									
										
										
										
											2023-03-14 09:10:26 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                try : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    res [ k ]  =  json . loads ( v ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-10 07:52:45 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                except  Exception : 
							 
						 
					
						
							
								
									
										
										
										
											2023-03-14 09:10:26 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                    pass 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  res 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-11-27 15:51:29 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  read_state_dict ( checkpoint_file ,  print_global_state = False ,  map_location = None ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    _ ,  extension  =  os . path . splitext ( checkpoint_file ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  extension . lower ( )  ==  " .safetensors " : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        device  =  map_location  or  shared . weight_load_location  or  devices . get_optimal_device_name ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-12-21 13:45:58 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        pl_sd  =  safetensors . torch . load_file ( checkpoint_file ,  device = device ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-11-27 15:51:29 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        pl_sd  =  torch . load ( checkpoint_file ,  map_location = map_location  or  shared . weight_load_location ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  print_global_state  and  " global_step "  in  pl_sd : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( f " Global Step:  { pl_sd [ ' global_step ' ] } " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    sd  =  get_state_dict_from_checkpoint ( pl_sd ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    return  sd 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  get_checkpoint_state_dict ( checkpoint_info :  CheckpointInfo ,  timer ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    sd_model_hash  =  checkpoint_info . calculate_shorthash ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    timer . record ( " calculate hash " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  checkpoint_info  in  checkpoints_loaded : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # use checkpoint cache 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( f " Loading weights [ { sd_model_hash } ] from cache " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  checkpoints_loaded [ checkpoint_info ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    print ( f " Loading weights [ { sd_model_hash } ] from  { checkpoint_info . filename } " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    res  =  read_state_dict ( checkpoint_info . filename ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    timer . record ( " load weights from disk " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    return  res 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  load_model_weights ( model ,  checkpoint_info :  CheckpointInfo ,  state_dict ,  timer ) :  
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    sd_model_hash  =  checkpoint_info . calculate_shorthash ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    timer . record ( " calculate hash " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-28 16:23:49 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    shared . opts . data [ " sd_model_checkpoint " ]  =  checkpoint_info . title 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-08 23:26:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  state_dict  is  None : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        state_dict  =  get_checkpoint_state_dict ( checkpoint_info ,  timer ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-11-09 04:54:21 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    model . load_state_dict ( state_dict ,  strict = False ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    del  state_dict 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    timer . record ( " apply weights to model " ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  shared . opts . sd_checkpoint_cache  >  0 : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # cache newly loaded model 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        checkpoints_loaded [ checkpoint_info ]  =  model . state_dict ( ) . copy ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  shared . cmd_opts . opt_channelslast : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        model . to ( memory_format = torch . channels_last ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        timer . record ( " apply channels_last " ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  not  shared . cmd_opts . no_half : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        vae  =  model . first_stage_model 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        depth_model  =  getattr ( model ,  ' depth_model ' ,  None ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  shared . cmd_opts . no_half_vae : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            model . first_stage_model  =  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # with --upcast-sampling, don't convert the depth model weights to float16 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  shared . cmd_opts . upcast_sampling  and  depth_model : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            model . depth_model  =  None 
							 
						 
					
						
							
								
									
										
										
										
											2022-11-02 14:41:29 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        model . half ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        model . first_stage_model  =  vae 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  depth_model : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            model . depth_model  =  depth_model 
							 
						 
					
						
							
								
									
										
										
										
											2022-11-02 14:41:29 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        timer . record ( " apply half() " ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    devices . dtype_unet  =  model . model . diffusion_model . dtype 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    devices . unet_needs_upcast  =  shared . cmd_opts . upcast_sampling  and  devices . dtype  ==  torch . float16  and  devices . dtype_unet  ==  torch . float16 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    model . first_stage_model . to ( devices . dtype_vae ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    timer . record ( " apply dtype to VAE " ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-11-02 14:41:29 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-11-09 04:54:21 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    # clean up cache if limit is reached 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    while  len ( checkpoints_loaded )  >  shared . opts . sd_checkpoint_cache : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        checkpoints_loaded . popitem ( last = False ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-31 16:27:27 +07:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    model . sd_model_hash  =  sd_model_hash 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 09:56:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    model . sd_model_checkpoint  =  checkpoint_info . filename 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-08 23:26:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    model . sd_checkpoint_info  =  checkpoint_info 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 15:55:40 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    shared . opts . data [ " sd_checkpoint_hash " ]  =  checkpoint_info . sha256 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-02 00:38:09 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    model . logvar  =  model . logvar . to ( devices . device )   # fix for training 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-11-13 11:11:14 +07:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    sd_vae . delete_base_vae ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-11-03 11:10:53 +07:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    sd_vae . clear_loaded_vae ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 19:56:09 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    vae_file ,  vae_source  =  sd_vae . resolve_vae ( checkpoint_info . filename ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    sd_vae . load_vae ( model ,  vae_file ,  vae_source ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    timer . record ( " load VAE " ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-11-02 12:51:46 +07:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-12-08 18:14:35 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  enable_midas_autodownload ( ) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    """ 
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    Gives  the  ldm . modules . midas . api . load_model  function  automatic  downloading . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    When  the  512 - depth - ema  model ,  and  other  future  models  like  it ,  is  loaded , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    it  calls  midas . api . load_model  to  load  the  associated  midas  depth  model . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    This  function  applies  a  wrapper  to  download  the  model  to  the  correct 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    location  automatically . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    """ 
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-25 17:15:42 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    midas_path  =  os . path . join ( paths . models_path ,  ' midas ' ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-12-08 18:14:35 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    # stable-diffusion-stability-ai hard-codes the midas model path to 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    # a location that differs from where other scripts using this model look. 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    # HACK: Overriding the path here. 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  k ,  v  in  midas . api . ISL_PATHS . items ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        file_name  =  os . path . basename ( v ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        midas . api . ISL_PATHS [ k ]  =  os . path . join ( midas_path ,  file_name ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    midas_urls  =  { 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        " dpt_large " :  " https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt " , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        " dpt_hybrid " :  " https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt " , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        " midas_v21 " :  " https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt " , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        " midas_v21_small " :  " https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt " , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    midas . api . load_model_inner  =  midas . api . load_model 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    def  load_model_wrapper ( model_type ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        path  =  midas . api . ISL_PATHS [ model_type ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  not  os . path . exists ( path ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            if  not  os . path . exists ( midas_path ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                mkdir ( midas_path ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-11 18:28:15 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-12-08 18:14:35 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            print ( f " Downloading midas model weights for  { model_type }  to  { path } " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            request . urlretrieve ( midas_urls [ model_type ] ,  path ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            print ( f " { model_type }  downloaded " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  midas . api . load_model_inner ( model_type ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    midas . api . load_model  =  load_model_wrapper 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-04 12:35:07 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  repair_config ( sd_config ) :  
						 
					
						
							
								
									
										
										
										
											2023-01-10 16:51:04 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  not  hasattr ( sd_config . model . params ,  " use_ema " ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        sd_config . model . params . use_ema  =  False 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-10 16:51:04 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  shared . cmd_opts . no_half : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        sd_config . model . params . unet_config . params . use_fp16  =  False 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    elif  shared . cmd_opts . upcast_sampling : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        sd_config . model . params . unet_config . params . use_fp16  =  True 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-10 16:51:04 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-03-26 16:55:29 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  getattr ( sd_config . model . params . first_stage_config . params . ddconfig ,  " attn_type " ,  None )  ==  " vanilla-xformers "  and  not  shared . xformers_available : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        sd_config . model . params . first_stage_config . params . ddconfig . attn_type  =  " vanilla " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-03-24 22:48:16 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    # For UnCLIP-L, override the hardcoded karlo directory 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  hasattr ( sd_config . model . params ,  " noise_aug_config " )  and  hasattr ( sd_config . model . params . noise_aug_config . params ,  " clip_stats_path " ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        karlo_path  =  os . path . join ( paths . models_path ,  ' karlo ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        sd_config . model . params . noise_aug_config . params . clip_stats_path  =  sd_config . model . params . noise_aug_config . params . clip_stats_path . replace ( " checkpoints/karlo_models " ,  karlo_path ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-02-05 11:20:47 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
# Dotted parameter names inside the text-encoder ("cond_stage_model") part of a model.
# NOTE(review): presumably used as state-dict keys to tell SD1.x checkpoints from
# SD2.x ones -- confirm against the code that reads these constants.
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-05-02 09:08:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class  SdModelData :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    def  __init__ ( self ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . sd_model  =  None 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-18 15:47:43 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . was_loaded_at_least_once  =  False 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-02 09:08:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        self . lock  =  threading . Lock ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    def  get_sd_model ( self ) : 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-18 15:47:43 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        if  self . was_loaded_at_least_once : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return  self . sd_model 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-05-02 09:08:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        if  self . sd_model  is  None : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            with  self . lock : 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-18 15:47:43 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                if  self . sd_model  is  not  None  or  self . was_loaded_at_least_once : 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-14 13:27:50 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                    return  self . sd_model 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-05-02 09:08:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                try : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    load_model ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                except  Exception  as  e : 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-26 15:15:59 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                    errors . display ( e ,  " loading stable diffusion model " ,  full_traceback = True ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-02 09:08:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                    print ( " " ,  file = sys . stderr ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    print ( " Stable diffusion model failed to load " ,  file = sys . stderr ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    self . sd_model  =  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  self . sd_model 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    def  set_sd_model ( self ,  v ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        self . sd_model  =  v 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
# Module-level SdModelData instance; load_model() below reads and resets its sd_model attribute.
model_data = SdModelData()
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  load_model ( checkpoint_info = None ,  already_loaded_state_dict = None ) :  
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    from  modules  import  lowvram ,  sd_hijack 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-20 16:01:27 -07:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    checkpoint_info  =  checkpoint_info  or  select_checkpoint ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-08 23:26:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-05-02 09:08:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  model_data . sd_model : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        sd_hijack . model_hijack . undo_hijack ( model_data . sd_model ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        model_data . sd_model  =  None 
							 
						 
					
						
							
								
									
										
										
										
											2022-11-01 04:01:49 -03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        gc . collect ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        devices . torch_gc ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    do_inpainting_hijack ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-25 08:53:23 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    timer  =  Timer ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-12-11 10:19:46 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  already_loaded_state_dict  is  not  None : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        state_dict  =  already_loaded_state_dict 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        state_dict  =  get_checkpoint_state_dict ( checkpoint_info ,  timer ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-11-01 04:01:49 -03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    checkpoint_config  =  sd_models_config . find_checkpoint_config ( state_dict ,  checkpoint_info ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-02-05 11:20:47 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    clip_is_included_into_sd  =  sd1_clip_weight  in  state_dict  or  sd2_clip_weight  in  state_dict 
							 
						 
					
						
							
								
									
										
										
										
											2022-11-26 13:28:44 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    timer . record ( " find config " ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-10 16:51:04 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    sd_config  =  OmegaConf . load ( checkpoint_config ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    repair_config ( sd_config ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    timer . record ( " load config " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    print ( f " Creating model from config:  { checkpoint_config } " ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-11 18:54:04 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    sd_model  =  None 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-10 17:46:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    try : 
							 
						 
					
						
							
								
									
										
										
										
											2023-02-05 11:20:47 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        with  sd_disable_initialization . DisableInitialization ( disable_clip = clip_is_included_into_sd ) : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-10 17:46:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            sd_model  =  instantiate_from_config ( sd_config . model ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-10 07:52:45 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    except  Exception : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-11 10:24:56 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        pass 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  sd_model  is  None : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-10 17:46:59 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        print ( ' Failed to create model quickly; will retry using slow method. ' ,  file = sys . stderr ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-10 14:08:29 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        sd_model  =  instantiate_from_config ( sd_config . model ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-04 12:35:07 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    sd_model . used_config  =  checkpoint_config 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-10 16:51:04 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    timer . record ( " create model " ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    load_model_weights ( sd_model ,  checkpoint_info ,  state_dict ,  timer ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-10 16:51:04 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    if  shared . cmd_opts . lowvram  or  shared . cmd_opts . medvram : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        lowvram . setup_for_low_vram ( sd_model ,  shared . cmd_opts . medvram ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        sd_model . to ( shared . device ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    timer . record ( " move model to device " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    sd_hijack . model_hijack . hijack ( sd_model ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    timer . record ( " hijack " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    sd_model . eval ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-02 09:08:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    model_data . sd_model  =  sd_model 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-18 15:47:43 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    model_data . was_loaded_at_least_once  =  True 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-22 12:23:45 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-03 18:39:14 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    sd_hijack . model_hijack . embedding_db . load_textual_inversion_embeddings ( force_reload = True )   # Reload embeddings after model load as they may or may not fit the model 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    timer . record ( " load textual inversion embeddings " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-22 20:15:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    script_callbacks . model_loaded_callback ( sd_model ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    timer . record ( " scripts callbacks " ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-10 16:51:04 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-05-22 00:13:53 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    with  devices . autocast ( ) ,  torch . no_grad ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        sd_model . cond_stage_model_empty_prompt  =  sd_model . cond_stage_model ( [ " " ] ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    timer . record ( " calculate empty prompt " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    print ( f " Model loaded in  { timer . summary ( ) } . " ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-12-31 11:27:02 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    return  sd_model 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-11-02 12:51:46 +07:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  reload_model_weights ( sd_model = None ,  info = None ) :  
						 
					
						
							
								
									
										
										
										
											2022-09-29 15:40:28 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    from  modules  import  lowvram ,  devices ,  sd_hijack 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 13:49:36 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    checkpoint_info  =  info  or  select_checkpoint ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-04 12:35:07 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-11-01 04:01:49 -03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  not  sd_model : 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-02 09:08:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        sd_model  =  model_data . sd_model 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-10 16:51:04 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  sd_model  is  None :   # previous model load failed 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-09 18:34:26 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        current_checkpoint_info  =  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        current_checkpoint_info  =  sd_model . sd_checkpoint_info 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  sd_model . sd_model_checkpoint  ==  checkpoint_info . filename : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            return 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-05-27 15:47:33 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        sd_unet . apply_unet ( " None " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:54:19 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        if  shared . cmd_opts . lowvram  or  shared . cmd_opts . medvram : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            lowvram . send_everything_to_cpu ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            sd_model . to ( devices . cpu ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:54:19 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        sd_hijack . model_hijack . undo_hijack ( sd_model ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-29 15:40:28 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-10 16:51:04 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    timer  =  Timer ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    state_dict  =  get_checkpoint_state_dict ( checkpoint_info ,  timer ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    checkpoint_config  =  sd_models_config . find_checkpoint_config ( state_dict ,  checkpoint_info ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    timer . record ( " find config " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  sd_model  is  None  or  checkpoint_config  !=  sd_model . used_config : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        del  sd_model 
							 
						 
					
						
							
								
									
										
										
										
											2023-03-09 07:56:19 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        load_model ( checkpoint_info ,  already_loaded_state_dict = state_dict ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-02 09:08:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        return  model_data . sd_model 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-04 12:35:07 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    try : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        load_model_weights ( sd_model ,  checkpoint_info ,  state_dict ,  timer ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-10 07:52:45 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    except  Exception : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-04 12:35:07 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        print ( " Failed to load checkpoint, restoring previous " ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        load_model_weights ( sd_model ,  current_checkpoint_info ,  None ,  timer ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-04 12:35:07 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        raise 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    finally : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        sd_hijack . model_hijack . hijack ( sd_model ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        timer . record ( " hijack " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-04 12:35:07 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        script_callbacks . model_loaded_callback ( sd_model ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        timer . record ( " script callbacks " ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-04 12:35:07 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        if  not  shared . cmd_opts . lowvram  and  not  shared . cmd_opts . medvram : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            sd_model . to ( devices . device ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            timer . record ( " move model to device " ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 11:28:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    print ( f " Weights loaded in  { timer . summary ( ) } . " ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-04 12:35:07 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 12:05:04 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    return  sd_model 
							 
						 
					
						
							
								
									
										
										
										
											2023-03-09 07:56:19 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-05-02 09:08:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-03-09 07:56:19 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  unload_model_weights ( sd_model = None ,  info = None ) :  
						 
					
						
							
								
									
										
										
										
											2023-05-10 08:43:42 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    from  modules  import  devices ,  sd_hijack 
							 
						 
					
						
							
								
									
										
										
										
											2023-03-09 07:56:19 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    timer  =  Timer ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-05-02 09:08:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  model_data . sd_model : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        model_data . sd_model . to ( devices . cpu ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        sd_hijack . model_hijack . undo_hijack ( model_data . sd_model ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        model_data . sd_model  =  None 
							 
						 
					
						
							
								
									
										
										
										
											2023-03-09 07:56:19 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        sd_model  =  None 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        gc . collect ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        devices . torch_gc ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        torch . cuda . empty_cache ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    print ( f " Unloaded weights  { timer . summary ( ) } . " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-04-04 02:26:44 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    return  sd_model 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-05-17 20:22:38 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  apply_token_merging ( sd_model ,  token_merging_ratio ) :  
						 
					
						
							
								
									
										
										
										
											2023-04-04 02:26:44 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    """ 
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    Applies  speed  and  memory  optimizations  from  tomesd . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    """ 
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-05-17 20:22:38 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    current_token_merging_ratio  =  getattr ( sd_model ,  ' applied_token_merged_ratio ' ,  0 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  current_token_merging_ratio  ==  token_merging_ratio : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  current_token_merging_ratio  >  0 : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        tomesd . remove_patch ( sd_model ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  token_merging_ratio  >  0 : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        tomesd . apply_patch ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            sd_model , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            ratio = token_merging_ratio , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            use_rand = False ,   # can cause issues with some samplers 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            merge_attn = True , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            merge_crossattn = False , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            merge_mlp = False 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    sd_model . applied_token_merged_ratio  =  token_merging_ratio