mirror of
https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
synced 2025-10-31 10:03:40 +00:00

86 lines · 2.4 KiB · Python

import threading
import time
from collections import defaultdict

import torch


class MemUsageMonitor(threading.Thread):
    run_flag = None
    device = None
    disabled = False
    opts = None
    data = None

    def __init__(self, name, device, opts):
        threading.Thread.__init__(self)
        self.name = name
        self.device = device
        self.opts = opts

        self.daemon = True
        self.run_flag = threading.Event()
        self.data = defaultdict(int)

        # probe the CUDA memory APIs once; disable the monitor if they are unavailable
        try:
            torch.cuda.mem_get_info()
            torch.cuda.memory_stats(self.device)
        except Exception as e:  # AMD or whatever
            print(f"Warning: caught exception '{e}', memory monitor disabled")
            self.disabled = True

    def run(self):
        if self.disabled:
            return

        while True:
            self.run_flag.wait()  # idle until monitor() opens a measurement window

            torch.cuda.reset_peak_memory_stats()
            self.data.clear()

            if self.opts.memmon_poll_rate <= 0:
                self.run_flag.clear()
                continue

            self.data["min_free"] = torch.cuda.mem_get_info()[0]

            # poll free VRAM until stop() clears the flag, tracking the low-water mark
            while self.run_flag.is_set():
                free, total = torch.cuda.mem_get_info()  # calling with self.device errors, torch bug?
                self.data["min_free"] = min(self.data["min_free"], free)

                time.sleep(1 / self.opts.memmon_poll_rate)

    def dump_debug(self):
        print(self, 'recorded data:')
        for k, v in self.read().items():
            print(k, -(v // -(1024 ** 2)))  # ceiling-divide bytes to MiB

        print(self, 'raw torch memory stats:')
        tm = torch.cuda.memory_stats(self.device)
        for k, v in tm.items():
            if 'bytes' not in k:
                continue
            print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))

        print(torch.cuda.memory_summary())

    def monitor(self):
        self.run_flag.set()

    def read(self):
        if not self.disabled:
            free, total = torch.cuda.mem_get_info()
            self.data["total"] = total

            torch_stats = torch.cuda.memory_stats(self.device)
            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
            # VRAM in use system-wide at the busiest point of the window
            self.data["system_peak"] = total - self.data["min_free"]

        return self.data

    def stop(self):
        self.run_flag.clear()
        return self.read()
