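# LangGraph + Gradio "startup plan writer": the ewriter class builds a plan -> research ->
# draft -> review -> follow-up research loop with per-node interrupts, and writer_gui wraps
# it in a Gradio interface for stepping through, inspecting, and editing the graph state.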
import os
import sqlite3
import operator
from typing import TypedDict, Annotated, List
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel
from dotenv import load_dotenv

# Load OPENAI_API_KEY (and, if available, TAVILY_API_KEY) from a local .env file.
_ = load_dotenv()
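# Shared state passed between the graph nodes; `count` is annotated with operator.add
# so each node's increment accumulates instead of overwriting the previous value.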
class AgentState(TypedDict):
    task: str
    lnode: str
    plan: str
    draft: str
    critique: str
    content: List[str]
    queries: List[str]
    revision_number: int
    max_revisions: int
    count: Annotated[int, operator.add]
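# Structured-output schema: the model is asked to return a plain list of search queries.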
class Queries(BaseModel):
    queries: List[str]
class ewriter():
    def __init__(self):
        self.model = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)
        self.PLAN_PROMPT = (
            "You are a visionary startup strategist tasked with outlining comprehensive business plans "
            "for emerging companies. Identify the unique selling points (USPs) and core competencies "
            "of startups in various sectors such as tech (e.g., innovative tech stacks), healthcare "
            "(e.g., compliance with medical regulations), fintech (e.g., secure payment systems), "
            "and e-commerce (e.g., efficient supply chain management). Provide detailed strategies."
        )
        self.WRITER_PROMPT = (
            "You are a seasoned startup consultant responsible for crafting compelling pitches "
            "that highlight what sets each company apart. Ensure clarity and persuasive storytelling. "
            "Use the research content below as needed:\n\n------\n\n{content}"
        )
        self.RESEARCH_PLAN_PROMPT = (
            "You are a cutting-edge market research expert specializing in identifying trends "
            "and analyzing competitors within specific industries. Offer actionable insights based on "
            "the latest market data regarding regulatory environments, technological innovations, "
            "and competitive positioning. Respond with a short list of search queries that will "
            "gather this information."
        )
        self.REFLECTION_PROMPT = (
            "You are a top-tier investor and mentor providing strategic advice on fundraising, "
            "team building, and operational efficiency tailored to each industry's challenges."
        )
        self.RESEARCH_CRITIQUE_PROMPT = (
            "You are an expert research analyst tasked with evaluating the feasibility of new business ideas, "
            "assessing potential risks and rewards based on industry-specific factors like regulatory hurdles "
            "and market competition. Respond with a short list of search queries that will gather the "
            "information needed to address the critique."
        )
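        # Use the real Tavily client when the package and TAVILY_API_KEY are available;
        # otherwise fall back to a stub so the graph still runs end to end.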
        try:
            from tavily import TavilyClient
            self.tavily = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])
        except (ImportError, KeyError):
            class DummyTavilyClient:
                def __init__(self, api_key):
                    self.api_key = api_key

                def search(self, query: str, max_results: int = 2):
                    return {"results": [{"content": f"Dummy result for '{query}'"}]}

            self.tavily = DummyTavilyClient(api_key="dummy")
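        # Graph wiring: strategy -> market_research -> refinement, then a review ->
        # followup_research -> refinement loop until should_continue returns END.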
        builder = StateGraph(AgentState)
        builder.add_node("strategy", self.strategy_node)
        builder.add_node("market_research", self.market_research_node)
        builder.add_node("refinement", self.refinement_node)
        builder.add_node("review", self.review_node)
        builder.add_node("followup_research", self.followup_research_node)
        builder.set_entry_point("strategy")
        builder.add_conditional_edges(
            "refinement",
            self.should_continue,
            {END: END, "review": "review"}
        )
        builder.add_edge("strategy", "market_research")
        builder.add_edge("market_research", "refinement")
        builder.add_edge("review", "followup_research")
        builder.add_edge("followup_research", "refinement")
        memory = SqliteSaver(conn=sqlite3.connect(":memory:", check_same_thread=False))
        self.graph = builder.compile(
            checkpointer=memory,
            interrupt_after=[
                'strategy',
                'refinement',
                'review',
                'market_research',
                'followup_research'
            ]
        )
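    # Each node returns a partial state update; "lnode" records the node that just ran and
    # "count": 1 is accumulated across steps via the operator.add annotation.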
    def strategy_node(self, state: AgentState):
        messages = [
            SystemMessage(content=self.PLAN_PROMPT),
            HumanMessage(content=state['task'])
        ]
        response = self.model.invoke(messages)
        return {"plan": response.content, "lnode": "strategy", "count": 1}
    def market_research_node(self, state: AgentState):
        queries = self.model.with_structured_output(Queries).invoke([
            SystemMessage(content=self.RESEARCH_PLAN_PROMPT),
            HumanMessage(content=state['task'])
        ])
        content = state.get('content') or []
        for q in queries.queries:
            resp = self.tavily.search(query=q, max_results=2)
            for r in resp['results']:
                content.append(r['content'])
        return {"content": content, "queries": queries.queries, "lnode": "market_research", "count": 1}
    def refinement_node(self, state: AgentState):
        content = "\n\n".join(state.get('content') or [])
        user_message = HumanMessage(
            content=f"{state['task']}\n\nHere is the strategic outline:\n\n{state['plan']}"
        )
        messages = [
            SystemMessage(content=self.WRITER_PROMPT.format(content=content)),
            user_message
        ]
        response = self.model.invoke(messages)
        return {
            "draft": response.content,
            "revision_number": state.get("revision_number", 1) + 1,
            "lnode": "refinement",
            "count": 1
        }
    def review_node(self, state: AgentState):
        messages = [
            SystemMessage(content=self.REFLECTION_PROMPT),
            HumanMessage(content=state['draft'])
        ]
        response = self.model.invoke(messages)
        return {"critique": response.content, "lnode": "review", "count": 1}
    def followup_research_node(self, state: AgentState):
        queries = self.model.with_structured_output(Queries).invoke([
            SystemMessage(content=self.RESEARCH_CRITIQUE_PROMPT),
            HumanMessage(content=state['critique'])
        ])
        content = state.get('content') or []
        for q in queries.queries:
            resp = self.tavily.search(query=q, max_results=2)
            for r in resp['results']:
                content.append(r['content'])
        return {"content": content, "lnode": "followup_research", "count": 1}
    def should_continue(self, state):
        if state["revision_number"] > state["max_revisions"]:
            return END
        return "review"
import gradio as gr


class writer_gui():
    def __init__(self, graph, share=False):
        self.graph = graph
        self.share = share
        self.partial_message = ""
        self.response = {}
        self.max_iterations = 10
        # Initialize with a default thread so dropdowns are populated.
        self.iterations = [0]
        self.threads = [0]
        self.thread_id = 0
        self.thread = {"configurable": {"thread_id": "0"}}
        self.demo = self.create_interface()
    def run_agent(self, start, topic, stop_after):
        if start:
            self.iterations.append(0)
            config = {
                'task': topic,
                "max_revisions": 2,
                "revision_number": 0,
                'lnode': "",
                'plan': "no plan",
                'draft': "no draft",
                'critique': "no critique",
                'content': ["no research content"],
                'queries': ["no queries"],
                'count': 0
            }
            self.thread_id += 1
            self.threads.append(self.thread_id)
        else:
            config = None
        self.thread = {"configurable": {"thread_id": str(self.thread_id)}}
        while self.iterations[self.thread_id] < self.max_iterations:
            self.response = self.graph.invoke(config, self.thread)
            self.iterations[self.thread_id] += 1
            self.partial_message += str(self.response)
            self.partial_message += "\n------------------\n\n"
            lnode, nnode, _, rev, acount = self.get_disp_state()
            yield (self.partial_message, lnode, nnode, self.thread_id, rev, acount)
            config = None  # Resume from the checkpoint on subsequent iterations.
            if not nnode or lnode in stop_after:
                return
    def get_disp_state(self):
        current_state = self.graph.get_state(self.thread)
        lnode = current_state.values["lnode"]
        acount = current_state.values["count"]
        rev = current_state.values["revision_number"]
        nnode = current_state.next
        return lnode, nnode, self.thread_id, rev, acount
    def get_state(self, key):
        current_values = self.graph.get_state(self.thread)
        if key in current_values.values:
            lnode, nnode, _, rev, astep = self.get_disp_state()
            new_label = f"last_node: {lnode}, thread: {self.thread_id}, rev: {rev}, step: {astep}"
            return gr.update(label=new_label, value=current_values.values[key])
        else:
            return ""
    def get_content(self):
        current_values = self.graph.get_state(self.thread)
        if "content" in current_values.values:
            content = current_values.values["content"]
            lnode, nnode, _, rev, astep = self.get_disp_state()
            new_label = f"last_node: {lnode}, thread: {self.thread_id}, rev: {rev}, step: {astep}"
            return gr.update(label=new_label, value="\n\n".join(content) + "\n\n")
        else:
            return ""
    def create_interface(self):
        with gr.Blocks(theme=gr.themes.Default(spacing_size='sm', text_size="sm")) as demo:

            def updt_disp():
                # Refresh the display widgets from the current thread's checkpointed state.
                current_state = self.graph.get_state(self.thread)
                hist = []
                for st in self.graph.get_state_history(self.thread):
                    if st.metadata['step'] < 1:  # Skip the initial input checkpoint.
                        continue
                    # NOTE: newer langgraph releases expose this as "checkpoint_id" rather than "thread_ts".
                    s_thread_ts = st.config['configurable']['thread_ts']
                    s_tid = st.config['configurable']['thread_id']
                    s_count = st.values['count']
                    s_lnode = st.values['lnode']
                    s_rev = st.values['revision_number']
                    s_nnode = st.next
                    hist.append(f"{s_tid}:{s_count}:{s_lnode}:{s_nnode}:{s_rev}:{s_thread_ts}")
                if not current_state.metadata or not hist:
                    return {}
                return {
                    topic_bx: gr.update(value=str(current_state.values["task"])),
                    lnode_bx: gr.update(value=str(current_state.values["lnode"])),
                    count_bx: gr.update(value=str(current_state.values["count"])),
                    revision_bx: gr.update(value=str(current_state.values["revision_number"])),
                    nnode_bx: gr.update(value=str(current_state.next or "")),
                    threadid_bx: gr.update(value=str(self.thread_id)),
                    thread_pd: gr.update(
                        label="choose thread",
                        choices=self.threads,
                        value=self.thread_id,
                        interactive=True
                    ),
                    step_pd: gr.update(
                        label="update state from: thread:count:last_node:next_node:rev:thread_ts",
                        choices=hist,
                        value=hist[0],
                        interactive=True
                    ),
                }
with gr.Tab("Agent"): | |
with gr.Row(): | |
topic_bx = gr.Textbox(label="Startup Idea", value="Revolutionary AI-driven Healthtech Platform") | |
gen_btn = gr.Button("Generate Plan", scale=0, min_width=80, variant='primary') | |
cont_btn = gr.Button("Continue Refinement", scale=0, min_width=80) | |
with gr.Row(): | |
lnode_bx = gr.Textbox(label="Last Node", min_width=100) | |
nnode_bx = gr.Textbox(label="Next Node", min_width=100) | |
threadid_bx = gr.Textbox(label="Thread", scale=0, min_width=80) | |
revision_bx = gr.Textbox(label="Plan Rev", scale=0, min_width=80) | |
count_bx = gr.Textbox(label="Count", scale=0, min_width=80) | |
with gr.Accordion("Manage Agent", open=False): | |
checks = list(self.graph.nodes.keys()) | |
checks.remove('__start__') | |
stop_after = gr.CheckboxGroup( | |
checks, | |
label="Interrupt After State", | |
value=checks, | |
scale=0, | |
min_width=400 | |
) | |
with gr.Row(): | |
thread_pd = gr.Dropdown( | |
choices=self.threads, | |
interactive=True, | |
label="Select Thread", | |
min_width=120, | |
scale=0 | |
) | |
step_pd = gr.Dropdown( | |
choices=['N/A'], | |
interactive=True, | |
label="Select Step", | |
min_width=160, | |
scale=1 | |
) | |
live = gr.Textbox(label="Live Agent Output", lines=5, max_lines=5) | |
sdisps = [topic_bx, lnode_bx, nnode_bx, threadid_bx, revision_bx, count_bx, step_pd, thread_pd] | |
                thread_pd.change(
                    self.switch_thread, [thread_pd], None
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                )
                step_pd.change(
                    self.copy_state, [step_pd], None
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                )
                gen_btn.click(
                    fn=lambda stat: gr.update(variant=stat),
                    inputs=gr.Textbox(value="secondary", visible=False),
                    outputs=gen_btn
                ).then(
                    fn=self.run_agent,
                    inputs=[gr.State(value=True), topic_bx, stop_after],
                    outputs=[live, lnode_bx, nnode_bx, threadid_bx, revision_bx, count_bx],
                    show_progress=True
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                ).then(
                    fn=lambda stat: gr.update(variant=stat),
                    inputs=gr.Textbox(value="primary", visible=False),
                    outputs=gen_btn
                ).then(
                    fn=lambda stat: gr.update(variant=stat),
                    inputs=gr.Textbox(value="primary", visible=False),
                    outputs=cont_btn
                )
                cont_btn.click(
                    fn=lambda stat: gr.update(variant=stat),
                    inputs=gr.Textbox(value="secondary", visible=False),
                    outputs=cont_btn
                ).then(
                    fn=self.run_agent,
                    inputs=[gr.State(value=False), topic_bx, stop_after],
                    outputs=[live, lnode_bx, nnode_bx, threadid_bx, revision_bx, count_bx]
                ).then(
                    fn=updt_disp, inputs=None, outputs=sdisps
                ).then(
                    fn=lambda stat: gr.update(variant=stat),
                    inputs=gr.Textbox(value="primary", visible=False),
                    outputs=cont_btn
                )
with gr.Tab("Strategic Outline"): | |
with gr.Row(): | |
refresh_btn = gr.Button("Refresh") | |
modify_btn = gr.Button("Modify") | |
plan = gr.Textbox(label="Strategic Outline", lines=10, interactive=True) | |
refresh_btn.click( | |
fn=self.get_state, | |
inputs=gr.Textbox(value="plan", visible=False), | |
outputs=plan | |
) | |
modify_btn.click( | |
fn=self.modify_state, | |
inputs=[ | |
gr.Textbox(value="plan", visible=False), | |
gr.Textbox(value="strategy", visible=False), | |
plan | |
], | |
outputs=None | |
).then( | |
fn=updt_disp, inputs=None, outputs=None | |
) | |
with gr.Tab("Research Content"): | |
refresh_btn = gr.Button("Refresh") | |
content_bx = gr.Textbox(label="Research Content", lines=10) | |
refresh_btn.click(fn=self.get_content, inputs=None, outputs=content_bx) | |
with gr.Tab("Startup Plan"): | |
with gr.Row(): | |
refresh_btn = gr.Button("Refresh") | |
modify_btn = gr.Button("Modify") | |
draft_bx = gr.Textbox(label="Startup Plan Draft", lines=10, interactive=True) | |
refresh_btn.click( | |
fn=self.get_state, | |
inputs=gr.Textbox(value="draft", visible=False), | |
outputs=draft_bx | |
) | |
modify_btn.click( | |
fn=self.modify_state, | |
inputs=[ | |
gr.Textbox(value="draft", visible=False), | |
gr.Textbox(value="refinement", visible=False), | |
draft_bx | |
], | |
outputs=None | |
).then( | |
fn=updt_disp, inputs=None, outputs=None | |
) | |
with gr.Tab("Review"): | |
with gr.Row(): | |
refresh_btn = gr.Button("Refresh") | |
modify_btn = gr.Button("Modify") | |
critique_bx = gr.Textbox(label="Review / Critique", lines=10, interactive=True) | |
refresh_btn.click( | |
fn=self.get_state, | |
inputs=gr.Textbox(value="critique", visible=False), | |
outputs=critique_bx | |
) | |
modify_btn.click( | |
fn=self.modify_state, | |
inputs=[ | |
gr.Textbox(value="critique", visible=False), | |
gr.Textbox(value="review", visible=False), | |
critique_bx | |
], | |
outputs=None | |
).then( | |
fn=updt_disp, inputs=None, outputs=None | |
) | |
with gr.Tab("State Snapshots"): | |
with gr.Row(): | |
refresh_btn = gr.Button("Refresh") | |
snapshots = gr.Textbox(label="State Snapshots", lines=10) | |
refresh_btn.click(fn=lambda: self.get_snapshots(), inputs=None, outputs=snapshots) | |
return demo | |
    def launch(self, share=None):
        if port := os.getenv("PORT1"):
            self.demo.launch(share=True, server_port=int(port), server_name="0.0.0.0")
        else:
            self.demo.launch(share=self.share)
    def copy_state(self, hist_str):
        # Copy a historical checkpoint's values onto the current thread so the run can resume from there.
        thread_ts = hist_str.split(":")[-1]
        config = self.find_config(thread_ts)
        state = self.graph.get_state(config)
        self.graph.update_state(self.thread, state.values, as_node=state.values['lnode'])
        new_state = self.graph.get_state(self.thread)
        return (
            new_state.values['lnode'],
            new_state.next,
            new_state.config['configurable']['thread_ts'],
            new_state.values['revision_number'],
            new_state.values['count']
        )
    def find_config(self, thread_ts):
        for st in self.graph.get_state_history(self.thread):
            cfg = st.config
            if cfg['configurable']['thread_ts'] == thread_ts:
                return cfg
        return None
    def switch_thread(self, new_thread_id):
        self.thread = {"configurable": {"thread_id": str(new_thread_id)}}
        self.thread_id = int(new_thread_id)
        return

    def modify_state(self, key, asnode, new_state):
        current_values = self.graph.get_state(self.thread)
        current_values.values[key] = new_state
        self.graph.update_state(self.thread, current_values.values, as_node=asnode)
        return
    def get_snapshots(self):
        s = ""
        for st in self.graph.get_state_history(self.thread):
            # Truncate long fields so each snapshot stays readable.
            for key in ['plan', 'draft', 'critique']:
                if key in st.values:
                    st.values[key] = st.values[key][:80] + "..."
            if 'content' in st.values:
                for i in range(len(st.values['content'])):
                    st.values['content'][i] = st.values['content'][i][:20] + '...'
            s += str(st) + "\n\n"
        return gr.update(label=f"Thread {self.thread_id} Snapshots", value=s)
# Build the agent graph and launch the Gradio app.
if __name__ == "__main__":
    writer_instance = ewriter()
    gui_instance = writer_gui(writer_instance.graph, share=True)
    gui_instance.launch()