Spaces:
Running
on
Zero
Running
on
Zero
| import spaces | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from ola_vlm.constants import DEFAULT_IMAGE_TOKEN | |
| from ola_vlm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN | |
| from ola_vlm.conversation import conv_templates, SeparatorStyle | |
| from ola_vlm.model.builder import load_pretrained_model | |
| from ola_vlm.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images | |
| from diffusers import StableUnCLIPImg2ImgPipeline, DPMSolverMultistepScheduler | |
| from transformers import OneFormerProcessor | |
| from ola_vlm.model.aux_heads.oneformer_head import OneFormerHead | |
| from ola_vlm.ola_utils import visualize_oneformer_masks_on_image, oneformer_prepare_panoptic_instance_prediction | |
| import matplotlib | |
| from PIL import Image, ImageDraw, ImageFont | |
| import argparse | |
| import math | |
| from transformers import TextIteratorStreamer | |
| from threading import Thread | |
| import subprocess | |
# Install flash-attn at startup while skipping the compilation of its CUDA
# extension.
#
# ``env`` given to subprocess.run *replaces* the child's environment, so the
# current environment is copied and only the skip-build flag is added on top
# (a bare dict would drop PATH, HOME, HF_* and CUDA variables).  pip is run via
# the current interpreter so the package lands in the environment this app is
# running in, and no shell is involved.  As before, the exit status is not
# checked.
import os
import sys

subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
def make_grid(pil_images, layer_indices=None, tile_size=256, max_per_row=4, font_size=18):
    """Tile images into a white grid with a "Layer: <n>" caption on each tile.

    Every image is resized to ``tile_size`` x ``tile_size``.  Tiles are laid
    out left-to-right, top-to-bottom with at most ``max_per_row`` per row, and
    each gets a boxed caption in its bottom-left corner.

    Args:
        pil_images: Non-empty sequence of PIL images.
        layer_indices: Optional indexable sequence of caption labels, one per
            image; defaults to the 1-based position of each image.
        tile_size: Edge length, in pixels, of every square tile.
        max_per_row: Maximum number of tiles per grid row (must be >= 1).
        font_size: Caption font size.

    Returns:
        A new RGB ``PIL.Image.Image`` holding the grid.

    Raises:
        ValueError: If ``pil_images`` is empty.
        IndexError: If ``layer_indices`` has fewer entries than images.
    """
    if not pil_images:
        raise ValueError("make_grid() needs at least one image")

    tiles = [img.resize((tile_size, tile_size)) for img in pil_images]
    if layer_indices is None:
        layer_indices = range(1, len(tiles) + 1)
    captions = [f"Layer: {layer_indices[i]}" for i in range(len(tiles))]

    per_row = min(len(tiles), max_per_row)
    n_rows = math.ceil(len(tiles) / per_row)
    grid = Image.new("RGB", (tile_size * per_row, tile_size * n_rows), "white")
    draw = ImageDraw.Draw(grid)

    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", font_size)
    except (OSError, ImportError):
        # Font file not present, or Pillow built without FreeType support.
        font = ImageFont.load_default()

    for i, (tile, caption) in enumerate(zip(tiles, captions)):
        x = (i % per_row) * tile_size
        y = (i // per_row) * tile_size
        grid.paste(tile, (x, y))

        # ImageDraw.textsize() was removed in Pillow 10; textbbox() replaces it.
        _, top, _, bottom = draw.textbbox((0, 0), caption, font=font)
        text_pos = (x + 10, y + tile_size - (bottom - top) - 10)
        # Background box = the rendered text's exact extent, padded by 5 px.
        bx0, by0, bx1, by1 = draw.textbbox(text_pos, caption, font=font)
        draw.rectangle((bx0 - 5, by0 - 5, bx1 + 5, by1 + 5), fill="white", outline="black")
        draw.text(text_pos, caption, fill="black", font=font)

    return grid
def reload_from_ckpt(model_path, model, cache_dir=None):
    """Overwrite ``model``'s parameters with tensors from ``*.safetensors`` files.

    ``model_path`` is either a local directory or a Hugging Face Hub repo id.
    The ``.safetensors`` files found there are read in sorted file-name order.
    When a key appears in several files, the last file wins, so the result is
    deterministic.  The combined state dict is loaded with ``strict=False``:
    parameters absent from the checkpoint keep their current values, and
    checkpoint keys the model does not have are ignored.

    Args:
        model_path: Local directory or Hub repo id holding the safetensors files.
        model: Object exposing ``load_state_dict(state_dict, strict=...)``
            (a ``torch.nn.Module`` here); it is updated in place.
        cache_dir: Optional Hub cache directory for downloaded files.

    Returns:
        The same ``model`` object.  If no ``.safetensors`` file is found, a
        ``UserWarning`` is emitted and the model is returned untouched.
    """
    import os
    import warnings

    if os.path.isdir(model_path):
        safetensors_paths = [
            os.path.join(model_path, name)
            for name in sorted(os.listdir(model_path))
            if name.endswith(".safetensors")
        ]
    else:
        from huggingface_hub import hf_hub_download, list_repo_files

        safetensors_paths = [
            hf_hub_download(model_path, name, cache_dir=cache_dir)
            for name in sorted(list_repo_files(model_path))
            if name.endswith(".safetensors")
        ]

    if not safetensors_paths:
        # Previously this loaded an empty state dict, a silent no-op.
        warnings.warn(
            f"reload_from_ckpt: no .safetensors files found in {model_path!r}; "
            "model weights left unchanged"
        )
        return model

    from safetensors import safe_open

    state_dict = {}
    for path in safetensors_paths:
        with safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)

    model.load_state_dict(state_dict, strict=False)
    return model
