"""
Train SDXL-LoRA on lambdalabs/pokemon-blip-captions.
Runs in <16 GB RAM (CPU) thanks to LoRA + 8-bit Adam + fp16.
"""
import os
import threading
import torch
from datasets import load_dataset
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
from diffusers.training_utils import set_seed
from transformers import CLIPTokenizer
from accelerate import notebook_launcher
import gradio as gr
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
DATASET_NAME = "lambdalabs/pokemon-blip-captions"
OUTPUT_DIR = "/tmp/sdxl-lora"
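# NOTE: /tmp is ephemeral on a Space, so the trained LoRA is also pushed to the Hub
# (push_to_hub / hub_model_id below) rather than kept only on local disk.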
# LoRA & memory-friendly defaults
training_args = dict(
    pretrained_model_name_or_path=MODEL_ID,
    dataset_name=DATASET_NAME,
    resolution=512,
    center_crop=True,
    random_flip=True,
    train_batch_size=1,
    gradient_accumulation_steps=4,
    max_train_steps=1000,
    learning_rate=1e-4,
    lr_scheduler="constant",
    lr_warmup_steps=0,
    mixed_precision="fp16",
    gradient_checkpointing=True,
    use_8bit_adam=True,
    enable_xformers_memory_efficient_attention=True,
    checkpointing_steps=500,
    output_dir=OUTPUT_DIR,
    push_to_hub=True,
    hub_model_id=os.getenv("REPO_ID"),
    hub_token=os.getenv("HF_TOKEN"),
    report_to="none",  # no wandb on free tier
)
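# For reference, the kwargs above mirror the CLI flags of the diffusers example script;
# a roughly equivalent standalone run (with the diffusers repo cloned locally) would be:
#   accelerate launch examples/text_to_image/train_text_to_image_lora_sdxl.py \
#     --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
#     --dataset_name=lambdalabs/pokemon-blip-captions --resolution=512 \
#     --center_crop --random_flip --train_batch_size=1 --gradient_accumulation_steps=4 \
#     --max_train_steps=1000 --learning_rate=1e-4 --lr_scheduler=constant --lr_warmup_steps=0 \
#     --mixed_precision=fp16 --gradient_checkpointing --use_8bit_adam \
#     --enable_xformers_memory_efficient_attention --checkpointing_steps=500 \
#     --output_dir=/tmp/sdxl-lora --push_to_hub --report_to=none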

def run_training():
    # Launch the diffusers SDXL LoRA example script (assumes the diffusers repo is
    # cloned so that examples/ is importable). The script expects argparse-style args,
    # so the kwargs above are converted into CLI flags for its parse_args() helper.
    from examples.text_to_image.train_text_to_image_lora_sdxl import main, parse_args
    flags = [f"--{k}" if v is True else f"--{k}={v}"
             for k, v in training_args.items() if v is not None and v is not False]
    main(parse_args(flags))

def launch():
    # Gradio dummy UI so the Space stays alive
    gr.Interface(
        fn=lambda: "Training started in background (see logs below)",
        inputs=None,
        outputs="text",
    ).launch()
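
# Optional inference sketch (not called anywhere above): a minimal example, under the
# assumption that training has finished and saved LoRA weights into OUTPUT_DIR, of how
# they could be loaded back into the base SDXL pipeline. Prompt and step count are
# illustrative placeholders.
def generate_sample(prompt="a drawing of a green pokemon with red eyes"):
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    pipe = StableDiffusionXLPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)
    pipe.load_lora_weights(OUTPUT_DIR)  # attach the LoRA weights written by the training script
    if torch.cuda.is_available():
        pipe = pipe.to("cuda")
    return pipe(prompt, num_inference_steps=25).images[0]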

if __name__ == "__main__":
    set_seed(42)
    # Start training in a background thread so the Gradio app below comes up
    # immediately and keeps the Space alive while training runs.
    threading.Thread(
        target=lambda: notebook_launcher(run_training, num_processes=1)
    ).start()
    launch()