Spaces:
Build error
Build error
import os
from collections import Counter
from datetime import datetime, timedelta
from typing import Optional

import jwt
from bson import ObjectId
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from passlib.context import CryptContext
from pydantic import BaseModel
from pymongo import MongoClient
app = FastAPI()

# MongoDB connection.
# Connection settings can be supplied through environment variables so that
# credentials need not live in source control. The literal fallbacks reproduce
# the previously hard-coded values for backward compatibility.
# TODO: drop the credential fallback once every deployment sets MONGODB_URI.
MONGODB_URI = os.environ.get(
    "MONGODB_URI",
    "mongodb+srv://sarmadsiddiqui29:[email protected]/?retryWrites=true&w=majority&appName=Cluster0",
)
# Skipping TLS certificate validation is for testing only; set
# MONGODB_TLS_ALLOW_INVALID_CERTS=false in production.
MONGODB_TLS_ALLOW_INVALID_CERTS = os.environ.get(
    "MONGODB_TLS_ALLOW_INVALID_CERTS", "true"
).strip().lower() in ("1", "true", "yes")

client = MongoClient(
    MONGODB_URI,
    tls=True,
    tlsAllowInvalidCertificates=MONGODB_TLS_ALLOW_INVALID_CERTS,
)
db = client["annotations_db"]

# Password hashing context (bcrypt).
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# JWT signing settings. Override the placeholder key via JWT_SECRET_KEY:
# with HS256, anyone who knows the key can mint valid tokens.
SECRET_KEY = os.environ.get("JWT_SECRET_KEY", "your_secret_key")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30  # Token lifetime used by login_user

# Token of the most recent successful login; read by get_annotator_id.
# NOTE(review): this is a single process-wide value shared by every client,
# so requests act as whichever user logged in last.
current_token = None

# MongoDB collections
users_collection = db["users"]
stories_collection = db["stories"]
prompts_collection = db["prompts"]
summaries_collection = db["summaries"]
| # Models | |
class User(BaseModel):
    # Credentials payload shared by the register and login endpoints.
    email: str
    password: str  # plain text on the wire; only a hash is persisted
class Story(BaseModel):
    # A story document. Stories are shared between annotators, so this model
    # intentionally has no annotator_id field.
    story_id: str
    story: str
class Prompt(BaseModel):
    # Request body for creating or updating a prompt attached to a story.
    story_id: str
    prompt: str
    # Overwritten server-side by add_prompt from the caller's token, so clients
    # may omit it. Declared Optional[int] (not plain int) so the declared type
    # actually admits its None default.
    annotator_id: Optional[int] = None
class Summary(BaseModel):
    # Request body for creating or updating a summary attached to a story.
    story_id: str
    summary: str
    # Overwritten server-side by add_summary from the caller's token, so clients
    # may omit it. Declared Optional[int] (not plain int) so the declared type
    # actually admits its None default.
    annotator_id: Optional[int] = None
| # Serialize document function | |
def serialize_document(doc):
    """Recursively turn a MongoDB value into something JSON-friendly.

    ``ObjectId`` instances become their string form; dicts and lists are
    rebuilt with every nested value converted the same way. Any other value
    (including tuples) is returned unchanged.
    """
    if isinstance(doc, ObjectId):
        return str(doc)
    if isinstance(doc, dict):
        converted = {}
        for key, value in doc.items():
            converted[key] = serialize_document(value)
        return converted
    if isinstance(doc, list):
        return list(map(serialize_document, doc))
    return doc
