|
|
|
""" |
|
Model Preloader for Multilingual Audio Intelligence System - Enhanced Version |
|
|
|
Key improvements: |
|
1. Smart local cache detection with corruption checking |
|
2. Fallback to download if local files don't exist or are corrupted |
|
3. Better error handling and retry mechanisms |
|
4. Consistent approach across all model types |
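
Typical usage (a minimal sketch; the file name "model_preloader.py" is an assumption):

    python model_preloader.py

or programmatically:

    preloader = ModelPreloader(cache_dir="./model_cache")
    results = preloader.preload_all_models()
    models = preloader.get_models()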
|
""" |
|
|
|
import os |
|
import sys |
|
import logging |
|
import time |
|
from pathlib import Path |
|
from typing import Dict, Any, Optional |
|
import json |
|
from datetime import datetime |
|
|
|
|
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
import whisper |
|
from pyannote.audio import Pipeline |
|
from rich.console import Console |
|
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeRemainingColumn |
|
from rich.panel import Panel |
|
from rich.text import Text |
|
import psutil

# Make the project's local "src" package importable when this script is run directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
|
logger = logging.getLogger(__name__) |
|
|
|
console = Console() |
|
|
|
class ModelPreloader: |
|
"""Comprehensive model preloader with enhanced local cache detection.""" |
|
|
|
def __init__(self, cache_dir: str = "./model_cache", device: str = "auto"): |
|
self.cache_dir = Path(cache_dir) |
|
self.cache_dir.mkdir(exist_ok=True) |
|
|
|
|
|
if device == "auto": |
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
else: |
|
self.device = device |
|
|
|
self.models = {} |
|
self.model_info = {} |
|
|
|
|
|
self.model_configs = { |
|
"speaker_diarization": { |
|
"name": "pyannote/speaker-diarization-3.1", |
|
"type": "pyannote", |
|
"description": "Speaker Diarization Pipeline", |
|
"size_mb": 32 |
|
}, |
|
"whisper_small": { |
|
"name": "openai/whisper-small", |
|
"type": "whisper", |
|
"description": "Whisper Speech Recognition (Small)", |
|
"size_mb": 484 |
|
}, |
|
"mbart_translation": { |
|
"name": "facebook/mbart-large-50-many-to-many-mmt", |
|
"type": "mbart", |
|
"description": "mBART Neural Machine Translation", |
|
"size_mb": 2440 |
|
}, |
|
|
|
"opus_mt_ja_en": { |
|
"name": "Helsinki-NLP/opus-mt-ja-en", |
|
"type": "opus_mt", |
|
"description": "Japanese to English Translation", |
|
"size_mb": 303 |
|
}, |
|
"opus_mt_es_en": { |
|
"name": "Helsinki-NLP/opus-mt-es-en", |
|
"type": "opus_mt", |
|
"description": "Spanish to English Translation", |
|
"size_mb": 303 |
|
}, |
|
"opus_mt_fr_en": { |
|
"name": "Helsinki-NLP/opus-mt-fr-en", |
|
"type": "opus_mt", |
|
"description": "French to English Translation", |
|
"size_mb": 303 |
|
}, |
|
|
|
"opus_mt_hi_en": { |
|
"name": "Helsinki-NLP/opus-mt-hi-en", |
|
"type": "opus_mt", |
|
"description": "Hindi to English Translation", |
|
"size_mb": 303 |
|
}, |
|
"opus_mt_ta_en": { |
|
"name": "Helsinki-NLP/opus-mt-ta-en", |
|
"type": "opus_mt", |
|
"description": "Tamil to English Translation", |
|
"size_mb": 303 |
|
}, |
|
"opus_mt_bn_en": { |
|
"name": "Helsinki-NLP/opus-mt-bn-en", |
|
"type": "opus_mt", |
|
"description": "Bengali to English Translation", |
|
"size_mb": 303 |
|
}, |
|
"opus_mt_te_en": { |
|
"name": "Helsinki-NLP/opus-mt-te-en", |
|
"type": "opus_mt", |
|
"description": "Telugu to English Translation", |
|
"size_mb": 303 |
|
}, |
|
"opus_mt_mr_en": { |
|
"name": "Helsinki-NLP/opus-mt-mr-en", |
|
"type": "opus_mt", |
|
"description": "Marathi to English Translation", |
|
"size_mb": 303 |
|
}, |
|
"opus_mt_gu_en": { |
|
"name": "Helsinki-NLP/opus-mt-gu-en", |
|
"type": "opus_mt", |
|
"description": "Gujarati to English Translation", |
|
"size_mb": 303 |
|
}, |
|
"opus_mt_kn_en": { |
|
"name": "Helsinki-NLP/opus-mt-kn-en", |
|
"type": "opus_mt", |
|
"description": "Kannada to English Translation", |
|
"size_mb": 303 |
|
}, |
|
"opus_mt_pa_en": { |
|
"name": "Helsinki-NLP/opus-mt-pa-en", |
|
"type": "opus_mt", |
|
"description": "Punjabi to English Translation", |
|
"size_mb": 303 |
|
}, |
|
"opus_mt_ml_en": { |
|
"name": "Helsinki-NLP/opus-mt-ml-en", |
|
"type": "opus_mt", |
|
"description": "Malayalam to English Translation", |
|
"size_mb": 303 |
|
}, |
|
"opus_mt_ne_en": { |
|
"name": "Helsinki-NLP/opus-mt-ne-en", |
|
"type": "opus_mt", |
|
"description": "Nepali to English Translation", |
|
"size_mb": 303 |
|
}, |
|
"opus_mt_ur_en": { |
|
"name": "Helsinki-NLP/opus-mt-ur-en", |
|
"type": "opus_mt", |
|
"description": "Urdu to English Translation", |
|
"size_mb": 303 |
|
} |
|
} |
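        # Illustrative only: additional Helsinki-NLP pairs can be registered in the
        # dict above in the same shape, e.g. German -> English (size is approximate):
        #
        #   "opus_mt_de_en": {
        #       "name": "Helsinki-NLP/opus-mt-de-en",
        #       "type": "opus_mt",
        #       "description": "German to English Translation",
        #       "size_mb": 303
        #   },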
|
|
|
def check_local_model_files(self, model_name: str, model_type: str) -> bool: |
|
""" |
|
Check if model files exist locally and are not corrupted. |
|
Returns True if valid local files exist, False otherwise. |
|
""" |
|
try: |
|
            if model_type == "whisper":
                # whisper.load_model() below stores a single checkpoint (e.g. small.pt)
                # under its download root, so check for that file first.
                whisper_cache = self.cache_dir / "whisper"
                checkpoint = whisper_cache / f"{model_name}.pt"
                if checkpoint.exists() and checkpoint.stat().st_size > 0:
                    return True

                # Also accept a faster-whisper style Hugging Face cache in case another
                # component populated it.
                snapshots_dir = whisper_cache / "models--Systran--faster-whisper-small" / "snapshots"
                if not snapshots_dir.exists():
                    return False

                snapshot_dirs = [d for d in snapshots_dir.iterdir() if d.is_dir()]
                if not snapshot_dirs:
                    return False

                snapshot_path = max(snapshot_dirs, key=lambda x: x.stat().st_mtime)
                required_files = ["config.json", "model.bin", "tokenizer.json", "vocabulary.txt"]
                for file in required_files:
                    file_path = snapshot_path / file
                    if not file_path.exists() or file_path.stat().st_size == 0:
                        return False

                return True
|
|
|
elif model_type in ["mbart", "opus_mt"]: |
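                # Expected cache layout (illustrative) as produced by huggingface_hub,
                # which the checks below walk:
                #
                #   <cache_path>/models--<org>--<name>/
                #       snapshots/<revision-hash>/
                #           config.json
                #           tokenizer_config.json
                #           model.safetensors  (or pytorch_model.bin)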
|
|
|
if model_type == "mbart": |
|
model_cache_path = self.cache_dir / "mbart" / f"models--{model_name.replace('/', '--')}" |
|
else: |
|
model_cache_path = self.cache_dir / "opus_mt" / f"{model_name.replace('/', '--')}" / f"models--{model_name.replace('/', '--')}" |
|
|
|
required_files = ["config.json", "tokenizer_config.json"] |
|
|
|
model_files = ["pytorch_model.bin", "model.safetensors"] |
|
|
|
|
|
snapshots_dir = model_cache_path / "snapshots" |
|
if not snapshots_dir.exists(): |
|
return False |
|
|
|
|
|
snapshot_dirs = [d for d in snapshots_dir.iterdir() if d.is_dir()] |
|
if not snapshot_dirs: |
|
return False |
|
|
|
|
|
snapshot_path = max(snapshot_dirs, key=lambda x: x.stat().st_mtime) |
|
|
|
|
|
for file in required_files: |
|
file_path = snapshot_path / file |
|
if not file_path.exists() or file_path.stat().st_size == 0: |
|
return False |
|
|
|
|
|
model_file_exists = any( |
|
(snapshot_path / model_file).exists() and (snapshot_path / model_file).stat().st_size > 0 |
|
for model_file in model_files |
|
) |
|
|
|
return model_file_exists |
|
|
|
            elif model_type == "pyannote":
                # pyannote pipelines are assembled from several (gated) hub repositories,
                # so a simple file check is unreliable; always verify by loading.
                return False
|
|
|
except Exception as e: |
|
logger.warning(f"Error checking local files for {model_name}: {e}") |
|
return False |
|
|
|
return False |
|
|
|
def load_transformers_model_with_cache_check(self, model_name: str, cache_path: Path, model_type: str = "seq2seq") -> Optional[Dict[str, Any]]: |
|
""" |
|
Load transformers model with intelligent cache checking and fallback. |
|
""" |
|
try: |
|
|
|
has_local_files = self.check_local_model_files(model_name, "mbart" if "mbart" in model_name else "opus_mt") |
|
|
|
if has_local_files: |
|
console.print(f"[green]Found valid local cache for {model_name}, loading from cache...[/green]") |
|
try: |
|
|
|
tokenizer = AutoTokenizer.from_pretrained( |
|
model_name, |
|
cache_dir=str(cache_path), |
|
local_files_only=True |
|
) |
|
|
|
model = AutoModelForSeq2SeqLM.from_pretrained( |
|
model_name, |
|
cache_dir=str(cache_path), |
|
local_files_only=True, |
|
torch_dtype=torch.float32 if self.device == "cpu" else torch.float16 |
|
) |
|
|
|
console.print(f"[green]SUCCESS: Successfully loaded {model_name} from local cache[/green]") |
|
|
|
except Exception as e: |
|
console.print(f"[yellow]Local cache load failed for {model_name}, will download: {e}[/yellow]") |
|
has_local_files = False |
|
|
|
if not has_local_files: |
|
console.print(f"[yellow]No valid local cache for {model_name}, downloading...[/yellow]") |
|
|
|
tokenizer = AutoTokenizer.from_pretrained( |
|
model_name, |
|
cache_dir=str(cache_path) |
|
) |
|
|
|
model = AutoModelForSeq2SeqLM.from_pretrained( |
|
model_name, |
|
cache_dir=str(cache_path), |
|
torch_dtype=torch.float32 if self.device == "cpu" else torch.float16 |
|
) |
|
|
|
console.print(f"[green]SUCCESS: Successfully downloaded and loaded {model_name}[/green]") |
|
|
|
|
|
if self.device != "cpu": |
|
model = model.to(self.device) |
|
|
|
|
|
            # Smoke test: run a short generation to confirm the weights are usable.
            test_input = tokenizer("Hello world", return_tensors="pt")
            if self.device != "cpu":
                test_input = {k: v.to(self.device) for k, v in test_input.items()}

            with torch.no_grad():
                model.generate(**test_input, max_length=10)
|
|
|
return { |
|
"model": model, |
|
"tokenizer": tokenizer |
|
} |
|
|
|
except Exception as e: |
|
console.print(f"[red]✗ Failed to load {model_name}: {e}[/red]") |
|
logger.error(f"Model loading failed for {model_name}: {e}") |
|
return None |
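    # Illustrative use of the {"model": ..., "tokenizer": ...} dict returned above;
    # the variable names and sample sentence are assumptions, not project API:
    #
    #   bundle = preloader.load_opus_mt_model("t1", "Helsinki-NLP/opus-mt-fr-en")
    #   inputs = bundle["tokenizer"]("Bonjour le monde", return_tensors="pt")
    #   ids = bundle["model"].generate(**inputs, max_length=64)
    #   print(bundle["tokenizer"].batch_decode(ids, skip_special_tokens=True)[0])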
|
|
|
def get_system_info(self) -> Dict[str, Any]: |
|
"""Get system information for optimal model loading.""" |
|
return { |
|
"cpu_count": psutil.cpu_count(), |
|
"memory_gb": round(psutil.virtual_memory().total / (1024**3), 2), |
|
"available_memory_gb": round(psutil.virtual_memory().available / (1024**3), 2), |
|
"device": self.device, |
|
"torch_version": torch.__version__, |
|
"cuda_available": torch.cuda.is_available(), |
|
"gpu_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None |
|
} |
|
|
|
def check_model_cache(self, model_key: str) -> bool: |
|
"""Check if model is already cached and working.""" |
|
cache_file = self.cache_dir / f"{model_key}_info.json" |
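        # This reads the small JSON sidecar written by save_model_cache(), e.g.
        # (values illustrative):
        #   {
        #     "timestamp": "2024-01-01T12:00:00",
        #     "status": "success",
        #     "device": "cpu",
        #     "info": {"load_time": 42.0, "model_name": "openai/whisper-small"}
        #   }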
|
if not cache_file.exists(): |
|
return False |
|
|
|
try: |
|
with open(cache_file, 'r') as f: |
|
cache_info = json.load(f) |
|
|
|
|
|
cache_time = datetime.fromisoformat(cache_info['timestamp']) |
|
days_old = (datetime.now() - cache_time).days |
|
|
|
if days_old > 7: |
|
logger.info(f"Cache for {model_key} is {days_old} days old, will refresh") |
|
return False |
|
|
|
return cache_info.get('status') == 'success' |
|
except Exception as e: |
|
logger.warning(f"Error reading cache for {model_key}: {e}") |
|
return False |
|
|
|
def save_model_cache(self, model_key: str, status: str, info: Dict[str, Any]): |
|
"""Save model loading information to cache.""" |
|
cache_file = self.cache_dir / f"{model_key}_info.json" |
|
cache_data = { |
|
"timestamp": datetime.now().isoformat(), |
|
"status": status, |
|
"device": self.device, |
|
"info": info |
|
} |
|
|
|
try: |
|
with open(cache_file, 'w') as f: |
|
json.dump(cache_data, f, indent=2) |
|
except Exception as e: |
|
logger.warning(f"Error saving cache for {model_key}: {e}") |
|
|
|
def load_pyannote_pipeline(self, task_id: str) -> Optional[Pipeline]: |
|
"""Load pyannote speaker diarization pipeline with container-safe settings.""" |
|
try: |
|
            console.print(f"[yellow]Loading pyannote.audio pipeline...[/yellow]")

hf_token = os.getenv('HUGGINGFACE_TOKEN') or os.getenv('HF_TOKEN') |
|
if not hf_token: |
|
console.print("[red]Warning: HUGGINGFACE_TOKEN not found. Some models may not be accessible.[/red]") |
|
|
|
|
|
            import warnings
|
|
|
|
|
old_warning_filters = warnings.filters[:] |
|
warnings.filterwarnings("ignore") |
|
|
|
|
|
os.environ['ORT_LOGGING_LEVEL'] = '3' |
|
|
|
|
|
|
|
logging.getLogger('transformers').setLevel(logging.ERROR) |
|
|
|
try: |
|
pipeline = Pipeline.from_pretrained( |
|
"pyannote/speaker-diarization-3.1", |
|
use_auth_token=hf_token, |
|
cache_dir=str(self.cache_dir / "pyannote") |
|
) |
|
|
|
|
|
if hasattr(pipeline, '_models'): |
|
for model_name, model in pipeline._models.items(): |
|
if hasattr(model, 'to'): |
|
model.to('cpu') |
|
|
|
console.print(f"[green]SUCCESS: pyannote.audio pipeline loaded successfully on CPU[/green]") |
|
return pipeline |
|
|
|
finally: |
|
|
|
warnings.filters[:] = old_warning_filters |
|
|
|
        except Exception as e:
            console.print(f"[red]ERROR: Failed to load pyannote.audio pipeline: {e}[/red]")
            logger.error(f"Pyannote loading failed: {e}")
|
return None |
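    # Sketch of applying a loaded diarization pipeline ("meeting.wav" is an
    # illustrative path, not part of this module):
    #
    #   diarization = pipeline("meeting.wav")
    #   for turn, _, speaker in diarization.itertracks(yield_label=True):
    #       print(f"{speaker}: {turn.start:.1f}s - {turn.end:.1f}s")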
|
|
|
def load_whisper_model(self, task_id: str) -> Optional[whisper.Whisper]: |
|
"""Load Whisper speech recognition model with enhanced cache checking.""" |
|
try: |
|
console.print(f"[yellow]Loading Whisper model (small)...[/yellow]") |
|
|
|
whisper_cache_dir = self.cache_dir / "whisper" |
|
|
|
|
|
has_local_files = self.check_local_model_files("small", "whisper") |
|
|
|
if has_local_files: |
|
console.print(f"[green]Found valid local Whisper cache, loading from cache...[/green]") |
|
else: |
|
console.print(f"[yellow]No valid local Whisper cache found, will download...[/yellow]") |
|
|
|
|
|
            # download_root keeps the checkpoint inside the project cache directory.
            model = whisper.load_model("small", device=self.device, download_root=str(whisper_cache_dir))
|
|
|
|
|
            # Warm-up sanity check: transcribe one second of silence.
            import numpy as np
            dummy_audio = np.zeros(16000, dtype=np.float32)
            model.transcribe(dummy_audio, language="en")
|
|
|
console.print(f"[green]SUCCESS: Whisper model loaded successfully on {self.device}[/green]") |
|
|
|
return model |
|
|
|
except Exception as e: |
|
console.print(f"[red]ERROR: Failed to load Whisper model: {e}[/red]") |
|
logger.error(f"Whisper loading failed: {e}") |
|
return None |
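    # Sketch of transcribing a real file with the returned Whisper model
    # ("meeting.wav" is an illustrative path):
    #
    #   result = model.transcribe("meeting.wav")
    #   print(result["language"], result["text"])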
|
|
|
def load_mbart_model(self, task_id: str) -> Optional[Dict[str, Any]]: |
|
"""Load mBART translation model with enhanced cache checking.""" |
|
console.print(f"[yellow]Loading mBART translation model...[/yellow]") |
|
|
|
model_name = "facebook/mbart-large-50-many-to-many-mmt" |
|
cache_path = self.cache_dir / "mbart" |
|
cache_path.mkdir(exist_ok=True) |
|
|
|
return self.load_transformers_model_with_cache_check(model_name, cache_path, "seq2seq") |
|
|
|
def load_opus_mt_model(self, task_id: str, model_name: str) -> Optional[Dict[str, Any]]: |
|
"""Load Opus-MT translation model with enhanced cache checking.""" |
|
console.print(f"[yellow]Loading Opus-MT model: {model_name}...[/yellow]") |
|
|
|
cache_path = self.cache_dir / "opus_mt" / model_name.replace("/", "--") |
|
cache_path.mkdir(parents=True, exist_ok=True) |
|
|
|
return self.load_transformers_model_with_cache_check(model_name, cache_path, "seq2seq") |
|
|
|
def preload_all_models(self) -> Dict[str, Any]: |
|
"""Preload all models with progress tracking.""" |
|
|
|
|
|
sys_info = self.get_system_info() |
|
|
|
info_panel = Panel.fit( |
|
f"""System Information |
|
|
|
• CPU Cores: {sys_info['cpu_count']} |
|
• Total Memory: {sys_info['memory_gb']} GB |
|
• Available Memory: {sys_info['available_memory_gb']} GB |
|
• Device: {sys_info['device'].upper()} |
|
• PyTorch: {sys_info['torch_version']} |
|
• CUDA Available: {sys_info['cuda_available']} |
|
{f"• GPU: {sys_info['gpu_name']}" if sys_info['gpu_name'] else ""}""", |
|
title="[bold blue]Audio Intelligence System[/bold blue]", |
|
border_style="blue" |
|
) |
|
console.print(info_panel) |
|
console.print() |
|
|
|
results = { |
|
"system_info": sys_info, |
|
"models": {}, |
|
"total_time": 0, |
|
"success_count": 0, |
|
"total_count": len(self.model_configs) |
|
} |
|
|
|
start_time = time.time() |
|
|
|
with Progress( |
|
SpinnerColumn(), |
|
TextColumn("[progress.description]{task.description}"), |
|
BarColumn(), |
|
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), |
|
TimeRemainingColumn(), |
|
console=console |
|
) as progress: |
|
|
|
|
|
main_task = progress.add_task("[cyan]Loading AI Models...", total=len(self.model_configs)) |
|
|
|
|
|
for model_key, config in self.model_configs.items(): |
|
task_id = progress.add_task(f"[yellow]{config['description']}", total=100) |
|
|
|
|
|
if self.check_model_cache(model_key): |
|
console.print(f"[green]SUCCESS: {config['description']} found in cache[/green]") |
|
progress.update(task_id, completed=100) |
|
progress.update(main_task, advance=1) |
|
results["models"][model_key] = {"status": "cached", "time": 0} |
|
results["success_count"] += 1 |
|
continue |
|
|
|
model_start_time = time.time() |
|
progress.update(task_id, completed=10) |
|
|
|
|
|
if config["type"] == "pyannote": |
|
model = self.load_pyannote_pipeline(task_id) |
|
elif config["type"] == "whisper": |
|
model = self.load_whisper_model(task_id) |
|
elif config["type"] == "mbart": |
|
model = self.load_mbart_model(task_id) |
|
elif config["type"] == "opus_mt": |
|
model = self.load_opus_mt_model(task_id, config["name"]) |
|
else: |
|
model = None |
|
|
|
model_time = time.time() - model_start_time |
|
|
|
if model is not None: |
|
self.models[model_key] = model |
|
progress.update(task_id, completed=100) |
|
results["models"][model_key] = {"status": "success", "time": model_time} |
|
results["success_count"] += 1 |
|
|
|
|
|
self.save_model_cache(model_key, "success", { |
|
"load_time": model_time, |
|
"device": self.device, |
|
"model_name": config["name"] |
|
}) |
|
else: |
|
progress.update(task_id, completed=100) |
|
results["models"][model_key] = {"status": "failed", "time": model_time} |
|
|
|
|
|
self.save_model_cache(model_key, "failed", { |
|
"load_time": model_time, |
|
"device": self.device, |
|
"error": "Model loading failed" |
|
}) |
|
|
|
progress.update(main_task, advance=1) |
|
|
|
results["total_time"] = time.time() - start_time |
|
|
|
|
|
console.print() |
|
if results["success_count"] == results["total_count"]: |
|
status_text = "[bold green]SUCCESS: All models loaded successfully![/bold green]" |
|
status_color = "green" |
|
elif results["success_count"] > 0: |
|
status_text = f"[bold yellow]WARNING: {results['success_count']}/{results['total_count']} models loaded[/bold yellow]" |
|
status_color = "yellow" |
|
else: |
|
status_text = "[bold red]ERROR: No models loaded successfully[/bold red]" |
|
status_color = "red" |
|
|
|
summary_panel = Panel.fit( |
|
f"""{status_text} |
|
|
|
• Loading Time: {results['total_time']:.1f} seconds |
|
• Device: {self.device.upper()} |
|
• Memory Usage: {psutil.virtual_memory().percent:.1f}% |
|
• Models Ready: {results['success_count']}/{results['total_count']}""", |
|
title="[bold]Model Loading Summary[/bold]", |
|
border_style=status_color |
|
) |
|
console.print(summary_panel) |
|
|
|
return results |
|
|
|
def get_models(self) -> Dict[str, Any]: |
|
"""Get loaded models.""" |
|
return self.models |
|
|
|
def cleanup(self): |
|
"""Cleanup resources.""" |
|
|
|
if torch.cuda.is_available(): |
|
torch.cuda.empty_cache() |
|
|
|
|
|
def main(): |
|
"""Main function to run model preloading.""" |
|
|
|
console.print(Panel.fit( |
|
"[bold blue]Multilingual Audio Intelligence System[/bold blue]\n[yellow]Model Preloader[/yellow]", |
|
border_style="blue" |
|
)) |
|
console.print() |
|
|
|
|
|
preloader = ModelPreloader() |
|
|
|
|
|
try: |
|
results = preloader.preload_all_models() |
|
|
|
if results["success_count"] > 0: |
|
console.print("\n[bold green]SUCCESS: Model preloading completed![/bold green]") |
|
console.print(f"[dim]Models cached in: {preloader.cache_dir}[/dim]") |
|
return True |
|
else: |
|
console.print("\n[bold red]ERROR: Model preloading failed![/bold red]") |
|
return False |
|
|
|
except KeyboardInterrupt: |
|
console.print("\n[yellow]Model preloading interrupted by user[/yellow]") |
|
return False |
|
except Exception as e: |
|
console.print(f"\n[bold red]✗ Model preloading failed: {e}[/bold red]") |
|
logger.error(f"Preloading failed: {e}") |
|
return False |
|
finally: |
|
preloader.cleanup() |
|
|
|
|
|
if __name__ == "__main__": |
|
success = main() |
|
sys.exit(0 if success else 1) |