|
import gradio as gr |
|
from PIL import Image |
|
import numpy as np |
|
import mlx.core as mx |
|
from stable_diffusion import StableDiffusion |
|
|
|
def generate_images(prompt, n_images=4, steps=50, cfg=7.5, negative_prompt="", n_rows=1):
    """Generate images for ``prompt`` and return them tiled into one grid image.

    Args:
        prompt: Text forwarded to ``StableDiffusion.generate_latents``.
        n_images: Number of images to generate. Must be a positive multiple
            of ``n_rows`` so the batch can be split into equal rows.
        steps: Forwarded as ``num_steps``. Must be at least 1.
        cfg: Forwarded as ``cfg_weight``.
        negative_prompt: Forwarded as ``negative_text``.
        n_rows: Number of rows in the output grid. Must be at least 1.

    Returns:
        A ``PIL.Image.Image`` holding all decoded images laid out in
        ``n_rows`` rows, each tile padded by 8 pixels on every side.

    Raises:
        ValueError: If ``n_images``, ``steps`` or ``n_rows`` is smaller
            than 1, if ``n_rows`` does not evenly divide ``n_images``, or if
            the latent generator yields nothing.
    """
    # UI sliders may deliver whole numbers as floats; ``range`` and slicing
    # below need real ints.
    n_images = int(n_images)
    steps = int(steps)
    n_rows = int(n_rows)

    # Validate before the model is built so bad input fails fast with a
    # readable message instead of a reshape / division error later on.
    if n_images < 1:
        raise ValueError(f"n_images must be at least 1, got {n_images}")
    if steps < 1:
        raise ValueError(f"steps must be at least 1, got {steps}")
    if n_rows < 1:
        raise ValueError(f"n_rows must be at least 1, got {n_rows}")
    if n_images % n_rows != 0:
        raise ValueError(
            f"n_rows ({n_rows}) must evenly divide n_images ({n_images})"
        )

    # NOTE(review): the model is constructed on every call; if construction
    # is expensive, consider building it once — confirm it is safe to share.
    sd = StableDiffusion()

    latents = sd.generate_latents(
        prompt,
        n_images=n_images,
        cfg_weight=cfg,
        num_steps=steps,
        negative_text=negative_prompt,
    )

    # Drive the generator to completion, evaluating each step as it is
    # produced; only the last yielded value is used afterwards.
    x_t = None
    for x_t in latents:
        # The original calls simplify twice; kept as-is.
        # NOTE(review): confirm whether the second pass is intentional.
        mx.simplify(x_t)
        mx.simplify(x_t)
        mx.eval(x_t)
    if x_t is None:
        raise ValueError("latent generation produced no steps")

    # Decode one image at a time, evaluating each result before moving on.
    decoded = []
    for i in range(n_images):
        decoded_img = sd.decode(x_t[i:i + 1])
        mx.eval(decoded_img)
        decoded.append(decoded_img)

    # Arrange the batch on a grid: pad axes 1 and 2 by 8 on each side, split
    # the batch axis into (rows, cols), then interleave so rows stack
    # vertically and columns sit side by side.
    x = mx.concatenate(decoded, axis=0)
    x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
    B, H, W, C = x.shape
    n_cols = B // n_rows
    x = x.reshape(n_rows, n_cols, H, W, C).transpose(0, 2, 1, 3, 4)
    x = x.reshape(n_rows * H, n_cols * W, C)
    # Presumably the decoded values lie in [0, 1] — TODO confirm against
    # ``StableDiffusion.decode``.
    x = (x * 255).astype(mx.uint8)

    return Image.fromarray(np.array(x))
|
|
|
# Gradio front end: one input widget per ``generate_images`` parameter, in
# the same positional order as the function signature.
iface = gr.Interface(
    fn=generate_images,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(minimum=1, maximum=10, step=1, value=4, label="Number of Images"),
        gr.Slider(minimum=20, maximum=100, step=1, value=50, label="Steps"),
        gr.Slider(minimum=0.0, maximum=10.0, step=0.1, value=7.5, label="CFG Weight"),
        # ``value=`` (not the legacy ``default=``) sets the initial content,
        # matching the keyword used by the sliders in this list.
        gr.Textbox(value="", label="Negative Prompt"),
        # The row count has to evenly divide the number of images, because
        # ``generate_images`` reshapes the batch into equal rows.
        gr.Slider(minimum=1, maximum=10, step=1, value=1, label="Number of Rows"),
    ],
    outputs="image",
    title="Stable Diffusion Image Generator",
    description="Generate images from a textual prompt using Stable Diffusion",
)

# Start the web UI as soon as the script runs.
iface.launch()