|
|
|
|
|
|
|
|
|
|
|
|
|
import contextlib |
|
import os |
|
import pickle |
|
import time |
|
|
|
import torch |
|
|
|
from torchtitan.config_manager import JobConfig |
|
from torchtitan.tools.logging import logger |
|
|
|
|
|
# Number of warmup steps the profiler schedule runs before its active step.
WARMUP = 3

# Value passed as ``max_entries`` when CUDA memory-history recording starts.
MEMORY_SNAPSHOT_MAX_ENTRIES = 100_000
|
|
|
|
|
@contextlib.contextmanager
def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0):
    """Run the enclosed code under ``torch.profiler`` when profiling is enabled.

    When ``config.profiling.enable_profiling`` is falsy, this yields ``None``
    and does nothing else. Otherwise a profiler is created whose schedule is
    ``wait`` idle steps, ``WARMUP`` warmup steps and one active step, with
    ``wait`` chosen so one full cycle spans ``config.profiling.profile_freq``
    steps. Each time a trace is ready it is exported as a Chrome trace to
    ``<dump_folder>/<save_traces_folder>/iteration_<step>/rank<rank>_trace.json``.

    Args:
        config: Job configuration. Reads ``profiling.enable_profiling``,
            ``profiling.save_traces_folder``, ``profiling.profile_freq`` and
            ``job.dump_folder``.
        global_step: Value the profiler's ``step_num`` counter is initialised
            to, so trace directory names follow the caller's step numbering.

    Yields:
        The active ``torch.profiler.profile`` object, or ``None`` when
        profiling is disabled.

    Raises:
        AssertionError: If ``profile_freq`` is smaller than warmup + active.
    """
    if not config.profiling.enable_profiling:
        yield None
        return

    trace_dir = os.path.join(
        config.job.dump_folder, config.profiling.save_traces_folder
    )
    profile_freq = config.profiling.profile_freq

    rank = torch.distributed.get_rank()

    def trace_handler(prof):
        """Export the finished trace into a per-iteration directory."""
        curr_trace_dir = os.path.join(trace_dir, f"iteration_{prof.step_num}")
        # exist_ok=True already tolerates a pre-existing directory, including
        # the case where another rank creates it concurrently.
        os.makedirs(curr_trace_dir, exist_ok=True)

        logger.info(f"Dumping profiler traces at step {prof.step_num}")
        begin = time.monotonic()
        prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")
        logger.info(
            f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds"
        )

    logger.info(f"Profiling active. Traces will be saved at {trace_dir}")

    os.makedirs(trace_dir, exist_ok=True)

    warmup, active = WARMUP, 1
    wait = profile_freq - (active + warmup)
    assert (
        wait >= 0
    ), "profile_freq must be greater than or equal to warmup + active"

    # Always profile the CPU; add an accelerator activity only when one is
    # actually present, so ``None`` is never handed to the profiler.
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    elif torch.xpu.is_available():
        activities.append(torch.profiler.ProfilerActivity.XPU)

    with torch.profiler.profile(
        activities=activities,
        schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
        on_trace_ready=trace_handler,
        record_shapes=True,
    ) as torch_profiler:
        # Align the profiler's step counter with the caller's global step.
        torch_profiler.step_num = global_step
        yield torch_profiler
|
|
|
|
|
@contextlib.contextmanager
def maybe_enable_memory_snapshot(config: JobConfig, *, global_step: int = 0):
    """Periodically dump CUDA memory snapshots when snapshotting is enabled.

    When ``config.profiling.enable_memory_snapshot`` is falsy, this yields
    ``None`` and does nothing else. Otherwise it yields a ``MemoryProfiler``
    whose ``step()`` the caller is expected to invoke once per training step.
    Whenever the step counter is a multiple of ``config.profiling.profile_freq``
    a snapshot is pickled to
    ``<dump_folder>/<save_memory_snapshot_folder>/iteration_<step>/rank<rank>_memory_snapshot.pickle``.

    If the enclosed code raises ``torch.OutOfMemoryError``, one extra snapshot
    is written to ``iteration_<step>_exit`` and the error is then re-raised, so
    the out-of-memory failure is still visible to the caller.

    Args:
        config: Job configuration. Reads ``profiling.enable_memory_snapshot``,
            ``profiling.save_memory_snapshot_folder``,
            ``profiling.profile_freq`` and ``job.dump_folder``.
        global_step: Initial value of the profiler's step counter.

    Yields:
        A ``MemoryProfiler`` instance, or ``None`` when snapshotting is
        disabled.

    Raises:
        torch.OutOfMemoryError: Propagated from the enclosed code after the
            exit snapshot has been dumped.
    """
    if not config.profiling.enable_memory_snapshot:
        yield None
        return

    snapshot_dir = os.path.join(
        config.job.dump_folder, config.profiling.save_memory_snapshot_folder
    )
    # exist_ok=True already tolerates a pre-existing directory, including the
    # case where another rank creates it concurrently.
    os.makedirs(snapshot_dir, exist_ok=True)
    rank = torch.distributed.get_rank()

    class MemoryProfiler:
        """Step counter that dumps a memory snapshot every ``freq`` steps."""

        def __init__(self, step_num: int, freq: int):
            # Start recording allocator history, bounded by max_entries.
            torch.cuda.memory._record_memory_history(
                max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES
            )
            self.step_num = step_num
            self.freq = freq

        def step(self, exit_ctx: bool = False):
            """Advance one step and dump a snapshot if one is due.

            Args:
                exit_ctx: When true, dump unconditionally into an ``_exit``
                    directory labelled with the last completed step.
            """
            self.step_num += 1
            if not exit_ctx and self.step_num % self.freq != 0:
                return
            if not exit_ctx:
                curr_step = self.step_num
                dir_name = f"iteration_{curr_step}"
            else:
                # The increment above counted a step that never completed,
                # so label the dump with the previous step number.
                curr_step = self.step_num - 1
                dir_name = f"iteration_{curr_step}_exit"
            curr_snapshot_dir = os.path.join(snapshot_dir, dir_name)
            os.makedirs(curr_snapshot_dir, exist_ok=True)
            logger.info(f"Dumping memory snapshot at step {curr_step}")
            begin = time.monotonic()
            with open(
                f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", "wb"
            ) as output:
                pickle.dump(torch.cuda.memory._snapshot(), output)
            logger.info(
                f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds"
            )

    logger.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}")
    profiler = MemoryProfiler(global_step, config.profiling.profile_freq)
    try:
        yield profiler
    except torch.OutOfMemoryError:
        profiler.step(exit_ctx=True)
        # Returning normally from a generator-based context manager would
        # suppress the exception; re-raise so the OOM is not silently hidden.
        raise
|
|