# my-gradio-app / app.py
# Last commit by ayushroy: "Update app.py" (d0673bc, verified).
import os
import sqlite3
import operator
import time
from typing import TypedDict, Annotated, List
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel
from dotenv import load_dotenv
# Populate os.environ from a local .env file before anything reads it
# (ewriter reads TAVILY_API_KEY); variables already set in the environment
# are kept because override is False.
_ = load_dotenv(override=False)
# Shared state passed between the LangGraph nodes of ``ewriter``.
# ``count`` is Annotated with ``operator.add``: LangGraph uses that as the
# channel's reducer, so each node's returned ``count`` is added to the running
# total. Every other key is simply overwritten by the latest update.
AgentState = TypedDict(
    "AgentState",
    {
        "task": str,                              # the startup idea entered in the UI
        "lnode": str,                             # name of the last node that ran
        "plan": str,                              # strategic outline from strategy_node
        "draft": str,                             # latest plan draft from refinement_node
        "critique": str,                          # review text from review_node
        "content": List[str],                     # accumulated search-result snippets
        "queries": List[str],                     # search queries from market_research_node
        "revision_number": int,                   # bumped by refinement_node
        "max_revisions": int,                     # limit checked by should_continue
        "count": Annotated[int, operator.add],    # step counter (summed, see above)
    },
)
class Queries(BaseModel):
    """A list of web search queries that would gather information relevant to the task."""
    # NOTE: this model is passed to ``with_structured_output``; its docstring
    # becomes the schema description the LLM sees. Neither research prompt
    # explicitly asks for search queries, so this description is what tells
    # the model what ``queries`` should contain.
    queries: List[str]
