import os
import sqlite3
import operator
import time
from typing import TypedDict, Annotated, List
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel
from dotenv import load_dotenv

_ = load_dotenv()

# Kept after load_dotenv(), matching the original statement order.
import gradio as gr


class AgentState(TypedDict):
    """State dictionary passed between the nodes of the planning graph."""

    task: str  # the user-supplied startup idea
    lnode: str  # name of the node that produced the latest update
    plan: str  # strategic outline written by the "strategy" node
    draft: str  # pitch draft written by the "refinement" node
    critique: str  # feedback written by the "review" node
    content: List[str]  # research snippets collected from the search client
    queries: List[str]  # search queries produced by the "market_research" node
    revision_number: int  # incremented by every "refinement" run
    max_revisions: int  # refinement stops once revision_number exceeds this
    # Every node returns ``count: 1``; operator.add is the reducer attached
    # to this key.
    count: Annotated[int, operator.add]


class Queries(BaseModel):
    """Structured-output schema: the list of search queries to run."""

    queries: List[str]


class _DummyTavilyClient:
    """Offline stand-in used when the ``tavily`` package is not installed."""

    def __init__(self, api_key):
        self.api_key = api_key

    def search(self, query: str, max_results: int = 2):
        """Return one canned result shaped like the real client's response."""
        return {"results": [{"content": f"Dummy result for '{query}'"}]}


