|
import uuid |
|
from typing import List |
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status |
|
from sqlmodel import Session |
|
|
|
|
|
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage |
|
|
|
|
|
from api.dependencies.auth import get_current_user |
|
from db.session import get_db |
|
from db.crud import chat as chat_crud, message as message_crud |
|
from db.models.chat import Chat |
|
from db.models.message import Message as DBMessage |
|
from db.schemas.chat import ChatReadSimple, ChatUpdate, ChatReadWithMessages |
|
from db.schemas.message import MessageCreate, MessageRead |
|
from workflow.agent import agent |
|
from workflow.title_generator import generate_chat_title |
|
|
|
# Router that collects the chat CRUD endpoints and the message-posting
# endpoint defined in this module; it is presumably included by the app
# under a chat-specific prefix elsewhere (TODO confirm).
router = APIRouter()
|
|
|
|
|
def get_chat_for_user(chat_id: uuid.UUID, user_id: uuid.UUID, db: Session) -> Chat:
    """Load a chat by id and ensure it belongs to the given user.

    This is a plain helper called directly by the route handlers (it is not
    wired in through ``Depends``).

    Args:
        chat_id: Identifier of the chat to load.
        user_id: Identifier of the requesting user.
        db: Active database session.

    Returns:
        The ``Chat`` row, when it exists and its ``user_id`` matches.

    Raises:
        HTTPException: 404 if no chat is found for ``chat_id``; 403 if the
            chat is owned by a different user.
    """
    found = chat_crud.get_chat_by_id(db, chat_id)
    if not found:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")

    if found.user_id == user_id:
        return found

    # The chat exists but belongs to someone else.
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Not authorized to access this chat",
    )
|
|
|
def _convert_db_messages_to_langchain(db_messages: List[DBMessage]) -> List[BaseMessage]:
    """Translate persisted chat messages into LangChain message objects.

    Mapping by ``role``:

    * ``"user"``      -> ``HumanMessage`` carrying the stored content.
    * ``"assistant"`` -> ``AIMessage``. When the row has ``tool_calls``, the
      stored content is not used: the message gets an empty string plus the
      tool calls. Otherwise it carries the stored content.
    * ``"tool"``      -> ``ToolMessage`` linked via ``tool_call_id``.

    Rows with any other role are skipped.

    Args:
        db_messages: Database message rows, in conversation order.

    Returns:
        A new list of LangChain messages, preserving input order.
    """
    converted: List[BaseMessage] = []
    for record in db_messages:
        role = record.role
        if role == "user":
            item = HumanMessage(content=record.content)
        elif role == "assistant":
            item = (
                AIMessage(content="", tool_calls=record.tool_calls)
                if record.tool_calls
                else AIMessage(content=record.content)
            )
        elif role == "tool":
            item = ToolMessage(content=record.content, tool_call_id=record.tool_call_id)
        else:
            # Unknown roles are not forwarded to the agent.
            continue
        converted.append(item)
    return converted
|
|
|
|
|
|
|
|
|
@router.post("/", response_model=ChatReadSimple, status_code=status.HTTP_201_CREATED)
def create_new_chat(*, db: Session = Depends(get_db), user_id: uuid.UUID = Depends(get_current_user)):
    """Create an empty chat owned by the authenticated user and return it (HTTP 201)."""
    new_chat = chat_crud.create_chat(db=db, user_id=user_id)
    return new_chat
|
|
|
@router.get("/", response_model=List[ChatReadSimple])
def get_user_chats(*, db: Session = Depends(get_db), user_id: uuid.UUID = Depends(get_current_user)):
    """List every chat belonging to the authenticated user (summary form, no messages)."""
    user_chats = chat_crud.get_chats_by_user(db=db, user_id=user_id)
    return user_chats
|
|
|
@router.get("/{chat_id}", response_model=ChatReadWithMessages)
def get_single_chat_with_messages(*, chat_id: uuid.UUID, user_id: uuid.UUID = Depends(get_current_user), db: Session = Depends(get_db)):
    """Return one of the caller's chats together with its user-visible messages.

    Only two kinds of message are included:

    * every ``"user"`` message;
    * ``"assistant"`` messages whose content is truthy.

    Tool messages, and assistant messages with empty content (presumably
    those that only carried tool calls), are left out of the response.

    Raises:
        HTTPException: 404/403 if the chat is missing or owned by another user.
    """
    chat = get_chat_for_user(chat_id, user_id, db)

    def _is_visible(message) -> bool:
        # Decides whether a stored message is shown to the client.
        if message.role == 'user':
            return True
        return message.role == 'assistant' and bool(message.content)

    visible = [m for m in chat.messages if _is_visible(m)]

    return ChatReadWithMessages(
        id=chat.id,
        title=chat.title,
        created_at=chat.created_at,
        updated_at=chat.updated_at,
        messages=visible,
    )
