humblemikey commited on
Commit
9c865fe
Β·
verified Β·
1 Parent(s): 1ea2fbb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -126
app.py CHANGED
@@ -12,7 +12,7 @@ import spaces
12
  import torch
13
  from diffusers import AutoencoderKL, DiffusionPipeline
14
 
15
- DESCRIPTION = "# SDXL"
16
  if not torch.cuda.is_available():
17
  DESCRIPTION += "\n<p>Running on CPU πŸ₯Ά This demo does not work on CPU.</p>"
18
 
@@ -20,11 +20,10 @@ MAX_SEED = np.iinfo(np.int32).max
20
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
21
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
22
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
23
- ENABLE_REFINER = os.getenv("ENABLE_REFINER", "1") == "1"
24
 
25
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26
  if torch.cuda.is_available():
27
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
28
  pipe = DiffusionPipeline.from_pretrained(
29
  "humblemikey/PixelWave10",
30
  #vae=vae,
@@ -32,14 +31,6 @@ if torch.cuda.is_available():
32
  use_safetensors=True,
33
  #variant="fp16",
34
  )
35
- if ENABLE_REFINER:
36
- refiner = DiffusionPipeline.from_pretrained(
37
- "stabilityai/stable-diffusion-xl-refiner-1.0",
38
- vae=vae,
39
- torch_dtype=torch.float16,
40
- use_safetensors=True,
41
- variant="fp16",
42
- )
43
 
44
  if ENABLE_CPU_OFFLOAD:
45
  pipe.enable_model_cpu_offload()
@@ -47,13 +38,9 @@ if torch.cuda.is_available():
47
  refiner.enable_model_cpu_offload()
48
  else:
49
  pipe.to(device)
50
- if ENABLE_REFINER:
51
- refiner.to(device)
52
 
53
  if USE_TORCH_COMPILE:
54
  pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
55
- if ENABLE_REFINER:
56
- refiner.unet = torch.compile(refiner.unet, mode="reduce-overhead", fullgraph=True)
57
 
58
 
59
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
@@ -66,67 +53,31 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
66
  def generate(
67
  prompt: str,
68
  negative_prompt: str = "",
69
- prompt_2: str = "",
70
- negative_prompt_2: str = "",
71
  use_negative_prompt: bool = False,
72
- use_prompt_2: bool = False,
73
- use_negative_prompt_2: bool = False,
74
  seed: int = 0,
75
  width: int = 1024,
76
  height: int = 1024,
77
- guidance_scale_base: float = 5.0,
78
- guidance_scale_refiner: float = 5.0,
79
- num_inference_steps_base: int = 25,
80
- num_inference_steps_refiner: int = 25,
81
- apply_refiner: bool = False,
82
  ) -> PIL.Image.Image:
83
  generator = torch.Generator().manual_seed(seed)
84
 
85
  if not use_negative_prompt:
86
  negative_prompt = None # type: ignore
