whisper-mongolian / whisper_model.py
Nasanbuyan's picture
Upload Transformers-compatible Mongolian Whisper model
00af197 verified
import os
import torch
import json
from transformers import WhisperForConditionalGeneration, WhisperConfig
class ModelLoader:
    """Load the Mongolian Whisper model from a native checkpoint or via Transformers.

    ``load_model`` is the entry point. It prefers a native checkpoint named
    ``original_model.pt`` inside the model directory and otherwise falls back
    to a thin wrapper around ``WhisperForConditionalGeneration``.
    """

    class Segment:
        """Minimal transcription segment exposing only a ``text`` attribute."""

        def __init__(self, text):
            self.text = text

    class TransformersWrapper:
        """Adapter giving a Transformers Whisper model a ``transcribe`` method.

        Attributes:
            config: ``WhisperConfig`` loaded from ``model_path``.
            model: ``WhisperForConditionalGeneration`` moved to ``device``.
            device: Device the model and input features are placed on.
            model_path: Directory the model (and, lazily, the processor) is
                loaded from.
        """

        def __init__(self, model_path, device):
            self.model_path = model_path
            self.config = WhisperConfig.from_pretrained(model_path)
            self.model = WhisperForConditionalGeneration.from_pretrained(model_path).to(device)
            self.device = device
            # Loaded on first ``transcribe`` call and then reused, so that
            # constructing the wrapper does not require processor files and
            # repeated transcriptions do not re-read them from disk.
            self._processor = None

        def _get_processor(self):
            """Return the cached ``WhisperProcessor``, loading it on first use."""
            if self._processor is None:
                from transformers import WhisperProcessor

                self._processor = WhisperProcessor.from_pretrained(self.model_path)
            return self._processor

        def transcribe(self, audio, beam_size=5):
            """Transcribe ``audio`` and return ``(segments, info)``.

            This is a simplified implementation: only ``beam_size`` is
            supported, and the whole transcription is returned as one segment.

            Args:
                audio: Audio handed to the processor with
                    ``sampling_rate=16000`` (presumably a 16 kHz mono
                    waveform; confirm against callers).
                beam_size: Number of beams passed to ``generate`` as
                    ``num_beams``.

            Returns:
                A tuple ``(segments, info)`` where ``segments`` is a list with
                a single ``Segment`` holding the decoded text and ``info`` is
                the dict ``{"language": "mn"}``. The language is hard-coded;
                no language detection is performed.
            """
            processor = self._get_processor()

            # Process audio
            features = processor(audio, sampling_rate=16000, return_tensors="pt")
            input_features = features.input_features.to(self.device)

            # Generate
            predicted_ids = self.model.generate(input_features, num_beams=beam_size)

            # Decode (only the first item of the batch is used)
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

            segments = [ModelLoader.Segment(transcription)]
            info = {"language": "mn"}
            return segments, info

    @staticmethod
    def load_model(model_path=".", device="cpu"):
        """Load a model from ``model_path`` onto ``device``.

        Args:
            model_path: Directory containing the model files.
            device: Device passed to ``torch.load`` / ``.to(...)``.

        Returns:
            The native model when ``original_model.pt`` exists in
            ``model_path`` and the native implementation is importable;
            otherwise a ``TransformersWrapper``.
        """
        # First try to load as native checkpoint
        native_model_path = os.path.join(model_path, "original_model.pt")
        if os.path.exists(native_model_path):
            return ModelLoader._load_native_model(native_model_path, device)
        # Fall back to the transformers API
        return ModelLoader._load_transformers_model(model_path, device)

    @staticmethod
    def _checkpoint_dir(checkpoint_path):
        """Return the directory holding ``checkpoint_path`` (``"."`` if none).

        ``os.path.dirname`` returns an empty string for a bare file name,
        which is not a usable directory argument, so it is normalised to the
        current directory.
        """
        return os.path.dirname(checkpoint_path) or "."

    @staticmethod
    def _load_native_model(model_path, device):
        """Load the native checkpoint at ``model_path``.

        Falls back to the Transformers wrapper (loading from the checkpoint's
        directory) when the ``whisper_impl`` module cannot be imported.

        Args:
            model_path: Path to the ``original_model.pt`` checkpoint file.
            device: Device used for ``map_location`` and model placement.

        Returns:
            The native model in eval mode, or a ``TransformersWrapper`` when
            ``whisper_impl`` is unavailable.
        """
        model_dir = ModelLoader._checkpoint_dir(model_path)

        # Only the imports are guarded: an ImportError raised later (while
        # loading the checkpoint or building the model) is a real failure and
        # must not be reported as "implementation not found".
        try:
            from whisper_impl import WhisperModel as NativeWhisperModel
            from whisper_impl import WhisperConfig as NativeConfig
            from whisper_impl import SimpleTokenizer
        except ImportError:
            # If whisper_impl is not available, fall back to transformers
            print("Native model implementation not found. Using Transformers wrapper.")
            return ModelLoader._load_transformers_model(model_dir, device)

        # Load the checkpoint.
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=device)

        # Create config: copy plain values from the checkpoint; the tokenizer
        # is rebuilt below rather than restored from the checkpoint.
        config = NativeConfig()
        for key, value in checkpoint['config'].items():
            if not callable(value) and key != "tokenizer":
                setattr(config, key, value)

        # Create tokenizer; the vocabulary is optional and read from the
        # checkpoint's directory when present.
        tokenizer = SimpleTokenizer()
        vocab_path = os.path.join(model_dir, "vocab.json")
        if os.path.exists(vocab_path):
            tokenizer.load_vocab(vocab_path)
        config.tokenizer = tokenizer

        # Create model
        model = NativeWhisperModel(config).to(device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        return model

    @staticmethod
    def _load_transformers_model(model_path, device):
        """Return a ``TransformersWrapper`` for the model in ``model_path``.

        The wrapper mimics the native model's ``transcribe`` API but is backed
        by the Transformers implementation.
        """
        return ModelLoader.TransformersWrapper(model_path, device)
# For compatibility with the test code: expose the loader under the name
# ``WhisperModel``. Note that this is the ``load_model`` function, not a class,
# so ``WhisperModel(model_path, device)`` returns an already-loaded model object.
WhisperModel = ModelLoader.load_model