| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | import glob | 
					
						
							|  |  |  | import os.path | 
					
						
							|  |  |  | import sys | 
					
						
							|  |  |  | from collections import namedtuple | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | from omegaconf import OmegaConf | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from ldm.util import instantiate_from_config | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-27 11:01:13 -05:00
										 |  |  | from modules import shared, modelloader | 
					
						
							|  |  |  | from modules.paths import models_path | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | model_dir = "Stable-diffusion" | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  | model_path = os.path.abspath(os.path.join(models_path, model_dir)) | 
					
						
							| 
									
										
										
										
											2022-09-27 11:01:13 -05:00
										 |  |  | model_name = "sd-v1-4.ckpt" | 
					
						
							|  |  |  | model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1" | 
					
						
							| 
									
										
										
										
											2022-09-30 12:15:29 +03:00
										 |  |  | user_dir = None | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-27 21:08:07 -04:00
										 |  |  | CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | checkpoints_list = {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | try: | 
					
						
							|  |  |  |     # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     from transformers import logging | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     logging.set_verbosity_error() | 
					
						
							|  |  |  | except Exception: | 
					
						
							|  |  |  |     pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-27 11:01:13 -05:00
										 |  |  | def setup_model(dirname): | 
					
						
							| 
									
										
										
										
											2022-09-29 19:59:36 -05:00
										 |  |  |     global user_dir | 
					
						
							|  |  |  |     user_dir = dirname | 
					
						
							| 
									
										
										
										
											2022-09-27 11:01:13 -05:00
										 |  |  |     if not os.path.exists(model_path): | 
					
						
							|  |  |  |         os.makedirs(model_path) | 
					
						
							|  |  |  |     checkpoints_list.clear() | 
					
						
							| 
									
										
										
										
											2022-09-29 19:59:36 -05:00
										 |  |  |     list_models() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-29 00:59:44 +03:00
										 |  |  | def checkpoint_tiles(): | 
					
						
							|  |  |  |     return sorted([x.title for x in checkpoints_list.values()]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | def list_models(): | 
					
						
							|  |  |  |     checkpoints_list.clear() | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  |     model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name) | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  |     def modeltitle(path, shorthash): | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  |         abspath = os.path.abspath(path) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  |         if user_dir is not None and abspath.startswith(user_dir): | 
					
						
							|  |  |  |             name = abspath.replace(user_dir, '') | 
					
						
							|  |  |  |         elif abspath.startswith(model_path): | 
					
						
							|  |  |  |             name = abspath.replace(model_path, '') | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  |         else: | 
					
						
							|  |  |  |             name = os.path.basename(path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if name.startswith("\\") or name.startswith("/"): | 
					
						
							|  |  |  |             name = name[1:] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-29 00:59:44 +03:00
										 |  |  |         shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  |         return f'{name} [{shorthash}]', shortname | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     cmd_ckpt = shared.cmd_opts.ckpt | 
					
						
							|  |  |  |     if os.path.exists(cmd_ckpt): | 
					
						
							|  |  |  |         h = model_hash(cmd_ckpt) | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  |         title, short_model_name = modeltitle(cmd_ckpt, h) | 
					
						
							|  |  |  |         checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  |     elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: | 
					
						
							| 
									
										
										
										
											2022-09-27 11:01:13 -05:00
										 |  |  |         print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) | 
					
						
							|  |  |  |     for filename in model_list: | 
					
						
							|  |  |  |         h = model_hash(filename) | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  |         title, short_model_name = modeltitle(filename, h) | 
					
						
							|  |  |  |         checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-28 22:30:09 +01:00
										 |  |  | def get_closet_checkpoint_match(searchString): | 
					
						
							| 
									
										
										
										
											2022-09-29 19:08:03 +01:00
										 |  |  |     applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title)) | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  |     if len(applicable) > 0: | 
					
						
							| 
									
										
										
										
											2022-09-28 22:30:09 +01:00
										 |  |  |         return applicable[0] | 
					
						
							|  |  |  |     return None | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | def model_hash(filename): | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         with open(filename, "rb") as file: | 
					
						
							|  |  |  |             import hashlib | 
					
						
							|  |  |  |             m = hashlib.sha256() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             file.seek(0x100000) | 
					
						
							|  |  |  |             m.update(file.read(0x10000)) | 
					
						
							|  |  |  |             return m.hexdigest()[0:8] | 
					
						
							|  |  |  |     except FileNotFoundError: | 
					
						
							|  |  |  |         return 'NOFILE' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def select_checkpoint(): | 
					
						
							|  |  |  |     model_checkpoint = shared.opts.sd_model_checkpoint | 
					
						
							|  |  |  |     checkpoint_info = checkpoints_list.get(model_checkpoint, None) | 
					
						
							|  |  |  |     if checkpoint_info is not None: | 
					
						
							|  |  |  |         return checkpoint_info | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if len(checkpoints_list) == 0: | 
					
						
							| 
									
										
										
										
											2022-09-18 23:52:01 +03:00
										 |  |  |         print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) | 
					
						
							|  |  |  |         print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  |         print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) | 
					
						
							| 
									
										
										
										
											2022-09-18 23:52:01 +03:00
										 |  |  |         print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr) | 
					
						
							|  |  |  |         exit(1) | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     checkpoint_info = next(iter(checkpoints_list.values())) | 
					
						
							|  |  |  |     if model_checkpoint is not None: | 
					
						
							|  |  |  |         print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return checkpoint_info | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def load_model_weights(model, checkpoint_file, sd_model_hash): | 
					
						
							|  |  |  |     print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     pl_sd = torch.load(checkpoint_file, map_location="cpu") | 
					
						
							|  |  |  |     if "global_step" in pl_sd: | 
					
						
							|  |  |  |         print(f"Global Step: {pl_sd['global_step']}") | 
					
						
							|  |  |  |     sd = pl_sd["state_dict"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     model.load_state_dict(sd, strict=False) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if shared.cmd_opts.opt_channelslast: | 
					
						
							|  |  |  |         model.to(memory_format=torch.channels_last) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if not shared.cmd_opts.no_half: | 
					
						
							|  |  |  |         model.half() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     model.sd_model_hash = sd_model_hash | 
					
						
							|  |  |  |     model.sd_model_checkpint = checkpoint_file | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def load_model(): | 
					
						
							|  |  |  |     from modules import lowvram, sd_hijack | 
					
						
							|  |  |  |     checkpoint_info = select_checkpoint() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     sd_config = OmegaConf.load(shared.cmd_opts.config) | 
					
						
							|  |  |  |     sd_model = instantiate_from_config(sd_config.model) | 
					
						
							|  |  |  |     load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: | 
					
						
							|  |  |  |         lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         sd_model.to(shared.device) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     sd_hijack.model_hijack.hijack(sd_model) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     sd_model.eval() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     print(f"Model loaded.") | 
					
						
							|  |  |  |     return sd_model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-17 13:49:36 +03:00
										 |  |  | def reload_model_weights(sd_model, info=None): | 
					
						
							| 
									
										
										
										
											2022-09-29 15:40:28 +03:00
										 |  |  |     from modules import lowvram, devices, sd_hijack | 
					
						
							| 
									
										
										
										
											2022-09-17 13:49:36 +03:00
										 |  |  |     checkpoint_info = info or select_checkpoint() | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     if sd_model.sd_model_checkpint == checkpoint_info.filename: | 
					
						
							|  |  |  |         return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: | 
					
						
							|  |  |  |         lowvram.send_everything_to_cpu() | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         sd_model.to(devices.cpu) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-29 15:40:28 +03:00
										 |  |  |     sd_hijack.model_hijack.undo_hijack(sd_model) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  |     load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-29 15:40:28 +03:00
										 |  |  |     sd_hijack.model_hijack.hijack(sd_model) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-17 12:05:04 +03:00
										 |  |  |     if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: | 
					
						
							|  |  |  |         sd_model.to(devices.device) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     print(f"Weights loaded.") | 
					
						
							|  |  |  |     return sd_model |