#!/usr/bin/env python import os import re import tempfile from collections.abc import Iterator from threading import Thread import cv2 import gradio as gr import spaces import torch from loguru import logger from PIL import Image from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer # [PDF] PyPDF2 추가 import PyPDF2 model_id = os.getenv("MODEL_ID", "google/gemma-3-27b-it") processor = AutoProcessor.from_pretrained(model_id, padding_side="left") model = Gemma3ForConditionalGeneration.from_pretrained( model_id, device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="eager" ) MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5")) def count_files_in_new_message(paths: list[str]) -> tuple[int, int]: image_count = 0 video_count = 0 for path in paths: if path.endswith(".mp4"): video_count += 1 else: image_count += 1 return image_count, video_count def count_files_in_history(history: list[dict]) -> tuple[int, int]: image_count = 0 video_count = 0 for item in history: if item["role"] != "user" or isinstance(item["content"], str): continue if item["content"][0].endswith(".mp4"): video_count += 1 else: image_count += 1 return image_count, video_count def validate_media_constraints(message: dict, history: list[dict]) -> bool: """ 이미지/비디오 개수와 혼합 여부 등을 검사하는 함수. PDF는 검사 로직에서 제외하여 업로드만 허용. """ # [PDF] PDF 파일 제외 처리 pdf_files = [f for f in message["files"] if f.endswith(".pdf")] non_pdf_files = [f for f in message["files"] if not f.endswith(".pdf")] # 기존 로직은 non_pdf_files(= 이미지/비디오)에 대해서만 체크 new_image_count, new_video_count = count_files_in_new_message(non_pdf_files) history_image_count, history_video_count = count_files_in_history(history) image_count = history_image_count + new_image_count video_count = history_video_count + new_video_count if video_count > 1: gr.Warning("Only one video is supported.") return False if video_count == 1: if image_count > 0: gr.Warning("Mixing images and videos is not allowed.") return False if "" in message["text"]: gr.Warning("Using tags with video files is not supported.") return False # TODO: Add frame count validation for videos similar to image count limits # noqa: FIX002, TD002, TD003 if video_count == 0 and image_count > MAX_NUM_IMAGES: gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.") return False # [PDF] PDF 갯수 제한(필요하다면)도 추가 가능 # 일단 제한은 두지 않고 바로 True 반환 # 태그가 있을 경우, 이미지 개수와 매칭 검사 if "" in message["text"]: # new_image_count는 pdf 제외된 이미지 수 if message["text"].count("") != new_image_count: gr.Warning("The number of tags in the text does not match the number of images.") return False return True def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]: vidcap = cv2.VideoCapture(video_path) fps = vidcap.get(cv2.CAP_PROP_FPS) total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) frame_interval = int(fps / 3) frames = [] for i in range(0, total_frames, frame_interval): vidcap.set(cv2.CAP_PROP_POS_FRAMES, i) success, image = vidcap.read() if success: image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) pil_image = Image.fromarray(image) timestamp = round(i / fps, 2) frames.append((pil_image, timestamp)) vidcap.release() return frames def process_video(video_path: str) -> list[dict]: content = [] frames = downsample_video(video_path) for frame in frames: pil_image, timestamp = frame with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file: pil_image.save(temp_file.name) content.append({"type": "text", "text": f"Frame {timestamp}:"}) content.append({"type": "image", "url": temp_file.name}) logger.debug(f"{content=}") return content def process_interleaved_images(message: dict) -> list[dict]: logger.debug(f"{message['files']=}") parts = re.split(r"()", message["text"]) logger.debug(f"{parts=}") content = [] image_index = 0 for part in parts: logger.debug(f"{part=}") if part == "": content.append({"type": "image", "url": message["files"][image_index]}) logger.debug(f"file: {message['files'][image_index]}") image_index += 1 elif part.strip(): content.append({"type": "text", "text": part.strip()}) elif isinstance(part, str) and part != "": content.append({"type": "text", "text": part}) logger.debug(f"{content=}") return content # [PDF] PDF -> Markdown 변환 함수 추가 def pdf_to_markdown(pdf_path: str) -> str: """ PDF 파일을 텍스트로 추출 후, 간단한 Markdown 형태로 반환. """ text_chunks = [] with open(pdf_path, "rb") as f: reader = PyPDF2.PdfReader(f) for page_num, page in enumerate(reader.pages, start=1): page_text = page.extract_text() page_text = page_text.strip() if page_text else "" if page_text: # 페이지별로 간단한 헤더와 본문을 Markdown으로 합침 text_chunks.append(f"## Page {page_num}\n\n{page_text}\n") return "\n".join(text_chunks) def process_new_user_message(message: dict) -> list[dict]: """ 새 user message에서 text, 파일(이미지/비디오/PDF)을 처리. """ if not message["files"]: return [{"type": "text", "text": message["text"]}] # [PDF] PDF 파일 목록 pdf_files = [f for f in message["files"] if f.endswith(".pdf")] # 이미지·비디오 목록 other_files = [f for f in message["files"] if not f.endswith(".pdf")] # 일단 사용자의 text를 가장 먼저 넣는다 content_list = [{"type": "text", "text": message["text"]}] # PDF 변환 후 추가 for pdf_path in pdf_files: pdf_markdown = pdf_to_markdown(pdf_path) if pdf_markdown.strip(): content_list.append({"type": "text", "text": pdf_markdown}) else: content_list.append({"type": "text", "text": "(PDF에서 텍스트 추출 실패)"}) # 영상이 있는지 확인 video_files = [f for f in other_files if f.endswith(".mp4")] if video_files: # 비디오는 한 개만 처리한다는 전제 (validate_media_constraints에서 이미 검사) # 여러 개일 경우 첫 번째 것만 처리하거나, 경고 처리 content_list += process_video(video_files[0]) return content_list # interleaved 이미지 if "" in message["text"]: return process_interleaved_images(message) # 일반 이미지(여러 장) image_files = [f for f in other_files if not f.endswith(".mp4")] if image_files: content_list += [{"type": "image", "url": path} for path in image_files] return content_list def process_history(history: list[dict]) -> list[dict]: messages = [] current_user_content: list[dict] = [] for item in history: if item["role"] == "assistant": if current_user_content: messages.append({"role": "user", "content": current_user_content}) current_user_content = [] messages.append({"role": "assistant", "content": [{"type": "text", "text": item["content"]}]}) else: content = item["content"] if isinstance(content, str): current_user_content.append({"type": "text", "text": content}) else: current_user_content.append({"type": "image", "url": content[0]}) return messages @spaces.GPU(duration=120) def run(message: dict, history: list[dict], system_prompt: str = "", max_new_tokens: int = 512) -> Iterator[str]: if not validate_media_constraints(message, history): yield "" return messages = [] if system_prompt: messages.append({"role": "system", "content": [{"type": "text", "text": system_prompt}]}) messages.extend(process_history(history)) messages.append({"role": "user", "content": process_new_user_message(message)}) inputs = processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ).to(device=model.device, dtype=torch.bfloat16) streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True) generate_kwargs = dict( inputs, streamer=streamer, max_new_tokens=max_new_tokens, ) t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() output = "" for delta in streamer: output += delta yield output examples = [ [ { "text": "I need to be in Japan for 10 days, going to Tokyo, Kyoto and Osaka. Think about number of attractions in each of them and allocate number of days to each city. Make public transport recommendations.", "files": [], } ], [ { "text": "Write the matplotlib code to generate the same bar chart.", "files": ["assets/additional-examples/barchart.png"], } ], [ { "text": "What is odd about this video?", "files": ["assets/additional-examples/tmp.mp4"], } ], [ { "text": "I already have this supplement and I want to buy this one . Any warnings I should know about?", "files": ["assets/additional-examples/pill1.png", "assets/additional-examples/pill2.png"], } ], [ { "text": "Write a poem inspired by the visual elements of the images.", "files": ["assets/sample-images/06-1.png", "assets/sample-images/06-2.png"], } ], [ { "text": "Compose a short musical piece inspired by the visual elements of the images.", "files": [ "assets/sample-images/07-1.png", "assets/sample-images/07-2.png", "assets/sample-images/07-3.png", "assets/sample-images/07-4.png", ], } ], [ { "text": "Write a short story about what might have happened in this house.", "files": ["assets/sample-images/08.png"], } ], [ { "text": "Create a short story based on the sequence of images.", "files": [ "assets/sample-images/09-1.png", "assets/sample-images/09-2.png", "assets/sample-images/09-3.png", "assets/sample-images/09-4.png", "assets/sample-images/09-5.png", ], } ], [ { "text": "Describe the creatures that would live in this world.", "files": ["assets/sample-images/10.png"], } ], [ { "text": "Read text in the image.", "files": ["assets/additional-examples/1.png"], } ], [ { "text": "When is this ticket dated and how much did it cost?", "files": ["assets/additional-examples/2.png"], } ], [ { "text": "Read the text in the image into markdown.", "files": ["assets/additional-examples/3.png"], } ], [ { "text": "Evaluate this integral.", "files": ["assets/additional-examples/4.png"], } ], [ { "text": "caption this image", "files": ["assets/sample-images/01.png"], } ], [ { "text": "What's the sign says?", "files": ["assets/sample-images/02.png"], } ], [ { "text": "Compare and contrast the two images.", "files": ["assets/sample-images/03.png"], } ], [ { "text": "List all the objects in the image and their colors.", "files": ["assets/sample-images/04.png"], } ], [ { "text": "Describe the atmosphere of the scene.", "files": ["assets/sample-images/05.png"], } ], ] DESCRIPTION = """\ This is a demo of Gemma 3 27B it, a vision language model with outstanding performance on a wide range of tasks. You can upload images, interleaved images and videos. Note that video input only supports single-turn conversation and mp4 input. Also, PDF files are now supported: any uploaded PDF will be converted to Markdown text and passed into the conversation. """ # [PDF] .pdf 허용 demo = gr.ChatInterface( fn=run, type="messages", chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]), textbox=gr.MultimodalTextbox( file_types=["image", ".mp4", ".pdf"], # [PDF] 허용 file_count="multiple", autofocus=True ), multimodal=True, additional_inputs=[ gr.Textbox(label="System Prompt", value="You are a helpful assistant."), gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700), ], stop_btn=False, title="Gemma 3 27B IT", description=DESCRIPTION, examples=examples, run_examples_on_click=False, cache_examples=False, css_paths="style.css", delete_cache=(1800, 1800), ) if __name__ == "__main__": demo.launch()