| 
									
										
										
										
											2022-09-03 12:08:45 +03:00
										 |  |  | import torch | 
					
						
							| 
									
										
										
										
											2023-08-22 18:49:08 +03:00
										 |  |  | from modules import devices, shared | 
					
						
							| 
									
										
										
										
											2022-09-03 12:08:45 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | module_in_gpu = None | 
					
						
							|  |  |  | cpu = torch.device("cpu") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-12 11:55:27 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | def send_everything_to_cpu(): | 
					
						
							|  |  |  |     global module_in_gpu | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if module_in_gpu is not None: | 
					
						
							|  |  |  |         module_in_gpu.to(cpu) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     module_in_gpu = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-22 18:49:08 +03:00
										 |  |  | def is_needed(sd_model): | 
					
						
							|  |  |  |     return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def apply(sd_model): | 
					
						
							|  |  |  |     enable = is_needed(sd_model) | 
					
						
							|  |  |  |     shared.parallel_processing_allowed = not enable | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if enable: | 
					
						
							|  |  |  |         setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         sd_model.lowvram = False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-03 12:08:45 +03:00
										 |  |  | def setup_for_low_vram(sd_model, use_medvram): | 
					
						
							| 
									
										
										
										
											2023-08-01 00:24:48 +03:00
										 |  |  |     if getattr(sd_model, 'lowvram', False): | 
					
						
							|  |  |  |         return | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-04 13:07:22 +03:00
										 |  |  |     sd_model.lowvram = True | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-03 12:08:45 +03:00
										 |  |  |     parents = {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def send_me_to_gpu(module, _): | 
					
						
							|  |  |  |         """send this module to GPU; send whatever tracked module was previous in GPU to CPU;
 | 
					
						
							|  |  |  |         we add this as forward_pre_hook to a lot of modules and this way all but one of them will | 
					
						
							|  |  |  |         be in CPU | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         global module_in_gpu | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         module = parents.get(module, module) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if module_in_gpu == module: | 
					
						
							|  |  |  |             return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if module_in_gpu is not None: | 
					
						
							|  |  |  |             module_in_gpu.to(cpu) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-22 14:04:14 +03:00
										 |  |  |         module.to(devices.device) | 
					
						
							| 
									
										
										
										
											2022-09-03 12:08:45 +03:00
										 |  |  |         module_in_gpu = module | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # see below for register_forward_pre_hook; | 
					
						
							|  |  |  |     # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is | 
					
						
							|  |  |  |     # useless here, and we just replace those methods | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-01 04:01:49 -03:00
										 |  |  |     first_stage_model = sd_model.first_stage_model | 
					
						
							|  |  |  |     first_stage_model_encode = sd_model.first_stage_model.encode | 
					
						
							|  |  |  |     first_stage_model_decode = sd_model.first_stage_model.decode | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def first_stage_model_encode_wrap(x): | 
					
						
							|  |  |  |         send_me_to_gpu(first_stage_model, None) | 
					
						
							|  |  |  |         return first_stage_model_encode(x) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def first_stage_model_decode_wrap(z): | 
					
						
							|  |  |  |         send_me_to_gpu(first_stage_model, None) | 
					
						
							|  |  |  |         return first_stage_model_decode(z) | 
					
						
							| 
									
										
										
										
											2022-09-03 12:08:45 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-12 23:52:43 +03:00
										 |  |  |     to_remain_in_cpu = [ | 
					
						
							|  |  |  |         (sd_model, 'first_stage_model'), | 
					
						
							|  |  |  |         (sd_model, 'depth_model'), | 
					
						
							|  |  |  |         (sd_model, 'embedder'), | 
					
						
							|  |  |  |         (sd_model, 'model'), | 
					
						
							|  |  |  |         (sd_model, 'embedder'), | 
					
						
							|  |  |  |     ] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     is_sdxl = hasattr(sd_model, 'conditioner') | 
					
						
							|  |  |  |     is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if is_sdxl: | 
					
						
							|  |  |  |         to_remain_in_cpu.append((sd_model, 'conditioner')) | 
					
						
							|  |  |  |     elif is_sd2: | 
					
						
							|  |  |  |         to_remain_in_cpu.append((sd_model.cond_stage_model, 'model')) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer')) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model | 
					
						
							|  |  |  |     stored = [] | 
					
						
							|  |  |  |     for obj, field in to_remain_in_cpu: | 
					
						
							|  |  |  |         module = getattr(obj, field, None) | 
					
						
							|  |  |  |         stored.append(module) | 
					
						
							|  |  |  |         setattr(obj, field, None) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # send the model to GPU. | 
					
						
							| 
									
										
										
										
											2022-10-22 14:04:14 +03:00
										 |  |  |     sd_model.to(devices.device) | 
					
						
							| 
									
										
										
										
											2023-07-12 23:52:43 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # put modules back. the modules will be in CPU. | 
					
						
							|  |  |  |     for (obj, field), module in zip(to_remain_in_cpu, stored): | 
					
						
							|  |  |  |         setattr(obj, field, module) | 
					
						
							| 
									
										
										
										
											2022-09-03 12:08:45 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-12-10 11:02:47 -05:00
										 |  |  |     # register hooks for those the first three models | 
					
						
							| 
									
										
										
										
											2023-07-12 23:52:43 +03:00
										 |  |  |     if is_sdxl: | 
					
						
							|  |  |  |         sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu) | 
					
						
							|  |  |  |     elif is_sd2: | 
					
						
							|  |  |  |         sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu) | 
					
						
							| 
									
										
										
										
											2023-07-24 11:57:59 +03:00
										 |  |  |         sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu) | 
					
						
							|  |  |  |         parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model | 
					
						
							|  |  |  |         parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model | 
					
						
							| 
									
										
										
										
											2023-07-12 23:52:43 +03:00
										 |  |  |     else: | 
					
						
							|  |  |  |         sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) | 
					
						
							| 
									
										
										
										
											2023-07-24 11:57:59 +03:00
										 |  |  |         parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model | 
					
						
							| 
									
										
										
										
											2023-07-12 23:52:43 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-03 12:08:45 +03:00
										 |  |  |     sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) | 
					
						
							| 
									
										
										
										
											2022-11-01 04:01:49 -03:00
										 |  |  |     sd_model.first_stage_model.encode = first_stage_model_encode_wrap | 
					
						
							|  |  |  |     sd_model.first_stage_model.decode = first_stage_model_decode_wrap | 
					
						
							| 
									
										
										
										
											2022-12-10 11:02:47 -05:00
										 |  |  |     if sd_model.depth_model: | 
					
						
							|  |  |  |         sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu) | 
					
						
							| 
									
										
										
										
											2023-03-24 22:48:16 -04:00
										 |  |  |     if sd_model.embedder: | 
					
						
							|  |  |  |         sd_model.embedder.register_forward_pre_hook(send_me_to_gpu) | 
					
						
							| 
									
										
										
										
											2023-07-14 09:56:01 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-03 12:08:45 +03:00
										 |  |  |     if use_medvram: | 
					
						
							|  |  |  |         sd_model.model.register_forward_pre_hook(send_me_to_gpu) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         diff_model = sd_model.model.diffusion_model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # the third remaining model is still too big for 4 GB, so we also do the same for its submodules | 
					
						
							|  |  |  |         # so that only one of them is in GPU at a time | 
					
						
							|  |  |  |         stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed | 
					
						
							|  |  |  |         diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None | 
					
						
							| 
									
										
										
										
											2022-10-22 14:04:14 +03:00
										 |  |  |         sd_model.model.to(devices.device) | 
					
						
							| 
									
										
										
										
											2022-09-03 12:08:45 +03:00
										 |  |  |         diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # install hooks for bits of third model | 
					
						
							|  |  |  |         diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu) | 
					
						
							|  |  |  |         for block in diff_model.input_blocks: | 
					
						
							|  |  |  |             block.register_forward_pre_hook(send_me_to_gpu) | 
					
						
							|  |  |  |         diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu) | 
					
						
							|  |  |  |         for block in diff_model.output_blocks: | 
					
						
							|  |  |  |             block.register_forward_pre_hook(send_me_to_gpu) | 
					
						
							| 
									
										
										
										
											2023-06-04 13:07:22 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def is_enabled(sd_model): | 
					
						
							| 
									
										
										
										
											2023-08-22 18:49:08 +03:00
										 |  |  |     return sd_model.lowvram |