akashmadisetty committed on
Commit b820bc7 · 1 Parent(s): 45c882e
Files changed (1)
  1. app.py +63 -14
app.py CHANGED
@@ -27,37 +27,66 @@ def load_model(hf_token):
         return "⚠️ Please enter your Hugging Face token to use the model."

     try:
-        # Try both model versions
+        # Try different model versions from smallest to largest
         model_options = [
-            "google/gemma-3-4b-pt",  # Try the quantized PT version first
-            "google/gemma-2b",  # Fallback to 2b model
+            "google/gemma-2b-it",  # Try an instruction-tuned 2B model first (smallest)
+            "google/gemma-2b",  # Try base 2B model next
+            "google/gemma-7b-it",  # Try 7B instruction-tuned model
+            "google/gemma-7b",  # Try base 7B model
+            # Fallback to completely different models if all Gemma models fail
+            "meta-llama/Llama-2-7b-chat-hf",
+            "facebook/opt-1.3b",
+            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
         ]

+        print(f"Attempting to load models with token starting with: {hf_token[:5]}...")
+
         # Try to load models in order until one works
         for model_name in model_options:
             try:
                 print(f"Attempting to load model: {model_name}")

                 # Load tokenizer
+                print("Loading tokenizer...")
                 global_tokenizer = AutoTokenizer.from_pretrained(
                     model_name,
                     token=hf_token
                 )
+                print("Tokenizer loaded successfully")

-                # Load model with minimal configuration to avoid errors
+                # Load model with minimal configuration
+                print(f"Loading model {model_name}...")
                 global_model = AutoModelForCausalLM.from_pretrained(
                     model_name,
                     torch_dtype=torch.float16,
                     device_map="auto",
                     token=hf_token
                 )
+                print(f"Model {model_name} loaded successfully!")

                 model_loaded = True
                 return f"✅ Model {model_name} loaded successfully!"
             except Exception as specific_e:
                 print(f"Failed to load {model_name}: {specific_e}")
+                import traceback
+                traceback.print_exc()
                 continue

+        # If we get here, all model options failed - try one more option with no token
+        try:
+            print("Trying a public model with no token requirement...")
+            model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+            global_tokenizer = AutoTokenizer.from_pretrained(model_name)
+            global_model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=torch.float16,
+                device_map="auto"
+            )
+            model_loaded = True
+            return f"✅ Fallback model {model_name} loaded successfully! Note: This is not Gemma but a fallback model."
+        except Exception as fallback_e:
+            print(f"Failed to load fallback model: {fallback_e}")
+
         # If we get here, all model options failed
         model_loaded = False
         return "❌ Could not load any model version. Please check your token and try again."
@@ -65,6 +94,10 @@ def load_model(hf_token):
     except Exception as e:
         model_loaded = False
         error_msg = str(e)
+        print(f"Error in load_model: {error_msg}")
+        import traceback
+        traceback.print_exc()
+
         if "401 Client Error" in error_msg:
             return "❌ Authentication failed. Please check your token and make sure you've accepted the model license on Hugging Face."
         else:
@@ -152,7 +185,11 @@ def generate_text(prompt, max_length=1024, temperature=0.7, top_p=0.95):
     """Generate text using the Gemma model"""
    global global_model, global_tokenizer, model_loaded

+    print(f"Generating text with params: max_length={max_length}, temp={temperature}, top_p={top_p}")
+    print(f"Prompt: {prompt[:100]}...")
+
     if not model_loaded or global_model is None or global_tokenizer is None:
+        print("Model not loaded")
         return "⚠️ Model not loaded. Please authenticate with your Hugging Face token."

     if not prompt:
@@ -161,21 +198,35 @@ def generate_text(prompt, max_length=1024, temperature=0.7, top_p=0.95):
     try:
         # Keep generation simple to avoid errors
         inputs = global_tokenizer(prompt, return_tensors="pt").to(global_model.device)
+        print(f"Input token length: {len(inputs.input_ids[0])}")

-        # Use simpler generation parameters that work reliably
-        outputs = global_model.generate(
-            inputs.input_ids,
-            max_length=min(2048, max_length + len(inputs.input_ids[0])),
-            temperature=max(0.3, temperature),  # Prevent too low temperature
-            do_sample=True
-        )
+        # Use even simpler generation parameters
+        generation_args = {
+            "input_ids": inputs.input_ids,
+            "max_length": min(2048, max_length + len(inputs.input_ids[0])),
+            "do_sample": True,
+        }
+
+        # Only add temperature if not too low (can cause issues)
+        if temperature >= 0.3:
+            generation_args["temperature"] = temperature
+
+        print(f"Generation args: {generation_args}")
+
+        # Generate text
+        outputs = global_model.generate(**generation_args)

         # Decode and return the generated text
         generated_text = global_tokenizer.decode(outputs[0], skip_special_tokens=True)
+        print(f"Generated text length: {len(generated_text)}")
         return generated_text
     except Exception as e:
         error_msg = str(e)
         print(f"Generation error: {error_msg}")
+        print(f"Error type: {type(e)}")
+        import traceback
+        traceback.print_exc()
+
         if "probability tensor" in error_msg:
             return "Error: There was a problem with the generation parameters. Try using simpler parameters or a different prompt."
         else:
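For reference, the generation path introduced in this hunk, condensed into a standalone sketch. The function name and signature below are illustrative, not part of app.py; it assumes a tokenizer and model loaded as in load_model() above.

    # Sketch of the new generation path (not part of app.py).
    def generate(tokenizer, model, prompt, max_length=1024, temperature=0.7):
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        generation_args = {
            "input_ids": inputs.input_ids,
            "max_length": min(2048, max_length + len(inputs.input_ids[0])),
            "do_sample": True,
        }
        # Very low temperatures are what triggered the "probability tensor" errors,
        # so temperature is only passed through when it is at or above 0.3.
        if temperature >= 0.3:
            generation_args["temperature"] = temperature
        outputs = model.generate(**generation_args)
        return tokenizer.decode(outputs[0], skip_special_tokens=True)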
@@ -234,11 +285,9 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
         with gr.Column(scale=1):
             auth_button = gr.Button("Authenticate", variant="primary")

-    with gr.Group(visible=True) as auth_message_group:
-        auth_status = gr.Markdown("Please authenticate to use the model.")
+    auth_status = gr.Markdown("Please authenticate to use the model.")

     def authenticate(token):
-        auth_message_group.visible = True
         return "Loading model... Please wait, this may take a minute."

     def auth_complete(token):
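The click wiring for these handlers is outside this diff. A typical Gradio pattern consistent with authenticate() and auth_complete() is sketched below as an assumption; hf_token_input is a hypothetical textbox name, not a component shown in the commit.

    # Assumed wiring, not part of this diff: show the loading message first,
    # then run the (slow) model load and update the same status component.
    auth_button.click(
        fn=authenticate,
        inputs=hf_token_input,   # hypothetical token textbox
        outputs=auth_status,
    ).then(
        fn=auth_complete,
        inputs=hf_token_input,
        outputs=auth_status,
    )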
 