File size: 2,116 Bytes
b233041
 
8b2cc38
b233041
 
4d7a2e1
 
 
 
 
 
 
 
 
 
 
 
b233041
 
05bafe4
 
b233041
 
 
99ac71e
 
b233041
e743472
 
 
b233041
99ac71e
 
 
 
05bafe4
99ac71e
5bc68e8
99ac71e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05bafe4
99ac71e
05bafe4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# りんなGPT-2-medium ファインチューニングやってみた

# %%time

# ファインチューニングの実行
# python ./transformers/examples/pytorch/language-modeling/run_clm.py \
#    --model_name_or_path=rinna/japanese-gpt2-medium \
#    --train_file=natsumesouseki.txt \
#   --validation_file=natsumesouseki.txt \
#    --do_train \
#    --do_eval \
#    --num_train_epochs=3 \
#    --save_steps=5000 \
#    --save_total_limit=3 \
#    --per_device_train_batch_size=1 \
#    --per_device_eval_batch_size=1 \
#    --output_dir=output/

from transformers import T5Tokenizer, AutoModelForCausalLM
import gradio as gr
import torch

# トークナイザーとモデルの準備
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-gpt2-medium")
# model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")
model = AutoModelForCausalLM.from_pretrained("output/")

# 平均/分散の値を正規化
model.eval()

# 推論の実行
#def Chat(prompt):
#    input = tokenizer.encode(prompt, return_tensors="pt")
#    output = model.generate(input, do_sample=True, max_length=100, num_return_sequences=5)
#    return tokenizer.batch_decode(output)
def Chat(prompt):
  num = 3
  input_ids = tokenizer.encode(prompt, return_tensors="pt",add_special_tokens=False)
  #with torch.no_grad():
  output = model.generate(
        input_ids,
        max_length=300, # 最長の文章長
        min_length=100, # 最短の文章長
        do_sample=True,
        top_k=500, # 上位{top_k}個の文章を保持
        top_p=0.95, # 上位{top_p}%の単語から選択する。例)上位95%の単語から選んでくる
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        #bad_word_ids=[[tokenizer.unk_token_id]],
        num_return_sequences=num # 生成する文章の数
    )
  decoded = tokenizer.decode(output.tolist()[0])
  return decoded

app = gr.Interface(fn=Chat, inputs=gr.Textbox(lines=3, placeholder="文章を入力してください"), outputs="text" , title="夏目漱石GPT")
app.launch()