testapp / app.py
soniakhamitkar's picture
Update app.py
b5b61f1 verified
# Setup (run in a shell, not inside Python): pip install torch torchvision torchaudio
import io
import argparse
import numpy as np
import torch
from decord import cpu, VideoReader, bridge
from transformers import AutoModelForCausalLM, AutoTokenizer
# Identifier of the checkpoint handed to the transformers `from_pretrained` loaders.
MODEL_PATH = "THUDM/cogvlm2-llama3-caption"

# Pick the device and tensor dtype once, based on what PyTorch reports.
if torch.cuda.is_available():
    DEVICE = 'cuda'
    # bfloat16 is selected only when the reported compute-capability major
    # version is 8 or higher; otherwise float16 is used.
    if torch.cuda.get_device_capability()[0] >= 8:
        TORCH_TYPE = torch.bfloat16
    else:
        TORCH_TYPE = torch.float16
else:
    DEVICE = 'cpu'
    TORCH_TYPE = torch.float16

# Command-line interface. Parsing happens at module level, so `args` is
# available to every statement that follows.
parser = argparse.ArgumentParser(description="CogVLM2 Video to Text")
parser.add_argument(
    '--video',
    required=True,
    type=str,
    help="Path to the video file",
)
parser.add_argument(
    '--quant',
    default=0,  # 0 is the value seen when the flag is not given
    type=int,
    choices=[4, 8],
    help='Enable 4-bit or 8-bit precision loading',
)
args = parser.parse_args()
def load_video(video_path, strategy='chat', num_frames=24):
    """Decode a video file and return a fixed-size sample of its frames.

    Args:
        video_path: Path of the video file to read.
        strategy: Frame-sampling policy.
            ``'base'`` takes ``num_frames`` evenly spaced frame indices from
            the first 60 seconds of the clip (or the whole clip if shorter).
            ``'chat'`` takes, for each whole second starting at 0, the frame
            whose start timestamp is closest to that second, and stops after
            ``num_frames`` frames.
        num_frames: Upper bound on the number of sampled frames.

    Returns:
        The batch returned by ``VideoReader.get_batch`` (torch bridge) after
        ``permute(3, 0, 1, 2)``.
        NOTE(review): presumably (frames, H, W, C) -> (C, frames, H, W);
        confirm against the decord documentation.

    Raises:
        ValueError: If ``strategy`` is not ``'base'`` or ``'chat'``, if
            ``num_frames`` is smaller than 1, or if the video has no frames.
    """
    # Validate arguments before touching the file so that mistakes surface
    # as a clear error instead of an obscure failure inside the decoder.
    if strategy not in ('base', 'chat'):
        raise ValueError(
            f"Unknown strategy {strategy!r}; expected 'base' or 'chat'"
        )
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames!r}")

    bridge.set_bridge('torch')
    with open(video_path, 'rb') as f:
        video_stream = f.read()
    decord_vr = VideoReader(io.BytesIO(video_stream), ctx=cpu(0))

    total_frames = len(decord_vr)
    if total_frames == 0:
        raise ValueError(f"Video {video_path!r} contains no frames")

    if strategy == 'base':
        clip_start_sec = 0
        clip_end_sec = 60
        fps = decord_vr.get_avg_fps()
        start_frame = int(clip_start_sec * fps)
        end_frame = min(total_frames, int(clip_end_sec * fps))
        frame_id_list = np.linspace(
            start_frame, end_frame - 1, num_frames, dtype=int
        )
    else:  # strategy == 'chat'
        timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))
        # Only the first entry of each timestamp record is used.
        starts = np.asarray([stamp[0] for stamp in timestamps])
        max_second = round(float(starts.max())) + 1
        # argmin returns the first index with the smallest distance, which is
        # the same frame the min()/list.index() pair would select.
        frame_id_list = [
            int(np.argmin(np.abs(starts - second)))
            for second in range(min(max_second, num_frames))
        ]

    video_data = decord_vr.get_batch(frame_id_list)
    video_data = video_data.permute(3, 0, 1, 2)
    return video_data
# Tokenizer for the checkpoint; trust_remote_code is passed because the
# loaders are asked to run code shipped with the model repository.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
)

# Honour the --quant flag: when 4 or 8 is requested, load the weights with a
# bitsandbytes quantization config; otherwise load at TORCH_TYPE and move the
# model to DEVICE as before.
if args.quant in (4, 8):
    # Imported here so the extra dependency is only needed when requested.
    from transformers import BitsAndBytesConfig

    if args.quant == 4:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=TORCH_TYPE,
        )
    else:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)

    # NOTE(review): no .to(DEVICE) here; transformers is expected to place
    # quantized weights itself and to reject .to() on them. Confirm against
    # the installed transformers version.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=TORCH_TYPE,
        trust_remote_code=True,
        quantization_config=quantization_config,
        low_cpu_mem_usage=True,
    ).eval()
else:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=TORCH_TYPE,
        trust_remote_code=True,
    ).eval().to(DEVICE)
def predict(video_path, temperature=0.1):
    """Produce a text description of the video stored at ``video_path``.

    The video is sampled with ``load_video`` using the 'chat' strategy, a
    single-turn conversation asking for a detailed description is built with
    the model's ``build_conversation_input_ids``, and the decoded
    continuation (prompt tokens stripped) is returned.

    Args:
        video_path: Path of the video file to caption.
        temperature: Value forwarded to ``model.generate`` as ``temperature``.
            NOTE(review): ``do_sample`` is False, so the sampling parameters
            are presumably ignored by generation; confirm.

    Returns:
        The generated description as a string.
    """
    template = 'chat'
    frames = load_video(video_path, strategy=template)

    # One question, no previous turns.
    encoded = model.build_conversation_input_ids(
        tokenizer=tokenizer,
        query="Please describe this video in detail.",
        images=[frames],
        history=[],
        template_version=template,
    )

    # Add a leading batch dimension to each tensor field and move it to the
    # target device; the image entry is additionally cast to TORCH_TYPE.
    model_inputs = {
        field: encoded[field].unsqueeze(0).to(DEVICE)
        for field in ('input_ids', 'token_type_ids', 'attention_mask')
    }
    model_inputs['images'] = [[encoded['images'][0].to(DEVICE).to(TORCH_TYPE)]]

    generation_options = dict(
        max_new_tokens=2048,
        pad_token_id=128002,
        top_k=1,
        do_sample=False,
        top_p=0.1,
        temperature=temperature,
    )

    prompt_length = model_inputs['input_ids'].shape[1]
    with torch.no_grad():
        generated = model.generate(**model_inputs, **generation_options)

    # Drop the prompt tokens so only the newly generated part is decoded.
    completion = generated[:, prompt_length:]
    return tokenizer.decode(completion[0], skip_special_tokens=True)
if __name__ == '__main__':
    # Caption the video named on the command line and print the result.
    description = predict(args.video)
    print("\nGenerated Text Description:\n")
    print(description)