| 
									
										
										
										
											2023-12-25 14:43:51 +02:00
										 |  |  | from modules import modelloader, devices, errors | 
					
						
							| 
									
										
										
										
											2022-09-26 09:29:50 -05:00
										 |  |  | from modules.shared import opts | 
					
						
							| 
									
										
										
										
											2023-05-29 10:38:51 +03:00
										 |  |  | from modules.upscaler import Upscaler, UpscalerData | 
					
						
							| 
									
										
										
										
											2023-12-27 11:04:33 +02:00
										 |  |  | from modules.upscaler_utils import upscale_with_model | 
					
						
							| 
									
										
										
										
											2022-09-04 18:54:12 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-29 17:46:23 -05:00
										 |  |  | class UpscalerESRGAN(Upscaler): | 
					
						
							|  |  |  |     def __init__(self, dirname): | 
					
						
							|  |  |  |         self.name = "ESRGAN" | 
					
						
							| 
									
										
										
										
											2022-10-02 12:58:17 -05:00
										 |  |  |         self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth" | 
					
						
							|  |  |  |         self.model_name = "ESRGAN_4x" | 
					
						
							| 
									
										
										
										
											2022-09-29 17:46:23 -05:00
										 |  |  |         self.scalers = [] | 
					
						
							|  |  |  |         self.user_path = dirname | 
					
						
							|  |  |  |         super().__init__() | 
					
						
							|  |  |  |         model_paths = self.find_models(ext_filter=[".pt", ".pth"]) | 
					
						
							|  |  |  |         scalers = [] | 
					
						
							|  |  |  |         if len(model_paths) == 0: | 
					
						
							|  |  |  |             scaler_data = UpscalerData(self.model_name, self.model_url, self, 4) | 
					
						
							|  |  |  |             scalers.append(scaler_data) | 
					
						
							|  |  |  |         for file in model_paths: | 
					
						
							| 
									
										
										
										
											2023-05-29 09:41:36 +03:00
										 |  |  |             if file.startswith("http"): | 
					
						
							| 
									
										
										
										
											2022-09-29 17:46:23 -05:00
										 |  |  |                 name = self.model_name | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 name = modelloader.friendly_name(file) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             scaler_data = UpscalerData(name, file, self, 4) | 
					
						
							|  |  |  |             self.scalers.append(scaler_data) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def do_upscale(self, img, selected_model): | 
					
						
							| 
									
										
										
										
											2023-05-29 10:38:51 +03:00
										 |  |  |         try: | 
					
						
							|  |  |  |             model = self.load_model(selected_model) | 
					
						
							| 
									
										
										
										
											2023-12-25 14:43:51 +02:00
										 |  |  |         except Exception: | 
					
						
							|  |  |  |             errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True) | 
					
						
							| 
									
										
										
										
											2022-09-29 17:46:23 -05:00
										 |  |  |             return img | 
					
						
							| 
									
										
										
										
											2022-10-04 04:24:35 -04:00
										 |  |  |         model.to(devices.device_esrgan) | 
					
						
							| 
									
										
										
										
											2023-12-25 14:43:51 +02:00
										 |  |  |         return esrgan_upscale(model, img) | 
					
						
							| 
									
										
										
										
											2022-09-08 15:49:47 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-29 17:46:23 -05:00
										 |  |  |     def load_model(self, path: str): | 
					
						
							| 
									
										
										
										
											2023-05-29 09:41:36 +03:00
										 |  |  |         if path.startswith("http"): | 
					
						
							| 
									
										
										
										
											2023-05-29 09:45:07 +03:00
										 |  |  |             # TODO: this doesn't use `path` at all? | 
					
						
							| 
									
										
										
										
											2023-05-29 09:34:26 +03:00
										 |  |  |             filename = modelloader.load_file_from_url( | 
					
						
							| 
									
										
										
										
											2023-05-09 22:17:58 +03:00
										 |  |  |                 url=self.model_url, | 
					
						
							| 
									
										
										
										
											2023-05-19 09:09:00 +03:00
										 |  |  |                 model_dir=self.model_download_path, | 
					
						
							| 
									
										
										
										
											2023-05-09 22:17:58 +03:00
										 |  |  |                 file_name=f"{self.model_name}.pth", | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2022-09-04 18:54:12 +03:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2022-09-29 17:46:23 -05:00
										 |  |  |             filename = path | 
					
						
							| 
									
										
										
										
											2022-09-30 11:42:40 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-25 14:43:51 +02:00
										 |  |  |         return modelloader.load_spandrel_model( | 
					
						
							|  |  |  |             filename, | 
					
						
							|  |  |  |             device=('cpu' if devices.device_esrgan.type == 'mps' else None), | 
					
						
							| 
									
										
										
										
											2023-12-30 16:37:03 +02:00
										 |  |  |             expected_architecture='ESRGAN', | 
					
						
							| 
									
										
										
										
											2023-12-25 14:43:51 +02:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2022-09-29 17:46:23 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-04 18:54:12 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | def esrgan_upscale(model, img): | 
					
						
							| 
									
										
										
										
											2023-12-27 11:04:33 +02:00
										 |  |  |     return upscale_with_model( | 
					
						
							|  |  |  |         model, | 
					
						
							|  |  |  |         img, | 
					
						
							|  |  |  |         tile_size=opts.ESRGAN_tile, | 
					
						
							|  |  |  |         tile_overlap=opts.ESRGAN_tile_overlap, | 
					
						
							|  |  |  |     ) |