# Kimi-VL-A3B / app.py
import argparse
import gradio as gr
import os
import re
from PIL import Image, ImageDraw, ImageFont, ImageColor
import traceback
import spaces
from kimi_vl.serve.frontend import reload_javascript
from kimi_vl.serve.utils import (
configure_logger,
pil_to_base64,
parse_ref_bbox,
strip_stop_words,
is_variable_assigned,
)
from kimi_vl.serve.gradio_utils import (
cancel_outputing,
delete_last_conversation,
reset_state,
reset_textbox,
transfer_input,
wrap_gen_fn,
)
from kimi_vl.serve.chat_utils import (
generate_prompt_with_history,
convert_conversation_to_prompts,
to_gradio_chatbot,
to_gradio_history,
)
from kimi_vl.serve.inference import kimi_vl_generate, load_model
from kimi_vl.serve.examples import get_examples
TITLE = """<h1 align="left" style="min-width:200px; margin-top:0;">Chat with Kimi-VL-A3B-Thinking-2506🤔 </h1>"""
DESCRIPTION_TOP = """<a href="https://github.com/MoonshotAI/Kimi-VL" target="_blank">Kimi-VL-A3B-Thinking</a> is a multi-modal LLM that can understand text and images, and generate text with thinking processes. This demo has been updated to its new [2506](https://huggingface.co/moonshotai/Kimi-VL-A3B-Thinking-2506) version."""
DESCRIPTION = """"""
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
DEPLOY_MODELS = dict()
logger = configure_logger()
def draw_click_action(
action: str,
point: tuple[float, float],
image: Image.Image,
color: str = "green",
):
"""
    Draw a click action on the image as a translucent circle with a solid center dot.

    Args:
        action (str): Label of the click action (currently not rendered on the image)
        point (tuple[float, float]): The point to click (x, y) in normalized coordinates (0-1)
        image (Image.Image): The image to draw on
        color (str): Fill color of the click marker (falls back to translucent red if invalid)
Returns:
Image.Image: The image with the click action drawn on it
"""
image = image.copy()
image_w, image_h = image.size
# Convert normalized coordinates to pixel coordinates and clip to image bounds
x = int(min(max(point[0], 0), 1) * image_w)
y = int(min(max(point[1], 0), 1) * image_h)
    # accept a color name or hex string and add 50% alpha; otherwise fall back to translucent red
    if isinstance(color, str):
try:
color = ImageColor.getrgb(color)
color = color + (128,)
except ValueError:
color = (255, 0, 0, 128)
else:
color = (255, 0, 0, 128)
overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
overlay_draw = ImageDraw.Draw(overlay)
radius = min(image.size) * 0.06
overlay_draw.ellipse([(x - radius, y - radius), (x + radius, y + radius)], fill=color)
center_radius = radius * 0.12
overlay_draw.ellipse(
[(x - center_radius, y - center_radius), (x + center_radius, y + center_radius)], fill=(0, 255, 0, 255)
)
image = image.convert('RGBA')
combined = Image.alpha_composite(image, overlay)
return combined.convert('RGB')
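# Illustrative sketch (not wired into the demo): the label, image size, and the normalized
# point (0.15, 0.08) below are arbitrary example values.
def _example_draw_click_action():
    """Draw a click marker at a normalized point on a blank test image."""
    test_img = Image.new("RGB", (640, 480), "white")
    return draw_click_action("click", (0.15, 0.08), test_img, color="green")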
def draw_agent_response(codes, image: Image.Image):
"""
Draw the agent response on the image.
Example:
        prompt: Please observe the screenshot and locate the following element with an action and point.
<instruction>Half-sectional view
the response format is like:
<action>Half-sectional view
<point>
```python
pyautogui.click(x=0.150, y=0.081)
```
Args:
        codes (list[str]): the code snippets extracted from the response
image (Image.Image): the image to draw on
Returns:
Image.Image: the image with the action and point drawn on it
"""
image = image.copy()
for code in codes:
if "pyautogui.click" in code:
            # e.g. 'pyautogui.click(x=0.075, y=0.087)' -> x=0.075, y=0.087
            pattern = r'x=([0-9.]+),\s*y=([0-9.]+)'
            match = re.search(pattern, code)
            if match is None:
                continue
            x = float(match.group(1))
            y = float(match.group(2))
            # clip the normalized coordinates to [0, 1]
            x = min(max(x, 0), 1)
            y = min(max(y, 0), 1)
image = draw_click_action("", (x, y), image)
return image
def parse_and_draw_response(response, image: Image.Image):
"""
Parse the response and draw the response on the image.
"""
try:
plotted = False
# draw agent response, with relaxed judgement
if 'pyautogui.click(' in response:
# code is between ```python and ```
code = re.findall(r'```python(.*?)```', response, flags=re.DOTALL)
image = draw_agent_response(code, image)
plotted = True
if not plotted:
logger.warning("No response to draw")
return None
return image
except Exception as e:
traceback.print_exc()
        logger.error(f"Error parsing agent response: {e}")
return None
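# Illustrative sketch (not wired into the demo): a typical agent response wraps a pyautogui
# click inside a ```python fence; the sample response and test image below are made up.
def _example_parse_and_draw_response():
    sample_response = (
        "<action>Half-sectional view\n<point>\n"
        "```python\npyautogui.click(x=0.150, y=0.081)\n```"
    )
    test_img = Image.new("RGB", (640, 480), "white")
    return parse_and_draw_response(sample_response, test_img)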
def pdf_to_multi_image(local_pdf):
    """Render each page of a PDF file to a PIL image (at 96 dpi) using PyMuPDF."""
    import fitz
    import io
    doc = fitz.open(local_pdf)
all_input_images = []
for page_num in range(len(doc)):
page = doc.load_page(page_num)
pix = page.get_pixmap(dpi=96)
image = Image.open(io.BytesIO(pix.tobytes("png")))
all_input_images.append(image)
doc.close()
return all_input_images
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="Kimi-VL-A3B-Thinking-2506")
parser.add_argument(
"--local-path",
type=str,
default="",
help="huggingface ckpt, optional",
)
parser.add_argument("--ip", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7890)
return parser.parse_args()
def fetch_model(model_name: str):
global args, DEPLOY_MODELS
if args.local_path:
model_path = args.local_path
else:
model_path = f"moonshotai/{args.model}"
if model_name in DEPLOY_MODELS:
model_info = DEPLOY_MODELS[model_name]
print(f"{model_name} has been loaded.")
else:
print(f"{model_name} is loading...")
DEPLOY_MODELS[model_name] = load_model(model_path)
print(f"Load {model_name} successfully...")
model_info = DEPLOY_MODELS[model_name]
return model_info
def preview_images(files) -> list[str]:
if files is None:
return []
image_paths = []
for file in files:
image_paths.append(file.name)
return image_paths
def get_prompt(conversation) -> str:
"""
Get the prompt for the conversation.
"""
system_prompt = conversation.system_template.format(system_message=conversation.system_message)
return system_prompt
def highlight_thinking(msg: str) -> str:
if "◁think▷" in msg:
msg = msg.replace("◁think▷", "<b style='color:blue;'>🤔Thinking...</b>\n")
if "◁/think▷" in msg:
msg = msg.replace("◁/think▷", "\n<b style='color:purple;'>💡Summary</b>\n")
return msg
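# Example (for reference): the model emits its reasoning between ◁think▷ and ◁/think▷, so
#   highlight_thinking("◁think▷step by step◁/think▷final answer")
# yields a blue "🤔Thinking..." label, the reasoning, a purple "💡Summary" label, then the answer.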
def resize_image(image: Image.Image, max_size: int = 640, min_size: int = 28):
width, height = image.size
if width < min_size or height < min_size:
        # Scale up so the shorter edge reaches min_size, preserving aspect ratio
scale = min_size / min(width, height)
new_width = int(width * scale)
new_height = int(height * scale)
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
elif max_size > 0 and (width > max_size or height > max_size):
        # Scale down so the longer edge fits within max_size, preserving aspect ratio
scale = max_size / max(width, height)
new_width = int(width * scale)
new_height = int(height * scale)
image = image.resize((new_width, new_height))
return image
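# Worked example (for reference): a 1920x1080 image exceeds max_size=640, so it is scaled by
# 640/1920 down to 640x360; a 20x40 image is below min_size=28, so it is scaled by 28/20 up to 28x56.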
def load_frames(video_file, max_num_frames=64, long_edge=448):
    """Uniformly sample up to max_num_frames frames (at most ~1 per second) from a video and resize them."""
    from decord import VideoReader
    vr = VideoReader(video_file)
    duration = len(vr)  # total number of frames
    fps = vr.get_avg_fps()
    length = int(duration / fps)  # video length in seconds
    num_frames = min(max_num_frames, length)
frame_timestamps = [int(duration / num_frames * (i+0.5)) / fps for i in range(num_frames)]
frame_indices = [int(duration / num_frames * (i+0.5)) for i in range(num_frames)]
frames_data = vr.get_batch(frame_indices).asnumpy()
imgs = []
for idx in range(num_frames):
img = resize_image(Image.fromarray(frames_data[idx]).convert("RGB"), long_edge)
imgs.append(img)
return imgs, frame_timestamps
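# Worked example (for reference): a 60 s video at 30 fps has 1800 frames; with max_num_frames=16,
# num_frames = min(16, 60) = 16, so frames are sampled at the midpoints of 16 equal spans,
# roughly every 112 frames (about 3.75 s apart).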
@wrap_gen_fn
@spaces.GPU(duration=30)
def predict(
text,
images,
chatbot,
history,
top_p,
temperature,
max_length_tokens,
max_context_length_tokens,
video_num_frames,
video_long_edge,
system_prompt,
chunk_size: int = 512,
):
"""
Predict the response for the input text and images.
Args:
text (str): The input text.
images (list[PIL.Image.Image]): The input images.
chatbot (list): The chatbot.
history (list): The history.
top_p (float): The top-p value.
temperature (float): The temperature value.
repetition_penalty (float): The repetition penalty value.
max_length_tokens (int): The max length tokens.
max_context_length_tokens (int): The max context length tokens.
chunk_size (int): The chunk size.
system_prompt (str): Default
"""
print("running the prediction function")
print("system prompt overrided by user as:", system_prompt)
try:
model, processor = fetch_model(args.model)
if text == "":
yield chatbot, history, "Empty context."
return
except KeyError:
yield [[text, "No Model Found"]], [], "No Model Found"
return
if images is None:
images = []
# load images
pil_images = []
timestamps = None
    for img_or_file in images:
        if img_or_file.endswith(".pdf") or img_or_file.endswith(".PDF"):
            pil_images = pdf_to_multi_image(img_or_file)
            continue
        try:
            # load as a PIL image
            if isinstance(img_or_file, Image.Image):
                pil_images.append(img_or_file)
            else:
                image = Image.open(img_or_file.name).convert("RGB")
                pil_images.append(image)
        except Exception:
            try:
                # fall back to loading the file as a video; only one video is allowed as input
                pil_images, timestamps = load_frames(img_or_file, video_num_frames, video_long_edge)
                break
            except Exception as e:
                print(f"Error loading image or video: {e}")
# generate prompt
conversation = generate_prompt_with_history(
text,
pil_images,
timestamps,
history,
processor,
max_length=max_context_length_tokens,
)
all_conv, last_image = convert_conversation_to_prompts(conversation)
stop_words = conversation.stop_str
gradio_chatbot_output = to_gradio_chatbot(conversation)
full_response = ""
for x in kimi_vl_generate(
conversations=all_conv,
override_system_prompt=system_prompt,
model=model,
processor=processor,
stop_words=stop_words,
max_length=max_length_tokens,
temperature=temperature,
top_p=top_p,
):
full_response += x
response = strip_stop_words(full_response, stop_words)
conversation.update_last_message(response)
gradio_chatbot_output[-1][1] = highlight_thinking(response)
yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."
if last_image is not None:
vg_image = parse_and_draw_response(response, last_image)
if vg_image is not None:
vg_base64 = pil_to_base64(vg_image, "vg", max_size=2048, min_size=400)
# the end of the last message will be ```python ```
gradio_chatbot_output[-1][1] += '\n\n' + vg_base64
yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."
logger.info("flushed result to gradio")
if is_variable_assigned("x"):
print(
f"temperature: {temperature}, "
f"top_p: {top_p}, "
f"max_length_tokens: {max_length_tokens}"
)
yield gradio_chatbot_output, to_gradio_history(conversation), "Generate: Success"
def retry(
text,
images,
chatbot,
history,
top_p,
temperature,
max_length_tokens,
max_context_length_tokens,
video_num_frames,
video_long_edge,
system_prompt,
chunk_size: int = 512,
):
"""
Retry the response for the input text and images.
"""
if len(history) == 0:
yield (chatbot, history, "Empty context")
return
chatbot.pop()
history.pop()
text = history.pop()[-1]
if type(text) is tuple:
text, _ = text
yield from predict(
text,
images,
chatbot,
history,
top_p,
temperature,
max_length_tokens,
max_context_length_tokens,
video_num_frames,
video_long_edge,
system_prompt,
chunk_size,
)
def build_demo(args: argparse.Namespace) -> gr.Blocks:
with gr.Blocks(theme=gr.themes.Soft(), delete_cache=(1800, 1800)) as demo:
history = gr.State([])
input_text = gr.State()
input_images = gr.State()
with gr.Row():
gr.HTML(TITLE)
status_display = gr.Markdown("Success", elem_id="status_display")
gr.Markdown(DESCRIPTION_TOP)
with gr.Row(equal_height=True):
with gr.Column(scale=4):
with gr.Row():
chatbot = gr.Chatbot(
elem_id="Kimi-VL-A3B-Thinking-chatbot",
show_share_button=True,
bubble_full_width=False,
height=600,
)
with gr.Row():
system_prompt = gr.Textbox(show_label=False, placeholder="Customize system prompt", container=False)
with gr.Row():
with gr.Column(scale=4):
text_box = gr.Textbox(show_label=False, placeholder="Enter text", container=False)
with gr.Column(min_width=70):
submit_btn = gr.Button("Send")
with gr.Column(min_width=70):
cancel_btn = gr.Button("Stop")
with gr.Row():
empty_btn = gr.Button("🧹 New Conversation")
retry_btn = gr.Button("🔄 Regenerate")
del_last_btn = gr.Button("🗑️ Remove Last Turn")
with gr.Column():
                # TODO: add a note that no more than 2 images can be uploaded at once
upload_images = gr.Files(file_types=["image", "video", ".pdf"], show_label=True)
gallery = gr.Gallery(columns=[3], height="200px", show_label=True)
upload_images.change(preview_images, inputs=upload_images, outputs=gallery)
# Parameter Setting Tab for control the generation parameters
with gr.Tab(label="Parameter Setting"):
                    top_p = gr.Slider(minimum=0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p")
temperature = gr.Slider(
minimum=0, maximum=1.0, value=0.8, step=0.1, interactive=True, label="Temperature"
)
max_length_tokens = gr.Slider(
minimum=512, maximum=16384, value=2048, step=64, interactive=True, label="Max Length Tokens"
)
max_context_length_tokens = gr.Slider(
minimum=512, maximum=16384, value=4096, step=64, interactive=True, label="Max Context Length Tokens"
)
video_num_frames = gr.Slider(
minimum=1, maximum=64, value=16, step=1, interactive=True, label="Max Number of Frames for Video"
)
video_long_edge = gr.Slider(
minimum=28, maximum=896, value=448, step=28, interactive=True, label="Long Edge of Video"
)
show_images = gr.HTML(visible=False)
gr.Examples(
examples=get_examples(ROOT_DIR),
inputs=[upload_images, show_images, system_prompt, text_box],
)
gr.Markdown()
input_widgets = [
input_text,
input_images,
chatbot,
history,
top_p,
temperature,
max_length_tokens,
max_context_length_tokens,
video_num_frames,
video_long_edge,
system_prompt
]
output_widgets = [chatbot, history, status_display]
transfer_input_args = dict(
fn=transfer_input,
inputs=[text_box, upload_images],
outputs=[input_text, input_images, text_box, upload_images, submit_btn],
show_progress=True,
)
predict_args = dict(fn=predict, inputs=input_widgets, outputs=output_widgets, show_progress=True)
retry_args = dict(fn=retry, inputs=input_widgets, outputs=output_widgets, show_progress=True)
reset_args = dict(fn=reset_textbox, inputs=[], outputs=[text_box, status_display])
predict_events = [
text_box.submit(**transfer_input_args).then(**predict_args),
submit_btn.click(**transfer_input_args).then(**predict_args),
]
empty_btn.click(reset_state, outputs=output_widgets, show_progress=True)
empty_btn.click(**reset_args)
retry_btn.click(**retry_args)
del_last_btn.click(delete_last_conversation, [chatbot, history], output_widgets, show_progress=True)
cancel_btn.click(cancel_outputing, [], [status_display], cancels=predict_events)
demo.title = "Kimi-VL-A3B-Thinking-2506 Chatbot"
return demo
def main(args: argparse.Namespace):
demo = build_demo(args)
reload_javascript()
# concurrency_count=CONCURRENT_COUNT, max_size=MAX_EVENTS
    favicon_path = os.path.join(ROOT_DIR, "kimi_vl/serve/assets/favicon.ico")
demo.queue().launch(
favicon_path=favicon_path,
server_name=args.ip,
server_port=args.port,
)
if __name__ == "__main__":
args = parse_args()
main(args)