mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-10-31 10:03:40 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			82 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			82 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import tqdm
 | |
| 
 | |
| 
 | |
class LearnScheduleIterator:
    """Iterate over ``(learn_rate, end_step)`` tuples parsed from a schedule string.

    The schedule is a comma-separated list of ``rate:step`` entries, e.g.
    ``"0.001:100, 0.00001:1000, 1e-5:10000"`` means a rate of 0.001 until
    step 100, 0.00001 until step 1000, and 1e-5 until step 10000.  A bare rate
    (``"0.005"``) or a step of ``-1`` (``"0.005:-1"``) applies until ``max_steps``.
    """

    def __init__(self, learn_rate, max_steps, cur_step=0):
        """Parse *learn_rate* into ``self.rates``, a list of ``(rate, end_step)`` tuples.

        Args:
            learn_rate: schedule string, e.g. "0.001:100, 0.00001:1000, 1e-5:10000"
                to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5
                until 10000.
            max_steps: last training step; every end step is clamped to it.
            cur_step: step training starts/resumes from; entries whose step is
                not greater than it are skipped.

        Raises:
            Exception: if an entry is malformed (non-numeric rate/step, more than
                one ':') or no entry remains after filtering by ``cur_step``.
        """
        pairs = learn_rate.split(',')
        self.rates = []
        self.it = 0     # index of the next entry __next__ will return
        self.maxit = 0  # number of parsed entries (always len(self.rates))
        try:
            for pair in pairs:
                if not pair.strip():
                    continue
                tmp = pair.split(':')
                if len(tmp) > 2:
                    # "0.1:2:3" used to be silently treated as a constant rate of 0.1.
                    raise ValueError(f"too many ':' in schedule entry {pair!r}")
                if len(tmp) == 2:
                    step = int(tmp[1])
                    if step > cur_step:
                        self.rates.append((float(tmp[0]), min(step, max_steps)))
                        self.maxit += 1
                        # Once an entry reaches max_steps, later entries can never
                        # take effect; stop so no redundant entry ending at
                        # max_steps is appended.
                        if step >= max_steps:
                            return
                    elif step == -1:
                        # "-1" means "until the end of training".
                        self.rates.append((float(tmp[0]), max_steps))
                        self.maxit += 1
                        return
                else:
                    # A bare rate applies for the rest of training.
                    self.rates.append((float(tmp[0]), max_steps))
                    self.maxit += 1
                    return
            if not self.rates:
                # Explicit check instead of `assert`, which is stripped under `python -O`.
                raise ValueError('learning rate schedule has no entries after cur_step')
        except ValueError as e:
            raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(rate, end_step)`` tuple, or raise StopIteration."""
        if self.it < self.maxit:
            self.it += 1
            return self.rates[self.it - 1]
        else:
            raise StopIteration
 | |
| 
 | |
| 
 | |
class LearnRateScheduler:
    """Step an optimizer's learning rate through a parsed schedule.

    Holds a ``LearnScheduleIterator`` and, as training advances, loads the
    next ``(learn_rate, end_step)`` entry and writes the new rate into every
    optimizer parameter group.  ``finished`` becomes True once the schedule
    has no entries left.
    """

    def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
        """Build the schedule iterator and load its first entry.

        Args:
            learn_rate: schedule string understood by ``LearnScheduleIterator``.
            max_steps: last training step.
            cur_step: step training starts/resumes from.
            verbose: announce each rate when it takes effect.
        """
        self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
        self.learn_rate, self.end_step = next(self.schedules)
        self.verbose = verbose

        if verbose:
            print(f'Training at rate of {self.learn_rate} until step {self.end_step}')

        self.finished = False

    def step(self, step_number):
        """Move to the next schedule entry once *step_number* reaches ``end_step``.

        Returns:
            True when a new ``(learn_rate, end_step)`` entry was loaded; False when
            the current entry still applies, or when the schedule is exhausted
            (in which case ``finished`` is set to True).
        """
        if step_number >= self.end_step:
            upcoming = next(self.schedules, None)  # entries are tuples, never None
            if upcoming is None:
                self.finished = True
                return False
            self.learn_rate, self.end_step = upcoming
            return True
        return False

    def apply(self, optimizer, step_number):
        """Write the newly scheduled rate into *optimizer* when the schedule advances.

        Nothing happens unless ``step`` loads a new entry; then ``lr`` is set on
        every group in ``optimizer.param_groups``.
        """
        changed = self.step(step_number)
        if changed:
            if self.verbose:
                tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
            for group in optimizer.param_groups:
                group['lr'] = self.learn_rate
 | |
| 
 | 
