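# Hosting-page residue captured with the file (repository path, avatar
# caption, commit message, commit id); kept as comments so the module parses.
# Jhfhnrqgx-Gxeelqj-Vwxglr / inference.py
# ASesYusuf1's picture
# Update inference.py
# 8a49b0c verified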
# coding: utf-8
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
import argparse
import time
import logging
import librosa
import sys
import os
import glob
import torch
import torch.nn as nn
import numpy as np
import soundfile as sf
import spaces
import warnings
warnings.filterwarnings("ignore")
# Logging settings: everything from DEBUG upwards is written to utils.log.
logging.basicConfig(
    filename='utils.log',
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Colab check: the google.colab package can only be imported inside Colab.
try:
    from google.colab import drive
except ImportError:
    IS_COLAB = False
else:
    IS_COLAB = True
# i18n placeholder
class I18nAuto:
    """Stand-in translator that hands every message back untranslated."""

    def __call__(self, message):
        """Return ``message`` unchanged (no lookup is performed)."""
        return message

    def format(self, message, *args):
        """Return ``message`` with ``args`` substituted via ``message.format``."""
        formatted = message.format(*args)
        return formatted


# Module-wide translator instance used for every user-facing string.
i18n = I18nAuto()
# Directory that contains this file; appended to sys.path so that the
# sibling ``utils`` module imported below can be resolved.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
# Project-local helpers. They must be imported after the sys.path tweak above.
from utils import demix, get_model_from_config, normalize_audio, denormalize_audio
from utils import prefer_target_instrument, apply_tta, load_start_checkpoint
def shorten_filename(filename, max_length=30):
    """Abbreviate an over-long file name by eliding the middle of its base name.

    The extension is always preserved. Names whose base (the part before the
    extension) is at most ``max_length`` characters are returned untouched.

    Args:
        filename: File name (not a full path) to abbreviate.
        max_length: Longest base name that is left as is.

    Returns:
        ``filename`` itself, or ``<head>...<tail><ext>``. With room for it
        (``max_length >= 28``) the head is 15 and the tail 10 characters, as
        before. For smaller limits the head and tail shrink so the abbreviated
        base never exceeds ``max_length``; the three-dot ellipsis itself is
        always kept, so the base is never shorter than 3 characters.
    """
    ellipsis = "..."
    base, ext = os.path.splitext(filename)
    if len(base) <= max_length:
        return filename

    # Historical split; used whenever it fits inside the limit.
    head, tail = 15, 10
    if head + len(ellipsis) + tail > max_length:
        # Previously the fixed 15/10 split was applied regardless, which could
        # return a name longer than the limit (and even longer than the input).
        budget = max(max_length - len(ellipsis), 0)
        head = (budget * 3 + 4) // 5  # ceil(3/5 of budget): keeps the 15:10 ratio
        tail = budget - head

    # base[-0:] is the whole string, so an empty tail needs an explicit guard.
    suffix = base[-tail:] if tail else ""
    return base[:head] + ellipsis + suffix + ext
def get_soundfile_subtype(pcm_type, is_float=False):
    """Pick the subtype string handed to ``soundfile.write``.

    Args:
        pcm_type: Requested subtype name, e.g. ``'PCM_16'`` or ``'PCM_24'``.
        is_float: When truthy, ``'FLOAT'`` is returned regardless of ``pcm_type``.

    Returns:
        ``'PCM_16'`` or ``'PCM_24'`` when requested and ``is_float`` is falsy;
        ``'FLOAT'`` in every other case, including unknown names.
    """
    if pcm_type == 'FLOAT' or is_float:
        return 'FLOAT'

    recognised = {'PCM_16': 'PCM_16', 'PCM_24': 'PCM_24'}
    try:
        return recognised[pcm_type]
    except KeyError:
        # Anything unrecognised falls back to float output.
        return 'FLOAT'
def update_progress_html(progress_label, progress_percent):
    """Build the HTML snippet for the custom progress bar.

    Args:
        progress_label: Text shown above the bar; inserted into the markup
            as is (no HTML escaping is applied).
        progress_percent: Completion percentage; rounded to an integer and
            clamped to the 0-100 range before use.

    Returns:
        A string of HTML containing the label and a bar whose CSS width is
        the clamped percentage.
    """
    percent = round(progress_percent)
    if percent < 0:
        percent = 0
    elif percent > 100:
        percent = 100

    # The empty first and last entries reproduce the leading and trailing
    # newline of the markup.
    rows = [
        "",
        '<div id="custom-progress" style="margin-top: 10px;">',
        f'<div style="font-size: 1rem; color: #C0C0C0; margin-bottom: 5px;" id="progress-label">{progress_label}</div>',
        '<div style="width: 100%; background-color: #444; border-radius: 5px; overflow: hidden;">',
        f'<div id="progress-bar" style="width: {percent}%; height: 20px; background-color: #6e8efb; transition: width 0.3s; max-width: 100%;"></div>',
        '</div>',
        '</div>',
        "",
    ]
    return "\n".join(rows)
def run_folder(model, args, config, device, verbose: bool = False, progress=None):
    """Separate every file in ``args.input_folder`` and write the stems to disk.

    Each path matching ``*.*`` in the input folder is loaded with librosa at
    the configured sample rate, passed through ``demix`` (optionally followed
    by ``apply_tta`` and the phase-remix "demud" pass), and every resulting
    stem is written into ``args.store_dir`` with soundfile. If anything fails
    for a file, the error is logged and printed and the next file is processed.

    Args:
        model: Model object; ``eval()`` is called on it and it is passed to
            ``demix`` / ``apply_tta``.
        args: Namespace-like object. Reads ``input_folder``, ``store_dir``,
            ``model_type``, ``use_tta``, ``demud_phaseremix_inst``,
            ``extract_instrumental``, ``pcm_type`` and, when present,
            ``export_format`` and ``flac_file``.
        config: Configuration object. Reads ``config.audio.sample_rate``
            (44100 when absent) and the ``normalize`` entry of
            ``config.inference``; also passed to the project helpers.
        device: Passed through to ``demix`` / ``apply_tta``.
        verbose: Accepted for interface compatibility; not used here.
        progress: Optional callable invoked as ``progress(fraction, desc=text)``.
            Ignored when it is ``None`` or not callable.

    Returns:
        None.
    """
    start_time = time.time()
    model.eval()

    mixture_paths = sorted(glob.glob(os.path.join(args.input_folder, '*.*')))
    sample_rate = getattr(config.audio, 'sample_rate', 44100)
    logging.info(f"Total files found: {len(mixture_paths)} with sample rate: {sample_rate}")
    print(i18n("total_files_found").format(len(mixture_paths), sample_rate))

    # Pristine stem list. Every file works on its own copy (see the loop), so
    # the extra stems appended for one file cannot leak into the next one.
    base_instruments = prefer_target_instrument(config)[:]
    store_dir = args.store_dir
    os.makedirs(store_dir, exist_ok=True)

    total_files = len(mixture_paths)
    processed_files = 0
    base_progress_per_file = 100 / total_files if total_files > 0 else 100
    # Share of the overall bar owned by one file (1.0 when there is one file).
    file_scale = base_progress_per_file / 100
    has_progress = callable(progress)

    def report(percent, label):
        """Send an absolute percentage (0-100) to the progress callback, if any."""
        if has_progress:
            progress(percent / 100, desc=label)
        # Return value is not consumed here; call kept from the original flow.
        update_progress_html(label, percent)

    def stage_callback(file_base, offset, span):
        """Build the ``progress`` callback handed to demix/apply_tta for one stage.

        ``offset`` and ``span`` are percentages of a single file's work. They
        are scaled by ``file_scale`` so that, with several input files, the
        reported fraction stays inside that file's slice of the overall bar
        instead of running past 1.0.
        """
        def callback(p, desc):
            # NOTE(review): assumes p is a 0-1 fraction of the stage - confirm
            # against the demix/apply_tta implementations.
            if not has_progress:
                return None
            return progress((file_base + (offset + p * span) * file_scale) / 100, desc=desc)
        return callback

    for path in mixture_paths:
        try:
            instruments = list(base_instruments)

            mix, sr = librosa.load(path, sr=sample_rate, mono=False)
            logging.info(f"Loaded audio: {path}, shape: {mix.shape}")
            print(i18n("loaded_audio").format(path, mix.shape))

            processed_files += 1
            base_progress = round((processed_files - 1) * base_progress_per_file)
            report(base_progress, i18n("processing_file").format(processed_files, total_files))

            mix_orig = mix.copy()
            normalize = 'normalize' in config.inference and config.inference.get('normalize', False)
            if normalize:
                mix, norm_params = normalize_audio(mix)

            waveforms_orig = demix(
                config, model, mix, device, model_type=args.model_type, pbar=False,
                progress=stage_callback(base_progress, 0, 50),
            )
            if args.use_tta:
                waveforms_orig = apply_tta(
                    config, model, mix, waveforms_orig, device, args.model_type,
                    progress=stage_callback(base_progress, 50, 20),
                )

            if args.demud_phaseremix_inst:
                logging.info(f"Demudding track: {path}")
                print(i18n("demudding_track").format(path))
                instr = 'vocals' if 'vocals' in instruments else instruments[0]
                # Decided on this file's own stem list, so the choice is the
                # same for every file of the run.
                has_instrumental_stem = 'instrumental' in instruments or 'Instrumental' in instruments
                if 'instrumental_phaseremix' not in instruments:
                    instruments.append('instrumental_phaseremix')

                if has_instrumental_stem:
                    mix_modified = 2 * waveforms_orig[instr] - mix_orig
                else:
                    mix_modified = mix_orig - 2 * waveforms_orig[instr]
                # Snapshot taken before the second pass in case the helpers
                # modify their input array.
                mix_modified_ = mix_modified.copy()

                waveforms_modified = demix(
                    config, model, mix_modified, device, model_type=args.model_type, pbar=False,
                    progress=stage_callback(base_progress, 70, 15),
                )
                if args.use_tta:
                    # Refine the estimates of the *modified* mix. One branch
                    # used to pass waveforms_orig here, unlike its sibling.
                    waveforms_modified = apply_tta(
                        config, model, mix_modified, waveforms_modified, device, args.model_type,
                        progress=stage_callback(base_progress, 85, 10),
                    )

                if has_instrumental_stem:
                    waveforms_orig['instrumental_phaseremix'] = mix_orig + mix_modified_ - waveforms_modified[instr]
                else:
                    waveforms_orig['instrumental_phaseremix'] = mix_orig + waveforms_modified[instr]

            if args.extract_instrumental:
                instr = 'vocals' if 'vocals' in instruments else instruments[0]
                waveforms_orig['instrumental'] = mix_orig - waveforms_orig[instr]
                if 'instrumental' not in instruments:
                    instruments.append('instrumental')

            # Output settings are identical for every stem of this file.
            is_float = getattr(args, 'export_format', '').startswith('wav FLOAT')
            codec = 'flac' if getattr(args, 'flac_file', False) else 'wav'
            subtype = get_soundfile_subtype(args.pcm_type, is_float=is_float)
            shortened_filename = shorten_filename(os.path.basename(path))

            for i, stem in enumerate(instruments):
                estimates = waveforms_orig[stem]
                if normalize:
                    estimates = denormalize_audio(estimates, norm_params)
                output_filename = f"{shortened_filename}_{stem}.{codec}"
                output_path = os.path.join(store_dir, output_filename)
                sf.write(output_path, estimates.T, sr, subtype=subtype)

                # Saving covers the last 5% of this file's slice.
                save_progress = round(base_progress + (95 + (i / len(instruments)) * 5) * file_scale)
                report(save_progress, i18n("saving_output").format(stem, processed_files, total_files))

            file_progress = round(processed_files * base_progress_per_file)
            report(file_progress, i18n("completed_file").format(processed_files, total_files))
        except Exception as e:
            # Best effort: a failing file must not abort the whole batch.
            logging.error(f"Cannot read track: {path}. Error: {str(e)}")
            print(i18n("cannot_read_track").format(path))
            print(i18n("error_message").format(str(e)))

    elapsed_time = time.time() - start_time
    logging.info(f"Processing time: {elapsed_time:.2f} seconds")
    print(i18n("elapsed_time").format(elapsed_time))
    report(100, i18n("processing_complete"))
@spaces.GPU
def proc_folder(args=None, progress=None):
    """Parse the separation options, load the model and process the input folder.

    Args:
        args: Sequence of command-line style tokens (e.g.
            ``["--config_path", "cfg.yaml", ...]``). ``None`` means "read the
            real command line" (``sys.argv[1:]``), which is what the
            ``__main__`` entry point relies on.
        progress: Optional progress callable forwarded to ``run_folder``.

    Returns:
        The string ``"Processing completed"``.

    Raises:
        ValueError: If an explicitly supplied ``args`` sequence cannot be
            parsed, or building the parser fails.
        SystemExit: Raised by argparse for ``--help`` and for invalid options
            when parsing the real command line (``args is None``), so normal
            CLI exit codes are kept.
        Exception: Whatever ``get_model_from_config`` or
            ``load_start_checkpoint`` raise; logged and re-raised unchanged.
    """
    cli_invocation = args is None
    try:
        parser = argparse.ArgumentParser(description=i18n("proc_folder_description"))
        # NOTE(review): 'melod_band_roformer' looks like a misspelling of
        # 'mel_band_roformer'; left as is because the names accepted by
        # get_model_from_config are not visible here - confirm.
        parser.add_argument("--model_type", type=str, default='melod_band_roformer', help=i18n("model_type_help"))
        parser.add_argument("--config_path", type=str, required=True, help=i18n("config_path_help"))
        parser.add_argument("--start_check_point", type=str, required=True, help=i18n("start_checkpoint_help"))
        parser.add_argument("--input_folder", type=str, required=True, help=i18n("input_folder_help"))
        parser.add_argument("--store_dir", type=str, required=True, help=i18n("store_dir_help"))
        parser.add_argument("--chunk_size", type=int, default=352800, help=i18n("chunk_size_help"))
        parser.add_argument("--overlap", type=int, default=2, help=i18n("overlap_help"))
        parser.add_argument("--export_format", type=str, default='wav FLOAT', choices=['wav FLOAT', 'flac PCM_16', 'flac PCM_24'], help=i18n("export_format_help"))
        parser.add_argument("--demud_phaseremix_inst", action='store_true', help=i18n("demud_phaseremix_help"))
        parser.add_argument("--extract_instrumental", action='store_true', help=i18n("extract_instrumental_help"))
        parser.add_argument("--use_tta", action='store_true', help=i18n("use_tta_help"))
        parser.add_argument("--flac_file", action='store_true', help=i18n("flac_file_help"))
        parser.add_argument("--pcm_type", type=str, choices=['PCM_16', 'PCM_24'], default='PCM_24', help=i18n("pcm_type_help"))
        parser.add_argument("--device_ids", nargs='+', type=int, default=[0], help=i18n("device_ids_help"))
        parser.add_argument("--force_cpu", action='store_true', help=i18n("force_cpu_help"))
        parser.add_argument("--lora_checkpoint", type=str, default='', help=i18n("lora_checkpoint_help"))
        # parse_args(None) reads sys.argv[1:]. Substituting [] for None made
        # every command-line run fail on the required options.
        args = parser.parse_args(args)
    except SystemExit as e:
        # argparse reports bad options by exiting, which `except Exception`
        # cannot see. Keep the exit for real CLI use and for --help (code 0);
        # programmatic callers get the ValueError this function promises.
        if cli_invocation or not e.code:
            raise
        logging.error(f"Argument parsing failed: argparse exited with status {e.code}")
        raise ValueError(f"Invalid command-line arguments: argparse exited with status {e.code}") from e
    except Exception as e:
        logging.error(f"Argument parsing failed: {str(e)}")
        raise ValueError(f"Invalid command-line arguments: {str(e)}") from e

    # Device preference: forced CPU, then CUDA, then MPS, else CPU.
    device = "cpu"
    if args.force_cpu:
        logging.info("Forced to use CPU")
    elif torch.cuda.is_available():
        logging.info("CUDA available")
        print(i18n("cuda_available"))
        device = f'cuda:{args.device_ids[0]}'
    elif torch.backends.mps.is_available():
        device = "mps"
    logging.info(f"Using device: {device}")
    print(i18n("using_device").format(device))

    model_load_start_time = time.time()
    torch.backends.cudnn.benchmark = True

    try:
        model, config = get_model_from_config(args.model_type, args.config_path)
    except Exception as e:
        logging.error(f"Failed to load model: {str(e)}")
        raise

    if args.start_check_point:
        try:
            load_start_checkpoint(args, model, type_='inference')
        except Exception as e:
            logging.error(f"Failed to load checkpoint: {str(e)}")
            raise

    logging.info(f"Instruments: {config.training.instruments}")
    print(i18n("instruments_print").format(config.training.instruments))

    # Wrap for multi-GPU use only when several device ids were requested.
    if len(args.device_ids) > 1 and not args.force_cpu:
        model = nn.DataParallel(model, device_ids=args.device_ids)
        logging.info(f"Using DataParallel with devices: {args.device_ids}")
    model = model.to(device)

    elapsed_time = time.time() - model_load_start_time
    logging.info(f"Model load time: {elapsed_time:.2f} seconds")
    print(i18n("model_load_time").format(elapsed_time))

    run_folder(model, args, config, device, verbose=False, progress=progress)
    return "Processing completed"
if __name__ == "__main__":
    # Script entry point: run the pipeline without an explicit argument list.
    # Any failure is recorded in the log file before it propagates.
    try:
        proc_folder(None)
    except Exception as exc:
        logging.error(f"Main execution failed: {exc!s}")
        raise