mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-11-04 03:55:05 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			658 lines
		
	
	
		
			29 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			658 lines
		
	
	
		
			29 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import os
 | 
						|
import sys
 | 
						|
import traceback
 | 
						|
import inspect
 | 
						|
from collections import namedtuple
 | 
						|
 | 
						|
import torch
 | 
						|
import tqdm
 | 
						|
import html
 | 
						|
import datetime
 | 
						|
import csv
 | 
						|
import safetensors.torch
 | 
						|
 | 
						|
import numpy as np
 | 
						|
from PIL import Image, PngImagePlugin
 | 
						|
from torch.utils.tensorboard import SummaryWriter
 | 
						|
 | 
						|
from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint
 | 
						|
import modules.textual_inversion.dataset
 | 
						|
from modules.textual_inversion.learn_schedule import LearnRateScheduler
 | 
						|
 | 
						|
from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
 | 
						|
from modules.textual_inversion.logging import save_settings_to_file
 | 
						|
 | 
						|
 | 
						|
TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
 | 
						|
textual_inversion_templates = {}
 | 
						|
 | 
						|
 | 
						|
def list_textual_inversion_templates():
 | 
						|
    textual_inversion_templates.clear()
 | 
						|
 | 
						|
    for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
 | 
						|
        for fn in fns:
 | 
						|
            path = os.path.join(root, fn)
 | 
						|
 | 
						|
            textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)
 | 
						|
 | 
						|
    return textual_inversion_templates
 | 
						|
 | 
						|
 | 
						|
class Embedding:
 | 
						|
    def __init__(self, vec, name, step=None):
 | 
						|
        self.vec = vec
 | 
						|
        self.name = name
 | 
						|
        self.step = step
 | 
						|
        self.shape = None
 | 
						|
        self.vectors = 0
 | 
						|
        self.cached_checksum = None
 | 
						|
        self.sd_checkpoint = None
 | 
						|
        self.sd_checkpoint_name = None
 | 
						|
        self.optimizer_state_dict = None
 | 
						|
        self.filename = None
 | 
						|
 | 
						|
    def save(self, filename):
 | 
						|
        embedding_data = {
 | 
						|
            "string_to_token": {"*": 265},
 | 
						|
            "string_to_param": {"*": self.vec},
 | 
						|
            "name": self.name,
 | 
						|
            "step": self.step,
 | 
						|
            "sd_checkpoint": self.sd_checkpoint,
 | 
						|
            "sd_checkpoint_name": self.sd_checkpoint_name,
 | 
						|
        }
 | 
						|
 | 
						|
        torch.save(embedding_data, filename)
 | 
						|
 | 
						|
        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
 | 
						|
            optimizer_saved_dict = {
 | 
						|
                'hash': self.checksum(),
 | 
						|
                'optimizer_state_dict': self.optimizer_state_dict,
 | 
						|
            }
 | 
						|
            torch.save(optimizer_saved_dict, filename + '.optim')
 | 
						|
 | 
						|
    def checksum(self):
 | 
						|
        if self.cached_checksum is not None:
 | 
						|
            return self.cached_checksum
 | 
						|
 | 
						|
        def const_hash(a):
 | 
						|
            r = 0
 | 
						|
            for v in a:
 | 
						|
                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
 | 
						|
            return r
 | 
						|
 | 
						|
        self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
 | 
						|
        return self.cached_checksum
 | 
						|
 | 
						|
 | 
						|
class DirWithTextualInversionEmbeddings:
 | 
						|
    def __init__(self, path):
 | 
						|
        self.path = path
 | 
						|
        self.mtime = None
 | 
						|
 | 
						|
    def has_changed(self):
 | 
						|
        if not os.path.isdir(self.path):
 | 
						|
            return False
 | 
						|
 | 
						|
        mt = os.path.getmtime(self.path)
 | 
						|
        if self.mtime is None or mt > self.mtime:
 | 
						|
            return True
 | 
						|
 | 
						|
    def update(self):
 | 
						|
        if not os.path.isdir(self.path):
 | 
						|
            return
 | 
						|
 | 
						|
        self.mtime = os.path.getmtime(self.path)
 | 
						|
 | 
						|
 | 
						|
