import gradio as gr
import cv2
import torch
from PIL import Image
from pathlib import Path
from threading import Thread
from transformers import (
    AutoProcessor,
    Gemma3ForConditionalGeneration,
    TextIteratorStreamer,
)
import spaces
import time

TITLE = " מודל מבוסס גמה 3 ליצירת שירים מטופשים בעברית "

DESCRIPTION = """ ניתן לבקש שיר על בסיס טקסט, תמונה ווידאו בכל פעם, יווצר שיר שונה, אז אם לא אהבתם, אפשר לנסות שוב עם אותו הפרומפט [המודל זמין להורדה](https://huggingface.co/Norod78/gemma-3_4b_hebrew-lyrics-finetune) המודל כּוּיַּיל ע״י [דורון אדלר](https://linktr.ee/Norod78) """

# model config
model_4b_name = "Norod78/gemma-3_4b_hebrew-lyrics-finetune"
model_4b = Gemma3ForConditionalGeneration.from_pretrained(
    model_4b_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()
processor_4b = AutoProcessor.from_pretrained(model_4b_name)

# Marker in the user's text at which the next uploaded file is inserted.
# NOTE(review): the reviewed source showed an empty-string separator here,
# which makes str.split raise ValueError for every non-empty message.
# "<image>" is the assumed intended placeholder -- confirm.
IMAGE_PLACEHOLDER = "<image>"

# File suffixes (lower-case) routed to the image and video branches.
IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")
VIDEO_SUFFIXES = (".mp4", ".mov")


# I will add timestamp later
def extract_video_frames(video_path, num_frames=8):
    """Sample up to ``num_frames`` evenly spaced frames from a video.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        num_frames: Number of sampling positions to try.

    Returns:
        A list of RGB ``PIL.Image`` objects. Positions that cannot be read
        are skipped, so the list may be shorter than ``num_frames`` (and is
        empty when ``num_frames`` is not positive).
    """
    if num_frames <= 0:
        # Avoids a ZeroDivisionError when computing the sampling step.
        return []

    frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # A step of at least 1 keeps short clips from re-reading frame 0.
        step = max(total_frames // num_frames, 1)
        for i in range(num_frames):
            cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
            ret, frame = cap.read()
            if ret:
                # OpenCV yields BGR; convert to RGB before wrapping in PIL.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame))
    finally:
        # Release the capture handle even if decoding raised.
        cap.release()
    return frames


def _resolve_file_path(file):
    """Return the path for an upload given as a string or a ``.name`` object."""
    return file if isinstance(file, str) else file.name


def format_message(content, files):
    """Build the multimodal content list for one user turn.

    The text is split on ``IMAGE_PLACEHOLDER``; after every part except the
    last, the next pending file (if any) is opened as an image and inserted.
    Remaining files are appended afterwards: images as a single entry,
    videos as one image entry per sampled frame. Files with any other
    suffix are ignored.

    Args:
        content: The user's text; may be empty or ``None``.
        files: Iterable of uploads, each a path string or an object with a
            ``.name`` attribute; may be ``None``. It is not modified.

    Returns:
        A list of ``{"type": "text", "text": ...}`` and
        ``{"type": "image", "image": ...}`` dicts.
    """
    message_content = []
    # A shared iterator lets the inline placeholders and the trailing loop
    # consume the files in order without mutating the caller's list.
    pending = iter(files or [])

    if content:
        parts = content.split(IMAGE_PLACEHOLDER)
        last_index = len(parts) - 1
        for i, part in enumerate(parts):
            stripped = part.strip()
            if stripped:
                message_content.append({"type": "text", "text": stripped})
            if i < last_index:
                inline_file = next(pending, None)
                if inline_file is not None:
                    img = Image.open(_resolve_file_path(inline_file))
                    message_content.append({"type": "image", "image": img})

    for file in pending:
        file_path = _resolve_file_path(file)
        suffix = Path(file_path).suffix.lower()
        if suffix in IMAGE_SUFFIXES:
            img = Image.open(file_path)
            message_content.append({"type": "image", "image": img})
        elif suffix in VIDEO_SUFFIXES:
            for frame in extract_video_frames(file_path):
                message_content.append({"type": "image", "image": frame})

    return message_content


def format_conversation_history(chat_history):
    """Convert role/content history items into chat-template messages.

    Consecutive user items are merged into one user message. String content
    becomes a text entry, list content is appended as-is, and anything else
    is stringified. Each assistant item becomes its own text message.
    Items with any other role are skipped.

    Args:
        chat_history: Iterable of dicts with ``"role"`` and ``"content"``.

    Returns:
        A list of ``{"role": ..., "content": [...]}`` dicts.
    """
    messages = []
    current_user_content = []

    for item in chat_history:
        role = item["role"]
        content = item["content"]

        if role == "user":
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            elif isinstance(content, list):
                current_user_content.extend(content)
            else:
                current_user_content.append(
                    {"type": "text", "text": str(content)}
                )
        elif role == "assistant":
            # Flush the accumulated user turn before the assistant reply.
            if current_user_content:
                messages.append(
                    {"role": "user", "content": current_user_content}
                )
                current_user_content = []
            messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": str(content)}],
                }
            )

    # A trailing user turn with no assistant reply yet.
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})

    return messages


