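# SamAI chat demo (Hugging Face Space app script).
# Streams replies from the primary Sam-reason-A3 model and switches to the smaller
# Sam-reason-A1 fallback once the per-browser usage limit is exhausted. A "Reason"
# toggle flips the system flag between /think and /no_think, and any
# <think>...</think> span is rendered behind a collapsible details block.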
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import re, time, json
from html import escape

PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A3"
FALLBACK_MODEL = "Smilyai-labs/Sam-reason-A1"
USAGE_LIMIT = 5
RESET_MS = 20 * 60 * 1000  # client-side usage counter resets after 20 minutes

device = "cuda" if torch.cuda.is_available() else "cpu"
primary_model = primary_tokenizer = None
fallback_model = fallback_tokenizer = None
def load_models():
    """Load the primary (A3) and fallback (A1) models once at startup."""
    global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer
    # Use fp16 only when a GPU is available; half-precision inference is generally
    # unsupported or very slow on CPU.
    dtype = torch.float16 if device == "cuda" else torch.float32
    primary_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL, trust_remote_code=True)
    primary_model = AutoModelForCausalLM.from_pretrained(PRIMARY_MODEL, torch_dtype=dtype).to(device).eval()
    fallback_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL, trust_remote_code=True)
    fallback_model = AutoModelForCausalLM.from_pretrained(FALLBACK_MODEL, torch_dtype=dtype).to(device).eval()
    return f"✅ Loaded {PRIMARY_MODEL} + fallback {FALLBACK_MODEL}"
def build_chat_prompt(history, user_input, reasoning_enabled):
    system_flag = "/think" if reasoning_enabled else "/no_think"
    prompt = f"<|system|>\n{system_flag}\n"
    for u, a in history:
        prompt += f"<|user|>\n{u}\n<|assistant|>\n{a}\n"
    prompt += f"<|user|>\n{user_input}\n<|assistant|>\n"
    return prompt
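# For illustration, with reasoning enabled and one prior turn the prompt looks like:
#   <|system|>
#   /think
#   <|user|>
#   Hi
#   <|assistant|>
#   Hello!
#   <|user|>
#   <current message>
#   <|assistant|>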
def format_thinking(text):
    """Move any <think>...</think> block into a collapsible 'Show reasoning' section."""
    match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
    if not match:
        return escape(text)
    reasoning = escape(match.group(1).strip())
    visible = re.sub(r"<think>.*?</think>", "[thinking...]", text, flags=re.DOTALL).strip()
    return escape(visible) + "<br><details><summary>🧠 Show reasoning</summary><pre>" + reasoning + "</pre></details>"
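# Example: "<think>add 2 and 2</think>4" is rendered as
# "[thinking...]4<br><details><summary>🧠 Show reasoning</summary><pre>add 2 and 2</pre></details>"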
def generate_stream(prompt, use_fallback=False, max_length=100, temperature=0.2, top_p=0.9):
    """Yield the assistant's reply token by token using temperature + nucleus (top-p) sampling."""
    model = fallback_model if use_fallback else primary_model
    tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    generated = input_ids
    assistant_text = ""
    with torch.no_grad():  # inference only; skip building autograd graphs
        for _ in range(max_length):
            logits = model(generated).logits[:, -1, :] / temperature
            # Nucleus (top-p) filtering: keep the smallest set of top tokens whose
            # cumulative probability exceeds top_p and mask out everything else.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
            mask = probs > top_p
            mask[..., 1:] = mask[..., :-1].clone()  # shift so the first token past the threshold is kept
            mask[..., 0] = 0
            filtered = logits.clone()
            filtered[:, sorted_indices[mask]] = -float("Inf")  # safe because the batch size is 1
            next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
            generated = torch.cat([generated, next_token], dim=-1)
            new_text = tokenizer.decode(next_token[0], skip_special_tokens=False)
            assistant_text += new_text
            if assistant_text.startswith("<|assistant|>"):
                assistant_text = assistant_text[len("<|assistant|>"):]
            if "<|user|>" in new_text:
                # The model started a new user turn; stop before leaking it.
                break
            yield assistant_text
            if next_token.item() == tokenizer.eos_token_id:
                break
def respond(message, history, reasoning_enabled, limit_json):
    """Stream a reply, switching to the fallback model once the A3 usage limit is exceeded."""
    info = json.loads(limit_json) if limit_json else {"count": 0}
    count = info.get("count", 0)
    use_fallback = count > USAGE_LIMIT
    remaining = max(0, USAGE_LIMIT - count)
    model_label = "A3" if not use_fallback else "Fallback A1"
    prompt = build_chat_prompt(history, message.strip(), reasoning_enabled)
    history = history + [[message, ""]]
    yield history, history, f"🧠 A3 left: {remaining}", "Generating…"
    for chunk in generate_stream(prompt, use_fallback=use_fallback):
        formatted = format_thinking(chunk)
        history[-1][1] = f"{formatted}<br><sub style='color:gray'>({model_label})</sub>"
        yield history, history, f"🧠 A3 left: {remaining}", "Generating…"
    yield history, history, f"🧠 A3 left: {remaining}", "Send"


def clear_chat():
    return [], [], f"🧠 A3 left: {USAGE_LIMIT}", "Send"
with gr.Blocks() as demo:
    # Client-side helpers: per-browser usage counting in localStorage plus
    # cosmetic tweaks to the Send button.
    gr.HTML(f"""
    <script>
    function updateUsageLimit() {{
        let key = "samai_limit";
        let now = Date.now();
        let record = JSON.parse(localStorage.getItem(key) || "null");
        if (!record || (now - record.lastSeen) > {RESET_MS}) {{
            record = {{count: 0, lastSeen: now}};
        }}
        record.count += 1;
        record.lastSeen = now;
        localStorage.setItem(key, JSON.stringify(record));
        document.getElementById("limit_json").value = JSON.stringify(record);
    }}
    function setGeneratingText() {{
        document.getElementById("send_btn").innerText = "Generating…";
    }}
    function setIdleText() {{
        document.getElementById("send_btn").innerText = "Send";
    }}
    </script>
    <style>
    .send-circle {{
        border-radius: 50%;
        height: 40px;
        width: 40px;
        padding: 0;
        font-size: 12px;
        text-align: center;
    }}
    </style>
    """)

    gr.Markdown("# 🤖 SamAI – Chat Reasoning (Gradio v3 Compatible)")

    limit_json = gr.Textbox(visible=False, elem_id="limit_json")
    model_status = gr.Textbox(interactive=False, label="Model Status")
    usage_counter = gr.Textbox(f"🧠 A3 left: {USAGE_LIMIT}", interactive=False, show_label=False)
    chat_box = gr.Chatbot()  # Gradio 3's Chatbot expects (user, assistant) tuples; the `type` kwarg is a newer-Gradio option
    chat_state = gr.State([])

    with gr.Row():
        user_input = gr.Textbox(placeholder="Ask anything...", show_label=False, scale=6)
        reason_toggle = gr.Checkbox(label="Reason", value=True, scale=1)
        send_btn = gr.Button("Send", elem_id="send_btn", elem_classes=["send-circle"], scale=1)

    update_btn = gr.Button(visible=False)  # hidden trigger for the client-side usage bookkeeping
    clear_btn = gr.Button("Clear")

    # Load both models eagerly while the Blocks context is being built.
    model_status.value = load_models()
    update_btn.click(None, _js="updateUsageLimit")
    send_btn.click(None, _js="setGeneratingText").then(
        fn=respond,
        inputs=[user_input, chat_state, reason_toggle, limit_json],
        outputs=[chat_box, chat_state, usage_counter, send_btn],
    ).then(fn=None, _js="setIdleText")
    clear_btn.click(
        fn=clear_chat,
        inputs=[],
        outputs=[chat_box, chat_state, usage_counter, send_btn],
    )

demo.queue()
demo.launch()
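# Assumed runtime dependencies (not pinned anywhere in this script): gradio 3.x
# (the `_js` event parameter used above is Gradio 3 API; newer releases renamed it),
# torch, and transformers.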