import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer
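# Assumed environment: a CUDA GPU plus the bitsandbytes and accelerate
# packages, which transformers requires for load_in_4bit and device_map="auto".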

# Initialize the tokenizer and model. CogVLM reuses the Vicuna-7B tokenizer;
# the vision-language weights come from THUDM/cogvlm-chat-hf.
tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
model = AutoModelForCausalLM.from_pretrained(
    'THUDM/cogvlm-chat-hf',
    load_in_4bit=True,        # 4-bit weight quantization via bitsandbytes
    trust_remote_code=True,   # CogVLM ships custom modeling code
    device_map="auto"         # let accelerate place layers on the GPU(s)
).eval()

def generate_description(image, query, top_p, top_k, output_length, temperature):
    # Normalize the uploaded PIL image: force RGB (PNG uploads can be RGBA)
    # and resize to a fixed edge length before preprocessing
    image = image.convert('RGB').resize((224, 224), Image.LANCZOS)

    # Build the conversation input
    inputs = model.build_conversation_input_ids(tokenizer, query=query, history=[], images=[image])

    # Batch the tensors (unsqueeze adds a batch dimension of 1) and move them
    # to the GPU; the image tensor is cast to float16 to match the model's
    # compute dtype under 4-bit quantization
    inputs = {
        'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
        'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
        'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
        'images': [[inputs['images'][0].to('cuda').to(torch.float16)]],
    }

    # Generation settings from the UI. Note that max_length caps the total
    # sequence (image tokens + prompt + completion), not just the number of
    # newly generated tokens
    gen_kwargs = {
        "max_length": output_length,
        "do_sample": True,  # sampling must be on for top_p, top_k, and temperature to take effect
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature
    }
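    # If stable, reproducible JSON matters more than variety, one could instead
    # use greedy decoding, e.g. {"max_length": output_length, "do_sample": False};
    # the sampling sliders above would then have no effect.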

    # Generate, then decode only the newly generated tokens so the echoed
    # prompt does not appear in the output
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        description = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return description
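
# Illustrative sanity check (hypothetical file path), handy for exercising the
# pipeline without the UI — uncomment and adjust the path to use:
# img = Image.open("sample.jpg").convert("RGB")
# print(generate_description(img, "Capture all attributes as JSON",
#                            top_p=0.1, top_k=100, output_length=2048, temperature=0.1))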

with gr.Blocks() as app:
    gr.Markdown("# Visual Product DNA - Image to Attribute Extractor")
    
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil", height=500)
            query_input = gr.Textbox(label="Enter your prompt", value="Capture all attributes as JSON", lines=4)
            
        with gr.Column():
            top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.1, label="Creativity (top_p)")
            top_k_slider = gr.Slider(minimum=0, maximum=100, step=1, value=100, label="Coherence (top_k)")
            output_length_slider = gr.Slider(minimum=1, maximum=4096, step=1, value=2048, label="Output Length")
            temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, step=0.01, value=0.1, label="Temperature")
            submit_button = gr.Button("Extract Attributes")
            description_output = gr.Textbox(label="Generated JSON", lines=12)  
            
    submit_button.click(
        fn=generate_description, 
        inputs=[image_input, query_input, top_p_slider, top_k_slider, output_length_slider, temperature_slider], 
        outputs=description_output
    )
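
    # For long-running generations, enabling Gradio's request queue can help
    # avoid client timeouts, e.g. calling app.queue() before launch (a standard
    # Blocks API); omitted here to keep the demo minimal.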

app.launch(share=True)  # share=True exposes a temporary public URL