import gradio as gr
import numpy as np
import random
import spaces  # ZeroGPU: provides the @spaces.GPU decorator used on `infer`
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModel

from models.gen_pipeline import NextStepPipeline
from utils.aspect_ratio import center_crop_arr_with_buckets

HF_HUB = "stepfun-ai/NextStep-1-Large-Edit"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer and model once at import time. `trust_remote_code=True` executes
# the modelling code shipped with the hub repository.
tokenizer = AutoTokenizer.from_pretrained(HF_HUB, local_files_only=False, trust_remote_code=True)
model = AutoModel.from_pretrained(HF_HUB, local_files_only=False, trust_remote_code=True)
pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(device=device)

# NOTE(review): int16 caps the seed slider at 32767; confirm whether a wider
# range (e.g. int32) was intended before changing it.
MAX_SEED = np.iinfo(np.int16).max
# Bucket size passed to center_crop_arr_with_buckets for the reference image.
MAX_IMAGE_SIZE = 512

DEFAULT_POSITIVE_PROMPT = None
DEFAULT_NEGATIVE_PROMPT = "copy the original image"

# Prefix prepended to the user's edit instruction before it reaches the pipeline.
# NOTE(review): the code previously used an empty prefix ("" + prompt), which is a
# no-op; the upstream NextStep-1 edit example prefixes captions with "<image>".
# Confirm against models.gen_pipeline if outputs look wrong.
EDIT_CAPTION_PREFIX = "<image>"


@spaces.GPU(duration=300)
def infer(
    prompt=None,
    ref=None,
    seed=0,
    text_cfg=7.5,
    img_cfg=2.0,
    num_inference_steps=30,
    positive_prompt=DEFAULT_POSITIVE_PROMPT,
    negative_prompt=DEFAULT_NEGATIVE_PROMPT,
    progress=gr.Progress(track_tqdm=True),
):
    """Edit ``ref`` according to ``prompt`` and return two candidate images.

    Args:
        prompt: Edit instruction text.
        ref: Reference image (the UI supplies it with ``type="pil"``).
        seed: Sampling seed forwarded to the pipeline.
        text_cfg: Text guidance scale, forwarded as ``cfg``.
        img_cfg: Image guidance scale, forwarded as ``cfg_img``.
        num_inference_steps: Number of sampling steps.
        positive_prompt: Optional extra positive prompt; blank means none.
        negative_prompt: Negative prompt forwarded to the pipeline.
        progress: Gradio progress tracker (mirrors tqdm progress bars).

    Returns:
        A 2-tuple ``(image_1, image_2)`` for the two result components, or
        ``(None, None)`` after showing a warning when an input is missing.
    """
    if ref is None:
        gr.Warning("⚠️ 请上传图片!")  # "Please upload an image!"
        # Two output components are wired to this handler, so two values are required.
        return None, None
    if not prompt or not prompt.strip():
        gr.Warning("⚠️ 请输入提示词!")  # "Please enter a prompt!"
        return None, None

    # Gradio textboxes yield "" when left empty; map that back to the None default.
    if positive_prompt is not None and not positive_prompt.strip():
        positive_prompt = None

    input_image = center_crop_arr_with_buckets(ref, buckets=[MAX_IMAGE_SIZE])
    # PIL-style .size is (width, height); the pipeline takes hw=(height, width).
    width, height = input_image.size

    images = pipeline.generate_image(
        captions=EDIT_CAPTION_PREFIX + prompt,
        images=input_image,
        num_images_per_caption=2,
        positive_prompt=positive_prompt,
        negative_prompt=negative_prompt,
        hw=(height, width),
        cfg=text_cfg,
        cfg_img=img_cfg,
        cfg_schedule="constant",
        use_norm=True,
        # Slider values may arrive as floats; step counts and seeds must be ints.
        num_sampling_steps=int(num_inference_steps),
        seed=int(seed),
        progress=True,
    )
    return images[0], images[1]


examples = [
    ["修改图像,让白马向镜头奔跑。", "assets/1.jpg"],
    ["Change the background to the sea view.", "assets/2.jpg"],
    [
        "Add a pirate hat to the dog's head. Change the background to a stormy sea with dark clouds. "
        "Include the text 'NextStep-Edit' in bold white letters at the top portion of the image.",
        "assets/3.jpg",
    ],
    ["改为吉卜力风格。", "assets/4.jpg"],
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 800px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # NextStep-1-Large-Edit")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0, variant="primary")

        with gr.Row():
            ref = gr.Image(label="Reference Image", show_label=True, type="pil", height=400)

        with gr.Accordion("Advanced Settings", open=True):
            positive_prompt = gr.Text(
                label="Positive Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your positive prompt",
                container=False,
            )
            # Pre-filled with the default: an empty textbox would send "" and the
            # default negative prompt would otherwise never be applied from the UI.
            negative_prompt = gr.Text(
                label="Negative Prompt",
                show_label=False,
                max_lines=2,
                placeholder="Enter your negative prompt",
                container=False,
                value=DEFAULT_NEGATIVE_PROMPT,
            )

            with gr.Row():
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                )
                num_inference_steps = gr.Slider(
                    label="# sampling steps",
                    minimum=10,
                    maximum=50,
                    step=1,
                    value=30,
                )

            with gr.Row():
                text_cfg = gr.Slider(
                    label="Text cfg",
                    minimum=1.0,
                    maximum=15.0,
                    step=0.1,
                    value=7.5,
                )
                img_cfg = gr.Slider(
                    label="Image cfg",
                    minimum=1.0,
                    maximum=15.0,
                    step=0.1,
                    value=2.0,
                )

        # Result slots start hidden and are revealed by `show_result` on run.
        with gr.Row():
            result_1 = gr.Image(label="Result 1", show_label=False, container=True, height=400, visible=False)
            result_2 = gr.Image(label="Result 2", show_label=False, container=True, height=400, visible=False)

        gr.Examples(examples=examples, inputs=[prompt, ref])

    def show_result():
        """Make both result image components visible."""
        return gr.update(visible=True), gr.update(visible=True)

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            ref,
            seed,
            text_cfg,
            img_cfg,
            num_inference_steps,
            positive_prompt,
            negative_prompt,
        ],
        outputs=[result_1, result_2],
    )

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=show_result,
        outputs=[result_1, result_2],
    )


if __name__ == "__main__":
    demo.launch()