mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-10-31 10:03:40 +00:00 
			
		
		
		
	Merge pull request #14467 from akx/drop-basicsr
Drop basicsr dependency
This commit is contained in:
		
						commit
						16848f950b
					
				
							
								
								
									
										2
									
								
								.github/workflows/run_tests.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/run_tests.yaml
									
									
									
									
										vendored
									
									
								
							| @ -57,7 +57,7 @@ jobs: | |||||||
|           2>&1 | tee output.txt & |           2>&1 | tee output.txt & | ||||||
|       - name: Run tests |       - name: Run tests | ||||||
|         run: | |         run: | | ||||||
|           wait-for-it --service 127.0.0.1:7860 -t 600 |           wait-for-it --service 127.0.0.1:7860 -t 20 | ||||||
|           python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test |           python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test | ||||||
|       - name: Kill test server |       - name: Kill test server | ||||||
|         if: always() |         if: always() | ||||||
|  | |||||||
| @ -17,6 +17,28 @@ if TYPE_CHECKING: | |||||||
| logger = logging.getLogger(__name__) | logger = logging.getLogger(__name__) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor: | ||||||
|  |     """Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor.""" | ||||||
|  |     assert img.shape[2] == 3, "image must be RGB" | ||||||
|  |     if img.dtype == "float64": | ||||||
|  |         img = img.astype("float32") | ||||||
|  |     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | ||||||
|  |     return torch.from_numpy(img.transpose(2, 0, 1)).float() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray: | ||||||
|  |     """ | ||||||
|  |     Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range. | ||||||
|  |     """ | ||||||
|  |     tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) | ||||||
|  |     tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) | ||||||
|  |     assert tensor.dim() == 3, "tensor must be RGB" | ||||||
|  |     img_np = tensor.numpy().transpose(1, 2, 0) | ||||||
|  |     if img_np.shape[2] == 1:  # gray image, no RGB/BGR required | ||||||
|  |         return np.squeeze(img_np, axis=2) | ||||||
|  |     return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def create_face_helper(device) -> FaceRestoreHelper: | def create_face_helper(device) -> FaceRestoreHelper: | ||||||
|     from facexlib.detection import retinaface |     from facexlib.detection import retinaface | ||||||
|     from facexlib.utils.face_restoration_helper import FaceRestoreHelper |     from facexlib.utils.face_restoration_helper import FaceRestoreHelper | ||||||
| @ -36,14 +58,13 @@ def create_face_helper(device) -> FaceRestoreHelper: | |||||||
| def restore_with_face_helper( | def restore_with_face_helper( | ||||||
|     np_image: np.ndarray, |     np_image: np.ndarray, | ||||||
|     face_helper: FaceRestoreHelper, |     face_helper: FaceRestoreHelper, | ||||||
|     restore_face: Callable[[np.ndarray], np.ndarray], |     restore_face: Callable[[torch.Tensor], torch.Tensor], | ||||||
| ) -> np.ndarray: | ) -> np.ndarray: | ||||||
|     """ |     """ | ||||||
|     Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image. |     Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image. | ||||||
| 
 | 
 | ||||||
|     `restore_face` should take a cropped face image and return a restored face image. |     `restore_face` should take a cropped face image and return a restored face image. | ||||||
|     """ |     """ | ||||||
|     from basicsr.utils import img2tensor, tensor2img |  | ||||||
|     from torchvision.transforms.functional import normalize |     from torchvision.transforms.functional import normalize | ||||||
|     np_image = np_image[:, :, ::-1] |     np_image = np_image[:, :, ::-1] | ||||||
|     original_resolution = np_image.shape[0:2] |     original_resolution = np_image.shape[0:2] | ||||||
| @ -56,23 +77,19 @@ def restore_with_face_helper( | |||||||
|         face_helper.align_warp_face() |         face_helper.align_warp_face() | ||||||
|         logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces)) |         logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces)) | ||||||
|         for cropped_face in face_helper.cropped_faces: |         for cropped_face in face_helper.cropped_faces: | ||||||
|             cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) |             cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0) | ||||||
|             normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) |             normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) | ||||||
|             cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) |             cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) | ||||||
| 
 | 
 | ||||||
|             try: |             try: | ||||||
|                 with torch.no_grad(): |                 with torch.no_grad(): | ||||||
|                     restored_face = tensor2img( |                     cropped_face_t = restore_face(cropped_face_t) | ||||||
|                         restore_face(cropped_face_t), |  | ||||||
|                         rgb2bgr=True, |  | ||||||
|                         min_max=(-1, 1), |  | ||||||
|                     ) |  | ||||||
|                 devices.torch_gc() |                 devices.torch_gc() | ||||||
|             except Exception: |             except Exception: | ||||||
|                 errors.report('Failed face-restoration inference', exc_info=True) |                 errors.report('Failed face-restoration inference', exc_info=True) | ||||||
|                 restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) |  | ||||||
| 
 | 
 | ||||||
