Reubencf committed
Commit f98a7e0 · verified · 1 Parent(s): 16bcde3

Update app.py

Files changed (1)
  1. app.py +83 -271

app.py CHANGED
@@ -1,299 +1,111 @@
- # app.py — Hugging Face Space ready (LoRA adapter, Gradio compat)
  # ---------------------------------------------------------------
- # What changed vs your script
- # - Removed ChatInterface args that broke on old Gradio (retry_btn, undo_btn)
- # - No interactive input() for merging (Spaces are non-interactive). Use MERGE_LORA env var.
- # - Secrets: read HF token from env (Settings → Secrets → HF_TOKEN), never hardcode.
- # - Token passing works across transformers/peft versions (token/use_auth_token fallback).
- # - Optional 8-bit via USE_8BIT=1 (GPU only). Safe CPU defaults.
- # - Robust theme/queue/launch for mixed Gradio versions.
-
  import os
- import gc
- import warnings
- from typing import List, Tuple
-
  import torch
  import gradio as gr
-
- warnings.filterwarnings("ignore")
- os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
-
- try:
-     from peft import PeftConfig, PeftModel
-     from transformers import (
-         AutoTokenizer,
-         AutoModelForCausalLM,
-         BitsAndBytesConfig,
-     )
-     IMPORTS_OK = True
- except Exception as e:
-     IMPORTS_OK = False
-     print(f"Missing dependencies: {e}")
-     print("Install: pip install --upgrade 'transformers>=4.41' peft accelerate gradio torch bitsandbytes")

  # ── Configuration ──────────────────────────────────────────────────────────────
- HF_TOKEN = os.getenv("HF_TOKEN")  # set in Space Settings → Secrets → HF_TOKEN
-
- # LoRA adapter repo (must be compatible with BASE_MODEL_ID)
- ADAPTER_ID = os.getenv("ADAPTER_ID", "Reubencf/gemma3-goan-finetuned")
-
- # Base model used during fine-tuning (should match adapter's base)
- BASE_MODEL_ID_DEFAULT = os.getenv("BASE_MODEL_ID", "google/gemma-3-4b-it")
-
- # Quantization toggle (GPU only): set USE_8BIT=1 in Space variables
- USE_8BIT = os.getenv("USE_8BIT", "0").lower() in {"1", "true", "yes", "y"}
-
- # Merge LoRA into the base for faster inference: MERGE_LORA=1/0
- MERGE_LORA = os.getenv("MERGE_LORA", "1").lower() in {"1", "true", "yes", "y"}

  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
  TITLE = "🌴 Gemma Goan Q&A Bot"
- DESCRIPTION_TMPL = (
-     "Gemma base model + LoRA adapter fine-tuned on a Goan Q&A dataset.\n"
-     "Ask about Goa, Konkani culture, or general topics!\n\n"
-     "**Status**: {}"
  )

- # ── Helpers ───────────────────────────────────────────────────────────────────
-
- def call_with_token(fn, *args, **kwargs):
-     """Call HF/Transformers/PEFT functions with token OR use_auth_token for
-     broad version compatibility."""
-     if HF_TOKEN:
-         try:
-             return fn(*args, token=HF_TOKEN, **kwargs)
-         except TypeError:
-             return fn(*args, use_auth_token=HF_TOKEN, **kwargs)
-     return fn(*args, **kwargs)
-
- # ── Load model + tokenizer ─────────────────────────────────────────────────────
-
- def load_model_and_tokenizer():
-     if not IMPORTS_OK:
-         raise ImportError("Required packages not installed.")
-
-     print("[Init] Starting model load…")
-     print(f"[Config] Device: {DEVICE}")
-
-     # GC + VRAM cleanup
-     gc.collect()
-     if torch.cuda.is_available():
-         torch.cuda.empty_cache()
-
-     # Step 1: Confirm base model from the adapter's config if possible
-     actual_base_model = BASE_MODEL_ID_DEFAULT
-     try:
-         print(f"[Load] Reading adapter config: {ADAPTER_ID}")
-         peft_cfg = call_with_token(PeftConfig.from_pretrained, ADAPTER_ID)
-         if getattr(peft_cfg, "base_model_name_or_path", None):
-             actual_base_model = peft_cfg.base_model_name_or_path
-             print(f"[Load] Adapter expects base model: {actual_base_model}")
-         else:
-             print("[Warn] Adapter did not expose base_model_name_or_path; using configured base.")
-     except Exception as e:
-         print(f"[Warn] Could not read adapter config ({e}); using configured base: {actual_base_model}")
-
-     # Step 2: Load base model (optionally quantized on GPU)
-     print(f"[Load] Loading base model: {actual_base_model}")
-     quant_cfg = None
-     if USE_8BIT and torch.cuda.is_available():
-         print("[Load] Enabling 8-bit quantization (bitsandbytes)")
-         quant_cfg = BitsAndBytesConfig(load_in_8bit=True, bnb_8bit_compute_dtype=torch.float16)
-
-     base_model = call_with_token(
-         AutoModelForCausalLM.from_pretrained,
-         actual_base_model,
-         trust_remote_code=True,
-         quantization_config=quant_cfg,
-         low_cpu_mem_usage=True,
-         torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
-         device_map="auto" if torch.cuda.is_available() else None,
-     )
-
-     if DEVICE == "cpu" and not torch.cuda.is_available():
-         base_model = base_model.to("cpu")
-         print("[Load] Model on CPU")
-
-     print("[Load] Base model loaded ✔")
-
-     # Step 3: Tokenizer
-     print("[Load] Loading tokenizer…")
-     tokenizer = call_with_token(
-         AutoTokenizer.from_pretrained,
-         actual_base_model,
-         use_fast=True,
-         trust_remote_code=True,
-     )
-     if tokenizer.pad_token is None:
-         tokenizer.pad_token = tokenizer.eos_token
-     tokenizer.padding_side = "left"
-
-     # Step 4: Apply LoRA adapter
-     status = ""
-     model = base_model
-     try:
-         print(f"[Load] Applying LoRA adapter: {ADAPTER_ID}")
-         model = call_with_token(PeftModel.from_pretrained, base_model, ADAPTER_ID)
-
-         if MERGE_LORA:
-             print("[Load] Merging adapter into base (merge_and_unload)…")
-             model = model.merge_and_unload()
-             status = f"✅ Using fine-tuned model (merged): {ADAPTER_ID}"
-         else:
-             status = f"✅ Using fine-tuned model via adapter: {ADAPTER_ID}"
-     except FileNotFoundError as e:
-         print(f"[Error] Adapter files not found: {e}")
-         status = f"⚠️ Adapter not found. Using base only: {actual_base_model}"
-     except Exception as e:
-         print(f"[Error] Failed to load adapter: {e}")
-         status = f"⚠️ Could not load adapter. Using base only: {actual_base_model}"
-
-     model.eval()
-     print(f"[Load] Model ready on {DEVICE} ✔")
-
-     gc.collect()
-     if torch.cuda.is_available():
-         torch.cuda.empty_cache()
-
-     return model, tokenizer, status
-
- # Global load at import time (Space-friendly)
  try:
-     model, tokenizer, STATUS_MSG = load_model_and_tokenizer()
      MODEL_LOADED = True
-     DESCRIPTION = DESCRIPTION_TMPL.format(STATUS_MSG)
  except Exception as e:
-     print(f"[Fatal] Could not load model: {e}")
      MODEL_LOADED = False
-     model = tokenizer = None
-     DESCRIPTION = DESCRIPTION_TMPL.format(f"❌ Model failed to load: {str(e)[:140]}")

- # ── Generation ────────────────────────────────────────────────────────────────

- def generate_response(
-     message: str,
-     history: List[Tuple[str, str]],
-     temperature: float = 0.7,
-     max_new_tokens: int = 256,
-     top_p: float = 0.95,
-     repetition_penalty: float = 1.1,
- ) -> str:
      if not MODEL_LOADED:
-         return "⚠️ Model failed to load. Check Space logs."
-
-     try:
-         # Build short chat history
-         conversation = []
-         if history:
-             for u, a in history[-3:]:
-                 if u:
-                     conversation.append({"role": "user", "content": u})
-                 if a:
-                     conversation.append({"role": "assistant", "content": a})
-         conversation.append({"role": "user", "content": message})
-
-         # Try the tokenizer's chat template first
-         try:
-             input_ids = tokenizer.apply_chat_template(
-                 conversation,
-                 add_generation_prompt=True,
-                 return_tensors="pt",
-             )
-         except Exception as e:
-             print(f"[Warn] chat_template failed: {e}; using manual format")
-             prompt_text = "".join(
-                 [
-                     ("User: " + m["content"] + "\n") if m["role"] == "user" else ("Assistant: " + m["content"] + "\n")
-                     for m in conversation
-                 ]
-             ) + "Assistant: "
-             input_ids = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=1024).input_ids
-
-         input_ids = input_ids.to(model.device if hasattr(model, "device") else DEVICE)
-
-         with torch.no_grad():
-             out = model.generate(
-                 input_ids=input_ids,
-                 max_new_tokens=max(1, min(int(max_new_tokens), 512)),
-                 temperature=float(temperature),
-                 top_p=float(top_p),
-                 repetition_penalty=float(repetition_penalty),
-                 do_sample=True,
-                 pad_token_id=tokenizer.pad_token_id,
-                 eos_token_id=tokenizer.eos_token_id,
-                 use_cache=True,
-             )
-
-         gen = out[0][input_ids.shape[-1]:]
-         text = tokenizer.decode(gen, skip_special_tokens=True).strip()
-
-         # Cleanup
-         del out, input_ids, gen
-         gc.collect()
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()

-         return text or "(no output)"

-     except Exception as e:
-         gc.collect()
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-         return f"⚠️ Error generating response: {e}"

  # ── UI ────────────────────────────────────────────────────────────────────────
  examples = [
-     ["What is the capital of Goa?", 0.7, 256, 0.95, 1.1],
-     ["Tell me about the Konkani language.", 0.7, 256, 0.95, 1.1],
-     ["What are famous beaches in Goa?", 0.7, 256, 0.95, 1.1],
-     ["Describe Goan fish curry.", 0.7, 256, 0.95, 1.1],
-     ["What is the history of Old Goa?", 0.7, 256, 0.95, 1.1],
  ]

- # Best-effort theme across versions
- try:
-     THEME = gr.themes.Soft()
- except Exception:
-     THEME = None
-
- if MODEL_LOADED:
-     demo = gr.ChatInterface(
-         fn=generate_response,
-         title=TITLE,
-         description=DESCRIPTION,
-         examples=examples,
-         additional_inputs=[
-             gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.05, label="Temperature"),
-             gr.Slider(minimum=32, maximum=512, value=256, step=16, label="Max new tokens"),
-             gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
-             gr.Slider(minimum=1.0, maximum=2.0, value=1.1, step=0.05, label="Repetition penalty"),
-         ],
-         theme=THEME,
-     )
- else:
-     demo = gr.Interface(
-         fn=lambda x: "Model failed to load. Check Space logs.",
-         inputs=gr.Textbox(label="Message"),
-         outputs=gr.Textbox(label="Response"),
-         title=TITLE,
-         description=DESCRIPTION,
-         theme=THEME,
-     )
-
- # Queue — keep params minimal for cross-version compat
- try:
-     demo.queue()
- except Exception:
-     pass

  if __name__ == "__main__":
-     print("\n" + "=" * 60)
-     print(f"🚀 Starting Gradio app on {DEVICE} …")
-     print(f"📍 Base model: {BASE_MODEL_ID_DEFAULT}")
-     print(f"🔧 LoRA adapter: {ADAPTER_ID}")
-     print(f"🧩 Merge LoRA: {MERGE_LORA}")
-     print("=" * 60 + "\n")
-     # On Spaces, just calling launch() is fine.
      demo.launch()
+ # app.py — Simplified for Hugging Face Spaces
+ # ---------------------------------------------------------------
+ # This version uses the high-level `pipeline` from transformers
+ # for a much simpler and cleaner implementation.
  # ---------------------------------------------------------------

  import os
  import torch
  import gradio as gr
+ from transformers import pipeline

  # ── Configuration ──────────────────────────────────────────────────────────────
+ # Set the model repository ID
+ MODEL_ID = "Reubencf/gemma3-goan-finetuned"
+ HF_TOKEN = os.getenv("HF_TOKEN")  # Optional: for private models

  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  TITLE = "🌴 Gemma Goan Q&A Bot"
+ DESCRIPTION = (
+     "This is a simple Gradio chat interface for the Gemma model fine-tuned on a Goan Q&A dataset.\n"
+     "Ask about Goa, Konkani culture, or general topics!"
  )

+ # ── Load Model Pipeline ─────────────────────────────────────────────────────
+ # We load the model and tokenizer into a pipeline object.
+ # This is done only once when the app starts.
+ # `device_map="auto"` ensures the model is placed on a GPU if available.
+ print(f"[Init] Loading model pipeline: {MODEL_ID} on {DEVICE}...")

  try:
+     pipe = pipeline(
+         "text-generation",
+         model=MODEL_ID,
+         torch_dtype=torch.bfloat16,  # Use bfloat16 for better performance
+         device_map="auto",
+         token=HF_TOKEN,
+     )
      MODEL_LOADED = True
+     print("[Init] Model pipeline loaded successfully.")
  except Exception as e:
      MODEL_LOADED = False
+     DESCRIPTION = f"❌ Model failed to load: {e}"
+     print(f"[Fatal] Could not load model: {e}")


+ # ── Generation Function ──────────────────────────────────────────────────────
+ def generate_response(message, history):
+     """
+     This function is called for each user message.
+     It takes the user's message and the conversation history,
+     formats them for the model, and returns the model's response.
+     """
      if not MODEL_LOADED:
+         return "⚠️ Model is not available. Please check the Space logs for errors."
+
+     # Format the conversation history into the format expected by the model
+     # The model expects a list of dictionaries with "role" and "content" keys
+     conversation = []
+     for user_msg, assistant_msg in history:
+         conversation.append({"role": "user", "content": user_msg})
+         if assistant_msg:
+             conversation.append({"role": "assistant", "content": assistant_msg})
+
+     # Add the current user's message
+     conversation.append({"role": "user", "content": message})
+
+     # Use the pipeline's tokenizer to apply the chat template
+     # This correctly formats the input for the conversational model
+     prompt = pipe.tokenizer.apply_chat_template(
+         conversation,
+         tokenize=False,
+         add_generation_prompt=True
+     )
+
+     # Generate the response using the pipeline
+     outputs = pipe(
+         prompt,
+         do_sample=True,
+         temperature=0.7,
+         top_k=50,
+         top_p=0.95
+     )
+
+     # The pipeline output includes the entire conversation history (prompt).
+     # We need to extract only the newly generated text from the assistant.
+     response = outputs[0]["generated_text"]
+     # Slice the response to get only the new part
+     new_response = response[len(prompt):].strip()
+
+     return new_response

  # ── UI ────────────────────────────────────────────────────────────────────────
+ # Define some example questions to display in the UI
  examples = [
+     "What is bebinca?",
+     "Tell me about the history of Feni.",
+     "Suggest a good, quiet beach in South Goa.",
+     "Describe Goan fish curry.",
  ]

+ # Create the Gradio ChatInterface
+ demo = gr.ChatInterface(
+     fn=generate_response,
+     title=TITLE,
+     description=DESCRIPTION,
+     examples=examples,
+     theme="soft",
+ )

+ # ── Launch ────────────────────────────────────────────────────────────────────
  if __name__ == "__main__":
+     print("🚀 Starting Gradio app...")
      demo.launch()
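
The committed generate_response trims the echoed prompt off the decoded output by character count and leaves the response length to the model's default generation settings. Below is a minimal sketch of an equivalent handler, assuming a transformers release whose text-generation pipeline accepts max_new_tokens and return_full_text; the name generate_response_alt is illustrative and not part of the commit, and the sketch is untested against this Space's pinned versions.

# ── Sketch (not part of the commit): same handler with an explicit token cap ──
# Assumes the installed transformers text-generation pipeline supports
# max_new_tokens and return_full_text.
def generate_response_alt(message, history):
    if not MODEL_LOADED:
        return "⚠️ Model is not available. Please check the Space logs for errors."

    # Same chat-message list as the committed version.
    conversation = []
    for user_msg, assistant_msg in history:
        conversation.append({"role": "user", "content": user_msg})
        if assistant_msg:
            conversation.append({"role": "assistant", "content": assistant_msg})
    conversation.append({"role": "user", "content": message})

    # Render the prompt with the tokenizer's chat template, as in the commit.
    prompt = pipe.tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )

    outputs = pipe(
        prompt,
        max_new_tokens=256,       # explicit cap; the committed call relies on the model's default
        do_sample=True,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        return_full_text=False,   # return only the newly generated text
    )
    return outputs[0]["generated_text"].strip()

Letting the pipeline drop the prompt avoids the slice by len(prompt), which can misalign when decoding does not reproduce the prompt string exactly.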