| 
									
										
										
										
											2023-01-28 22:52:27 +03:00
										 |  |  | import html | 
					
						
							|  |  |  | import os | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from modules import shared, ui_extra_networks, sd_models | 
					
						
							| 
									
										
										
										
											2023-08-04 22:05:40 +03:00
										 |  |  | from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor | 
					
						
							| 
									
										
										
										
											2023-01-28 22:52:27 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): | 
					
						
							|  |  |  |     def __init__(self): | 
					
						
							|  |  |  |         super().__init__('Checkpoints') | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-05 19:19:55 +03:00
										 |  |  |         self.allow_prompt = False | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-28 22:52:27 +03:00
										 |  |  |     def refresh(self): | 
					
						
							|  |  |  |         shared.refresh_checkpoints() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-03 22:46:57 +03:00
										 |  |  |     def create_item(self, name, index=None, enable_filter=True): | 
					
						
							| 
									
										
										
										
											2023-07-16 09:49:22 +03:00
										 |  |  |         checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name) | 
					
						
							| 
									
										
										
										
											2023-09-09 17:28:06 +09:00
										 |  |  |         if checkpoint is None: | 
					
						
							|  |  |  |             return | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-16 08:38:23 +03:00
										 |  |  |         path, ext = os.path.splitext(checkpoint.filename) | 
					
						
							| 
									
										
										
										
											2024-01-11 15:06:57 -05:00
										 |  |  |         search_terms = [self.search_terms_from_path(checkpoint.filename)] | 
					
						
							|  |  |  |         if checkpoint.sha256: | 
					
						
							|  |  |  |             search_terms.append(checkpoint.sha256) | 
					
						
							| 
									
										
										
										
											2023-07-16 08:38:23 +03:00
										 |  |  |         return { | 
					
						
							|  |  |  |             "name": checkpoint.name_for_extra, | 
					
						
							|  |  |  |             "filename": checkpoint.filename, | 
					
						
							| 
									
										
										
										
											2023-08-13 02:32:54 -04:00
										 |  |  |             "shorthash": checkpoint.shorthash, | 
					
						
							| 
									
										
										
										
											2023-07-16 08:38:23 +03:00
										 |  |  |             "preview": self.find_preview(path), | 
					
						
							|  |  |  |             "description": self.find_description(path), | 
					
						
							| 
									
										
										
										
											2024-01-11 15:06:57 -05:00
										 |  |  |             "search_terms": search_terms, | 
					
						
							| 
									
										
										
										
											2024-02-02 19:41:07 +03:00
										 |  |  |             "onclick": html.escape(f"return selectCheckpoint({ui_extra_networks.quote_js(name)})"), | 
					
						
							| 
									
										
										
										
											2023-07-16 08:38:23 +03:00
										 |  |  |             "local_preview": f"{path}.{shared.opts.samples_format}", | 
					
						
							| 
									
										
										
										
											2023-08-01 07:08:11 +03:00
										 |  |  |             "metadata": checkpoint.metadata, | 
					
						
							| 
									
										
										
										
											2023-07-16 08:38:23 +03:00
										 |  |  |             "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-28 22:52:27 +03:00
										 |  |  |     def list_items(self): | 
					
						
							| 
									
										
										
										
											2023-09-12 09:29:07 +09:00
										 |  |  |         # instantiate a list to protect against concurrent modification | 
					
						
							| 
									
										
										
										
											2023-08-19 08:39:48 +03:00
										 |  |  |         names = list(sd_models.checkpoints_list) | 
					
						
							|  |  |  |         for index, name in enumerate(names): | 
					
						
							| 
									
										
										
										
											2023-09-09 17:28:06 +09:00
										 |  |  |             item = self.create_item(name, index) | 
					
						
							|  |  |  |             if item is not None: | 
					
						
							|  |  |  |                 yield item | 
					
						
							| 
									
										
										
										
											2023-01-28 22:52:27 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def allowed_directories_for_previews(self): | 
					
						
							| 
									
										
										
										
											2023-01-29 02:32:53 -05:00
										 |  |  |         return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] | 
					
						
							| 
									
										
										
										
											2023-01-28 22:52:27 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-04 22:05:40 +03:00
										 |  |  |     def create_user_metadata_editor(self, ui, tabname): | 
					
						
							|  |  |  |         return CheckpointUserMetadataEditor(ui, tabname, self) |