Keeby-smilyai committed
Commit 1a351ce · verified · 1 Parent(s): f5f8831

Update app.py

Files changed (1): app.py (+32 -15)
app.py CHANGED

@@ -1,9 +1,12 @@
 import gradio as gr
 import torch
 from dataclasses import dataclass
-from transformers import AutoTokenizer, PretrainedConfig, pipeline
+from transformers import AutoTokenizer, PretrainedConfig, pipeline, GenerationConfig
 from optimum.onnxruntime import ORTModelForCausalLM
 import onnx
+import logging
+
+logging.basicConfig(level=logging.INFO)
 
 # -----------------------------------------------------------------------------
 # Configuration and Special Tokens
@@ -40,7 +43,6 @@ class Sam3Config(PretrainedConfig):
     _attn_implementation_internal: str = "eager"
     is_encoder_decoder: bool = False
 
-    # These are the required attributes for ORTModelForCausalLM
     hidden_size: int = 384
     num_attention_heads: int = 6
 
@@ -55,8 +57,6 @@ class Sam3Config(PretrainedConfig):
         self.input_modality = input_modality
         self.head_type = head_type
         self.version = version
-
-        # Ensure hidden_size and num_attention_heads are set correctly
         self.hidden_size = self.d_model
         self.num_attention_heads = self.n_heads
 
@@ -64,28 +64,45 @@ class Sam3Config(PretrainedConfig):
 model_config = Sam3Config()
 
 # Load the ONNX model by providing the configuration
-model = ORTModelForCausalLM.from_pretrained(
-    "Smilyai-labs/Sam-3.0-2-onnx",
-    config=model_config,
-    trust_remote_code=True
-)
+try:
+    model = ORTModelForCausalLM.from_pretrained(
+        "Smilyai-labs/Sam-3.0-2-onnx",
+        config=model_config,
+        trust_remote_code=True,
+    )
+    logging.info("ONNX model loaded successfully.")
+
+    # Fix the use_cache issue by setting it to False if the model doesn't support it
+    if not getattr(model, "_is_stateful", True):
+        logging.warning("Model does not support `_is_stateful`, setting `use_cache=False` for generation.")
+        model.generation_config.use_cache = False
+
+except Exception as e:
+    logging.error(f"Failed to load ONNX model: {e}")
+    raise e
 
 # Define a function to generate text
 def generate_text(prompt, max_length=128, temperature=0.8, top_k=60, top_p=0.9):
+    # Set generation parameters
+    gen_config = GenerationConfig(
+        max_length=max_length,
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        do_sample=True,
+        use_cache=False,  # Explicitly disable cache to avoid the error
+    )
+
     gen_pipeline = pipeline(
         "text-generation",
         model=model,
         tokenizer=tokenizer,
-        device=device
+        device=device,
+        generation_config=gen_config
     )
 
     generated_text = gen_pipeline(
         prompt,
-        max_length=max_length,
-        temperature=temperature,
-        top_k=top_k,
-        top_p=top_p,
-        do_sample=True,
     )
     return generated_text[0]["generated_text"]
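For context: an ONNX decoder exported without past-key-value inputs cannot reuse a KV cache during generation, which appears to be why this commit disables `use_cache` both on the model's generation_config and in the per-call GenerationConfig. Below is a minimal standalone sketch of the resulting generation path. The shown hunks reference `tokenizer` and `device` without defining them, so the tokenizer repo, the device index, and the gr.Interface wiring here are assumptions rather than part of the commit, and the custom Sam3Config is elided for brevity.

import gradio as gr
import torch
from transformers import AutoTokenizer, GenerationConfig, pipeline
from optimum.onnxruntime import ORTModelForCausalLM

# Assumed setup: the diff uses `tokenizer` and `device` but never defines them.
tokenizer = AutoTokenizer.from_pretrained("Smilyai-labs/Sam-3.0-2-onnx")
device = 0 if torch.cuda.is_available() else -1  # pipeline() accepts an int device index

model = ORTModelForCausalLM.from_pretrained(
    "Smilyai-labs/Sam-3.0-2-onnx",
    # config=model_config,  # the custom Sam3Config defined earlier in app.py
    trust_remote_code=True,
)
model.generation_config.use_cache = False  # decoder exported without a KV cache

def generate_text(prompt, max_length=128, temperature=0.8, top_k=60, top_p=0.9):
    # Bundle the sampling parameters; use_cache=False mirrors the commit's workaround.
    gen_config = GenerationConfig(
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=True,
        use_cache=False,
    )
    gen_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
    # Passing the GenerationConfig at call time forwards it to model.generate().
    return gen_pipeline(prompt, generation_config=gen_config)[0]["generated_text"]

# Hypothetical UI wiring; the actual Gradio interface is outside the shown hunks.
demo = gr.Interface(fn=generate_text, inputs="text", outputs="text")
if __name__ == "__main__":
    demo.launch()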