ahmdliaqat committed on
Commit 7611753 · 1 Parent(s): 2cb8cb6
Files changed (1)
  1. app.py +37 -29
app.py CHANGED
@@ -35,42 +35,50 @@ from transformers import CLIPFeatureExtractor
 feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
 
 # Function
-@spaces.CPU(duration=30,queue=False)
 def generate_image(prompt, base="Realistic", motion="", step=8, progress=gr.Progress()):
     global step_loaded
     global base_loaded
     global motion_loaded
     print(prompt, base, step)
 
-    if step_loaded != step:
-        repo = "ByteDance/AnimateDiff-Lightning"
-        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
-        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
-        step_loaded = step
-
-    if base_loaded != base:
-        pipe.unet.load_state_dict(torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device), strict=False)
-        base_loaded = base
-
-    if motion_loaded != motion:
-        pipe.unload_lora_weights()
-        if motion != "":
-            pipe.load_lora_weights(motion, adapter_name="motion")
-            pipe.set_adapters(["motion"], [0.7])
-        motion_loaded = motion
-
-    progress((0, step))
-    def progress_callback(i, t, z):
-        progress((i+1, step))
-
-    output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step, callback=progress_callback, callback_steps=1)
-
-    name = str(uuid.uuid4()).replace("-", "")
-    path = f"/tmp/{name}.mp4"
-    export_to_video(output.frames[0], path, fps=10)
-    return path
-
+    try:
+        if step_loaded != step:
+            repo = "ByteDance/AnimateDiff-Lightning"
+            ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
+            pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
+            step_loaded = step
+
+        if base_loaded != base:
+            pipe.unet.load_state_dict(torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device), strict=False)
+            base_loaded = base
+
+        if motion_loaded != motion:
+            pipe.unload_lora_weights()
+            if motion != "":
+                pipe.load_lora_weights(motion, adapter_name="motion")
+                pipe.set_adapters(["motion"], [0.7])
+            motion_loaded = motion
+
+        progress((0, step))
+        def progress_callback(i, t, z):
+            progress((i+1, step))
+
+        output = pipe(
+            prompt=prompt,
+            guidance_scale=1.2,
+            num_inference_steps=step,
+            callback=progress_callback,
+            callback_steps=1
+        )
 
+        name = str(uuid.uuid4()).replace("-", "")
+        path = f"/tmp/{name}.mp4"
+        export_to_video(output.frames[0], path, fps=10)
+        return path
+
+    except Exception as e:
+        print(f"Error during generation: {str(e)}")
+        return None
 # Gradio Interface
 with gr.Blocks(css="style.css") as demo:
     gr.HTML(
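
Note: with this change, generate_image returns None instead of raising when generation fails, so whatever output component the Blocks UI binds it to would simply receive None. Below is a minimal sketch, not part of this commit, of how a caller could surface that failure explicitly; the wrapper function and widget names (generate_or_warn, prompt_box, video_out, go_btn) are hypothetical, and it assumes the generate_image defined above is in scope.

import gradio as gr

def generate_or_warn(prompt):
    # generate_image is the function patched above; it now returns None on failure
    path = generate_image(prompt)
    if path is None:
        # Raise gr.Error so the UI shows an error message instead of an empty video
        raise gr.Error("Generation failed; see the Space logs for details.")
    return path

with gr.Blocks() as sketch:
    prompt_box = gr.Textbox(label="Prompt")   # hypothetical widget names
    video_out = gr.Video(label="Result")
    go_btn = gr.Button("Generate")
    go_btn.click(generate_or_warn, inputs=prompt_box, outputs=video_out)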