Commit 53a8438 (parent: 9b9cd68): requirement fix
Files changed:
- app.py (+38 -8)
- assets/avengers.gif (+2 -2)
- inpainter/.DS_Store (+0 -0, binary)
- inpainter/base_inpainter.py (+6 -4)
- inpainter/model/modules/tfocal_transformer_hq.py (+2 -0)
- requirements.txt (+2 -5)
- track_anything.py (+10 -6)
- tracker/.DS_Store (+0 -0, binary)
- tracker/base_tracker.py (+1 -0)
app.py
CHANGED
@@ -13,7 +13,13 @@ import requests
 import json
 import torchvision
 import torch
+from tools.interact_tools import SamControler
+from tracker.base_tracker import BaseTracker
 from tools.painter import mask_painter
+try:
+    from mmcv.cnn import ConvModule
+except:
+    os.system("mim install mmcv")
 
 # download checkpoints
 def download_checkpoint(url, folder, filename):
@@ -200,6 +206,7 @@ def show_mask(video_state, interactive_state, mask_dropdown):
 
 # tracking vos
 def vos_tracking_video(video_state, interactive_state, mask_dropdown):
+
     model.xmem.clear_memory()
     if interactive_state["track_end_number"]:
         following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
@@ -219,6 +226,8 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
     template_mask = video_state["masks"][video_state["select_frame_number"]]
     fps = video_state["fps"]
     masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
+    # clear GPU memory
+    model.xmem.clear_memory()
 
     if interactive_state["track_end_number"]:
         video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
@@ -258,6 +267,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
 
 # inpaint
 def inpaint_video(video_state, interactive_state, mask_dropdown):
+
     frames = np.asarray(video_state["origin_images"])
     fps = video_state["fps"]
     inpaint_masks = np.asarray(video_state["masks"])
@@ -304,20 +314,33 @@ def generate_video_from_frames(frames, output_path, fps=30):
     torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
     return output_path
 
+
+# args, defined in track_anything.py
+args = parse_augment()
+
 # check and download checkpoints if needed
-SAM_checkpoint = "sam_vit_h_4b8939.pth"
-sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
+SAM_checkpoint_dict = {
+    'vit_h': "sam_vit_h_4b8939.pth",
+    'vit_l': "sam_vit_l_0b3195.pth",
+    "vit_b": "sam_vit_b_01ec64.pth"
+}
+SAM_checkpoint_url_dict = {
+    'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
+    'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
+    'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
+}
+sam_checkpoint = SAM_checkpoint_dict[args.sam_model_type]
+sam_checkpoint_url = SAM_checkpoint_url_dict[args.sam_model_type]
 xmem_checkpoint = "XMem-s012.pth"
 xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
 e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
 e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"
 
+
 folder ="./checkpoints"
-SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
+SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
 xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
 e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
-# args, defined in track_anything.py
-args = parse_augment()
 # args.port = 12315
 # args.device = "cuda:2"
 # args.mask_save = True
@@ -325,6 +348,12 @@ args = parse_augment()
 # initialize sam, xmem, e2fgvi models
 model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)
 
+
+title = """<p><h1 align="center">Track-Anything</h1></p>
+"""
+description = """<p>Gradio demo for Track Anything, a flexible and interactive tool for video object tracking, segmentation, and inpainting. To use it, simply upload your video, or click one of the examples to load them. Code: <a href="https://github.com/gaomingqi/Track-Anything">https://github.com/gaomingqi/Track-Anything</a> <a href="https://huggingface.co/spaces/watchtowerss/Track-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
+
+
 with gr.Blocks() as iface:
     """
     state for
@@ -356,7 +385,8 @@ with gr.Blocks() as iface:
         "fps": 30
         }
    )
-
+    gr.Markdown(title)
+    gr.Markdown(description)
    with gr.Row():
 
        # for user video input
@@ -365,7 +395,7 @@ with gr.Blocks() as iface:
            video_input = gr.Video(autosize=True)
            with gr.Column():
                video_info = gr.Textbox()
-               resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \
+               resize_info = gr.Textbox(value="Due to server restrictions, please upload a video that is no longer than 2 minutes. If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \
                Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
            resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
 
@@ -534,7 +564,7 @@ with gr.Blocks() as iface:
    # cache_examples=True,
    )
 iface.queue(concurrency_count=1)
-iface.launch(debug=True
+iface.launch(debug=True)
 
 
 
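Note on the new import block: the try/except around `from mmcv.cnn import ConvModule` is an import-or-install fallback, relying on the `mim` CLI that the newly added `openmim` requirement provides. A minimal sketch of the same pattern (not the app's exact code: the bare `except:` is narrowed to `ImportError` and the import is retried after installing):

import os

try:
    from mmcv.cnn import ConvModule
except ImportError:
    os.system("mim install mmcv")    # `mim` comes from the openmim package
    from mmcv.cnn import ConvModule  # retry now that mmcv is installed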
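The checkpoint block above keys the SAM weights off `args.sam_model_type` and downloads whatever is missing. A condensed, self-contained sketch of that select-then-download flow; `download_if_missing` is a hypothetical stand-in for the app's `download_checkpoint` helper and is assumed to skip files already on disk:

import os
import requests

# Filenames mirror the diff's SAM_checkpoint_dict.
SAM_CHECKPOINTS = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
}

def download_if_missing(url: str, folder: str, filename: str) -> str:
    """Download url into folder/filename unless it already exists; return the path."""
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, filename)
    if not os.path.exists(path):
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(path, "wb") as f:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
    return path

name = SAM_CHECKPOINTS["vit_b"]  # smallest variant, for illustration
url = f"https://dl.fbaipublicfiles.com/segment_anything/{name}"
sam_path = download_if_missing(url, "./checkpoints", name)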
assets/avengers.gif
CHANGED
Git LFS tracked binary; pointer updated.
inpainter/.DS_Store
CHANGED
Binary files a/inpainter/.DS_Store and b/inpainter/.DS_Store differ
inpainter/base_inpainter.py
CHANGED
@@ -7,6 +7,8 @@ import yaml
 import cv2
 import importlib
 import numpy as np
+from tqdm import tqdm
+
 from inpainter.util.tensor_util import resize_frames, resize_masks
 
 
@@ -66,15 +68,15 @@ class BaseInpainter:
         if ratio == 1:
             size = None
         else:
-            size =
+            size = [int(W*ratio), int(H*ratio)]
             if size[0] % 2 > 0:
                 size[0] += 1
             if size[1] % 2 > 0:
                 size[1] += 1
 
         masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
-        binary_masks = resize_masks(masks, size)
-        frames = resize_frames(frames, size) # T, H, W, 3
+        binary_masks = resize_masks(masks, tuple(size))
+        frames = resize_frames(frames, tuple(size)) # T, H, W, 3
         # frames and binary_masks are numpy arrays
 
         h, w = frames.shape[1:3]
@@ -87,7 +89,7 @@ class BaseInpainter:
         imgs, masks = imgs.to(self.device), masks.to(self.device)
         comp_frames = [None] * video_length
 
-        for f in range(0, video_length, self.neighbor_stride):
+        for f in tqdm(range(0, video_length, self.neighbor_stride), desc='Inpainting image'):
             neighbor_ids = [
                 i for i in range(max(0, f - self.neighbor_stride),
                     min(video_length, f + self.neighbor_stride + 1))
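The new `size = [int(W*ratio), int(H*ratio)]` line, together with the existing `% 2` checks, rounds each scaled dimension up to the next even number before resizing (video codecs and the inpainting model generally expect even frame sizes). A standalone sketch of that computation, using a hypothetical `even_size` helper:

def even_size(width: int, height: int, ratio: float) -> tuple:
    """Scale (width, height) by ratio, bumping any odd dimension up to even."""
    size = [int(width * ratio), int(height * ratio)]
    for i in (0, 1):
        if size[i] % 2 > 0:  # odd -> round up to even
            size[i] += 1
    return tuple(size)

print(even_size(1280, 720, 0.5))   # (640, 360)
print(even_size(1919, 1079, 1.0))  # (1920, 1080)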
inpainter/model/modules/tfocal_transformer_hq.py
CHANGED
@@ -128,8 +128,10 @@ def window_partition(x, window_size):
         windows: (B*num_windows, T*window_size*window_size, C)
     """
     B, T, H, W, C = x.shape
+
     x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
                window_size[1], C)
+
     windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
         -1, T * window_size[0] * window_size[1], C)
     return windows
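For reference, `window_partition` reshapes (B, T, H, W, C) into (B*num_windows, T*window_size[0]*window_size[1], C); the two added blank lines don't change behavior. A quick shape check under assumed toy dimensions:

import torch

# Assumed sizes: B=2 clips, T=4 frames, H=W=8, C=16 channels, 4x4 windows.
B, T, H, W, C = 2, 4, 8, 8, 16
window_size = (4, 4)
x = torch.randn(B, T, H, W, C)

x = x.view(B, T, H // window_size[0], window_size[0],
           W // window_size[1], window_size[1], C)
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
    -1, T * window_size[0] * window_size[1], C)

# (H//4)*(W//4) = 4 windows per clip -> 2*4 = 8 windows total,
# each flattened to T*4*4 = 64 tokens of dimension C = 16.
assert windows.shape == (8, 64, 16)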
requirements.txt
CHANGED
@@ -10,10 +10,7 @@ gradio==3.25.0
 opencv-python
 pycocotools
 matplotlib
-onnxruntime
-onnx
-metaseg==0.6.1
 pyyaml
 av
-
-
+openmim
+tqdm
track_anything.py
CHANGED
@@ -1,4 +1,6 @@
-import PIL
+import PIL
+from tqdm import tqdm
+
 from tools.interact_tools import SamControler
 from tracker.base_tracker import BaseTracker
 from inpainter.base_inpainter import BaseInpainter
@@ -10,9 +12,12 @@ import argparse
 class TrackingAnything():
     def __init__(self, sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args):
         self.args = args
-        self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
-        self.xmem = BaseTracker(xmem_checkpoint, device=args.device)
-        self.baseinpainter = BaseInpainter(e2fgvi_checkpoint, args.device)
+        self.sam_checkpoint = sam_checkpoint
+        self.xmem_checkpoint = xmem_checkpoint
+        self.e2fgvi_checkpoint = e2fgvi_checkpoint
+        self.samcontroler = SamControler(self.sam_checkpoint, args.sam_model_type, args.device)
+        self.xmem = BaseTracker(self.xmem_checkpoint, device=args.device)
+        self.baseinpainter = BaseInpainter(self.e2fgvi_checkpoint, args.device)
     # def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
     # same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
     # if first_flag:
@@ -39,7 +44,7 @@ class TrackingAnything():
         masks = []
         logits = []
         painted_images = []
-        for i in range(len(images)):
+        for i in tqdm(range(len(images)), desc="Tracking image"):
             if i ==0:
                 mask, logit, painted_image = self.xmem.track(images[i], template_mask)
                 masks.append(mask)
@@ -51,7 +56,6 @@ class TrackingAnything():
                 masks.append(mask)
                 logits.append(logit)
                 painted_images.append(painted_image)
-                print("tracking image {}".format(i))
         return masks, logits, painted_images
 
 
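The tracking loop now reports progress with a tqdm bar instead of printing once per frame (the removed `print("tracking image {}".format(i))`). The same pattern in a toy stand-in for the frame loop:

from tqdm import tqdm
import time

frames = list(range(100))  # placeholder for decoded video frames
results = []
for i in tqdm(range(len(frames)), desc="Tracking image"):
    time.sleep(0.01)       # stand-in for self.xmem.track(...)
    results.append(frames[i])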
tracker/.DS_Store
CHANGED
Binary files a/tracker/.DS_Store and b/tracker/.DS_Store differ
tracker/base_tracker.py
CHANGED
@@ -126,6 +126,7 @@ class BaseTracker:
     def clear_memory(self):
         self.tracker.clear_memory()
         self.mapper.clear_labels()
+        torch.cuda.empty_cache()
 
 
 ## how to use:
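Why add `torch.cuda.empty_cache()` here: deleting tensors only returns their memory to PyTorch's caching allocator; `empty_cache()` releases the cached blocks back to the driver so a subsequent model (e.g. the E2FGVI inpainter) can allocate that memory. A minimal demonstration, assuming a CUDA device is present:

import torch

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, 256, device="cuda")  # ~1 GiB of fp32
    print(torch.cuda.memory_reserved() // 2**20, "MiB reserved")
    del x                      # memory returns to the caching allocator
    torch.cuda.empty_cache()   # cached blocks are released to the driver
    print(torch.cuda.memory_reserved() // 2**20, "MiB reserved")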