rmdhirr committed
Commit c29f849 · verified · 1 Parent(s): 7a38948

Update app.py

Files changed (1):
  app.py  +106 -32
app.py CHANGED
@@ -12,17 +12,97 @@ import spaces
 import torch
 from loguru import logger
 from PIL import Image
-from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
+from transformers import AutoProcessor, TextIteratorStreamer
+
+# ─────────────────────────────────────────────────────────────────────
+# Model & processor
+# ─────────────────────────────────────────────────────────────────────
+MODEL_ID = os.getenv("MODEL_ID", "rmdhirr/Kenanga-11B-IT")
+processor = AutoProcessor.from_pretrained(MODEL_ID, padding_side="left")
+
+# Try Gemma-3 vision first; if it fails, fall back to Llama 3.2 Vision (Mllama)
+model = None
+_last_load_error = None
+try:
+    from transformers import Gemma3ForConditionalGeneration
+    model = Gemma3ForConditionalGeneration.from_pretrained(
+        MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
+    )
+except Exception as e:
+    _last_load_error = e
+    try:
+        from transformers import MllamaForConditionalGeneration
+        model = MllamaForConditionalGeneration.from_pretrained(
+            MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
+        )
+    except Exception as e2:
+        raise RuntimeError(
+            f"Failed to load model as Gemma3 and Mllama.\nGemma3 error: {type(_last_load_error).__name__}: {_last_load_error}\n"
+            f"Mllama error: {type(e2).__name__}: {e2}"
+        )
+
+MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
 
-model_id = os.getenv("MODEL_ID", "rmdhirr/Kenanga-11B-IT")
-processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
-model = Gemma3ForConditionalGeneration.from_pretrained(
-    model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
+# ─────────────────────────────────────────────────────────────────────
+# Identity controls (System Prompt + Stream Sanitizer + Optional Logit Ban)
+# ─────────────────────────────────────────────────────────────────────
+IDENTITY_PROMPT = (
+    "You are Kenanga, an Indonesian multimodal LVLM adapted for Sundanese and Javanese.\n"
+    "Identity rules:\n"
+    "• When referring to yourself, always say “Kenanga”.\n"
+    "• Never claim to be Gemma/Llama or any base model. If asked about your base, reply briefly: "
+    "“I’m Kenanga (locally adapted); please refer to me as Kenanga.”\n"
+    "• Stay helpful, concise, and safe."
 )
 
-MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
+BAN_BASE_NAMES = os.getenv("BAN_BASE_NAMES", "0") == "1"
+
+def _make_bad_words_ids(words):
+    toks = processor.tokenizer
+    ids = []
+    for w in words:
+        for variant in {w, w.lower(), w.upper(), w.title(), " " + w, " " + w.lower()}:
+            enc = toks(variant, add_special_tokens=False).input_ids
+            if enc:
+                ids.append(enc)
+    # dedupe
+    uniq, seen = [], set()
+    for seq in ids:
+        t = tuple(seq)
+        if t and t not in seen:
+            uniq.append(seq)
+            seen.add(t)
+    return uniq
+
+BAD_WORDS_IDS = _make_bad_words_ids([
+    "Gemma", "Gemma-3", "Gemma 3", "Gemma3",
+    # Uncomment to ban base model family self-calls entirely:
+    # "Llama", "LLaMA", "Llama 3", "Llama 3.2", "Llama3", "Llama3.2",
+])
+
+# Only rewrite self-identity claims; allow legitimate mentions in analysis/comparison text
+SELF_REF_PAT = re.compile(
+    r"\b(?:(?:I\s*am|I'm|This\s+is|You'?re\s+chatting\s+with)\s+)(Gemma(?:[-\s]?3)?|LLa?ma(?:\s*3(?:\.2)?)?)\b",
+    flags=re.IGNORECASE,
+)
+AS_MODEL_PAT = re.compile(
+    r"\bAs\s+(?:an?\s+)?(Gemma(?:[-\s]?3)?|LLa?ma(?:\s*3(?:\.2)?)?)\b",
+    flags=re.IGNORECASE,
+)
+THIS_MODEL_IS_PAT = re.compile(
+    r"\b(This\s+model\s+is)\s+(Gemma(?:[-\s]?3)?|LLa?ma(?:\s*3(?:\.2)?)?)\b",
+    flags=re.IGNORECASE,
+)
 
+def sanitize_identity(text: str) -> str:
+    text = SELF_REF_PAT.sub("I am Kenanga", text)
+    text = AS_MODEL_PAT.sub("As Kenanga", text)
+    text = THIS_MODEL_IS_PAT.sub(r"\1 Kenanga", text)
+    return text
 
+# ─────────────────────────────────────────────────────────────────────
+# Media utilities
+# ─────────────────────────────────────────────────────────────────────
 def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
     image_count = 0
     video_count = 0
@@ -33,7 +113,6 @@ def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
             image_count += 1
     return image_count, video_count
 
-
 def count_files_in_history(history: list[dict]) -> tuple[int, int]:
     image_count = 0
     video_count = 0
@@ -46,7 +125,6 @@ def count_files_in_history(history: list[dict]) -> tuple[int, int]:
             image_count += 1
     return image_count, video_count
 
-
 def validate_media_constraints(message: dict, history: list[dict]) -> bool:
     new_image_count, new_video_count = count_files_in_new_message(message["files"])
     history_image_count, history_video_count = count_files_in_history(history)
@@ -70,19 +148,15 @@ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
         return False
     return True
 
-
 def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
     vidcap = cv2.VideoCapture(video_path)
     fps = vidcap.get(cv2.CAP_PROP_FPS)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-
     frame_interval = max(total_frames // MAX_NUM_IMAGES, 1)
     frames: list[tuple[Image.Image, float]] = []
-
     for i in range(0, min(total_frames, MAX_NUM_IMAGES * frame_interval), frame_interval):
         if len(frames) >= MAX_NUM_IMAGES:
             break
-
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
         success, image = vidcap.read()
         if success:
@@ -90,16 +164,13 @@ def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
             pil_image = Image.fromarray(image)
             timestamp = round(i / fps, 2)
             frames.append((pil_image, timestamp))
-
     vidcap.release()
     return frames
 
-
 def process_video(video_path: str) -> list[dict]:
     content = []
     frames = downsample_video(video_path)
-    for frame in frames:
-        pil_image, timestamp = frame
+    for pil_image, timestamp in frames:
         with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
             pil_image.save(temp_file.name)
             content.append({"type": "text", "text": f"Frame {timestamp}:"})
@@ -107,12 +178,10 @@ def process_video(video_path: str) -> list[dict]:
     logger.debug(f"{content=}")
     return content
 
-
 def process_interleaved_images(message: dict) -> list[dict]:
     logger.debug(f"{message['files']=}")
     parts = re.split(r"(<image>)", message["text"])
     logger.debug(f"{parts=}")
-
     content = []
     image_index = 0
     for part in parts:
@@ -128,23 +197,18 @@ def process_interleaved_images(message: dict) -> list[dict]:
     logger.debug(f"{content=}")
     return content
 
-
 def process_new_user_message(message: dict) -> list[dict]:
     if not message["files"]:
         return [{"type": "text", "text": message["text"]}]
-
     if message["files"][0].endswith(".mp4"):
         return [{"type": "text", "text": message["text"]}, *process_video(message["files"][0])]
-
     if "<image>" in message["text"]:
         return process_interleaved_images(message)
-
     return [
         {"type": "text", "text": message["text"]},
         *[{"type": "image", "url": path} for path in message["files"]],
     ]
 
-
 def process_history(history: list[dict]) -> list[dict]:
     messages = []
     current_user_content: list[dict] = []
@@ -162,16 +226,19 @@ def process_history(history: list[dict]) -> list[dict]:
                 current_user_content.append({"type": "image", "url": content[0]})
     return messages
 
-
+# ─────────────────────────────────────────────────────────────────────
+# Generation
+# ─────────────────────────────────────────────────────────────────────
 @spaces.GPU(duration=120)
 def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
     if not validate_media_constraints(message, history):
         yield ""
         return
 
+    effective_sys = IDENTITY_PROMPT if not system_prompt else (IDENTITY_PROMPT + "\n\n" + system_prompt)
+
     messages = []
-    if system_prompt:
-        messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
+    messages.append({"role": "system", "content": [{"type": "text", "text": effective_sys}]})
     messages.extend(process_history(history))
     messages.append({"role": "user", "content": process_new_user_message(message)})
 
@@ -183,22 +250,30 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
         return_tensors="pt",
     ).to(device=model.device, dtype=torch.bfloat16)
 
-    streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(
+        processor.tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
+    )
+
     generate_kwargs = dict(
         inputs,
         streamer=streamer,
        max_new_tokens=max_new_tokens,
        disable_compile=True,
     )
+    if BAN_BASE_NAMES and BAD_WORDS_IDS:
+        generate_kwargs["bad_words_ids"] = BAD_WORDS_IDS
+
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
     output = ""
     for delta in streamer:
         output += delta
-        yield output
-
+        yield sanitize_identity(output)
 
+# ─────────────────────────────────────────────────────────────────────
+# Demo UI
+# ─────────────────────────────────────────────────────────────────────
 examples = [
     [
         {
@@ -321,11 +396,10 @@ examples = [
     ],
 ]
 
-
 DESCRIPTION = """\
 <img src='https://huggingface.co/spaces/huggingface-projects/gemma-3-12b-it/resolve/main/assets/logo.png' id='logo' />
 <div align='center'>
-This is a demo of Kenanga 11B IT, a multimodal Large Vision-Language Model (LVLM) adapted for Sundanese and Javanese support.
+This is a demo of Kenanga 11B IT, a multimodal Large Vision-Language Model (LVLM) adapted for Sundanese and Javanese support.<br/>
 You can upload images, as well as interleaved images and videos. Video input is limited to single-turn conversations and must be in MP4 format.
 </div>
 """
@@ -337,7 +411,7 @@ demo = gr.ChatInterface(
     textbox=gr.MultimodalTextbox(file_types=["image", ".mp4"], file_count="multiple", autofocus=True),
     multimodal=True,
     additional_inputs=[
-        gr.Textbox(label="System Prompt", value="You are a helpful assistant."),
+        gr.Textbox(label="System Prompt", value=IDENTITY_PROMPT),
         gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
     ],
     stop_btn=False,
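
Reviewer note: the identity sanitizer introduced in this commit can be exercised without loading the model. Below is a minimal, self-contained sketch; the pattern and replacement string are copied from the diff above, and the sample sentences are illustrative only.

import re

SELF_REF_PAT = re.compile(
    r"\b(?:(?:I\s*am|I'm|This\s+is|You'?re\s+chatting\s+with)\s+)(Gemma(?:[-\s]?3)?|LLa?ma(?:\s*3(?:\.2)?)?)\b",
    flags=re.IGNORECASE,
)

# A self-identity claim is rewritten to "I am Kenanga" ...
print(SELF_REF_PAT.sub("I am Kenanga", "Hi, I am Gemma-3 and I can describe your image."))
# -> Hi, I am Kenanga and I can describe your image.

# ... while text that merely mentions the base model is left untouched.
print(SELF_REF_PAT.sub("I am Kenanga", "Kenanga was adapted from a Gemma 3 checkpoint."))
# -> Kenanga was adapted from a Gemma 3 checkpoint.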
 
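Reviewer note: downsample_video keeps at most MAX_NUM_IMAGES frames, stepping total_frames // MAX_NUM_IMAGES frames at a time. The following is a small, self-contained sketch of just that index arithmetic (no OpenCV required); sampled_indices is a hypothetical helper name and the numbers are illustrative.

def sampled_indices(total_frames: int, max_images: int) -> list[int]:
    # Mirrors the loop bounds in downsample_video: evenly spaced indices, capped at max_images.
    frame_interval = max(total_frames // max_images, 1)
    indices = []
    for i in range(0, min(total_frames, max_images * frame_interval), frame_interval):
        if len(indices) >= max_images:
            break
        indices.append(i)
    return indices

print(sampled_indices(300, 5))  # [0, 60, 120, 180, 240]
print(sampled_indices(3, 5))    # [0, 1, 2] -- clips shorter than max_images yield fewer frames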