"""
Chat demo for local LLMs using Streamlit.
Run with:
```
streamlit run src/streamlit_app.py --server.address 0.0.0.0
```
"""
import logging
import os
import openai
import regex
import streamlit as st
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def convert_latex_brackets_to_dollars(text):
"""Convert LaTeX bracket notation to dollar notation for Streamlit."""
def replace_display_latex(match):
return f"\n<bdi> $$ {match.group(1).strip()} $$ </bdi>\n"
text = regex.sub(r"(?r)\\\[\s*([^\[\]]+?)\s*\\\]", replace_display_latex, text)
def replace_paren_latex(match):
return f" <bdi> $ {match.group(1).strip()} $ </bdi> "
text = regex.sub(r"(?r)\\\(\s*(.+?)\s*\\\)", replace_paren_latex, text)
return text
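# Illustrative example (comment only): r"\[ x^2 + 1 \]" becomes
# "\n<bdi> $$ x^2 + 1 $$ </bdi>\n" and r"\( \pi r^2 \)" becomes
# " <bdi> $ \pi r^2 $ </bdi> ", so Streamlit renders the formulas as math while
# the <bdi> wrapper keeps each LTR expression intact inside surrounding RTL text.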
# Add RTL CSS styling for Hebrew support
st.markdown(
"""
<style>
/* RTL support for specific text elements - avoid global .stMarkdown RTL */
.stText, .stTextArea textarea, .stTextArea label, .stSelectbox select, .stSelectbox label, .stSelectbox div {
direction: rtl;
text-align: right;
}
/* Chat messages styling for RTL */
.stChatMessage {
direction: rtl;
text-align: right;
}
/* Title alignment - more specific selectors */
h1, .stTitle, [data-testid="stHeader"] h1 {
direction: rtl !important;
text-align: right !important;
}
/* Apply RTL only to text content, not math */
.stMarkdown p:not(:has(.MathJax)):not(:has(mjx-container)):not(:has(.katex)) {
direction: rtl;
text-align: right;
unicode-bidi: plaintext;
}
/* Code blocks should remain LTR */
.stMarkdown code, .stMarkdown pre {
direction: ltr !important;
text-align: left !important;
display: inline-block;
}
/* Details/summary styling for RTL */
details {
direction: rtl;
text-align: right;
}
/* Button alignment */
.stButton button {
direction: rtl;
}
/* Ensure LaTeX/Math rendering works normally - comprehensive selectors */
.MathJax,
.MathJax_Display,
mjx-container,
.katex,
.katex-display,
[data-testid="stMarkdownContainer"] .MathJax,
[data-testid="stMarkdownContainer"] .MathJax_Display,
[data-testid="stMarkdownContainer"] mjx-container,
[data-testid="stMarkdownContainer"] .katex,
[data-testid="stMarkdownContainer"] .katex-display,
.stMarkdown .MathJax,
.stMarkdown .MathJax_Display,
.stMarkdown mjx-container,
.stMarkdown .katex,
.stMarkdown .katex-display {
direction: ltr !important;
text-align: center !important;
unicode-bidi: normal !important;
}
/* Inline math should be LTR but inline */
mjx-container[display="false"],
.katex:not(.katex-display),
.MathJax:not(.MathJax_Display) {
direction: ltr !important;
text-align: left !important;
display: inline !important;
unicode-bidi: normal !important;
}
/* Block/display math should be centered */
mjx-container[display="true"],
.katex-display,
.MathJax_Display {
direction: ltr !important;
text-align: center !important;
display: block !important;
margin: 1em auto !important;
unicode-bidi: normal !important;
}
/* For custom RTL wrappers */
.rtl-text {
direction: rtl;
text-align: right;
unicode-bidi: plaintext;
}
</style>
""",
unsafe_allow_html=True,
)
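# Endpoint configuration is read from environment variables (defaults in parentheses):
#   MY_MODEL - model name served by the endpoint (Intel/hebrew-math-tutor-v1)
#   AWS_URL  - base URL of an OpenAI-compatible endpoint (http://localhost:8111/v1)
#   MY_KEY   - API key for that endpoint (no default)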
@st.cache_resource
def openai_configured():
return {
"model": os.getenv("MY_MODEL", "Intel/hebrew-math-tutor-v1"),
"api_base": os.getenv("AWS_URL", "http://localhost:8111/v1"),
"api_key": os.getenv("MY_KEY"),
}
config = openai_configured()
@st.cache_resource
def get_client():
return openai.OpenAI(api_key=config["api_key"], base_url=config["api_base"])
client = get_client()
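# The UI strings below are in Hebrew. The title reads "MathBot"; the intro welcomes
# users to the demo of the new 4-billion-parameter language model trained to answer
# math questions in Hebrew, running locally without a network connection, and links
# to the model card for details, contact information, and terms of use.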
st.title("מתמטיבוט 🧮")
st.markdown("""
ברוכים הבאים לדמו! 💡 כאן תוכלו להתרשם **ממודל השפה החדש** שלנו; מודל בגודל 4 מיליארד פרמטרים שאומן לענות על שאלות מתמטיות בעברית, על המחשב שלכם, ללא חיבור לרשת.
קישור למודל, פרטים נוספים, יצירת קשר ותנאי שימוש:
https://huggingface.co/Intel/hebrew-math-tutor-v1
-----
""")
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
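# chat_history holds (role, text) tuples for the current session; the
# "שיחה חדשה" (new chat) button below clears it.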
# Predefined options
predefined_options = [
"שאלה חדשה...",
" מהו סכום הסדרה הבאה: 1 + 1/2 + 1/4 + 1/8 + ...",
"פתח את הביטוי: (a-b)^4",
"פתרו את המשוואה הבאה: sin(2x) = 0.5",
]
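# The canned prompts above are, in order: "New question...", "What is the sum of the
# series 1 + 1/2 + 1/4 + 1/8 + ...", "Expand the expression: (a-b)^4", and
# "Solve the following equation: sin(2x) = 0.5".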
# Dropdown for predefined options
selected_option = st.selectbox("בחרו שאלה מוכנה או צרו שאלה חדשה:", predefined_options)
# Text area for input
if selected_option == "שאלה חדשה...":
user_input = st.text_area(
"שאלה:", height=100, key="user_input", placeholder="הזינו את השאלה כאן..."
)
else:
user_input = st.text_area("שאלה:", height=100, key="user_input", value=selected_option)
# Add reset button next to Send button
col1, col2 = st.columns([8, 4])
with col2:
send_clicked = st.button("שלח", type="primary", use_container_width=True) and user_input.strip()
with col1:
if st.button("שיחה חדשה", type="secondary", use_container_width=True):
st.session_state.chat_history = []
st.rerun()
if send_clicked:
st.session_state.chat_history.append(("user", user_input))
# Create a placeholder for streaming output
with st.chat_message("assistant"):
message_placeholder = st.empty()
full_response = ""
# System prompt - not visible in UI but guides the model
system_prompt = """\
You are a helpful AI assistant specialized in mathematics and problem-solving who can answer math questions with the correct answer.
Answer shortly, not more than 500 tokens, but outline the process step by step.
Answer ONLY in Hebrew!
"""
# Create messages in proper chat format
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_input},
]
        # Pass the chat-format messages directly to the OpenAI-compatible chat API.
        # The model's special thinking tokens (<think>...</think>) are kept in the
        # stream if the remote model emits them.
        prompt_messages = messages
# Stream from OpenAI-compatible API (vllm remote exposing openai-compatible endpoint)
# Use the chat completions streaming interface
in_thinking = True
thinking_content = "<think>"
final_answer = ""
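        # Streaming protocol assumed here: the model first emits a reasoning block
        # delimited by <think> ... </think>, then the final answer. While inside the
        # block an animated "thinking" indicator is shown; once </think> arrives the
        # reasoning is rendered in a collapsible <details> element and the answer is
        # streamed below it.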
try:
            # client.chat.completions.create with stream=True yields chunks with 'choices'
stream = client.chat.completions.create(
messages=prompt_messages,
model=config["model"],
temperature=0.6,
max_tokens=2000,
top_p=0.95,
stream=True,
extra_body={"top_k": 20},
)
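            # Sampling settings: temperature 0.6, top_p 0.95, up to 2000 generated
            # tokens; top_k 20 is passed via extra_body (a vLLM-style extension that
            # is not part of the standard OpenAI parameters).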
for chunk in stream:
                # With the openai>=1.x client each chunk is a ChatCompletionChunk;
                # the incremental text is at chunk.choices[0].delta.content
                delta = ""
                try:
                    # standard OpenAI streaming response structure
                    delta = chunk.choices[0].delta.content
                except Exception:
                    # fallback for older/other response shapes; default to an empty delta
                    delta = getattr(chunk, "text", "") or ""
if not delta:
continue
full_response += delta
# Handle thinking markers
if "<think>" in delta:
in_thinking = True
if in_thinking:
thinking_content += delta
if "</think>" in delta:
in_thinking = False
thinking_text = (
thinking_content.replace("<think>", "").replace("</think>", "").strip()
)
display_content = f"""
<details dir="rtl" style="text-align: right;">
<summary>🤔 <em>לחץ כדי לראות את תהליך החשיבה</em></summary>
<div style="white-space: pre-wrap; margin: 10px 0; direction: rtl; text-align: right;">
{thinking_text}
</div>
</details>
"""
message_placeholder.markdown(display_content + "▌", unsafe_allow_html=True)
else:
dots = "." * ((len(thinking_content) // 10) % 6)
thinking_indicator = f"""
<div dir="rtl" style="padding: 10px; background-color: #f0f2f6; border-radius: 10px; border-right: 4px solid #1f77b4; text-align: right;">
<p style="margin: 0; color: #1f77b4; font-style: italic;">
🤔 חושב{dots}
</p>
</div>
"""
message_placeholder.markdown(thinking_indicator, unsafe_allow_html=True)
else:
# Final answer streaming
final_answer += delta
converted_answer = convert_latex_brackets_to_dollars(final_answer)
message_placeholder.markdown(
"🤔 *תהליך החשיבה הושלם, מכין תשובה...*\n\n**📝 תשובה סופית:**\n\n"
+ converted_answer
+ "▌",
unsafe_allow_html=True,
)
except Exception as e:
# Show an error to the user
message_placeholder.markdown(f"**Error contacting remote model:** {e}")
# Final rendering: if there was thinking content include it
if thinking_content and "</think>" in thinking_content:
thinking_text = thinking_content.replace("<think>", "").replace("</think>", "").strip()
message_placeholder.empty()
with message_placeholder.container():
thinking_html = f"""
<details dir="rtl" style="text-align: right;">
<summary>🤔 <em>לחץ כדי לראות את תהליך החשיבה</em></summary>
<div style="white-space: pre-wrap; margin: 10px 0; direction: rtl; text-align: right;">
{thinking_text}
</div>
</details>
"""
st.markdown(thinking_html, unsafe_allow_html=True)
st.markdown(
'<div dir="rtl" style="text-align: right; margin: 10px 0;"><strong>📝 תשובה סופית:</strong></div>',
unsafe_allow_html=True,
)
converted_answer = convert_latex_brackets_to_dollars(final_answer or full_response)
st.markdown(converted_answer, unsafe_allow_html=True)
else:
converted_response = convert_latex_brackets_to_dollars(final_answer or full_response)
message_placeholder.markdown(converted_response, unsafe_allow_html=True)
st.session_state.chat_history.append(("assistant", final_answer or full_response))