Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,086 Bytes
efcf35e 47f64a9 efcf35e 698d75d 4d3de5e 76bc95f 4d3de5e 47f64a9 efcf35e 76bc95f efcf35e 76bc95f efcf35e 76bc95f 47f64a9 efcf35e 47f64a9 434184f 8bfc45f efcf35e 76bc95f 8bfc45f efcf35e 76bc95f 47f64a9 efcf35e 1219a4a efcf35e 4d3de5e efcf35e 4d3de5e efcf35e 1219a4a efcf35e 47f64a9 efcf35e 47f64a9 efcf35e 76bc95f efcf35e 47f64a9 efcf35e 47f64a9 efcf35e 47f64a9 1219a4a 47f64a9 efcf35e 4d3de5e efcf35e 76bc95f efcf35e 47f64a9 76bc95f 47f64a9 efcf35e 76bc95f efcf35e 76bc95f efcf35e 76bc95f 47f64a9 efcf35e 47f64a9 1219a4a efcf35e 47f64a9 efcf35e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import importlib.util
import os
import subprocess
import sys
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Best-effort install of flash-attn at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD asks the
# package not to compile CUDA kernels from source during install.
# The extra variable is merged into the current environment rather than replacing it:
# passing only {"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"} would drop PATH/HOME/virtualenv
# settings, so `pip` could resolve to the wrong interpreter or not be found at all.
# Using sys.executable guarantees the package lands in the interpreter running this app,
# and an argument list avoids going through a shell.
# check=False keeps this best-effort: a failed install must not stop the app from starting.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
# Markdown banner rendered at the top of the page.
DESCRIPTION = """\
# Gemma 3 270m IT 💎💬
Try this mini model by Google.
[🪪 **Model card**](https://huggingface.co/google/gemma-3-270m-it)
"""

# Upper bound of the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 2048
# Initial value of the "Max new tokens" slider.
DEFAULT_MAX_NEW_TOKENS = 1024
# Longest prompt (in tokens) sent to the model; longer prompts are left-trimmed.
# Overridable through the MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))
# Checkpoint served by this Space.
model_id = "google/gemma-3-270m-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# flash_attention_2 needs both a CUDA device and the flash_attn package. The startup pip
# install does not verify its result, so fall back to eager attention when the package is
# not importable instead of failing inside from_pretrained.
# invalidate_caches() makes sure a package installed earlier in this same process is seen.
importlib.invalidate_caches()
_has_flash_attn = importlib.util.find_spec("flash_attn") is not None
attn_impl = "flash_attention_2" if torch.cuda.is_available() and _has_flash_attn else "eager"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_impl,
)
# Override the checkpoint's sliding-window size.
# NOTE(review): presumably chosen to match the default MAX_INPUT_TOKEN_LENGTH (4096);
# confirm this is intended.
model.config.sliding_window = 4096
model.eval()  # inference only
def _build_conversation(message: str, chat_history: list, system_message: str) -> list[dict]:
    """Build the chat-template message list for one generation request.

    ``chat_history`` may hold ``(user, assistant)`` pairs (Gradio "tuples" history) or
    ``{"role": ..., "content": ...}`` dicts (Gradio "messages" history); both are
    normalized to role/content dicts. A falsy ``chat_history`` is treated as empty.
    The system turn is emitted only when ``system_message`` is non-empty, so an empty
    system prompt is never sent to the chat template.
    """
    conversation = []
    if system_message:
        conversation.append({"role": "system", "content": system_message})
    for turn in chat_history or []:
        if isinstance(turn, dict):
            conversation.append({"role": turn["role"], "content": turn["content"]})
        else:
            user, assistant = turn
            conversation.extend(
                [
                    {"role": "user", "content": user},
                    {"role": "assistant", "content": assistant},
                ]
            )
    conversation.append({"role": "user", "content": message})
    return conversation


