# Hugging Face Space "gemma-3-270m" — app.py (author: TakiTakiTa, commit 9694f44, verified)
# app.py
import gradio as gr
import torch
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Any causal LM whose tokenizer ships a chat template works here; this Space uses
# the instruction-tuned Gemma 3 270M checkpoint.
MODEL_NAME = "google/gemma-3-270m-it"
# Load tokenizer + model once at import time (module-level, shared by all requests).
# torch_dtype="auto" picks the checkpoint's native dtype; device_map="auto" places
# the model on GPU when available, else CPU.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
)
def build_chat(system_message: str, history: list[tuple[str, str]], user_message: str):
    """Turn a Gradio (user, assistant) tuple history into chat-template messages.

    Empty/None turns are skipped; an optional system message leads the list and
    the new user message is appended last.
    """
    messages = [{"role": "system", "content": system_message}] if system_message else []
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": user_message})
    return messages
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream an assistant reply for *message*, yielding the growing text.

    Args:
        message: Latest user message.
        history: Gradio-style list of (user, assistant) tuples.
        system_message: Optional system prompt prepended to the conversation.
        max_tokens: Cap on newly generated tokens.
        temperature: Sampling temperature (> 0; UI slider enforces >= 0.1).
        top_p: Nucleus-sampling probability mass.

    Yields:
        The accumulated response string after each streamed chunk.
    """
    # 1) Build chat messages and tokenize with the model's chat template.
    # return_dict=True makes apply_chat_template return the full encoding
    # (input_ids AND attention_mask); without it generate() gets no mask and,
    # since pad_token_id == eos_token_id here, has to guess padding (and warns).
    messages = build_chat(system_message, history, message)
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    # 2) Stream generation: generate() runs in a background thread and pushes
    # decoded text into the streamer, which we iterate here.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,          # don't echo the prompt back
        skip_special_tokens=True,  # drop eos/bos markers from the output
    )
    gen_kwargs = dict(
        **inputs,  # input_ids + attention_mask
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,  # Gemma tokenizer has no pad token
        streamer=streamer,
    )
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        yield response
    thread.join()  # make sure generate() has fully finished before returning
# Gradio UI: chat widget plus the standard sampling controls, each bound as an
# extra input to respond() in the same order as its parameters.
_system_box = gr.Textbox(value="You are a friendly Chatbot.", label="System message")
_max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens")
_temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
_top_p = gr.Slider(
    minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
)

demo = gr.ChatInterface(
    respond,
    additional_inputs=[_system_box, _max_tokens, _temperature, _top_p],
)

if __name__ == "__main__":
    demo.launch()