# Hugging Face Spaces page header (scraped): Runtime error
import argparse | |
import gradio as gr | |
import os | |
import re | |
from PIL import Image, ImageDraw, ImageFont, ImageColor | |
import traceback | |
from PIL import Image | |
import spaces | |
import copy | |
from kimi_vl.serve.frontend import reload_javascript | |
from kimi_vl.serve.utils import ( | |
configure_logger, | |
pil_to_base64, | |
parse_ref_bbox, | |
strip_stop_words, | |
is_variable_assigned, | |
) | |
from kimi_vl.serve.gradio_utils import ( | |
cancel_outputing, | |
delete_last_conversation, | |
reset_state, | |
reset_textbox, | |
transfer_input, | |
wrap_gen_fn, | |
) | |
from kimi_vl.serve.chat_utils import ( | |
generate_prompt_with_history, | |
convert_conversation_to_prompts, | |
to_gradio_chatbot, | |
to_gradio_history, | |
) | |
from kimi_vl.serve.inference import kimi_vl_generate, load_model | |
from kimi_vl.serve.examples import get_examples | |
# --- Static UI text ----------------------------------------------------------
TITLE = """<h1 align="left" style="min-width:200px; margin-top:0;">Chat with Kimi-VL-A3B-Thinking-2506🤔 </h1>"""
DESCRIPTION_TOP = """<a href="https://github.com/MoonshotAI/Kimi-VL" target="_blank">Kimi-VL-A3B-Thinking</a> is a multi-modal LLM that can understand text and images, and generate text with thinking processes. This demo has been updated to its new [2506](https://huggingface.co/moonshotai/Kimi-VL-A3B-Thinking-2506) version."""
DESCRIPTION = ""

# --- Process-wide state ------------------------------------------------------
# Directory containing this script; build_demo() passes it to get_examples().
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
# Loaded models keyed by model name; filled lazily by fetch_model().
DEPLOY_MODELS = {}
logger = configure_logger()
def draw_click_action(
    action: str,
    point: tuple[float, float],
    image: "Image.Image",
    color: str = "green",
    radius_ratio: float = 0.06,
):
    """
    Draw a semi-transparent click marker on a copy of ``image``.

    The marker is a filled circle in ``color`` (alpha 128) with a small opaque
    green dot at its centre.

    Args:
        action (str): Label of the action. Currently not drawn; kept for
            interface compatibility.
        point (tuple[float, float]): Click position ``(x, y)`` in normalized
            coordinates (0-1); values outside that range are clipped.
        image (Image.Image): The image to draw on. It is not modified.
        color (str | tuple): Fill colour of the outer circle: a colour string
            understood by ``PIL.ImageColor.getrgb`` or an RGB/RGBA tuple of ints.
            Unparseable values fall back to red.
        radius_ratio (float): Outer circle radius as a fraction of the shorter
            image side.

    Returns:
        Image.Image: A new RGB image with the click marker drawn on it.
    """
    image_w, image_h = image.size
    # Convert normalized coordinates to pixel coordinates and clip to image bounds.
    x = int(min(max(point[0], 0), 1) * image_w)
    y = int(min(max(point[1], 0), 1) * image_h)

    rgb = (255, 0, 0)  # fallback colour
    if isinstance(color, str):
        try:
            # getrgb() may return an RGBA 4-tuple (e.g. for "#ff000080"); keep only
            # RGB so that appending our own alpha always yields a valid 4-tuple.
            rgb = tuple(ImageColor.getrgb(color)[:3])
        except ValueError:
            pass
    elif (
        isinstance(color, tuple)
        and len(color) in (3, 4)
        and all(isinstance(channel, int) for channel in color)
    ):
        rgb = tuple(color[:3])
    fill = rgb + (128,)

    # Draw on a transparent overlay and alpha-composite it, so the marker is
    # semi-transparent over the underlying screenshot.
    overlay = Image.new("RGBA", image.size, (255, 255, 255, 0))
    overlay_draw = ImageDraw.Draw(overlay)
    radius = min(image.size) * radius_ratio
    overlay_draw.ellipse([(x - radius, y - radius), (x + radius, y + radius)], fill=fill)
    center_radius = radius * 0.12
    overlay_draw.ellipse(
        [(x - center_radius, y - center_radius), (x + center_radius, y + center_radius)],
        fill=(0, 255, 0, 255),
    )
    # convert() returns a new image, so the caller's image is left untouched.
    combined = Image.alpha_composite(image.convert("RGBA"), overlay)
    return combined.convert("RGB")
