Spaces:
Running
Running
# This module handles the data loading and preprocessing for various phoneme transcription datasets. | |
import torch | |
import torchaudio | |
import zipfile | |
from pathlib import Path | |
# Every data path is anchored at this module's own directory rather than the
# process's working directory, so the dataset is found wherever code runs from.
CURRENT_DIR = Path(__file__).absolute().parent

# Constants
DATA_DIR = CURRENT_DIR.joinpath("data")
TIMIT_PATH = DATA_DIR.joinpath("TIMIT.zip")
# Abstract data manager class
class DataManager:
    """Common interface for dataset managers.

    This base class provides no behaviour of its own: each method raises
    ``NotImplementedError`` and concrete managers are expected to override
    all three.
    """

    def get_file_list(self, subset: str) -> list[str]:
        """Return the names of the audio files that belong to ``subset``."""
        raise NotImplementedError()

    def load_audio(self, filename: str) -> torch.Tensor:
        """Load ``filename`` and return its preprocessed waveform tensor."""
        raise NotImplementedError()

    def get_phonemes(self, filename: str) -> str:
        """Return the phoneme sequence transcribed for ``filename``."""
        raise NotImplementedError()
# Implement datasets
class TimitDataManager(DataManager):
    """Handles all TIMIT dataset operations.

    Audio (``.WAV``) and phone transcriptions (``.PHN``) are read directly out
    of the TIMIT zip archive without extracting it, and TIMIT phone labels are
    converted into a simplified IPA string.
    """

    # TIMIT to IPA mapping with direct simplifications.
    # Any label that maps to "" (or is absent from this table) is dropped
    # from the phoneme string produced by get_phonemes.
    _TIMIT_TO_IPA = {
        # Vowels (simplified)
        "aa": "ɑ",
        "ae": "æ",
        "ah": "ʌ",
        "ao": "ɔ",
        "aw": "aʊ",
        "ay": "aɪ",
        "eh": "ɛ",
        "er": "ɹ",  # Simplified from 'ɝ'
        "ey": "eɪ",
        "ih": "ɪ",
        "ix": "i",  # Simplified from 'ɨ'
        "iy": "i",
        "ow": "oʊ",
        "oy": "ɔɪ",
        "uh": "ʊ",
        "uw": "u",
        "ux": "u",  # Simplified from 'ʉ'
        "ax": "ə",
        "ax-h": "ə",  # Simplified from 'ə̥'
        "axr": "ɹ",  # Simplified from 'ɚ'
        # Stops: the closure label ("bcl", "dcl", ...) carries the symbol and
        # the release label ("b", "d", ...) maps to "", so a closure+release
        # pair yields one symbol.
        # NOTE(review): a release that is not preceded by its closure label
        # therefore produces nothing; confirm this is intended.
        "b": "",
        "bcl": "b",
        "d": "",
        "dcl": "d",
        "g": "",
        "gcl": "g",
        "p": "",
        "pcl": "p",
        "t": "",
        "tcl": "t",
        "k": "",
        "kcl": "k",
        # Flap and glottal stop
        "dx": "ɾ",
        "q": "ʔ",
        # Affricates and fricatives
        "jh": "dʒ",
        "ch": "tʃ",
        "s": "s",
        "sh": "ʃ",
        "z": "z",
        "zh": "ʒ",
        "f": "f",
        "th": "θ",
        "v": "v",
        "dh": "ð",
        "hh": "h",
        "hv": "h",  # Simplified from 'ɦ'
        # Nasals (simplified)
        "m": "m",
        "n": "n",
        "ng": "ŋ",
        "em": "m",  # Simplified from 'm̩'
        "en": "n",  # Simplified from 'n̩'
        "eng": "ŋ",  # Simplified from 'ŋ̍'
        "nx": "ɾ",  # Simplified from 'ɾ̃'
        # Semivowels and Glides
        "l": "l",
        "r": "ɹ",
        "w": "w",
        "wh": "ʍ",
        "y": "j",
        "el": "l",  # Simplified from 'l̩'
        # Special
        "epi": "",  # Remove epenthetic silence
        "h#": "",  # Remove start/end silence
        "pau": "",  # Remove pause
    }

    def __init__(self, timit_path: Path):
        """Record the archive location and check that it exists.

        The archive itself is not opened here; see the ``_zip`` property.

        Args:
            timit_path: Path to the TIMIT zip archive.

        Raises:
            FileNotFoundError: If ``timit_path`` does not exist.
        """
        self.timit_path = timit_path
        # Lazily-opened zipfile.ZipFile handle (None until first access).
        self._zip_ = None
        print(f"TimitDataManager initialized with path: {self.timit_path.absolute()}")
        if not self.timit_path.exists():
            raise FileNotFoundError(
                f"TIMIT dataset not found at {self.timit_path.absolute()}. Try running ./scripts/download_data_lfs.sh again."
            )
        print("TIMIT dataset file exists!")

    @property
    def _zip(self) -> zipfile.ZipFile:
        """Return the archive handle, opening it on first access.

        This must be a property. Every caller uses it as an attribute
        (``self._zip.namelist()``, ``self._zip.open(...)``); as a plain method
        those expressions hit a bound method and raise ``AttributeError``.
        """
        if self._zip_ is None:
            self._zip_ = zipfile.ZipFile(self.timit_path, "r")
        return self._zip_

    def close(self) -> None:
        """Close the archive handle if it has been opened.

        Safe to call repeatedly. A later data access reopens the archive.
        """
        if self._zip_ is not None:
            self._zip_.close()
            self._zip_ = None

    def get_file_list(self, subset: str) -> list[str]:
        """Get list of WAV files for given subset.

        Args:
            subset: Case-insensitive substring matched against each archive
                member path (e.g. ``"train"`` or ``"test"``).

        Returns:
            Archive member names ending in ``.WAV`` (case-sensitive) whose
            path contains ``subset``, in archive order.
        """
        needle = subset.lower()
        files = [
            f
            for f in self._zip.namelist()
            if f.endswith(".WAV") and needle in f.lower()
        ]
        print(f"Found {len(files)} WAV files in {subset} subset")
        if files:
            print("First 3 files:", files[:3])
        return files

    def load_audio(self, filename: str) -> torch.Tensor:
        """Load an audio file from the archive and normalize it.

        The audio is down-mixed to one channel by averaging, resampled to
        16000 Hz when needed, and standardized to zero mean and (approximately)
        unit variance.

        Args:
            filename: Archive member name of a ``.WAV`` file.

        Returns:
            A 2-D waveform tensor with a single channel row.
        """
        with self._zip.open(filename) as wav_file:
            waveform, sample_rate = torchaudio.load(wav_file)  # type: ignore
        if waveform.shape[0] > 1:
            # Average the channels into a single mono channel.
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
        # Per-utterance standardization; the epsilon guards against silent clips.
        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-7)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        return waveform

    def get_phonemes(self, filename: str) -> str:
        """Get cleaned phoneme sequence from PHN file and convert to IPA.

        Args:
            filename: Archive member name of a ``.WAV`` file. The transcription
                is read from the member obtained by replacing ``.WAV`` with
                ``.PHN`` in that name.

        Returns:
            The IPA symbols concatenated without separators. Labels that are
            missing from ``_TIMIT_TO_IPA`` or map to ``""`` are dropped.

        Raises:
            KeyError: If the ``.PHN`` member is not present in the archive.
            ValueError: If a non-blank line does not have exactly three fields.
        """
        phn_file = filename.replace(".WAV", ".PHN")
        with self._zip.open(phn_file) as f:
            text = f.read().decode("utf-8")

        phonemes = []
        for line in text.splitlines():
            if not line.strip():
                continue
            # Three whitespace-separated fields; only the third (the phone
            # label) is used. Presumably the first two are the start/end
            # sample offsets of the TIMIT PHN format.
            _, _, phone = line.split()
            phone = self._remove_stress_mark(phone)
            # Map the TIMIT label to (simplified) IPA; unknown labels become "".
            ipa = self._TIMIT_TO_IPA.get(phone.lower(), "")
            if ipa:
                phonemes.append(ipa)
        return "".join(phonemes)  # Join without spaces for IPA

    def _remove_stress_mark(self, text: str) -> str:
        """Remove the combining double inverted breve (͡, the IPA tie bar) from text.

        Despite the method name, the character removed is the tie bar, not a
        stress mark. The name is kept for compatibility.

        Raises:
            TypeError: If ``text`` is not a ``str``.
        """
        if not isinstance(text, str):
            raise TypeError("Input must be string")
        return text.replace("͡", "")
# Initialize data managers.
# This shared instance is built when the module is imported, so importing
# the module raises FileNotFoundError if TIMIT_PATH does not exist.
timit_manager = TimitDataManager(timit_path=TIMIT_PATH)