# Button update objects returned by the event handlers to toggle the chat
# action buttons (no change / interactive / non-interactive).
no_change_btn, enable_btn, disable_btn = (
    gr.Button(),
    gr.Button(interactive=True),
    gr.Button(interactive=False),
)


def _build_argparser():
    """Return the command-line parser for the demo.

    NOTE(review): --server_name and --port are parsed but not passed to
    ``launch()`` anywhere in this file.
    """
    parser = argparse.ArgumentParser()
    options = [
        ("--server_name", dict(type=str, default="0.0.0.0")),
        ("--port", dict(type=str, default="6324")),
        ("--model-path", dict(type=str, default="shi-labs/pretrain_dsg_OLA-VLM-CLIP-ViT-Llama3-8b")),
        ("--model-base", dict(type=str, default=None)),
        ("--num-gpus", dict(type=int, default=1)),
        ("--conv-mode", dict(type=str, default="llava_llama_3")),
        ("--temperature", dict(type=float, default=0.2)),
        ("--max-new-tokens", dict(type=int, default=512)),
        ("--num_frames", dict(type=int, default=16)),
        ("--load-8bit", dict(action="store_true")),
        ("--load-4bit", dict(action="store_true")),
        ("--debug", dict(action="store_true")),
    ]
    for flag, kwargs in options:
        parser.add_argument(flag, **kwargs)
    return parser