def draw_agent_response(codes, image: "Image.Image"):
    """
    Draw every ``pyautogui.click(x=..., y=...)`` found in ``codes`` on a copy of ``image``.

    Example:
        prompt: Please observe the screenshot, please locate the following elements with action and point.
        <instruction>Half-sectional view
        the response format is like:
        <action>Half-sectional view
        <point>
        ```python
        pyautogui.click(x=0.150, y=0.081)
        ```

    Args:
        codes (list[str]): Code snippets, e.g. the bodies of ```python fences.
        image (Image.Image): The image to draw on. It is not modified.

    Returns:
        Image.Image: A copy of the image with one marker per click found. Click
        calls without parsable ``x=`` / ``y=`` keyword arguments (e.g.
        ``pyautogui.click()``) are skipped instead of aborting the drawing.
    """
    # Argument list of each click call; nested parentheses are not expected.
    click_call = re.compile(r"pyautogui\.click\(([^)]*)\)")
    number = r"(-?\d*\.?\d+)"
    x_arg = re.compile(r"\bx\s*=\s*" + number)
    y_arg = re.compile(r"\by\s*=\s*" + number)

    image = image.copy()
    for code in codes:
        # A snippet may contain several clicks; draw each of them.
        for call_args in click_call.findall(code):
            x_match = x_arg.search(call_args)
            y_match = y_arg.search(call_args)
            if x_match is None or y_match is None:
                continue
            # Clip normalized coordinates into [0, 1].
            x = min(max(float(x_match.group(1)), 0), 1)
            y = min(max(float(y_match.group(1)), 0), 1)
            image = draw_click_action("", (x, y), image)
    return image
def parse_and_draw_response(response, image: "Image.Image"):
    """
    Extract click actions from a model ``response`` and draw them on ``image``.

    Only ``pyautogui.click(`` calls inside fenced ```python blocks are drawn.

    Args:
        response (str): The model's text response.
        image (Image.Image): The screenshot to annotate. It is not modified.

    Returns:
        Image.Image | None: The annotated image, or ``None`` when the response has
        no click call inside a ```python block, or when drawing fails (the error
        is printed and logged).
    """
    try:
        # Relaxed pre-check: any click call anywhere in the response.
        if "pyautogui.click(" not in response:
            logger.warning("No response to draw")
            return None
        code_blocks = re.findall(r"```python(.*?)```", response, flags=re.DOTALL)
        if not any("pyautogui.click(" in block for block in code_blocks):
            # The click call is outside every fenced block: returning an
            # unannotated copy would only duplicate the screenshot.
            logger.warning("No response to draw")
            return None
        return draw_agent_response(code_blocks, image)
    except Exception as e:
        traceback.print_exc()
        logger.error(f"Error drawing agent response: {e}")
        return None
def pdf_to_multi_image(local_pdf, dpi: int = 96):
    """
    Render every page of a PDF file to a PIL image.

    Args:
        local_pdf (str): Path to the PDF file.
        dpi (int): Rendering resolution passed to PyMuPDF's ``Page.get_pixmap``.

    Returns:
        list[Image.Image]: One PNG-decoded image per page, in page order.
    """
    import io

    import fitz  # PyMuPDF; imported lazily so it is only required for PDF input

    all_input_images = []
    doc = fitz.open(local_pdf)
    try:
        for page in doc:
            pix = page.get_pixmap(dpi=dpi)
            all_input_images.append(Image.open(io.BytesIO(pix.tobytes("png"))))
    finally:
        # Release the document even if rendering a page raises.
        doc.close()
    return all_input_images
