File size: 2,908 Bytes
e65f030
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import gradio as gr
import torch
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler

# Initialize the model and other components.
# NOTE: everything below runs at import time and needs a CUDA device
# (the checkpoint is loaded with location='cuda' and the model is moved to GPU).
apply_canny = CannyDetector()
model = create_model('./models/cldm_v21_512_latctrl_coltrans.yaml').cpu()
# NOTE(review): this relative path looks like a Hugging Face repo id
# ("user/repo/..."); it only resolves if that directory exists relative to the
# working directory — confirm how the checkpoint is provisioned.
_incompatible_keys = model.load_state_dict(
    load_state_dict('xywwww/scene_diffusion/checkpoints/epoch=25-step=112553.ckpt', location='cuda'),
    strict=False,
)
# strict=False tolerates key mismatches between the checkpoint and the model
# config; report them so a wrong or partial checkpoint is not accepted silently.
if _incompatible_keys.missing_keys or _incompatible_keys.unexpected_keys:
    print(
        f"[warning] checkpoint/model key mismatch: "
        f"{len(_incompatible_keys.missing_keys)} missing, "
        f"{len(_incompatible_keys.unexpected_keys)} unexpected"
    )
del _incompatible_keys
model = model.cuda()
ddim_sampler = DDIMSampler(model)

def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold):
    """Gradio callback for the "Generate" action.

    Only preprocessing is implemented so far. The input image goes through
    ``HWC3`` and is then resized to ``image_resolution`` with
    ``resize_image``. The generation pipeline has not been written yet, so
    the function returns ``None`` and Gradio shows an empty output image.

    Args:
        input_image: value of the input ``gr.Image`` component; ``None`` when
            the user has not uploaded anything.
        prompt, a_prompt, n_prompt: prompt, additional prompt and negative
            prompt text (not used yet).
        num_samples, image_resolution, ddim_steps, guess_mode, strength,
        scale, seed, eta, low_threshold, high_threshold: values of the
            "Advanced options" controls; only ``image_resolution`` is used so far.

    Returns:
        ``None`` until the generation pipeline is implemented.

    Raises:
        gr.Error: if no input image was provided. The user then sees a clear
            message instead of an opaque error from the preprocessing helpers.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    with torch.no_grad():
        img = resize_image(HWC3(input_image), image_resolution)
        H, W, C = img.shape  # kept for the pending pipeline (latent shape etc.)
        # TODO: implement the rest of the pipeline (control map, sampling with
        # `ddim_sampler`, decoding with `model`) and return the generated image.
        # Canny edge detection is currently disabled:
        # detected_map = apply_canny(img, low_threshold, high_threshold)
        # detected_map = HWC3(detected_map)
    return None

def create_demo(process):
    """Build the Gradio Blocks UI around *process*.

    The left column holds the input image, the three prompt boxes, an
    "Advanced options" accordion and a "Generate" button. The right column
    holds the output image. Clicking "Generate" and submitting the Prompt
    textbox both call *process* with the same inputs. The Prompt box submits
    via its built-in submit button (``submit_btn=True``) or the Enter key.

    Args:
        process: callback receiving the 14 component values positionally, in
            the order of ``ips`` below; its return value is shown in the
            output image.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                input_image = gr.Image()
                prompt = gr.Textbox(label="Prompt", submit_btn=True)
                a_prompt = gr.Textbox(label="Additional Prompt")
                n_prompt = gr.Textbox(label="Negative Prompt")
                with gr.Accordion("Advanced options", open=False):
                    num_samples = gr.Slider(label="Number of images", minimum=1, maximum=10, value=1, step=1)
                    image_resolution = gr.Slider(label="Image resolution", minimum=256, maximum=1024, value=512, step=256)
                    ddim_steps = gr.Slider(label="DDIM Steps", minimum=1, maximum=100, value=50, step=1)
                    guess_mode = gr.Checkbox(label="Guess Mode")
                    strength = gr.Slider(label="Strength", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
                    scale = gr.Slider(label="Scale", minimum=0.1, maximum=30.0, value=10.0, step=0.1)
                    seed = gr.Slider(label="Seed", minimum=0, maximum=10000, value=42, step=1)
                    eta = gr.Slider(label="ETA", minimum=0.0, maximum=1.0, value=0.0, step=0.1)
                    low_threshold = gr.Slider(label="Canny Low Threshold", minimum=1, maximum=255, value=100, step=1)
                    high_threshold = gr.Slider(label="Canny High Threshold", minimum=1, maximum=255, value=200, step=1)
                submit = gr.Button("Generate")
            with gr.Column():
                output_image = gr.Image()
        # Order must match the positional parameters of `process`.
        ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, guess_mode, strength, scale, seed, eta, low_threshold, high_threshold]
        submit.click(fn=process, inputs=ips, outputs=output_image)
        # `submit_btn=True` renders a submit button on the Prompt box; without
        # this listener that button (and pressing Enter) would do nothing.
        prompt.submit(fn=process, inputs=ips, outputs=output_image)
    return demo

# Keep `demo` at module level so the Gradio CLI (`gradio app.py`) can find it;
# only start the server when this file is executed as a script.
demo = create_demo(process)

if __name__ == "__main__":
    demo.launch()