jhj0517
commited on
Commit
·
46fa2af
1
Parent(s):
7785d3b
Apply model type enum
Browse files- app.py +2 -0
- modules/live_portrait/live_portrait_inferencer.py +54 -11
app.py
CHANGED
|
@@ -22,6 +22,8 @@ class App:
|
|
| 22 |
@staticmethod
|
| 23 |
def create_parameters():
|
| 24 |
return [
|
|
|
|
|
|
|
| 25 |
gr.Slider(label=_("Rotate Pitch"), minimum=-20, maximum=20, step=0.5, value=0),
|
| 26 |
gr.Slider(label=_("Rotate Yaw"), minimum=-20, maximum=20, step=0.5, value=0),
|
| 27 |
gr.Slider(label=_("Rotate Roll"), minimum=-20, maximum=20, step=0.5, value=0),
|
|
|
|
| 22 |
@staticmethod
|
| 23 |
def create_parameters():
|
| 24 |
return [
|
| 25 |
+
gr.Dropdown(label=_("Model Type"),
|
| 26 |
+
choices=[item.value for item in ModelType], value=ModelType.HUMAN.value),
|
| 27 |
gr.Slider(label=_("Rotate Pitch"), minimum=-20, maximum=20, step=0.5, value=0),
|
| 28 |
gr.Slider(label=_("Rotate Yaw"), minimum=-20, maximum=20, step=0.5, value=0),
|
| 29 |
gr.Slider(label=_("Rotate Roll"), minimum=-20, maximum=20, step=0.5, value=0),
|
modules/live_portrait/live_portrait_inferencer.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import logging
|
|
|
|
| 2 |
import cv2
|
| 3 |
import time
|
| 4 |
import copy
|
|
@@ -6,7 +7,10 @@ import dill
|
|
| 6 |
from ultralytics import YOLO
|
| 7 |
import safetensors.torch
|
| 8 |
import gradio as gr
|
|
|
|
| 9 |
from ultralytics.utils import LOGGER as ultralytics_logger
|
|
|
|
|
|
|
| 10 |
|
| 11 |
from modules.utils.paths import *
|
| 12 |
from modules.utils.image_helper import *
|
|
@@ -14,6 +18,7 @@ from modules.live_portrait.model_downloader import *
|
|
| 14 |
from modules.live_portrait.live_portrait_wrapper import LivePortraitWrapper
|
| 15 |
from modules.utils.camera import get_rotation_matrix
|
| 16 |
from modules.utils.helper import load_yaml
|
|
|
|
| 17 |
from modules.config.inference_config import InferenceConfig
|
| 18 |
from modules.live_portrait.spade_generator import SPADEDecoder
|
| 19 |
from modules.live_portrait.warping_network import WarpingNetwork
|
|
@@ -27,6 +32,7 @@ class LivePortraitInferencer:
|
|
| 27 |
model_dir: str = MODELS_DIR,
|
| 28 |
output_dir: str = OUTPUTS_DIR):
|
| 29 |
self.model_dir = model_dir
|
|
|
|
| 30 |
self.output_dir = output_dir
|
| 31 |
self.model_config = load_yaml(MODEL_CONFIG)["model_params"]
|
| 32 |
|
|
@@ -38,6 +44,7 @@ class LivePortraitInferencer:
|
|
| 38 |
self.pipeline = None
|
| 39 |
self.detect_model = None
|
| 40 |
self.device = self.get_device()
|
|
|
|
| 41 |
|
| 42 |
self.mask_img = None
|
| 43 |
self.temp_img_idx = 0
|
|
@@ -52,8 +59,22 @@ class LivePortraitInferencer:
|
|
| 52 |
self.d_info = None
|
| 53 |
|
| 54 |
def load_models(self,
|
|
|
|
| 55 |
progress=gr.Progress()):
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
total_models_num = 5
|
| 59 |
progress(0/total_models_num, desc="Loading Appearance Feature Extractor model...")
|
|
@@ -61,7 +82,7 @@ class LivePortraitInferencer:
|
|
| 61 |
self.appearance_feature_extractor = AppearanceFeatureExtractor(**appearance_feat_config).to(self.device)
|
| 62 |
self.appearance_feature_extractor = self.load_safe_tensor(
|
| 63 |
self.appearance_feature_extractor,
|
| 64 |
-
os.path.join(
|
| 65 |
)
|
| 66 |
|
| 67 |
progress(1/total_models_num, desc="Loading Motion Extractor model...")
|
|
@@ -69,7 +90,7 @@ class LivePortraitInferencer:
|
|
| 69 |
self.motion_extractor = MotionExtractor(**motion_ext_config).to(self.device)
|
| 70 |
self.motion_extractor = self.load_safe_tensor(
|
| 71 |
self.motion_extractor,
|
| 72 |
-
os.path.join(
|
| 73 |
)
|
| 74 |
|
| 75 |
progress(2/total_models_num, desc="Loading Warping Module model...")
|
|
@@ -77,7 +98,7 @@ class LivePortraitInferencer:
|
|
| 77 |
self.warping_module = WarpingNetwork(**warping_module_config).to(self.device)
|
| 78 |
self.warping_module = self.load_safe_tensor(
|
| 79 |
self.warping_module,
|
| 80 |
-
os.path.join(
|
| 81 |
)
|
| 82 |
|
| 83 |
progress(3/total_models_num, desc="Loading Spade generator model...")
|
|
@@ -85,7 +106,7 @@ class LivePortraitInferencer:
|
|
| 85 |
self.spade_generator = SPADEDecoder(**spaded_decoder_config).to(self.device)
|
| 86 |
self.spade_generator = self.load_safe_tensor(
|
| 87 |
self.spade_generator,
|
| 88 |
-
os.path.join(
|
| 89 |
)
|
| 90 |
|
| 91 |
progress(4/total_models_num, desc="Loading Stitcher model...")
|
|
@@ -93,7 +114,7 @@ class LivePortraitInferencer:
|
|
| 93 |
self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching')).to(self.device)
|
| 94 |
self.stitching_retargeting_module = self.load_safe_tensor(
|
| 95 |
self.stitching_retargeting_module,
|
| 96 |
-
os.path.join(
|
| 97 |
True
|
| 98 |
)
|
| 99 |
self.stitching_retargeting_module = {"stitching": self.stitching_retargeting_module}
|
|
@@ -111,6 +132,7 @@ class LivePortraitInferencer:
|
|
| 111 |
self.detect_model = YOLO(MODEL_PATHS["face_yolov8n"]).to(self.device)
|
| 112 |
|
| 113 |
def edit_expression(self,
|
|
|
|
| 114 |
rotate_pitch=0,
|
| 115 |
rotate_yaw=0,
|
| 116 |
rotate_roll=0,
|
|
@@ -131,8 +153,15 @@ class LivePortraitInferencer:
|
|
| 131 |
sample_image=None,
|
| 132 |
motion_link=None,
|
| 133 |
add_exp=None):
|
| 134 |
-
if
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
try:
|
| 138 |
rotate_yaw = -rotate_yaw
|
|
@@ -330,14 +359,27 @@ class LivePortraitInferencer:
|
|
| 330 |
return out_imgs
|
| 331 |
|
| 332 |
def download_if_no_models(self,
|
| 333 |
-
|
|
|
|
| 334 |
progress(0, desc="Downloading models...")
|
| 335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
if model_url.endswith(".pt"):
|
| 337 |
model_name += ".pt"
|
|
|
|
|
|
|
| 338 |
else:
|
| 339 |
model_name += ".safetensors"
|
| 340 |
-
model_path = os.path.join(
|
| 341 |
if not os.path.exists(model_path):
|
| 342 |
download_model(model_path, model_url)
|
| 343 |
|
|
@@ -779,3 +821,4 @@ class Command:
|
|
| 779 |
self.es = es
|
| 780 |
self.change = change
|
| 781 |
self.keep = keep
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
+
import os
|
| 3 |
import cv2
|
| 4 |
import time
|
| 5 |
import copy
|
|
|
|
| 7 |
from ultralytics import YOLO
|
| 8 |
import safetensors.torch
|
| 9 |
import gradio as gr
|
| 10 |
+
from gradio_i18n import Translate, gettext as _
|
| 11 |
from ultralytics.utils import LOGGER as ultralytics_logger
|
| 12 |
+
from enum import Enum
|
| 13 |
+
from typing import Union
|
| 14 |
|
| 15 |
from modules.utils.paths import *
|
| 16 |
from modules.utils.image_helper import *
|
|
|
|
| 18 |
from modules.live_portrait.live_portrait_wrapper import LivePortraitWrapper
|
| 19 |
from modules.utils.camera import get_rotation_matrix
|
| 20 |
from modules.utils.helper import load_yaml
|
| 21 |
+
from modules.utils.constants import *
|
| 22 |
from modules.config.inference_config import InferenceConfig
|
| 23 |
from modules.live_portrait.spade_generator import SPADEDecoder
|
| 24 |
from modules.live_portrait.warping_network import WarpingNetwork
|
|
|
|
| 32 |
model_dir: str = MODELS_DIR,
|
| 33 |
output_dir: str = OUTPUTS_DIR):
|
| 34 |
self.model_dir = model_dir
|
| 35 |
+
os.makedirs(os.path.join(self.model_dir, "animal"), exist_ok=True)
|
| 36 |
self.output_dir = output_dir
|
| 37 |
self.model_config = load_yaml(MODEL_CONFIG)["model_params"]
|
| 38 |
|
|
|
|
| 44 |
self.pipeline = None
|
| 45 |
self.detect_model = None
|
| 46 |
self.device = self.get_device()
|
| 47 |
+
self.model_type = ModelType.HUMAN.value
|
| 48 |
|
| 49 |
self.mask_img = None
|
| 50 |
self.temp_img_idx = 0
|
|
|
|
| 59 |
self.d_info = None
|
| 60 |
|
| 61 |
def load_models(self,
|
| 62 |
+
model_type: str = ModelType.HUMAN.value,
|
| 63 |
progress=gr.Progress()):
|
| 64 |
+
if isinstance(model_type, ModelType):
|
| 65 |
+
model_type = model_type.value
|
| 66 |
+
if model_type not in [mode.value for mode in ModelType]:
|
| 67 |
+
model_type = ModelType.HUMAN.value
|
| 68 |
+
|
| 69 |
+
self.model_type = model_type
|
| 70 |
+
if model_type == ModelType.ANIMAL.value:
|
| 71 |
+
model_dir = os.path.join(self.model_dir, "animal")
|
| 72 |
+
else:
|
| 73 |
+
model_dir = self.model_dir
|
| 74 |
+
|
| 75 |
+
self.download_if_no_models(
|
| 76 |
+
model_type=model_type
|
| 77 |
+
)
|
| 78 |
|
| 79 |
total_models_num = 5
|
| 80 |
progress(0/total_models_num, desc="Loading Appearance Feature Extractor model...")
|
|
|
|
| 82 |
self.appearance_feature_extractor = AppearanceFeatureExtractor(**appearance_feat_config).to(self.device)
|
| 83 |
self.appearance_feature_extractor = self.load_safe_tensor(
|
| 84 |
self.appearance_feature_extractor,
|
| 85 |
+
os.path.join(model_dir, "appearance_feature_extractor.safetensors")
|
| 86 |
)
|
| 87 |
|
| 88 |
progress(1/total_models_num, desc="Loading Motion Extractor model...")
|
|
|
|
| 90 |
self.motion_extractor = MotionExtractor(**motion_ext_config).to(self.device)
|
| 91 |
self.motion_extractor = self.load_safe_tensor(
|
| 92 |
self.motion_extractor,
|
| 93 |
+
os.path.join(model_dir, "motion_extractor.safetensors")
|
| 94 |
)
|
| 95 |
|
| 96 |
progress(2/total_models_num, desc="Loading Warping Module model...")
|
|
|
|
| 98 |
self.warping_module = WarpingNetwork(**warping_module_config).to(self.device)
|
| 99 |
self.warping_module = self.load_safe_tensor(
|
| 100 |
self.warping_module,
|
| 101 |
+
os.path.join(model_dir, "warping_module.safetensors")
|
| 102 |
)
|
| 103 |
|
| 104 |
progress(3/total_models_num, desc="Loading Spade generator model...")
|
|
|
|
| 106 |
self.spade_generator = SPADEDecoder(**spaded_decoder_config).to(self.device)
|
| 107 |
self.spade_generator = self.load_safe_tensor(
|
| 108 |
self.spade_generator,
|
| 109 |
+
os.path.join(model_dir, "spade_generator.safetensors")
|
| 110 |
)
|
| 111 |
|
| 112 |
progress(4/total_models_num, desc="Loading Stitcher model...")
|
|
|
|
| 114 |
self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching')).to(self.device)
|
| 115 |
self.stitching_retargeting_module = self.load_safe_tensor(
|
| 116 |
self.stitching_retargeting_module,
|
| 117 |
+
os.path.join(model_dir, "stitching_retargeting_module.safetensors"),
|
| 118 |
True
|
| 119 |
)
|
| 120 |
self.stitching_retargeting_module = {"stitching": self.stitching_retargeting_module}
|
|
|
|
| 132 |
self.detect_model = YOLO(MODEL_PATHS["face_yolov8n"]).to(self.device)
|
| 133 |
|
| 134 |
def edit_expression(self,
|
| 135 |
+
model_type: str = ModelType.HUMAN.value,
|
| 136 |
rotate_pitch=0,
|
| 137 |
rotate_yaw=0,
|
| 138 |
rotate_roll=0,
|
|
|
|
| 153 |
sample_image=None,
|
| 154 |
motion_link=None,
|
| 155 |
add_exp=None):
|
| 156 |
+
if isinstance(model_type, ModelType):
|
| 157 |
+
model_type = model_type.value
|
| 158 |
+
if model_type not in [mode.value for mode in ModelType]:
|
| 159 |
+
model_type = ModelType.HUMAN
|
| 160 |
+
|
| 161 |
+
if self.pipeline is None or model_type != self.model_type:
|
| 162 |
+
self.load_models(
|
| 163 |
+
model_type=model_type
|
| 164 |
+
)
|
| 165 |
|
| 166 |
try:
|
| 167 |
rotate_yaw = -rotate_yaw
|
|
|
|
| 359 |
return out_imgs
|
| 360 |
|
| 361 |
def download_if_no_models(self,
|
| 362 |
+
model_type: str = ModelType.HUMAN.value,
|
| 363 |
+
progress=gr.Progress(), ):
|
| 364 |
progress(0, desc="Downloading models...")
|
| 365 |
+
|
| 366 |
+
if isinstance(model_type, ModelType):
|
| 367 |
+
model_type = model_type.value
|
| 368 |
+
if model_type == ModelType.ANIMAL.value:
|
| 369 |
+
models_urls_dic = MODELS_ANIMAL_URL
|
| 370 |
+
model_dir = os.path.join(self.model_dir, "animal")
|
| 371 |
+
else:
|
| 372 |
+
models_urls_dic = MODELS_URL
|
| 373 |
+
model_dir = self.model_dir
|
| 374 |
+
|
| 375 |
+
for model_name, model_url in models_urls_dic.items():
|
| 376 |
if model_url.endswith(".pt"):
|
| 377 |
model_name += ".pt"
|
| 378 |
+
# Exception for face_yolov8n.pt
|
| 379 |
+
model_dir = self.model_dir
|
| 380 |
else:
|
| 381 |
model_name += ".safetensors"
|
| 382 |
+
model_path = os.path.join(model_dir, model_name)
|
| 383 |
if not os.path.exists(model_path):
|
| 384 |
download_model(model_path, model_url)
|
| 385 |
|
|
|
|
| 821 |
self.es = es
|
| 822 |
self.change = change
|
| 823 |
self.keep = keep
|
| 824 |
+
|