|             restored_face = restored_face.astype('uint8') |             restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1)) | ||||||
|  |             restored_face = (restored_face * 255.0).astype('uint8') | ||||||
|             face_helper.add_restored_face(restored_face) |             face_helper.add_restored_face(restored_face) | ||||||
| 
 | 
 | ||||||
|         logger.debug("Merging restored faces into image") |         logger.debug("Merging restored faces into image") | ||||||
| @ -126,7 +143,7 @@ class CommonFaceRestoration(face_restoration.FaceRestoration): | |||||||
|     def restore_with_helper( |     def restore_with_helper( | ||||||
|         self, |         self, | ||||||
|         np_image: np.ndarray, |         np_image: np.ndarray, | ||||||
|         restore_face: Callable[[np.ndarray], np.ndarray], |         restore_face: Callable[[torch.Tensor], torch.Tensor], | ||||||
|     ) -> np.ndarray: |     ) -> np.ndarray: | ||||||
|         try: |         try: | ||||||
|             if self.net is None: |             if self.net is None: | ||||||
|  | |||||||
| @ -11,7 +11,6 @@ import safetensors.torch | |||||||
| 
 | 
 | ||||||
| import numpy as np | import numpy as np | ||||||
| from PIL import Image, PngImagePlugin | from PIL import Image, PngImagePlugin | ||||||
| from torch.utils.tensorboard import SummaryWriter |  | ||||||
| 
 | 
 | ||||||
| from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes | from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes | ||||||
| import modules.textual_inversion.dataset | import modules.textual_inversion.dataset | ||||||
| @ -344,6 +343,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): | |||||||
|         }) |         }) | ||||||
| 
 | 
 | ||||||
| def tensorboard_setup(log_directory): | def tensorboard_setup(log_directory): | ||||||
|  |     from torch.utils.tensorboard import SummaryWriter | ||||||
|     os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True) |     os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True) | ||||||
|     return SummaryWriter( |     return SummaryWriter( | ||||||
|             log_dir=os.path.join(log_directory, "tensorboard"), |             log_dir=os.path.join(log_directory, "tensorboard"), | ||||||
| @ -448,8 +448,12 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st | |||||||
|     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." |     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." | ||||||
|     old_parallel_processing_allowed = shared.parallel_processing_allowed |     old_parallel_processing_allowed = shared.parallel_processing_allowed | ||||||
| 
 | 
 | ||||||
|  |     tensorboard_writer = None | ||||||
|     if shared.opts.training_enable_tensorboard: |     if shared.opts.training_enable_tensorboard: | ||||||
|         tensorboard_writer = tensorboard_setup(log_directory) |         try: | ||||||
|  |             tensorboard_writer = tensorboard_setup(log_directory) | ||||||
|  |         except ImportError: | ||||||
|  |             errors.report("Error initializing tensorboard", exc_info=True) | ||||||
| 
 | 
 | ||||||
|     pin_memory = shared.opts.pin_memory |     pin_memory = shared.opts.pin_memory | ||||||
| 
 | 
 | ||||||
| @ -622,7 +626,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st | |||||||
|                         last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) |                         last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) | ||||||
|                         last_saved_image += f", prompt: {preview_text}" |                         last_saved_image += f", prompt: {preview_text}" | ||||||
| 
 | 
 | ||||||
|                         if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: |                         if tensorboard_writer and shared.opts.training_tensorboard_save_images: | ||||||
|                             tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step) |                             tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step) | ||||||
| 
 | 
 | ||||||
|                     if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: |                     if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: | ||||||
|  | |||||||
| @ -2,7 +2,6 @@ GitPython | |||||||
| Pillow | Pillow | ||||||
| accelerate | accelerate | ||||||
| 
 | 
 | ||||||
| basicsr |  | ||||||
| blendmodes | blendmodes | ||||||
| clean-fid | clean-fid | ||||||
| einops | einops | ||||||
|  | |||||||
| @ -1,7 +1,6 @@ | |||||||
| GitPython==3.1.32 | GitPython==3.1.32 | ||||||
| Pillow==9.5.0 | Pillow==9.5.0 | ||||||
| accelerate==0.21.0 | accelerate==0.21.0 | ||||||
| basicsr==1.4.2 |  | ||||||
| blendmodes==2022 | blendmodes==2022 | ||||||
| clean-fid==0.1.35 | clean-fid==0.1.35 | ||||||
| einops==0.4.1 | einops==0.4.1 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 AUTOMATIC1111
						AUTOMATIC1111