mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-10-31 01:54:44 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			43 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			43 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import torch
 | |
| from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 | |
| from transformers import AutoFeatureExtractor
 | |
| from PIL import Image
 | |
| 
 | |
| import modules.shared as shared
 | |
| 
 | |
# Model id handed to from_pretrained() in check_safety(); both the feature
# extractor and the safety checker model are loaded from it.
safety_model_id = "CompVis/stable-diffusion-safety-checker"

# Module-level caches for the loaded extractor and checker. They stay None
# until check_safety() is first called and loads them.
safety_feature_extractor = safety_checker = None
 | |
| 
 | |
def numpy_to_pil(images):
    """
    Convert a numpy image, or a batch of images, to a list of PIL images.

    `images` is expected to be channels-last: either a single image of shape
    (H, W, C) or a batch of shape (N, H, W, C), with float values in [0, 1]
    (they are scaled by 255 below). Values outside [0, 1] are clipped first,
    so they cannot wrap around when cast to uint8.

    Single-channel images (C == 1) become greyscale images, because PIL
    cannot construct an image from an (H, W, 1) array.

    Returns a list with one PIL image per image in the batch. A single input
    image yields a one-element list.
    """
    if images.ndim == 3:
        # Promote a lone (H, W, C) image to a batch of one.
        images = images[None, ...]

    # Clip before scaling so out-of-range floats saturate instead of wrapping.
    images = (images.clip(0, 1) * 255).round().astype("uint8")

    if images.shape[-1] == 1:
        # Drop the trailing channel axis: a 2-D uint8 array maps to mode "L".
        return [Image.fromarray(image[..., 0]) for image in images]

    return [Image.fromarray(image) for image in images]
 | |
| 
 | |
# check and replace nsfw content
def check_safety(x_image):
    """
    Run the safety checker over a batch of numpy images.

    On first use, the feature extractor and the checker model are loaded from
    `safety_model_id` and cached in module globals for later calls.

    x_image: numpy image batch. It is converted with numpy_to_pil() for
        feature extraction and is also passed unchanged to the checker.
        Presumably channels-last floats in [0, 1]; see censor_batch.

    Returns the pair (x_checked_image, has_nsfw_concept) produced by the
    StableDiffusionSafetyChecker call.
    """
    global safety_feature_extractor, safety_checker

    if safety_feature_extractor is None or safety_checker is None:
        # Load into locals and publish both globals only once both loads
        # succeed. Otherwise a failure in the second load (e.g. a download
        # error) would leave a half-initialised cache, and every later call
        # would skip loading and then call None.
        feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
        checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
        safety_feature_extractor, safety_checker = feature_extractor, checker

    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)

    return x_checked_image, has_nsfw_concept
 | |
| 
 | |
| 
 | |
def censor_batch(x):
    """
    Pass a batch of image tensors through the safety checker.

    x: torch tensor, assumed (N, C, H, W). It is copied to the CPU and
        permuted to channels-last (N, H, W, C) before being handed to
        check_safety().

    Returns a CPU tensor, permuted back to the input's channel-first layout,
    holding the images returned by check_safety(). The nsfw flags returned
    alongside them are discarded.
    """
    channels_last = x.cpu().permute(0, 2, 3, 1).numpy()
    checked, _ = check_safety(channels_last)

    # Undo the earlier permute so callers get the layout they passed in.
    return torch.from_numpy(checked).permute(0, 3, 1, 2)
 | 
