Spaces:
Build error
Build error
bankholdup
commited on
Commit
•
0cde2e7
1
Parent(s):
5dd41b9
Update app.py
Browse files
app.py
CHANGED
@@ -22,40 +22,42 @@ tokenizer, model = load_model()
|
|
22 |
title.title("ruGPT3 Song Writer")
|
23 |
context = st.text_input("Введите начало песни", "Нету милфы сексапильней, чем Екатерина Шульман")
|
24 |
temperature= st.slider("temperature (чем выше, тем текст безумнее, чем ниже, тем ближе к исходным данным)", 0.0, 2.5, 0.95)
|
|
|
25 |
|
26 |
if st.button("Поехали", help="Может занять какое-то время"):
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
output_sequences.
|
46 |
-
|
47 |
-
|
48 |
-
generated_sequence
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
|
|
|
22 |
title.title("ruGPT3 Song Writer")
|
23 |
context = st.text_input("Введите начало песни", "Нету милфы сексапильней, чем Екатерина Шульман")
|
24 |
temperature= st.slider("temperature (чем выше, тем текст безумнее, чем ниже, тем ближе к исходным данным)", 0.0, 2.5, 0.95)
|
25 |
+
st.write("_____________")
|
26 |
|
27 |
if st.button("Поехали", help="Может занять какое-то время"):
|
28 |
+
with st.spinner("Генерируем..."):
|
29 |
+
generated_sequences = []
|
30 |
+
set_seed()
|
31 |
+
#st.write("Генерируем...")
|
32 |
+
#st.write("temperature = {}".format(temperature))
|
33 |
+
#st.write("_____________")
|
34 |
+
prompt_text = f"{context}"
|
35 |
+
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
|
36 |
+
output_sequences = model.generate(
|
37 |
+
input_ids=encoded_prompt,
|
38 |
+
max_length=200 + len(encoded_prompt[0]),
|
39 |
+
temperature=temperature,
|
40 |
+
top_k=50,
|
41 |
+
top_p=0.95,
|
42 |
+
repetition_penalty=1.0,
|
43 |
+
do_sample=True,
|
44 |
+
num_return_sequences=1,
|
45 |
+
)
|
46 |
+
if len(output_sequences.shape) > 2:
|
47 |
+
output_sequences.squeeze_()
|
48 |
+
|
49 |
+
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
|
50 |
+
generated_sequence = generated_sequence.tolist()
|
51 |
+
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
|
52 |
+
|
53 |
+
total_sequence = (
|
54 |
+
prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
|
55 |
+
)
|
56 |
+
|
57 |
+
splits = total_sequence.splitlines()
|
58 |
+
for line in range(len(splits)-5):
|
59 |
+
if "[" in splits[line]:
|
60 |
+
st.write("\n")
|
61 |
+
continue
|
62 |
+
|
63 |
+
st.write(splits[line])
|