yhzhai commited on
Commit
51e733e
·
verified ·
1 Parent(s): 986f833

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -47
app.py CHANGED
@@ -69,7 +69,7 @@ def get_modelscope_pipeline(
69
  lora.merge_and_unload()
70
  pipe.unet = lora
71
 
72
- pipe = pipe.to(device)
73
 
74
  return pipe
75
 
@@ -136,22 +136,48 @@ def get_animatediff_pipeline(
136
  lora.merge_and_unload()
137
  pipe.unet = lora
138
 
139
- pipe = pipe.to(device)
140
  return pipe
141
 
142
 
143
- # pipe_dict = {
144
- # "ModelScope T2V": {"WebVid": None, "LAION-aes": None, "Anime": None, "Realistic": None, "3D Cartoon": None},
145
- # "AnimateDiff (SD1.5)": {"WebVid": None, "LAION-aes": None},
146
- # "AnimateDiff (RealisticVision)": {"WebVid": None, "LAION-aes": None},
147
- # "AnimateDiff (epiCRealism)": {"WebVid": None, "LAION-aes": None},
148
- # }
149
- cache_pipeline = {
150
- "base_model": None,
151
- "variant": None,
152
- "pipeline": None,
153
  }
 
 
 
 
 
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  @spaces.GPU(duration=120)
157
  def infer(
@@ -180,45 +206,46 @@ def infer(
180
  # )
181
  # else:
182
  # raise ValueError(f"Unknown base_model {base_model}")
183
- if (
184
- cache_pipeline["base_model"] == base_model
185
- and cache_pipeline["variant"] == variant
186
- ):
187
- pass
188
- else:
189
- if base_model == "ModelScope T2V":
190
- pipeline = get_modelscope_pipeline(mcm_variant=variant)
191
- elif base_model == "AnimateDiff (SD1.5)":
192
- pipeline = get_animatediff_pipeline(
193
- real_variant=None,
194
- motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
195
- mcm_variant=variant,
196
- )
197
- elif base_model == "AnimateDiff (RealisticVision)":
198
- pipeline = get_animatediff_pipeline(
199
- real_variant="realvision",
200
- motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
201
- mcm_variant=variant,
202
- )
203
- elif base_model == "AnimateDiff (epiCRealism)":
204
- pipeline = get_animatediff_pipeline(
205
- real_variant="epicrealism",
206
- motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
207
- mcm_variant=variant,
208
- )
209
- else:
210
- raise ValueError(f"Unknown base_model {base_model}")
211
 
212
- cache_pipeline["base_model"] = base_model
213
- cache_pipeline["variant"] = variant
214
- cache_pipeline["pipeline"] = pipeline
215
 
 
216
  if randomize_seed:
217
  seed = random.randint(0, MAX_SEED)
218
 
219
  generator = torch.Generator("cpu").manual_seed(seed)
220
 
221
- output = cache_pipeline["pipeline"](
222
  prompt=prompt,
223
  num_frames=16,
224
  guidance_scale=1.0,
@@ -238,7 +265,8 @@ def infer(
238
  fps=7,
239
  )
240
  print(f"Saved to {save_path}")
241
- return save_path
 
242
 
243
 
244
  examples = [
@@ -402,7 +430,7 @@ with gr.Blocks(css=css) as demo:
402
  inputs=[base_model, variant_dropdown, prompt, num_inference_steps],
403
  cache_examples=True,
404
  fn=infer,
405
- outputs=[result],
406
  )
407
 
408
  run_button.click(
@@ -415,7 +443,7 @@ with gr.Blocks(css=css) as demo:
415
  seed,
416
  randomize_seed,
417
  ],
418
- outputs=[result],
419
  )
420
 
421
  demo.queue().launch()
 
69
  lora.merge_and_unload()
70
  pipe.unet = lora
71
 
72
+ # pipe = pipe.to(device)
73
 
74
  return pipe
75
 
 
136
  lora.merge_and_unload()
137
  pipe.unet = lora
138
 
139
+ # pipe = pipe.to(device)
140
  return pipe
141
 
142
 
143
+ pipe_dict = {
144
+ "ModelScope T2V": {"WebVid": None, "LAION-aes": None, "Anime": None, "Realistic": None, "3D Cartoon": None},
145
+ "AnimateDiff (SD1.5)": {"WebVid": None, "LAION-aes": None},
146
+ "AnimateDiff (RealisticVision)": {"WebVid": None, "LAION-aes": None},
147
+ "AnimateDiff (epiCRealism)": {"WebVid": None, "LAION-aes": None},
 
 
 
 
 
148
  }
149
+ # cache_pipeline = {
150
+ # "base_model": None,
151
+ # "variant": None,
152
+ # "pipeline": None,
153
+ # }
154
 
155
+ def init_pipelines():
156
+ for base_model in variants.keys():
157
+ for variant in variants[base_model]:
158
+ if pipe_dict[base_model][variant] is None:
159
+ if base_model == "ModelScope T2V":
160
+ pipe_dict[base_model][variant] = get_modelscope_pipeline(mcm_variant=variant)
161
+ elif base_model == "AnimateDiff (SD1.5)":
162
+ pipe_dict[base_model][variant] = get_animatediff_pipeline(
163
+ real_variant=None,
164
+ motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
165
+ mcm_variant=variant,
166
+ )
167
+ elif base_model == "AnimateDiff (RealisticVision)":
168
+ pipe_dict[base_model][variant] = get_animatediff_pipeline(
169
+ real_variant="realvision",
170
+ motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
171
+ mcm_variant=variant,
172
+ )
173
+ elif base_model == "AnimateDiff (epiCRealism)":
174
+ pipe_dict[base_model][variant] = get_animatediff_pipeline(
175
+ real_variant="epicrealism",
176
+ motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
177
+ mcm_variant=variant,
178
+ )
179
+ else:
180
+ raise ValueError(f"Unknown base_model {base_model}")
181
 
182
  @spaces.GPU(duration=120)
183
  def infer(
 
206
  # )
207
  # else:
208
  # raise ValueError(f"Unknown base_model {base_model}")
209
+ # if (
210
+ # cache_pipeline["base_model"] == base_model
211
+ # and cache_pipeline["variant"] == variant
212
+ # ):
213
+ # pass
214
+ # else:
215
+ # if base_model == "ModelScope T2V":
216
+ # pipeline = get_modelscope_pipeline(mcm_variant=variant)
217
+ # elif base_model == "AnimateDiff (SD1.5)":
218
+ # pipeline = get_animatediff_pipeline(
219
+ # real_variant=None,
220
+ # motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
221
+ # mcm_variant=variant,
222
+ # )
223
+ # elif base_model == "AnimateDiff (RealisticVision)":
224
+ # pipeline = get_animatediff_pipeline(
225
+ # real_variant="realvision",
226
+ # motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
227
+ # mcm_variant=variant,
228
+ # )
229
+ # elif base_model == "AnimateDiff (epiCRealism)":
230
+ # pipeline = get_animatediff_pipeline(
231
+ # real_variant="epicrealism",
232
+ # motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
233
+ # mcm_variant=variant,
234
+ # )
235
+ # else:
236
+ # raise ValueError(f"Unknown base_model {base_model}")
237
 
238
+ # cache_pipeline["base_model"] = base_model
239
+ # cache_pipeline["variant"] = variant
240
+ # cache_pipeline["pipeline"] = pipeline
241
 
242
+ pipe_dict[base_model][variant] = pipe_dict[base_model][variant].to(device)
243
  if randomize_seed:
244
  seed = random.randint(0, MAX_SEED)
245
 
246
  generator = torch.Generator("cpu").manual_seed(seed)
247
 
248
+ output = pipe_dict[base_model][variant](
249
  prompt=prompt,
250
  num_frames=16,
251
  guidance_scale=1.0,
 
265
  fps=7,
266
  )
267
  print(f"Saved to {save_path}")
268
+ pipe_dict[base_model][variant] = pipe_dict[base_model][variant].to("cpu")
269
+ return save_path, seed
270
 
271
 
272
  examples = [
 
430
  inputs=[base_model, variant_dropdown, prompt, num_inference_steps],
431
  cache_examples=True,
432
  fn=infer,
433
+ outputs=[result, seed],
434
  )
435
 
436
  run_button.click(
 
443
  seed,
444
  randomize_seed,
445
  ],
446
+ outputs=[result, seed],
447
  )
448
 
449
  demo.queue().launch()