Boning c committed
Commit a40a845 · verified
1 Parent(s): bc956d1

Create app.py

Files changed (1)
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
+import gradio as gr
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+# List of available SmilyAI Sam models (adjust as needed)
+MODELS = [
+    "Smilyai-labs/Sam-reason-S1",
+    "Smilyai-labs/Sam-reason-S1.5",
+    "Smilyai-labs/Sam-reason-S2",
+    "Smilyai-labs/Sam-reason-S3",
+    "Smilyai-labs/Sam-reason-v1",
+    "Smilyai-labs/Sam-reason-v2",
+    "Smilyai-labs/Sam-reason-A1",
+    "Smilyai-labs/Sam-flash-mini-v1"
+]
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Global vars to hold model and tokenizer
+model = None
+tokenizer = None
+
+def load_model(model_name):
+    global model, tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
+    model.eval()
+    return f"Loaded model: {model_name}"
+
+def generate_stream(prompt, max_length=100, temperature=0.7, top_p=0.9):
+    global model, tokenizer
+    if model is None or tokenizer is None:
+        yield "Model not loaded. Please select a model first."
+        return
+
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+
+    generated_ids = input_ids
+    output_text = tokenizer.decode(input_ids[0])
+
+    # Generate tokens one by one
+    for _ in range(max_length):
+        outputs = model(generated_ids)
+        logits = outputs.logits
+
+        # Get logits for last token
+        next_token_logits = logits[:, -1, :] / temperature
+
+        # Apply top_p filtering for nucleus sampling
+        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+        cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
+
+        # Remove tokens with cumulative prob above top_p
+        sorted_indices_to_remove = cumulative_probs > top_p
+        # Shift mask right to keep at least one token
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+
+        filtered_logits = next_token_logits.clone()
+        filtered_logits[:, sorted_indices[sorted_indices_to_remove]] = -float('Inf')
+
+        # Sample from filtered distribution
+        probabilities = torch.softmax(filtered_logits, dim=-1)
+        next_token = torch.multinomial(probabilities, num_samples=1)
+
+        generated_ids = torch.cat([generated_ids, next_token], dim=-1)
+
+        new_token_text = tokenizer.decode(next_token[0])
+        output_text += new_token_text
+
+        yield output_text
+
+        # Stop if EOS token generated
+        if next_token.item() == tokenizer.eos_token_id:
+            break
+
+def on_model_change(model_name):
+    status = load_model(model_name)
+    return status
+
+with gr.Blocks() as demo:
+    gr.Markdown("# SmilyAI Sam Models — Manual Token Streaming Generator")
+
+    with gr.Row():
+        model_selector = gr.Dropdown(choices=MODELS, value=MODELS[0], label="Select Model")
+        status = gr.Textbox(label="Status", interactive=False)
+
+    prompt_input = gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Prompt")
+    output_box = gr.Textbox(label="Generated Text", lines=15, interactive=False)
+
+    generate_btn = gr.Button("Generate")
+
+    # Load default model
+    status.value = load_model(MODELS[0])
+
+    model_selector.change(on_model_change, inputs=model_selector, outputs=status)
+
+    def generate_func(prompt):
+        if not prompt.strip():
+            yield "Please enter a prompt."
+            return
+        yield from generate_stream(prompt)
+
+    generate_btn.click(generate_func, inputs=prompt_input, outputs=output_box)
+
+demo.launch()
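
The hand-rolled loop in generate_stream reimplements what transformers' own generate() does when called with do_sample=True, temperature, and top_p. For comparison only, and not part of the commit, a minimal streaming sketch using the library's TextIteratorStreamer could look like the following; the checkpoint name is simply the first entry of MODELS, and the sampling settings copy generate_stream's defaults.

# Comparison sketch (not part of the commit): streaming the same models
# through transformers' built-in sampling instead of the manual loop.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "Smilyai-labs/Sam-reason-S1"  # first entry of MODELS in app.py
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.eval()

inputs = tokenizer("Hello, Sam.", return_tensors="pt").to(device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() applies temperature scaling and top-p filtering internally,
# matching generate_stream's defaults (temperature=0.7, top_p=0.9).
generation_kwargs = dict(**inputs, max_new_tokens=100, do_sample=True,
                         temperature=0.7, top_p=0.9, streamer=streamer)
Thread(target=model.generate, kwargs=generation_kwargs).start()

for chunk in streamer:
    print(chunk, end="", flush=True)

Both approaches stream partial text to the caller. The manual loop in app.py keeps full control over the filtering step, at the cost of re-running the forward pass over the whole sequence on every iteration (no KV cache) and of accumulating autograd state unless the loop is wrapped in torch.no_grad().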