"""
File: vlm.py
Description: Vision language model utility functions.
Author: Didier Guillevic
Date: 2025-03-16
"""

from collections.abc import Iterator
from threading import Thread

# NOTE(review): `spaces` is kept ahead of torch/transformers, matching the
# original import order — presumably required by the hosting runtime; confirm
# before reordering.
import spaces
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from transformers import TextIteratorStreamer

#
# Load the model: google/gemma-3-4b-it
#
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_id = "google/gemma-3-4b-it"
processor = AutoProcessor.from_pretrained(
    model_id, use_fast=True, padding_side="left")
model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16
).to(device).eval()

# Upper bound on generated tokens, shared by the streaming and non-streaming
# entry points so the two cannot drift apart.
_MAX_NEW_TOKENS = 2_048


#
# Build messages
#
def build_messages(message: dict, history: list[tuple]) -> list[dict]:
    """Build messages given message & history from a **multimodal** chat interface.

    History turns whose user part is a tuple are treated as image URLs/paths,
    and turns whose user part is a string are treated as text. User parts are
    accumulated until a turn with a non-empty bot response is seen, at which
    point one 'user' message and one 'assistant' message are emitted.

    Args:
        message: dictionary with keys: 'text', 'files'
        history: list of tuples with (message, response)

    Returns:
        list of messages (to be sent to the model)
    """
    # Get the user's text and list of images
    user_text = message.get("text", "")
    user_images = message.get("files", [])  # List of images

    # Build the message list including history
    messages = []
    pending_user_parts = []  # Combine images and text if found in same turn.
    for user_turn, bot_turn in history:
        if isinstance(user_turn, tuple):  # Image input
            pending_user_parts.extend(
                {"type": "image", "url": image_url} for image_url in user_turn
            )
        elif isinstance(user_turn, str):  # Text input
            pending_user_parts.append({"type": "text", "text": user_turn})

        # Only flush once there is a bot reply; otherwise keep accumulating
        # so that an image-only turn is merged with the text turn after it.
        if pending_user_parts and bot_turn:
            messages.append({'role': 'user', 'content': pending_user_parts})
            messages.append({
                'role': 'assistant',
                'content': [{"type": "text", "text": bot_turn}],
            })
            pending_user_parts = []  # reset the combined user input.

    # NOTE(review): any user parts still pending here (history ended without
    # a bot reply) are not added to `messages` — confirm this is intended.

    # Build the user message's content from the provided message
    user_content = []
    if user_text:
        user_content.append({"type": "text", "text": user_text})
    user_content.extend({"type": "image", "url": image} for image in user_images)

    messages.append({'role': 'user', 'content': user_content})

    return messages


def _prepare_inputs(messages: list[dict]):
    """Apply the chat template to `messages` and move the result to the model's device."""
    return processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device, dtype=torch.bfloat16)


#
# Streaming response
#
@spaces.GPU
@torch.inference_mode()
def stream_response(messages: list[dict]) -> Iterator[str]:
    """Stream the model's response to the chat interface.

    Generation runs in a background thread that feeds a text streamer; this
    generator consumes the streamer and yields the accumulated text so far
    after every new chunk.

    Args:
        messages: list of messages to send to the model

    Yields:
        The response text generated so far (cumulative, not the delta).
    """
    # Generate model's response
    inputs = _prepare_inputs(messages)

    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=_MAX_NEW_TOKENS,
        do_sample=False
    )

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield partial_message

    # The streamer is exhausted, so generation is done; reap the worker
    # thread instead of leaving it dangling.
    thread.join()


#
# Response (non-streaming)
#
@spaces.GPU
@torch.inference_mode()
def get_response(messages: list[dict]) -> str:
    """Get the model's response.

    Args:
        messages: list of messages to send to the model

    Returns:
        The decoded response text, with the prompt tokens and special tokens
        removed.
    """
    # Generate model's response
    inputs = _prepare_inputs(messages)

    # Length of the prompt, used below to slice off the echoed input tokens.
    input_len = inputs["input_ids"].shape[-1]

    # The decorator already enables inference mode; no nested context needed.
    generation = model.generate(
        **inputs, max_new_tokens=_MAX_NEW_TOKENS, do_sample=False)
    generation = generation[0][input_len:]

    return processor.decode(generation, skip_special_tokens=True)