Spaces · Runtime error
anonymous committed
Commit 2896183 · 1 Parent(s): 0a4007d

update

Browse files:
- app.py (+24, -16)
- src/ddim_v_hacked.py (+5, -3)
app.py
CHANGED
@@ -303,6 +303,8 @@ def process1(*args):
     imgs = sorted(os.listdir(cfg.input_dir))
     imgs = [os.path.join(cfg.input_dir, img) for img in imgs]

+    model.cond_stage_model.device = device
+
     with torch.no_grad():
         frame = cv2.imread(imgs[0])
         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
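The added assignment points the conditioning (text) encoder at whatever device the Space actually has; it assumes a device string defined elsewhere in app.py, which this diff does not show (the change to src/ddim_v_hacked.py below introduces the same kind of constant). A minimal, self-contained sketch of the pattern, with TextEncoder standing in for the real cond_stage_model:

    import torch
    import torch.nn as nn

    # Pick the device once so the app still runs when no GPU is allocated.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'


    class TextEncoder(nn.Module):
        """Stand-in for cond_stage_model: it moves its own inputs to self.device."""

        def __init__(self):
            super().__init__()
            self.device = 'cuda'  # hard-coded default, the behaviour being fixed
            self.proj = nn.Linear(8, 8)

        def forward(self, tokens):
            return self.proj(tokens.to(self.device))


    encoder = TextEncoder().to(device)
    encoder.device = device  # same idea as model.cond_stage_model.device = device
    print(encoder(torch.zeros(1, 8)).shape)  # works on CPU and on GPU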
@@ -607,6 +609,7 @@ def process2(*args):

     return key_video_path

+
 DESCRIPTION = '''
 ## Rerender A Video
 ### This space provides the function of key frame translation. Full code for full video translation will be released upon the publication of the paper.
@@ -644,12 +647,13 @@ with block:
             run_button3 = gr.Button(value='Run Propagation')
             with gr.Accordion('Advanced options for the 1st frame translation',
                               open=False):
-                image_resolution = gr.Slider(
+                image_resolution = gr.Slider(
+                    label='Frame rsolution',
+                    minimum=256,
+                    maximum=512,
+                    value=512,
+                    step=64,
+                    info='To avoid overload, maximum 512')
                 control_strength = gr.Slider(label='ControNet strength',
                                              minimum=0.0,
                                              maximum=2.0,
@@ -734,12 +738,13 @@ with block:
                     value=1,
                     step=1,
                     info='Uniformly sample the key frames every K frames')
-                keyframe_count = gr.Slider(
+                keyframe_count = gr.Slider(
+                    label='Number of key frames',
+                    minimum=1,
+                    maximum=1,
+                    value=1,
+                    step=1,
+                    info='To avoid overload, maximum 8 key frames')

                 use_constraints = gr.CheckboxGroup(
                     [
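Both new sliders follow the same idea: clamp the user-facing range so a shared demo is never asked for more work than it can finish. Note that keyframe_count is pinned to maximum=1 in this commit even though its info text mentions 8 key frames. A small self-contained sketch of how such capped sliders feed a handler; the summarize callback is illustrative only, and the key-frame maximum here follows the info text rather than the pinned value:

    import gradio as gr

    def summarize(resolution, n_keyframes):
        # Placeholder for the real key-frame translation pipeline.
        return f'Would translate {n_keyframes} key frame(s) at {resolution}px.'

    with gr.Blocks() as demo:
        image_resolution = gr.Slider(label='Frame resolution', minimum=256,
                                     maximum=512, value=512, step=64,
                                     info='To avoid overload, maximum 512')
        keyframe_count = gr.Slider(label='Number of key frames', minimum=1,
                                   maximum=8, value=1, step=1,
                                   info='To avoid overload, maximum 8 key frames')
        result = gr.Textbox(label='Summary')
        run = gr.Button('Run')
        run.click(fn=summarize,
                  inputs=[image_resolution, keyframe_count],
                  outputs=[result])

    # demo.launch() would serve the interface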
@@ -769,8 +774,10 @@ with block:
                     maximum=100,
                     value=1,
                     step=1,
-                    info=
+                    info=
+                    ('Update the key and value for '
+                     'cross-frame attention every N key frames (recommend N*K>=10)'
+                     ))
             with gr.Row():
                 warp_start = gr.Slider(label='Shape-aware fusion start',
                                        minimum=0,
@@ -912,8 +919,9 @@ with block:
     run_button2.click(fn=process2, inputs=ips, outputs=[result_keyframe])

     def process3():
-        raise gr.Error(
+        raise gr.Error(
+            "Coming Soon. Full code for full video translation will be "
+            "released upon the publication of the paper.")

     run_button3.click(fn=process3, outputs=[result_keyframe])

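process3 is only a stub in this Space: raising gr.Error from a click handler makes Gradio display the message as an error popup instead of crashing the app. The same pattern in isolation (the output component type is an assumption; the real result_keyframe component is defined elsewhere in app.py):

    import gradio as gr

    def process3():
        # gr.Error surfaces this message in the UI and fails the request cleanly.
        raise gr.Error('Coming Soon. Full code for full video translation will be '
                       'released upon the publication of the paper.')

    with gr.Blocks() as demo:
        run_button3 = gr.Button(value='Run Propagation')
        result_keyframe = gr.Video(label='Key frame translation')
        run_button3.click(fn=process3, outputs=[result_keyframe])

    # demo.launch() would serve the interface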
src/ddim_v_hacked.py
CHANGED
@@ -14,6 +14,8 @@ from ControlNet.ldm.modules.diffusionmodules.util import (

 _ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32')

+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+

 def register_attention_control(model, controller=None):

@@ -36,7 +38,7 @@ def register_attention_control(model, controller=None):

             # force cast to fp32 to avoid overflowing
             if _ATTN_PRECISION == 'fp32':
-                with torch.autocast(enabled=False, device_type=
+                with torch.autocast(enabled=False, device_type=device):
                     q, k = q.float(), k.float()
                     sim = torch.einsum('b i d, b j d -> b i j', q,
                                        k) * self.scale
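torch.autocast needs a valid device_type string ('cuda' or 'cpu'), so the hard-coded value is replaced with the module-level device chosen above; the fp32 attention guard then works on CPU-only hardware too. A stripped-down sketch of the same guard (tensor shapes and the scale factor are made up for the example):

    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    scale = 64 ** -0.5

    dtype = torch.float16 if device == 'cuda' else torch.float32
    q = torch.randn(2, 16, 64, device=device, dtype=dtype)
    k = torch.randn(2, 16, 64, device=device, dtype=dtype)

    # Disable autocast locally and compute the similarity in fp32 to avoid
    # overflow, the same pattern used in the attention hook above.
    with torch.autocast(enabled=False, device_type=device):
        q32, k32 = q.float(), k.float()
        sim = torch.einsum('b i d, b j d -> b i j', q32, k32) * scale
    print(sim.shape)  # torch.Size([2, 16, 16])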
@@ -98,8 +100,8 @@ class DDIMVSampler(object):

     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device(
-                attr = attr.to(torch.device(
+            if attr.device != torch.device(device):
+                attr = attr.to(torch.device(device))
         setattr(self, name, attr)

     def make_schedule(self,
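register_buffer stores the sampler's schedule tensors as plain attributes; after this change they are moved to whichever device was detected instead of unconditionally to CUDA, so the sampler can also be built on a CPU-only machine. A reduced, runnable version of just that method, reusing the same module-level device constant (the real DDIMVSampler also holds the diffusion model and the rest of the schedule logic):

    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'


    class DDIMVSampler:

        def register_buffer(self, name, attr):
            if type(attr) == torch.Tensor:
                # Move schedule tensors to the active device (CPU or CUDA)
                # instead of assuming a GPU is present.
                if attr.device != torch.device(device):
                    attr = attr.to(torch.device(device))
            setattr(self, name, attr)


    sampler = DDIMVSampler()
    sampler.register_buffer('alphas_cumprod', torch.linspace(1.0, 0.0, 1000))
    print(sampler.alphas_cumprod.device)  # cpu, or cuda:0 when a GPU is available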