Update app.py
app.py CHANGED
@@ -57,6 +57,7 @@ class Sam3Config(PretrainedConfig):
         self.input_modality = input_modality
         self.head_type = head_type
         self.version = version
+
         self.hidden_size = self.d_model
         self.num_attention_heads = self.n_heads
 
@@ -71,11 +72,6 @@ try:
         trust_remote_code=True,
     )
     logging.info("ONNX model loaded successfully.")
-
-    # Fix the use_cache issue by setting it to False if the model doesn't support it
-    if not getattr(model, "_is_stateful", True):
-        logging.warning("Model does not support `_is_stateful`, setting `use_cache=False` for generation.")
-        model.generation_config.use_cache = False
 
 except Exception as e:
     logging.error(f"Failed to load ONNX model: {e}")
@@ -83,14 +79,15 @@ except Exception as e:
 
 # Define a function to generate text
 def generate_text(prompt, max_length=128, temperature=0.8, top_k=60, top_p=0.9):
-    # Set generation parameters
+    # Set generation parameters within a GenerationConfig object
+    # We set use_cache=False here to bypass the onnx export issue
     gen_config = GenerationConfig(
         max_length=max_length,
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,
         do_sample=True,
-        use_cache=False,
+        use_cache=False,
     )
 
     gen_pipeline = pipeline(
@@ -98,11 +95,12 @@ def generate_text(prompt, max_length=128, temperature=0.8, top_k=60, top_p=0.9):
         model=model,
         tokenizer=tokenizer,
         device=device,
-        generation_config=gen_config
     )
 
+    # Pass all generation parameters to the pipeline
    generated_text = gen_pipeline(
        prompt,
+        **gen_config.to_dict()
    )
     return generated_text[0]["generated_text"]
 
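For reference, a minimal sketch of how the generation path reads after this commit. It assumes `model`, `tokenizer`, and `device` are initialized earlier in app.py, as the surrounding context of the diff suggests; the pipeline task name is a placeholder, since the hunk does not show it.

from transformers import GenerationConfig, pipeline

def generate_text(prompt, max_length=128, temperature=0.8, top_k=60, top_p=0.9):
    # Collect the sampling parameters in a GenerationConfig;
    # use_cache=False works around the ONNX export issue noted in the diff.
    gen_config = GenerationConfig(
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        use_cache=False,
    )

    gen_pipeline = pipeline(
        "text-generation",  # assumed task; the diff hunk does not show it
        model=model,
        tokenizer=tokenizer,
        device=device,
    )

    # Splat the config into the call; the text-generation pipeline forwards
    # extra keyword arguments to model.generate().
    generated_text = gen_pipeline(prompt, **gen_config.to_dict())
    return generated_text[0]["generated_text"]

Note that to_dict() serializes every field of the GenerationConfig, not only the five set above, so the call forwards the full configuration as keyword arguments rather than a single generation_config object.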
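A hypothetical invocation, assuming the Space calls generate_text as defined above (the prompt and sampling values are illustrative):

output = generate_text("Once upon a time", max_length=64, temperature=0.7)
print(output)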