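"""Background monitor for CUDA memory usage.

MemUsageMonitor runs as a daemon thread that, while enabled, polls
torch.cuda.mem_get_info() and records the lowest free-VRAM value seen,
alongside torch's allocator statistics. It disables itself on backends
where these CUDA memory APIs are unavailable.
"""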
import threading
import time
from collections import defaultdict

import torch


class MemUsageMonitor(threading.Thread):
    run_flag = None
    device = None
    disabled = False
    opts = None
    data = None

    def __init__(self, name, device, opts):
        threading.Thread.__init__(self)
        self.name = name
        self.device = device
        self.opts = opts
        self.daemon = True
        self.run_flag = threading.Event()
        self.data = defaultdict(int)

        # Probe the CUDA memory APIs once up front; if either call fails
        # (e.g. AMD or whatever), disable the monitor instead of crashing
        # later in the polling loop.
        try:
            torch.cuda.mem_get_info()
            torch.cuda.memory_stats(self.device)
        except Exception as e:
            print(f"Warning: caught exception '{e}', memory monitor disabled")
            self.disabled = True

    def run(self):
        if self.disabled:
            return

        while True:
            # Block until monitor() sets the flag, then start a fresh
            # measurement pass.
            self.run_flag.wait()

            torch.cuda.reset_peak_memory_stats()
            self.data.clear()

            if self.opts.memmon_poll_rate <= 0:
                # Polling disabled in options; go back to waiting.
                self.run_flag.clear()
                continue

            self.data["min_free"] = torch.cuda.mem_get_info()[0]

            while self.run_flag.is_set():
                free, total = torch.cuda.mem_get_info()  # calling with self.device errors, torch bug?
                self.data["min_free"] = min(self.data["min_free"], free)

                time.sleep(1 / self.opts.memmon_poll_rate)

    def dump_debug(self):
        print(self, 'recorded data:')
        for k, v in self.read().items():
            # -(v // -(1024 ** 2)) is ceiling division: bytes -> MiB, rounded up.
            print(k, -(v // -(1024 ** 2)))

        print(self, 'raw torch memory stats:')
        tm = torch.cuda.memory_stats(self.device)
        for k, v in tm.items():
            if 'bytes' not in k:
                continue
            print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))

        print(torch.cuda.memory_summary())

    def monitor(self):
        self.run_flag.set()

    def read(self):
        if not self.disabled:
            free, total = torch.cuda.mem_get_info()
            self.data["free"] = free
            self.data["total"] = total

            torch_stats = torch.cuda.memory_stats(self.device)
            # Use the *_bytes keys throughout: "active.all.current" counts
            # allocation blocks, not bytes, so it would disagree with the
            # byte-valued peak below.
            self.data["active"] = torch_stats["active_bytes.all.current"]
            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
            self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
            # Peak system-wide usage over the polling window, derived from
            # the lowest free value the polling loop observed.
            self.data["system_peak"] = total - self.data["min_free"]

        return self.data

    def stop(self):
        self.run_flag.clear()
        return self.read()
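

# A minimal usage sketch, not part of the original module. It assumes a CUDA
# device is available and fakes the `opts` object (in the webui this is the
# shared options instance); only `memmon_poll_rate` (polls per second) is
# read by the monitor.
if __name__ == "__main__":
    from types import SimpleNamespace

    opts = SimpleNamespace(memmon_poll_rate=8)  # hypothetical stand-in for the real opts
    mon = MemUsageMonitor("MemMon", torch.device("cuda"), opts)
    mon.start()

    mon.monitor()  # begin polling
    x = torch.empty(1024, 1024, 256, device="cuda")  # allocate ~1 GiB so there is something to observe
    time.sleep(1)  # let the poll loop take a few samples
    stats = mon.stop()  # stop polling and read the counters

    for k, v in stats.items():
        print(f"{k}: {v / (1024 ** 2):.0f} MiB")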