Boning c committed on
Commit b612b32 · verified · 1 Parent(s): 7a55e85

Update app.py

Files changed (1)
  1. app.py +52 -66
app.py CHANGED
@@ -4,26 +4,23 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
  import re, time, json
  from html import escape

- # ─── Config ───────────────────────────────────────────────────
  PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A3"
  FALLBACK_MODEL = "Smilyai-labs/Sam-reason-A1"
  USAGE_LIMIT = 5
- RESET_MS = 20 * 60 * 1000 # 20 min in ms
+ RESET_MS = 20 * 60 * 1000
  device = "cuda" if torch.cuda.is_available() else "cpu"

  primary_model = primary_tokenizer = None
  fallback_model = fallback_tokenizer = None

- # ─── Load Models ───────────────────────────────────────────────
  def load_models():
      global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer
      primary_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL, trust_remote_code=True)
      primary_model = AutoModelForCausalLM.from_pretrained(PRIMARY_MODEL, torch_dtype=torch.float16).to(device).eval()
      fallback_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL, trust_remote_code=True)
      fallback_model = AutoModelForCausalLM.from_pretrained(FALLBACK_MODEL, torch_dtype=torch.float16).to(device).eval()
-     return f"✅ Loaded {PRIMARY_MODEL} with fallback {FALLBACK_MODEL}"
+     return f"✅ Loaded {PRIMARY_MODEL} + fallback {FALLBACK_MODEL}"

- # ─── Prompt Builder ────────────────────────────────────────────
  def build_chat_prompt(history, user_input, reasoning_enabled):
      system_flag = "/think" if reasoning_enabled else "/no_think"
      prompt = f"<|system|>\n{system_flag}\n"
@@ -32,7 +29,6 @@ def build_chat_prompt(history, user_input, reasoning_enabled):
      prompt += f"<|user|>\n{user_input}\n<|assistant|>\n"
      return prompt

- # ─── Collapse <think> blocks ──────────────────────────────────
  def format_thinking(text):
      match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
      if not match:
@@ -41,122 +37,112 @@ def format_thinking(text):
      visible = re.sub(r"<think>.*?</think>", "[thinking...]", text, flags=re.DOTALL).strip()
      return escape(visible) + "<br><details><summary>🧠 Show reasoning</summary><pre>" + reasoning + "</pre></details>"

- # ─── Stream Generator ─────────────────────────────────────────
  def generate_stream(prompt, use_fallback=False, max_length=100, temperature=0.2, top_p=0.9):
      model = fallback_model if use_fallback else primary_model
      tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer
      input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
      generated = input_ids
      assistant_text = ""
-
      for _ in range(max_length):
          logits = model(generated).logits[:, -1, :] / temperature
          sorted_logits, sorted_indices = torch.sort(logits, descending=True)
          probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
-
          mask = probs > top_p
          mask[..., 1:] = mask[..., :-1].clone()
          mask[..., 0] = 0
          filtered = logits.clone()
          filtered[:, sorted_indices[mask]] = -float("Inf")
-
          next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
          generated = torch.cat([generated, next_token], dim=-1)
          new_text = tokenizer.decode(next_token[0], skip_special_tokens=False)
          assistant_text += new_text
-
          if assistant_text.startswith("<|assistant|>"):
              assistant_text = assistant_text[len("<|assistant|>"):]
-
          if "<|user|>" in new_text:
              break
-
          yield assistant_text
-
          if next_token.item() == tokenizer.eos_token_id:
              break

- # ─── Respond Handler ──────────────────────────────────────────
  def respond(message, history, reasoning_enabled, limit_json):
      info = json.loads(limit_json) if limit_json else {"count": 0}
      count = info.get("count", 0)
-
      use_fallback = count > USAGE_LIMIT
      remaining = max(0, USAGE_LIMIT - count)
      model_label = "A3" if not use_fallback else "Fallback A1"
      prompt = build_chat_prompt(history, message.strip(), reasoning_enabled)
      history = history + [[message, ""]]
-
      yield history, history, f"🧠 A3 left: {remaining}", "Generating…"
-
      for chunk in generate_stream(prompt, use_fallback=use_fallback):
          formatted = format_thinking(chunk)
          history[-1][1] = f"{formatted}<br><sub style='color:gray'>({model_label})</sub>"
          yield history, history, f"🧠 A3 left: {remaining}", "Generating…"
-
      yield history, history, f"🧠 A3 left: {remaining}", "Send"

  def clear_chat():
      return [], [], "🧠 A3 left: 5", "Send"

- # ─── Gradio UI ────────────────────────────────────────────────
  with gr.Blocks() as demo:
-     gr.HTML(
-         """
-         <script>
-         function updateUsageLimit() {
-             let key = "samai_limit";
-             let now = Date.now();
-             let record = JSON.parse(localStorage.getItem(key) || "null");
-             if (!record || (now - record.lastSeen) > """ + str(RESET_MS) + """) {
-                 record = {count: 0, lastSeen: now};
-             }
-             record.count += 1;
-             record.lastSeen = now;
-             localStorage.setItem(key, JSON.stringify(record));
-             return record;
-         }
-         </script>
-         <style>
-         .send-circle {
-             border-radius: 50%;
-             height: 40px;
-             width: 40px;
-             padding: 0;
-             font-size: 12px;
-             text-align: center;
-         }
-         </style>
-         """
-     )
-
-     gr.Markdown("# 🤖 SamAI – Chat Reasoning UI")
-
-     limit_json = gr.Textbox(visible=False)
-     model_status = gr.Textbox(interactive=False, label="Model Status")
-     usage_counter = gr.Textbox(value="🧠 A3 left: 5", interactive=False, show_label=False)
-
-     chat_box = gr.Chatbot(type="tuples")
-     chat_state= gr.State([])
+     gr.HTML(f"""
+         <script>
+         function updateUsageLimit() {{
+             let key = "samai_limit";
+             let now = Date.now();
+             let record = JSON.parse(localStorage.getItem(key) || "null");
+             if (!record || (now - record.lastSeen) > {RESET_MS}) {{
+                 record = {{count: 0, lastSeen: now}};
+             }}
+             record.count += 1;
+             record.lastSeen = now;
+             localStorage.setItem(key, JSON.stringify(record));
+             document.getElementById("limit_json").value = JSON.stringify(record);
+         }}
+
+         function setGeneratingText() {{
+             document.getElementById("send_btn").innerText = "Generating…";
+         }}
+         function setIdleText() {{
+             document.getElementById("send_btn").innerText = "Send";
+         }}
+         </script>
+         <style>
+         .send-circle {{
+             border-radius: 50%;
+             height: 40px;
+             width: 40px;
+             padding: 0;
+             font-size: 12px;
+             text-align: center;
+         }}
+         </style>
+     """)
+
+     gr.Markdown("# 🤖 SamAI – Chat Reasoning (Gradio v3 Compatible)")
+     limit_json = gr.Textbox(visible=False, elem_id="limit_json")
+     model_status = gr.Textbox(interactive=False, label="Model Status")
+     usage_counter = gr.Textbox("🧠 A3 left: 5", interactive=False, show_label=False)
+
+     chat_box = gr.Chatbot(type="tuples")
+     chat_state = gr.State([])

      with gr.Row():
-         user_input = gr.Textbox(placeholder="Ask anything...", show_label=False, scale=6)
-         reason_toggle= gr.Checkbox(label="Reason", value=True, scale=1)
-         send_btn = gr.Button("Send", elem_classes=["send-circle"], scale=1)
+         user_input = gr.Textbox(placeholder="Ask anything...", show_label=False, scale=6)
+         reason_toggle = gr.Checkbox(label="Reason", value=True, scale=1)
+         send_btn = gr.Button("Send", elem_id="send_btn", elem_classes=["send-circle"], scale=1)
+
+     update_btn = gr.Button(visible=False)

      clear_btn = gr.Button("Clear")

      model_status.value = load_models()

-     send_btn.click(
-         None,
-         _js="() => JSON.stringify(updateUsageLimit())",
-         outputs=[limit_json]
-     ).then(
+     update_btn.click(None, _js="updateUsageLimit")
+
+     send_btn.click(None, _js="setGeneratingText").then(
          fn=respond,
          inputs=[user_input, chat_state, reason_toggle, limit_json],
          outputs=[chat_box, chat_state, usage_counter, send_btn]
-     )
+     ).then(fn=None, _js="setIdleText")

      clear_btn.click(fn=clear_chat,
          inputs=[],
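
Note on the unchanged decoding loop: generate_stream still samples each token with a manual top-p (nucleus) filter before torch.multinomial. The snippet below is a minimal, self-contained sketch of that filter, restating the logic from app.py; the function name top_p_filter and the toy logits are illustrative only and do not appear in the source file.

# Sketch of the per-step top-p (nucleus) filter used in generate_stream.
# Assumes batch size 1, as in app.py; logits are already divided by temperature.
import torch

def top_p_filter(logits: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    # Sort next-token scores in descending order and take the cumulative probability.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

    # Mask everything after the point where cumulative probability first exceeds top_p;
    # the shift keeps the token that crosses the threshold, so the nucleus is never empty.
    mask = cumulative > top_p
    mask[..., 1:] = mask[..., :-1].clone()
    mask[..., 0] = False

    # Set masked vocabulary entries to -inf so they get zero probability after softmax.
    filtered = logits.clone()
    filtered[:, sorted_indices[mask]] = -float("inf")
    return filtered

if __name__ == "__main__":
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]]) / 0.2  # temperature 0.2, toy vocab of 5
    filtered = top_p_filter(logits, top_p=0.9)
    next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
    print(filtered, next_token)

With the low temperature used in app.py (0.2), most of the probability mass concentrates on the top-ranked token, so the nucleus is small and multinomial sampling behaves close to greedy decoding.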