mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-11-04 12:03:36 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			101 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			101 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import os
 | 
						|
import tempfile
 | 
						|
from collections import namedtuple
 | 
						|
from pathlib import Path
 | 
						|
 | 
						|
import gradio.components
 | 
						|
 | 
						|
from PIL import PngImagePlugin
 | 
						|
 | 
						|
from modules import shared
 | 
						|
 | 
						|
 | 
						|
# Single-field record type; the stored value is exposed as `.name`.
Savedfile = namedtuple("Savedfile", "name")
 | 
						|
 | 
						|
 | 
						|
def register_tmp_file(gradio, filename):
    """Record *filename* on the ``gradio`` object's temp-file bookkeeping.

    Two attribute layouts are handled, each only if the attribute exists:

    * ``temp_file_sets`` (gradio 3.15): the absolute path of the file is
      added to the first set in the list.
    * ``temp_dirs`` (gradio 3.9): the absolute path of the file's directory
      is added.

    In both cases the attribute is rebound to a new set rather than mutated
    in place.
    """
    if hasattr(gradio, 'temp_file_sets'):  # gradio 3.15
        absolute_file = {os.path.abspath(filename)}
        gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | absolute_file

    if hasattr(gradio, 'temp_dirs'):  # gradio 3.9
        containing_dir = {os.path.abspath(os.path.dirname(filename))}
        gradio.temp_dirs = gradio.temp_dirs | containing_dir
 | 
						|
 | 
						|
 | 
						|
def check_tmp_file(gradio, filename):
    """Return True if *filename* is covered by the ``gradio`` object's temp-file bookkeeping.

    When ``gradio`` has ``temp_file_sets``, the answer is whether *filename*
    (exactly as given, with no normalisation) is a member of any of those
    sets; ``temp_dirs`` is not consulted in that case.

    Otherwise, when ``gradio`` has ``temp_dirs``, the answer is whether the
    resolved path of one of those directories is an ancestor of the resolved
    *filename*.

    With neither attribute present the result is False.
    """
    if hasattr(gradio, 'temp_file_sets'):
        for known_files in gradio.temp_file_sets:
            if filename in known_files:
                return True
        return False

    if hasattr(gradio, 'temp_dirs'):
        for temp_dir in gradio.temp_dirs:
            resolved_dir = Path(temp_dir).resolve()
            if resolved_dir in Path(filename).resolve().parents:
                return True
        return False

    return False
 | 
						|
 | 
						|
 | 
						|
def save_pil_to_file(self, pil_image, dir=None, format="png"):
    """Save *pil_image* to a temporary ``.png`` file, keeping its text metadata, and return the path.

    Installed as a replacement for gradio's ``pil_to_temp_file`` method, hence
    the ``self`` parameter (which is not used here).

    Args:
        self: the component instance the method is bound to; unused.
        pil_image: image to save. If it carries an ``already_saved_as``
            attribute naming an existing file, that file is reused instead of
            writing a new one.
        dir: directory for the temporary file. Overridden by
            ``shared.opts.temp_dir`` when that option is non-empty. When left
            as ``None`` and no option is set, ``tempfile`` chooses its default
            directory.
        format: accepted for signature compatibility; not used by this
            function, the file is always created with a ``.png`` suffix.

    Returns:
        Path of the file. For an already-saved image this is
        ``"<path>?<mtime>"``; otherwise it is the new temporary file's name.
    """
    already_saved_as = getattr(pil_image, 'already_saved_as', None)
    if already_saved_as and os.path.isfile(already_saved_as):
        # Register both the plain path and the "?mtime" variant that is returned.
        register_tmp_file(shared.demo, already_saved_as)
        filename_with_mtime = f'{already_saved_as}?{os.path.getmtime(already_saved_as)}'
        register_tmp_file(shared.demo, filename_with_mtime)
        return filename_with_mtime

    if shared.opts.temp_dir != "":
        dir = shared.opts.temp_dir
    elif dir is not None:
        # os.makedirs(None) raises TypeError, so only create a directory that
        # was actually supplied by the caller.
        os.makedirs(dir, exist_ok=True)

    # Carry over only str -> str entries of the image info as PNG text chunks.
    use_metadata = False
    metadata = PngImagePlugin.PngInfo()
    for key, value in pil_image.info.items():
        if isinstance(key, str) and isinstance(value, str):
            metadata.add_text(key, value)
            use_metadata = True

    # delete=False keeps the file on disk after the handle is closed; the
    # context manager makes sure the handle is flushed and released instead of
    # being left open until garbage collection.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) as file_obj:
        pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))

    return file_obj.name
 | 
						|
 | 
						|
 | 
						|
def install_ui_tempdir_override():
    """Swap gradio's ``IOComponent.pil_to_temp_file`` for ``save_pil_to_file`` so saved images also carry PNG info."""
    io_component_cls = gradio.components.IOComponent
    io_component_cls.pil_to_temp_file = save_pil_to_file
 | 
						|
 | 
						|
 | 
						|
def on_tmpdir_changed():
    """Create the configured temp directory and register it with the running UI.

    Does nothing when ``shared.opts.temp_dir`` is empty or ``shared.demo`` is
    ``None``.
    """
    configured_dir = shared.opts.temp_dir
    if configured_dir != "" and shared.demo is not None:
        os.makedirs(configured_dir, exist_ok=True)

        # "x" is only a placeholder file name inside the directory;
        # register_tmp_file takes a file path rather than a directory.
        placeholder = os.path.join(configured_dir, "x")
        register_tmp_file(shared.demo, placeholder)
 | 
						|
 | 
						|
 | 
						|
def cleanup_tmpdr():
    """Delete every ``.png`` file found under ``shared.opts.temp_dir``.

    Does nothing when the option is empty or does not name an existing
    directory. The tree is walked bottom-up; files whose extension is not
    exactly ``.png`` are left alone, and directories are not removed.
    """
    temp_dir = shared.opts.temp_dir
    if temp_dir == "" or not os.path.isdir(temp_dir):
        return

    for folder, _subdirs, entries in os.walk(temp_dir, topdown=False):
        png_entries = (entry for entry in entries if os.path.splitext(entry)[1] == ".png")
        for entry in png_entries:
            os.remove(os.path.join(folder, entry))
 | 
						|
 | 
						|
 | 
						|
def is_gradio_temp_path(path):
    """
    Tell whether *path* lies inside one of the temp directories used by gradio.

    Checked in order: the ``shared.opts.temp_dir`` option (if set), the
    ``GRADIO_TEMP_DIR`` environment variable (if set), and the ``gradio``
    folder inside the system temp directory.
    """
    candidate = Path(path)

    configured_dir = shared.opts.temp_dir
    if configured_dir and candidate.is_relative_to(configured_dir):
        return True

    env_dir = os.environ.get("GRADIO_TEMP_DIR")
    if env_dir and candidate.is_relative_to(env_dir):
        return True

    default_dir = Path(tempfile.gettempdir()) / "gradio"
    return candidate.is_relative_to(default_dir)
 |