"""Registry of inference pipelines, keyed by task name.

Each entry in ``model_loaders`` is a zero-argument callable. No model is
constructed until that callable is invoked (normally through ``load_model``).
"""
from transformers import pipeline
from diffusers import StableDiffusionPipeline
from core.config import device, cache_dir
import torch


def _hf_pipeline(task, model_id, **extra):
    """Return a zero-argument loader for a ``transformers`` pipeline.

    The returned loader calls ``pipeline(task, model=model_id, device=device,
    **extra)``. The module-level ``device`` is looked up when the loader
    runs, not when it is created.
    """
    def _load():
        return pipeline(task, model=model_id, device=device, **extra)
    return _load


def _load_text_to_image():
    """Build the Stable Diffusion v1.5 text-to-image pipeline on ``device``."""
    # NOTE(review): the "runwayml/stable-diffusion-v1-5" Hub repo may no longer
    # be available; recent diffusers docs reference the
    # "stable-diffusion-v1-5/stable-diffusion-v1-5" mirror. Confirm before
    # relying on a fresh download.
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float32,
        use_safetensors=True,
        safety_checker=None,  # the safety checker is deliberately disabled
        # Only this loader passes cache_dir. The transformers loaders above do
        # not, so they presumably use the default Hugging Face cache.
        # NOTE(review): confirm that asymmetry is intended.
        cache_dir=cache_dir,
    )
    return sd_pipe.to(device)


# Registry key -> zero-argument callable that constructs the pipeline.
model_loaders = {
    'sentiment_analysis': _hf_pipeline(
        "sentiment-analysis", "cardiffnlp/twitter-roberta-base-sentiment-latest"
    ),
    'text_classification': _hf_pipeline(
        "text-classification", "distilbert-base-uncased-finetuned-sst-2-english"
    ),
    # The extra kwargs are passed straight to pipeline(). They are presumably
    # used as default generation-length bounds.
    'summarization': _hf_pipeline(
        "summarization", "facebook/bart-large-cnn", max_length=150, min_length=30
    ),
    'question_answering': _hf_pipeline(
        "question-answering", "deepset/roberta-base-squad2"
    ),
    # NOTE(review): opus-mt-tc-big checkpoints are typically multi-target and
    # expect a target-language token prefix (e.g. ">>por<<") on each input.
    # Check the model card and callers.
    'translation': _hf_pipeline(
        "translation", "Helsinki-NLP/opus-mt-tc-big-en-pt"
    ),
    # 50256 is presumably GPT-2's <|endoftext|> id, reused as the pad id
    # because GPT-2 defines no pad token.
    'text_generation': _hf_pipeline(
        "text-generation", "gpt2", pad_token_id=50256
    ),
    'ner': _hf_pipeline(
        "ner",
        "dbmdz/bert-large-cased-finetuned-conll03-english",
        aggregation_strategy="simple",
    ),
    'image_classification': _hf_pipeline(
        "image-classification", "google/vit-base-patch16-224"
    ),
    'object_detection': _hf_pipeline(
        "object-detection", "facebook/detr-resnet-50"
    ),
    'image_segmentation': _hf_pipeline(
        "image-segmentation", "facebook/detr-resnet-50-panoptic"
    ),
    # Despite the key name, the model id indicates facial *expression*
    # classification (a plain image-classification task), not identity
    # recognition.
    'facial_recognition': _hf_pipeline(
        "image-classification", "mo-thecreator/vit-Facial-Expression-Recognition"
    ),
    'speech_to_text': _hf_pipeline(
        "automatic-speech-recognition", "openai/whisper-base"
    ),
    'audio_classification': _hf_pipeline(
        "audio-classification", "superb/hubert-base-superb-er"
    ),
    'text_to_image': _load_text_to_image,
}


def load_model(key: str):
    """Construct and return the pipeline registered under *key*.

    A new pipeline object is built on every call; this function does no
    caching.

    Raises:
        ValueError: if *key* is not registered in ``model_loaders``.
    """
    if key in model_loaders:
        return model_loaders[key]()
    raise ValueError(f"Model {key} not found in registry")