vonliechti commited on
Commit
be5ca66
Β·
verified Β·
1 Parent(s): 5f43612

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. app.py +46 -12
  2. prompts/__init__.py +1 -0
  3. utils.py +38 -26
app.py CHANGED
@@ -7,13 +7,26 @@ import pickle
7
  import os
8
  from dotenv import load_dotenv
9
  from agent import get_agent, DEFAULT_TASK_SOLVING_TOOLBOX
 
 
 
 
 
10
  from tools.text_to_image import TextToImageTool
 
 
 
 
11
 
12
  load_dotenv()
13
 
 
 
14
  sessions_path = "sessions.pkl"
15
  sessions = (
16
- pickle.load(open(sessions_path, "rb")) if os.path.exists(sessions_path) else {}
 
 
17
  )
18
 
19
  # If currently hosted on HuggingFace Spaces, use the default model, otherwise use the local model
@@ -23,10 +36,24 @@ model_name = (
23
  else "http://localhost:1234/v1"
24
  )
25
 
 
 
 
 
 
 
 
 
 
 
26
  # Add image tools to the default task solving toolbox, for a more visually interactive experience
27
- TASK_SOLVING_TOOLBOX = DEFAULT_TASK_SOLVING_TOOLBOX + [TextToImageTool()]
28
 
29
- agent = get_agent(model_name=model_name, toolbox=TASK_SOLVING_TOOLBOX)
 
 
 
 
30
 
31
  app = None
32
 
@@ -54,10 +81,14 @@ def interact_with_agent(messages, request: Request):
54
  session_hash = request.session_hash
55
  prompt = messages[-1]["content"]
56
  agent.logs = sessions.get(session_hash + "_logs", [])
 
57
  for msg in stream_from_transformers_agent(agent, prompt):
58
- messages.append(msg)
59
- yield messages
60
- yield messages
 
 
 
61
 
62
 
63
  def persist(component):
@@ -74,16 +105,18 @@ def persist(component):
74
  print(f"Updating persisted session state for {session_hash}")
75
  sessions[session_hash] = value
76
  sessions[session_hash + "_logs"] = agent.logs
77
- pickle.dump(sessions, open(sessions_path, "wb"))
78
- return
79
 
80
  Context.root_block.load(resume_session, inputs=[component], outputs=component)
81
- component.change(update_session, inputs=[component], outputs=[])
82
 
83
  return component
84
 
85
 
