Spaces:
Running
Running
File size: 5,481 Bytes
38024bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
# This module handles the data loading and preprocessing for various phoneme transcription datasets.
import torch
import torchaudio
import zipfile
from pathlib import Path
# Absolute path of the directory that contains this module
CURRENT_DIR = Path(__file__).parent.absolute()
# Dataset locations, all resolved relative to this module's directory
DATA_DIR = CURRENT_DIR.joinpath("data")
TIMIT_PATH = DATA_DIR.joinpath("TIMIT.zip")
# Abstract data manager class
class DataManager:
    """Common interface that every dataset backend is expected to provide.

    The base class implements nothing itself: each method raises
    ``NotImplementedError`` until a subclass overrides it.
    """

    def get_file_list(self, subset: str) -> list[str]:
        """Return the names of the files that belong to ``subset``."""
        raise NotImplementedError

    def load_audio(self, filename: str) -> torch.Tensor:
        """Return the preprocessed audio stored under ``filename``."""
        raise NotImplementedError

    def get_phonemes(self, filename: str) -> str:
        """Return the phoneme sequence associated with ``filename``."""
        raise NotImplementedError
# Implement datasets
class TimitDataManager(DataManager):
    """Read audio and phoneme labels straight out of a zipped TIMIT corpus.

    The archive is opened lazily on first access and the handle is reused for
    every subsequent read. Call :meth:`close` (or use the instance in a
    ``with`` statement) to release the handle; it is reopened transparently
    if the manager is used again afterwards.
    """

    # Sample rate (Hz) that every loaded waveform is resampled to.
    _TARGET_SAMPLE_RATE = 16000

    # TIMIT to IPA mapping with direct simplifications.
    # Labels mapped to "" are dropped from the transcription.
    _TIMIT_TO_IPA = {
        # Vowels (simplified)
        "aa": "ɑ",
        "ae": "æ",
        "ah": "ʌ",
        "ao": "ɔ",
        "aw": "aʊ",
        "ay": "aɪ",
        "eh": "ɛ",
        "er": "ɹ",  # Simplified from 'ɝ'
        "ey": "eɪ",
        "ih": "ɪ",
        "ix": "i",  # Simplified from 'ɨ'
        "iy": "i",
        "ow": "oʊ",
        "oy": "ɔɪ",
        "uh": "ʊ",
        "uw": "u",
        "ux": "u",  # Simplified from 'ʉ'
        "ax": "ə",
        "ax-h": "ə",  # Simplified from 'ə̥'
        "axr": "ɹ",  # Simplified from 'ɚ'
        # Consonants
        # Stops: the closure label ("bcl", "dcl", ...) carries the IPA symbol
        # and the release label ("b", "d", ...) is dropped, so a
        # closure + release pair yields a single symbol.
        "b": "",
        "bcl": "b",
        "d": "",
        "dcl": "d",
        "g": "",
        "gcl": "g",
        "p": "",
        "pcl": "p",
        "t": "",
        "tcl": "t",
        "k": "",
        "kcl": "k",
        "dx": "ɾ",
        "q": "ʔ",
        # Affricates and fricatives
        "jh": "dʒ",
        "ch": "tʃ",
        "s": "s",
        "sh": "ʃ",
        "z": "z",
        "zh": "ʒ",
        "f": "f",
        "th": "θ",
        "v": "v",
        "dh": "ð",
        "hh": "h",
        "hv": "h",  # Simplified from 'ɦ'
        # Nasals (simplified)
        "m": "m",
        "n": "n",
        "ng": "ŋ",
        "em": "m",  # Simplified from 'm̩'
        "en": "n",  # Simplified from 'n̩'
        "eng": "ŋ",  # Simplified from 'ŋ̍'
        "nx": "ɾ",  # Simplified from 'ɾ̃'
        # Semivowels and Glides
        "l": "l",
        "r": "ɹ",
        "w": "w",
        "wh": "ʍ",
        "y": "j",
        "el": "l",  # Simplified from 'l̩'
        # Special
        "epi": "",  # Remove epenthetic silence
        "h#": "",  # Remove start/end silence
        "pau": "",  # Remove pause
    }

    def __init__(self, timit_path: Path):
        """Remember the archive location and verify that it exists.

        Args:
            timit_path: Location of the TIMIT zip archive. A plain string is
                accepted as well and is converted to a ``Path``.

        Raises:
            FileNotFoundError: If nothing exists at ``timit_path``.
        """
        self.timit_path = Path(timit_path)
        # Archive handle; stays None until the ``_zip`` property opens it.
        self._zip_ = None
        print(f"TimitDataManager initialized with path: {self.timit_path.absolute()}")
        if not self.timit_path.exists():
            raise FileNotFoundError(
                f"TIMIT dataset not found at {self.timit_path.absolute()}. Try running ./scripts/download_data_lfs.sh again."
            )
        print("TIMIT dataset file exists!")

    @property
    def _zip(self) -> zipfile.ZipFile:
        """Shared read-only handle on the archive, opened on first use."""
        # Compare against None explicitly so the check does not depend on
        # how ZipFile objects evaluate in a boolean context.
        if self._zip_ is None:
            self._zip_ = zipfile.ZipFile(self.timit_path, "r")
        return self._zip_

    def close(self) -> None:
        """Close the archive if it is open.

        Safe to call repeatedly. The next access through ``_zip`` reopens
        the archive.
        """
        if self._zip_ is not None:
            self._zip_.close()
            self._zip_ = None

    def __enter__(self) -> "TimitDataManager":
        """Return the manager itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the archive when the ``with`` block ends."""
        self.close()

    def get_file_list(self, subset: str) -> list[str]:
        """Get list of WAV files for given subset.

        Args:
            subset: Text that must occur (case-insensitively) somewhere in
                the archive member's path, e.g. ``"train"``.

        Returns:
            Archive member names ending in ``.WAV`` whose path contains
            ``subset``, in archive order.
        """
        # Lower-case the needle once instead of once per archive member.
        needle = subset.lower()
        files = [
            name
            for name in self._zip.namelist()
            if name.endswith(".WAV") and needle in name.lower()
        ]
        print(f"Found {len(files)} WAV files in {subset} subset")
        if files:
            print("First 3 files:", files[:3])
        return files

    def load_audio(self, filename: str) -> torch.Tensor:
        """Load one archive member as a normalised waveform.

        The audio is averaged over its first dimension when that dimension is
        larger than one, resampled to ``_TARGET_SAMPLE_RATE`` if needed, and
        standardised to zero mean and (approximately) unit variance.

        Args:
            filename: Archive member name of the audio file.

        Returns:
            The processed waveform with at least two dimensions.
            NOTE(review): presumably shaped (1, num_samples), assuming
            ``torchaudio.load`` returns (channels, time) -- confirm.
        """
        with self._zip.open(filename) as wav_file:
            waveform, sample_rate = torchaudio.load(wav_file)  # type: ignore

        # Collapse multiple channels into one by averaging.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if sample_rate != self._TARGET_SAMPLE_RATE:
            resample = torchaudio.transforms.Resample(
                sample_rate, self._TARGET_SAMPLE_RATE
            )
            waveform = resample(waveform)

        # The small epsilon guards against division by zero on silent audio.
        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-7)

        # Defensive: guarantee a leading dimension even for 1-D input.
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        return waveform

    @staticmethod
    def _phn_path(filename: str) -> str:
        """Return the ``.PHN`` member name that accompanies ``filename``.

        Only the last ``.WAV`` in the path is swapped, so a directory whose
        name happens to contain ``.WAV`` is left untouched. A name without
        ``.WAV`` is returned unchanged.
        """
        head, sep, tail = filename.rpartition(".WAV")
        return f"{head}.PHN{tail}" if sep else filename

    def get_phonemes(self, filename: str) -> str:
        """Get cleaned phoneme sequence from PHN file and convert to IPA.

        Each non-blank line of the PHN file must hold exactly three
        whitespace-separated fields; the third is the TIMIT phone label.
        Labels that are unknown or map to ``""`` are skipped.

        Args:
            filename: Archive member name of the ``.WAV`` file.

        Returns:
            The IPA symbols concatenated without separators.

        Raises:
            ValueError: If a non-blank line does not have three fields.
        """
        phn_file = self._phn_path(filename)
        with self._zip.open(phn_file) as f:
            text = f.read().decode("utf-8")

        phonemes = []
        for line in text.splitlines():
            fields = line.split()
            if not fields:
                continue  # blank line
            if len(fields) != 3:
                raise ValueError(
                    f"Malformed line in {phn_file}: expected 3 fields, got {line!r}"
                )
            phone = self._remove_stress_mark(fields[2])
            # Convert to IPA instead of using simplify_timit
            ipa = self._TIMIT_TO_IPA.get(phone.lower(), "")
            if ipa:
                phonemes.append(ipa)
        return "".join(phonemes)  # Join without spaces for IPA

    def _remove_stress_mark(self, text: str) -> str:
        """Removes the combining double inverted breve (͡) from text.

        Despite the method name, the only character removed is that
        combining mark.

        Raises:
            TypeError: If ``text`` is not a string.
        """
        if not isinstance(text, str):
            raise TypeError("Input must be string")
        return text.replace("͡", "")
# Initialize data managers
# Module-level TIMIT manager, created once when this module is imported.
# The constructor raises FileNotFoundError if TIMIT_PATH does not exist, so
# importing this module requires the archive to be present.
timit_manager = TimitDataManager(TIMIT_PATH)
|