Spaces:
Configuration error
Configuration error
File size: 2,662 Bytes
a0a6a64 e02e941 a539d3b a0a6a64 a539d3b a0a6a64 9c5835b e4974a1 9c5835b e4974a1 e02e941 9c5835b a539d3b e02e941 a539d3b 9c5835b a539d3b 9c5835b a539d3b a0a6a64 9cb2953 442299b 9cb2953 e02e941 a539d3b 9cb2953 3e0b719 a539d3b 9cb2953 442299b 9cb2953 e02e941 a8ec507 16b4096 a539d3b 5ff3449 a539d3b 9cb2953 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import os
import torch
import spaces
import gradio as gr
from PIL import Image
from transformers.utils import move_cache
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer
# Checkpoint to serve (Hugging Face Hub repo id):
# https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B
MODEL_PATH = "THUDM/cogvlm2-llama3-chat-19B"
# Alternative int4 checkpoint:
# https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B-int4
# NOTE(review): the int4 repo may need different from_pretrained arguments
# (and may not accept the `.to(DEVICE)` below) — verify before switching.
# MODEL_PATH = "THUDM/cogvlm2-llama3-chat-19B-int4"

### DOWNLOAD ###
# NOTE(review): huggingface_hub has already been imported above (directly and
# through transformers). As far as I know it reads HF_HUB_ENABLE_HF_TRANSFER
# at import time, so this assignment probably has no effect. To really use
# hf_transfer, set the variable before those imports and make sure the
# `hf_transfer` package is installed — confirm against the installed version.
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
# snapshot_download returns the local snapshot folder, so MODEL_PATH is
# rebound from a repo id to a filesystem path used by the loaders below.
MODEL_PATH = snapshot_download(MODEL_PATH)
# NOTE(review): move_cache is a transformers cache-layout migration helper;
# it may be unnecessary (or removed) in newer transformers releases — confirm.
move_cache()

# Query CUDA availability once and derive both device and dtype from it.
_HAS_CUDA = torch.cuda.is_available()
DEVICE = 'cuda' if _HAS_CUDA else 'cpu'
# bfloat16 on GPUs with compute capability >= 8, float16 everywhere else
# (including the CPU fallback).
# NOTE(review): float16 on CPU may be slow or unsupported for some ops — confirm.
TORCH_TYPE = (
    torch.bfloat16
    if _HAS_CUDA and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)

## TOKENIZER ##
# trust_remote_code lets transformers execute the custom code shipped in the repo.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,
)

## MODEL ##
# Loaded in TORCH_TYPE, moved to DEVICE, and switched to inference mode.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_TYPE,
    trust_remote_code=True,
).to(DEVICE).eval()

# Single-turn prompt wrapper; `{}` is replaced by the user's prompt text.
text_only_template = """USER: {} ASSISTANT:"""
@spaces.GPU
def generate_caption(image, prompt):
    """Run CogVLM2 on an optional image plus a text prompt and return its reply.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.
        prompt: Instruction text from the prompt textbox.

    Returns:
        The decoded generation, cut off at the first ``<|end_of_text|>`` marker.
    """
    print(DEVICE)
    # text_only_template is "USER: {} ASSISTANT:", which is byte-for-byte the
    # query the image path has always sent ("USER: %s ASSISTANT:" % prompt).
    # NOTE(review): build_conversation_input_ids(template_version='chat') may
    # apply its own question/answer template on top of this wrapper — confirm
    # against the model's remote code whether the extra wrapping is intended.
    query = text_only_template.format(prompt)

    if image is None:
        # No upload: previously image.convert() raised AttributeError here.
        # Build a language-only input instead of crashing.
        input_by_model = model.build_conversation_input_ids(
            tokenizer,
            query=query,
            history=[],
            template_version='chat',
        )
    else:
        image = image.convert('RGB')
        input_by_model = model.build_conversation_input_ids(
            tokenizer,
            query=query,
            history=[],
            images=[image],
            template_version='chat',
        )

    # Add a batch dimension of 1 and move everything to the model's device;
    # the image tensor must also match the model's dtype.
    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
        'images': [[input_by_model['images'][0].to(DEVICE).to(TORCH_TYPE)]] if image is not None else None,
    }
    gen_kwargs = {
        "max_new_tokens": 2048,
        # Hard-coded pad id; presumably a reserved special token of the
        # Llama-3 tokenizer — confirm against tokenizer config.
        "pad_token_id": 128002,
    }
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        # generate() returns prompt + continuation; keep only the new tokens.
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0])
        response = response.split("<|end_of_text|>")[0]
    print("\nCogVLM2:", response)
    return response
## make predictions via api ##
# The same endpoint can be called programmatically with the Gradio Python client:
# https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app
demo = gr.Interface(
    fn=generate_caption,
    inputs=[
        # type="pil" hands generate_caption a PIL image (or None if empty).
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Prompt", value="Describe the image in great detail"),
    ],
    outputs=gr.Textbox(label="Generated Caption"),
)

# Launch the interface.
# NOTE(review): share=True creates a public link when run locally; on Hugging
# Face Spaces it is presumably ignored — confirm with the Gradio docs.
demo.launch(share=True)