JAX-Diffusion
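JAX-Diffusion is a project that implements diffusion models using JAX, a Python library for high-performance numerical computing that provides automatic differentiation and XLA-based just-in-time compilation. Diffusion models are a class of generative models that learn to reverse a gradual noising process, turning random noise into data step by step. They have gained popularity for their ability to generate high-quality samples, particularly images.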
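As a minimal sketch of the idea (not this project's code), the forward noising process used by DDPM-style models can be written in JAX as shown here. A clean sample x_0 is mapped to a noisy x_t in one step via q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I). The linear beta schedule (1e-4 to 0.02 over 1000 steps) and the names `linear_beta_schedule` and `q_sample` are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def linear_beta_schedule(num_steps: int, beta_start: float = 1e-4, beta_end: float = 0.02) -> jnp.ndarray:
    # Assumed schedule: linearly spaced noise variances beta_t, as in the original DDPM setup.
    return jnp.linspace(beta_start, beta_end, num_steps)


def q_sample(key, x0, t, alpha_bars):
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)."""
    noise = jax.random.normal(key, x0.shape, dtype=x0.dtype)
    a_bar = alpha_bars[t]
    x_t = jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * noise
    # Return the noise too: training objectives typically regress the model onto it.
    return x_t, noise


betas = linear_beta_schedule(1000)
# alpha_bar_t = prod_{s <= t} (1 - beta_s), the fraction of signal variance kept at step t.
alpha_bars = jnp.cumprod(1.0 - betas)

key = jax.random.PRNGKey(0)
x0 = jnp.ones((4, 8))  # toy "data" batch: 4 samples with 8 features each
x_t, noise = q_sample(key, x0, 500, alpha_bars)
```

Because x_t is available in closed form for any t, training never has to simulate the noising chain step by step.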
Features
- Implementation of diffusion models in JAX.
- High-performance and scalable computations.
- Modular and extensible codebase.
Installation
Clone the repository:
git clone https://github.com/your-username/JAX-Diffusion.git
cd JAX-Diffusion
Install dependencies:
pip install -r requirements.txt
Usage
To train a diffusion model:
python train.py --config configs/default.yaml
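For orientation only: a common training objective for diffusion models is the simplified DDPM loss, the mean-squared error between the noise added to a sample and the model's prediction of that noise. We have not confirmed that `train.py` uses this objective or these settings. The sketch assumes a linear beta schedule, a hypothetical toy linear noise predictor `eps_model` standing in for a real network, and plain SGD.

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 1000
LEARNING_RATE = 0.1  # assumed value for this toy example
betas = jnp.linspace(1e-4, 0.02, NUM_STEPS)
alpha_bars = jnp.cumprod(1.0 - betas)


def eps_model(params, x_t, t):
    # Hypothetical toy predictor: an affine map of x_t plus a learned scalar per timestep.
    return x_t @ params["w"] + params["b"] + params["t_emb"][t][:, None]


def loss_fn(params, key, x0):
    t_key, n_key = jax.random.split(key)
    # Sample one random timestep per example and noise it in closed form.
    t = jax.random.randint(t_key, (x0.shape[0],), 0, NUM_STEPS)
    noise = jax.random.normal(n_key, x0.shape)
    a_bar = alpha_bars[t][:, None]
    x_t = jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * noise
    # Simplified DDPM objective: predict the added noise.
    return jnp.mean((eps_model(params, x_t, t) - noise) ** 2)


@jax.jit
def train_step(params, key, x0):
    loss, grads = jax.value_and_grad(loss_fn)(params, key, x0)
    # Plain SGD update applied leaf-by-leaf over the parameter pytree.
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


dim = 8
params = {"w": jnp.zeros((dim, dim)), "b": jnp.zeros((dim,)), "t_emb": jnp.zeros((NUM_STEPS,))}
x0 = jax.random.normal(jax.random.PRNGKey(1), (32, dim))  # toy dataset
key = jax.random.PRNGKey(0)
losses = []
for step in range(100):
    key, subkey = jax.random.split(key)
    params, loss = train_step(params, subkey, x0)
    losses.append(float(loss))
```

Wrapping the whole update in `jax.jit` compiles the loss, gradient, and parameter update into a single XLA computation.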
To generate samples:
python generate.py --model checkpoints/model.pth
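Sampling in DDPM-style models runs a learned denoiser backwards, starting from pure Gaussian noise. The following is a minimal JAX sketch of ancestral sampling with sigma_t^2 = beta_t under an assumed linear beta schedule; it is not `generate.py`. `eps_model` is a hypothetical placeholder that returns zeros, so the output is only rescaled noise until a trained noise-prediction network is plugged in.

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 1000
betas = jnp.linspace(1e-4, 0.02, NUM_STEPS)
alphas = 1.0 - betas
alpha_bars = jnp.cumprod(alphas)


def eps_model(x_t, t):
    # Hypothetical placeholder: replace with a trained noise-prediction network.
    return jnp.zeros_like(x_t)


def p_sample_loop(key, shape):
    """DDPM ancestral sampling from t = T-1 down to t = 0 with sigma_t^2 = beta_t."""
    key, init_key = jax.random.split(key)
    x_T = jax.random.normal(init_key, shape)

    def step(carry, t):
        x_t, key = carry
        key, noise_key = jax.random.split(key)
        eps = eps_model(x_t, t)
        # Posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t).
        mean = (x_t - betas[t] / jnp.sqrt(1.0 - alpha_bars[t]) * eps) / jnp.sqrt(alphas[t])
        # No fresh noise is added at the final step (t == 0).
        z = jnp.where(t > 0, jax.random.normal(noise_key, shape), 0.0)
        return (mean + jnp.sqrt(betas[t]) * z, key), None

    timesteps = jnp.arange(NUM_STEPS - 1, -1, -1)
    (x_0, _), _ = jax.lax.scan(step, (x_T, key), timesteps)
    return x_0


# `shape` must be static under jit because it determines array sizes.
samples = jax.jit(p_sample_loop, static_argnums=1)(jax.random.PRNGKey(0), (4, 8))
```

Using `jax.lax.scan` instead of a Python loop keeps all 1000 denoising steps inside one compiled computation.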
Contributing
Contributions are welcome! Please follow these steps:
- Fork the repository.
- Create a new branch for your feature or bug fix.
- Submit a pull request with a clear description of your changes.
License
This project is licensed under the MIT License. See the LICENSE file for details.
Acknowledgments
- JAX for providing the foundation for numerical computing.
- The research community for advancements in diffusion models.