p1atdev commited on
Commit
2f0a1c2
Β·
1 Parent(s): 3954b30

fix: dtype, feat: better ui

Browse files
Files changed (1) hide show
  1. app.py +75 -61
app.py CHANGED
@@ -56,6 +56,7 @@ def infer(
56
  ),
57
  generator=generator,
58
  device=device,
 
59
  )
60
  latents = apply_tkg_noise(
61
  latents,
@@ -103,7 +104,7 @@ def on_generate(
103
  tkg_channels = color_name_to_channels(color_name)
104
  # TODO: custom channels
105
 
106
- return infer(
107
  prompt,
108
  negative_prompt,
109
  seed,
@@ -113,8 +114,12 @@ def on_generate(
113
  guidance_scale,
114
  num_inference_steps,
115
  tkg_channels=tkg_channels,
 
 
116
  )
117
 
 
 
118
  examples = [
119
  # "1girl, arima kana, oshi no ko, hoshimachi suisei, hoshimachi suisei \(1st costume\), cosplay, looking at viewer, smile, outdoors, night, v, masterpiece, high score, great score, absurdres",
120
  "1girl, solo, upper body, looking at viewer, straight-on, masterpiece, best quality",
@@ -134,71 +139,79 @@ TKG-DMπŸ₯šπŸš: Training-free Chroma Key Content Generation Diffusion Model
134
  """)
135
 
136
  with gr.Row():
137
- prompt = gr.Text(
138
- label="Prompt",
139
- show_label=False,
140
- max_lines=1,
141
- placeholder="Enter your prompt",
142
- container=False,
143
- )
144
-
145
- run_button = gr.Button("Run", scale=0, variant="primary")
146
-
147
- with gr.Row():
148
- result_w_tkg = gr.Image(label="Result", show_label=False)
149
- result_wo_tkg = gr.Image(label="Result", show_label=False)
150
-
151
- with gr.Accordion("Advanced Settings", open=False):
152
- negative_prompt = gr.Text(
153
- label="Negative prompt",
154
- max_lines=1,
155
- placeholder="Enter a negative prompt",
156
- value="lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry",
157
- )
158
-
159
- seed = gr.Slider(
160
- label="Seed",
161
- minimum=0,
162
- maximum=MAX_SEED,
163
- step=1,
164
- value=0,
165
- )
166
-
167
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
168
-
169
- with gr.Row():
170
- width = gr.Slider(
171
- label="Width",
172
- minimum=256,
173
- maximum=MAX_IMAGE_SIZE,
174
- step=32,
175
- value=832,
176
  )
177
-
178
- height = gr.Slider(
179
- label="Height",
180
- minimum=256,
181
- maximum=MAX_IMAGE_SIZE,
182
- step=32,
183
- value=1152,
184
  )
185
 
186
- with gr.Row():
187
- guidance_scale = gr.Slider(
188
- label="Guidance scale",
189
- minimum=0.0,
190
- maximum=10.0,
191
- step=0.1,
192
- value=5.0,
193
  )
194
 
195
- num_inference_steps = gr.Slider(
196
- label="Number of inference steps",
197
- minimum=1,
198
- maximum=50,
199
- step=1,
200
- value=25,
201
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  gr.Examples(examples=examples, inputs=[prompt])
204
 
@@ -214,6 +227,7 @@ TKG-DMπŸ₯šπŸš: Training-free Chroma Key Content Generation Diffusion Model
214
  height,
215
  guidance_scale,
216
  num_inference_steps,
 
217
  ],
218
  outputs=[result_w_tkg, result_wo_tkg, seed],
219
  )
 
56
  ),
57
  generator=generator,
58
  device=device,
59
+ dtype=torch.bfloat16,
60
  )
61
  latents = apply_tkg_noise(
62
  latents,
 
104
  tkg_channels = color_name_to_channels(color_name)
105
  # TODO: custom channels
106
 
107
+ w_tkg, wo_tkg, seed = infer(
108
  prompt,
109
  negative_prompt,
110
  seed,
 
114
  guidance_scale,
115
  num_inference_steps,
116
  tkg_channels=tkg_channels,
117
+ *args,
118
+ **kwargs,
119
  )
120
 
121
+ return w_tkg, wo_tkg, seed
122
+
123
  examples = [
124
  # "1girl, arima kana, oshi no ko, hoshimachi suisei, hoshimachi suisei \(1st costume\), cosplay, looking at viewer, smile, outdoors, night, v, masterpiece, high score, great score, absurdres",
125
  "1girl, solo, upper body, looking at viewer, straight-on, masterpiece, best quality",
 
139
  """)
140
 
141
  with gr.Row():
142
+ with gr.Column():
143
+ prompt = gr.Text(
144
+ label="Prompt",
145
+ show_label=False,
146
+ max_lines=1,
147
+ placeholder="Enter your prompt",
148
+ container=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  )
150
+ negative_prompt = gr.Textbox(
151
+ label="Negative prompt",
152
+ max_lines=4,
153
+ placeholder="Enter a negative prompt",
154
+ value="lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry",
 
 
155
  )
156
 
157
+ color_set = gr.Dropdown(
158
+ label="Chroma key color",
159
+ choices=list(COLOR_SET_MAP.keys()),
160
+ value="green",
 
 
 
161
  )
162
 
163
+
164
+ with gr.Accordion("Advanced Settings", open=False):
165
+ seed = gr.Slider(
166
+ label="Seed",
167
+ minimum=0,
168
+ maximum=MAX_SEED,
169
+ step=1,
170
+ value=0,
171
+ )
172
+
173
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
174
+
175
+ with gr.Row():
176
+ width = gr.Slider(
177
+ label="Width",
178
+ minimum=256,
179
+ maximum=MAX_IMAGE_SIZE,
180
+ step=32,
181
+ value=832,
182
+ )
183
+
184
+ height = gr.Slider(
185
+ label="Height",
186
+ minimum=256,
187
+ maximum=MAX_IMAGE_SIZE,
188
+ step=32,
189
+ value=1152,
190
+ )
191
+
192
+ with gr.Row():
193
+ guidance_scale = gr.Slider(
194
+ label="Guidance scale",
195
+ minimum=0.0,
196
+ maximum=10.0,
197
+ step=0.1,
198
+ value=5.0,
199
+ )
200
+
201
+ num_inference_steps = gr.Slider(
202
+ label="Number of inference steps",
203
+ minimum=1,
204
+ maximum=50,
205
+ step=1,
206
+ value=25,
207
+ )
208
+
209
+ with gr.Column():
210
+ run_button = gr.Button("Run", scale=0, variant="primary")
211
+ result_w_tkg = gr.Image(label="Result", show_label=False)
212
+ result_wo_tkg = gr.Image(label="Result", show_label=False)
213
+
214
+
215
 
216
  gr.Examples(examples=examples, inputs=[prompt])
217
 
 
227
  height,
228
  guidance_scale,
229
  num_inference_steps,
230
+ color_set,
231
  ],
232
  outputs=[result_w_tkg, result_wo_tkg, seed],
233
  )