# bark_with_batch_inference/train_hubert.py
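"""Train HuBERTForBarkSemantic: a HuBERT-based quantizer that maps audio
waveforms to Bark semantic tokens.

The script downloads the training dataset (and, optionally, a noise dataset
for augmentation) from the Hugging Face Hub, builds the model on top of a
pretrained HuBERT checkpoint, and delegates the training loop to
core.trainer.train_hubert_quantizer.
"""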
import argparse
import logging
import os
from dataclasses import asdict
import torch
import torch.nn as nn
from core.trainer import train_hubert_quantizer
from core.model.hubert import (
HuBERTForBarkSemantic,
HubertForBarkSemanticConfig,
)
from core.utils import download_dataset_from_hf
from core.bark.constants import HUBERT_OUTPUT_VOCAB_SIZE
# Set up logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
WORKSPACE = "./"
# HF repo id of the dataset
DATASET_REPO_ID = "sleeper371/bark-wave-semantic"
# if publishing checkpoints to HF, the repo id to publish them to
CHECKPOINT_REPO_ID = "sleeper371/hubert-for-bark-semantic"
# name of the noise data file in the HF dataset repo
HF_NOISE_FILE_NAME = "environmental_sound.zip"
# local path to the noise data used to augment the training data
_LOCAL_NOISE_DATA_PATH = "noise_dataset"
# local path to the training audio folder
_LOCAL_TRAINING_DATA_PATH = "wav_semantic_dataset"
# local folder path to save trained checkpoints
_LOCAL_CHECKPOINTS_PATH = "checkpoints"
def prefix_workspace(workspace_path: str, path: str) -> str:
return os.path.join(workspace_path, path)
def parse_args():
parser = argparse.ArgumentParser(description="HuBERT Training Script")
    parser.add_argument(
        "--hubert-checkpoint-name",
        type=str,
        default="facebook/hubert-base-ls960",
        help="HF checkpoint used as the feature-extractor backbone for HuBERTForBarkSemantic",
    )
    parser.add_argument(
        "--feature-layer",
        type=int,
        default=11,
        help="HuBERT layer whose hidden states are used as features for the LSTM",
    )
    parser.add_argument(
        "--mix-precision",
        action="store_true",
        help="train the model with mixed-precision bfloat16 and a gradient scaler",
    )
parser.add_argument(
"--lr", type=float, default=8e-5, help="Learning rate (default: 8e-5)"
)
parser.add_argument(
"--num-epochs",
type=int,
default=3,
help="Number of training epochs (default: 3)",
)
parser.add_argument(
"--train-ratio",
type=float,
default=0.8,
help="Train/validation split ratio (default: 0.8)",
)
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        help="Batch size for training (default: 2)",
    )
parser.add_argument(
"--dataset-file-name",
type=str,
default="short_sentences.zip",
help="name of the dataset file in the HF repo to download",
)
parser.add_argument(
"--save-checkpoint-every",
type=int,
default=1,
help="Save checkpoint every N epochs (default: 1)",
)
    parser.add_argument(
        "--model-bfloat16",
        action="store_true",
        default=False,
        help="convert the model to bfloat16 and train it in that precision",
    )
    parser.add_argument(
        "--augment-data-with-noise",
        action="store_true",
        default=False,
        help="randomly add noise to training samples as a regularization technique",
    )
    parser.add_argument(
        "--augment-prob",
        type=float,
        default=0.5,
        help="probability of adding noise to each audio sample",
    )
parser.add_argument(
"--publish-hf",
action="store_true",
default=False,
help="if set, publish checkpoints to huggingface hub",
)
parser.add_argument(
"--workspace",
type=str,
default=WORKSPACE,
help="workspace folder to store data",
)
    parser.add_argument(
        "--num-samples",
        type=int,
        default=10000,
        help="number of examples to load from the dataset",
    )
return parser.parse_args()
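# Illustrative invocation (example values, using the flags defined above):
#   python train_hubert.py --num-epochs 3 --batch-size 8 \
#       --mix-precision --augment-data-with-noise --augment-prob 0.5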
def ensure_directory(path: str):
"""Create directory if it doesn't exist."""
os.makedirs(path, exist_ok=True)
def calculate_model_memory(model: nn.Module):
"""
Calculate and print the memory usage of a PyTorch model's parameters based on their detected data type.
Args:
model (nn.Module): The PyTorch model to analyze.
"""
# Dictionary mapping PyTorch dtypes to bytes per parameter
    bytes_per_param_dict = {
        torch.float32: 4,  # 32 bits = 4 bytes
        torch.float16: 2,  # 16 bits = 2 bytes
        torch.bfloat16: 2,  # 16 bits = 2 bytes (needed since --model-bfloat16 converts the model)
        torch.int8: 1,  # 8 bits = 1 byte
        torch.int32: 4,  # 32 bits = 4 bytes
        torch.int64: 8,  # 64 bits = 8 bytes
    }
# Detect the data type from the first parameter
param_iter = iter(model.parameters())
try:
first_param = next(param_iter)
dtype = first_param.dtype
except StopIteration:
print("Model has no parameters!")
return
# Get bytes per parameter based on detected dtype
# Default to 4 bytes if dtype not found
bytes_per_param = bytes_per_param_dict.get(dtype, 4)
dtype_name = str(dtype).replace("torch.", "") # Clean up dtype name for printing
    # Count total number of parameters
    total_params = sum(p.numel() for p in model.parameters())
    # Calculate total memory in bytes
    total_memory_bytes = total_params * bytes_per_param
    # Convert to GB for readability
    total_memory_gb = total_memory_bytes / (1024**3)
    # Log results
logger.info(f"Model Memory Usage (Detected dtype: {dtype_name}):")
logger.info(f"Total Parameters: {total_params:,}")
logger.info(f"Total Memory: {total_memory_gb:,.2f} GB")
def main():
args = parse_args()
    # local path to the noise data used to augment the training data
    LOCAL_NOISE_DATA_PATH = prefix_workspace(args.workspace, _LOCAL_NOISE_DATA_PATH)
    # local path to the training audio folder
    LOCAL_TRAINING_DATA_PATH = prefix_workspace(
        args.workspace, _LOCAL_TRAINING_DATA_PATH
    )
    # local folder path to save trained checkpoints
    LOCAL_CHECKPOINTS_PATH = prefix_workspace(args.workspace, _LOCAL_CHECKPOINTS_PATH)
# Create necessary directories
ensure_directory(LOCAL_CHECKPOINTS_PATH)
logger.info("Starting HuBERT training")
device = (
torch.device("cuda")
if torch.cuda.is_available()
else (
torch.device("mps")
if torch.backends.mps.is_available()
else torch.device("cpu")
)
)
config = HubertForBarkSemanticConfig(
vocab_size=HUBERT_OUTPUT_VOCAB_SIZE,
checkpoint_name=args.hubert_checkpoint_name,
feature_layer=args.feature_layer,
num_decoder_layer=6,
)
model = HuBERTForBarkSemantic(
config=config, load_hubert_pretrained_weights=True, device=device
)
if args.model_bfloat16:
model = model.to(torch.bfloat16)
logger.info("Training model in bfloat16 precision")
calculate_model_memory(model)
# Download datasets if needed
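    # (the existence check doubles as a cache: delete the local folder to
    # force a fresh download)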
if not os.path.exists(LOCAL_TRAINING_DATA_PATH):
download_dataset_from_hf(
DATASET_REPO_ID,
args.dataset_file_name,
LOCAL_TRAINING_DATA_PATH,
)
if args.augment_data_with_noise and not os.path.exists(LOCAL_NOISE_DATA_PATH):
download_dataset_from_hf(
DATASET_REPO_ID,
HF_NOISE_FILE_NAME,
LOCAL_NOISE_DATA_PATH,
)
# Train the model
trained_model = train_hubert_quantizer(
model=model,
model_config=asdict(config),
lr=args.lr,
num_epoch=args.num_epochs,
train_ratio=args.train_ratio,
batch_size=args.batch_size,
data_path=LOCAL_TRAINING_DATA_PATH,
checkpoint_path=LOCAL_CHECKPOINTS_PATH,
save_checkpoint_every=args.save_checkpoint_every,
augment_data_with_noise=args.augment_data_with_noise,
augment_prob=args.augment_prob,
noise_data_path=LOCAL_NOISE_DATA_PATH,
publish_hf=args.publish_hf,
publish_to_repo=CHECKPOINT_REPO_ID,
device=device,
num_samples=args.num_samples,
enable_grad_scaler=args.mix_precision,
)
logger.info("Training completed")
return trained_model
if __name__ == "__main__":
main()