John Doe commited on
Commit
ea7e714
·
1 Parent(s): 04bdba9

Zmaker.pyをアップロード

Browse files

前回のcommitでZmaker.pyを入れ忘れたので追加

Files changed (1) hide show
  1. Zmaker.py +140 -0
Zmaker.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, T5Tokenizer
3
+ import csv, re, mojimoji
4
+
5
+ class Zmaker:
6
+
7
+ #GPT2のモデル名
8
+ gpt_model_name = "rinna/japanese-gpt2-medium"
9
+
10
+ #文章の最大長
11
+ min_len, max_len = 1, 128
12
+
13
+ #予測時のパラメータ
14
+ top_k, top_p = 50, 0.95 #top-k検索の閾値
15
+ num_text = 1 #出力する文の数
16
+ temp = 1.0
17
+ repeat_ngram_size = 1
18
+
19
+ #推論にCPU利用を強制するか
20
+ use_cpu = True
21
+
22
+ def __init__(self, ft_path = None):
23
+ """コンストラクタ
24
+
25
+ コンストラクタ。モデルをファイルから読み込む場合と,
26
+ 新規作成する場合で動作を分ける.
27
+
28
+ Args:
29
+ ft_path : ファインチューニングされたモデルのパス.Noneを指定すると
30
+ train : 学習を行うか
31
+ Returns:
32
+ なし
33
+ """
34
+
35
+ #モデルの設定
36
+ self.__SetModel(ft_path)
37
+
38
+ #モデルの状態をCPUかGPUかで切り替える
39
+ if self.use_cpu: #CPUの利用を強制する場合の処理
40
+ device = torch.device('cpu')
41
+ else: #特に指定が無いなら,GPUがあるときはGPUを使い,CPUのみの場合はCPUを使う
42
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
43
+ self.model.to(device)
44
+
45
+
46
+
47
+ def __SetModel(self, ft_path = None):
48
+ """GPT2の設定
49
+
50
+ GPT2のTokenizerおよびモデルを設定する.
51
+ ユーザー定義後と顔文字も語彙として認識されるように設定する.
52
+
53
+ Args:
54
+ ft_path : ファインチューニング済みのモデルを読み込む
55
+ 何も指定しないとself.gpt_model_nameの事前学習モデルを
56
+ ネットからダウンロードする.
57
+ Returns:
58
+ なし
59
+ """
60
+ #GPT2のTokenizerのインスタンスを生成
61
+ self.tokenizer = T5Tokenizer.from_pretrained(self.gpt_model_name)
62
+ self.tokenizer.do_lower_case = True # due to some bug of tokenizer config loading
63
+
64
+ #モデルの読み込み
65
+ if ft_path is not None:
66
+ self.model = AutoModelForCausalLM.from_pretrained(
67
+ ft_path,
68
+ #torch_dtype = torch.bfloat16,
69
+ low_cpu_mem_usage = True
70
+ )
71
+ else:
72
+ print("fine-tuned model was not found")
73
+
74
+ #モデルをevalモードに
75
+ self.model.eval()
76
+
77
+ def __TextCleaning(self, texts):
78
+ """テキストの前処理をする
79
+
80
+ テキストの前処理を行う.具体的に行うこととしては...
81
+ ・全角/半角スペースの除去
82
+ ・半角数字/アルファベットの全角化
83
+ """
84
+ #半角スペース,タブ,改行改ページを削除
85
+ texts = [re.sub("[\u3000 \t \s \n]", "", t) for t in texts]
86
+
87
+ #半角/全角を変換
88
+ texts = [mojimoji.han_to_zen(t) for t in texts]
89
+ return texts
90
+
91
+
92
+ def GenLetter(self, prompt):
93
+ """怪文書の生成
94
+
95
+ GPT2で怪文書を生成する.
96
+ promptに続く文章を生成して出力する
97
+
98
+ Args:
99
+ prompt : 文章の先頭
100
+ Retunrs:
101
+ 生成された文章のリスト
102
+ """
103
+
104
+ #テキストをクリーニング
105
+ prompt_clean = [prompt]
106
+
107
+ #文章をtokenizerでエンコード
108
+ x = self.tokenizer.encode(
109
+ prompt_clean[0], return_tensors="pt",
110
+ add_special_tokens=False
111
+ )
112
+
113
+ #デバイスの選択
114
+ if self.use_cpu: #CPUの利用を強制する場合の処理
115
+ device = torch.device('cpu')
116
+ else: #特に指定が無いなら,GPUがあるときはGPUを使い,CPUのみの場合はCPUを使う
117
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
118
+ x = x.to(device)
119
+
120
+ #gpt2による推論
121
+ with torch.no_grad():
122
+ y = self.model.generate(
123
+ x, #入力
124
+ min_length=self.min_len, # 文章の最小長
125
+ max_length=self.max_len, # 文章の最大長
126
+ do_sample=True, # 次の単語を確率で選ぶ
127
+ top_k=self.top_k, # Top-Kサンプリング
128
+ top_p=self.top_p, # Top-pサンプリング
129
+ temperature=self.temp, # 確率分布の調整
130
+ no_repeat_ngram_size = self.repeat_ngram_size, #同じ単語を何回繰り返していいか
131
+ num_return_sequences=self.num_text, # 生成する文章の数
132
+ pad_token_id=self.tokenizer.pad_token_id, # パディングのトークンID
133
+ bos_token_id=self.tokenizer.bos_token_id, # テキスト先頭のトークンID
134
+ eos_token_id=self.tokenizer.eos_token_id, # テキスト終端のトークンID
135
+ early_stopping=True
136
+ )
137
+
138
+ # 特殊トークンをスキップして推論結果を文章にデコード
139
+ res = self.tokenizer.batch_decode(y, skip_special_tokens=True)
140
+ return res