@spaces.GPU(duration=120)
def generate_response(
    input_data,
    chat_history,
    max_new_tokens,
    system_prompt,
    temperature,
    top_p,
    top_k,
    repetition_penalty,
):
    """Stream a model reply for the new message plus the chat history.

    Args:
        input_data: Either a dict with ``"text"`` and optional ``"files"``,
            or any other value, which is stringified and used as the text.
        chat_history: Previous turns as role/content dicts.
        max_new_tokens: Upper bound on generated tokens.
        system_prompt: Optional system instruction; omitted when empty.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling cutoff.
        repetition_penalty: Repetition penalty passed to ``generate``.

    Yields:
        The reply text accumulated so far, once per streamed chunk.
    """
    if isinstance(input_data, dict) and "text" in input_data:
        text = input_data["text"]
        files = input_data.get("files", [])
    else:
        text = str(input_data)
        files = []

    new_message_content = format_message(text, files)
    new_message = {"role": "user", "content": new_message_content}

    system_message = (
        [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
        if system_prompt
        else []
    )
    processed_history = format_conversation_history(chat_history)
    messages = system_message + processed_history

    # Merge into a dangling user turn rather than emitting two in a row.
    if messages and messages[-1]["role"] == "user":
        messages[-1]["content"].extend(new_message["content"])
    else:
        messages.append(new_message)

    model = model_4b
    processor = processor_4b

    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(
        processor, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )

    # Generation runs in a worker thread so this generator can consume the
    # streamer and yield partial output while tokens are produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    outputs = []
    for chunk in streamer:
        outputs.append(chunk)
        yield "".join(outputs)


chat_interface = gr.ChatInterface(
    fn=generate_response,
    chatbot=gr.Chatbot(rtl=True, show_copy_button=True, type="messages"),
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
        gr.Textbox(
            label="System Prompt",
            value="אתה משורר ישראלי, כותב שירים בעברית",
            lines=4,
            placeholder="שנה את ההגדרות של המודל",
            text_align='right',
            rtl=True,
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.2),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.4),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=30),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1),
    ],
    examples=[
        [{"text": "כתוב לי בבקשה שיר המתאר את התמונה", "files": ["examples/image1.jpg"]}],
        [{"text": "תפוח אדמה עם חרדה חברתית"}],
    ],
    textbox=gr.MultimodalTextbox(
        rtl=True,
        label="קלט",
        file_types=["image", "video"],
        file_count="multiple",
        placeholder="בקשו שיר ו/או העלו תמונה",
    ),
    cache_examples=False,
    type="messages",
    fill_height=True,
    stop_btn="הפסק",
    css_paths=["style.css"],
    multimodal=True,
    title=TITLE,
    description=DESCRIPTION,
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    chat_interface.queue(max_size=20).launch(mcp_server=True)