"""FastAPI service exposing a Gemini-backed mental-health helper and a
movie-review sentiment classifier.

Endpoints:
    POST /rag             -- send the user's question to Gemini and return the
                             streamed text chunks plus a list of articles.
    POST /classification  -- run the joblib-loaded classifier on a review text.
"""

import logging
import os

import google.generativeai as genai
import joblib
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

logger = logging.getLogger(__name__)

# Load environment variables (e.g. GOOGLE_API_KEY) from a local .env file.
load_dotenv()
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

# Load the machine learning model.
# NOTE(security): joblib.load unpickles the file, which can execute arbitrary
# code. Only ever point this at a trusted, locally produced artifact.
try:
    model = joblib.load('./movie_review_classifier.joblib')
except Exception as e:
    raise ImportError(f"Failed to load model: {e}") from e

app = FastAPI()


# Define models for requests
class QueryRequest(BaseModel):
    """Request body for POST /rag: the user's free-text question."""

    question: str


class Review(BaseModel):
    """Request body for POST /classification: the review text to classify."""

    text: str


# Initialize the Gemini model. GEMINI_MODEL may override the model name;
# "gemini-pro" remains the default.
gemini_model = genai.GenerativeModel(os.getenv("GEMINI_MODEL", "gemini-pro"))

# Retained only for backward compatibility with code that imports `chat`.
# The request handlers deliberately do NOT use a shared chat session: its
# history would accumulate every user's messages, leaking one user's
# conversation into another's context and growing for the life of the process.
chat = gemini_model.start_chat(history=[])

mental_health_prompt = """
You are an expert in providing mental health support. When a user describes their mental health issues, you should provide relevant articles or blog posts to assist them.
"""


# Gemini response function
def get_gemini_response(question, prompt):
    """Send ``prompt`` followed by ``question`` to Gemini and collect the reply.

    Each call is an independent single-turn request, so nothing said by one
    caller is visible to another. This call is blocking network I/O.

    Args:
        question: The user's text.
        prompt: Instruction text placed before the question.

    Returns:
        A list containing the text of each streamed response chunk, in order.
    """
    response = gemini_model.generate_content(f"{prompt} {question}", stream=True)
    return [chunk.text for chunk in response]


# Function to retrieve articles from a database or external source
def get_articles(query):
    """Return articles relevant to ``query``.

    Placeholder implementation: ``query`` is currently ignored and the same two
    hard-coded articles are always returned.
    TODO: back this with a real search over a database or external source.

    Returns:
        A list of dicts with "title", "url" and "summary" keys.
    """
    return [
        {"title": "Understanding Anxiety", "url": "https://newsinhealth.nih.gov/2016/03/understanding-anxiety-disorders", "summary": "A comprehensive guide on anxiety disorders."},
        {"title": "Coping with Depression", "url": "https://www.helpguide.org/articles/depression/coping-with-depression.htm", "summary": "Effective strategies for dealing with depression."},
    ]


# Mental health support endpoint
@app.post("/rag")
async def mental_health_support(request: QueryRequest):
    """Answer a mental-health question with Gemini output plus articles.

    Returns:
        {"responses": [str, ...], "articles": [dict, ...]}

    Raises:
        HTTPException: status 500 carrying the underlying error message if
        anything fails.
    """
    try:
        # The Gemini client call blocks; run it in the threadpool so it does
        # not stall the event loop for every other in-flight request.
        responses = await run_in_threadpool(
            get_gemini_response, request.question, mental_health_prompt
        )
        articles = get_articles(request.question)
    except Exception as e:
        logger.exception("Mental health support request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"responses": responses, "articles": articles}


# Classification endpoint
@app.post("/classification")
async def classify_review(review: Review):
    """Classify a single review with the loaded model.

    Returns:
        {"predicted_sentiment": <label>} where the label is the model's first
        prediction converted to a plain Python value when possible.

    Raises:
        HTTPException: status 500 carrying the underlying error message if
        prediction fails.
    """
    try:
        # model.predict is synchronous CPU work; keep it off the event loop.
        prediction = await run_in_threadpool(model.predict, [review.text])
        label = prediction[0]
        # Estimators commonly return NumPy arrays, and scalars such as
        # numpy.int64 are not JSON-serialisable by FastAPI. Unwrap any
        # NumPy-style scalar to the equivalent Python builtin.
        if hasattr(label, "item"):
            label = label.item()
    except Exception as e:
        logger.exception("Review classification failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"predicted_sentiment": label}


# Main function to run the server
if __name__ == "__main__":
    import uvicorn

    # PORT may override the listening port; 7860 remains the default.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))