class EmbeddingDatabase:
 | 
						|
    def __init__(self):
 | 
						|
        self.ids_lookup = {}
 | 
						|
        self.word_embeddings = {}
 | 
						|
        self.skipped_embeddings = {}
 | 
						|
        self.expected_shape = -1
 | 
						|
        self.embedding_dirs = {}
 | 
						|
        self.previously_displayed_embeddings = ()
 | 
						|
 | 
						|
    def add_embedding_dir(self, path):
 | 
						|
        self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
 | 
						|
 | 
						|
    def clear_embedding_dirs(self):
 | 
						|
        self.embedding_dirs.clear()
 | 
						|
 | 
						|
    def register_embedding(self, embedding, model):
 | 
						|
        self.word_embeddings[embedding.name] = embedding
 | 
						|
 | 
						|
        ids = model.cond_stage_model.tokenize([embedding.name])[0]
 | 
						|
 | 
						|
        first_id = ids[0]
 | 
						|
        if first_id not in self.ids_lookup:
 | 
						|
            self.ids_lookup[first_id] = []
 | 
						|
 | 
						|
        self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
 | 
						|
 | 
						|
        return embedding
 | 
						|
 | 
						|
    def get_expected_shape(self):
 | 
						|
        vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
 | 
						|
        return vec.shape[1]
 | 
						|
 | 
						|
    def load_from_file(self, path, filename):
 | 
						|
        name, ext = os.path.splitext(filename)
 | 
						|
        ext = ext.upper()
 | 
						|
 | 
						|
        if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
 | 
						|
            _, second_ext = os.path.splitext(name)
 | 
						|
            if second_ext.upper() == '.PREVIEW':
 | 
						|
                return
 | 
						|
 | 
						|
            embed_image = Image.open(path)
 | 
						|
            if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
 | 
						|
                data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
 | 
						|
                name = data.get('name', name)
 | 
						|
            else:
 | 
						|
                data = extract_image_data_embed(embed_image)
 | 
						|
                name = data.get('name', name)
 | 
						|
        elif ext in ['.BIN', '.PT']:
 | 
						|
            data = torch.load(path, map_location="cpu")
 | 
						|
        elif ext in ['.SAFETENSORS']:
 | 
						|
            data = safetensors.torch.load_file(path, device="cpu")
 | 
						|
        else:
 | 
						|
            return
 | 
						|
 | 
						|
        # textual inversion embeddings
 | 
						|
        if 'string_to_param' in data:
 | 
						|
            param_dict = data['string_to_param']
 | 
						|
            if hasattr(param_dict, '_parameters'):
 | 
						|
                param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
 | 
						|
            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
 | 
						|
            emb = next(iter(param_dict.items()))[1]
 | 
						|
        # diffuser concepts
 | 
						|
        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
 | 
						|
            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
 | 
						|
 | 
						|
            emb = next(iter(data.values()))
 | 
						|
            if len(emb.shape) == 1:
 | 
						|
                emb = emb.unsqueeze(0)
 | 
						|
        else:
 | 
						|
            raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
 | 
						|
 | 
						|
        vec = emb.detach().to(devices.device, dtype=torch.float32)
 | 
						|
        embedding = Embedding(vec, name)
 | 
						|
        embedding.step = data.get('step', None)
 | 
						|
        embedding.sd_checkpoint = data.get('sd_checkpoint', None)
 | 
						|
        embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
 | 
						|
        embedding.vectors = vec.shape[0]
 | 
						|
        embedding.shape = vec.shape[-1]
 | 
						|
        embedding.filename = path
 | 
						|
 | 
						|
        if self.expected_shape == -1 or self.expected_shape == embedding.shape:
 | 
						|
            self.register_embedding(embedding, shared.sd_model)
 | 
						|
        else:
 | 
						|
            self.skipped_embeddings[name] = embedding
 | 
						|
 | 
						|
    def load_from_dir(self, embdir):
 | 
						|
        if not os.path.isdir(embdir.path):
 | 
						|
            return
 | 
						|
 | 
						|
        for root, dirs, fns in os.walk(embdir.path, followlinks=True):
 | 
						|
            for fn in fns:
 | 
						|
                try:
 | 
						|
                    fullfn = os.path.join(root, fn)
 | 
						|
 | 
						|
                    if os.stat(fullfn).st_size == 0:
 | 
						|
                        continue
 | 
						|
 | 
						|
                    self.load_from_file(fullfn, fn)
 | 
						|
                except Exception:
 | 
						|
                    print(f"Error loading embedding {fn}:", file=sys.stderr)
 | 
						|
                    print(traceback.format_exc(), file=sys.stderr)
 | 
						|
                    continue
 | 
						|
 | 
						|
    def load_textual_inversion_embeddings(self, force_reload=False):
 | 
						|
        if not force_reload:
 | 
						|
            need_reload = False
 | 
						|
            for path, embdir in self.embedding_dirs.items():
 | 
						|
                if embdir.has_changed():
 | 
						|
                    need_reload = True
 | 
						|
                    break
 | 
						|
 | 
						|
            if not need_reload:
 | 
						|
                return
 | 
						|
 | 
						|
        self.ids_lookup.clear()
 | 
						|
        self.word_embeddings.clear()
 | 
						|
        self.skipped_embeddings.clear()
 | 
						|
        self.expected_shape = self.get_expected_shape()
 | 
						|
 | 
						|
        for path, embdir in self.embedding_dirs.items():
 | 
						|
            self.load_from_dir(embdir)
 | 
						|
            embdir.update()
 | 
						|
 | 
						|
        displayed_embeddings = (tuple(self.word_embeddings.keys()), tuple(self.skipped_embeddings.keys()))
 | 
						|
        if self.previously_displayed_embeddings != displayed_embeddings:
 | 
						|
            self.previously_displayed_embeddings = displayed_embeddings
 | 
						|
            print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
 | 
						|
            if len(self.skipped_embeddings) > 0:
 | 
						|
                print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
 | 
						|
 | 
						|
    def find_embedding_at_position(self, tokens, offset):
 | 
						|
        token = tokens[offset]
 | 
						|
        possible_matches = self.ids_lookup.get(token, None)
 | 
						|
 | 
						|
        if possible_matches is None:
 | 
						|
            return None, None
 | 
						|
 | 
						|
        for ids, embedding in possible_matches:
 | 
						|
            if tokens[offset:offset + len(ids)] == ids:
 | 
						|
                return embedding, len(ids)
 | 
						|
 | 
						|
        return None, None
 | 
						|
 | 
						|
 | 
						|
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
 | 
						|
    cond_model = shared.sd_model.cond_stage_model
 | 
						|
 | 
						|
    with devices.autocast():
 | 
						|
        cond_model([""])  # will send cond model to GPU if lowvram/medvram is active
 | 
						|
 | 
						|
    #cond_model expects at least some text, so we provide '*' as backup.
 | 
						|
    embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
 | 
						|
    vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
 | 
						|
 | 
						|
    #Only copy if we provided an init_text, otherwise keep vectors as zeros
 | 
						|
    if init_text:
 | 
						|
        for i in range(num_vectors_per_token):
 | 
						|
            vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
 | 
						|
 | 
						|
    # Remove illegal characters from name.
 | 
						|
    name = "".join( x for x in name if (x.isalnum() or x in "._- "))
 | 
						|
    fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
 | 
						|
    if not overwrite_old:
 | 
						|
        assert not os.path.exists(fn), f"file {fn} already exists"
 | 
						|
 | 
						|
    embedding = Embedding(vec, name)
 | 
						|
    embedding.step = 0
 | 
						|
    embedding.save(fn)
 | 
						|
 | 
						|
    return fn
 | 
						|
 | 
						|
 | 
						|
