# ThinkFlow-llama / app.py
import re
import threading
import gradio as gr
import spaces
import transformers
from transformers import pipeline
# Loading model and tokenizer
model_name = "meta-llama/Llama-3.1-8B-Instruct"
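# gr.NO_RELOAD ensures the model is loaded only once and is not re-initialized
# when Gradio's hot-reload re-executes this file during development.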
if gr.NO_RELOAD:
    pipe = pipeline(
        "text-generation",
        model=model_name,
        device_map="auto",
        torch_dtype="auto",
    )
# Marker for detecting final answer
ANSWER_MARKER = "**Answer**"
# Sentences to start step-by-step reasoning
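# Each prepend below seeds one generation pass; the model continues from it, and
# the passes are concatenated into a single visible "thinking" transcript.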
rethink_prepends = [
    "Now, I need to understand the following ",
    "In my opinion ",
    "Let me verify if the following is correct ",
    "Also, I should remember that ",
    "Another point to note is ",
    "And I also remember the following fact ",
    "Now I think I understand sufficiently ",
]
# Prompt appended after the reasoning to generate the final answer
final_answer_prompt = """
Based on my reasoning process so far, I will answer the original question in the language it was asked:
{question}
Here is the conclusion I've reasoned:
{reasoning_conclusion}
Based on the above reasoning, my final answer:
{ANSWER_MARKER}
"""
# Settings for displaying formulas
latex_delimiters = [
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
]
def reformat_math(text):
"""Modify MathJax delimiters to use Gradio syntax (Katex).
This is a temporary fix for displaying math formulas in Gradio. Currently,
I haven't found a way to make it work as expected with other latex_delimiters...
"""
text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
return text
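# Note: ANSWER_MARKER is stripped from the incoming message below so a user cannot
# prematurely trigger the final-answer marker inside the reasoning transcript.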
def user_input(message, history_original, history_thinking):
"""Add user input to history and clear input text box"""
return "", history_original + [
gr.ChatMessage(role="user", content=message.replace(ANSWER_MARKER, ""))
], history_thinking + [
gr.ChatMessage(role="user", content=message.replace(ANSWER_MARKER, ""))
]
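# Note: history entries whose metadata carries a "title" are the collapsible thinking
# bubbles; they are filtered out so the model only sees plain user/assistant turns.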
def rebuild_messages(history: list):
"""Reconstruct messages from history for model use without intermediate thinking process"""
messages = []
for h in history:
if isinstance(h, dict) and not h.get("metadata", {}).get("title", False):
messages.append(h)
elif (
isinstance(h, gr.ChatMessage)
and h.metadata.get("title", None) is None
and isinstance(h.content, str)
):
messages.append({"role": h.role, "content": h.content})
return messages
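# On Hugging Face ZeroGPU Spaces, @spaces.GPU requests a GPU for the duration of the call.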
@spaces.GPU
def bot_original(
    history: list,
    max_num_tokens: int,
    do_sample: bool,
    temperature: float,
):
"""Make the original model answer questions (without reasoning process)"""
# For streaming tokens from thread later
streamer = transformers.TextIteratorStreamer(
pipe.tokenizer, # pyright: ignore
skip_special_tokens=True,
skip_prompt=True,
)
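    # pipe() blocks until generation finishes, so it runs in a background thread;
    # the streamer yields decoded text chunks here as they are produced, which lets
    # this generator stream partial output to the Gradio UI.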
    # Prepare assistant message
    history.append(
        gr.ChatMessage(
            role="assistant",
            content=str(""),
        )
    )
    # Messages to be displayed in current chat
    messages = rebuild_messages(history[:-1])  # Excluding last empty message
    # Original model answers directly without reasoning
    t = threading.Thread(
        target=pipe,
        args=(messages,),
        kwargs=dict(
            max_new_tokens=max_num_tokens,
            streamer=streamer,
            do_sample=do_sample,
            temperature=temperature,
        ),
    )
    t.start()
    for token in streamer:
        history[-1].content += token
        history[-1].content = reformat_math(history[-1].content)
        yield history
    t.join()
    yield history
@spaces.GPU
def bot_thinking(
    history: list,
    max_num_tokens: int,
    final_num_tokens: int,
    do_sample: bool,
    temperature: float,
):
"""Make the model answer questions with reasoning process"""
# For streaming tokens from thread later
streamer = transformers.TextIteratorStreamer(
pipe.tokenizer, # pyright: ignore
skip_special_tokens=True,
skip_prompt=True,
)
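    # The same streamer instance is reused for every reasoning pass below; each pass
    # fully drains it (the for-loop ends when generation signals completion) before
    # the next pass starts, so the passes do not interleave.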
    # For reinserting the question into reasoning if needed
    question = history[-1]["content"]
    # Prepare assistant message
    history.append(
        gr.ChatMessage(
            role="assistant",
            content=str(""),
            metadata={"title": "🧠 Thinking...", "status": "pending"},
        )
    )
    # Reasoning process to be displayed in current chat
    messages = rebuild_messages(history)
    # Variable to store the entire reasoning process
    full_reasoning = ""
    # Run reasoning steps
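    # Each iteration appends one prepend plus the model's continuation to the same
    # prompt, so later passes can build on the reasoning produced by earlier ones.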
    for i, prepend in enumerate(rethink_prepends):
        if i > 0:
            messages[-1]["content"] += "\n\n"
        messages[-1]["content"] += prepend.format(question=question)
        t = threading.Thread(
            target=pipe,
            args=(messages,),
            kwargs=dict(
                max_new_tokens=max_num_tokens,
                streamer=streamer,
                do_sample=do_sample,
                temperature=temperature,
            ),
        )
        t.start()
        # Reconstruct history with new content
        history[-1].content += prepend.format(question=question)
        for token in streamer:
            history[-1].content += token
            history[-1].content = reformat_math(history[-1].content)
            yield history
        t.join()
        # Save the result of each reasoning step to full_reasoning
        full_reasoning = history[-1].content
    # Reasoning complete, now generate final answer
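    # Mark the thinking bubble as finished so the UI no longer shows it as pending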
    history[-1].metadata = {"title": "💭 Thought Process", "status": "done"}
    # Extract conclusion part from reasoning process (approximately last 1-2 paragraphs)
    reasoning_parts = full_reasoning.split("\n\n")
    reasoning_conclusion = "\n\n".join(reasoning_parts[-2:]) if len(reasoning_parts) > 2 else full_reasoning
    # Add final answer message
    history.append(gr.ChatMessage(role="assistant", content=""))
    # Construct message for final answer
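    # Note: rebuild_messages drops the thinking bubble (it carries a metadata title), so
    # the model only receives the question plus the extracted conclusion in final_prompt.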
    final_messages = rebuild_messages(history[:-1])  # Excluding last empty message
    final_prompt = final_answer_prompt.format(
        question=question,
        reasoning_conclusion=reasoning_conclusion,
        ANSWER_MARKER=ANSWER_MARKER,
    )
    final_messages[-1]["content"] += final_prompt
    # Generate final answer
    t = threading.Thread(
        target=pipe,
        args=(final_messages,),
        kwargs=dict(
            max_new_tokens=final_num_tokens,
            streamer=streamer,
            do_sample=do_sample,
            temperature=temperature,
        ),
    )
    t.start()
    # Stream final answer
    for token in streamer:
        history[-1].content += token
        history[-1].content = reformat_math(history[-1].content)
        yield history
    t.join()
    yield history
with gr.Blocks(fill_height=True, title="ThinkFlow") as demo:
    # Title and description
    gr.Markdown("# ThinkFlow")
    gr.Markdown("### An LLM reasoning platform that automatically adds step-by-step reasoning to existing LLMs without modifying the models themselves")
    # Features and benefits section
    with gr.Accordion("✨ Features & Benefits", open=True):
        gr.Markdown("""
        - **Enhanced Reasoning**: Transform any LLM into a step-by-step reasoning engine without model modifications
        - **Transparency**: Visualize the model's thought process alongside direct answers
        - **Improved Accuracy**: See how guided reasoning leads to more accurate solutions for complex problems
        - **Educational Tool**: Perfect for teaching critical thinking and problem-solving approaches
        - **Versatile Application**: Works with mathematical problems, logical puzzles, and complex questions
        - **Side-by-Side Comparison**: Compare standard model responses with reasoning-enhanced outputs
        """)
    with gr.Row(scale=1):
        with gr.Column(scale=2):
            gr.Markdown("## Before (Original)")
            chatbot_original = gr.Chatbot(
                scale=1,
                type="messages",
                latex_delimiters=latex_delimiters,
                label="Original Model (No Reasoning)",
            )
        with gr.Column(scale=2):
            gr.Markdown("## After (Thinking)")
            chatbot_thinking = gr.Chatbot(
                scale=1,
                type="messages",
                latex_delimiters=latex_delimiters,
                label="Model with Reasoning",
            )
    with gr.Row():
        # Define msg textbox first
        msg = gr.Textbox(
            submit_btn=True,
            label="",
            show_label=False,
            placeholder="Enter your question here.",
            autofocus=True,
        )
    # Examples section - placed after the msg variable is defined
    with gr.Accordion("EXAMPLES", open=False):
        examples = gr.Examples(
            examples=[
                "[Source: MATH-500] How many numbers among the first 100 positive integers are divisible by 3, 4, and 5?",
                "[Source: MATH-500] In the land of Ink, the money system is unique. 1 trinket equals 4 blinkets, and 3 blinkets equal 7 drinkits. What is the value of 56 drinkits in trinkets?",
                "[Source: MATH-500] The average age of Amy, Ben, and Chris is 6 years. Four years ago, Chris was the same age as Amy is now. Four years from now, Ben's age will be $\\frac{3}{5}$ of Amy's age at that time. How old is Chris now?",
                "[Source: MATH-500] A bag contains yellow and blue marbles. Currently, the ratio of blue marbles to yellow marbles is 4:3. After adding 5 blue marbles and removing 3 yellow marbles, the ratio becomes 7:3. How many blue marbles were in the bag before any were added?",
            ],
            inputs=msg,
        )
    with gr.Row():
        with gr.Column():
            gr.Markdown("""## Parameter Adjustment""")
            num_tokens = gr.Slider(
                50,
                4000,
                2000,
                step=1,
                label="Maximum tokens per reasoning step",
                interactive=True,
            )
            final_num_tokens = gr.Slider(
                50,
                4000,
                2000,
                step=1,
                label="Maximum tokens for final answer",
                interactive=True,
            )
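            # Temperature only takes effect when "Use sampling" is enabled;
            # with sampling disabled, generation is greedy and the value is ignored.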
            do_sample = gr.Checkbox(True, label="Use sampling")
            temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature")
    # Community link at the bottom
    gr.Markdown("<p style='font-size: 12px;'>Community: <a href='https://discord.gg/openfreeai' target='_blank'>https://discord.gg/openfreeai</a></p>")
    # When the user submits a message, both bots respond: .then() chains the handlers,
    # so the original model answers first and the reasoning run starts afterwards
    msg.submit(
        user_input,
        [msg, chatbot_original, chatbot_thinking],  # inputs
        [msg, chatbot_original, chatbot_thinking],  # outputs
    ).then(
        bot_original,
        [
            chatbot_original,
            num_tokens,
            do_sample,
            temperature,
        ],
        chatbot_original,  # save new history in outputs
    ).then(
        bot_thinking,
        [
            chatbot_thinking,
            num_tokens,
            final_num_tokens,
            do_sample,
            temperature,
        ],
        chatbot_thinking,  # save new history in outputs
    )
if __name__ == "__main__":
demo.queue().launch()