vonliechti commited on
Commit
8e43a1d
·
verified ·
1 Parent(s): 77dce35

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +29 -4
app.py CHANGED
@@ -5,6 +5,8 @@ from utils import stream_from_transformers_agent
5
  from prompts import SQUAD_REACT_CODE_SYSTEM_PROMPT
6
  from tools.squad_tools import SquadRetrieverTool, SquadQueryTool
7
  from tools.text_to_image import TextToImageTool
 
 
8
  from dotenv import load_dotenv
9
 
10
  load_dotenv()
@@ -44,15 +46,38 @@ def add_message(message, messages):
44
  messages.append(ChatMessage(role="user", content=message))
45
  return messages
46
 
47
- def interact_with_agent(messages):
 
 
48
  prompt = messages[-1]['content']
 
49
  for msg in stream_from_transformers_agent(agent, prompt):
50
  messages.append(msg)
51
  yield messages
52
  yield messages
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  with gr.Blocks(fill_height=True) as demo:
55
- chatbot = gr.Chatbot(
 
56
  label="SQuAD Agent",
57
  type="messages",
58
  avatar_images=(
@@ -60,7 +85,7 @@ with gr.Blocks(fill_height=True) as demo:
60
  "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
61
  ),
62
  scale=1,
63
- bubble_full_width=False,
64
  autoscroll=True,
65
  show_copy_all_button=True,
66
  show_copy_button=True,
@@ -76,7 +101,7 @@ with gr.Blocks(fill_height=True) as demo:
76
  "text": "Draw a picture of whatever is on top of the Notre Dame building.",
77
  },
78
  ],
79
- )
80
  text_input = gr.Textbox(lines=1, label="Chat Message", scale=0)
81
  chat_msg = text_input.submit(add_message, [text_input, chatbot], [chatbot])
82
  bot_msg = chat_msg.then(interact_with_agent, [chatbot], [chatbot])
 
5
  from prompts import SQUAD_REACT_CODE_SYSTEM_PROMPT
6
  from tools.squad_tools import SquadRetrieverTool, SquadQueryTool
7
  from tools.text_to_image import TextToImageTool
8
+ from gradio.context import Context
9
+ from gradio import Request
10
  from dotenv import load_dotenv
11
 
12
  load_dotenv()
 
46
  messages.append(ChatMessage(role="user", content=message))
47
  return messages
48
 
49
+ sessions = {}
50
+
51
+ def interact_with_agent(messages, request: Request):
52
  prompt = messages[-1]['content']
53
+ agent.logs = sessions.get(request.session_hash + "_logs", [])
54
  for msg in stream_from_transformers_agent(agent, prompt):
55
  messages.append(msg)
56
  yield messages
57
  yield messages
58
 
59
+ def persist(component):
60
+
61
+ def resume_session(value, request: Request):
62
+ print(f"Resuming session for {request.session_hash}")
63
+ state = sessions.get(request.session_hash, value)
64
+ agent.logs = sessions.get(request.session_hash + "_logs", [])
65
+ return state
66
+
67
+ def update_session(value, request: Request):
68
+ print(f"Updating persisted session state for {request.session_hash}")
69
+ sessions[request.session_hash] = value
70
+ sessions[request.session_hash + "_logs"] = agent.logs
71
+ return
72
+
73
+ Context.root_block.load(resume_session, inputs=[component], outputs=component)
74
+ component.change(update_session, inputs=[component], outputs=[])
75
+
76
+ return component
77
+
78
  with gr.Blocks(fill_height=True) as demo:
79
+ chatbot = persist(gr.Chatbot(
80
+ value=[],
81
  label="SQuAD Agent",
82
  type="messages",
83
  avatar_images=(
 
85
  "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
86
  ),
87
  scale=1,
88
+ # bubble_full_width=False,
89
  autoscroll=True,
90
  show_copy_all_button=True,
91
  show_copy_button=True,
 
101
  "text": "Draw a picture of whatever is on top of the Notre Dame building.",
102
  },
103
  ],
104
+ ))
105
  text_input = gr.Textbox(lines=1, label="Chat Message", scale=0)
106
  chat_msg = text_input.submit(add_message, [text_input, chatbot], [chatbot])
107
  bot_msg = chat_msg.then(interact_with_agent, [chatbot], [chatbot])