| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | import os | 
					
						
							| 
									
										
										
										
											2023-04-13 23:12:33 -04:00
										 |  |  | from collections import namedtuple | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | import torch | 
					
						
							|  |  |  | import tqdm | 
					
						
							|  |  |  | import html | 
					
						
							|  |  |  | import datetime | 
					
						
							| 
									
										
										
										
											2022-10-12 23:36:29 +02:00
										 |  |  | import csv | 
					
						
							| 
									
										
										
										
											2023-01-10 18:40:34 -07:00
										 |  |  | import safetensors.torch | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-13 15:04:37 +03:00
										 |  |  | import numpy as np | 
					
						
							| 
									
										
										
										
											2022-10-12 13:15:35 +01:00
										 |  |  | from PIL import Image, PngImagePlugin | 
					
						
							| 
									
										
										
										
											2022-10-20 16:26:16 +02:00
										 |  |  | from torch.utils.tensorboard import SummaryWriter | 
					
						
							| 
									
										
										
										
											2022-10-02 22:41:21 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-31 19:56:37 +03:00
										 |  |  | from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | import modules.textual_inversion.dataset | 
					
						
							| 
									
										
										
										
											2022-10-12 20:49:47 +03:00
										 |  |  | from modules.textual_inversion.learn_schedule import LearnRateScheduler | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-09 23:35:40 +03:00
										 |  |  | from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay | 
					
						
							| 
									
										
										
										
											2023-01-06 08:52:06 +03:00
										 |  |  | from modules.textual_inversion.logging import save_settings_to_file | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-09 23:35:40 +03:00
										 |  |  | TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"]) | 
					
						
							|  |  |  | textual_inversion_templates = {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def list_textual_inversion_templates(): | 
					
						
							|  |  |  |     textual_inversion_templates.clear() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 11:37:18 +03:00
										 |  |  |     for root, _, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir): | 
					
						
							| 
									
										
										
										
											2023-01-09 23:35:40 +03:00
										 |  |  |         for fn in fns: | 
					
						
							|  |  |  |             path = os.path.join(root, fn) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             textual_inversion_templates[fn] = TextualInversionTemplate(fn, path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return textual_inversion_templates | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | class Embedding: | 
					
						
							|  |  |  |     def __init__(self, vec, name, step=None): | 
					
						
							|  |  |  |         self.vec = vec | 
					
						
							|  |  |  |         self.name = name | 
					
						
							|  |  |  |         self.step = step | 
					
						
							| 
									
										
										
										
											2022-12-31 11:27:02 -05:00
										 |  |  |         self.shape = None | 
					
						
							|  |  |  |         self.vectors = 0 | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  |         self.cached_checksum = None | 
					
						
							| 
									
										
										
										
											2022-10-02 20:15:25 +03:00
										 |  |  |         self.sd_checkpoint = None | 
					
						
							|  |  |  |         self.sd_checkpoint_name = None | 
					
						
							| 
									
										
										
										
											2023-01-03 10:26:37 +01:00
										 |  |  |         self.optimizer_state_dict = None | 
					
						
							| 
									
										
										
										
											2023-01-21 08:36:07 +03:00
										 |  |  |         self.filename = None | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def save(self, filename): | 
					
						
							|  |  |  |         embedding_data = { | 
					
						
							|  |  |  |             "string_to_token": {"*": 265}, | 
					
						
							|  |  |  |             "string_to_param": {"*": self.vec}, | 
					
						
							|  |  |  |             "name": self.name, | 
					
						
							|  |  |  |             "step": self.step, | 
					
						
							| 
									
										
										
										
											2022-10-02 20:15:25 +03:00
										 |  |  |             "sd_checkpoint": self.sd_checkpoint, | 
					
						
							|  |  |  |             "sd_checkpoint_name": self.sd_checkpoint_name, | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         torch.save(embedding_data, filename) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-03 10:26:37 +01:00
										 |  |  |         if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None: | 
					
						
							|  |  |  |             optimizer_saved_dict = { | 
					
						
							|  |  |  |                 'hash': self.checksum(), | 
					
						
							|  |  |  |                 'optimizer_state_dict': self.optimizer_state_dict, | 
					
						
							|  |  |  |             } | 
					
						
							| 
									
										
										
										
											2023-05-09 22:17:58 +03:00
										 |  |  |             torch.save(optimizer_saved_dict, f"{filename}.optim") | 
					
						
							| 
									
										
										
										
											2023-01-03 10:26:37 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  |     def checksum(self): | 
					
						
							|  |  |  |         if self.cached_checksum is not None: | 
					
						
							|  |  |  |             return self.cached_checksum | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def const_hash(a): | 
					
						
							|  |  |  |             r = 0 | 
					
						
							|  |  |  |             for v in a: | 
					
						
							|  |  |  |                 r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF | 
					
						
							|  |  |  |             return r | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}' | 
					
						
							|  |  |  |         return self.cached_checksum | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-02 20:15:25 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  | class DirWithTextualInversionEmbeddings: | 
					
						
							|  |  |  |     def __init__(self, path): | 
					
						
							|  |  |  |         self.path = path | 
					
						
							|  |  |  |         self.mtime = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def has_changed(self): | 
					
						
							|  |  |  |         if not os.path.isdir(self.path): | 
					
						
							|  |  |  |             return False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         mt = os.path.getmtime(self.path) | 
					
						
							|  |  |  |         if self.mtime is None or mt > self.mtime: | 
					
						
							|  |  |  |             return True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def update(self): | 
					
						
							|  |  |  |         if not os.path.isdir(self.path): | 
					
						
							|  |  |  |             return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.mtime = os.path.getmtime(self.path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | class EmbeddingDatabase: | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |     def __init__(self): | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  |         self.ids_lookup = {} | 
					
						
							| 
									
										
										
										
											2023-04-13 23:12:33 -04:00
										 |  |  |         self.word_embeddings = {} | 
					
						
							| 
									
										
										
										
											2023-01-02 12:21:22 +11:00
										 |  |  |         self.skipped_embeddings = {} | 
					
						
							| 
									
										
										
										
											2022-12-31 11:27:02 -05:00
										 |  |  |         self.expected_shape = -1 | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |         self.embedding_dirs = {} | 
					
						
							| 
									
										
										
										
											2023-01-29 11:53:05 +03:00
										 |  |  |         self.previously_displayed_embeddings = () | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |     def add_embedding_dir(self, path): | 
					
						
							|  |  |  |         self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path) | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |     def clear_embedding_dirs(self): | 
					
						
							|  |  |  |         self.embedding_dirs.clear() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def register_embedding(self, embedding, model): | 
					
						
							| 
									
										
										
										
											2023-05-29 01:09:59 +05:00
										 |  |  |         return self.register_embedding_by_name(embedding, model, embedding.name) | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-29 01:09:59 +05:00
										 |  |  |     def register_embedding_by_name(self, embedding, model, name): | 
					
						
							|  |  |  |         ids = model.cond_stage_model.tokenize([name])[0] | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  |         first_id = ids[0] | 
					
						
							|  |  |  |         if first_id not in self.ids_lookup: | 
					
						
							|  |  |  |             self.ids_lookup[first_id] = [] | 
					
						
							| 
									
										
										
										
											2023-05-29 01:09:59 +05:00
										 |  |  |         if name in self.word_embeddings: | 
					
						
							|  |  |  |             # remove old one from the lookup list | 
					
						
							|  |  |  |             lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             lookup = self.ids_lookup[first_id] | 
					
						
							|  |  |  |         if embedding is not None: | 
					
						
							|  |  |  |             lookup += [(ids, embedding)] | 
					
						
							|  |  |  |         self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True) | 
					
						
							|  |  |  |         if embedding is None: | 
					
						
							|  |  |  |             # unregister embedding with specified name | 
					
						
							|  |  |  |             if name in self.word_embeddings: | 
					
						
							|  |  |  |                 del self.word_embeddings[name] | 
					
						
							|  |  |  |             if len(self.ids_lookup[first_id])==0: | 
					
						
							|  |  |  |                 del self.ids_lookup[first_id] | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  |         self.word_embeddings[name] = embedding | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  |         return embedding | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-12-31 11:27:02 -05:00
										 |  |  |     def get_expected_shape(self): | 
					
						
							| 
									
										
										
										
											2022-12-31 22:49:09 +03:00
										 |  |  |         vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1) | 
					
						
							|  |  |  |         return vec.shape[1] | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |     def load_from_file(self, path, filename): | 
					
						
							|  |  |  |         name, ext = os.path.splitext(filename) | 
					
						
							|  |  |  |         ext = ext.upper() | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |         if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: | 
					
						
							|  |  |  |             _, second_ext = os.path.splitext(name) | 
					
						
							|  |  |  |             if second_ext.upper() == '.PREVIEW': | 
					
						
							| 
									
										
										
										
											2023-01-03 10:26:37 +01:00
										 |  |  |                 return | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |             embed_image = Image.open(path) | 
					
						
							|  |  |  |             if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: | 
					
						
							|  |  |  |                 data = embedding_from_b64(embed_image.text['sd-ti-embedding']) | 
					
						
							|  |  |  |                 name = data.get('name', name) | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  |             else: | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |                 data = extract_image_data_embed(embed_image) | 
					
						
							| 
									
										
											  
											
												Fix None type error for TI module
