# coding: utf-8
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'

import argparse
import time
import logging
import librosa
import sys
import os
import glob
import torch
import torch.nn as nn
import numpy as np
import soundfile as sf
import spaces
import warnings

warnings.filterwarnings("ignore")

# Logging configuration
logging.basicConfig(level=logging.DEBUG, filename='utils.log', format='%(asctime)s - %(levelname)s - %(message)s')

# Colab detection
try:
    from google.colab import drive
    IS_COLAB = True
except ImportError:
    IS_COLAB = False


# i18n placeholder
class I18nAuto:
    """Identity stand-in for a translation layer: messages are returned as given."""

    def __call__(self, message):
        return message

    def format(self, message, *args):
        return message.format(*args)


i18n = I18nAuto()

# Make the sibling ``utils`` module importable regardless of the working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)

from utils import demix, get_model_from_config, normalize_audio, denormalize_audio
from utils import prefer_target_instrument, apply_tta, load_start_checkpoint


def shorten_filename(filename, max_length=30):
    """Return ``filename`` with an over-long stem abbreviated.

    If the stem (name without extension) is longer than ``max_length``
    characters, it is replaced by its first 15 characters, ``"..."`` and its
    last 10 characters; the extension is kept. Otherwise ``filename`` is
    returned unchanged.
    """
    base, ext = os.path.splitext(filename)
    if len(base) <= max_length:
        return filename
    return base[:15] + "..." + base[-10:] + ext


def get_soundfile_subtype(pcm_type, is_float=False):
    """Map the requested PCM type to a ``soundfile`` subtype string.

    Returns ``'FLOAT'`` when ``is_float`` is true, when ``pcm_type`` is
    ``'FLOAT'``, or when ``pcm_type`` is not one of the known keys; otherwise
    returns ``'PCM_16'`` or ``'PCM_24'``.
    """
    if pcm_type == 'FLOAT' or is_float:
        return 'FLOAT'
    subtype_map = {'PCM_16': 'PCM_16', 'PCM_24': 'PCM_24', 'FLOAT': 'FLOAT'}
    return subtype_map.get(pcm_type, 'FLOAT')


def update_progress_html(progress_label, progress_percent):
    """Build the progress snippet for ``progress_label``.

    ``progress_percent`` is rounded and clamped to ``0..100``.
    """
    progress_percent = min(max(round(progress_percent), 0), 100)
    # NOTE(review): the template below only interpolates the label; the clamped
    # percentage is unused. The surrounding markup looks like it was lost in
    # this copy of the file -- confirm against the upstream source.
    return f"""
{progress_label}
"""


def _notify(progress, fraction, desc):
    """Call ``progress(fraction, desc=desc)`` if a callable reporter was supplied.

    ``fraction`` is clamped to ``[0.0, 1.0]``. Returns whatever the reporter
    returns, or ``None`` when there is no reporter.
    """
    if progress is None or not callable(progress):
        return None
    return progress(min(max(fraction, 0.0), 1.0), desc=desc)


def _stage_reporter(progress, base_progress, file_share, offset, weight):
    """Return a ``(p, desc)`` callback that maps one stage onto the overall bar.

    ``base_progress`` is the percentage already completed by earlier files and
    ``file_share`` is the percentage of the whole run owned by the current
    file. Within that share the stage starts at ``offset`` percent and spans
    ``weight`` percent.
    """
    def report(p, desc):
        # Presumably ``p`` is a 0..1 fraction of the stage -- TODO confirm
        # against the callers in ``utils``.
        within_file = offset + p * weight
        return _notify(progress, (base_progress + within_file * file_share / 100) / 100, desc)

    return report


