| 
									
										
										
										
											2022-11-27 11:52:53 +03:00
										 |  |  | import os | 
					
						
							|  |  |  | import tempfile | 
					
						
							|  |  |  | from collections import namedtuple | 
					
						
							| 
									
										
										
										
											2023-01-03 14:18:48 +03:00
										 |  |  | from pathlib import Path | 
					
						
							| 
									
										
										
										
											2022-11-27 11:52:53 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-27 19:06:49 +03:00
										 |  |  | import gradio.components | 
					
						
							| 
									
										
										
										
											2022-11-27 11:52:53 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | from PIL import PngImagePlugin | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from modules import shared | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | Savedfile = namedtuple("Savedfile", ["name"]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-03 14:18:48 +03:00
										 |  |  | def register_tmp_file(gradio, filename): | 
					
						
							|  |  |  |     if hasattr(gradio, 'temp_file_sets'):  # gradio 3.15 | 
					
						
							|  |  |  |         gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if hasattr(gradio, 'temp_dirs'):  # gradio 3.9 | 
					
						
							|  |  |  |         gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def check_tmp_file(gradio, filename): | 
					
						
							|  |  |  |     if hasattr(gradio, 'temp_file_sets'): | 
					
						
							| 
									
										
										
										
											2023-05-10 11:05:02 +03:00
										 |  |  |         return any(filename in fileset for fileset in gradio.temp_file_sets) | 
					
						
							| 
									
										
										
										
											2023-01-03 14:18:48 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     if hasattr(gradio, 'temp_dirs'): | 
					
						
							|  |  |  |         return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-06-04 14:20:23 +08:00
										 |  |  | def save_pil_to_file(self, pil_image, dir=None, format="png"): | 
					
						
							| 
									
										
										
										
											2022-11-27 11:52:53 +03:00
										 |  |  |     already_saved_as = getattr(pil_image, 'already_saved_as', None) | 
					
						
							| 
									
										
										
										
											2022-11-27 23:14:13 +03:00
										 |  |  |     if already_saved_as and os.path.isfile(already_saved_as): | 
					
						
							| 
									
										
										
										
											2023-05-04 15:55:57 +08:00
										 |  |  |         register_tmp_file(shared.demo, already_saved_as) | 
					
						
							| 
									
										
										
										
											2023-05-27 19:06:49 +03:00
										 |  |  |         filename = already_saved_as | 
					
						
							| 
									
										
										
										
											2023-01-01 11:08:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-27 19:06:49 +03:00
										 |  |  |         if not shared.opts.save_images_add_number: | 
					
						
							|  |  |  |             filename += f'?{os.path.getmtime(already_saved_as)}' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return filename | 
					
						
							| 
									
										
										
										
											2022-11-27 11:52:53 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     if shared.opts.temp_dir != "": | 
					
						
							|  |  |  |         dir = shared.opts.temp_dir | 
					
						
							| 
									
										
										
										
											2023-08-21 17:36:17 -04:00
										 |  |  |     else: | 
					
						
							|  |  |  |         os.makedirs(dir, exist_ok=True) | 
					
						
							| 
									
										
										
										
											2022-11-27 11:52:53 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     use_metadata = False | 
					
						
							|  |  |  |     metadata = PngImagePlugin.PngInfo() | 
					
						
							|  |  |  |     for key, value in pil_image.info.items(): | 
					
						
							|  |  |  |         if isinstance(key, str) and isinstance(value, str): | 
					
						
							|  |  |  |             metadata.add_text(key, value) | 
					
						
							|  |  |  |             use_metadata = True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir) | 
					
						
							|  |  |  |     pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) | 
					
						
							| 
									
										
										
										
											2023-05-27 19:06:49 +03:00
										 |  |  |     return file_obj.name | 
					
						
							| 
									
										
										
										
											2022-11-27 11:52:53 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-09 18:11:13 +03:00
										 |  |  | def install_ui_tempdir_override(): | 
					
						
							|  |  |  |     """override save to file function so that it also writes PNG info""" | 
					
						
							|  |  |  |     gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file | 
					
						
							| 
									
										
										
										
											2022-11-27 11:52:53 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def on_tmpdir_changed(): | 
					
						
							|  |  |  |     if shared.opts.temp_dir == "" or shared.demo is None: | 
					
						
							|  |  |  |         return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     os.makedirs(shared.opts.temp_dir, exist_ok=True) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-03 14:18:48 +03:00
										 |  |  |     register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x")) | 
					
						
							| 
									
										
										
										
											2022-11-27 11:52:53 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def cleanup_tmpdr(): | 
					
						
							|  |  |  |     temp_dir = shared.opts.temp_dir | 
					
						
							|  |  |  |     if temp_dir == "" or not os.path.isdir(temp_dir): | 
					
						
							|  |  |  |         return | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 11:37:18 +03:00
										 |  |  |     for root, _, files in os.walk(temp_dir, topdown=False): | 
					
						
							| 
									
										
										
										
											2022-11-27 11:52:53 +03:00
										 |  |  |         for name in files: | 
					
						
							|  |  |  |             _, extension = os.path.splitext(name) | 
					
						
							|  |  |  |             if extension != ".png": | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             filename = os.path.join(root, name) | 
					
						
							|  |  |  |             os.remove(filename) |