Keeby-smilyai committed on
Commit
01d2db3
·
verified ·
1 Parent(s): eb48590

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -22
app.py CHANGED
@@ -1,10 +1,11 @@
1
  import gradio as gr
2
  import torch
3
  from dataclasses import dataclass
4
- from transformers import AutoTokenizer, PretrainedConfig, pipeline, GenerationConfig
5
  from optimum.onnxruntime import ORTModelForCausalLM
6
  import onnx
7
  import logging
 
8
 
9
  logging.basicConfig(level=logging.INFO)
10
 
@@ -57,7 +58,6 @@ class Sam3Config(PretrainedConfig):
57
  self.input_modality = input_modality
58
  self.head_type = head_type
59
  self.version = version
60
-
61
  self.hidden_size = self.d_model
62
  self.num_attention_heads = self.n_heads
63
 
@@ -77,36 +77,49 @@ except Exception as e:
77
  logging.error(f"Failed to load ONNX model: {e}")
78
  raise e
79
 
80
- # Define a function to generate text
81
- def generate_text(prompt, max_length=128, temperature=0.8, top_k=60, top_p=0.9):
 
 
 
 
 
 
 
 
 
 
 
 
82
  # Set generation parameters within a GenerationConfig object
83
- # We set use_cache=False here to bypass the onnx export issue
84
  gen_config = GenerationConfig(
85
  max_length=max_length,
86
  temperature=temperature,
87
  top_k=top_k,
88
  top_p=top_p,
89
  do_sample=True,
90
- use_cache=False,
91
  )
92
 
93
- gen_pipeline = pipeline(
94
- "text-generation",
95
- model=model,
96
- tokenizer=tokenizer,
97
- device=device,
98
  )
99
-
100
- # Pass all generation parameters to the pipeline
101
- generated_text = gen_pipeline(
102
- prompt,
103
- **gen_config.to_dict()
104
- )
105
- return generated_text[0]["generated_text"]
106
 
107
- # Create and launch the Gradio interface
 
 
108
  demo = gr.Interface(
109
- fn=generate_text,
110
  inputs=[
111
  gr.Textbox(label="Prompt", lines=2),
112
  gr.Slider(minimum=10, maximum=512, value=128, label="Max Length"),
@@ -115,8 +128,8 @@ demo = gr.Interface(
115
  gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P"),
116
  ],
117
  outputs="text",
118
- title="SmilyAI Sam 3.0-2 ONNX Text Generation",
119
- description="A simple API and UI for text generation using the ONNX version of Sam 3.0-2."
120
  )
121
 
122
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
  from dataclasses import dataclass
4
+ from transformers import AutoTokenizer, PretrainedConfig, GenerationConfig, TextIteratorStreamer
5
  from optimum.onnxruntime import ORTModelForCausalLM
6
  import onnx
7
  import logging
8
+ from threading import Thread
9
 
10
  logging.basicConfig(level=logging.INFO)
11
 
 
58
  self.input_modality = input_modality
59
  self.head_type = head_type
60
  self.version = version
 
61
  self.hidden_size = self.d_model
62
  self.num_attention_heads = self.n_heads
63
 
 
77
  logging.error(f"Failed to load ONNX model: {e}")
78
  raise e
79
 
80
+ # -----------------------------------------------------------------------------
81
+ # Streaming Generation Function
82
+ # -----------------------------------------------------------------------------
83
+ def generate_text_stream(prompt, max_length, temperature, top_k, top_p):
84
+ """
85
+ This function acts as a generator to stream text.
86
+ It yields each new token as it's generated by the model.
87
+ """
88
+ # Create a streamer to iterate over the generated tokens
89
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
90
+
91
+ # Prepare the generation inputs
92
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
93
+
94
  # Set generation parameters within a GenerationConfig object
95
+ # We explicitly set use_cache=False to avoid the ONNX export bug
96
  gen_config = GenerationConfig(
97
  max_length=max_length,
98
  temperature=temperature,
99
  top_k=top_k,
100
  top_p=top_p,
101
  do_sample=True,
102
+ use_cache=False,
103
  )
104
 
105
+ # Create a thread to run the generation in the background
106
+ generation_kwargs = dict(
107
+ input_ids=input_ids,
108
+ streamer=streamer,
109
+ generation_config=gen_config,
110
  )
111
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
112
+ thread.start()
113
+
114
+ # Yield each token from the streamer as it is generated
115
+ for new_text in streamer:
116
+ yield new_text
 
117
 
118
+ # -----------------------------------------------------------------------------
119
+ # Gradio Interface
120
+ # -----------------------------------------------------------------------------
121
  demo = gr.Interface(
122
+ fn=generate_text_stream,
123
  inputs=[
124
  gr.Textbox(label="Prompt", lines=2),
125
  gr.Slider(minimum=10, maximum=512, value=128, label="Max Length"),
 
128
  gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P"),
129
  ],
130
  outputs="text",
131
+ title="SmilyAI Sam 3.0-2 ONNX Text Generation (Streaming)",
132
+ description="A simple API and UI for text generation using the ONNX version of Sam 3.0-2, with streaming output.",
133
  )
134
 
135
  demo.launch()