username
initial commit
b782015
raw
history blame
3.76 kB
import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
# トークナイザーのロード(GPUは不要なのでグローバルにロード)
tokenizer = AutoTokenizer.from_pretrained("llm-jp/llm-jp-3-8x1.8b-instruct3")
# グローバル変数としてモデルを定義(最初はNone)
model = None
# ZeroGPUで実行される関数(モデルロードに時間がかかる可能性があるため120秒に設定)
@spaces.GPU(duration=120)
def generate_text(system_prompt, user_input, max_length=512, temperature=0.7, top_p=0.95):
# モデルのロード(初回実行時のみ)
global model
if model is None:
model = AutoModelForCausalLM.from_pretrained(
"llm-jp/llm-jp-3-8x1.8b-instruct3",
device_map="auto",
torch_dtype=torch.bfloat16
)
# チャット形式の入力を作成
chat = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_input},
]
# トークン化
tokenized_input = tokenizer.apply_chat_template(
chat,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt"
).to(model.device)
# 生成
with torch.no_grad():
output = model.generate(
tokenized_input,
max_new_tokens=max_length,
do_sample=True,
top_p=top_p,
temperature=temperature,
repetition_penalty=1.05,
)[0]
# デコード
generated_text = tokenizer.decode(output, skip_special_tokens=True)
return generated_text
# Gradioインターフェースの作成
with gr.Blocks() as demo:
gr.Markdown("# LLM-JP-3-8x1.8b-instruct3 デモ")
gr.Markdown("国立情報学研究所が開発した日本語大規模言語モデル「LLM-JP-3」のデモです。ZeroGPUを使用しているため、初回実行時はGPUの割り当てに少し時間がかかることがあります。")
with gr.Row():
with gr.Column():
system_prompt = gr.Textbox(
label="システムプロンプト",
value="以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。",
lines=2
)
user_input = gr.Textbox(label="ユーザー入力", lines=5, placeholder="質問や指示を入力してください...")
with gr.Row():
max_length = gr.Slider(label="最大生成トークン数", minimum=10, maximum=1024, value=512, step=1)
temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=0.7, step=0.1)
top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.95, step=0.05)
submit_btn = gr.Button("生成")
with gr.Column():
output = gr.Textbox(label="生成結果", lines=20)
submit_btn.click(
fn=generate_text,
inputs=[system_prompt, user_input, max_length, temperature, top_p],
outputs=output
)
gr.Examples(
examples=[
["以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。", "自然言語処理とは何か"],
["以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。", "日本の四季について教えてください"],
["あなたは優秀な物語作家です。", "猫と犬が友達になる短い物語を書いてください。"]
],
inputs=[system_prompt, user_input]
)
# アプリの起動
demo.launch()