"""Train SDXL-LoRA on lambdalabs/pokemon-blip-captions.

Runs in <16 GB RAM (CPU) thanks to LoRA + 8-bit Adam + fp16.
"""
import argparse
import os
import threading

import gradio as gr
import torch
from accelerate import notebook_launcher  # kept for compatibility; no longer used below
from datasets import load_dataset
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
from diffusers.training_utils import set_seed
from transformers import CLIPTokenizer

MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
DATASET_NAME = "lambdalabs/pokemon-blip-captions"
OUTPUT_DIR = "/tmp/sdxl-lora"

# LoRA & memory-friendly defaults, mirroring the CLI flags of the diffusers
# train_text_to_image_lora_sdxl example script.
training_args = dict(
    pretrained_model_name_or_path=MODEL_ID,
    dataset_name=DATASET_NAME,
    resolution=512,
    center_crop=True,
    random_flip=True,
    train_batch_size=1,
    gradient_accumulation_steps=4,
    max_train_steps=1000,
    learning_rate=1e-4,
    lr_scheduler="constant",
    lr_warmup_steps=0,
    mixed_precision="fp16",
    gradient_checkpointing=True,
    use_8bit_adam=True,
    enable_xformers_memory_efficient_attention=True,
    checkpointing_steps=500,
    output_dir=OUTPUT_DIR,
    push_to_hub=True,
    hub_model_id=os.getenv("REPO_ID"),
    hub_token=os.getenv("HF_TOKEN"),
    report_to="none",  # no wandb on free tier
)


def run_training():
    """Invoke the diffusers SDXL LoRA example's main() with our settings.

    BUG FIX: the example's ``main(args)`` reads options as attributes
    (``args.resolution``, ``args.output_dir``, ...), so passing a plain
    dict raised AttributeError on the first access.  Wrap the dict in an
    ``argparse.Namespace`` so attribute access works.

    NOTE(review): this import only resolves when the diffusers repository's
    ``examples/`` directory is on sys.path — the installed ``diffusers``
    wheel does not ship the examples package.  Confirm the Space clones the
    repo; flags not present in ``training_args`` that the script normally
    defaults via ``parse_args()`` may also need to be supplied here.
    """
    from examples.text_to_image.train_text_to_image_lora_sdxl import main

    main(argparse.Namespace(**training_args))


def launch():
    """Serve a minimal Gradio UI so the Space process stays alive."""
    gr.Interface(
        fn=lambda: "Training started in background (see logs below)",
        inputs=None,
        outputs="text",
    ).launch()


if __name__ == "__main__":
    set_seed(42)
    # BUG FIX: notebook_launcher(run_training, num_processes=1) runs the
    # function *synchronously* and blocks until training finishes, so the
    # Gradio keep-alive UI never came up while training ran.  Run training
    # on a daemon thread instead so launch() serves the UI concurrently.
    threading.Thread(target=run_training, daemon=True).start()
    launch()