import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

# Initialize tokenizer and model
tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
# tokenizer = LlamaTokenizer.from_pretrained('vicuna-7b-v1.5')
model = AutoModelForCausalLM.from_pretrained(
    'THUDM/cogvlm-chat-hf',
    load_in_4bit=True,
    trust_remote_code=True,
    device_map="auto"
).eval()


def generate_description(image, query, top_p, top_k, output_length, temperature):
    # Resize the uploaded image (PIL format) before passing it to the model
    display_size = (224, 224)
    image = image.resize(display_size, Image.LANCZOS)

    # Build the conversation input
    inputs = model.build_conversation_input_ids(tokenizer, query=query, history=[], images=[image])

    # Prepare the inputs dictionary for model.generate()
    inputs = {
        'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
        'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
        'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
        'images': [[inputs['images'][0].to('cuda').to(torch.float16)]],
    }

    # Set the generation kwargs with user-defined values
    gen_kwargs = {
        "max_length": output_length,
        "do_sample": True,  # Enable sampling so top_p, top_k, and temperature take effect
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature
    }

    # Generate the description
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        description = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return description


with gr.Blocks() as app:
    gr.Markdown("# Visual Product DNA - Image to Attribute Extractor")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil", height=500)
            query_input = gr.Textbox(label="Enter your prompt", value="Capture all attributes as JSON", lines=4)
        with gr.Column():
            top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.1, label="Creativity (top_p)")
            top_k_slider = gr.Slider(minimum=0, maximum=100, step=1, value=100, label="Coherence (top_k)")
            output_length_slider = gr.Slider(minimum=1, maximum=4096, step=1, value=2048, label="Output Length")
            temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, step=0.01, value=0.1, label="Temperature")
            submit_button = gr.Button("Extract Attributes")
            description_output = gr.Textbox(label="Generated JSON", lines=12)

    submit_button.click(
        fn=generate_description,
        inputs=[image_input, query_input, top_p_slider, top_k_slider, output_length_slider, temperature_slider],
        outputs=description_output
    )

app.launch(share=True)
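
# Optional quick test of generate_description() without the Gradio UI,
# shown as a hedged sketch: "sample.jpg" is a hypothetical local image,
# and the prompt/parameter values simply mirror the UI defaults above.
# img = Image.open("sample.jpg").convert("RGB")
# print(generate_description(
#     img,
#     query="Capture all attributes as JSON",
#     top_p=0.1,
#     top_k=100,
#     output_length=2048,
#     temperature=0.1,
# ))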