Spaces:
Runtime error
Runtime error
| import sys | |
| import argparse | |
| import os | |
| import cv2 | |
| import glob | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from .raft import RAFT | |
| from .utils import flow_viz | |
| from .utils.utils import InputPadder | |
# Device every tensor and the model are moved to. Fall back to CPU when CUDA
# is unavailable so the helpers below still run on GPU-less hosts.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def load_image(imfile):
    """Read an image file into a float tensor laid out as (3, H, W).

    Args:
        imfile: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        ``torch`` float tensor of shape ``(3, H, W)`` holding the raw 0-255
        pixel values; no normalisation is applied here.
    """
    # Convert to RGB so grayscale ('L'), palette ('P') and RGBA inputs all
    # yield exactly three channels. Without this, 2-D arrays from 'L'/'P'
    # images break the permute below, and RGBA yields a 4-channel tensor.
    # The context manager releases the underlying file handle promptly.
    with Image.open(imfile) as pil_img:
        rgb = np.array(pil_img.convert('RGB'), dtype=np.uint8)
    # HWC (PIL/numpy layout) -> CHW.
    return torch.from_numpy(rgb).permute(2, 0, 1).float()
def load_image_list(image_files):
    """Load images in sorted path order and return them as one padded batch.

    Args:
        image_files: Iterable of image paths; they are sorted before loading.

    Returns:
        Tensor stacked along a new leading batch dimension, moved to
        ``DEVICE`` and passed through ``InputPadder.pad`` (presumably padding
        H/W to a size RAFT accepts -- see ``InputPadder``).
    """
    batch = torch.stack([load_image(path) for path in sorted(image_files)], dim=0)
    batch = batch.to(DEVICE)
    # pad() returns an indexable result; the padded batch is its first entry.
    padded = InputPadder(batch.shape).pad(batch)
    return padded[0]
def viz(img, flo, out_path='/home/chengao/test/flow.png'):
    """Render a flow field as a colour image and write it to ``out_path``.

    Args:
        img: Batched first-frame image tensor. Not used for rendering; kept
            so existing callers ``viz(image, flow)`` keep working.
        flo: Batched flow tensor of rank 4 (presumably ``(N, 2, H, W)`` --
            confirm against the model output); only batch element 0 is drawn.
        out_path: Destination file. Defaults to the previously hard-coded
            path for backward compatibility; callers should pass their own.

    Returns:
        bool: the result of ``cv2.imwrite`` -- ``False`` means nothing was
        written (e.g. the target directory does not exist).
    """
    del img  # unused; avoids a pointless GPU->CPU copy of the image
    flow_hwc = flo[0].permute(1, 2, 0).cpu().numpy()
    # Map flow vectors to an RGB visualisation.
    flow_rgb = flow_viz.flow_to_image(flow_hwc)
    # Reverse the channel order (RGB -> BGR) for OpenCV's writer.
    return cv2.imwrite(out_path, flow_rgb[:, :, [2, 1, 0]])
def demo(args):
    """Run RAFT on consecutive image pairs from ``args.path`` and visualise the flow.

    Args:
        args: Namespace passed to ``RAFT_infer`` (which reads ``args.model``
            and whatever ``RAFT`` itself reads) that also provides
            ``args.path``, a directory of ``*.png`` / ``*.jpg`` frames.

    Raises:
        FileNotFoundError: if ``args.path`` contains no ``.png``/``.jpg`` files.
    """
    image_files = (glob.glob(os.path.join(args.path, '*.png'))
                   + glob.glob(os.path.join(args.path, '*.jpg')))
    # Fail early with a clear message instead of an opaque torch.stack error,
    # and before paying for model construction.
    if not image_files:
        raise FileNotFoundError(f"no .png/.jpg images found in {args.path!r}")

    # Reuse the shared loader rather than duplicating the checkpoint logic.
    model = RAFT_infer(args)

    with torch.no_grad():
        images = load_image_list(image_files)
        for i in range(images.shape[0] - 1):
            # `None` index keeps a batch dimension of size 1.
            image1 = images[i, None]
            image2 = images[i + 1, None]
            _, flow_up = model(image1, image2, iters=20, test_mode=True)
            # NOTE(review): viz() writes to one fixed file, so each pair
            # overwrites the previous visualisation.
            viz(image1, flow_up)
def RAFT_infer(args):
    """Build a RAFT model, load its checkpoint and prepare it for inference.

    Args:
        args: Namespace passed to ``RAFT``; ``args.model`` is the checkpoint path.

    Returns:
        The unwrapped RAFT module on ``DEVICE``, in eval mode.
    """
    # The checkpoint is loaded into a DataParallel wrapper (its keys presumably
    # carry the 'module.' prefix), then the plain module is unwrapped.
    model = torch.nn.DataParallel(RAFT(args))
    # map_location='cpu' lets checkpoints saved from GPU load on CPU-only
    # hosts and avoids a temporary GPU copy; the model is moved to DEVICE below.
    state_dict = torch.load(args.model, map_location='cpu')
    model.load_state_dict(state_dict)
    model = model.module
    model.to(DEVICE)
    model.eval()
    return model