Spaces:
Runtime error
Runtime error
File size: 2,116 Bytes
b233041 8b2cc38 b233041 4d7a2e1 b233041 05bafe4 b233041 99ac71e b233041 e743472 b233041 99ac71e 05bafe4 99ac71e 5bc68e8 99ac71e 05bafe4 99ac71e 05bafe4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
# りんなGPT-2-medium ファインチューニングやってみた
# %%time
# ファインチューニングの実行
# python ./transformers/examples/pytorch/language-modeling/run_clm.py \
# --model_name_or_path=rinna/japanese-gpt2-medium \
# --train_file=natsumesouseki.txt \
# --validation_file=natsumesouseki.txt \
# --do_train \
# --do_eval \
# --num_train_epochs=3 \
# --save_steps=5000 \
# --save_total_limit=3 \
# --per_device_train_batch_size=1 \
# --per_device_eval_batch_size=1 \
# --output_dir=output/
from transformers import T5Tokenizer, AutoModelForCausalLM
import gradio as gr
import torch
# トークナイザーとモデルの準備
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
# model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
model = AutoModelForCausalLM.from_pretrained("output/")
# 平均/分散の値を正規化
model.eval()
# 推論の実行
#def Chat(prompt):
# input = tokenizer.encode(prompt, return_tensors="pt")
# output = model.generate(input, do_sample=True, max_length=100, num_return_sequences=5)
# return tokenizer.batch_decode(output)
def Chat(prompt):
num = 3
input_ids = tokenizer.encode(prompt, return_tensors="pt",add_special_tokens=False)
#with torch.no_grad():
output = model.generate(
input_ids,
max_length=300, # 最長の文章長
min_length=100, # 最短の文章長
do_sample=True,
top_k=500, # 上位{top_k}個の文章を保持
top_p=0.95, # 上位{top_p}%の単語から選択する。例)上位95%の単語から選んでくる
pad_token_id=tokenizer.pad_token_id,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
#bad_word_ids=[[tokenizer.unk_token_id]],
num_return_sequences=num # 生成する文章の数
)
decoded = tokenizer.decode(output.tolist()[0])
return decoded
app = gr.Interface(fn=Chat, inputs=gr.Textbox(lines=3, placeholder="文章を入力してください"), outputs="text" , title="夏目漱石GPT")
app.launch() |