def write_loss(log_directory, filename, step, epoch_len, values):
 | 
						|
    if shared.opts.training_write_csv_every == 0:
 | 
						|
        return
 | 
						|
 | 
						|
    if step % shared.opts.training_write_csv_every != 0:
 | 
						|
        return
 | 
						|
    write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
 | 
						|
 | 
						|
    with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
 | 
						|
        csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
 | 
						|
 | 
						|
        if write_csv_header:
 | 
						|
            csv_writer.writeheader()
 | 
						|
 | 
						|
        epoch = (step - 1) // epoch_len
 | 
						|
        epoch_step = (step - 1) % epoch_len
 | 
						|
 | 
						|
        csv_writer.writerow({
 | 
						|
            "step": step,
 | 
						|
            "epoch": epoch,
 | 
						|
            "epoch_step": epoch_step,
 | 
						|
            **values,
 | 
						|
        })
 | 
						|
 | 
						|
def tensorboard_setup(log_directory):
 | 
						|
    os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
 | 
						|
    return SummaryWriter(
 | 
						|
            log_dir=os.path.join(log_directory, "tensorboard"),
 | 
						|
            flush_secs=shared.opts.training_tensorboard_flush_every)
 | 
						|
 | 
						|
def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
 | 
						|
    tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
 | 
						|
    tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
 | 
						|
    tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
 | 
						|
    tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)
 | 
						|
 | 
						|
