AiiTServices / core / model_registry.py
HenriqueBraz's picture
Update core/model_registry.py
cedd1c0 verified
from transformers import pipeline
from diffusers import StableDiffusionPipeline
from core.config import device, cache_dir
import torch
def _pipeline_loader(task, model, **extra):
    """Return a zero-argument callable that builds a transformers pipeline.

    Nothing is downloaded or instantiated when this factory runs; the
    returned callable does the work when it is invoked. ``device`` (from
    ``core.config``) is looked up at that point, not here.

    Args:
        task: Pipeline task name passed as the first argument to ``pipeline``.
        model: Model id passed as ``model=``.
        **extra: Additional keyword arguments forwarded unchanged to
            ``pipeline`` after ``model`` and ``device``.
    """

    def _load():
        return pipeline(task, model=model, device=device, **extra)

    return _load


def _load_text_to_image():
    """Build the Stable Diffusion v1.5 pipeline and move it to ``device``."""
    sd_pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float32,
        use_safetensors=True,
        # NOTE(review): per the diffusers docs, None loads the pipeline
        # without its safety-checker component -- confirm this is intended.
        safety_checker=None,
        cache_dir=cache_dir,
    )
    return sd_pipeline.to(device)


# Registry of lazy model constructors, keyed by service name. Every value is
# a zero-argument callable; a model is only built when its callable is invoked.
model_loaders = {
    # --- text ---
    "sentiment_analysis": _pipeline_loader(
        "sentiment-analysis",
        "cardiffnlp/twitter-roberta-base-sentiment-latest",
    ),
    "text_classification": _pipeline_loader(
        "text-classification",
        "distilbert-base-uncased-finetuned-sst-2-english",
    ),
    "summarization": _pipeline_loader(
        "summarization",
        "facebook/bart-large-cnn",
        max_length=150,
        min_length=30,
    ),
    "question_answering": _pipeline_loader(
        "question-answering",
        "deepset/roberta-base-squad2",
    ),
    "translation": _pipeline_loader(
        "translation",
        "Helsinki-NLP/opus-mt-tc-big-en-pt",
    ),
    "text_generation": _pipeline_loader(
        "text-generation",
        "gpt2",
        # presumably GPT-2's end-of-text token id -- TODO confirm
        pad_token_id=50256,
    ),
    "ner": _pipeline_loader(
        "ner",
        "dbmdz/bert-large-cased-finetuned-conll03-english",
        aggregation_strategy="simple",
    ),
    # --- vision ---
    "image_classification": _pipeline_loader(
        "image-classification",
        "google/vit-base-patch16-224",
    ),
    "object_detection": _pipeline_loader(
        "object-detection",
        "facebook/detr-resnet-50",
    ),
    "image_segmentation": _pipeline_loader(
        "image-segmentation",
        "facebook/detr-resnet-50-panoptic",
    ),
    # Uses the image-classification task; going by the model id the
    # checkpoint targets facial expressions.
    "facial_recognition": _pipeline_loader(
        "image-classification",
        "mo-thecreator/vit-Facial-Expression-Recognition",
    ),
    # --- audio ---
    "speech_to_text": _pipeline_loader(
        "automatic-speech-recognition",
        "openai/whisper-base",
    ),
    "audio_classification": _pipeline_loader(
        "audio-classification",
        "superb/hubert-base-superb-er",
    ),
    # --- generative image (diffusers, not a transformers pipeline) ---
    "text_to_image": _load_text_to_image,
}
def load_model(key):
    """Instantiate and return the model registered under ``key``.

    Looks ``key`` up in ``model_loaders`` and invokes the loader found
    there; a fresh loader call is made on every invocation.

    Args:
        key: Registry key identifying the model to load.

    Returns:
        Whatever the registered loader callable returns.

    Raises:
        ValueError: If ``key`` is not present in ``model_loaders``.
    """
    loader = model_loaders.get(key)
    if loader is None:
        raise ValueError(f"Model {key} not found in registry")
    return loader()