def run_folder(model, args, config, device, verbose: bool = False, progress=None):
    """Separate every file matching ``*.*`` in ``args.input_folder``.

    For each file the mixture is loaded, optionally normalized, separated with
    ``demix`` (plus ``apply_tta`` when ``args.use_tta`` is set), optionally
    extended with the phase-remix and plain instrumental stems, and each stem
    is written to ``args.store_dir``. A failure on one file is logged and
    printed, and processing continues with the next file.

    Args:
        model: Separation model; switched to eval mode here.
        args: Parsed options. Reads ``input_folder``, ``store_dir``,
            ``model_type``, ``use_tta``, ``demud_phaseremix_inst``,
            ``extract_instrumental``, ``pcm_type`` and, when present,
            ``export_format`` and ``flac_file``.
        config: Model configuration. Reads ``audio.sample_rate`` (default
            44100) and ``inference.normalize``.
        device: Device handed to ``demix`` / ``apply_tta``.
        verbose: Unused; kept for interface compatibility.
        progress: Optional callable invoked as ``progress(fraction, desc=...)``
            with ``fraction`` in ``[0, 1]``.
    """
    start_time = time.time()
    model.eval()

    mixture_paths = sorted(glob.glob(os.path.join(args.input_folder, '*.*')))
    sample_rate = getattr(config.audio, 'sample_rate', 44100)

    logging.info(f"Total files found: {len(mixture_paths)} with sample rate: {sample_rate}")
    print(i18n("total_files_found").format(len(mixture_paths), sample_rate))

    instruments = prefer_target_instrument(config)[:]

    store_dir = args.store_dir
    os.makedirs(store_dir, exist_ok=True)

    total_files = len(mixture_paths)
    processed_files = 0
    base_progress_per_file = 100 / total_files if total_files > 0 else 100

    for path in mixture_paths:
        try:
            mix, sr = librosa.load(path, sr=sample_rate, mono=False)
            logging.info(f"Loaded audio: {path}, shape: {mix.shape}")
            print(i18n("loaded_audio").format(path, mix.shape))

            processed_files += 1
            base_progress = round((processed_files - 1) * base_progress_per_file)
            _notify(progress, base_progress / 100,
                    i18n("processing_file").format(processed_files, total_files))

            def stage(offset, weight):
                # Stage offsets are scaled by this file's share of the run so the
                # reported fraction stays within [0, 1] when there are several files.
                return _stage_reporter(progress, base_progress, base_progress_per_file, offset, weight)

            # Work on a per-file copy: the extra stems appended below must not
            # leak into the next file (duplicate outputs, flipped branch choice).
            file_instruments = list(instruments)

            mix_orig = mix.copy()
            # Evaluated once per file so ``norm_params`` exists exactly when it is used.
            normalize = 'normalize' in config.inference and config.inference.get('normalize', False)
            if normalize:
                mix, norm_params = normalize_audio(mix)

            waveforms_orig = demix(
                config, model, mix, device, model_type=args.model_type, pbar=False,
                progress=stage(0, 50)
            )

            if args.use_tta:
                waveforms_orig = apply_tta(
                    config, model, mix, waveforms_orig, device, args.model_type,
                    progress=stage(50, 20)
                )

            if args.demud_phaseremix_inst:
                logging.info(f"Demudding track: {path}")
                print(i18n("demudding_track").format(path))
                instr = 'vocals' if 'vocals' in file_instruments else file_instruments[0]
                has_instrumental = 'instrumental' in file_instruments or 'Instrumental' in file_instruments
                file_instruments.append('instrumental_phaseremix')

                if has_instrumental:
                    mix_modified = 2 * waveforms_orig[instr] - mix_orig
                else:
                    mix_modified = mix_orig - 2 * waveforms_orig[instr]
                mix_modified_ = mix_modified.copy()

                waveforms_modified = demix(
                    config, model, mix_modified, device, model_type=args.model_type, pbar=False,
                    progress=stage(70, 15)
                )
                if args.use_tta:
                    # Both branches now refine the estimates of the *modified* mix.
                    # The original passed ``waveforms_orig`` in the ``has_instrumental``
                    # case, which looks like a copy-paste slip next to the sibling
                    # branch -- NOTE(review): confirm against ``apply_tta``.
                    waveforms_modified = apply_tta(
                        config, model, mix_modified, waveforms_modified, device, args.model_type,
                        progress=stage(85, 10)
                    )

                if has_instrumental:
                    waveforms_orig['instrumental_phaseremix'] = mix_orig + mix_modified_ - waveforms_modified[instr]
                else:
                    waveforms_orig['instrumental_phaseremix'] = mix_orig + waveforms_modified[instr]

            if args.extract_instrumental:
                instr = 'vocals' if 'vocals' in file_instruments else file_instruments[0]
                waveforms_orig['instrumental'] = mix_orig - waveforms_orig[instr]
                if 'instrumental' not in file_instruments:
                    file_instruments.append('instrumental')

            codec = 'flac' if getattr(args, 'flac_file', False) else 'wav'
            # FLAC has no float subtype, so the float request is honoured for WAV only;
            # FLAC output falls back to ``args.pcm_type``.
            is_float = codec == 'wav' and getattr(args, 'export_format', '').startswith('wav FLOAT')
            subtype = get_soundfile_subtype(args.pcm_type, is_float=is_float)
            shortened_filename = shorten_filename(os.path.basename(path))
            stem_count = len(file_instruments)

            for i, instr in enumerate(file_instruments):
                estimates = waveforms_orig[instr]
                if normalize:
                    estimates = denormalize_audio(estimates, norm_params)

                output_filename = f"{shortened_filename}_{instr}.{codec}"
                output_path = os.path.join(store_dir, output_filename)
                sf.write(output_path, estimates.T, sr, subtype=subtype)

                save_progress = round(
                    base_progress + (95 + (i / stem_count) * 5) * base_progress_per_file / 100
                )
                _notify(progress, save_progress / 100,
                        i18n("saving_output").format(instr, processed_files, total_files))

            file_progress = round(processed_files * base_progress_per_file)
            _notify(progress, file_progress / 100,
                    i18n("completed_file").format(processed_files, total_files))

        except Exception as e:
            # Best-effort batch: any failure for this file (load, separation or
            # write) is recorded with its traceback and the loop moves on.
            logging.error(f"Cannot read track: {path}. Error: {str(e)}", exc_info=True)
            print(i18n("cannot_read_track").format(path))
            print(i18n("error_message").format(str(e)))
            continue

    elapsed_time = time.time() - start_time
    logging.info(f"Processing time: {elapsed_time:.2f} seconds")
    print(i18n("elapsed_time").format(elapsed_time))

    _notify(progress, 1.0, i18n("processing_complete"))


