bankholdup committed
Commit 8ff16c1
1 Parent(s): 54bec18

Create app.py

Files changed (1)
app.py +251 -0
app.py ADDED
@@ -0,0 +1,251 @@
+ import pystuck
+ pystuck.run_server()  # pystuck opens a remote-debugging server for inspecting hangs
+
+ import os
+
+ import argparse
+ import logging
+
+ import numpy as np
+ import torch
+ import datetime
+ import gradio as gr
+
+ from transformers import (
+     CTRLLMHeadModel,
+     CTRLTokenizer,
+     GPT2LMHeadModel,
+     GPT2Tokenizer,
+     OpenAIGPTLMHeadModel,
+     OpenAIGPTTokenizer,
+     TransfoXLLMHeadModel,
+     TransfoXLTokenizer,
+     XLMTokenizer,
+     XLMWithLMHeadModel,
+     XLNetLMHeadModel,
+     XLNetTokenizer,
+ )
+
+
+ logging.basicConfig(
+     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,
+ )
+ logger = logging.getLogger(__name__)
+
+ MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop
+
+ MODEL_CLASSES = {
+     "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
+     "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
+     "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
+     "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
+     "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
+     "xlm": (XLMWithLMHeadModel, XLMTokenizer),
+ }
+
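+ # PADDING_TEXT is used by prepare_xlnet_input/prepare_transfoxl_input below
+ # but is never defined in this commit; upstream transformers' run_generation.py
+ # defines it as a long English priming passage. An empty default is assumed
+ # here so the script can run unmodified.
+ PADDING_TEXT = ""
+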
+ def set_seed(args):
+     # Draw a fresh random seed on every call so repeated prompts yield
+     # different songs; note that the --seed CLI argument is effectively ignored.
+     rd = np.random.randint(100000)
+     print('seed =', rd)
+     np.random.seed(rd)
+     torch.manual_seed(rd)
+     if args.n_gpu > 0:
+         torch.cuda.manual_seed_all(rd)
+
+ #
+ # Functions to prepare models' input
+ #
+
+
+ def prepare_ctrl_input(args, _, tokenizer, prompt_text):
+     if args.temperature > 0.7:
+         logger.info("CTRL typically works better with lower temperatures (and lower top_k).")
+
+     encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
+     if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
+         logger.info("WARNING! You are not starting your generation from a control code, so you won't get good results")
+     return prompt_text
+
+
+ def prepare_xlm_input(args, model, tokenizer, prompt_text):
+     # kwargs = {"language": None, "mask_token_id": None}
+
+     # Set the language
+     use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
+     if hasattr(model.config, "lang2id") and use_lang_emb:
+         available_languages = model.config.lang2id.keys()
+         if args.xlm_language in available_languages:
+             language = args.xlm_language
+         else:
+             language = None
+             while language not in available_languages:
+                 language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")
+
+         model.config.lang_id = model.config.lang2id[language]
+         # kwargs["language"] = tokenizer.lang2id[language]
+
+     # TODO: fix mask_token_id setup when configurations are synchronized between models and tokenizers
+     # XLM masked-language modeling (MLM) models need the mask token:
+     # is_xlm_mlm = "mlm" in args.model_name_or_path
+     # if is_xlm_mlm:
+     #     kwargs["mask_token_id"] = tokenizer.mask_token_id
+
+     return prompt_text
+
+
+ def prepare_xlnet_input(args, _, tokenizer, prompt_text):
+     prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
+     return prompt_text
+
+
+ def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
+     prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
+     return prompt_text
+
+
+ PREPROCESSING_FUNCTIONS = {
+     "ctrl": prepare_ctrl_input,
+     "xlm": prepare_xlm_input,
+     "xlnet": prepare_xlnet_input,
+     "transfo-xl": prepare_transfoxl_input,
+ }
+
+
+ def adjust_length_to_model(length, max_sequence_length):
+     if length < 0 and max_sequence_length > 0:
+         length = max_sequence_length
+     elif 0 < max_sequence_length < length:
+         length = max_sequence_length  # No generation bigger than model size
+     elif length < 0:
+         length = MAX_LENGTH  # avoid infinite loop
+     return length
+
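+ # For illustration (not in the original commit): adjust_length_to_model(-1, 1024)
+ # and adjust_length_to_model(2000, 1024) both return 1024, capping generation at
+ # the model's context size, while adjust_length_to_model(-1, 0) falls back to
+ # MAX_LENGTH and adjust_length_to_model(20, 1024) returns 20 unchanged.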
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--model_type",
+         default=None,
+         type=str,
+         required=True,
+         help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
+     )
+     parser.add_argument(
+         "--model_name_or_path",
+         default=None,
+         type=str,
+         required=True,
+         help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
+     )
+
+     parser.add_argument("--prompt", type=str, default="")
+     parser.add_argument("--length", type=int, default=20)
+     parser.add_argument("--stop_token", type=str, default="</s>", help="Token at which lyrics generation is stopped")
+
+     parser.add_argument(
+         "--temperature",
+         type=float,
+         default=1.0,
+         help="temperature of 1.0 has no effect, lower tends toward greedy sampling",
+     )
+     parser.add_argument(
+         "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
+     )
+     parser.add_argument("--k", type=int, default=0)
+     parser.add_argument("--p", type=float, default=0.9)
+
+     parser.add_argument("--padding_text", type=str, default="", help="Padding lyrics for Transfo-XL and XLNet.")
+     parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
+
+     parser.add_argument("--seed", type=int, default=42, help="random seed for initialization (currently overridden by set_seed)")
+     parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
+     parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
+     args = parser.parse_args()
+
+     args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
+     args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
+
+     # Initialize the model and tokenizer
+     try:
+         args.model_type = args.model_type.lower()
+         model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+     except KeyError:
+         # The original message had a bare "{}"; fill it with the offending model type
+         raise KeyError(
+             "the model {} you specified is not supported. You are welcome to add it and open a PR :)".format(args.model_type)
+         )
+
+     tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
+     model = model_class.from_pretrained(args.model_name_or_path)
+     model.to(args.device)
+
+     args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
+     logger.info(args)
+     generated_sequences = []
+     prompt_text = ""
+     while prompt_text != "stop":
+         set_seed(args)
+         while not len(prompt_text):
+             prompt_text = args.prompt if args.prompt else input("Context >>> ")
+
+         # Different models need different input formatting and/or extra arguments
+         requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
+         if requires_preprocessing:
+             prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
+             preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
+             encoded_prompt = tokenizer.encode(
+                 preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", add_space_before_punct_symbol=True
+             )
+         else:
+             encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
+         encoded_prompt = encoded_prompt.to(args.device)
+
+         output_sequences = model.generate(
+             input_ids=encoded_prompt,
+             max_length=args.length + len(encoded_prompt[0]),
+             temperature=args.temperature,
+             top_k=args.k,
+             top_p=args.p,
+             repetition_penalty=args.repetition_penalty,
+             do_sample=True,
+             num_return_sequences=args.num_return_sequences,
+         )
+
+         # Remove the batch dimension when returning multiple sequences
+         if len(output_sequences.shape) > 2:
+             output_sequences.squeeze_()
+
+         now = datetime.datetime.now()
+         date_time = now.strftime('%Y%m%d_%H%M%S%f')  # currently unused
+
+         for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
+             # The original format string had no placeholder; show the sequence index
+             print("ruGPT ({}):".format(generated_sequence_idx + 1))
+             generated_sequence = generated_sequence.tolist()
+
+             # Decode lyrics
+             text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
+
+             # Remove all lyrics after the stop token (guard against find() returning -1 when the token is absent)
+             text = text[: text.find(args.stop_token) if args.stop_token and args.stop_token in text else None]
+
+             # Add the prompt at the beginning of the sequence and remove the excess text used for pre-processing
+             total_sequence = (
+                 prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
+             )
+
+             generated_sequences.append(total_sequence)
+             # os.system('clear')
+             print(total_sequence)
+
+         prompt_text = ""
+         if args.prompt:
+             break
+
+     return generated_sequences
+
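+ # Note (not in the original commit): main() is the interactive CLI loop this
+ # script closely follows from the transformers run_generation.py example;
+ # nothing in this file actually calls it. A usage sketch for that CLI half,
+ # with a placeholder checkpoint path:
+ #   python app.py --model_type gpt2 --model_name_or_path path/to/fine-tuned-rugpt3 --length 150
+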
+ title = "ruGPT3 Song Writer"
+ description = "Generate Russian songs with a fine-tuned ruGPT3"
+
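+ # `process` is passed to gr.Interface below but is not defined anywhere in
+ # this commit. The sketch that follows is an assumption about the intended
+ # callback, not the author's code: it loads the fine-tuned ruGPT3 checkpoint
+ # (GPT-2 architecture) once at import time and runs one sampling pass per
+ # request, mirroring the defaults used in main(). MODEL_PATH, gen_tokenizer,
+ # and gen_model are hypothetical names introduced here.
+ MODEL_PATH = "path/to/fine-tuned-rugpt3"  # placeholder; the real checkpoint is not recorded in this commit
+ gen_tokenizer = GPT2Tokenizer.from_pretrained(MODEL_PATH)
+ gen_model = GPT2LMHeadModel.from_pretrained(MODEL_PATH)
+ gen_model.eval()
+
+ def process(prompt_text):
+     # Encode the prompt, sample a continuation, and truncate at the stop token
+     encoded_prompt = gen_tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
+     output_sequences = gen_model.generate(
+         input_ids=encoded_prompt,
+         max_length=200 + len(encoded_prompt[0]),
+         temperature=1.0,
+         top_k=0,
+         top_p=0.9,
+         repetition_penalty=1.0,
+         do_sample=True,
+     )
+     text = gen_tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)
+     return text[: text.find("</s>") if "</s>" in text else None]
+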
+ gr.Interface(
+     process,
+     gr.inputs.Textbox(lines=1, label="Input text"),
+     gr.outputs.Textbox(lines=20, label="Output text"),
+     title=title,
+     description=description,
+     # `examples` is a gr.Interface argument, not a Textbox one; moved here from the input component
+     examples=[["Как дела? Как дела? Это новый кадиллак"]],
+ ).launch(enable_queue=True, cache_examples=True)