# SimpleFSDP
This folder includes an experimental frontend implementation of [SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile](https://arxiv.org/abs/2411.00284). SimpleFSDP is a compiler-based Fully Sharded Data Parallel (FSDP) framework: it keeps the implementation simple for ease of maintenance and composability, allows full computation-communication graph tracing, and brings performance enhancements via compiler backend optimizations.
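To give a flavor of how this works, below is a minimal, single-process sketch of the idea, not torchtitan's actual API: parameters are kept in sharded form and re-materialized on the fly through a `torch.nn.utils.parametrize` parametrization, so `torch.compile` can trace the unshard together with the computation that consumes it. The `FakeAllGather` and `fake_fully_shard` names are made up for this illustration; in SimpleFSDP proper the shards are DTensors living on different ranks, and the concatenation below stands in for a real all-gather collective.

```python
# Illustrative sketch only (single process, made-up helper names); SimpleFSDP itself
# shards parameters as DTensors across ranks and issues real collectives.
import torch
import torch.nn as nn
from torch.nn.utils import parametrize


class FakeAllGather(nn.Module):
    """Keep a 'local' shard as the stored parameter; rebuild the full tensor in forward."""

    def right_inverse(self, full: torch.Tensor) -> torch.Tensor:
        # Called once at registration: split the tensor along dim 0, keep shard 0 as
        # the stored parameter, and stash the other "ranks'" shards as a buffer.
        pieces = full.detach().chunk(2, dim=0)
        self.register_buffer("remote", torch.cat(pieces[1:], dim=0))
        return pieces[0]

    def forward(self, local: torch.Tensor) -> torch.Tensor:
        # Stand-in for an all-gather: every access to the parameter re-materializes
        # the full tensor, which torch.compile traces alongside the op that uses it.
        return torch.cat([local, self.remote], dim=0)


def fake_fully_shard(module: nn.Module) -> nn.Module:
    # Snapshot the module tree first so newly added parametrization modules are skipped.
    for submodule in list(module.modules()):
        for name, param in list(submodule.named_parameters(recurse=False)):
            if param.ndim > 0 and param.size(0) >= 2:  # leave tiny tensors unsharded
                parametrize.register_parametrization(submodule, name, FakeAllGather())
    return module


model = fake_fully_shard(nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8)))
compiled = torch.compile(model)  # unshard + compute traced together
print(compiled(torch.randn(4, 8)).shape)  # torch.Size([4, 8])
```

Because the collectives appear in the traced graph, compiler backend passes can bucket and reorder them to overlap communication with compute, which is the source of the performance gains mentioned above.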
## Enable SimpleFSDP Training
```bash
CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ./run_train.sh --model.name llama3_simple_fsdp --training.compile --training.mixed_precision_param float32
```
Note: Mixed precision training support is ongoing. We set `training.mixed_precision_param` to `float32` for now and will remove this flag once the integration is complete.
## Composability Support
Some features require updates from PyTorch; we are working on providing composability support for the following features:
| Feature | Support |
| --- | --- |
| Meta Initialization | ✅ |
| Activation Checkpointing | ✅ |
| Mixed Precision Training | 🚧 |
| Tensor Parallelism | 🚧 |
| Context Parallelism | ✅ |
| Pipeline Parallelism | ✅ |
| Distributed Checkpointing | 🚧 |
| Float8 Training | ✅ |
## Citation
If you find SimpleFSDP useful, please consider citing the following paper:
```bibtex
@article{zhang2024simplefsdp,
  title={SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile},
  author={Zhang, Ruisi and Liu, Tianyu and Feng, Will and Gu, Andrew and Purandare, Sanket and Liang, Wanchao and Massa, Francisco},
  journal={arXiv preprint arXiv:2411.00284},
  year={2024}
}
```