| # Helper Functions | |
def hash_password(password: str) -> str:
    """Hash *password* with the module's passlib context (bcrypt)."""
    hashed = pwd_context.hash(password)
    return hashed
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches the stored passlib hash."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Build a signed JWT carrying *data* plus an ``exp`` claim.

    Args:
        data: Claims to embed. The dict is copied, never mutated.
        expires_delta: Token lifetime. When omitted (``None``) the token
            expires 15 minutes from now. An explicit ``timedelta(0)`` is
            honoured rather than being mistaken for "not given".

    Returns:
        The value of ``jwt.encode`` signed with SECRET_KEY / ALGORITHM.
    """
    from datetime import timezone  # only needed here

    lifetime = expires_delta if expires_delta is not None else timedelta(minutes=15)
    to_encode = data.copy()
    # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12,
    # and PyJWT converts an aware datetime to the same UTC epoch seconds.
    to_encode.update({"exp": datetime.now(timezone.utc) + lifetime})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def get_annotator_id(token: Optional[str] = None) -> int:
    """Return the ``annotator_id`` claim of the caller's JWT.

    Args:
        token: JWT to decode. When omitted, falls back to the module-level
            ``current_token`` stored by ``login_user``.

    Raises:
        HTTPException(401): no token is available, the token fails
            verification, or its payload has no ``annotator_id`` claim
            (previously this last case escaped as a KeyError / HTTP 500).
    """
    # NOTE(review): current_token is one process-wide value, so every client
    # shares whichever user logged in last. Prefer passing the token taken
    # from the request (e.g. the Authorization header).
    if token is None:
        token = current_token
    if token is None:
        raise HTTPException(status_code=401, detail="User not logged in")
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except jwt.PyJWTError as exc:
        raise HTTPException(status_code=401, detail="Invalid token") from exc
    try:
        return payload["annotator_id"]
    except KeyError as exc:
        raise HTTPException(status_code=401, detail="Invalid token") from exc
| # Endpoints for user, story, prompt, and summary operations | |
| # Register User | |
async def register_user(user: User):
    """Create a user account and return its newly assigned annotator id.

    Raises HTTPException(400) when the email is already registered.
    """
    # NOTE(review): the id is "number of users + 1". That is not atomic
    # (concurrent registrations can collide), and after delete_all_users it
    # restarts at 1 while prompts/summaries keep the old ids.
    if db.users.find_one({"email": user.email}):
        raise HTTPException(status_code=400, detail="Email already registered")

    hashed = hash_password(user.password)
    new_id = db.users.count_documents({}) + 1
    record = dict(email=user.email, password=hashed, annotator_id=new_id)
    db.users.insert_one(record)
    return {"message": "User registered successfully", "annotator_id": record["annotator_id"]}
| # Login User | |
async def login_user(user: User):
    """Authenticate *user* and issue a bearer token.

    On success the token is also stored in the module-level ``current_token``
    that ``get_annotator_id`` reads. Raises HTTPException(400) when the email
    is unknown or the password does not match.
    """
    # NOTE(review): writing a module global means every client shares the
    # most recently issued token.
    global current_token
    account = db.users.find_one({"email": user.email})
    if account and verify_password(user.password, account["password"]):
        lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        claims = {"email": account["email"], "annotator_id": account["annotator_id"]}
        current_token = create_access_token(data=claims, expires_delta=lifetime)
        return {"access_token": current_token, "token_type": "bearer"}
    raise HTTPException(status_code=400, detail="Invalid email or password")
| # Add Story | |
async def add_story(story: Story):
    """Insert a new story. Stories are shared and not tied to an annotator.

    Raises HTTPException(400) when a story with the same ``story_id`` exists.
    """
    if not db.stories.find_one({"story_id": story.story_id}):
        db.stories.insert_one(story.dict())
        return {"message": "Story added successfully"}
    raise HTTPException(status_code=400, detail="Story already exists")
| # Add Prompt | |
async def add_prompt(prompt: Prompt):
    """Store *prompt* under the logged-in annotator's id.

    Any ``annotator_id`` sent by the client is overwritten with the id from
    the current token; HTTPException(401) propagates from get_annotator_id
    when no valid token is available.
    """
    prompt.annotator_id = get_annotator_id()
    record = prompt.dict()
    db.prompts.insert_one(record)
    return {"message": "Prompt added successfully"}
| # Add Summary | |
async def add_summary(summary: Summary):
    """Store *summary* under the logged-in annotator's id.

    Any ``annotator_id`` sent by the client is overwritten with the id from
    the current token; HTTPException(401) propagates from get_annotator_id
    when no valid token is available.
    """
    summary.annotator_id = get_annotator_id()
    db.summaries.insert_one(summary.dict())
    return {"message": "Summary added successfully"}
| # Delete All Users | |
async def delete_all_users():
    """Remove every user document and report how many were deleted.

    NOTE(review): this function itself performs no authentication check.
    """
    removed = db.users.delete_many({}).deleted_count
    return {"message": f"{removed} users deleted"}
| # Delete All Stories | |
async def delete_all_stories():
    """Remove every story document and report how many were deleted.

    NOTE(review): this function itself performs no authentication check.
    """
    removed = db.stories.delete_many({}).deleted_count
    return {"message": f"{removed} stories deleted"}
| # Delete All Prompts | |
async def delete_all_prompts():
    """Remove every prompt document and report how many were deleted.

    NOTE(review): this function itself performs no authentication check.
    """
    removed = db.prompts.delete_many({}).deleted_count
    return {"message": f"{removed} prompts deleted"}
| # Delete All Summaries | |
async def delete_all_summaries():
    """Remove every summary document and report how many were deleted.

    NOTE(review): this function itself performs no authentication check.
    """
    removed = db.summaries.delete_many({}).deleted_count
    return {"message": f"{removed} summaries deleted"}
| # Test MongoDB Connection | |
async def test_connection():
    """Probe MongoDB by listing the database's collections.

    Returns a success message, or raises HTTPException(500) carrying the
    error text if the call fails for any reason.
    """
    try:
        db.list_collection_names()
    except Exception as err:  # surface any driver/network error as a 500
        raise HTTPException(status_code=500, detail=str(err))
    return {"message": "Connected to MongoDB successfully"}
| # Display Story by ID | |
async def display_story(story_id: str):
    """Return the whole story document (``_id`` stringified) or raise 404."""
    found = db.stories.find_one({"story_id": story_id})
    if not found:
        raise HTTPException(status_code=404, detail="Story not found")
    return serialize_document(found)
| # Display All for a Given Annotator ID | |
| from fastapi import Query | |
| from fastapi import Query, HTTPException | |
async def display_all(story_id: str = Query(...)):
    """Combine story text, prompt and summary for the logged-in annotator.

    The annotator's prompt for *story_id* is mandatory (HTTPException 404
    otherwise). A missing story document yields an empty story text, and a
    missing summary yields an empty summary.
    """
    annotator_id = get_annotator_id()
    owner_filter = {"story_id": story_id, "annotator_id": annotator_id}

    prompt_doc = db.prompts.find_one(owner_filter)
    if not prompt_doc:
        raise HTTPException(status_code=404, detail="Prompt not found for this annotator and story ID")

    story_doc = db.stories.find_one({"story_id": story_id}) or {"story": ""}
    summary_doc = db.summaries.find_one(owner_filter) or {"summary": ""}

    combined = dict(
        story_id=story_id,
        story=story_doc["story"],
        annotator_id=prompt_doc["annotator_id"],
        summary=summary_doc.get("summary", ""),
        prompt=prompt_doc.get("prompt", ""),
    )
    return serialize_document(combined)
async def delete_prompt(story_id: str):
    """Delete every prompt the logged-in annotator wrote for *story_id*.

    Returns a message with the number removed. Raises HTTPException(404) when
    nothing matched (and 401 via get_annotator_id when not logged in).
    """
    criteria = {"story_id": story_id, "annotator_id": get_annotator_id()}
    removed = db.prompts.delete_many(criteria).deleted_count
    if not removed:
        raise HTTPException(status_code=404, detail="No prompts found for this annotator and story ID")
    return {"message": f"{removed} prompt(s) deleted successfully"}
async def delete_summary(story_id: str):
    """Delete every summary the logged-in annotator wrote for *story_id*.

    Returns a message with the number removed. Raises HTTPException(404) when
    nothing matched (and 401 via get_annotator_id when not logged in).
    """
    criteria = {"story_id": story_id, "annotator_id": get_annotator_id()}
    removed = db.summaries.delete_many(criteria).deleted_count
    if not removed:
        raise HTTPException(status_code=404, detail="No summaries found for this annotator and story ID")
    return {"message": f"{removed} summary(ies) deleted successfully"}
async def delete_story(story_id: str):
    """Delete a story and cascade to the caller's prompts/summaries for it.

    The story document is removed whoever created it, because stories carry
    no annotator_id. Only the logged-in annotator's prompts and summaries for
    *story_id* are removed.

    Returns:
        A message plus the counts of cascaded prompt and summary deletions.

    Raises:
        HTTPException(401): not logged in / invalid token.
        HTTPException(404): no story with *story_id* exists. In that case
            nothing else is deleted; previously the prompts and summaries
            were removed before the 404 was returned.
    """
    annotator_id = get_annotator_id()

    story_result = db.stories.delete_one({"story_id": story_id})
    if story_result.deleted_count == 0:
        raise HTTPException(status_code=404, detail="Story not found for this annotator")

    # NOTE(review): other annotators' prompts/summaries for this story are
    # left in place and become orphaned.
    owned = {"story_id": story_id, "annotator_id": annotator_id}
    prompts_result = db.prompts.delete_many(owned)
    summaries_result = db.summaries.delete_many(owned)
    return {
        "message": "Story deleted successfully",
        "deleted_prompts": prompts_result.deleted_count,
        "deleted_summaries": summaries_result.deleted_count,
    }
async def update_story(story_id: str, updated_story: Story):
    """Replace the text of the story identified by *story_id*.

    Only ``updated_story.story`` is used; its ``story_id`` field is ignored.
    A logged-in annotator is required, but any annotator may edit any story:
    stories are stored without an annotator_id, so filtering on one (as this
    endpoint used to) never matched and every update returned 404.

    Raises:
        HTTPException(401): not logged in / invalid token.
        HTTPException(404): no story with *story_id* exists.
    """
    get_annotator_id()  # authentication gate only
    result = db.stories.update_one(
        {"story_id": story_id},
        {"$set": {"story": updated_story.story}},
    )
    if result.matched_count == 0:
        raise HTTPException(status_code=404, detail="Story not found")
    return {"message": "Story updated successfully"}
async def update_prompt(story_id: str, updated_prompt: Prompt):
    """Overwrite the logged-in annotator's prompt text for *story_id*.

    Only ``updated_prompt.prompt`` is used. If several prompts match, only
    one of them is updated (``update_one``).

    Raises:
        HTTPException(401): not logged in / invalid token.
        HTTPException(404): the annotator has no prompt for *story_id*.
    """
    annotator_id = get_annotator_id()
    # A single update_one checked via matched_count replaces a separate
    # find_one: one round trip, and no gap between the check and the write.
    result = db.prompts.update_one(
        {"story_id": story_id, "annotator_id": annotator_id},
        {"$set": {"prompt": updated_prompt.prompt}},
    )
    if result.matched_count == 0:
        raise HTTPException(status_code=404, detail="Prompt not found or does not belong to this annotator")
    return {"message": "Prompt updated successfully"}
async def update_summary(story_id: str, updated_summary: Summary):
    """Overwrite the logged-in annotator's summary text for *story_id*.

    Only ``updated_summary.summary`` is used. If several summaries match,
    only one of them is updated (``update_one``).

    Raises:
        HTTPException(401): not logged in / invalid token.
        HTTPException(404): the annotator has no summary for *story_id*.
    """
    annotator_id = get_annotator_id()
    # A single update_one checked via matched_count replaces a separate
    # find_one: one round trip, and no gap between the check and the write.
    result = db.summaries.update_one(
        {"story_id": story_id, "annotator_id": annotator_id},
        {"$set": {"summary": updated_summary.summary}},
    )
    if result.matched_count == 0:
        raise HTTPException(status_code=404, detail="Summary not found or does not belong to this annotator")
    return {"message": "Summary updated successfully"}
async def get_prompt(story_id: str):
    """Return the logged-in annotator's prompt text for *story_id*.

    A missing prompt yields an empty string rather than an error;
    HTTPException(401) still propagates from get_annotator_id.
    """
    doc = db.prompts.find_one({"story_id": story_id, "annotator_id": get_annotator_id()})
    text = doc.get("prompt", "") if doc else ""
    return {"story_id": story_id, "prompt": text}
async def get_summary(story_id: str):
    """Return the logged-in annotator's summary text for *story_id*.

    A missing summary yields an empty string rather than an error;
    HTTPException(401) still propagates from get_annotator_id.
    """
    doc = db.summaries.find_one({"story_id": story_id, "annotator_id": get_annotator_id()})
    text = doc.get("summary", "") if doc else ""
    return {"story_id": story_id, "summary": text}
async def get_story(story_id: str):
    """Return the text of story *story_id*, or an empty string if absent.

    No login is required and no 404 is raised.
    """
    doc = db.stories.find_one({"story_id": story_id})
    return {"story_id": story_id, "story": doc.get("story", "") if doc else ""}
async def get_annotators():
    """Return how many prompts each annotator has written.

    Response: a JSON list of ``{"annotator_id": ..., "prompt_count": ...}``,
    in order of each id's first appearance while scanning the prompts
    collection. Prompts without an ``annotator_id`` field are ignored.
    """
    # Project only the counted field so MongoDB does not ship every prompt's
    # text just to be discarded. Documents without the field come back empty
    # and are skipped below, exactly as before.
    cursor = prompts_collection.find({}, {"annotator_id": 1, "_id": 0})
    annotator_counts = Counter(doc["annotator_id"] for doc in cursor if "annotator_id" in doc)
    annotators = [
        {"annotator_id": annotator_id, "prompt_count": count}
        for annotator_id, count in annotator_counts.items()
    ]
    return JSONResponse(content=annotators)