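# SmilyAI Reasoning Chat: a Gradio Space that streams tokens from Sam-reason-A1 and
# falls back to Sam-reason-S2.1 once a per-IP usage limit is exceeded.
# Assumed dependencies (not pinned in this file): gradio, torch, transformers.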
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Model definitions
PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A1"
FALLBACK_MODEL = "Smilyai-labs/Sam-reason-S2.1"
USAGE_LIMIT = 10
device = "cuda" if torch.cuda.is_available() else "cpu"
# Globals for models and tokenizers
primary_model, primary_tokenizer = None, None
fallback_model, fallback_tokenizer = None, None
# IP-based usage tracking
usage_counts = {}
def load_models():
    global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer
    primary_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL)
    primary_model = AutoModelForCausalLM.from_pretrained(PRIMARY_MODEL).to(device).eval()
    fallback_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
    fallback_model = AutoModelForCausalLM.from_pretrained(FALLBACK_MODEL).to(device).eval()
    return f"Models loaded: {PRIMARY_MODEL} + fallback {FALLBACK_MODEL}"
def generate_stream(prompt, use_fallback=False, max_length=100, temperature=0.7, top_p=0.9):
    model = fallback_model if use_fallback else primary_model
    tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    generated = input_ids
    output_text = tokenizer.decode(input_ids[0])
    for _ in range(max_length):
        # Forward pass without gradient tracking to keep memory use down
        with torch.no_grad():
            outputs = model(generated)
        logits = outputs.logits[:, -1, :] / temperature
        # Top-p (nucleus) filtering: drop tokens past the cumulative-probability cutoff
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        mask = probs > top_p
        # Shift the mask right so the first token over the cutoff is still kept
        mask[..., 1:] = mask[..., :-1].clone()
        mask[..., 0] = False
        filtered = logits.clone()
        filtered[:, sorted_indices[mask]] = -float("Inf")
        # Sample the next token and stream the decoded text so far
        next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
        generated = torch.cat([generated, next_token], dim=-1)
        new_text = tokenizer.decode(next_token[0])
        output_text += new_text
        yield output_text
        if next_token.item() == tokenizer.eos_token_id:
            break
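# Minimal local sketch of the streaming generator (hypothetical usage, assuming the
# models above download successfully and fit on the available device):
#
#   load_models()
#   for partial in generate_stream("/think What is 2 + 2?", max_length=32):
#       print(partial, end="\r")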
def respond(msg, history, reasoning_enabled, request: gr.Request):
    ip = request.client.host if request else "unknown"
    usage_counts[ip] = usage_counts.get(ip, 0) + 1
    use_fallback = usage_counts[ip] > USAGE_LIMIT
    model_used = "A1" if not use_fallback else "Fallback S2.1"
    prefix = "/think " if reasoning_enabled else "/no_think "
    prompt = prefix + msg.strip()
    history = history + [[msg, ""]]
    for output in generate_stream(prompt, use_fallback):
        history[-1][1] = output + f" ({model_used})"
        yield history, history
def clear_chat():
    return [], []
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 SmilyAI Reasoning Chat • Token-by-Token + IP Usage Limits")
    model_status = gr.Textbox(label="Model Load Status", interactive=False)
    chat_box = gr.Chatbot(label="Chat", type="tuples")
    chat_state = gr.State([])
    with gr.Row():
        user_input = gr.Textbox(placeholder="Your message here...", show_label=False, scale=6)
        reason_toggle = gr.Checkbox(label="Reason", value=True, scale=1)
        send_btn = gr.Button("Send", scale=1)
    clear_btn = gr.Button("Clear Chat")
    model_status.value = load_models()
    send_btn.click(
        respond,
        inputs=[user_input, chat_state, reason_toggle],
        outputs=[chat_box, chat_state]
    )
    clear_btn.click(fn=clear_chat, inputs=[], outputs=[chat_box, chat_state])
demo.queue()
demo.launch()
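# To run outside of Spaces (a sketch, assuming the dependencies noted above are installed):
#   python app.py
# then open the local URL Gradio prints (http://127.0.0.1:7860 by default).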