import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import re, json
from html import escape

# ─── Configuration ──────────────────────────────────────────────────────────
PRIMARY_MODEL  = "Smilyai-labs/Sam-reason-A3"
FALLBACK_MODEL = "Smilyai-labs/Sam-reason-A1"
USAGE_LIMIT = 5
RESET_MS = 20 * 60 * 1000  # 20 minutes in milliseconds

device = "cuda" if torch.cuda.is_available() else "cpu"
primary_model = primary_tokenizer = None
fallback_model = fallback_tokenizer = None
# ─── Model Loading ──────────────────────────────────────────────────────────
def load_models():
    global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer
    primary_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL, trust_remote_code=True)
    primary_model = AutoModelForCausalLM.from_pretrained(
        PRIMARY_MODEL, torch_dtype=torch.float16
    ).to(device).eval()
    fallback_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL, trust_remote_code=True)
    fallback_model = AutoModelForCausalLM.from_pretrained(
        FALLBACK_MODEL, torch_dtype=torch.float16
    ).to(device).eval()
    return f"✅ Loaded {PRIMARY_MODEL} (fallback: {FALLBACK_MODEL})"
# ─── Build Chat Prompt ──────────────────────────────────────────────────────
def build_chat_prompt(history, user_input, reasoning_enabled):
    system_flag = "/think" if reasoning_enabled else "/no_think"
    prompt = f"<|system|>\n{system_flag}\n"
    for u, a in history:
        prompt += f"<|user|>\n{u}\n<|assistant|>\n{a}\n"
    prompt += f"<|user|>\n{user_input}\n<|assistant|>\n"
    return prompt
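
# Illustrative only (example values are made up): for history = [("Hi", "Hello!")]
# and user_input = "How are you?" with reasoning enabled, build_chat_prompt returns:
#
#   <|system|>
#   /think
#   <|user|>
#   Hi
#   <|assistant|>
#   Hello!
#   <|user|>
#   How are you?
#   <|assistant|>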
# ─── Collapse <think> Blocks ────────────────────────────────────────────────
def format_thinking(text):
    match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
    if not match:
        return escape(text)
    reasoning = escape(match.group(1).strip())
    visible = re.sub(r"<think>.*?</think>", "[thinking...]", text, flags=re.DOTALL).strip()
    return (
        escape(visible)
        + "<br><details><summary>🧠 Show reasoning</summary>"
        + f"<pre>{reasoning}</pre></details>"
    )
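
# Illustrative only (made-up input): format_thinking("<think>2+2=4</think>The answer is 4.")
# returns
#   "[thinking...]The answer is 4.<br><details><summary>🧠 Show reasoning</summary><pre>2+2=4</pre></details>"
# i.e. the model's reasoning is tucked behind a collapsible <details> element while the
# visible reply stays inline in the chat bubble.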
# ─── Token-by-Token Streaming (Stops on <|user|>) ───────────────────────────
def generate_stream(prompt, use_fallback=False,
                    max_length=100, temperature=0.2, top_p=0.9):
    model = fallback_model if use_fallback else primary_model
    tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    generated = input_ids
    assistant_text = ""
    for _ in range(max_length):
        # 1) Get next-token logits and apply top-p (nucleus) filtering
        with torch.no_grad():  # inference only, no gradients needed
            logits = model(generated).logits[:, -1, :] / temperature
        sorted_logits, idxs = torch.sort(logits, descending=True)
        probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        mask = probs > top_p
        mask[..., 1:] = mask[..., :-1].clone()  # shift right so the token crossing top_p survives
        mask[..., 0] = 0                        # always keep the most likely token
        filtered = logits.clone()
        filtered[:, idxs[mask]] = -float("Inf")  # assumes batch size 1
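        # Worked example with made-up numbers: if the sorted probabilities are
        # [0.5, 0.3, 0.15, 0.05] and top_p = 0.9, the cumulative sums are
        # [0.5, 0.8, 0.95, 1.0], so the raw mask is [F, F, T, T]; after the shift
        # it becomes [F, F, F, T], and only the last token is filtered out,
        # leaving the first three candidates for sampling.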
        # 2) Sample and append
        next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
        generated = torch.cat([generated, next_token], dim=-1)
        new_text = tokenizer.decode(next_token[0], skip_special_tokens=False)
        assistant_text += new_text
        # 3) Remove any leading assistant tag
        if assistant_text.startswith("<|assistant|>"):
            assistant_text = assistant_text[len("<|assistant|>"):]
        # 4) If we see a user-turn tag, truncate and bail
        if "<|user|>" in assistant_text:
            assistant_text = assistant_text.split("<|user|>")[0]
            yield assistant_text
            break
        # 5) Otherwise stream the clean assistant text so far
        yield assistant_text
        # 6) End if EOS
        if next_token.item() == tokenizer.eos_token_id:
            break
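
# Illustrative only (not executed here; "Hello" is a made-up prompt): generate_stream
# is a generator, so callers iterate over it and receive a progressively longer reply,
# which is exactly how respond() below consumes it:
#
#     for partial in generate_stream(build_chat_prompt([], "Hello", True)):
#         ...  # each `partial` is the full assistant reply decoded so far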
# ─── Main Chat Handler ──────────────────────────────────────────────────────
def respond(message, history, reasoning_enabled, limit_json):
    # parse client-side usage info
    info = json.loads(limit_json) if limit_json else {"count": 0}
    count = info.get("count", 0)
    use_fallback = count > USAGE_LIMIT
    remaining = max(0, USAGE_LIMIT - count)
    model_label = "A3" if not use_fallback else "Fallback A1"

    # initial yield to set "Generating…"
    prompt = build_chat_prompt(history, message.strip(), reasoning_enabled)
    history = history + [[message, ""]]
    yield history, history, f"🧠 A3 left: {remaining}", "Generating…"

    # stream assistant reply
    for chunk in generate_stream(prompt, use_fallback):
        formatted = format_thinking(chunk)
        history[-1][1] = (
            f"{formatted}<br><sub style='color:gray'>({model_label})</sub>"
        )
        yield history, history, f"🧠 A3 left: {remaining}", "Generating…"

    # final yield resets button text
    yield history, history, f"🧠 A3 left: {remaining}", "Send"
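
# Each yield above maps positionally onto outputs=[chat_box, chat_state, usage_counter,
# send_btn] in the send_btn.click() binding at the bottom of the file.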
# ─── Clear Chat ─────────────────────────────────────────────────────────────
def clear_chat():
    return [], [], "🧠 A3 left: 5", "Send"
# ─── Gradio UI ──────────────────────────────────────────────────────────────
with gr.Blocks() as demo:
    # Inject client-side JS + CSS
    gr.HTML(f"""
    <script>
      function updateUsageLimit() {{
        const key = "samai_limit";
        const now = Date.now();
        let rec = JSON.parse(localStorage.getItem(key) || "null");
        if (!rec || now - rec.lastSeen > {RESET_MS}) {{
          rec = {{count: 0, lastSeen: now}};
        }}
        rec.count += 1;
        rec.lastSeen = now;
        localStorage.setItem(key, JSON.stringify(rec));
        document.getElementById("limit_json").value = JSON.stringify(rec);
      }}
      document.addEventListener("DOMContentLoaded", () => {{
        const btn = document.getElementById("send_btn");
        btn.addEventListener("click", () => {{
          updateUsageLimit();
          btn.innerText = "Generating…";
        }});
      }});
    </script>
    <style>
      .send-circle {{
        border-radius: 50%;
        height: 40px;
        width: 40px;
        padding: 0;
        font-size: 12px;
        text-align: center;
      }}
    </style>
    """)
    gr.Markdown("# 🤖 SamAI – Chat Reasoning (Final)")

    # Hidden textbox ferrying usage JSON from JS → Python
    limit_json = gr.Textbox(visible=False, elem_id="limit_json")

    model_status = gr.Textbox(interactive=False, label="Model Status")
    usage_counter = gr.Textbox("🧠 A3 left: 5", interactive=False, show_label=False)
    chat_box = gr.Chatbot(type="tuples")
    chat_state = gr.State([])

    with gr.Row():
        user_input = gr.Textbox(placeholder="Ask anything...", show_label=False, scale=6)
        reason_toggle = gr.Checkbox(label="Reason", value=True, scale=1)
        send_btn = gr.Button("Send", elem_id="send_btn", elem_classes=["send-circle"], scale=1)
    clear_btn = gr.Button("Clear")

    model_status.value = load_models()
    # Bind Send button -> respond()
    send_btn.click(
        fn=respond,
        inputs=[user_input, chat_state, reason_toggle, limit_json],
        outputs=[chat_box, chat_state, usage_counter, send_btn]
    )
    clear_btn.click(
        fn=clear_chat,
        inputs=[],
        outputs=[chat_box, chat_state, usage_counter, send_btn]
    )

demo.queue()
demo.launch()