Upload folder using huggingface_hub
Browse files- app.py +46 -12
- prompts/__init__.py +1 -0
- utils.py +38 -26
app.py
CHANGED
@@ -7,13 +7,26 @@ import pickle
|
|
7 |
import os
|
8 |
from dotenv import load_dotenv
|
9 |
from agent import get_agent, DEFAULT_TASK_SOLVING_TOOLBOX
|
|
|
|
|
|
|
|
|
|
|
10 |
from tools.text_to_image import TextToImageTool
|
|
|
|
|
|
|
|
|
11 |
|
12 |
load_dotenv()
|
13 |
|
|
|
|
|
14 |
sessions_path = "sessions.pkl"
|
15 |
sessions = (
|
16 |
-
pickle.load(open(sessions_path, "rb"))
|
|
|
|
|
17 |
)
|
18 |
|
19 |
# If currently hosted on HuggingFace Spaces, use the default model, otherwise use the local model
|
@@ -23,10 +36,24 @@ model_name = (
|
|
23 |
else "http://localhost:1234/v1"
|
24 |
)
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
# Add image tools to the default task solving toolbox, for a more visually interactive experience
|
27 |
-
TASK_SOLVING_TOOLBOX = DEFAULT_TASK_SOLVING_TOOLBOX +
|
28 |
|
29 |
-
|
|
|
|
|
|
|
|
|
30 |
|
31 |
app = None
|
32 |
|
@@ -54,10 +81,14 @@ def interact_with_agent(messages, request: Request):
|
|
54 |
session_hash = request.session_hash
|
55 |
prompt = messages[-1]["content"]
|
56 |
agent.logs = sessions.get(session_hash + "_logs", [])
|
|
|
57 |
for msg in stream_from_transformers_agent(agent, prompt):
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
61 |
|
62 |
|
63 |
def persist(component):
|
@@ -74,16 +105,18 @@ def persist(component):
|
|
74 |
print(f"Updating persisted session state for {session_hash}")
|
75 |
sessions[session_hash] = value
|
76 |
sessions[session_hash + "_logs"] = agent.logs
|
77 |
-
|
78 |
-
|
79 |
|
80 |
Context.root_block.load(resume_session, inputs=[component], outputs=component)
|
81 |
-
component.change(update_session, inputs=[component], outputs=
|
82 |
|
83 |
return component
|
84 |
|
85 |
|
86 |
-
with gr.Blocks(fill_height=True) as demo:
|
|
|
|
|
87 |
chatbot = persist(
|
88 |
gr.Chatbot(
|
89 |
value=[],
|
@@ -99,6 +132,7 @@ with gr.Blocks(fill_height=True) as demo:
|
|
99 |
show_copy_button=True,
|
100 |
placeholder="""<h1>SQuAD Agent</h1>
|
101 |
<h2>I am your friendly guide to the Stanford Question and Answer Dataset (SQuAD).</h2>
|
|
|
102 |
""",
|
103 |
examples=[
|
104 |
{
|
@@ -115,10 +149,10 @@ with gr.Blocks(fill_height=True) as demo:
|
|
115 |
)
|
116 |
text_input = gr.Textbox(lines=1, label="Chat Message", scale=0)
|
117 |
chat_msg = text_input.submit(add_message, [text_input, chatbot], [chatbot])
|
118 |
-
bot_msg = chat_msg.then(interact_with_agent, [chatbot], [chatbot])
|
119 |
text_input.submit(lambda: "", None, text_input)
|
120 |
chatbot.example_select(append_example_message, [chatbot], [chatbot]).then(
|
121 |
-
interact_with_agent, [chatbot], [chatbot]
|
122 |
)
|
123 |
|
124 |
if __name__ == "__main__":
|
|
|
7 |
import os
|
8 |
from dotenv import load_dotenv
|
9 |
from agent import get_agent, DEFAULT_TASK_SOLVING_TOOLBOX
|
10 |
+
from transformers.agents import (
|
11 |
+
DuckDuckGoSearchTool,
|
12 |
+
ImageQuestionAnsweringTool,
|
13 |
+
VisitWebpageTool,
|
14 |
+
)
|
15 |
from tools.text_to_image import TextToImageTool
|
16 |
+
from transformers import load_tool
|
17 |
+
from prompts import DEFAULT_SQUAD_REACT_CODE_SYSTEM_PROMPT
|
18 |
+
from pygments.formatters import HtmlFormatter
|
19 |
+
|
20 |
|
21 |
load_dotenv()
|
22 |
|
23 |
+
SESSION_PERSISTENCE_ENABLED = os.getenv("SESSION_PERSISTENCE_ENABLED", False)
|
24 |
+
|
25 |
sessions_path = "sessions.pkl"
|
26 |
sessions = (
|
27 |
+
pickle.load(open(sessions_path, "rb"))
|
28 |
+
if SESSION_PERSISTENCE_ENABLED and os.path.exists(sessions_path)
|
29 |
+
else {}
|
30 |
)
|
31 |
|
32 |
# If currently hosted on HuggingFace Spaces, use the default model, otherwise use the local model
|
|
|
36 |
else "http://localhost:1234/v1"
|
37 |
)
|
38 |
|
39 |
+
ADDITIONAL_TOOLS = [
|
40 |
+
DuckDuckGoSearchTool(),
|
41 |
+
VisitWebpageTool(),
|
42 |
+
ImageQuestionAnsweringTool(),
|
43 |
+
load_tool("speech_to_text"),
|
44 |
+
load_tool("text_to_speech"),
|
45 |
+
load_tool("translation"),
|
46 |
+
TextToImageTool(),
|
47 |
+
]
|
48 |
+
|
49 |
# Add image tools to the default task solving toolbox, for a more visually interactive experience
|
50 |
+
TASK_SOLVING_TOOLBOX = DEFAULT_TASK_SOLVING_TOOLBOX + ADDITIONAL_TOOLS
|
51 |
|
52 |
+
system_prompt = DEFAULT_SQUAD_REACT_CODE_SYSTEM_PROMPT
|
53 |
+
|
54 |
+
agent = get_agent(
|
55 |
+
model_name=model_name, toolbox=TASK_SOLVING_TOOLBOX, system_prompt=system_prompt, use_openai=False
|
56 |
+
)
|
57 |
|
58 |
app = None
|
59 |
|
|
|
81 |
session_hash = request.session_hash
|
82 |
prompt = messages[-1]["content"]
|
83 |
agent.logs = sessions.get(session_hash + "_logs", [])
|
84 |
+
yield messages, gr.update(value = "<center><h1>Thinking...</h1></center>", visible = True)
|
85 |
for msg in stream_from_transformers_agent(agent, prompt):
|
86 |
+
if isinstance(msg, ChatMessage):
|
87 |
+
messages.append(msg)
|
88 |
+
yield messages, gr.update(visible = True)
|
89 |
+
else:
|
90 |
+
yield messages, gr.update(value = f"<center><h1>{msg}</h1></center>", visible = True)
|
91 |
+
yield messages, gr.update(value = "<center><h1>Idle</h1></center>", visible = False)
|
92 |
|
93 |
|
94 |
def persist(component):
|
|
|
105 |
print(f"Updating persisted session state for {session_hash}")
|
106 |
sessions[session_hash] = value
|
107 |
sessions[session_hash + "_logs"] = agent.logs
|
108 |
+
if SESSION_PERSISTENCE_ENABLED:
|
109 |
+
pickle.dump(sessions, open(sessions_path, "wb"))
|
110 |
|
111 |
Context.root_block.load(resume_session, inputs=[component], outputs=component)
|
112 |
+
component.change(update_session, inputs=[component], outputs=None)
|
113 |
|
114 |
return component
|
115 |
|
116 |
|
117 |
+
with gr.Blocks(fill_height=True, css=".gradio-container .message .content {text-align: left;}" + HtmlFormatter().get_style_defs('.highlight')) as demo:
|
118 |
+
state = gr.State()
|
119 |
+
inner_monologue_component = gr.Markdown("""<h2>Inner Monologue</h2>""", visible = False)
|
120 |
chatbot = persist(
|
121 |
gr.Chatbot(
|
122 |
value=[],
|
|
|
132 |
show_copy_button=True,
|
133 |
placeholder="""<h1>SQuAD Agent</h1>
|
134 |
<h2>I am your friendly guide to the Stanford Question and Answer Dataset (SQuAD).</h2>
|
135 |
+
<h2>You can ask me questions about the dataset, or you can ask me to generate images based on your prompts.</h2>
|
136 |
""",
|
137 |
examples=[
|
138 |
{
|
|
|
149 |
)
|
150 |
text_input = gr.Textbox(lines=1, label="Chat Message", scale=0)
|
151 |
chat_msg = text_input.submit(add_message, [text_input, chatbot], [chatbot])
|
152 |
+
bot_msg = chat_msg.then(interact_with_agent, [chatbot], [chatbot, inner_monologue_component])
|
153 |
text_input.submit(lambda: "", None, text_input)
|
154 |
chatbot.example_select(append_example_message, [chatbot], [chatbot]).then(
|
155 |
+
interact_with_agent, [chatbot], [chatbot, inner_monologue_component]
|
156 |
)
|
157 |
|
158 |
if __name__ == "__main__":
|
prompts/__init__.py
CHANGED
@@ -22,4 +22,5 @@ def load_constants(constants_dir):
|
|
22 |
PROMPTS = load_constants("prompts")
|
23 |
|
24 |
# Import all prompts locally as well, for code completion
|
|
|
25 |
from prompts.default import DEFAULT_SQUAD_REACT_CODE_SYSTEM_PROMPT
|
|
|
22 |
PROMPTS = load_constants("prompts")
|
23 |
|
24 |
# Import all prompts locally as well, for code completion
|
25 |
+
from transformers.agents.prompts import DEFAULT_REACT_CODE_SYSTEM_PROMPT
|
26 |
from prompts.default import DEFAULT_SQUAD_REACT_CODE_SYSTEM_PROMPT
|
utils.py
CHANGED
@@ -1,53 +1,65 @@
|
|
1 |
from __future__ import annotations
|
2 |
|
|
|
3 |
from gradio import ChatMessage
|
4 |
from transformers.agents import ReactCodeAgent, agent_types
|
5 |
from typing import Generator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
def pull_message(step_log: dict):
|
8 |
if step_log.get("rationale"):
|
9 |
-
yield
|
10 |
-
role="assistant",
|
11 |
-
metadata={"title": "π§ Rationale"},
|
12 |
-
content=step_log["rationale"]
|
13 |
-
)
|
14 |
if step_log.get("tool_call"):
|
15 |
used_code = step_log["tool_call"]["tool_name"] == "code interpreter"
|
16 |
content = step_log["tool_call"]["tool_arguments"]
|
17 |
-
|
18 |
-
content = f"```py\n{content}\n```"
|
19 |
-
yield ChatMessage(
|
20 |
-
role="assistant",
|
21 |
-
metadata={"title": f"π οΈ Used tool {step_log['tool_call']['tool_name']}"},
|
22 |
-
content=content,
|
23 |
-
)
|
24 |
if step_log.get("observation"):
|
25 |
-
yield
|
26 |
-
role="assistant",
|
27 |
-
metadata={"title": "π Observation"},
|
28 |
-
content=f"```\n{step_log['observation']}\n```"
|
29 |
-
)
|
30 |
if step_log.get("error"):
|
31 |
-
yield
|
32 |
-
role="assistant",
|
33 |
-
metadata={"title": "π₯ Error"},
|
34 |
-
content=str(step_log["error"]),
|
35 |
-
)
|
36 |
|
37 |
def stream_from_transformers_agent(
|
38 |
-
agent: ReactCodeAgent, prompt: str
|
39 |
) -> Generator[ChatMessage, None, ChatMessage | None]:
|
40 |
"""Runs an agent with the given prompt and streams the messages from the agent as ChatMessages."""
|
41 |
|
42 |
class Output:
|
43 |
output: agent_types.AgentType | str = None
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
step_log = None
|
46 |
for step_log in agent.run(prompt, stream=True, reset=len(agent.logs) == 0): # Reset=False misbehaves if the agent has not yet been run
|
47 |
if isinstance(step_log, dict):
|
48 |
-
for message in pull_message(step_log):
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
Output.output = step_log
|
53 |
if isinstance(Output.output, agent_types.AgentText):
|
|
|
1 |
from __future__ import annotations
|
2 |
|
3 |
+
import gradio as gr
|
4 |
from gradio import ChatMessage
|
5 |
from transformers.agents import ReactCodeAgent, agent_types
|
6 |
from typing import Generator
|
7 |
+
from termcolor import colored
|
8 |
+
from pygments import highlight
|
9 |
+
from pygments.lexers import PythonLexer
|
10 |
+
from pygments.formatters import HtmlFormatter
|
11 |
+
from pygments.formatters import TerminalFormatter
|
12 |
+
|
13 |
+
def highlight_code_terminal(text):
|
14 |
+
return highlight(text, PythonLexer(), TerminalFormatter())
|
15 |
+
|
16 |
+
def highlight_code_html(code):
|
17 |
+
return highlight(code, PythonLexer(), HtmlFormatter())
|
18 |
+
|
19 |
|
20 |
def pull_message(step_log: dict):
|
21 |
if step_log.get("rationale"):
|
22 |
+
yield "π§ Thinking...", f"{step_log["rationale"]}"
|
|
|
|
|
|
|
|
|
23 |
if step_log.get("tool_call"):
|
24 |
used_code = step_log["tool_call"]["tool_name"] == "code interpreter"
|
25 |
content = step_log["tool_call"]["tool_arguments"]
|
26 |
+
yield f"π οΈ Using tool {step_log['tool_call']['tool_name']}...", content
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
if step_log.get("observation"):
|
28 |
+
yield "π Observing...", step_log['observation']
|
|
|
|
|
|
|
|
|
29 |
if step_log.get("error"):
|
30 |
+
yield "π₯ Coping with an Error...", step_log['error'].message
|
|
|
|
|
|
|
|
|
31 |
|
32 |
def stream_from_transformers_agent(
|
33 |
+
agent: ReactCodeAgent, prompt: str
|
34 |
) -> Generator[ChatMessage, None, ChatMessage | None]:
|
35 |
"""Runs an agent with the given prompt and streams the messages from the agent as ChatMessages."""
|
36 |
|
37 |
class Output:
|
38 |
output: agent_types.AgentType | str = None
|
39 |
|
40 |
+
inner_monologue = ChatMessage(
|
41 |
+
role="assistant",
|
42 |
+
metadata={"title": "π§ Thinking..."},
|
43 |
+
content=""
|
44 |
+
)
|
45 |
+
|
46 |
step_log = None
|
47 |
for step_log in agent.run(prompt, stream=True, reset=len(agent.logs) == 0): # Reset=False misbehaves if the agent has not yet been run
|
48 |
if isinstance(step_log, dict):
|
49 |
+
for title, message in pull_message(step_log):
|
50 |
+
terminal_message = message
|
51 |
+
if ("Using tool" in title) or ("Error" in title):
|
52 |
+
terminal_message = highlight_code_terminal(message)
|
53 |
+
message = highlight_code_html(message)
|
54 |
+
if "Observing" in title:
|
55 |
+
message = f"<pre>{message}</pre>"
|
56 |
+
print(colored("=== Inner Monologue Message:\n", "blue", attrs=["bold"]), f"{title}\n{terminal_message}")
|
57 |
+
inner_monologue.content += f"<h2>{title}</h2><p>{message}</p>"
|
58 |
+
yield title
|
59 |
+
|
60 |
+
if inner_monologue is not None:
|
61 |
+
inner_monologue.metadata = {"title": "Inner Monologue (click to expand)"}
|
62 |
+
yield inner_monologue
|
63 |
|
64 |
Output.output = step_log
|
65 |
if isinstance(Output.output, agent_types.AgentText):
|