File size: 5,481 Bytes
38024bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# This module handles the data loading and preprocessing for various phoneme transcription datasets.

import torch
import torchaudio

import zipfile
from pathlib import Path

# Absolute path of the directory that contains this module
CURRENT_DIR = Path(__file__).parent.absolute()

# Dataset locations, expressed relative to this module's directory
DATA_DIR = CURRENT_DIR.joinpath("data")
TIMIT_PATH = DATA_DIR.joinpath("TIMIT.zip")


# Abstract data manager class
class DataManager:
    """Base interface for dataset managers.

    Every method here raises ``NotImplementedError``; concrete dataset
    managers are expected to override all three.
    """

    def get_file_list(self, subset: str) -> list[str]:
        """Return the names of the files that belong to ``subset``."""
        raise NotImplementedError()

    def load_audio(self, filename: str) -> torch.Tensor:
        """Return the preprocessed audio for ``filename`` as a tensor."""
        raise NotImplementedError()

    def get_phonemes(self, filename: str) -> str:
        """Return the phoneme sequence associated with ``filename``."""
        raise NotImplementedError()


# Implement datasets
class TimitDataManager(DataManager):
    """Reads TIMIT audio and phoneme labels straight out of a zip archive.

    The archive is opened lazily on first access and the handle is then
    reused by later calls. Call :meth:`close` (or use the instance in a
    ``with`` statement) to release the handle; it is reopened on demand.
    """

    # Sample rate (Hz) that load_audio resamples every waveform to
    _TARGET_SAMPLE_RATE = 16000

    # TIMIT to IPA mapping with direct simplifications.
    # A label mapped to "" contributes nothing to the output string.
    # Stop closures ("bcl", "dcl", ...) carry the IPA symbol while the release
    # bursts ("b", "d", ...) map to "", so a closure+release pair yields a
    # single symbol and a release without a closure is dropped.
    # NOTE(review): "jh"/"ch" already include their stop component, so a
    # preceding "dcl"/"tcl" label produces a doubled stop (e.g. "ddʒ") --
    # confirm whether that is intended.
    _TIMIT_TO_IPA = {
        # Vowels (simplified)
        "aa": "ɑ",
        "ae": "æ",
        "ah": "ʌ",
        "ao": "ɔ",
        "aw": "aʊ",
        "ay": "aɪ",
        "eh": "ɛ",
        "er": "ɹ",  # Simplified from 'ɝ'
        "ey": "eɪ",
        "ih": "ɪ",
        "ix": "i",  # Simplified from 'ɨ'
        "iy": "i",
        "ow": "oʊ",
        "oy": "ɔɪ",
        "uh": "ʊ",
        "uw": "u",
        "ux": "u",  # Simplified from 'ʉ'
        "ax": "ə",
        "ax-h": "ə",  # Simplified from 'ə̥'
        "axr": "ɹ",  # Simplified from 'ɚ'
        # Consonants
        "b": "",
        "bcl": "b",
        "d": "",
        "dcl": "d",
        "g": "",
        "gcl": "g",
        "p": "",
        "pcl": "p",
        "t": "",
        "tcl": "t",
        "k": "",
        "kcl": "k",
        "dx": "ɾ",
        "q": "ʔ",
        # Fricatives
        "jh": "dʒ",
        "ch": "tʃ",
        "s": "s",
        "sh": "ʃ",
        "z": "z",
        "zh": "ʒ",
        "f": "f",
        "th": "θ",
        "v": "v",
        "dh": "ð",
        "hh": "h",
        "hv": "h",  # Simplified from 'ɦ'
        # Nasals (simplified)
        "m": "m",
        "n": "n",
        "ng": "ŋ",
        "em": "m",  # Simplified from 'm̩'
        "en": "n",  # Simplified from 'n̩'
        "eng": "ŋ",  # Simplified from 'ŋ̍'
        "nx": "ɾ",  # Simplified from 'ɾ̃'
        # Semivowels and Glides
        "l": "l",
        "r": "ɹ",
        "w": "w",
        "wh": "ʍ",
        "y": "j",
        "el": "l",  # Simplified from 'l̩'
        # Special
        "epi": "",  # Remove epenthetic silence
        "h#": "",  # Remove start/end silence
        "pau": "",  # Remove pause
    }

    def __init__(self, timit_path: Path):
        """Store the archive location and verify that it exists.

        The archive itself is not opened here; that happens on first use.

        Args:
            timit_path: Location of the TIMIT zip archive.

        Raises:
            FileNotFoundError: If ``timit_path`` does not exist.
        """
        self.timit_path = timit_path
        self._zip_ = None  # zipfile.ZipFile handle, created on first use
        print(f"TimitDataManager initialized with path: {self.timit_path.absolute()}")
        if not self.timit_path.exists():
            raise FileNotFoundError(
                f"TIMIT dataset not found at {self.timit_path.absolute()}. Try running ./scripts/download_data_lfs.sh again."
            )
        print("TIMIT dataset file exists!")

    def __enter__(self) -> "TimitDataManager":
        """Return the manager itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the archive handle when leaving a ``with`` block."""
        self.close()

    def close(self) -> None:
        """Close the archive handle if it is open; safe to call repeatedly.

        A later call that needs the archive reopens it transparently.
        """
        if self._zip_ is not None:
            self._zip_.close()
            self._zip_ = None

    @property
    def _zip(self) -> zipfile.ZipFile:
        """Return the shared archive handle, opening it on first access."""
        # Test against None explicitly instead of relying on the truthiness
        # of the ZipFile object.
        if self._zip_ is None:
            self._zip_ = zipfile.ZipFile(self.timit_path, "r")
        return self._zip_

    def get_file_list(self, subset: str) -> list[str]:
        """Get list of WAV files for given subset.

        An archive member is selected when its name ends with ``.WAV``
        (case-sensitive) and contains ``subset`` anywhere in its path
        (case-insensitive). An empty ``subset`` therefore matches every WAV.

        Args:
            subset: Substring identifying the subset, e.g. a directory name.

        Returns:
            Matching archive member names, in archive order.
        """
        wanted = subset.lower()
        files = [
            name
            for name in self._zip.namelist()
            if name.endswith(".WAV") and wanted in name.lower()
        ]
        print(f"Found {len(files)} WAV files in {subset} subset")
        if files:
            print("First 3 files:", files[:3])
        return files

    def load_audio(self, filename: str) -> torch.Tensor:
        """Load and preprocess audio file.

        Steps: average multi-channel audio down to one channel, resample to
        16 kHz when the file uses another rate, then normalise to zero mean
        and unit standard deviation (with a small epsilon against division
        by zero).

        Args:
            filename: Archive member name of the audio file.

        Returns:
            The preprocessed waveform with a leading channel dimension of 1
            (assuming ``torchaudio.load`` returns ``(channels, frames)``).
        """
        with self._zip.open(filename) as wav_file:
            waveform, sample_rate = torchaudio.load(wav_file)  # type: ignore

        # The tensor is fully in memory now, so the archive member can be
        # released before preprocessing.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        if sample_rate != self._TARGET_SAMPLE_RATE:
            resample = torchaudio.transforms.Resample(
                sample_rate, self._TARGET_SAMPLE_RATE
            )
            waveform = resample(waveform)

        waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-7)

        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        return waveform

    def get_phonemes(self, filename: str) -> str:
        """Get cleaned phoneme sequence from PHN file and convert to IPA.

        The PHN file is the archive member with the same name as
        ``filename`` but a ``.PHN`` suffix. Each non-blank line must hold
        three whitespace-separated fields; the third is the TIMIT label.
        Labels missing from the mapping, or mapped to an empty string, are
        skipped.

        Args:
            filename: Archive member name of the WAV file.

        Returns:
            The IPA symbols concatenated without separators.

        Raises:
            KeyError: If the PHN member is not present in the archive.
            ValueError: If a non-blank line does not have three fields.
        """
        wav_suffix = ".WAV"
        if filename.endswith(wav_suffix):
            # Swap only the suffix so a ".WAV" elsewhere in the path
            # (e.g. in a directory name) is left alone.
            phn_file = filename[: -len(wav_suffix)] + ".PHN"
        else:
            # Preserve the previous behaviour for names without the suffix
            phn_file = filename.replace(wav_suffix, ".PHN")

        with self._zip.open(phn_file) as f:
            text = f.read().decode("utf-8")

        phonemes = []
        for line_no, line in enumerate(text.splitlines(), start=1):
            fields = line.split()
            if not fields:
                continue  # blank line
            if len(fields) != 3:
                raise ValueError(
                    f"Malformed line {line_no} in {phn_file}: expected "
                    f"3 fields, got {len(fields)}: {line!r}"
                )
            phone = self._remove_stress_mark(fields[2])
            # Convert to IPA instead of using simplify_timit
            ipa = self._TIMIT_TO_IPA.get(phone.lower(), "")
            if ipa:
                phonemes.append(ipa)
        return "".join(phonemes)  # Join without spaces for IPA

    def _remove_stress_mark(self, text: str) -> str:
        """Removes the combining double inverted breve (͡) from text.

        Despite the method name, this is the only character removed; all
        other characters are returned unchanged.

        Raises:
            TypeError: If ``text`` is not a string.
        """
        if not isinstance(text, str):
            raise TypeError("Input must be string")
        return text.replace("͡", "")


# Build the TIMIT data manager when this module is imported
timit_manager = TimitDataManager(timit_path=TIMIT_PATH)