# Source: SanderGi, commit 38024bc ("clean up and make contribution ready")
# This module handles the data loading and preprocessing for various phoneme transcription datasets.
import torch
import torchaudio
import zipfile
from pathlib import Path
# Absolute directory that contains this module; dataset paths hang off it
CURRENT_DIR = Path(__file__).parent.absolute()
# Dataset locations
DATA_DIR = CURRENT_DIR.joinpath("data")
TIMIT_PATH = DATA_DIR.joinpath("TIMIT.zip")
# Abstract data manager class
class DataManager:
    """Common interface implemented by every dataset manager.

    Concrete subclasses supply file enumeration, audio loading and phoneme
    lookup; each base method here only raises ``NotImplementedError``.
    """

    def get_file_list(self, subset: str) -> list[str]:
        """Return the names of the audio files belonging to *subset*."""
        raise NotImplementedError()

    def load_audio(self, filename: str) -> torch.Tensor:
        """Read *filename* and return its preprocessed waveform tensor."""
        raise NotImplementedError()

    def get_phonemes(self, filename: str) -> str:
        """Return the phoneme string that accompanies *filename*."""
        raise NotImplementedError()
# Implement datasets
class TimitDataManager(DataManager):
    """Handles all TIMIT dataset operations.

    Audio and phone transcriptions are read straight from the TIMIT zip
    archive. The archive is opened lazily on first access. The handle can be
    released with :meth:`close`, or by using the manager as a context manager.
    """

    # TIMIT to IPA mapping with direct simplifications
    _TIMIT_TO_IPA = {
        # Vowels (simplified)
        "aa": "ɑ",
        "ae": "æ",
        "ah": "ʌ",
        "ao": "ɔ",
        "aw": "aʊ",
        "ay": "aɪ",
        "eh": "ɛ",
        "er": "ɹ",  # Simplified from 'ɝ'
        "ey": "eɪ",
        "ih": "ɪ",
        "ix": "i",  # Simplified from 'ɨ'
        "iy": "i",
        "ow": "oʊ",
        "oy": "ɔɪ",
        "uh": "ʊ",
        "uw": "u",
        "ux": "u",  # Simplified from 'ʉ'
        "ax": "ə",
        "ax-h": "ə",  # Simplified from 'ə̥'
        "axr": "ɹ",  # Simplified from 'ɚ'
        # Consonants
        "b": "",
        "bcl": "b",
        "d": "",
        "dcl": "d",
        "g": "",
        "gcl": "g",
        "p": "",
        "pcl": "p",
        "t": "",
        "tcl": "t",
        "k": "",
        "kcl": "k",
        "dx": "ɾ",
        "q": "ʔ",
        # Fricatives
        "jh": "dʒ",
        "ch": "tʃ",
        "s": "s",
        "sh": "ʃ",
        "z": "z",
        "zh": "ʒ",
        "f": "f",
        "th": "θ",
        "v": "v",
        "dh": "ð",
        "hh": "h",
        "hv": "h",  # Simplified from 'ɦ'
        # Nasals (simplified)
        "m": "m",
        "n": "n",
        "ng": "ŋ",
        "em": "m",  # Simplified from 'm̩'
        "en": "n",  # Simplified from 'n̩'
        "eng": "ŋ",  # Simplified from 'ŋ̍'
        "nx": "ɾ",  # Simplified from 'ɾ̃'
        # Semivowels and Glides
        "l": "l",
        "r": "ɹ",
        "w": "w",
        "wh": "ʍ",
        "y": "j",
        "el": "l",  # Simplified from 'l̩'
        # Special
        "epi": "",  # Remove epenthetic silence
        "h#": "",  # Remove start/end silence
        "pau": "",  # Remove pause
    }

    # Maps each stop-closure code to the codes that can release it. Besides
    # the matching stop, TIMIT writes the closure of the affricates jh and ch
    # as dcl and tcl. When the release follows, the release supplies the
    # symbol. Without this, "dcl jh" would come out as "ddʒ".
    _CLOSURE_RELEASES = {
        "bcl": ("b",),
        "dcl": ("d", "jh"),
        "gcl": ("g",),
        "pcl": ("p",),
        "tcl": ("t", "ch"),
        "kcl": ("k",),
    }

    # Stop-release codes. They map to "" in _TIMIT_TO_IPA, so their symbol
    # is taken from the matching closure entry ("b" -> "bcl" -> "b"). This
    # keeps a release that has no closure in front of it from being dropped.
    _STOP_RELEASES = frozenset({"b", "d", "g", "p", "t", "k"})

    def __init__(self, timit_path: Path):
        """Remember *timit_path* and fail fast if the archive is missing.

        Raises:
            FileNotFoundError: if *timit_path* does not exist.
        """
        self.timit_path = timit_path
        self._zip_ = None  # opened lazily by the `_zip` property
        print(f"TimitDataManager initialized with path: {self.timit_path.absolute()}")
        if not self.timit_path.exists():
            raise FileNotFoundError(
                f"TIMIT dataset not found at {self.timit_path.absolute()}. Try running ./scripts/download_data_lfs.sh again."
            )
        else:
            print("TIMIT dataset file exists!")

    @property
    def _zip(self) -> zipfile.ZipFile:
        """Return the cached archive handle, opening it on first use."""
        if self._zip_ is None:
            self._zip_ = zipfile.ZipFile(self.timit_path, "r")
        return self._zip_

    def close(self) -> None:
        """Close the cached archive handle, if one is open.

        Later accesses reopen the archive transparently.
        """
        if self._zip_ is not None:
            self._zip_.close()
            self._zip_ = None

    def __enter__(self) -> "TimitDataManager":
        return self

    def __exit__(self, *exc_info) -> None:
        # Returning None means exceptions from the `with` body propagate.
        self.close()

    def get_file_list(self, subset: str) -> list[str]:
        """Get list of WAV files for given subset.

        A file is included when its archive path ends in the upper-case
        ".WAV" suffix and contains *subset* as a case-insensitive substring.
        """
        needle = subset.lower()
        files = [
            f
            for f in self._zip.namelist()
            if f.endswith(".WAV") and needle in f.lower()
        ]
        print(f"Found {len(files)} WAV files in {subset} subset")
        if files:
            print("First 3 files:", files[:3])
        return files

    def load_audio(self, filename: str) -> torch.Tensor:
        """Load and preprocess an audio file from the archive.

        Multi-channel audio is averaged to one channel and resampled to
        16 kHz if needed. The result is normalised to zero mean and unit
        variance; the 1e-7 term guards against a zero std on silent input.
        """
        with self._zip.open(filename) as wav_file:
            # presumably (channels, frames), per torchaudio.load — TODO confirm
            waveform, sample_rate = torchaudio.load(wav_file)  # type: ignore
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-7)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        return waveform

    def get_phonemes(self, filename: str) -> str:
        """Get cleaned phoneme sequence from PHN file and convert to IPA.

        The transcription is the ".PHN" archive member that sits alongside
        *filename*. Each non-blank line must contain exactly three
        whitespace-separated fields, of which the third is the phone code.
        Malformed lines raise ValueError.

        Returns:
            The IPA symbols joined without separators. Silences and unknown
            codes are dropped.
        """
        # Replace only the last ".WAV", so a directory containing ".WAV" is
        # left alone.
        head, sep, tail = filename.rpartition(".WAV")
        phn_file = f"{head}.PHN{tail}" if sep else filename
        with self._zip.open(phn_file) as f:
            text = f.read().decode("utf-8")
        codes = []
        for line in text.splitlines():
            if line.strip():
                _, _, phone = line.split()
                codes.append(self._remove_stress_mark(phone).lower())
        return "".join(self._codes_to_ipa(codes))

    def _codes_to_ipa(self, codes: list[str]) -> list[str]:
        """Convert lower-cased TIMIT phone codes into a list of IPA symbols."""
        symbols = []
        for code, following in zip(codes, [*codes[1:], None]):
            if code in self._CLOSURE_RELEASES:
                if following in self._CLOSURE_RELEASES[code]:
                    # The release (or affricate) right after this closure
                    # emits the symbol, so emitting here would double it.
                    continue
                # Unreleased stop: the closure is its only trace.
                symbol = self._TIMIT_TO_IPA[code]
            elif code in self._STOP_RELEASES:
                symbol = self._TIMIT_TO_IPA[code + "cl"]
            else:
                symbol = self._TIMIT_TO_IPA.get(code, "")
            if symbol:
                symbols.append(symbol)
        return symbols

    def _remove_stress_mark(self, text: str) -> str:
        """Remove the IPA tie bar from *text*.

        The tie bar is U+0361 COMBINING DOUBLE INVERTED BREVE. Despite the
        method name, it is not a stress mark.

        Raises:
            TypeError: if *text* is not a string.
        """
        if not isinstance(text, str):
            raise TypeError("Input must be string")
        return text.replace("\u0361", "")
# Shared module-level TIMIT manager. Constructing it checks that the archive
# exists, so importing this module raises FileNotFoundError when it is missing.
timit_manager = TimitDataManager(timit_path=TIMIT_PATH)