argparser = _build_argparser()
args = argparser.parse_args()
# ---- Model and visualization-pipeline setup (runs once, at import time) ----
model_path = args.model_path
# Conversation template key; templates are copied from conv_templates[conv_mode].
conv_mode = args.conv_mode
filt_invalid="cut"  # NOTE(review): not referenced anywhere else in the visible file
model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
# Overwrite the weights just loaded from --model-path with the safetensors shards
# of a fixed Hub checkpoint (non-strict, so parameters missing from that
# checkpoint keep their --model-path values). This happens whatever
# --model-path is set to.
model = reload_from_ckpt("shi-labs/OLA-VLM-CLIP-ViT-Llama3-8b", model)
our_chatbot = None  # NOTE(review): not referenced anywhere else in the visible file
# SD-2.1 unCLIP pipeline (fp16) used by get_gen_images to decode per-layer
# image embeddings into images for the "gen" panel.
pipe = StableUnCLIPImg2ImgPipeline.from_pretrained(f"stabilityai/stable-diffusion-2-1-unclip", torch_dtype=torch.float16, variant="fp16")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
# OneFormer processor + head used by get_seg_images for the "seg" panel.
oneformer_processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_coco_swin_large")
oneformer = OneFormerHead.from_pretrained("shi-labs/oneformer_coco_swin_large")
# LLM layers probed by each visualization, stored in the model config as
# "-"-separated strings; the resulting string lists are only used as the
# "Layer: <n>" captions of the result grids.
gen_layer_indices = model.config.image_gen["img_layer_indices"].split("-")
seg_layer_indices = model.config.image_seg["seg_layer_indices"].split("-")
depth_layer_indices = model.config.image_depth["depth_layer_indices"].split("-")
def clear_history():
    """Start a fresh conversation and reset the chat outputs.

    Returns a 12-tuple: a new copy of the ``conv_mode`` template, its chatbot
    rendering, an empty textbox string, four ``None`` values (image input and
    the three visualization images), then five disabled buttons.
    """
    fresh = conv_templates[conv_mode].copy()
    cleared_media = (None,) * 4
    buttons = tuple([disable_btn] * 5)
    return (fresh, fresh.to_gradio_chatbot(), "", *cleared_media, *buttons)
def add_text(state, imagebox, textbox, image_process_mode):
    """Append the user's turn and an empty assistant slot to the conversation.

    Args:
        state: Current conversation, or ``None`` on the first turn (a fresh
            copy of the ``conv_mode`` template is created then).
        imagebox: File path of the uploaded image, or ``None``.
        textbox: The user's message text.
        image_process_mode: Preprocessing mode stored with the image.

    Yields:
        One update tuple: the state, its chatbot rendering, an empty textbox,
        ``None`` for the image input, then three disabled and two enabled
        buttons.
    """
    if state is None:
        state = conv_templates[conv_mode].copy()

    message = textbox
    if imagebox is not None:
        # Decode fully inside the context manager so the file handle is released.
        with Image.open(imagebox) as img:
            image = img.convert("RGB")
        message = (DEFAULT_IMAGE_TOKEN + "\n" + textbox, image, image_process_mode)

    state.append_message(state.roles[0], message)
    state.append_message(state.roles[1], None)
    yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
def get_gen_images(out, pipe):
    """Decode each probed layer's image embedding with unCLIP and tile the results.

    Args:
        out: Result of ``model.get_visual_interpretations``; only
            ``out.image_embs`` is read.
        pipe: The unCLIP pipeline; it is moved to ``"cuda"`` before use.

    Returns:
        A captioned grid image, or ``None`` if ``out.image_embs`` is empty.
    """
    pipe = pipe.to("cuda")
    embeddings = out.image_embs
    if not len(embeddings):
        return None

    decoded = [
        pipe(image_embeds=emb.squeeze(1), num_inference_steps=25).images[0]
        for emb in embeddings
    ]
    return make_grid(decoded, gen_layer_indices)
def get_depth_images(out, org_size):
    """Colorize each probed layer's depth prediction and tile the results.

    Each prediction is min-max normalized to 0..255, mapped through the
    ``Spectral_r`` colormap, and resized to ``org_size``.

    Args:
        out: Result of ``model.get_visual_interpretations``; only
            ``out.depth_preds`` (a sequence of tensors) is read.
        org_size: Target ``(width, height)`` of each depth map; callers pass
            the input image's PIL ``size``.

    Returns:
        A captioned grid image, or ``None`` if there are no depth predictions.
    """
    depth_preds = out.depth_preds
    if len(depth_preds) == 0:
        return None

    cmap = matplotlib.colormaps.get_cmap("Spectral_r")
    depths = []
    for depth_pred in depth_preds:
        # float32: fp16 cannot hold the epsilon below, and NumPy has no bf16.
        depth_pred = depth_pred.detach().float()
        low, high = depth_pred.min(), depth_pred.max()
        # A constant map (high == low) used to divide by zero and yield NaNs,
        # whose uint8 cast is undefined; it now maps to all zeros.
        span = (high - low).clamp(min=1e-8)
        depth = ((depth_pred - low) / span * 255.0).squeeze(0).cpu().numpy().astype(np.uint8)
        colored = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
        depths.append(Image.fromarray(colored).resize(org_size))

    return make_grid(depths, depth_layer_indices)
