# coding: utf-8
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
import argparse
import numpy as np
import torch
import torch.nn as nn
import yaml
import os
import soundfile as sf
from ml_collections import ConfigDict
from omegaconf import OmegaConf
from tqdm.auto import tqdm
from typing import Dict, List, Tuple, Any, Union
import loralib as lora
import gc # For garbage collection
import logging  # for error tracking
import time  # used by demix() for timing
# Logging configuration
logging.basicConfig(level=logging.INFO, filename='utils.log', format='%(asctime)s - %(message)s')
def load_config(model_type: str, config_path: str) -> Union[ConfigDict, OmegaConf]:
try:
with open(config_path, 'r') as f:
if model_type == 'htdemucs':
config = OmegaConf.load(config_path)
else:
config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
return config
except FileNotFoundError:
raise FileNotFoundError(f"Configuration file not found at {config_path}")
except Exception as e:
raise ValueError(f"Error loading configuration: {e}")
def get_model_from_config(model_type: str, config_path: str) -> Tuple:
"""
Load the model specified by the model type and configuration file.
Parameters:
----------
model_type : str
The type of model to load (e.g., 'mdx23c', 'htdemucs', 'scnet', etc.).
config_path : str
The path to the configuration file (YAML or OmegaConf format).
Returns:
-------
    model : nn.Module
        The model initialized for the given `model_type`.
config : Any
The configuration used to initialize the model. This could be in different formats
depending on the model type (e.g., OmegaConf, ConfigDict).
Raises:
------
ValueError:
If the `model_type` is unknown or an error occurs during model initialization.
"""
config = load_config(model_type, config_path)
if model_type == 'mdx23c':
from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net
model = TFC_TDF_net(config)
elif model_type == 'htdemucs':
from models.demucs4ht import get_model
model = get_model(config)
elif model_type == 'segm_models':
from models.segm_models import Segm_Models_Net
model = Segm_Models_Net(config)
elif model_type == 'torchseg':
from models.torchseg_models import Torchseg_Net
model = Torchseg_Net(config)
elif model_type == 'mel_band_roformer':
from models.bs_roformer import MelBandRoformer
model = MelBandRoformer(**dict(config.model))
elif model_type == 'bs_roformer':
from models.bs_roformer import BSRoformer
model = BSRoformer(**dict(config.model))
elif model_type == 'swin_upernet':
from models.upernet_swin_transformers import Swin_UperNet_Model
model = Swin_UperNet_Model(config)
elif model_type == 'bandit':
from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple
model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model)
elif model_type == 'bandit_v2':
from models.bandit_v2.bandit import Bandit
model = Bandit(**config.kwargs)
elif model_type == 'scnet_unofficial':
from models.scnet_unofficial import SCNet
model = SCNet(**config.model)
elif model_type == 'scnet':
from models.scnet import SCNet
model = SCNet(**config.model)
elif model_type == 'apollo':
from models.look2hear.models import BaseModel
model = BaseModel.apollo(**config.model)
elif model_type == 'bs_mamba2':
from models.ts_bs_mamba2 import Separator
model = Separator(**config.model)
elif model_type == 'experimental_mdx23c_stht':
from models.mdx23c_tfc_tdf_v3_with_STHT import TFC_TDF_net
model = TFC_TDF_net(config)
else:
raise ValueError(f"Unknown model type: {model_type}")
return model, config
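# Minimal usage sketch (illustrative only; the config path below is a placeholder
# and should point at a real YAML file from this repository's configs folder):
#
#   model, config = get_model_from_config('mdx23c', 'configs/config_mdx23c.yaml')
#   model = model.to('cuda' if torch.cuda.is_available() else 'cpu')
#   model.eval()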
def read_audio_transposed(path: str, instr: str = None, skip_err: bool = False) -> Tuple[np.ndarray, int]:
    """Read an audio file and return it as a (channels, samples) array plus the sample rate.

    When `skip_err` is True, a missing or unreadable stem is skipped and (None, None) is returned.
    """
    try:
mix, sr = sf.read(path)
if len(mix.shape) == 1: # Mono audio
mix = np.expand_dims(mix, axis=-1)
return mix.T, sr
except Exception as e:
if skip_err:
print(f"No stem {instr}: skip!")
return None, None
raise RuntimeError(f"Error reading the file at {path}: {e}")
def normalize_audio(audio: np.ndarray) -> Tuple[np.ndarray, Dict[str, float]]:
mono = audio.mean(0)
mean, std = mono.mean(), mono.std()
return (audio - mean) / (std + 1e-8), {"mean": mean, "std": std}
def denormalize_audio(audio: np.ndarray, norm_params: Dict[str, float]) -> np.ndarray:
return audio * norm_params["std"] + norm_params["mean"]
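# Round-trip sketch: normalization statistics come from the mono mix and must be
# reused to undo the scaling after separation (arrays are shaped (channels, samples);
# `mix` and `separated_stem` are hypothetical names):
#
#   normalized_mix, norm_params = normalize_audio(mix)
#   # ... run separation on `normalized_mix` ...
#   restored = denormalize_audio(separated_stem, norm_params)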
def apply_tta(
    config,
    model: nn.Module,
    mix: np.ndarray,
    waveforms_orig: Dict[str, np.ndarray],
    device: str,
    model_type: str,
    progress=None  # Gradio progress object
) -> Dict[str, np.ndarray]:
    """Refine `waveforms_orig` with test-time augmentation: separate a channel-swapped
    and a polarity-inverted copy of the mix, then average all three passes.
    """
    track_proc_list = [mix[::-1].copy(), -mix.copy()]
total_steps = len(track_proc_list)
processed_steps = 0
for i, augmented_mix in enumerate(track_proc_list):
        # Progress update for this TTA step
        processed_steps += 1
        progress_value = round((processed_steps / total_steps) * 50)  # TTA uses the 0-50% range
if progress is not None and callable(getattr(progress, '__call__', None)):
progress(progress_value / 100, desc=f"Applying TTA step {processed_steps}/{total_steps}")
update_progress_html(f"Applying TTA step {processed_steps}/{total_steps}", progress_value)
waveforms = demix(config, model, augmented_mix, device, model_type=model_type, pbar=False, progress=progress)
for el in waveforms:
if i == 0:
                waveforms_orig[el] += waveforms[el][::-1].copy()
else:
waveforms_orig[el] -= waveforms[el]
del waveforms, augmented_mix
gc.collect()
if device.startswith('cuda'):
torch.cuda.empty_cache()
for el in waveforms_orig:
waveforms_orig[el] /= (len(track_proc_list) + 1)
    # TTA finished
if progress is not None and callable(getattr(progress, '__call__', None)):
progress(0.5, desc="TTA completed")
update_progress_html("TTA completed", 50)
return waveforms_orig
def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor:
    """Build a rectangular window with linear fade-in/fade-out edges of `fade_size` samples."""
    fadein = torch.linspace(0, 1, fade_size)
fadeout = torch.linspace(1, 0, fade_size)
window = torch.ones(window_size)
window[-fade_size:] = fadeout
window[:fade_size] = fadein
return window
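# For example, _getWindowingArray(8, 2) produces
#   [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]
# so overlapping chunks are cross-faded at their edges instead of hard-concatenated.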
def demix(
config: ConfigDict,
model: nn.Module,
    mix: np.ndarray,
device: str,
model_type: str,
pbar: bool = False,
    progress=None  # Gradio progress object
) -> Dict[str, np.ndarray]:
    """Separate `mix` into stems by running `model` over overlapping chunks and
    blending the chunk outputs (windowed cross-fade for generic models, plain
    averaging for htdemucs). Returns a dict mapping stem name to waveform array.
    """
logging.info(f"Starting demix for model_type: {model_type}, chunk_size: {config.audio.chunk_size}")
    # Start in FP16 on the CPU
mix = torch.tensor(mix, dtype=torch.float16, device='cpu')
mode = 'demucs' if model_type == 'htdemucs' else 'generic'
    # Processing parameters
if mode == 'demucs':
chunk_size = config.training.samplerate * config.training.segment
num_instruments = len(config.training.instruments)
num_overlap = config.inference.num_overlap
step = chunk_size // num_overlap
else:
chunk_size = config.audio.chunk_size
num_instruments = len(prefer_target_instrument(config))
num_overlap = config.inference.num_overlap
fade_size = chunk_size // 10
step = chunk_size // num_overlap
border = chunk_size - step
length_init = mix.shape[-1]
windowing_array = _getWindowingArray(chunk_size, fade_size).to('cpu', dtype=torch.float16)
if length_init > 2 * border and border > 0:
mix = nn.functional.pad(mix, (border, border), mode="reflect")
    batch_size = getattr(config.inference, 'batch_size', 1)  # default 1 for low-memory setups
    # Move the model to the target device (cuda:0 on ZeroGPU)
model = model.to(device)
model.eval()
    # Compute the total number of chunks
total_chunks = (mix.shape[1] + step - 1) // step
processed_chunks = 0
    with torch.no_grad():  # no gradients needed for inference
with torch.cuda.amp.autocast(enabled=device.startswith('cuda'), dtype=torch.float16):
req_shape = (num_instruments,) + mix.shape
result = torch.zeros(req_shape, dtype=torch.float16, device='cpu')
counter = torch.zeros(req_shape, dtype=torch.float16, device='cpu')
i = 0
batch_data = []
batch_locations = []
start_time = time.time()
while i < mix.shape[1]:
part = mix[:, i:i + chunk_size]
chunk_len = part.shape[-1]
pad_mode = "reflect" if mode == "generic" and chunk_len > chunk_size // 2 else "constant"
part = nn.functional.pad(part, (0, chunk_size - chunk_len), mode=pad_mode, value=0)
batch_data.append(part)
batch_locations.append((i, chunk_len))
i += step
if len(batch_data) >= batch_size or i >= mix.shape[1]:
                    # Move the batch to the GPU
                    arr = torch.stack(batch_data, dim=0).to(device, non_blocking=True)
                    x = model(arr)  # run inference on the GPU
                    # Move the results back to the CPU immediately to free GPU memory
x = x.cpu()
if mode == "generic":
window = windowing_array.clone()
if i - step == 0:
window[:fade_size] = 1
elif i >= mix.shape[1]:
window[-fade_size:] = 1
for j, (start, seg_len) in enumerate(batch_locations):
if mode == "generic":
result[..., start:start + seg_len] += (x[j, ..., :seg_len] * window[..., :seg_len])
counter[..., start:start + seg_len] += window[..., :seg_len]
else:
result[..., start:start + seg_len] += x[j, ..., :seg_len]
counter[..., start:start + seg_len] += 1.0
                    # Progress update
                    processed_chunks += len(batch_data)
                    progress_value = min(round((processed_chunks / total_chunks) * 100), 100)  # whole-percent granularity
if progress is not None and callable(getattr(progress, '__call__', None)):
progress(progress_value / 100, desc=f"Processing chunk {processed_chunks}/{total_chunks}")
update_progress_html(f"Processing chunk {processed_chunks}/{total_chunks}", progress_value)
del arr, x
batch_data.clear()
batch_locations.clear()
gc.collect()
if device.startswith('cuda'):
torch.cuda.empty_cache()
logging.info("Cleared CUDA cache")
elapsed_time = time.time() - start_time
logging.info(f"Demix completed in {elapsed_time:.2f} seconds")
estimated_sources = result / (counter + 1e-8)
estimated_sources = estimated_sources.numpy().astype(np.float32)
np.nan_to_num(estimated_sources, copy=False, nan=0.0)
if mode == "generic" and length_init > 2 * border and border > 0:
estimated_sources = estimated_sources[..., border:-border]
instruments = config.training.instruments if mode == "demucs" else prefer_target_instrument(config)
ret_data = {k: v for k, v in zip(instruments, estimated_sources)}
logging.info("Demix completed successfully")
    # Final progress update
if progress is not None and callable(getattr(progress, '__call__', None)):
progress(1.0, desc="Demix completed")
update_progress_html("Demix completed", 100)
return ret_data
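# End-to-end sketch (paths, model type and the 'vocals' key are placeholders; the
# stem keys actually returned depend on the instruments listed in the config):
#
#   model, config = get_model_from_config('mdx23c', 'configs/config_mdx23c.yaml')
#   mix, sr = read_audio_transposed('input.wav')                    # (channels, samples)
#   stems = demix(config, model, mix, 'cuda', model_type='mdx23c')
#   stems = apply_tta(config, model, mix, stems, 'cuda', 'mdx23c')  # optional TTA refinement
#   sf.write('vocals.wav', stems['vocals'].T, sr)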
def prefer_target_instrument(config: ConfigDict) -> List[str]:
return [config.training.target_instrument] if getattr(config.training, 'target_instrument', None) else config.training.instruments
def load_not_compatible_weights(model: nn.Module, weights: str, verbose: bool = False) -> None:
new_model = model.state_dict()
old_model = torch.load(weights, map_location='cpu')
if 'state' in old_model:
old_model = old_model['state']
if 'state_dict' in old_model:
old_model = old_model['state_dict']
for el in new_model:
if el in old_model and new_model[el].shape == old_model[el].shape:
new_model[el] = old_model[el]
model.load_state_dict(new_model)
def load_lora_weights(model: nn.Module, lora_path: str, device: str = 'cpu') -> None:
lora_state_dict = torch.load(lora_path, map_location=device)
model.load_state_dict(lora_state_dict, strict=False)
def load_start_checkpoint(args: argparse.Namespace, model: nn.Module, type_='train') -> None:
print(f'Start from checkpoint: {args.start_check_point}')
device = 'cpu'
state_dict = torch.load(args.start_check_point, map_location=device, weights_only=True)
if args.model_type in ['htdemucs', 'apollo'] and isinstance(state_dict, dict):
state_dict = state_dict.get('state', state_dict.get('state_dict', state_dict))
model.load_state_dict(state_dict)
if args.lora_checkpoint:
print(f"Loading LoRA weights from: {args.lora_checkpoint}")
load_lora_weights(model, args.lora_checkpoint, device)
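# Sketch of a call site (the checkpoint paths are placeholders; in this repository
# `args` normally comes from the training/inference argument parser):
#
#   args = argparse.Namespace(start_check_point='model.ckpt', model_type='mdx23c',
#                             lora_checkpoint='')
#   load_start_checkpoint(args, model, type_='inference')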
def bind_lora_to_model(config: Dict[str, Any], model: nn.Module) -> nn.Module:
if 'lora' not in config:
raise ValueError("Configuration must contain the 'lora' key with parameters for LoRA.")
replaced_layers = 0
for name, module in model.named_modules():
hierarchy = name.split('.')
layer_name = hierarchy[-1]
if isinstance(module, nn.Linear):
try:
parent_module = model
for submodule_name in hierarchy[:-1]:
parent_module = getattr(parent_module, submodule_name)
setattr(
parent_module,
layer_name,
lora.MergedLinear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
**config['lora']
)
)
replaced_layers += 1
except Exception as e:
print(f"Error replacing layer {name}: {e}")
print(f"Number of layers replaced with LoRA: {replaced_layers}")
return model
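# Sketch of the expected `lora` section in the config; the keyword names are those
# accepted by loralib.MergedLinear (values below are illustrative only):
#
#   config['lora'] = {'r': 8, 'lora_alpha': 16, 'enable_lora': [True]}
#   model = bind_lora_to_model(config, model)
#   lora.mark_only_lora_as_trainable(model)  # typical follow-up when fine-tuning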
def draw_spectrogram(waveform, sample_rate, length, output_file):
    import librosa.display
    import matplotlib.pyplot as plt  # imported lazily, matching the librosa import above
x = waveform[:int(length * sample_rate), :]
X = librosa.stft(x.mean(axis=-1))
Xdb = librosa.amplitude_to_db(np.abs(X), ref=np.max)
fig, ax = plt.subplots()
img = librosa.display.specshow(
Xdb, cmap='plasma', sr=sample_rate, x_axis='time', y_axis='linear', ax=ax
)
ax.set(title='File: ' + os.path.basename(output_file))
fig.colorbar(img, ax=ax, format="%+2.f dB")
if output_file:
plt.savefig(output_file)
plt.close()
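# Usage sketch (the file name is a placeholder; a stereo file is assumed because the
# function indexes the waveform as (samples, channels)):
#
#   audio, sr = sf.read('vocals.wav')
#   draw_spectrogram(audio, sr, length=30, output_file='vocals_spec.png')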