# Makhfi_AI / api/routers/chats.py
# Author: Aasher
# Last commit 8909b6d: fix(chats): filter messages to include only relevant user and assistant content in chat retrieval
# File size at that commit: 6.23 kB
import logging
import uuid
from typing import List

from fastapi import APIRouter, Depends, HTTPException, status
from sqlmodel import Session

# --- LangChain Imports ---
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage

# --- Local Project Imports ---
from api.dependencies.auth import get_current_user
from db.session import get_db
from db.crud import chat as chat_crud, message as message_crud
from db.models.chat import Chat
from db.models.message import Message as DBMessage
from db.schemas.chat import ChatReadSimple, ChatUpdate, ChatReadWithMessages
from db.schemas.message import MessageCreate, MessageRead
from workflow.agent import agent
from workflow.title_generator import generate_chat_title
# Router that every endpoint in this module registers on via @router.<method>(...).
# NOTE(review): the URL prefix it is mounted under is set by whoever includes it — confirm in the app setup.
router = APIRouter()
# --- Helper Functions ---
def get_chat_for_user(chat_id: uuid.UUID, user_id: uuid.UUID, db: Session) -> Chat:
    """
    Look up a chat and confirm it belongs to the requesting user.

    The route handlers in this module call this directly (it is not injected
    through ``Depends``).

    Args:
        chat_id: Identifier of the chat to load.
        user_id: Identifier of the authenticated user making the request.
        db: Active database session.

    Returns:
        The chat returned by ``chat_crud.get_chat_by_id``, once ownership is confirmed.

    Raises:
        HTTPException: 404 when no chat is found for ``chat_id``; 403 when the
            chat exists but its ``user_id`` differs from the caller's.
    """
    found = chat_crud.get_chat_by_id(db, chat_id)
    if not found:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")

    belongs_to_caller = found.user_id == user_id
    if not belongs_to_caller:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Not authorized to access this chat",
        )
    return found
def _convert_db_messages_to_langchain(db_messages: List[DBMessage]) -> List[BaseMessage]:
    """
    Convert stored chat messages into LangChain message objects for the agent.

    Mapping by ``role``:
      * ``"user"``      -> ``HumanMessage``
      * ``"assistant"`` -> ``AIMessage``, carrying ``tool_calls`` when the row has any
      * ``"tool"``      -> ``ToolMessage`` linked to its call via ``tool_call_id``
    Rows with any other role are skipped.

    Args:
        db_messages: Message rows, in the order they should appear in the history.

    Returns:
        The converted messages, in the same relative order.
    """
    langchain_messages: List[BaseMessage] = []
    for msg in db_messages:
        if msg.role == "user":
            langchain_messages.append(HumanMessage(content=msg.content))
        elif msg.role == "assistant":
            # Assistant rows may have no text (e.g. a step that only issued tool
            # calls), so fall back to "" instead of passing None to AIMessage.
            # Text that accompanied tool calls is kept so the agent sees the
            # history exactly as the model produced it.
            content = msg.content or ""
            if msg.tool_calls:
                langchain_messages.append(AIMessage(content=content, tool_calls=msg.tool_calls))
            else:
                langchain_messages.append(AIMessage(content=content))
        elif msg.role == "tool":
            langchain_messages.append(
                ToolMessage(content=msg.content, tool_call_id=msg.tool_call_id)
            )
    return langchain_messages
# --- Chat CRUD Routes ---
@router.post("/", response_model=ChatReadSimple, status_code=status.HTTP_201_CREATED)
def create_new_chat(
    *,
    db: Session = Depends(get_db),
    user_id: uuid.UUID = Depends(get_current_user),
):
    """
    Start a fresh chat, with no messages yet, owned by the authenticated user.

    Returns whatever ``chat_crud.create_chat`` produces, serialized as
    ``ChatReadSimple`` with a 201 status.
    """
    new_chat = chat_crud.create_chat(db=db, user_id=user_id)
    return new_chat
@router.get("/", response_model=List[ChatReadSimple])
def get_user_chats(
    *,
    db: Session = Depends(get_db),
    user_id: uuid.UUID = Depends(get_current_user),
):
    """
    List the chats that belong to the authenticated user.

    Returns the result of ``chat_crud.get_chats_by_user``, serialized as a
    list of ``ChatReadSimple``.
    """
    user_chats = chat_crud.get_chats_by_user(db=db, user_id=user_id)
    return user_chats
