Boning c committed
Commit d5529e8 · verified · 1 Parent(s): d84b5b3

Update app.py

Files changed (1)
    app.py  +119  -68
app.py CHANGED
@@ -1,75 +1,89 @@
 import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
-import re
-import time
+import re, time, json
 from html import escape

-# Model config
-PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A3"
+# ─── Model Config ─────────────────────────────────────────────────────────────
+PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A3"
 FALLBACK_MODEL = "Smilyai-labs/Sam-reason-A1"
-USAGE_LIMIT = 5
-RESET_AFTER_SECONDS = 20 * 60
-device = "cuda" if torch.cuda.is_available() else "cpu"
+USAGE_LIMIT = 5            # max messages before fallback
+RESET_MS = 20 * 60 * 1000  # 20 minutes in milliseconds
+device = "cuda" if torch.cuda.is_available() else "cpu"

 primary_model = primary_tokenizer = None
 fallback_model = fallback_tokenizer = None
-usage_info = {}

-# Load models
+# ─── Load Models ──────────────────────────────────────────────────────────────
 def load_models():
     global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer
     primary_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL, trust_remote_code=True)
-    primary_model = AutoModelForCausalLM.from_pretrained(PRIMARY_MODEL, torch_dtype=torch.float16).to(device).eval()
-    fallback_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL, trust_remote_code=True)
-    fallback_model = AutoModelForCausalLM.from_pretrained(FALLBACK_MODEL, torch_dtype=torch.float16).to(device).eval()
+    primary_model = AutoModelForCausalLM.from_pretrained(PRIMARY_MODEL,
+                                                         torch_dtype=torch.float16
+                                                         ).to(device).eval()
+    fallback_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL, trust_remote_code=True)
+    fallback_model = AutoModelForCausalLM.from_pretrained(FALLBACK_MODEL,
+                                                          torch_dtype=torch.float16
+                                                          ).to(device).eval()
     return f"✅ Loaded: {PRIMARY_MODEL} with fallback {FALLBACK_MODEL}"

-# Build prompt with full chat context
+# ─── Prompt Builder ───────────────────────────────────────────────────────────
 def build_chat_prompt(history, user_input, reasoning_enabled):
-    system = "/think" if reasoning_enabled else "/no_think"
-    prompt = f"<|system|>\n{system}\n"
-    for user_msg, bot_msg in history:
-        prompt += f"<|user|>\n{user_msg}\n<|assistant|>\n{bot_msg}\n"
+    # inject think/no_think as a system role
+    system_flag = "/think" if reasoning_enabled else "/no_think"
+    prompt = f"<|system|>\n{system_flag}\n"
+    for u, a in history:
+        prompt += f"<|user|>\n{u}\n<|assistant|>\n{a}\n"
     prompt += f"<|user|>\n{user_input}\n<|assistant|>\n"
     return prompt

-# Collapse <think> reasoning blocks
+# ─── Collapse <think> Blocks ──────────────────────────────────────────────────
 def format_thinking(text):
     match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
-    if match:
-        reasoning = escape(match.group(1).strip())
-        visible = re.sub(r"<think>.*?</think>", "[thinking...]", text, flags=re.DOTALL).strip()
-        return f"{escape(visible)}<br><details><summary>🧠 Show reasoning</summary><pre>{reasoning}</pre></details>"
-    return escape(text)
-
-# Stream tokens and stop on <|user|> tag
-def generate_stream(prompt, use_fallback=False, max_length=100, temperature=0.2, top_p=0.9):
-    model = fallback_model if use_fallback else primary_model
+    if not match:
+        return escape(text)
+    reasoning = escape(match.group(1).strip())
+    visible = re.sub(r"<think>.*?</think>", "[thinking...]", text,
+                     flags=re.DOTALL).strip()
+    return (
+        escape(visible)
+        + "<br><details><summary>🧠 Show reasoning</summary>"
+        + "<pre>" + reasoning + "</pre></details>"
+    )
+
+# ─── Token-Stream Generator ───────────────────────────────────────────────────
+def generate_stream(prompt, use_fallback=False,
+                    max_length=100, temperature=0.2, top_p=0.9):
+    model = fallback_model if use_fallback else primary_model
     tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer
+
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
     generated = input_ids
     assistant_text = ""

     for _ in range(max_length):
         logits = model(generated).logits[:, -1, :] / temperature
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        sorted_logits, indices = torch.sort(logits, descending=True)
         probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
+
+        # top-p filtering
         mask = probs > top_p
         mask[..., 1:] = mask[..., :-1].clone()
-        mask[..., 0] = 0
+        mask[..., 0] = 0
         filtered = logits.clone()
-        filtered[:, sorted_indices[mask]] = -float("Inf")
+        filtered[:, indices[mask]] = -float("Inf")

+        # sample next token
         next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
         generated = torch.cat([generated, next_token], dim=-1)
-        new_text = tokenizer.decode(next_token[0])
+        new_text = tokenizer.decode(next_token[0], skip_special_tokens=False)
         assistant_text += new_text

+        # strip opening assistant tag
         if assistant_text.startswith("<|assistant|>"):
             assistant_text = assistant_text[len("<|assistant|>"):]

-        # ⛔️ Stop if model starts new user turn
+        # stop if model begins a new user turn
         if "<|user|>" in new_text:
             break

@@ -78,55 +92,92 @@ def generate_stream(prompt, use_fallback=False, max_length=100, temperature=0.2,
         if next_token.item() == tokenizer.eos_token_id:
             break

-# Respond to incoming message
-def respond(message, history, reasoning_enabled, request: gr.Request):
-    ip = request.client.host if request else "unknown"
-    now = time.time()
-    info = usage_info.get(ip, {"count": 0, "last_seen": 0})
-    if now - info["last_seen"] > RESET_AFTER_SECONDS:
-        info["count"] = 0
-    info["count"] += 1
-    info["last_seen"] = now
-    usage_info[ip] = info
-
-    use_fallback = info["count"] > USAGE_LIMIT
-    remaining = max(0, USAGE_LIMIT - info["count"])
-    model_used = "A3" if not use_fallback else "Fallback A1"
-
-    prompt = build_chat_prompt(history, message.strip(), reasoning_enabled)
-    history = history + [[message, ""]]
-
-    for partial in generate_stream(prompt, use_fallback=use_fallback):
-        formatted = format_thinking(partial)
-        history[-1][1] = f"{formatted}<br><sub style='color:gray'>({model_used})</sub>"
-        yield history, history, f"🧠 A3 messages left: {remaining}"
+# ─── Main Respond Handler ─────────────────────────────────────────────────────
+def respond(user_msg, history, reasoning_enabled, limit_json):
+    # parse usage info from localStorage
+    info = json.loads(limit_json) if limit_json else {"count": 0}
+    count = info.get("count", 0)
+
+    use_fallback = count > USAGE_LIMIT
+    remaining = max(0, USAGE_LIMIT - count)
+    model_label = "A3" if not use_fallback else "Fallback A1"
+
+    # build prompt & init history
+    prompt = build_chat_prompt(history, user_msg.strip(), reasoning_enabled)
+    history = history + [[user_msg, ""]]
+
+    # stream assistant reply
+    for chunk in generate_stream(prompt, use_fallback=use_fallback):
+        formatted = format_thinking(chunk)
+        history[-1][1] = (
+            f"{formatted}<br><sub style='color:gray'>({model_label})</sub>"
+        )
+        # during streaming, show Generating
+        yield history, history, f"🧠 A3 left: {remaining}", "Generating..."
+
+    # final update: set status back to Idle
+    yield history, history, f"🧠 A3 left: {remaining}", "Idle"

 def clear_chat():
-    return [], [], "🧠 A3 messages left: 5"
+    return [], [], "🧠 A3 left: 5", "Idle"

-# UI Layout
+# ─── Gradio UI ────────────────────────────────────────────────────────────────
 with gr.Blocks() as demo:
-    gr.Markdown("# 🤖 SamAI – Chat Reasoning (Qwen-Style)")
-    model_status = gr.Textbox(interactive=False, label="Model Status")
-    usage_counter = gr.Textbox(value="🧠 A3 messages left: 5", interactive=False, show_label=False)
-    chat_box = gr.Chatbot(type="tuples")
-    chat_state = gr.State([])
+    gr.HTML(  # inject localStorage logic
+        """
+        <script>
+        function updateUsageLimit() {
+            const key = "samai_limit";
+            let now = Date.now();
+            let record = JSON.parse(localStorage.getItem(key) || "null");
+            if (!record || (now - record.lastSeen) > {RESET_MS}) {{
+                record = {{count: 0, lastSeen: now}};
+            }}
+            record.count += 1;
+            record.lastSeen = now;
+            localStorage.setItem(key, JSON.stringify(record));
+            return record;
+        }
+        </script>
+        """.replace("{RESET_MS}", str(RESET_MS))
+    )
+
+    gr.Markdown("# 🤖 SamAI – Qwen Chat with Client-Side Limits")
+
+    # hidden box to carry JSON string from JS → Python
+    limit_json = gr.Textbox(visible=False)
+    model_status = gr.Textbox(interactive=False, label="Model Status")
+    usage_counter = gr.Textbox("🧠 A3 left: 5", interactive=False, show_label=False)
+    status_display = gr.Textbox("Idle", interactive=False, label="Status")
+
+    chat_box = gr.Chatbot(type="tuples")
+    chat_state = gr.State([])

     with gr.Row():
-        user_input = gr.Textbox(placeholder="Ask anything...", show_label=False, scale=6)
-        reason_toggle = gr.Checkbox(label="Reason", value=True, scale=1)
-        send_btn = gr.Button("Send", scale=1)
+        user_input = gr.Textbox(placeholder="Ask me anything…", show_label=False, scale=6)
+        reason_toggle = gr.Checkbox(label="Reason", value=True, scale=1)
+        send_btn = gr.Button("Send", scale=1)

     clear_btn = gr.Button("Clear Chat")
+
     model_status.value = load_models()

+    # first: JS updates localStorage → limit_json
     send_btn.click(
-        respond,
-        inputs=[user_input, chat_state, reason_toggle],
-        outputs=[chat_box, chat_state, usage_counter]
+        fn=None,
+        _js="() => JSON.stringify(updateUsageLimit())",
+        outputs=[limit_json]
+    ).then(
+        # then: call our Python respond() with that JSON
+        fn=respond,
+        inputs=[user_input, chat_state, reason_toggle, limit_json],
+        outputs=[chat_box, chat_state, usage_counter, status_display]
     )

-    clear_btn.click(fn=clear_chat, inputs=[], outputs=[chat_box, chat_state, usage_counter])
+    clear_btn.click(fn=clear_chat,
+                    inputs=[],
+                    outputs=[chat_box, chat_state, usage_counter, status_display]
+    )

 demo.queue()
 demo.launch()
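
For reference, the sampling loop in the new generate_stream above performs nucleus (top-p) filtering by hand. Below is a minimal standalone sketch of that one step, assuming PyTorch and a single [1, vocab_size] logits row; sample_top_p is an illustrative name and is not part of this commit.

import torch

def sample_top_p(logits, temperature=0.2, top_p=0.9):
    """Nucleus (top-p) sampling for a single [1, vocab_size] logits row."""
    logits = logits / temperature
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

    # Mark everything outside the nucleus; shift right so the token that
    # crosses the threshold stays, and always keep the most likely token.
    mask = cumulative > top_p
    mask[..., 1:] = mask[..., :-1].clone()
    mask[..., 0] = False

    filtered = logits.clone()
    filtered[:, sorted_indices[mask]] = float("-inf")
    return torch.multinomial(torch.softmax(filtered, dim=-1), num_samples=1)

# Example: draw one token id from random logits over a toy 32-token vocabulary.
print(sample_top_p(torch.randn(1, 32)).item())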