from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import BertTokenizer, BertForSequenceClassification
import logging
from contextlib import asynccontextmanager

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for the model, tokenizer, and device
tokenizer = None
model = None
device = None
# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    global tokenizer, model, device
    logger.info("Starting model loading and quantization process (lifespan event)...")
    try:
        # Use the smallest available Amharic BERT model for extreme resource
        # efficiency; it has only 4.18 million parameters.
        model_name_or_path = "rasyosef/bert-tiny-amharic"
logger.info(f"Attempting to load tokenizer from: {model_name_or_path}") | |
tokenizer = BertTokenizer.from_pretrained(model_name_or_path) | |
logger.info("Tokenizer loaded successfully.") | |
logger.info(f"Attempting to load model from: {model_name_or_path}") | |
# Load the model | |
model = BertForSequenceClassification.from_pretrained(model_name_or_path) | |
model.eval() # Set model to evaluation mode | |
logger.info("Model loaded successfully.") | |
        # --- Apply Dynamic Quantization ---
        # Replaces the weights of every torch.nn.Linear layer with int8
        # equivalents; activations are quantized on the fly at inference time.
        logger.info("Applying dynamic quantization to the model...")
        model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
        logger.info("Dynamic quantization applied successfully.")
        # --- End Dynamic Quantization ---
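        # Optional sanity check (assumed helper, safe to remove): serialize the
        # state_dict to an in-memory buffer and log its size, to confirm the
        # int8 weights actually shrank the model.
        import io
        _buffer = io.BytesIO()
        torch.save(model.state_dict(), _buffer)
        logger.info(f"Serialized model size after quantization: {_buffer.getbuffer().nbytes / 1e6:.2f} MB")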
        # PyTorch's dynamic int8 kernels only run on CPU, so keep the
        # quantized model there rather than moving it to CUDA.
        device = torch.device("cpu")
        model.to(device)
        logger.info(f"Model moved to device: {device}")
        logger.info("All components loaded, quantized, and ready for serving.")
    except Exception as e:
        logger.error(f"Critical error during model loading or quantization: {e}", exc_info=True)
        # For a truly free deployment it might be better to let the app crash
        # when the model fails to load, to avoid silent failures. For now we
        # swallow the exception so the health check can report the state.
    yield  # startup logic is complete; the app can now serve requests

    # --- Shutdown logic (optional) ---
    logger.info("Shutting down application (lifespan event)...")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# Initialize the FastAPI app with the lifespan handler
app = FastAPI(lifespan=lifespan)
# --- CORS Configuration ---
origins = [
    "http://localhost",
    "http://localhost:3000",  # allow requests from the React development server
    "https://natishanau-amharic-fake-news-frontend.vercel.app",  # Vercel frontend domain
    "https://natishanau-amharic-fake-news-backend.hf.space",  # allow self-access if needed
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
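# Note: allow_credentials=True requires the explicit origin list above;
# browsers reject credentialed responses served with a wildcard "*" origin.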
# --- Health Check Endpoint ---
@app.get("/")  # health-check route (path assumed)
async def health_check():
    model_status = (
        "loaded and quantized"
        if model is not None and tokenizer is not None
        else "loading, quantizing, or failed"
    )
    return {"status": "ok", "message": "FastAPI backend is running", "model_status": model_status}
# Define the request body structure
class TextInput(BaseModel):
    text: str
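# Example request body for POST /predict:
#   {"text": "<Amharic news text>"}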
@app.post("/predict")  # prediction route (path assumed)
async def predict(input: TextInput):
    """
    Predicts whether the given Amharic text is 'Real News' or 'Fake News'.
    """
    if model is None or tokenizer is None:
        logger.warning("Prediction request received but model is not yet loaded or failed to load.")
        raise HTTPException(status_code=503, detail="Model not loaded yet. Please try again in a moment.")

    text = input.text
    if not text.strip():
        return {"text": text, "prediction": "Invalid Input", "details": "Text input cannot be empty."}
    try:
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            probabilities = torch.softmax(logits, dim=-1)

        predicted_class_idx = torch.argmax(logits, dim=-1).item()
        confidence = probabilities[0][predicted_class_idx].item()
        if predicted_class_idx == 0:
            prediction_label = "Real News"
        else:
            prediction_label = "Fake News"

        return {
            "original_text": text,
            "prediction": prediction_label,
            "confidence": round(confidence, 4),
        }
    except Exception as e:
        logger.error(f"Error during prediction: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An error occurred during prediction: {e}")