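"""Training entry point for HuBERTForBarkSemantic.

Downloads the wav/semantic-token dataset (and, optionally, a noise dataset
for augmentation) from the Hugging Face Hub, then trains a HuBERT-based
quantizer that maps audio to Bark semantic tokens. See parse_args() for the
available options.
"""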
import argparse
import logging
import os
from dataclasses import asdict

import torch
import torch.nn as nn

from core.trainer import train_hubert_quantizer
from core.model.hubert import (
    HuBERTForBarkSemantic,
    HubertForBarkSemanticConfig,
)
from core.utils import download_dataset_from_hf
from core.bark.constants import HUBERT_OUTPUT_VOCAB_SIZE
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

WORKSPACE = "./"
# HF repo id of the dataset
DATASET_REPO_ID = "sleeper371/bark-wave-semantic"
# if publishing checkpoints to HF, they are pushed to this repo id
CHECKPOINT_REPO_ID = "sleeper371/hubert-for-bark-semantic"
# name of the noise data file in the HF dataset repo
HF_NOISE_FILE_NAME = "environmental_sound.zip"
# local path to the noise data used to augment the training data
_LOCAL_NOISE_DATA_PATH = "noise_dataset"
# local path to the training audio folder
_LOCAL_TRAINING_DATA_PATH = "wav_semantic_dataset"
# local folder in which to save trained checkpoints
_LOCAL_CHECKPOINTS_PATH = "checkpoints"
def prefix_workspace(workspace_path: str, path: str) -> str:
    return os.path.join(workspace_path, path)

def parse_args():
    parser = argparse.ArgumentParser(description="HuBERT Training Script")
    parser.add_argument(
        "--hubert-checkpoint-name",
        type=str,
        default="facebook/hubert-base-ls960",
        help="checkpoint name used as the feature-extractor backbone for the custom HuBERT",
    )
    parser.add_argument(
        "--feature-layer",
        type=int,
        default=11,
        help="layer whose features are fed to the LSTM",
    )
    parser.add_argument(
        "--mix-precision",
        action="store_true",
        help="train with mixed-precision bfloat16 and a gradient scaler",
    )
    parser.add_argument(
        "--lr", type=float, default=8e-5, help="Learning rate (default: 8e-5)"
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=3,
        help="Number of training epochs (default: 3)",
    )
    parser.add_argument(
        "--train-ratio",
        type=float,
        default=0.8,
        help="Train/validation split ratio (default: 0.8)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        help="Batch size for training (default: 2)",
    )
    parser.add_argument(
        "--dataset-file-name",
        type=str,
        default="short_sentences.zip",
        help="name of the dataset file to download from the HF repo",
    )
    parser.add_argument(
        "--save-checkpoint-every",
        type=int,
        default=1,
        help="Save a checkpoint every N epochs (default: 1)",
    )
    parser.add_argument(
        "--model-bfloat16",
        action="store_true",
        default=False,
        help="convert the model to bfloat16 and train it in that precision",
    )
    parser.add_argument(
        "--augment-data-with-noise",
        action="store_true",
        default=False,
        help="randomly add noise to training samples as a regularization technique",
    )
    parser.add_argument(
        "--augment-prob",
        type=float,
        default=0.5,
        help="probability with which noise is added to an audio sample",
    )
    parser.add_argument(
        "--publish-hf",
        action="store_true",
        default=False,
        help="if set, publish checkpoints to the Hugging Face Hub",
    )
    parser.add_argument(
        "--workspace",
        type=str,
        default=WORKSPACE,
        help="workspace folder in which to store data",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=10000,
        help="number of examples to load from the dataset",
    )
    return parser.parse_args()

def ensure_directory(path: str):
    """Create directory if it doesn't exist."""
    os.makedirs(path, exist_ok=True)

def calculate_model_memory(model: nn.Module):
    """
    Calculate and log the memory usage of a PyTorch model's parameters
    based on their detected data type.

    Args:
        model (nn.Module): The PyTorch model to analyze.
    """
    # Dictionary mapping PyTorch dtypes to bytes per parameter
    bytes_per_param_dict = {
        torch.float32: 4,  # 32 bits = 4 bytes
        torch.float16: 2,  # 16 bits = 2 bytes
        torch.bfloat16: 2,  # 16 bits = 2 bytes
        torch.int8: 1,  # 8 bits = 1 byte
        torch.int32: 4,  # 32 bits = 4 bytes
        torch.int64: 8,  # 64 bits = 8 bytes
    }

    # Detect the data type from the first parameter
    param_iter = iter(model.parameters())
    try:
        first_param = next(param_iter)
        dtype = first_param.dtype
    except StopIteration:
        logger.warning("Model has no parameters!")
        return

    # Get bytes per parameter based on detected dtype;
    # default to 4 bytes if the dtype is not in the table
    bytes_per_param = bytes_per_param_dict.get(dtype, 4)
    dtype_name = str(dtype).replace("torch.", "")  # Clean up dtype name for logging

    # Count total number of parameters
    total_params = sum(p.numel() for p in model.parameters())

    # Calculate total memory in bytes
    total_memory_bytes = total_params * bytes_per_param

    # Convert to KB, MB, and GB for readability
    total_memory_kb = total_memory_bytes / 1024
    total_memory_mb = total_memory_kb / 1024
    total_memory_gb = total_memory_mb / 1024

    # Log results
    logger.info(f"Model Memory Usage (Detected dtype: {dtype_name}):")
    logger.info(f"Total Parameters: {total_params:,}")
    logger.info(f"Total Memory: {total_memory_gb:,.2f} GB")
def main():
    args = parse_args()

    # local path to the noise data used to augment the training data
    LOCAL_NOISE_DATA_PATH = prefix_workspace(args.workspace, _LOCAL_NOISE_DATA_PATH)
    # local path to the training audio folder
    LOCAL_TRAINING_DATA_PATH = prefix_workspace(
        args.workspace, _LOCAL_TRAINING_DATA_PATH
    )
    # local folder in which to save trained checkpoints
    LOCAL_CHECKPOINTS_PATH = prefix_workspace(args.workspace, _LOCAL_CHECKPOINTS_PATH)

    # Create necessary directories
    ensure_directory(LOCAL_CHECKPOINTS_PATH)

    logger.info("Starting HuBERT training")

    # Prefer CUDA, then Apple MPS, falling back to CPU
    device = (
        torch.device("cuda")
        if torch.cuda.is_available()
        else (
            torch.device("mps")
            if torch.backends.mps.is_available()
            else torch.device("cpu")
        )
    )

    config = HubertForBarkSemanticConfig(
        vocab_size=HUBERT_OUTPUT_VOCAB_SIZE,
        checkpoint_name=args.hubert_checkpoint_name,
        feature_layer=args.feature_layer,
        num_decoder_layer=6,
    )
    model = HuBERTForBarkSemantic(
        config=config, load_hubert_pretrained_weights=True, device=device
    )

    if args.model_bfloat16:
        model = model.to(torch.bfloat16)
        logger.info("Training model in bfloat16 precision")

    calculate_model_memory(model)

    # Download datasets if needed
    if not os.path.exists(LOCAL_TRAINING_DATA_PATH):
        download_dataset_from_hf(
            DATASET_REPO_ID,
            args.dataset_file_name,
            LOCAL_TRAINING_DATA_PATH,
        )
    if args.augment_data_with_noise and not os.path.exists(LOCAL_NOISE_DATA_PATH):
        download_dataset_from_hf(
            DATASET_REPO_ID,
            HF_NOISE_FILE_NAME,
            LOCAL_NOISE_DATA_PATH,
        )

    # Train the model
    trained_model = train_hubert_quantizer(
        model=model,
        model_config=asdict(config),
        lr=args.lr,
        num_epoch=args.num_epochs,
        train_ratio=args.train_ratio,
        batch_size=args.batch_size,
        data_path=LOCAL_TRAINING_DATA_PATH,
        checkpoint_path=LOCAL_CHECKPOINTS_PATH,
        save_checkpoint_every=args.save_checkpoint_every,
        augment_data_with_noise=args.augment_data_with_noise,
        augment_prob=args.augment_prob,
        noise_data_path=LOCAL_NOISE_DATA_PATH,
        publish_hf=args.publish_hf,
        publish_to_repo=CHECKPOINT_REPO_ID,
        device=device,
        num_samples=args.num_samples,
        enable_grad_scaler=args.mix_precision,
    )

    logger.info("Training completed")
    return trained_model


if __name__ == "__main__":
    main()
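
# Example invocation (hypothetical script name and flag values; adjust to your setup):
#   python train_hubert.py --batch-size 16 --num-epochs 3 --lr 8e-5 \
#       --augment-data-with-noise --augment-prob 0.5 --mix-precision
# Checkpoints are written to <workspace>/checkpoints every
# --save-checkpoint-every epochs and pushed to the Hub when --publish-hf is set.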