# coding: utf-8

"""
Benchmark the inference speed of each module in LivePortrait.

TODO: heavy GPT style, need to refactor
"""
import yaml
import torch
import time
import numpy as np

from src.utils.helper import load_model, concat_feat
from src.config.inference_config import InferenceConfig

def initialize_inputs(batch_size=1):
    """
    Generate random input tensors and move them to GPU
    """
    feature_3d = torch.randn(batch_size, 32, 16, 64, 64).cuda().half()
    kp_source = torch.randn(batch_size, 21, 3).cuda().half()
    kp_driving = torch.randn(batch_size, 21, 3).cuda().half()
    source_image = torch.randn(batch_size, 3, 256, 256).cuda().half()
    generator_input = torch.randn(batch_size, 256, 64, 64).cuda().half()
    eye_close_ratio = torch.randn(batch_size, 3).cuda().half()
    lip_close_ratio = torch.randn(batch_size, 2).cuda().half()
    feat_stitching = concat_feat(kp_source, kp_driving).half()
    feat_eye = concat_feat(kp_source, eye_close_ratio).half()
    feat_lip = concat_feat(kp_source, lip_close_ratio).half()

    inputs = {
        'feature_3d': feature_3d,
        'kp_source': kp_source,
        'kp_driving': kp_driving,
        'source_image': source_image,
        'generator_input': generator_input,
        'feat_stitching': feat_stitching,
        'feat_eye': feat_eye,
        'feat_lip': feat_lip
    }

    return inputs

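# `concat_feat` (from src.utils.helper) is assumed here to flatten each tensor
# per batch element and concatenate along the feature dimension, which is why
# keypoints and close ratios of different shapes can be combined above.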
def load_and_compile_models(cfg, model_config):
    """
    Load and compile models for inference
    """
    appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
    motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
    warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
    spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
    stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')

    models_with_params = [
        ('Appearance Feature Extractor', appearance_feature_extractor),
        ('Motion Extractor', motion_extractor),
        ('Warping Network', warping_module),
        ('SPADE Decoder', spade_generator)
    ]

    compiled_models = {}
    for name, model in models_with_params:
        model = model.half()
        model = torch.compile(model, mode='max-autotune')  # Optimize for inference
        model.eval()  # Switch to evaluation mode
        compiled_models[name] = model

    retargeting_models = ['stitching', 'eye', 'lip']
    for retarget in retargeting_models:
        module = stitching_retargeting_module[retarget].half()
        module = torch.compile(module, mode='max-autotune')  # Optimize for inference
        module.eval()  # Switch to evaluation mode
        stitching_retargeting_module[retarget] = module

    return compiled_models, stitching_retargeting_module

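# Note: with mode='max-autotune', torch.compile spends extra time autotuning
# kernels on the first few calls, so initial latencies are not representative;
# the warm-up pass below absorbs that one-time cost before anything is timed.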
def warm_up_models(compiled_models, stitching_retargeting_module, inputs):
    """
    Warm up models to prepare them for benchmarking
    """
    print("Warm up start!")
    with torch.no_grad():
        for _ in range(10):
            compiled_models['Appearance Feature Extractor'](inputs['source_image'])
            compiled_models['Motion Extractor'](inputs['source_image'])
            compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
            compiled_models['SPADE Decoder'](inputs['generator_input'])  # Adjust input as required
            stitching_retargeting_module['stitching'](inputs['feat_stitching'])
            stitching_retargeting_module['eye'](inputs['feat_eye'])
            stitching_retargeting_module['lip'](inputs['feat_lip'])
    print("Warm up end!")

def measure_inference_times(compiled_models, stitching_retargeting_module, inputs):
    """
    Measure inference times for each model
    """
    times = {name: [] for name in compiled_models.keys()}
    times['Retargeting Models'] = []
    overall_times = []

    with torch.no_grad():
        for _ in range(100):
            torch.cuda.synchronize()
            overall_start = time.time()

            start = time.time()
            compiled_models['Appearance Feature Extractor'](inputs['source_image'])
            torch.cuda.synchronize()
            times['Appearance Feature Extractor'].append(time.time() - start)

            start = time.time()
            compiled_models['Motion Extractor'](inputs['source_image'])
            torch.cuda.synchronize()
            times['Motion Extractor'].append(time.time() - start)

            start = time.time()
            compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
            torch.cuda.synchronize()
            times['Warping Network'].append(time.time() - start)

            start = time.time()
            compiled_models['SPADE Decoder'](inputs['generator_input'])  # Adjust input as required
            torch.cuda.synchronize()
            times['SPADE Decoder'].append(time.time() - start)

            start = time.time()
            stitching_retargeting_module['stitching'](inputs['feat_stitching'])
            stitching_retargeting_module['eye'](inputs['feat_eye'])
            stitching_retargeting_module['lip'](inputs['feat_lip'])
            torch.cuda.synchronize()
            times['Retargeting Models'].append(time.time() - start)

            overall_times.append(time.time() - overall_start)

    return times, overall_times

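# As an alternative to wall-clock timing with explicit synchronization, CUDA
# events measure elapsed time on the device itself. A minimal sketch (not
# wired into main(); `fn` is any zero-argument callable wrapping a forward
# pass, e.g. lambda: compiled_models['Motion Extractor'](inputs['source_image'])):
def time_with_cuda_events(fn, n_runs=100):
    """Return a list of per-run latencies in milliseconds for `fn`."""
    latencies = []
    with torch.no_grad():
        for _ in range(n_runs):
            start_evt = torch.cuda.Event(enable_timing=True)
            end_evt = torch.cuda.Event(enable_timing=True)
            start_evt.record()
            fn()
            end_evt.record()
            torch.cuda.synchronize()  # ensure both events have completed
            latencies.append(start_evt.elapsed_time(end_evt))  # ms between events
    return latencies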
def print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times):
    """
    Print benchmark results with average and standard deviation of inference times
    """
    average_times = {name: np.mean(times[name]) * 1000 for name in times.keys()}
    std_times = {name: np.std(times[name]) * 1000 for name in times.keys()}

    for name, model in compiled_models.items():
        num_params = sum(p.numel() for p in model.parameters())
        num_params_in_millions = num_params / 1e6
        print(f"Number of parameters for {name}: {num_params_in_millions:.2f} M")

    for index, retarget in enumerate(retargeting_models):
        num_params = sum(p.numel() for p in stitching_retargeting_module[retarget].parameters())
        num_params_in_millions = num_params / 1e6
        print(f"Number of parameters for part_{index} in Stitching and Retargeting Modules: {num_params_in_millions:.2f} M")

    for name, avg_time in average_times.items():
        std_time = std_times[name]
        print(f"Average inference time for {name} over 100 runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)")

    # End-to-end latency per iteration, covering all modules
    print(f"Average overall inference time over 100 runs: {np.mean(overall_times) * 1000:.2f} ms (std: {np.std(overall_times) * 1000:.2f} ms)")

def main():
    """
    Main function to benchmark speed and model parameters
    """
    # Sample input tensors
    inputs = initialize_inputs()

    # Load configuration
    cfg = InferenceConfig(device_id=0)
    model_config_path = cfg.models_config
    with open(model_config_path, 'r') as file:
        model_config = yaml.safe_load(file)

    # Load and compile models
    compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config)

    # Warm up models
    warm_up_models(compiled_models, stitching_retargeting_module, inputs)

    # Measure inference times
    times, overall_times = measure_inference_times(compiled_models, stitching_retargeting_module, inputs)

    # Print benchmark results
    print_benchmark_results(compiled_models, stitching_retargeting_module, ['stitching', 'eye', 'lip'], times, overall_times)


if __name__ == "__main__":
    main()
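# Usage, assuming this file is saved as speed.py at the repo root (so that the
# src.* imports resolve) and the checkpoints referenced by InferenceConfig are
# in place:
#   python speed.py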