# Imports
import gradio as gr
import spaces
import torch
import os
import math
import gc
import librosa

from PIL import Image, ImageSequence
from decord import VideoReader, cpu
from transformers import AutoModel, AutoTokenizer, AutoProcessor

# Variables
DEVICE = "auto"
if DEVICE == "auto":
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[SYSTEM] | Using {DEVICE} type compute device.")

DEFAULT_INPUT = "Describe in one short sentence."
MAX_FRAMES = 64
AUDIO_SR = 16000

model_name = "openbmb/MiniCPM-o-2_6"

repo = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    init_vision=True,
    init_audio=True,
    init_tts=False,
).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

css = '''
.gradio-container { max-width: 560px !important }
h1 { text-align: center }
footer { visibility: hidden }
'''

instruction = "You will analyze image, GIF, video, and audio input, then use as many keywords as possible to describe the given content and make as many guesses as you can about what it could be."

filetypes = {
    "Image": {
        "extensions": [".jpg", ".jpeg", ".png", ".bmp"],
        "instruction": "Analyze the '█' image.",
        "function": "build_image",
    },
    "GIF": {
        "extensions": [".gif"],
        "instruction": "Analyze the '█' GIF.",
        "function": "build_gif",
    },
    "Video": {
        "extensions": [".mp4", ".mov", ".avi", ".mkv"],
        "instruction": "Analyze the '█' video including the audio associated with the video.",
        "function": "build_video",
    },
    "Audio": {
        "extensions": [".wav", ".mp3", ".flac", ".aac"],
        "instruction": "Analyze the '█' audio.",
        "function": "build_audio",
    },
}

# Functions
def uniform_sample(sequence, n):
    # Evenly subsample the sequence down to at most n items.
    return sequence[::max(len(sequence) // n, 1)][:n]

def build_image(path):
    return [Image.open(path).convert("RGB")]

def build_gif(path):
    frames = [frame.copy().convert("RGB") for frame in ImageSequence.Iterator(Image.open(path))]
    return uniform_sample(frames, MAX_FRAMES)

def build_video(path):
    vr = VideoReader(path, ctx=cpu(0))
    idx = uniform_sample(range(len(vr)), MAX_FRAMES)
    frames = [Image.fromarray(f.astype("uint8")) for f in vr.get_batch(list(idx)).asnumpy()]
    audio = build_audio(path)[0]
    # Interleave each sampled frame with one second of audio per unit for omni input.
    units = []
    for i, frame in enumerate(frames):
        chunk = audio[i * AUDIO_SR:(i + 1) * AUDIO_SR]
        if not chunk.size:
            break
        units.extend(["", frame, chunk])
    return units

def build_audio(path):
    audio, _ = librosa.load(path, sr=AUDIO_SR, mono=True)
    return [audio]

@spaces.GPU(duration=30)
def generate(filepath, input=DEFAULT_INPUT, sampling=False, temperature=0.7, top_p=0.8, top_k=100, repetition_penalty=1.05, max_tokens=512):
    if not filepath:
        return "No file provided."
    if not input:
        return "No input provided."

    extension = os.path.splitext(filepath)[1].lower()
    filetype = next((k for k, v in filetypes.items() if extension in v["extensions"]), None)
    if not filetype:
        return "Unsupported file type."

    filetype_data = filetypes[filetype]
    input_prefix = filetype_data["instruction"].replace("█", os.path.basename(filepath))
    content = globals()[filetype_data["function"]](filepath) + [f"{instruction}\n{input_prefix}\n{input}"]

    messages = [{"role": "user", "content": content}]
    print(messages)

    output = repo.chat(
        msgs=messages,
        tokenizer=tokenizer,
        sampling=sampling,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        max_new_tokens=max_tokens,
        omni_input=True,
        use_image_id=False,
        max_slice_nums=9,
    )

    torch.cuda.empty_cache()
    gc.collect()

    return output

def cloud():
    print("[CLOUD] | Space maintained.")

# Initialize
with gr.Blocks(css=css) as main:
    with gr.Column():
        file = gr.File(label="File", file_types=["image", "video", "audio"], type="filepath")
        input = gr.Textbox(lines=1, value=DEFAULT_INPUT, label="Input")
        sampling = gr.Checkbox(value=True, label="Sampling")
        temperature = gr.Slider(minimum=0, maximum=2, step=0.01, value=0.7, label="Temperature")
        top_p = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.8, label="Top P")
        top_k = gr.Slider(minimum=0, maximum=1000, step=1, value=50, label="Top K")
        repetition_penalty = gr.Slider(minimum=0, maximum=2, step=0.01, value=1.05, label="Repetition Penalty")
        max_tokens = gr.Slider(minimum=1, maximum=4096, step=1, value=512, label="Max Tokens")
        submit = gr.Button("▶")
        maintain = gr.Button("☁️")

    with gr.Column():
        output = gr.Textbox(lines=1, value="", label="Output")

    submit.click(fn=generate, inputs=[file, input, sampling, temperature, top_p, top_k, repetition_penalty, max_tokens], outputs=[output], queue=False)
    maintain.click(cloud, inputs=[], outputs=[], queue=False)

main.launch(show_api=True)