mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-10-25 23:24:45 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			193 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			193 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # this code is adapted from the script contributed by anon from /h/
 | |
| 
 | |
| import io
 | |
| import pickle
 | |
| import collections
 | |
| import sys
 | |
| import traceback
 | |
| 
 | |
| import torch
 | |
| import numpy
 | |
| import _codecs
 | |
| import zipfile
 | |
| import re
 | |
| 
 | |
| 
 | |
# PyTorch 1.13 and later expose the class as torch.storage.TypedStorage;
# earlier releases only have the private name torch.storage._TypedStorage.
try:
    TypedStorage = torch.storage.TypedStorage
except AttributeError:
    TypedStorage = torch.storage._TypedStorage
 | |
| 
 | |
| 
 | |
def encode(*args):
    """Forward every positional argument to `_codecs.encode` and return its result.

    This is the object handed out by `RestrictedUnpickler.find_class` when a
    pickle references the global `_codecs.encode`.
    """
    return _codecs.encode(*args)
 | |
| 
 | |
| 
 | |
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves an allow-list of globals.

    Any global outside the allow-list makes `find_class` raise, and objects
    referenced through persistent ids are replaced by empty storages instead
    of being read from the file.
    """

    # Optional callable `(module, name) -> object or None`, consulted before the
    # built-in allow-list; a non-None result is returned to the unpickler as is.
    extra_handler = None

    def persistent_load(self, saved_id):
        """Return an empty storage placeholder for a persistent id.

        Args:
            saved_id: the persistent id from the pickle stream; its first
                element must be the string 'storage'.

        Raises:
            AssertionError: if the persistent id is not tagged 'storage'.
        """
        # Explicit raise instead of an `assert` statement: asserts are removed
        # when Python runs with -O, which would silently disable this check on
        # untrusted data. The exception type is kept for existing callers.
        if saved_id[0] != 'storage':
            raise AssertionError(f"unexpected persistent id: {saved_id!r}")

        try:
            # PyTorch 2.x warns about TypedStorage being deprecated when it is
            # constructed without _internal=True.
            return TypedStorage(_internal=True)
        except TypeError:
            # Older PyTorch releases do not accept the _internal keyword.
            return TypedStorage()

    def find_class(self, module, name):
        """Resolve the global `module.name` requested by the pickle stream.

        Args:
            module: module name as written in the pickle.
            name: attribute name as written in the pickle.

        Returns:
            The object the pickle is allowed to use for that global.

        Raises:
            Exception: if the global is neither accepted by `extra_handler`
                nor part of the allow-list below.
        """
        if self.extra_handler is not None:
            res = self.extra_handler(module, name)
            if res is not None:
                return res

        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name in ('_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy'):
            return getattr(torch._utils, name)
        if module == 'torch' and name in ('FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32'):
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name == 'ParameterDict':
            return getattr(torch.nn.modules.container, name)
        if module == 'numpy.core.multiarray' and name in ('scalar', '_reconstruct'):
            return getattr(numpy.core.multiarray, name)
        if module == 'numpy' and name in ('dtype', 'ndarray'):
            return getattr(numpy, name)
        if module == '_codecs' and name == 'encode':
            return encode
        # pytorch_lightning is imported lazily, only when a pickle actually refers to it
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        if module == "__builtin__" and name == 'set':
            return set

        # Forbid everything else.
        raise Exception(f"global '{module}/{name}' is forbidden")
 | |
| 
 | |
| 
 | |
# Regular expressions that accept 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>'.
# They end in \Z rather than $: $ also matches just before a trailing newline, which would let a
# member name such as 'dirname/data.pkl\n' through the allow-list.
allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))\Z")
data_pkl_re = re.compile(r"^([^/]+)/data\.pkl\Z")
 | |
| 
 | |
def check_zip_filenames(filename, names):
    """Raise if any member name of a zip archive is not matched by `allowed_zip_names_re`.

    Args:
        filename: path of the archive, used only in the error message.
        names: iterable of member names to validate.

    Raises:
        Exception: on the first name that the allow-list pattern rejects.
    """
    for name in names:
        if allowed_zip_names_re.match(name) is None:
            raise Exception(f"bad file inside {filename}: {name}")
 | |
| 
 | |
| 
 | |
def check_pt(filename, extra_handler):
    """Run the pickle data of a checkpoint file through `RestrictedUnpickler`.

    The file is first opened as a zip archive; if that fails because it is not
    a zip file, it is read as a plain pickle stream instead.

    Args:
        filename: path of the checkpoint to verify.
        extra_handler: callable assigned to the unpickler's `extra_handler`
            attribute, or None.

    Raises:
        Exception: if the archive has a disallowed member name, has no
            data.pkl or more than one, or if the unpickler rejects the data.
    """
    try:

        # new pytorch format is a zip file
        with zipfile.ZipFile(filename) as z:
            names = z.namelist()
            check_zip_filenames(filename, names)

            # find filename of data.pkl in zip file: '<directory name>/data.pkl'
            data_pkl_filenames = [f for f in names if data_pkl_re.match(f)]
            if not data_pkl_filenames:
                raise Exception(f"data.pkl not found in {filename}")
            if len(data_pkl_filenames) > 1:
                raise Exception(f"Multiple data.pkl found in {filename}")
            with z.open(data_pkl_filenames[0]) as file:
                unpickler = RestrictedUnpickler(file)
                unpickler.extra_handler = extra_handler
                unpickler.load()

    except zipfile.BadZipFile:

        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
        with open(filename, "rb") as file:
            unpickler = RestrictedUnpickler(file)
            unpickler.extra_handler = extra_handler
            for _ in range(5):
                unpickler.load()
 | |
| 
 | |
| 
 | |
def load(filename, *args, **kwargs):
    """Verify `filename` and load it, using the handler currently set in `global_extra_handler`.

    This module installs the function as `torch.load`, so the remaining
    arguments are whatever the caller meant for the original `torch.load`.

    Args:
        filename: path of the file to load.
        *args: extra positional arguments, forwarded to `load_with_extra`.
        **kwargs: extra keyword arguments, forwarded to `load_with_extra`.

    Returns:
        Whatever `load_with_extra` returns.
    """
    # The handler is passed positionally on purpose. Written as a keyword in
    # front of *args, the first extra positional argument (for example
    # torch.load(path, 'cpu')) would also be bound to load_with_extra's second
    # parameter, extra_handler, and the call would fail with
    # "got multiple values for argument 'extra_handler'".
    return load_with_extra(filename, global_extra_handler, *args, **kwargs)
 | |
| 
 | |
| 
 | |
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
    """
    Verify a pickled file and, if nothing objectionable is found, load it with the original torch.load.

    This is the entry point for extensions whose models contain extra classes that the
    usual unpickler would find suspicious. Pass a function as extra_handler; it receives
    the module and field name as text and returns that field's value, or None to leave
    the decision to the built-in checks:

    ```python
    def extra(module, name):
        if module == 'collections' and name == 'OrderedDict':
            return collections.OrderedDict

        return None

    safe.load_with_extra('model.pt', extra_handler=extra)
    ```

    The alternative is to call safe.unsafe_torch_load('model.pt') directly, which as the
    name implies is definitely unsafe.

    The verification is skipped when the disable_safe_unpickle command line option is set.
    If it raises, the error is reported on stderr and None is returned; otherwise the
    result of unsafe_torch_load(filename, *args, **kwargs) is returned.
    """

    from modules import shared

    try:
        if not shared.cmd_opts.disable_safe_unpickle:
            check_pt(filename, extra_handler)

    except Exception as err:
        # An UnpicklingError points at a damaged file; anything else is treated as a possible attack.
        if isinstance(err, pickle.UnpicklingError):
            verdict = "-----> !!!! The file is most likely corrupted !!!! <-----"
            hint = "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n"
        else:
            verdict = "\nThe file may be malicious, so the program is not going to read it."
            hint = "You can skip this check with --disable-safe-unpickle commandline argument.\n\n"

        print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        print(verdict, file=sys.stderr)
        print(hint, file=sys.stderr)
        return None

    return unsafe_torch_load(filename, *args, **kwargs)
 | |
| 
 | |
| 
 | |
class Extra:
    """
    Context manager that sets the module-level global_extra_handler for the duration of a
    `with` block. It is meant for the cases where load_with_extra cannot be called
    explicitly because the torch.load call is made by someone else's code:

    ```
    import torch
    from modules import safe

    def handler(module, name):
        if module == 'torch' and name in ['float64', 'float16']:
            return getattr(torch, name)

        return None

    with safe.Extra(handler):
        x = torch.load('model.pt')
    ```

    Blocks cannot be nested: entering one while a handler is already set fails an assertion.
    """

    def __init__(self, handler):
        # function taking (module, name); installed as the global handler on entry
        self.handler = handler

    def __enter__(self):
        global global_extra_handler

        # only one handler can be active at a time
        assert global_extra_handler is None, 'already inside an Extra() block'
        global_extra_handler = self.handler

    def __exit__(self, exc_type, exc_val, exc_tb):
        global global_extra_handler

        # clear the handler again; nothing is returned, so an exception raised
        # inside the block keeps propagating
        global_extra_handler = None
 | |
| 
 | |
| 
 | |
# No Extra() block is active when the module is first imported.
global_extra_handler = None

# Keep the original torch.load reachable as unsafe_torch_load and put the checking
# wrapper in its place. The right-hand side is evaluated before either target is
# assigned, so unsafe_torch_load receives the original function, not the wrapper.
unsafe_torch_load, torch.load = torch.load, load
 | |
| 
 | 
