| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  | # this code is adapted from the script contributed by anon from /h/ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import io | 
					
						
							|  |  |  | import pickle | 
					
						
							|  |  |  | import collections | 
					
						
							|  |  |  | import sys | 
					
						
							|  |  |  | import traceback | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | import numpy | 
					
						
							|  |  |  | import _codecs | 
					
						
							|  |  |  | import zipfile | 
					
						
							| 
									
										
										
										
											2022-10-11 17:03:00 +03:00
										 |  |  | import re | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-10 00:38:55 -04:00
										 |  |  | # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage | 
					
						
							|  |  |  | TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  | def encode(*args): | 
					
						
							|  |  |  |     out = _codecs.encode(*args) | 
					
						
							|  |  |  |     return out | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class RestrictedUnpickler(pickle.Unpickler): | 
					
						
							| 
									
										
										
										
											2022-11-06 11:20:23 +03:00
										 |  |  |     extra_handler = None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  |     def persistent_load(self, saved_id): | 
					
						
							|  |  |  |         assert saved_id[0] == 'storage' | 
					
						
							| 
									
										
										
										
											2022-10-10 00:38:55 -04:00
										 |  |  |         return TypedStorage() | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def find_class(self, module, name): | 
					
						
							| 
									
										
										
										
											2022-11-06 11:20:23 +03:00
										 |  |  |         if self.extra_handler is not None: | 
					
						
							|  |  |  |             res = self.extra_handler(module, name) | 
					
						
							|  |  |  |             if res is not None: | 
					
						
							|  |  |  |                 return res | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  |         if module == 'collections' and name == 'OrderedDict': | 
					
						
							|  |  |  |             return getattr(collections, name) | 
					
						
							| 
									
										
										
										
											2022-12-17 03:24:54 -05:00
										 |  |  |         if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']: | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  |             return getattr(torch._utils, name) | 
					
						
							| 
									
										
										
										
											2022-12-17 03:24:54 -05:00
										 |  |  |         if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']: | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  |             return getattr(torch, name) | 
					
						
							|  |  |  |         if module == 'torch.nn.modules.container' and name in ['ParameterDict']: | 
					
						
							|  |  |  |             return getattr(torch.nn.modules.container, name) | 
					
						
							| 
									
										
										
										
											2022-12-17 03:24:54 -05:00
										 |  |  |         if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']: | 
					
						
							|  |  |  |             return getattr(numpy.core.multiarray, name) | 
					
						
							|  |  |  |         if module == 'numpy' and name in ['dtype', 'ndarray']: | 
					
						
							|  |  |  |             return getattr(numpy, name) | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  |         if module == '_codecs' and name == 'encode': | 
					
						
							|  |  |  |             return encode | 
					
						
							|  |  |  |         if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': | 
					
						
							|  |  |  |             import pytorch_lightning.callbacks | 
					
						
							|  |  |  |             return pytorch_lightning.callbacks.model_checkpoint | 
					
						
							|  |  |  |         if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': | 
					
						
							|  |  |  |             import pytorch_lightning.callbacks.model_checkpoint | 
					
						
							|  |  |  |             return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint | 
					
						
							|  |  |  |         if module == "__builtin__" and name == 'set': | 
					
						
							|  |  |  |             return set | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Forbid everything else. | 
					
						
							| 
									
										
										
										
											2022-11-06 11:20:23 +03:00
										 |  |  |         raise Exception(f"global '{module}/{name}' is forbidden") | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-12-02 11:12:13 +01:00
										 |  |  | # Regular expression that accepts 'dirname/version', 'dirname/data.pkl', and 'dirname/data/<number>' | 
					
						
							|  |  |  | allowed_zip_names_re = re.compile(r"^([^/]+)/((data/\d+)|version|(data\.pkl))$") | 
					
						
							|  |  |  | data_pkl_re = re.compile(r"^([^/]+)/data\.pkl$") | 
					
						
							| 
									
										
										
										
											2022-10-11 17:03:00 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | def check_zip_filenames(filename, names): | 
					
						
							|  |  |  |     for name in names: | 
					
						
							|  |  |  |         if allowed_zip_names_re.match(name): | 
					
						
							|  |  |  |             continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         raise Exception(f"bad file inside {filename}: {name}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-06 11:20:23 +03:00
										 |  |  | def check_pt(filename, extra_handler): | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  |     try: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # new pytorch format is a zip file | 
					
						
							|  |  |  |         with zipfile.ZipFile(filename) as z: | 
					
						
							| 
									
										
										
										
											2022-10-11 17:03:00 +03:00
										 |  |  |             check_zip_filenames(filename, z.namelist()) | 
					
						
							| 
									
										
										
										
											2022-12-14 21:01:32 -05:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-12-02 11:12:13 +01:00
										 |  |  |             # find filename of data.pkl in zip file: '<directory name>/data.pkl' | 
					
						
							|  |  |  |             data_pkl_filenames = [f for f in z.namelist() if data_pkl_re.match(f)] | 
					
						
							|  |  |  |             if len(data_pkl_filenames) == 0: | 
					
						
							|  |  |  |                 raise Exception(f"data.pkl not found in {filename}") | 
					
						
							|  |  |  |             if len(data_pkl_filenames) > 1: | 
					
						
							|  |  |  |                 raise Exception(f"Multiple data.pkl found in {filename}") | 
					
						
							|  |  |  |             with z.open(data_pkl_filenames[0]) as file: | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  |                 unpickler = RestrictedUnpickler(file) | 
					
						
							| 
									
										
										
										
											2022-11-06 11:20:23 +03:00
										 |  |  |                 unpickler.extra_handler = extra_handler | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  |                 unpickler.load() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     except zipfile.BadZipfile: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle | 
					
						
							|  |  |  |         with open(filename, "rb") as file: | 
					
						
							|  |  |  |             unpickler = RestrictedUnpickler(file) | 
					
						
							| 
									
										
										
										
											2022-11-06 11:20:23 +03:00
										 |  |  |             unpickler.extra_handler = extra_handler | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  |             for i in range(5): | 
					
						
							|  |  |  |                 unpickler.load() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def load(filename, *args, **kwargs): | 
					
						
							| 
									
										
										
										
											2022-12-25 09:03:56 +03:00
										 |  |  |     return load_with_extra(filename, extra_handler=global_extra_handler, *args, **kwargs) | 
					
						
							| 
									
										
										
										
											2022-11-06 11:20:23 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def load_with_extra(filename, extra_handler=None, *args, **kwargs): | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2022-12-14 21:01:32 -05:00
										 |  |  |     this function is intended to be used by extensions that want to load models with | 
					
						
							| 
									
										
										
										
											2022-11-06 11:20:23 +03:00
										 |  |  |     some extra classes in them that the usual unpickler would find suspicious. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     Use the extra_handler argument to specify a function that takes module and field name as text, | 
					
						
							|  |  |  |     and returns that field's value: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ```python | 
					
						
							|  |  |  |     def extra(module, name): | 
					
						
							|  |  |  |         if module == 'collections' and name == 'OrderedDict': | 
					
						
							|  |  |  |             return collections.OrderedDict | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     safe.load_with_extra('model.pt', extra_handler=extra) | 
					
						
							|  |  |  |     ``` | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is | 
					
						
							|  |  |  |     definitely unsafe. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  |     from modules import shared | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         if not shared.cmd_opts.disable_safe_unpickle: | 
					
						
							| 
									
										
										
										
											2022-11-06 11:20:23 +03:00
										 |  |  |             check_pt(filename, extra_handler) | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-14 16:37:32 +03:00
										 |  |  |     except pickle.UnpicklingError: | 
					
						
							|  |  |  |         print(f"Error verifying pickled file from {filename}:", file=sys.stderr) | 
					
						
							|  |  |  |         print(traceback.format_exc(), file=sys.stderr) | 
					
						
							| 
									
										
										
										
											2022-12-24 21:35:29 +02:00
										 |  |  |         print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr) | 
					
						
							|  |  |  |         print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) | 
					
						
							| 
									
										
										
										
											2022-10-14 16:37:32 +03:00
										 |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  |     except Exception: | 
					
						
							|  |  |  |         print(f"Error verifying pickled file from {filename}:", file=sys.stderr) | 
					
						
							|  |  |  |         print(traceback.format_exc(), file=sys.stderr) | 
					
						
							| 
									
										
										
										
											2022-12-24 21:35:29 +02:00
										 |  |  |         print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) | 
					
						
							|  |  |  |         print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return unsafe_torch_load(filename, *args, **kwargs) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-12-25 09:03:56 +03:00
										 |  |  | class Extra: | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     A class for temporarily setting the global handler for when you can't explicitly call load_with_extra | 
					
						
							|  |  |  |     (because it's not your code making the torch.load call). The intended use is like this: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ``` | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | from modules import safe | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def handler(module, name): | 
					
						
							|  |  |  |     if module == 'torch' and name in ['float64', 'float16']: | 
					
						
							|  |  |  |         return getattr(torch, name) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | with safe.Extra(handler): | 
					
						
							|  |  |  |     x = torch.load('model.pt') | 
					
						
							|  |  |  | ``` | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, handler): | 
					
						
							|  |  |  |         self.handler = handler | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __enter__(self): | 
					
						
							|  |  |  |         global global_extra_handler | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         assert global_extra_handler is None, 'already inside an Extra() block' | 
					
						
							|  |  |  |         global_extra_handler = self.handler | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __exit__(self, exc_type, exc_val, exc_tb): | 
					
						
							|  |  |  |         global global_extra_handler | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         global_extra_handler = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-09 17:58:43 +03:00
										 |  |  | unsafe_torch_load = torch.load | 
					
						
							|  |  |  | torch.load = load | 
					
						
							| 
									
										
										
										
											2022-12-25 09:03:56 +03:00
										 |  |  | global_extra_handler = None | 
					
						
							|  |  |  | 
 |