"""
This script creates a Gradio demo with a Transformers backend for the
glm-4v-9b model, allowing users to interact with the model through a
Gradio web UI.

Usage:
- Run the script to start the Gradio server.
- Interact with the model via the web UI.

Requirements:
- Gradio package
- Type `pip install gradio` to install Gradio.
"""
import os | |
import torch | |
import gradio as gr | |
from threading import Thread | |
from transformers import ( | |
AutoTokenizer, | |
StoppingCriteria, | |
StoppingCriteriaList, | |
TextIteratorStreamer, AutoModel, BitsAndBytesConfig | |
) | |
from PIL import Image | |
import requests | |
from io import BytesIO | |
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4v-9b') | |
tokenizer = AutoTokenizer.from_pretrained( | |
MODEL_PATH, | |
trust_remote_code=True, | |
encode_special_tokens=True | |
) | |
model = AutoModel.from_pretrained( | |
MODEL_PATH, | |
trust_remote_code=True, | |
device_map="auto", | |
torch_dtype=torch.bfloat16 | |
).eval() | |
class StopOnTokens(StoppingCriteria): | |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: | |
stop_ids = model.config.eos_token_id | |
for stop_id in stop_ids: | |
if input_ids[0][-1] == stop_id: | |
return True | |
return False | |
def get_image(image_path=None, image_url=None): | |
if image_path: | |
return Image.open(image_path).convert("RGB") | |
elif image_url: | |
response = requests.get(image_url) | |
return Image.open(BytesIO(response.content)).convert("RGB") | |
return None | |
def chatbot(image_path=None, image_url=None, assistant_prompt=""): | |
image = get_image(image_path, image_url) | |
messages = [ | |
{"role": "assistant", "content": assistant_prompt}, | |
{"role": "user", "content": "", "image": image} | |
] | |
model_inputs = tokenizer.apply_chat_template( | |
messages, | |
add_generation_prompt=True, | |
tokenize=True, | |
return_tensors="pt", | |
return_dict=True | |
).to(next(model.parameters()).device) | |
streamer = TextIteratorStreamer( | |
tokenizer=tokenizer, | |
timeout=60, | |
skip_prompt=True, | |
skip_special_tokens=True | |
) | |
generate_kwargs = { | |
**model_inputs, | |
"streamer": streamer, | |
"max_new_tokens": 1024, | |
"do_sample": True, | |
"top_p": 0.8, | |
"temperature": 0.6, | |
"stopping_criteria": StoppingCriteriaList([StopOnTokens()]), | |
"repetition_penalty": 1.2, | |
"eos_token_id": [151329, 151336, 151338], | |
} | |
t = Thread(target=model.generate, kwargs=generate_kwargs) | |
t.start() | |
response = "" | |
for new_token in streamer: | |
if new_token: | |
response += new_token | |
return image, response.strip() | |
with gr.Blocks() as demo: | |
demo.title = "GLM-4V-9B Image Recognition Demo" | |
demo.description = """ | |
This demo uses the GLM-4V-9B model to got image infomation. | |
""" | |
with gr.Row(): | |
with gr.Column(): | |
image_path_input = gr.File(label="Upload Image (High-Priority)", type="filepath") | |
image_url_input = gr.Textbox(label="Image URL (Low-Priority)") | |
assistant_prompt_input = gr.Textbox(label="Assistant Prompt (You Can Change It)", value="这是什么?") | |
submit_button = gr.Button("Submit") | |
with gr.Column(): | |
chatbot_output = gr.Textbox(label="GLM-4V-9B Model Response") | |
image_output = gr.Image(label="Image Preview") | |
submit_button.click(chatbot, | |
inputs=[image_path_input, image_url_input, assistant_prompt_input], | |
outputs=[image_output, chatbot_output]) | |
demo.launch(server_name="127.0.0.1", server_port=8911, inbrowser=True, share=False) |