ddosxd commited on
Commit
6ece283
·
verified ·
1 Parent(s): 271487e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -12
app.py CHANGED
@@ -20,11 +20,11 @@ if not torch.cuda.is_available():
20
  DESCRIPTION += "\n<p>Running on CPU 🥶</p>"
21
 
22
  MAX_SEED = np.iinfo(np.int32).max
23
- CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") != "0"
24
- MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
25
  USE_TORCH_COMPILE = False
26
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
27
- PREVIEW_IMAGES = True
28
 
29
  dtype = torch.bfloat16
30
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -47,10 +47,12 @@ if torch.cuda.is_available():
47
  previewer = Previewer()
48
  previewer_state_dict = torch.load("previewer/previewer_v1_100k.pt", map_location=torch.device('cpu'))["state_dict"]
49
  previewer.load_state_dict(previewer_state_dict)
50
- def callback_prior(i, t, latents):
 
51
  output = previewer(latents)
52
  output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
53
- return output
 
54
  callback_steps = 1
55
  else:
56
  previewer = None
@@ -62,6 +64,7 @@ else:
62
 
63
 
64
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
 
65
  if randomize_seed:
66
  seed = random.randint(0, MAX_SEED)
67
  return seed
@@ -82,7 +85,8 @@ def generate(
82
  num_images_per_prompt: int = 2,
83
  profile: gr.OAuthProfile | None = None,
84
  ) -> PIL.Image.Image:
85
- previewer.eval().requires_grad_(False).to(device).to(dtype)
 
86
  prior_pipeline.to(device)
87
  decoder_pipeline.to(device)
88
 
@@ -98,10 +102,9 @@ def generate(
98
  guidance_scale=prior_guidance_scale,
99
  num_images_per_prompt=num_images_per_prompt,
100
  generator=generator,
101
- callback=callback_prior,
102
- callback_steps=callback_steps
103
  )
104
-
105
  if PREVIEW_IMAGES:
106
  for _ in range(len(DEFAULT_STAGE_C_TIMESTEPS)):
107
  r = next(prior_output)
@@ -119,7 +122,7 @@ def generate(
119
  generator=generator,
120
  output_type="pil",
121
  ).images
122
-
123
  #Save images
124
  for image in decoder_output:
125
  user_history.save_image(
@@ -137,14 +140,14 @@ def generate(
137
  "num_images_per_prompt": num_images_per_prompt,
138
  },
139
  )
140
-
141
  yield decoder_output[0]
142
 
143
 
144
  examples = [
145
  "An astronaut riding a green horse",
146
  "A mecha robot in a favela by Tarsila do Amaral",
147
- "The sprirt of a Tamagotchi wandering in the city of Los Angeles",
148
  "A delicious feijoada ramen dish"
149
  ]
150
 
@@ -186,12 +189,14 @@ with gr.Blocks() as demo:
186
  label="Width",
187
  minimum=1024,
188
  maximum=MAX_IMAGE_SIZE,
 
189
  value=1024,
190
  )
191
  height = gr.Slider(
192
  label="Height",
193
  minimum=1024,
194
  maximum=MAX_IMAGE_SIZE,
 
195
  value=1024,
196
  )
197
  num_images_per_prompt = gr.Slider(
 
20
  DESCRIPTION += "\n<p>Running on CPU 🥶</p>"
21
 
22
  MAX_SEED = np.iinfo(np.int32).max
23
+ CACHE_EXAMPLES = False #torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") != "0"
24
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
25
  USE_TORCH_COMPILE = False
26
  ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
27
+ PREVIEW_IMAGES = False
28
 
29
  dtype = torch.bfloat16
30
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
47
  previewer = Previewer()
48
  previewer_state_dict = torch.load("previewer/previewer_v1_100k.pt", map_location=torch.device('cpu'))["state_dict"]
49
  previewer.load_state_dict(previewer_state_dict)
50
+ def callback_prior(pipeline, step_index, t, callback_kwargs):
51
+ latents = callback_kwargs["latents"]
52
  output = previewer(latents)
53
  output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
54
+ callback_kwargs["preview_output"] = output
55
+ return callback_kwargs
56
  callback_steps = 1
57
  else:
58
  previewer = None
 
64
 
65
 
66
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
67
+ print("randomizing seed")
68
  if randomize_seed:
69
  seed = random.randint(0, MAX_SEED)
70
  return seed
 
85
  num_images_per_prompt: int = 2,
86
  profile: gr.OAuthProfile | None = None,
87
  ) -> PIL.Image.Image:
88
+
89
+ #previewer.eval().requires_grad_(False).to(device).to(dtype)
90
  prior_pipeline.to(device)
91
  decoder_pipeline.to(device)
92
 
 
102
  guidance_scale=prior_guidance_scale,
103
  num_images_per_prompt=num_images_per_prompt,
104
  generator=generator,
105
+ #callback_on_step_end=callback_prior,
106
+ #callback_on_step_end_tensor_inputs=['latents']
107
  )
 
108
  if PREVIEW_IMAGES:
109
  for _ in range(len(DEFAULT_STAGE_C_TIMESTEPS)):
110
  r = next(prior_output)
 
122
  generator=generator,
123
  output_type="pil",
124
  ).images
125
+ print(decoder_output)
126
  #Save images
127
  for image in decoder_output:
128
  user_history.save_image(
 
140
  "num_images_per_prompt": num_images_per_prompt,
141
  },
142
  )
143
+
144
  yield decoder_output[0]
145
 
146
 
147
  examples = [
148
  "An astronaut riding a green horse",
149
  "A mecha robot in a favela by Tarsila do Amaral",
150
+ "The spirit of a Tamagotchi wandering in the city of Los Angeles",
151
  "A delicious feijoada ramen dish"
152
  ]
153
 
 
189
  label="Width",
190
  minimum=1024,
191
  maximum=MAX_IMAGE_SIZE,
192
+ step=512,
193
  value=1024,
194
  )
195
  height = gr.Slider(
196
  label="Height",
197
  minimum=1024,
198
  maximum=MAX_IMAGE_SIZE,
199
+ step=512,
200
  value=1024,
201
  )
202
  num_images_per_prompt = gr.Slider(