# Spaces: Sleeping
from transformers import pipeline | |
from diffusers import StableDiffusionPipeline | |
from core.config import device, cache_dir | |
import torch | |
def _pipeline_loader(task, model, **extra):
    """Return a zero-argument callable that builds a transformers pipeline.

    The pipeline is only constructed when the returned callable is invoked;
    the module-level ``device`` is read at that moment, not when the
    registry below is defined.
    """
    def build():
        return pipeline(task, model=model, device=device, **extra)

    return build


def _load_text_to_image():
    """Build the Stable Diffusion v1.5 pipeline and move it to ``device``."""
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float32,
        use_safetensors=True,
        safety_checker=None,
        cache_dir=cache_dir,
    )
    return sd_pipe.to(device)


# Registry mapping a task key to a zero-argument loader. Defining the
# registry instantiates nothing; each model is built only when its loader
# is called.
model_loaders = {
    # --- text ---
    'sentiment_analysis': _pipeline_loader(
        "sentiment-analysis",
        "cardiffnlp/twitter-roberta-base-sentiment-latest",
    ),
    'text_classification': _pipeline_loader(
        "text-classification",
        "distilbert-base-uncased-finetuned-sst-2-english",
    ),
    'summarization': _pipeline_loader(
        "summarization",
        "facebook/bart-large-cnn",
        max_length=150,
        min_length=30,
    ),
    'question_answering': _pipeline_loader(
        "question-answering",
        "deepset/roberta-base-squad2",
    ),
    'translation': _pipeline_loader(
        "translation",
        "Helsinki-NLP/opus-mt-tc-big-en-pt",
    ),
    'text_generation': _pipeline_loader(
        "text-generation",
        "gpt2",
        pad_token_id=50256,
    ),
    'ner': _pipeline_loader(
        "ner",
        "dbmdz/bert-large-cased-finetuned-conll03-english",
        aggregation_strategy="simple",
    ),
    # --- vision ---
    'image_classification': _pipeline_loader(
        "image-classification",
        "google/vit-base-patch16-224",
    ),
    'object_detection': _pipeline_loader(
        "object-detection",
        "facebook/detr-resnet-50",
    ),
    'image_segmentation': _pipeline_loader(
        "image-segmentation",
        "facebook/detr-resnet-50-panoptic",
    ),
    # Registered under 'facial_recognition' but served by an
    # image-classification pipeline with a facial-expression checkpoint.
    'facial_recognition': _pipeline_loader(
        "image-classification",
        "mo-thecreator/vit-Facial-Expression-Recognition",
    ),
    # --- audio ---
    'speech_to_text': _pipeline_loader(
        "automatic-speech-recognition",
        "openai/whisper-base",
    ),
    'audio_classification': _pipeline_loader(
        "audio-classification",
        "superb/hubert-base-superb-er",
    ),
    # --- generative image ---
    'text_to_image': _load_text_to_image,
}
def load_model(key):
    """Instantiate and return the model registered under ``key``.

    Looks ``key`` up in ``model_loaders`` and calls the stored zero-argument
    loader, so a new model object is built on every call.

    Raises:
        ValueError: if ``key`` is not present in ``model_loaders``.
    """
    if key in model_loaders:
        factory = model_loaders[key]
        return factory()
    raise ValueError(f"Model {key} not found in registry")