from typing import Any, Dict

from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from qwen_vl_utils import process_vision_info


class EndpointHandler:
    def __init__(self, path: str = "") -> None:
        # Load the Qwen2-VL-7B-Instruct model on the available devices.
        # torch_dtype="auto" picks the checkpoint's dtype and device_map="auto"
        # spreads the weights across the available hardware.
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            path, torch_dtype="auto", device_map="auto"
        )
        # Load the default processor, which handles chat-template formatting,
        # image resizing, and optional video preprocessing for Qwen2-VL.
        self.processor = AutoProcessor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Extract the conversation messages from the input data.
        messages = data.get("messages")
        if messages is None:
            raise ValueError("Input data must contain a 'messages' key with conversation data.")

        # Build the text prompt with the processor's chat template, which adds
        # the necessary system and generation prompts.
        text_prompt = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Collect any visual inputs (images and/or videos) from the messages.
        # The qwen_vl_utils helper accepts URLs, base64 strings, and local files.
        image_inputs, video_inputs = process_vision_info(messages)

        # Assemble the model inputs.
        inputs = self.processor(
            text=[text_prompt],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Move all tensors to the device where the model is loaded.
        inputs = inputs.to(self.model.device)

        # Generate the output; an optional "max_new_tokens" value can be
        # supplied in the input data (default: 128).
        max_new_tokens = data.get("max_new_tokens", 128)
        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

        # Strip the input prompt tokens from each generated sequence.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        # Decode the remaining token ids into the final text output.
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return {"output": output_text}
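

# A minimal local smoke test: a sketch, not part of the endpoint contract.
# It assumes the weights are reachable under the "Qwen/Qwen2-VL-7B-Instruct"
# model id and that the example image URL resolves; the payload shape mirrors
# the JSON body the handler expects from an Inference Endpoints request.
if __name__ == "__main__":
    handler = EndpointHandler("Qwen/Qwen2-VL-7B-Instruct")
    sample = {
        "messages": [
            {
                "role": "user",
                "content": [
                    # Hypothetical example image; replace with any URL,
                    # base64 string, or local file path.
                    {"type": "image", "image": "https://example.com/demo.jpeg"},
                    {"type": "text", "text": "Describe this image."},
                ],
            }
        ],
        "max_new_tokens": 64,
    }
    print(handler(sample)["output"])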