def parse_args():
    """
    Parse the demo server's command-line options.

    Returns:
        argparse.Namespace: with attributes ``model``, ``local_path``, ``ip``
        and ``port``.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--model", {"type": str, "default": "Kimi-VL-A3B-Thinking-2506"}),
        ("--local-path", {"type": str, "default": "", "help": "huggingface ckpt, optional"}),
        ("--ip", {"type": str, "default": "0.0.0.0"}),
        ("--port", {"type": int, "default": 7890}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def fetch_model(model_name: str):
    """
    Return the model for ``model_name``, loading and caching it on first use.

    The checkpoint is ``--local-path`` when given, otherwise the Hugging Face
    repo ``moonshotai/<model_name>``.

    Args:
        model_name (str): Model name; also the key into ``DEPLOY_MODELS``.

    Returns:
        Whatever ``load_model`` returned for that checkpoint (``predict``
        unpacks it as ``(model, processor)``).
    """
    if model_name in DEPLOY_MODELS:
        print(f"{model_name} has been loaded.")
        return DEPLOY_MODELS[model_name]
    # `args` is the module-level namespace bound in the __main__ guard. The hub
    # path now follows `model_name`, so the cache key and the loaded checkpoint
    # always agree.
    model_path = args.local_path or f"moonshotai/{model_name}"
    print(f"{model_name} is loading...")
    DEPLOY_MODELS[model_name] = load_model(model_path)
    print(f"Load {model_name} successfully...")
    return DEPLOY_MODELS[model_name]
def preview_images(files) -> list[str]:
    """Return the ``.name`` path of each uploaded file for the gallery (``[]`` when nothing is uploaded)."""
    if files is None:
        return []
    return [uploaded.name for uploaded in files]
def get_prompt(conversation) -> str:
    """Render ``conversation.system_template`` with the conversation's system message filled in."""
    template = conversation.system_template
    message = conversation.system_message
    return template.format(system_message=message)
def highlight_thinking(msg: str) -> str:
    """
    Replace the model's thinking markers with highlighted HTML headings.

    ``◁think▷`` becomes a blue "Thinking..." heading and ``◁/think▷`` a purple
    "Summary" heading. Text without markers is returned unchanged.
    """
    # str is immutable, so no defensive copy is needed, and str.replace is
    # already a no-op when a marker is absent.
    return msg.replace("◁think▷", "<b style='color:blue;'>🤔Thinking...</b>\n").replace(
        "◁/think▷", "\n<b style='color:purple;'>💡Summary</b>\n"
    )
def resize_image(image: "Image.Image", max_size: int = 640, min_size: int = 28):
    """
    Rescale ``image``, keeping its aspect ratio, so its sides respect the size limits.

    If either side is shorter than ``min_size``, the image is upscaled (LANCZOS)
    so the shorter side becomes ``min_size``. Otherwise, if ``max_size > 0`` and
    either side is longer than ``max_size``, it is downscaled so the longer side
    becomes ``max_size``. Images already within bounds, or with a zero-sized
    side, are returned unchanged.

    Args:
        image (Image.Image): The image to rescale.
        max_size (int): Upper bound for the longer side; ``<= 0`` disables it.
        min_size (int): Lower bound for the shorter side.

    Returns:
        Image.Image: The rescaled image, or ``image`` itself if no resize is needed.
    """
    width, height = image.size
    if width <= 0 or height <= 0:
        # Nothing sensible to scale (and it would divide by zero below).
        return image

    def _scaled(scale):
        # round() rather than int(): float error could otherwise truncate e.g.
        # 27.999999 to 27 and leave a side just short of the target, and a very
        # thin image must never collapse to a 0-pixel side.
        return max(1, round(width * scale)), max(1, round(height * scale))

    if width < min_size or height < min_size:
        scale = min_size / min(width, height)
        return image.resize(_scaled(scale), Image.Resampling.LANCZOS)
    if max_size > 0 and (width > max_size or height > max_size):
        scale = max_size / max(width, height)
        return image.resize(_scaled(scale))
    return image
def load_frames(video_file, max_num_frames=64, long_edge=448):
    """
    Sample frames uniformly from a video, at most roughly one per second.

    Args:
        video_file (str): Path to the video, read with ``decord.VideoReader``.
        max_num_frames (int): Upper bound on the number of sampled frames.
        long_edge (int): Frames are passed to ``resize_image`` with this as
            ``max_size`` (downscaled if their longer side exceeds it).

    Returns:
        tuple[list[Image.Image], list[float]]: The sampled RGB frames and each
        frame's timestamp in seconds.

    Raises:
        ValueError: If the video reports no frames or a non-positive frame rate.
    """
    from decord import VideoReader

    vr = VideoReader(video_file)
    total_frames = len(vr)
    fps = vr.get_avg_fps()
    if total_frames <= 0 or not fps > 0:  # `not fps > 0` also rejects NaN
        raise ValueError(
            f"Cannot sample frames from {video_file!r}: {total_frames} frames at {fps} fps"
        )
    duration_sec = int(total_frames / fps)
    # About one frame per second, capped by max_num_frames, and always at least
    # one frame so that clips shorter than a second still yield a sample.
    num_frames = max(1, min(max_num_frames, duration_sec))
    # Middle frame of each of `num_frames` equal-length segments.
    frame_indices = [int(total_frames / num_frames * (i + 0.5)) for i in range(num_frames)]
    frame_timestamps = [index / fps for index in frame_indices]
    frames_data = vr.get_batch(frame_indices).asnumpy()
    imgs = [
        resize_image(Image.fromarray(frame).convert("RGB"), long_edge)
        for frame in frames_data
    ]
    return imgs, frame_timestamps
