YaohuiW committed (verified)
Commit 685fe6b · 1 Parent(s): 4e5ad3a

Update gradio_tabs/animation.py

Files changed (1): gradio_tabs/animation.py (+29 -35)
gradio_tabs/animation.py CHANGED
@@ -7,17 +7,15 @@ import numpy as np
 import imageio
 import spaces

-extensions_dir = "./torch_extension/"
-os.environ["TORCH_EXTENSIONS_DIR"] = extensions_dir
+# extensions_dir = "./torch_extension/"
+# os.environ["TORCH_EXTENSIONS_DIR"] = extensions_dir

-from networks.generator import Generator
+# from networks.generator import Generator

-device = torch.device("cuda")
-#ckpt_path = './models/lia-x.pt'
-gen = Generator(size=512, motion_dim=40, scale=2).to(device)
-#gen.load_state_dict(torch.load(ckpt_path, weights_only=False))
-gen.load_state_dict(torch.hub.load_state_dict_from_url(f"https://huggingface.co/YaohuiW/LIA-X/resolve/main/lia-x.pt"))
-gen.eval()
+# device = torch.device("cuda")
+# gen = Generator(size=512, motion_dim=40, scale=2).to(device)
+# gen.load_state_dict(torch.hub.load_state_dict_from_url(f"https://huggingface.co/YaohuiW/LIA-X/resolve/main/lia-x.pt"))
+# gen.eval()

 output_dir = "./res_gradio"
 os.makedirs(output_dir, exist_ok=True)
@@ -124,45 +122,41 @@ def vid_postprocessing(video, fps, output_path=output_dir + "/output_vid.mp4"):

     return output_path

-@torch.no_grad()
-def edit_media(image, *selected_s):
-
-    image_tensor = img_preprocessing(image, 512)
-    image_tensor = image_tensor.to(device)
-
-    edited_image_tensor = gen.edit_img(image_tensor, labels_v, selected_s)
-
-    # de-norm
-    edited_image = img_postprocessing(edited_image_tensor)
-
-    return edited_image
-
-@torch.no_grad()
-def animate_media(image, video, *selected_s):
-
-    image_tensor = img_preprocessing(image, 512)
-    vid_target_tensor, fps = vid_preprocessing(video, 512)
-    image_tensor = image_tensor.to(device)
-    video_target_tensor = vid_target_tensor.to(device)
-
-    animated_video = gen.animate(image_tensor, video_target_tensor, labels_v, selected_s)
-
-    # postprocessing
-    animated_video = vid_postprocessing(animated_video, fps)
-
-    return animated_video
-
-
-def clear_media():
-    return None, None, *([0] * len(labels_k))
-
-
-# image_output = gr.Image(label="Output Image", elem_id="output_img", type='numpy', interactive=False, width=512)
-# video_output = gr.Video(label="Output Video", elem_id="output_vid", width=512)
-
-
-
-def animation():
+
+def animation(gen):
+
+    @torch.no_grad()
+    def edit_media(image, *selected_s):
+
+        image_tensor = img_preprocessing(image, 512)
+        image_tensor = image_tensor.to(device)
+
+        edited_image_tensor = gen.edit_img(image_tensor, labels_v, selected_s)
+
+        # de-norm
+        edited_image = img_postprocessing(edited_image_tensor)
+
+        return edited_image
+
+    @torch.no_grad()
+    def animate_media(image, video, *selected_s):
+
+        image_tensor = img_preprocessing(image, 512)
+        vid_target_tensor, fps = vid_preprocessing(video, 512)
+        image_tensor = image_tensor.to(device)
+        video_target_tensor = vid_target_tensor.to(device)
+
+        animated_video = gen.animate(image_tensor, video_target_tensor, labels_v, selected_s)
+
+        # postprocessing
+        animated_video = vid_postprocessing(animated_video, fps)
+
+        return animated_video
+
+    def clear_media():
+        return None, None, *([0] * len(labels_k))
+

     with gr.Tab("Animation & Image Editing"):

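For context, a minimal sketch of how the refactored tab might be wired up after this change: the generator is no longer built at import time inside gradio_tabs/animation.py, so the caller is expected to construct it once and pass it to animation(gen). The entry-point name (app.py), the gr.Blocks layout, and the CPU fallback below are assumptions for illustration, not part of this commit.

# app.py (hypothetical caller): builds the Generator once, then hands it to the tab,
# matching the new animation(gen) signature introduced in this commit.
import gradio as gr
import torch

from networks.generator import Generator      # same import the old module-level code used
from gradio_tabs.animation import animation   # function whose signature changed here

# Assumption: CPU fallback added for illustration; the original code targeted CUDA only.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

gen = Generator(size=512, motion_dim=40, scale=2).to(device)
gen.load_state_dict(
    torch.hub.load_state_dict_from_url(
        "https://huggingface.co/YaohuiW/LIA-X/resolve/main/lia-x.pt",
        map_location=device,
    )
)
gen.eval()

with gr.Blocks() as demo:
    animation(gen)  # the tab builds its UI and closes over the shared, already-initialized generator

demo.launch()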