def get_seg_images(out, image, oneformer):
    """Decode each probed layer's segmentation embedding with OneFormer and tile the overlays.

    The OneFormer backbone runs once on ``image``.  Then, for every layer
    embedding in ``out.seg_embs``, masks are predicted, panoptic
    post-processing is applied at the image's resolution, and the masks are
    drawn over ``image``.

    Args:
        out: Result of ``model.get_visual_interpretations``; ``seg_embs`` and
            ``logits`` (for device/dtype) are read.
        image: The PIL input image.
        oneformer: The OneFormer head; it is moved to ``"cuda"`` before use.

    Returns:
        A captioned grid image, or ``None`` if ``out.seg_embs`` is empty.

    NOTE(review): the processor is prompted with the "semantic" task while
    panoptic post-processing is applied; confirm this is intended.
    """
    oneformer = oneformer.to("cuda")
    embeddings = out.seg_embs
    if not len(embeddings):
        return None

    overlays = []
    inputs = oneformer_processor(image, ["semantic"], return_tensors="pt")
    for key in ("pixel_values", "task_inputs"):
        inputs[key] = inputs[key].to(out.logits.device, out.logits.dtype)
    backbone_features = oneformer.get_backbone_feats(**inputs)

    for emb in embeddings:
        raw = oneformer.get_masks(**inputs, backbone_last_feature=emb.float(), all_backbone_features=backbone_features)
        # PIL size is (width, height); target_sizes expects (height, width).
        panoptic = oneformer_processor.post_process_panoptic_segmentation(raw, target_sizes=[image.size[::-1]])[0]
        masks, classes = oneformer_prepare_panoptic_instance_prediction(**panoptic, oneformer=oneformer)
        overlays.append(visualize_oneformer_masks_on_image(image, masks, classes))

    return make_grid(overlays, seg_layer_indices)
def delete_text(state, image_process_mode):
    """Blank the latest assistant reply so the turn can be answered again.

    If the preceding user turn holds a tuple/list payload, its first two
    items are kept and the preprocessing mode is replaced by
    ``image_process_mode``.

    Yields:
        One update tuple: the state, its chatbot rendering, an empty textbox,
        ``None`` for the image input, then three disabled and two enabled
        buttons.
    """
    history = state.messages
    history[-1][-1] = None
    user_turn = history[-2]
    payload = user_turn[1]
    if type(payload) in (tuple, list):
        user_turn[1] = (*payload[:2], image_process_mode)
    buttons = (disable_btn,) * 3 + (enable_btn,) * 2
    yield (state, state.to_gradio_chatbot(), "", None) + buttons
def regenerate(state, image_process_mode):
    """Blank the latest assistant reply and reset ``state.skip_next``.

    Like ``delete_text``, but returns (rather than yields) a single update
    with all five action buttons disabled.  NOTE(review): no event in this
    file is wired to this function.
    """
    history = state.messages
    history[-1][-1] = None
    user_turn = history[-2]
    if type(user_turn[1]) in (tuple, list):
        text_and_image = user_turn[1][:2]
        user_turn[1] = (*text_and_image, image_process_mode)
    state.skip_next = False

    outputs = [state, state.to_gradio_chatbot(), "", None]
    outputs.extend([disable_btn] * 5)
    return tuple(outputs)
