Update app.py

app.py
CHANGED
@@ -1,10 +1,11 @@
 import gradio as gr
 import torch
 from dataclasses import dataclass
-from transformers import AutoTokenizer, PretrainedConfig, GenerationConfig
+from transformers import AutoTokenizer, PretrainedConfig, GenerationConfig, TextIteratorStreamer
 from optimum.onnxruntime import ORTModelForCausalLM
 import onnx
 import logging
+from threading import Thread
 
 logging.basicConfig(level=logging.INFO)
 
@@ -57,7 +58,6 @@ class Sam3Config(PretrainedConfig):
         self.input_modality = input_modality
         self.head_type = head_type
         self.version = version
-
         self.hidden_size = self.d_model
         self.num_attention_heads = self.n_heads
 
@@ -77,36 +77,49 @@ except Exception as e:
     logging.error(f"Failed to load ONNX model: {e}")
     raise e
 
-# …
-
+# -----------------------------------------------------------------------------
+# Streaming Generation Function
+# -----------------------------------------------------------------------------
+def generate_text_stream(prompt, max_length, temperature, top_k, top_p):
+    """
+    This function acts as a generator to stream text.
+    It yields each new token as it's generated by the model.
+    """
+    # Create a streamer to iterate over the generated tokens
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    # Prepare the generation inputs
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
+
     # Set generation parameters within a GenerationConfig object
-    # We set use_cache=False
+    # We explicitly set use_cache=False to avoid the ONNX export bug
     gen_config = GenerationConfig(
         max_length=max_length,
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,
         do_sample=True,
-        use_cache=False,
+        use_cache=False,
     )
 
-    …
+    # Create a thread to run the generation in the background
+    generation_kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
+        generation_config=gen_config,
     )
-    …
-    return generated_text[0]["generated_text"]
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    # Yield each token from the streamer as it is generated
+    for new_text in streamer:
+        yield new_text
 
-# …
+# -----------------------------------------------------------------------------
+# Gradio Interface
+# -----------------------------------------------------------------------------
 demo = gr.Interface(
-    fn=…
+    fn=generate_text_stream,
     inputs=[
         gr.Textbox(label="Prompt", lines=2),
         gr.Slider(minimum=10, maximum=512, value=128, label="Max Length"),
@@ -115,8 +128,8 @@ demo = gr.Interface(
         gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top P"),
     ],
     outputs="text",
-    title="SmilyAI Sam 3.0-2 ONNX Text Generation",
-    description="A simple API and UI for text generation using the ONNX version of Sam 3.0-2."
+    title="SmilyAI Sam 3.0-2 ONNX Text Generation (Streaming)",
+    description="A simple API and UI for text generation using the ONNX version of Sam 3.0-2, with streaming output.",
 )
 
 demo.launch()
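Side note: the streamer-plus-thread pattern introduced in this commit is the standard way to stream from transformers' generate(), which blocks until generation finishes; running it on a worker thread lets the TextIteratorStreamer be consumed as tokens arrive. A minimal self-contained sketch of the same pattern, with gpt2 standing in purely for illustration (this Space instead loads its ONNX model through ORTModelForCausalLM):

from threading import Thread

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    TextIteratorStreamer,
)

# gpt2 is a stand-in model for illustration only.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

def stream_generate(prompt):
    # skip_prompt drops the echoed input; skip_special_tokens drops EOS markers.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # generate() blocks until done, so run it on a worker thread while the
    # streamer's iterator yields decoded text chunks here as they arrive.
    Thread(
        target=model.generate,
        kwargs=dict(
            input_ids=input_ids,
            streamer=streamer,
            generation_config=GenerationConfig(max_new_tokens=64, do_sample=True, temperature=0.8),
        ),
    ).start()
    for piece in streamer:
        yield piece

for piece in stream_generate("Once upon a time"):
    print(piece, end="", flush=True)
print()

One caveat worth knowing: gr.Interface displays the latest yielded value, so generate_text_stream as written repaints the output with each individual chunk; accumulating first (text += new_text; yield text) is the usual way to make the textbox grow as tokens stream in.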