| 
									
										
										
										
											2023-07-11 21:16:43 +03:00
										 |  |  | from __future__ import annotations | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import sgm.models.diffusion | 
					
						
							|  |  |  | import sgm.modules.diffusionmodules.denoiser_scaling | 
					
						
							|  |  |  | import sgm.modules.diffusionmodules.discretizer | 
					
						
							| 
									
										
										
										
											2023-07-12 23:52:43 +03:00
										 |  |  | from modules import devices, shared, prompt_parser | 
					
						
							| 
									
										
										
										
											2023-12-31 22:38:30 +03:00
										 |  |  | from modules import torch_utils | 
					
						
							| 
									
										
										
										
											2023-07-11 21:16:43 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-12 23:52:43 +03:00
										 |  |  | def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): | 
					
						
							| 
									
										
										
										
											2023-07-11 21:16:43 +03:00
										 |  |  |     for embedder in self.conditioner.embedders: | 
					
						
							|  |  |  |         embedder.ucg_rate = 0.0 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-09 17:12:54 +08:00
										 |  |  |     width = getattr(batch, 'width', 1024) or 1024 | 
					
						
							|  |  |  |     height = getattr(batch, 'height', 1024) or 1024 | 
					
						
							| 
									
										
										
										
											2023-07-14 09:16:01 +03:00
										 |  |  |     is_negative_prompt = getattr(batch, 'is_negative_prompt', False) | 
					
						
							|  |  |  |     aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     devices_args = dict(device=devices.device, dtype=devices.dtype) | 
					
						
							| 
									
										
										
										
											2023-07-12 23:52:43 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     sdxl_conds = { | 
					
						
							|  |  |  |         "txt": batch, | 
					
						
							| 
									
										
										
										
											2023-07-14 09:16:01 +03:00
										 |  |  |         "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), | 
					
						
							|  |  |  |         "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1), | 
					
						
							|  |  |  |         "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), | 
					
						
							|  |  |  |         "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1), | 
					
						
							| 
									
										
										
										
											2023-07-12 23:52:43 +03:00
										 |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-14 09:16:01 +03:00
										 |  |  |     force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch) | 
					
						
							| 
									
										
										
										
											2023-07-13 11:35:52 +03:00
										 |  |  |     c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else []) | 
					
						
							| 
									
										
										
										
											2023-07-11 21:16:43 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     return c | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): | 
					
						
							| 
									
										
										
										
											2023-12-21 20:15:51 +08:00
										 |  |  |     sd = self.model.state_dict() | 
					
						
							|  |  |  |     diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None) | 
					
						
							| 
									
										
										
										
											2023-12-27 10:20:56 +08:00
										 |  |  |     if diffusion_model_input is not None: | 
					
						
							|  |  |  |         if diffusion_model_input.shape[1] == 9: | 
					
						
							|  |  |  |             x = torch.cat([x] + cond['c_concat'], dim=1) | 
					
						
							| 
									
										
										
										
											2023-12-21 20:15:51 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-11 21:16:43 +03:00
										 |  |  |     return self.model(x, t, cond) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-13 16:18:39 +03:00
										 |  |  | def get_first_stage_encoding(self, x):  # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility | 
					
						
							|  |  |  |     return x | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-14 09:16:01 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning | 
					
						
							|  |  |  | sgm.models.diffusion.DiffusionEngine.apply_model = apply_model | 
					
						
							|  |  |  | sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt): | 
					
						
							|  |  |  |     res = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]: | 
					
						
							|  |  |  |         encoded = embedder.encode_embedding_init_text(init_text, nvpt) | 
					
						
							|  |  |  |         res.append(encoded) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return torch.cat(res, dim=1) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-29 15:15:06 +03:00
										 |  |  | def tokenize(self: sgm.modules.GeneralConditioner, texts): | 
					
						
							|  |  |  |     for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]: | 
					
						
							|  |  |  |         return embedder.tokenize(texts) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     raise AssertionError('no tokenizer available') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-14 09:16:01 +03:00
										 |  |  | def process_texts(self, texts): | 
					
						
							|  |  |  |     for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]: | 
					
						
							|  |  |  |         return embedder.process_texts(texts) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def get_target_prompt_token_count(self, token_count): | 
					
						
							|  |  |  |     for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]: | 
					
						
							|  |  |  |         return embedder.get_target_prompt_token_count(token_count) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist | 
					
						
							|  |  |  | sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text | 
					
						
							| 
									
										
										
										
											2023-07-29 15:15:06 +03:00
										 |  |  | sgm.modules.GeneralConditioner.tokenize = tokenize | 
					
						
							| 
									
										
										
										
											2023-07-14 09:16:01 +03:00
										 |  |  | sgm.modules.GeneralConditioner.process_texts = process_texts | 
					
						
							|  |  |  | sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-11 21:16:43 +03:00
										 |  |  | def extend_sdxl(model): | 
					
						
							| 
									
										
										
										
											2023-07-14 09:16:01 +03:00
										 |  |  |     """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-31 22:38:30 +03:00
										 |  |  |     dtype = torch_utils.get_param(model.model.diffusion_model).dtype | 
					
						
							| 
									
										
										
										
											2023-07-11 21:16:43 +03:00
										 |  |  |     model.model.diffusion_model.dtype = dtype | 
					
						
							|  |  |  |     model.model.conditioning_key = 'crossattn' | 
					
						
							| 
									
										
										
										
											2023-07-14 09:16:01 +03:00
										 |  |  |     model.cond_stage_key = 'txt' | 
					
						
							|  |  |  |     # model.cond_stage_model will be set in sd_hijack | 
					
						
							| 
									
										
										
										
											2023-07-11 21:16:43 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() | 
					
						
							| 
									
										
										
										
											2023-10-25 12:54:28 +08:00
										 |  |  |     model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32) | 
					
						
							| 
									
										
										
										
											2023-07-11 21:16:43 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-14 09:16:01 +03:00
										 |  |  |     model.conditioner.wrapped = torch.nn.Module() | 
					
						
							| 
									
										
										
										
											2023-07-12 23:52:43 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-11 21:16:43 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-01 00:24:48 +03:00
										 |  |  | sgm.modules.attention.print = shared.ldm_print | 
					
						
							|  |  |  | sgm.modules.diffusionmodules.model.print = shared.ldm_print | 
					
						
							|  |  |  | sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print | 
					
						
							|  |  |  | sgm.modules.encoders.modules.print = shared.ldm_print | 
					
						
							| 
									
										
										
										
											2023-07-12 23:52:43 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-13 09:30:33 +03:00
										 |  |  | # this gets the code to load the vanilla attention that we override | 
					
						
							|  |  |  | sgm.modules.attention.SDP_IS_AVAILABLE = True | 
					
						
							| 
									
										
										
										
											2023-07-13 09:38:54 +03:00
										 |  |  | sgm.modules.attention.XFORMERS_IS_AVAILABLE = False |