"""Model loader for Whisper checkpoints.

Prefers a native ``original_model.pt`` checkpoint when one is present in the
model directory and falls back to a Hugging Face Transformers model otherwise.
"""

import os

import torch
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor


class ModelLoader:
    @staticmethod
    def load_model(model_path=".", device="cpu"):
        """Load a model from ``model_path``, preferring the native checkpoint."""
        native_model_path = os.path.join(model_path, "original_model.pt")
        if os.path.exists(native_model_path):
            return ModelLoader._load_native_model(native_model_path, device)
        # No native checkpoint: treat model_path as a Transformers model directory.
        return ModelLoader._load_transformers_model(model_path, device)
|
    @staticmethod
    def _load_native_model(model_path, device):
        try:
            from whisper_impl import WhisperModel as NativeWhisperModel
            from whisper_impl import WhisperConfig as NativeConfig
            from whisper_impl import SimpleTokenizer

            checkpoint = torch.load(model_path, map_location=device)

            # Rebuild the config from the checkpoint, skipping anything that
            # cannot be copied over directly (callables and the tokenizer).
            config = NativeConfig()
            for k, v in checkpoint["config"].items():
                if not callable(v) and k != "tokenizer":
                    setattr(config, k, v)

            # Attach a tokenizer, loading its vocabulary if one is shipped
            # alongside the checkpoint.
            tokenizer = SimpleTokenizer()
            vocab_path = os.path.join(os.path.dirname(model_path), "vocab.json")
            if os.path.exists(vocab_path):
                tokenizer.load_vocab(vocab_path)
            config.tokenizer = tokenizer

            model = NativeWhisperModel(config).to(device)
            model.load_state_dict(checkpoint["model_state_dict"])
            model.eval()

            return model
        except ImportError:
            # whisper_impl is not installed; fall back to the Transformers path.
            print("Native model implementation not found. Using Transformers wrapper.")
            return ModelLoader._load_transformers_model(os.path.dirname(model_path), device)
|
    @staticmethod
    def _load_transformers_model(model_path, device):
        class TransformersWrapper:
            """Wraps a Transformers Whisper model behind a transcribe() API
            that returns (segments, info) in the style of faster-whisper."""

            def __init__(self, model_path, device):
                self.config = WhisperConfig.from_pretrained(model_path)
                self.model = WhisperForConditionalGeneration.from_pretrained(model_path).to(device)
                # Load the processor once here instead of on every transcribe() call.
                self.processor = WhisperProcessor.from_pretrained(model_path)
                self.device = device

            def transcribe(self, audio, beam_size=5):
                # Convert raw 16 kHz audio into log-mel input features.
                input_features = self.processor(
                    audio, sampling_rate=16000, return_tensors="pt"
                ).input_features.to(self.device)

                # Beam-search decode, then strip special tokens from the output.
                predicted_ids = self.model.generate(input_features, num_beams=beam_size)
                transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

                # Minimal stand-in for a faster-whisper segment object.
                class Segment:
                    def __init__(self, text):
                        self.text = text

                segments = [Segment(transcription)]
                info = {"language": "mn"}  # hard-coded; no language detection is run

                return segments, info

        return TransformersWrapper(model_path, device)
|
|
# Module-level alias: lets callers build a model with ``WhisperModel(path, device)``,
# mirroring the faster-whisper calling convention that transcribe() also follows.
WhisperModel = ModelLoader.load_model
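

# A minimal usage sketch. It assumes the working directory contains either a
# native "original_model.pt" checkpoint or a Transformers-format Whisper
# model, and that a native model exposes the same transcribe(audio, beam_size)
# interface as the wrapper above; the silent clip is a placeholder for real
# 16 kHz mono audio.
if __name__ == "__main__":
    import numpy as np

    model = WhisperModel(".", device="cpu")
    audio = np.zeros(16000, dtype=np.float32)  # placeholder: 1 s of 16 kHz silence
    segments, info = model.transcribe(audio, beam_size=5)
    for segment in segments:
        print(segment.text)
    print("language:", info["language"])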
|