@spaces.GPU(duration=90)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_message: str = "",
    max_new_tokens: int = 1024,
    temperature: float = 0.001,
    top_p: float = 1.0,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the prior chat history.

    Args:
        message: The new user message.
        chat_history: Previous turns, as (user, assistant) pairs or role/content dicts.
        system_message: Optional system prompt; omitted when empty.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; values <= 0 select greedy decoding.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        top_k: Number of highest-probability tokens kept (used only when sampling).
        repetition_penalty: Penalty on repeated tokens (1.0 disables it).

    Yields:
        The accumulated reply text after each streamed chunk.
    """
    conversation = _build_conversation(message, chat_history, system_message)

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens.
        # NOTE(review): left-truncation can also drop the leading BOS / turn markers that
        # the chat template emits.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # The consumer loop below waits at most 20 s for each new chunk.
    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        num_beams=1,
        # Coerced because Gradio sliders can deliver ints, while transformers' logits
        # processors may type-check for float.
        repetition_penalty=float(repetition_penalty),
        disable_compile=True,  # https://ai.google.dev/gemma/docs/core/huggingface_text_full_finetune#test_model_inference
    )

    temperature = float(temperature)
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=float(top_p),
            top_k=int(top_k),
        )
    else:
        # The temperature slider allows 0, which transformers rejects when do_sample=True.
        # That error would be raised only inside the worker thread, leaving the streamer
        # to time out. Treat temperature 0 as a request for greedy decoding instead.
        generate_kwargs["do_sample"] = False

    # model.generate runs in a background thread and pushes text into the streamer,
    # which this generator drains incrementally.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    thread.join()
# Prompts offered as clickable examples; each example fills only the user message.
_EXAMPLE_PROMPTS = [
    "Hi! How are you?",
    "Pros and cons of a long-term relationship. Bullet list with max 3 pros and 3 cons, concise.",
    "How many hours does it take a man to eat a helicopter?",
    "How do you open a JSON file in Python?",
    "Make a bullet list of pros and cons of living in San Francisco. Maximum 2 pros and 2 cons.",
    "Invent a short story with animals about the value of friendship.",
    "Can you briefly explain what the Python programming language is?",
    "Write a 100-word article on 'Benefits of Open-Source in AI Research'.",
]

# Extra controls, listed in the same order as generate()'s parameters after
# message/chat_history: system_message, max_new_tokens, temperature, top_p, top_k,
# repetition_penalty.
_ADDITIONAL_INPUTS = [
    gr.Textbox(label="System message", value="", render=False),
    gr.Slider(
        label="Max new tokens",
        minimum=1,
        maximum=MAX_MAX_NEW_TOKENS,
        step=1,
        value=DEFAULT_MAX_NEW_TOKENS,
    ),
    gr.Slider(
        label="Temperature",
        minimum=0,
        maximum=4.0,
        step=0.1,
        value=1.0,  # default from https://huggingface.co/docs/transformers/en/main_classes/text_generation
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        minimum=0.05,
        maximum=1.0,
        step=0.05,
        value=0.95,  # from https://huggingface.co/google/gemma-3-270m-it/blob/main/generation_config.json
    ),
    gr.Slider(
        label="Top-k",
        minimum=1,
        maximum=1000,
        step=1,
        value=64,  # from https://huggingface.co/google/gemma-3-270m-it/blob/main/generation_config.json
    ),
    gr.Slider(
        label="Repetition penalty",
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        value=1.0,  # default from https://huggingface.co/docs/transformers/en/main_classes/text_generation
    ),
]

# Chat UI wired to generate(); rendered later inside the page layout.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=_ADDITIONAL_INPUTS,
    examples=[[prompt] for prompt in _EXAMPLE_PROMPTS],
    cache_examples=False,
    stop_btn=None,
)
# Page layout: description banner, duplicate button, then the chat UI defined above.
with gr.Blocks(theme="soft", fill_height=True, css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        elem_id="duplicate-button",
        value="Duplicate Space for private use",
    )
    chat_interface.render()

if __name__ == "__main__":
    # Enable request queuing (at most 20 pending requests), then start the server.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()
|