Keeby-smilyai committed
Commit eb48590 · verified · 1 Parent(s): 1a351ce

Update app.py

Files changed (1): app.py +6 -8
app.py CHANGED
@@ -57,6 +57,7 @@ class Sam3Config(PretrainedConfig):
         self.input_modality = input_modality
         self.head_type = head_type
         self.version = version
+
         self.hidden_size = self.d_model
         self.num_attention_heads = self.n_heads
 
@@ -71,11 +72,6 @@ try:
         trust_remote_code=True,
     )
     logging.info("ONNX model loaded successfully.")
-
-    # Fix the use_cache issue by setting it to False if the model doesn't support it
-    if not getattr(model, "_is_stateful", True):
-        logging.warning("Model does not support `_is_stateful`, setting `use_cache=False` for generation.")
-        model.generation_config.use_cache = False
 
 except Exception as e:
     logging.error(f"Failed to load ONNX model: {e}")
@@ -83,14 +79,15 @@ except Exception as e:
 
 # Define a function to generate text
 def generate_text(prompt, max_length=128, temperature=0.8, top_k=60, top_p=0.9):
-    # Set generation parameters
+    # Set generation parameters within a GenerationConfig object
+    # We set use_cache=False here to bypass the onnx export issue
     gen_config = GenerationConfig(
         max_length=max_length,
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,
         do_sample=True,
-        use_cache=False, # Explicitly disable cache to avoid the error
+        use_cache=False,
     )
 
     gen_pipeline = pipeline(
@@ -98,11 +95,12 @@ def generate_text(prompt, max_length=128, temperature=0.8, top_k=60, top_p=0.9):
         model=model,
         tokenizer=tokenizer,
         device=device,
-        generation_config=gen_config
     )
 
+    # Pass all generation parameters to the pipeline
     generated_text = gen_pipeline(
         prompt,
+        **gen_config.to_dict()
     )
     return generated_text[0]["generated_text"]
 
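
For reference, a minimal sketch of how generate_text reads once this commit is applied. It assumes model, tokenizer, and device are initialized earlier in app.py (their setup is only partially visible in the hunks above), and the "text-generation" task string is an assumption, since the pipeline's first positional argument falls outside the diff context.

from transformers import GenerationConfig, pipeline

def generate_text(prompt, max_length=128, temperature=0.8, top_k=60, top_p=0.9):
    # Collect the sampling parameters in a GenerationConfig;
    # use_cache=False works around the ONNX export issue the commit mentions.
    gen_config = GenerationConfig(
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        use_cache=False,
    )

    # "text-generation" is assumed here; the task argument is outside the diff.
    gen_pipeline = pipeline(
        "text-generation",
        model=model,        # assumed: the ONNX model loaded in the try/except block above
        tokenizer=tokenizer,
        device=device,
    )

    # Unpack every config field as keyword arguments to the pipeline call;
    # the text-generation pipeline forwards unrecognized kwargs to model.generate().
    generated_text = gen_pipeline(prompt, **gen_config.to_dict())
    return generated_text[0]["generated_text"]

Note that GenerationConfig.to_dict() serializes every field of the config, not only the ones set above, so depending on the transformers version the pipeline may warn about defaults it does not use; the commit nonetheless prefers unpacking over passing generation_config to the pipeline constructor, which is what it removes.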