import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import re, json
from html import escape

# ─── Configuration ─────────────────────────────────────────────────────────
PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A3"
FALLBACK_MODEL = "Smilyai-labs/Sam-reason-A1"
USAGE_LIMIT   = 5
RESET_MS      = 20 * 60 * 1000   # 20 minutes in milliseconds
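# Once a client has sent more than USAGE_LIMIT messages within a RESET_MS
# window, replies are served by the fallback model instead (see respond()).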
device        = "cuda" if torch.cuda.is_available() else "cpu"
dtype         = torch.float16 if device == "cuda" else torch.float32  # fp16 kernels are unreliable on CPU

primary_model = primary_tokenizer = None
fallback_model = fallback_tokenizer = None

# ─── Model Loading ─────────────────────────────────────────────────────────
def load_models():
    global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer
    primary_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL, trust_remote_code=True)
    primary_model     = AutoModelForCausalLM.from_pretrained(
        PRIMARY_MODEL, torch_dtype=dtype
    ).to(device).eval()
    fallback_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL, trust_remote_code=True)
    fallback_model    = AutoModelForCausalLM.from_pretrained(
        FALLBACK_MODEL, torch_dtype=dtype
    ).to(device).eval()
    return f"βœ… Loaded {PRIMARY_MODEL} (fallback: {FALLBACK_MODEL})"

# ─── Build Chat Prompt ──────────────────────────────────────────────────────
def build_chat_prompt(history, user_input, reasoning_enabled):
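    # Produces a plain-text chat template, e.g.:
    #   <|system|>\n/think\n<|user|>\nHi\n<|assistant|>\nHello!\n<|user|>\n...\n<|assistant|>\n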
    system_flag = "/think" if reasoning_enabled else "/no_think"
    prompt = f"<|system|>\n{system_flag}\n"
    for u, a in history:
        prompt += f"<|user|>\n{u}\n<|assistant|>\n{a}\n"
    prompt += f"<|user|>\n{user_input}\n<|assistant|>\n"
    return prompt

# ─── Collapse <think> Blocks ────────────────────────────────────────────────
def format_thinking(text):
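    # Swap the model's <think>…</think> block for a collapsible HTML <details>
    # element; both the visible text and the reasoning are escaped so raw
    # model output cannot inject HTML into the chat window.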
    match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
    if not match:
        return escape(text)
    reasoning = escape(match.group(1).strip())
    visible   = re.sub(r"<think>.*?</think>", "[thinking...]", text, flags=re.DOTALL).strip()
    return (
        escape(visible)
        + "<br><details><summary>🧠 Show reasoning</summary>"
        + f"<pre>{reasoning}</pre></details>"
    )

# ─── Token-by-Token Streaming (Stops on <|user|>) ─────────────────────────
@torch.no_grad()  # inference only; no autograd graphs needed while sampling
def generate_stream(prompt, use_fallback=False,
                    max_new_tokens=100, temperature=0.2, top_p=0.9):
    model     = fallback_model if use_fallback else primary_model
    tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    generated = input_ids
    assistant_text = ""

    for _ in range(max_new_tokens):
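        # NOTE: each step re-runs the full sequence through the model (no KV
        # cache), so per-token cost grows with length; fine for a short demo.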
        # 1) Get next-token logits, apply temperature, then top-p (nucleus) filtering
        logits = model(generated).logits[:, -1, :] / temperature
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        sorted_mask = cum_probs > top_p
        # Shift right so the first token past the threshold is still kept
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0]  = False
        # Scatter the mask from sorted order back to vocabulary order
        # (batch-safe, unlike fancy-indexing with the sorted indices)
        mask = sorted_mask.scatter(1, sorted_idx, sorted_mask)
        filtered = logits.masked_fill(mask, -float("Inf"))

        # 2) Sample and append
        next_token = torch.multinomial(torch.softmax(filtered, dim=-1), 1)
        generated = torch.cat([generated, next_token], dim=-1)
        new_text = tokenizer.decode(next_token[0], skip_special_tokens=False)
        assistant_text += new_text

        # 3) Remove any leading assistant tag
        if assistant_text.startswith("<|assistant|>"):
            assistant_text = assistant_text[len("<|assistant|>"):]

        # 4) If we see a user-turn tag, truncate and bail
        if "<|user|>" in assistant_text:
            assistant_text = assistant_text.split("<|user|>")[0]
            yield assistant_text
            break

        # 5) Otherwise stream clean assistant text
        yield assistant_text

        # 6) End if EOS
        if next_token.item() == tokenizer.eos_token_id:
            break

# ─── Main Chat Handler ──────────────────────────────────────────────────────
def respond(message, history, reasoning_enabled, limit_json):
    # parse client-side usage info
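    # "count" is bumped in the browser on every Send click (see the JS below);
    # it is advisory only, since a user can clear localStorage to reset it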
    info  = json.loads(limit_json) if limit_json else {"count": 0}
    count = info.get("count", 0)
    use_fallback = count > USAGE_LIMIT
    remaining    = max(0, USAGE_LIMIT - count)
    model_label  = "A3" if not use_fallback else "Fallback A1"

    # initial yield to set "Generating…"
    prompt = build_chat_prompt(history, message.strip(), reasoning_enabled)
    history = history + [[message, ""]]
    yield history, history, f"🧠 A3 left: {remaining}", "Generating…"

    # stream assistant reply
    for chunk in generate_stream(prompt, use_fallback):
        formatted = format_thinking(chunk)
        history[-1][1] = (
            f"{formatted}<br><sub style='color:gray'>({model_label})</sub>"
        )
        yield history, history, f"🧠 A3 left: {remaining}", "Generating…"

    # final yield resets button text
    yield history, history, f"🧠 A3 left: {remaining}", "Send"

# ─── Clear Chat ─────────────────────────────────────────────────────────────
def clear_chat():
    return [], [], f"🧠 A3 left: {USAGE_LIMIT}", "Send"

# ─── Gradio UI ──────────────────────────────────────────────────────────────
with gr.Blocks() as demo:
    # Inject client-side JS + CSS
    # NOTE: newer Gradio releases strip <script> tags from gr.HTML, so if the
    # counter never updates, pass this script via gr.Blocks(head=...) instead.
    gr.HTML(f"""
<script>
  function updateUsageLimit() {{
    const key = "samai_limit";
    const now = Date.now();
    let rec = JSON.parse(localStorage.getItem(key) || "null");
    if (!rec || now - rec.lastSeen > {RESET_MS}) {{
      rec = {{count: 0, lastSeen: now}};
    }}
    rec.count += 1;
    rec.lastSeen = now;
    localStorage.setItem(key, JSON.stringify(rec));
    // Gradio puts elem_id on a wrapper div, so write to the inner textarea
    // and fire an input event so Gradio registers the new value.
    const box = document.querySelector("#limit_json textarea");
    if (box) {{
      box.value = JSON.stringify(rec);
      box.dispatchEvent(new Event("input", {{ bubbles: true }}));
    }}
  }}
  document.addEventListener("DOMContentLoaded", () => {{
    // Gradio mounts its UI asynchronously, so poll until the button exists
    const hook = setInterval(() => {{
      const btn = document.getElementById("send_btn");
      if (!btn) return;
      clearInterval(hook);
      btn.addEventListener("click", () => {{
        updateUsageLimit();
        btn.innerText = "Generating…";
      }});
    }}, 200);
  }});
</script>
<style>
  .send-circle {{
    border-radius: 50%;
    height: 40px;
    width: 40px;
    padding: 0;
    font-size: 12px;
    text-align: center;
  }}
</style>
    """)

    gr.Markdown("# πŸ€– SamAI – Chat Reasoning (Final)")

    # Hidden textbox ferrying usage JSON from JS β†’ Python
    limit_json    = gr.Textbox(visible=False, elem_id="limit_json")
    model_status  = gr.Textbox(interactive=False, label="Model Status")
    usage_counter = gr.Textbox(f"🧠 A3 left: {USAGE_LIMIT}", interactive=False, show_label=False)

    chat_box   = gr.Chatbot(type="tuples")
    chat_state = gr.State([])

    with gr.Row():
        user_input    = gr.Textbox(placeholder="Ask anything...", show_label=False, scale=6)
        reason_toggle = gr.Checkbox(label="Reason", value=True, scale=1)
        send_btn      = gr.Button("Send", elem_id="send_btn", elem_classes=["send-circle"], scale=1)

    clear_btn = gr.Button("Clear")

    model_status.value = load_models()

    # Bind Send button -> respond()
    send_btn.click(
        fn=respond,
        inputs=[user_input, chat_state, reason_toggle, limit_json],
        outputs=[chat_box, chat_state, usage_counter, send_btn]
    )

    clear_btn.click(
        fn=clear_chat,
        inputs=[], 
        outputs=[chat_box, chat_state, usage_counter, send_btn]
    )

demo.queue()
demo.launch()