Sam-chat-full / app.py
Boning c
Update app.py
f2a079c verified
raw
history blame
3.55 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Hugging Face Hub model ids. respond() routes a client's first USAGE_LIMIT
# messages to the primary model and every later message to the fallback.
PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A1"
FALLBACK_MODEL = "Smilyai-labs/Sam-reason-S2.1"
USAGE_LIMIT = 10

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Filled in by load_models(); None until it has run.
primary_model = primary_tokenizer = None
fallback_model = fallback_tokenizer = None

# Message count per client host address. Kept in process memory only, so it
# resets whenever the app restarts.
usage_counts = {}
def load_models():
    """Load the primary and fallback tokenizer/model pairs into module globals.

    Each model is moved to ``device`` and switched to eval mode. The primary
    pair is loaded before the fallback pair, and within each pair the
    tokenizer is loaded before the model.

    Returns:
        A human-readable status string naming both loaded checkpoints.
    """
    global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer

    def _load_pair(checkpoint):
        # Tokenizer first, then the model, matching the original load order.
        tok = AutoTokenizer.from_pretrained(checkpoint)
        mdl = AutoModelForCausalLM.from_pretrained(checkpoint).to(device).eval()
        return tok, mdl

    primary_tokenizer, primary_model = _load_pair(PRIMARY_MODEL)
    fallback_tokenizer, fallback_model = _load_pair(FALLBACK_MODEL)
    return f"Models loaded: {PRIMARY_MODEL} + fallback {FALLBACK_MODEL}"