def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
 | 
						|
    tensorboard_writer.add_scalar(tag=tag, 
 | 
						|
        scalar_value=value, global_step=step)
 | 
						|
 | 
						|
def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
 | 
						|
    # Convert a pil image to a torch tensor
 | 
						|
    img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
 | 
						|
    img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], 
 | 
						|
        len(pil_image.getbands()))
 | 
						|
    img_tensor = img_tensor.permute((2, 0, 1))
 | 
						|
                
 | 
						|
    tensorboard_writer.add_image(tag, img_tensor, global_step=step)
 | 
						|
 | 
						|
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
 | 
						|
    assert model_name, f"{name} not selected"
 | 
						|
    assert learn_rate, "Learning rate is empty or 0"
 | 
						|
    assert isinstance(batch_size, int), "Batch size must be integer"
 | 
						|
    assert batch_size > 0, "Batch size must be positive"
 | 
						|
    assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
 | 
						|
    assert gradient_step > 0, "Gradient accumulation step must be positive"
 | 
						|
    assert data_root, "Dataset directory is empty"
 | 
						|
    assert os.path.isdir(data_root), "Dataset directory doesn't exist"
 | 
						|
    assert os.listdir(data_root), "Dataset directory is empty"
 | 
						|
    assert template_filename, "Prompt template file not selected"
 | 
						|
    assert template_file, f"Prompt template file {template_filename} not found"
 | 
						|
    assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
 | 
						|
    assert steps, "Max steps is empty or 0"
 | 
						|
    assert isinstance(steps, int), "Max steps must be integer"
 | 
						|
    assert steps > 0, "Max steps must be positive"
 | 
						|
    assert isinstance(save_model_every, int), "Save {name} must be integer"
 | 
						|
    assert save_model_every >= 0, "Save {name} must be positive or 0"
 | 
						|
    assert isinstance(create_image_every, int), "Create image must be integer"
 | 
						|
    assert create_image_every >= 0, "Create image must be positive or 0"
 | 
						|
    if save_model_every or create_image_every:
 | 
						|
        assert log_directory, "Log directory is empty"
 | 
						|
 | 
						|
 | 
						|
