import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Model definitions
PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A1"
FALLBACK_MODEL = "Smilyai-labs/Sam-reason-S2.1"
USAGE_LIMIT = 10  # requests per client IP before switching to the fallback model

device = "cuda" if torch.cuda.is_available() else "cpu"

# Globals for models and tokenizers
primary_model, primary_tokenizer = None, None
fallback_model, fallback_tokenizer = None, None

# IP-based usage tracking
usage_counts = {}

def load_models():
    global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer
    primary_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL)
    primary_model = AutoModelForCausalLM.from_pretrained(PRIMARY_MODEL).to(device).eval()
    fallback_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
    fallback_model = AutoModelForCausalLM.from_pretrained(FALLBACK_MODEL).to(device).eval()
    return f"Models loaded: {PRIMARY_MODEL} + fallback {FALLBACK_MODEL}"

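# Stream tokens one at a time using temperature scaling and nucleus (top-p)
# sampling. Each step re-runs the model on the full sequence (no KV cache),
# which keeps the loop simple at the cost of generation speed.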
def generate_stream(prompt, use_fallback=False, max_length=100, temperature=0.7, top_p=0.9):
    model = fallback_model if use_fallback else primary_model
    tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    generated = input_ids
    # Start from the decoded prompt so the streamed reply carries the full context
    output_text = tokenizer.decode(input_ids[0])

    with torch.no_grad():  # inference only; avoids accumulating autograd state
        for _ in range(max_length):
            outputs = model(generated)
            # Temperature-scaled logits for the next-token position
            logits = outputs.logits[:, -1, :] / temperature
            # Nucleus (top-p) filtering over the sorted distribution
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
            mask = probs > top_p
            # Shift the mask right so the first token past the threshold is still kept
            mask[..., 1:] = mask[..., :-1].clone()
            mask[..., 0] = False
            filtered = logits.clone()
            # Boolean indexing here assumes batch size 1
            filtered[:, sorted_indices[mask]] = -float("Inf")
            next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
            generated = torch.cat([generated, next_token], dim=-1)
            new_text = tokenizer.decode(next_token[0])
            output_text += new_text
            yield output_text
            if next_token.item() == tokenizer.eos_token_id:
                break

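# Chat handler: counts requests per client IP, switches to the fallback model
# once USAGE_LIMIT is exceeded, prefixes the prompt with /think or /no_think
# based on the Reason checkbox, and streams the growing reply into the chat.
# The gr.Request parameter is injected by Gradio and is not listed in `inputs`.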
def respond(msg, history, reasoning_enabled, request: gr.Request):
    ip = request.client.host if request else "unknown"
    usage_counts[ip] = usage_counts.get(ip, 0) + 1
    use_fallback = usage_counts[ip] > USAGE_LIMIT
    model_used = "Fallback S2.1" if use_fallback else "A1"
    prefix = "/think " if reasoning_enabled else "/no_think "
    prompt = prefix + msg.strip()
    history = history + [[msg, ""]]
    for output in generate_stream(prompt, use_fallback):
        history[-1][1] = output + f" ({model_used})"
        yield history, history

def clear_chat():
    return [], []

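# Gradio UI: model status, chat window, input row, and event wiring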
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 SmilyAI Reasoning Chat • Token-by-Token + IP Usage Limits")

    model_status = gr.Textbox(label="Model Load Status", interactive=False)
    chat_box = gr.Chatbot(label="Chat", type="tuples")
    chat_state = gr.State([])

    with gr.Row():
        user_input = gr.Textbox(placeholder="Your message here...", show_label=False, scale=6)
        reason_toggle = gr.Checkbox(label="Reason", value=True, scale=1)
        send_btn = gr.Button("Send", scale=1)

    clear_btn = gr.Button("Clear Chat")

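    # Load both models once at build time and surface the result in the status box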
    model_status.value = load_models()

    send_btn.click(
        respond,
        inputs=[user_input, chat_state, reason_toggle],
        outputs=[chat_box, chat_state]
    )

    clear_btn.click(fn=clear_chat, inputs=[], outputs=[chat_box, chat_state])

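# Queueing lets the generator-based handler stream partial outputs to the client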
demo.queue()
demo.launch()