# Necessary imports
import sys
import gradio as gr
import spaces
from decord import VideoReader, cpu
from PIL import Image
# Local imports
from src.config import (
    device,
    model_name,
    sampling,
    stream,
    repetition_penalty,
)
from src.minicpm.model import load_model_tokenizer_and_processor
from src.logger import logging
from src.exception import CustomExceptionHandling
# Model, tokenizer and processor
model, tokenizer, processor = load_model_tokenizer_and_processor(model_name, device)
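# For reference, a plausible shape for src/config (hypothetical values shown
# purely for illustration; the actual module is not part of this file):
#
#   device = "cuda"
#   model_name = "openbmb/MiniCPM-V-2_6"  # assumed MiniCPM checkpoint
#   sampling = True
#   stream = True
#   repetition_penalty = 1.05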
MAX_NUM_FRAMES = 64


def encode_video(video_path: str) -> list:
    """Sample up to MAX_NUM_FRAMES frames from a video at roughly one frame per second."""

    def uniform_sample(l, n):
        # Pick n indices evenly spread across the list
        gap = len(l) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [l[i] for i in idxs]

    vr = VideoReader(video_path, ctx=cpu(0))
    sample_fps = round(vr.get_avg_fps() / 1)  # index step for ~1 frame per second
    frame_idx = [i for i in range(0, len(vr), sample_fps)]
    if len(frame_idx) > MAX_NUM_FRAMES:
        frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
    frames = vr.get_batch(frame_idx).asnumpy()
    frames = [Image.fromarray(v.astype("uint8")) for v in frames]
    logging.info("num frames: %d", len(frames))
    return frames
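# Illustrative usage sketch (not part of the original app): the frames from
# encode_video could replace the single image in a chat message, with the
# frame list and the question sharing one user turn. The clip path and
# question below are hypothetical.
#
#   frames = encode_video("example_clip.mp4")
#   msgs = [{"role": "user", "content": frames + ["What happens in this video?"]}]
#   answer = model.chat(image=None, msgs=msgs, tokenizer=tokenizer, processor=processor)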
@spaces.GPU(duration=120)
def describe_image(
    image: str,
    question: str,
    temperature: float,
    top_p: float,
    top_k: int,
    max_new_tokens: int,
) -> str:
    """
    Generates an answer to the given question based on the provided image.

    Args:
        - image (str): The path to the image file.
        - question (str): The question text.
        - temperature (float): The temperature parameter for the model.
        - top_p (float): The top_p parameter for the model.
        - top_k (int): The top_k parameter for the model.
        - max_new_tokens (int): The maximum number of tokens to generate.

    Returns:
        str: The generated answer to the question.
    """
    try:
        # Bail out early if the image or question is missing
        if not image or not question:
            gr.Warning("Please provide an image and a question.")
            return ""

        # Load the image from the given path
        image = Image.open(image).convert("RGB")

        # Message format expected by the model
        msgs = [{"role": "user", "content": [image, question]}]

        # Generate the answer
        answer = model.chat(
            image=None,
            msgs=msgs,
            tokenizer=tokenizer,
            processor=processor,
            sampling=sampling,
            stream=stream,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            max_new_tokens=max_new_tokens,
        )

        # Log the successful generation of the answer
        logging.info("Answer generated successfully.")

        # Join the (possibly streamed) chunks into a single string
        return "".join(answer)

    # Handle exceptions that may occur during answer generation
    except Exception as e:
        # Wrap the error in the project's custom exception handler
        raise CustomExceptionHandling(e, sys) from e
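
# Minimal sketch of wiring describe_image into a Gradio interface, assuming the
# app is launched from this module. The component choices, labels, and default
# values below are illustrative assumptions, not taken from the original app.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=describe_image,
        inputs=[
            gr.Image(type="filepath", label="Image"),
            gr.Textbox(label="Question"),
            gr.Slider(0.0, 2.0, value=0.7, label="Temperature"),
            gr.Slider(0.0, 1.0, value=0.8, label="Top-p"),
            gr.Slider(0, 200, value=100, step=1, label="Top-k"),
            gr.Slider(16, 2048, value=512, step=16, label="Max new tokens"),
        ],
        outputs=gr.Textbox(label="Answer"),
    )
    demo.launch()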