humblemikey committed on
Commit 54bbeca · verified · 1 Parent(s): b58b727

Update app.py

Files changed (1): app.py +22 -0
app.py CHANGED
@@ -11,6 +11,13 @@ import PIL.Image
 import spaces
 import torch
 from diffusers import AutoencoderKL, StableDiffusionXLPipeline
+from diffusers import (
+    DDIMScheduler,
+    DPMSolverMultistepScheduler,
+    DPMSolverSinglestepScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+)
 
 DESCRIPTION = "# humblemikey/PixelWave10"
 if not torch.cuda.is_available():
@@ -46,6 +53,16 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
         seed = random.randint(0, MAX_SEED)
     return seed
 
+def get_scheduler(scheduler_config: Dict, name: str) -> Optional[Callable]:
+    scheduler_factory_map = {
+        "DPM++ 2M Karras": lambda: DPMSolverMultistepScheduler.from_config(scheduler_config, use_karras_sigmas=True),
+        "DPM++ SDE Karras": lambda: DPMSolverSinglestepScheduler.from_config(scheduler_config, use_karras_sigmas=True),
+        "DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(scheduler_config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"),
+        "Euler": lambda: EulerDiscreteScheduler.from_config(scheduler_config),
+        "Euler a": lambda: EulerAncestralDiscreteScheduler.from_config(scheduler_config),
+        "DDIM": lambda: DDIMScheduler.from_config(scheduler_config),
+    }
+    return scheduler_factory_map.get(name, lambda: None)()
 
 @spaces.GPU
 def generate(
@@ -57,9 +74,14 @@ def generate(
     height: int = 1024,
     guidance_scale_base: float = 4.0,
     num_inference_steps_base: int = 40,
+    sampler: str = "DPM++ 2M SDE Karras",
+    progress=gr.Progress(track_tqdm=True)
 ) -> PIL.Image.Image:
     generator = torch.Generator().manual_seed(seed)
 
+    #backup_scheduler = pipe.scheduler
+    pipe.scheduler = get_scheduler(pipe.scheduler.config, sampler)
+
     if not use_negative_prompt:
         negative_prompt = None  # type: ignore
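The change adds a sampler selector: get_scheduler rebuilds a diffusers scheduler from the pipeline's existing scheduler config, and generate swaps pipe.scheduler before inference. Below is a minimal sketch of the same from_config swap pattern, assuming an already-loaded StableDiffusionXLPipeline named pipe; the helper name use_sampler and the shortened sampler list are illustrative and not part of this commit.

from diffusers import (
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
    StableDiffusionXLPipeline,
)

def use_sampler(pipe: StableDiffusionXLPipeline, name: str) -> None:
    # Rebuild a scheduler from the pipeline's current config and swap it in,
    # mirroring the commit's get_scheduler(pipe.scheduler.config, sampler) call.
    factories = {
        "DPM++ 2M SDE Karras": lambda: DPMSolverMultistepScheduler.from_config(
            pipe.scheduler.config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
        ),
        "Euler": lambda: EulerDiscreteScheduler.from_config(pipe.scheduler.config),
    }
    new_scheduler = factories.get(name, lambda: None)()
    if new_scheduler is not None:
        # Defensive choice in this sketch: an unrecognized name leaves the
        # pipeline's current scheduler untouched instead of assigning None.
        pipe.scheduler = new_scheduler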