class ewriter():
    """Startup-plan writing agent built on a LangGraph ``StateGraph``.

    Flow: strategy -> market_research -> refinement. After each refinement,
    ``should_continue`` either ends the run or routes review ->
    followup_research -> refinement again. The compiled graph is stored in
    ``self.graph``. It checkpoints to an in-memory SQLite database and is
    interrupted after every node, so callers resume it one step at a time.
    """

    def __init__(self):
        self.model = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)
        # Adjacent string literals are joined with no separator, so every
        # fragment that continues onto the next one must end with a space.
        self.PLAN_PROMPT = (
            "You are a visionary startup strategist tasked with outlining comprehensive business plans "
            "for emerging companies. Identify the unique selling points (USPs) and core competencies "
            "of startups in various sectors such as tech (e.g., innovative tech stacks), healthcare "
            "(e.g., compliance with medical regulations), fintech (e.g., secure payment systems), "
            "and e-commerce (e.g., efficient supply chain management). Provide detailed strategies"
        )
        # str.format template: refinement_node substitutes the accumulated
        # research results for ``{content}``. Any other literal brace added
        # here must be doubled, or .format() will raise.
        self.WRITER_PROMPT = (
            "You are a seasoned startup consultant responsible for crafting compelling pitches "
            "that highlight what sets each company apart. Ensure clarity and persuasive storytelling.\n\n"
            "Utilize all the information below as needed:\n\n"
            "------\n\n"
            "{content}"
        )
        self.RESEARCH_PLAN_PROMPT = (
            "You are a cutting-edge market research expert specializing in identifying trends "
            "and analyzing competitors within specific industries. Offer actionable insights based on "
            "the latest market data regarding regulatory environments, technological innovations,"
        )
        self.REFLECTION_PROMPT = (
            "You are a top-tier investor and mentor providing strategic advice on fundraising, "
            "team building, operational efficiency tailored to each industry's challenges."
        )
        self.RESEARCH_CRITIQUE_PROMPT = (
            "You are an expert research analyst tasked with evaluating the feasibility of new business ideas, "
            "assessing potential risks/rewards based on industry-specific factors like regulatory hurdles,"
        )
        try:
            from tavily import TavilyClient
            self.tavily = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])
        except ImportError:
            # tavily is not installed: fall back to a stub that returns one
            # canned result per query, in the {"results": [{"content": ...}]}
            # shape the research nodes read. A missing TAVILY_API_KEY is not
            # caught here and still raises KeyError.
            class DummyTavilyClient:
                def __init__(self, api_key):
                    self.api_key = api_key

                def search(self, query: str, max_results: int = 2):
                    return {"results": [{"content": f"Dummy result for '{query}'"}]}

            self.tavily = DummyTavilyClient(api_key="dummy")

        builder = StateGraph(AgentState)
        builder.add_node("strategy", self.strategy_node)
        builder.add_node("market_research", self.market_research_node)
        builder.add_node("refinement", self.refinement_node)
        builder.add_node("review", self.review_node)
        builder.add_node("followup_research", self.followup_research_node)
        builder.set_entry_point("strategy")
        builder.add_conditional_edges(
            "refinement",
            self.should_continue,
            {END: END, "review": "review"}
        )
        builder.add_edge("strategy", "market_research")
        builder.add_edge("market_research", "refinement")
        builder.add_edge("review", "followup_research")
        builder.add_edge("followup_research", "refinement")
        # check_same_thread=False lets the connection be used from threads other
        # than the one that created it (presumably the UI's worker threads).
        memory = SqliteSaver(conn=sqlite3.connect(":memory:", check_same_thread=False))
        self.graph = builder.compile(
            checkpointer=memory,
            interrupt_after=[
                'strategy',
                'refinement',
                'review',
                'market_research',
                'followup_research'
            ]
        )

    def _gather_research(self, queries, existing):
        """Return *existing* content followed by search results for each query.

        Calls ``self.tavily.search(query=q, max_results=2)`` per query and keeps
        each result's ``'content'`` field. *existing* (may be None) is copied
        rather than mutated, so the list held in the graph state is untouched.
        """
        content = list(existing or [])
        for q in queries:
            resp = self.tavily.search(query=q, max_results=2)
            content.extend(r['content'] for r in resp['results'])
        return content

    def strategy_node(self, state: AgentState):
        """Ask the model for a strategic outline of ``state['task']``; stores it in ``plan``."""
        messages = [
            SystemMessage(content=self.PLAN_PROMPT),
            HumanMessage(content=state['task'])
        ]
        response = self.model.invoke(messages)
        return {"plan": response.content, "lnode": "strategy", "count": 1}

    def market_research_node(self, state: AgentState):
        """Generate search queries for the task, run them, and extend ``content`` with the results."""
        queries = self.model.with_structured_output(Queries).invoke([
            SystemMessage(content=self.RESEARCH_PLAN_PROMPT),
            HumanMessage(content=state['task'])
        ])
        content = self._gather_research(queries.queries, state.get('content'))
        return {"content": content, "queries": queries.queries, "lnode": "market_research", "count": 1}

    def refinement_node(self, state: AgentState):
        """Write a new draft from the task, the plan and all research content.

        Increments ``revision_number``; if it is missing it is treated as 1.
        """
        content = "\n\n".join(state.get('content') or [])
        user_message = HumanMessage(content=f"{state['task']}\n\nHere is the strategic outline:\n\n{state['plan']}")
        messages = [
            SystemMessage(content=self.WRITER_PROMPT.format(content=content)),
            user_message
        ]
        response = self.model.invoke(messages)
        return {
            "draft": response.content,
            "revision_number": state.get("revision_number", 1) + 1,
            "lnode": "refinement",
            "count": 1
        }

    def review_node(self, state: AgentState):
        """Ask the model to critique the current draft; stores it in ``critique``."""
        messages = [
            SystemMessage(content=self.REFLECTION_PROMPT),
            HumanMessage(content=state['draft'])
        ]
        response = self.model.invoke(messages)
        return {"critique": response.content, "lnode": "review", "count": 1}

    def followup_research_node(self, state: AgentState):
        """Generate search queries from the critique, run them, and extend ``content``."""
        queries = self.model.with_structured_output(Queries).invoke([
            SystemMessage(content=self.RESEARCH_CRITIQUE_PROMPT),
            HumanMessage(content=state['critique'])
        ])
        content = self._gather_research(queries.queries, state.get('content'))
        return {"content": content, "lnode": "followup_research", "count": 1}

    def should_continue(self, state):
        """Route after refinement: END once ``revision_number`` exceeds ``max_revisions``, else "review"."""
        if state["revision_number"] > state["max_revisions"]:
            return END
        return "review"
