svjack committed
Commit 7591849 · 1 Parent(s): 52d3b71

Create app.py

Files changed (1):
  1. app.py +88 -0
app.py ADDED
@@ -0,0 +1,88 @@
+ import os
+ import sys
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ import random
+
+ from diffusers import AutoPipelineForText2Image
+ from diffusers.pipelines.wuerstchen.pipeline_wuerstchen_prior import DEFAULT_STAGE_C_TIMESTEPS
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ # float32 keeps CPU inference working; on a CUDA GPU, torch.float16 is the usual choice.
+ pipe = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen",
+                                                  torch_dtype=torch.float32)
+ pipe.to(device)
+
+ # Disable the safety checker (a no-op if the pipeline does not carry one).
+ pipe.safety_checker = None
+
+ '''
+ #### roughly 9 min per sample on 2 CPU cores
+ caption = "Anthropomorphic cat dressed as a fire fighter"
+ images = pipe(
+     caption,
+     width=512,
+     height=512,
+     prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,  #### a schedule of 30 timesteps
+     prior_guidance_scale=4.0,
+     num_images_per_prompt=1,
+     num_inference_steps=6,  #### default is 12; 6 is a good speed/quality trade-off
+ ).images
+ '''
+
+
+ def process(prompt, num_samples, image_resolution, sample_steps, seed):
+     """Generate num_samples images for the prompt and return them as numpy arrays."""
+     from PIL import Image
+     with torch.no_grad():
+         if seed == -1:
+             seed = random.randint(0, 65535)
+         #control_image = Image.fromarray(detected_map)
+
+         # run inference with a seeded generator so the Seed slider is honoured
+         generator = torch.Generator(device=device).manual_seed(seed)
+         H = image_resolution
+         W = image_resolution
+         images = []
+         for i in range(num_samples):
+             image = pipe(
+                 prompt,
+                 prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
+                 prior_guidance_scale=4.0,
+                 num_inference_steps=sample_steps,
+                 num_images_per_prompt=1,
+                 generator=generator,
+                 height=H, width=W).images[0]
+             images.append(np.asarray(image))
+
+         results = images
+     return results
+     #return [255 - detected_map] + results
+
+
+ block = gr.Blocks().queue()
+ with block:
+     with gr.Row():
+         gr.Markdown("## Rapid Diffusion model from warp-ai/wuerstchen")
+         #gr.Markdown("This _example_ was **derived** from <br/><b><h4>[https://github.com/svjack/ControlLoRA-Chinese](https://github.com/svjack/ControlLoRA-Chinese)</h4></b>\n")
+     with gr.Row():
+         with gr.Column():
+             #input_image = gr.Image(source='upload', type="numpy", value="hate_dog.png")
+             prompt = gr.Textbox(label="Prompt", value="Anthropomorphic cat dressed as a fire fighter")
+             run_button = gr.Button(value="Run")
+             with gr.Accordion("Advanced options", open=False):
+                 num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+                 image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
+                 #low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
+                 #high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
+                 sample_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=6, step=1)
+                 #scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+                 seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+                 #eta = gr.Number(label="eta", value=0.0)
+                 #a_prompt = gr.Textbox(label="Added Prompt", value='')
+                 #n_prompt = gr.Textbox(label="Negative Prompt",
+                 #                      value='低质量,模糊,混乱')  # "low quality, blurry, messy"
+         with gr.Column():
+             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+     #ips = [None, prompt, None, None, num_samples, image_resolution, sample_steps, None, seed, None, None, None]
+     ips = [prompt, num_samples, image_resolution, sample_steps, seed]
+     run_button.click(fn=process, inputs=ips, outputs=[result_gallery], show_progress=True)
+
+
+ block.launch(server_name='0.0.0.0')
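
For a quick smoke test without the Gradio UI, the `process` helper above can also be called directly. This is a minimal sketch, assuming it runs in the same interpreter session as the code above (before the blocking `block.launch()` call); the output filenames are purely illustrative:

    from PIL import Image

    results = process(
        prompt="Anthropomorphic cat dressed as a fire fighter",
        num_samples=1,
        image_resolution=512,
        sample_steps=6,   # matches the app's default Steps value
        seed=42,          # any fixed value; -1 makes process() pick a random seed
    )
    # process() returns numpy arrays, so convert back to PIL to save them
    for idx, arr in enumerate(results):
        Image.fromarray(arr).save(f"wuerstchen_sample_{idx}.png")  # illustrative filenames

On CPU this takes on the order of minutes per image (the commit's own note estimates roughly 9 minutes per sample on 2 cores), so `num_samples=1` is the sensible starting point.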