File size: 6,718 Bytes
15cc131
c7999f8
 
2e96320
 
15cc131
 
2e96320
6a8ebee
c7999f8
2e96320
15cc131
2e96320
 
 
 
 
 
 
 
 
ca57ee4
 
 
 
 
 
 
 
 
 
 
 
 
c7999f8
 
 
 
 
 
 
 
 
 
 
 
 
6a8ebee
 
c7999f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e96320
15cc131
c7999f8
2e96320
 
 
 
 
 
 
15cc131
2e96320
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6a8ebee
 
 
2e96320
 
 
15cc131
 
2e96320
 
 
 
 
 
 
 
 
 
 
 
 
 
15cc131
6a8ebee
 
 
15cc131
 
6a8ebee
2e96320
6a8ebee
 
 
 
 
 
 
 
 
 
 
 
 
d80f5ac
6a8ebee
 
 
2e96320
15cc131
 
2e96320
 
15cc131
 
2e96320
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15cc131
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import chainlit as cl
from chainlit.input_widget import Select, TextInput

import logging
from typing import Optional

from src.axiom.agent import AxiomAgent
from src.axiom.config import settings, load_mcp_servers_from_config
from src.axiom.prompts import AXIOM_AGENT_PROMPT, AXIOM_ASSISTANT_PROMPT

from agents.mcp import MCPServer 

# Logging setup for the UI module: INFO level, timestamped records that
# include the logger name and severity.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# --- Constants ---
# Keys used to store per-session state in `cl.user_session`.
SESSION_AGENT_KEY = "axiom_agent"          # not referenced by the handlers visible here
SESSION_HISTORY_KEY = "chat_history"       # list of {"role", "content"} dicts
SESSION_MCP_SERVERS_KEY = "mcp_servers"    # MCP servers that started successfully

#################################
# User Authentication
#################################
@cl.oauth_callback
def oauth_callback(
    provider_id: str,
    token: str,
    raw_user_data: dict[str, str],
    default_user: cl.User,
) -> Optional[cl.User]:
    """Return the ``default_user`` supplied to the callback, unchanged.

    Args:
        provider_id: Identifier of the OAuth provider (not inspected).
        token: OAuth token from the provider (not inspected).
        raw_user_data: Raw user payload from the provider (not inspected).
        default_user: User object passed in by the framework.

    Returns:
        ``default_user`` as received; this callback never returns ``None``.
    """
    # No provider, token or user-data filtering is applied here.
    return default_user

#################################
# Quick Starter Questions
#################################
@cl.set_starters
async def set_starters():
    """Return the starter prompts offered by the chat UI.

    Each starter has a short ``label``, the ``message`` that is submitted
    when it is picked, and an ``icon`` path under ``/public/msg_icons``.

    Returns:
        A list of three ``cl.Starter`` objects.
    """
    return [
        cl.Starter(
            label="LangGraph Agents Creation",
            message="Create a Multi-Agent customer support system using LangGraph swarm.",
            icon="/public/msg_icons/chatbot.png",
        ),
        cl.Starter(
            # Label previously read "Autogent"; spelled like the message below.
            label="How to use Autogen",
            message="How to use Autogen? Create some agents using Autogen",
            icon="/public/msg_icons/usb.png",
        ),
        cl.Starter(
            label="JWT Auth Implementation",
            # "FastAPi" corrected to the framework's actual name.
            message="How to implement JWT authentication in FastAPI? Create a complete project.",
            icon="/public/msg_icons/tools.png",
        ),
    ]
#################################
# Response modes for Axiom
#################################
@cl.set_chat_profiles
async def chat_profile():
    """Return the two selectable chat profiles, "Agent✨" then "Assistant".

    The profile names are compared verbatim in ``on_message`` to choose the
    model, prompt and MCP servers, so they must stay in sync with it.
    """
    agent_profile = cl.ChatProfile(
        name="Agent✨",
        markdown_description="Ideal for complex tasks like complete projects building, and full-stack apps creation.",
    )
    assistant_profile = cl.ChatProfile(
        name="Assistant",
        markdown_description="Suited for quick information retrieval and answering questions.",
    )
    return [agent_profile, assistant_profile]

#################################
# Chat Session Startup (history + MCP servers)
#################################
@cl.on_chat_start
async def on_chat_start():
    """Initialise session state and start the configured MCP servers.

    Steps:
      1. Reset the chat history and the started-server list in the session.
      2. Load MCP server definitions via ``load_mcp_servers_from_config``.
         On any failure an error message is sent to the UI and the handler
         returns without starting anything.
      3. Start each loaded server in turn. A server that fails to start is
         logged and skipped; the remaining servers are still attempted.
      4. Store the servers that did start in the session so that
         ``cleanup_mcp_servers`` can stop them when the chat ends.
    """
    # Initialize chat history and MCP servers
    cl.user_session.set(SESSION_HISTORY_KEY, [])
    cl.user_session.set(SESSION_MCP_SERVERS_KEY, [])

    # 1. Load MCP Server configurations
    try:
        loaded_mcp_servers: list[MCPServer] = load_mcp_servers_from_config()
    except FileNotFoundError:
        logger.error("MCP configuration file not found at '%s'.", settings.mcp_config_path)
        await cl.ErrorMessage(content=f"Fatal Error: MCP configuration file not found at '{settings.mcp_config_path}'. Agent cannot start.").send()
        return
    except Exception as e:
        # ValueError is already an Exception subclass, so one clause suffices.
        logger.exception("Failed to load MCP server configurations.")
        await cl.ErrorMessage(content=f"Fatal Error: Could not load MCP configurations: {e}. Agent cannot start.").send()
        return

    if not loaded_mcp_servers:
        # Nothing to start; the session already holds an empty server list.
        return

    # 2. Start the loaded MCP servers sequentially
    started_mcp_servers: list[MCPServer] = []
    for server in loaded_mcp_servers:
        try:
            logger.info("Starting MCP Server: %s...", server.name)
            # The __aenter__ starts the server process; the matching
            # __aexit__ is issued later by cleanup_mcp_servers().
            await server.__aenter__()
            started_mcp_servers.append(server)
            logger.info("MCP Server '%s' started successfully.", server.name)
        except Exception as e:
            # Best effort: keep going with the other servers, but keep the
            # traceback (same as the shutdown path) so failures are diagnosable.
            logger.error("Failed to start MCP Server '%s': %s", server.name, e, exc_info=True)

    # Store only the successfully started servers for later cleanup
    cl.user_session.set(SESSION_MCP_SERVERS_KEY, started_mcp_servers)

