import math
import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class WhisperConfig:
    """Hyperparameters for the simplified Whisper-style model."""

    def __init__(self):
        # Audio front end: 16 kHz input, 400-sample (25 ms) FFT windows,
        # 160-sample (10 ms) hop, 80 mel bins.
        self.sampling_rate = 16000
        self.n_fft = 400
        self.hop_length = 160
        self.n_mels = 80

        # Transformer dimensions.
        self.d_model = 384
        self.n_heads = 6
        self.n_layers = 4
        self.vocab_size = 1000

        # transcribe() reads the tokenizer from the config, so attach one here.
        self.tokenizer = SimpleTokenizer()

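# Note on frame rates (derived from the config above): a 160-sample hop at
# 16 kHz yields 100 mel frames per second, and the three stride-2 convolutions
# in AudioEncoder reduce this 8x to 12.5 encoder frames per second.
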
class SimpleTokenizer:
    """Character-level tokenizer with a small set of special tokens."""

    def __init__(self):
        self.token_to_id = {}
        self.id_to_token = {}
        self.special_tokens = {
            "<pad>": 0,
            "<s>": 1,
            "</s>": 2,
            "<unk>": 3,
        }

        # Seed both lookup tables with the special tokens.
        for token, idx in self.special_tokens.items():
            self.token_to_id[token] = idx
            self.id_to_token[idx] = token

        self.next_id = len(self.special_tokens)

    def load_vocab(self, vocab_file):
        # The vocab file is a JSON mapping of token -> id; it replaces the
        # seeded table, so it should include the special tokens as well.
        with open(vocab_file, 'r', encoding='utf-8') as f:
            self.token_to_id = json.load(f)

        self.id_to_token = {int(v): k for k, v in self.token_to_id.items()}
        self.next_id = max(map(int, self.id_to_token.keys())) + 1

    def encode(self, text):
        if not isinstance(text, str):
            text = str(text)

        # Character-level encoding, wrapped in <s> ... </s>.
        ids = [self.special_tokens["<s>"]]
        for char in text:
            if char in self.token_to_id:
                ids.append(self.token_to_id[char])
            else:
                ids.append(self.special_tokens["<unk>"])
        ids.append(self.special_tokens["</s>"])
        return ids

    def decode(self, ids):
        text = ""
        for token_id in ids:
            # ids may arrive as numpy or tensor scalars; normalize to int.
            token_id = int(token_id)

            # Skip structural special tokens.
            if token_id in (self.special_tokens["<pad>"], self.special_tokens["<s>"], self.special_tokens["</s>"]):
                continue

            if token_id in self.id_to_token:
                text += self.id_to_token[token_id]
            else:
                text += self.id_to_token[self.special_tokens["<unk>"]]

        return text

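# Illustrative round-trip (assumption: "vocab.json" is a hypothetical JSON file
# mapping tokens to ids, including the four special tokens above):
#   tok = SimpleTokenizer()
#   tok.load_vocab("vocab.json")
#   tok.decode(tok.encode("hello"))  # -> "hello", with "<unk>" for unseen characters
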
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input embeddings."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # pe[pos, 2i] = sin(pos / 10000^(2i/d_model)), pe[pos, 2i+1] = cos(...).
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model), broadcast over the batch

        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        return x + self.pe[:, :x.size(1)]

class EncoderBlock(nn.Module):
    """Post-norm Transformer encoder block: self-attention, then feed-forward."""

    def __init__(self, d_model, n_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention sublayer with residual connection and post-norm.
        attn_output, _ = self.self_attn(x, x, x, key_padding_mask=mask)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)

        # Feed-forward sublayer.
        ff_output = self.ff(x)
        x = x + self.dropout(ff_output)
        x = self.norm2(x)

        return x

class DecoderBlock(nn.Module):
    """Post-norm Transformer decoder block: masked self-attention, cross-attention, feed-forward."""

    def __init__(self, d_model, n_heads, d_ff=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, tgt_mask=None, src_mask=None):
        # Causal self-attention over the target tokens.
        attn_output, _ = self.self_attn(x, x, x, attn_mask=tgt_mask)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)

        # Cross-attention: queries from the decoder, keys/values from the encoder.
        attn_output, _ = self.cross_attn(x, enc_output, enc_output, key_padding_mask=src_mask)
        x = x + self.dropout(attn_output)
        x = self.norm2(x)

        # Feed-forward sublayer.
        ff_output = self.ff(x)
        x = x + self.dropout(ff_output)
        x = self.norm3(x)

        return x

class AudioEncoder(nn.Module):
    """Convolutional front end (8x temporal downsampling) followed by Transformer blocks."""

    def __init__(self, config):
        super().__init__()
        d_model = config.d_model

        # Project n_mels -> d_model, then halve the time axis three times (8x total).
        self.conv1 = nn.Conv1d(config.n_mels, d_model, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1)

        self.norm = nn.LayerNorm(d_model)
        self.pos_encoder = PositionalEncoding(d_model)

        self.layers = nn.ModuleList([
            EncoderBlock(d_model, config.n_heads, d_model * 4)
            for _ in range(config.n_layers)
        ])

        self.dropout = nn.Dropout(0.1)  # defined but not applied in forward

    def forward(self, x):
        # x: (batch, n_mels, time)
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = F.gelu(self.conv3(x))
        x = F.gelu(self.conv4(x))

        # (batch, d_model, time/8) -> (batch, time/8, d_model)
        x = x.transpose(1, 2)
        x = self.norm(x)
        x = self.pos_encoder(x)

        for layer in self.layers:
            x = layer(x)

        return x

class TextDecoder(nn.Module):
    """Autoregressive text decoder attending over the encoder output."""

    def __init__(self, config):
        super().__init__()
        d_model = config.d_model
        vocab_size = config.vocab_size

        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)

        self.layers = nn.ModuleList([
            DecoderBlock(d_model, config.n_heads, d_model * 4)
            for _ in range(config.n_layers)
        ])

        self.output_projection = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(0.1)  # defined but not applied in forward

    def forward(self, x, encoder_output, tgt_mask=None):
        # x: (batch, seq_len) token ids -> (batch, seq_len, d_model)
        x = self.token_embedding(x)
        x = self.pos_encoder(x)

        for layer in self.layers:
            x = layer(x, encoder_output, tgt_mask=tgt_mask)

        # (batch, seq_len, vocab_size) logits.
        x = self.output_projection(x)
        return x

class WhisperModel(nn.Module):
    """Encoder-decoder speech recognition model in the style of Whisper."""

    def __init__(self, config):
        super().__init__()
        self.encoder = AudioEncoder(config)
        self.decoder = TextDecoder(config)
        self.config = config

    def _create_causal_mask(self, size):
        # Boolean upper-triangular mask; True marks positions that may not be attended to.
        mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
        return mask.to(next(self.parameters()).device)

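    # Example: _create_causal_mask(3) yields (True = blocked):
    #   [[False,  True,  True],
    #    [False, False,  True],
    #    [False, False, False]]
    # so each position attends only to itself and earlier positions.
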
    def forward(self, audio_features, token_ids, attention_mask=None):
        # audio_features: (batch, n_mels, time); token_ids: (batch, seq_len).
        # attention_mask is accepted for interface compatibility but is currently unused.
        encoder_output = self.encoder(audio_features)

        # Causal mask so each target position only attends to earlier positions.
        seq_len = token_ids.size(1)
        causal_mask = self._create_causal_mask(seq_len)

        output = self.decoder(token_ids, encoder_output, tgt_mask=causal_mask)

        return output

    def generate(self, audio_features, tokenizer, max_len=100):
        """Greedy autoregressive decoding, up to max_len tokens."""
        batch_size = audio_features.size(0)
        device = next(self.parameters()).device

        with torch.no_grad():
            encoder_output = self.encoder(audio_features)

            # Start every sequence with the <s> token.
            curr_tokens = torch.full((batch_size, 1), tokenizer.special_tokens["<s>"], dtype=torch.long, device=device)

            for _ in range(max_len - 1):
                causal_mask = self._create_causal_mask(curr_tokens.size(1))

                # Re-run the decoder on the full prefix and take the last step's logits.
                output = self.decoder(curr_tokens, encoder_output, tgt_mask=causal_mask)
                next_token_logits = output[:, -1, :]
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                curr_tokens = torch.cat([curr_tokens, next_token], dim=1)

                # Stop once every sequence in the batch has emitted </s>.
                if (next_token == tokenizer.special_tokens["</s>"]).all():
                    break

        return curr_tokens

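    # A possible variation (not in the original): the argmax in generate can be
    # replaced with temperature sampling over the last-step logits, e.g.
    #   probs = F.softmax(next_token_logits / temperature, dim=-1)
    #   next_token = torch.multinomial(probs, num_samples=1)
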
    def transcribe(self, audio, beam_size=5):
        # NOTE: beam_size is accepted for interface compatibility, but decoding
        # is greedy (see generate); beam search is not implemented here.

        # Accept a file path and decode it to a mono 16 kHz float32 waveform.
        if isinstance(audio, str):
            try:
                from pydub import AudioSegment
                audio_seg = AudioSegment.from_file(audio)
                audio_seg = audio_seg.set_channels(1).set_frame_rate(16000)
                audio = np.array(audio_seg.get_array_of_samples()).astype(np.float32) / 32768.0
            except Exception:
                print("Error loading audio file. Using dummy audio.")
                audio = np.zeros(16000, dtype=np.float32)

        if not isinstance(audio, np.ndarray):
            audio = np.array(audio, dtype=np.float32)

        # Add a channel dimension: (n_samples,) -> (1, n_samples).
        if len(audio.shape) == 1:
            audio = audio.reshape(1, -1)

        try:
            import torchaudio

            if not isinstance(audio, torch.Tensor):
                audio = torch.from_numpy(audio)

            # Log-mel spectrogram with per-utterance mean/variance normalization.
            mel_spec = torchaudio.transforms.MelSpectrogram(
                sample_rate=self.config.sampling_rate,
                n_fft=self.config.n_fft,
                hop_length=self.config.hop_length,
                n_mels=self.config.n_mels
            )(audio)

            log_mel_spec = torch.log(mel_spec + 1e-9)

            mean = log_mel_spec.mean()
            std = log_mel_spec.std()
            log_mel_spec = (log_mel_spec - mean) / (std + 1e-9)

        except ImportError:
            print("torchaudio not available. Using dummy features.")
            log_mel_spec = torch.zeros(1, self.config.n_mels, 100)

        # Normalize the feature tensor to (batch, n_mels, time).
        if log_mel_spec.dim() == 2:
            log_mel_spec = log_mel_spec.unsqueeze(0)
        elif log_mel_spec.dim() == 4:
            log_mel_spec = log_mel_spec.squeeze(0)

        log_mel_spec = log_mel_spec.to(next(self.parameters()).device)

        with torch.no_grad():
            generated = self.generate(log_mel_spec, self.config.tokenizer)

        transcription = self.config.tokenizer.decode(generated[0].cpu().numpy())

        # Package the result as (segments, info), mirroring the shape of
        # common Whisper-API wrappers.
        class Segment:
            def __init__(self, text):
                self.text = text

        segments = [Segment(transcription)]
        info = {"language": "mn"}  # hardcoded placeholder language code (Mongolian)

        return segments, info
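

# Minimal usage sketch (not part of the original module): builds the model and
# runs a forward pass plus greedy generation on random features as a shape check.
if __name__ == "__main__":
    config = WhisperConfig()
    model = WhisperModel(config)
    model.eval()

    # Dummy batch: 1 utterance, 80 mel bins, 3000 frames (~30 s at a 10 ms hop).
    mel = torch.randn(1, config.n_mels, 3000)
    tokens = torch.tensor([[config.tokenizer.special_tokens["<s>"]]])

    with torch.no_grad():
        logits = model(mel, tokens)
    print(logits.shape)  # expected: torch.Size([1, 1, 1000])

    # With untrained weights this prints garbage (mostly "<unk>"), but it
    # exercises the full generate/decode path.
    generated = model.generate(mel, config.tokenizer, max_len=10)
    print(config.tokenizer.decode(generated[0].tolist()))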