def _load_media(files, video_num_frames, video_long_edge):
    """Load uploaded items into PIL images; return ``(pil_images, timestamps)``.

    PIL images are used as-is, PDFs are rendered page by page, other files are
    opened as images and, failing that, decoded as a video. A video replaces any
    other media (only one is allowed) and is the only source of ``timestamps``.
    """
    pil_images = []
    timestamps = None
    for item in files:
        if isinstance(item, Image.Image):
            pil_images.append(item)
            continue
        # Uploaded files are path strings or objects whose `.name` is the path.
        path = getattr(item, "name", item)
        if str(path).lower().endswith(".pdf"):
            # Extend rather than replace, so images uploaded before the PDF survive.
            pil_images.extend(pdf_to_multi_image(path))
            continue
        try:
            with Image.open(path) as opened:
                pil_images.append(opened.convert("RGB"))
        except Exception:
            # Not a readable image: fall back to treating it as a video.
            try:
                pil_images, timestamps = load_frames(path, video_num_frames, video_long_edge)
                # Only allow one video as input; it replaces other media so that
                # frames and timestamps stay aligned.
                break
            except Exception as e:
                print(f"Error loading image or video: {e}")
    return pil_images, timestamps


def predict(
    text,
    images,
    chatbot,
    history,
    top_p,
    temperature,
    max_length_tokens,
    max_context_length_tokens,
    video_num_frames,
    video_long_edge,
    system_prompt,
    chunk_size: int = 512,
):
    """
    Stream the model's answer to ``text`` (plus uploaded media) into the chatbot.

    Generator: every yield is ``(chatbot_messages, history, status)``.

    Args:
        text (str): The user message; an empty string yields "Empty context.".
        images (list | None): Uploaded items: PIL images, or uploaded files (path
            strings / objects with a ``.name`` path). PDFs are rendered page by
            page; a file that cannot be opened as an image is tried as a video.
        chatbot (list): The current chatbot messages.
        history (list): Conversation history passed to ``generate_prompt_with_history``.
        top_p (float): The top-p value.
        temperature (float): The temperature value.
        max_length_tokens (int): Generation length limit for ``kimi_vl_generate``.
        max_context_length_tokens (int): Prompt length limit for
            ``generate_prompt_with_history``.
        video_num_frames (int): Max number of frames sampled from a video.
        video_long_edge (int): Max long edge of sampled video frames.
        system_prompt (str): Passed to ``kimi_vl_generate`` as ``override_system_prompt``.
        chunk_size (int): Unused; kept for interface compatibility.
    """
    print("running the prediction function")
    print("system prompt overrided by user as:", system_prompt)
    try:
        model, processor = fetch_model(args.model)
    except KeyError:
        yield [[text, "No Model Found"]], [], "No Model Found"
        return
    if text == "":
        yield chatbot, history, "Empty context."
        return

    pil_images, timestamps = _load_media(images or [], video_num_frames, video_long_edge)

    # generate prompt
    conversation = generate_prompt_with_history(
        text,
        pil_images,
        timestamps,
        history,
        processor,
        max_length=max_context_length_tokens,
    )
    all_conv, last_image = convert_conversation_to_prompts(conversation)
    stop_words = conversation.stop_str
    gradio_chatbot_output = to_gradio_chatbot(conversation)

    full_response = ""
    response = ""  # stays defined even if the generator yields nothing
    for x in kimi_vl_generate(
        conversations=all_conv,
        override_system_prompt=system_prompt,
        model=model,
        processor=processor,
        stop_words=stop_words,
        max_length=max_length_tokens,
        temperature=temperature,
        top_p=top_p,
    ):
        full_response += x
        response = strip_stop_words(full_response, stop_words)
        conversation.update_last_message(response)
        gradio_chatbot_output[-1][1] = highlight_thinking(response)
        yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."

    if last_image is not None:
        vg_image = parse_and_draw_response(response, last_image)
        if vg_image is not None:
            vg_base64 = pil_to_base64(vg_image, "vg", max_size=2048, min_size=400)
            # Append the annotated screenshot after the model's ```python block.
            gradio_chatbot_output[-1][1] += '\n\n' + vg_base64
            yield gradio_chatbot_output, to_gradio_history(conversation), "Generating..."

    logger.info("flushed result to gradio")
    if is_variable_assigned("x"):
        print(
            f"temperature: {temperature}, "
            f"top_p: {top_p}, "
            f"max_length_tokens: {max_length_tokens}"
        )
    yield gradio_chatbot_output, to_gradio_history(conversation), "Generate: Success"
