Boning c committed · verified
Commit 7a55e85 · 1 Parent(s): d5529e8

Update app.py

Files changed (1)
  1. app.py  +73 -89
app.py CHANGED
@@ -4,32 +4,27 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 import re, time, json
 from html import escape

-# ─── Model Config ─────────────────────────────────────────────────────────────
-PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A3"
+# ─── Config ───────────────────────────────────────────────────
+PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A3"
 FALLBACK_MODEL = "Smilyai-labs/Sam-reason-A1"
-USAGE_LIMIT = 5  # max messages before fallback
-RESET_MS = 20 * 60 * 1000  # 20 minutes in milliseconds
-device = "cuda" if torch.cuda.is_available() else "cpu"
+USAGE_LIMIT = 5
+RESET_MS = 20 * 60 * 1000  # 20 min in ms
+device = "cuda" if torch.cuda.is_available() else "cpu"

 primary_model = primary_tokenizer = None
 fallback_model = fallback_tokenizer = None

-# ─── Load Models ────────────────────────────────────────────────────────────────
+# ─── Load Models ───────────────────────────────────────────────
 def load_models():
     global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer
     primary_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL, trust_remote_code=True)
-    primary_model = AutoModelForCausalLM.from_pretrained(PRIMARY_MODEL,
-                                                         torch_dtype=torch.float16
-                                                         ).to(device).eval()
-    fallback_tokenizer= AutoTokenizer.from_pretrained(FALLBACK_MODEL, trust_remote_code=True)
-    fallback_model = AutoModelForCausalLM.from_pretrained(FALLBACK_MODEL,
-                                                          torch_dtype=torch.float16
-                                                          ).to(device).eval()
-    return f"✅ Loaded: {PRIMARY_MODEL} with fallback {FALLBACK_MODEL}"
-
-# ─── Prompt Builder ────────────────────────────────────────────
+    primary_model = AutoModelForCausalLM.from_pretrained(PRIMARY_MODEL, torch_dtype=torch.float16).to(device).eval()
+    fallback_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL, trust_remote_code=True)
+    fallback_model = AutoModelForCausalLM.from_pretrained(FALLBACK_MODEL, torch_dtype=torch.float16).to(device).eval()
+    return f"✅ Loaded {PRIMARY_MODEL} with fallback {FALLBACK_MODEL}"
+
+# ─── Prompt Builder ────────────────────────────────────────────
 def build_chat_prompt(history, user_input, reasoning_enabled):
-    # inject think/no_think as a system role
     system_flag = "/think" if reasoning_enabled else "/no_think"
     prompt = f"<|system|>\n{system_flag}\n"
     for u, a in history:
@@ -37,53 +32,42 @@ def build_chat_prompt(history, user_input, reasoning_enabled):
     prompt += f"<|user|>\n{user_input}\n<|assistant|>\n"
     return prompt

-# ─── Collapse <think> Blocks ───────────────────────────────────
+# ─── Collapse <think> blocks ──────────────────────────────────
 def format_thinking(text):
     match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
     if not match:
         return escape(text)
     reasoning = escape(match.group(1).strip())
-    visible = re.sub(r"<think>.*?</think>", "[thinking...]", text,
-                     flags=re.DOTALL).strip()
-    return (
-        escape(visible)
-        + "<br><details><summary>🧠 Show reasoning</summary>"
-        + "<pre>" + reasoning + "</pre></details>"
-    )
+    visible = re.sub(r"<think>.*?</think>", "[thinking...]", text, flags=re.DOTALL).strip()
+    return escape(visible) + "<br><details><summary>🧠 Show reasoning</summary><pre>" + reasoning + "</pre></details>"

-# ─── Token-Stream Generator ───────────────────────────────────
-def generate_stream(prompt, use_fallback=False,
-                    max_length=100, temperature=0.2, top_p=0.9):
-    model = fallback_model if use_fallback else primary_model
+# ─── Stream Generator ─────────────────────────────────────────
+def generate_stream(prompt, use_fallback=False, max_length=100, temperature=0.2, top_p=0.9):
+    model = fallback_model if use_fallback else primary_model
     tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer
-
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
     generated = input_ids
     assistant_text = ""

     for _ in range(max_length):
         logits = model(generated).logits[:, -1, :] / temperature
-        sorted_logits, indices = torch.sort(logits, descending=True)
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
         probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

-        # top-p filtering
         mask = probs > top_p
         mask[..., 1:] = mask[..., :-1].clone()
-        mask[..., 0] = 0
+        mask[..., 0] = 0
         filtered = logits.clone()
-        filtered[:, indices[mask]] = -float("Inf")
+        filtered[:, sorted_indices[mask]] = -float("Inf")

-        # sample next token
         next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
         generated = torch.cat([generated, next_token], dim=-1)
         new_text = tokenizer.decode(next_token[0], skip_special_tokens=False)
         assistant_text += new_text

-        # strip opening assistant tag
         if assistant_text.startswith("<|assistant|>"):
             assistant_text = assistant_text[len("<|assistant|>"):]

-        # stop if model begins a new user turn
         if "<|user|>" in new_text:
             break

@@ -92,91 +76,91 @@ def generate_stream(prompt, use_fallback=False,
         if next_token.item() == tokenizer.eos_token_id:
             break

-# ─── Main Respond Handler ─────────────────────────────────────
-def respond(user_msg, history, reasoning_enabled, limit_json):
-    # parse usage info from localStorage
+# ─── Respond Handler ──────────────────────────────────────────
+def respond(message, history, reasoning_enabled, limit_json):
     info = json.loads(limit_json) if limit_json else {"count": 0}
     count = info.get("count", 0)

     use_fallback = count > USAGE_LIMIT
-    remaining = max(0, USAGE_LIMIT - count)
-    model_label = "A3" if not use_fallback else "Fallback A1"
+    remaining = max(0, USAGE_LIMIT - count)
+    model_label = "A3" if not use_fallback else "Fallback A1"
+    prompt = build_chat_prompt(history, message.strip(), reasoning_enabled)
+    history = history + [[message, ""]]

-    # build prompt & init history
-    prompt = build_chat_prompt(history, user_msg.strip(), reasoning_enabled)
-    history = history + [[user_msg, ""]]
+    yield history, history, f"🧠 A3 left: {remaining}", "Generating…"

-    # stream assistant reply
     for chunk in generate_stream(prompt, use_fallback=use_fallback):
         formatted = format_thinking(chunk)
-        history[-1][1] = (
-            f"{formatted}<br><sub style='color:gray'>({model_label})</sub>"
-        )
-        # during streaming, show Generating
-        yield history, history, f"🧠 A3 left: {remaining}", "Generating..."
+        history[-1][1] = f"{formatted}<br><sub style='color:gray'>({model_label})</sub>"
+        yield history, history, f"🧠 A3 left: {remaining}", "Generating…"

-    # final update: set status back to Idle
-    yield history, history, f"🧠 A3 left: {remaining}", "Idle"
+    yield history, history, f"🧠 A3 left: {remaining}", "Send"

 def clear_chat():
-    return [], [], "🧠 A3 left: 5", "Idle"
+    return [], [], "🧠 A3 left: 5", "Send"

-# ─── Gradio UI ─────────────────────────────────────────────────
+# ─── Gradio UI ────────────────────────────────────────────────
 with gr.Blocks() as demo:
-    gr.HTML(  # inject localStorage logic
-        """
-        <script>
-        function updateUsageLimit() {
-            const key = "samai_limit";
-            let now = Date.now();
-            let record = JSON.parse(localStorage.getItem(key) || "null");
-            if (!record || (now - record.lastSeen) > {RESET_MS}) {{
-                record = {{count: 0, lastSeen: now}};
-            }}
-            record.count += 1;
-            record.lastSeen = now;
-            localStorage.setItem(key, JSON.stringify(record));
-            return record;
-        }
-        </script>
-        """.replace("{RESET_MS}", str(RESET_MS))
+    gr.HTML(
+        """
+        <script>
+        function updateUsageLimit() {
+            let key = "samai_limit";
+            let now = Date.now();
+            let record = JSON.parse(localStorage.getItem(key) || "null");
+            if (!record || (now - record.lastSeen) > """ + str(RESET_MS) + """) {
+                record = {count: 0, lastSeen: now};
+            }
+            record.count += 1;
+            record.lastSeen = now;
+            localStorage.setItem(key, JSON.stringify(record));
+            return record;
+        }
+        </script>
+        <style>
+        .send-circle {
+            border-radius: 50%;
+            height: 40px;
+            width: 40px;
+            padding: 0;
+            font-size: 12px;
+            text-align: center;
+        }
+        </style>
+        """
     )

-    gr.Markdown("# 🤖 SamAI – Qwen Chat with Client-Side Limits")
+    gr.Markdown("# 🤖 SamAI – Chat Reasoning UI")

-    # hidden box to carry JSON string from JS → Python
     limit_json = gr.Textbox(visible=False)
     model_status = gr.Textbox(interactive=False, label="Model Status")
-    usage_counter = gr.Textbox("🧠 A3 left: 5", interactive=False, show_label=False)
-    status_display = gr.Textbox("Idle", interactive=False, label="Status")
+    usage_counter = gr.Textbox(value="🧠 A3 left: 5", interactive=False, show_label=False)

     chat_box = gr.Chatbot(type="tuples")
     chat_state= gr.State([])

     with gr.Row():
-        user_input = gr.Textbox(placeholder="Ask me anything…", show_label=False, scale=6)
+        user_input = gr.Textbox(placeholder="Ask anything...", show_label=False, scale=6)
         reason_toggle= gr.Checkbox(label="Reason", value=True, scale=1)
-        send_btn = gr.Button("Send", scale=1)
+        send_btn = gr.Button("Send", elem_classes=["send-circle"], scale=1)

-    clear_btn = gr.Button("Clear Chat")
+    clear_btn = gr.Button("Clear")

     model_status.value = load_models()

-    # first: JS updates localStorage → limit_json
     send_btn.click(
-        fn=None,
-        _js="() => JSON.stringify(updateUsageLimit())",
-        outputs=[limit_json]
+        None,
+        _js="() => JSON.stringify(updateUsageLimit())",
+        outputs=[limit_json]
     ).then(
-        # then: call our Python respond() with that JSON
-        fn=respond,
-        inputs=[user_input, chat_state, reason_toggle, limit_json],
-        outputs=[chat_box, chat_state, usage_counter, status_display]
+        fn=respond,
+        inputs=[user_input, chat_state, reason_toggle, limit_json],
+        outputs=[chat_box, chat_state, usage_counter, send_btn]
     )

     clear_btn.click(fn=clear_chat,
-        inputs=[],
-        outputs=[chat_box, chat_state, usage_counter, status_display]
+        inputs=[],
+        outputs=[chat_box, chat_state, usage_counter, send_btn]
     )

 demo.queue()
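
For reference, a minimal standalone sketch of the nucleus (top-p) filtering step that generate_stream applies at each decoding step; the dummy logits and vocabulary size below are illustrative only and are not part of the commit.

    import torch

    top_p = 0.9
    # Fake next-token logits for a 6-token vocabulary (batch size 1).
    logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0, -2.0]])

    # Sort logits, take the cumulative probability of the sorted distribution,
    # and mask every token past the top_p threshold (shifted by one position
    # so the token that crosses the boundary is still kept).
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    mask = probs > top_p
    mask[..., 1:] = mask[..., :-1].clone()
    mask[..., 0] = 0

    # Ban the masked tokens and sample from what remains.
    filtered = logits.clone()
    filtered[:, sorted_indices[mask]] = -float("Inf")
    next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
    print(next_token.item())

With batch size 1, sorted_indices[mask] flattens to the list of token ids to exclude, which is why the column assignment on filtered works in this setting.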
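
A separate small sketch of how the localStorage record written by updateUsageLimit() feeds the fallback decision in respond(); the count and timestamp are made-up example values, and only the record shape comes from the diff.

    import json

    USAGE_LIMIT = 5

    # Example payload the _js hook would place into the hidden limit_json textbox.
    limit_json = json.dumps({"count": 7, "lastSeen": 1700000000000})

    info = json.loads(limit_json) if limit_json else {"count": 0}
    count = info.get("count", 0)
    use_fallback = count > USAGE_LIMIT        # True here, so Sam-reason-A1 would be used
    remaining = max(0, USAGE_LIMIT - count)   # 0 A3 messages left
    print(use_fallback, remaining)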