mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-11-04 12:03:36 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			133 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			133 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from __future__ import annotations
 | 
						|
 | 
						|
import functools
 | 
						|
import logging
 | 
						|
from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, sd_samplers_lcm, shared, sd_samplers_common, sd_schedulers
 | 
						|
 | 
						|
# Aliases kept for backward compatibility: these functions used to live in this
# module and other modules still import them from here.
samples_to_image_grid = sd_samplers_common.samples_to_image_grid
sample_to_image = sd_samplers_common.sample_to_image

# Every sampler known to this module: k-diffusion samplers, then timestep-based
# samplers, then LCM, concatenated in that order.
all_samplers = [
    *sd_samplers_kdiffusion.samplers_data_k_diffusion,
    *sd_samplers_timesteps.samplers_data_timesteps,
    *sd_samplers_lcm.samplers_data_lcm,
]
# exact sampler name -> SamplerData
all_samplers_map = {x.name: x for x in all_samplers}

# The tables below are (re)populated by set_samplers().
samplers: list[sd_samplers_common.SamplerData] = []
samplers_for_img2img: list[sd_samplers_common.SamplerData] = []
# lowercased sampler name or alias -> canonical sampler name
samplers_map: dict[str, str] = {}
# names of samplers listed in shared.opts.hide_samplers; a set (not a dict) so
# its type matches what set_samplers() assigns
samplers_hidden: set[str] = set()
 | 
						|
 | 
						|
 | 
						|
def find_sampler_config(name):
    """Return the SamplerData registered under ``name``.

    ``None`` selects the first entry of ``all_samplers``. Any other value is
    looked up by exact name in ``all_samplers_map``; unknown names yield
    ``None``.
    """
    if name is None:
        return all_samplers[0]

    return all_samplers_map.get(name)
 | 
						|
 | 
						|
 | 
						|
def create_sampler(name, model):
    """Build the sampler called ``name`` for ``model``.

    The instance comes from the config's ``constructor`` and carries the config
    as ``.config``.

    Raises:
        AssertionError: (via ``assert``) if no sampler config matches ``name``.
        Exception: if ``model.is_sdxl`` is truthy and the config's options set
            ``no_sdxl``.
    """
    config = find_sampler_config(name)
    assert config is not None, f'bad sampler name: {name}'

    # model.is_sdxl is checked first, so the options lookup only happens for SDXL models
    if model.is_sdxl and config.options.get("no_sdxl", False):
        raise Exception(f"Sampler {config.name} is not supported for SDXL")

    instance = config.constructor(model)
    instance.config = config
    return instance
 | 
						|
 | 
						|
 | 
						|
def set_samplers():
    """(Re)build the module-level sampler tables.

    - ``samplers_hidden`` becomes a set built from ``shared.opts.hide_samplers``.
    - ``samplers`` and ``samplers_for_img2img`` are bound to ``all_samplers``.
    - ``samplers_map`` is refilled in place, mapping each sampler's lowercased
      name and aliases to its canonical name.
    """
    global samplers, samplers_for_img2img, samplers_hidden

    samplers_hidden = set(shared.opts.hide_samplers)
    samplers = samplers_for_img2img = all_samplers

    # mutated in place (not rebound), so no global declaration is needed for it
    samplers_map.clear()
    for entry in all_samplers:
        for key in (entry.name, *entry.aliases):
            samplers_map[key.lower()] = entry.name
 | 
						|
 | 
						|
 | 
						|
def visible_sampler_names():
    """Return the names of the samplers whose name is not in ``samplers_hidden``."""
    return [sampler.name for sampler in visible_samplers()]
 | 
						|
 | 
						|
 | 
						|
def visible_samplers():
    """Return the entries of ``samplers`` whose name is not in ``samplers_hidden``."""
    return list(filter(lambda sampler: sampler.name not in samplers_hidden, samplers))
 | 
						|
 | 
						|
 | 
						|
def get_sampler_from_infotext(d: dict):
    """Return the normalized sampler name from infotext ``Sampler``/``Schedule type``."""
    sampler_name, _scheduler_label = get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))
    return sampler_name
 | 
						|
 | 
						|
 | 
						|
def get_scheduler_from_infotext(d: dict):
    """Return the normalized scheduler label from infotext ``Sampler``/``Schedule type``."""
    _sampler_name, scheduler_label = get_sampler_and_scheduler(d.get("Sampler"), d.get("Schedule type"))
    return scheduler_label
 | 
						|
 | 
						|
 | 
						|
