"""Data loading and preprocessing for phoneme transcription datasets.

Currently only the TIMIT corpus is supported; it is read directly from a
ZIP archive without extracting it to disk.
"""

import torch
import torchaudio
import zipfile
from pathlib import Path
from typing import Optional

# Get absolute path
CURRENT_DIR = Path(__file__).parent.absolute()

# Constants
DATA_DIR = CURRENT_DIR / "data"
TIMIT_PATH = DATA_DIR / "TIMIT.zip"
# Sample rate (Hz) every waveform is resampled to by ``load_audio``.
TARGET_SAMPLE_RATE = 16000


# Abstract data manager class
class DataManager:
    """Abstract interface for dataset operations.

    Subclasses override all three methods; the base implementations only
    raise ``NotImplementedError``.
    """

    def get_file_list(self, subset: str) -> list[str]:
        """Return the list of audio files belonging to ``subset``."""
        raise NotImplementedError

    def load_audio(self, filename: str) -> torch.Tensor:
        """Load and preprocess the audio file ``filename``."""
        raise NotImplementedError

    def get_phonemes(self, filename: str) -> str:
        """Return the phoneme sequence associated with ``filename``."""
        raise NotImplementedError


# Implement datasets
class TimitDataManager(DataManager):
    """Handles all TIMIT dataset operations on a zipped copy of the corpus.

    The archive is opened lazily on first access and kept open for reuse.
    Call ``close()`` (or use the instance as a context manager) to release
    the underlying file handle; the archive is reopened transparently if
    the manager is used again afterwards.
    """

    # TIMIT to IPA mapping with direct simplifications.
    # Labels mapped to "" are dropped from the output of ``get_phonemes``.
    _TIMIT_TO_IPA = {
        # Vowels (simplified)
        "aa": "ɑ",
        "ae": "æ",
        "ah": "ʌ",
        "ao": "ɔ",
        "aw": "aʊ",
        "ay": "aɪ",
        "eh": "ɛ",
        "er": "ɹ",  # Simplified from 'ɝ'
        "ey": "eɪ",
        "ih": "ɪ",
        "ix": "i",  # Simplified from 'ɨ'
        "iy": "i",
        "ow": "oʊ",
        "oy": "ɔɪ",
        "uh": "ʊ",
        "uw": "u",
        "ux": "u",  # Simplified from 'ʉ'
        "ax": "ə",
        "ax-h": "ə",  # Simplified from 'ə̥'
        "axr": "ɹ",  # Simplified from 'ɚ'
        # Consonants: the closure label carries the stop symbol and the
        # release label maps to "", so a closure+release pair yields one symbol.
        "b": "",
        "bcl": "b",
        "d": "",
        "dcl": "d",
        "g": "",
        "gcl": "g",
        "p": "",
        "pcl": "p",
        "t": "",
        "tcl": "t",
        "k": "",
        "kcl": "k",
        "dx": "ɾ",
        "q": "ʔ",
        # Fricatives
        "jh": "dʒ",
        "ch": "tʃ",
        "s": "s",
        "sh": "ʃ",
        "z": "z",
        "zh": "ʒ",
        "f": "f",
        "th": "θ",
        "v": "v",
        "dh": "ð",
        "hh": "h",
        "hv": "h",  # Simplified from 'ɦ'
        # Nasals (simplified)
        "m": "m",
        "n": "n",
        "ng": "ŋ",
        "em": "m",  # Simplified from 'm̩'
        "en": "n",  # Simplified from 'n̩'
        "eng": "ŋ",  # Simplified from 'ŋ̍'
        "nx": "ɾ",  # Simplified from 'ɾ̃'
        # Semivowels and Glides
        "l": "l",
        "r": "ɹ",
        "w": "w",
        "wh": "ʍ",
        "y": "j",
        "el": "l",  # Simplified from 'l̩'
        # Special
        "epi": "",  # Remove epenthetic silence
        "h#": "",  # Remove start/end silence
        "pau": "",  # Remove pause
    }

    def __init__(self, timit_path: Path):
        """Remember the archive location and verify that it exists.

        Args:
            timit_path: Path to the TIMIT ZIP archive.

        Raises:
            FileNotFoundError: If ``timit_path`` does not exist.
        """
        self.timit_path = timit_path
        # Lazily populated by the ``_zip`` property.
        self._zip_: Optional[zipfile.ZipFile] = None
        print(f"TimitDataManager initialized with path: {self.timit_path.absolute()}")
        if not self.timit_path.exists():
            raise FileNotFoundError(
                f"TIMIT dataset not found at {self.timit_path.absolute()}. Try running ./scripts/download_data_lfs.sh again."
            )
        print("TIMIT dataset file exists!")

    def __enter__(self) -> "TimitDataManager":
        """Return the manager itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the archive when leaving a ``with`` block."""
        self.close()

    def close(self) -> None:
        """Close the archive handle if it is open; safe to call repeatedly."""
        if self._zip_ is not None:
            self._zip_.close()
            self._zip_ = None

    @property
    def _zip(self) -> zipfile.ZipFile:
        """The open archive handle, created on first access."""
        # Explicit None check: the sentinel is None, not "any falsy value".
        if self._zip_ is None:
            self._zip_ = zipfile.ZipFile(self.timit_path, "r")
        return self._zip_

    def get_file_list(self, subset: str) -> list[str]:
        """Get list of WAV files for given subset.

        A member is selected when its name ends with ``.WAV`` (case-sensitive)
        and contains ``subset`` anywhere in its path, compared
        case-insensitively.

        Args:
            subset: Substring identifying the subset within member paths.

        Returns:
            Matching archive member names, in archive order.
        """
        wanted = subset.lower()
        files = [
            name
            for name in self._zip.namelist()
            if name.endswith(".WAV") and wanted in name.lower()
        ]
        print(f"Found {len(files)} WAV files in {subset} subset")
        if files:
            print("First 3 files:", files[:3])
        return files

    def load_audio(self, filename: str) -> torch.Tensor:
        """Load and preprocess audio file.

        The audio is down-mixed to a single channel, resampled to
        ``TARGET_SAMPLE_RATE`` when its native rate differs, and normalised
        to zero mean and (approximately) unit standard deviation.

        Args:
            filename: Archive member name of the audio file.

        Returns:
            The preprocessed waveform with a leading channel dimension.
        """
        with self._zip.open(filename) as wav_file:
            waveform, sample_rate = torchaudio.load(wav_file)  # type: ignore

        # Average the channels when the recording is not mono.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if sample_rate != TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(sample_rate, TARGET_SAMPLE_RATE)
            waveform = resampler(waveform)

        # The epsilon guards against division by zero on constant signals.
        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-7)

        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        return waveform

    def get_phonemes(self, filename: str) -> str:
        """Get cleaned phoneme sequence from PHN file and convert to IPA.

        The transcription is looked up by replacing ``.WAV`` with ``.PHN`` in
        ``filename``. Each non-blank line must hold exactly three
        whitespace-separated fields, the last being the TIMIT phone label.
        Labels missing from ``_TIMIT_TO_IPA`` or mapped to an empty string
        are skipped.

        Args:
            filename: Archive member name of the audio file.

        Returns:
            The IPA symbols concatenated without separators.

        Raises:
            ValueError: If a non-blank line does not have three fields.
        """
        phn_file = filename.replace(".WAV", ".PHN")
        # Read everything up front so the archive member is closed before parsing.
        with self._zip.open(phn_file) as f:
            text = f.read().decode("utf-8")

        phonemes = []
        for line_no, line in enumerate(text.splitlines(), start=1):
            if not line.strip():
                continue
            fields = line.split()
            if len(fields) != 3:
                # Same exception type as a failed tuple-unpack, but with context.
                raise ValueError(
                    f"{phn_file}:{line_no}: expected 3 whitespace-separated "
                    f"fields, got {len(fields)}: {line!r}"
                )
            phone = self._remove_stress_mark(fields[2])
            # Convert to IPA instead of using simplify_timit
            ipa = self._TIMIT_TO_IPA.get(phone.lower(), "")
            if ipa:
                phonemes.append(ipa)
        return "".join(phonemes)  # Join without spaces for IPA

    def _remove_stress_mark(self, text: str) -> str:
        """Removes the combining double inverted breve (͡) from text.

        Despite the method name, this is the only character stripped; the
        character is commonly used as a tie bar in IPA.

        Raises:
            TypeError: If ``text`` is not a string.
        """
        if not isinstance(text, str):
            raise TypeError("Input must be string")
        return text.replace("͡", "")


# Initialize data managers
timit_manager = TimitDataManager(TIMIT_PATH)