@router.get("/{chat_id}", response_model=ChatReadWithMessages)
def get_single_chat_with_messages(
    *,
    chat_id: uuid.UUID,
    user_id: uuid.UUID = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """
    Return one chat together with the messages a user should see.

    Only two kinds of message are included:
      * user messages;
      * assistant messages whose ``content`` is non-empty.

    Tool messages and assistant rows with empty content (presumably
    tool-call-only steps) are left out of the response.

    Raises:
        HTTPException: 404 if the chat does not exist, 403 if it belongs to
            another user (raised by ``get_chat_for_user``).
    """
    chat = get_chat_for_user(chat_id, user_id, db)

    visible_messages = []
    for message in chat.messages:
        if message.role == 'user':
            visible_messages.append(message)
        elif message.role == 'assistant' and message.content:
            visible_messages.append(message)

    return ChatReadWithMessages(
        id=chat.id,
        title=chat.title,
        created_at=chat.created_at,
        updated_at=chat.updated_at,
        messages=visible_messages,
    )
@router.patch("/{chat_id}", response_model=ChatReadSimple)
def rename_chat(
    *,
    chat_id: uuid.UUID,
    chat_update: ChatUpdate,
    user_id: uuid.UUID = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """
    Change the title of one of the authenticated user's chats.

    Ownership is checked first (404/403 from ``get_chat_for_user``). The
    endpoint then returns whatever ``chat_crud.update_chat_title`` gives back,
    serialized as ``ChatReadSimple``.
    """
    owned_chat = get_chat_for_user(chat_id, user_id, db)
    return chat_crud.update_chat_title(db=db, chat=owned_chat, chat_update=chat_update)
@router.delete("/{chat_id}", status_code=status.HTTP_204_NO_CONTENT)
def remove_chat(
    *,
    chat_id: uuid.UUID,
    user_id: uuid.UUID = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """
    Delete one of the authenticated user's chats; responds 204 with no body.

    After the ownership check (404/403 from ``get_chat_for_user``), removal
    is delegated to ``chat_crud.delete_chat``. The chat's messages are
    expected to be removed with it; confirm the cascade in the crud layer.
    """
    doomed_chat = get_chat_for_user(chat_id, user_id, db)
    chat_crud.delete_chat(db=db, chat=doomed_chat)
    return None
# --- Main Message Handling Endpoint ---
@router.post("/{chat_id}/messages", response_model=MessageRead)
async def post_message_and_get_response(
    *,
    chat_id: uuid.UUID,
    message_in: MessageCreate,
    user_id: uuid.UUID = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """
    Run one conversational turn: feed the user's message to the agent and store the result.

    Steps:
      1. Verify the caller owns the chat (404/403 otherwise).
      2. Rebuild the stored history as LangChain messages and append the new
         user message.
      3. Invoke the agent, passing the caller's ``user_id`` in the run config.
      4. Persist the new user message plus everything the agent appended, via
         ``message_crud.create_messages_for_turn``.
      5. On a chat's first message, generate a title. This is best-effort: the
         turn has already been handed to the crud layer, so a title failure is
         logged and the reply is still returned.

    Returns:
        The last message row created for this turn, serialized as ``MessageRead``.
    """
    # 1. Get and verify chat ownership
    chat = get_chat_for_user(chat_id, user_id, db)

    # 2. Load and convert history for the agent
    db_messages = message_crud.get_messages_by_chat(db, chat_id)
    is_first_user_message = not db_messages
    messages_for_agent: list[BaseMessage] = _convert_db_messages_to_langchain(db_messages)

    # 3. Add the new user message; its position marks where this turn's messages begin
    messages_for_agent.append(HumanMessage(content=message_in.content))
    initial_message_count = len(messages_for_agent)

    # 4. Invoke the agent
    response = await agent.ainvoke(messages_for_agent, config={"configurable": {"user_id": str(user_id)}})

    # 5. Extract results from the agent's structured output
    final_answer = response["answer"]
    links = response["links"]
    updated_messages_from_agent = response["messages"]

    # 6. Isolate the messages that belong to this turn. The slice starts one
    #    before the count so the new user message is persisted too.
    #    NOTE(review): assumes the agent returns the input messages unchanged as
    #    a prefix of "messages" — confirm against workflow.agent.
    new_lc_messages = updated_messages_from_agent[initial_message_count - 1:]

    # 7. Persist the entire turn in a single crud call
    newly_created_db_messages = message_crud.create_messages_for_turn(
        db=db,
        chat_id=chat_id,
        new_lc_messages=new_lc_messages,
        final_answer=final_answer,
        links=links,
    )

    # 8. If this was the first message, generate and set a title for the chat
    if is_first_user_message:
        try:
            new_title = await generate_chat_title(message_in.content)
        except Exception:
            # The title is cosmetic and the turn is already persisted above.
            # Failing the request here would hide the reply from the client and
            # invite a retry that duplicates the turn, so log and move on.
            logging.getLogger(__name__).exception(
                "Title generation failed for chat %s; keeping the existing title", chat_id
            )
        else:
            chat_crud.update_chat_title(db=db, chat=chat, chat_update=ChatUpdate(title=new_title))

    # 9. The last row created for the turn is returned to the client
    final_ai_message = newly_created_db_messages[-1]
    return final_ai_message