# whisper_impl.py
# Transformers-compatible Mongolian Whisper model (whisper-mongolian).
import torch
import torch.nn as nn
import torch.nn.functional as F
class WhisperConfig:
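    """Model and feature-extraction hyperparameters (overridden from the checkpoint)."""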
def __init__(self):
# Default values - will be overridden from checkpoint
self.sampling_rate = 16000
self.n_fft = 400
self.hop_length = 160
self.n_mels = 80
self.d_model = 384
self.n_heads = 6
self.n_layers = 4
self.vocab_size = 1000
class SimpleTokenizer:
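    """Character-level tokenizer with <pad>, <s>, </s> and <unk> special tokens."""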
def __init__(self):
self.token_to_id = {}
self.id_to_token = {}
self.special_tokens = {
"<pad>": 0,
"<s>": 1,
"</s>": 2,
"<unk>": 3,
}
# Initialize with special tokens
for token, idx in self.special_tokens.items():
self.token_to_id[token] = idx
self.id_to_token[idx] = token
self.next_id = len(self.special_tokens)
def load_vocab(self, vocab_file):
import json
with open(vocab_file, 'r', encoding='utf-8') as f:
self.token_to_id = json.load(f)
# Rebuild id_to_token
self.id_to_token = {int(v): k for k, v in self.token_to_id.items()}
self.next_id = max(map(int, self.id_to_token.keys())) + 1
def encode(self, text):
if not isinstance(text, str):
text = str(text)
ids = [self.special_tokens["<s>"]]
for char in text:
if char in self.token_to_id:
ids.append(self.token_to_id[char])
else:
ids.append(self.special_tokens["<unk>"])
ids.append(self.special_tokens["</s>"])
return ids
def decode(self, ids):
text = ""
        for token_id in ids:
            token_id = int(token_id)
            # Skip special tokens
            if token_id in (self.special_tokens["<pad>"], self.special_tokens["<s>"], self.special_tokens["</s>"]):
                continue
            if token_id in self.id_to_token:
                text += self.id_to_token[token_id]
            else:
                text += self.id_to_token[self.special_tokens["<unk>"]]
return text
class PositionalEncoding(nn.Module):
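    """Fixed sinusoidal positional encoding added to the input embeddings."""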
def __init__(self, d_model, max_len=5000):
super().__init__()
import math
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
return x + self.pe[:, :x.size(1)]
class EncoderBlock(nn.Module):
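    """Post-norm Transformer encoder block: self-attention followed by a feed-forward network."""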
def __init__(self, d_model, n_heads, d_ff=2048, dropout=0.1):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.ff = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model)
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
attn_output, _ = self.self_attn(x, x, x, key_padding_mask=mask)
x = x + self.dropout(attn_output)
x = self.norm1(x)
ff_output = self.ff(x)
x = x + self.dropout(ff_output)
x = self.norm2(x)
return x
class DecoderBlock(nn.Module):
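    """Post-norm Transformer decoder block: causal self-attention, cross-attention over the encoder output, then a feed-forward network."""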
def __init__(self, d_model, n_heads, d_ff=2048, dropout=0.1):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.ff = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model)
)
self.dropout = nn.Dropout(dropout)
def forward(self, x, enc_output, tgt_mask=None, src_mask=None):
# Self-attention
attn_output, _ = self.self_attn(x, x, x, attn_mask=tgt_mask)
x = x + self.dropout(attn_output)
x = self.norm1(x)
# Cross-attention
attn_output, _ = self.cross_attn(x, enc_output, enc_output, key_padding_mask=src_mask)
x = x + self.dropout(attn_output)
x = self.norm2(x)
# Feed forward
ff_output = self.ff(x)
x = x + self.dropout(ff_output)
x = self.norm3(x)
return x
class AudioEncoder(nn.Module):
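    """Encodes log-mel spectrograms: a strided Conv1d front-end (8x downsampling in time) followed by Transformer encoder blocks."""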
def __init__(self, config):
super().__init__()
d_model = config.d_model
# Convolutional front-end
self.conv1 = nn.Conv1d(config.n_mels, d_model, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1)
self.conv3 = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1)
self.conv4 = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1)
self.norm = nn.LayerNorm(d_model)
self.pos_encoder = PositionalEncoding(d_model)
self.layers = nn.ModuleList([
EncoderBlock(d_model, config.n_heads, d_model * 4)
for _ in range(config.n_layers)
])
self.dropout = nn.Dropout(0.1)
def forward(self, x):
# x shape: [batch_size, n_mels, time]
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = F.gelu(self.conv3(x))
x = F.gelu(self.conv4(x))
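        # [batch, d_model, time/8] -> [batch, time/8, d_model] for the Transformer layers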
x = x.transpose(1, 2)
x = self.norm(x)
x = self.pos_encoder(x)
for layer in self.layers:
x = layer(x)
return x
class TextDecoder(nn.Module):
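    """Autoregressive text decoder: token embeddings plus positional encoding, decoder blocks, and a projection to vocabulary logits."""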
def __init__(self, config):
super().__init__()
d_model = config.d_model
vocab_size = config.vocab_size
self.token_embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoder = PositionalEncoding(d_model)
self.layers = nn.ModuleList([
DecoderBlock(d_model, config.n_heads, d_model * 4)
for _ in range(config.n_layers)
])
self.output_projection = nn.Linear(d_model, vocab_size)
self.dropout = nn.Dropout(0.1)
def forward(self, x, encoder_output, tgt_mask=None):
x = self.token_embedding(x)
x = self.pos_encoder(x)
for layer in self.layers:
x = layer(x, encoder_output, tgt_mask=tgt_mask)
x = self.output_projection(x)
return x
class WhisperModel(nn.Module):
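    """Whisper-style encoder-decoder speech-to-text model for Mongolian."""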
def __init__(self, config):
super().__init__()
self.encoder = AudioEncoder(config)
self.decoder = TextDecoder(config)
self.config = config
def _create_causal_mask(self, size):
mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
return mask.to(next(self.parameters()).device)
def forward(self, audio_features, token_ids, attention_mask=None):
# Encode audio
encoder_output = self.encoder(audio_features)
# Create causal mask for decoder
seq_len = token_ids.size(1)
causal_mask = self._create_causal_mask(seq_len)
# Decode text
output = self.decoder(token_ids, encoder_output, tgt_mask=causal_mask)
return output
def generate(self, audio_features, tokenizer, max_len=100):
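        """Greedy autoregressive decoding: start from <s> and append the argmax token
        until every sequence in the batch emits </s> or the length reaches max_len."""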
batch_size = audio_features.size(0)
# Encode audio
encoder_output = self.encoder(audio_features)
# Initialize with start token
        device = next(self.parameters()).device
        curr_tokens = torch.full(
            (batch_size, 1), tokenizer.special_tokens["<s>"], dtype=torch.long, device=device
        )
# Generate tokens auto-regressively
for i in range(max_len - 1):
# Create causal mask
causal_mask = self._create_causal_mask(curr_tokens.size(1))
# Get next token probabilities
with torch.no_grad():
output = self.decoder(curr_tokens, encoder_output, tgt_mask=causal_mask)
next_token_logits = output[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
# Append to sequence
curr_tokens = torch.cat([curr_tokens, next_token], dim=1)
# Check if end token is generated
if (next_token == tokenizer.special_tokens["</s>"]).all():
break
return curr_tokens
# Add transcribe method for compatibility with test code
def transcribe(self, audio, beam_size=5):
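        """Transcribe a file path or a waveform array and return (segments, info).

        beam_size is accepted only for API compatibility; decoding is greedy.
        Expects a SimpleTokenizer to have been attached to the config as config.tokenizer.
        """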
import numpy as np
# Process audio if it's a file path
if isinstance(audio, str):
try:
from pydub import AudioSegment
audio_seg = AudioSegment.from_file(audio)
audio_seg = audio_seg.set_channels(1).set_frame_rate(16000)
audio = np.array(audio_seg.get_array_of_samples()).astype(np.float32) / 32768.0
            except Exception as e:
                print(f"Error loading audio file ({e}). Using dummy audio.")
audio = np.zeros(16000, dtype=np.float32) # 1 second of silence
# Make sure audio is a numpy array
if not isinstance(audio, np.ndarray):
audio = np.array(audio, dtype=np.float32)
        # Add a batch dimension to a mono waveform: [num_samples] -> [1, num_samples]
        if len(audio.shape) == 1:
            audio = audio.reshape(1, -1)
        # Use torchaudio to extract log-mel features if it is available
try:
import torchaudio
# Convert to torch tensor if needed
if not isinstance(audio, torch.Tensor):
audio = torch.from_numpy(audio)
# Extract mel spectrogram
mel_spec = torchaudio.transforms.MelSpectrogram(
sample_rate=self.config.sampling_rate,
n_fft=self.config.n_fft,
hop_length=self.config.hop_length,
n_mels=self.config.n_mels
)(audio)
log_mel_spec = torch.log(mel_spec + 1e-9)
# Normalize
mean = log_mel_spec.mean()
std = log_mel_spec.std()
log_mel_spec = (log_mel_spec - mean) / (std + 1e-9)
except ImportError:
# Fallback: create a dummy spectrogram
print("torchaudio not available. Using dummy features.")
log_mel_spec = torch.zeros(1, self.config.n_mels, 100)
# Make sure the spectrogram has the right shape
if log_mel_spec.dim() == 3:
# Already has batch dimension
pass
elif log_mel_spec.dim() == 2:
# Add batch dimension
log_mel_spec = log_mel_spec.unsqueeze(0)
elif log_mel_spec.dim() == 4:
# Remove first dimension
log_mel_spec = log_mel_spec.squeeze(0)
# Move to the same device as the model
log_mel_spec = log_mel_spec.to(next(self.parameters()).device)
# Generate transcription
with torch.no_grad():
generated = self.generate(log_mel_spec, self.config.tokenizer)
# Convert to text
transcription = self.config.tokenizer.decode(generated[0].cpu().numpy())
# Create segments object to match expected output format
class Segment:
def __init__(self, text):
self.text = text
segments = [Segment(transcription)]
info = {"language": "mn"}
return segments, info
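# Minimal usage sketch. Assumptions (not part of the original repo layout): a
# "vocab.json" token-to-id mapping and a "model.pt" state dict trained with the
# default hyperparameters above, plus a "sample.wav" audio file. All three file
# names are illustrative.
if __name__ == "__main__":
    config = WhisperConfig()
    tokenizer = SimpleTokenizer()
    tokenizer.load_vocab("vocab.json")  # hypothetical vocab file
    config.vocab_size = max(tokenizer.id_to_token.keys()) + 1
    config.tokenizer = tokenizer  # transcribe() looks the tokenizer up on the config
    model = WhisperModel(config)
    model.load_state_dict(torch.load("model.pt", map_location="cpu"))  # hypothetical checkpoint
    model.eval()
    segments, info = model.transcribe("sample.wav")  # hypothetical audio file
    print(info["language"], segments[0].text)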