mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-11-04 03:55:05 +00:00 
			
		
		
		
	make existing script loading and new preload code use same code for loading modules
limit extension preload scripts to just one file named preload.py
This commit is contained in:
		
							parent
							
								
									e5690d0bf2
								
							
						
					
					
						commit
						a1a376331c
					
				@ -1,7 +1,6 @@
 | 
				
			|||||||
import os
 | 
					import os
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
import traceback
 | 
					import traceback
 | 
				
			||||||
from importlib.machinery import SourceFileLoader
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
import git
 | 
					import git
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -85,23 +84,3 @@ def list_extensions():
 | 
				
			|||||||
        extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
 | 
					        extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
 | 
				
			||||||
        extensions.append(extension)
 | 
					        extensions.append(extension)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
def preload_extensions(parser):
 | 
					 | 
				
			||||||
    if not os.path.isdir(extensions_dir):
 | 
					 | 
				
			||||||
        return
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for dirname in sorted(os.listdir(extensions_dir)):
 | 
					 | 
				
			||||||
        path = os.path.join(extensions_dir, dirname)
 | 
					 | 
				
			||||||
        if not os.path.isdir(path):
 | 
					 | 
				
			||||||
            continue
 | 
					 | 
				
			||||||
        for file in os.listdir(path):
 | 
					 | 
				
			||||||
            if "preload.py" in file:
 | 
					 | 
				
			||||||
                full_file = os.path.join(path, file)
 | 
					 | 
				
			||||||
                print(f"Got preload file: {full_file}")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                try:
 | 
					 | 
				
			||||||
                    ext = SourceFileLoader("preload", full_file).load_module()
 | 
					 | 
				
			||||||
                    parser = ext.preload(parser)
 | 
					 | 
				
			||||||
                except Exception as e:
 | 
					 | 
				
			||||||
                    print(f"Exception preloading script: {e}")
 | 
					 | 
				
			||||||
    return parser
 | 
					 | 
				
			||||||
							
								
								
									
										34
									
								
								modules/script_loading.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										34
									
								
								modules/script_loading.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,34 @@
 | 
				
			|||||||
 | 
					import os
 | 
				
			||||||
 | 
					import sys
 | 
				
			||||||
 | 
					import traceback
 | 
				
			||||||
 | 
					from types import ModuleType
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def load_module(path):
 | 
				
			||||||
 | 
					    with open(path, "r", encoding="utf8") as file:
 | 
				
			||||||
 | 
					        text = file.read()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    compiled = compile(text, path, 'exec')
 | 
				
			||||||
 | 
					    module = ModuleType(os.path.basename(path))
 | 
				
			||||||
 | 
					    exec(compiled, module.__dict__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def preload_extensions(extensions_dir, parser):
 | 
				
			||||||
 | 
					    if not os.path.isdir(extensions_dir):
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for dirname in sorted(os.listdir(extensions_dir)):
 | 
				
			||||||
 | 
					        preload_script = os.path.join(extensions_dir, dirname, "preload.py")
 | 
				
			||||||
 | 
					        if not os.path.isfile(preload_script):
 | 
				
			||||||
 | 
					            continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            module = load_module(preload_script)
 | 
				
			||||||
 | 
					            if hasattr(module, 'preload'):
 | 
				
			||||||
 | 
					                module.preload(parser)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        except Exception:
 | 
				
			||||||
 | 
					            print(f"Error running preload() for {preload_script}", file=sys.stderr)
 | 
				
			||||||
 | 
					            print(traceback.format_exc(), file=sys.stderr)
 | 
				
			||||||
@ -6,7 +6,7 @@ from collections import namedtuple
 | 
				
			|||||||
import gradio as gr
 | 
					import gradio as gr
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from modules.processing import StableDiffusionProcessing
 | 
					from modules.processing import StableDiffusionProcessing
 | 
				
			||||||
from modules import shared, paths, script_callbacks, extensions
 | 
					from modules import shared, paths, script_callbacks, extensions, script_loading
 | 
				
			||||||
 | 
					
 | 
				
			||||||
AlwaysVisible = object()
 | 
					AlwaysVisible = object()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -161,13 +161,7 @@ def load_scripts():
 | 
				
			|||||||
                sys.path = [scriptfile.basedir] + sys.path
 | 
					                sys.path = [scriptfile.basedir] + sys.path
 | 
				
			||||||
            current_basedir = scriptfile.basedir
 | 
					            current_basedir = scriptfile.basedir
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            with open(scriptfile.path, "r", encoding="utf8") as file:
 | 
					            module = script_loading.load_module(scriptfile.path)
 | 
				
			||||||
                text = file.read()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            from types import ModuleType
 | 
					 | 
				
			||||||
            compiled = compile(text, scriptfile.path, 'exec')
 | 
					 | 
				
			||||||
            module = ModuleType(scriptfile.filename)
 | 
					 | 
				
			||||||
            exec(compiled, module.__dict__)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            for key, script_class in module.__dict__.items():
 | 
					            for key, script_class in module.__dict__.items():
 | 
				
			||||||
                if type(script_class) == type and issubclass(script_class, Script):
 | 
					                if type(script_class) == type and issubclass(script_class, Script):
 | 
				
			||||||
@ -328,27 +322,21 @@ class ScriptRunner:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def reload_sources(self, cache):
 | 
					    def reload_sources(self, cache):
 | 
				
			||||||
        for si, script in list(enumerate(self.scripts)):
 | 
					        for si, script in list(enumerate(self.scripts)):
 | 
				
			||||||
            with open(script.filename, "r", encoding="utf8") as file:
 | 
					            args_from = script.args_from
 | 
				
			||||||
                args_from = script.args_from
 | 
					            args_to = script.args_to
 | 
				
			||||||
                args_to = script.args_to
 | 
					            filename = script.filename
 | 
				
			||||||
                filename = script.filename
 | 
					 | 
				
			||||||
                text = file.read()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                from types import ModuleType
 | 
					            module = cache.get(filename, None)
 | 
				
			||||||
 | 
					            if module is None:
 | 
				
			||||||
 | 
					                module = script_loading.load_module(script.filename)
 | 
				
			||||||
 | 
					                cache[filename] = module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                module = cache.get(filename, None)
 | 
					            for key, script_class in module.__dict__.items():
 | 
				
			||||||
                if module is None:
 | 
					                if type(script_class) == type and issubclass(script_class, Script):
 | 
				
			||||||
                    compiled = compile(text, filename, 'exec')
 | 
					                    self.scripts[si] = script_class()
 | 
				
			||||||
                    module = ModuleType(script.filename)
 | 
					                    self.scripts[si].filename = filename
 | 
				
			||||||
                    exec(compiled, module.__dict__)
 | 
					                    self.scripts[si].args_from = args_from
 | 
				
			||||||
                    cache[filename] = module
 | 
					                    self.scripts[si].args_to = args_to
 | 
				
			||||||
 | 
					 | 
				
			||||||
                for key, script_class in module.__dict__.items():
 | 
					 | 
				
			||||||
                    if type(script_class) == type and issubclass(script_class, Script):
 | 
					 | 
				
			||||||
                        self.scripts[si] = script_class()
 | 
					 | 
				
			||||||
                        self.scripts[si].filename = filename
 | 
					 | 
				
			||||||
                        self.scripts[si].args_from = args_from
 | 
					 | 
				
			||||||
                        self.scripts[si].args_to = args_to
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
scripts_txt2img = ScriptRunner()
 | 
					scripts_txt2img = ScriptRunner()
 | 
				
			||||||
 | 
				
			|||||||
@ -3,7 +3,6 @@ import datetime
 | 
				
			|||||||
import json
 | 
					import json
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
from collections import OrderedDict
 | 
					 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import gradio as gr
 | 
					import gradio as gr
 | 
				
			||||||
@ -15,7 +14,7 @@ import modules.memmon
 | 
				
			|||||||
import modules.sd_models
 | 
					import modules.sd_models
 | 
				
			||||||
import modules.styles
 | 
					import modules.styles
 | 
				
			||||||
import modules.devices as devices
 | 
					import modules.devices as devices
 | 
				
			||||||
from modules import sd_samplers, sd_models, localization, sd_vae, extensions
 | 
					from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading
 | 
				
			||||||
from modules.hypernetworks import hypernetwork
 | 
					from modules.hypernetworks import hypernetwork
 | 
				
			||||||
from modules.paths import models_path, script_path, sd_path
 | 
					from modules.paths import models_path, script_path, sd_path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -91,7 +90,7 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ
 | 
				
			|||||||
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
 | 
					parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
 | 
				
			||||||
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
 | 
					parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
extensions.preload_extensions(parser)
 | 
					script_loading.preload_extensions(extensions.extensions_dir, parser)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cmd_opts = parser.parse_args()
 | 
					cmd_opts = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user