Spaces:
Running
Running
Update core/model_registry.py
Browse files- core/model_registry.py +27 -0
core/model_registry.py
CHANGED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import pipeline
|
| 2 |
+
from diffusers import StableDiffusionPipeline
|
| 3 |
+
from core.config import device, cache_dir
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
model_loaders = {
|
| 7 |
+
'sentiment_analysis': lambda: pipeline("sentiment-analysis", model="cardiffnlp/twitter-roberta-base-sentiment-latest", device=device),
|
| 8 |
+
'text_classification': lambda: pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english", device=device),
|
| 9 |
+
'summarization': lambda: pipeline("summarization", model="facebook/bart-large-cnn", device=device, max_length=150, min_length=30),
|
| 10 |
+
'question_answering': lambda: pipeline("question-answering", model="deepset/roberta-base-squad2", device=device),
|
| 11 |
+
'translation': lambda: pipeline("translation", model="Helsinki-NLP/opus-mt-tc-big-en-pt", device=device),
|
| 12 |
+
'text_generation': lambda: pipeline("text-generation", model="gpt2", device=device, pad_token_id=50256),
|
| 13 |
+
'ner': lambda: pipeline("ner", model="dbmdz/bert-large-cased-finetuned-conll03-english", device=device, aggregation_strategy="simple"),
|
| 14 |
+
'image_classification': lambda: pipeline("image-classification", model="google/vit-base-patch16-224", device=device),
|
| 15 |
+
'object_detection': lambda: pipeline("object-detection", model="facebook/detr-resnet-50", device=device),
|
| 16 |
+
'image_segmentation': lambda: pipeline("image-segmentation", model="facebook/detr-resnet-50-panoptic", device=device),
|
| 17 |
+
'facial_recognition': lambda: pipeline("image-classification", model="mo-thecreator/vit-Facial-Expression-Recognition", device=device),
|
| 18 |
+
'speech_to_text': lambda: pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device),
|
| 19 |
+
'audio_classification': lambda: pipeline("audio-classification", model="superb/hubert-base-superb-er", device=device),
|
| 20 |
+
'text_to_image': lambda: StableDiffusionPipeline.from_pretrained(
|
| 21 |
+
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float32, use_safetensors=True, safety_checker=None, cache_dir=cache_dir
|
| 22 |
+
).to(device)
|
| 23 |
+
}
|
| 24 |
+
def load_model(key):
|
| 25 |
+
if key not in model_loaders:
|
| 26 |
+
raise ValueError(f"Model {key} not found in registry")
|
| 27 |
+
return model_loaders[key]()
|