mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-10-31 01:54:44 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			198 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			198 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| Supports saving and restoring webui and extensions from a known working set of commits
 | |
| """
 | |
| 
 | |
| import os
 | |
| import json
 | |
| import time
 | |
| import tqdm
 | |
| 
 | |
| from datetime import datetime
 | |
| from collections import OrderedDict
 | |
| import git
 | |
| 
 | |
| from modules import shared, extensions, errors
 | |
| from modules.paths_internal import script_path, config_states_dir
 | |
| 
 | |
| 
 | |
all_config_states = OrderedDict()


def list_config_states():
    """Reload the saved config states found in ``config_states_dir``.

    The directory is created if it does not exist. Every ``*.json`` file in
    it is parsed, tagged with the path it was read from (key ``"filepath"``)
    and registered in the module-level ``all_config_states`` mapping under
    the display name ``"<name>: <creation time in UTC>"``, newest first.
    Entries without a ``"name"`` are listed as ``"Config"``.

    A file that cannot be read or parsed, or that does not contain a JSON
    object with a numeric ``"created_at"``, is reported on stdout and
    skipped, so one bad file does not break the whole listing.

    Returns:
        The module-level ``all_config_states`` mapping. The same object is
        returned on every call; it is cleared and refilled in place.
    """
    # Mutated in place, never rebound, so no ``global`` statement is needed.
    all_config_states.clear()
    os.makedirs(config_states_dir, exist_ok=True)

    config_states = []
    for filename in os.listdir(config_states_dir):
        if not filename.endswith(".json"):
            continue

        path = os.path.join(config_states_dir, filename)
        try:
            with open(path, "r", encoding="utf-8") as f:
                state = json.load(f)
        except (OSError, ValueError) as e:
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
            print(f"Error: Could not read config state {path}: {e}")
            continue

        # "created_at" drives both the sort order and the display name.
        created_at = state.get("created_at") if isinstance(state, dict) else None
        if not isinstance(created_at, (int, float)):
            print(f'Error: Config state {path} has no valid "created_at", skipping')
            continue

        state["filepath"] = path
        config_states.append(state)

    # Newest first.
    config_states.sort(key=lambda cs: cs["created_at"], reverse=True)

    for cs in config_states:
        timestamp = time.asctime(time.gmtime(cs["created_at"]))
        name = cs.get("name", "Config")
        all_config_states[f"{name}: {timestamp}"] = cs

    return all_config_states
 | |
| 
 | |
| 
 | |
def get_webui_config():
    """Describe the git state of the webui checkout at ``script_path``.

    Returns:
        A dict with the keys ``"remote"``, ``"commit_hash"``,
        ``"commit_date"`` and ``"branch"``. A value is ``None`` when it
        could not be determined, e.g. when ``script_path`` has no ``.git``
        entry or the repository is bare.
    """
    webui_repo = None

    try:
        if os.path.exists(os.path.join(script_path, ".git")):
            webui_repo = git.Repo(script_path)
    except Exception:
        errors.report(f"Error reading webui git info from {script_path}", exc_info=True)

    webui_remote = None
    webui_commit_hash = None
    webui_commit_date = None
    webui_branch = None

    if webui_repo and not webui_repo.bare:
        try:
            webui_remote = next(webui_repo.remote().urls, None)
            head = webui_repo.head.commit
            webui_commit_date = head.committed_date
            webui_commit_hash = head.hexsha
        except Exception:
            # Report instead of failing silently; as before, the remote is
            # cleared when the repository information is incomplete.
            errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
            webui_remote = None
        else:
            try:
                webui_branch = webui_repo.active_branch.name
            except TypeError:
                # GitPython raises TypeError for a detached HEAD. There is no
                # branch to record, but remote and commit are still valid.
                webui_branch = None
            except Exception:
                errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
                webui_remote = None

    return {
        "remote": webui_remote,
        "commit_hash": webui_commit_hash,
        "commit_date": webui_commit_date,
        "branch": webui_branch,
    }
 | |
| 
 | |
| 
 | |
def get_extension_config():
    """Collect the current state of every entry in ``extensions.extensions``.

    Returns:
        A dict keyed by extension name. Each value is a dict with the keys
        ``"name"``, ``"path"``, ``"enabled"``, ``"is_builtin"``,
        ``"remote"``, ``"commit_hash"``, ``"commit_date"``, ``"branch"``
        and ``"have_info_from_repo"``, copied from the extension object.
    """
    snapshot = {}

    for extension in extensions.extensions:
        # Called before the attributes below are read.
        extension.read_info_from_repo()

        snapshot[extension.name] = dict(
            name=extension.name,
            path=extension.path,
            enabled=extension.enabled,
            is_builtin=extension.is_builtin,
            remote=extension.remote,
            commit_hash=extension.commit_hash,
            commit_date=extension.commit_date,
            branch=extension.branch,
            have_info_from_repo=extension.have_info_from_repo,
        )

    return snapshot
 | |
| 
 | |
| 
 | |
def get_config():
    """Assemble a full config state for the webui and its extensions.

    Returns:
        A dict with ``"created_at"`` (the current time as a POSIX
        timestamp), ``"webui"`` (result of ``get_webui_config()``) and
        ``"extensions"`` (result of ``get_extension_config()``).
    """
    # The timestamp is taken first, before the two collectors run.
    state = {"created_at": datetime.now().timestamp()}
    state["webui"] = get_webui_config()
    state["extensions"] = get_extension_config()
    return state
 | |
| 
 | |
| 
 | |
def restore_webui_config(config):
    """Reset the webui checkout at ``script_path`` to the commit saved in *config*.

    Args:
        config: A config state dict. Its ``"webui"`` entry must hold a
            non-empty ``"commit_hash"``.

    Returns:
        None. Progress and problems are printed or reported; the function
        returns early, without touching the repository, when the config has
        no usable commit or ``script_path`` is not a git repository.
    """
    print("* Restoring webui state...")

    if "webui" not in config:
        print("Error: No webui data saved to config")
        return

    webui_config = config["webui"]

    # get_webui_config() stores None when the commit could not be read.
    # A missing or empty hash must never reach `git reset --hard`.
    webui_commit_hash = webui_config.get("commit_hash")
    if not webui_commit_hash:
        print("Error: No commit saved to webui config")
        return

    webui_repo = None

    try:
        if os.path.exists(os.path.join(script_path, ".git")):
            webui_repo = git.Repo(script_path)
    except Exception:
        errors.report(f"Error reading webui git info from {script_path}", exc_info=True)
        return

    if webui_repo is None:
        # No .git entry: say so rather than failing on a None repo below.
        print(f"Error: No git repository found at {script_path}, cannot restore webui")
        return

    try:
        webui_repo.git.fetch(all=True)
        webui_repo.git.reset(webui_commit_hash, hard=True)
        print(f"* Restored webui to commit {webui_commit_hash}.")
    except Exception:
        errors.report(f"Error restoring webui to commit {webui_commit_hash}", exc_info=True)
 | |
| 
 | |
| 
 | |
def _short_commit(commit_hash):
    """Return the first 8 characters of *commit_hash*, or "unknown" if there is none."""
    return commit_hash[:8] if commit_hash else "unknown"


def restore_extension_config(config):
    """Restore non-builtin extensions to the commits and enabled flags saved in *config*.

    For every non-builtin entry of ``extensions.extensions``:

    * if it has no entry in ``config["extensions"]`` it is marked disabled;
    * otherwise, if a commit hash was saved, ``fetch_and_reset_hard`` is
      called with it, and the extension's ``disabled`` flag is set from the
      saved ``"enabled"`` value (treated as disabled when absent).

    The names of the disabled extensions are then written to
    ``shared.opts.disabled_extensions``, the options are saved, and a
    per-extension summary is printed.

    Args:
        config: A config state dict with an ``"extensions"`` mapping keyed
            by extension name.

    Returns:
        None. Returns early if *config* has no ``"extensions"`` key.
    """
    print("* Restoring extension state...")

    if "extensions" not in config:
        print("Error: No extension data saved to config")
        return

    ext_config = config["extensions"]

    results = []
    disabled = []

    for ext in tqdm.tqdm(extensions.extensions):
        if ext.is_builtin:
            continue

        ext.read_info_from_repo()
        current_commit = ext.commit_hash
        # commit_hash may be None (no commit information for the extension);
        # slicing it directly would abort the whole restore with a TypeError.
        previous = _short_commit(current_commit)

        if ext.name not in ext_config:
            ext.disabled = True
            disabled.append(ext.name)
            results.append((ext, previous, False, "Saved extension state not found in config, marking as disabled"))
            continue

        entry = ext_config[ext.name]
        saved_commit = entry.get("commit_hash")

        if saved_commit:
            try:
                ext.fetch_and_reset_hard(saved_commit)
                ext.read_info_from_repo()
                # Only report extensions whose commit actually changed.
                if current_commit != saved_commit:
                    results.append((ext, previous, True, saved_commit[:8]))
            except Exception as ex:
                results.append((ext, previous, False, ex))
        else:
            results.append((ext, previous, False, "No commit hash found in config"))

        if entry.get("enabled", False):
            ext.disabled = False
        else:
            ext.disabled = True
            disabled.append(ext.name)

    shared.opts.disabled_extensions = disabled
    shared.opts.save(shared.config_filename)

    print("* Finished restoring extensions. Results:")
    for ext, prev_commit, success, result in results:
        if success:
            print(f"  + {ext.name}: {prev_commit} -> {result}")
        else:
            print(f"  ! {ext.name}: FAILURE ({result})")
 | 