|
|
|
@router.patch("/{chat_id}", response_model=ChatReadSimple)
def rename_chat(*, chat_id: uuid.UUID, chat_update: ChatUpdate, user_id: uuid.UUID = Depends(get_current_user), db: Session = Depends(get_db)):
    """Apply a new title to one of the caller's chats and return the updated chat.

    Raises:
        HTTPException: 404/403 if the chat is missing or owned by another user.
    """
    owned_chat = get_chat_for_user(chat_id, user_id, db)
    updated = chat_crud.update_chat_title(db=db, chat=owned_chat, chat_update=chat_update)
    return updated
|
|
|
@router.delete("/{chat_id}", status_code=status.HTTP_204_NO_CONTENT)
def remove_chat(*, chat_id: uuid.UUID, user_id: uuid.UUID = Depends(get_current_user), db: Session = Depends(get_db)):
    """Delete one of the caller's chats; responds 204 with no body.

    Whether the chat's messages are removed along with it is decided by
    ``chat_crud.delete_chat`` / the model's cascade settings. It is assumed
    to cascade — TODO confirm.

    Raises:
        HTTPException: 404/403 if the chat is missing or owned by another user.
    """
    chat_crud.delete_chat(db=db, chat=get_chat_for_user(chat_id, user_id, db))
    return None
|
|
|
|
|
|
|
|
|
@router.post("/{chat_id}/messages", response_model=MessageRead)
async def post_message_and_get_response(
    *,
    chat_id: uuid.UUID,
    message_in: MessageCreate,
    user_id: uuid.UUID = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Run one conversational turn in a chat and persist everything it produced.

    Steps:

    1. Load the chat's stored history.
    2. Append the user's new message and invoke the agent.
    3. Persist the new user message plus every message the agent appended,
       together with the agent's ``answer`` and ``links``.
    4. On the chat's first turn, try to generate a title from the user's
       message.

    Title generation is best-effort. The turn is already saved at that
    point, so a failing title call is logged instead of failing the request
    (a failed request would invite a retry that duplicates the turn).

    Args:
        chat_id: Chat to post into; must be owned by the caller.
        message_in: The user's new message.
        user_id: Authenticated user id (injected).
        db: Database session (injected).

    Returns:
        The last message persisted for this turn (expected to be the
        assistant's final answer).

    Raises:
        HTTPException: 404/403 from the ownership check; 500 if the turn
            persisted no messages.
    """
    import logging

    chat = get_chat_for_user(chat_id, user_id, db)

    # NOTE(review): this handler is `async def`, but the Session calls here are
    # synchronous and run on the event-loop thread. Consider an async session
    # or offloading them to a threadpool.
    db_messages = message_crud.get_messages_by_chat(db, chat_id)
    # An empty history means this is the chat's first turn (used for titling).
    is_first_user_message = not db_messages
    messages_for_agent: list[BaseMessage] = _convert_db_messages_to_langchain(db_messages)
    messages_for_agent.append(HumanMessage(content=message_in.content))
    initial_message_count = len(messages_for_agent)

    response = await agent.ainvoke(messages_for_agent, config={"configurable": {"user_id": str(user_id)}})

    final_answer = response["answer"]
    links = response["links"]
    # The new user message has not been persisted yet. Start the slice one
    # element early so it is saved together with everything the agent added.
    # Assumes response["messages"] begins with the input messages unchanged —
    # TODO confirm against workflow.agent.
    new_lc_messages = response["messages"][initial_message_count - 1:]

    newly_created_db_messages = message_crud.create_messages_for_turn(
        db=db,
        chat_id=chat_id,
        new_lc_messages=new_lc_messages,
        final_answer=final_answer,
        links=links,
    )
    if not newly_created_db_messages:
        # Previously this surfaced as a bare IndexError; make the failure explicit.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Agent turn produced no messages to return",
        )

    if is_first_user_message:
        try:
            new_title = await generate_chat_title(message_in.content)
        except Exception:
            # Best-effort: keep the chat's existing title rather than failing a
            # request whose turn is already persisted.
            logging.getLogger(__name__).exception(
                "Title generation failed for chat %s", chat_id
            )
        else:
            # Kept outside the try so DB errors still propagate (and are not
            # masked while the session is in a failed state).
            chat_crud.update_chat_title(db=db, chat=chat, chat_update=ChatUpdate(title=new_title))

    return newly_created_db_messages[-1]