Spaces:
Runtime error
Runtime error
import torch | |
import gradio as gr | |
import spaces | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
# トークナイザーのロード(GPUは不要なのでグローバルにロード) | |
tokenizer = AutoTokenizer.from_pretrained("llm-jp/llm-jp-3-8x1.8b-instruct3") | |
# グローバル変数としてモデルを定義(最初はNone) | |
model = None | |
# ZeroGPUで実行される関数(モデルロードに時間がかかる可能性があるため120秒に設定) | |
def generate_text(system_prompt, user_input, max_length=512, temperature=0.7, top_p=0.95): | |
# モデルのロード(初回実行時のみ) | |
global model | |
if model is None: | |
model = AutoModelForCausalLM.from_pretrained( | |
"llm-jp/llm-jp-3-8x1.8b-instruct3", | |
device_map="auto", | |
torch_dtype=torch.bfloat16 | |
) | |
# チャット形式の入力を作成 | |
chat = [ | |
{"role": "system", "content": system_prompt}, | |
{"role": "user", "content": user_input}, | |
] | |
# トークン化 | |
tokenized_input = tokenizer.apply_chat_template( | |
chat, | |
add_generation_prompt=True, | |
tokenize=True, | |
return_tensors="pt" | |
).to(model.device) | |
# 生成 | |
with torch.no_grad(): | |
output = model.generate( | |
tokenized_input, | |
max_new_tokens=max_length, | |
do_sample=True, | |
top_p=top_p, | |
temperature=temperature, | |
repetition_penalty=1.05, | |
)[0] | |
# デコード | |
generated_text = tokenizer.decode(output, skip_special_tokens=True) | |
return generated_text | |
# Gradioインターフェースの作成 | |
with gr.Blocks() as demo: | |
gr.Markdown("# LLM-JP-3-8x1.8b-instruct3 デモ") | |
gr.Markdown("国立情報学研究所が開発した日本語大規模言語モデル「LLM-JP-3」のデモです。ZeroGPUを使用しているため、初回実行時はGPUの割り当てに少し時間がかかることがあります。") | |
with gr.Row(): | |
with gr.Column(): | |
system_prompt = gr.Textbox( | |
label="システムプロンプト", | |
value="以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。", | |
lines=2 | |
) | |
user_input = gr.Textbox(label="ユーザー入力", lines=5, placeholder="質問や指示を入力してください...") | |
with gr.Row(): | |
max_length = gr.Slider(label="最大生成トークン数", minimum=10, maximum=1024, value=512, step=1) | |
temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=0.7, step=0.1) | |
top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.95, step=0.05) | |
submit_btn = gr.Button("生成") | |
with gr.Column(): | |
output = gr.Textbox(label="生成結果", lines=20) | |
submit_btn.click( | |
fn=generate_text, | |
inputs=[system_prompt, user_input, max_length, temperature, top_p], | |
outputs=output | |
) | |
gr.Examples( | |
examples=[ | |
["以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。", "自然言語処理とは何か"], | |
["以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。", "日本の四季について教えてください"], | |
["あなたは優秀な物語作家です。", "猫と犬が友達になる短い物語を書いてください。"] | |
], | |
inputs=[system_prompt, user_input] | |
) | |
# アプリの起動 | |
demo.launch() | |