import gradio as gr
from gradio_client import Client
from langgraph.graph import StateGraph, START, END
from typing import TypedDict
import io
from PIL import Image

# OPEN QUESTION: should we pass all params from the orchestrator to the
# nodes instead of setting them in each module?
# Define the state schema
class GraphState(TypedDict):
    query: str
    context: str
    result: str
    # Orchestrator-level filter parameters (see the open question above)
    reports_filter: str
    sources_filter: str
    subtype_filter: str
    year_filter: str
# Retriever node: fetch context from the ChatFed retriever service
def retrieve_node(state: GraphState) -> GraphState:
    client = Client("giz/chatfed_retriever")  # HF repo name
    context = client.predict(
        query=state["query"],
        reports_filter=state.get("reports_filter", ""),
        sources_filter=state.get("sources_filter", ""),
        subtype_filter=state.get("subtype_filter", ""),
        year_filter=state.get("year_filter", ""),
        api_name="/retrieve"
    )
    return {"context": context}
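
# Note: a fresh gradio_client.Client is created on every node invocation.
# If handshake latency becomes noticeable, the clients could be hoisted to
# module level, at the cost of failing at import time when a Space is asleep.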

# Generator node: produce the answer from the query plus retrieved context
def generate_node(state: GraphState) -> GraphState:
    client = Client("giz/chatfed_generator")
    result = client.predict(
        query=state["query"],
        context=state["context"],
        api_name="/generate"
    )
    return {"result": result}

# Build the graph
workflow = StateGraph(GraphState)

# Add nodes
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("generate", generate_node)

# Add edges: START -> retrieve -> generate -> END
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)

# Compile the graph
graph = workflow.compile()
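
# Minimal smoke test for the compiled graph (hypothetical query; assumes both
# Spaces are awake), kept as a comment so it does not run on import:
#
#   state = graph.invoke({"query": "What does ChatFed do?", "context": "",
#                         "result": "", "reports_filter": "", "sources_filter": "",
#                         "subtype_filter": "", "year_filter": ""})
#   print(state["result"])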

# Single tool for processing queries
def process_query(
    query: str,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = ""
) -> str:
    """
    Execute the ChatFed orchestration pipeline to process a user query.

    This function orchestrates a two-step workflow:
    1. Retrieve relevant context using the ChatFed retriever service, applying optional filters
    2. Generate a response using the ChatFed generator service with the retrieved context

    Args:
        query (str): The user's input query/question to be processed
        reports_filter (str, optional): Filter for specific report types. Defaults to "".
        sources_filter (str, optional): Filter for specific data sources. Defaults to "".
        subtype_filter (str, optional): Filter for document subtypes. Defaults to "".
        year_filter (str, optional): Filter for specific years. Defaults to "".

    Returns:
        str: The generated response from the ChatFed generator service
    """
    initial_state = {
        "query": query,
        "context": "",
        "result": "",
        "reports_filter": reports_filter or "",
        "sources_filter": sources_filter or "",
        "subtype_filter": subtype_filter or "",
        "year_filter": year_filter or ""
    }
    final_state = graph.invoke(initial_state)
    return final_state["result"]
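
# Example direct call (hypothetical values, mirroring the MCP example below;
# ChatUI normally invokes this via the /process_query endpoint instead):
#
#   answer = process_query(
#       "What do the annual reports say about X?",
#       reports_filter="annual_reports",
#       year_filter="2024",
#   )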

# Simple testing interface (exposes only the query; the filter arguments fall
# back to their defaults)
ui = gr.Interface(
    fn=process_query,
    inputs=gr.Textbox(lines=2, placeholder="Enter query here"),
    outputs="text",
    flagging_mode="never"
)

# Generate the graph visualization for display in the UI
def get_graph_visualization():
    """Generate and return the LangGraph workflow visualization as a PIL Image."""
    # Render the graph as PNG bytes
    graph_png_bytes = graph.get_graph().draw_mermaid_png()
    # Convert the bytes to a PIL Image for Gradio display
    graph_image = Image.open(io.BytesIO(graph_png_bytes))
    return graph_image
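
# Note: draw_mermaid_png() renders via the remote mermaid.ink API by default,
# so this call assumes outbound network access from the Space; langgraph also
# supports local rendering via its draw_method parameter if that is a problem.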

# Guidance page for ChatUI; can be removed later. A front end may not even be
# necessary, but it is nice to show the graph.
with gr.Blocks(title="ChatFed Orchestrator") as demo:
    gr.Markdown("# ChatFed Orchestrator")
    gr.Markdown("This LangGraph server exposes MCP endpoints for the ChatUI module to call (which triggers the graph).")

    with gr.Row():
        # Left column: graph visualization
        with gr.Column(scale=1):
            gr.Markdown("**Workflow Visualization**")
            graph_display = gr.Image(
                value=get_graph_visualization(),
                label="LangGraph Workflow",
                interactive=False,
                height=300
            )
            # Refresh button for the graph
            refresh_graph_btn = gr.Button("🔄 Refresh Graph", size="sm")
            refresh_graph_btn.click(
                fn=get_graph_visualization,
                outputs=graph_display
            )

        # Right column: interface and documentation
        with gr.Column(scale=2):
            gr.Markdown("**Available MCP Tools:**")
            with gr.Accordion("MCP Endpoint Information", open=True):
                gr.Markdown("""
                **MCP Server Endpoint:** https://giz-chatfed-orchestrator.hf.space/gradio_api/mcp/sse

                **For ChatUI Integration:**
                ```python
                from gradio_client import Client

                # Connect to the orchestrator
                orchestrator_client = Client("https://giz-chatfed-orchestrator.hf.space")

                # Basic usage (no filters)
                response = orchestrator_client.predict(
                    query="query",
                    api_name="/process_query"
                )

                # Advanced usage with any combination of filters
                response = orchestrator_client.predict(
                    query="query",
                    reports_filter="annual_reports",
                    sources_filter="internal",
                    year_filter="2024",
                    api_name="/process_query"
                )
                ```
                """)

            with gr.Accordion("Quick Testing Interface", open=True):
                ui.render()

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        mcp_server=True,
        show_error=True
    )
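
# Note: mcp_server=True assumes Gradio was installed with its MCP extra
# (pip install "gradio[mcp]").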