Boning c committed
Commit 206f796 · verified · 1 Parent(s): 3e1d6c1

Update app.py

Files changed (1)
  1. app.py +61 -72
app.py CHANGED
@@ -2,105 +2,94 @@ import gradio as gr
  import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM

- # List of available SmilyAI Sam models (adjust as needed)
- MODELS = [
-     "Smilyai-labs/Sam-reason-A1",
-     "Smilyai-labs/Sam-reason-S1",
-     "Smilyai-labs/Sam-reason-S1.5",
-     "Smilyai-labs/Sam-reason-S2",
-     "Smilyai-labs/Sam-reason-S3",
-     "Smilyai-labs/Sam-reason-v1",
-     "Smilyai-labs/Sam-reason-v2",
-     "Smilyai-labs/Sam-flash-mini-v1"
- ]

  device = "cuda" if torch.cuda.is_available() else "cpu"

- # Global vars to hold model and tokenizer
- model = None
- tokenizer = None
-
- def load_model(model_name):
-     global model, tokenizer
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
-     model.eval()
-     return f"Loaded model: {model_name}"
-
- def generate_stream(prompt, max_length=100, temperature=0.7, top_p=0.9):
-     global model, tokenizer
-     if model is None or tokenizer is None:
-         yield "Model not loaded. Please select a model first."
-         return
-
      input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
-
      generated_ids = input_ids
      output_text = tokenizer.decode(input_ids[0])

-     # Generate tokens one by one
      for _ in range(max_length):
          outputs = model(generated_ids)
-         logits = outputs.logits
-
-         # Get logits for last token
-         next_token_logits = logits[:, -1, :] / temperature
-
-         # Apply top_p filtering for nucleus sampling
-         sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
          cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
-
-         # Remove tokens with cumulative prob above top_p
          sorted_indices_to_remove = cumulative_probs > top_p
-         # Shift mask right to keep at least one token
-         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
          sorted_indices_to_remove[..., 0] = 0
-
-         filtered_logits = next_token_logits.clone()
-         filtered_logits[:, sorted_indices[sorted_indices_to_remove]] = -float('Inf')
-
-         # Sample from filtered distribution
          probabilities = torch.softmax(filtered_logits, dim=-1)
          next_token = torch.multinomial(probabilities, num_samples=1)
-
          generated_ids = torch.cat([generated_ids, next_token], dim=-1)
-
          new_token_text = tokenizer.decode(next_token[0])
          output_text += new_token_text
-
          yield output_text
-
-         # Stop if EOS token generated
          if next_token.item() == tokenizer.eos_token_id:
              break

- def on_model_change(model_name):
-     status = load_model(model_name)
-     return status

- with gr.Blocks() as demo:
-     gr.Markdown("# SmilyAI Sam Models Manual Token Streaming Generator")

-     with gr.Row():
-         model_selector = gr.Dropdown(choices=MODELS, value=MODELS[0], label="Select Model")
-         status = gr.Textbox(label="Status", interactive=False)

-     prompt_input = gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Prompt")
-     output_box = gr.Textbox(label="Generated Text", lines=15, interactive=False)

-     generate_btn = gr.Button("Generate")

-     # Load default model
-     status.value = load_model(MODELS[0])

-     model_selector.change(on_model_change, inputs=model_selector, outputs=status)

-     def generate_func(prompt):
-         if not prompt.strip():
-             yield "Please enter a prompt."
-             return
-         yield from generate_stream(prompt)

-     generate_btn.click(generate_func, inputs=prompt_input, outputs=output_box)

- demo.launch()

  import torch
  from transformers import AutoTokenizer, AutoModelForCausalLM

+ # Model identifiers
+ PRIMARY_MODEL = "Smilyai-labs/Sam-reason-A1"
+ FALLBACK_MODEL = "Smilyai-labs/Sam-reason-S2.1"

  device = "cuda" if torch.cuda.is_available() else "cpu"

+ # Global model/tokenizer holders
+ primary_model = primary_tokenizer = None
+ fallback_model = fallback_tokenizer = None
+
+ # IP usage tracking
+ usage_counts = {}
+ USAGE_LIMIT = 10
+
+ def load_models():
+     global primary_model, primary_tokenizer, fallback_model, fallback_tokenizer
+     primary_tokenizer = AutoTokenizer.from_pretrained(PRIMARY_MODEL)
+     primary_model = AutoModelForCausalLM.from_pretrained(PRIMARY_MODEL).to(device).eval()
+     fallback_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL)
+     fallback_model = AutoModelForCausalLM.from_pretrained(FALLBACK_MODEL).to(device).eval()
+     return f"Models loaded: {PRIMARY_MODEL} and fallback {FALLBACK_MODEL}"
+
+ def generate_stream(prompt, use_fallback=False, max_length=100, temperature=0.7, top_p=0.9):
+     model = fallback_model if use_fallback else primary_model
+     tokenizer = fallback_tokenizer if use_fallback else primary_tokenizer
      input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
      generated_ids = input_ids
      output_text = tokenizer.decode(input_ids[0])

      for _ in range(max_length):
          outputs = model(generated_ids)
+         logits = outputs.logits[:, -1, :] / temperature
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
          cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
          sorted_indices_to_remove = cumulative_probs > top_p
+         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1]
          sorted_indices_to_remove[..., 0] = 0
+         filtered_logits = logits.clone()
+         filtered_logits[:, sorted_indices[sorted_indices_to_remove]] = -float("Inf")
          probabilities = torch.softmax(filtered_logits, dim=-1)
          next_token = torch.multinomial(probabilities, num_samples=1)
          generated_ids = torch.cat([generated_ids, next_token], dim=-1)
          new_token_text = tokenizer.decode(next_token[0])
          output_text += new_token_text
          yield output_text
          if next_token.item() == tokenizer.eos_token_id:
              break

+ def respond(message, chat_history, reason_toggle, request: gr.Request):
+     ip = request.client.host if request else "unknown"
+     usage_counts[ip] = usage_counts.get(ip, 0) + 1
+     use_fallback = usage_counts[ip] > USAGE_LIMIT
+     model_label = "A1" if not use_fallback else "Fallback S2.1"

+     # Prefix prompt with reasoning mode
+     prefix = "/think " if reason_toggle else "/no_think "
+     processed_message = prefix + message.strip()

+     chat_history = chat_history + [[message, ""]]

+     for response in generate_stream(processed_message, use_fallback=use_fallback):
+         chat_history[-1][1] = response + f" ({model_label})"
+         yield chat_history, chat_history

+ def clear_chat():
+     return [], []

+ with gr.Blocks() as demo:
+     gr.Markdown("# 🧠 SmilyAI Chatbot with Reasoning Toggle & Usage Limits")
+
+     model_status = gr.Textbox(label="Model Status", interactive=False)
+
+     chat_box = gr.Chatbot()
+     chat_history_state = gr.State([])
+
+     with gr.Row():
+         user_input = gr.Textbox(placeholder="Type your message...", show_label=False, scale=6)
+         reason_toggle = gr.Checkbox(label="Reason", value=True, scale=1)
+         send_btn = gr.Button("Send", scale=1)

+     clear_btn = gr.Button("Clear Chat")

+     model_status.value = load_models()

+     send_btn.click(
+         respond,
+         inputs=[user_input, chat_history_state, reason_toggle],
+         outputs=[chat_box, chat_history_state]
+     )

+     clear_btn.click(fn=clear_chat, inputs=[], outputs=[chat_box, chat_history_state])
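
For reference, the top-p (nucleus) filtering step that both versions of generate_stream rely on can be written as a standalone helper. This is a minimal sketch assuming a batch size of 1, as in the app; the function name top_p_filter and the toy logits are illustrative and are not part of the commit:

import torch

def top_p_filter(logits, top_p=0.9):
    # Sort logits descending and accumulate probability mass.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    # Mark tokens whose cumulative probability exceeds top_p.
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the mask right so the token that crosses the threshold is kept,
    # guaranteeing at least one candidate survives.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = 0
    filtered = logits.clone()
    # Works for a single row: sorted_indices[mask] yields the column indices to drop.
    filtered[:, sorted_indices[sorted_indices_to_remove]] = -float("Inf")
    return filtered

# Toy example: with top_p=0.9 only the most likely tokens keep finite logits.
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
print(top_p_filter(logits, top_p=0.9))  # last column becomes -inf

The column indexing on the last line only holds for a batch of one prompt, which is how the app calls it; a batched version would need to scatter the mask back per row.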
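
The per-IP fallback routing added in respond() can also be exercised on its own. A small sketch under the same USAGE_LIMIT of 10; the helper name pick_model and the sample IP are illustrative, not part of the commit:

# Mirrors the per-IP counting in respond(); the counter never resets, matching the diff.
usage_counts = {}
USAGE_LIMIT = 10

def pick_model(ip):
    # Count this request and switch to the fallback once the limit is exceeded.
    usage_counts[ip] = usage_counts.get(ip, 0) + 1
    use_fallback = usage_counts[ip] > USAGE_LIMIT
    return "Smilyai-labs/Sam-reason-S2.1" if use_fallback else "Smilyai-labs/Sam-reason-A1"

for i in range(12):
    print(i + 1, pick_model("203.0.113.7"))  # requests 11 and 12 route to the fallback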