Boning c committed (verified)
Commit 9fe966b · 1 Parent(s): b612b32

Update app.py

Files changed (1)
  1. app.py +103 -69
app.py CHANGED
@@ -1,26 +1,33 @@
  import gradio as gr
  import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM
- import re, time, json
  from html import escape

  PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A3"
  FALLBACK_MODEL = "Smilyai-labs/Sam-reason-A1"
- USAGE_LIMIT = 5
- RESET_MS = 20 * 60 * 1000
- device = "cuda" if torch.cuda.is_available() else "cpu"

  primary_model = primary_tokenizer = None
  fallback_model = fallback_tokenizer = None

  def load_models():
      global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer
      primary_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL, trust_remote_code=True)
-     primary_model = AutoModelForCausalLM.from_pretrained(PRIMARY_MODEL, torch_dtype=torch.float16).to(device).eval()
-     fallback_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL, trust_remote_code=True)
-     fallback_model = AutoModelForCausalLM.from_pretrained(FALLBACK_MODEL, torch_dtype=torch.float16).to(device).eval()
-     return f"✅ Loaded {PRIMARY_MODEL} + fallback {FALLBACK_MODEL}"
-
  def build_chat_prompt(history, user_input, reasoning_enabled):
      system_flag = "/think" if reasoning_enabled else "/no_think"
      prompt = f"<|system|>\n{system_flag}\n"
@@ -29,124 +36,151 @@ def build_chat_prompt(history, user_input, reasoning_enabled):
      prompt += f"<|user|>\n{user_input}\n<|assistant|>\n"
      return prompt

  def format_thinking(text):
      match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
      if not match:
          return escape(text)
      reasoning = escape(match.group(1).strip())
      visible = re.sub(r"<think>.*?</think>", "[thinking...]", text, flags=re.DOTALL).strip()
-     return escape(visible) + "<br><details><summary>🧠 Show reasoning</summary><pre>" + reasoning + "</pre></details>"

- def generate_stream(prompt, use_fallback=False, max_length=100, temperature=0.2, top_p=0.9):
-     model = fallback_model if use_fallback else primary_model
      tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer
      input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
      generated = input_ids
      assistant_text = ""
      for _ in range(max_length):
          logits = model(generated).logits[:, -1, :] / temperature
-         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
          probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
          mask = probs > top_p
          mask[..., 1:] = mask[..., :-1].clone()
-         mask[..., 0] = 0
          filtered = logits.clone()
-         filtered[:, sorted_indices[mask]] = -float("Inf")
          next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
          generated = torch.cat([generated, next_token], dim=-1)
          new_text = tokenizer.decode(next_token[0], skip_special_tokens=False)
          assistant_text += new_text
          if assistant_text.startswith("<|assistant|>"):
              assistant_text = assistant_text[len("<|assistant|>"):]
          if "<|user|>" in new_text:
              break
          yield assistant_text
          if next_token.item() == tokenizer.eos_token_id:
              break

  def respond(message, history, reasoning_enabled, limit_json):
-     info = json.loads(limit_json) if limit_json else {"count": 0}
      count = info.get("count", 0)
      use_fallback = count > USAGE_LIMIT
-     remaining = max(0, USAGE_LIMIT - count)
-     model_label = "A3" if not use_fallback else "Fallback A1"
      prompt = build_chat_prompt(history, message.strip(), reasoning_enabled)
      history = history + [[message, ""]]
      yield history, history, f"🧠 A3 left: {remaining}", "Generating…"
-     for chunk in generate_stream(prompt, use_fallback=use_fallback):
          formatted = format_thinking(chunk)
-         history[-1][1] = f"{formatted}<br><sub style='color:gray'>({model_label})</sub>"
          yield history, history, f"🧠 A3 left: {remaining}", "Generating…"
      yield history, history, f"🧠 A3 left: {remaining}", "Send"

  def clear_chat():
      return [], [], "🧠 A3 left: 5", "Send"

  with gr.Blocks() as demo:
      gr.HTML(f"""
-     <script>
-     function updateUsageLimit() {{
-       let key = "samai_limit";
-       let now = Date.now();
-       let record = JSON.parse(localStorage.getItem(key) || "null");
-       if (!record || (now - record.lastSeen) > {RESET_MS}) {{
-         record = {{count: 0, lastSeen: now}};
-       }}
-       record.count += 1;
-       record.lastSeen = now;
-       localStorage.setItem(key, JSON.stringify(record));
-       document.getElementById("limit_json").value = JSON.stringify(record);
-     }}
-
-     function setGeneratingText() {{
-       document.getElementById("send_btn").innerText = "Generating…";
-     }}
-     function setIdleText() {{
-       document.getElementById("send_btn").innerText = "Send";
-     }}
-     </script>
-     <style>
-     .send-circle {{
-       border-radius: 50%;
-       height: 40px;
-       width: 40px;
-       padding: 0;
-       font-size: 12px;
-       text-align: center;
-     }}
-     </style>
-     """)
-
-     gr.Markdown("# 🤖 SamAI – Chat Reasoning (Gradio v3 Compatible)")
-     limit_json = gr.Textbox(visible=False, elem_id="limit_json")
-     model_status = gr.Textbox(interactive=False, label="Model Status")
      usage_counter = gr.Textbox("🧠 A3 left: 5", interactive=False, show_label=False)

-     chat_box = gr.Chatbot(type="tuples")
      chat_state = gr.State([])

      with gr.Row():
-         user_input = gr.Textbox(placeholder="Ask anything...", show_label=False, scale=6)
          reason_toggle = gr.Checkbox(label="Reason", value=True, scale=1)
-         send_btn = gr.Button("Send", elem_id="send_btn", elem_classes=["send-circle"], scale=1)
-
-     update_btn = gr.Button(visible=False)

      clear_btn = gr.Button("Clear")

      model_status.value = load_models()

-     update_btn.click(None, _js="updateUsageLimit")
-
-     send_btn.click(None, _js="setGeneratingText").then(
          fn=respond,
          inputs=[user_input, chat_state, reason_toggle, limit_json],
          outputs=[chat_box, chat_state, usage_counter, send_btn]
-     ).then(fn=None, _js="setIdleText")
-
-     clear_btn.click(fn=clear_chat,
-         inputs=[],
-         outputs=[chat_box, chat_state, usage_counter, send_btn]
      )

      demo.queue()
 
app.py (updated version):

  import gradio as gr
  import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM
+ import re, json
  from html import escape

+ # ─── Config ─────────────────────────────────────────────────────────
  PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A3"
  FALLBACK_MODEL = "Smilyai-labs/Sam-reason-A1"
+ USAGE_LIMIT = 5
+ RESET_MS = 20 * 60 * 1000  # 20 minutes in milliseconds
+ device = "cuda" if torch.cuda.is_available() else "cpu"

  primary_model = primary_tokenizer = None
  fallback_model = fallback_tokenizer = None

+ # ─── Load Models ───────────────────────────────────────────────────────
  def load_models():
      global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer
      primary_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL, trust_remote_code=True)
+     primary_model = AutoModelForCausalLM.from_pretrained(
+         PRIMARY_MODEL, torch_dtype=torch.float16
+     ).to(device).eval()
+     fallback_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL, trust_remote_code=True)
+     fallback_model = AutoModelForCausalLM.from_pretrained(
+         FALLBACK_MODEL, torch_dtype=torch.float16
+     ).to(device).eval()
+     return f"✅ Loaded {PRIMARY_MODEL} with fallback {FALLBACK_MODEL}"
+
+ # ─── Build Qwen-style Prompt ──────────────────────────────────────────
  def build_chat_prompt(history, user_input, reasoning_enabled):
      system_flag = "/think" if reasoning_enabled else "/no_think"
      prompt = f"<|system|>\n{system_flag}\n"

      prompt += f"<|user|>\n{user_input}\n<|assistant|>\n"
      return prompt

+ # ─── Collapse <think> Blocks ──────────────────────────────────────────
  def format_thinking(text):
      match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
      if not match:
          return escape(text)
      reasoning = escape(match.group(1).strip())
      visible = re.sub(r"<think>.*?</think>", "[thinking...]", text, flags=re.DOTALL).strip()
+     return (
+         escape(visible)
+         + "<br><details><summary>🧠 Show reasoning</summary>"
+         + f"<pre>{reasoning}</pre></details>"
+     )

+ # ─── Token-by-Token Streaming ─────────────────────────────────────────
+ def generate_stream(prompt, use_fallback=False,
+                     max_length=100, temperature=0.2, top_p=0.9):
+     model = fallback_model if use_fallback else primary_model
      tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer
      input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
      generated = input_ids
      assistant_text = ""
+
      for _ in range(max_length):
          logits = model(generated).logits[:, -1, :] / temperature
+         sorted_logits, idxs = torch.sort(logits, descending=True)
          probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
+
          mask = probs > top_p
          mask[..., 1:] = mask[..., :-1].clone()
+         mask[..., 0] = 0
          filtered = logits.clone()
+         filtered[:, idxs[mask]] = -float("Inf")
+
          next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
          generated = torch.cat([generated, next_token], dim=-1)
          new_text = tokenizer.decode(next_token[0], skip_special_tokens=False)
          assistant_text += new_text
+
+         # strip the opening assistant tag
          if assistant_text.startswith("<|assistant|>"):
              assistant_text = assistant_text[len("<|assistant|>"):]
+
+         # stop if model tries to start a new user turn
          if "<|user|>" in new_text:
              break
+
          yield assistant_text
+
          if next_token.item() == tokenizer.eos_token_id:
              break

+ # ─── Main Respond Handler ─────────────────────────────────────────────
  def respond(message, history, reasoning_enabled, limit_json):
+     info = json.loads(limit_json) if limit_json else {"count": 0}
      count = info.get("count", 0)
      use_fallback = count > USAGE_LIMIT
+     remaining = max(0, USAGE_LIMIT - count)
+     model_label = "A3" if not use_fallback else "Fallback A1"
+
+     # show "Generating…" immediately
      prompt = build_chat_prompt(history, message.strip(), reasoning_enabled)
      history = history + [[message, ""]]
      yield history, history, f"🧠 A3 left: {remaining}", "Generating…"
+
+     # stream assistant reply
+     for chunk in generate_stream(prompt, use_fallback):
          formatted = format_thinking(chunk)
+         history[-1][1] = (
+             f"{formatted}<br><sub style='color:gray'>({model_label})</sub>"
+         )
          yield history, history, f"🧠 A3 left: {remaining}", "Generating…"
+
+     # final: reset to Send
      yield history, history, f"🧠 A3 left: {remaining}", "Send"

  def clear_chat():
      return [], [], "🧠 A3 left: 5", "Send"

+ # ─── Gradio UI ─────────────────────────────────────────────────────────
  with gr.Blocks() as demo:
+     # Inject client-side JS + CSS
      gr.HTML(f"""
+     <script>
+     // bump/reset usage in localStorage and write to hidden textbox
+     function updateUsageLimit() {{
+       const key = "samai_limit";
+       const now = Date.now();
+       let rec = JSON.parse(localStorage.getItem(key) || "null");
+       if (!rec || now - rec.lastSeen > {RESET_MS}) {{
+         rec = {{count: 0, lastSeen: now}};
+       }}
+       rec.count += 1;
+       rec.lastSeen = now;
+       localStorage.setItem(key, JSON.stringify(rec));
+       document.getElementById("limit_json").value = JSON.stringify(rec);
+     }}
+     // on Send click: update limit & flip button text
+     document.addEventListener("DOMContentLoaded", () => {{
+       const btn = document.getElementById("send_btn");
+       btn.addEventListener("click", () => {{
+         updateUsageLimit();
+         btn.innerText = "Generating…";
+       }});
+     }});
+     </script>
+     <style>
+     .send-circle {{
+       border-radius: 50%;
+       height: 40px;
+       width: 40px;
+       padding: 0;
+       font-size: 12px;
+       text-align: center;
+     }}
+     </style>
+     """)
+
+     gr.Markdown("# 🤖 SamAI – Chat Reasoning (Final)")
+
+     # carry usage JSON from JS → Python
+     limit_json = gr.Textbox(visible=False, elem_id="limit_json")
+     model_status = gr.Textbox(interactive=False, label="Model Status")
      usage_counter = gr.Textbox("🧠 A3 left: 5", interactive=False, show_label=False)

+     chat_box = gr.Chatbot(type="tuples")
      chat_state = gr.State([])

      with gr.Row():
+         user_input = gr.Textbox(placeholder="Ask anything...", show_label=False, scale=6)
          reason_toggle = gr.Checkbox(label="Reason", value=True, scale=1)
+         send_btn = gr.Button("Send", elem_id="send_btn", elem_classes=["send-circle"], scale=1)

      clear_btn = gr.Button("Clear")

      model_status.value = load_models()

+     send_btn.click(
          fn=respond,
          inputs=[user_input, chat_state, reason_toggle, limit_json],
          outputs=[chat_box, chat_state, usage_counter, send_btn]
+     )
+     clear_btn.click(
+         fn=clear_chat,
+         inputs=[],
+         outputs=[chat_box, chat_state, usage_counter, send_btn]
      )

      demo.queue()
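
For readers skimming the diff, here is a minimal, self-contained sketch of the top-p (nucleus) filtering step that generate_stream applies on every decoding iteration. It is an illustration of the technique, not part of the commit: the helper name top_p_filter, the 32-token vocabulary, and the random logits are invented, and the advanced-indexing line assumes a batch size of 1, as in the app.

import torch

def top_p_filter(logits, top_p=0.9):
    # Sort token logits, find the smallest prefix whose cumulative probability
    # exceeds top_p, and push every other token's logit to -inf so it can
    # never be sampled.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    mask = cumulative > top_p
    mask[..., 1:] = mask[..., :-1].clone()  # shift right: keep the first token that crosses top_p
    mask[..., 0] = 0                        # always keep the single most likely token
    filtered = logits.clone()
    filtered[:, sorted_indices[mask]] = -float("Inf")  # valid for batch size 1
    return filtered

logits = torch.randn(1, 32)  # pretend 32-token vocabulary
next_token = torch.multinomial(torch.softmax(top_p_filter(logits), dim=-1), 1)
print(next_token.item())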