mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-11-04 03:55:05 +00:00 
			
		
		
		
	
		
			
	
	
		
			148 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			148 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| 
								 | 
							
								import re
							 | 
						||
| 
								 | 
							
								from collections import defaultdict
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from modules import errors
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# Maps an extra network's name to its registered ExtraNetwork object.
								# Filled by register_extra_network() and emptied in place by initialize().
								extra_network_registry = {}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def initialize():
								    """Remove every entry from extra_network_registry.

								    The dictionary is cleared in place, so the module-level object stays
								    the same and only its contents are dropped.
								    """
								    extra_network_registry.clear()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def register_extra_network(extra_network):
								    """Store extra_network in extra_network_registry under its ``name``.

								    Registering a second object with the same name replaces the first one.
								    """
								    extra_network_registry[extra_network.name] = extra_network
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class ExtraNetworkParams:
								    """Arguments of a single ``<name:arg1:arg2:...>`` tag taken from a prompt.

								    The arguments are kept, in order, as a list of strings in ``items``.
								    """

								    def __init__(self, items=None):
								        """Store items; any falsy value (None, empty list) becomes a new empty list."""
								        # ``or []`` creates a fresh list per instance, so no default is shared.
								        self.items = items or []

								    def __repr__(self):
								        """Return a readable form, e.g. ``ExtraNetworkParams(items=['agm', '1.1'])``."""
								        # Lists of these objects are interpolated into error messages; without
								        # this they would show up as opaque ``<... object at 0x...>`` entries.
								        return f"{type(self).__name__}(items={self.items!r})"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class ExtraNetwork:
								    """Base class for an extra network that can be referenced from a prompt.

								    Both activate() and deactivate() raise NotImplementedError here, so a
								    subclass has to override each of them.
								    """

								    def __init__(self, name):
								        """Remember name, the identifier used in ``<name:...>`` prompt tags."""
								        self.name = name

								    def activate(self, p, params_list):
								        """
								        Called by processing on every run. Whatever the extra network is meant to do
								        should be switched on here, using the arguments passed in params_list.

								        The user supplies arguments by writing a tag in the prompt:

								        <name:arg1:arg2:arg3>

								        where name matches the name of this ExtraNetwork object and arg1:arg2:arg3 are
								        any number of text arguments separated by colons.

								        Even if the user does not mention this ExtraNetwork in the prompt, the call is
								        still made, with an empty params_list - in that case all effects of this extra
								        network should be disabled.

								        Can be called multiple times before deactivate() - each new call should
								        override the previous call completely.

								        For example, if this ExtraNetwork's name is 'hypernet' and the user's prompt is:

								        > "1girl, <hypernet:agm:1.1> <extrasupernet:master:12:13:14> <hypernet:ray>"

								        params_list will be:

								        [
								            ExtraNetworkParams(items=["agm", "1.1"]),
								            ExtraNetworkParams(items=["ray"])
								        ]

								        The base implementation always raises NotImplementedError.
								        """
								        raise NotImplementedError()

								    def deactivate(self, p):
								        """
								        Called at the end of processing for housekeeping.

								        The base implementation raises NotImplementedError, so a subclass must
								        override this method even when it has nothing to clean up.
								        """
								        raise NotImplementedError()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def activate(p, extra_network_data):
								    """Activate extra networks for one processing run.

								    First, every entry of extra_network_data is handled in iteration order:
								    the network registered under that name gets activate(p, args); a name
								    with no registered network is reported with print() and skipped.

								    Second, every registered network for which extra_network_data yields no
								    value (missing key, or a None value) gets activate(p, []).

								    Exceptions raised by a network are passed to errors.display() and do not
								    stop the remaining networks from being processed.
								    """
								    # Pass 1: networks named in the parsed prompt data, in the order given.
								    for name, params in extra_network_data.items():
								        network = extra_network_registry.get(name)
								        if network is not None:
								            try:
								                network.activate(p, params)
								            except Exception as exc:
								                errors.display(exc, f"activating extra network {name} with arguments {params}")
								        else:
								            print(f"Skipping unknown extra network: {name}")

								    # Pass 2: all other registered networks, called with an empty argument list.
								    for name, network in extra_network_registry.items():
								        if extra_network_data.get(name) is None:
								            try:
								                network.activate(p, [])
								            except Exception as exc:
								                errors.display(exc, f"activating extra network {name}")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def deactivate(p, extra_network_data):
								    """Deactivate extra networks after a processing run.

								    First, for every name in extra_network_data (in iteration order) the
								    network registered under that name gets deactivate(p); names with no
								    registered network are skipped silently.

								    Second, every registered network for which extra_network_data yields no
								    value (missing key, or a None value) gets deactivate(p) as well.

								    Exceptions raised by a network are passed to errors.display() and do not
								    stop the remaining networks from being processed.
								    """
								    # Pass 1: networks named in the parsed prompt data, in the order given.
								    for name in extra_network_data:
								        network = extra_network_registry.get(name)
								        if network is not None:
								            try:
								                network.deactivate(p)
								            except Exception as exc:
								                errors.display(exc, f"deactivating extra network {name}")

								    # Pass 2: registered networks that the prompt data did not cover.
								    for name, network in extra_network_registry.items():
								        if extra_network_data.get(name) is None:
								            try:
								                network.deactivate(p)
								            except Exception as exc:
								                errors.display(exc, f"deactivating unmentioned extra network {name}")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								# Matches one extra-network tag of the form <name:args>.
								# Group 1 is the name (one or more word characters); group 2 is everything
								# after the first colon up to, but not including, the closing '>'.
								re_extra_net = re.compile(r"<(\w+):([^>]+)>")
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def parse_prompt(prompt):
								    """Strip extra-network tags from prompt and collect their arguments.

								    Every match of re_extra_net is removed from the text. For each match the
								    part after the name is split on ":" and wrapped in an ExtraNetworkParams,
								    which is appended to the list kept under the tag's name.

								    Returns a tuple (prompt_without_tags, params_by_name), where
								    params_by_name is a defaultdict(list) keyed by network name.
								    """
								    collected = defaultdict(list)

								    def _consume(match):
								        # Record the tag, then replace it with nothing in the output text.
								        network_name, raw_args = match.groups()
								        collected[network_name].append(ExtraNetworkParams(items=raw_args.split(":")))
								        return ""

								    stripped = re_extra_net.sub(_consume, prompt)
								    return stripped, collected
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def parse_prompts(prompts):
								    """Run parse_prompt() on each prompt and gather the results.

								    Returns a tuple (cleaned_prompts, extra_data): cleaned_prompts is the
								    list of prompts with their tags removed, and extra_data is the parsed
								    tag data of the first prompt only. For empty input extra_data is None.
								    """
								    parsed = [parse_prompt(text) for text in prompts]
								    cleaned = [text for text, _ in parsed]
								    # Only the first prompt's tag data is kept; later prompts' data is discarded.
								    first_data = parsed[0][1] if parsed else None
								    return cleaned, first_data
							 | 
						||
| 
								 | 
							
								
							 |