File size: 2,217 Bytes
cedd1c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from transformers import pipeline
from diffusers import StableDiffusionPipeline
from core.config import device, cache_dir
import torch

def _pipeline_loader(task, checkpoint, **options):
    """Return a zero-argument callable that builds a transformers pipeline.

    Nothing is constructed here; the returned callable calls ``pipeline``
    with ``task``, ``model=checkpoint``, the module-level ``device`` and any
    extra keyword ``options`` each time it is invoked.
    """
    def build():
        # `pipeline` and `device` are looked up when the loader runs,
        # not when it is registered.
        return pipeline(task, model=checkpoint, device=device, **options)

    return build


def _stable_diffusion_loader():
    """Build the Stable Diffusion v1.5 pipeline and move it to ``device``."""
    diffusion = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float32,
        use_safetensors=True,
        # NOTE(review): the safety checker is explicitly disabled here.
        safety_checker=None,
        cache_dir=cache_dir,
    )
    return diffusion.to(device)


# Registry of lazy model loaders, keyed by the task name that `load_model`
# accepts. Every value is a zero-argument callable, so defining the registry
# does not construct any model; a model is only built when its loader is called.
model_loaders = {
    # --- text ---
    'sentiment_analysis': _pipeline_loader(
        "sentiment-analysis",
        "cardiffnlp/twitter-roberta-base-sentiment-latest",
    ),
    'text_classification': _pipeline_loader(
        "text-classification",
        "distilbert-base-uncased-finetuned-sst-2-english",
    ),
    'summarization': _pipeline_loader(
        "summarization",
        "facebook/bart-large-cnn",
        max_length=150,
        min_length=30,
    ),
    'question_answering': _pipeline_loader(
        "question-answering",
        "deepset/roberta-base-squad2",
    ),
    'translation': _pipeline_loader(
        "translation",
        "Helsinki-NLP/opus-mt-tc-big-en-pt",
    ),
    'text_generation': _pipeline_loader(
        "text-generation",
        "gpt2",
        # presumably GPT-2's end-of-text token id reused for padding — confirm
        pad_token_id=50256,
    ),
    'ner': _pipeline_loader(
        "ner",
        "dbmdz/bert-large-cased-finetuned-conll03-english",
        aggregation_strategy="simple",
    ),
    # --- vision ---
    'image_classification': _pipeline_loader(
        "image-classification",
        "google/vit-base-patch16-224",
    ),
    'object_detection': _pipeline_loader(
        "object-detection",
        "facebook/detr-resnet-50",
    ),
    'image_segmentation': _pipeline_loader(
        "image-segmentation",
        "facebook/detr-resnet-50-panoptic",
    ),
    # Registered under 'facial_recognition', but the task passed to
    # `pipeline` is "image-classification" with a facial-expression checkpoint.
    'facial_recognition': _pipeline_loader(
        "image-classification",
        "mo-thecreator/vit-Facial-Expression-Recognition",
    ),
    # --- audio ---
    'speech_to_text': _pipeline_loader(
        "automatic-speech-recognition",
        "openai/whisper-base",
    ),
    'audio_classification': _pipeline_loader(
        "audio-classification",
        "superb/hubert-base-superb-er",
    ),
    # --- generative image ---
    'text_to_image': _stable_diffusion_loader,
}
def load_model(key):
    """Build and return the model registered under ``key``.

    The loader stored in ``model_loaders`` is invoked on every call, so each
    call constructs the model again; no caching happens at this level.

    Args:
        key: Registry name of the model (a key of ``model_loaders``).

    Returns:
        Whatever the registered loader returns.

    Raises:
        ValueError: If ``key`` is not present in ``model_loaders``.
    """
    # Only the lookup is guarded, so a KeyError raised while a loader is
    # actually running is never mistaken for an unknown registry key.
    try:
        loader = model_loaders[key]
    except KeyError:
        raise ValueError(f"Model {key} not found in registry") from None
    return loader()