YaohuiW commited on
Commit
cbe9c24
·
verified ·
1 Parent(s): 323ea49

Update gradio_tabs/animation.py

Browse files
Files changed (1) hide show
  1. gradio_tabs/animation.py +5 -12
gradio_tabs/animation.py CHANGED
@@ -10,17 +10,12 @@ import spaces
10
  extensions_dir = "./torch_extension/"
11
  os.environ["TORCH_EXTENSIONS_DIR"] = extensions_dir
12
 
13
- @spaces.GPU(duration=240)
14
- def load_model():
15
- from networks.generator import Generator
16
 
17
- device = torch.device("cuda")
18
- ckpt_path = './models/lia-x.pt'
19
- model = Generator(size=512, motion_dim=40, scale=2).to(device)
20
- model.load_state_dict(torch.load(ckpt_path, weights_only=False))
21
- return model
22
-
23
- gen = load_model()
24
  gen.eval()
25
 
26
  output_dir = "./res_gradio"
@@ -129,7 +124,6 @@ def vid_postprocessing(video, fps, output_path=output_dir + "/output_vid.mp4"):
129
  return output_path
130
 
131
 
132
- @spaces.GPU(duration=240)
133
  def edit_media(image, *selected_s):
134
 
135
  image_tensor = img_preprocessing(image, 512)
@@ -143,7 +137,6 @@ def edit_media(image, *selected_s):
143
  return edited_image
144
 
145
 
146
- @spaces.GPU(duration=240)
147
  def animate_media(image, video, *selected_s):
148
 
149
  image_tensor = img_preprocessing(image, 512)
 
10
  extensions_dir = "./torch_extension/"
11
  os.environ["TORCH_EXTENSIONS_DIR"] = extensions_dir
12
 
13
+ from networks.generator import Generator
 
 
14
 
15
+ device = torch.device("cuda")
16
+ ckpt_path = './models/lia-x.pt'
17
+ gen = Generator(size=512, motion_dim=40, scale=2).to(device)
18
+ gen.load_state_dict(torch.load(ckpt_path, weights_only=False))
 
 
 
19
  gen.eval()
20
 
21
  output_dir = "./res_gradio"
 
124
  return output_path
125
 
126
 
 
127
  def edit_media(image, *selected_s):
128
 
129
  image_tensor = img_preprocessing(image, 512)
 
137
  return edited_image
138
 
139
 
 
140
  def animate_media(image, video, *selected_s):
141
 
142
  image_tensor = img_preprocessing(image, 512)