Boning c committed on
Commit 2c4e10a · verified · 1 Parent(s): baf1963

Update app.py

Files changed (1)
  1. app.py +54 -31
app.py CHANGED
@@ -1,8 +1,8 @@
 import gradio as gr
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM
+import time

-# List of available SmilyAI Sam models (adjust as needed)
 MODELS = [
     "Smilyai-labs/Sam-reason-S1",
     "Smilyai-labs/Sam-reason-S1.5",
@@ -16,7 +16,6 @@ MODELS = [

 device = "cuda" if torch.cuda.is_available() else "cpu"

-# Global vars to hold model and tokenizer
 model = None
 tokenizer = None

@@ -27,80 +26,104 @@ def load_model(model_name):
     model.eval()
     return f"Loaded model: {model_name}"

-def generate_stream(prompt, max_length=100, temperature=0.7, top_p=0.9):
+def build_prompt(chat_history):
+    """
+    Build the prompt string for the model from chat history.
+    Adjust this format to match your model's expected input style.
+    Example format:
+    User: ...
+    Assistant: ...
+    User: ...
+    """
+    prompt = ""
+    for entry in chat_history:
+        role, text = entry
+        prompt += f"{role}: {text}\n"
+    prompt += "Assistant: "  # Model is expected to continue here
+    return prompt
+
+def generate_stream(chat_history, max_length=100, temperature=0.7, top_p=0.9):
     global model, tokenizer
     if model is None or tokenizer is None:
         yield "Model not loaded. Please select a model first."
         return

+    prompt = build_prompt(chat_history)
     input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
-
     generated_ids = input_ids
-    output_text = tokenizer.decode(input_ids[0])
+    output_text = prompt

-    # Generate tokens one by one
     for _ in range(max_length):
         outputs = model(generated_ids)
         logits = outputs.logits
-
-        # Get logits for last token
         next_token_logits = logits[:, -1, :] / temperature

-        # Apply top_p filtering for nucleus sampling
         sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
         cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

-        # Remove tokens with cumulative prob above top_p
         sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift mask right to keep at least one token
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
         sorted_indices_to_remove[..., 0] = 0

         filtered_logits = next_token_logits.clone()
         filtered_logits[:, sorted_indices[sorted_indices_to_remove]] = -float('Inf')

-        # Sample from filtered distribution
         probabilities = torch.softmax(filtered_logits, dim=-1)
         next_token = torch.multinomial(probabilities, num_samples=1)
-
         generated_ids = torch.cat([generated_ids, next_token], dim=-1)

         new_token_text = tokenizer.decode(next_token[0])
         output_text += new_token_text

-        yield output_text
+        # Extract only assistant's reply (after "Assistant: ")
+        assistant_reply = output_text.split("Assistant:")[-1].strip()
+
+        yield assistant_reply

-        # Stop if EOS token generated
         if next_token.item() == tokenizer.eos_token_id:
             break

-def on_model_change(model_name):
-    status = load_model(model_name)
-    return status
+def chatbot_step(user_input, chat_history):
+    if not user_input.strip():
+        return chat_history, "Please type something."
+
+    # Append user input to chat history
+    chat_history = chat_history + [("User", user_input)]
+
+    # We will collect the assistant's streaming reply here
+    assistant_response = ""
+
+    # Generator to stream tokens
+    def response_generator():
+        nonlocal assistant_response
+        for partial_reply in generate_stream(chat_history):
+            assistant_response = partial_reply
+            yield chat_history + [("Assistant", assistant_response)]
+
+    return response_generator()

 with gr.Blocks() as demo:
-    gr.Markdown("# SmilyAI Sam Models Manual Token Streaming Generator")
+    gr.Markdown("# SmilyAI Sam Multi-turn Chatbot with Token Streaming")

     with gr.Row():
         model_selector = gr.Dropdown(choices=MODELS, value=MODELS[0], label="Select Model")
         status = gr.Textbox(label="Status", interactive=False)

-    prompt_input = gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Prompt")
-    output_box = gr.Textbox(label="Generated Text", lines=15, interactive=False)
+    chatbot = gr.Chatbot()
+    msg = gr.Textbox(label="Your message")
+    send_btn = gr.Button("Send")

-    generate_btn = gr.Button("Generate")
-
-    # Load default model
     status.value = load_model(MODELS[0])

-    model_selector.change(on_model_change, inputs=model_selector, outputs=status)
+    model_selector.change(lambda m: load_model(m), inputs=model_selector, outputs=status)
+
+    # Keep chat history in state
+    state = gr.State([])

-    def generate_func(prompt):
-        if not prompt.strip():
-            yield "Please enter a prompt."
-            return
-        yield from generate_stream(prompt)
+    def update_chat(user_message, chat_history):
+        return chatbot_step(user_message, chat_history)

-    generate_btn.click(generate_func, inputs=prompt_input, outputs=output_box)
+    send_btn.click(update_chat, inputs=[msg, state], outputs=[chatbot, state])
+    msg.submit(update_chat, inputs=[msg, state], outputs=[chatbot, state])

 demo.launch()
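
The top-p filtering loop inside generate_stream() is carried over unchanged by this commit. As a standalone illustration (not part of app.py), the same filtering steps can be run on an arbitrary 1 x 6 logits row to see which tokens survive before sampling; the tensor values below are made up:

import torch

# Illustration only: the same temperature / top-p filtering as generate_stream(),
# applied to an arbitrary batch-of-one logits row.
logits = torch.tensor([[2.0, 1.5, 0.5, 0.1, -1.0, -2.0]])
temperature, top_p = 0.7, 0.9

next_token_logits = logits / temperature
sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

# Drop tokens past the point where cumulative probability exceeds top_p,
# shifting the mask right so at least the most likely token is always kept.
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0

filtered_logits = next_token_logits.clone()
filtered_logits[:, sorted_indices[sorted_indices_to_remove]] = -float('Inf')

probabilities = torch.softmax(filtered_logits, dim=-1)
print(probabilities)                                    # low-probability tail is zeroed out
print(torch.multinomial(probabilities, num_samples=1))  # sampled token id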
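
chatbot_step() hands a generator object back through update_chat(), while send_btn.click and msg.submit list two outputs ([chatbot, state]). A minimal alternative sketch, assuming Gradio's generator-style event handlers and the (user_message, bot_message) pair format that gr.Chatbot renders; stream_chat and pair_history are illustrative names, not part of the commit:

def stream_chat(user_message, pair_history):
    # pair_history: list of (user_text, assistant_text) tuples, the format gr.Chatbot renders.
    if not user_message.strip():
        # Empty input: leave the conversation unchanged.
        yield pair_history, pair_history
        return

    # Re-tag the display pairs into the role/text form that build_prompt() expects.
    role_history = []
    for user_text, assistant_text in pair_history:
        role_history.append(("User", user_text))
        role_history.append(("Assistant", assistant_text))
    role_history.append(("User", user_message))

    partial_reply = ""
    for partial_reply in generate_stream(role_history):
        # Each yield updates both outputs: the Chatbot repaints with the partial
        # reply while the State keeps the already-committed turns.
        yield pair_history + [(user_message, partial_reply)], pair_history

    # Commit the finished turn once streaming ends.
    pair_history = pair_history + [(user_message, partial_reply)]
    yield pair_history, pair_history

# send_btn.click(stream_chat, inputs=[msg, state], outputs=[chatbot, state])
# msg.submit(stream_chat, inputs=[msg, state], outputs=[chatbot, state])

Under those assumptions the handler itself yields once per generated token, so the Chatbot repaints as the reply grows and the stored history is only updated after the final yield.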