class ewriter():
    """Startup-plan writer built as a LangGraph state machine.

    The graph runs ``strategy -> market_research -> refinement`` and then
    loops ``review -> followup_research -> refinement`` until
    ``should_continue`` returns ``END``.  The compiled graph is exposed as
    ``self.graph`` and interrupts after every node.
    """

    def __init__(self):
        self.model = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)

        # NOTE: adjacent string literals are concatenated verbatim, so every
        # fragment that continues a sentence must end with a space.
        self.PLAN_PROMPT = (
            "You are a visionary startup strategist tasked with outlining comprehensive business plans "
            "for emerging companies. Identify the unique selling points (USPs) and core competencies "
            "of startups in various sectors such as tech (e.g., innovative tech stacks), healthcare "
            "(e.g., compliance with medical regulations), fintech (e.g., secure payment systems), "
            "and e-commerce (e.g., efficient supply chain management). Provide detailed strategies"
        )
        self.WRITER_PROMPT = (
            "You are a seasoned startup consultant responsible for crafting compelling pitches "
            "that highlight what sets each company apart. Ensure clarity and persuasive storytelling"
        )
        self.RESEARCH_PLAN_PROMPT = (
            "You are a cutting-edge market research expert specializing in identifying trends "
            "and analyzing competitors within specific industries. Offer actionable insights based on "
            "the latest market data regarding regulatory environments, technological innovations,"
        )
        self.REFLECTION_PROMPT = (
            "You are a top-tier investor and mentor providing strategic advice on fundraising, "
            "team building, operational efficiency tailored to each industry's challenges."
        )
        self.RESEARCH_CRITIQUE_PROMPT = (
            "You are an expert research analyst tasked with evaluating the feasibility of new business ideas, "
            "assessing potential risks/rewards based on industry-specific factors like regulatory hurdles,"
        )

        # Use the real Tavily client when the package is installed; a missing
        # TAVILY_API_KEY still raises KeyError, exactly as before.
        try:
            from tavily import TavilyClient
            self.tavily = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])
        except ImportError:
            self.tavily = _DummyTavilyClient(api_key="dummy")

        self.graph = self._build_graph()

    def _build_graph(self):
        """Wire the nodes and edges and compile the graph with a checkpointer."""
        builder = StateGraph(AgentState)
        builder.add_node("strategy", self.strategy_node)
        builder.add_node("market_research", self.market_research_node)
        builder.add_node("refinement", self.refinement_node)
        builder.add_node("review", self.review_node)
        builder.add_node("followup_research", self.followup_research_node)
        builder.set_entry_point("strategy")
        builder.add_conditional_edges(
            "refinement",
            self.should_continue,
            {END: END, "review": "review"},
        )
        builder.add_edge("strategy", "market_research")
        builder.add_edge("market_research", "refinement")
        builder.add_edge("review", "followup_research")
        builder.add_edge("followup_research", "refinement")

        memory = SqliteSaver(conn=sqlite3.connect(":memory:", check_same_thread=False))
        return builder.compile(
            checkpointer=memory,
            interrupt_after=[
                'strategy', 'refinement', 'review',
                'market_research', 'followup_research',
            ],
        )

    def _collect_research(self, state, queries):
        """Run every query and return prior content plus the new snippets.

        A copy of the existing list is extended so the list object held in
        the incoming state is never mutated in place.
        """
        gathered = list(state.get('content') or [])
        for query in queries:
            resp = self.tavily.search(query=query, max_results=2)
            gathered.extend(r['content'] for r in resp['results'])
        return gathered

    def strategy_node(self, state: AgentState):
        """Ask the model for a strategic outline of ``state['task']``."""
        messages = [
            SystemMessage(content=self.PLAN_PROMPT),
            HumanMessage(content=state['task']),
        ]
        response = self.model.invoke(messages)
        return {"plan": response.content, "lnode": "strategy", "count": 1}

    def market_research_node(self, state: AgentState):
        """Generate search queries for the task and collect their results."""
        queries = self.model.with_structured_output(Queries).invoke([
            SystemMessage(content=self.RESEARCH_PLAN_PROMPT),
            HumanMessage(content=state['task']),
        ])
        return {
            "content": self._collect_research(state, queries.queries),
            "queries": queries.queries,
            "lnode": "market_research",
            "count": 1,
        }

    def refinement_node(self, state: AgentState):
        """Write a draft from the task, the outline and the research.

        Returns the draft together with ``revision_number`` incremented by
        one (a missing revision number is treated as 1).
        """
        research = "\n\n".join(state.get('content') or [])
        # WRITER_PROMPT has no ``{content}`` placeholder, so the former
        # ``.format(content=...)`` call silently discarded the research.
        # Append it explicitly so the writer can actually use it.
        system_text = self.WRITER_PROMPT
        if research:
            system_text += (
                "\n\nUse the research notes below where relevant:\n\n" + research
            )
        user_message = HumanMessage(
            content=f"{state['task']}\n\nHere is the strategic outline:\n\n{state['plan']}"
        )
        response = self.model.invoke([SystemMessage(content=system_text), user_message])
        return {
            "draft": response.content,
            "revision_number": state.get("revision_number", 1) + 1,
            "lnode": "refinement",
            "count": 1,
        }

    def review_node(self, state: AgentState):
        """Ask the model to critique the current draft."""
        messages = [
            SystemMessage(content=self.REFLECTION_PROMPT),
            HumanMessage(content=state['draft']),
        ]
        response = self.model.invoke(messages)
        return {"critique": response.content, "lnode": "review", "count": 1}

    def followup_research_node(self, state: AgentState):
        """Generate queries from the critique and collect further research."""
        queries = self.model.with_structured_output(Queries).invoke([
            SystemMessage(content=self.RESEARCH_CRITIQUE_PROMPT),
            HumanMessage(content=state['critique']),
        ])
        return {
            "content": self._collect_research(state, queries.queries),
            "lnode": "followup_research",
            "count": 1,
        }

    def should_continue(self, state):
        """Return ``END`` once the revision budget is exceeded, else "review"."""
        if state["revision_number"] > state["max_revisions"]:
            return END
        return "review"


