# Page banner captured together with this file from Hugging Face Spaces:
# "Configuration error" (shown twice). It is not part of the program.
import os | |
import torch | |
import spaces | |
import gradio as gr | |
from PIL import Image | |
from transformers.utils import move_cache | |
from huggingface_hub import snapshot_download | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
# Model repository: https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B
MODEL_PATH = "THUDM/cogvlm2-llama3-chat-19B"
# Alternative int4 checkpoint: https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B-int4
# MODEL_PATH = "THUDM/cogvlm2-llama3-chat-19B-int4"

### DOWNLOAD ###
# NOTE(review): huggingface_hub has already been imported by the time this
# variable is set; the library presumably reads the flag at import time, in
# which case this assignment has no effect -- confirm, and if fast transfer
# is wanted, set it before the imports (requires the hf_transfer package).
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
# From here on MODEL_PATH holds the value returned by snapshot_download
# (the local snapshot location) instead of the repository id.
MODEL_PATH = snapshot_download(MODEL_PATH)
move_cache()

# Use the GPU when one is visible. bfloat16 is selected only on CUDA devices
# whose compute-capability major version is >= 8; every other case
# (older GPUs and CPU) falls back to float16.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
TORCH_TYPE = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)

## TOKENIZER ##
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True
)

## MODEL ##
# trust_remote_code=True executes the modelling code shipped in the model
# repository. The model is moved to DEVICE and put in eval mode once, at
# import time.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_TYPE,
    trust_remote_code=True,
).to(DEVICE).eval()

# Prompt wrapper with the same shape as the query built in generate_caption.
# It is not referenced anywhere in the visible part of this file.
text_only_template = """USER: {} ASSISTANT:"""
def generate_caption(image, prompt):
    """Generate a CogVLM2 text response for an uploaded image and a prompt.

    Args:
        image: PIL image delivered by the Gradio image input, or ``None``
            when the request carries no image.
        prompt: Instruction text that is wrapped into the chat query.

    Returns:
        The decoded model output, cut at the first ``<|end_of_text|>`` marker.

    Raises:
        gr.Error: If ``image`` is ``None``. Previously this case failed with
            an ``AttributeError`` on ``image.convert`` before the later
            ``image is not None`` check could ever run.
    """
    # NOTE(review): `spaces` is imported at file level but never used; if this
    # app is meant to run on ZeroGPU hardware, this function presumably needs
    # the `@spaces.GPU` decorator -- confirm against the Space configuration.
    if image is None:
        raise gr.Error("Please upload an image before requesting a caption.")

    print(DEVICE)

    # Process the image and the prompt
    image = image.convert('RGB')
    query = f"USER: {prompt} ASSISTANT:"

    input_by_model = model.build_conversation_input_ids(
        tokenizer,
        query=query,
        history=[],
        images=[image],
        template_version='chat'
    )

    # unsqueeze(0) adds a leading dimension of size 1 to each tensor; the
    # image tensor is additionally cast to the model dtype.
    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
        'images': [[input_by_model['images'][0].to(DEVICE).to(TORCH_TYPE)]],
    }
    gen_kwargs = {
        "max_new_tokens": 2048,
        # NOTE(review): hard-coded pad token id; presumably specific to this
        # model's tokenizer -- confirm before switching checkpoints.
        "pad_token_id": 128002,
    }

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        # Drop the prompt positions so that only newly generated tokens remain.
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0])
        response = response.split("<|end_of_text|>")[0]

    print("\nCogVLM2:", response)
    return response
## make predictions via api ##
# https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app
# One image input (handed to the callback as a PIL image) and one prompt box
# pre-filled with a default instruction; the callback's return value is shown
# in a single text box.
demo = gr.Interface(
    fn=generate_caption,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Prompt", value="Describe the image in great detail"),
    ],
    outputs=gr.Textbox(label="Generated Caption"),
)

# Launch the interface
# NOTE(review): share=True asks Gradio for a public share link; presumably
# unnecessary when the app is hosted on Hugging Face Spaces -- confirm.
demo.launch(share=True)