def get_hr_sampler_and_scheduler(d: dict):
    """Resolve the hires-fix sampler and scheduler from infotext ``d``.

    "Hires sampler" / "Hires schedule type" default to the "Use same sampler" /
    "Use same scheduler" placeholders. In that case the main "Sampler" /
    "Schedule type" values are used instead.

    The resulting pair is normalized with get_sampler_and_scheduler(). Any part
    equal to the raw main value is reported back as its placeholder.

    Returns:
        tuple: (hires sampler, hires scheduler)
    """
    same_sampler = "Use same sampler"
    same_scheduler = "Use same scheduler"
    main_sampler = d.get("Sampler")
    main_scheduler = d.get("Schedule type")

    hr_sampler = d.get("Hires sampler", same_sampler)
    hr_scheduler = d.get("Hires schedule type", same_scheduler)

    sampler, scheduler = get_sampler_and_scheduler(
        main_sampler if hr_sampler == same_sampler else hr_sampler,
        main_scheduler if hr_scheduler == same_scheduler else hr_scheduler,
    )

    if sampler == main_sampler:
        sampler = same_sampler
    if scheduler == main_scheduler:
        scheduler = same_scheduler

    return sampler, scheduler
 | 
						|
 | 
						|
 | 
						|
def get_hr_sampler_from_infotext(d: dict):
    """Return the hires-fix sampler (or "Use same sampler") for infotext ``d``."""
    hr_sampler, _hr_scheduler = get_hr_sampler_and_scheduler(d)
    return hr_sampler
 | 
						|
 | 
						|
 | 
						|
def get_hr_scheduler_from_infotext(d: dict):
    """Return the hires-fix scheduler (or "Use same scheduler") for infotext ``d``."""
    _hr_sampler, hr_scheduler = get_hr_sampler_and_scheduler(d)
    return hr_scheduler
 | 
						|
 | 
						|
 | 
						|
@functools.cache
def get_sampler_and_scheduler(sampler_name, scheduler_name):
    """Normalize a (sampler, scheduler) pair, e.g. one read from infotext.

    A sampler name may carry a scheduler as a trailing " <label/name/alias>"
    suffix. When present, the suffix is split off and that scheduler overrides
    ``scheduler_name``.

    Args:
        sampler_name: sampler name, optionally with a scheduler suffix. A falsy
            value selects the first entry of ``samplers``.
        scheduler_name: key looked up in ``sd_schedulers.schedulers_map``.
            Unknown keys fall back to ``sd_schedulers.schedulers[0]``.

    Returns:
        tuple: ``(sampler name, scheduler label)``. Unknown sampler names fall
        back to the first entry of ``samplers``. If the resolved scheduler is
        the sampler's own default (its ``options['scheduler']``), the label of
        ``sd_schedulers.schedulers[0]`` (Automatic) is returned instead.

    Results are memoized by ``functools.cache``, so arguments must be hashable.
    """
    default_sampler = samplers[0]
    found_scheduler = sd_schedulers.schedulers_map.get(scheduler_name, sd_schedulers.schedulers[0])

    name = sampler_name or default_sampler.name

    # Pick the single longest matching " <option>" suffix and strip it once.
    # With a first-match-per-scheduler scan, a label that is itself a suffix of
    # another label (e.g. "Uniform" vs "SGM Uniform", if both exist) could win
    # just by being listed earlier. Later schedulers could also match and strip
    # the already-shortened name a second time.
    # Ties keep the earliest scheduler, as the original scan did.
    best_scheduler = None
    best_option = None
    for scheduler in sd_schedulers.schedulers:
        for name_option in (scheduler.label, scheduler.name, *(scheduler.aliases or [])):
            if not name.endswith(" " + name_option):
                continue
            if best_option is None or len(name_option) > len(best_option):
                best_scheduler, best_option = scheduler, name_option

    if best_scheduler is not None:
        found_scheduler = best_scheduler
        name = name[:-(len(best_option) + 1)]

    sampler = all_samplers_map.get(name, default_sampler)

    # revert back to Automatic if it's the default scheduler for the selected sampler
    if sampler.options.get('scheduler', None) == found_scheduler.name:
        found_scheduler = sd_schedulers.schedulers[0]

    return sampler.name, found_scheduler.label
 | 
						|
 | 
						|
 | 
						|
def fix_p_invalid_sampler_and_scheduler(p):
    """Normalize ``p.sampler_name`` / ``p.scheduler`` in place.

    Both values go through get_sampler_and_scheduler(). A warning is logged if
    either one changed.
    """
    before = (p.sampler_name, p.scheduler)
    p.sampler_name, p.scheduler = get_sampler_and_scheduler(*before)
    after = (p.sampler_name, p.scheduler)

    if after != before:
        logging.warning(f'Sampler Scheduler autocorrection: "{before[0]}" -> "{after[0]}", "{before[1]}" -> "{after[1]}"')
 | 
						|
 | 
						|
 | 
						|
# Populate the module-level sampler tables (samplers, samplers_map, ...) at import time.
set_samplers()
 |