bankholdup commited on
Commit
3e8e7bc
1 Parent(s): 69f6ca7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -7,8 +7,8 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel
7
 
8
  @st.cache(allow_output_mutation=True)
9
  def load_model():
10
- model_ckpt = "bankholdup/rugpt3_song_writer"
11
- # model_ckpt = "bankholdup/mgpt_song_write"
12
  tokenizer = GPT2Tokenizer.from_pretrained(model_ckpt)
13
  model = GPT2LMHeadModel.from_pretrained(model_ckpt)
14
  return tokenizer, model
@@ -32,7 +32,8 @@ if st.button("Поехали", help="Может занять какое-то в
32
  #st.write("temperature = {}".format(temperature))
33
  #st.write("_____________")
34
  prompt_text = f"{context}"
35
- encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
 
36
  output_sequences = model.generate(
37
  input_ids=encoded_prompt,
38
  max_length=200 + len(encoded_prompt[0]),
@@ -42,7 +43,10 @@ if st.button("Поехали", help="Может занять какое-то в
42
  repetition_penalty=1.0,
43
  do_sample=True,
44
  num_return_sequences=1,
 
 
45
  )
 
46
  if len(output_sequences.shape) > 2:
47
  output_sequences.squeeze_()
48
 
 
7
 
8
  @st.cache(allow_output_mutation=True)
9
  def load_model():
10
+ # model_ckpt = "bankholdup/rugpt3_song_writer"
11
+ model_ckpt = "bankholdup/mgpt_song_write"
12
  tokenizer = GPT2Tokenizer.from_pretrained(model_ckpt)
13
  model = GPT2LMHeadModel.from_pretrained(model_ckpt)
14
  return tokenizer, model
 
32
  #st.write("temperature = {}".format(temperature))
33
  #st.write("_____________")
34
  prompt_text = f"{context}"
35
+ #encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
36
+ encoded_prompt = tokenizer.encode(text, return_tensors="pt").cuda(device)
37
  output_sequences = model.generate(
38
  input_ids=encoded_prompt,
39
  max_length=200 + len(encoded_prompt[0]),
 
43
  repetition_penalty=1.0,
44
  do_sample=True,
45
  num_return_sequences=1,
46
+ eos_token_id=5,
47
+ pad_token=1,
48
  )
49
+
50
  if len(output_sequences.shape) > 2:
51
  output_sequences.squeeze_()
52