yrobel-lima commited on
Commit
1dc1222
·
verified ·
1 Parent(s): 5c4df90

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -173
app.py CHANGED
@@ -1,173 +1,186 @@
1
- import openai
2
- import streamlit as st
3
- from langchain_core.messages import AIMessage, ChatMessage, HumanMessage
4
- from langchain_core.tracers.context import collect_runs
5
- from langsmith import Client
6
- from streamlit_feedback import streamlit_feedback
7
-
8
- from rag_chain.chain import get_rag_chain
9
-
10
- # Langsmith client for the feedback system
11
- client = Client()
12
-
13
- # Streamlit page configuration
14
- st.set_page_config(page_title="Tall Tree Health",
15
- page_icon="πŸ’¬",
16
- layout="centered",
17
- initial_sidebar_state="expanded")
18
-
19
- # Streamlit CSS configuration
20
-
21
- with open("styles/styles.css") as css:
22
- st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
23
-
24
- # Error message template
25
- base_error_message = (
26
- "Something went wrong while processing your request. "
27
- "Please refresh the page or try again later.\n\n"
28
- "If the error persists, please contact us at "
29
- "[Tall Tree Health](https://www.talltreehealth.ca/contact-us)."
30
- )
31
-
32
- # Get chain and memory
33
-
34
-
35
- @st.cache_resource(ttl="5d", show_spinner=False)
36
- def get_chain_and_memory():
37
- try:
38
- # gpt-4 points to gpt-4-0613
39
- # gpt-4-turbo-preview points to gpt-4-0125-preview
40
- # Fine-tuned: ft:gpt-3.5-turbo-1106:tall-tree::8mAkOSED
41
- # gpt-4-1106-preview
42
- return get_rag_chain(model_name="gpt-4-turbo", temperature=0.2)
43
-
44
- except Exception as e:
45
- st.warning(base_error_message, icon="πŸ™")
46
- st.stop()
47
-
48
-
49
- chain, memory = get_chain_and_memory()
50
-
51
- # Set up session state and clean memory (important to clean the memory at the end of each session)
52
- if "history" not in st.session_state:
53
- st.session_state["history"] = []
54
- memory.clear()
55
-
56
- if "messages" not in st.session_state:
57
- st.session_state["messages"] = []
58
-
59
- # Select locations element into a container
60
- with st.container(border=False):
61
- # Set the welcome message
62
- st.markdown(
63
- "\n\nHello there! πŸ‘‹ Need help finding the right service or practitioner? Let our AI-powered assistant give you a hand.\n\n"
64
- "To get started, please select your preferred location and share details about your symptoms or needs. "
65
- )
66
- location = st.radio(
67
- "**Our Locations**:",
68
- ["Cordova Bay - Victoria", "James Bay - Victoria",
69
- "Commercial Drive - Vancouver"],
70
- index=None, horizontal=False,
71
- )
72
- st.write("\n")
73
-
74
- # Get user input only if a location is selected
75
- prompt = ""
76
- if location:
77
- user_input = st.chat_input("Enter your message...")
78
- if user_input:
79
- st.session_state["messages"].append(
80
- ChatMessage(role="user", content=user_input))
81
- prompt = f"{user_input}\nLocation: {location}"
82
-
83
-
84
- # Display previous messages
85
-
86
- user_avatar = "images/user.png"
87
- ai_avatar = "images/tall-tree-logo.png"
88
- for msg in st.session_state["messages"]:
89
- avatar = user_avatar if msg.role == 'user' else ai_avatar
90
- with st.chat_message(msg.role, avatar=avatar):
91
- st.markdown(msg.content)
92
-
93
- # Chat interface
94
- if prompt:
95
-
96
- # Add all previous messages to memory
97
- for human, ai in st.session_state["history"]:
98
- memory.chat_memory.add_user_message(HumanMessage(content=human))
99
- memory.chat_memory.add_ai_message(AIMessage(content=ai))
100
-
101
- # render the assistant's response
102
- with st.chat_message("assistant", avatar=ai_avatar):
103
- message_placeholder = st.empty()
104
-
105
- try:
106
- partial_message = ""
107
- # Collect runs for feedback using Langsmith
108
- with st.spinner(" "), collect_runs() as cb:
109
- for chunk in chain.stream({"message": prompt}):
110
- partial_message += chunk
111
- message_placeholder.markdown(partial_message + "|")
112
- st.session_state.run_id = cb.traced_runs[0].id
113
- message_placeholder.markdown(partial_message)
114
- except openai.BadRequestError:
115
- st.warning(base_error_message, icon="πŸ™")
116
- st.stop()
117
- except Exception as e:
118
- st.warning(base_error_message, icon="πŸ™")
119
- st.stop()
120
-
121
- # Add the full response to the history
122
- st.session_state["history"].append((prompt, partial_message))
123
-
124
- # Add AI message to memory after the response is generated
125
- memory.chat_memory.add_ai_message(AIMessage(content=partial_message))
126
-
127
- # Add the full response to the message history
128
- st.session_state["messages"].append(ChatMessage(
129
- role="assistant", content=partial_message))
130
-
131
-
132
- # Feedback system using streamlit feedback and Langsmith
133
-
134
- # Get the feedback option
135
- feedback_option = "thumbs"
136
-
137
- if st.session_state.get("run_id"):
138
- run_id = st.session_state.run_id
139
- feedback = streamlit_feedback(
140
- feedback_type=feedback_option,
141
- optional_text_label="[Optional] Please provide an explanation",
142
- key=f"feedback_{run_id}",
143
- )
144
- score_mappings = {
145
- "thumbs": {"πŸ‘": 1, "πŸ‘Ž": 0},
146
- "faces": {"😀": 1, "🙂": 0.75, "😐": 0.5, "🙁": 0.25, "😞": 0},
147
- }
148
-
149
- # Get the score mapping based on the selected feedback option
150
- scores = score_mappings[feedback_option]
151
-
152
- if feedback:
153
- # Get the score from the selected feedback option's score mapping
154
- score = scores.get(feedback["score"])
155
-
156
- if score is not None:
157
- # Formulate feedback type string incorporating the feedback option
158
- # and score value
159
- feedback_type_str = f"{feedback_option} {feedback['score']}"
160
-
161
- # Record the feedback with the formulated feedback type string
162
- feedback_record = client.create_feedback(
163
- run_id,
164
- feedback_type_str,
165
- score=score,
166
- comment=feedback.get("text"),
167
- )
168
- st.session_state.feedback = {
169
- "feedback_id": str(feedback_record.id),
170
- "score": score,
171
- }
172
- else:
173
- st.warning("Invalid feedback score.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import openai
2
+ import streamlit as st
3
+ from langchain_core.messages import AIMessage, ChatMessage, HumanMessage
4
+ from langchain_core.tracers.context import collect_runs
5
+ from langsmith import Client
6
+ from streamlit_feedback import streamlit_feedback
7
+
8
+ from rag.runnable import get_runnable
9
+ from utils.error_message_template import ERROR_MESSAGE
10
+
11
# Streamlit page configuration
st.set_page_config(
    page_title="ELLA AI Assistant",
    page_icon="💬",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Streamlit CSS configuration
# Inline the stylesheet into the page so the custom styles apply on rerun.
with open("styles/styles.css") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
22
+
23
+
24
# Get runnable and memory
@st.cache_resource(show_spinner=False)
def get_runnable_and_memory():
    """Build the RAG runnable and its chat memory, cached across reruns.

    Returns:
        The (runnable, memory) pair produced by ``get_runnable``.

    On any failure, shows a friendly warning and halts the Streamlit
    script instead of propagating the exception.
    """
    try:
        return get_runnable(model="gpt-4o", temperature=0)
    except Exception:
        # Top-level app boundary: surface the generic error message and stop.
        st.warning(ERROR_MESSAGE, icon="🙁")
        st.stop()


chain, memory = get_runnable_and_memory()
35
+
36
+
37
# Set up session state variables
# Clean memory (important! to clean the memory at the end of each session)
if "history" not in st.session_state:
    # (prompt, response) pairs accumulated this session; replayed into
    # memory before each new request.
    st.session_state["history"] = []
    # First run of a fresh session: the memory object is cached (shared),
    # so clear anything left over from a previous session.
    memory.clear()

if "messages" not in st.session_state:
    # ChatMessage objects used to re-render the visible transcript.
    st.session_state["messages"] = []

if "selected_location" not in st.session_state:
    # Bound to the location radio widget via its `key`.
    st.session_state["selected_location"] = None

if "disable_chat_input" not in st.session_state:
    # Chat input stays disabled until a location is chosen.
    st.session_state["disable_chat_input"] = True
51
+
52
+
53
# Welcome message and Selectbox for location preferences
def welcome_message():
    """Show the assistant's greeting and usage instructions."""
    greeting = (
        "Hello there! 👋 Need help finding the right service or practitioner? Let our AI assistant give you a hand.\n\n"
        "To get started, please select your preferred location and share details about your symptoms or needs. "
    )
    st.markdown(greeting)
59
+
60
+
61
def on_change_location():
    """Streamlit ``on_change`` callback for the location radio widget.

    Enables the chat input once a location is selected and disables it
    again if the selection is cleared.
    """
    # `selected_location` is None (falsy) until the user picks an option;
    # `not` replaces the redundant `False if ... else True` ternary.
    st.session_state["disable_chat_input"] = not st.session_state["selected_location"]
65
+
66
+
67
with st.container():
    welcome_message()
    # Location picker: `key` binds the selection to
    # st.session_state["selected_location"], and the callback toggles the
    # chat input's disabled state on every change.
    location = st.radio(
        "**Our Locations**:",
        (
            "Cordova Bay - Victoria",
            "James Bay - Victoria",
            "Commercial Drive - Vancouver",
        ),
        index=None,  # no default selection
        label_visibility="visible",
        key="selected_location",
        on_change=on_change_location,
    )
    st.markdown("<br>", unsafe_allow_html=True)  # spacing below the radio
82
+
83
# Get user input only if a location is selected
user_input = st.chat_input(
    "Ask ELLA...", disabled=st.session_state["disable_chat_input"]
)

# Default to no prompt; build one only when the user submitted a message.
prompt = None
if user_input:
    st.session_state["messages"].append(ChatMessage(role="user", content=user_input))
    prompt = f"{user_input}\nLocation preference: {st.session_state.selected_location}."
94
+
95
# Avatar images per chat role; `ai_avatar` is reused further below when
# streaming the assistant's reply.
user_avatar = "images/user.png"
ai_avatar = "images/tall-tree-logo.png"

# Replay the stored transcript, one chat bubble per message.
role_avatars = {"user": user_avatar}
for msg in st.session_state["messages"]:
    with st.chat_message(msg.role, avatar=role_avatars.get(msg.role, ai_avatar)):
        st.markdown(msg.content)
102
+
103
# Chat interface
if prompt:
    # Add all previous messages to memory
    # (memory is a cached/shared resource, so it is rebuilt from this
    # session's history before every request).
    for human, ai in st.session_state["history"]:
        memory.chat_memory.add_user_message(HumanMessage(content=human))
        memory.chat_memory.add_ai_message(AIMessage(content=ai))

    # render the assistant's response
    with st.chat_message("assistant", avatar=ai_avatar):
        message_placeholder = st.empty()

        try:
            partial_message = ""
            # Collect runs for feedback using Langsmith.
            with st.spinner(" "), collect_runs() as cb:
                for chunk in chain.stream({"message": prompt}):
                    partial_message += chunk
                    # Trailing "|" acts as a typing cursor while streaming.
                    message_placeholder.markdown(partial_message + "|")
                # NOTE(review): assumes the run of interest is the first
                # traced run — confirm if the chain structure changes.
                st.session_state.run_id = cb.traced_runs[0].id
            message_placeholder.markdown(partial_message)
        except openai.BadRequestError:
            st.warning(ERROR_MESSAGE, icon="🙁")
            st.stop()
        except Exception:
            # Boundary catch: any other failure gets the same friendly stop.
            st.warning(ERROR_MESSAGE, icon="🙁")
            st.stop()

    # Add the full response to the history
    st.session_state["history"].append((prompt, partial_message))

    # Add AI message to memory after the response is generated
    memory.chat_memory.add_ai_message(AIMessage(content=partial_message))

    # Add the full response to the message history
    st.session_state["messages"].append(
        ChatMessage(role="assistant", content=partial_message)
    )
140
+
141
+
142
# Feedback system using streamlit-feedback and Langsmith

# Langsmith client for the feedback system
ls_client = Client()

# Feedback option ("thumbs" or "faces" — selects a score mapping below)
feedback_option = "thumbs"

if st.session_state.get("run_id"):
    # Offer feedback for the most recently traced run only.
    run_id = st.session_state.run_id
    feedback = streamlit_feedback(
        feedback_type=feedback_option,
        optional_text_label="[Optional] Please provide an explanation",
        # Keying by run_id resets the widget for each new response.
        key=f"feedback_{run_id}",
    )
    # Map the widget's emoji feedback to numeric scores for Langsmith.
    score_mappings = {
        "thumbs": {"👍": 1, "👎": 0},
        "faces": {"😀": 1, "🙂": 0.75, "😐": 0.5, "🙁": 0.25, "😞": 0},
    }

    # Get the score mapping based on the selected feedback option
    scores = score_mappings[feedback_option]

    if feedback:
        # Get the score from the selected feedback option's score mapping
        score = scores.get(feedback["score"])

        if score is not None:
            # Formulate feedback type string incorporating the feedback option
            # and score value
            feedback_type_str = f"{feedback_option} {feedback['score']}"

            # Record the feedback with the formulated feedback type string
            feedback_record = ls_client.create_feedback(
                run_id,
                feedback_type_str,
                score=score,
                comment=feedback.get("text"),
            )
            # Keep a handle to the created feedback record in session state.
            st.session_state.feedback = {
                "feedback_id": str(feedback_record.id),
                "score": score,
            }
        else:
            st.warning("Invalid feedback score.")