Boning c committed
Commit 23339f6 · verified · 1 Parent(s): abccee4

Update app.py

Files changed (1): app.py +35 -17
app.py CHANGED
@@ -16,7 +16,7 @@ primary_model = primary_tokenizer = None
 fallback_model = fallback_tokenizer = None
 usage_info = {}
 
-# Load models
+# Load both models
 def load_models():
     global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer
     primary_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL, trust_remote_code=True)
@@ -25,31 +25,39 @@ def load_models():
     fallback_model = AutoModelForCausalLM.from_pretrained(FALLBACK_MODEL, torch_dtype=torch.float16).to(device).eval()
     return f"✅ Loaded: {PRIMARY_MODEL} with fallback {FALLBACK_MODEL}"
 
-# Format multi-turn history
+# Build a Qwen-style chat prompt
 def build_chat_prompt(history, user_input, reasoning_enabled):
     prefix = "/think " if reasoning_enabled else "/no_think "
     prompt = ""
     for user_msg, bot_msg in history:
-        prompt += f"<|user|>\n{user_msg}\n<|assistant|>\n{bot_msg}\n"
-    prompt += f"<|user|>\n{user_input}\n<|assistant|>\n"
+        prompt += "<|user|>\n" + user_msg + "\n<|assistant|>\n" + bot_msg + "\n"
+    prompt += "<|user|>\n" + user_input + "\n<|assistant|>\n"
     return prefix + prompt
 
-# Collapse <think> block
+# Collapse <think> blocks into hidden details
 def format_thinking(text):
     match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
     if match:
         reasoning = escape(match.group(1).strip())
         visible = re.sub(r"<think>.*?</think>", "[thinking...]", text, flags=re.DOTALL).strip()
-        return f"{escape(visible)}<br><details><summary>🧠 Show reasoning</summary><pre>{reasoning}</pre></details>"
+        return (
+            escape(visible)
+            + "<br><details><summary>🧠 Show reasoning</summary><pre>"
+            + reasoning
+            + "</pre></details>"
+        )
     return escape(text)
 
-# Token stream generator
+# Stream only the new assistant tokens (no prompt echo)
 def generate_stream(prompt, use_fallback=False, max_length=100, temperature=0.2, top_p=0.9):
     model = fallback_model if use_fallback else primary_model
     tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer
+
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
     generated = input_ids
-    output = tokenizer.decode(input_ids[0])
+
+    # We’ll accumulate only the new assistant text:
+    assistant_text = ""
 
     for _ in range(max_length):
         logits = model(generated).logits[:, -1, :] / temperature
@@ -60,20 +68,29 @@ def generate_stream(prompt, use_fallback=False, max_length=100, temperature=0.2,
         mask[..., 0] = 0
         filtered = logits.clone()
         filtered[:, sorted_indices[mask]] = -float("Inf")
+
        next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
        generated = torch.cat([generated, next_token], dim=-1)
+
         new_text = tokenizer.decode(next_token[0])
-        output += new_text
-        yield output
+        assistant_text += new_text
+
+        # Strip leading assistant tag if it shows up
+        if assistant_text.startswith("<|assistant|>"):
+            assistant_text = assistant_text[len("<|assistant|>"):]
+
+        yield assistant_text
+
         if next_token.item() == tokenizer.eos_token_id:
             break
 
-# Response pipeline
+# Main respond handler
 def respond(message, history, reasoning_enabled, request: gr.Request):
     ip = request.client.host if request else "unknown"
     now = time.time()
     info = usage_info.get(ip, {"count": 0, "last_seen": 0})
 
+    # Reset count if idle
     if now - info["last_seen"] > RESET_AFTER_SECONDS:
         info["count"] = 0
 
@@ -88,19 +105,21 @@ def respond(message, history, reasoning_enabled, request: gr.Request):
     prompt = build_chat_prompt(history, message.strip(), reasoning_enabled)
     history = history + [[message, ""]]
 
-    for output in generate_stream(prompt, use_fallback=use_fallback):
-        formatted = format_thinking(output)
+    # Stream only the assistant’s new text
+    for partial in generate_stream(prompt, use_fallback=use_fallback):
+        formatted = format_thinking(partial)
         history[-1][1] = f"{formatted}<br><sub style='color:gray'>({model_used})</sub>"
         yield history, history, f"🧠 A3 messages left: {remaining}"
 
 def clear_chat():
     return [], [], "🧠 A3 messages left: 5"
 
-# UI
+# Build Gradio UI
 with gr.Blocks() as demo:
-    gr.Markdown("# 🤖 SamAI – Reasoning Chat (Chat Mode Enabled)")
+    gr.Markdown("# 🤖 SamAI – Qwen-Chat Mode")
     model_status = gr.Textbox(interactive=False, label="Model Status")
     usage_counter = gr.Textbox(value="🧠 A3 messages left: 5", interactive=False, show_label=False)
+
     chat_box = gr.Chatbot(type="tuples")
     chat_state = gr.State([])
 
@@ -115,9 +134,8 @@ with gr.Blocks() as demo:
     send_btn.click(
         respond,
         inputs=[user_input, chat_state, reason_toggle],
-        outputs=[chat_box, chat_state, usage_counter]
+        outputs=[chat_box, chat_state, usage_counter],
     )
-
     clear_btn.click(fn=clear_chat, inputs=[], outputs=[chat_box, chat_state, usage_counter])
 
 demo.queue()
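For reference, a minimal sketch (not part of the commit) of what the rewritten build_chat_prompt now returns for a one-turn history with reasoning enabled; the function body is copied from the new side of the diff, and the history and user strings are invented examples.

```python
# Hypothetical usage sketch; build_chat_prompt copied from the updated app.py.
def build_chat_prompt(history, user_input, reasoning_enabled):
    prefix = "/think " if reasoning_enabled else "/no_think "
    prompt = ""
    for user_msg, bot_msg in history:
        prompt += "<|user|>\n" + user_msg + "\n<|assistant|>\n" + bot_msg + "\n"
    prompt += "<|user|>\n" + user_input + "\n<|assistant|>\n"
    return prefix + prompt

# One prior turn plus a new question (example strings are made up):
history = [["Hi", "Hello! How can I help?"]]
print(build_chat_prompt(history, "What is 2 + 2?", reasoning_enabled=True))
# /think <|user|>
# Hi
# <|assistant|>
# Hello! How can I help?
# <|user|>
# What is 2 + 2?
# <|assistant|>
```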
 
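Similarly, a small sketch of what format_thinking produces once the model emits a <think>…</think> block; this assumes escape in app.py is html.escape, which the diff does not show.

```python
import re
from html import escape  # assumption: app.py's escape is html.escape

def format_thinking(text):
    # Copied from the new side of the diff.
    match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
    if match:
        reasoning = escape(match.group(1).strip())
        visible = re.sub(r"<think>.*?</think>", "[thinking...]", text, flags=re.DOTALL).strip()
        return (
            escape(visible)
            + "<br><details><summary>🧠 Show reasoning</summary><pre>"
            + reasoning
            + "</pre></details>"
        )
    return escape(text)

print(format_thinking("<think>2 + 2 = 4</think>The answer is 4."))
# [thinking...]The answer is 4.<br><details><summary>🧠 Show reasoning</summary><pre>2 + 2 = 4</pre></details>
```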
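The hunk starting at mask[..., 0] = 0 shows only the tail of the nucleus-sampling block; the sort/cumsum setup that produces sorted_indices and mask is outside the diff. A minimal, self-contained sketch of top-p filtering written the way those lines suggest (batch size 1, as in the app) could look like this; everything above the last two assignments is an assumption, not the file's actual implementation.

```python
import torch

def top_p_filter(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    """Mask out tokens outside the top-p nucleus (assumes batch size 1)."""
    # Assumed setup: sort descending and accumulate probability mass.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Drop tokens once the cumulative mass exceeds top_p, shifted right so the
    # first token past the threshold survives; always keep the top token.
    mask = cumulative_probs > top_p
    mask[..., 1:] = mask[..., :-1].clone()
    mask[..., 0] = 0

    # These two lines mirror the diff: map the mask back to vocabulary order.
    filtered = logits.clone()
    filtered[:, sorted_indices[mask]] = -float("Inf")
    return filtered

# Toy example with a fake 32-token vocabulary:
logits = torch.randn(1, 32)
next_token = torch.multinomial(torch.softmax(top_p_filter(logits), dim=-1), 1)
```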