danf committed
Commit 3af7433 · verified · 1 Parent(s): c6afa77
Files changed (1)
  1. src/streamlit_app.py +323 -33
src/streamlit_app.py CHANGED
@@ -1,40 +1,330 @@
- import altair as alt
- import numpy as np
- import pandas as pd
  import streamlit as st
  """
- # Welcome to Streamlit!

- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).

- In the meantime, below is an example of what you can do with just a few lines of code:
  """

- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))
 
+ """
+ Chat demo for local LLMs using Streamlit.
+
+
+ Run with:
+ ```
+ streamlit run src/streamlit_app.py --server.address 0.0.0.0
+ ```
+ """
+
+ import logging
+ import os
+
+ import openai
+ import regex
  import streamlit as st

+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ def convert_latex_brackets_to_dollars(text):
+     """Convert LaTeX bracket notation to dollar notation for Streamlit."""
+
+     def replace_display_latex(match):
+         return f"\n<bdi> $$ {match.group(1).strip()} $$ </bdi>\n"
+
+     text = regex.sub(r"(?r)\\\[\s*([^\[\]]+?)\s*\\\]", replace_display_latex, text)
+
+     def replace_paren_latex(match):
+         return f" <bdi> $ {match.group(1).strip()} $ </bdi> "
+
+     text = regex.sub(r"(?r)\\\(\s*(.+?)\s*\\\)", replace_paren_latex, text)
+
+     return text
+
+
+ # Add RTL CSS styling for Hebrew support
+ st.markdown(
+     """
+     <style>
+     /* RTL support for specific text elements - avoid global .stMarkdown RTL */
+     .stText, .stTextArea textarea, .stTextArea label, .stSelectbox select, .stSelectbox label, .stSelectbox div {
+         direction: rtl;
+         text-align: right;
+     }
+
+     /* Chat messages styling for RTL */
+     .stChatMessage {
+         direction: rtl;
+         text-align: right;
+     }
+
+     /* Title alignment - more specific selectors */
+     h1, .stTitle, [data-testid="stHeader"] h1 {
+         direction: rtl !important;
+         text-align: right !important;
+     }
+
+     /* Apply RTL only to text content, not math */
+     .stMarkdown p:not(:has(.MathJax)):not(:has(mjx-container)):not(:has(.katex)) {
+         direction: rtl;
+         text-align: right;
+         unicode-bidi: plaintext;
+     }
+
+     /* Code blocks should remain LTR */
+     .stMarkdown code, .stMarkdown pre {
+         direction: ltr !important;
+         text-align: left !important;
+         display: inline-block;
+     }
+
+     /* Details/summary styling for RTL */
+     details {
+         direction: rtl;
+         text-align: right;
+     }
+
+     /* Button alignment */
+     .stButton button {
+         direction: rtl;
+     }
+
+     /* Ensure LaTeX/Math rendering works normally - comprehensive selectors */
+     .MathJax,
+     .MathJax_Display,
+     mjx-container,
+     .katex,
+     .katex-display,
+     [data-testid="stMarkdownContainer"] .MathJax,
+     [data-testid="stMarkdownContainer"] .MathJax_Display,
+     [data-testid="stMarkdownContainer"] mjx-container,
+     [data-testid="stMarkdownContainer"] .katex,
+     [data-testid="stMarkdownContainer"] .katex-display,
+     .stMarkdown .MathJax,
+     .stMarkdown .MathJax_Display,
+     .stMarkdown mjx-container,
+     .stMarkdown .katex,
+     .stMarkdown .katex-display {
+         direction: ltr !important;
+         text-align: center !important;
+         unicode-bidi: normal !important;
+     }
+
+     /* Inline math should be LTR but inline */
+     mjx-container[display="false"],
+     .katex:not(.katex-display),
+     .MathJax:not(.MathJax_Display) {
+         direction: ltr !important;
+         text-align: left !important;
+         display: inline !important;
+         unicode-bidi: normal !important;
+     }
+
+     /* Block/display math should be centered */
+     mjx-container[display="true"],
+     .katex-display,
+     .MathJax_Display {
+         direction: ltr !important;
+         text-align: center !important;
+         display: block !important;
+         margin: 1em auto !important;
+         unicode-bidi: normal !important;
+     }
+
+     /* For custom RTL wrappers */
+     .rtl-text {
+         direction: rtl;
+         text-align: right;
+         unicode-bidi: plaintext;
+     }
+     </style>
+     """,
+     unsafe_allow_html=True,
+ )
+
+
+ @st.cache_resource
+ def openai_configured():
+     return {
+         "model": os.getenv("MY_MODEL", "Intel/hebrew-math-tutor-v1"),
+         "api_base": os.getenv("AWS_URL", "http://localhost:8111/v1"),
+         "api_key": os.getenv("MY_KEY"),
+     }
+
+
+ config = openai_configured()
+
+
+ @st.cache_resource
+ def get_client():
+     return openai.OpenAI(api_key=config["api_key"], base_url=config["api_base"])
+
+
+ client = get_client()
+
+ st.title("מתמטיבוט 🧮")
+
+ st.markdown("""
+
+ ברוכים הבאים לדמו! 💡 כאן תוכלו להתרשם **ממודל השפה החדש** שלנו; מודל בגודל 4 מיליארד פרמטרים שאומן לענות על שאלות מתמטיות בעברית, על המחשב שלכם, ללא חיבור לרשת.
+
+ קישור למודל, פרטים נוספים, יצירת קשר ותנאי שימוש:
+
+ https://huggingface.co/Intel/hebrew-math-tutor-v1
+
+ -----
+ """)
+
+ if "chat_history" not in st.session_state:
+     st.session_state.chat_history = []
+
+ # Predefined options
+ predefined_options = [
+     "שאלה חדשה...",
+     " מהו סכום הסדרה הבאה: 1 + 1/2 + 1/4 + 1/8 + ...",
+     "פתח את הביטוי: (a-b)^4",
+     "פתרו את המשוואה הבאה: sin(2x) = 0.5",
+ ]
+
+ # Dropdown for predefined options
+ selected_option = st.selectbox("בחרו שאלה מוכנה או צרו שאלה חדשה:", predefined_options)
+
+ # Text area for input
+ if selected_option == "שאלה חדשה...":
+     user_input = st.text_area(
+         "שאלה:", height=100, key="user_input", placeholder="הזינו את השאלה כאן..."
+     )
+ else:
+     user_input = st.text_area("שאלה:", height=100, key="user_input", value=selected_option)
+
+ # Add reset button next to Send button
+ col1, col2 = st.columns([8, 4])
+ with col2:
+     send_clicked = st.button("שלח", type="primary", use_container_width=True) and user_input.strip()
+ with col1:
+     if st.button("נקה שיחה", type="secondary", use_container_width=True):
+         st.session_state.chat_history = []
+         st.rerun()
+
+ if send_clicked:
+     st.session_state.chat_history.append(("user", user_input))
+
+     # Create a placeholder for streaming output
+     with st.chat_message("assistant"):
+         message_placeholder = st.empty()
+         full_response = ""
+
+         # System prompt - not visible in the UI but guides the model
+         system_prompt = """\
+ You are a helpful AI assistant specialized in mathematics and problem-solving who can answer math questions with the correct answer.
+ Answer shortly, not more than 500 tokens, but outline the process step by step.
+ Answer ONLY in Hebrew!
+ """
+
+         # Create messages in proper chat format
+         messages = [
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": user_input},
+         ]
+
+         # Pass the chat messages directly to the OpenAI-compatible chat API
+         # Keep the special thinking tokens (<think>...</think>) if the remote model supports them
+         prompt_messages = messages
+
+         # Stream from the OpenAI-compatible API (remote vLLM exposing an openai-compatible endpoint)
+         # Use the chat completions streaming interface
+         in_thinking = True
+         thinking_content = "<think>"
+         final_answer = ""
+
+         try:
+             # chat.completions.create with stream=True yields incremental chunks with 'choices'
+             stream = client.chat.completions.create(
+                 messages=prompt_messages,
+                 model=config["model"],
+                 temperature=0.6,
+                 max_tokens=2000,
+                 top_p=0.95,
+                 stream=True,
+                 extra_body={"top_k": 20},
+             )
+
+             for chunk in stream:
+                 # The text delta lives at chunk.choices[0].delta.content in the current client API
+                 delta = ""
+                 try:
+                     # compatible with OpenAI response structure
+                     delta = chunk.choices[0].delta.content
+                 except Exception:
+                     # fallback for dict-shaped chunks from other server implementations
+                     delta = chunk.get("text", "") if isinstance(chunk, dict) else ""
+
+                 if not delta:
+                     continue
+
+                 full_response += delta
+
+                 # Handle thinking markers
+                 if "<think>" in delta:
+                     in_thinking = True
+
+                 if in_thinking:
+                     thinking_content += delta
+                     if "</think>" in delta:
+                         in_thinking = False
+                         thinking_text = (
+                             thinking_content.replace("<think>", "").replace("</think>", "").strip()
+                         )
+                         display_content = f"""
+ <details dir="rtl" style="text-align: right;">
+ <summary>🤔 <em>לחץ כדי לראות את תהליך החשיבה</em></summary>
+ <div style="white-space: pre-wrap; margin: 10px 0; direction: rtl; text-align: right;">
+ {thinking_text}
+ </div>
+ </details>
+
+ """
+                         message_placeholder.markdown(display_content + "▌", unsafe_allow_html=True)
+                     else:
+                         dots = "." * ((len(thinking_content) // 10) % 6)
+                         thinking_indicator = f"""
+ <div dir="rtl" style="padding: 10px; background-color: #f0f2f6; border-radius: 10px; border-right: 4px solid #1f77b4; text-align: right;">
+ <p style="margin: 0; color: #1f77b4; font-style: italic;">
+ 🤔 חושב{dots}
+ </p>
+ </div>
  """
+                         message_placeholder.markdown(thinking_indicator, unsafe_allow_html=True)
+                 else:
+                     # Final answer streaming
+                     final_answer += delta
+                     converted_answer = convert_latex_brackets_to_dollars(final_answer)
+                     message_placeholder.markdown(
+                         "🤔 *תהליך החשיבה הושלם, מכין תשובה...*\n\n**📝 תשובה סופית:**\n\n"
+                         + converted_answer
+                         + "▌",
+                         unsafe_allow_html=True,
+                     )
+         except Exception as e:
+             # Show an error to the user
+             message_placeholder.markdown(f"**Error contacting remote model:** {e}")

+         # Final rendering: if there was thinking content include it
+         if thinking_content and "</think>" in thinking_content:
+             thinking_text = thinking_content.replace("<think>", "").replace("</think>", "").strip()
+             message_placeholder.empty()
+             with message_placeholder.container():
+                 thinking_html = f"""
+ <details dir="rtl" style="text-align: right;">
+ <summary>🤔 <em>לחץ כדי לראות את תהליך החשיבה</em></summary>
+ <div style="white-space: pre-wrap; margin: 10px 0; direction: rtl; text-align: right;">
+ {thinking_text}
+ </div>
+ </details>

  """
+                 st.markdown(thinking_html, unsafe_allow_html=True)
+                 st.markdown(
+                     '<div dir="rtl" style="text-align: right; margin: 10px 0;"><strong>📝 תשובה סופית:</strong></div>',
+                     unsafe_allow_html=True,
+                 )
+                 converted_answer = convert_latex_brackets_to_dollars(final_answer or full_response)
+                 st.markdown(converted_answer, unsafe_allow_html=True)
+         else:
+             converted_response = convert_latex_brackets_to_dollars(final_answer or full_response)
+             message_placeholder.markdown(converted_response, unsafe_allow_html=True)

+     st.session_state.chat_history.append(("assistant", final_answer or full_response))
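
To sanity-check the LaTeX delimiter handling outside Streamlit, here is a minimal standalone sketch of the same `convert_latex_brackets_to_dollars` logic. It assumes the third-party `regex` package (the `(?r)` right-to-left flag is not available in the stdlib `re` module); the example input string is illustrative only.

```python
# Standalone sketch of the bracket-to-dollar conversion used in the app.
# Requires the third-party `regex` package, not the stdlib `re` module,
# because of the (?r) reverse-matching flag.
import regex


def convert_latex_brackets_to_dollars(text):
    # Display math: \[ ... \]  ->  $$ ... $$, wrapped in <bdi> for RTL pages
    text = regex.sub(
        r"(?r)\\\[\s*([^\[\]]+?)\s*\\\]",
        lambda m: f"\n<bdi> $$ {m.group(1).strip()} $$ </bdi>\n",
        text,
    )
    # Inline math: \( ... \)  ->  $ ... $
    text = regex.sub(
        r"(?r)\\\(\s*(.+?)\s*\\\)",
        lambda m: f" <bdi> $ {m.group(1).strip()} $ </bdi> ",
        text,
    )
    return text


print(convert_latex_brackets_to_dollars(r"The root is \(x = \frac{1}{2}\)."))
# -> The root is  <bdi> $ x = \frac{1}{2} $ </bdi> .
```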
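The app reads its endpoint configuration from the `MY_MODEL`, `AWS_URL`, and `MY_KEY` environment variables. A quick way to confirm that the OpenAI-compatible backend (for example a local vLLM server) is reachable before launching Streamlit is a plain script like the sketch below; the `"EMPTY"` key fallback and the test prompt are assumptions, not part of the commit.

```python
# Smoke-test sketch for the OpenAI-compatible endpoint the app expects.
# Uses the same environment variables and defaults as the app; "EMPTY" is
# a placeholder key commonly accepted by local servers (assumption).
import os

import openai

client = openai.OpenAI(
    api_key=os.getenv("MY_KEY", "EMPTY"),
    base_url=os.getenv("AWS_URL", "http://localhost:8111/v1"),
)

response = client.chat.completions.create(
    model=os.getenv("MY_MODEL", "Intel/hebrew-math-tutor-v1"),
    messages=[{"role": "user", "content": "1 + 1 = ?"}],
    max_tokens=32,
)
print(response.choices[0].message.content)
```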
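The streaming loop splits the model output on `<think>`/`</think>` markers so the reasoning can be tucked behind a collapsible `<details>` block. Below is the same state machine reduced to plain Python with hard-coded deltas standing in for the stream; whether the served model actually emits these tags depends on the backend and is an assumption here.

```python
# Sketch of the <think>...</think> splitting done in the streaming loop,
# with fake deltas so it can be run without Streamlit or a model server.
chunks = ["<think>", "2 + 2", " = 4", "</think>", "The answer", " is 4."]

in_thinking = True
thinking = ""
answer = ""

for delta in chunks:
    if "<think>" in delta:
        in_thinking = True
    if in_thinking:
        thinking += delta
        if "</think>" in delta:
            in_thinking = False
    else:
        answer += delta

print("thinking:", thinking.replace("<think>", "").replace("</think>", "").strip())
print("answer:", answer.strip())
# thinking: 2 + 2 = 4
# answer: The answer is 4.
```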