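# SmilyAI reasoning chat Space: streams token-by-token replies from
# Sam-reason-A1 and, once a client IP exceeds USAGE_LIMIT requests,
# falls back to Sam-reason-S2.1.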
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Model definitions
PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A1"
FALLBACK_MODEL = "Smilyai-labs/Sam-reason-S2.1"
USAGE_LIMIT = 10

device = "cuda" if torch.cuda.is_available() else "cpu"

# Globals for models and tokenizers
primary_model, primary_tokenizer = None, None
fallback_model, fallback_tokenizer = None, None

# IP-based usage tracking
usage_counts = {}
def load_models():
    global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer
    primary_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL)
    primary_model = AutoModelForCausalLM.from_pretrained(PRIMARY_MODEL).to(device).eval()
    fallback_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
    fallback_model = AutoModelForCausalLM.from_pretrained(FALLBACK_MODEL).to(device).eval()
    return f"Models loaded: {PRIMARY_MODEL} + fallback {FALLBACK_MODEL}"
def generate_stream(prompt, use_fallback=False, max_length=100, temperature=0.7, top_p=0.9):
    model = fallback_model if use_fallback else primary_model
    tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    generated = input_ids
    output_text = tokenizer.decode(input_ids[0])
    for _ in range(max_length):
        with torch.no_grad():  # inference only; no autograd graph needed
            outputs = model(generated)
        logits = outputs.logits[:, -1, :] / temperature
        # Top-p (nucleus) filtering: drop tokens past the cumulative-probability cutoff.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        mask = probs > top_p
        # Shift the mask right so the first token above the cutoff is always kept.
        mask[..., 1:] = mask[..., :-1].clone()
        mask[..., 0] = False
        filtered = logits.clone()
        filtered[:, sorted_indices[mask]] = -float("Inf")  # assumes batch size 1
        next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
        generated = torch.cat([generated, next_token], dim=-1)
        new_text = tokenizer.decode(next_token[0])
        output_text += new_text
        yield output_text
        if next_token.item() == tokenizer.eos_token_id:
            break
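# respond: chat handler wired to the Send button. It counts requests per client
# IP and switches to the fallback model once an IP exceeds USAGE_LIMIT,
# labelling each reply with the model that produced it.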
def respond(msg, history, reasoning_enabled, request: gr.Request):
    ip = request.client.host if request else "unknown"
    usage_counts[ip] = usage_counts.get(ip, 0) + 1
    use_fallback = usage_counts[ip] > USAGE_LIMIT
    model_used = "A1" if not use_fallback else "Fallback S2.1"
    prefix = "/think " if reasoning_enabled else "/no_think "
    prompt = prefix + msg.strip()
    history = history + [[msg, ""]]
    for output in generate_stream(prompt, use_fallback):
        history[-1][1] = output + f" ({model_used})"
        yield history, history

def clear_chat():
    return [], []
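# UI layout: model-load status box, streaming chatbot, input row with a
# "Reason" toggle (prepends /think or /no_think), Send, and Clear buttons.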
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 SmilyAI Reasoning Chat • Token-by-Token + IP Usage Limits")
    model_status = gr.Textbox(label="Model Load Status", interactive=False)
    chat_box = gr.Chatbot(label="Chat", type="tuples")
    chat_state = gr.State([])

    with gr.Row():
        user_input = gr.Textbox(placeholder="Your message here...", show_label=False, scale=6)
        reason_toggle = gr.Checkbox(label="Reason", value=True, scale=1)
        send_btn = gr.Button("Send", scale=1)

    clear_btn = gr.Button("Clear Chat")

    # Load both models eagerly at startup and surface the result in the status box.
    model_status.value = load_models()

    send_btn.click(
        respond,
        inputs=[user_input, chat_state, reason_toggle],
        outputs=[chat_box, chat_state],
    )
    clear_btn.click(fn=clear_chat, inputs=[], outputs=[chat_box, chat_state])

# Queueing lets the generator-based handlers stream partial outputs to the client.
demo.queue()
demo.launch()
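# To try it outside of Spaces (assumption: this file is saved as app.py):
#   pip install gradio torch transformers
#   python app.py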