def _sample_top_p(logits, top_p):
    """Sample one token id per row from the nucleus (top-p) of ``logits``.

    Keeps the smallest set of highest-probability tokens whose cumulative
    probability exceeds ``top_p`` (always at least the single best token),
    masks the rest to -inf and draws with ``torch.multinomial``.

    Args:
        logits: 2-D tensor of (already temperature-scaled) logits, one row
            per sequence.
        top_p: Probability mass to keep.

    Returns:
        Tensor of shape ``(rows, 1)`` holding the sampled token ids.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    remove_sorted = cumulative > top_p
    # Shift right so the token that first crosses the threshold is kept,
    # and never remove the most likely token.
    remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
    remove_sorted[..., 0] = False
    # Map the mask from sorted order back to vocabulary order, per row
    # (correct for any batch size, not just 1).
    remove = remove_sorted.scatter(-1, sorted_indices, remove_sorted)
    filtered = logits.masked_fill(remove, float("-inf"))
    return torch.multinomial(torch.softmax(filtered, dim=-1), 1)


def generate_stream(prompt, use_fallback=False, max_length=100, temperature=0.7, top_p=0.9):
    """Stream a sampled continuation of ``prompt``, one token at a time.

    Requires load_models() to have populated the model/tokenizer globals.

    Args:
        prompt: Text fed to the model.
        use_fallback: If True, use the fallback model/tokenizer pair instead
            of the primary pair.
        max_length: Maximum number of new tokens to sample (the prompt is not
            counted).
        temperature: Logits are divided by this before top-p sampling. Values
            <= 0 switch to greedy (argmax) decoding instead of dividing by zero.
        top_p: Nucleus-sampling probability mass to keep.

    Yields:
        After every new token, the decoded text of the prompt plus all tokens
        generated so far. ``tokenizer.decode`` is called with its defaults,
        so special tokens (e.g. EOS) can appear in the text. Generation stops
        after the tokenizer's EOS token is produced.
    """
    model = fallback_model if use_fallback else primary_model
    tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer

    generated = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    next_input = generated
    past_key_values = None

    for _ in range(max_length):
        # no_grad wraps only the forward pass: grad mode is thread-local and
        # must not leak across the yield below into the consumer's code.
        with torch.no_grad():
            outputs = model(
                input_ids=next_input,
                past_key_values=past_key_values,
                use_cache=True,
            )
        # Reuse the KV cache so each step only feeds the newest token instead
        # of re-running the whole sequence (O(n) instead of O(n^2) overall).
        past_key_values = outputs.past_key_values
        logits = outputs.logits[:, -1, :]

        if temperature <= 0:
            next_token = logits.argmax(dim=-1, keepdim=True)
        else:
            next_token = _sample_top_p(logits / temperature, top_p)

        generated = torch.cat([generated, next_token], dim=-1)
        next_input = next_token
        # Decode the whole sequence rather than concatenating per-token
        # decodes, which loses word-boundary spaces and splits multi-byte
        # characters across tokens.
        yield tokenizer.decode(generated[0])
        if next_token.item() == tokenizer.eos_token_id:
            break
def respond(msg, history, reasoning_enabled, request: gr.Request):
    """Handle one chat turn: meter usage per client, pick a model, stream the reply.

    The first USAGE_LIMIT non-empty messages from a client host go to the
    primary model; later ones go to the fallback model. Blank messages are
    ignored and do not count against the limit.

    Args:
        msg: The user's message text.
        history: Chat history as a list of ``[user, bot]`` pairs.
        reasoning_enabled: If True the prompt is prefixed with "/think ",
            otherwise with "/no_think ".
        request: Injected by Gradio; only the client host is read from it.

    Yields:
        ``(history, history)`` after every streamed token, so the Chatbot
        and the State component both receive the updated history. The bot
        text is the model output followed by the name of the model used.
    """
    text = (msg or "").strip()
    if not text:
        # Nothing to send: leave the chat unchanged and don't consume quota.
        yield history, history
        return

    client = request.client if request else None
    # NOTE(review): behind a reverse proxy this may be the proxy's address
    # rather than the end user's; confirm how the host forwards client IPs.
    ip = client.host if client is not None else "unknown"
    usage_counts[ip] = usage_counts.get(ip, 0) + 1
    use_fallback = usage_counts[ip] > USAGE_LIMIT
    model_used = "Fallback S2.1" if use_fallback else "A1"

    prefix = "/think " if reasoning_enabled else "/no_think "
    prompt = prefix + text
    # Build a new list so the caller's history object is not mutated in place.
    history = history + [[msg, ""]]
    for output in generate_stream(prompt, use_fallback):
        history[-1][1] = output + f" ({model_used})"
        yield history, history
def clear_chat():
    """Reset the conversation.

    Returns:
        Two fresh, independent empty lists: one for the Chatbot display and
        one for the chat State.
    """
    cleared_display = []
    cleared_state = []
    return cleared_display, cleared_state
with gr.Blocks() as demo:
    # Title text repaired from mis-encoded bytes ("πŸ€–" / "β€’").
    gr.Markdown("# 🤖 SmilyAI Reasoning Chat • Token-by-Token + IP Usage Limits")
    # load_models() runs once here, while the UI is being built; its status
    # message is passed as the initial value of a read-only textbox.
    model_status = gr.Textbox(
        label="Model Load Status",
        value=load_models(),
        interactive=False,
    )
    chat_box = gr.Chatbot(label="Chat", type="tuples")
    # Holds the list of [user, bot] pairs that respond() extends each turn.
    chat_state = gr.State([])
    with gr.Row():
        user_input = gr.Textbox(placeholder="Your message here...", show_label=False, scale=6)
        reason_toggle = gr.Checkbox(label="Reason", value=True, scale=1)
        send_btn = gr.Button("Send", scale=1)
    clear_btn = gr.Button("Clear Chat")

    chat_inputs = [user_input, chat_state, reason_toggle]
    chat_outputs = [chat_box, chat_state]
    # Both clicking Send and pressing Enter in the textbox start a turn.
    send_btn.click(respond, inputs=chat_inputs, outputs=chat_outputs)
    user_input.submit(respond, inputs=chat_inputs, outputs=chat_outputs)
    clear_btn.click(fn=clear_chat, inputs=[], outputs=chat_outputs)

# Enable Gradio's request queue, which generator handlers such as respond()
# stream their output through, then start the server.
demo.queue()
demo.launch()