bankholdup committed on
Commit
0cde2e7
1 Parent(s): 5dd41b9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -35
app.py CHANGED
# --- UI: title, prompt input, and sampling temperature control ---
title.title("ruGPT3 Song Writer")
context = st.text_input("Введите начало песни", "Нету милфы сексапильней, чем Екатерина Шульман")
# Slider label explains (in Russian) that higher temperature = wilder text,
# lower = closer to the training data.
temperature = st.slider("temperature (чем выше, тем текст безумнее, чем ниже, тем ближе к исходным данным)", 0.0, 2.5, 0.95)

if st.button("Поехали", help="Может занять какое-то время"):
    # Reseed the RNG for this generation run.
    # NOTE(review): set_seed is defined elsewhere in this file — presumably
    # transformers.set_seed or a wrapper; confirm whether it takes a seed arg.
    set_seed()
    st.write("Генерируем...")
    st.write("temperature = {}".format(temperature))
    st.write("_____________")
    prompt_text = f"{context}"
    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
    # Sample one continuation of up to 200 new tokens on top of the prompt.
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=200 + len(encoded_prompt[0]),
        temperature=temperature,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
    )
    # Collapse an extra batch dimension if generate() returned one.
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    for generated_sequence in output_sequences:
        token_ids = generated_sequence.tolist()
        text = tokenizer.decode(token_ids, clean_up_tokenization_spaces=True)

        # Re-attach the raw prompt text and keep only the newly generated
        # tail (skip the decoded copy of the prompt at the start of `text`).
        total_sequence = (
            prompt_text
            + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)):]
        )

        # Print line by line, dropping the last 5 lines — presumably because
        # the tail is often cut off mid-sentence; TODO confirm intent.
        splits = total_sequence.splitlines()
        for song_line in splits[:-5]:
            # Lines containing "[" look like section markers (e.g. "[Припев]");
            # render them as a blank separator instead of printing the tag.
            if "[" in song_line:
                st.write("\n")
                continue
            st.write(song_line)
 
 
# --- UI: title, prompt input, and sampling temperature control ---
title.title("ruGPT3 Song Writer")
context = st.text_input("Введите начало песни", "Нету милфы сексапильней, чем Екатерина Шульман")
# Slider label explains (in Russian) that higher temperature = wilder text,
# lower = closer to the training data.
temperature = st.slider("temperature (чем выше, тем текст безумнее, чем ниже, тем ближе к исходным данным)", 0.0, 2.5, 0.95)
st.write("_____________")

if st.button("Поехали", help="Может занять какое-то время"):
    # Show a spinner while the (potentially slow) generation runs.
    with st.spinner("Генерируем..."):
        # Reseed the RNG for this generation run.
        # NOTE(review): set_seed is defined elsewhere in this file — presumably
        # transformers.set_seed or a wrapper; confirm whether it takes a seed arg.
        set_seed()
        prompt_text = f"{context}"
        encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
        # Sample one continuation of up to 200 new tokens on top of the prompt.
        output_sequences = model.generate(
            input_ids=encoded_prompt,
            max_length=200 + len(encoded_prompt[0]),
            temperature=temperature,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.0,
            do_sample=True,
            num_return_sequences=1,
        )
        # Collapse an extra batch dimension if generate() returned one.
        if len(output_sequences.shape) > 2:
            output_sequences.squeeze_()

        for generated_sequence in output_sequences:
            token_ids = generated_sequence.tolist()
            text = tokenizer.decode(token_ids, clean_up_tokenization_spaces=True)

            # Re-attach the raw prompt text and keep only the newly generated
            # tail (skip the decoded copy of the prompt at the start of `text`).
            total_sequence = (
                prompt_text
                + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)):]
            )

            # Print line by line, dropping the last 5 lines — presumably because
            # the tail is often cut off mid-sentence; TODO confirm intent.
            splits = total_sequence.splitlines()
            for song_line in splits[:-5]:
                # Lines containing "[" look like section markers (e.g. "[Припев]");
                # render them as a blank separator instead of printing the tag.
                if "[" in song_line:
                    st.write("\n")
                    continue
                st.write(song_line)