import json
import os
import time
from datetime import datetime, timedelta

import torch
from datasets import interleave_datasets, load_dataset
from torch.distributed.elastic.multiprocessing.errors import record
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

import fla
from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss
from fla.ops.common.utils import prepare_position_ids
from flame.components.checkpoint import TrainState
from flame.config_manager import JobConfig
from flame.data import build_dataloader, shuffle
from flame.models.parallelize_fla import parallelize_fla
from flame.models.pipeline_fla import pipeline_fla
from flame.tools.utils import get_nparams_and_flops
from flame.utils.checkpoint import cleanup_local_checkpoints
from flame.utils.convert_dcp_to_hf import save_pretrained
from flame.utils.hf_utils import upload_checkpoint_to_hf
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.ft import FTParallelDims, init_ft_manager
from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.metrics import build_device_memory_monitor, build_metrics_processor, ensure_pp_loss_visible
from torchtitan.components.optimizer import build_optimizers
from torchtitan.distributed import ParallelDims
from torchtitan.distributed import utils as dist_utils
from torchtitan.protocols.model_converter import build_model_converters
from torchtitan.protocols.train_spec import TrainSpec, get_train_spec, register_train_spec
from torchtitan.tools import utils
from torchtitan.tools.logging import init_logger, logger
from torchtitan.tools.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling


def build_tokenizer(job_config: JobConfig) -> AutoTokenizer:
    return AutoTokenizer.from_pretrained(job_config.model.tokenizer_path)


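# Register the "fla" train spec: torchtitan's generic training components are wired
# to HF AutoModelForCausalLM / AutoConfig together with flame's parallelization,
# pipelining, dataloading and tokenizer builders.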
register_train_spec(
    TrainSpec(
        name="fla",
        cls=AutoModelForCausalLM,
        config=AutoConfig,
        parallelize_fn=parallelize_fla,
        pipelining_fn=pipeline_fla,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_dataloader,
        build_tokenizer_fn=build_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
)


@record
def main(job_config: JobConfig):
    logger.info(f"Starting job: {job_config.job.description}")

    if job_config.experimental.custom_model_path:
        utils.import_module_from_path(job_config.experimental.custom_model_path)

    color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color

    if job_config.job.print_args:
        logger.info(
            f"{color.green}{json.dumps(job_config.to_dict(), indent=2, sort_keys=True)}{color.reset}"
        )

    gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)

    device_module, device_type = utils.device_module, utils.device_type
    device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
    device_module.set_device(device)
    ft_manager = init_ft_manager(job_config)

    run_specific_repo_id = None
    if getattr(job_config.checkpoint, "hf_upload_enabled", False):
        hf_repo_base = getattr(job_config.checkpoint, "hf_repo_base_name", None)
        if hf_repo_base:
            timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
            run_specific_repo_id = f"{hf_repo_base}-{timestamp}"
            logger.info(f"Target Hugging Face repository for this run: {run_specific_repo_id}")
        else:
            logger.warning("HF Hub upload enabled, but 'checkpoint.hf_repo_base_name' is not set.")
            job_config.checkpoint.hf_upload_enabled = False

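    # Build the parallelism layout (dp_shard, dp_replicate, cp, tp, pp) from the
    # job config; the fault-tolerant variant is used when torchft is enabled.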
    world_size = int(os.environ["WORLD_SIZE"])
    if not ft_manager.enabled:
        parallel_dims = ParallelDims(
            dp_shard=job_config.training.data_parallel_shard_degree,
            dp_replicate=job_config.training.data_parallel_replicate_degree,
            cp=job_config.experimental.context_parallel_degree,
            tp=job_config.training.tensor_parallel_degree,
            pp=job_config.experimental.pipeline_parallel_degree,
            world_size=world_size,
            enable_loss_parallel=not job_config.training.disable_loss_parallel,
        )
    else:
        parallel_dims = FTParallelDims(
            dp_shard=job_config.training.data_parallel_shard_degree,
            dp_replicate=job_config.training.data_parallel_replicate_degree,
            cp=job_config.experimental.context_parallel_degree,
            tp=job_config.training.tensor_parallel_degree,
            pp=job_config.experimental.pipeline_parallel_degree,
            world_size=world_size,
            enable_loss_parallel=not job_config.training.disable_loss_parallel,
            ft_manager=ft_manager,
        )
    dist_utils.init_distributed(job_config)

    device_memory_monitor = build_device_memory_monitor()
    gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name)
    logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}")

    world_mesh = parallel_dims.build_mesh(device_type=device_type)
    if parallel_dims.dp_enabled:
        dp_mesh = world_mesh["dp"]
        dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank()
    else:
        dp_degree, dp_rank = 1, 0

    if parallel_dims.pp_enabled:
        raise NotImplementedError(
            "Pipeline parallelism is not supported in this version"
        )
        """
        ! TODO[flame]: We need to fix pipeline parallelism for flame.
        [x] Match the keys of the models' components with the actual naming
        [ ] Fix post-init and embedding tying for pipeline parallelism; HF transformers
            automatically ties the embeddings if the head is None, and we need to handle
            this case.
        """
        pp_mesh = world_mesh["pp"]

    dist_utils.set_determinism(
        world_mesh, device, job_config.training.seed, job_config.training.deterministic
    )
    train_spec = get_train_spec(job_config.model.name)

    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        job_config.model.tokenizer_path,
        trust_remote_code=True,
        model_max_length=int(1e10),
    )
    logger.info(f"{tokenizer}")
    logger.info(
        f"Loading dataset {job_config.training.dataset}"
        + (
            f":{job_config.training.dataset_name}"
            if job_config.training.dataset_name is not None
            else ""
        )
    )

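    # Every dataloader worker on every data-parallel rank consumes its own shard,
    # so the dataset needs at least dp_degree * num_workers shards overall.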
    min_num_shards = dp_degree * job_config.training.num_workers
    if len(job_config.training.dataset.split(",")) == 1:
        dataset = load_dataset(
            path=job_config.training.dataset,
            name=getattr(job_config.training, "dataset_name", None),
            data_dir=getattr(job_config.training, "data_dir", None),
            data_files=getattr(job_config.training, "data_files", None),
            split=job_config.training.dataset_split or "train",
            trust_remote_code=True,
            streaming=job_config.training.streaming,
            num_proc=(
                job_config.training.num_workers
                if not job_config.training.streaming
                else None
            ),
        )
        logger.info(f"{dataset}")

        logger.info(f"Shuffling the dataset with seed {job_config.training.seed}")
        if not job_config.training.streaming:
            dataset = dataset.shuffle(
                seed=job_config.training.seed
            ).to_iterable_dataset(num_shards=min_num_shards)
        else:
            if dataset.num_shards < min_num_shards:
                logger.warning(
                    f"{color.red}"
                    f"Dataset {job_config.training.dataset} has insufficient shards ({dataset.num_shards}). "
                    f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
                    f"{job_config.training.num_workers} dataloader workers. "
                    f"Disabling streaming mode and resharding the dataset to {min_num_shards} shards."
                    f"{color.reset}"
                )
                dataset = (
                    load_dataset(
                        path=job_config.training.dataset,
                        name=getattr(job_config.training, "dataset_name", None),
                        data_dir=getattr(job_config.training, "data_dir", None),
                        data_files=getattr(job_config.training, "data_files", None),
                        split=job_config.training.dataset_split or "train",
                        trust_remote_code=True,
                        streaming=False,
                        num_proc=job_config.training.num_workers,
                    )
                    .shuffle(seed=job_config.training.seed)
                    .to_iterable_dataset(num_shards=min_num_shards)
                )
            else:
                dataset = shuffle(dataset, seed=job_config.training.seed)
    else:
        datasets = job_config.training.dataset.split(",")
        if job_config.training.dataset_name is not None:
            dataset_names = [
                name or None for name in job_config.training.dataset_name.split(",")
            ]
            assert len(dataset_names) == len(datasets), (
                "The number of dataset names must match the number of datasets"
            )
        else:
            dataset_names = [None] * len(datasets)
        if job_config.training.dataset_split is not None:
            dataset_splits = [
                split or "train"
                for split in job_config.training.dataset_split.split(",")
            ]
            assert len(dataset_splits) == len(datasets), (
                "The number of dataset splits must match the number of datasets"
            )
        else:
            dataset_splits = ["train"] * len(datasets)
        if job_config.training.data_dir is not None:
            data_dirs = [
                data_dir or None for data_dir in job_config.training.data_dir.split(",")
            ]
            assert len(data_dirs) == len(datasets), (
                "The number of data dirs must match the number of datasets"
            )
        else:
            data_dirs = [None] * len(datasets)
        if job_config.training.data_files is not None:
            data_files = job_config.training.data_files.split(",")
            assert len(data_files) == len(datasets), (
                "The number of data files must match the number of datasets"
            )
        else:
            data_files = [None] * len(datasets)
        if job_config.training.data_probs is not None:
            data_probs = [float(p) for p in job_config.training.data_probs.split(",")]
            assert len(data_probs) == len(datasets), (
                "The number of data probabilities must match the number of datasets"
            )
        else:
            raise ValueError(
                "Data sampling probabilities are required if using multiple datasets"
            )

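        # Example (hypothetical values): with
        #   training.dataset      = "HuggingFaceFW/fineweb-edu,allenai/c4"
        #   training.dataset_name = ",en"
        #   training.data_probs   = "0.8,0.2"
        # two subsets are loaded below and later interleaved with sampling
        # probabilities [0.8, 0.2].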
        subsets = []
        for i, prob in enumerate(data_probs):
            subset = load_dataset(
                path=datasets[i],
                name=dataset_names[i],
                data_dir=data_dirs[i],
                data_files=data_files[i],
                split=dataset_splits[i],
                trust_remote_code=True,
                streaming=job_config.training.streaming,
                num_proc=(
                    job_config.training.num_workers
                    if not job_config.training.streaming
                    else None
                ),
            )
            logger.info(
                f"Subset {color.cyan}{datasets[i]}"
                + (f":{dataset_names[i]} " if dataset_names[i] else " ")
                + f"(p = {prob:.3f}){color.reset}:\n"
                + f"{subset}"
            )

            logger.info(f"Shuffling the dataset with seed {job_config.training.seed}")
            if not job_config.training.streaming:
                subset = subset.shuffle(
                    seed=job_config.training.seed
                ).to_iterable_dataset(num_shards=min_num_shards)
            else:
                if subset.num_shards < min_num_shards:
                    logger.warning(
                        f"{color.red}"
                        f"Dataset {datasets[i]} has insufficient shards ({subset.num_shards}). "
                        f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
                        f"{job_config.training.num_workers} dataloader workers. "
                        f"Resharding the dataset to {min_num_shards} shards and disabling streaming mode."
                        f"{color.reset}"
                    )
                    subset = (
                        load_dataset(
                            path=datasets[i],
                            name=dataset_names[i],
                            data_dir=data_dirs[i],
                            data_files=data_files[i],
                            split=dataset_splits[i],
                            trust_remote_code=True,
                            streaming=False,
                            num_proc=job_config.training.num_workers,
                        )
                        .shuffle(seed=job_config.training.seed)
                        .to_iterable_dataset(num_shards=min_num_shards)
                    )
                else:
                    subset = shuffle(
                        subset,
                        seed=job_config.training.seed,
                        buffer_size=max(128, 1024 // len(datasets)),
                    )

            if "text" in subset.column_names:
                subset = subset.select_columns("text")
            elif "content" in subset.column_names:
                subset = subset.select_columns("content")
            else:
                raise ValueError(
                    f"Subset {datasets[i]} has no 'text' or 'content' column"
                )
            subsets.append(subset)

        logger.info(
            f"Interleaving {len(subsets)} datasets with probabilities {data_probs}"
        )
        dataset = interleave_datasets(
            datasets=subsets,
            probabilities=data_probs,
            stopping_strategy="all_exhausted",
            seed=job_config.training.seed,
        )
        logger.info(f"{dataset}")

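    # build_dataloader (flame.data) tokenizes the text stream on the fly and packs it
    # into fixed-length sequences, sharding the stream across dp_rank / dp_degree.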
logger.info("Building dataloader...") |
|
dataloader = build_dataloader( |
|
dataset=dataset, |
|
tokenizer=tokenizer, |
|
rank=dp_rank, |
|
world_size=dp_degree, |
|
batch_size=job_config.training.batch_size, |
|
seq_len=job_config.training.seq_len, |
|
context_len=job_config.training.context_len, |
|
varlen=job_config.training.varlen, |
|
num_workers=job_config.training.num_workers, |
|
pin_memory=job_config.training.pin_memory, |
|
persistent_workers=job_config.training.persistent_workers, |
|
snapshot_every_n_steps=job_config.checkpoint.interval, |
|
) |
|
|
|
logger.info(f"Loading model config from {job_config.model.config}") |
|
model_config = AutoConfig.from_pretrained(job_config.model.config) |
|
|
|
|
|
|
|
|
|
|
|
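    # The fused kernels below are incompatible with sharded execution: fused norm is
    # disabled under tensor parallelism and fused cross entropy under loss parallelism.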
    if parallel_dims.tp_enabled:
        if model_config.fuse_norm:
            logger.warning(
                f"{color.red}"
                f"Fused norm is not compatible with tensor parallelism. "
                f"Disabling it for now."
                f"{color.reset}"
            )
            model_config.fuse_norm = False
    if parallel_dims.loss_parallel_enabled:
        if model_config.fuse_cross_entropy:
            logger.warning(
                f"{color.red}"
                f"Loss parallel enabled. Disabling fused cross entropy for now."
                f"{color.reset}"
            )
            model_config.fuse_cross_entropy = False
    model_config.vocab_size = max(tokenizer.vocab_size, model_config.vocab_size)

    logger.info(
        f"Building model from the config\n{color.green}{model_config}{color.reset}"
    )
with torch.device("meta"): |
|
model = AutoModelForCausalLM.from_config(model_config) |
|
if ( |
|
getattr(model_config, "fuse_cross_entropy", False) |
|
and FusedLinearCrossEntropyLoss is not None |
|
): |
|
model.criterion = FusedLinearCrossEntropyLoss( |
|
num_chunks=8 // parallel_dims.tp |
|
) |
|
|
|
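    # Mark every module as uninitialized so that HF's post_init() re-runs weight
    # init later, once parameters have been materialized off the meta device.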
    model.apply(lambda m: setattr(m, "_is_hf_initialized", False))
    logger.info(f"{color.blue}\n{model}{color.reset}\n")

    model_converters = build_model_converters(job_config, parallel_dims)
    model_converters.convert(model)

    model_param_count, num_flops_per_token = get_nparams_and_flops(
        model, model_config, job_config.training.context_len
    )

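    # Initialize weights on CPU when creating a seed checkpoint or when CPU offload
    # is enabled; otherwise materialize them directly on the training device.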
    if job_config.checkpoint.create_seed_checkpoint:
        init_device = "cpu"
    elif job_config.training.enable_cpu_offload:
        init_device = "cpu"
    else:
        init_device = device_type

    if parallel_dims.pp_enabled:
        (
            pp_schedule,
            model_parts,
            has_first_stage,
            has_last_stage,
        ) = train_spec.pipelining_fn(
            model,
            pp_mesh,
            parallel_dims,
            job_config,
            device,
            model_config,
            train_spec.loss_fn,
        )

        del model

        for m in model_parts:
            train_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
            m.to_empty(device=init_device)
            with torch.no_grad():
                m.post_init()
            m.train()

        ensure_pp_loss_visible(parallel_dims, job_config, color)
    else:
        train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
        model.to_empty(device=init_device)
        with torch.no_grad():
            model.post_init()
        model.train()

        model_parts = [model]

    device_mem_stats = device_memory_monitor.get_peak_stats()
    logger.info(
        f"{device_type.upper()} memory usage for model: "
        f"{device_mem_stats.max_reserved_gib:.2f}GiB"
        f"({device_mem_stats.max_reserved_pct:.2f}%)"
    )

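    # Build optimizers and LR schedulers through the train spec; model converters
    # (e.g. float8) run after every optimizer step via the post-step hook below.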
    optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager)
    lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config)

    optimizers.register_step_post_hook(
        lambda *args, **kwargs: model_converters.post_optimizer_hook(model_parts)
    )

    train_state = TrainState()

    checkpoint = CheckpointManager(
        dataloader=dataloader,
        model_parts=model_parts,
        optimizers=optimizers,
        lr_schedulers=lr_schedulers,
        states={"train_state": train_state},
        job_config=job_config,
        ft_manager=ft_manager,
    )

    if job_config.checkpoint.create_seed_checkpoint:
        assert world_size == 1, (
            "Must create seed checkpoint using a single device, to disable sharding"
        )
        assert job_config.checkpoint.enable_checkpoint, (
            "Must enable checkpointing when creating a seed checkpoint"
        )
        checkpoint.save(curr_step=0, force=True)
        logger.info("Created seed checkpoint")
        return

    checkpoint.load(step=job_config.checkpoint.load_step)
    metric_logger = build_metrics_processor(job_config, parallel_dims)
    metric_logger.num_flops_per_token = num_flops_per_token
    metric_logger.optimizers = optimizers
    metric_logger.lr_schedulers = lr_schedulers

    if train_state.step > 0 and len(metric_logger.data_loading_times) > 0:
        for idx, step in enumerate(train_state.log_steps):
            metric_logger.log(
                step,
                global_avg_loss=train_state.global_avg_losses[idx],
                global_max_loss=train_state.global_max_losses[idx],
            )

    data_iterator = iter(dataloader)

    train_context = dist_utils.get_train_context(
        parallel_dims.loss_parallel_enabled,
        job_config.experimental.enable_compiled_autograd,
    )

    device_memory_monitor.reset_peak_stats()

    global_batch_size = (
        job_config.training.batch_size
        * dp_degree
        * job_config.training.gradient_accumulation_steps
    )
    num_tokens_per_step = global_batch_size * job_config.training.seq_len
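    # Worked example (hypothetical config): batch_size=8, dp_degree=4 and
    # gradient_accumulation_steps=4 give global_batch_size = 8 * 4 * 4 = 128;
    # with seq_len=4096 that is 128 * 4096 = 524,288 tokens per optimization step.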

    logger.info(f"{color.red}***** Running training *****{color.reset}")
    logger.info(f"{color.green} Training starts at step {train_state.step + 1}")
    logger.info(
        f"{color.green} Number of tokens per sequence = {job_config.training.seq_len:,}"
    )
    logger.info(
        f"{color.green} Gradient accumulation steps = {job_config.training.gradient_accumulation_steps}"
    )
    logger.info(
        f"{color.green} Instantaneous batch size (per device) = {job_config.training.batch_size:,}"
    )
    logger.info(
        f"{color.green} Global batch size (w. parallel, distributed & accumulation) = {global_batch_size:,}"
        f" ({num_tokens_per_step:,} tokens)"
    )
    logger.info(
        f"{color.green} Total optimization steps = {job_config.training.steps:,} "
        f"({job_config.training.steps * num_tokens_per_step:,} tokens)"
    )
    logger.info(
        f"{color.green} Warmup steps = {job_config.lr_scheduler.warmup_steps:,}"
        f" ({job_config.lr_scheduler.warmup_steps * num_tokens_per_step:,} tokens)"
    )
    logger.info(
        f"{color.green} Number of parameters = {model_param_count:,} {color.reset}"
    )

    with (
        maybe_enable_profiling(
            job_config, global_step=train_state.step
        ) as torch_profiler,
        maybe_enable_memory_snapshot(
            job_config, global_step=train_state.step
        ) as memory_profiler,
    ):
        while train_state.step < job_config.training.steps:
            train_state.step += 1
            gc_handler.run(train_state.step)

            optimizers.zero_grad()

            losses = []

            for _ in range(job_config.training.gradient_accumulation_steps):
                data_load_start = time.perf_counter()
                batch = next(data_iterator)
                input_ids, labels = batch["input_ids"], batch["labels"]

                metric_logger.ntokens_since_last_log += labels.numel()
                metric_logger.data_loading_times.append(
                    time.perf_counter() - data_load_start
                )

                input_ids = input_ids.to(device_type)

""" |
|
TODO[flame]: We need to carefully handle the position_ids for TP/CP |
|
Depending on the Models'PE, the position_ids might be different. |
|
|
|
e.g. for TP |
|
For RoPE, all ranks have the same position_ids. [FOR HF model] |
|
For sinusoidal, each rank has the coresponding chunked position_ids. [FOR HF model] |
|
|
|
e.g. for CP, [optional_context_parallel_ctx shoudl automatically distbute the position_ids] |
|
Each rank has the coresponding chunked position_ids. [FOR All model] |
|
|
|
""" |
|
labels = labels.to(device_type) |
|
cu_seqlens = ( |
|
batch["cu_seqlens"].to(device_type) |
|
if "cu_seqlens" in batch |
|
else None |
|
) |
|
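                # For packed variable-length batches, positions restart at each document
                # boundary. Rough example (hypothetical): cu_seqlens = [0, 3, 7] yields
                # position_ids = [0, 1, 2, 0, 1, 2, 3]; otherwise positions simply count
                # 0..seq_len-1 for every sample.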
                if cu_seqlens is not None:
                    position_ids = prepare_position_ids(cu_seqlens).to(torch.int32)
                else:
                    position_ids = (
                        torch.arange(0, input_ids.shape[1], device=device_type)
                        .repeat(input_ids.shape[0], 1)
                        .to(torch.int32)
                    )

                optional_context_parallel_ctx = (
                    dist_utils.create_context_parallel_ctx(
                        cp_mesh=world_mesh["cp"],
                        cp_buffers=[input_ids, labels, position_ids],
                        cp_seq_dims=[1, 1, 1],
                        cp_no_restore_buffers={input_ids, labels, position_ids},
                        cp_rotate_method=job_config.experimental.context_parallel_rotate_method,
                    )
                    if parallel_dims.cp_enabled
                    else None
                )

                if parallel_dims.pp_enabled:
                    raise NotImplementedError(
                        "Pipeline parallelism is not supported in this version"
                    )

                    with train_context(optional_context_parallel_ctx):
                        targets, losses = (
                            (labels, []) if has_last_stage else (None, None)
                        )
                        if has_first_stage:
                            pp_schedule.step(input_ids, target=targets, losses=losses)
                        else:
                            pp_schedule.step(target=targets, losses=losses)

                    loss = (
                        torch.mean(torch.stack(losses)).to(device)
                        if has_last_stage
                        else torch.tensor([-1.0], device=device)
                    )
                else:
                    with train_context(optional_context_parallel_ctx):
                        output = model(
                            input_ids=input_ids,
                            labels=labels,
                            position_ids=position_ids,
                            cu_seqlens=cu_seqlens,
                        )
                        loss = (
                            output.loss
                            / job_config.training.gradient_accumulation_steps
                        )
                        loss.backward()

                losses.append(loss)
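            # Each micro-batch loss was already divided by gradient_accumulation_steps,
            # so the sum below is the mean loss over the accumulated micro-batches.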
            loss = sum(losses)

            grad_norm = dist_utils.clip_grad_norm_(
                [p for m in model_parts for p in m.parameters()],
                job_config.training.max_norm,
                foreach=True,
                pp_mesh=pp_mesh if parallel_dims.pp_enabled else None,
            )

            checkpoint.maybe_wait_for_staging()
            if job_config.training.skip_nan_inf and (
                grad_norm.isnan() or grad_norm.isinf()
            ):
                logger.warning(
                    f"Skipping optimizer step - detected invalid gradient norm: {grad_norm:.4f}"
                )
                optimizers.zero_grad()
                train_state.skipped_step += 1
            else:
                optimizers.step()
                lr_schedulers.step()

            if metric_logger.should_log(train_state.step):
                if (
                    parallel_dims.dp_replicate_enabled
                    or parallel_dims.dp_shard_enabled
                    or parallel_dims.cp_enabled
                ):
                    loss = loss.detach()
                    global_avg_loss, global_max_loss = (
                        dist_utils.dist_mean(loss, world_mesh["dp_cp"]),
                        dist_utils.dist_max(loss, world_mesh["dp_cp"]),
                    )
                else:
                    global_avg_loss = global_max_loss = loss.item()

                time_now = time.perf_counter()
                time_delta = time_now - metric_logger.time_last_log
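                # ntokens_since_last_log counts tokens seen by this rank; scaling by
                # world_size / non_data_parallel_size converts it to the number of
                # distinct tokens consumed globally (TP/CP/PP ranks share the same data).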
                train_state.token += (
                    metric_logger.ntokens_since_last_log
                    * parallel_dims.world_size
                    / parallel_dims.non_data_parallel_size
                )
                train_state.elapsed += timedelta(seconds=time_delta)
                train_state.log_steps.append(train_state.step)
                train_state.global_avg_losses.append(global_avg_loss)
                train_state.global_max_losses.append(global_max_loss)

                last_lr = lr_schedulers.schedulers[0].get_last_lr()[0]
                eta = (
                    train_state.elapsed
                    * (job_config.training.steps - train_state.step)
                    / train_state.step
                )
                metric_logger.log(
                    train_state.step,
                    global_avg_loss,
                    global_max_loss,
                    extra_metrics={
                        "optimizer/lr": last_lr,
                        "optimizer/grad_norm": grad_norm.item(),
                        "optimizer/skipped_step": train_state.skipped_step,
                    },
                )

                logger.info(
                    f"{color.blue}lr: {last_lr:.4e} gnorm: {grad_norm:5.2f} "
                    f"{color.magenta}[{str(train_state.elapsed).split('.')[0]:>8}<{str(eta).split('.')[0]:>8}]{color.reset}"
                )

            checkpoint.save(
                train_state.step, force=(train_state.step == job_config.training.steps)
            )

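            # Rank 0 optionally converts the freshly saved DCP checkpoint to HF format
            # and uploads either format to the Hugging Face Hub.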
            if torch.distributed.get_rank() == 0:
                if job_config.checkpoint.enable_checkpoint:
                    hf_target_path = None
                    dcp_save_path = os.path.join(
                        job_config.job.dump_folder,
                        job_config.checkpoint.folder,
                        f"step-{train_state.step}",
                    )

                    if getattr(job_config.checkpoint, "convert_to_hf_on_save", False):
                        try:
                            hf_target_path = dcp_save_path
                            logger.info(
                                f"Converting step {train_state.step} DCP checkpoint to HF format at: {hf_target_path}"
                            )
                            save_pretrained(
                                path=hf_target_path,
                                step=train_state.step,
                                config=job_config.model.config,
                                tokenizer=job_config.model.tokenizer_path,
                            )
                            logger.info(f"Successfully converted step {train_state.step} to HF format.")
                        except Exception as e:
                            logger.error(
                                f"Failed to convert checkpoint step {train_state.step} to HF format: {e}",
                                exc_info=True,
                            )

                    base_checkpoint_dir = os.path.join(job_config.job.dump_folder, job_config.checkpoint.folder)
                    if getattr(job_config.checkpoint, "hf_upload_enabled", False):
                        upload_format = getattr(job_config.checkpoint, "hf_upload_format", "hf")
                        keep_k_hub = getattr(job_config.checkpoint, "hf_keep_latest_k", 5)

                        local_path_to_upload = None
                        if upload_format == "hf":
                            if hf_target_path and os.path.isdir(hf_target_path):
                                local_path_to_upload = hf_target_path
                        elif upload_format == "dcp":
                            if dcp_save_path and os.path.isdir(dcp_save_path):
                                local_path_to_upload = dcp_save_path

                        if local_path_to_upload:
                            try:
                                upload_checkpoint_to_hf(
                                    local_path=local_path_to_upload,
                                    step=train_state.step,
                                    hf_repo_id_for_run=run_specific_repo_id,
                                    upload_format=upload_format,
                                    hf_keep_latest_k=keep_k_hub,
                                )
                            except Exception as e:
                                logger.error(
                                    f"Failed during HF Hub upload for step {train_state.step}: {e}",
                                    exc_info=True,
                                )

            if torch_profiler:
                torch_profiler.step()
            if memory_profiler:
                memory_profiler.step()

            if train_state.step == 1:
                dist_utils.set_pg_timeouts(
                    timeout=timedelta(seconds=job_config.comm.train_timeout_seconds),
                    world_mesh=world_mesh,
                )

    if torch.distributed.get_rank() == 0:
        logger.info("Sleeping 2 seconds for other ranks to complete")
        time.sleep(2)

    metric_logger.close()
    logger.info("Training completed")


if __name__ == "__main__":
    init_logger()
    config = JobConfig()
    config.parse_args()
    main(config)
    torch.distributed.destroy_process_group()