@spaces.GPU
def proc_folder(args=None, progress=None):
    """Parse options, load the model and run separation over a folder.

    Args:
        args: List of command-line style arguments. ``None`` lets ``argparse``
            read ``sys.argv[1:]``, which is what the ``__main__`` entry point
            relies on.
        progress: Optional progress callable forwarded to ``run_folder``.

    Returns:
        The string ``"Processing completed"``.

    Raises:
        ValueError: If building the parser or parsing raises an ``Exception``.
            Ordinary ``argparse`` usage errors exit via ``SystemExit`` and are
            not converted.
    """
    try:
        parser = argparse.ArgumentParser(description=i18n("proc_folder_description"))
        # NOTE(review): 'melod_band_roformer' may be a typo for 'mel_band_roformer';
        # left unchanged because the accepted names live in ``utils``.
        parser.add_argument("--model_type", type=str, default='melod_band_roformer', help=i18n("model_type_help"))
        parser.add_argument("--config_path", type=str, required=True, help=i18n("config_path_help"))
        parser.add_argument("--start_check_point", type=str, required=True, help=i18n("start_checkpoint_help"))
        parser.add_argument("--input_folder", type=str, required=True, help=i18n("input_folder_help"))
        parser.add_argument("--store_dir", type=str, required=True, help=i18n("store_dir_help"))
        parser.add_argument("--chunk_size", type=int, default=352800, help=i18n("chunk_size_help"))
        parser.add_argument("--overlap", type=int, default=2, help=i18n("overlap_help"))
        parser.add_argument("--export_format", type=str, default='wav FLOAT',
                            choices=['wav FLOAT', 'flac PCM_16', 'flac PCM_24'], help=i18n("export_format_help"))
        parser.add_argument("--demud_phaseremix_inst", action='store_true', help=i18n("demud_phaseremix_help"))
        parser.add_argument("--extract_instrumental", action='store_true', help=i18n("extract_instrumental_help"))
        parser.add_argument("--use_tta", action='store_true', help=i18n("use_tta_help"))
        parser.add_argument("--flac_file", action='store_true', help=i18n("flac_file_help"))
        parser.add_argument("--pcm_type", type=str, choices=['PCM_16', 'PCM_24'], default='PCM_24',
                            help=i18n("pcm_type_help"))
        parser.add_argument("--device_ids", nargs='+', type=int, default=[0], help=i18n("device_ids_help"))
        parser.add_argument("--force_cpu", action='store_true', help=i18n("force_cpu_help"))
        parser.add_argument("--lora_checkpoint", type=str, default='', help=i18n("lora_checkpoint_help"))

        # Pass ``args`` through untouched: replacing ``None`` with ``[]`` hid
        # ``sys.argv`` from the script entry point, which then always failed on
        # the required options.
        args = parser.parse_args(args)
    except Exception as e:
        logging.error(f"Argument parsing failed: {str(e)}")
        raise ValueError(f"Invalid command-line arguments: {str(e)}") from e

    device = "cpu"
    if args.force_cpu:
        logging.info("Forced to use CPU")
    elif torch.cuda.is_available():
        logging.info("CUDA available")
        print(i18n("cuda_available"))
        device = f'cuda:{args.device_ids[0]}'
    elif torch.backends.mps.is_available():
        device = "mps"

    logging.info(f"Using device: {device}")
    print(i18n("using_device").format(device))

    model_load_start_time = time.time()
    torch.backends.cudnn.benchmark = True

    try:
        model, config = get_model_from_config(args.model_type, args.config_path)
    except Exception as e:
        logging.error(f"Failed to load model: {str(e)}")
        raise

    if args.start_check_point:
        try:
            load_start_checkpoint(args, model, type_='inference')
        except Exception as e:
            logging.error(f"Failed to load checkpoint: {str(e)}")
            raise

    logging.info(f"Instruments: {config.training.instruments}")
    print(i18n("instruments_print").format(config.training.instruments))

    if len(args.device_ids) > 1 and not args.force_cpu:
        model = nn.DataParallel(model, device_ids=args.device_ids)
        logging.info(f"Using DataParallel with devices: {args.device_ids}")

    model = model.to(device)

    elapsed_time = time.time() - model_load_start_time
    logging.info(f"Model load time: {elapsed_time:.2f} seconds")
    print(i18n("model_load_time").format(elapsed_time))

    run_folder(model, args, config, device, verbose=False, progress=progress)
    return "Processing completed"


if __name__ == "__main__":
    try:
        proc_folder(None)
    except Exception as e:
        logging.error(f"Main execution failed: {str(e)}")
        raise