bankholdup committed
Commit c0ad07a
Parent: 468bc22

Update app.py

Files changed (1): app.py +11 -227
app.py CHANGED
@@ -1,45 +1,7 @@
- import os
-
- import argparse
- import logging
-
  import numpy as np
  import torch
- import datetime
  import gradio as gr

- from transformers import (
-     CTRLLMHeadModel,
-     CTRLTokenizer,
-     GPT2LMHeadModel,
-     GPT2Tokenizer,
-     OpenAIGPTLMHeadModel,
-     OpenAIGPTTokenizer,
-     TransfoXLLMHeadModel,
-     TransfoXLTokenizer,
-     XLMTokenizer,
-     XLMWithLMHeadModel,
-     XLNetLMHeadModel,
-     XLNetTokenizer,
- )
-
-
- logging.basicConfig(
-     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,
- )
- logger = logging.getLogger(__name__)
-
- MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
-
- MODEL_CLASSES = {
-     "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
-     "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
-     "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
-     "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
-     "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
-     "xlm": (XLMWithLMHeadModel, XLMTokenizer),
- }
-
  def set_seed(args):
      rd = np.random.randint(100000)
      print('seed =', rd)
@@ -48,201 +10,23 @@ def set_seed(args):
      if args.n_gpu > 0:
          torch.cuda.manual_seed_all(rd)

- #
- # Functions to prepare models' input
- #
-
-
- def prepare_ctrl_input(args, _, tokenizer, prompt_text):
-     if args.temperature > 0.7:
-         logger.info("CTRL typically works better with lower temperatures (and lower top_k).")
-
-     encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
-     if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
-         logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
-     return prompt_text
-
-
- def prepare_xlm_input(args, model, tokenizer, prompt_text):
-     # kwargs = {"language": None, "mask_token_id": None}
-
-     # Set the language
-     use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
-     if hasattr(model.config, "lang2id") and use_lang_emb:
-         available_languages = model.config.lang2id.keys()
-         if args.xlm_language in available_languages:
-             language = args.xlm_language
-         else:
-             language = None
-             while language not in available_languages:
-                 language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")
-
-         model.config.lang_id = model.config.lang2id[language]
-         # kwargs["language"] = tokenizer.lang2id[language]
-
-     # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
-     # XLM masked-language modeling (MLM) models need masked token
-     # is_xlm_mlm = "mlm" in args.model_name_or_path
-     # if is_xlm_mlm:
-     #     kwargs["mask_token_id"] = tokenizer.mask_token_id
-
-     return prompt_text
-
-
- def prepare_xlnet_input(args, _, tokenizer, prompt_text):
-     prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
-     return prompt_text
-
-
- def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
-     prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
-     return prompt_text
-
-
- PREPROCESSING_FUNCTIONS = {
-     "ctrl": prepare_ctrl_input,
-     "xlm": prepare_xlm_input,
-     "xlnet": prepare_xlnet_input,
-     "transfo-xl": prepare_transfoxl_input,
- }
-
-
- def adjust_length_to_model(length, max_sequence_length):
-     if length < 0 and max_sequence_length > 0:
-         length = max_sequence_length
-     elif 0 < max_sequence_length < length:
-         length = max_sequence_length # No generation bigger than model size
-     elif length < 0:
-         length = MAX_LENGTH # avoid infinite loop
-     return length
-
-
- def main():
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "--model_type",
-         default=None,
-         type=str,
-         required=True,
-         help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
-     )
-     parser.add_argument(
-         "--model_name_or_path",
-         default=None,
-         type=str,
-         required=True,
-         help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
-     )
-
-     parser.add_argument("--prompt", type=str, default="")
-     parser.add_argument("--length", type=int, default=20)
-     parser.add_argument("--stop_token", type=str, default="</s>", help="Token at which lyrics generation is stopped")
-
-     parser.add_argument(
-         "--temperature",
-         type=float,
-         default=1.0,
-         help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
-     )
-     parser.add_argument(
-         "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
-     )
-     parser.add_argument("--k", type=int, default=0)
-     parser.add_argument("--p", type=float, default=0.9)
-
-     parser.add_argument("--padding_text", type=str, default="", help="Padding lyrics for Transfo-XL and XLNet.")
-     parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
-
-     parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
-     parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
-     parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
-     args = parser.parse_args()
-
-     args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
-     args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
-
-     # Initialize the model and tokenizer
-     try:
-         args.model_type = args.model_type.lower()
-         model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-     except KeyError:
-         raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")
-
-     tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
-     model = model_class.from_pretrained(args.model_name_or_path)
-     model.to(args.device)
-
-     args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
-     logger.info(args)
-     generated_sequences = []
-     prompt_text = ""
-     while prompt_text != "stop":
-         set_seed(args)
-         while not len(prompt_text):
-             prompt_text = args.prompt if args.prompt else input("Context >>> ")
-
-         # Different models need different input formatting and/or extra arguments
-         requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
-         if requires_preprocessing:
-             prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
-             preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
-             encoded_prompt = tokenizer.encode(
-                 preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", add_space_before_punct_symbol=True
-             )
-         else:
-             encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
-         encoded_prompt = encoded_prompt.to(args.device)
-
-         output_sequences = model.generate(
-             input_ids=encoded_prompt,
-             max_length=args.length + len(encoded_prompt[0]),
-             temperature=args.temperature,
-             top_k=args.k,
-             top_p=args.p,
-             repetition_penalty=args.repetition_penalty,
-             do_sample=True,
-             num_return_sequences=args.num_return_sequences,
-         )
-
-         # Remove the batch dimension when returning multiple sequences
-         if len(output_sequences.shape) > 2:
-             output_sequences.squeeze_()
-
-         now = datetime.datetime.now()
-         date_time = now.strftime('%Y%m%d_%H%M%S%f')
-
-         for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
-             print("ruGPT:".format(generated_sequence_idx + 1))
-             generated_sequence = generated_sequence.tolist()
-
-             # Decode lyrics
-             text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
-
-             # Remove all lyrics after the stop token
-             text = text[: text.find(args.stop_token) if args.stop_token else None]
-
-             # Add the prompt at the beginning of the sequence. Remove the excess lyrics that was used for pre-processing
-             total_sequence = (
-                 prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
-             )
-
-             generated_sequences.append(total_sequence)
-             # os.system('clear')
-             print(total_sequence)

-         prompt_text = ""
-         if args.prompt:
-             break

-     return generated_sequences

- title = "ruGPT3 Song Writer"
- description = "Generate russian songs via fine-tuned ruGPT3"
+ title = "ruGPT3 Song Writer"
+ description = "Generate russian songs via fine-tuned ruGPT3"
+
+ io = gr.Interface.load("models/bankholdup/rugpt3_song_writer")
+
+ examples = [
+     ['Как дела? Как дела? Это новый кадиллак']
+ ]
+
+ def inference(text):
+     return io(text)

  gr.Interface(
-     process,
-     gr.inputs.Textbox(lines=1, label="Input text", examples="Как дела? Как дела? Это новый кадиллак"),
+     inference,
+     [gr.inputs.Textbox(lines=1, label="Input text")],
      gr.outputs.Textbox(lines=20, label="Output text"),
+     examples=examples,
      title=title,
      description=description,
  ).launch(enable_queue=True,cache_examples=True)
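For context, the deleted code appears to be adapted from transformers' run_generation.py example: parse CLI flags, load one of six model/tokenizer pairs locally, then sample with model.generate in a console loop. Below is a minimal sketch of that local path, under two assumptions the diff does not state: that the checkpoint behind the Space is the hub model bankholdup/rugpt3_song_writer, and that it loads with the GPT-2 classes, as the deleted MODEL_CLASSES mapping would for a ruGPT3-style model. The defaults mirror the removed argparse flags.

import numpy as np
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Assumed checkpoint id; the deleted script took --model_name_or_path instead.
MODEL_ID = "bankholdup/rugpt3_song_writer"

tokenizer = GPT2Tokenizer.from_pretrained(MODEL_ID)
model = GPT2LMHeadModel.from_pretrained(MODEL_ID)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def generate_song(prompt, length=20, temperature=1.0, top_k=0, top_p=0.9,
                  repetition_penalty=1.0, stop_token="</s>"):
    # The deleted set_seed() drew a fresh random seed for every request.
    torch.manual_seed(np.random.randint(100000))

    encoded = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(device)
    output = model.generate(
        input_ids=encoded,
        max_length=length + encoded.shape[1],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
    )
    text = tokenizer.decode(output[0].tolist(), clean_up_tokenization_spaces=True)
    # Truncate at the stop token, as the deleted decoding loop did.
    cut = text.find(stop_token)
    return text[:cut] if cut != -1 else text

The commit trades all of this for gr.Interface.load, which proxies generation through the hosted model, so the Space no longer downloads weights, needs a GPU, or controls sampling parameters itself.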
 
 
 
 
 
 
app.py (after this commit):

import numpy as np
import torch
import gradio as gr

def set_seed(args):
    rd = np.random.randint(100000)
    print('seed =', rd)
    # (file lines 8-9 are elided in the diff view above)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(rd)

title = "ruGPT3 Song Writer"
description = "Generate russian songs via fine-tuned ruGPT3"

io = gr.Interface.load("models/bankholdup/rugpt3_song_writer")

examples = [
    ['Как дела? Как дела? Это новый кадиллак']
]

def inference(text):
    return io(text)

gr.Interface(
    inference,
    [gr.inputs.Textbox(lines=1, label="Input text")],
    gr.outputs.Textbox(lines=20, label="Output text"),
    examples=examples,
    title=title,
    description=description,
).launch(enable_queue=True,cache_examples=True)
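One behavioral detail the new wiring depends on: gr.Interface.load returns an Interface object that can be called like a function, which is why inference(text) can simply return io(text). A hypothetical smoke test of the wrapper outside the UI, assuming a gradio version with callable interfaces and network access to the Hugging Face inference backend:

import gradio as gr

# Hypothetical check, not part of the commit: call the loaded interface directly.
io = gr.Interface.load("models/bankholdup/rugpt3_song_writer")
print(io("Как дела? Как дела? Это новый кадиллак"))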