def retry(
    text,
    images,
    chatbot,
    history,
    top_p,
    temperature,
    max_length_tokens,
    max_context_length_tokens,
    video_num_frames,
    video_long_edge,
    system_prompt,
    chunk_size: int = 512,
):
    """
    Regenerate the last answer by dropping the last turn and re-running ``predict``.

    Assumes (as the pops below require) that the last turn is stored as the
    final two ``history`` entries: the user entry, then the assistant entry,
    each with the message content as its last element. A user content that is
    a tuple is unpacked and only its first element (the text) is reused.
    ``chatbot`` and ``history`` are mutated in place.

    Yields:
        ``(chatbot, history, status)`` tuples, exactly as ``predict`` does; a
        single ``"Empty context"`` status when there is no full turn to redo.
    """
    # Need a full (user, assistant) pair and a chatbot row to pop; bail out
    # instead of raising IndexError on an empty or partial history.
    if len(history) < 2 or not chatbot:
        yield (chatbot, history, "Empty context")
        return
    chatbot.pop()
    history.pop()  # assistant entry of the last turn
    text = history.pop()[-1]  # user entry: keep only its content
    if isinstance(text, tuple):
        text, _ = text
    yield from predict(
        text,
        images,
        chatbot,
        history,
        top_p,
        temperature,
        max_length_tokens,
        max_context_length_tokens,
        video_num_frames,
        video_long_edge,
        system_prompt,
        chunk_size,
    )
