# Captured from a Hugging Face Spaces page (Space status at capture time: "Build error").
""" | |
Train SDXL-LoRA on lambdalabs/pokemon-blip-captions. | |
Runs in <16 GB RAM (CPU) thanks to LoRA + 8-bit Adam + fp16. | |
""" | |
import os
import threading

import gradio as gr
import torch
from accelerate import notebook_launcher
from datasets import load_dataset
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
from diffusers.training_utils import set_seed
from transformers import CLIPTokenizer
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
DATASET_NAME = "lambdalabs/pokemon-blip-captions"
OUTPUT_DIR = "/tmp/sdxl-lora"

# fp16 mixed precision, bitsandbytes 8-bit Adam and xformers attention are
# CUDA-only features; requesting them on a CPU-only machine makes training
# fail before the first step. Enable them only when a GPU is present.
_HAS_CUDA = torch.cuda.is_available()

# LoRA & memory-friendly defaults. Keys are named after the diffusers SDXL
# LoRA training script's options (presumably its ``--<key>`` CLI flags).
training_args = dict(
    pretrained_model_name_or_path=MODEL_ID,
    dataset_name=DATASET_NAME,
    resolution=512,
    center_crop=True,
    random_flip=True,
    train_batch_size=1,
    gradient_accumulation_steps=4,  # 1 x 4 = 4 samples per optimizer step
    max_train_steps=1000,
    learning_rate=1e-4,
    lr_scheduler="constant",
    lr_warmup_steps=0,
    mixed_precision="fp16" if _HAS_CUDA else "no",
    gradient_checkpointing=True,
    use_8bit_adam=_HAS_CUDA,
    enable_xformers_memory_efficient_attention=_HAS_CUDA,
    checkpointing_steps=500,
    output_dir=OUTPUT_DIR,
    push_to_hub=True,
    # Read from the environment (e.g. Space secrets); None when unset.
    hub_model_id=os.getenv("REPO_ID"),
    hub_token=os.getenv("HF_TOKEN"),
    report_to="none",  # no wandb on free tier
)
def _to_cli_args(options):
    """Translate an option mapping into an argv list for an argparse parser.

    ``True`` becomes a bare ``--key`` switch; ``False`` and ``None`` are
    omitted so the parser keeps its own default; any other value becomes
    ``--key <str(value)>``.
    """
    argv = []
    for key, value in options.items():
        if value is None or value is False:
            continue
        if value is True:
            argv.append(f"--{key}")
        else:
            argv.extend((f"--{key}", str(value)))
    return argv


def run_training(options=None):
    """Run diffusers' SDXL LoRA training script in the current thread.

    Args:
        options: mapping of the script's option names to values. Defaults to
            the module-level ``training_args``.
    """
    # NOTE(review): ``examples`` is the diffusers *repository* examples tree,
    # which the pip package does not ship; this assumes that directory is on
    # sys.path (e.g. a repo checkout next to this file) — TODO confirm.
    from examples.text_to_image.train_text_to_image_lora_sdxl import (
        main,
        parse_args,
    )

    if options is None:
        options = training_args
    # main() reads its configuration as attributes of an argparse Namespace
    # and expects every option the script defines (seed, rank, image_column,
    # ...), so passing a plain dict fails. Round-trip through the script's
    # own parse_args() to get a complete Namespace with defaults filled in.
    main(parse_args(_to_cli_args(options)))
def launch():
    """Serve a minimal Gradio page so the Space keeps running.

    The page has no inputs and its single text output always reports the
    same fixed status message.
    """
    def _status():
        return "Training started in background (see logs below)"

    demo = gr.Interface(fn=_status, inputs=None, outputs="text")
    demo.launch()
if __name__ == "__main__":
    set_seed(42)
    # notebook_launcher(..., num_processes=1) calls the function synchronously
    # in the calling thread. Invoked directly, it would block here until every
    # training step finished, and the Gradio UI would not come up meanwhile.
    # Run it on a daemon thread instead so launch() starts serving right away;
    # the training thread ends together with the server process.
    trainer = threading.Thread(
        target=notebook_launcher,
        args=(run_training,),
        kwargs={"num_processes": 1},
        name="sdxl-lora-training",
        daemon=True,
    )
    trainer.start()
    # Serve the UI on the main thread. In a script, Gradio's launch() blocks
    # the main thread, which keeps the process (and the daemon trainer) alive.
    launch()