Create vtoonify/style_transfer.py
vtoonify/style_transfer.py    (ADDED, +232 -0)
import os
#os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import argparse
import numpy as np
import cv2
import dlib
import torch
from torchvision import transforms
import torch.nn.functional as F
from tqdm import tqdm
from model.vtoonify import VToonify
from model.bisenet.model import BiSeNet
from model.encoder.align_all_parallel import align_face
from util import save_image, load_image, visualize, load_psp_standalone, get_video_crop_parameter, tensor2cv2


class TestOptions():
    def __init__(self):

        self.parser = argparse.ArgumentParser(description="Style Transfer")
        self.parser.add_argument("--content", type=str, default='./data/077436.jpg', help="path of the content image/video")
        self.parser.add_argument("--style_id", type=int, default=26, help="the id of the style image")
        self.parser.add_argument("--style_degree", type=float, default=0.5, help="style degree for VToonify-D")
        self.parser.add_argument("--color_transfer", action="store_true", help="transfer the color of the style")
        self.parser.add_argument("--ckpt", type=str, default='./checkpoint/vtoonify_d_cartoon/vtoonify_s_d.pt', help="path of the saved model")
        self.parser.add_argument("--output_path", type=str, default='./output/', help="path of the output images")
        self.parser.add_argument("--scale_image", action="store_true", help="resize and crop the image to best fit the model")
        self.parser.add_argument("--style_encoder_path", type=str, default='./checkpoint/encoder.pt', help="path of the style encoder")
        self.parser.add_argument("--exstyle_path", type=str, default=None, help="path of the extrinsic style code")
        self.parser.add_argument("--faceparsing_path", type=str, default='./checkpoint/faceparsing.pth', help="path of the face parsing model")
        self.parser.add_argument("--video", action="store_true", help="if true, video stylization; if false, image stylization")
        self.parser.add_argument("--cpu", action="store_true", help="if true, only use cpu")
        self.parser.add_argument("--backbone", type=str, default='dualstylegan', help="dualstylegan | toonify")
        self.parser.add_argument("--padding", type=int, nargs=4, default=[200, 200, 200, 200], help="left, right, top, bottom paddings to the face center")
        self.parser.add_argument("--batch_size", type=int, default=4, help="batch size of frames when processing video")
        self.parser.add_argument("--parsing_map_path", type=str, default=None, help="path of the refined parsing map of the target video")

    def parse(self):
        self.opt = self.parser.parse_args()
        if self.opt.exstyle_path is None:
            self.opt.exstyle_path = os.path.join(os.path.dirname(self.opt.ckpt), 'exstyle_code.npy')
        args = vars(self.opt)
        print('Load options')
        for name, value in sorted(args.items()):
            print('%s: %s' % (str(name), str(value)))
        return self.opt

if __name__ == "__main__":

    parser = TestOptions()
    args = parser.parse()
    print('*' * 98)

    device = "cpu" if args.cpu else "cuda"

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    vtoonify = VToonify(backbone=args.backbone)
    vtoonify.load_state_dict(torch.load(args.ckpt, map_location=lambda storage, loc: storage)['g_ema'])
    vtoonify.to(device)

    parsingpredictor = BiSeNet(n_classes=19)
    parsingpredictor.load_state_dict(torch.load(args.faceparsing_path, map_location=lambda storage, loc: storage))
    parsingpredictor.to(device).eval()

    modelname = './checkpoint/shape_predictor_68_face_landmarks.dat'
    if not os.path.exists(modelname):
        import wget, bz2
        wget.download('http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2', modelname + '.bz2')
        zipfile = bz2.BZ2File(modelname + '.bz2')
        data = zipfile.read()
        open(modelname, 'wb').write(data)
    landmarkpredictor = dlib.shape_predictor(modelname)

    pspencoder = load_psp_standalone(args.style_encoder_path, device)

    if args.backbone == 'dualstylegan':
        exstyles = np.load(args.exstyle_path, allow_pickle='TRUE').item()
        stylename = list(exstyles.keys())[args.style_id]
        exstyle = torch.tensor(exstyles[stylename]).to(device)
        with torch.no_grad():
            exstyle = vtoonify.zplus2wplus(exstyle)

    if args.video and args.parsing_map_path is not None:
        x_p_hat = torch.tensor(np.load(args.parsing_map_path))

    print('Load models successfully!')

    filename = args.content
    basename = os.path.basename(filename).split('.')[0]
    scale = 1
    kernel_1d = np.array([[0.125], [0.375], [0.375], [0.125]])
    print('Processing ' + os.path.basename(filename) + ' with vtoonify_' + args.backbone[0])
    if args.video:
        cropname = os.path.join(args.output_path, basename + '_input.mp4')
        savename = os.path.join(args.output_path, basename + '_vtoonify_' + args.backbone[0] + '.mp4')

        video_cap = cv2.VideoCapture(filename)
        num = int(video_cap.get(7))

        first_valid_frame = True
        batch_frames = []
        for i in tqdm(range(num)):
            success, frame = video_cap.read()
            assert success, 'load video frames error'
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # We preprocess the video by detecting the face in the first frame,
            # and resizing the frame so that the eye distance is 64 pixels.
            # Centered on the eyes, we crop the first frame to almost 400x400 (based on args.padding).
            # All other frames use the same resizing and cropping parameters as the first frame.
            if first_valid_frame:
                if args.scale_image:
                    paras = get_video_crop_parameter(frame, landmarkpredictor, args.padding)
                    if paras is None:
                        continue
                    h, w, top, bottom, left, right, scale = paras
                    H, W = int(bottom - top), int(right - left)
                    # for HR video, we apply gaussian blur to the frames to avoid flickers caused by bilinear downsampling
                    # this can also prevent over-sharp stylization results.
                    if scale <= 0.75:
                        frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
                    if scale <= 0.375:
                        frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
                    frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
                else:
                    H, W = frame.shape[0], frame.shape[1]

                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                videoWriter = cv2.VideoWriter(cropname, fourcc, video_cap.get(5), (W, H))
                videoWriter2 = cv2.VideoWriter(savename, fourcc, video_cap.get(5), (4 * W, 4 * H))

                # For each video, we detect and align the face in the first frame for pSp to obtain the style code.
                # This style code is used for all other frames.
                with torch.no_grad():
                    I = align_face(frame, landmarkpredictor)
                    I = transform(I).unsqueeze(dim=0).to(device)
                    s_w = pspencoder(I)
                    s_w = vtoonify.zplus2wplus(s_w)
                    if vtoonify.backbone == 'dualstylegan':
                        if args.color_transfer:
                            s_w = exstyle
                        else:
                            s_w[:, :7] = exstyle[:, :7]
                first_valid_frame = False
            elif args.scale_image:
                if scale <= 0.75:
                    frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
                if scale <= 0.375:
                    frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
                frame = cv2.resize(frame, (w, h))[top:bottom, left:right]

            videoWriter.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

            batch_frames += [transform(frame).unsqueeze(dim=0).to(device)]

            if len(batch_frames) == args.batch_size or (i + 1) == num:
                x = torch.cat(batch_frames, dim=0)
                batch_frames = []
                with torch.no_grad():
                    # the parsing network works best on 512x512 images, so we predict parsing maps on upsampled frames
                    # followed by downsampling the parsing maps
                    if args.video and args.parsing_map_path is not None:
                        x_p = x_p_hat[i + 1 - x.size(0):i + 1].to(device)
                    else:
                        x_p = F.interpolate(parsingpredictor(2 * (F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)))[0],
                                            scale_factor=0.5, recompute_scale_factor=False).detach()
                    # we give parsing maps lower weight (1/16)
                    inputs = torch.cat((x, x_p / 16.), dim=1)
                    # d_s has no effect when backbone is toonify
                    y_tilde = vtoonify(inputs, s_w.repeat(inputs.size(0), 1, 1), d_s=args.style_degree)
                    y_tilde = torch.clamp(y_tilde, -1, 1)
                for k in range(y_tilde.size(0)):
                    videoWriter2.write(tensor2cv2(y_tilde[k].cpu()))

        videoWriter.release()
        videoWriter2.release()
        video_cap.release()

    else:
        cropname = os.path.join(args.output_path, basename + '_input.jpg')
        savename = os.path.join(args.output_path, basename + '_vtoonify_' + args.backbone[0] + '.jpg')

        frame = cv2.imread(filename)
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # We detect the face in the image, and resize the image so that the eye distance is 64 pixels.
        # Centered on the eyes, we crop the image to almost 400x400 (based on args.padding).
        if args.scale_image:
            paras = get_video_crop_parameter(frame, landmarkpredictor, args.padding)
            if paras is not None:
                h, w, top, bottom, left, right, scale = paras
                H, W = int(bottom - top), int(right - left)
                # for HR image, we apply gaussian blur to it to avoid over-sharp stylization results
                if scale <= 0.75:
                    frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
                if scale <= 0.375:
                    frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
                frame = cv2.resize(frame, (w, h))[top:bottom, left:right]

        with torch.no_grad():
            I = align_face(frame, landmarkpredictor)
            I = transform(I).unsqueeze(dim=0).to(device)
            s_w = pspencoder(I)
            s_w = vtoonify.zplus2wplus(s_w)
            if vtoonify.backbone == 'dualstylegan':
                if args.color_transfer:
                    s_w = exstyle
                else:
                    s_w[:, :7] = exstyle[:, :7]

            x = transform(frame).unsqueeze(dim=0).to(device)
            # the parsing network works best on 512x512 images, so we predict parsing maps on upsampled frames
            # followed by downsampling the parsing maps
            x_p = F.interpolate(parsingpredictor(2 * (F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)))[0],
                                scale_factor=0.5, recompute_scale_factor=False).detach()
            # we give parsing maps lower weight (1/16)
            inputs = torch.cat((x, x_p / 16.), dim=1)
            # d_s has no effect when backbone is toonify
            y_tilde = vtoonify(inputs, s_w.repeat(inputs.size(0), 1, 1), d_s=args.style_degree)
            y_tilde = torch.clamp(y_tilde, -1, 1)

        cv2.imwrite(cropname, cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        save_image(y_tilde[0].cpu(), savename)

    print('Transfer style successfully!')
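For reference, a typical invocation of the script might look like the following. This is only a sketch: it assumes the script is run from the vtoonify/ directory so that the default relative checkpoint and data paths resolve, that the checkpoints listed in the argument defaults have already been downloaded, and that the --output_path directory exists; the video filename below is a placeholder.

# image stylization with the default checkpoint and default content image
python style_transfer.py --scale_image

# video stylization; --style_id, --style_degree and --color_transfer control the style
python style_transfer.py --content ./data/YOUR_VIDEO.mp4 --video --scale_image --style_id 26 --style_degree 0.5

Outputs are written to --output_path as <basename>_input.jpg/.mp4 (the cropped input) and <basename>_vtoonify_<d|t>.jpg/.mp4 (the stylized result), where d or t reflects the chosen backbone.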