vonliechti commited on
Commit
7b4ec76
·
verified ·
1 Parent(s): a3a2da3

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +42 -12
.gitignore CHANGED
@@ -4,6 +4,7 @@
4
  # Data
5
  chroma_db/
6
  data/
 
7
 
8
  # Byte-compiled / optimized / DLL files
9
  __pycache__/
 
4
  # Data
5
  chroma_db/
6
  data/
7
+ sessions.pkl
8
 
9
  # Byte-compiled / optimized / DLL files
10
  __pycache__/
app.py CHANGED
@@ -7,6 +7,8 @@ from tools.squad_tools import SquadRetrieverTool, SquadQueryTool
7
  from tools.text_to_image import TextToImageTool
8
  from gradio.context import Context
9
  from gradio import Request
 
 
10
  from dotenv import load_dotenv
11
 
12
  load_dotenv()
@@ -17,6 +19,11 @@ TASK_SOLVING_TOOLBOX = [
17
  TextToImageTool(),
18
  ]
19
 
 
 
 
 
 
20
  model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
21
  # model_name = "http://localhost:1234/v1"
22
 
@@ -46,11 +53,10 @@ def add_message(message, messages):
46
  messages.append(ChatMessage(role="user", content=message))
47
  return messages
48
 
49
- sessions = {}
50
-
51
  def interact_with_agent(messages, request: Request):
 
52
  prompt = messages[-1]['content']
53
- agent.logs = sessions.get(request.session_hash + "_logs", [])
54
  for msg in stream_from_transformers_agent(agent, prompt):
55
  messages.append(msg)
56
  yield messages
@@ -59,15 +65,18 @@ def interact_with_agent(messages, request: Request):
59
  def persist(component):
60
 
61
  def resume_session(value, request: Request):
62
- print(f"Resuming session for {request.session_hash}")
63
- state = sessions.get(request.session_hash, value)
64
- agent.logs = sessions.get(request.session_hash + "_logs", [])
 
65
  return state
66
 
67
  def update_session(value, request: Request):
68
- print(f"Updating persisted session state for {request.session_hash}")
69
- sessions[request.session_hash] = value
70
- sessions[request.session_hash + "_logs"] = agent.logs
 
 
71
  return
72
 
73
  Context.root_block.load(resume_session, inputs=[component], outputs=component)
@@ -75,7 +84,16 @@ def persist(component):
75
 
76
  return component
77
 
 
 
 
78
  with gr.Blocks(fill_height=True) as demo:
 
 
 
 
 
 
79
  chatbot = persist(gr.Chatbot(
80
  value=[],
81
  label="SQuAD Agent",
@@ -85,11 +103,12 @@ with gr.Blocks(fill_height=True) as demo:
85
  "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
86
  ),
87
  scale=1,
88
- # bubble_full_width=False,
89
  autoscroll=True,
90
  show_copy_all_button=True,
91
  show_copy_button=True,
92
- placeholder="Enter a message",
 
 
93
  examples=[
94
  {
95
  "text": "What is on top of the Notre Dame building?",
@@ -110,5 +129,16 @@ with gr.Blocks(fill_height=True) as demo:
110
  interact_with_agent, [chatbot], [chatbot]
111
  )
112
 
 
 
 
113
  if __name__ == "__main__":
114
- demo.launch()
 
 
 
 
 
 
 
 
 
7
  from tools.text_to_image import TextToImageTool
8
  from gradio.context import Context
9
  from gradio import Request
10
+ import pickle
11
+ import os
12
  from dotenv import load_dotenv
13
 
14
  load_dotenv()
 
19
  TextToImageTool(),
20
  ]
21
 
22
+ sessions_path = "sessions.pkl"
23
+ sessions = pickle.load(open(sessions_path, "rb")) if os.path.exists(sessions_path) else {}
24
+
25
+ app = None
26
+
27
  model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
28
  # model_name = "http://localhost:1234/v1"
29
 
 
53
  messages.append(ChatMessage(role="user", content=message))
54
  return messages
55
 
 
 
56
  def interact_with_agent(messages, request: Request):
57
+ username = request.username
58
  prompt = messages[-1]['content']
59
+ agent.logs = sessions.get(username + "_logs", [])
60
  for msg in stream_from_transformers_agent(agent, prompt):
61
  messages.append(msg)
62
  yield messages
 
65
  def persist(component):
66
 
67
  def resume_session(value, request: Request):
68
+ username = request.username
69
+ print(f"Resuming session for {username}")
70
+ state = sessions.get(username, value)
71
+ agent.logs = sessions.get(username + "_logs", [])
72
  return state
73
 
74
  def update_session(value, request: Request):
75
+ username = request.username
76
+ print(f"Updating persisted session state for {username}")
77
+ sessions[username] = value
78
+ sessions[username + "_logs"] = agent.logs
79
+ pickle.dump(sessions, open(sessions_path, "wb"))
80
  return
81
 
82
  Context.root_block.load(resume_session, inputs=[component], outputs=component)
 
84
 
85
  return component
86
 
87
+ def welcome_message(request: Request):
88
+ return f"<h2>Welcome, {request.username}</h2>"
89
+
90
  with gr.Blocks(fill_height=True) as demo:
91
+ # put the welcome message and logout button in a row
92
+ with gr.Row() as row:
93
+ welcome_msg = gr.Markdown(f"Welcome")
94
+ logout_button = gr.Button("Logout", link="/logout")
95
+ demo.load(welcome_message, None, welcome_msg)
96
+
97
  chatbot = persist(gr.Chatbot(
98
  value=[],
99
  label="SQuAD Agent",
 
103
  "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
104
  ),
105
  scale=1,
 
106
  autoscroll=True,
107
  show_copy_all_button=True,
108
  show_copy_button=True,
109
+ placeholder="""<h1>SQuAD Agent</h1>
110
+ <h2>I am your friendly guide to the Stanford Question and Answer Dataset (SQuAD).</h2>
111
+ """,
112
  examples=[
113
  {
114
  "text": "What is on top of the Notre Dame building?",
 
129
  interact_with_agent, [chatbot], [chatbot]
130
  )
131
 
132
+ def honor_system_auth(username, password):
133
+ return password == "happy"
134
+
135
  if __name__ == "__main__":
136
+ demo.launch(
137
+ auth=honor_system_auth,
138
+ auth_message="""<h3>Honor System Authentication:</h3>
139
+ <p>Log in with <strong>any username</strong> and the password <strong>"happy"</strong>.</p>
140
+ <br/><br/>
141
+ <i>Note:Your chat history will be saved to your username, but take care because others
142
+ may log in with the same username and see your chat history.</i>
143
+ """,
144
+ )