# Needs a GPU on ZeroGPU Spaces; the decorator has no effect elsewhere.
@spaces.GPU
def get_interm_outs(state):
    """Visualize what selected LLM layers encode about the conversation's image.

    Runs ``model.get_visual_interpretations`` on the current prompt and image,
    then renders depth, segmentation and unCLIP ("gen") grids.

    Args:
        state: Current conversation, or ``None`` if nothing has been sent yet.

    Returns:
        ``(depth_grid, seg_grid, gen_grid)``.  Each entry is a grid image or
        ``None``; all three are ``None`` when there is no conversation or it
        contains no image.

    Raises:
        ValueError: If the number of images differs from the number of image
            tokens in the prompt.
    """
    # The visualizers need the input image (for its size and as the overlay
    # background). Previously, a text-only chat hit an unbound `image_sizes`.
    if state is None:
        return None, None, None
    prompt = state.get_prompt()
    images = state.get_images(return_pil=True)
    if not images:
        return None, None, None
    if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
        raise ValueError("Number of images does not match number of <image> tokens in prompt")

    image_sizes = [image.size for image in images]
    inp_images = process_images(images, image_processor, model.config)
    if type(inp_images) is list:
        # Move the processed tensors (not the PIL images, which have no .to()).
        inp_images = [img.to(model.device, dtype=torch.float16) for img in inp_images]
    else:
        inp_images = inp_images.to(model.device, dtype=torch.float16)

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(model.device)

    # Inference only: skip autograd bookkeeping to save GPU memory.
    with torch.no_grad():
        interm_outs = model.get_visual_interpretations(input_ids, images=inp_images, image_sizes=image_sizes)
        depth_outs = get_depth_images(interm_outs, image_sizes[0])
        seg_outs = get_seg_images(interm_outs, images[0], oneformer)
        gen_outs = get_gen_images(interm_outs, pipe)
    return depth_outs, seg_outs, gen_outs
