"""Gradio demo for the NextStep-1-Large text-to-image pipeline.

Loads the tokenizer and model once at import time, wraps them in a
``NextStepPipeline``, and exposes a small Blocks UI that generates a single
image at the requested width and height.
"""

# NOTE(review): import order is kept exactly as in the original file
# (`spaces` before `torch`); confirm before letting a formatter re-sort it.
import gradio as gr
import numpy as np
import spaces
from PIL import Image
import torch
from torch.amp import autocast
from transformers import AutoTokenizer, AutoModel

from models.gen_pipeline import NextStepPipeline

HF_HUB = "stepfun-ai/NextStep-1-Large"

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(
    HF_HUB,
    local_files_only=False,
    trust_remote_code=True,
)
model = AutoModel.from_pretrained(
    HF_HUB,
    local_files_only=False,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device)
pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(
    device=device, dtype=torch.bfloat16
)

# Upper bound of the seed slider.
# NOTE(review): int16 caps the seed at 32767; many demos use int32 here.
# Left unchanged because it is not clear whether the small range is deliberate.
MAX_SEED = np.iinfo(np.int16).max

DEFAULT_POSITIVE_PROMPT = None
DEFAULT_NEGATIVE_PROMPT = None


def _ensure_pil(x) -> Image.Image:
    """Convert a pipeline output to a ``PIL.Image.Image``.

    Accepts a PIL image (returned unchanged), an object exposing ``detach``
    (treated as a torch tensor), or a NumPy array.

    Floating-point data is assumed to be in ``[0, 1]`` and is scaled to
    ``uint8``. A 3-D array whose first axis has length 1, 3 or 4 is treated
    as channel-first and moved to channel-last. A trailing single channel is
    dropped so that ``Image.fromarray`` receives a 2-D grayscale array.

    Raises:
        TypeError: if ``x`` is none of the supported types.
    """
    if isinstance(x, Image.Image):
        return x

    if hasattr(x, "detach"):
        x = x.detach()
        # Only floating-point tensors are clamped to [0, 1]; clamping an
        # integer tensor (e.g. uint8 0..255) would destroy its values.
        if x.is_floating_point():
            x = x.float().clamp(0, 1)
        x = x.cpu().numpy()

    if isinstance(x, np.ndarray):
        if x.dtype != np.uint8:
            x = (x * 255.0).clip(0, 255).astype(np.uint8)
        if x.ndim == 3 and x.shape[0] in (1, 3, 4):
            # CHW -> HWC
            x = np.moveaxis(x, 0, -1)
        if x.ndim == 3 and x.shape[-1] == 1:
            # Image.fromarray rejects (H, W, 1); use plain (H, W) instead.
            x = x[..., 0]
        return Image.fromarray(x)

    raise TypeError("Unsupported image type returned by pipeline.")


@spaces.GPU(duration=300)
def infer(
    prompt=None,
    seed=0,
    width=512,
    height=512,
    num_inference_steps=28,
    positive_prompt=DEFAULT_POSITIVE_PROMPT,
    negative_prompt=DEFAULT_NEGATIVE_PROMPT,
    progress=gr.Progress(track_tqdm=True),
):
    """Run inference at exactly (width, height).

    Args:
        prompt: Text prompt. ``None``, empty or whitespace-only prompts show
            a warning and produce no image.
        seed: Seed forwarded to the pipeline (cast to ``int``).
        width: Output width in pixels (cast to ``int``).
        height: Output height in pixels (cast to ``int``).
        num_inference_steps: Forwarded as ``num_sampling_steps``.
        positive_prompt: Optional text forwarded as ``positive_prompt``.
        negative_prompt: Optional text forwarded as ``negative_prompt``.
        progress: Gradio progress tracker; tqdm tracking is enabled.

    Returns:
        The first generated image as a ``PIL.Image.Image``, or ``None`` when
        no prompt was supplied.
    """
    if prompt is None or not str(prompt).strip():
        gr.Warning("⚠️ Please enter a prompt!")
        return None

    # `device` is already either "cuda" or "cpu", so it is passed directly.
    with autocast(device_type=device, dtype=torch.bfloat16):
        imgs = pipeline.generate_image(
            prompt,
            hw=(int(height), int(width)),
            num_images_per_caption=1,
            positive_prompt=positive_prompt,
            negative_prompt=negative_prompt,
            cfg=7.5,
            cfg_img=1.0,
            cfg_schedule="constant",
            use_norm=False,
            num_sampling_steps=int(num_inference_steps),
            timesteps_shift=1.0,
            seed=int(seed),
            progress=True,
        )

    return _ensure_pil(imgs[0])  # Return raw output exactly as generated


css = """
#col-container {
    margin: 0 auto;
    max-width: 800px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# NextStep-1-Large — Exact Output Size")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")
            cancel_button = gr.Button("Cancel", scale=0, variant="secondary")

        with gr.Row():
            with gr.Accordion("Advanced Settings", open=True):
                positive_prompt = gr.Text(
                    label="Positive Prompt",
                    show_label=True,
                    max_lines=1,
                    placeholder="Optional: add positives",
                    container=True,
                )
                negative_prompt = gr.Text(
                    label="Negative Prompt",
                    show_label=True,
                    max_lines=2,
                    placeholder="Optional: add negatives",
                    container=True,
                )
                with gr.Row():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=3407,
                    )
                    num_inference_steps = gr.Slider(
                        label="Sampling steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=28,
                    )
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=512,
                        step=64,
                        value=512,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=512,
                        step=64,
                        value=512,
                    )

        with gr.Row():
            result_1 = gr.Image(
                label="Result",
                show_label=True,
                container=True,
                interactive=False,
                format="png",
            )

        # Click & Fill Examples (all <=512px)
        # Column order matches the `inputs` list passed to gr.Examples below.
        examples = [
            [
                "A cozy wooden cabin by a frozen lake, northern lights in the sky",
                123,
                512,
                512,
                28,
                "photorealistic, cinematic lighting, starry night, glowing reflections",
                "low-res, distorted, extra objects",
            ],
            [
                "Futuristic city skyline at sunset, flying cars, neon reflections",
                456,
                512,
                384,
                30,
                "detailed, vibrant, cinematic, sharp edges",
                "washed out, cartoon, blurry",
            ],
            [
                "Close-up of a rare orchid in a greenhouse with soft morning light",
                789,
                384,
                512,
                32,
                "macro lens effect, ultra-detailed petals, dew drops",
                "grainy, noisy, oversaturated",
            ],
        ]

        gr.Examples(
            examples=examples,
            inputs=[
                prompt,
                seed,
                width,
                height,
                num_inference_steps,
                positive_prompt,
                negative_prompt,
            ],
            label="Click & Fill Examples (Exact Size)",
        )

    def show_result():
        """Return an update that makes a component visible.

        Not wired to any event in this file; kept for compatibility.
        """
        return gr.update(visible=True)

    # Both the Run button and pressing Enter in the prompt box start a run.
    generation_event = gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            seed,
            width,
            height,
            num_inference_steps,
            positive_prompt,
            negative_prompt,
        ],
        outputs=[result_1],
    )

    # Cancel has no handler of its own; it only cancels the running event.
    cancel_button.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[generation_event],
    )

if __name__ == "__main__":
    demo.launch()