bankholdup committed on
Commit
931dc99
1 Parent(s): 67c2dd6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -17
app.py CHANGED
@@ -9,8 +9,8 @@ from transformers import pipeline
9
  @st.cache(allow_output_mutation=True)
10
  def load_model():
11
  model_ckpt = "bankholdup/rugpt3_song_writer"
12
- tokenizer = GPT2Tokenizer.from_pretrained(model_ckpt,from_flax=True)
13
- model = GPT2LMHeadModel.from_pretrained(model_ckpt,from_flax=True)
14
  return tokenizer, model
15
 
16
  def set_seed(args):
@@ -23,24 +23,43 @@ def set_seed(args):
23
 
24
  title = st.title("Loading model")
25
  tokenizer, model = load_model()
26
- text_generation = pipeline("text-generation", model=model, tokenizer=tokenizer)
27
  title.title("ruGPT3 Song Writer")
28
  context = st.text_input("Введите начало песни", "Как дела? Как дела? Это новый кадиллак")
29
 
30
  if st.button("Поехали", help="Может занять какое-то время"):
31
  st.title(f"{context}")
32
  prefix_text = f"{context}"
33
- generated_song = text_generation(prefix_text, max_length=200, temperature=0.9, k=50, do_sample=True)[0]
34
- for count, line in enumerate(generated_song['generated_text'].split("\n")):
35
- if"<EOS>" in line:
36
- break
37
- if count == 0:
38
- st.markdown(f"**{line[line.find('['):]}**")
39
- continue
40
- if "<BOS>" in line:
41
- st.write(line[5:])
42
- continue
43
- if line.startswith("["):
44
- st.markdown(f"**{line}**")
45
- continue
46
- st.write(line)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the ruGPT3 song-writer tokenizer and model, cached across Streamlit reruns.

    Returns:
        (tokenizer, model): a GPT2Tokenizer and GPT2LMHeadModel pair loaded
        from the Hugging Face Hub checkpoint.
    """
    checkpoint = "bankholdup/rugpt3_song_writer"
    # Native PyTorch weights are loaded directly (no from_flax conversion needed).
    return (
        GPT2Tokenizer.from_pretrained(checkpoint),
        GPT2LMHeadModel.from_pretrained(checkpoint),
    )
15
 
16
  def set_seed(args):
 
23
 
24
# --- Streamlit UI: load the model, take a song prompt, and generate lyrics. ---
title = st.title("Loading model")
tokenizer, model = load_model()
title.title("ruGPT3 Song Writer")
context = st.text_input("Введите начало песни", "Как дела? Как дела? Это новый кадиллак")

if st.button("Поехали", help="Может занять какое-то время"):
    st.title(f"{context}")
    prefix_text = f"{context}"
    encoded_prompt = tokenizer.encode(prefix_text, add_special_tokens=False, return_tensors="pt")
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        # Budget 200 new tokens on top of the prompt length.
        max_length=200 + len(encoded_prompt[0]),
        temperature=0.95,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
    )

    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    # BUG FIX: `generated_sequences` was appended to without ever being initialized.
    generated_sequences = []
    # BUG FIX: `args.stop_token` raised NameError — there is no `args` in this
    # script (the loop was copied from transformers' run_generation.py example).
    # Use an explicit local; None means "keep the full generated text".
    stop_token = None

    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        # BUG FIX: "ruGPT:".format(idx) had no placeholder, so .format was a
        # no-op; the printed text is unchanged.
        print("ruGPT:")
        generated_sequence = generated_sequence.tolist()

        # Decode text
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # Remove all text after the stop token
        text = text[: text.find(stop_token) if stop_token else None]

        # BUG FIX: `prompt_text` was undefined — the prompt variable in this
        # script is `prefix_text`. Re-attach the prompt, dropping the decoded
        # prompt prefix from the generated continuation.
        total_sequence = (
            prefix_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
        )

        generated_sequences.append(total_sequence)
        # os.system('clear')
        st.write(total_sequence)