mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-11-04 12:03:36 +00:00 
			
		
		
		
	train: make it possible to make text files with prompts train: rework scheduler so that there's less repeating code in textual inversion and hypernets train: move epochs setting to options
		
			
				
	
	
		
			70 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			70 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import tqdm
 | 
						|
 | 
						|
 | 
						|
class LearnScheduleIterator:
    """Iterator over the ``(learn_rate, end_step)`` segments of a learning-rate schedule.

    Attributes:
        rates: list of ``(learn_rate, end_step)`` tuples kept from the schedule.
        it: index of the next segment to be returned by ``__next__``.
        maxit: number of segments stored in ``rates``.
    """

    def __init__(self, learn_rate, max_steps, cur_step=0):
        """
        specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000

        Args:
            learn_rate: schedule string of comma-separated ``rate:step`` entries,
                or a single rate (string or number) with no step. An entry
                without a step, or with a step of ``-1``, lasts until
                ``max_steps`` and ends the schedule.
            max_steps: upper bound applied to every segment's end step; parsing
                stops at the first entry whose step exceeds it.
            cur_step: entries whose step is not greater than this are skipped.

        Raises:
            ValueError: if a rate or step in a non-blank entry cannot be parsed
                by ``float()`` / ``int()``.
        """

        # str() lets callers pass a bare number (e.g. 0.005) as well as a string.
        pairs = str(learn_rate).split(',')
        self.rates = []
        self.it = 0
        self.maxit = 0
        for pair in pairs:
            # Tolerate stray or trailing commas ("0.001:100,") instead of
            # failing on float('').
            if not pair.strip():
                continue

            tmp = pair.split(':')
            if len(tmp) == 2:
                step = int(tmp[1])
                if step > cur_step:
                    self.rates.append((float(tmp[0]), min(step, max_steps)))
                    self.maxit += 1
                    # Nothing after max_steps can ever be reached.
                    if step > max_steps:
                        return
                elif step == -1:
                    # -1 means "until the end of training".
                    self.rates.append((float(tmp[0]), max_steps))
                    self.maxit += 1
                    return
            else:
                # No explicit step: the rate applies until max_steps.
                self.rates.append((float(tmp[0]), max_steps))
                self.maxit += 1
                return

    def __iter__(self):
        """Return the iterator itself."""
        return self

    def __next__(self):
        """Return the next ``(learn_rate, end_step)`` tuple, or raise StopIteration when exhausted."""
        if self.it < self.maxit:
            self.it += 1
            return self.rates[self.it - 1]
        else:
            raise StopIteration
 | 
						|
 | 
						|
 | 
						|
class LearnRateScheduler:
    """Tracks the active segment of a learning-rate schedule and applies it to an optimizer.

    Attributes:
        schedules: the LearnScheduleIterator supplying ``(learn_rate, end_step)`` segments.
        learn_rate: rate of the currently active segment.
        end_step: last step at which the current segment applies.
        verbose: whether rate changes are printed.
        finished: set to True once ``apply`` is called past the final segment.
    """

    def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
        """Parse the schedule and load its first applicable segment.

        Args:
            learn_rate: schedule specification, passed to LearnScheduleIterator.
            max_steps: total number of training steps, passed to LearnScheduleIterator.
            cur_step: current training step, passed to LearnScheduleIterator.
            verbose: print the active rate here and on every later change.

        Raises:
            ValueError: if the schedule yields no segment at all (for example,
                every listed step is not greater than ``cur_step``).
        """
        self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
        try:
            (self.learn_rate, self.end_step) = next(self.schedules)
        except StopIteration:
            # A raw StopIteration escaping a constructor is confusing and is
            # silently swallowed by enclosing for-loops/generators.
            raise ValueError(
                f'Learning rate schedule {learn_rate!r} has no entry past step {cur_step}'
            ) from None
        self.verbose = verbose

        if self.verbose:
            print(f'Training at rate of {self.learn_rate} until step {self.end_step}')

        self.finished = False

    def apply(self, optimizer, step_number):
        """Advance to the next segment once ``step_number`` passes ``end_step``.

        Does nothing while ``step_number <= self.end_step``. Otherwise moves to
        the next segment (at most one per call) and writes its rate into the
        ``'lr'`` entry of every group in ``optimizer.param_groups``. When no
        segment is left, sets ``self.finished`` and leaves the optimizer as is.

        Args:
            optimizer: object exposing ``param_groups``, an iterable of dicts
                with an ``'lr'`` key (presumably a torch optimizer - confirm
                against the caller).
            step_number: the current training step.
        """
        if step_number <= self.end_step:
            return

        try:
            (self.learn_rate, self.end_step) = next(self.schedules)
        except StopIteration:
            # Only exhaustion means "finished"; any other error should surface.
            self.finished = True
            return

        if self.verbose:
            # tqdm.write is used here rather than print, as in the original code.
            tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')

        for pg in optimizer.param_groups:
            pg['lr'] = self.learn_rate
 | 
						|
 |