import gradio as gr
class writer_gui():
    """Gradio front end for stepping through an interruptible LangGraph agent.

    "Generate Plan" starts a new checkpoint thread and "Continue Refinement"
    resumes the current one. Further tabs show and edit individual state keys
    (plan, content, draft, critique) and list the thread's checkpoints.
    """

    def __init__(self, graph, share=False):
        self.graph = graph
        self.share = share
        self.partial_message = ""
        self.response = {}
        # Upper bound on graph.invoke calls per thread.
        self.max_iterations = 10
        # Per-thread invoke counters, indexed by thread id (kept the same
        # length as self.threads). Start with a default thread 0 so the
        # dropdowns are populated.
        self.iterations = [0]
        self.threads = [0]
        self.thread_id = 0
        self.thread = {"configurable": {"thread_id": "0"}}
        self.demo = self.create_interface()

    @staticmethod
    def _thread_ts_from_hist(hist_str):
        """Extract the checkpoint timestamp from a step-dropdown entry.

        Entries are built by ``updt_disp`` as
        ``thread:count:last_node:next_node:rev:thread_ts``. Everything after the
        fifth colon is returned, so a timestamp that itself contains colons is
        kept intact. Returns None for entries without six fields (e.g. 'N/A').
        """
        parts = str(hist_str).split(":", 5)
        return parts[5] if len(parts) == 6 else None

    def _state_label(self):
        """Build the 'last_node / thread / rev / step' label shown on state text boxes."""
        lnode, _, _, rev, astep = self.get_disp_state()
        return f"last_node: {lnode}, thread: {self.thread_id}, rev: {rev}, step: {astep}"

    def run_agent(self, start, topic, stop_after):
        """Start (``start`` truthy) or resume the agent and stream its progress.

        Yields ``(partial_message, last_node, next_node, thread_id, revision, count)``
        after every ``graph.invoke``. Stops when the graph has no next node,
        when the last node is in *stop_after*, or once this thread has been
        invoked ``max_iterations`` times.
        """
        if start:
            config = {
                'task': topic,
                "max_revisions": 2,
                "revision_number": 0,
                'lnode': "",
                'plan': "no strategy",
                'draft': "no plan",
                'critique': "no review",
                'content': ["no research content"],
                'queries': [],
                'count': 0
            }
            # Allocate a fresh id instead of incrementing the current one:
            # after switching back to an older thread, thread_id + 1 may
            # already be in use.
            self.thread_id = len(self.threads)
            self.threads.append(self.thread_id)
            self.iterations.append(0)
        else:
            config = None
        self.thread = {"configurable": {"thread_id": str(self.thread_id)}}
        while self.iterations[self.thread_id] < self.max_iterations:
            self.response = self.graph.invoke(config, self.thread)
            self.iterations[self.thread_id] += 1
            self.partial_message += str(self.response)
            self.partial_message += "\n------------------\n\n"
            lnode, nnode, _, rev, acount = self.get_disp_state()
            yield (self.partial_message, lnode, nnode, self.thread_id, rev, acount)
            # Subsequent invokes resume from the checkpoint rather than restart.
            config = None
            if not nnode or lnode in stop_after:
                return

    def get_disp_state(self):
        """Return ``(last_node, next_nodes, thread_id, revision, count)`` for the current thread.

        Missing keys (e.g. a thread that has not run yet) fall back to "", 0, 0.
        """
        current_state = self.graph.get_state(self.thread)
        values = current_state.values or {}
        return (
            values.get("lnode", ""),
            current_state.next,
            self.thread_id,
            values.get("revision_number", 0),
            values.get("count", 0),
        )

    def get_state(self, key):
        """Return a gr.update showing state value *key* of the current thread, or "" if unset."""
        current_values = self.graph.get_state(self.thread)
        if key not in current_values.values:
            return ""
        return gr.update(label=self._state_label(), value=current_values.values[key])

    def get_content(self):
        """Return a gr.update with the research content joined by blank lines, or "" if unset."""
        current_values = self.graph.get_state(self.thread)
        if "content" not in current_values.values:
            return ""
        content = current_values.values["content"]
        return gr.update(label=self._state_label(), value="\n\n".join(content) + "\n\n")

    def create_interface(self):
        """Build and return the Gradio Blocks UI (without launching it)."""
        with gr.Blocks(theme=gr.themes.Default(spacing_size='sm', text_size="sm")) as demo:

            def updt_disp():
                """Refresh the status boxes and the thread/step dropdowns.

                Returns a {component: update} dict over the components in
                ``sdisps``; every caller wires it with ``outputs=sdisps``.
                """
                current_state = self.graph.get_state(self.thread)
                hist = []
                for st in self.graph.get_state_history(self.thread):
                    if st.metadata['step'] < 1:
                        continue
                    s_thread_ts = st.config['configurable']['thread_ts']
                    s_tid = st.config['configurable']['thread_id']
                    s_count = st.values['count']
                    s_lnode = st.values['lnode']
                    s_rev = st.values['revision_number']
                    s_nnode = st.next
                    hist.append(f"{s_tid}:{s_count}:{s_lnode}:{s_nnode}:{s_rev}:{s_thread_ts}")
                if not current_state.metadata or not hist:
                    # Nothing to show yet: a no-op update for every output keeps
                    # the return value consistent with ``outputs=sdisps``.
                    return {comp: gr.update() for comp in sdisps}
                return {
                    topic_bx: gr.update(value=str(current_state.values["task"])),
                    lnode_bx: gr.update(value=str(current_state.values["lnode"])),
                    count_bx: gr.update(value=str(current_state.values["count"])),
                    revision_bx: gr.update(value=str(current_state.values["revision_number"])),
                    nnode_bx: gr.update(value=str(current_state.next or "")),
                    threadid_bx: gr.update(value=str(self.thread_id)),
                    thread_pd: gr.update(
                        label="choose thread",
                        choices=self.threads,
                        value=self.thread_id,
                        interactive=True
                    ),
                    step_pd: gr.update(
                        label="update state from: thread:count:last_node:next_node:rev:thread_ts",
                        choices=hist,
                        value=hist[0],
                        interactive=True
                    ),
                }

            with gr.Tab("Agent"):
                with gr.Row():
                    topic_bx = gr.Textbox(label="Startup Idea", value="Revolutionary AI-driven Healthtech Platform")
                    gen_btn = gr.Button("Generate Plan", scale=0, min_width=80, variant='primary')
                    cont_btn = gr.Button("Continue Refinement", scale=0, min_width=80)
                with gr.Row():
                    lnode_bx = gr.Textbox(label="Last Node", min_width=100)
                    nnode_bx = gr.Textbox(label="Next Node", min_width=100)
                    threadid_bx = gr.Textbox(label="Thread", scale=0, min_width=80)
                    revision_bx = gr.Textbox(label="Plan Rev", scale=0, min_width=80)
                    count_bx = gr.Textbox(label="Count", scale=0, min_width=80)
                with gr.Accordion("Manage Agent", open=False):
                    checks = list(self.graph.nodes.keys())
                    checks.remove('__start__')
                    stop_after = gr.CheckboxGroup(
                        checks,
                        label="Interrupt After State",
                        value=checks,
                        scale=0,
                        min_width=400
                    )
                    with gr.Row():
                        thread_pd = gr.Dropdown(
                            choices=self.threads,
                            interactive=True,
                            label="Select Thread",
                            min_width=120,
                            scale=0
                        )
                        step_pd = gr.Dropdown(
                            choices=['N/A'],
                            interactive=True,
                            label="Select Step",
                            min_width=160,
                            scale=1
                        )
                live = gr.Textbox(label="Live Agent Output", lines=5, max_lines=5)

                sdisps = [topic_bx, lnode_bx, nnode_bx, threadid_bx, revision_bx, count_bx, step_pd, thread_pd]
                # ``.input`` fires only on user interaction. With ``.change``,
                # updt_disp setting these dropdowns would re-trigger the
                # handlers (copy_state writes a new checkpoint each time, which
                # changes step_pd again), producing a feedback loop.
                thread_pd.input(
                    self.switch_thread, [thread_pd], None
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                )
                step_pd.input(
                    self.copy_state, [step_pd], None
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                )
                gen_btn.click(
                    fn=lambda stat: gr.update(variant=stat),
                    inputs=gr.Textbox(value="secondary", visible=False),
                    outputs=gen_btn
                ).then(
                    fn=self.run_agent,
                    inputs=[gr.State(value=True), topic_bx, stop_after],
                    outputs=[live, lnode_bx, nnode_bx, threadid_bx, revision_bx, count_bx],
                    show_progress=True
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                ).then(
                    fn=lambda stat: gr.update(variant=stat),
                    inputs=gr.Textbox(value="primary", visible=False),
                    outputs=gen_btn
                ).then(
                    fn=lambda stat: gr.update(variant=stat),
                    inputs=gr.Textbox(value="primary", visible=False),
                    outputs=cont_btn
                )
                cont_btn.click(
                    fn=lambda stat: gr.update(variant=stat),
                    inputs=gr.Textbox(value="secondary", visible=False),
                    outputs=cont_btn
                ).then(
                    fn=self.run_agent,
                    inputs=[gr.State(value=False), topic_bx, stop_after],
                    outputs=[live, lnode_bx, nnode_bx, threadid_bx, revision_bx, count_bx]
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                ).then(
                    fn=lambda stat: gr.update(variant=stat),
                    inputs=gr.Textbox(value="primary", visible=False),
                    outputs=cont_btn
                )

            with gr.Tab("Strategic Outline"):
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                    modify_btn = gr.Button("Modify")
                plan = gr.Textbox(label="Strategic Outline", lines=10, interactive=True)
                refresh_btn.click(
                    fn=self.get_state,
                    inputs=gr.Textbox(value="plan", visible=False),
                    outputs=plan
                )
                modify_btn.click(
                    fn=self.modify_state,
                    inputs=[
                        gr.Textbox(value="plan", visible=False),
                        gr.Textbox(value="strategy", visible=False),
                        plan
                    ],
                    outputs=None
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                )

            with gr.Tab("Research Content"):
                refresh_btn = gr.Button("Refresh")
                content_bx = gr.Textbox(label="Research Content", lines=10)
                refresh_btn.click(fn=self.get_content, inputs=None, outputs=content_bx)

            with gr.Tab("Startup Plan"):
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                    modify_btn = gr.Button("Modify")
                draft_bx = gr.Textbox(label="Startup Plan Draft", lines=10, interactive=True)
                refresh_btn.click(
                    fn=self.get_state,
                    inputs=gr.Textbox(value="draft", visible=False),
                    outputs=draft_bx
                )
                modify_btn.click(
                    fn=self.modify_state,
                    inputs=[
                        gr.Textbox(value="draft", visible=False),
                        gr.Textbox(value="refinement", visible=False),
                        draft_bx
                    ],
                    outputs=None
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                )

            with gr.Tab("Review"):
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                    modify_btn = gr.Button("Modify")
                critique_bx = gr.Textbox(label="Review / Critique", lines=10, interactive=True)
                refresh_btn.click(
                    fn=self.get_state,
                    inputs=gr.Textbox(value="critique", visible=False),
                    outputs=critique_bx
                )
                modify_btn.click(
                    fn=self.modify_state,
                    inputs=[
                        gr.Textbox(value="critique", visible=False),
                        gr.Textbox(value="review", visible=False),
                        critique_bx
                    ],
                    outputs=None
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                )

            with gr.Tab("State Snapshots"):
                with gr.Row():
                    refresh_btn = gr.Button("Refresh")
                snapshots = gr.Textbox(label="State Snapshots", lines=10)
                refresh_btn.click(fn=lambda: self.get_snapshots(), inputs=None, outputs=snapshots)
        return demo

    def launch(self, share=None):
        """Start the Gradio server.

        If the PORT1 environment variable is set, listen on 0.0.0.0 at that
        port, and sharing defaults to True. Otherwise sharing defaults to the
        ``share`` value given to the constructor. An explicit *share* argument
        overrides either default.
        """
        if port := os.getenv("PORT1"):
            self.demo.launch(
                share=True if share is None else share,
                server_port=int(port),
                server_name="0.0.0.0",
            )
        else:
            self.demo.launch(share=self.share if share is None else share)

    def copy_state(self, hist_str):
        """Re-apply the checkpoint selected in the step dropdown onto the current thread.

        Returns ``(last_node, next_nodes, thread_ts, revision, count)`` of the
        resulting state, or None when the entry does not name a checkpoint in
        this thread's history (e.g. the initial 'N/A').
        """
        config = self.find_config(self._thread_ts_from_hist(hist_str))
        if config is None:
            return None
        state = self.graph.get_state(config)
        self.graph.update_state(self.thread, state.values, as_node=state.values['lnode'])
        new_state = self.graph.get_state(self.thread)
        return (
            new_state.values['lnode'],
            new_state.next,
            new_state.config['configurable']['thread_ts'],
            new_state.values['revision_number'],
            new_state.values['count']
        )

    def find_config(self, thread_ts):
        """Return the config of this thread's checkpoint with *thread_ts*, or None if absent."""
        return next(
            (
                st.config
                for st in self.graph.get_state_history(self.thread)
                if st.config['configurable']['thread_ts'] == thread_ts
            ),
            None,
        )

    def switch_thread(self, new_thread_id):
        """Make *new_thread_id* the current thread for subsequent UI actions."""
        self.thread = {"configurable": {"thread_id": str(new_thread_id)}}
        self.thread_id = int(new_thread_id)
        return

    def modify_state(self, key, asnode, new_state):
        """Overwrite state key *key* with *new_state*, recorded as an update from node *asnode*.

        Only the edited key is written. Re-submitting the whole state would
        push ``count`` through its ``operator.add`` reducer again and double it.
        """
        self.graph.update_state(self.thread, {key: new_state}, as_node=asnode)
        return

    def get_snapshots(self):
        """Return a gr.update listing every checkpoint of the current thread.

        ``plan``/``draft``/``critique`` are cut to 80 characters and each
        content item to 20, each followed by "...", to keep the listing short.
        """
        snapshots = []
        for st in self.graph.get_state_history(self.thread):
            for key in ('plan', 'draft', 'critique'):
                if key in st.values:
                    st.values[key] = st.values[key][:80] + "..."
            if 'content' in st.values:
                st.values['content'] = [item[:20] + '...' for item in st.values['content']]
            snapshots.append(str(st) + "\n\n")
        return gr.update(label=f"Thread {self.thread_id} Snapshots", value="".join(snapshots))
# Script entry point: build the agent graph, wrap it in the Gradio UI and serve it.
if __name__ == "__main__":
    agent = ewriter()
    ui = writer_gui(agent.graph, share=True)
    ui.launch()