bankholdup commited on
Commit
dc60d3c
1 Parent(s): 41e6718

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -16
app.py CHANGED
@@ -13,13 +13,10 @@ def load_model():
13
  model = GPT2LMHeadModel.from_pretrained(model_ckpt)
14
  return tokenizer, model
15
 
16
- def set_seed(args):
17
- rd = np.random.randint(100000)
18
- print('seed =', rd)
19
  np.random.seed(rd)
20
  torch.manual_seed(rd)
21
- if args.n_gpu > 0:
22
- torch.cuda.manual_seed_all(rd)
23
 
24
  title = st.title("Загрузка модели")
25
  tokenizer, model = load_model()
@@ -28,38 +25,34 @@ context = st.text_input("Введите начало песни", "Как дел
28
  generated_sequences = []
29
 
30
  if st.button("Поехали", help="Может занять какое-то время"):
 
31
  prompt_text = f"{context}"
32
  encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
33
  output_sequences = model.generate(
34
  input_ids=encoded_prompt,
35
- max_length=200 + len(encoded_prompt[0]),
36
- temperature=0.95,
37
  top_k=50,
38
  top_p=0.95,
39
  repetition_penalty=1.0,
40
  do_sample=True,
41
  num_return_sequences=1,
42
  )
43
-
44
- # Remove the batch dimension when returning multiple sequences
45
  if len(output_sequences.shape) > 2:
46
  output_sequences.squeeze_()
47
 
48
  for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
49
  print("ruGPT:".format(generated_sequence_idx + 1))
50
  generated_sequence = generated_sequence.tolist()
51
-
52
- # Decode text
53
  text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
54
 
55
- # Remove all text after the stop token
56
  text = text[: text.find("</s>") if "</s>" else None]
57
-
58
- # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
59
  total_sequence = (
60
- prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :] + "/n"
61
  )
62
 
63
- generated_sequences.append(total_sequence)
64
  # os.system('clear')
65
  st.write(total_sequence)
 
13
  model = GPT2LMHeadModel.from_pretrained(model_ckpt)
14
  return tokenizer, model
15
 
16
+ def set_seed(rng=100000):
17
+ rd = np.random.randint(rng)
 
18
  np.random.seed(rd)
19
  torch.manual_seed(rd)
 
 
20
 
21
  title = st.title("Загрузка модели")
22
  tokenizer, model = load_model()
 
25
  generated_sequences = []
26
 
27
  if st.button("Поехали", help="Может занять какое-то время"):
28
+ set_seed()
29
  prompt_text = f"{context}"
30
  encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
31
  output_sequences = model.generate(
32
  input_ids=encoded_prompt,
33
+ max_length=250 + len(encoded_prompt[0]),
34
+ temperature=1.95,
35
  top_k=50,
36
  top_p=0.95,
37
  repetition_penalty=1.0,
38
  do_sample=True,
39
  num_return_sequences=1,
40
  )
 
 
41
  if len(output_sequences.shape) > 2:
42
  output_sequences.squeeze_()
43
 
44
  for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
45
  print("ruGPT:".format(generated_sequence_idx + 1))
46
  generated_sequence = generated_sequence.tolist()
47
+
 
48
  text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
49
 
 
50
  text = text[: text.find("</s>") if "</s>" else None]
51
+
 
52
  total_sequence = (
53
+ prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
54
  )
55
 
56
+ # generated_sequences.append(total_sequence)
57
  # os.system('clear')
58
  st.write(total_sequence)