Spaces:
Paused
Paused
| from fastapi import APIRouter, HTTPException, Depends | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from sqlalchemy.future import select | |
| from passlib.context import CryptContext | |
| from jose import jwt | |
| from pydantic import BaseModel, EmailStr | |
| from app.database import get_db # Updated: use the correct async session dependency | |
| from app.models import User | |
| import os | |
| import logging | |
| from dotenv import load_dotenv | |
| from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials | |
| from jose import JWTError | |
router = APIRouter()
logger = logging.getLogger(__name__)

# Populate os.environ from a local .env file (if present) before reading settings.
load_dotenv()

# JWT signing configuration.
# The fallback exists only so local development starts without configuration.
# With it, anyone who knows the literal can mint valid tokens, so complain
# loudly instead of silently running with it.
_DEFAULT_SECRET_KEY = "secret"
SECRET_KEY = os.getenv("SECRET_KEY", _DEFAULT_SECRET_KEY)
if SECRET_KEY in ("", _DEFAULT_SECRET_KEY):
    logger.warning(
        "SECRET_KEY is unset, empty or the insecure default; JWTs issued by "
        "this service can be forged. Set SECRET_KEY in the environment."
    )
ALGORITHM = "HS256"

# Password hashing config: bcrypt, with passlib's deprecated-scheme handling on "auto".
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# Extracts the Bearer credentials that get_current_user decodes.
security = HTTPBearer()
async def get_current_user(token: HTTPAuthorizationCredentials = Depends(security),
                           db: AsyncSession = Depends(get_db)):
    """Resolve the authenticated ``User`` from the request's Bearer JWT.

    The token is decoded with ``SECRET_KEY`` and ``ALGORITHM``. Its
    ``"user_id"`` claim is then used to load the matching ``User`` row.

    Args:
        token: Bearer credentials extracted by ``security``.
        db: Async database session.

    Returns:
        The ``User`` whose id is in the token.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            token fails to decode, lacks a ``"user_id"`` claim, or names a
            user that does not exist.
    """
    unauthorized = HTTPException(
        status_code=401,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Only the decode can raise JWTError; keep the try that narrow.
    try:
        claims = jwt.decode(token.credentials, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise unauthorized

    user_id: int = claims.get("user_id")
    if user_id is None:
        raise unauthorized

    rows = await db.execute(select(User).where(User.id == user_id))
    user = rows.scalar_one_or_none()
    if user is not None:
        return user
    raise unauthorized
# Request Schemas
# Signup request body: required login credentials plus optional profile fields.
class SignUp(BaseModel):
    email: EmailStr  # must be a syntactically valid email address
    password: str  # plaintext as submitted; the signup handler hashes it before storing
    mobile: str | None = None
    name: str | None = None
    dob: str | None = None  # free-form string; no date format is enforced here
    preparing_for: str | None = None
# Login request body: email/password pair checked against the stored hash.
class Login(BaseModel):
    email: EmailStr  # must be a syntactically valid email address
    password: str  # plaintext as submitted
# Partial profile update body: every field is optional. The update handler
# applies only the fields that are not None.
class UpdateProfile(BaseModel):
    mobile: str | None = None
    name: str | None = None
    dob: str | None = None  # free-form string; no date format is enforced here
    preparing_for: str | None = None
async def update_profile(data: UpdateProfile,
                         current_user: User = Depends(get_current_user),
                         db: AsyncSession = Depends(get_db)):
    """Apply a partial profile update to the authenticated user.

    Args:
        data: Requested changes. Only fields that are not ``None`` are copied
            onto the user, so a field cannot be cleared back to ``None`` here.
        current_user: The user resolved from the Bearer token.
        db: Async database session used to persist the change.

    Returns:
        A dict with a confirmation ``"message"`` and the refreshed ``"user"``
        profile (id, email, mobile, name, dob, preparing_for).

    Raises:
        HTTPException: 500 if committing or refreshing fails. The session is
            rolled back first and the full traceback is logged.
    """
    # NOTE(review): no route decorator precedes this handler in the visible
    # source; confirm it is registered on `router`.
    for field in ("mobile", "name", "dob", "preparing_for"):
        value = getattr(data, field)
        if value is not None:
            setattr(current_user, field, value)

    try:
        await db.commit()
        await db.refresh(current_user)
    except Exception as exc:
        # Leave the session usable and keep DB details out of the response.
        # logger.exception records the traceback that logger.error dropped.
        await db.rollback()
        logger.exception("Profile update error: %s", exc)
        raise HTTPException(status_code=500, detail="Internal Server Error") from exc

    return {"message": "Profile updated successfully",
            "user": {"id": current_user.id,
                     "email": current_user.email,
                     "mobile": current_user.mobile,
                     "name": current_user.name,
                     "dob": current_user.dob,
                     "preparing_for": current_user.preparing_for}}
async def signup(data: SignUp, db: AsyncSession = Depends(get_db)):
    """Create a new user account from a signup request.

    Args:
        data: Signup payload; the password is bcrypt-hashed before storage.
        db: Async database session.

    Returns:
        ``{"message": "User created", "user_id": <new user's id>}``.

    Raises:
        HTTPException: 400 "Email already exists" if the email is already
            registered. This includes a concurrent signup that inserts the
            same email between the existence check and the commit.
        HTTPException: 500 on any other failure while saving. The session is
            rolled back and the traceback is logged.
    """
    import asyncio

    from sqlalchemy.exc import IntegrityError

    # NOTE(review): no route decorator precedes this handler in the visible
    # source; confirm it is registered on `router`.

    async def _email_taken() -> bool:
        # True when a User row with the requested email exists.
        found = await db.execute(select(User).where(User.email == data.email))
        return found.scalar_one_or_none() is not None

    if await _email_taken():
        raise HTTPException(status_code=400, detail="Email already exists")

    # bcrypt is deliberately slow and CPU-bound. Run it in a worker thread so
    # it does not stall every other request on the event loop.
    hashed_password = await asyncio.to_thread(pwd_context.hash, data.password)
    new_user = User(email=data.email, hashed_password=hashed_password,
                    mobile=data.mobile, name=data.name, dob=data.dob,
                    preparing_for=data.preparing_for)
    try:
        db.add(new_user)
        await db.commit()
        await db.refresh(new_user)
    except IntegrityError as exc:
        await db.rollback()
        # Another request may have claimed this email after the check above.
        # That is a client error, not a server failure. (This is only
        # detectable if the DB enforces uniqueness on email; TODO confirm the
        # User model declares it.)
        if await _email_taken():
            raise HTTPException(status_code=400, detail="Email already exists") from exc
        logger.exception("Signup error: %s", exc)
        raise HTTPException(status_code=500, detail="Internal Server Error") from exc
    except Exception as exc:
        await db.rollback()
        logger.exception("Signup error: %s", exc)
        raise HTTPException(status_code=500, detail="Internal Server Error") from exc

    return {"message": "User created", "user_id": new_user.id}
async def login(data: Login, db: AsyncSession = Depends(get_db)):
    """Authenticate by email and password and issue a JWT access token.

    Args:
        data: Login payload (email and plaintext password).
        db: Async database session.

    Returns:
        A dict with:
            - ``"access_token"``: a JWT signed with ``SECRET_KEY``/``ALGORITHM``
              carrying a ``"user_id"`` claim.
            - ``"token_type"``: ``"bearer"``.
            - ``"user"``: ``{"id", "email"}``.

    Raises:
        HTTPException: 401 "Invalid credentials" for an unknown email or a
            wrong password. Both cases get the same response.
    """
    import asyncio

    # NOTE(review): no route decorator precedes this handler in the visible
    # source; confirm it is registered on `router`.
    result = await db.execute(select(User).where(User.email == data.email))
    user = result.scalar_one_or_none()

    # bcrypt is CPU-bound, so both branches run it in a worker thread to keep
    # the event loop responsive.
    if user is None:
        # Spend comparable hashing time on unknown emails. Otherwise the
        # faster response reveals which emails are registered.
        await asyncio.to_thread(pwd_context.dummy_verify)
        password_ok = False
    else:
        password_ok = await asyncio.to_thread(
            pwd_context.verify, data.password, user.hashed_password
        )
    if not password_ok:
        raise HTTPException(status_code=401, detail="Invalid credentials")

    # NOTE(review): the payload has no "exp" claim, so issued tokens never
    # expire; consider adding one.
    token = jwt.encode({"user_id": user.id}, SECRET_KEY, algorithm=ALGORITHM)
    return {
        "access_token": token,
        "token_type": "bearer",
        "user": {"id": user.id, "email": user.email},
    }