from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import torch
from transformers import BertTokenizer, BertForSequenceClassification
import logging
from contextlib import asynccontextmanager

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for model and tokenizer
tokenizer = None
model = None
device = None


# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    global tokenizer, model, device
    logger.info("Starting model loading and quantization process (lifespan event)...")
    try:
        # --- CHANGE STARTS HERE ---
        # Use the smallest available Amharic BERT model for extreme resource
        # efficiency. This model has only 4.18 million parameters.
        model_name_or_path = "rasyosef/bert-tiny-amharic"
        # --- CHANGE ENDS HERE ---

        logger.info(f"Attempting to load tokenizer from: {model_name_or_path}")
        tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
        logger.info("Tokenizer loaded successfully.")

        logger.info(f"Attempting to load model from: {model_name_or_path}")
        model = BertForSequenceClassification.from_pretrained(model_name_or_path)
        model.eval()  # Set model to evaluation mode
        logger.info("Model loaded successfully.")

        # --- Apply Dynamic Quantization ---
        logger.info("Applying dynamic quantization to the model...")
        model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
        logger.info("Dynamic quantization applied successfully.")
        # --- End Dynamic Quantization ---

        # Dynamically quantized models only run on CPU, so pin the device to CPU
        # rather than probing for CUDA.
        device = torch.device("cpu")
        model.to(device)  # Move quantized model to device
        logger.info(f"Model moved to device: {device}")

        logger.info("All components loaded, quantized, and ready for serving.")
    except Exception as e:
        logger.error(f"Critical error during model loading or quantization: {e}", exc_info=True)
        # For a truly free deployment, if essential model loading fails, it might
        # be better to let the app crash to avoid silent failures. For debugging
        # purposes we keep 'pass' for now; the health check reports the state.
        pass

    yield  # Startup logic is complete; the app can now serve requests.

    # --- Shutdown logic (optional) ---
    logger.info("Shutting down application (lifespan event)...")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# Initialize FastAPI app with lifespan
app = FastAPI(lifespan=lifespan)

# --- CORS Configuration ---
origins = [
    "http://localhost",
    "http://localhost:3000",  # Allow requests from the React development server
    "https://natishanau-amharic-fake-news-frontend.vercel.app",  # Vercel frontend domain
    "https://natishanau-amharic-fake-news-backend.hf.space",  # Allow self-access if needed
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# --- Health Check Endpoint ---
@app.get("/")
async def health_check():
    model_status = (
        "loaded and quantized"
        if model is not None and tokenizer is not None
        else "loading, quantizing, or failed"
    )
    return {"status": "ok", "message": "FastAPI backend is running", "model_status": model_status}


# Define request body structure
class TextInput(BaseModel):
    text: str


@app.post("/predict/")
async def predict(input: TextInput):
    """
    Predicts whether the given Amharic text is 'Real News' or 'Fake News'.
""" if model is None or tokenizer is None: logger.warning("Prediction request received but model is not yet loaded or failed to load.") return {"error": "Model not loaded yet. Please try again in a moment.", "status_code": 503} text = input.text if not text.strip(): return {"text": text, "prediction": "Invalid Input", "details": "Text input cannot be empty."} try: inputs = tokenizer( text, return_tensors="pt", truncation=True, padding=True, max_length=512 ).to(device) with torch.no_grad(): outputs = model(**inputs) logits = outputs.logits probabilities = torch.softmax(logits, dim=-1) predicted_class_idx = torch.argmax(logits, dim=-1).item() confidence = probabilities[0][predicted_class_idx].item() if predicted_class_idx == 0: prediction_label = "Real News" else: prediction_label = "Fake News" return { "original_text": text, "prediction": prediction_label, "confidence": round(confidence, 4) } except Exception as e: logger.error(f"Error during prediction: {e}", exc_info=True) return {"error": f"An error occurred during prediction: {e}", "status_code": 500}