Update app.py
app.py CHANGED
@@ -12,17 +12,97 @@ import spaces
 import torch
 from loguru import logger
 from PIL import Image
-from transformers import AutoProcessor,
+from transformers import AutoProcessor, TextIteratorStreamer
+
+# ─────────────────────────────────────────────────────────────────────
+# Model & processor
+# ─────────────────────────────────────────────────────────────────────
+MODEL_ID = os.getenv("MODEL_ID", "rmdhirr/Kenanga-11B-IT")
+processor = AutoProcessor.from_pretrained(MODEL_ID, padding_side="left")
+
+# Try Gemma-3 vision first; if it fails, fall back to Llama 3.2 Vision (Mllama)
+model = None
+_last_load_error = None
+try:
+    from transformers import Gemma3ForConditionalGeneration
+    model = Gemma3ForConditionalGeneration.from_pretrained(
+        MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
+    )
+except Exception as e:
+    _last_load_error = e
+    try:
+        from transformers import MllamaForConditionalGeneration
+        model = MllamaForConditionalGeneration.from_pretrained(
+            MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager"
+        )
+    except Exception as e2:
+        raise RuntimeError(
+            f"Failed to load model as Gemma3 and Mllama.\nGemma3 error: {type(_last_load_error).__name__}: {_last_load_error}\n"
+            f"Mllama error: {type(e2).__name__}: {e2}"
+        )
+
+MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
 
-
-
-
-
+# ─────────────────────────────────────────────────────────────────────
+# Identity controls (System Prompt + Stream Sanitizer + Optional Logit Ban)
+# ─────────────────────────────────────────────────────────────────────
+IDENTITY_PROMPT = (
+    "You are Kenanga, an Indonesian multimodal LVLM adapted for Sundanese and Javanese.\n"
+    "Identity rules:\n"
+    "• When referring to yourself, always say “Kenanga”.\n"
+    "• Never claim to be Gemma/Llama or any base model. If asked about your base, reply briefly: "
+    "“I’m Kenanga (locally adapted); please refer to me as Kenanga.”\n"
+    "• Stay helpful, concise, and safe."
 )
 
-
+BAN_BASE_NAMES = os.getenv("BAN_BASE_NAMES", "0") == "1"
+
+def _make_bad_words_ids(words):
+    toks = processor.tokenizer
+    ids = []
+    for w in words:
+        for variant in {w, w.lower(), w.upper(), w.title(), " " + w, " " + w.lower()}:
+            enc = toks(variant, add_special_tokens=False).input_ids
+            if enc:
+                ids.append(enc)
+    # dedupe
+    uniq, seen = [], set()
+    for seq in ids:
+        t = tuple(seq)
+        if t and t not in seen:
+            uniq.append(seq)
+            seen.add(t)
+    return uniq
+
+BAD_WORDS_IDS = _make_bad_words_ids([
+    "Gemma", "Gemma-3", "Gemma 3", "Gemma3",
+    # Uncomment to ban base model family self-calls entirely:
+    # "Llama", "LLaMA", "Llama 3", "Llama 3.2", "Llama3", "Llama3.2",
+])
+
+# Only rewrite self-identity claims; allow legitimate mentions in analysis/comparison text
+SELF_REF_PAT = re.compile(
+    r"\b(?:(?:I\s*am|I'm|This\s+is|You'?re\s+chatting\s+with)\s+)(Gemma(?:[-\s]?3)?|LLa?ma(?:\s*3(?:\.2)?)?)\b",
+    flags=re.IGNORECASE,
+)
+AS_MODEL_PAT = re.compile(
+    r"\bAs\s+(?:an?\s+)?(Gemma(?:[-\s]?3)?|LLa?ma(?:\s*3(?:\.2)?)?)\b",
+    flags=re.IGNORECASE,
+)
+THIS_MODEL_IS_PAT = re.compile(
+    r"\b(This\s+model\s+is)\s+(Gemma(?:[-\s]?3)?|LLa?ma(?:\s*3(?:\.2)?)?)\b",
+    flags=re.IGNORECASE,
+)
 
+def sanitize_identity(text: str) -> str:
+    text = SELF_REF_PAT.sub("I am Kenanga", text)
+    text = AS_MODEL_PAT.sub("As Kenanga", text)
+    text = THIS_MODEL_IS_PAT.sub(r"\1 Kenanga", text)
+    return text
 
+# ─────────────────────────────────────────────────────────────────────
+# Media utilities
+# ─────────────────────────────────────────────────────────────────────
 def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
     image_count = 0
     video_count = 0
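A quick standalone check of the identity sanitizer introduced in this hunk. This is a minimal sketch: the pattern is copied verbatim from the diff, and the sample strings are invented for illustration.

import re

SELF_REF_PAT = re.compile(
    r"\b(?:(?:I\s*am|I'm|This\s+is|You'?re\s+chatting\s+with)\s+)(Gemma(?:[-\s]?3)?|LLa?ma(?:\s*3(?:\.2)?)?)\b",
    flags=re.IGNORECASE,
)

# Self-identity claims are rewritten to Kenanga...
print(SELF_REF_PAT.sub("I am Kenanga", "Hi! I'm Gemma-3, happy to help."))
# -> Hi! I am Kenanga, happy to help.

# ...while plain comparative mentions are left untouched.
print(SELF_REF_PAT.sub("I am Kenanga", "Gemma 3 and Llama 3.2 are the base model families."))
# -> unchanged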
@@ -33,7 +113,6 @@ def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
             image_count += 1
     return image_count, video_count
 
-
 def count_files_in_history(history: list[dict]) -> tuple[int, int]:
     image_count = 0
     video_count = 0
@@ -46,7 +125,6 @@ def count_files_in_history(history: list[dict]) -> tuple[int, int]:
             image_count += 1
     return image_count, video_count
 
-
 def validate_media_constraints(message: dict, history: list[dict]) -> bool:
     new_image_count, new_video_count = count_files_in_new_message(message["files"])
     history_image_count, history_video_count = count_files_in_history(history)
@@ -70,19 +148,15 @@ def validate_media_constraints(message: dict, history: list[dict]) -> bool:
         return False
     return True
 
-
 def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
     vidcap = cv2.VideoCapture(video_path)
     fps = vidcap.get(cv2.CAP_PROP_FPS)
     total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
-
     frame_interval = max(total_frames // MAX_NUM_IMAGES, 1)
     frames: list[tuple[Image.Image, float]] = []
-
     for i in range(0, min(total_frames, MAX_NUM_IMAGES * frame_interval), frame_interval):
         if len(frames) >= MAX_NUM_IMAGES:
             break
-
         vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
         success, image = vidcap.read()
         if success:
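For reference, the sampling arithmetic in downsample_video works out as follows. A minimal sketch with illustrative numbers (a 10-second clip at 30 fps and the default MAX_NUM_IMAGES of 5), independent of OpenCV:

# Illustrative values only: a 300-frame clip at 30 fps, MAX_NUM_IMAGES = 5.
total_frames, fps, MAX_NUM_IMAGES = 300, 30.0, 5

frame_interval = max(total_frames // MAX_NUM_IMAGES, 1)                        # 60
indices = range(0, min(total_frames, MAX_NUM_IMAGES * frame_interval), frame_interval)
print(list(indices))                          # [0, 60, 120, 180, 240] -> 5 evenly spaced frames
print([round(i / fps, 2) for i in indices])   # timestamps in seconds: [0.0, 2.0, 4.0, 6.0, 8.0]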
@@ -90,16 +164,13 @@ def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
             pil_image = Image.fromarray(image)
             timestamp = round(i / fps, 2)
             frames.append((pil_image, timestamp))
-
     vidcap.release()
     return frames
 
-
 def process_video(video_path: str) -> list[dict]:
     content = []
     frames = downsample_video(video_path)
-    for frame in frames:
-        pil_image, timestamp = frame
+    for pil_image, timestamp in frames:
         with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
             pil_image.save(temp_file.name)
             content.append({"type": "text", "text": f"Frame {timestamp}:"})
@@ -107,12 +178,10 @@ def process_video(video_path: str) -> list[dict]:
     logger.debug(f"{content=}")
     return content
 
-
 def process_interleaved_images(message: dict) -> list[dict]:
     logger.debug(f"{message['files']=}")
     parts = re.split(r"(<image>)", message["text"])
     logger.debug(f"{parts=}")
-
     content = []
     image_index = 0
     for part in parts:
@@ -128,23 +197,18 @@ def process_interleaved_images(message: dict) -> list[dict]:
     logger.debug(f"{content=}")
     return content
 
-
 def process_new_user_message(message: dict) -> list[dict]:
     if not message["files"]:
         return [{"type": "text", "text": message["text"]}]
-
     if message["files"][0].endswith(".mp4"):
         return [{"type": "text", "text": message["text"]}, *process_video(message["files"][0])]
-
     if "<image>" in message["text"]:
         return process_interleaved_images(message)
-
     return [
         {"type": "text", "text": message["text"]},
         *[{"type": "image", "url": path} for path in message["files"]],
     ]
 
-
 def process_history(history: list[dict]) -> list[dict]:
     messages = []
     current_user_content: list[dict] = []
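For context, the message dict handled in this hunk is what gr.MultimodalTextbox produces. The sketch below uses hypothetical file paths and shows the plain-attachment branch whose return value is spelled out above; the .mp4 and <image> branches delegate to process_video and process_interleaved_images.

# Hypothetical input from gr.MultimodalTextbox: free text plus attached file paths.
message = {"text": "What is in this photo?", "files": ["/tmp/photo.png"]}

# process_new_user_message(message) would return the chat-template content list:
# [{"type": "text", "text": "What is in this photo?"},
#  {"type": "image", "url": "/tmp/photo.png"}]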
@@ -162,16 +226,19 @@ def process_history(history: list[dict]) -> list[dict]:
             current_user_content.append({"type": "image", "url": content[0]})
     return messages
 
-
+# ─────────────────────────────────────────────────────────────────────
+# Generation
+# ─────────────────────────────────────────────────────────────────────
 @spaces.GPU(duration=120)
 def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]:
     if not validate_media_constraints(message, history):
         yield ""
         return
 
+    effective_sys = IDENTITY_PROMPT if not system_prompt else (IDENTITY_PROMPT + "\n\n" + system_prompt)
+
     messages = []
-
-    messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]})
+    messages.append({"role": "system", "content": [{"type": "text", "text": effective_sys}]})
     messages.extend(process_history(history))
     messages.append({"role": "user", "content": process_new_user_message(message)})
 
@@ -183,22 +250,30 @@ def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tok
         return_tensors="pt",
     ).to(device=model.device, dtype=torch.bfloat16)
 
-    streamer = TextIteratorStreamer(
+    streamer = TextIteratorStreamer(
+        processor.tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
+    )
+
     generate_kwargs = dict(
         inputs,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         disable_compile=True,
     )
+    if BAN_BASE_NAMES and BAD_WORDS_IDS:
+        generate_kwargs["bad_words_ids"] = BAD_WORDS_IDS
+
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
     output = ""
     for delta in streamer:
         output += delta
-        yield output
-
+        yield sanitize_identity(output)
 
+# ─────────────────────────────────────────────────────────────────────
+# Demo UI
+# ─────────────────────────────────────────────────────────────────────
 examples = [
     [
         {
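A note on the optional logit ban wired in above: the bad_words_ids argument of generate() expects a list of token-id sequences, which is what _make_bad_words_ids builds from the name variants. A minimal standalone sketch follows; the gpt2 tokenizer is used purely for illustration, while the Space itself uses processor.tokenizer.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")   # illustrative tokenizer only
banned = ["Gemma", " Gemma", "GEMMA"]         # casing / leading-space variants tokenize differently
bad_words_ids = [tok(w, add_special_tokens=False).input_ids for w in banned]
print(bad_words_ids)   # one token-id sequence per banned string, e.g. [[...], [...], [...]]

# model.generate(..., bad_words_ids=bad_words_ids) blocks exactly these sequences,
# which is why _make_bad_words_ids enumerates several surface forms of each name.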
@@ -321,11 +396,10 @@ examples = [
     ],
 ]
 
-
 DESCRIPTION = """\
 <img src='https://huggingface.co/spaces/huggingface-projects/gemma-3-12b-it/resolve/main/assets/logo.png' id='logo' />
 <div align='center'>
-This is a demo of Kenanga 11B IT, a multimodal Large Vision-Language Model (LVLM) adapted for Sundanese and Javanese support
+This is a demo of Kenanga 11B IT, a multimodal Large Vision-Language Model (LVLM) adapted for Sundanese and Javanese support.<br/>
 You can upload images, as well as interleaved images and videos. Video input is limited to single-turn conversations and must be in MP4 format.
 </div>
 """
@@ -337,7 +411,7 @@ demo = gr.ChatInterface(
     textbox=gr.MultimodalTextbox(file_types=["image", ".mp4"], file_count="multiple", autofocus=True),
     multimodal=True,
     additional_inputs=[
-        gr.Textbox(label="System Prompt", value=
+        gr.Textbox(label="System Prompt", value=IDENTITY_PROMPT),
         gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
     ],
     stop_btn=False,