2022-09-20 00:13:12 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  os  
						 
					
						
							
								
									
										
										
										
											2023-01-22 10:17:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  re  
						 
					
						
							
								
									
										
										
										
											2023-01-11 09:10:07 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  shutil  
						 
					
						
							
								
									
										
										
										
											2022-09-20 00:13:12 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-11 11:31:16 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-25 19:22:12 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  torch  
						 
					
						
							
								
									
										
										
										
											2022-09-27 10:44:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  tqdm  
						 
					
						
							
								
									
										
										
										
											2022-09-25 19:22:12 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-27 22:43:08 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  modules  import  shared ,  images ,  sd_models ,  sd_vae ,  sd_models_config  
						 
					
						
							
								
									
										
										
										
											2023-01-23 14:50:20 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  modules . ui_common  import  plaintext_to_html  
						 
					
						
							
								
									
										
										
										
											2022-09-29 00:59:44 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  gradio  as  gr  
						 
					
						
							
								
									
										
										
										
											2022-11-27 15:51:29 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								import  safetensors . torch  
						 
					
						
							
								
									
										
										
										
											2022-09-13 19:23:55 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-11 11:31:16 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-17 16:07:07 +10:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def run_pnginfo(image):
								    """Collect the metadata stored in an image and render it as HTML.

								    Args:
								        image: the image to inspect; None means nothing was supplied.

								    Returns:
								        ``('', '', '')`` when ``image`` is None. Otherwise a 3-tuple
								        ``('', geninfo, info)``, where ``geninfo`` is the first value returned
								        by ``images.read_info_from_image`` (passed through unchanged) and
								        ``info`` is an HTML fragment with one ``<div>`` per metadata entry, or
								        a "Nothing found in the image." message when there are no entries.
								    """
								    if image is None:
								        return '', '', ''

								    geninfo, items = images.read_info_from_image(image)

								    # List the generation parameters first, but only when the image
								    # actually has them. Inserting a None value unconditionally would
								    # render the text "None" and make the empty-info fallback below
								    # unreachable.
								    if geninfo is not None:
								        items = {'parameters': geninfo, **items}

								    info = ''
								    for key, text in items.items():
								        info += "\n".join([
								            "<div>",
								            f"<p><b>{plaintext_to_html(str(key))}</b></p>",
								            f"<p>{plaintext_to_html(str(text))}</p>",
								            "</div>",
								        ]) + "\n"

								    if not info:
								        message = "Nothing found in the image."
								        # Paragraph is properly closed with </p> (the old markup reopened it).
								        info = f"<div><p>{message}</p></div>"

								    return '', geninfo, info
							 
						 
					
						
							
								
									
										
										
										
											2022-09-25 19:22:12 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-11 09:10:07 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def create_config(ckpt_result, config_source, a, b, c):
								    """Copy a model config (.yaml) next to the checkpoint at ``ckpt_result``.

								    ``config_source`` picks where the config comes from:
								      0 -> models ``a``, ``b``, ``c`` tried in that order, first hit wins;
								      1 -> model ``b`` only;
								      2 -> model ``c`` only;
								      anything else -> nothing is copied.
								    A config equal to ``shared.sd_default_config`` counts as no config.
								    The copy is written to ``ckpt_result`` with its extension replaced by
								    ".yaml"; the function returns None in every case.
								    """
								    def lookup(model_filename):
								        # Config located near model_filename (per the sd_models_config helper),
								        # or None for a falsy filename or when it is the default config.
								        found = None
								        if model_filename:
								            found = sd_models_config.find_checkpoint_config_near_filename(model_filename)
								        if found != shared.sd_default_config:
								            return found
								        return None

								    if config_source == 0:
								        candidates = (a, b, c)
								    elif config_source == 1:
								        candidates = (b,)
								    elif config_source == 2:
								        candidates = (c,)
								    else:
								        candidates = ()

								    # Equivalent to an `x or y or z` chain: stop at the first truthy result,
								    # otherwise keep whatever the last lookup produced.
								    cfg = None
								    for candidate in candidates:
								        cfg = lookup(candidate)
								        if cfg:
								            break

								    if cfg is None:
								        return

								    target_yaml = os.path.splitext(ckpt_result)[0] + ".yaml"

								    print("Copying config:")
								    print("   from:", cfg)
								    print("     to:", target_yaml)
								    shutil.copyfile(cfg, target_yaml)
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:24:17 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								# State-dict keys to skip during checkpoint merging. The consumer is not
								# visible in this excerpt; presumably run_modelmerger leaves these keys
								# out of the interpolation. Confirm against its use.
								checkpoint_dict_skip_on_merge = [
								    "cond_stage_model.transformer.text_model.embeddings.position_ids",
								]
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 12:12:09 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def to_half(tensor, enable):
								    """Return ``tensor.half()`` when ``enable`` is truthy and ``tensor`` is float32.

								    Tensors of any other dtype, or any tensor when ``enable`` is falsy,
								    are returned as the very same object.
								    """
								    convert = enable and tensor.dtype == torch.float
								    return tensor.half() if convert else tensor
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-22 10:17:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  run_modelmerger ( id_task ,  primary_model_name ,  secondary_model_name ,  tertiary_model_name ,  interp_method ,  multiplier ,  save_as_half ,  custom_name ,  checkpoint_format ,  config_source ,  bake_in_vae ,  discard_weights ) :  
						 
					
						
							
								
									
										
										
										
											2023-01-03 10:21:51 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    shared . state . begin ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    shared . state . job  =  ' model-merge ' 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 08:53:50 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    def  fail ( message ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        shared . state . textinfo  =  message 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        shared . state . end ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 09:25:37 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        return  [ * [ gr . update ( )  for  _  in  range ( 4 ) ] ,  message ] 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 08:53:50 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-16 18:44:39 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    def  weighted_sum ( theta0 ,  theta1 ,  alpha ) : 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-26 10:50:21 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        return  ( ( 1  -  alpha )  *  theta0 )  +  ( alpha  *  theta1 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-16 18:44:39 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    def  get_difference ( theta1 ,  theta2 ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  theta1  -  theta2 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    def  add_difference ( theta0 ,  theta1_2_diff ,  alpha ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  theta0  +  ( alpha  *  theta1_2_diff ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-14 09:05:06 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:24:17 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    def  filename_weighted_sum ( ) : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        a  =  primary_model_info . model_name 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        b  =  secondary_model_info . model_name 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        Ma  =  round ( 1  -  multiplier ,  2 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        Mb  =  round ( multiplier ,  2 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  f " { Ma } ( { a } ) +  { Mb } ( { b } ) " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:24:17 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    def  filename_add_difference ( ) : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        a  =  primary_model_info . model_name 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        b  =  secondary_model_info . model_name 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        c  =  tertiary_model_info . model_name 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        M  =  round ( multiplier ,  2 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  f " { a }  +  { M } ( { b }  -  { c } ) " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    def  filename_nothing ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  primary_model_info . model_name 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    theta_funcs  =  { 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:24:17 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        " Weighted sum " :  ( filename_weighted_sum ,  None ,  weighted_sum ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        " Add difference " :  ( filename_add_difference ,  get_difference ,  add_difference ) , 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        " No interpolation " :  ( filename_nothing ,  None ,  None ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    } 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    filename_generator ,  theta_func1 ,  theta_func2  =  theta_funcs [ interp_method ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    shared . state . job_count  =  ( 1  if  theta_func1  else  0 )  +  ( 1  if  theta_func2  else  0 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-18 19:13:15 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  not  primary_model_name : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 08:53:50 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        return  fail ( " Failed: Merging requires a primary model. " ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-18 19:13:15 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-29 00:59:44 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    primary_model_info  =  sd_models . checkpoints_list [ primary_model_name ] 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-18 19:13:15 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  theta_func2  and  not  secondary_model_name : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 08:53:50 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        return  fail ( " Failed: Merging requires a secondary model. " ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-27 10:44:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    secondary_model_info  =  sd_models . checkpoints_list [ secondary_model_name ]  if  theta_func2  else  None 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-27 10:44:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-18 19:13:15 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  theta_func1  and  not  tertiary_model_name : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 08:53:50 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        return  fail ( f " Failed: Interpolation method ( { interp_method } ) requires a tertiary model. " ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-18 21:21:52 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    tertiary_model_info  =  sd_models . checkpoints_list [ tertiary_model_name ]  if  theta_func1  else  None 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-18 19:13:15 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    result_is_inpainting_model  =  False 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-26 06:05:40 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    result_is_instruct_pix2pix_model  =  False 
							 
						 
					
						
							
								
									
										
										
										
											2022-12-04 01:13:36 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  theta_func2 : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        shared . state . textinfo  =  f " Loading B " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( f " Loading  { secondary_model_info . filename } ... " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        theta_1  =  sd_models . read_state_dict ( secondary_model_info . filename ,  map_location = ' cpu ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        theta_1  =  None 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-14 09:05:06 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-16 18:44:39 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  theta_func1 : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        shared . state . textinfo  =  f " Loading C " 
							 
						 
					
						
							
								
									
										
										
										
											2022-12-04 01:13:36 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        print ( f " Loading  { tertiary_model_info . filename } ... " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        theta_2  =  sd_models . read_state_dict ( tertiary_model_info . filename ,  map_location = ' cpu ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        shared . state . textinfo  =  ' Merging B and C ' 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 09:25:37 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        shared . state . sampling_steps  =  len ( theta_1 . keys ( ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-16 18:44:39 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        for  key  in  tqdm . tqdm ( theta_1 . keys ( ) ) : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:24:17 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            if  key  in  checkpoint_dict_skip_on_merge : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                continue 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-10-16 18:44:39 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            if  ' model '  in  key : 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-18 15:33:24 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                if  key  in  theta_2 : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    t2  =  theta_2 . get ( key ,  torch . zeros_like ( theta_1 [ key ] ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    theta_1 [ key ]  =  theta_func1 ( theta_1 [ key ] ,  t2 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                else : 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-18 16:05:52 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                    theta_1 [ key ]  =  torch . zeros_like ( theta_1 [ key ] ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 09:25:37 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            shared . state . sampling_step  + =  1 
							 
						 
					
						
							
								
									
										
										
										
											2022-12-04 01:13:36 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        del  theta_2 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 09:25:37 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        shared . state . nextjob ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-03 10:21:51 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    shared . state . textinfo  =  f " Loading  { primary_model_info . filename } ... " 
							 
						 
					
						
							
								
									
										
										
										
											2022-12-04 01:13:36 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    print ( f " Loading  { primary_model_info . filename } ... " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    theta_0  =  sd_models . read_state_dict ( primary_model_info . filename ,  map_location = ' cpu ' ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    print ( " Merging... " ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    shared . state . textinfo  =  ' Merging A and B ' 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 09:25:37 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    shared . state . sampling_steps  =  len ( theta_0 . keys ( ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-27 10:44:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    for  key  in  tqdm . tqdm ( theta_0 . keys ( ) ) : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        if  theta_1  and  ' model '  in  key  and  key  in  theta_1 : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 20:00:00 +09:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:24:17 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            if  key  in  checkpoint_dict_skip_on_merge : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 20:00:00 +09:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                continue 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-12-04 12:30:44 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            a  =  theta_0 [ key ] 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            b  =  theta_1 [ key ] 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-14 21:20:28 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-12-04 12:30:44 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            # this enables merging an inpainting model (A) with another one (B); 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            # where normal model would have 4 channels, for latenst space, inpainting model would 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            if  a . shape  !=  b . shape  and  a . shape [ 0 : 1 ]  +  a . shape [ 2 : ]  ==  b . shape [ 0 : 1 ]  +  b . shape [ 2 : ] : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                if  a . shape [ 1 ]  ==  4  and  b . shape [ 1 ]  ==  9 : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    raise  RuntimeError ( " When merging inpainting model with a normal one, A must be the inpainting model. " ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-26 04:38:04 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                if  a . shape [ 1 ]  ==  4  and  b . shape [ 1 ]  ==  8 : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-26 06:05:40 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                    raise  RuntimeError ( " When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model. " ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-12-04 12:30:44 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-26 06:05:40 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                if  a . shape [ 1 ]  ==  8  and  b . shape [ 1 ]  ==  4 : #If we have an Instruct-Pix2Pix model... 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-26 03:45:16 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                    theta_0 [ key ] [ : ,  0 : 4 ,  : ,  : ]  =  theta_func2 ( a [ : ,  0 : 4 ,  : ,  : ] ,  b ,  multiplier ) #Merge only the vectors the models have in common.  Otherwise we get an error due to dimension mismatch. 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-26 06:05:40 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                    result_is_instruct_pix2pix_model  =  True 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-26 03:45:16 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    assert  a . shape [ 1 ]  ==  9  and  b . shape [ 1 ]  ==  4 ,  f " Bad dimensions for merged layer  { key } : A= { a . shape } , B= { b . shape } " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    theta_0 [ key ] [ : ,  0 : 4 ,  : ,  : ]  =  theta_func2 ( a [ : ,  0 : 4 ,  : ,  : ] ,  b ,  multiplier ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                    result_is_inpainting_model  =  True 
							 
						 
					
						
							
								
									
										
										
										
											2022-12-04 12:30:44 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                theta_0 [ key ]  =  theta_func2 ( a ,  b ,  multiplier ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-26 03:45:16 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 12:12:09 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								            theta_0 [ key ]  =  to_half ( theta_0 [ key ] ,  save_as_half ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-09 19:26:52 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 09:25:37 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        shared . state . sampling_step  + =  1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    del  theta_1 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    bake_in_vae_filename  =  sd_vae . vae_dict . get ( bake_in_vae ,  None ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  bake_in_vae_filename  is  not  None : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        print ( f " Baking in VAE from  { bake_in_vae_filename } " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        shared . state . textinfo  =  ' Baking in VAE ' 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        vae_dict  =  sd_vae . load_vae_dict ( bake_in_vae_filename ,  map_location = ' cpu ' ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 20:00:00 +09:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        for  key  in  vae_dict . keys ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            theta_0_key  =  ' first_stage_model. '  +  key 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            if  theta_0_key  in  theta_0 : 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 12:12:09 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								                theta_0 [ theta_0_key ]  =  to_half ( vae_dict [ key ] ,  save_as_half ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-14 20:00:00 +09:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        del  vae_dict 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-27 10:44:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 12:12:09 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  save_as_half  and  not  theta_func2 : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        for  key  in  theta_0 . keys ( ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            theta_0 [ key ]  =  to_half ( theta_0 [ key ] ,  save_as_half ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-22 10:17:12 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    if  discard_weights : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        regex  =  re . compile ( discard_weights ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        for  key  in  list ( theta_0 ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            if  re . search ( regex ,  key ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								                theta_0 . pop ( key ,  None ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-30 22:57:25 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    ckpt_dir  =  shared . cmd_opts . ckpt_dir  or  sd_models . model_path 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    filename  =  filename_generator ( )  if  custom_name  ==  ' '  else  custom_name 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    filename  + =  " .inpainting "  if  result_is_inpainting_model  else  " " 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-26 11:27:07 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    filename  + =  " .instruct-pix2pix "  if  result_is_instruct_pix2pix_model  else  " " 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    filename  + =  " . "  +  checkpoint_format 
							 
						 
					
						
							
								
									
										
										
										
											2022-12-04 12:30:44 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-30 22:57:25 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    output_modelname  =  os . path . join ( ckpt_dir ,  filename ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-29 00:21:54 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 09:25:37 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    shared . state . nextjob ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    shared . state . textinfo  =  " Saving " 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-27 10:44:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    print ( f " Saving to  { output_modelname } ... " ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-11-27 15:51:29 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    _ ,  extension  =  os . path . splitext ( output_modelname ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  extension . lower ( )  ==  " .safetensors " : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        safetensors . torch . save_file ( theta_0 ,  output_modelname ,  metadata = { " format " :  " pt " } ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    else : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        torch . save ( theta_0 ,  output_modelname ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-09-27 10:44:00 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-09-29 00:59:44 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    sd_models . list_models ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-11 09:10:07 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    create_config ( output_modelname ,  config_source ,  primary_model_info ,  secondary_model_info ,  tertiary_model_info ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 10:39:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    print ( f " Checkpoint saved to  { output_modelname } . " ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    shared . state . textinfo  =  " Checkpoint saved " 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-03 10:21:51 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    shared . state . end ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-19 09:25:37 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    return  [ * [ gr . Dropdown . update ( choices = sd_models . checkpoint_tiles ( ) )  for  _  in  range ( 4 ) ] ,  " Checkpoint saved to  "  +  output_modelname ]