def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
 | 
						|
    save_embedding_every = save_embedding_every or 0
 | 
						|
    create_image_every = create_image_every or 0
 | 
						|
    template_file = textual_inversion_templates.get(template_filename, None)
 | 
						|
    validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
 | 
						|
    template_file = template_file.path
 | 
						|
 | 
						|
    shared.state.job = "train-embedding"
 | 
						|
    shared.state.textinfo = "Initializing textual inversion training..."
 | 
						|
    shared.state.job_count = steps
 | 
						|
 | 
						|
    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
 | 
						|
 | 
						|
    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
 | 
						|
    unload = shared.opts.unload_models_when_training
 | 
						|
 | 
						|
    if save_embedding_every > 0:
 | 
						|
        embedding_dir = os.path.join(log_directory, "embeddings")
 | 
						|
        os.makedirs(embedding_dir, exist_ok=True)
 | 
						|
    else:
 | 
						|
        embedding_dir = None
 | 
						|
 | 
						|
    if create_image_every > 0:
 | 
						|
        images_dir = os.path.join(log_directory, "images")
 | 
						|
        os.makedirs(images_dir, exist_ok=True)
 | 
						|
    else:
 | 
						|
        images_dir = None
 | 
						|
 | 
						|
    if create_image_every > 0 and save_image_with_stored_embedding:
 | 
						|
        images_embeds_dir = os.path.join(log_directory, "image_embeddings")
 | 
						|
        os.makedirs(images_embeds_dir, exist_ok=True)
 | 
						|
    else:
 | 
						|
        images_embeds_dir = None
 | 
						|
 | 
						|
    hijack = sd_hijack.model_hijack
 | 
						|
 | 
						|
    embedding = hijack.embedding_db.word_embeddings[embedding_name]
 | 
						|
    checkpoint = sd_models.select_checkpoint()
 | 
						|
 | 
						|
    initial_step = embedding.step or 0
 | 
						|
    if initial_step >= steps:
 | 
						|
        shared.state.textinfo = "Model has already been trained beyond specified max steps"
 | 
						|
        return embedding, filename
 | 
						|
    
 | 
						|
    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
 | 
						|
    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
 | 
						|
        torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
 | 
						|
        None
 | 
						|
    if clip_grad:
 | 
						|
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
 | 
						|
    # dataset loading may take a while, so input validations and early returns should be done before this
 | 
						|
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
 | 
						|
    old_parallel_processing_allowed = shared.parallel_processing_allowed
 | 
						|
    
 | 
						|
    if shared.opts.training_enable_tensorboard:
 | 
						|
        tensorboard_writer = tensorboard_setup(log_directory)
 | 
						|
 | 
						|
    pin_memory = shared.opts.pin_memory
 | 
						|
 | 
						|
    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize, use_weight=use_weight)
 | 
						|
 | 
						|
    if shared.opts.save_training_settings_to_txt:
 | 
						|
        save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
 | 
						|
 | 
						|
    latent_sampling_method = ds.latent_sampling_method
 | 
						|
 | 
						|
    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
 | 
						|
 | 
						|
    if unload:
 | 
						|
        shared.parallel_processing_allowed = False
 | 
						|
        shared.sd_model.first_stage_model.to(devices.cpu)
 | 
						|
 | 
						|
    embedding.vec.requires_grad = True
 | 
						|
    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
 | 
						|
    if shared.opts.save_optimizer_state:
 | 
						|
        optimizer_state_dict = None
 | 
						|
        if os.path.exists(filename + '.optim'):
 | 
						|
            optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
 | 
						|
            if embedding.checksum() == optimizer_saved_dict.get('hash', None):
 | 
						|
                optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
 | 
						|
    
 | 
						|
        if optimizer_state_dict is not None:
 | 
						|
            optimizer.load_state_dict(optimizer_state_dict)
 | 
						|
            print("Loaded existing optimizer from checkpoint")
 | 
						|
        else:
 | 
						|
            print("No saved optimizer exists in checkpoint")
 | 
						|
 | 
						|
    scaler = torch.cuda.amp.GradScaler()
 | 
						|
 | 
						|
    batch_size = ds.batch_size
 | 
						|
    gradient_step = ds.gradient_step
 | 
						|
    # n steps = batch_size * gradient_step * n image processed
 | 
						|
    steps_per_epoch = len(ds) // batch_size // gradient_step
 | 
						|
    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
 | 
						|
    loss_step = 0
 | 
						|
    _loss_step = 0 #internal
 | 
						|
 | 
						|
    last_saved_file = "<none>"
 | 
						|
    last_saved_image = "<none>"
 | 
						|
    forced_filename = "<none>"
 | 
						|
    embedding_yet_to_be_embedded = False
 | 
						|
 | 
						|
    is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
 | 
						|
    img_c = None
 | 
						|
 | 
						|
    pbar = tqdm.tqdm(total=steps - initial_step)
 | 
						|
    try:
 | 
						|
        sd_hijack_checkpoint.add()
 | 
						|
 | 
						|
        for i in range((steps-initial_step) * gradient_step):
 | 
						|
            if scheduler.finished:
 | 
						|
                break
 | 
						|
            if shared.state.interrupted:
 | 
						|
                break
 | 
						|
            for j, batch in enumerate(dl):
 | 
						|
                # works as a drop_last=True for gradient accumulation
 | 
						|
                if j == max_steps_per_epoch:
 | 
						|
                    break
 | 
						|
                scheduler.apply(optimizer, embedding.step)
 | 
						|
                if scheduler.finished:
 | 
						|
                    break
 | 
						|
                if shared.state.interrupted:
 | 
						|
                    break
 | 
						|
 | 
						|
                if clip_grad:
 | 
						|
                    clip_grad_sched.step(embedding.step)
 | 
						|
            
 | 
						|
                with devices.autocast():
 | 
						|
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
 | 
						|
                    if use_weight:
 | 
						|
                        w = batch.weight.to(devices.device, non_blocking=pin_memory)
 | 
						|
                    c = shared.sd_model.cond_stage_model(batch.cond_text)
 | 
						|
 | 
						|
                    if is_training_inpainting_model:
 | 
						|
                        if img_c is None:
 | 
						|
                            img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
 | 
						|
 | 
						|
                        cond = {"c_concat": [img_c], "c_crossattn": [c]}
 | 
						|
                    else:
 | 
						|
                        cond = c
 | 
						|
 | 
						|
                    if use_weight:
 | 
						|
                        loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step
 | 
						|
                        del w
 | 
						|
                    else:
 | 
						|
                        loss = shared.sd_model.forward(x, cond)[0] / gradient_step
 | 
						|
                    del x
 | 
						|
 | 
						|
                    _loss_step += loss.item()
 | 
						|
                scaler.scale(loss).backward()
 | 
						|
 | 
						|
                # go back until we reach gradient accumulation steps
 | 
						|
                if (j + 1) % gradient_step != 0:
 | 
						|
                    continue
 | 
						|
                
 | 
						|
                if clip_grad:
 | 
						|
                    clip_grad(embedding.vec, clip_grad_sched.learn_rate)
 | 
						|
 | 
						|
                scaler.step(optimizer)
 | 
						|
                scaler.update()
 | 
						|
                embedding.step += 1
 | 
						|
                pbar.update()
 | 
						|
                optimizer.zero_grad(set_to_none=True)
 | 
						|
                loss_step = _loss_step
 | 
						|
                _loss_step = 0
 | 
						|
 | 
						|
                steps_done = embedding.step + 1
 | 
						|
 | 
						|
                epoch_num = embedding.step // steps_per_epoch
 | 
						|
                epoch_step = embedding.step % steps_per_epoch
 | 
						|
 | 
						|
                description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
 | 
						|
                pbar.set_description(description)
 | 
						|
                if embedding_dir is not None and steps_done % save_embedding_every == 0:
 | 
						|
                    # Before saving, change name to match current checkpoint.
 | 
						|
                    embedding_name_every = f'{embedding_name}-{steps_done}'
 | 
						|
                    last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
 | 
						|
                    save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
 | 
						|
                    embedding_yet_to_be_embedded = True
 | 
						|
 | 
						|
                write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
 | 
						|
                    "loss": f"{loss_step:.7f}",
 | 
						|
                    "learn_rate": scheduler.learn_rate
 | 
						|
                })
 | 
						|
 | 
						|
                if images_dir is not None and steps_done % create_image_every == 0:
 | 
						|
                    forced_filename = f'{embedding_name}-{steps_done}'
 | 
						|
                    last_saved_image = os.path.join(images_dir, forced_filename)
 | 
						|
 | 
						|
                    shared.sd_model.first_stage_model.to(devices.device)
 | 
						|
 | 
						|
                    p = processing.StableDiffusionProcessingTxt2Img(
 | 
						|
                        sd_model=shared.sd_model,
 | 
						|
                        do_not_save_grid=True,
 | 
						|
                        do_not_save_samples=True,
 | 
						|
                        do_not_reload_embeddings=True,
 | 
						|
                    )
 | 
						|
 | 
						|
                    if preview_from_txt2img:
 | 
						|
                        p.prompt = preview_prompt
 | 
						|
                        p.negative_prompt = preview_negative_prompt
 | 
						|
                        p.steps = preview_steps
 | 
						|
                        p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
 | 
						|
                        p.cfg_scale = preview_cfg_scale
 | 
						|
                        p.seed = preview_seed
 | 
						|
                        p.width = preview_width
 | 
						|
                        p.height = preview_height
 | 
						|
                    else:
 | 
						|
                        p.prompt = batch.cond_text[0]
 | 
						|
                        p.steps = 20
 | 
						|
                        p.width = training_width
 | 
						|
                        p.height = training_height
 | 
						|
 | 
						|
                    preview_text = p.prompt
 | 
						|
 | 
						|
                    processed = processing.process_images(p)
 | 
						|
                    image = processed.images[0] if len(processed.images) > 0 else None
 | 
						|
 | 
						|
                    if unload:
 | 
						|
                        shared.sd_model.first_stage_model.to(devices.cpu)
 | 
						|
 | 
						|
                    if image is not None:
 | 
						|
                        shared.state.assign_current_image(image)
 | 
						|
 | 
						|
                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
 | 
						|
                        last_saved_image += f", prompt: {preview_text}"
 | 
						|
 | 
						|
                        if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
 | 
						|
                            tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
 | 
						|
 | 
						|
                    if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
 | 
						|
 | 
						|
                        last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
 | 
						|
 | 
						|
                        info = PngImagePlugin.PngInfo()
 | 
						|
                        data = torch.load(last_saved_file)
 | 
						|
                        info.add_text("sd-ti-embedding", embedding_to_b64(data))
 | 
						|
 | 
						|
                        title = "<{}>".format(data.get('name', '???'))
 | 
						|
 | 
						|
                        try:
 | 
						|
                            vectorSize = list(data['string_to_param'].values())[0].shape[0]
 | 
						|
                        except Exception as e:
 | 
						|
                            vectorSize = '?'
 | 
						|
 | 
						|
                        checkpoint = sd_models.select_checkpoint()
 | 
						|
                        footer_left = checkpoint.model_name
 | 
						|
                        footer_mid = '[{}]'.format(checkpoint.shorthash)
 | 
						|
                        footer_right = '{}v {}s'.format(vectorSize, steps_done)
 | 
						|
 | 
						|
                        captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
 | 
						|
                        captioned_image = insert_image_data_embed(captioned_image, data)
 | 
						|
 | 
						|
                        captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
 | 
						|
                        embedding_yet_to_be_embedded = False
 | 
						|
 | 
						|
                    last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
 | 
						|
                    last_saved_image += f", prompt: {preview_text}"
 | 
						|
 | 
						|
                shared.state.job_no = embedding.step
 | 
						|
 | 
						|
                shared.state.textinfo = f"""
 | 
						|
<p>
 | 
						|
Loss: {loss_step:.7f}<br/>
 | 
						|
Step: {steps_done}<br/>
 | 
						|
Last prompt: {html.escape(batch.cond_text[0])}<br/>
 | 
						|
Last saved embedding: {html.escape(last_saved_file)}<br/>
 | 
						|
Last saved image: {html.escape(last_saved_image)}<br/>
 | 
						|
</p>
 | 
						|
"""
 | 
						|
        filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
 | 
						|
        save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
 | 
						|
    except Exception:
 | 
						|
        print(traceback.format_exc(), file=sys.stderr)
 | 
						|
        pass
 | 
						|
    finally:
 | 
						|
        pbar.leave = False
 | 
						|
        pbar.close()
 | 
						|
        shared.sd_model.first_stage_model.to(devices.device)
 | 
						|
        shared.parallel_processing_allowed = old_parallel_processing_allowed
 | 
						|
        sd_hijack_checkpoint.remove()
 | 
						|
 | 
						|
    return embedding, filename
 | 
						|
 | 
						|
 | 
						|
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
 | 
						|
    old_embedding_name = embedding.name
 | 
						|
    old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
 | 
						|
    old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
 | 
						|
    old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
 | 
						|
    try:
 | 
						|
        embedding.sd_checkpoint = checkpoint.shorthash
 | 
						|
        embedding.sd_checkpoint_name = checkpoint.model_name
 | 
						|
        if remove_cached_checksum:
 | 
						|
            embedding.cached_checksum = None
 | 
						|
        embedding.name = embedding_name
 | 
						|
        embedding.optimizer_state_dict = optimizer.state_dict()
 | 
						|
        embedding.save(filename)
 | 
						|
    except:
 | 
						|
        embedding.sd_checkpoint = old_sd_checkpoint
 | 
						|
        embedding.sd_checkpoint_name = old_sd_checkpoint_name
 | 
						|
        embedding.name = old_embedding_name
 | 
						|
        embedding.cached_checksum = old_cached_checksum
 | 
						|
        raise
 |