Spaces:
Runtime error
Runtime error
File size: 5,018 Bytes
ea7e714 61938d3 ea7e714 61938d3 ea7e714 93770f8 ea7e714 61938d3 ea7e714 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import torch
from transformers import AutoModelForCausalLM, T5Tokenizer
import csv, re, mojimoji
class Zmaker:
#GPT2のモデル名
gpt_model_name = "rinna/japanese-gpt2-medium"
#文章の最大長
min_len, max_len = 1, 128
#予測時のパラメータ
top_k, top_p = 40, 0.95 #top-k検索の閾値
num_text = 1 #出力する文の数
temp = 0.1
repeat_ngram_size = 1
#推論にCPU利用を強制するか
use_cpu = True
def __init__(self, ft_path = None):
"""コンストラクタ
コンストラクタ。モデルをファイルから読み込む場合と,
新規作成する場合で動作を分ける.
Args:
ft_path : ファインチューニングされたモデルのパス.
Returns:
なし
"""
#モデルの設定
self.__SetModel(ft_path)
#モデルの状態をCPUかGPUかで切り替える
if self.use_cpu: #CPUの利用を強制する場合の処理
device = torch.device('cpu')
else: #特に指定が無いなら,GPUがあるときはGPUを使い,CPUのみの場合はCPUを使う
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self.model.to(device)
def __SetModel(self, ft_path = None):
"""GPT2の設定
GPT2のTokenizerおよびモデルを設定する.
ユーザー定義後と顔文字も語彙として認識されるように設定する.
Args:
ft_path : ファインチューニング済みのモデルを読み込む
何も指定しないとself.gpt_model_nameの事前学習モデルを
ネットからダウンロードする.
Returns:
なし
"""
#GPT2のTokenizerのインスタンスを生成
self.tokenizer = T5Tokenizer.from_pretrained(self.gpt_model_name)
self.tokenizer.do_lower_case = True # due to some bug of tokenizer config loading
#モデルの読み込み
if ft_path is not None:
self.model = AutoModelForCausalLM.from_pretrained(
ft_path, #torch_dtype = torch.bfloat16
)
else:
print("fine-tuned model was not found")
#モデルをevalモードに
self.model.eval()
def __TextCleaning(self, texts):
"""テキストの前処理をする
テキストの前処理を行う.具体的に行うこととしては...
・全角/半角スペースの除去
・半角数字/アルファベットの全角化
"""
#半角スペース,タブ,改行改ページを削除
texts = [re.sub("[\u3000 \t \s \n]", "", t) for t in texts]
#半角/全角を変換
texts = [mojimoji.han_to_zen(t) for t in texts]
return texts
def GenLetter(self, prompt):
"""怪文書の生成
GPT2で怪文書を生成する.
promptに続く文章を生成して出力する
Args:
prompt : 文章の先頭
Retunrs:
生成された文章のリスト
"""
#テキストをクリーニング
prompt_clean = [prompt]
#文章をtokenizerでエンコード
x = self.tokenizer.encode(
prompt_clean[0], return_tensors="pt",
add_special_tokens=False
)
#デバイスの選択
if self.use_cpu: #CPUの利用を強制する場合の処理
device = torch.device('cpu')
else: #特に指定が無いなら,GPUがあるときはGPUを使い,CPUのみの場合はCPUを使う
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
x = x.to(device)
#gpt2による推論
with torch.no_grad():
y = self.model.generate(
x, #入力
min_length=self.min_len, # 文章の最小長
max_length=self.max_len, # 文章の最大長
do_sample=True, # 次の単語を確率で選ぶ
top_k=self.top_k, # Top-Kサンプリング
top_p=self.top_p, # Top-pサンプリング
temperature=self.temp, # 確率分布の調整
no_repeat_ngram_size = self.repeat_ngram_size, #同じ単語を何回繰り返していいか
num_return_sequences=self.num_text, # 生成する文章の数
pad_token_id=self.tokenizer.pad_token_id, # パディングのトークンID
bos_token_id=self.tokenizer.bos_token_id, # テキスト先頭のトークンID
eos_token_id=self.tokenizer.eos_token_id, # テキスト終端のトークンID
early_stopping=True
)
# 特殊トークンをスキップして推論結果を文章にデコード
res = self.tokenizer.batch_decode(y, skip_special_tokens=True)
return res |