import importlib
import os
import time
from datetime import timedelta
from typing import Any, Generator, Iterable, Optional

import torch
from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.components.ft as ft
import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.metrics import (
    build_metrics_processor,
    ensure_pp_loss_visible,
)
from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.protocols.model_converter import build_model_converters
from torchtitan.tools import utils
from torchtitan.tools.logging import init_logger, logger
from torchtitan.tools.profiling import (
    maybe_enable_memory_snapshot,
    maybe_enable_profiling,
)


class Trainer(torch.distributed.checkpoint.stateful.Stateful):
    job_config: JobConfig
    gc_handler: utils.GarbageCollection

    parallel_dims: ParallelDims
    train_spec: train_spec_module.TrainSpec
    world_mesh: torch.distributed.DeviceMesh

    dataloader: train_spec_module.BaseDataLoader
    metrics_processor: train_spec_module.MetricsProcessor
    checkpointer: CheckpointManager
    train_context: Generator[None, None, None]

    model_parts: list[torch.nn.Module]
    loss_fn: train_spec_module.LossFunction
    optimizers: train_spec_module.OptimizersContainer
    lr_schedulers: train_spec_module.LRSchedulersContainer

    pp_has_first_stage: bool
    pp_has_last_stage: bool

    device: torch.device

    step: int

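    # @record writes uncaught exceptions from the wrapped method to the error
    # file consumed by torchrun / torchelastic.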
    @record
    def __init__(self, job_config: JobConfig):
        self.job_config = job_config

        logger.info(f"Starting job: {job_config.job.description}")

        if job_config.experimental.custom_import:
            importlib.import_module(job_config.experimental.custom_import)

        if job_config.job.print_args:
            logger.info(f"Running with args: {job_config.to_dict()}")

        self.gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)

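        # Pick this rank's accelerator device and set up the fault tolerance
        # (torchft) manager before distributed initialization.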
        device_module, device_type = utils.device_module, utils.device_type
        self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
        device_module.set_device(self.device)
        ft_manager = ft.init_ft_manager(job_config)

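        # Assemble the parallelism layout (data-parallel shard/replicate, context,
        # tensor, and pipeline parallel degrees); use the fault-tolerant variant
        # when torchft is enabled.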
        world_size = int(os.environ["WORLD_SIZE"])
        parallelism_config = job_config.parallelism
        if not ft_manager.enabled:
            self.parallel_dims = parallel_dims = ParallelDims(
                dp_shard=parallelism_config.data_parallel_shard_degree,
                dp_replicate=parallelism_config.data_parallel_replicate_degree,
                cp=parallelism_config.context_parallel_degree,
                tp=parallelism_config.tensor_parallel_degree,
                pp=parallelism_config.pipeline_parallel_degree,
                world_size=world_size,
                enable_loss_parallel=not parallelism_config.disable_loss_parallel,
            )
        else:
            self.parallel_dims = parallel_dims = ft.FTParallelDims(
                dp_shard=parallelism_config.data_parallel_shard_degree,
                dp_replicate=parallelism_config.data_parallel_replicate_degree,
                cp=parallelism_config.context_parallel_degree,
                tp=parallelism_config.tensor_parallel_degree,
                pp=parallelism_config.pipeline_parallel_degree,
                world_size=world_size,
                enable_loss_parallel=not parallelism_config.disable_loss_parallel,
                ft_manager=ft_manager,
            )
        dist_utils.init_distributed(job_config)

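        # Build the device mesh and derive the data-parallel degree/rank used by
        # the dataloader.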
        self.world_mesh = world_mesh = parallel_dims.build_mesh(device_type=device_type)
        if parallel_dims.dp_enabled:
            dp_mesh = world_mesh["dp"]
            dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
        else:
            dp_degree, dp_rank = 1, 0

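        # Seed RNGs (optionally with fully deterministic kernels), then look up
        # the TrainSpec that supplies the model class, tokenizer, dataloader,
        # optimizer, and parallelization functions for this model.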
        dist_utils.set_determinism(
            world_mesh,
            self.device,
            job_config.training.seed,
            job_config.training.deterministic,
        )
        self.train_spec = train_spec_module.get_train_spec(job_config.model.name)

        tokenizer = (
            self.train_spec.build_tokenizer_fn(job_config)
            if self.train_spec.build_tokenizer_fn is not None
            else None
        )

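        # Under torchft the effective data-parallel degree/rank can differ from
        # the mesh values, so defer to the manager before building the dataloader.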
        if ft_manager.enabled:
            dp_degree, dp_rank = ft_manager.get_dp_info(dp_degree, dp_rank)

        self.dataloader = self.train_spec.build_dataloader_fn(
            dp_world_size=dp_degree,
            dp_rank=dp_rank,
            tokenizer=tokenizer,
            job_config=job_config,
        )

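        # Construct the model on the meta device so no memory is allocated until
        # parameters are materialized after sharding.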
        model_cls = self.train_spec.cls
        model_args = self.train_spec.config[job_config.model.flavor]
        model_args.update_from_config(job_config, tokenizer)

        logger.info(
            f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
        )
        with torch.device("meta"):
            model = model_cls.from_model_args(model_args)

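        # Apply model converters (e.g. float8 conversion) before parallelization.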
        model_converters = build_model_converters(job_config, parallel_dims)
        model_converters.convert(model)

        build_metrics_processor_fn = (
            build_metrics_processor
            if self.train_spec.build_metrics_processor_fn is None
            else self.train_spec.build_metrics_processor_fn
        )
        self.metrics_processor = build_metrics_processor_fn(job_config, parallel_dims)
        color = self.metrics_processor.color

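        # Count parameters and record the per-token FLOPs estimate used for MFU.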
        (
            model_param_count,
            self.metrics_processor.num_flops_per_token,
        ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)

        logger.info(
            f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} "
            f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
        )

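        # Choose where weights are materialized: CPU for seed checkpoints and CPU
        # offload, otherwise directly on the accelerator.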
        if job_config.checkpoint.create_seed_checkpoint:
            init_device = "cpu"
            buffer_device = None
        elif job_config.training.enable_cpu_offload:
            init_device = "cpu"
            buffer_device = device_type
        else:
            init_device = device_type
            buffer_device = None

        self.loss_fn = self.train_spec.build_loss_fn(job_config)

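        # Apply parallelisms and initialize weights. With pipeline parallelism the
        # model is first split into stages and each stage is parallelized; without
        # it, the single model is parallelized in place. In both cases parameters
        # are materialized with to_empty() and then init_weights().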
        if parallel_dims.pp_enabled:
            if not self.train_spec.pipelining_fn:
                raise RuntimeError(
                    f"Pipeline Parallel is enabled but {self.train_spec.name} "
                    f"does not support pipelining"
                )

            (
                self.pp_schedule,
                self.model_parts,
                self.pp_has_first_stage,
                self.pp_has_last_stage,
            ) = self.train_spec.pipelining_fn(
                model,
                world_mesh,
                parallel_dims,
                job_config,
                self.device,
                model_args,
                self.train_spec.parallelize_fn,
                self.loss_fn,
            )

            del model

            for m in self.model_parts:
                m.to_empty(device=init_device)
                with torch.no_grad():
                    m.init_weights(buffer_device=buffer_device)
                m.train()

            ensure_pp_loss_visible(parallel_dims, job_config, color)
        else:
            model = self.train_spec.parallelize_fn(
                model, world_mesh, parallel_dims, job_config
            )

            model.to_empty(device=init_device)
            with torch.no_grad():
                model.init_weights(buffer_device=buffer_device)
            model.train()

            self.model_parts = [model]

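        # Log peak hardware FLOPS (used for MFU) and the memory footprint of the
        # initialized model.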
        device_memory_monitor = self.metrics_processor.device_memory_monitor
        gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
        logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")
        device_mem_stats = device_memory_monitor.get_peak_stats()
        logger.info(
            f"{device_type.upper()} memory usage for model: "
            f"{device_mem_stats.max_reserved_gib:.2f}GiB"
            f"({device_mem_stats.max_reserved_pct:.2f}%)"
        )

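        # Build optimizers and LR schedulers; the containers span all model parts
        # so stepping and checkpointing treat them as a single unit.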
        self.optimizers = self.train_spec.build_optimizers_fn(
            self.model_parts, job_config, ft_manager
        )
        self.lr_schedulers = self.train_spec.build_lr_schedulers_fn(
            self.optimizers, job_config
        )

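        # Run model-converter hooks after every optimizer step (e.g. to refresh
        # converter-managed state such as float8 scales).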
        self.optimizers.register_step_post_hook(
            lambda *args, **kwargs: model_converters.post_optimizer_hook(
                self.model_parts
            )
        )
        self.metrics_processor.optimizers = self.optimizers

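        # The checkpointer tracks the dataloader, model parts, optimizers, LR
        # schedulers, and this Trainer's step counter ("train_state").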
        self.step = 0

        self.checkpointer = CheckpointManager(
            dataloader=self.dataloader,
            model_parts=self.model_parts,
            optimizers=self.optimizers,
            lr_schedulers=self.lr_schedulers,
            states={"train_state": self},
            job_config=job_config,
            ft_manager=ft_manager,
        )

        self.train_context = dist_utils.get_train_context(
            parallel_dims.loss_parallel_enabled,
            parallelism_config.enable_compiled_autograd,
        )

        logger.info(
            "Trainer is initialized with "
            f"local batch size {job_config.training.batch_size}, "
            f"global batch size {job_config.training.batch_size * dp_degree}, "
            f"sequence length {job_config.training.seq_len}, "
            f"total steps {job_config.training.steps} "
            f"(warmup {job_config.lr_scheduler.warmup_steps})."
        )

    def next_batch(
        self, data_iterator: Iterable
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        data_load_start = time.perf_counter()
        batch = next(data_iterator)
        input_dict, labels = batch
        self.metrics_processor.ntokens_since_last_log += labels.numel()
        self.metrics_processor.data_loading_times.append(
            time.perf_counter() - data_load_start
        )

        device_type = utils.device_type
        for k, v in input_dict.items():
            input_dict[k] = v.to(device_type)
        labels = labels.to(device_type)
        return input_dict, labels

    def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
        self.optimizers.zero_grad()

        model_parts = self.model_parts
        world_mesh = self.world_mesh
        parallel_dims = self.parallel_dims

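        # With context parallelism, wrap the step in a CP context that shards the
        # listed buffers (inputs, labels, rotary freqs) along their sequence dims.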
        inputs = input_dict["input"]
        optional_context_parallel_ctx = (
            dist_utils.create_context_parallel_ctx(
                cp_mesh=world_mesh["cp"],
                cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                cp_no_restore_buffers={inputs, labels},
                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
            )
            if parallel_dims.cp_enabled
            else None
        )

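        # Pipeline-parallel path: the schedule drives forward/backward internally
        # and only last-stage ranks observe the loss. Otherwise run a plain
        # forward, loss, and backward on the single model part.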
        if parallel_dims.pp_enabled:
            with self.train_context(optional_context_parallel_ctx):
                targets, losses = (
                    (labels, []) if self.pp_has_last_stage else (None, None)
                )
                if self.pp_has_first_stage:
                    self.pp_schedule.step(inputs, target=targets, losses=losses)
                else:
                    self.pp_schedule.step(target=targets, losses=losses)

            loss = (
                torch.mean(torch.stack(losses)).to(self.device)
                if self.pp_has_last_stage
                else torch.tensor([-1.0], device=self.device)
            )
        else:
            with self.train_context(optional_context_parallel_ctx):
                assert len(model_parts) == 1
                pred = model_parts[0](inputs)
                loss = self.loss_fn(pred, labels)

                del pred
                loss.backward()

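        # Clip gradients across all model parts (and PP stages), wait for any
        # async checkpoint staging, then step the optimizers and LR schedulers.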
        dist_utils.clip_grad_norm_(
            [p for m in model_parts for p in m.parameters()],
            self.job_config.training.max_norm,
            foreach=True,
            pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None,
        )
        self.checkpointer.maybe_wait_for_staging()
        self.optimizers.step()
        self.lr_schedulers.step()

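        # Metrics logging: reduce the loss over the DP/CP ranks so the reported
        # average and max are global rather than per-rank values.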
        if not self.metrics_processor.should_log(self.step):
            return

        if (
            parallel_dims.dp_replicate_enabled
            or parallel_dims.dp_shard_enabled
            or parallel_dims.cp_enabled
        ):
            loss = loss.detach()
            global_avg_loss, global_max_loss = (
                dist_utils.dist_mean(loss, world_mesh["dp_cp"]),
                dist_utils.dist_max(loss, world_mesh["dp_cp"]),
            )
        else:
            global_avg_loss = global_max_loss = loss.detach().item()

        self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)

    @record
    def train(self):
        job_config = self.job_config

        self.checkpointer.load(step=job_config.checkpoint.load_step)
        logger.info(f"Training starts at step {self.step + 1}.")

        with maybe_enable_profiling(
            job_config, global_step=self.step
        ) as torch_profiler, maybe_enable_memory_snapshot(
            job_config, global_step=self.step
        ) as memory_profiler:
            data_iterator = iter(self.dataloader)
            while self.step < job_config.training.steps:
                self.step += 1
                self.gc_handler.run(self.step)
                inputs, labels = self.next_batch(data_iterator)
                self.train_step(inputs, labels)
                self.checkpointer.save(
                    self.step, force=(self.step == job_config.training.steps)
                )

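                # advance profilers (when enabled) once per training step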
                if torch_profiler:
                    torch_profiler.step()
                if memory_profiler:
                    memory_profiler.step()

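                # after the first step, shrink the process-group timeout from the
                # generous init value to the configured training timeout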
                if self.step == 1:
                    dist_utils.set_pg_timeouts(
                        timeout=timedelta(
                            seconds=job_config.comm.train_timeout_seconds
                        ),
                        world_mesh=self.world_mesh,
                    )

        if torch.distributed.get_rank() == 0:
            logger.info("Sleeping 2 seconds for other ranks to complete")
            time.sleep(2)

        self.metrics_processor.close()
        logger.info("Training completed")

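    # Stateful protocol: the Trainer itself only persists the step counter; model,
    # optimizer, and dataloader state are saved by their own containers.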
    def state_dict(self) -> dict[str, Any]:
        return {"step": self.step}

    def load_state_dict(self, state_dict: dict[str, Any]):
        self.step = state_dict["step"]

    def close(self) -> None:
        if self.checkpointer:
            self.checkpointer.close()


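# Typically launched via torchrun, for example (flag names assume the standard
# torchtitan JobConfig schema; adjust to your setup):
#   torchrun --nproc_per_node=8 train.py --job.config_file ./train_config.toml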
if __name__ == "__main__":
    init_logger()
    config = JobConfig()
    config.maybe_add_custom_args()
    config.parse_args()
    trainer: Optional[Trainer] = None

    try:
        trainer = Trainer(config)

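        # Seed-checkpoint mode only writes the step-0 checkpoint; it must run on a
        # single rank so parameters are saved unsharded.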
        if config.checkpoint.create_seed_checkpoint:
            assert (
                int(os.environ["WORLD_SIZE"]) == 1
            ), "Must create seed checkpoint using a single device, to disable sharding."
            assert (
                config.checkpoint.enable_checkpoint
            ), "Must enable checkpointing when creating a seed checkpoint."
            trainer.checkpointer.save(curr_step=0, force=True)
            logger.info("Created seed checkpoint")
        else:
            trainer.train()
    finally:
        if trainer:
            trainer.close()

        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
            logger.info("Process group destroyed.")