# Needs a GPU on ZeroGPU Spaces; the decorator has no effect elsewhere.
@spaces.GPU
def generate(state, temperature, top_p, max_output_tokens):
    """Stream the model's reply to the latest user turn into ``state``.

    ``model.generate`` runs on a background thread; decoded text is read back
    through a ``TextIteratorStreamer`` and written into the last (assistant)
    message of ``state`` as it arrives.

    Args:
        state: Conversation whose last message is the empty assistant slot.
        temperature: Sampling temperature; values <= 0.001 use greedy decoding.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        max_output_tokens: Upper bound on new tokens, further clamped so that
            prompt + reply fit in ``model.config.max_position_embeddings``
            (2048 if absent).

    Yields:
        Update tuples: the state, its chatbot rendering, an empty textbox,
        ``None`` for the image input, then the five action buttons.  While
        streaming, only the last two buttons are enabled; the final update
        enables all five.

    Raises:
        ValueError: If the number of images differs from the number of image
            tokens in the prompt.
    """
    prompt = state.get_prompt()
    images = state.get_images(return_pil=True)

    image_args = {}
    if images:
        if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
            raise ValueError("Number of images does not match number of <image> tokens in prompt")
        image_sizes = [image.size for image in images]
        pixel_values = process_images(images, image_processor, model.config)
        if type(pixel_values) is list:
            pixel_values = [pv.to(model.device, dtype=torch.float16) for pv in pixel_values]
        else:
            pixel_values = pixel_values.to(model.device, dtype=torch.float16)
        image_args = {"images": pixel_values, "image_sizes": image_sizes}

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(model.device)

    # Keep prompt + reply inside the context window.
    # NOTE(review): image features are presumably expanded inside the model and
    # are not counted in input_ids here; confirm the budget is conservative enough.
    max_context_length = getattr(model.config, "max_position_embeddings", 2048)
    max_new_tokens = min(int(max_output_tokens), max_context_length - input_ids.shape[-1])
    if max_new_tokens < 1:
        # A bare `return` left the reply empty and every button disabled.
        state.messages[-1][-1] = "No room to generate a reply: increase 'Max output tokens' or start a new conversation."
        yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
        return

    do_sample = temperature > 0.001
    stop_str = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
    thread = Thread(target=model.generate, kwargs=dict(
        inputs=input_ids,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        streamer=streamer,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id,
        **image_args
    ))
    thread.start()

    streaming_buttons = (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        # Guard: an empty stop string would make endswith() true and wipe the text.
        if stop_str and generated_text.endswith(stop_str):
            generated_text = generated_text[:-len(stop_str)]
        state.messages[-1][-1] = generated_text
        yield (state, state.to_gradio_chatbot(), "", None) + streaming_buttons
    # The streamer is exhausted once generation ends; make sure the thread is done.
    thread.join()

    yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
    torch.cuda.empty_cache()
# Stand-alone text input. NOTE(review): not referenced elsewhere in the visible
# file; the chat UI renders `textbox` instead.
txt = gr.Textbox(
    placeholder="Enter text and press enter.",
    show_label=False,
    container=False,
    scale=4,
)
# Page heading and HTML header shown at the top of the demo (rendered with
# gr.Markdown in the UI definition).
title = "<h1 style='margin-bottom: -10px; text-align: center'>Elevating Visual Perception in Multimodal LLMs with Auxiliary Embedding Distillation</h1>"

# Authors, project links, abstract, usage note and a legend for the three
# visualization panels.  Built from adjacent string literals so the HTML needs
# no backslash line continuations.
description = (
    "<p style='font-size: 16px; margin: 5px; font-weight: w300; text-align: center'> "
    "<a href='https://praeclarumjj3.github.io/' style='text-decoration:none' target='_blank'>Jitesh Jain</a>&nbsp;&nbsp;&nbsp;"
    "<a href='https://zyang-ur.github.io/' style='text-decoration:none' target='_blank'>Zhengyuan Yang</a>&nbsp;&nbsp;&nbsp;"
    "<a href='https://www.humphreyshi.com/home' style='text-decoration:none' target='_blank'>Humphrey Shi<sup>*</sup></a>&nbsp;&nbsp;&nbsp;"
    # This link previously pointed at Humphrey Shi's homepage (copy-paste error).
    "<a href='https://www.microsoft.com/en-us/research/people/jfgao/' style='text-decoration:none' target='_blank'>Jianfeng Gao<sup>*</sup></a>&nbsp;&nbsp;&nbsp;"
    "<a href='https://jwyang.github.io/' style='text-decoration:none' target='_blank'>Jianwei Yang<sup>*</sup></a></p>"
    "<p style='font-size: 12px; margin: 5px; font-weight: w300; text-align: center'><sup>*</sup>Equal Advising</p>"
    "<p style='font-size: 16px; margin: 5px; font-weight: w600; text-align: center'> "
    "<a href='https://praeclarumjj3.github.io/visper_lm/' target='_blank'>Project Page</a> | "
    "<a href='https://youtu.be/' target='_blank'>Video</a> | "
    "<a href='https://arxiv.org/abs/2412.09585' target='_blank'>ArXiv</a> | "
    "<a href='https://github.com/SHI-Labs/VisPer-LM' target='_blank'>Github</a></p>"
    "<p style='text-align: center; font-size: 14px; margin: 5px; font-weight: w300;'>VisPer-LM introduces a new approach to distilling vision knowledge into the hidden representations of LLMs, utilizing target representations to advance visual perception in MLLMs.</p>"
    "<p style='text-align: left; font-size: 14px; margin: 5px; font-weight: w300;'>In the demo, along with the chatting with VisPer-LM, you can also visualize the intermediate representations from selected layers of the LLM by clicking on the <code style='font-size: 14px;'>Visualize Intermediate Representations</code> button! Note that our demo only supports single image input currently.</p>"
    "<ul style='text-align: left; font-size: 14px; margin: 5px; font-weight: w300; padding: 0;'> "
    "<li><b>depth</b>: Visualizes the depth information in the representations using the decoder from the <a href='https://github.com/DepthAnything/Depth-Anything-V2' target='_blank'>Depth-Anything-v2 model</a>.</li> "
    "<li><b>seg</b>: Visualizes the segmentation information in the representations using the decoder from the <a href='https://github.com/SHI-Labs/OneFormer' target='_blank'>OneFormer model</a>.</li> "
    "<li><b>gen</b>: Visualizes the general information of the representations using the <a href='https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip' target='_blank'>SD-2.1-unCLIP</a>. Note that the output is a variation of the input image due to the nature of unCLIP.</li> "
    "</ul>"
)
# Terms-of-use and license notices (Markdown).
# NOTE(review): no active code in this file references these two constants.
tos_markdown = ("""
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.
""")
learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the [License](https://huggingface.co/lmsys/vicuna-7b-v1.5) of Vicuna-v1.5, [License](https://github.com/haotian-liu/LLaVA/blob/main/LICENSE) of LLaVA, [Terms of Use](https://cocodataset.org/#termsofuse) of the COCO dataset, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")
# Page CSS passed to gr.Blocks: minimum width for buttons inside the element
# with id "buttons".
block_css = """
#buttons button {
min-width: min(120px,100%);
}
"""
# Chat input box; created outside the Blocks context and placed into the
# layout later with textbox.render().
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
with gr.Blocks(title="VisPer-LM", theme=gr.themes.Default(), css=block_css) as demo:
    # Per-session conversation; None until the first message is sent.
    state = gr.State()

    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        # Left column: image input and sampling parameters.
        with gr.Column(scale=4):
            imagebox = gr.Image(label="Input Image", type="filepath")
            # Hidden, so the "Default" preprocessing mode is always used.
            image_process_mode = gr.Radio(
                ["Crop", "Resize", "Pad", "Default"],
                value="Default",
                label="Preprocess for non-square image",
                visible=False,
            )
            with gr.Row():
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P")
                max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens")

        # Right column: transcript, input box and action buttons.
        with gr.Column(scale=8):
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                label="VisPer-LM",
                height=300,
                layout="panel",
            )
            textbox.render()
            with gr.Row(elem_id="buttons") as button_row:
                # Labels were mojibake (UTF-8 emoji decoded as cp1253); restored.
                # The feedback buttons stay hidden but keep their btn_list slots.
                upvote_btn = gr.Button(value="👍 Upvote", interactive=False, visible=False)
                downvote_btn = gr.Button(value="👎 Downvote", interactive=False, visible=False)
                flag_btn = gr.Button(value="⚠️ Flag", interactive=False, visible=False)
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
                submit_btn = gr.Button(value="Send", variant="primary")

    # Visualizations of intermediate LLM representations (single image only).
    inter_vis_btn = gr.Button(value="✨ Visualize Intermediate Representations")
    with gr.Row():
        depth_box = gr.Image(label="depth", type="pil", visible=True)
        seg_box = gr.Image(label="seg", type="pil", visible=True)
        gen_box = gr.Image(label="gen", type="pil", visible=True)

    gr.Examples(
        examples=[
            ["assets/cars.jpg", "Which car is in front: the blue or the brown one?"],
            ["assets/pb.jpg", "Where is the building located with respect to the man?"],
        ],
        inputs=[imagebox, textbox],
        cache_examples=False,
    )

    # ---- Event wiring ----
    # Handlers return one button update per entry, in this order.
    btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
    chat_outputs = [state, chatbot, textbox, imagebox] + btn_list
    gen_inputs = [state, temperature, top_p, max_output_tokens]

    inter_vis_btn.click(
        get_interm_outs,
        [state],
        [depth_box, seg_box, gen_box],
    )

    clear_btn.click(
        clear_history,
        None,
        [state, chatbot, textbox, imagebox, depth_box, gen_box, seg_box] + btn_list,
        queue=False,
    )

    regenerate_btn.click(
        delete_text,
        [state, image_process_mode],
        chat_outputs,
    ).then(
        generate,
        gen_inputs,
        chat_outputs,
    )

    # Pressing Enter in the textbox and clicking "Send" behave identically.
    for trigger in (textbox.submit, submit_btn.click):
        trigger(
            add_text,
            [state, imagebox, textbox, image_process_mode],
            chat_outputs,
        ).then(
            generate,
            gen_inputs,
            chat_outputs,
        )
# Configure the request queue and start the server.  With the default
# prevent_thread_lock=False, launch() blocks until the server shuts down, so a
# trailing demo.queue() after it (as before) would only run after shutdown and
# has been dropped.
demo.queue(
    status_update_rate=10,
    api_open=False,
).launch(share=False)