mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-11-03 19:44:27 +00:00 
			
		
		
		
	xy_grid: Refactor confirm functions
This commit is contained in:
		
							parent
							
								
									7dba1c07cb
								
							
						
					
					
						commit
						2fffd4bddc
					
				@ -77,12 +77,26 @@ def apply_sampler(p, x, xs):
 | 
			
		||||
    p.sampler_index = sampler_index
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def confirm_samplers(p, xs):
 | 
			
		||||
    samplers_dict = build_samplers_dict(p)
 | 
			
		||||
    for x in xs:
 | 
			
		||||
        if x.lower() not in samplers_dict.keys():
 | 
			
		||||
            raise RuntimeError(f"Unknown sampler: {x}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def apply_checkpoint(p, x, xs):
 | 
			
		||||
    info = modules.sd_models.get_closet_checkpoint_match(x)
 | 
			
		||||
    assert info is not None, f'Checkpoint for {x} not found'
 | 
			
		||||
    if info is None:
 | 
			
		||||
        raise RuntimeError(f"Unknown checkpoint: {x}")
 | 
			
		||||
    modules.sd_models.reload_model_weights(shared.sd_model, info)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def confirm_checkpoints(p, xs):
 | 
			
		||||
    for x in xs:
 | 
			
		||||
        if modules.sd_models.get_closet_checkpoint_match(x) is None:
 | 
			
		||||
            raise RuntimeError(f"Unknown checkpoint: {x}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def apply_hypernetwork(p, x, xs):
 | 
			
		||||
    if x.lower() in ["", "none"]:
 | 
			
		||||
        name = None
 | 
			
		||||
@ -93,7 +107,7 @@ def apply_hypernetwork(p, x, xs):
 | 
			
		||||
    hypernetwork.load_hypernetwork(name)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def confirm_hypernetworks(xs):
 | 
			
		||||
def confirm_hypernetworks(p, xs):
 | 
			
		||||
    for x in xs:
 | 
			
		||||
        if x.lower() in ["", "none"]:
 | 
			
		||||
            continue
 | 
			
		||||
@ -135,29 +149,29 @@ def str_permutations(x):
 | 
			
		||||
    return x
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
 | 
			
		||||
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
 | 
			
		||||
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
 | 
			
		||||
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
axis_options = [
 | 
			
		||||
    AxisOption("Nothing", str, do_nothing, format_nothing),
 | 
			
		||||
    AxisOption("Seed", int, apply_field("seed"), format_value_add_label),
 | 
			
		||||
    AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label),
 | 
			
		||||
    AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label),
 | 
			
		||||
    AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
 | 
			
		||||
    AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
 | 
			
		||||
    AxisOption("Prompt S/R", str, apply_prompt, format_value),
 | 
			
		||||
    AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list),
 | 
			
		||||
    AxisOption("Sampler", str, apply_sampler, format_value),
 | 
			
		||||
    AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
 | 
			
		||||
    AxisOption("Hypernetwork", str, apply_hypernetwork, format_value),
 | 
			
		||||
    AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
 | 
			
		||||
    AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label),
 | 
			
		||||
    AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
 | 
			
		||||
    AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
 | 
			
		||||
    AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
 | 
			
		||||
    AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label),
 | 
			
		||||
    AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label),  # as it is now all AxisOptionImg2Img items must go after AxisOption ones
 | 
			
		||||
    AxisOption("Nothing", str, do_nothing, format_nothing, None),
 | 
			
		||||
    AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None),
 | 
			
		||||
    AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None),
 | 
			
		||||
    AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None),
 | 
			
		||||
    AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None),
 | 
			
		||||
    AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None),
 | 
			
		||||
    AxisOption("Prompt S/R", str, apply_prompt, format_value, None),
 | 
			
		||||
    AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None),
 | 
			
		||||
    AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
 | 
			
		||||
    AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
 | 
			
		||||
    AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
 | 
			
		||||
    AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
 | 
			
		||||
    AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
 | 
			
		||||
    AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
 | 
			
		||||
    AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None),
 | 
			
		||||
    AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
 | 
			
		||||
    AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
 | 
			
		||||
    AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),  # as it is now all AxisOptionImg2Img items must go after AxisOption ones
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -283,19 +297,10 @@ class Script(scripts.Script):
 | 
			
		||||
                valslist = list(permutations(valslist))
 | 
			
		||||
 | 
			
		||||
            valslist = [opt.type(x) for x in valslist]
 | 
			
		||||
            
 | 
			
		||||
 | 
			
		||||
            # Confirm options are valid before starting
 | 
			
		||||
            if opt.label == "Sampler":
 | 
			
		||||
                samplers_dict = build_samplers_dict(p)
 | 
			
		||||
                for sampler_val in valslist:
 | 
			
		||||
                    if sampler_val.lower() not in samplers_dict.keys():
 | 
			
		||||
                        raise RuntimeError(f"Unknown sampler: {sampler_val}")
 | 
			
		||||
            elif opt.label == "Checkpoint name":
 | 
			
		||||
                for ckpt_val in valslist:
 | 
			
		||||
                    if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None:
 | 
			
		||||
                        raise RuntimeError(f"Checkpoint for {ckpt_val} not found")
 | 
			
		||||
            elif opt.label == "Hypernetwork":
 | 
			
		||||
                confirm_hypernetworks(valslist)
 | 
			
		||||
            if opt.confirm:
 | 
			
		||||
                opt.confirm(p, valslist)
 | 
			
		||||
 | 
			
		||||
            return valslist
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user