# Author: lycaoduong — commit 4cdd788 ("update new model")
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteria
import gradio as gr
import torch
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
import os
from threading import Thread
import re
import time
from PIL import Image
import torch
import spaces
import subprocess
from time import sleep
import base64
from io import BytesIO
# Best-effort install of flash-attn at startup (no check=True, so a failed
# install does not raise here). FLASH_ATTENTION_SKIP_CUDA_BUILD is presumably
# read by flash-attn's setup to skip building CUDA kernels — confirm against
# flash-attn's setup.py.
# The variable is merged into a copy of the current environment. Passing a
# dict containing only that key would replace the whole environment (dropping
# PATH, HOME, ...), so `pip` could fail to resolve or hit the wrong interpreter.
subprocess.run(
    ['pip', 'install', 'flash-attn', '--no-build-isolation'],
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
)
# torch.set_default_device('cuda')
# Per-channel RGB mean/std used by build_transform's Normalize step.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size):
    """Return the torchvision pipeline that turns one PIL tile into a model input.

    Steps, in order:
      1. convert to RGB unless the image is already RGB,
      2. resize to a square ``input_size`` x ``input_size`` with bicubic interpolation,
      3. convert to a tensor,
      4. normalize with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    def _ensure_rgb(img):
        # Leave RGB images untouched; convert everything else (e.g. L, RGBA, P).
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the ``(cols, rows)`` grid from ``target_ratios`` whose shape best matches ``aspect_ratio``.

    Candidates are scanned in order and scored by ``abs(aspect_ratio - cols / rows)``.
    A strictly smaller score always takes over. A later candidate with an exactly
    equal score takes over only when the image area ``width * height`` exceeds half
    the pixel area of that candidate's grid (``image_size * image_size * cols * rows``).

    Returns the chosen element of ``target_ratios``, or ``(1, 1)`` if it is empty.
    """
    image_area = width * height
    best_gap = float('inf')
    best = (1, 1)
    for candidate in target_ratios:
        cols, rows = candidate[0], candidate[1]
        gap = abs(aspect_ratio - cols / rows)
        if gap < best_gap:
            best_gap = gap
            best = candidate
        elif gap == best_gap and image_area > 0.5 * image_size * image_size * cols * rows:
            # Tie: prefer the larger grid only if the image has enough pixels for it.
            best = candidate
    return best
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Cut ``image`` into a grid of ``image_size`` x ``image_size`` tiles.

    Candidate grids are all ``(cols, rows)`` pairs with
    ``min_num <= cols * rows <= max_num``, sorted by tile count. The one picked by
    ``find_closest_aspect_ratio`` sets the layout. The image is resized to
    ``(image_size * cols, image_size * rows)`` and cropped into tiles in row-major
    order (left to right, then top to bottom). If ``use_thumbnail`` is true and
    more than one tile was produced, the whole original image resized to
    ``(image_size, image_size)`` is appended last.

    Returns:
        list of the cropped (and optional thumbnail) images.
    """
    width, height = image.size

    # Every grid layout whose tile count lies in [min_num, max_num].
    layouts = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
    layouts = sorted(layouts, key=lambda layout: layout[0] * layout[1])

    cols, rows = find_closest_aspect_ratio(width / height, layouts, width, height, image_size)

    grid_width = image_size * cols
    grid_height = image_size * rows
    tile_count = cols * rows
    resized = image.resize((grid_width, grid_height))

    tiles_per_row = grid_width // image_size
    tiles = []
    for index in range(tile_count):
        row, col = divmod(index, tiles_per_row)
        tiles.append(resized.crop((
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )))
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles
def load_image(image_file, input_size=448, max_num=12):
    """Load one image and turn it into a stacked tensor of normalized tiles.

    Args:
        image_file: one of
            - a filesystem path (``str`` or ``os.PathLike``; a plain ``str`` is
              always treated as a path, never as base64),
            - an already-loaded ``PIL.Image.Image``,
            - base64-encoded image data (e.g. ``bytes``) for anything else.
        input_size: side length in pixels of each square tile.
        max_num: maximum number of grid tiles passed to ``dynamic_preprocess``.

    Returns:
        torch.Tensor of shape ``(num_tiles, 3, input_size, input_size)``. It holds
        the grid tiles plus, when there is more than one grid tile, a thumbnail.
    """
    if isinstance(image_file, Image.Image):
        image = image_file.convert('RGB')
    elif isinstance(image_file, (str, os.PathLike)):
        # The context manager closes the file once the pixels have been converted.
        with Image.open(image_file) as img:
            image = img.convert('RGB')
    else:
        # Anything else is assumed to be base64-encoded image data.
        with Image.open(BytesIO(base64.b64decode(image_file))) as img:
            image = img.convert('RGB')

    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([transform(tile) for tile in tiles])
# Hugging Face Hub repository holding the fine-tuned checkpoint.
model_name = "lycaoduong/KLPintern-v2-1B"
# Hub access token taken from the LPInternVL21B environment variable (None if unset).
access_token = os.getenv("LPInternVL21B")

# bfloat16 weights, switched to inference mode with .eval(). trust_remote_code
# lets transformers execute the modeling/tokenizer code stored in the repo.
model = AutoModel.from_pretrained(
    model_name,
    token=access_token,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
).eval()
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=access_token,
    trust_remote_code=True,
    use_fast=False,
)
@spaces.GPU
def chat(message, history):
    """Gradio ChatInterface handler: run the model on the uploaded image and stream the answer.

    Args:
        message: multimodal message from ``gr.ChatInterface(multimodal=True)``.
            ``message["text"]`` is the prompt and ``message["files"]`` the uploaded
            files; only the first file is used.
        history: previous turns supplied by Gradio. Unused: each request is
            answered independently with ``history=None``.

    Yields:
        Progressively longer prefixes of the response, so the UI shows it typed out.
    """
    files = message["files"]
    if not files:
        # Text-only follow-up turns carry no image; indexing files[0] would raise IndexError.
        yield "Please attach a license-plate image together with your question."
        return

    model.to('cuda')
    pixel_values = load_image(files[0], max_num=12).to(torch.bfloat16).cuda()
    generation_config = dict(max_new_tokens=1024, do_sample=True, num_beams=3, repetition_penalty=2.5)
    question = '<image>\n' + message["text"]
    response, _ = model.chat(
        tokenizer, pixel_values, question, generation_config,
        history=None, return_history=True,
    )

    # model.chat returns the complete answer at once; replay it one character at
    # a time (slicing avoids rebuilding the string with += on every step).
    for end in range(1, len(response) + 1):
        sleep(0.01)
        yield response[:end]
# Custom CSS injected into the Gradio page.
# NOTE(review): `#component-3` and the `.svelte-*` class hashes are generated by
# Gradio's front-end build and may change between Gradio versions; confirm they
# match the installed version. The image rules match only the aria-label that
# names image/jpeg uploads.
CSS = """
/* Disabled mobile layout, kept for reference. It previously used '#' line
   prefixes as "comments", but '#' is not a CSS comment token.
@media only screen and (max-width: 600px){
    #component-3 {
        height: 90dvh !important;
        transform-origin: top; (ensure that the element expands from top to bottom)
        border-style: solid;
        overflow: hidden;
        flex-grow: 1;
        min-width: min(160px, 100%);
        border-width: var(--block-border-width);
    }
}
*/
#component-3 {
    height: 50dvh !important;
    transform-origin: top; /* Ensure that the element expands from top to bottom. */
    border-style: solid;
    overflow: hidden;
    flex-grow: 1;
    min-width: min(160px, 100%);
    border-width: var(--block-border-width);
}
/* Ensure that the image inside the button is displayed correctly for buttons with a specified aria-label. */
button.svelte-1lcyrx4[aria-label="user's message: a file of type image/jpeg, "] img.svelte-1pijsyv {
    width: 100%;
    object-fit: contain;
    height: 100%;
    border-radius: 13px; /* Add rounded corners to the image. */
    max-width: 50vw; /* Limit the image width. */
}
/* Set the height for the button and allow text selection only for buttons with a specified aria-label. */
button.svelte-1lcyrx4[aria-label="user's message: a file of type image/jpeg, "] {
    user-select: text;
    text-align: left;
    height: 300px;
}
/* Add border-radius and limit the width for images that do not belong to the avatar container. */
.message-wrap.svelte-1lcyrx4 > div.svelte-1lcyrx4 .svelte-1lcyrx4:not(.avatar-container) img {
    border-radius: 13px;
    max-width: 50vw;
}
.message-wrap.svelte-1lcyrx4 .message.svelte-1lcyrx4 img {
    margin: var(--size-2);
    max-height: 500px;
}
"""
# Chat UI: every submission is routed to chat(); the examples cover the bundled
# demo images ./demo/P0.jpg ... ./demo/P10.jpg, with P3 asking for JSON output.
demo = gr.ChatInterface(
    fn=chat,
    multimodal=True,
    title="❄️ InternVL2-1B fine-tuned for Korean License Plate recognition.❄️",
    description="Test Korean License Plate OCR VLMs in this demo.",
    examples=[
        {
            "text": (
                "Extract LPR information, return in JSON format"
                if k == 3
                else "Extract LPR information."
            ),
            "files": [f"./demo/P{k}.jpg"],
        }
        for k in range(11)
    ],
    css=CSS,
)
demo.queue().launch()