Boning c committed: Update app.py
app.py CHANGED
Previous version of app.py (lines removed by this commit are prefixed with -; some removed lines are truncated in the diff view):

@@ -1,8 +1,8 @@
 import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
-# List of available SmilyAI Sam models (adjust as needed)
 MODELS = [
     "Smilyai-labs/Sam-reason-S1",
     "Smilyai-labs/Sam-reason-S1.5",
@@ -16,7 +16,6 @@ MODELS = [
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-# Global vars to hold model and tokenizer
 model = None
 tokenizer = None
 
@@ -27,80 +26,104 @@ def load_model(model_name):
     model.eval()
     return f"Loaded model: {model_name}"
 
-def generate_stream(prompt, max_length=100, temperature=0.7, top_p=0.9):
     global model, tokenizer
     if model is None or tokenizer is None:
         yield "Model not loaded. Please select a model first."
         return
 
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
-
     generated_ids = input_ids
-    output_text = ""
 
-    # Generate tokens one by one
     for _ in range(max_length):
         outputs = model(generated_ids)
         logits = outputs.logits
-
-        # Get logits for last token
         next_token_logits = logits[:, -1, :] / temperature
 
-        # Apply top_p filtering for nucleus sampling
        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
         cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
 
-        # Remove tokens with cumulative prob above top_p
         sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift mask right to keep at least one token
         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
         sorted_indices_to_remove[..., 0] = 0
 
         filtered_logits = next_token_logits.clone()
         filtered_logits[:, sorted_indices[sorted_indices_to_remove]] = -float('Inf')
 
-        # Sample from filtered distribution
         probabilities = torch.softmax(filtered_logits, dim=-1)
         next_token = torch.multinomial(probabilities, num_samples=1)
-
         generated_ids = torch.cat([generated_ids, next_token], dim=-1)
 
         new_token_text = tokenizer.decode(next_token[0])
         output_text += new_token_text
 
-        yield output_text
 
-        # Stop if EOS token generated
         if next_token.item() == tokenizer.eos_token_id:
             break
 
-def
-
-
 
 with gr.Blocks() as demo:
-    gr.Markdown("# SmilyAI Sam
 
     with gr.Row():
         model_selector = gr.Dropdown(choices=MODELS, value=MODELS[0], label="Select Model")
         status = gr.Textbox(label="Status", interactive=False)
 
-
-
 
-    generate_btn = gr.Button("Generate")
-
-    # Load default model
     status.value = load_model(MODELS[0])
 
-    model_selector.change(
 
-    def
-
-            yield "Please enter a prompt."
-            return
-        yield from generate_stream(prompt)
 
-
 
 demo.launch()
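Both the old and the new generation loop use the same decoding rule: scale the last-position logits by the temperature, keep only the smallest set of tokens whose cumulative probability stays within top_p (nucleus sampling), and sample from what remains. As a self-contained reference, here is a minimal sketch of that filtering step pulled out into its own function; the name top_p_filter, the default values, and the dummy 1 x 32 logits row are illustrative and not part of the commit, and, like the committed code, the masking assumes a single-row batch.

import torch

def top_p_filter(logits, temperature=0.7, top_p=0.9):
    """Mask the low-probability tail of a 1 x vocab logits row (nucleus sampling)."""
    logits = logits / temperature
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

    # Drop tokens once the cumulative probability exceeds top_p; shift the mask
    # right so the single most likely token always survives.
    sorted_mask = cumulative_probs > top_p
    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_mask[..., 0] = False

    filtered = logits.clone()
    filtered[:, sorted_indices[sorted_mask]] = float("-inf")
    return filtered

# Example: sample one token id from a dummy 1 x 32 logits row.
probs = torch.softmax(top_p_filter(torch.randn(1, 32)), dim=-1)
next_token_id = torch.multinomial(probs, num_samples=1)

Factoring the step out this way also makes the mask easy to test independently of the model.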
Updated version of app.py (lines added by this commit are prefixed with +):

@@ -1,8 +1,8 @@
 import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
+import time
 
 MODELS = [
     "Smilyai-labs/Sam-reason-S1",
     "Smilyai-labs/Sam-reason-S1.5",
@@ -16,7 +16,6 @@ MODELS = [
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 model = None
 tokenizer = None
 
@@ -27,80 +26,104 @@ def load_model(model_name):
     model.eval()
     return f"Loaded model: {model_name}"
 
+def build_prompt(chat_history):
+    """
+    Build the prompt string for the model from chat history.
+    Adjust this format to match your model's expected input style.
+    Example format:
+    User: ...
+    Assistant: ...
+    User: ...
+    """
+    prompt = ""
+    for entry in chat_history:
+        role, text = entry
+        prompt += f"{role}: {text}\n"
+    prompt += "Assistant: "  # Model is expected to continue here
+    return prompt
+
+def generate_stream(chat_history, max_length=100, temperature=0.7, top_p=0.9):
     global model, tokenizer
     if model is None or tokenizer is None:
         yield "Model not loaded. Please select a model first."
         return
 
+    prompt = build_prompt(chat_history)
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
     generated_ids = input_ids
+    output_text = prompt
 
     for _ in range(max_length):
         outputs = model(generated_ids)
         logits = outputs.logits
         next_token_logits = logits[:, -1, :] / temperature
 
         sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
         cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
 
         sorted_indices_to_remove = cumulative_probs > top_p
         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
         sorted_indices_to_remove[..., 0] = 0
 
         filtered_logits = next_token_logits.clone()
         filtered_logits[:, sorted_indices[sorted_indices_to_remove]] = -float('Inf')
 
         probabilities = torch.softmax(filtered_logits, dim=-1)
         next_token = torch.multinomial(probabilities, num_samples=1)
         generated_ids = torch.cat([generated_ids, next_token], dim=-1)
 
         new_token_text = tokenizer.decode(next_token[0])
         output_text += new_token_text
 
+        # Extract only assistant's reply (after "Assistant: ")
+        assistant_reply = output_text.split("Assistant:")[-1].strip()
+
+        yield assistant_reply
 
         if next_token.item() == tokenizer.eos_token_id:
             break
 
+def chatbot_step(user_input, chat_history):
+    if not user_input.strip():
+        return chat_history, "Please type something."
+
+    # Append user input to chat history
+    chat_history = chat_history + [("User", user_input)]
+
+    # We will collect the assistant's streaming reply here
+    assistant_response = ""
+
+    # Generator to stream tokens
+    def response_generator():
+        nonlocal assistant_response
+        for partial_reply in generate_stream(chat_history):
+            assistant_response = partial_reply
+            yield chat_history + [("Assistant", assistant_response)]
+
+    return response_generator()
 
 with gr.Blocks() as demo:
+    gr.Markdown("# SmilyAI Sam Multi-turn Chatbot with Token Streaming")
 
     with gr.Row():
         model_selector = gr.Dropdown(choices=MODELS, value=MODELS[0], label="Select Model")
         status = gr.Textbox(label="Status", interactive=False)
 
+    chatbot = gr.Chatbot()
+    msg = gr.Textbox(label="Your message")
+    send_btn = gr.Button("Send")
 
     status.value = load_model(MODELS[0])
 
+    model_selector.change(lambda m: load_model(m), inputs=model_selector, outputs=status)
+
+    # Keep chat history in state
+    state = gr.State([])
 
+    def update_chat(user_message, chat_history):
+        return chatbot_step(user_message, chat_history)
 
+    send_btn.click(update_chat, inputs=[msg, state], outputs=[chatbot, state])
+    msg.submit(update_chat, inputs=[msg, state], outputs=[chatbot, state])
 
 demo.launch()
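One wiring detail worth a note: the committed chatbot_step is a plain function that returns a generator object, and each yield carries only the chatbot list even though the click/submit events declare two outputs ([chatbot, state]); depending on the Gradio version, the reply may therefore not stream and the state may not receive a matching value. Below is a minimal sketch (not part of the commit) of the more common pattern, where the handler itself is a generator and yields one value per declared output. The names respond and fake_token_stream are hypothetical stand-ins (fake_token_stream plays the role of generate_stream), the tuple-style gr.Chatbot format with (user_message, bot_message) pairs is assumed, and demo.queue() reflects the Gradio 3.x requirement that generator outputs go through the queue.

import gradio as gr

def fake_token_stream(prompt):
    # Hypothetical stand-in for generate_stream(): yields a growing partial reply.
    reply = ""
    for word in ["Hello", " from", " Sam!"]:
        reply += word
        yield reply

def respond(user_message, history):
    # Start a new (user, bot) pair, then overwrite the bot half as tokens arrive.
    history = history + [(user_message, "")]
    for partial in fake_token_stream(user_message):
        history[-1] = (user_message, partial)
        # One value per declared output on every yield: the Chatbot display and the State.
        yield history, history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    state = gr.State([])
    msg = gr.Textbox(label="Your message")
    msg.submit(respond, inputs=[msg, state], outputs=[chatbot, state])

demo.queue()   # generator-based streaming goes through the queue in Gradio 3.x
demo.launch()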