Spaces:
Running
on
Zero
Running
on
Zero
import argparse | |
import logging | |
import os | |
from typing import Optional | |
from core.bark.generate_audio_semantic_dataset import ( | |
generate_wav_semantic_dataset, | |
BarkGenerationConfig, | |
) | |
from core.utils import upload_file_to_hf, zip_folder | |
# Configure root logging once at import time: INFO and above, each record
# prefixed with a timestamp and its level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-scoped logger used by the functions in this file.
logger = logging.getLogger(__name__)
def parse_dataset_args(args_list=None):
    """Parse command-line arguments for audio semantic dataset creation.

    Args:
        args_list: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and programmatic calls).

    Returns:
        argparse.Namespace whose attribute names (text_file, batch_size,
        output_dir, max_tokens, use_small_model, save_raw_audio, publish_hf,
        repo_id, path_in_repo, silent) match the keyword parameters of
        ``create_audio_semantic_dataset``.

    Raises:
        SystemExit: via ``parser.error`` on invalid arguments, including
            ``--publish-hf`` given without ``--repo-id`` (the upload would
            otherwise be skipped silently).
    """
    parser = argparse.ArgumentParser(description="Audio Semantic Dataset Creation")
    parser.add_argument(
        "--text-file",
        type=str,
        default="data/test_data.txt",
        help="Path to text file for dataset generation (default: %(default)s)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=2,
        # %(default)s is expanded by argparse, so the help can never disagree
        # with the real default again (it previously claimed "default: 1").
        help="Batch size for processing (default: %(default)s)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./dataset",
        help="Output directory for generated files (default: %(default)s)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=256,
        # NOTE(review): create_audio_semantic_dataset does not currently
        # forward this value to the generator — confirm intended use.
        help="Maximum tokens per example (default: %(default)s)",
    )
    parser.add_argument(
        "--use-small-model",
        action="store_true",
        help="Use small model for generation",
    )
    parser.add_argument(
        "--save-raw-audio",
        action="store_true",
        help="Store generated audio as .wav instead of .npz",
    )
    parser.add_argument(
        "--publish-hf",
        action="store_true",
        help="Publish dataset to HuggingFace Hub",
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        help="HuggingFace repo ID to publish to",
    )
    parser.add_argument(
        "--path-in-repo",
        type=str,
        help="Path in HF repo",
        default=None,
    )
    parser.add_argument(
        "--silent", action="store_true", help="Suppress progress output"
    )
    args = parser.parse_args(args_list)
    # Publishing needs a target repo; fail before any expensive generation
    # instead of silently skipping the upload afterwards.
    if args.publish_hf and not args.repo_id:
        parser.error("--publish-hf requires --repo-id")
    return args
def create_audio_semantic_dataset(
    text_file: str,
    output_dir: str = "./dataset",
    batch_size: int = 1,
    max_tokens: int = 256,
    use_small_model: bool = False,
    save_raw_audio: bool = False,
    publish_hf: bool = False,
    repo_id: Optional[str] = None,
    path_in_repo: Optional[str] = None,
    silent: bool = False,
) -> None:
    """Create an audio semantic dataset from a text file.

    Can be called directly with parameters or via the command line using
    ``parse_dataset_args()``.

    Args:
        text_file: Path to the input text file.
        output_dir: Directory to save the generated dataset; created if it
            does not exist (only after ``text_file`` has been validated).
        batch_size: Batch size for processing.
        max_tokens: Maximum tokens per example. NOTE(review): this value is
            not forwarded to ``generate_wav_semantic_dataset`` and currently
            has no effect; kept for interface compatibility.
        use_small_model: Whether to use the small model.
        save_raw_audio: Save as raw audio (.wav) instead of .npz.
        publish_hf: Whether to zip ``output_dir`` and upload it to the
            HuggingFace Hub. Requires ``repo_id``; without it the upload is
            skipped and a warning is logged before generation starts.
        repo_id: HF repo ID to publish to.
        path_in_repo: Path in the HF repo.
        silent: Suppress progress output.

    Raises:
        FileNotFoundError: If ``text_file`` is not an existing file.
        RuntimeError: If zipping ``output_dir`` fails.
    """
    # Validate input first so a bad path does not leave an empty output
    # directory behind.
    if not os.path.isfile(text_file):
        raise FileNotFoundError(f"Text file not found: {text_file}")
    if publish_hf and not repo_id:
        # Surface the misconfiguration now rather than after a long run.
        logger.warning(
            "publish_hf=True but no repo_id was given; the dataset will not be uploaded"
        )

    os.makedirs(output_dir, exist_ok=True)

    logger.info("Starting dataset generation from %s", text_file)
    # None temperatures are passed through as-is; their meaning is defined
    # by BarkGenerationConfig (presumably library defaults — confirm).
    generation_config = BarkGenerationConfig(
        temperature=None,
        generate_coarse_temperature=None,
        generate_fine_temperature=None,
        use_small_model=use_small_model,
    )
    generate_wav_semantic_dataset(
        text_file_path=text_file,
        generation_config=generation_config,
        batch_size=batch_size,
        save_path=output_dir,
        save_data_as_raw_audio=save_raw_audio,
        silent=silent,
    )
    logger.info("Dataset generation completed")

    if publish_hf and repo_id:
        logger.info("Publishing dataset to huggingface hub")
        # The archive is written to the current working directory, not
        # inside output_dir.
        zip_path = "./dataset.zip"
        if not zip_folder(output_dir, zip_path):
            raise RuntimeError(f"Unable to zip folder {output_dir}")
        upload_file_to_hf(zip_path, repo_id, "dataset", path_in_repo=path_in_repo)
if __name__ == "__main__":
    # Script entry point: read CLI flags, then forward each one to the
    # dataset builder under the keyword parameter of the same name.
    cli_args = parse_dataset_args()
    dataset_options = {
        "text_file": cli_args.text_file,
        "output_dir": cli_args.output_dir,
        "batch_size": cli_args.batch_size,
        "max_tokens": cli_args.max_tokens,
        "use_small_model": cli_args.use_small_model,
        "save_raw_audio": cli_args.save_raw_audio,
        "publish_hf": cli_args.publish_hf,
        "repo_id": cli_args.repo_id,
        "path_in_repo": cli_args.path_in_repo,
        "silent": cli_args.silent,
    }
    create_audio_semantic_dataset(**dataset_options)