File size: 2,662 Bytes
a0a6a64
e02e941
a539d3b
a0a6a64
a539d3b
 
 
 
 
a0a6a64
9c5835b
e4974a1
9c5835b
 
e4974a1
e02e941
9c5835b
 
a539d3b
 
 
e02e941
a539d3b
 
 
9c5835b
 
a539d3b
 
 
 
9c5835b
 
a539d3b
 
 
 
 
a0a6a64
 
9cb2953
442299b
9cb2953
 
e02e941
a539d3b
9cb2953
3e0b719
a539d3b
9cb2953
 
 
442299b
9cb2953
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e02e941
a8ec507
 
 
 
16b4096
a539d3b
 
5ff3449
a539d3b
 
 
9cb2953
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import torch
import spaces
import gradio as gr
from PIL import Image
from transformers.utils import move_cache
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer


# Hugging Face repo id of the default checkpoint:
# https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B
MODEL_PATH = "THUDM/cogvlm2-llama3-chat-19B"

# Alternative int4 checkpoint (swap in for the assignment above to use it):
# https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B-int4
# MODEL_PATH = "THUDM/cogvlm2-llama3-chat-19B-int4"


### DOWNLOAD ###
# NOTE(review): huggingface_hub is already imported at the top of the file when
# this flag is set; if the library reads HF_HUB_ENABLE_HF_TRANSFER at import
# time, this assignment comes too late to take effect -- confirm.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# MODEL_PATH is rebound from the repo id to whatever snapshot_download returns.
MODEL_PATH = snapshot_download(MODEL_PATH)
move_cache()

### DEVICE / DTYPE ###
if torch.cuda.is_available():
    DEVICE = "cuda"
    # bfloat16 only when the CUDA device capability major version is >= 8;
    # otherwise fall back to float16.
    TORCH_TYPE = (
        torch.bfloat16
        if torch.cuda.get_device_capability()[0] >= 8
        else torch.float16
    )
else:
    DEVICE = "cpu"
    TORCH_TYPE = torch.float16


## TOKENIZER ##
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

## MODEL ##
# Load the checkpoint with TORCH_TYPE as its dtype, then move it to DEVICE and
# switch it to eval mode.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
    torch_dtype=TORCH_TYPE,
)
model = model.to(DEVICE).eval()



# Prompt wrapper with a single `{}` slot. It appears to be unused by the rest
# of the visible file (generate_caption builds the same string inline).
text_only_template = "USER: {} ASSISTANT:"


@spaces.GPU
def generate_caption(image, prompt):
    """Run the model on ``prompt`` (optionally with ``image``) and return its reply.

    Args:
        image: PIL image coming from the Gradio image input, or ``None`` when
            the request carries no image. A non-``None`` image is converted to
            RGB before being handed to the model.
        prompt: Instruction text; it is wrapped as ``USER: <prompt> ASSISTANT:``
            before being passed to the conversation builder.

    Returns:
        The text decoded from the newly generated tokens only (the prompt
        tokens are sliced off), cut at the first ``<|end_of_text|>`` marker.
    """
    print(DEVICE)

    # Previously a missing image crashed on `image.convert(...)`, which made the
    # later `images=None` fallback unreachable. Decide once and branch on it.
    has_image = image is not None

    query = f"USER: {prompt} ASSISTANT:"

    build_kwargs = {
        'query': query,
        'history': [],
        'template_version': 'chat',
    }
    if has_image:
        build_kwargs['images'] = [image.convert('RGB')]
    # NOTE(review): the image-less call relies on the model's remote code
    # accepting a conversation without `images` -- confirm against the
    # CogVLM2 reference demo.
    input_by_model = model.build_conversation_input_ids(tokenizer, **build_kwargs)

    # Add a leading batch dimension to each tensor and move it to DEVICE.
    inputs = {
        key: input_by_model[key].unsqueeze(0).to(DEVICE)
        for key in ('input_ids', 'token_type_ids', 'attention_mask')
    }
    inputs['images'] = (
        [[input_by_model['images'][0].to(DEVICE).to(TORCH_TYPE)]]
        if has_image
        else None
    )

    gen_kwargs = {
        "max_new_tokens": 2048,
        # presumably the tokenizer's pad token id -- verify against the model card
        "pad_token_id": 128002,
    }
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # Drop the prompt tokens so only the continuation is decoded.
    new_tokens = outputs[:, inputs['input_ids'].shape[1]:]
    response = tokenizer.decode(new_tokens[0])
    response = response.split("<|end_of_text|>")[0]
    print("\nCogVLM2:", response)
    return response


## make predictions via api ##
# https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app

# Input widgets, listed in the positional order generate_caption(image, prompt)
# receives them, followed by the single text output.
image_input = gr.Image(type="pil", label="Upload Image")
prompt_input = gr.Textbox(label="Prompt", value="Describe the image in great detail")
caption_output = gr.Textbox(label="Generated Caption")

demo = gr.Interface(
    fn=generate_caption,
    inputs=[image_input, prompt_input],
    outputs=caption_output,
)

# Launch the interface
demo.launch(share=True)