mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-10-31 10:03:40 +00:00 
			
		
		
		
	Merge pull request #14467 from akx/drop-basicsr
Drop basicsr dependency
This commit is contained in:
		
						commit
						16848f950b
					
				
							
								
								
									
										2
									
								
								.github/workflows/run_tests.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/run_tests.yaml
									
									
									
									
										vendored
									
									
								
							| @ -57,7 +57,7 @@ jobs: | ||||
|           2>&1 | tee output.txt & | ||||
|       - name: Run tests | ||||
|         run: | | ||||
|           wait-for-it --service 127.0.0.1:7860 -t 600 | ||||
|           wait-for-it --service 127.0.0.1:7860 -t 20 | ||||
|           python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test | ||||
|       - name: Kill test server | ||||
|         if: always() | ||||
|  | ||||
| @ -17,6 +17,28 @@ if TYPE_CHECKING: | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| def bgr_image_to_rgb_tensor(img: np.ndarray) -> torch.Tensor: | ||||
|     """Convert a BGR NumPy image in [0..1] range to a PyTorch RGB float32 tensor.""" | ||||
|     assert img.shape[2] == 3, "image must be RGB" | ||||
|     if img.dtype == "float64": | ||||
|         img = img.astype("float32") | ||||
|     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | ||||
|     return torch.from_numpy(img.transpose(2, 0, 1)).float() | ||||
| 
 | ||||
| 
 | ||||
| def rgb_tensor_to_bgr_image(tensor: torch.Tensor, *, min_max=(0.0, 1.0)) -> np.ndarray: | ||||
|     """ | ||||
|     Convert a PyTorch RGB tensor in range `min_max` to a BGR NumPy image in [0..1] range. | ||||
|     """ | ||||
|     tensor = tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) | ||||
|     tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) | ||||
|     assert tensor.dim() == 3, "tensor must be RGB" | ||||
|     img_np = tensor.numpy().transpose(1, 2, 0) | ||||
|     if img_np.shape[2] == 1:  # gray image, no RGB/BGR required | ||||
|         return np.squeeze(img_np, axis=2) | ||||
|     return cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB) | ||||
| 
 | ||||
| 
 | ||||
| def create_face_helper(device) -> FaceRestoreHelper: | ||||
|     from facexlib.detection import retinaface | ||||
|     from facexlib.utils.face_restoration_helper import FaceRestoreHelper | ||||
| @ -36,14 +58,13 @@ def create_face_helper(device) -> FaceRestoreHelper: | ||||
| def restore_with_face_helper( | ||||
|     np_image: np.ndarray, | ||||
|     face_helper: FaceRestoreHelper, | ||||
|     restore_face: Callable[[np.ndarray], np.ndarray], | ||||
|     restore_face: Callable[[torch.Tensor], torch.Tensor], | ||||
| ) -> np.ndarray: | ||||
|     """ | ||||
|     Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image. | ||||
| 
 | ||||
|     `restore_face` should take a cropped face image and return a restored face image. | ||||
|     """ | ||||
|     from basicsr.utils import img2tensor, tensor2img | ||||
|     from torchvision.transforms.functional import normalize | ||||
|     np_image = np_image[:, :, ::-1] | ||||
|     original_resolution = np_image.shape[0:2] | ||||
| @ -56,23 +77,19 @@ def restore_with_face_helper( | ||||
|         face_helper.align_warp_face() | ||||
|         logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces)) | ||||
|         for cropped_face in face_helper.cropped_faces: | ||||
|             cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) | ||||
|             cropped_face_t = bgr_image_to_rgb_tensor(cropped_face / 255.0) | ||||
|             normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) | ||||
|             cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) | ||||
| 
 | ||||
|             try: | ||||
|                 with torch.no_grad(): | ||||
|                     restored_face = tensor2img( | ||||
|                         restore_face(cropped_face_t), | ||||
|                         rgb2bgr=True, | ||||
|                         min_max=(-1, 1), | ||||
|                     ) | ||||
|                     cropped_face_t = restore_face(cropped_face_t) | ||||
|                 devices.torch_gc() | ||||
|             except Exception: | ||||
|                 errors.report('Failed face-restoration inference', exc_info=True) | ||||
|                 restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) | ||||
| 
 | ||||
|             restored_face = restored_face.astype('uint8') | ||||
|             restored_face = rgb_tensor_to_bgr_image(cropped_face_t, min_max=(-1, 1)) | ||||
|             restored_face = (restored_face * 255.0).astype('uint8') | ||||
|             face_helper.add_restored_face(restored_face) | ||||
| 
 | ||||
|         logger.debug("Merging restored faces into image") | ||||
| @ -126,7 +143,7 @@ class CommonFaceRestoration(face_restoration.FaceRestoration): | ||||
|     def restore_with_helper( | ||||
|         self, | ||||
|         np_image: np.ndarray, | ||||
|         restore_face: Callable[[np.ndarray], np.ndarray], | ||||
|         restore_face: Callable[[torch.Tensor], torch.Tensor], | ||||
|     ) -> np.ndarray: | ||||
|         try: | ||||
|             if self.net is None: | ||||
|  | ||||
| @ -11,7 +11,6 @@ import safetensors.torch | ||||
| 
 | ||||
| import numpy as np | ||||
| from PIL import Image, PngImagePlugin | ||||
| from torch.utils.tensorboard import SummaryWriter | ||||
| 
 | ||||
| from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes | ||||
| import modules.textual_inversion.dataset | ||||
| @ -344,6 +343,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): | ||||
|         }) | ||||
| 
 | ||||
| def tensorboard_setup(log_directory): | ||||
|     from torch.utils.tensorboard import SummaryWriter | ||||
|     os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True) | ||||
|     return SummaryWriter( | ||||
|             log_dir=os.path.join(log_directory, "tensorboard"), | ||||
| @ -448,8 +448,12 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st | ||||
|     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." | ||||
|     old_parallel_processing_allowed = shared.parallel_processing_allowed | ||||
| 
 | ||||
|     tensorboard_writer = None | ||||
|     if shared.opts.training_enable_tensorboard: | ||||
|         tensorboard_writer = tensorboard_setup(log_directory) | ||||
|         try: | ||||
|             tensorboard_writer = tensorboard_setup(log_directory) | ||||
|         except ImportError: | ||||
|             errors.report("Error initializing tensorboard", exc_info=True) | ||||
| 
 | ||||
|     pin_memory = shared.opts.pin_memory | ||||
| 
 | ||||
| @ -622,7 +626,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st | ||||
|                         last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) | ||||
|                         last_saved_image += f", prompt: {preview_text}" | ||||
| 
 | ||||
|                         if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: | ||||
|                         if tensorboard_writer and shared.opts.training_tensorboard_save_images: | ||||
|                             tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step) | ||||
| 
 | ||||
|                     if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: | ||||
|  | ||||
| @ -2,7 +2,6 @@ GitPython | ||||
| Pillow | ||||
| accelerate | ||||
| 
 | ||||
| basicsr | ||||
| blendmodes | ||||
| clean-fid | ||||
| einops | ||||
|  | ||||
| @ -1,7 +1,6 @@ | ||||
| GitPython==3.1.32 | ||||
| Pillow==9.5.0 | ||||
| accelerate==0.21.0 | ||||
| basicsr==1.4.2 | ||||
| blendmodes==2022 | ||||
| clean-fid==0.1.35 | ||||
| einops==0.4.1 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 AUTOMATIC1111
						AUTOMATIC1111