jhj0517
commited on
Commit
·
4daf2ff
1
Parent(s):
a83a3d9
Add inferencer class
Browse files
modules/live_portrait/live_portrait_inferencer.py
ADDED
|
@@ -0,0 +1,763 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import cv2
|
| 6 |
+
import time
|
| 7 |
+
import copy
|
| 8 |
+
import dill
|
| 9 |
+
from ultralytics import YOLO
|
| 10 |
+
import safetensors.torch
|
| 11 |
+
import gradio as gr
|
| 12 |
+
|
| 13 |
+
from modules.utils.paths import *
|
| 14 |
+
from modules.utils.image_helper import *
|
| 15 |
+
from modules.live_portrait.model_downloader import *
|
| 16 |
+
from modules.live_portrait_wrapper import LivePortraitWrapper
|
| 17 |
+
from modules.utils.camera import get_rotation_matrix
|
| 18 |
+
from modules.utils.helper import load_yaml
|
| 19 |
+
from modules.config.inference_config import InferenceConfig
|
| 20 |
+
from modules.live_portrait.spade_generator import SPADEDecoder
|
| 21 |
+
from modules.live_portrait.warping_network import WarpingNetwork
|
| 22 |
+
from modules.live_portrait.motion_extractor import MotionExtractor
|
| 23 |
+
from modules.live_portrait.appearance_feature_extractor import AppearanceFeatureExtractor
|
| 24 |
+
from modules.live_portrait.stitching_retargeting_network import StitchingRetargetingNetwork
|
| 25 |
+
from collections import OrderedDict
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class LivePortraitInferencer:
|
| 29 |
+
def __init__(self,
|
| 30 |
+
model_dir: str = MODELS_DIR,
|
| 31 |
+
output_dir: str = OUTPUTS_DIR):
|
| 32 |
+
self.model_dir = model_dir
|
| 33 |
+
self.output_dir = output_dir
|
| 34 |
+
self.model_config = load_yaml(MODEL_CONFIG)["model_params"]
|
| 35 |
+
|
| 36 |
+
self.appearance_feature_extractor = None
|
| 37 |
+
self.motion_extractor = None
|
| 38 |
+
self.warping_module = None
|
| 39 |
+
self.spade_generator = None
|
| 40 |
+
self.stitching_retargeting_module = None
|
| 41 |
+
self.pipeline = None
|
| 42 |
+
self.detect_model = None
|
| 43 |
+
self.device = self.get_device()
|
| 44 |
+
|
| 45 |
+
self.mask_img = None
|
| 46 |
+
self.temp_img_idx = 0
|
| 47 |
+
self.src_image = None
|
| 48 |
+
self.src_image_list = None
|
| 49 |
+
self.sample_image = None
|
| 50 |
+
self.driving_images = None
|
| 51 |
+
self.driving_values = None
|
| 52 |
+
self.crop_factor = None
|
| 53 |
+
self.psi = None
|
| 54 |
+
self.psi_list = None
|
| 55 |
+
self.d_info = None
|
| 56 |
+
|
| 57 |
+
def load_models(self):
|
| 58 |
+
self.download_if_no_models()
|
| 59 |
+
|
| 60 |
+
appearance_feat_config = self.model_config["appearance_feature_extractor_params"]
|
| 61 |
+
self.appearance_feature_extractor = AppearanceFeatureExtractor(**appearance_feat_config).to(self.device)
|
| 62 |
+
self.appearance_feature_extractor = self.load_safe_tensor(
|
| 63 |
+
self.appearance_feature_extractor,
|
| 64 |
+
os.path.join(self.model_dir, "appearance_feature_extractor.safetensors")
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
motion_ext_config = self.model_config["motion_extractor_params"]
|
| 68 |
+
self.motion_extractor = MotionExtractor(**motion_ext_config).to(self.device)
|
| 69 |
+
self.motion_extractor = self.load_safe_tensor(
|
| 70 |
+
self.motion_extractor,
|
| 71 |
+
os.path.join(self.model_dir, "motion_extractor.safetensors")
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
warping_module_config = self.model_config["warping_module_params"]
|
| 75 |
+
self.warping_module = WarpingNetwork(**warping_module_config).to(self.device)
|
| 76 |
+
self.warping_module = self.load_safe_tensor(
|
| 77 |
+
self.warping_module,
|
| 78 |
+
os.path.join(self.model_dir, "warping_module.safetensors")
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
spaded_decoder_config = self.model_config["spade_generator_params"]
|
| 82 |
+
self.spade_generator = SPADEDecoder(**spaded_decoder_config).to(self.device)
|
| 83 |
+
self.spade_generator = self.load_safe_tensor(
|
| 84 |
+
self.spade_generator,
|
| 85 |
+
os.path.join(self.model_dir, "spade_generator.safetensors")
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
def filter_stitcher(checkpoint, prefix):
|
| 89 |
+
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
|
| 90 |
+
key.startswith(prefix)}
|
| 91 |
+
return filtered_checkpoint
|
| 92 |
+
|
| 93 |
+
stitcher_config = self.model_config["stitching_retargeting_module_params"]
|
| 94 |
+
self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
|
| 95 |
+
stitcher_model_path = os.path.join(self.model_dir, "stitching_retargeting_module.safetensors")
|
| 96 |
+
ckpt = safetensors.torch.load_file(stitcher_model_path)
|
| 97 |
+
self.stitching_retargeting_module.load_state_dict(filter_stitcher(ckpt, 'retarget_shoulder'))
|
| 98 |
+
self.stitching_retargeting_module.to(self.device).eval()
|
| 99 |
+
self.stitching_retargeting_module = {"stitching": self.stitching_retargeting_module}
|
| 100 |
+
|
| 101 |
+
if self.pipeline is None:
|
| 102 |
+
self.pipeline = LivePortraitWrapper(
|
| 103 |
+
InferenceConfig(),
|
| 104 |
+
self.appearance_feature_extractor,
|
| 105 |
+
self.motion_extractor,
|
| 106 |
+
self.warping_module,
|
| 107 |
+
self.spade_generator,
|
| 108 |
+
self.stitching_retargeting_module
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
self.detect_model = YOLO(MODEL_PATHS["face_yolov8n"])
|
| 112 |
+
|
| 113 |
+
def edit_expression(self,
|
| 114 |
+
rotate_pitch=0,
|
| 115 |
+
rotate_yaw=0,
|
| 116 |
+
rotate_roll=0,
|
| 117 |
+
blink=0,
|
| 118 |
+
eyebrow=0,
|
| 119 |
+
wink=0,
|
| 120 |
+
pupil_x=0,
|
| 121 |
+
pupil_y=0,
|
| 122 |
+
aaa=0,
|
| 123 |
+
eee=0,
|
| 124 |
+
woo=0,
|
| 125 |
+
smile=0,
|
| 126 |
+
src_ratio=1,
|
| 127 |
+
sample_ratio=1,
|
| 128 |
+
sample_parts="All",
|
| 129 |
+
crop_factor=1.5,
|
| 130 |
+
src_image=None,
|
| 131 |
+
sample_image=None,
|
| 132 |
+
motion_link=None,
|
| 133 |
+
add_exp=None):
|
| 134 |
+
if self.pipeline is None:
|
| 135 |
+
self.load_models()
|
| 136 |
+
|
| 137 |
+
try:
|
| 138 |
+
rotate_yaw = -rotate_yaw
|
| 139 |
+
|
| 140 |
+
new_editor_link = None
|
| 141 |
+
if motion_link is not None:
|
| 142 |
+
self.psi = motion_link[0]
|
| 143 |
+
new_editor_link = motion_link.copy()
|
| 144 |
+
elif src_image is not None:
|
| 145 |
+
if id(src_image) != id(self.src_image) or self.crop_factor != crop_factor:
|
| 146 |
+
self.crop_factor = crop_factor
|
| 147 |
+
self.psi = self.prepare_source(src_image, crop_factor)
|
| 148 |
+
self.src_image = src_image
|
| 149 |
+
new_editor_link = []
|
| 150 |
+
new_editor_link.append(self.psi)
|
| 151 |
+
else:
|
| 152 |
+
return None, None
|
| 153 |
+
|
| 154 |
+
psi = self.psi
|
| 155 |
+
s_info = psi.x_s_info
|
| 156 |
+
#delta_new = copy.deepcopy()
|
| 157 |
+
s_exp = s_info['exp'] * src_ratio
|
| 158 |
+
s_exp[0, 5] = s_info['exp'][0, 5]
|
| 159 |
+
s_exp += s_info['kp']
|
| 160 |
+
|
| 161 |
+
es = ExpressionSet()
|
| 162 |
+
|
| 163 |
+
if sample_image is not None:
|
| 164 |
+
if id(self.sample_image) != id(sample_image):
|
| 165 |
+
self.sample_image = sample_image
|
| 166 |
+
d_image_np = (sample_image * 255).byte().numpy()
|
| 167 |
+
d_face = self.crop_face(d_image_np[0], 1.7)
|
| 168 |
+
i_d = self.prepare_src_image(d_face)
|
| 169 |
+
self.d_info = self.pipeline.get_kp_info(i_d)
|
| 170 |
+
self.d_info['exp'][0, 5, 0] = 0
|
| 171 |
+
self.d_info['exp'][0, 5, 1] = 0
|
| 172 |
+
|
| 173 |
+
# "OnlyExpression", "OnlyRotation", "OnlyMouth", "OnlyEyes", "All"
|
| 174 |
+
if sample_parts == "OnlyExpression" or sample_parts == "All":
|
| 175 |
+
es.e += self.d_info['exp'] * sample_ratio
|
| 176 |
+
if sample_parts == "OnlyRotation" or sample_parts == "All":
|
| 177 |
+
rotate_pitch += self.d_info['pitch'] * sample_ratio
|
| 178 |
+
rotate_yaw += self.d_info['yaw'] * sample_ratio
|
| 179 |
+
rotate_roll += self.d_info['roll'] * sample_ratio
|
| 180 |
+
elif sample_parts == "OnlyMouth":
|
| 181 |
+
self.retargeting(es.e, self.d_info['exp'], sample_ratio, (14, 17, 19, 20))
|
| 182 |
+
elif sample_parts == "OnlyEyes":
|
| 183 |
+
self.retargeting(es.e, self.d_info['exp'], sample_ratio, (1, 2, 11, 13, 15, 16))
|
| 184 |
+
|
| 185 |
+
es.r = self.calc_fe(es.e, blink, eyebrow, wink, pupil_x, pupil_y, aaa, eee, woo, smile,
|
| 186 |
+
rotate_pitch, rotate_yaw, rotate_roll)
|
| 187 |
+
|
| 188 |
+
if add_exp is not None:
|
| 189 |
+
es.add(add_exp)
|
| 190 |
+
|
| 191 |
+
new_rotate = get_rotation_matrix(s_info['pitch'] + es.r[0], s_info['yaw'] + es.r[1],
|
| 192 |
+
s_info['roll'] + es.r[2])
|
| 193 |
+
x_d_new = (s_info['scale'] * (1 + es.s)) * ((s_exp + es.e) @ new_rotate) + s_info['t']
|
| 194 |
+
|
| 195 |
+
x_d_new = self.pipeline.stitching(psi.x_s_user, x_d_new)
|
| 196 |
+
|
| 197 |
+
crop_out = self.pipeline.warp_decode(psi.f_s_user, psi.x_s_user, x_d_new)
|
| 198 |
+
crop_out = self.pipeline.parse_output(crop_out['out'])[0]
|
| 199 |
+
|
| 200 |
+
crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb), cv2.INTER_LINEAR)
|
| 201 |
+
out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(np.uint8)
|
| 202 |
+
|
| 203 |
+
out_img = pil2tensor(out)
|
| 204 |
+
out_img_path = get_auto_incremental_file_path(TEMP_DIR, "png")
|
| 205 |
+
|
| 206 |
+
img = Image.fromarray(crop_out)
|
| 207 |
+
img.save(out_img_path, compress_level=1)
|
| 208 |
+
new_editor_link.append(es)
|
| 209 |
+
|
| 210 |
+
return out_img # {"ui": {"images": results}, "result": (out_img, new_editor_link, es)}
|
| 211 |
+
except Exception as e:
|
| 212 |
+
raise
|
| 213 |
+
|
| 214 |
+
def create_video(self,
|
| 215 |
+
retargeting_eyes,
|
| 216 |
+
retargeting_mouth,
|
| 217 |
+
turn_on,
|
| 218 |
+
tracking_src_vid,
|
| 219 |
+
animate_without_vid,
|
| 220 |
+
command,
|
| 221 |
+
crop_factor,
|
| 222 |
+
src_images=None,
|
| 223 |
+
driving_images=None,
|
| 224 |
+
motion_link=None,
|
| 225 |
+
progress=gr.Progress()):
|
| 226 |
+
if not turn_on:
|
| 227 |
+
return None, None
|
| 228 |
+
src_length = 1
|
| 229 |
+
|
| 230 |
+
if src_images is None:
|
| 231 |
+
if motion_link is not None:
|
| 232 |
+
self.psi_list = [motion_link[0]]
|
| 233 |
+
else:
|
| 234 |
+
return None, None
|
| 235 |
+
|
| 236 |
+
if src_images is not None:
|
| 237 |
+
src_length = len(src_images)
|
| 238 |
+
if id(src_images) != id(self.src_images) or self.crop_factor != crop_factor:
|
| 239 |
+
self.crop_factor = crop_factor
|
| 240 |
+
self.src_images = src_images
|
| 241 |
+
if 1 < src_length:
|
| 242 |
+
self.psi_list = self.prepare_source(src_images, crop_factor, True, tracking_src_vid)
|
| 243 |
+
else:
|
| 244 |
+
self.psi_list = [self.prepare_source(src_images, crop_factor)]
|
| 245 |
+
|
| 246 |
+
cmd_list, cmd_length = self.parsing_command(command, motion_link)
|
| 247 |
+
if cmd_list is None:
|
| 248 |
+
return None,None
|
| 249 |
+
cmd_idx = 0
|
| 250 |
+
|
| 251 |
+
driving_length = 0
|
| 252 |
+
if driving_images is not None:
|
| 253 |
+
if id(driving_images) != id(self.driving_images):
|
| 254 |
+
self.driving_images = driving_images
|
| 255 |
+
self.driving_values = self.prepare_driving_video(driving_images)
|
| 256 |
+
driving_length = len(self.driving_values)
|
| 257 |
+
|
| 258 |
+
total_length = max(driving_length, src_length)
|
| 259 |
+
|
| 260 |
+
if animate_without_vid:
|
| 261 |
+
total_length = max(total_length, cmd_length)
|
| 262 |
+
|
| 263 |
+
c_i_es = ExpressionSet()
|
| 264 |
+
c_o_es = ExpressionSet()
|
| 265 |
+
d_0_es = None
|
| 266 |
+
out_list = []
|
| 267 |
+
|
| 268 |
+
psi = None
|
| 269 |
+
for i in range(total_length):
|
| 270 |
+
|
| 271 |
+
if i < src_length:
|
| 272 |
+
psi = self.psi_list[i]
|
| 273 |
+
s_info = psi.x_s_info
|
| 274 |
+
s_es = ExpressionSet(erst=(s_info['kp'] + s_info['exp'], torch.Tensor([0, 0, 0]), s_info['scale'], s_info['t']))
|
| 275 |
+
|
| 276 |
+
new_es = ExpressionSet(es=s_es)
|
| 277 |
+
|
| 278 |
+
if i < cmd_length:
|
| 279 |
+
cmd = cmd_list[cmd_idx]
|
| 280 |
+
if 0 < cmd.change:
|
| 281 |
+
cmd.change -= 1
|
| 282 |
+
c_i_es.add(cmd.es)
|
| 283 |
+
c_i_es.sub(c_o_es)
|
| 284 |
+
elif 0 < cmd.keep:
|
| 285 |
+
cmd.keep -= 1
|
| 286 |
+
|
| 287 |
+
new_es.add(c_i_es)
|
| 288 |
+
|
| 289 |
+
if cmd.change == 0 and cmd.keep == 0:
|
| 290 |
+
cmd_idx += 1
|
| 291 |
+
if cmd_idx < len(cmd_list):
|
| 292 |
+
c_o_es = ExpressionSet(es=c_i_es)
|
| 293 |
+
cmd = cmd_list[cmd_idx]
|
| 294 |
+
c_o_es.div(cmd.change)
|
| 295 |
+
elif 0 < cmd_length:
|
| 296 |
+
new_es.add(c_i_es)
|
| 297 |
+
|
| 298 |
+
if i < driving_length:
|
| 299 |
+
d_i_info = self.driving_values[i]
|
| 300 |
+
d_i_r = torch.Tensor([d_i_info['pitch'], d_i_info['yaw'], d_i_info['roll']])#.float().to(device="cuda:0")
|
| 301 |
+
|
| 302 |
+
if d_0_es is None:
|
| 303 |
+
d_0_es = ExpressionSet(erst = (d_i_info['exp'], d_i_r, d_i_info['scale'], d_i_info['t']))
|
| 304 |
+
|
| 305 |
+
self.retargeting(s_es.e, d_0_es.e, retargeting_eyes, (11, 13, 15, 16))
|
| 306 |
+
self.retargeting(s_es.e, d_0_es.e, retargeting_mouth, (14, 17, 19, 20))
|
| 307 |
+
|
| 308 |
+
new_es.e += d_i_info['exp'] - d_0_es.e
|
| 309 |
+
new_es.r += d_i_r - d_0_es.r
|
| 310 |
+
new_es.t += d_i_info['t'] - d_0_es.t
|
| 311 |
+
|
| 312 |
+
r_new = get_rotation_matrix(
|
| 313 |
+
s_info['pitch'] + new_es.r[0], s_info['yaw'] + new_es.r[1], s_info['roll'] + new_es.r[2])
|
| 314 |
+
d_new = new_es.s * (new_es.e @ r_new) + new_es.t
|
| 315 |
+
d_new = self.pipeline.stitching(psi.x_s_user, d_new)
|
| 316 |
+
crop_out = self.pipeline.warp_decode(psi.f_s_user, psi.x_s_user, d_new)
|
| 317 |
+
crop_out = self.pipeline.parse_output(crop_out['out'])[0]
|
| 318 |
+
|
| 319 |
+
crop_with_fullsize = cv2.warpAffine(crop_out, psi.crop_trans_m, get_rgb_size(psi.src_rgb),
|
| 320 |
+
cv2.INTER_LINEAR)
|
| 321 |
+
out = np.clip(psi.mask_ori * crop_with_fullsize + (1 - psi.mask_ori) * psi.src_rgb, 0, 255).astype(
|
| 322 |
+
np.uint8)
|
| 323 |
+
out_list.append(out)
|
| 324 |
+
|
| 325 |
+
progress(i/total_length, "predicting..")
|
| 326 |
+
|
| 327 |
+
if len(out_list) == 0:
|
| 328 |
+
return None
|
| 329 |
+
|
| 330 |
+
out_imgs = torch.cat([pil2tensor(img_rgb) for img_rgb in out_list])
|
| 331 |
+
return out_imgs
|
| 332 |
+
|
| 333 |
+
def download_if_no_models(self):
|
| 334 |
+
for model_name, model_url in MODELS_URL.items():
|
| 335 |
+
if model_url.endswith(".pt"):
|
| 336 |
+
model_name += ".pt"
|
| 337 |
+
else:
|
| 338 |
+
model_name += ".safetensors"
|
| 339 |
+
model_path = os.path.join(self.model_dir, model_name)
|
| 340 |
+
if not os.path.exists(model_path):
|
| 341 |
+
download_model(model_path, model_url)
|
| 342 |
+
|
| 343 |
+
@staticmethod
|
| 344 |
+
def load_safe_tensor(model, file_path):
|
| 345 |
+
model.load_state_dict(safetensors.torch.load_file(file_path))
|
| 346 |
+
model.eval()
|
| 347 |
+
return model
|
| 348 |
+
|
| 349 |
+
@staticmethod
|
| 350 |
+
def get_device():
|
| 351 |
+
if torch.cuda.is_available():
|
| 352 |
+
return "cuda"
|
| 353 |
+
elif torch.backends.mps.is_available():
|
| 354 |
+
return "mps"
|
| 355 |
+
else:
|
| 356 |
+
return "cpu"
|
| 357 |
+
|
| 358 |
+
def get_temp_img_name(self):
|
| 359 |
+
self.temp_img_idx += 1
|
| 360 |
+
return "expression_edit_preview" + str(self.temp_img_idx) + ".png"
|
| 361 |
+
|
| 362 |
+
@staticmethod
|
| 363 |
+
def parsing_command(command, motoin_link):
|
| 364 |
+
command.replace(' ', '')
|
| 365 |
+
lines = command.split('\n')
|
| 366 |
+
|
| 367 |
+
cmd_list = []
|
| 368 |
+
|
| 369 |
+
total_length = 0
|
| 370 |
+
|
| 371 |
+
i = 0
|
| 372 |
+
for line in lines:
|
| 373 |
+
i += 1
|
| 374 |
+
if not line:
|
| 375 |
+
continue
|
| 376 |
+
try:
|
| 377 |
+
cmds = line.split('=')
|
| 378 |
+
idx = int(cmds[0])
|
| 379 |
+
if idx == 0: es = ExpressionSet()
|
| 380 |
+
else: es = ExpressionSet(es = motoin_link[idx])
|
| 381 |
+
cmds = cmds[1].split(':')
|
| 382 |
+
change = int(cmds[0])
|
| 383 |
+
keep = int(cmds[1])
|
| 384 |
+
except Exception as e:
|
| 385 |
+
print(f"(AdvancedLivePortrait) Command Err Line {i}: {line}, :{e}")
|
| 386 |
+
return None, None
|
| 387 |
+
|
| 388 |
+
total_length += change + keep
|
| 389 |
+
es.div(change)
|
| 390 |
+
cmd_list.append(Command(es, change, keep))
|
| 391 |
+
|
| 392 |
+
return cmd_list, total_length
|
| 393 |
+
|
| 394 |
+
def get_face_bboxes(self, image_rgb):
|
| 395 |
+
pred = self.detect_model(image_rgb, conf=0.7, device="")
|
| 396 |
+
return pred[0].boxes.xyxy.cpu().numpy()
|
| 397 |
+
|
| 398 |
+
def detect_face(self, image_rgb, crop_factor, sort = True):
|
| 399 |
+
bboxes = self.get_face_bboxes(image_rgb)
|
| 400 |
+
w, h = get_rgb_size(image_rgb)
|
| 401 |
+
|
| 402 |
+
print(f"w, h:{w, h}")
|
| 403 |
+
|
| 404 |
+
cx = w / 2
|
| 405 |
+
min_diff = w
|
| 406 |
+
best_box = None
|
| 407 |
+
for x1, y1, x2, y2 in bboxes:
|
| 408 |
+
bbox_w = x2 - x1
|
| 409 |
+
if bbox_w < 30: continue
|
| 410 |
+
diff = abs(cx - (x1 + bbox_w / 2))
|
| 411 |
+
if diff < min_diff:
|
| 412 |
+
best_box = [x1, y1, x2, y2]
|
| 413 |
+
print(f"diff, min_diff, best_box:{diff, min_diff, best_box}")
|
| 414 |
+
min_diff = diff
|
| 415 |
+
|
| 416 |
+
if best_box == None:
|
| 417 |
+
print("Failed to detect face!!")
|
| 418 |
+
return [0, 0, w, h]
|
| 419 |
+
|
| 420 |
+
x1, y1, x2, y2 = best_box
|
| 421 |
+
|
| 422 |
+
#for x1, y1, x2, y2 in bboxes:
|
| 423 |
+
bbox_w = x2 - x1
|
| 424 |
+
bbox_h = y2 - y1
|
| 425 |
+
|
| 426 |
+
crop_w = bbox_w * crop_factor
|
| 427 |
+
crop_h = bbox_h * crop_factor
|
| 428 |
+
|
| 429 |
+
crop_w = max(crop_h, crop_w)
|
| 430 |
+
crop_h = crop_w
|
| 431 |
+
|
| 432 |
+
kernel_x = int(x1 + bbox_w / 2)
|
| 433 |
+
kernel_y = int(y1 + bbox_h / 2)
|
| 434 |
+
|
| 435 |
+
new_x1 = int(kernel_x - crop_w / 2)
|
| 436 |
+
new_x2 = int(kernel_x + crop_w / 2)
|
| 437 |
+
new_y1 = int(kernel_y - crop_h / 2)
|
| 438 |
+
new_y2 = int(kernel_y + crop_h / 2)
|
| 439 |
+
|
| 440 |
+
if not sort:
|
| 441 |
+
return [int(new_x1), int(new_y1), int(new_x2), int(new_y2)]
|
| 442 |
+
|
| 443 |
+
if new_x1 < 0:
|
| 444 |
+
new_x2 -= new_x1
|
| 445 |
+
new_x1 = 0
|
| 446 |
+
elif w < new_x2:
|
| 447 |
+
new_x1 -= (new_x2 - w)
|
| 448 |
+
new_x2 = w
|
| 449 |
+
if new_x1 < 0:
|
| 450 |
+
new_x2 -= new_x1
|
| 451 |
+
new_x1 = 0
|
| 452 |
+
|
| 453 |
+
if new_y1 < 0:
|
| 454 |
+
new_y2 -= new_y1
|
| 455 |
+
new_y1 = 0
|
| 456 |
+
elif h < new_y2:
|
| 457 |
+
new_y1 -= (new_y2 - h)
|
| 458 |
+
new_y2 = h
|
| 459 |
+
if new_y1 < 0:
|
| 460 |
+
new_y2 -= new_y1
|
| 461 |
+
new_y1 = 0
|
| 462 |
+
|
| 463 |
+
if w < new_x2 and h < new_y2:
|
| 464 |
+
over_x = new_x2 - w
|
| 465 |
+
over_y = new_y2 - h
|
| 466 |
+
over_min = min(over_x, over_y)
|
| 467 |
+
new_x2 -= over_min
|
| 468 |
+
new_y2 -= over_min
|
| 469 |
+
|
| 470 |
+
return [int(new_x1), int(new_y1), int(new_x2), int(new_y2)]
|
| 471 |
+
|
| 472 |
+
@staticmethod
|
| 473 |
+
def retargeting(delta_out, driving_exp, factor, idxes):
|
| 474 |
+
for idx in idxes:
|
| 475 |
+
# delta_out[0, idx] -= src_exp[0, idx] * factor
|
| 476 |
+
delta_out[0, idx] += driving_exp[0, idx] * factor
|
| 477 |
+
|
| 478 |
+
@staticmethod
|
| 479 |
+
def calc_face_region(square, dsize):
|
| 480 |
+
region = copy.deepcopy(square)
|
| 481 |
+
is_changed = False
|
| 482 |
+
if dsize[0] < region[2]:
|
| 483 |
+
region[2] = dsize[0]
|
| 484 |
+
is_changed = True
|
| 485 |
+
if dsize[1] < region[3]:
|
| 486 |
+
region[3] = dsize[1]
|
| 487 |
+
is_changed = True
|
| 488 |
+
|
| 489 |
+
return region, is_changed
|
| 490 |
+
|
| 491 |
+
@staticmethod
|
| 492 |
+
def expand_img(rgb_img, square):
|
| 493 |
+
crop_trans_m = create_transform_matrix(max(-square[0], 0), max(-square[1], 0), 1, 1)
|
| 494 |
+
new_img = cv2.warpAffine(rgb_img, crop_trans_m, (square[2] - square[0], square[3] - square[1]),
|
| 495 |
+
cv2.INTER_LINEAR)
|
| 496 |
+
return new_img
|
| 497 |
+
|
| 498 |
+
def prepare_src_image(self, img):
|
| 499 |
+
h, w = img.shape[:2]
|
| 500 |
+
input_shape = [256,256]
|
| 501 |
+
if h != input_shape[0] or w != input_shape[1]:
|
| 502 |
+
if 256 < h: interpolation = cv2.INTER_AREA
|
| 503 |
+
else: interpolation = cv2.INTER_LINEAR
|
| 504 |
+
x = cv2.resize(img, (input_shape[0], input_shape[1]), interpolation = interpolation)
|
| 505 |
+
else:
|
| 506 |
+
x = img.copy()
|
| 507 |
+
|
| 508 |
+
if x.ndim == 3:
|
| 509 |
+
x = x[np.newaxis].astype(np.float32) / 255. # HxWx3 -> 1xHxWx3, normalized to 0~1
|
| 510 |
+
elif x.ndim == 4:
|
| 511 |
+
x = x.astype(np.float32) / 255. # BxHxWx3, normalized to 0~1
|
| 512 |
+
else:
|
| 513 |
+
raise ValueError(f'img ndim should be 3 or 4: {x.ndim}')
|
| 514 |
+
x = np.clip(x, 0, 1) # clip to 0~1
|
| 515 |
+
x = torch.from_numpy(x).permute(0, 3, 1, 2) # 1xHxWx3 -> 1x3xHxW
|
| 516 |
+
x = x.to(self.device)
|
| 517 |
+
return x
|
| 518 |
+
|
| 519 |
+
def get_mask_img(self):
|
| 520 |
+
if self.mask_img is None:
|
| 521 |
+
self.mask_img = cv2.imread(MASK_TEMPLATES, cv2.IMREAD_COLOR)
|
| 522 |
+
return self.mask_img
|
| 523 |
+
|
| 524 |
+
def crop_face(self, img_rgb, crop_factor):
|
| 525 |
+
crop_region = self.detect_face(img_rgb, crop_factor)
|
| 526 |
+
face_region, is_changed = self.calc_face_region(crop_region, get_rgb_size(img_rgb))
|
| 527 |
+
face_img = rgb_crop(img_rgb, face_region)
|
| 528 |
+
if is_changed: face_img = self.expand_img(face_img, crop_region)
|
| 529 |
+
return face_img
|
| 530 |
+
|
| 531 |
+
def prepare_source(self, source_image, crop_factor, is_video=False, tracking=False):
|
| 532 |
+
print("Prepare source...")
|
| 533 |
+
#source_image_np = (source_image * 255).byte().numpy()
|
| 534 |
+
# img_rgb = source_image_np[0]
|
| 535 |
+
|
| 536 |
+
psi_list = []
|
| 537 |
+
for img_rgb in source_image:
|
| 538 |
+
if tracking or len(psi_list) == 0:
|
| 539 |
+
crop_region = self.detect_face(img_rgb, crop_factor)
|
| 540 |
+
face_region, is_changed = self.calc_face_region(crop_region, get_rgb_size(img_rgb))
|
| 541 |
+
|
| 542 |
+
s_x = (face_region[2] - face_region[0]) / 512.
|
| 543 |
+
s_y = (face_region[3] - face_region[1]) / 512.
|
| 544 |
+
crop_trans_m = create_transform_matrix(crop_region[0], crop_region[1], s_x, s_y)
|
| 545 |
+
mask_ori = cv2.warpAffine(self.get_mask_img(), crop_trans_m, get_rgb_size(img_rgb), cv2.INTER_LINEAR)
|
| 546 |
+
mask_ori = mask_ori.astype(np.float32) / 255.
|
| 547 |
+
|
| 548 |
+
if is_changed:
|
| 549 |
+
s = (crop_region[2] - crop_region[0]) / 512.
|
| 550 |
+
crop_trans_m = create_transform_matrix(crop_region[0], crop_region[1], s, s)
|
| 551 |
+
|
| 552 |
+
face_img = rgb_crop(img_rgb, face_region)
|
| 553 |
+
if is_changed: face_img = self.expand_img(face_img, crop_region)
|
| 554 |
+
i_s = self.prepare_src_image(face_img)
|
| 555 |
+
x_s_info = self.pipeline.get_kp_info(i_s)
|
| 556 |
+
f_s_user = self.pipeline.extract_feature_3d(i_s)
|
| 557 |
+
x_s_user = self.pipeline.transform_keypoint(x_s_info)
|
| 558 |
+
psi = PreparedSrcImg(img_rgb, crop_trans_m, x_s_info, f_s_user, x_s_user, mask_ori)
|
| 559 |
+
if is_video == False:
|
| 560 |
+
return psi
|
| 561 |
+
psi_list.append(psi)
|
| 562 |
+
|
| 563 |
+
return psi_list
|
| 564 |
+
|
| 565 |
+
def prepare_driving_video(self, face_images):
|
| 566 |
+
print("Prepare driving video...")
|
| 567 |
+
f_img_np = (face_images * 255).byte().numpy()
|
| 568 |
+
|
| 569 |
+
out_list = []
|
| 570 |
+
for f_img in f_img_np:
|
| 571 |
+
i_d = self.prepare_src_image(f_img)
|
| 572 |
+
d_info = self.pipeline.get_kp_info(i_d)
|
| 573 |
+
out_list.append(d_info)
|
| 574 |
+
|
| 575 |
+
return out_list
|
| 576 |
+
|
| 577 |
+
@staticmethod
|
| 578 |
+
def calc_fe(x_d_new, eyes, eyebrow, wink, pupil_x, pupil_y, mouth, eee, woo, smile,
|
| 579 |
+
rotate_pitch, rotate_yaw, rotate_roll):
|
| 580 |
+
|
| 581 |
+
x_d_new[0, 20, 1] += smile * -0.01
|
| 582 |
+
x_d_new[0, 14, 1] += smile * -0.02
|
| 583 |
+
x_d_new[0, 17, 1] += smile * 0.0065
|
| 584 |
+
x_d_new[0, 17, 2] += smile * 0.003
|
| 585 |
+
x_d_new[0, 13, 1] += smile * -0.00275
|
| 586 |
+
x_d_new[0, 16, 1] += smile * -0.00275
|
| 587 |
+
x_d_new[0, 3, 1] += smile * -0.0035
|
| 588 |
+
x_d_new[0, 7, 1] += smile * -0.0035
|
| 589 |
+
|
| 590 |
+
x_d_new[0, 19, 1] += mouth * 0.001
|
| 591 |
+
x_d_new[0, 19, 2] += mouth * 0.0001
|
| 592 |
+
x_d_new[0, 17, 1] += mouth * -0.0001
|
| 593 |
+
rotate_pitch -= mouth * 0.05
|
| 594 |
+
|
| 595 |
+
x_d_new[0, 20, 2] += eee * -0.001
|
| 596 |
+
x_d_new[0, 20, 1] += eee * -0.001
|
| 597 |
+
#x_d_new[0, 19, 1] += eee * 0.0006
|
| 598 |
+
x_d_new[0, 14, 1] += eee * -0.001
|
| 599 |
+
|
| 600 |
+
x_d_new[0, 14, 1] += woo * 0.001
|
| 601 |
+
x_d_new[0, 3, 1] += woo * -0.0005
|
| 602 |
+
x_d_new[0, 7, 1] += woo * -0.0005
|
| 603 |
+
x_d_new[0, 17, 2] += woo * -0.0005
|
| 604 |
+
|
| 605 |
+
x_d_new[0, 11, 1] += wink * 0.001
|
| 606 |
+
x_d_new[0, 13, 1] += wink * -0.0003
|
| 607 |
+
x_d_new[0, 17, 0] += wink * 0.0003
|
| 608 |
+
x_d_new[0, 17, 1] += wink * 0.0003
|
| 609 |
+
x_d_new[0, 3, 1] += wink * -0.0003
|
| 610 |
+
rotate_roll -= wink * 0.1
|
| 611 |
+
rotate_yaw -= wink * 0.1
|
| 612 |
+
|
| 613 |
+
if 0 < pupil_x:
|
| 614 |
+
x_d_new[0, 11, 0] += pupil_x * 0.0007
|
| 615 |
+
x_d_new[0, 15, 0] += pupil_x * 0.001
|
| 616 |
+
else:
|
| 617 |
+
x_d_new[0, 11, 0] += pupil_x * 0.001
|
| 618 |
+
x_d_new[0, 15, 0] += pupil_x * 0.0007
|
| 619 |
+
|
| 620 |
+
x_d_new[0, 11, 1] += pupil_y * -0.001
|
| 621 |
+
x_d_new[0, 15, 1] += pupil_y * -0.001
|
| 622 |
+
eyes -= pupil_y / 2.
|
| 623 |
+
|
| 624 |
+
x_d_new[0, 11, 1] += eyes * -0.001
|
| 625 |
+
x_d_new[0, 13, 1] += eyes * 0.0003
|
| 626 |
+
x_d_new[0, 15, 1] += eyes * -0.001
|
| 627 |
+
x_d_new[0, 16, 1] += eyes * 0.0003
|
| 628 |
+
x_d_new[0, 1, 1] += eyes * -0.00025
|
| 629 |
+
x_d_new[0, 2, 1] += eyes * 0.00025
|
| 630 |
+
|
| 631 |
+
if 0 < eyebrow:
|
| 632 |
+
x_d_new[0, 1, 1] += eyebrow * 0.001
|
| 633 |
+
x_d_new[0, 2, 1] += eyebrow * -0.001
|
| 634 |
+
else:
|
| 635 |
+
x_d_new[0, 1, 0] += eyebrow * -0.001
|
| 636 |
+
x_d_new[0, 2, 0] += eyebrow * 0.001
|
| 637 |
+
x_d_new[0, 1, 1] += eyebrow * 0.0003
|
| 638 |
+
x_d_new[0, 2, 1] += eyebrow * -0.0003
|
| 639 |
+
|
| 640 |
+
return torch.Tensor([rotate_pitch, rotate_yaw, rotate_roll])
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
class ExpressionSet:
|
| 644 |
+
def __init__(self, erst=None, es=None):
|
| 645 |
+
if es is not None:
|
| 646 |
+
self.e = copy.deepcopy(es.e) # [:, :, :]
|
| 647 |
+
self.r = copy.deepcopy(es.r) # [:]
|
| 648 |
+
self.s = copy.deepcopy(es.s)
|
| 649 |
+
self.t = copy.deepcopy(es.t)
|
| 650 |
+
elif erst is not None:
|
| 651 |
+
self.e = erst[0]
|
| 652 |
+
self.r = erst[1]
|
| 653 |
+
self.s = erst[2]
|
| 654 |
+
self.t = erst[3]
|
| 655 |
+
else:
|
| 656 |
+
self.e = torch.from_numpy(np.zeros((1, 21, 3))).float().to(self.get_device())
|
| 657 |
+
self.r = torch.Tensor([0, 0, 0])
|
| 658 |
+
self.s = 0
|
| 659 |
+
self.t = 0
|
| 660 |
+
|
| 661 |
+
def div(self, value):
|
| 662 |
+
self.e /= value
|
| 663 |
+
self.r /= value
|
| 664 |
+
self.s /= value
|
| 665 |
+
self.t /= value
|
| 666 |
+
|
| 667 |
+
def add(self, other):
|
| 668 |
+
self.e += other.e
|
| 669 |
+
self.r += other.r
|
| 670 |
+
self.s += other.s
|
| 671 |
+
self.t += other.t
|
| 672 |
+
|
| 673 |
+
def sub(self, other):
|
| 674 |
+
self.e -= other.e
|
| 675 |
+
self.r -= other.r
|
| 676 |
+
self.s -= other.s
|
| 677 |
+
self.t -= other.t
|
| 678 |
+
|
| 679 |
+
def mul(self, value):
|
| 680 |
+
self.e *= value
|
| 681 |
+
self.r *= value
|
| 682 |
+
self.s *= value
|
| 683 |
+
self.t *= value
|
| 684 |
+
|
| 685 |
+
@staticmethod
|
| 686 |
+
def get_device():
|
| 687 |
+
if torch.cuda.is_available():
|
| 688 |
+
return "cuda"
|
| 689 |
+
elif torch.backends.mps.is_available():
|
| 690 |
+
return "mps"
|
| 691 |
+
else:
|
| 692 |
+
return "cpu"
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
def logging_time(original_fn):
|
| 696 |
+
def wrapper_fn(*args, **kwargs):
|
| 697 |
+
start_time = time.time()
|
| 698 |
+
result = original_fn(*args, **kwargs)
|
| 699 |
+
end_time = time.time()
|
| 700 |
+
print("WorkingTime[{}]: {} sec".format(original_fn.__name__, end_time - start_time))
|
| 701 |
+
return result
|
| 702 |
+
|
| 703 |
+
return wrapper_fn
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
def save_exp_data(file_name: str, save_exp: ExpressionSet = None):
|
| 707 |
+
if save_exp is None or not file_name:
|
| 708 |
+
return file_name
|
| 709 |
+
|
| 710 |
+
with open(os.path.join(EXP_OUTPUT_DIR, file_name + ".exp"), "wb") as f:
|
| 711 |
+
dill.dump(save_exp, f)
|
| 712 |
+
|
| 713 |
+
return file_name
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
def load_exp_data(self, file_name, ratio):
|
| 717 |
+
file_list = [os.path.splitext(file)[0] for file in os.listdir(EXP_OUTPUT_DIR) if file.endswith('.exp')]
|
| 718 |
+
with open(os.path.join(EXP_OUTPUT_DIR, file_name + ".exp"), 'rb') as f:
|
| 719 |
+
es = dill.load(f)
|
| 720 |
+
es.mul(ratio)
|
| 721 |
+
return es
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
def handle_exp_data(code1, value1, code2, value2, code3, value3, code4, value4, code5, value5, add_exp=None):
|
| 725 |
+
if add_exp is None:
|
| 726 |
+
es = ExpressionSet()
|
| 727 |
+
else:
|
| 728 |
+
es = ExpressionSet(es=add_exp)
|
| 729 |
+
|
| 730 |
+
codes = [code1, code2, code3, code4, code5]
|
| 731 |
+
values = [value1, value2, value3, value4, value5]
|
| 732 |
+
for i in range(5):
|
| 733 |
+
idx = int(codes[i] / 10)
|
| 734 |
+
r = codes[i] % 10
|
| 735 |
+
es.e[0, idx, r] += values[i] * 0.001
|
| 736 |
+
|
| 737 |
+
return es
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
def print_exp_data(cut_noise, exp=None):
|
| 741 |
+
if exp is None:
|
| 742 |
+
return exp
|
| 743 |
+
|
| 744 |
+
cuted_list = []
|
| 745 |
+
e = exp.exp * 1000
|
| 746 |
+
for idx in range(21):
|
| 747 |
+
for r in range(3):
|
| 748 |
+
a = abs(e[0, idx, r])
|
| 749 |
+
if (cut_noise < a): cuted_list.append((a, e[0, idx, r], idx * 10 + r))
|
| 750 |
+
|
| 751 |
+
sorted_list = sorted(cuted_list, reverse=True, key=lambda item: item[0])
|
| 752 |
+
print(f"sorted_list: {[[item[2], round(float(item[1]), 1)] for item in sorted_list]}")
|
| 753 |
+
return exp
|
| 754 |
+
|
| 755 |
+
|
| 756 |
+
class Command:
|
| 757 |
+
def __init__(self,
|
| 758 |
+
es: ExpressionSet,
|
| 759 |
+
change,
|
| 760 |
+
keep):
|
| 761 |
+
self.es = es
|
| 762 |
+
self.change = change
|
| 763 |
+
self.keep = keep
|