SimpleFSDP

This folder includes an experimental frontend implementation of SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile. SimpleFSDP is a compiler-based Fully Sharded Data Parallel (FSDP) framework: it keeps the implementation simple for ease of maintenance and composability, enables full computation-communication graph tracing, and brings performance improvements via compiler backend optimizations.
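
To make the description above concrete, here is a minimal, heavily simplified sketch of the idea behind a compiler-based FSDP: parameters stay sharded across data-parallel ranks and the full parameter is re-materialized with an all-gather inside forward, so both communication and computation are expressed in the program that torch.compile traces. The ShardedLinear class and the collective calls below are illustrative assumptions for this sketch, not SimpleFSDP's actual implementation or API.

```python
# A deliberately simplified sketch (NOT SimpleFSDP's actual code or API) of the
# idea behind a compiler-based FSDP: each rank stores only a shard of every
# parameter and all-gathers the full parameter inside forward, so the
# communication is expressed in the same program that torch.compile traces.
# Run with e.g.: torchrun --nproc_per_node=2 this_file.py
import torch
import torch.distributed as dist
import torch.nn as nn


class ShardedLinear(nn.Module):
    """Toy linear layer whose weight is row-sharded across all ranks (illustrative only)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        world_size = dist.get_world_size()
        assert out_features % world_size == 0
        # Each rank keeps only its 1/world_size slice of the full weight.
        self.weight_shard = nn.Parameter(
            torch.randn(out_features // world_size, in_features)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Re-materialize the full weight with an all-gather in forward. SimpleFSDP
        # expresses this step with traceable PyTorch-native primitives so the
        # collective appears in the compiled graph; the plain c10d call below is
        # only for illustration and may fall back to eager under torch.compile.
        shards = [torch.empty_like(self.weight_shard) for _ in range(dist.get_world_size())]
        dist.all_gather(shards, self.weight_shard)
        full_weight = torch.cat(shards, dim=0)
        return x @ full_weight.t()


def main() -> None:
    dist.init_process_group(backend="gloo")  # use "nccl" for GPU training
    model = ShardedLinear(16, 32)
    compiled = torch.compile(model)  # compiler backend optimizations happen here
    with torch.no_grad():  # forward-only demo; a real framework keeps this differentiable
        out = compiled(torch.randn(4, 16))
    print(f"rank {dist.get_rank()}: output shape {tuple(out.shape)}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```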

Enable SimpleFSDP Training

CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ./run_train.sh --model.name llama3_simple_fsdp --training.compile --training.mixed_precision_param float32

Note: Mixed precision training support is ongoing. For now we set training.mixed_precision_param to float32 and will remove this override once the integration is complete.

Composability Support

Some of these features require updates on the PyTorch side; we are working on providing composability support for the following features (a small standalone example combining two of the supported ones follows the table):

Feature                       Support
Meta Initialization           ✅
Activation Checkpointing      ✅
Mixed Precision Training      🚧
Tensor Parallelism            🚧
Context Parallelism           ✅
Pipeline Parallelism          ✅
Distributed Checkpointing     🚧
Float8 Training               ❌
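
As a standalone illustration of two of the supported features above, the sketch below combines meta-device initialization with activation checkpointing using stock PyTorch APIs; the Block module, its dimensions, and the init scheme are made up for the example and are not SimpleFSDP or torchtitan code.

```python
# Illustrative sketch using stock PyTorch APIs (not SimpleFSDP/torchtitan code):
# meta-device initialization defers parameter allocation, and activation
# checkpointing trades memory for recomputation in backward.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    def __init__(self, dim: int = 256):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Activation checkpointing: recompute this block in backward instead of
        # storing its intermediate activations.
        return x + checkpoint(self.ff, x, use_reentrant=False)


# Meta initialization: construct the module without allocating parameter storage.
with torch.device("meta"):
    model = Block()

# Materialize (uninitialized) parameters on the target device, then initialize them.
model.to_empty(device="cpu")
for m in model.modules():
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.zeros_(m.bias)

out = model(torch.randn(8, 256, requires_grad=True))
out.sum().backward()  # gradients flow through the checkpointed block
```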

Citation

If you find SimpleFSDP useful, please consider citing the following paper:

@article{zhang2024simplefsdp,
  title={SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile},
  author={Zhang, Ruisi and Liu, Tianyu and Feng, Will and Gu, Andrew and Purandare, Sanket and Liang, Wanchao and Massa, Francisco},
  journal={arXiv preprint arXiv:2411.00284},
  year={2024}
}