import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import re, json
from html import escape
# ─── Configuration ─────────────────────────────────────────────────────────
PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A3"
FALLBACK_MODEL = "Smilyai-labs/Sam-reason-A1"
USAGE_LIMIT = 5
RESET_MS = 20 * 60 * 1000 # 20 minutes in milliseconds
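# A3 is served for the first USAGE_LIMIT requests per browser; after that the
# app silently switches to the fallback model until RESET_MS of inactivity has
# passed (tracked client-side in localStorage, see the injected JS below).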
device = "cuda" if torch.cuda.is_available() else "cpu"
primary_model = primary_tokenizer = None
fallback_model = fallback_tokenizer = None
# ─── Model Loading ──────────────────────────────────────────────────────────
def load_models():
    global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer
    # fp16 only makes sense on GPU; fall back to fp32 on CPU
    dtype = torch.float16 if device == "cuda" else torch.float32
    primary_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL, trust_remote_code=True)
    primary_model = AutoModelForCausalLM.from_pretrained(
        PRIMARY_MODEL, torch_dtype=dtype
    ).to(device).eval()
    fallback_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL, trust_remote_code=True)
    fallback_model = AutoModelForCausalLM.from_pretrained(
        FALLBACK_MODEL, torch_dtype=dtype
    ).to(device).eval()
    return f"✅ Loaded {PRIMARY_MODEL} (fallback: {FALLBACK_MODEL})"
# ─── Build Chat Prompt ──────────────────────────────────────────────────────
def build_chat_prompt(history, user_input, reasoning_enabled):
    system_flag = "/think" if reasoning_enabled else "/no_think"
    prompt = f"<|system|>\n{system_flag}\n"
    for u, a in history:
        prompt += f"<|user|>\n{u}\n<|assistant|>\n{a}\n"
    prompt += f"<|user|>\n{user_input}\n<|assistant|>\n"
    return prompt
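# NOTE: the <|system|>/<|user|>/<|assistant|> role tags and the /think flag are
# assumed to match the chat format Sam-reason was trained on. For one prior turn
# ("Hi" -> "Hello!") and a new message "What is 2+2?", the prompt becomes:
#   <|system|>\n/think\n<|user|>\nHi\n<|assistant|>\nHello!\n<|user|>\nWhat is 2+2?\n<|assistant|>\n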
# ─── Collapse <think> Blocks ────────────────────────────────────────────────
def format_thinking(text):
    match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
    if not match:
        return escape(text)
    reasoning = escape(match.group(1).strip())
    visible = re.sub(r"<think>.*?</think>", "[thinking...]", text, flags=re.DOTALL).strip()
    return (
        escape(visible)
        + "<br><details><summary>🧠 Show reasoning</summary>"
        + f"<pre>{reasoning}</pre></details>"
    )
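# Example: "<think>4 = 2 + 2</think>The answer is 4." renders as
# "[thinking...]The answer is 4." followed by a collapsible <details> block
# containing the escaped reasoning text.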
# ─── Token-by-Token Streaming (Stops on <|user|>) ───────────────────────────
def generate_stream(prompt, use_fallback=False,
                    max_length=100, temperature=0.2, top_p=0.9):
    model = fallback_model if use_fallback else primary_model
    tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    generated = input_ids
    assistant_text = ""
    for _ in range(max_length):
        # 1) Get next-token logits (no grad needed for inference) and apply top-p
        with torch.no_grad():
            logits = model(generated).logits[:, -1, :] / temperature
        sorted_logits, idxs = torch.sort(logits, descending=True)
        probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        mask = probs > top_p
        # shift right so the first token that crosses the threshold is kept
        mask[..., 1:] = mask[..., :-1].clone()
        mask[..., 0] = False
        filtered = logits.clone()
        filtered[:, idxs[mask]] = -float("Inf")  # batch size is always 1 here
        # 2) Sample and append
        next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
        generated = torch.cat([generated, next_token], dim=-1)
        new_text = tokenizer.decode(next_token[0], skip_special_tokens=False)
        assistant_text += new_text
        # 3) Remove any leading assistant tag
        if assistant_text.startswith("<|assistant|>"):
            assistant_text = assistant_text[len("<|assistant|>"):]
        # 4) If we see a user-turn tag, truncate and bail
        if "<|user|>" in assistant_text:
            assistant_text = assistant_text.split("<|user|>")[0]
            yield assistant_text
            break
        # 5) Otherwise stream clean assistant text
        yield assistant_text
        # 6) End if EOS
        if next_token.item() == tokenizer.eos_token_id:
            break
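# Illustrative usage (not executed here; assumes load_models() has already run):
#   for partial in generate_stream("<|system|>\n/think\n<|user|>\nHi\n<|assistant|>\n"):
#       print(partial)
# Each yield is the full assistant text so far, so the UI re-renders the growing
# message instead of appending deltas.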
# ─── Main Chat Handler ──────────────────────────────────────────────────────
def respond(message, history, reasoning_enabled, limit_json):
    # parse client-side usage info
    info = json.loads(limit_json) if limit_json else {"count": 0}
    count = info.get("count", 0)
    use_fallback = count > USAGE_LIMIT
    remaining = max(0, USAGE_LIMIT - count)
    model_label = "A3" if not use_fallback else "Fallback A1"
    # initial yield to set "Generating…"
    prompt = build_chat_prompt(history, message.strip(), reasoning_enabled)
    history = history + [[message, ""]]
    yield history, history, f"🧠 A3 left: {remaining}", "Generating…"
    # stream assistant reply
    for chunk in generate_stream(prompt, use_fallback):
        formatted = format_thinking(chunk)
        history[-1][1] = (
            f"{formatted}<br><sub style='color:gray'>({model_label})</sub>"
        )
        yield history, history, f"🧠 A3 left: {remaining}", "Generating…"
    # final yield resets button text
    yield history, history, f"🧠 A3 left: {remaining}", "Send"
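# Each tuple yielded by respond() maps to the four outputs bound to send_btn
# below: (chat_box, chat_state, usage_counter, send_btn label).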
# ─── Clear Chat ─────────────────────────────────────────────────────────────
def clear_chat():
    return [], [], f"🧠 A3 left: {USAGE_LIMIT}", "Send"
# ─── Gradio UI ──────────────────────────────────────────────────────────────
with gr.Blocks() as demo:
    # Inject client-side JS + CSS.
    # NOTE: depending on the Gradio version, <script> tags inside gr.HTML may be
    # sanitized or injected after DOMContentLoaded has fired; if the counter
    # never updates, the same JS can be passed via gr.Blocks(js=...) instead.
    gr.HTML(f"""
    <script>
    function updateUsageLimit() {{
        const key = "samai_limit";
        const now = Date.now();
        let rec = JSON.parse(localStorage.getItem(key) || "null");
        if (!rec || now - rec.lastSeen > {RESET_MS}) {{
            rec = {{count: 0, lastSeen: now}};
        }}
        rec.count += 1;
        rec.lastSeen = now;
        localStorage.setItem(key, JSON.stringify(rec));
        // Gradio puts elem_id on the component wrapper, so target the inner
        // field and fire an input event so the new value is picked up.
        const box = document.querySelector("#limit_json textarea, #limit_json input");
        if (box) {{
            box.value = JSON.stringify(rec);
            box.dispatchEvent(new Event("input", {{bubbles: true}}));
        }}
    }}
    document.addEventListener("DOMContentLoaded", () => {{
        const btn = document.getElementById("send_btn");
        btn.addEventListener("click", () => {{
            updateUsageLimit();
            btn.innerText = "Generating…";
        }});
    }});
    </script>
    <style>
    .send-circle {{
        border-radius: 50%;
        height: 40px;
        width: 40px;
        padding: 0;
        font-size: 12px;
        text-align: center;
    }}
    </style>
    """)
    gr.Markdown("# 🤖 SamAI - Chat Reasoning (Final)")

    # Hidden textbox ferrying usage JSON from JS → Python
    limit_json = gr.Textbox(visible=False, elem_id="limit_json")
    model_status = gr.Textbox(interactive=False, label="Model Status")
    usage_counter = gr.Textbox(f"🧠 A3 left: {USAGE_LIMIT}", interactive=False, show_label=False)
    chat_box = gr.Chatbot(type="tuples")
    chat_state = gr.State([])

    with gr.Row():
        user_input = gr.Textbox(placeholder="Ask anything...", show_label=False, scale=6)
        reason_toggle = gr.Checkbox(label="Reason", value=True, scale=1)
        send_btn = gr.Button("Send", elem_id="send_btn", elem_classes=["send-circle"], scale=1)
    clear_btn = gr.Button("Clear")

    model_status.value = load_models()
    # Bind Send button -> respond()
    send_btn.click(
        fn=respond,
        inputs=[user_input, chat_state, reason_toggle, limit_json],
        outputs=[chat_box, chat_state, usage_counter, send_btn]
    )
    clear_btn.click(
        fn=clear_chat,
        inputs=[],
        outputs=[chat_box, chat_state, usage_counter, send_btn]
    )

demo.queue()
demo.launch()