def build_demo(args: argparse.Namespace) -> gr.Blocks:
    """
    Build the Gradio Blocks UI for the chat demo and wire up its events.

    Args:
        args: Parsed command-line arguments. Not read inside this function.

    Returns:
        gr.Blocks: The assembled demo; the caller queues and launches it.
    """
    # delete_cache: per the Gradio docs, (frequency, age) in seconds for
    # cleaning up cached uploads.
    with gr.Blocks(theme=gr.themes.Soft(), delete_cache=(1800, 1800)) as demo:
        # Conversation history shared by predict / retry / delete_last_conversation.
        history = gr.State([])
        # Text and files of the last submission, filled by transfer_input below
        # and read by predict/retry.
        input_text = gr.State()
        input_images = gr.State()
        with gr.Row():
            gr.HTML(TITLE)
            status_display = gr.Markdown("Success", elem_id="status_display")
        gr.Markdown(DESCRIPTION_TOP)
        with gr.Row(equal_height=True):
            # Left column: chat window, prompt inputs and conversation controls.
            with gr.Column(scale=4):
                with gr.Row():
                    chatbot = gr.Chatbot(
                        elem_id="Kimi-VL-A3B-Thinking-chatbot",
                        show_share_button=True,
                        bubble_full_width=False,
                        height=600,
                    )
                with gr.Row():
                    system_prompt = gr.Textbox(show_label=False, placeholder="Customize system prompt", container=False)
                with gr.Row():
                    with gr.Column(scale=4):
                        text_box = gr.Textbox(show_label=False, placeholder="Enter text", container=False)
                    with gr.Column(min_width=70):
                        submit_btn = gr.Button("Send")
                    with gr.Column(min_width=70):
                        cancel_btn = gr.Button("Stop")
                with gr.Row():
                    empty_btn = gr.Button("🧹 New Conversation")
                    retry_btn = gr.Button("🔄 Regenerate")
                    del_last_btn = gr.Button("🗑️ Remove Last Turn")
            # Right column: uploads, preview gallery and generation parameters.
            with gr.Column():
                # NOTE(review): the original note asked to tell users to upload no
                # more than 2 images at once; no such limit is enforced here.
                upload_images = gr.Files(file_types=["image", "video", ".pdf"], show_label=True)
                gallery = gr.Gallery(columns=[3], height="200px", show_label=True)
                upload_images.change(preview_images, inputs=upload_images, outputs=gallery)
                # Parameter Setting tab: controls for the generation parameters.
                with gr.Tab(label="Parameter Setting"):
                    top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p")
                    temperature = gr.Slider(
                        minimum=0, maximum=1.0, value=0.8, step=0.1, interactive=True, label="Temperature"
                    )
                    max_length_tokens = gr.Slider(
                        minimum=512, maximum=16384, value=2048, step=64, interactive=True, label="Max Length Tokens"
                    )
                    max_context_length_tokens = gr.Slider(
                        minimum=512, maximum=16384, value=4096, step=64, interactive=True, label="Max Context Length Tokens"
                    )
                    video_num_frames = gr.Slider(
                        minimum=1, maximum=64, value=16, step=1, interactive=True, label="Max Number of Frames for Video"
                    )
                    video_long_edge = gr.Slider(
                        minimum=28, maximum=896, value=448, step=28, interactive=True, label="Long Edge of Video"
                    )
        # Hidden HTML component used as one of the example inputs.
        show_images = gr.HTML(visible=False)
        gr.Examples(
            examples=get_examples(ROOT_DIR),
            inputs=[upload_images, show_images, system_prompt, text_box],
        )
        gr.Markdown()
        # Order must match the positional parameters of predict() / retry().
        input_widgets = [
            input_text,
            input_images,
            chatbot,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens,
            video_num_frames,
            video_long_edge,
            system_prompt
        ]
        # Order must match the (chatbot, history, status) tuples predict/retry yield.
        output_widgets = [chatbot, history, status_display]
        # First step of a submission: move the textbox/upload contents into the
        # input_* states (transfer_input also updates text_box, upload_images
        # and submit_btn).
        transfer_input_args = dict(
            fn=transfer_input,
            inputs=[text_box, upload_images],
            outputs=[input_text, input_images, text_box, upload_images, submit_btn],
            show_progress=True,
        )
        predict_args = dict(fn=predict, inputs=input_widgets, outputs=output_widgets, show_progress=True)
        retry_args = dict(fn=retry, inputs=input_widgets, outputs=output_widgets, show_progress=True)
        reset_args = dict(fn=reset_textbox, inputs=[], outputs=[text_box, status_display])
        # Enter in the textbox and the Send button both run transfer -> predict;
        # these events are kept so the Stop button can cancel them.
        predict_events = [
            text_box.submit(**transfer_input_args).then(**predict_args),
            submit_btn.click(**transfer_input_args).then(**predict_args),
        ]
        # "New Conversation" has two handlers: one resets chat state, one resets
        # the textbox and status.
        empty_btn.click(reset_state, outputs=output_widgets, show_progress=True)
        empty_btn.click(**reset_args)
        # Retry reuses the last transferred input_* states (no transfer step).
        retry_btn.click(**retry_args)
        del_last_btn.click(delete_last_conversation, [chatbot, history], output_widgets, show_progress=True)
        cancel_btn.click(cancel_outputing, [], [status_display], cancels=predict_events)
    demo.title = "Kimi-VL-A3B-Thinking-2506 Chatbot"
    return demo
def main(args: argparse.Namespace):
    """
    Build the demo, call ``reload_javascript()``, and launch the queued Gradio server.

    Args:
        args: Parsed CLI arguments; ``args.ip`` / ``args.port`` set the bind address.
    """
    demo = build_demo(args)
    reload_javascript()
    # Resolve the favicon relative to this script instead of the current working
    # directory (os.path.join with a single argument was a no-op), so launching
    # from another directory still finds it.
    favicon_path = os.path.join(ROOT_DIR, "kimi_vl", "serve", "assets", "favicon.ico")
    demo.queue().launch(
        favicon_path=favicon_path,
        server_name=args.ip,
        server_port=args.port,
    )
if __name__ == "__main__":
    # Bind `args` at module scope while passing it on: fetch_model() and
    # predict() read it as a global.
    main(args := parse_args())