#################################
# MCP Server Cleanup
#################################
async def cleanup_mcp_servers():
    """Stop every MCP server recorded in the session, then clear the record.

    Each server is exited with ``__aexit__(None, None, None)``. A failure to
    stop one server is logged with its traceback and does not prevent the
    remaining servers from being stopped. Does nothing when the session
    holds no started servers.
    """
    active_servers = cl.user_session.get(SESSION_MCP_SERVERS_KEY, [])
    if not active_servers:
        return

    logger.info(f"Cleaning up {len(active_servers)} MCP server(s)...")
    for mcp_server in active_servers:
        try:
            # Counterpart of the __aenter__ call made when the chat started.
            await mcp_server.__aexit__(None, None, None)
            logger.info(f"MCP Server '{mcp_server.name}' stopped.")
        except Exception as exc:
            logger.error(
                f"Error stopping MCP Server '{mcp_server.name}': {exc}",
                exc_info=True,
            )

    # Reset the session entry so a repeated call becomes a no-op.
    cl.user_session.set(SESSION_MCP_SERVERS_KEY, [])

@cl.on_chat_end
async def on_chat_end():
    """Log that the chat session is ending and stop its MCP servers."""
    logger.info("Chat session ending.")
    # All teardown lives in cleanup_mcp_servers().
    await cleanup_mcp_servers()

#################################
# User Message Handler
#################################
@cl.on_message
async def on_message(message: cl.Message):
    """Answer a user message by streaming an ``AxiomAgent`` response.

    The active chat profile selects the configuration:
      * "Agent✨"   -> agent model and prompt, all started MCP servers.
      * "Assistant" -> assistant model and prompt, only the MCP server
                       named "context7" (case-insensitive).
      * anything else -> assistant model and prompt, all started servers.

    The user message is appended to the session history, the agent's tokens
    are streamed into a single UI message, and the assistant's reply (or an
    error marker) is appended to the history. The history is written back
    to the session on every path.

    Args:
        message: The incoming user message; only ``message.content`` is read.
    """
    # `or []` keeps the Assistant-mode filter from iterating over None if the
    # session entry is missing.
    started_mcp_servers = cl.user_session.get(SESSION_MCP_SERVERS_KEY) or []
    chat_history: list[dict] = cl.user_session.get(SESSION_HISTORY_KEY, [])

    # Initialize the Axiom Agent
    axiom_mode = cl.user_session.get("chat_profile")
    is_agent_mode = axiom_mode == "Agent✨"

    # Filter MCP servers based on chat profile
    if axiom_mode == "Assistant":
        # For Assistant mode, only use context7 MCP server
        filtered_servers = [
            server for server in started_mcp_servers
            if server.name.lower() == "context7"
        ]
    else:
        # For Agent mode, use all MCP servers
        filtered_servers = started_mcp_servers

    agent = AxiomAgent(
        model=settings.DEFAULT_AGENT_MODEL if is_agent_mode else settings.DEFAULT_ASSISTANT_MODEL,
        mcp_servers=filtered_servers,
        prompt=AXIOM_AGENT_PROMPT if is_agent_mode else AXIOM_ASSISTANT_PROMPT,
    )

    # Add user message to history
    chat_history.append({"role": "user", "content": message.content})

    msg = cl.Message(content="")
    response_parts: list[str] = []

    try:
        async for token in agent.stream_agent(chat_history):
            response_parts.append(token)
            await msg.stream_token(token)  # Stream tokens to the UI
        # Joined inside the try so a malformed token is handled like any
        # other streaming failure.
        full_response = "".join(response_parts)
        # Send the final message
        await msg.send()
    except Exception as e:
        logger.exception(f"Error during agent response streaming: {e}")
        # Record the failure first so the history stays consistent even if
        # notifying the UI fails as well.
        chat_history.append({"role": "assistant", "content": "[Agent Error Occurred]"})
        # Message.update() accepts no `content` argument and targets a
        # message that was already sent; set the content and send() instead,
        # which also covers a failure before the first streamed token.
        msg.content = "Sorry, an error occurred while processing your request."
        await msg.send()
    else:
        # Only append successful response if no exception occurred
        chat_history.append({"role": "assistant", "content": full_response})
    finally:
        # Update session history
        cl.user_session.set(SESSION_HISTORY_KEY, chat_history)