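"""Gradio demo for VideoLLaMA3-7B: chat about uploaded images and videos with streamed responses."""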
import os
import os.path as osp
import gradio as gr
import spaces
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
HEADER = ("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://github.com/DAMO-NLP-SG/VideoLLaMA3" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;">
<img src="https://github.com/DAMO-NLP-SG/VideoLLaMA3/blob/main/assets/logo.png?raw=true" alt="VideoLLaMA 3 🔥🚀🔥" style="max-width: 120px; height: auto;">
</a>
<div>
<h1>VideoLLaMA 3: Frontier Multimodal Foundation Models for Video Understanding</h1>
<h5 style="margin: 0;">If you like this demo, please give us a star ⭐ on GitHub or a 💖 on this Space.</h5>
</div>
</div>
<div style="display: flex; justify-content: center; margin-top: 10px;">
<a href="https://github.com/DAMO-NLP-SG/VideoLLaMA3"><img src='https://img.shields.io/badge/Github-VideoLLaMA3-9C276A' style="margin-right: 5px;"></a>
<a href="https://arxiv.org/pdf/2501.13106"><img src="https://img.shields.io/badge/Arxiv-2501.13106-AD1C18" style="margin-right: 5px;"></a>
<a href="https://huggingface.co/collections/DAMO-NLP-SG/videollama3-678cdda9281a0e32fe79af15"><img src="https://img.shields.io/badge/🤗-Checkpoints-ED5A22.svg" style="margin-right: 5px;"></a>
<a href="https://github.com/DAMO-NLP-SG/VideoLLaMA3/stargazers"><img src="https://img.shields.io/github/stars/DAMO-NLP-SG/VideoLLaMA3.svg?style=social"></a>
</div>
""")
device = "cuda"
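# Load VideoLLaMA3-7B in bfloat16 with FlashAttention-2, then move it to the GPU.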
model = AutoModelForCausalLM.from_pretrained(
"DAMO-NLP-SG/VideoLLaMA3-7B",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
)
model.to(device)
processor = AutoProcessor.from_pretrained("DAMO-NLP-SG/VideoLLaMA3-7B", trust_remote_code=True)
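# Collect bundled example images and videos from the local examples directory.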
example_dir = "./examples"
image_formats = ("png", "jpg", "jpeg")
video_formats = ("mp4",)
image_examples, video_examples = [], []
if example_dir is not None and osp.isdir(example_dir):
example_files = [
osp.join(example_dir, f) for f in os.listdir(example_dir)
]
for example_file in example_files:
if example_file.endswith(image_formats):
image_examples.append([example_file])
elif example_file.endswith(video_formats):
video_examples.append([example_file])
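# Chat callbacks: media uploads and text submissions are appended to the chat
# history as user messages; the corresponding input widget is cleared afterwards.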
def _on_video_upload(messages, video):
if video is not None:
        # Gradio passes the uploaded video as a filesystem path; store it as a
        # media message so the chatbot renders a video preview.
        messages.append({"role": "user", "content": {"path": video}})
return messages, None
def _on_image_upload(messages, image):
if image is not None:
        # Same for images: keep the file path so the chatbot shows a preview.
        messages.append({"role": "user", "content": {"path": image}})
return messages, None
def _on_text_submit(messages, text):
messages.append({"role": "user", "content": text})
return messages, ""
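# Streamed inference. The @spaces.GPU decorator requests a ZeroGPU worker for up
# to 120 seconds per call.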
@spaces.GPU(duration=120)
def _predict(messages, input_text, do_sample, temperature, top_p, max_new_tokens,
fps, max_frames):
if len(input_text) > 0:
messages.append({"role": "user", "content": input_text})
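    # Re-pack the flat chat history into alternating turns: consecutive user
    # entries (text and media) are merged into a single multimodal user message.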
new_messages = []
contents = []
for message in messages:
if message["role"] == "assistant":
if len(contents):
new_messages.append({"role": "user", "content": contents})
contents = []
new_messages.append(message)
elif message["role"] == "user":
if isinstance(message["content"], str):
contents.append(message["content"])
else:
media_path = message["content"][0]
if media_path.endswith(video_formats):
contents.append({"type": "video", "video": {"video_path": media_path, "fps": fps, "max_frames": max_frames}})
elif media_path.endswith(image_formats):
contents.append({"type": "image", "image": {"image_path": media_path}})
else:
raise ValueError(f"Unsupported media type: {media_path}")
if len(contents):
new_messages.append({"role": "user", "content": contents})
    if len(new_messages) == 0 or new_messages[-1]["role"] != "user":
        # No new user turn to respond to; return the history unchanged.
        yield messages
        return
generation_config = {
"do_sample": do_sample,
"temperature": temperature,
"top_p": top_p,
"max_new_tokens": max_new_tokens
}
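    # Tokenize the conversation; the processor also loads video frames, sampled
    # according to the requested fps / max_frames.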
inputs = processor(
conversation=new_messages,
add_system_prompt=True,
add_generation_prompt=True,
return_tensors="pt"
)
inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
if "pixel_values" in inputs:
inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = {
**inputs,
**generation_config,
"streamer": streamer,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
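    # Append an empty assistant turn and grow it token by token, yielding the
    # updated history so the chatbot refreshes incrementally.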
messages.append({"role": "assistant", "content": ""})
for token in streamer:
messages[-1]['content'] += token
yield messages
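# Gradio UI: chat panel on the left, input and generation-config tabs on the right.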
with gr.Blocks() as interface:
gr.HTML(HEADER)
with gr.Row():
chatbot = gr.Chatbot(type="messages", elem_id="chatbot", height=835)
with gr.Column():
with gr.Tab(label="Input"):
with gr.Row():
input_video = gr.Video(sources=["upload"], label="Upload Video")
input_image = gr.Image(sources=["upload"], type="filepath", label="Upload Image")
input_text = gr.Textbox(label="Input Text", placeholder="Type your message here and press enter to submit")
submit_button = gr.Button("Generate")
gr.Examples(examples=[
[f"examples/bear.mp4", "What is unusual in the video?"],
[f"examples/dog.mp4", "Please describe the video in detail."],
[f"examples/running.mp4", "Who won the competition?"],
], inputs=[input_video, input_text], label="Video examples")
with gr.Tab(label="Configure"):
with gr.Accordion("Generation Config", open=True):
do_sample = gr.Checkbox(value=True, label="Do Sample")
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, label="Temperature")
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top P")
max_new_tokens = gr.Slider(minimum=0, maximum=4096, value=2048, step=1, label="Max New Tokens")
with gr.Accordion("Video Config", open=True):
fps = gr.Slider(minimum=0.0, maximum=10.0, value=1, label="FPS")
max_frames = gr.Slider(minimum=0, maximum=256, value=180, step=1, label="Max Frames")
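    # Wire the UI events to the callbacks defined above.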
input_video.change(_on_video_upload, [chatbot, input_video], [chatbot, input_video])
input_image.change(_on_image_upload, [chatbot, input_image], [chatbot, input_image])
input_text.submit(_on_text_submit, [chatbot, input_text], [chatbot, input_text])
submit_button.click(
_predict,
[
chatbot, input_text, do_sample, temperature, top_p, max_new_tokens,
fps, max_frames
],
[chatbot],
)
if __name__ == "__main__":
interface.launch()