87
- if not use_prompt_2:
88
- prompt_2 = None # type: ignore
89
- if not use_negative_prompt_2:
90
- negative_prompt_2 = None # type: ignore
91
-
92
- if not apply_refiner:
93
- return pipe(
94
- prompt=prompt,
95
- negative_prompt=negative_prompt,
96
- prompt_2=prompt_2,
97
- negative_prompt_2=negative_prompt_2,
98
- width=width,
99
- height=height,
100
- guidance_scale=guidance_scale_base,
101
- num_inference_steps=num_inference_steps_base,
102
- generator=generator,
103
- output_type="pil",
104
- ).images[0]
105
- else:
106
- latents = pipe(
107
- prompt=prompt,
108
- negative_prompt=negative_prompt,
109
- prompt_2=prompt_2,
110
- negative_prompt_2=negative_prompt_2,
111
- width=width,
112
- height=height,
113
- guidance_scale=guidance_scale_base,
114
- num_inference_steps=num_inference_steps_base,
115
- generator=generator,
116
- output_type="latent",
117
- ).images
118
- image = refiner(
119
- prompt=prompt,
120
- negative_prompt=negative_prompt,
121
- prompt_2=prompt_2,
122
- negative_prompt_2=negative_prompt_2,
123
- guidance_scale=guidance_scale_refiner,
124
- num_inference_steps=num_inference_steps_refiner,
125
- image=latents,
126
- generator=generator,
127
- ).images[0]
128
- return image
129
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  examples = [
132
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
@@ -154,27 +105,12 @@ with gr.Blocks(css="style.css") as demo:
154
  with gr.Accordion("Advanced options", open=False):
155
  with gr.Row():
156
  use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
157
- use_prompt_2 = gr.Checkbox(label="Use prompt 2", value=False)
158
- use_negative_prompt_2 = gr.Checkbox(label="Use negative prompt 2", value=False)
159
  negative_prompt = gr.Text(
160
  label="Negative prompt",
161
  max_lines=1,
162
  placeholder="Enter a negative prompt",
163
  visible=False,
164
  )
165
- prompt_2 = gr.Text(
166
- label="Prompt 2",
167
- max_lines=1,
168
- placeholder="Enter your prompt",
169
- visible=False,
170
- )
171
- negative_prompt_2 = gr.Text(
172
- label="Negative prompt 2",
173
- max_lines=1,
174
- placeholder="Enter a negative prompt",
175
- visible=False,
176
- )
177
-
178
  seed = gr.Slider(
179
  label="Seed",
180
  minimum=0,
@@ -198,7 +134,6 @@ with gr.Blocks(css="style.css") as demo:
198
  step=32,
199
  value=1024,
200
  )
201
- apply_refiner = gr.Checkbox(label="Apply refiner", value=False, visible=ENABLE_REFINER)
202
  with gr.Row():
203
  guidance_scale_base = gr.Slider(
204
  label="Guidance scale for base",
@@ -214,21 +149,6 @@ with gr.Blocks(css="style.css") as demo:
214
  step=1,
215
  value=25,
216
  )
217
- with gr.Row(visible=False) as refiner_params:
218
- guidance_scale_refiner = gr.Slider(
219
- label="Guidance scale for refiner",
220
- minimum=1,
221
- maximum=20,
222
- step=0.1,
223
- value=5.0,
224
- )
225
- num_inference_steps_refiner = gr.Slider(
226
- label="Number of inference steps for refiner",
227
- minimum=10,
228
- maximum=100,
229
- step=1,
230
- value=25,
231
- )
232
 
233
  gr.Examples(
234
  examples=examples,
@@ -244,34 +164,11 @@ with gr.Blocks(css="style.css") as demo:
244
  queue=False,
245
  api_name=False,
246
  )
247
- use_prompt_2.change(
248
- fn=lambda x: gr.update(visible=x),
249
- inputs=use_prompt_2,
250
- outputs=prompt_2,
251
- queue=False,
252
- api_name=False,
253
- )
254
- use_negative_prompt_2.change(
255
- fn=lambda x: gr.update(visible=x),
256
- inputs=use_negative_prompt_2,
257
- outputs=negative_prompt_2,
258
- queue=False,
259
- api_name=False,
260
- )
261
- apply_refiner.change(
262
- fn=lambda x: gr.update(visible=x),
263
- inputs=apply_refiner,
264
- outputs=refiner_params,
265
- queue=False,
266
- api_name=False,
267
- )
268
 
269
  gr.on(
270
  triggers=[
271
  prompt.submit,
272
  negative_prompt.submit,
273
- prompt_2.submit,
274
- negative_prompt_2.submit,
275
  run_button.click,
276
  ],
277
  fn=randomize_seed_fn,
@@ -284,19 +181,13 @@ with gr.Blocks(css="style.css") as demo:
284
  inputs=[
285
  prompt,
286
  negative_prompt,
287
- prompt_2,
288
- negative_prompt_2,
289
  use_negative_prompt,
290
- use_prompt_2,
291
- use_negative_prompt_2,
292
  seed,
293
  width,
294
  height,
295
  guidance_scale_base,
296
  guidance_scale_refiner,
297
  num_inference_steps_base,
298
- num_inference_steps_refiner,
299
- apply_refiner,
300
  ],
301
  outputs=result,
302
  api_name="run",
 
12
  import torch
13
  from diffusers import AutoencoderKL, DiffusionPipeline
14
 
15
+ DESCRIPTION = "humblemikey/PixelWave10"
16
  if not torch.cuda.is_available():
17
  DESCRIPTION += "\n<p>Running on CPU πŸ₯Ά This demo does not work on CPU.</p>"
18
 
 
20
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1024"))
21
  USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
22
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
 
23
 
24
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25
  if torch.cuda.is_available():
26
+ #vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
27
  pipe = DiffusionPipeline.from_pretrained(
28
  "humblemikey/PixelWave10",
29
  #vae=vae,
 
31
  use_safetensors=True,
32
  #variant="fp16",
33
  )
 
 
 
 
 
 
 
 
34
 
35
  if ENABLE_CPU_OFFLOAD:
36
  pipe.enable_model_cpu_offload()
 
38
  refiner.enable_model_cpu_offload()
39
  else:
40
  pipe.to(device)
 
 
41
 
42
  if USE_TORCH_COMPILE:
43
  pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
 
 
44
 
45
 
46
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
 
53
  def generate(
54
  prompt: str,
55
  negative_prompt: str = "",
 
 
56
  use_negative_prompt: bool = False,
 
 
57
  seed: int = 0,
58
  width: int = 1024,
59
  height: int = 1024,
60
+ guidance_scale_base: float = 4.0,
61
+ num_inference_steps_base: int = 40,
 
 
 
62
  ) -> PIL.Image.Image:
63
  generator = torch.Generator().manual_seed(seed)
64
 
65
  if not use_negative_prompt:
66
  negative_prompt = None # type: ignore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
+ return pipe(
69
+ prompt=prompt,
70
+ negative_prompt=negative_prompt,
71
+ prompt_2=prompt_2,
72
+ negative_prompt_2=negative_prompt_2,
73
+ width=width,
74
+ height=height,
75
+ guidance_scale=guidance_scale_base,
76
+ num_inference_steps=num_inference_steps_base,
77
+ generator=generator,
78
+ output_type="pil",
79
+ ).images[0]
80
+ return image
81
 
82
  examples = [
83
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
 
105
  with gr.Accordion("Advanced options", open=False):
106
  with gr.Row():
107
  use_negative_prompt = gr.Checkbox(label="Use negative prompt", value=False)
 
 
108
  negative_prompt = gr.Text(
109
  label="Negative prompt",
110
  max_lines=1,
111
  placeholder="Enter a negative prompt",
112
  visible=False,
113
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  seed = gr.Slider(
115
  label="Seed",
116
  minimum=0,
 
134
  step=32,
135
  value=1024,
136
  )
 
137
  with gr.Row():
138
  guidance_scale_base = gr.Slider(
139
  label="Guidance scale for base",
 
149
  step=1,
150
  value=25,
151
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  gr.Examples(
154
  examples=examples,
 
164
  queue=False,
165
  api_name=False,
166
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  gr.on(
169
  triggers=[
170
  prompt.submit,
171
  negative_prompt.submit,
 
 
172
  run_button.click,
173
  ],
174
  fn=randomize_seed_fn,
 
181
  inputs=[
182
  prompt,
183
  negative_prompt,
 
 
184
  use_negative_prompt,
 
 
185
  seed,
186
  width,
187
  height,
188
  guidance_scale_base,
189
  guidance_scale_refiner,
190
  num_inference_steps_base,
 
 
191
  ],
192
  outputs=result,
193
  api_name="run",