File size: 5,160 Bytes
75b6530 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import os
import pickle
import time
import torch
from torchtitan.config_manager import JobConfig
from torchtitan.tools.logging import logger
# Warmup steps that precede the single active step of every profiling cycle.
WARMUP: int = 3
# Cap on the number of allocation/free events kept in memory snapshots.
MEMORY_SNAPSHOT_MAX_ENTRIES: int = 100_000
@contextlib.contextmanager
def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0):
    """Optionally run the enclosed code under ``torch.profiler.profile``.

    When ``config.profiling.enable_profiling`` is set, the profiler follows a
    repeating schedule of ``wait`` idle steps, ``WARMUP`` warmup steps and one
    active step, with ``wait = profile_freq - (WARMUP + 1)``. Whenever a trace
    is ready, it is exported as a Chrome trace to
    ``<dump_folder>/<save_traces_folder>/iteration_<step>/rank<rank>_trace.json``.

    Args:
        config: Job configuration; reads ``profiling.enable_profiling``,
            ``profiling.save_traces_folder``, ``profiling.profile_freq`` and
            ``job.dump_folder``.
        global_step: Step number to start the profiler's step counter from
            (e.g. when resuming training).

    Yields:
        The active ``torch.profiler.profile`` object when profiling is enabled,
        otherwise ``None``.

    Raises:
        AssertionError: If ``profile_freq`` is smaller than ``WARMUP + 1``.
    """
    if not config.profiling.enable_profiling:
        yield None
        return

    trace_dir = os.path.join(
        config.job.dump_folder, config.profiling.save_traces_folder
    )
    profile_freq = config.profiling.profile_freq
    # Fall back to rank 0 so single-process runs without an initialized
    # process group can still be profiled (get_rank() raises in that case).
    rank = (
        torch.distributed.get_rank()
        if torch.distributed.is_available() and torch.distributed.is_initialized()
        else 0
    )

    def trace_handler(prof):
        # Each profiling cycle gets its own directory, keyed by step number.
        curr_trace_dir = os.path.join(trace_dir, f"iteration_{prof.step_num}")
        os.makedirs(curr_trace_dir, exist_ok=True)

        logger.info(f"Dumping profiler traces at step {prof.step_num}")
        begin = time.monotonic()
        prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")
        logger.info(
            f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds"
        )

    logger.info(f"Profiling active. Traces will be saved at {trace_dir}")
    os.makedirs(trace_dir, exist_ok=True)

    warmup, active = WARMUP, 1
    wait = profile_freq - (active + warmup)
    assert (
        wait >= 0
    ), "profile_freq must be greater than or equal to warmup + active"

    # Always profile the CPU; add an accelerator activity only if one exists.
    # Putting None in this list (as happened on CPU-only hosts) breaks the
    # profiler.
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        activities.append(torch.profiler.ProfilerActivity.XPU)

    with torch.profiler.profile(
        activities=activities,
        schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
        on_trace_ready=trace_handler,
        record_shapes=True,
    ) as torch_profiler:
        # Align the profiler's step counter with the (possibly resumed) step.
        torch_profiler.step_num = global_step
        yield torch_profiler
@contextlib.contextmanager
def maybe_enable_memory_snapshot(config: JobConfig, *, global_step: int = 0):
    """Optionally record CUDA allocator history and dump periodic snapshots.

    When ``config.profiling.enable_memory_snapshot`` is set, CUDA memory
    history recording is switched on and a ``MemoryProfiler`` is yielded.
    Every call to its ``step()`` advances a step counter. Every
    ``profile_freq`` steps, ``torch.cuda.memory._snapshot()`` is pickled to
    ``<dump_folder>/<save_memory_snapshot_folder>/iteration_<step>/rank<rank>_memory_snapshot.pickle``.
    If the enclosed code raises ``torch.OutOfMemoryError``, a final snapshot is
    written under ``iteration_<step>_exit`` and the error is then re-raised.

    Args:
        config: Job configuration; reads ``profiling.enable_memory_snapshot``,
            ``profiling.save_memory_snapshot_folder``, ``profiling.profile_freq``
            and ``job.dump_folder``.
        global_step: Step to resume counting from when training is resumed.

    Yields:
        The ``MemoryProfiler`` when snapshots are enabled, otherwise ``None``.

    Raises:
        torch.OutOfMemoryError: Re-raised after the exit snapshot is dumped.
    """
    if not config.profiling.enable_memory_snapshot:
        yield None
        return

    snapshot_dir = os.path.join(
        config.job.dump_folder, config.profiling.save_memory_snapshot_folder
    )
    os.makedirs(snapshot_dir, exist_ok=True)
    # Fall back to rank 0 when no process group is initialized
    # (get_rank() raises in that case).
    rank = (
        torch.distributed.get_rank()
        if torch.distributed.is_available() and torch.distributed.is_initialized()
        else 0
    )

    class MemoryProfiler:
        """Counts steps and dumps a CUDA allocator snapshot every ``freq`` steps."""

        def __init__(self, step_num: int, freq: int):
            torch.cuda.memory._record_memory_history(
                max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES
            )
            # when resuming training, we start counting from the last step
            self.step_num = step_num
            self.freq = freq

        def step(self, exit_ctx: bool = False):
            """Advance the counter; dump on period boundaries or when exiting."""
            self.step_num += 1
            if not exit_ctx and self.step_num % self.freq != 0:
                return
            if not exit_ctx:
                curr_step = self.step_num
                dir_name = f"iteration_{curr_step}"
            else:
                # The failing step never completed, so label the dump with the
                # previous step: an OOM at iter 1 is dumped as iteration_0_exit.
                curr_step = self.step_num - 1
                dir_name = f"iteration_{curr_step}_exit"
            curr_snapshot_dir = os.path.join(snapshot_dir, dir_name)
            os.makedirs(curr_snapshot_dir, exist_ok=True)
            logger.info(f"Dumping memory snapshot at step {curr_step}")
            begin = time.monotonic()
            with open(
                f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", "wb"
            ) as output:
                pickle.dump(torch.cuda.memory._snapshot(), output)
            logger.info(
                f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds"
            )

    logger.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}")
    profiler = MemoryProfiler(global_step, config.profiling.profile_freq)
    try:
        yield profiler
    except torch.OutOfMemoryError:
        # Capture allocator state at the moment of OOM, then propagate the
        # error: a @contextmanager generator that swallows the exception
        # would silently suppress it for the caller.
        profiler.step(exit_ctx=True)
        raise
|