Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -20,6 +20,7 @@ from transformers import (
     TextIteratorStreamer,
     Qwen2VLForConditionalGeneration,
     AutoProcessor,
+    Gemma3ForConditionalGeneration,  # New import for Gemma3-4B
 )
 from transformers.image_utils import load_image
 from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
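One reviewer note on this hunk: Gemma3ForConditionalGeneration only exists in recent transformers releases, so an older pinned environment will fail at import time rather than at generation time. A minimal guard sketch; the 4.50.0 version bound is an assumption about when Gemma 3 support landed, not something stated in this diff:

# Hypothetical guard -- fail fast with a clear message if the running
# transformers build predates Gemma 3 support (assumed to be v4.50+).
import transformers
from packaging import version

if version.parse(transformers.__version__) < version.parse("4.50.0"):
    raise ImportError(
        f"transformers {transformers.__version__} lacks Gemma3ForConditionalGeneration; "
        "upgrade with: pip install -U transformers"
    )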
@@ -208,6 +209,15 @@ def save_image(img: Image.Image) -> str:
     img.save(unique_name)
     return unique_name
 
+# -----------------------
+# GEMMA3-4B MULTIMODAL MODEL
+# -----------------------
+gemma3_model_id = "google/gemma-3-4b-it"
+gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
+    gemma3_model_id, device_map="auto"
+).eval()
+gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
+
 # -----------------------
 # MAIN GENERATION FUNCTION
 # -----------------------
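As written, the checkpoint loads with device_map="auto" but no explicit dtype, so from_pretrained defaults to float32 weights even though the inputs are later cast to bfloat16. A sketch of loading with an explicit dtype so weights and inputs agree, assuming bf16-capable hardware (torch_dtype is a standard from_pretrained kwarg, but the choice here is a suggestion, not part of the commit):

# Sketch: load the weights directly in bfloat16 to match the
# bfloat16-cast inputs used later in the Gemma3 branch.
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
    "google/gemma-3-4b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16,  # avoid the float32 default
).eval()
gemma3_processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")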
@@ -225,7 +235,8 @@ def generate(
     files = input_dict.get("files", [])
 
     lower_text = text.lower().strip()
-
+
+    # Image Generation Branch (Stable Diffusion models)
     if (lower_text.startswith("@lightningv5") or
         lower_text.startswith("@lightningv4") or
         lower_text.startswith("@turbov3")):
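The new comment makes the routing explicit: generate() dispatches on a leading tag in the message text. A standalone sketch of the same prefix-routing idea; the handler labels are illustrative, not names from app.py. Note that str.startswith accepts a tuple, which would collapse the three-way or-chain the diff keeps in its explicit form:

# Illustrative prefix router mirroring the branch structure in generate().
SD_TAGS = ("@lightningv5", "@lightningv4", "@turbov3")

def route(lower_text: str) -> str:
    if lower_text.startswith(SD_TAGS):        # tuple form of startswith
        return "stable-diffusion"
    if lower_text.startswith("@gemma3-4b"):
        return "gemma3"
    if lower_text.startswith(("@tts1", "@tts2")):
        return "tts"
    return "chat"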
@@ -277,6 +288,52 @@ def generate(
         yield gr.Image(image_path)
         return
 
+    # GEMMA3-4B Branch for Multimodal/Text Generation with Streaming
+    if lower_text.startswith("@gemma3-4b"):
+        # Remove the gemma3 flag from the prompt.
+        prompt_clean = re.sub(r"@gemma3-4b", "", text, flags=re.IGNORECASE).strip().strip('"')
+        if files:
+            # If image files are provided, load them.
+            images = [load_image(f) for f in files]
+            messages = [{
+                "role": "user",
+                "content": [
+                    *[{"type": "image", "image": image} for image in images],
+                    {"type": "text", "text": prompt_clean},
+                ]
+            }]
+        else:
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
+            ]
+        inputs = gemma3_processor.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=True,
+            return_dict=True, return_tensors="pt"
+        ).to(gemma3_model.device, dtype=torch.bfloat16)
+        streamer = TextIteratorStreamer(
+            gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
+        )
+        generation_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+        }
+        thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        yield progress_bar_html("Processing with Gemma3-4b")
+        for new_text in streamer:
+            buffer += new_text
+            time.sleep(0.01)
+            yield buffer
+        return
+
     # Otherwise, handle text/chat (and TTS) generation.
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
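This branch uses the standard background-thread streaming pattern: model.generate blocks, so it runs in a worker Thread while TextIteratorStreamer acts as a queue the generator function drains; with timeout=20.0 the iterator raises queue.Empty if no new token arrives within 20 seconds. An isolated sketch of just that pattern, shown with a plain causal LM for clarity (the model name is illustrative, not from app.py):

# Minimal sketch of the Thread + TextIteratorStreamer pattern used above.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Streaming demo:", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs in a worker thread; the main thread
# consumes decoded fragments as they are produced.
Thread(target=model.generate,
       kwargs={**inputs, "streamer": streamer, "max_new_tokens": 40}).start()

text = ""
for fragment in streamer:
    text += fragment        # each fragment is newly decoded text
    print(fragment, end="", flush=True)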
@@ -391,7 +448,7 @@ demo = gr.ChatInterface(
     description=DESCRIPTION,
     css=css,
     fill_height=True,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="use the tags @lightningv5 @lightningv4 @turbov3 for multimodal gen !"),
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="use the tags @lightningv5 @lightningv4 @turbov3 or @gemma3-4b for multimodal gen !"),
     stop_btn="Stop Generation",
     multimodal=True,
 )
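With multimodal=True and a gr.MultimodalTextbox, the chat function receives a dict rather than a string, which is why generate() reads input_dict.get("files", []) at the top. A stripped-down sketch of that wiring; the echo handler is illustrative only:

# Sketch: how a MultimodalTextbox submission reaches the chat fn.
import gradio as gr

def echo(input_dict, history):
    text = input_dict["text"]               # the typed message
    files = input_dict.get("files", [])     # local paths of attached images
    return f"text={text!r}, {len(files)} file(s)"

demo = gr.ChatInterface(
    fn=echo,
    multimodal=True,
    textbox=gr.MultimodalTextbox(file_types=["image"], file_count="multiple"),
)
# demo.launch()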