```python
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline
import torch

app = FastAPI(title="Model Inference API")

# Allow CORS for external frontends
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
MODEL_MAP = {
    "tinny-llama": "Lyon28/Tinny-Llama",
    "pythia": "Lyon28/Pythia",
    "bert-tinny": "Lyon28/Bert-Tinny",
    "albert-base-v2": "Lyon28/Albert-Base-V2",
    "t5-small": "Lyon28/T5-Small",
    "gpt-2": "Lyon28/GPT-2",
    "gpt-neo": "Lyon28/GPT-Neo",
    "distilbert-base-uncased": "Lyon28/Distilbert-Base-Uncased",
    "distil-gpt-2": "Lyon28/Distil_GPT-2",
    "gpt-2-tinny": "Lyon28/GPT-2-Tinny",
    "electra-small": "Lyon28/Electra-Small",
}

TASK_MAP = {
    "text-generation": ["gpt-2", "gpt-neo", "distil-gpt-2", "gpt-2-tinny", "tinny-llama", "pythia"],
    "text-classification": ["bert-tinny", "albert-base-v2", "distilbert-base-uncased", "electra-small"],
    "text2text-generation": ["t5-small"],
}
class InferenceRequest(BaseModel):
    text: str
    max_length: int = 100
    temperature: float = 0.9

def get_task(model_id: str) -> str:
    # Map a model id to its pipeline task; fall back to text generation.
    for task, models in TASK_MAP.items():
        if model_id in models:
            return task
    return "text-generation"
@app.on_event("startup")
async def load_models():
    # Initialize the pipeline cache at startup; without this hook the cache
    # is never created and every request would fail.
    app.state.pipelines = {}
    print("Models initialized in memory")
@app.post("/inference/{model_id}")
async def model_inference(model_id: str, request: InferenceRequest):
    # Reject unknown model ids before entering the generic error handler,
    # so the 404 is not re-wrapped as a 500 below.
    if model_id not in MODEL_MAP:
        raise HTTPException(status_code=404, detail="Model not found")
    task = get_task(model_id)
    try:
        # Load the pipeline lazily and cache it for subsequent requests.
        if model_id not in app.state.pipelines:
            app.state.pipelines[model_id] = pipeline(
                task=task,
                model=MODEL_MAP[model_id],
                device_map="auto",
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            )
        pipe = app.state.pipelines[model_id]

        # Process based on task
        if task == "text-generation":
            result = pipe(
                request.text,
                max_length=request.max_length,
                temperature=request.temperature,
                do_sample=True,  # temperature only takes effect when sampling
            )[0]["generated_text"]
        elif task == "text-classification":
            output = pipe(request.text)[0]
            result = {
                "label": output["label"],
                "confidence": round(output["score"], 4),
            }
        else:  # text2text-generation
            result = pipe(request.text)[0]["generated_text"]
        return {"result": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/models")
async def list_models():
    return {"available_models": list(MODEL_MAP.keys())}

@app.get("/health")
async def health_check():
    return {"status": "healthy"}
```