mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-11-04 03:55:05 +00:00 
			
		
		
		
	init job and add info to model merge
This commit is contained in:
		
							parent
							
								
									e9fb9bb0c2
								
							
						
					
					
						commit
						1d9dc48efd
					
				@ -242,6 +242,9 @@ def run_pnginfo(image):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
 | 
			
		||||
    shared.state.begin()
 | 
			
		||||
    shared.state.job = 'model-merge'
 | 
			
		||||
 | 
			
		||||
    def weighted_sum(theta0, theta1, alpha):
 | 
			
		||||
        return ((1 - alpha) * theta0) + (alpha * theta1)
 | 
			
		||||
 | 
			
		||||
@ -263,8 +266,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
 | 
			
		||||
    theta_func1, theta_func2 = theta_funcs[interp_method]
 | 
			
		||||
 | 
			
		||||
    if theta_func1 and not tertiary_model_info:
 | 
			
		||||
        shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
 | 
			
		||||
        shared.state.end()
 | 
			
		||||
        return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
 | 
			
		||||
 | 
			
		||||
    shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
 | 
			
		||||
    print(f"Loading {secondary_model_info.filename}...")
 | 
			
		||||
    theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
 | 
			
		||||
 | 
			
		||||
@ -281,6 +287,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
 | 
			
		||||
                    theta_1[key] = torch.zeros_like(theta_1[key])
 | 
			
		||||
        del theta_2
 | 
			
		||||
 | 
			
		||||
    shared.state.textinfo = f"Loading {primary_model_info.filename}..."
 | 
			
		||||
    print(f"Loading {primary_model_info.filename}...")
 | 
			
		||||
    theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
 | 
			
		||||
 | 
			
		||||
@ -291,6 +298,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
 | 
			
		||||
            a = theta_0[key]
 | 
			
		||||
            b = theta_1[key]
 | 
			
		||||
 | 
			
		||||
            shared.state.textinfo = f'Merging layer {key}'
 | 
			
		||||
            # this enables merging an inpainting model (A) with another one (B);
 | 
			
		||||
            # where normal model would have 4 channels, for latenst space, inpainting model would
 | 
			
		||||
            # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
 | 
			
		||||
@ -303,8 +311,6 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
 | 
			
		||||
                theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
 | 
			
		||||
                result_is_inpainting_model = True
 | 
			
		||||
            else:
 | 
			
		||||
                assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'
 | 
			
		||||
 | 
			
		||||
                theta_0[key] = theta_func2(a, b, multiplier)
 | 
			
		||||
 | 
			
		||||
            if save_as_half:
 | 
			
		||||
@ -332,6 +338,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
 | 
			
		||||
 | 
			
		||||
    output_modelname = os.path.join(ckpt_dir, filename)
 | 
			
		||||
 | 
			
		||||
    shared.state.textinfo = f"Saving to {output_modelname}..."
 | 
			
		||||
    print(f"Saving to {output_modelname}...")
 | 
			
		||||
 | 
			
		||||
    _, extension = os.path.splitext(output_modelname)
 | 
			
		||||
@ -343,4 +350,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
 | 
			
		||||
    sd_models.list_models()
 | 
			
		||||
 | 
			
		||||
    print("Checkpoint saved.")
 | 
			
		||||
    shared.state.textinfo = "Checkpoint saved to " + output_modelname
 | 
			
		||||
    shared.state.end()
 | 
			
		||||
 | 
			
		||||
    return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user