When user using model_name.png as a preview image, textural_inversion.py still treat it as an embeding, and didn't handle its error, just let python throw out an None type error like following:
```bash
  File "D:\Work\Dev\AI\stable-diffusion-webui\modules\textual_inversion\textual_inversion.py", line 155, in load_from_file
    name = data.get('name', name)
AttributeError: 'NoneType' object has no attribute 'get'
```
With just a simple `if data:` checking as following, there will be no error, breaks nothing, and now this module can works fine with user's preview images.
Old code:  
```python
                data = extract_image_data_embed(embed_image)
                name = data.get('name', name)
```
New code:  
```python
                data = extract_image_data_embed(embed_image)
                if data:
                    name = data.get('name', name)
                else:
                    # if data is None, means this is not an embeding, just a preview image
                    return
```
Also, since there is no more errors on textual inversion module, from now on, extra network can set "model_name.png" as preview image for embedings.
											
										 
											2023-03-25 02:05:00 +08:00
										 |  |  |                 if data: | 
					
						
							|  |  |  |                     name = data.get('name', name) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     # if data is None, means this is not an embeding, just a preview image | 
					
						
							|  |  |  |                     return | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |         elif ext in ['.BIN', '.PT']: | 
					
						
							|  |  |  |             data = torch.load(path, map_location="cpu") | 
					
						
							| 
									
										
										
										
											2023-01-10 18:40:34 -07:00
										 |  |  |         elif ext in ['.SAFETENSORS']: | 
					
						
							|  |  |  |             data = safetensors.torch.load_file(path, device="cpu") | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |         else: | 
					
						
							|  |  |  |             return | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |         # textual inversion embeddings | 
					
						
							|  |  |  |         if 'string_to_param' in data: | 
					
						
							|  |  |  |             param_dict = data['string_to_param'] | 
					
						
							| 
									
										
										
										
											2023-05-10 21:21:32 +03:00
										 |  |  |             param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11 | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |             assert len(param_dict) == 1, 'embedding file has multiple terms in it' | 
					
						
							|  |  |  |             emb = next(iter(param_dict.items()))[1] | 
					
						
							|  |  |  |         # diffuser concepts | 
					
						
							|  |  |  |         elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: | 
					
						
							|  |  |  |             assert len(data.keys()) == 1, 'embedding file has multiple terms in it' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             emb = next(iter(data.values())) | 
					
						
							|  |  |  |             if len(emb.shape) == 1: | 
					
						
							|  |  |  |                 emb = emb.unsqueeze(0) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         vec = emb.detach().to(devices.device, dtype=torch.float32) | 
					
						
							|  |  |  |         embedding = Embedding(vec, name) | 
					
						
							|  |  |  |         embedding.step = data.get('step', None) | 
					
						
							|  |  |  |         embedding.sd_checkpoint = data.get('sd_checkpoint', None) | 
					
						
							|  |  |  |         embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) | 
					
						
							|  |  |  |         embedding.vectors = vec.shape[0] | 
					
						
							|  |  |  |         embedding.shape = vec.shape[-1] | 
					
						
							| 
									
										
										
										
											2023-01-21 08:36:07 +03:00
										 |  |  |         embedding.filename = path | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if self.expected_shape == -1 or self.expected_shape == embedding.shape: | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  |             self.register_embedding(embedding, shared.sd_model) | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |         else: | 
					
						
							|  |  |  |             self.skipped_embeddings[name] = embedding | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def load_from_dir(self, embdir): | 
					
						
							|  |  |  |         if not os.path.isdir(embdir.path): | 
					
						
							|  |  |  |             return | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 11:37:18 +03:00
										 |  |  |         for root, _, fns in os.walk(embdir.path, followlinks=True): | 
					
						
							| 
									
										
										
										
											2023-01-06 03:38:37 +07:00
										 |  |  |             for fn in fns: | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     fullfn = os.path.join(root, fn) | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-06 03:38:37 +07:00
										 |  |  |                     if os.stat(fullfn).st_size == 0: | 
					
						
							|  |  |  |                         continue | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |                     self.load_from_file(fullfn, fn) | 
					
						
							| 
									
										
										
										
											2023-01-06 03:38:37 +07:00
										 |  |  |                 except Exception: | 
					
						
							| 
									
										
										
										
											2023-05-31 19:56:37 +03:00
										 |  |  |                     errors.report(f"Error loading embedding {fn}", exc_info=True) | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  |                     continue | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |     def load_textual_inversion_embeddings(self, force_reload=False): | 
					
						
							|  |  |  |         if not force_reload: | 
					
						
							|  |  |  |             need_reload = False | 
					
						
							| 
									
										
										
										
											2023-05-10 11:37:18 +03:00
										 |  |  |             for embdir in self.embedding_dirs.values(): | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |                 if embdir.has_changed(): | 
					
						
							|  |  |  |                     need_reload = True | 
					
						
							|  |  |  |                     break | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |             if not need_reload: | 
					
						
							|  |  |  |                 return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.ids_lookup.clear() | 
					
						
							|  |  |  |         self.word_embeddings.clear() | 
					
						
							|  |  |  |         self.skipped_embeddings.clear() | 
					
						
							|  |  |  |         self.expected_shape = self.get_expected_shape() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 11:37:18 +03:00
										 |  |  |         for embdir in self.embedding_dirs.values(): | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |             self.load_from_dir(embdir) | 
					
						
							|  |  |  |             embdir.update() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-08 15:58:00 -04:00
										 |  |  |         # re-sort word_embeddings because load_from_dir may not load in alphabetic order. | 
					
						
							| 
									
										
										
										
											2023-04-13 23:12:33 -04:00
										 |  |  |         # using a temporary copy so we don't reinitialize self.word_embeddings in case other objects have a reference to it. | 
					
						
							|  |  |  |         sorted_word_embeddings = {e.name: e for e in sorted(self.word_embeddings.values(), key=lambda e: e.name.lower())} | 
					
						
							|  |  |  |         self.word_embeddings.clear() | 
					
						
							|  |  |  |         self.word_embeddings.update(sorted_word_embeddings) | 
					
						
							| 
									
										
										
										
											2023-04-08 15:58:00 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-29 11:53:05 +03:00
										 |  |  |         displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys())) | 
					
						
							|  |  |  |         if self.previously_displayed_embeddings != displayed_embeddings: | 
					
						
							|  |  |  |             self.previously_displayed_embeddings = displayed_embeddings | 
					
						
							|  |  |  |             print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") | 
					
						
							| 
									
										
										
										
											2023-06-02 14:58:10 +03:00
										 |  |  |             if self.skipped_embeddings: | 
					
						
							| 
									
										
										
										
											2023-01-29 11:53:05 +03:00
										 |  |  |                 print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def find_embedding_at_position(self, tokens, offset): | 
					
						
							|  |  |  |         token = tokens[offset] | 
					
						
							|  |  |  |         possible_matches = self.ids_lookup.get(token, None) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if possible_matches is None: | 
					
						
							| 
									
										
										
										
											2022-10-02 19:40:51 +03:00
										 |  |  |             return None, None | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         for ids, embedding in possible_matches: | 
					
						
							|  |  |  |             if tokens[offset:offset + len(ids)] == ids: | 
					
						
							| 
									
										
										
										
											2022-10-02 19:40:51 +03:00
										 |  |  |                 return embedding, len(ids) | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-02 19:40:51 +03:00
										 |  |  |         return None, None | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-20 00:10:59 +01:00
										 |  |  | def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  |     cond_model = shared.sd_model.cond_stage_model | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-26 09:44:02 +03:00
										 |  |  |     with devices.autocast(): | 
					
						
							|  |  |  |         cond_model([""])  # will send cond model to GPU if lowvram/medvram is active | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-12 09:22:29 +01:00
										 |  |  |     #cond_model expects at least some text, so we provide '*' as backup. | 
					
						
							|  |  |  |     embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token) | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  |     vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-12 09:22:29 +01:00
										 |  |  |     #Only copy if we provided an init_text, otherwise keep vectors as zeros | 
					
						
							|  |  |  |     if init_text: | 
					
						
							|  |  |  |         for i in range(num_vectors_per_token): | 
					
						
							|  |  |  |             vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-24 23:38:07 -07:00
										 |  |  |     # Remove illegal characters from name. | 
					
						
							|  |  |  |     name = "".join( x for x in name if (x.isalnum() or x in "._- ")) | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  |     fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") | 
					
						
							| 
									
										
										
										
											2022-10-20 00:10:59 +01:00
										 |  |  |     if not overwrite_old: | 
					
						
							|  |  |  |         assert not os.path.exists(fn), f"file {fn} already exists" | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     embedding = Embedding(vec, name) | 
					
						
							|  |  |  |     embedding.step = 0 | 
					
						
							|  |  |  |     embedding.save(fn) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return fn | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-14 22:43:55 +03:00
										 |  |  | def write_loss(log_directory, filename, step, epoch_len, values): | 
					
						
							|  |  |  |     if shared.opts.training_write_csv_every == 0: | 
					
						
							|  |  |  |         return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if step % shared.opts.training_write_csv_every != 0: | 
					
						
							|  |  |  |         return | 
					
						
							|  |  |  |     write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     with open(os.path.join(log_directory, filename), "a+", newline='') as fout: | 
					
						
							|  |  |  |         csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if write_csv_header: | 
					
						
							|  |  |  |             csv_writer.writeheader() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |         epoch = (step - 1) // epoch_len | 
					
						
							| 
									
										
										
										
											2022-12-14 21:01:32 -05:00
										 |  |  |         epoch_step = (step - 1) % epoch_len | 
					
						
							| 
									
										
										
										
											2022-10-14 22:43:55 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         csv_writer.writerow({ | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |             "step": step, | 
					
						
							| 
									
										
										
										
											2022-10-28 20:48:08 +07:00
										 |  |  |             "epoch": epoch, | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |             "epoch_step": epoch_step, | 
					
						
							| 
									
										
										
										
											2022-10-14 22:43:55 +03:00
										 |  |  |             **values, | 
					
						
							|  |  |  |         }) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-20 22:37:16 +02:00
										 |  |  | def tensorboard_setup(log_directory): | 
					
						
							|  |  |  |     os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True) | 
					
						
							|  |  |  |     return SummaryWriter( | 
					
						
							|  |  |  |             log_dir=os.path.join(log_directory, "tensorboard"), | 
					
						
							|  |  |  |             flush_secs=shared.opts.training_tensorboard_flush_every) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num): | 
					
						
							|  |  |  |     tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step) | 
					
						
							|  |  |  |     tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step) | 
					
						
							|  |  |  |     tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step) | 
					
						
							|  |  |  |     tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-20 16:26:16 +02:00
										 |  |  | def tensorboard_add_scaler(tensorboard_writer, tag, value, step): | 
					
						
							| 
									
										
										
										
											2023-05-11 18:28:15 +03:00
										 |  |  |     tensorboard_writer.add_scalar(tag=tag, | 
					
						
							| 
									
										
										
										
											2022-10-20 22:37:16 +02:00
										 |  |  |         scalar_value=value, global_step=step) | 
					
						
							| 
									
										
										
										
											2022-10-20 16:26:16 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | def tensorboard_add_image(tensorboard_writer, tag, pil_image, step): | 
					
						
							| 
									
										
										
										
											2022-10-20 22:37:16 +02:00
										 |  |  |     # Convert a pil image to a torch tensor | 
					
						
							|  |  |  |     img_tensor = torch.as_tensor(np.array(pil_image, copy=True)) | 
					
						
							| 
									
										
										
										
											2023-05-11 18:28:15 +03:00
										 |  |  |     img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], | 
					
						
							| 
									
										
										
										
											2022-10-20 22:37:16 +02:00
										 |  |  |         len(pil_image.getbands())) | 
					
						
							|  |  |  |     img_tensor = img_tensor.permute((2, 0, 1)) | 
					
						
							| 
									
										
										
										
											2023-05-11 18:28:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-20 22:37:16 +02:00
										 |  |  |     tensorboard_writer.add_image(tag, img_tensor, global_step=step) | 
					
						
							| 
									
										
										
										
											2022-10-14 22:43:55 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-09 23:35:40 +03:00
										 |  |  | def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"): | 
					
						
							| 
									
										
										
										
											2022-10-29 18:09:17 +07:00
										 |  |  |     assert model_name, f"{name} not selected" | 
					
						
							|  |  |  |     assert learn_rate, "Learning rate is empty or 0" | 
					
						
							|  |  |  |     assert isinstance(batch_size, int), "Batch size must be integer" | 
					
						
							|  |  |  |     assert batch_size > 0, "Batch size must be positive" | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |     assert isinstance(gradient_step, int), "Gradient accumulation step must be integer" | 
					
						
							|  |  |  |     assert gradient_step > 0, "Gradient accumulation step must be positive" | 
					
						
							| 
									
										
										
										
											2022-10-29 18:09:17 +07:00
										 |  |  |     assert data_root, "Dataset directory is empty" | 
					
						
							|  |  |  |     assert os.path.isdir(data_root), "Dataset directory doesn't exist" | 
					
						
							|  |  |  |     assert os.listdir(data_root), "Dataset directory is empty" | 
					
						
							| 
									
										
										
										
											2023-01-09 23:35:40 +03:00
										 |  |  |     assert template_filename, "Prompt template file not selected" | 
					
						
							|  |  |  |     assert template_file, f"Prompt template file {template_filename} not found" | 
					
						
							|  |  |  |     assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist" | 
					
						
							| 
									
										
										
										
											2022-10-29 18:09:17 +07:00
										 |  |  |     assert steps, "Max steps is empty or 0" | 
					
						
							|  |  |  |     assert isinstance(steps, int), "Max steps must be integer" | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |     assert steps > 0, "Max steps must be positive" | 
					
						
							| 
									
										
										
										
											2022-10-29 18:09:17 +07:00
										 |  |  |     assert isinstance(save_model_every, int), "Save {name} must be integer" | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |     assert save_model_every >= 0, "Save {name} must be positive or 0" | 
					
						
							| 
									
										
										
										
											2022-10-29 18:09:17 +07:00
										 |  |  |     assert isinstance(create_image_every, int), "Create image must be integer" | 
					
						
							| 
									
										
										
										
											2023-01-08 09:37:33 +03:00
										 |  |  |     assert create_image_every >= 0, "Create image must be positive or 0" | 
					
						
							| 
									
										
										
										
											2022-10-29 18:09:17 +07:00
										 |  |  |     if save_model_every or create_image_every: | 
					
						
							|  |  |  |         assert log_directory, "Log directory is empty" | 
					
						
							| 
									
										
										
										
											2022-10-14 22:43:55 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-09 23:35:40 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-12 16:29:00 +01:00
										 |  |  | def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): | 
					
						
							| 
									
										
										
										
											2022-10-29 18:09:17 +07:00
										 |  |  |     save_embedding_every = save_embedding_every or 0 | 
					
						
							|  |  |  |     create_image_every = create_image_every or 0 | 
					
						
							| 
									
										
										
										
											2023-01-09 23:35:40 +03:00
										 |  |  |     template_file = textual_inversion_templates.get(template_filename, None) | 
					
						
							|  |  |  |     validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding") | 
					
						
							|  |  |  |     template_file = template_file.path | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-03 10:34:51 -05:00
										 |  |  |     shared.state.job = "train-embedding" | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  |     shared.state.textinfo = "Initializing textual inversion training..." | 
					
						
							|  |  |  |     shared.state.job_count = steps | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-03 13:10:03 +03:00
										 |  |  |     log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name) | 
					
						
							| 
									
										
										
										
											2022-10-31 07:26:08 -04:00
										 |  |  |     unload = shared.opts.unload_models_when_training | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     if save_embedding_every > 0: | 
					
						
							|  |  |  |         embedding_dir = os.path.join(log_directory, "embeddings") | 
					
						
							|  |  |  |         os.makedirs(embedding_dir, exist_ok=True) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         embedding_dir = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if create_image_every > 0: | 
					
						
							|  |  |  |         images_dir = os.path.join(log_directory, "images") | 
					
						
							|  |  |  |         os.makedirs(images_dir, exist_ok=True) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         images_dir = None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-10 00:07:52 +01:00
										 |  |  |     if create_image_every > 0 and save_image_with_stored_embedding: | 
					
						
							|  |  |  |         images_embeds_dir = os.path.join(log_directory, "image_embeddings") | 
					
						
							|  |  |  |         os.makedirs(images_embeds_dir, exist_ok=True) | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         images_embeds_dir = None | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     hijack = sd_hijack.model_hijack | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     embedding = hijack.embedding_db.word_embeddings[embedding_name] | 
					
						
							| 
									
										
										
										
											2022-10-30 00:49:29 +07:00
										 |  |  |     checkpoint = sd_models.select_checkpoint() | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-20 19:43:21 +02:00
										 |  |  |     initial_step = embedding.step or 0 | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |     if initial_step >= steps: | 
					
						
							| 
									
										
										
										
											2022-12-24 21:35:29 +02:00
										 |  |  |         shared.state.textinfo = "Model has already been trained beyond specified max steps" | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  |         return embedding, filename | 
					
						
							| 
									
										
										
										
											2023-05-11 18:28:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-20 19:43:21 +02:00
										 |  |  |     scheduler = LearnRateScheduler(learn_rate, steps, initial_step) | 
					
						
							| 
									
										
										
										
											2022-11-05 11:48:38 +07:00
										 |  |  |     clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \ | 
					
						
							|  |  |  |         torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ | 
					
						
							|  |  |  |         None | 
					
						
							|  |  |  |     if clip_grad: | 
					
						
							| 
									
										
										
										
											2023-01-05 18:44:19 +01:00
										 |  |  |         clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False) | 
					
						
							| 
									
										
										
										
											2022-10-29 18:09:17 +07:00
										 |  |  |     # dataset loading may take a while, so input validations and early returns should be done before this | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  |     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." | 
					
						
							| 
									
										
										
										
											2022-11-04 04:50:22 -04:00
										 |  |  |     old_parallel_processing_allowed = shared.parallel_processing_allowed | 
					
						
							| 
									
										
										
										
											2023-05-11 18:28:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-20 16:26:16 +02:00
										 |  |  |     if shared.opts.training_enable_tensorboard: | 
					
						
							| 
									
										
										
										
											2022-10-20 22:37:16 +02:00
										 |  |  |         tensorboard_writer = tensorboard_setup(log_directory) | 
					
						
							| 
									
										
										
										
											2022-10-20 16:26:16 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |     pin_memory = shared.opts.pin_memory | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-12 16:29:00 +01:00
										 |  |  |     ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight) | 
					
						
							| 
									
										
										
										
											2022-10-20 16:26:16 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-05 08:14:38 -08:00
										 |  |  |     if shared.opts.save_training_settings_to_txt: | 
					
						
							| 
									
										
										
										
											2023-01-14 09:56:59 +03:00
										 |  |  |         save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |     latent_sampling_method = ds.latent_sampling_method | 
					
						
							| 
									
										
										
										
											2022-10-09 05:38:38 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-21 10:15:46 +09:00
										 |  |  |     dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) | 
					
						
							| 
									
										
										
										
											2022-10-12 13:15:35 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-31 07:26:08 -04:00
										 |  |  |     if unload: | 
					
						
							| 
									
										
										
										
											2022-11-04 04:50:22 -04:00
										 |  |  |         shared.parallel_processing_allowed = False | 
					
						
							| 
									
										
										
										
											2022-10-31 07:26:08 -04:00
										 |  |  |         shared.sd_model.first_stage_model.to(devices.cpu) | 
					
						
							| 
									
										
										
										
											2022-10-10 00:07:52 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  |     embedding.vec.requires_grad = True | 
					
						
							| 
									
										
										
										
											2022-11-27 00:35:44 +09:00
										 |  |  |     optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0) | 
					
						
							| 
									
										
										
										
											2023-01-03 10:26:37 +01:00
										 |  |  |     if shared.opts.save_optimizer_state: | 
					
						
							|  |  |  |         optimizer_state_dict = None | 
					
						
							| 
									
										
										
										
											2023-05-09 22:17:58 +03:00
										 |  |  |         if os.path.exists(f"{filename}.optim"): | 
					
						
							|  |  |  |             optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu') | 
					
						
							| 
									
										
										
										
											2023-01-03 10:26:37 +01:00
										 |  |  |             if embedding.checksum() == optimizer_saved_dict.get('hash', None): | 
					
						
							|  |  |  |                 optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) | 
					
						
							| 
									
										
										
										
											2023-05-11 18:28:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-03 10:26:37 +01:00
										 |  |  |         if optimizer_state_dict is not None: | 
					
						
							|  |  |  |             optimizer.load_state_dict(optimizer_state_dict) | 
					
						
							|  |  |  |             print("Loaded existing optimizer from checkpoint") | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             print("No saved optimizer exists in checkpoint") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |     scaler = torch.cuda.amp.GradScaler() | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |     batch_size = ds.batch_size | 
					
						
							|  |  |  |     gradient_step = ds.gradient_step | 
					
						
							|  |  |  |     # n steps = batch_size * gradient_step * n image processed | 
					
						
							|  |  |  |     steps_per_epoch = len(ds) // batch_size // gradient_step | 
					
						
							|  |  |  |     max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step | 
					
						
							|  |  |  |     loss_step = 0 | 
					
						
							|  |  |  |     _loss_step = 0 #internal | 
					
						
							| 
									
										
										
										
											2022-10-10 00:07:52 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  |     last_saved_file = "<none>" | 
					
						
							|  |  |  |     last_saved_image = "<none>" | 
					
						
							| 
									
										
										
										
											2022-10-24 23:22:58 -07:00
										 |  |  |     forced_filename = "<none>" | 
					
						
							| 
									
										
										
										
											2022-10-14 14:55:05 +01:00
										 |  |  |     embedding_yet_to_be_embedded = False | 
					
						
							| 
									
										
										
										
											2022-10-10 00:07:52 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-04 17:58:07 +03:00
										 |  |  |     is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'} | 
					
						
							| 
									
										
										
										
											2022-10-23 14:05:25 +02:00
										 |  |  |     img_c = None | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |     pbar = tqdm.tqdm(total=steps - initial_step) | 
					
						
							|  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2023-01-18 23:04:24 +03:00
										 |  |  |         sd_hijack_checkpoint.add() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-10 11:37:18 +03:00
										 |  |  |         for _ in range((steps-initial_step) * gradient_step): | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |             if scheduler.finished: | 
					
						
							|  |  |  |                 break | 
					
						
							|  |  |  |             if shared.state.interrupted: | 
					
						
							|  |  |  |                 break | 
					
						
							|  |  |  |             for j, batch in enumerate(dl): | 
					
						
							|  |  |  |                 # works as a drop_last=True for gradient accumulation | 
					
						
							|  |  |  |                 if j == max_steps_per_epoch: | 
					
						
							|  |  |  |                     break | 
					
						
							|  |  |  |                 scheduler.apply(optimizer, embedding.step) | 
					
						
							|  |  |  |                 if scheduler.finished: | 
					
						
							|  |  |  |                     break | 
					
						
							|  |  |  |                 if shared.state.interrupted: | 
					
						
							|  |  |  |                     break | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-04 19:56:35 +03:00
										 |  |  |                 if clip_grad: | 
					
						
							|  |  |  |                     clip_grad_sched.step(embedding.step) | 
					
						
							| 
									
										
										
										
											2023-05-11 18:28:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-28 21:36:35 -05:00
										 |  |  |                 with devices.autocast(): | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |                     x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) | 
					
						
							| 
									
										
										
										
											2023-01-12 16:29:00 +01:00
										 |  |  |                     if use_weight: | 
					
						
							|  |  |  |                         w = batch.weight.to(devices.device, non_blocking=pin_memory) | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |                     c = shared.sd_model.cond_stage_model(batch.cond_text) | 
					
						
							| 
									
										
										
										
											2022-10-12 13:15:35 +01:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-04 17:58:07 +03:00
										 |  |  |                     if is_training_inpainting_model: | 
					
						
							|  |  |  |                         if img_c is None: | 
					
						
							|  |  |  |                             img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height) | 
					
						
							| 
									
										
										
										
											2022-10-20 22:37:16 +02:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-04 17:58:07 +03:00
										 |  |  |                         cond = {"c_concat": [img_c], "c_crossattn": [c]} | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         cond = c | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-12 16:29:00 +01:00
										 |  |  |                     if use_weight: | 
					
						
							|  |  |  |                         loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step | 
					
						
							|  |  |  |                         del w | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         loss = shared.sd_model.forward(x, cond)[0] / gradient_step | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |                     del x | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |                     _loss_step += loss.item() | 
					
						
							|  |  |  |                 scaler.scale(loss).backward() | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |                 # go back until we reach gradient accumulation steps | 
					
						
							|  |  |  |                 if (j + 1) % gradient_step != 0: | 
					
						
							|  |  |  |                     continue | 
					
						
							| 
									
										
										
										
											2023-05-11 18:28:15 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-04 19:56:35 +03:00
										 |  |  |                 if clip_grad: | 
					
						
							|  |  |  |                     clip_grad(embedding.vec, clip_grad_sched.learn_rate) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |                 scaler.step(optimizer) | 
					
						
							|  |  |  |                 scaler.update() | 
					
						
							|  |  |  |                 embedding.step += 1 | 
					
						
							|  |  |  |                 pbar.update() | 
					
						
							|  |  |  |                 optimizer.zero_grad(set_to_none=True) | 
					
						
							|  |  |  |                 loss_step = _loss_step | 
					
						
							|  |  |  |                 _loss_step = 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 steps_done = embedding.step + 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 epoch_num = embedding.step // steps_per_epoch | 
					
						
							|  |  |  |                 epoch_step = embedding.step % steps_per_epoch | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-13 14:32:15 +03:00
										 |  |  |                 description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}" | 
					
						
							| 
									
										
										
										
											2023-01-11 10:28:55 -05:00
										 |  |  |                 pbar.set_description(description) | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |                 if embedding_dir is not None and steps_done % save_embedding_every == 0: | 
					
						
							|  |  |  |                     # Before saving, change name to match current checkpoint. | 
					
						
							|  |  |  |                     embedding_name_every = f'{embedding_name}-{steps_done}' | 
					
						
							|  |  |  |                     last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') | 
					
						
							| 
									
										
										
										
											2023-01-03 10:26:37 +01:00
										 |  |  |                     save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |                     embedding_yet_to_be_embedded = True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, { | 
					
						
							|  |  |  |                     "loss": f"{loss_step:.7f}", | 
					
						
							|  |  |  |                     "learn_rate": scheduler.learn_rate | 
					
						
							|  |  |  |                 }) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if images_dir is not None and steps_done % create_image_every == 0: | 
					
						
							|  |  |  |                     forced_filename = f'{embedding_name}-{steps_done}' | 
					
						
							|  |  |  |                     last_saved_image = os.path.join(images_dir, forced_filename) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     shared.sd_model.first_stage_model.to(devices.device) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     p = processing.StableDiffusionProcessingTxt2Img( | 
					
						
							|  |  |  |                         sd_model=shared.sd_model, | 
					
						
							|  |  |  |                         do_not_save_grid=True, | 
					
						
							|  |  |  |                         do_not_save_samples=True, | 
					
						
							|  |  |  |                         do_not_reload_embeddings=True, | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     if preview_from_txt2img: | 
					
						
							|  |  |  |                         p.prompt = preview_prompt | 
					
						
							|  |  |  |                         p.negative_prompt = preview_negative_prompt | 
					
						
							|  |  |  |                         p.steps = preview_steps | 
					
						
							|  |  |  |                         p.sampler_name = sd_samplers.samplers[preview_sampler_index].name | 
					
						
							|  |  |  |                         p.cfg_scale = preview_cfg_scale | 
					
						
							|  |  |  |                         p.seed = preview_seed | 
					
						
							|  |  |  |                         p.width = preview_width | 
					
						
							|  |  |  |                         p.height = preview_height | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         p.prompt = batch.cond_text[0] | 
					
						
							|  |  |  |                         p.steps = 20 | 
					
						
							| 
									
										
										
										
											2023-01-09 22:52:23 +03:00
										 |  |  |                         p.width = training_width | 
					
						
							|  |  |  |                         p.height = training_height | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  | 
 | 
					
						
							|  |  |  |                     preview_text = p.prompt | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     processed = processing.process_images(p) | 
					
						
							|  |  |  |                     image = processed.images[0] if len(processed.images) > 0 else None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     if unload: | 
					
						
							|  |  |  |                         shared.sd_model.first_stage_model.to(devices.cpu) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     if image is not None: | 
					
						
							| 
									
										
										
										
											2023-01-15 18:50:56 +03:00
										 |  |  |                         shared.state.assign_current_image(image) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |                         last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) | 
					
						
							|  |  |  |                         last_saved_image += f", prompt: {preview_text}" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-13 14:57:38 +03:00
										 |  |  |                         if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: | 
					
						
							|  |  |  |                             tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |                     if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         info = PngImagePlugin.PngInfo() | 
					
						
							|  |  |  |                         data = torch.load(last_saved_file) | 
					
						
							|  |  |  |                         info.add_text("sd-ti-embedding", embedding_to_b64(data)) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-09 22:17:58 +03:00
										 |  |  |                         title = f"<{data.get('name', '???')}>" | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  | 
 | 
					
						
							|  |  |  |                         try: | 
					
						
							|  |  |  |                             vectorSize = list(data['string_to_param'].values())[0].shape[0] | 
					
						
							| 
									
										
										
										
											2023-05-10 07:52:45 +03:00
										 |  |  |                         except Exception: | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |                             vectorSize = '?' | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         checkpoint = sd_models.select_checkpoint() | 
					
						
							|  |  |  |                         footer_left = checkpoint.model_name | 
					
						
							| 
									
										
										
										
											2023-05-09 22:17:58 +03:00
										 |  |  |                         footer_mid = f'[{checkpoint.shorthash}]' | 
					
						
							|  |  |  |                         footer_right = f'{vectorSize}v {steps_done}s' | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  | 
 | 
					
						
							|  |  |  |                         captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) | 
					
						
							|  |  |  |                         captioned_image = insert_image_data_embed(captioned_image, data) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) | 
					
						
							|  |  |  |                         embedding_yet_to_be_embedded = False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) | 
					
						
							|  |  |  |                     last_saved_image += f", prompt: {preview_text}" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 shared.state.job_no = embedding.step | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 shared.state.textinfo = f"""
 | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | <p> | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  | Loss: {loss_step:.7f}<br/> | 
					
						
							| 
									
										
										
										
											2022-11-23 02:49:01 +09:00
										 |  |  | Step: {steps_done}<br/> | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  | Last prompt: {html.escape(batch.cond_text[0])}<br/> | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | Last saved embedding: {html.escape(last_saved_file)}<br/> | 
					
						
							|  |  |  | Last saved image: {html.escape(last_saved_image)}<br/> | 
					
						
							|  |  |  | </p> | 
					
						
							|  |  |  | """
 | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |         filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') | 
					
						
							| 
									
										
										
										
											2023-01-03 10:26:37 +01:00
										 |  |  |         save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True) | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |     except Exception: | 
					
						
							| 
									
										
										
										
											2023-05-31 19:56:37 +03:00
										 |  |  |         errors.report("Error training embedding", exc_info=True) | 
					
						
							| 
									
										
										
										
											2022-11-20 12:35:26 +09:00
										 |  |  |     finally: | 
					
						
							|  |  |  |         pbar.leave = False | 
					
						
							|  |  |  |         pbar.close() | 
					
						
							|  |  |  |         shared.sd_model.first_stage_model.to(devices.device) | 
					
						
							| 
									
										
										
										
											2022-12-03 10:19:51 +03:00
										 |  |  |         shared.parallel_processing_allowed = old_parallel_processing_allowed | 
					
						
							| 
									
										
										
										
											2023-01-18 23:04:24 +03:00
										 |  |  |         sd_hijack_checkpoint.remove() | 
					
						
							| 
									
										
										
										
											2022-10-02 15:03:39 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |     return embedding, filename | 
					
						
							| 
									
										
										
										
											2022-10-30 00:49:29 +07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-18 23:04:24 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-03 10:26:37 +01:00
										 |  |  | def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True): | 
					
						
							| 
									
										
										
										
											2022-10-30 00:49:29 +07:00
										 |  |  |     old_embedding_name = embedding.name | 
					
						
							|  |  |  |     old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None | 
					
						
							|  |  |  |     old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None | 
					
						
							|  |  |  |     old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None | 
					
						
							|  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2023-01-14 09:56:59 +03:00
										 |  |  |         embedding.sd_checkpoint = checkpoint.shorthash | 
					
						
							| 
									
										
										
										
											2022-10-30 00:49:29 +07:00
										 |  |  |         embedding.sd_checkpoint_name = checkpoint.model_name | 
					
						
							|  |  |  |         if remove_cached_checksum: | 
					
						
							|  |  |  |             embedding.cached_checksum = None | 
					
						
							|  |  |  |         embedding.name = embedding_name | 
					
						
							| 
									
										
										
										
											2023-01-03 10:26:37 +01:00
										 |  |  |         embedding.optimizer_state_dict = optimizer.state_dict() | 
					
						
							| 
									
										
										
										
											2022-10-30 00:49:29 +07:00
										 |  |  |         embedding.save(filename) | 
					
						
							|  |  |  |     except: | 
					
						
							|  |  |  |         embedding.sd_checkpoint = old_sd_checkpoint | 
					
						
							|  |  |  |         embedding.sd_checkpoint_name = old_sd_checkpoint_name | 
					
						
							|  |  |  |         embedding.name = old_embedding_name | 
					
						
							|  |  |  |         embedding.cached_checksum = old_cached_checksum | 
					
						
							|  |  |  |         raise |