86
- with gr.Blocks(fill_height=True) as demo:
 
 
87
  chatbot = persist(
88
  gr.Chatbot(
89
  value=[],
@@ -99,6 +132,7 @@ with gr.Blocks(fill_height=True) as demo:
99
  show_copy_button=True,
100
  placeholder="""<h1>SQuAD Agent</h1>
101
  <h2>I am your friendly guide to the Stanford Question and Answer Dataset (SQuAD).</h2>
 
102
  """,
103
  examples=[
104
  {
@@ -115,10 +149,10 @@ with gr.Blocks(fill_height=True) as demo:
115
  )
116
  text_input = gr.Textbox(lines=1, label="Chat Message", scale=0)
117
  chat_msg = text_input.submit(add_message, [text_input, chatbot], [chatbot])
118
- bot_msg = chat_msg.then(interact_with_agent, [chatbot], [chatbot])
119
  text_input.submit(lambda: "", None, text_input)
120
  chatbot.example_select(append_example_message, [chatbot], [chatbot]).then(
121
- interact_with_agent, [chatbot], [chatbot]
122
  )
123
 
124
  if __name__ == "__main__":
 
7
  import os
8
  from dotenv import load_dotenv
9
  from agent import get_agent, DEFAULT_TASK_SOLVING_TOOLBOX
10
+ from transformers.agents import (
11
+ DuckDuckGoSearchTool,
12
+ ImageQuestionAnsweringTool,
13
+ VisitWebpageTool,
14
+ )
15
  from tools.text_to_image import TextToImageTool
16
+ from transformers import load_tool
17
+ from prompts import DEFAULT_SQUAD_REACT_CODE_SYSTEM_PROMPT
18
+ from pygments.formatters import HtmlFormatter
19
+
20
 
21
  load_dotenv()
22
 
23
+ SESSION_PERSISTENCE_ENABLED = os.getenv("SESSION_PERSISTENCE_ENABLED", False)
24
+
25
  sessions_path = "sessions.pkl"
26
  sessions = (
27
+ pickle.load(open(sessions_path, "rb"))
28
+ if SESSION_PERSISTENCE_ENABLED and os.path.exists(sessions_path)
29
+ else {}
30
  )
31
 
32
  # If currently hosted on HuggingFace Spaces, use the default model, otherwise use the local model
 
36
  else "http://localhost:1234/v1"
37
  )
38
 
39
+ ADDITIONAL_TOOLS = [
40
+ DuckDuckGoSearchTool(),
41
+ VisitWebpageTool(),
42
+ ImageQuestionAnsweringTool(),
43
+ load_tool("speech_to_text"),
44
+ load_tool("text_to_speech"),
45
+ load_tool("translation"),
46
+ TextToImageTool(),
47
+ ]
48
+
49
  # Add image tools to the default task solving toolbox, for a more visually interactive experience
50
+ TASK_SOLVING_TOOLBOX = DEFAULT_TASK_SOLVING_TOOLBOX + ADDITIONAL_TOOLS
51
 
52
+ system_prompt = DEFAULT_SQUAD_REACT_CODE_SYSTEM_PROMPT
53
+
54
+ agent = get_agent(
55
+ model_name=model_name, toolbox=TASK_SOLVING_TOOLBOX, system_prompt=system_prompt, use_openai=False
56
+ )
57
 
58
  app = None
59
 
 
81
  session_hash = request.session_hash
82
  prompt = messages[-1]["content"]
83
  agent.logs = sessions.get(session_hash + "_logs", [])
84
+ yield messages, gr.update(value = "<center><h1>Thinking...</h1></center>", visible = True)
85
  for msg in stream_from_transformers_agent(agent, prompt):
86
+ if isinstance(msg, ChatMessage):
87
+ messages.append(msg)
88
+ yield messages, gr.update(visible = True)
89
+ else:
90
+ yield messages, gr.update(value = f"<center><h1>{msg}</h1></center>", visible = True)
91
+ yield messages, gr.update(value = "<center><h1>Idle</h1></center>", visible = False)
92
 
93
 
94
  def persist(component):
 
105
  print(f"Updating persisted session state for {session_hash}")
106
  sessions[session_hash] = value
107
  sessions[session_hash + "_logs"] = agent.logs
108
+ if SESSION_PERSISTENCE_ENABLED:
109
+ pickle.dump(sessions, open(sessions_path, "wb"))
110
 
111
  Context.root_block.load(resume_session, inputs=[component], outputs=component)
112
+ component.change(update_session, inputs=[component], outputs=None)
113
 
114
  return component
115
 
116
 
117
+ with gr.Blocks(fill_height=True, css=".gradio-container .message .content {text-align: left;}" + HtmlFormatter().get_style_defs('.highlight')) as demo:
118
+ state = gr.State()
119
+ inner_monologue_component = gr.Markdown("""<h2>Inner Monologue</h2>""", visible = False)
120
  chatbot = persist(
121
  gr.Chatbot(
122
  value=[],
 
132
  show_copy_button=True,
133
  placeholder="""<h1>SQuAD Agent</h1>
134
  <h2>I am your friendly guide to the Stanford Question and Answer Dataset (SQuAD).</h2>
135
+ <h2>You can ask me questions about the dataset, or you can ask me to generate images based on your prompts.</h2>
136
  """,
137
  examples=[
138
  {
 
149
  )
150
  text_input = gr.Textbox(lines=1, label="Chat Message", scale=0)
151
  chat_msg = text_input.submit(add_message, [text_input, chatbot], [chatbot])
152
+ bot_msg = chat_msg.then(interact_with_agent, [chatbot], [chatbot, inner_monologue_component])
153
  text_input.submit(lambda: "", None, text_input)
154
  chatbot.example_select(append_example_message, [chatbot], [chatbot]).then(
155
+ interact_with_agent, [chatbot], [chatbot, inner_monologue_component]
156
  )
157
 
158
  if __name__ == "__main__":
prompts/__init__.py CHANGED
@@ -22,4 +22,5 @@ def load_constants(constants_dir):
22
  PROMPTS = load_constants("prompts")
23
 
24
  # Import all prompts locally as well, for code completion
 
25
  from prompts.default import DEFAULT_SQUAD_REACT_CODE_SYSTEM_PROMPT
 
22
  PROMPTS = load_constants("prompts")
23
 
24
  # Import all prompts locally as well, for code completion
25
+ from transformers.agents.prompts import DEFAULT_REACT_CODE_SYSTEM_PROMPT
26
  from prompts.default import DEFAULT_SQUAD_REACT_CODE_SYSTEM_PROMPT
utils.py CHANGED
@@ -1,53 +1,65 @@
1
  from __future__ import annotations
2
 
 
3
  from gradio import ChatMessage
4
  from transformers.agents import ReactCodeAgent, agent_types
5
  from typing import Generator
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def pull_message(step_log: dict):
8
  if step_log.get("rationale"):
9
- yield ChatMessage(
10
- role="assistant",
11
- metadata={"title": "🧠 Rationale"},
12
- content=step_log["rationale"]
13
- )
14
  if step_log.get("tool_call"):
15
  used_code = step_log["tool_call"]["tool_name"] == "code interpreter"
16
  content = step_log["tool_call"]["tool_arguments"]
17
- if used_code:
18
- content = f"```py\n{content}\n```"
19
- yield ChatMessage(
20
- role="assistant",
21
- metadata={"title": f"πŸ› οΈ Used tool {step_log['tool_call']['tool_name']}"},
22
- content=content,
23
- )
24
  if step_log.get("observation"):
25
- yield ChatMessage(
26
- role="assistant",
27
- metadata={"title": "πŸ‘€ Observation"},
28
- content=f"```\n{step_log['observation']}\n```"
29
- )
30
  if step_log.get("error"):
31
- yield ChatMessage(
32
- role="assistant",
33
- metadata={"title": "πŸ’₯ Error"},
34
- content=str(step_log["error"]),
35
- )
36
 
37
  def stream_from_transformers_agent(
38
- agent: ReactCodeAgent, prompt: str,
39
  ) -> Generator[ChatMessage, None, ChatMessage | None]:
40
  """Runs an agent with the given prompt and streams the messages from the agent as ChatMessages."""
41
 
42
  class Output:
43
  output: agent_types.AgentType | str = None
44
 
 
 
 
 
 
 
45
  step_log = None
46
  for step_log in agent.run(prompt, stream=True, reset=len(agent.logs) == 0): # Reset=False misbehaves if the agent has not yet been run
47
  if isinstance(step_log, dict):
48
- for message in pull_message(step_log):
49
- print("message", message)
50
- yield message
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  Output.output = step_log
53
  if isinstance(Output.output, agent_types.AgentText):
 
1
  from __future__ import annotations
2
 
3
+ import gradio as gr
4
  from gradio import ChatMessage
5
  from transformers.agents import ReactCodeAgent, agent_types
6
  from typing import Generator
7
+ from termcolor import colored
8
+ from pygments import highlight
9
+ from pygments.lexers import PythonLexer
10
+ from pygments.formatters import HtmlFormatter
11
+ from pygments.formatters import TerminalFormatter
12
+
13
+ def highlight_code_terminal(text):
14
+ return highlight(text, PythonLexer(), TerminalFormatter())
15
+
16
+ def highlight_code_html(code):
17
+ return highlight(code, PythonLexer(), HtmlFormatter())
18
+
19
 
20
  def pull_message(step_log: dict):
21
  if step_log.get("rationale"):
22
+ yield "🧠 Thinking...", f"{step_log["rationale"]}"
 
 
 
 
23
  if step_log.get("tool_call"):
24
  used_code = step_log["tool_call"]["tool_name"] == "code interpreter"
25
  content = step_log["tool_call"]["tool_arguments"]
26
+ yield f"πŸ› οΈ Using tool {step_log['tool_call']['tool_name']}...", content
 
 
 
 
 
 
27
  if step_log.get("observation"):
28
+ yield "πŸ‘€ Observing...", step_log['observation']
 
 
 
 
29
  if step_log.get("error"):
30
+ yield "πŸ’₯ Coping with an Error...", step_log['error'].message
 
 
 
 
31
 
32
  def stream_from_transformers_agent(
33
+ agent: ReactCodeAgent, prompt: str
34
  ) -> Generator[ChatMessage, None, ChatMessage | None]:
35
  """Runs an agent with the given prompt and streams the messages from the agent as ChatMessages."""
36
 
37
  class Output:
38
  output: agent_types.AgentType | str = None
39
 
40
+ inner_monologue = ChatMessage(
41
+ role="assistant",
42
+ metadata={"title": "🧠 Thinking..."},
43
+ content=""
44
+ )
45
+
46
  step_log = None
47
  for step_log in agent.run(prompt, stream=True, reset=len(agent.logs) == 0): # Reset=False misbehaves if the agent has not yet been run
48
  if isinstance(step_log, dict):
49
+ for title, message in pull_message(step_log):
50
+ terminal_message = message
51
+ if ("Using tool" in title) or ("Error" in title):
52
+ terminal_message = highlight_code_terminal(message)
53
+ message = highlight_code_html(message)
54
+ if "Observing" in title:
55
+ message = f"<pre>{message}</pre>"
56
+ print(colored("=== Inner Monologue Message:\n", "blue", attrs=["bold"]), f"{title}\n{terminal_message}")
57
+ inner_monologue.content += f"<h2>{title}</h2><p>{message}</p>"
58
+ yield title
59
+
60
+ if inner_monologue is not None:
61
+ inner_monologue.metadata = {"title": "Inner Monologue (click to expand)"}
62
+ yield inner_monologue
63
 
64
  Output.output = step_log
65
  if isinstance(Output.output, agent_types.AgentText):