| 
									
										
										
										
											2022-10-13 23:00:38 -06:00
										 |  |  | import collections | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | import os.path | 
					
						
							|  |  |  | import sys | 
					
						
							|  |  |  | from collections import namedtuple | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | from omegaconf import OmegaConf | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from ldm.util import instantiate_from_config | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-22 12:23:45 +03:00
										 |  |  | from modules import shared, modelloader, devices, script_callbacks | 
					
						
							| 
									
										
										
										
											2022-09-27 11:01:13 -05:00
										 |  |  | from modules.paths import models_path | 
					
						
							| 
									
										
										
										
											2022-10-19 13:47:45 -07:00
										 |  |  | from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting | 
					
						
							| 
									
										
										
										
											2022-09-27 11:01:13 -05:00
										 |  |  | 
 | 
					
						
							|  |  |  | model_dir = "Stable-diffusion" | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  | model_path = os.path.abspath(os.path.join(models_path, model_dir)) | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-08 23:26:48 +03:00
										 |  |  | CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | checkpoints_list = {} | 
					
						
							| 
									
										
										
										
											2022-10-13 23:00:38 -06:00
										 |  |  | checkpoints_loaded = collections.OrderedDict() | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | try: | 
					
						
							|  |  |  |     # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-16 17:53:56 +02:00
										 |  |  |     from transformers import logging, CLIPModel | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     logging.set_verbosity_error() | 
					
						
							|  |  |  | except Exception: | 
					
						
							|  |  |  |     pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-02 21:09:10 +03:00
										 |  |  | def setup_model(): | 
					
						
							| 
									
										
										
										
											2022-09-27 11:01:13 -05:00
										 |  |  |     if not os.path.exists(model_path): | 
					
						
							|  |  |  |         os.makedirs(model_path) | 
					
						
							| 
									
										
										
										
											2022-10-02 21:09:10 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-29 19:59:36 -05:00
										 |  |  |     list_models() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-29 00:59:44 +03:00
										 |  |  | def checkpoint_tiles(): | 
					
						
							|  |  |  |     return sorted([x.title for x in checkpoints_list.values()]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | def list_models(): | 
					
						
							|  |  |  |     checkpoints_list.clear() | 
					
						
							| 
									
										
										
										
											2022-10-02 21:09:10 +03:00
										 |  |  |     model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"]) | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  |     def modeltitle(path, shorthash): | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  |         abspath = os.path.abspath(path) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-02 21:22:20 +03:00
										 |  |  |         if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): | 
					
						
							|  |  |  |             name = abspath.replace(shared.cmd_opts.ckpt_dir, '') | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  |         elif abspath.startswith(model_path): | 
					
						
							|  |  |  |             name = abspath.replace(model_path, '') | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  |         else: | 
					
						
							|  |  |  |             name = os.path.basename(path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if name.startswith("\\") or name.startswith("/"): | 
					
						
							|  |  |  |             name = name[1:] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-29 00:59:44 +03:00
										 |  |  |         shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  |         return f'{name} [{shorthash}]', shortname | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     cmd_ckpt = shared.cmd_opts.ckpt | 
					
						
							|  |  |  |     if os.path.exists(cmd_ckpt): | 
					
						
							|  |  |  |         h = model_hash(cmd_ckpt) | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  |         title, short_model_name = modeltitle(cmd_ckpt, h) | 
					
						
							| 
									
										
										
										
											2022-10-08 23:26:48 +03:00
										 |  |  |         checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config) | 
					
						
							| 
									
										
										
										
											2022-10-02 17:24:50 +03:00
										 |  |  |         shared.opts.data['sd_model_checkpoint'] = title | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  |     elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: | 
					
						
							| 
									
										
										
										
											2022-09-27 11:01:13 -05:00
										 |  |  |         print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) | 
					
						
							|  |  |  |     for filename in model_list: | 
					
						
							|  |  |  |         h = model_hash(filename) | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  |         title, short_model_name = modeltitle(filename, h) | 
					
						
							| 
									
										
										
										
											2022-10-08 23:26:48 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         basename, _ = os.path.splitext(filename) | 
					
						
							|  |  |  |         config = basename + ".yaml" | 
					
						
							|  |  |  |         if not os.path.exists(config): | 
					
						
							|  |  |  |             config = shared.cmd_opts.config | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config) | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-28 22:30:09 +01:00
										 |  |  | def get_closet_checkpoint_match(searchString): | 
					
						
							| 
									
										
										
										
											2022-09-29 19:08:03 +01:00
										 |  |  |     applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title)) | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  |     if len(applicable) > 0: | 
					
						
							| 
									
										
										
										
											2022-09-28 22:30:09 +01:00
										 |  |  |         return applicable[0] | 
					
						
							|  |  |  |     return None | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | def model_hash(filename): | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         with open(filename, "rb") as file: | 
					
						
							|  |  |  |             import hashlib | 
					
						
							|  |  |  |             m = hashlib.sha256() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             file.seek(0x100000) | 
					
						
							|  |  |  |             m.update(file.read(0x10000)) | 
					
						
							|  |  |  |             return m.hexdigest()[0:8] | 
					
						
							|  |  |  |     except FileNotFoundError: | 
					
						
							|  |  |  |         return 'NOFILE' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def select_checkpoint(): | 
					
						
							|  |  |  |     model_checkpoint = shared.opts.sd_model_checkpoint | 
					
						
							|  |  |  |     checkpoint_info = checkpoints_list.get(model_checkpoint, None) | 
					
						
							|  |  |  |     if checkpoint_info is not None: | 
					
						
							|  |  |  |         return checkpoint_info | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if len(checkpoints_list) == 0: | 
					
						
							| 
									
										
										
										
											2022-09-18 23:52:01 +03:00
										 |  |  |         print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) | 
					
						
							| 
									
										
										
										
											2022-10-02 21:09:10 +03:00
										 |  |  |         if shared.cmd_opts.ckpt is not None: | 
					
						
							|  |  |  |             print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) | 
					
						
							|  |  |  |         print(f" - directory {model_path}", file=sys.stderr) | 
					
						
							|  |  |  |         if shared.cmd_opts.ckpt_dir is not None: | 
					
						
							|  |  |  |             print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) | 
					
						
							| 
									
										
										
										
											2022-09-18 23:52:01 +03:00
										 |  |  |         print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr) | 
					
						
							|  |  |  |         exit(1) | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     checkpoint_info = next(iter(checkpoints_list.values())) | 
					
						
							|  |  |  |     if model_checkpoint is not None: | 
					
						
							|  |  |  |         print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return checkpoint_info | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-19 08:42:22 +03:00
										 |  |  | chckpoint_dict_replacements = { | 
					
						
							|  |  |  |     'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.', | 
					
						
							|  |  |  |     'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.', | 
					
						
							|  |  |  |     'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.', | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def transform_checkpoint_dict_key(k): | 
					
						
							|  |  |  |     for text, replacement in chckpoint_dict_replacements.items(): | 
					
						
							|  |  |  |         if k.startswith(text): | 
					
						
							|  |  |  |             k = replacement + k[len(text):] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return k | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-09 10:23:31 +03:00
										 |  |  | def get_state_dict_from_checkpoint(pl_sd): | 
					
						
							|  |  |  |     if "state_dict" in pl_sd: | 
					
						
							| 
									
										
										
										
											2022-10-19 08:42:22 +03:00
										 |  |  |         pl_sd = pl_sd["state_dict"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     sd = {} | 
					
						
							|  |  |  |     for k, v in pl_sd.items(): | 
					
						
							|  |  |  |         new_key = transform_checkpoint_dict_key(k) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if new_key is not None: | 
					
						
							|  |  |  |             sd[new_key] = v | 
					
						
							| 
									
										
										
										
											2022-10-09 10:23:31 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-19 12:45:30 +03:00
										 |  |  |     pl_sd.clear() | 
					
						
							|  |  |  |     pl_sd.update(sd) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return pl_sd | 
					
						
							| 
									
										
										
										
											2022-10-09 10:23:31 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-21 17:35:51 +03:00
										 |  |  | vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-08 23:26:48 +03:00
										 |  |  | def load_model_weights(model, checkpoint_info): | 
					
						
							|  |  |  |     checkpoint_file = checkpoint_info.filename | 
					
						
							|  |  |  |     sd_model_hash = checkpoint_info.hash | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-13 23:00:38 -06:00
										 |  |  |     if checkpoint_info not in checkpoints_loaded: | 
					
						
							|  |  |  |         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-15 10:35:18 +03:00
										 |  |  |         pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) | 
					
						
							| 
									
										
										
										
											2022-10-13 23:00:38 -06:00
										 |  |  |         if "global_step" in pl_sd: | 
					
						
							|  |  |  |             print(f"Global Step: {pl_sd['global_step']}") | 
					
						
							| 
									
										
										
										
											2022-10-09 10:23:31 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-13 23:00:38 -06:00
										 |  |  |         sd = get_state_dict_from_checkpoint(pl_sd) | 
					
						
							| 
									
										
										
										
											2022-10-19 08:42:22 +03:00
										 |  |  |         missing, extra = model.load_state_dict(sd, strict=False) | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-13 23:00:38 -06:00
										 |  |  |         if shared.cmd_opts.opt_channelslast: | 
					
						
							|  |  |  |             model.to(memory_format=torch.channels_last) | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-13 23:00:38 -06:00
										 |  |  |         if not shared.cmd_opts.no_half: | 
					
						
							|  |  |  |             model.half() | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-13 23:00:38 -06:00
										 |  |  |         devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 | 
					
						
							|  |  |  |         devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-13 23:00:38 -06:00
										 |  |  |         vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-13 23:00:38 -06:00
										 |  |  |         if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: | 
					
						
							|  |  |  |             vae_file = shared.cmd_opts.vae_path | 
					
						
							| 
									
										
										
										
											2022-10-10 20:46:55 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-13 23:00:38 -06:00
										 |  |  |         if os.path.exists(vae_file): | 
					
						
							|  |  |  |             print(f"Loading VAE weights from: {vae_file}") | 
					
						
							| 
									
										
										
										
											2022-10-15 10:35:18 +03:00
										 |  |  |             vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) | 
					
						
							| 
									
										
										
										
											2022-10-21 17:35:51 +03:00
										 |  |  |             vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} | 
					
						
							| 
									
										
										
										
											2022-10-13 23:00:38 -06:00
										 |  |  |             model.first_stage_model.load_state_dict(vae_dict) | 
					
						
							| 
									
										
										
										
											2022-10-07 10:40:22 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-13 23:00:38 -06:00
										 |  |  |         model.first_stage_model.to(devices.dtype_vae) | 
					
						
							| 
									
										
										
										
											2022-10-07 10:40:22 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-13 23:00:38 -06:00
										 |  |  |         checkpoints_loaded[checkpoint_info] = model.state_dict().copy() | 
					
						
							|  |  |  |         while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: | 
					
						
							|  |  |  |             checkpoints_loaded.popitem(last=False)  # LRU | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         print(f"Loading weights [{sd_model_hash}] from cache") | 
					
						
							|  |  |  |         checkpoints_loaded.move_to_end(checkpoint_info) | 
					
						
							|  |  |  |         model.load_state_dict(checkpoints_loaded[checkpoint_info]) | 
					
						
							| 
									
										
										
										
											2022-10-10 16:11:14 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  |     model.sd_model_hash = sd_model_hash | 
					
						
							| 
									
										
										
										
											2022-10-08 15:12:24 -04:00
										 |  |  |     model.sd_model_checkpoint = checkpoint_file | 
					
						
							| 
									
										
										
										
											2022-10-08 23:26:48 +03:00
										 |  |  |     model.sd_checkpoint_info = checkpoint_info | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-20 16:01:27 -07:00
										 |  |  | def load_model(checkpoint_info=None): | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  |     from modules import lowvram, sd_hijack | 
					
						
							| 
									
										
										
										
											2022-10-20 16:01:27 -07:00
										 |  |  |     checkpoint_info = checkpoint_info or select_checkpoint() | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-08 23:26:48 +03:00
										 |  |  |     if checkpoint_info.config != shared.cmd_opts.config: | 
					
						
							| 
									
										
										
										
											2022-10-09 10:31:47 +03:00
										 |  |  |         print(f"Loading config from: {checkpoint_info.config}") | 
					
						
							| 
									
										
										
										
											2022-10-08 23:26:48 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     sd_config = OmegaConf.load(checkpoint_info.config) | 
					
						
							| 
									
										
										
										
											2022-10-19 13:47:45 -07:00
										 |  |  |      | 
					
						
							|  |  |  |     if should_hijack_inpainting(checkpoint_info): | 
					
						
							|  |  |  |         # Hardcoded config for now... | 
					
						
							|  |  |  |         sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" | 
					
						
							|  |  |  |         sd_config.model.params.use_ema = False | 
					
						
							|  |  |  |         sd_config.model.params.conditioning_key = "hybrid" | 
					
						
							|  |  |  |         sd_config.model.params.unet_config.params.in_channels = 9 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Create a "fake" config with a different name so that we know to unload it when switching models. | 
					
						
							|  |  |  |         checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-20 13:28:43 -07:00
										 |  |  |     do_inpainting_hijack() | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  |     sd_model = instantiate_from_config(sd_config.model) | 
					
						
							| 
									
										
										
										
											2022-10-08 23:26:48 +03:00
										 |  |  |     load_model_weights(sd_model, checkpoint_info) | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: | 
					
						
							|  |  |  |         lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         sd_model.to(shared.device) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     sd_hijack.model_hijack.hijack(sd_model) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     sd_model.eval() | 
					
						
							| 
									
										
										
										
											2022-10-22 12:23:45 +03:00
										 |  |  |     shared.sd_model = sd_model | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-22 20:15:12 +03:00
										 |  |  |     script_callbacks.model_loaded_callback(sd_model) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  |     print(f"Model loaded.") | 
					
						
							|  |  |  |     return sd_model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-17 13:49:36 +03:00
										 |  |  | def reload_model_weights(sd_model, info=None): | 
					
						
							| 
									
										
										
										
											2022-09-29 15:40:28 +03:00
										 |  |  |     from modules import lowvram, devices, sd_hijack | 
					
						
							| 
									
										
										
										
											2022-09-17 13:49:36 +03:00
										 |  |  |     checkpoint_info = info or select_checkpoint() | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-08 15:12:24 -04:00
										 |  |  |     if sd_model.sd_model_checkpoint == checkpoint_info.filename: | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  |         return | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-19 13:47:45 -07:00
										 |  |  |     if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): | 
					
						
							| 
									
										
										
										
											2022-10-13 23:00:38 -06:00
										 |  |  |         checkpoints_loaded.clear() | 
					
						
							| 
									
										
										
										
											2022-10-22 12:23:45 +03:00
										 |  |  |         load_model(checkpoint_info) | 
					
						
							| 
									
										
										
										
											2022-10-09 13:23:30 +03:00
										 |  |  |         return shared.sd_model | 
					
						
							| 
									
										
										
										
											2022-10-08 23:26:48 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  |     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: | 
					
						
							|  |  |  |         lowvram.send_everything_to_cpu() | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         sd_model.to(devices.cpu) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-29 15:40:28 +03:00
										 |  |  |     sd_hijack.model_hijack.undo_hijack(sd_model) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-08 23:26:48 +03:00
										 |  |  |     load_model_weights(sd_model, checkpoint_info) | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-29 15:40:28 +03:00
										 |  |  |     sd_hijack.model_hijack.hijack(sd_model) | 
					
						
							| 
									
										
										
										
											2022-10-22 12:59:21 -04:00
										 |  |  |     script_callbacks.model_loaded_callback(sd_model) | 
					
						
							| 
									
										
										
										
											2022-09-29 15:40:28 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  |     if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: | 
					
						
							|  |  |  |         sd_model.to(devices.device) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     print(f"Weights loaded.") | 
					
						
							|  |  |  |     return sd_model |