import os
import gc
import time

import numpy as np
import torch
import torchvision
from PIL import Image
from einops import rearrange, repeat
from omegaconf import OmegaConf
import safetensors.torch

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack

cached_ldsr_model: torch.nn.Module = None

# Create LDSR Class
class LDSR:
    def load_model_from_config(self, half_attention):
        global cached_ldsr_model

        if shared.opts.ldsr_cached and cached_ldsr_model is not None:
            print("Loading model from cache")
            model: torch.nn.Module = cached_ldsr_model
        else:
            print(f"Loading model from {self.modelPath}")
            _, extension = os.path.splitext(self.modelPath)
            if extension.lower() == ".safetensors":
                pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
            else:
                pl_sd = torch.load(self.modelPath, map_location="cpu")
            sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
            config = OmegaConf.load(self.yamlPath)
            config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
            model: torch.nn.Module = instantiate_from_config(config.model)
            model.load_state_dict(sd, strict=False)
            model = model.to(shared.device)
            if half_attention:
                model = model.half()
            if shared.cmd_opts.opt_channelslast:
                model = model.to(memory_format=torch.channels_last)

            sd_hijack.model_hijack.hijack(model)  # apply optimization
            model.eval()

            if shared.opts.ldsr_cached:
                cached_ldsr_model = model

        return {"model": model}

    def __init__(self, model_path, yaml_path):
        self.modelPath = model_path
        self.yamlPath = yaml_path

    @staticmethod
    def run(model, selected_path, custom_steps, eta):
        example = get_cond(selected_path)

        n_runs = 1
        guider = None
        ckwargs = None
        ddim_use_x0_pred = False
        temperature = 1.
        custom_shape = None

        height, width = example["image"].shape[1:3]
        split_input = height >= 128 and width >= 128

        if split_input:
            # for larger inputs, diffuse in overlapping patches and blend the results
            ks = 128
            stride = 64
            vqf = 4  # first-stage downscaling factor
            model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
                                        "vqf": vqf,
                                        "patch_distributed_vq": True,
                                        "tie_braker": False,
                                        "clip_max_weight": 0.5,
                                        "clip_min_weight": 0.01,
                                        "clip_max_tie_weight": 0.5,
                                        "clip_min_tie_weight": 0.01}
        else:
            if hasattr(model, "split_input_params"):
                delattr(model, "split_input_params")

        x_t = None
        logs = None
        for n in range(n_runs):
            if custom_shape is not None:
                x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
                x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])

            logs = make_convolutional_sample(example, model,
                                             custom_steps=custom_steps,
                                             eta=eta, quantize_x0=False,
                                             custom_shape=custom_shape,
                                             temperature=temperature, noise_dropout=0.,
                                             corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
                                             ddim_use_x0_pred=ddim_use_x0_pred
                                             )
        return logs

    def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
        model = self.load_model_from_config(half_attention)

        # Run settings
        diffusion_steps = int(steps)
        eta = 1.0

        down_sample_method = 'Lanczos'

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        im_og = image
        width_og, height_og = im_og.size
        # If we can adjust the max upscale size, then the 4 below should be our variable
        down_sample_rate = target_scale / 4
        wd = width_og * down_sample_rate
        hd = height_og * down_sample_rate
        width_downsampled_pre = int(np.ceil(wd))
        height_downsampled_pre = int(np.ceil(hd))

        if down_sample_rate != 1:
            print(
                f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
            im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
        else:
            print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")

        # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
        pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
        im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
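        # Worked example of the padding above: a 300x200 input pads to 320x256,
        # since each side rounds up to the next multiple of 64, with a floor of
        # 128 per side enforced by the max((2, 2), ...) term.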

        logs = self.run(model["model"], im_padded, diffusion_steps, eta)

        sample = logs["sample"]
        sample = sample.detach().cpu()
        sample = torch.clamp(sample, -1., 1.)
        # map the model output from [-1, 1] to 8-bit HWC for PIL
        sample = (sample + 1.) / 2. * 255
        sample = sample.numpy().astype(np.uint8)
        sample = np.transpose(sample, (0, 2, 3, 1))
        a = Image.fromarray(sample[0])

        # remove padding
        a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))

        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return a


def get_cond(selected_path):
    example = dict()
    up_f = 4
    c = selected_path.convert('RGB')
    c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
    c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
                                                    antialias=True)
    c_up = rearrange(c_up, '1 c h w -> 1 h w c')
    c = rearrange(c, '1 c h w -> 1 h w c')
    c = 2. * c - 1.

    c = c.to(shared.device)
    example["LR_image"] = c
    example["image"] = c_up

    return example
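
# For reference (derived from the transforms above): on a 256x256 RGB input,
# example["LR_image"] has shape (1, 256, 256, 3) scaled to [-1, 1], and
# example["image"] has shape (1, 1024, 1024, 3) in [0, 1].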


@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
                    mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
                    corrector_kwargs=None, x_t=None
                    ):
    ddim = DDIMSampler(model)
    bs = shape[0]
    shape = shape[1:]
    print(f"Sampling with eta = {eta}; steps: {steps}")
    samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
                                         normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
                                         mask=mask, x0=x0, temperature=temperature, verbose=False,
                                         score_corrector=score_corrector,
                                         corrector_kwargs=corrector_kwargs, x_t=x_t)

    return samples, intermediates


@torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None,
                              temperature=1., noise_dropout=0., corrector=None,
                              corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
    log = dict()

    z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
                                        return_first_stage_outputs=True,
                                        force_c_encode=not (hasattr(model, 'split_input_params')
                                                            and model.cond_stage_key == 'coordinates_bbox'),
                                        return_original_cond=True)

    if custom_shape is not None:
        z = torch.randn(custom_shape)
        print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")

    z0 = None

    log["input"] = x
    log["reconstruction"] = xrec

    if ismap(xc):
        log["original_conditioning"] = model.to_rgb(xc)
        if hasattr(model, 'cond_stage_key'):
            log[model.cond_stage_key] = model.to_rgb(xc)
    else:
        log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
        if model.cond_stage_model:
            log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
            if model.cond_stage_key == 'class_label':
                log[model.cond_stage_key] = xc[model.cond_stage_key]

    with model.ema_scope("Plotting"):
        t0 = time.time()

        sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
                                                eta=eta,
                                                quantize_x0=quantize_x0, mask=None, x0=z0,
                                                temperature=temperature, score_corrector=corrector,
                                                corrector_kwargs=corrector_kwargs,
                                                x_t=x_T)
        t1 = time.time()

        if ddim_use_x0_pred:
            sample = intermediates['pred_x0'][-1]

    x_sample = model.decode_first_stage(sample)

    try:
        x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
        log["sample_noquant"] = x_sample_noquant
        log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
    except Exception:
        # not every first-stage model accepts force_not_quantize; fall back silently
        pass

    log["sample"] = x_sample
    log["time"] = t1 - t0

    return log
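

# Usage sketch (illustrative only; the paths below are hypothetical placeholders,
# and in the webui this module is normally driven by the LDSR upscaler script,
# which supplies the real checkpoint/config locations):
#
#     ldsr = LDSR("models/LDSR/model.ckpt", "models/LDSR/project.yaml")
#     upscaled = ldsr.super_resolution(Image.open("input.png"), steps=100, target_scale=2)
#     upscaled.save("output.png")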