Update app.py
app.py
CHANGED
@@ -69,7 +69,7 @@ def get_modelscope_pipeline(
     lora.merge_and_unload()
     pipe.unet = lora
 
-    pipe = pipe.to(device)
+    # pipe = pipe.to(device)
 
     return pipe
 
@@ -136,22 +136,48 @@ def get_animatediff_pipeline(
     lora.merge_and_unload()
     pipe.unet = lora
 
-    pipe = pipe.to(device)
+    # pipe = pipe.to(device)
     return pipe
 
 
-# pipe_dict = {
-#     "ModelScope T2V": {"WebVid": None, "LAION-aes": None, "Anime": None, "Realistic": None, "3D Cartoon": None},
-#     "AnimateDiff (SD1.5)": {"WebVid": None, "LAION-aes": None},
-#     "AnimateDiff (RealisticVision)": {"WebVid": None, "LAION-aes": None},
-#     "AnimateDiff (epiCRealism)": {"WebVid": None, "LAION-aes": None},
-# }
-cache_pipeline = {
-    "base_model": None,
-    "variant": None,
-    "pipeline": None,
-}
+pipe_dict = {
+    "ModelScope T2V": {"WebVid": None, "LAION-aes": None, "Anime": None, "Realistic": None, "3D Cartoon": None},
+    "AnimateDiff (SD1.5)": {"WebVid": None, "LAION-aes": None},
+    "AnimateDiff (RealisticVision)": {"WebVid": None, "LAION-aes": None},
+    "AnimateDiff (epiCRealism)": {"WebVid": None, "LAION-aes": None},
+}
+# cache_pipeline = {
+#     "base_model": None,
+#     "variant": None,
+#     "pipeline": None,
+# }
 
+def init_pipelines():
+    for base_model in variants.keys():
+        for variant in variants[base_model]:
+            if pipe_dict[base_model][variant] is None:
+                if base_model == "ModelScope T2V":
+                    pipe_dict[base_model][variant] = get_modelscope_pipeline(mcm_variant=variant)
+                elif base_model == "AnimateDiff (SD1.5)":
+                    pipe_dict[base_model][variant] = get_animatediff_pipeline(
+                        real_variant=None,
+                        motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
+                        mcm_variant=variant,
+                    )
+                elif base_model == "AnimateDiff (RealisticVision)":
+                    pipe_dict[base_model][variant] = get_animatediff_pipeline(
+                        real_variant="realvision",
+                        motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
+                        mcm_variant=variant,
+                    )
+                elif base_model == "AnimateDiff (epiCRealism)":
+                    pipe_dict[base_model][variant] = get_animatediff_pipeline(
+                        real_variant="epicrealism",
+                        motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
+                        mcm_variant=variant,
+                    )
+                else:
+                    raise ValueError(f"Unknown base_model {base_model}")
 
 @spaces.GPU(duration=120)
 def infer(
@@ -180,45 +206,46 @@ def infer(
     # )
     # else:
     #     raise ValueError(f"Unknown base_model {base_model}")
-    if (
-        cache_pipeline["base_model"] == base_model
-        and cache_pipeline["variant"] == variant
-    ):
-        pass
-    else:
-        if base_model == "ModelScope T2V":
-            pipeline = get_modelscope_pipeline(mcm_variant=variant)
-        elif base_model == "AnimateDiff (SD1.5)":
-            pipeline = get_animatediff_pipeline(
-                real_variant=None,
-                motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
-                mcm_variant=variant,
-            )
-        elif base_model == "AnimateDiff (RealisticVision)":
-            pipeline = get_animatediff_pipeline(
-                real_variant="realvision",
-                motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
-                mcm_variant=variant,
-            )
-        elif base_model == "AnimateDiff (epiCRealism)":
-            pipeline = get_animatediff_pipeline(
-                real_variant="epicrealism",
-                motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
-                mcm_variant=variant,
-            )
-        else:
-            raise ValueError(f"Unknown base_model {base_model}")
+    # if (
+    #     cache_pipeline["base_model"] == base_model
+    #     and cache_pipeline["variant"] == variant
+    # ):
+    #     pass
+    # else:
+    #     if base_model == "ModelScope T2V":
+    #         pipeline = get_modelscope_pipeline(mcm_variant=variant)
+    #     elif base_model == "AnimateDiff (SD1.5)":
+    #         pipeline = get_animatediff_pipeline(
+    #             real_variant=None,
+    #             motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
+    #             mcm_variant=variant,
+    #         )
+    #     elif base_model == "AnimateDiff (RealisticVision)":
+    #         pipeline = get_animatediff_pipeline(
+    #             real_variant="realvision",
+    #             motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
+    #             mcm_variant=variant,
+    #         )
+    #     elif base_model == "AnimateDiff (epiCRealism)":
+    #         pipeline = get_animatediff_pipeline(
+    #             real_variant="epicrealism",
+    #             motion_module_path="guoyww/animatediff-motion-adapter-v1-5-2",
+    #             mcm_variant=variant,
+    #         )
+    #     else:
+    #         raise ValueError(f"Unknown base_model {base_model}")
 
-    cache_pipeline["base_model"] = base_model
-    cache_pipeline["variant"] = variant
-    cache_pipeline["pipeline"] = pipeline
+    # cache_pipeline["base_model"] = base_model
+    # cache_pipeline["variant"] = variant
+    # cache_pipeline["pipeline"] = pipeline
 
+    pipe_dict[base_model][variant] = pipe_dict[base_model][variant].to(device)
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
     generator = torch.Generator("cpu").manual_seed(seed)
 
-    output = cache_pipeline["pipeline"](
+    output = pipe_dict[base_model][variant](
         prompt=prompt,
         num_frames=16,
         guidance_scale=1.0,
@@ -238,7 +265,8 @@ def infer(
         fps=7,
     )
     print(f"Saved to {save_path}")
-    return save_path
+    pipe_dict[base_model][variant] = pipe_dict[base_model][variant].to("cpu")
+    return save_path, seed
 
 
 examples = [
@@ -402,7 +430,7 @@ with gr.Blocks(css=css) as demo:
         inputs=[base_model, variant_dropdown, prompt, num_inference_steps],
         cache_examples=True,
         fn=infer,
-        outputs=[result],
+        outputs=[result, seed],
     )
 
     run_button.click(
@@ -415,7 +443,7 @@ with gr.Blocks(css=css) as demo:
             seed,
             randomize_seed,
         ],
-        outputs=[result],
+        outputs=[result, seed],
     )
 
 demo.queue().launch()
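Note on the pattern (a sketch, not part of the commit): the change replaces per-request pipeline construction with a registry (pipe_dict) that init_pipelines() fills once at startup, keeping every pipeline in CPU RAM; infer() then moves only the requested pipeline onto the GPU for the duration of the @spaces.GPU call and parks it back on the CPU afterwards. Below is a minimal, self-contained sketch of that CPU/GPU hot-swap idiom; build_pipeline(), the registry keys, and the toy torch modules are illustrative stand-ins for the app's get_modelscope_pipeline/get_animatediff_pipeline helpers.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Registry of lazily built models, analogous to pipe_dict above.
registry = {"model_a": None, "model_b": None}

def build_pipeline(name: str) -> torch.nn.Module:
    # Stand-in for an expensive from_pretrained(...) load; built on CPU.
    return torch.nn.Linear(8, 8)

def init_registry() -> None:
    # Analogous to init_pipelines(): build everything once, keep on CPU.
    for name in registry:
        if registry[name] is None:
            registry[name] = build_pipeline(name)

def infer(name: str, x: torch.Tensor) -> torch.Tensor:
    # Borrow the GPU only for this request (the app wraps the real
    # version of this function in @spaces.GPU(duration=120)).
    pipe = registry[name].to(device)
    try:
        with torch.no_grad():
            return pipe(x.to(device)).cpu()
    finally:
        # Park the weights back in CPU RAM so the GPU slot is free
        # for whichever model the next request selects.
        registry[name] = pipe.to("cpu")

init_registry()
print(infer("model_a", torch.randn(2, 8)))

The trade-off is per-request host-to-device transfer latency in exchange for serving many base-model/variant combinations from a single ZeroGPU slot. The commit also returns seed alongside save_path (with outputs=[result, seed]) so the UI shows the seed actually used when randomize_seed is checked.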