Fix PermissionError by setting HF_HOME in Dockerfile
a2ae8d3
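# Note on the fix referenced in the commit above: in a Space container the
# Hugging Face cache defaults to a directory the runtime user may not be able
# to write, raising a PermissionError on model download. The assumed Dockerfile
# fix is to point the cache at a writable location, e.g.:
#   ENV HF_HOME=/tmp/huggingface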
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import BertTokenizer, BertForSequenceClassification
import logging
from contextlib import asynccontextmanager

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global variables for model and tokenizer
tokenizer = None
model = None
device = None
# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    global tokenizer, model, device
    logger.info("Starting model loading and quantization process (lifespan event)...")
    try:
        # Use the smallest available Amharic BERT model for extreme resource
        # efficiency; it has only 4.18 million parameters.
        model_name_or_path = "rasyosef/bert-tiny-amharic"

        logger.info(f"Attempting to load tokenizer from: {model_name_or_path}")
        tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
        logger.info("Tokenizer loaded successfully.")

        logger.info(f"Attempting to load model from: {model_name_or_path}")
        model = BertForSequenceClassification.from_pretrained(model_name_or_path)
        model.eval()  # Set the model to evaluation mode for inference.
        logger.info("Model loaded successfully.")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # --- Apply Dynamic Quantization ---
        # PyTorch dynamic quantization produces CPU-only kernels, and a
        # dynamically quantized model cannot be moved to CUDA, so we only
        # quantize when serving on the CPU (the free-tier deployment target).
        if device.type == "cpu":
            logger.info("Applying dynamic quantization to the model...")
            model = torch.quantization.quantize_dynamic(
                model, {torch.nn.Linear}, dtype=torch.qint8
            )
            logger.info("Dynamic quantization applied successfully.")
        # --- End Dynamic Quantization ---

        model.to(device)
        logger.info(f"Model moved to device: {device}")
        logger.info("All components loaded, quantized, and ready for serving.")
    except Exception as e:
        logger.error(f"Critical error during model loading or quantization: {e}", exc_info=True)
        # For a truly free deployment it might be better to let the app crash when
        # model loading fails, to avoid silent failures. We swallow the error for
        # now so the health-check endpoint can report the failed state instead.

    yield  # Startup logic is complete; the app can now serve requests.

    # --- Shutdown logic (optional) ---
    logger.info("Shutting down application (lifespan event)...")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# Initialize the FastAPI app with the lifespan handler.
app = FastAPI(lifespan=lifespan)

# --- CORS Configuration ---
origins = [
    "http://localhost",
    "http://localhost:3000",  # Allow requests from the React development server
    "https://natishanau-amharic-fake-news-frontend.vercel.app",  # Vercel frontend domain
    "https://natishanau-amharic-fake-news-backend.hf.space",  # Allow self-access if needed
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Health Check Endpoint ---
@app.get("/")
async def health_check():
    model_status = (
        "loaded and quantized"
        if model is not None and tokenizer is not None
        else "loading, quantizing, or failed"
    )
    return {"status": "ok", "message": "FastAPI backend is running", "model_status": model_status}

# Define the request body structure.
class TextInput(BaseModel):
    text: str
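# Illustrative request/response for /predict/ (assumed client usage, not part
# of the original file; the confidence value is hypothetical):
#   POST /predict/  {"text": "<Amharic news text>"}
#   -> {"original_text": "...", "prediction": "Real News", "confidence": 0.9731}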
@app.post("/predict/")
async def predict(input: TextInput):
"""
Predicts whether the given Amharic text is 'Real News' or 'Fake News'.
"""
if model is None or tokenizer is None:
logger.warning("Prediction request received but model is not yet loaded or failed to load.")
return {"error": "Model not loaded yet. Please try again in a moment.", "status_code": 503}
text = input.text
if not text.strip():
return {"text": text, "prediction": "Invalid Input", "details": "Text input cannot be empty."}
try:
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
padding=True,
max_length=512
).to(device)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probabilities = torch.softmax(logits, dim=-1)
predicted_class_idx = torch.argmax(logits, dim=-1).item()
confidence = probabilities[0][predicted_class_idx].item()
if predicted_class_idx == 0:
prediction_label = "Real News"
else:
prediction_label = "Fake News"
return {
"original_text": text,
"prediction": prediction_label,
"confidence": round(confidence, 4)
}
except Exception as e:
logger.error(f"Error during prediction: {e}", exc_info=True)
return {"error": f"An error occurred during prediction: {e}", "status_code": 500}