class writer_gui():
    """Gradio front end for stepping through and editing the agent graph.

    Args:
        graph: compiled graph (with a checkpointer) to drive.
        share: passed to ``launch`` when the ``PORT1`` variable is not set.
    """

    def __init__(self, graph, share=False):
        self.graph = graph
        self.share = share
        self.partial_message = ""
        self.response = {}
        self.max_iterations = 10
        # Initialize with a default thread so dropdowns are populated.
        # ``iterations`` and ``threads`` grow in lockstep; a thread id is
        # also the index of its iteration counter.
        self.iterations = [0]
        self.threads = [0]
        self.thread_id = 0
        self.thread = {"configurable": {"thread_id": "0"}}
        self.demo = self.create_interface()

    def run_agent(self, start, topic, stop_after):
        """Run the graph step by step, yielding display values after each step.

        Args:
            start: True to begin a new thread for ``topic``; False to resume
                the currently selected thread.
            topic: task text used as the initial state when ``start`` is True.
            stop_after: node names after which the loop should stop.

        Yields:
            ``(log_text, last_node, next_node, thread_id, revision, count)``.
        """
        if start:
            self.iterations.append(0)
            # Allocate the next unused id.  ``self.thread_id + 1`` would
            # collide with an existing thread whenever the user had switched
            # back to an older thread before starting a new run.
            self.thread_id = len(self.iterations) - 1
            self.threads.append(self.thread_id)
            config = {
                'task': topic,
                "max_revisions": 2,
                "revision_number": 0,
                'lnode': "",
                # The schema key is 'plan' (the original used 'strategy',
                # which is not a declared state key).
                'plan': "no strategy",
                'draft': "no plan",
                'critique': "no review",
                'content': ["no research content"],
                'queries': "no queries",
                'count': 0,
            }
        else:
            config = None  # None resumes from the thread's last checkpoint

        self.thread = {"configurable": {"thread_id": str(self.thread_id)}}
        while self.iterations[self.thread_id] < self.max_iterations:
            self.response = self.graph.invoke(config, self.thread)
            self.iterations[self.thread_id] += 1
            self.partial_message += str(self.response)
            self.partial_message += "\n------------------\n\n"
            lnode, nnode, _, rev, acount = self.get_disp_state()
            yield (self.partial_message, lnode, nnode, self.thread_id, rev, acount)
            config = None  # later iterations continue from the checkpoint
            if not nnode or lnode in stop_after:
                return

    def get_disp_state(self):
        """Return ``(last_node, next_nodes, thread_id, revision, count)``."""
        snapshot = self.graph.get_state(self.thread)
        values = snapshot.values
        return (
            values["lnode"],
            snapshot.next,
            self.thread_id,
            values["revision_number"],
            values["count"],
        )

    def _state_label(self):
        """Build the textbox label describing the current thread position."""
        lnode, _nnode, _, rev, astep = self.get_disp_state()
        return f"last_node: {lnode}, thread: {self.thread_id}, rev: {rev}, step: {astep}"

    def get_state(self, key):
        """Return a textbox update holding state value ``key``, or "" if absent."""
        current_values = self.graph.get_state(self.thread)
        if key not in current_values.values:
            return ""
        return gr.update(label=self._state_label(), value=current_values.values[key])

    def get_content(self):
        """Return a textbox update with all research snippets, or "" if absent."""
        current_values = self.graph.get_state(self.thread)
        if "content" not in current_values.values:
            return ""
        content = current_values.values["content"]
        return gr.update(
            label=self._state_label(),
            value="\n\n".join(content) + "\n\n",
        )

    def create_interface(self):
        """Build the Gradio Blocks layout and wire all event handlers."""
        with gr.Blocks(theme=gr.themes.Default(spacing_size='sm', text_size="sm")) as demo:

            def hidden(value):
                # Invisible textbox used to feed a constant into a handler.
                return gr.Textbox(value=value, visible=False)

            def set_variant(stat):
                return gr.update(variant=stat)

            def updt_disp():
                """Refresh the status widgets from the current checkpoint."""
                current_state = self.graph.get_state(self.thread)
                hist = []
                for st in self.graph.get_state_history(self.thread):
                    if st.metadata['step'] < 1:
                        continue
                    s_thread_ts = st.config['configurable']['thread_ts']
                    s_tid = st.config['configurable']['thread_id']
                    s_count = st.values['count']
                    s_lnode = st.values['lnode']
                    s_rev = st.values['revision_number']
                    s_nnode = st.next
                    hist.append(f"{s_tid}:{s_count}:{s_lnode}:{s_nnode}:{s_rev}:{s_thread_ts}")
                if not current_state.metadata or not hist:
                    return {}
                values = current_state.values
                return {
                    topic_bx: gr.update(value=str(values["task"])),
                    lnode_bx: gr.update(value=str(values["lnode"])),
                    count_bx: gr.update(value=str(values["count"])),
                    revision_bx: gr.update(value=str(values["revision_number"])),
                    nnode_bx: gr.update(value=str(current_state.next or "")),
                    threadid_bx: gr.update(value=str(self.thread_id)),
                    thread_pd: gr.update(
                        label="choose thread",
                        choices=self.threads,
                        value=self.thread_id,
                        interactive=True,
                    ),
                    step_pd: gr.update(
                        label="update state from: thread:count:last_node:next_node:rev:thread_ts",
                        choices=hist,
                        value=hist[0],
                        interactive=True,
                    ),
                }

            with gr.Tab("Agent"):
                with gr.Row():
                    topic_bx = gr.Textbox(
                        label="Startup Idea",
                        value="Revolutionary AI-driven Healthtech Platform",
                    )
                    gen_btn = gr.Button("Generate Plan", scale=0, min_width=80, variant='primary')
                    cont_btn = gr.Button("Continue Refinement", scale=0, min_width=80)
                with gr.Row():
                    lnode_bx = gr.Textbox(label="Last Node", min_width=100)
                    nnode_bx = gr.Textbox(label="Next Node", min_width=100)
                    threadid_bx = gr.Textbox(label="Thread", scale=0, min_width=80)
                    revision_bx = gr.Textbox(label="Plan Rev", scale=0, min_width=80)
                    count_bx = gr.Textbox(label="Count", scale=0, min_width=80)
                with gr.Accordion("Manage Agent", open=False):
                    # Filter rather than list.remove(): remove() raises
                    # ValueError when '__start__' is not among the nodes.
                    checks = [name for name in self.graph.nodes.keys() if name != '__start__']
                    stop_after = gr.CheckboxGroup(
                        checks,
                        label="Interrupt After State",
                        value=checks,
                        scale=0,
                        min_width=400,
                    )
                    with gr.Row():
                        thread_pd = gr.Dropdown(
                            choices=self.threads,
                            interactive=True,
                            label="Select Thread",
                            min_width=120,
                            scale=0,
                        )
                        step_pd = gr.Dropdown(
                            choices=['N/A'],
                            interactive=True,
                            label="Select Step",
                            min_width=160,
                            scale=1,
                        )
                live = gr.Textbox(label="Live Agent Output", lines=5, max_lines=5)

                # updt_disp returns a dict keyed by these components, so they
                # must be declared as its outputs for the updates to apply.
                sdisps = [
                    topic_bx, lnode_bx, nnode_bx, threadid_bx,
                    revision_bx, count_bx, step_pd, thread_pd,
                ]
                run_outputs = [live, lnode_bx, nnode_bx, threadid_bx, revision_bx, count_bx]

                # ``input`` fires on user edits only.  updt_disp itself sets
                # both dropdowns, and ``change`` would re-trigger these
                # handlers on every such programmatic refresh.
                thread_pd.input(
                    self.switch_thread, [thread_pd], None
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                )
                step_pd.input(
                    self.copy_state, [step_pd], None
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                )

                gen_btn.click(
                    fn=set_variant, inputs=hidden("secondary"), outputs=gen_btn
                ).then(
                    fn=self.run_agent,
                    inputs=[gr.State(value=True), topic_bx, stop_after],
                    outputs=run_outputs,
                    show_progress=True,
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                ).then(
                    fn=set_variant, inputs=hidden("primary"), outputs=gen_btn
                ).then(
                    fn=set_variant, inputs=hidden("primary"), outputs=cont_btn
                )

                cont_btn.click(
                    fn=set_variant, inputs=hidden("secondary"), outputs=cont_btn
                ).then(
                    fn=self.run_agent,
                    inputs=[gr.State(value=False), topic_bx, stop_after],
                    outputs=run_outputs,
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                ).then(
                    fn=set_variant, inputs=hidden("primary"), outputs=cont_btn
                )

            with gr.Tab("Strategic Outline"):
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                    modify_btn = gr.Button("Modify")
                plan = gr.Textbox(label="Strategic Outline", lines=10, interactive=True)
                refresh_btn.click(fn=self.get_state, inputs=hidden("plan"), outputs=plan)
                modify_btn.click(
                    fn=self.modify_state,
                    inputs=[hidden("plan"), hidden("strategy"), plan],
                    outputs=None,
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                )

            with gr.Tab("Research Content"):
                refresh_btn = gr.Button("Refresh")
                content_bx = gr.Textbox(label="Research Content", lines=10)
                refresh_btn.click(fn=self.get_content, inputs=None, outputs=content_bx)

            with gr.Tab("Startup Plan"):
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                    modify_btn = gr.Button("Modify")
                draft_bx = gr.Textbox(label="Startup Plan Draft", lines=10, interactive=True)
                refresh_btn.click(fn=self.get_state, inputs=hidden("draft"), outputs=draft_bx)
                modify_btn.click(
                    fn=self.modify_state,
                    inputs=[hidden("draft"), hidden("refinement"), draft_bx],
                    outputs=None,
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                )

            with gr.Tab("Review"):
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                    modify_btn = gr.Button("Modify")
                critique_bx = gr.Textbox(label="Review / Critique", lines=10, interactive=True)
                refresh_btn.click(fn=self.get_state, inputs=hidden("critique"), outputs=critique_bx)
                modify_btn.click(
                    fn=self.modify_state,
                    inputs=[hidden("critique"), hidden("review"), critique_bx],
                    outputs=None,
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                )

            with gr.Tab("State Snapshots"):
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                snapshots = gr.Textbox(label="State Snapshots", lines=10)
                refresh_btn.click(fn=self.get_snapshots, inputs=None, outputs=snapshots)

        return demo

    def launch(self, share=None):
        """Launch the app; bind to ``PORT1`` on all interfaces when it is set."""
        if port := os.getenv("PORT1"):
            self.demo.launch(share=True, server_port=int(port), server_name="0.0.0.0")
        else:
            self.demo.launch(share=self.share)

    def copy_state(self, hist_str):
        """Make the snapshot named by ``hist_str`` the thread's current state.

        Args:
            hist_str: ``thread:count:last_node:next_node:rev:thread_ts`` as
                listed in the step dropdown.

        Returns:
            ``(last_node, next_nodes, thread_ts, revision, count)`` of the new
            current state, or ``None`` when no snapshot matches ``hist_str``.
        """
        # thread_ts is the sixth field; capping the split keeps it intact
        # even when the timestamp itself contains ':' characters.
        thread_ts = hist_str.split(":", 5)[-1]
        config = self.find_config(thread_ts)
        if config is None:
            # e.g. the 'N/A' placeholder choice: there is nothing to copy.
            return None
        state = self.graph.get_state(config)
        self.graph.update_state(self.thread, state.values, as_node=state.values['lnode'])
        new_state = self.graph.get_state(self.thread)
        return (
            new_state.values['lnode'],
            new_state.next,
            new_state.config['configurable']['thread_ts'],
            new_state.values['revision_number'],
            new_state.values['count'],
        )

    def find_config(self, thread_ts):
        """Return the config of the snapshot with this ``thread_ts``, or None."""
        for st in self.graph.get_state_history(self.thread):
            if st.config['configurable']['thread_ts'] == thread_ts:
                return st.config
        return None

    def switch_thread(self, new_thread_id):
        """Point the GUI at another thread."""
        self.thread = {"configurable": {"thread_id": str(new_thread_id)}}
        self.thread_id = int(new_thread_id)

    def modify_state(self, key, asnode, new_state):
        """Overwrite state value ``key`` and record the update as node ``asnode``."""
        current_values = self.graph.get_state(self.thread)
        current_values.values[key] = new_state
        self.graph.update_state(self.thread, current_values.values, as_node=asnode)

    def get_snapshots(self):
        """Return a textbox update listing every snapshot of the thread.

        Long text fields are shortened in the retrieved snapshot objects
        (80 characters for plan/draft/critique, 20 per research snippet)
        before each snapshot is rendered with ``str``.
        """
        rendered = []
        for st in self.graph.get_state_history(self.thread):
            for key in ('plan', 'draft', 'critique'):
                if key in st.values:
                    st.values[key] = st.values[key][:80] + "..."
            if 'content' in st.values:
                st.values['content'] = [item[:20] + '...' for item in st.values['content']]
            rendered.append(str(st) + "\n\n")
        return gr.update(
            label=f"Thread {self.thread_id} Snapshots",
            value="".join(rendered),
        )


# Build the agent graph and launch the Gradio app when run as a script.
if __name__ == "__main__":
    writer_instance = ewriter()
    gui_instance = writer_gui(writer_instance.graph, share=True)
    gui_instance.launch()