Spaces:
Build error
Build error
File size: 2,326 Bytes
67c2dd6 8ff16c1 67c2dd6 931dc99 67c2dd6 8ff16c1 dc60d3c 8ff16c1 73288d7 67c2dd6 91c7bac 8ff16c1 67c2dd6 1e0c655 dc60d3c 9e88c07 8ffa1f8 9d70aae 73288d7 931dc99 9e88c07 91c7bac 9104177 931dc99 0c1b939 931dc99 dc60d3c 1e0c655 9e88c07 b4b3720 4ad1735 f5d485d b4b3720 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import transformers
import numpy as np
import torch
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel
@st.cache(allow_output_mutation=True)
def load_model():
    """Load and cache the ruGPT-3 song-writer tokenizer and model.

    Cached by Streamlit so the Hugging Face checkpoint is downloaded and
    instantiated only once per session, not on every script rerun.

    Returns:
        tuple: (GPT2Tokenizer, GPT2LMHeadModel) for the
        ``bankholdup/rugpt3_song_writer`` checkpoint.
    """
    checkpoint = "bankholdup/rugpt3_song_writer"
    return (
        GPT2Tokenizer.from_pretrained(checkpoint),
        GPT2LMHeadModel.from_pretrained(checkpoint),
    )
def set_seed(rng=100000):
    """Draw a fresh random seed and apply it to numpy and torch.

    A new seed is drawn on every call so each button press produces a
    different song. The chosen seed is now returned (previously it was
    discarded), so callers can log it to reproduce a generation.

    Args:
        rng: Exclusive upper bound for the drawn seed; must be > 0.

    Returns:
        int: The seed that was applied to both RNGs.
    """
    seed = int(np.random.randint(rng))
    np.random.seed(seed)
    torch.manual_seed(seed)
    return seed
# --- Streamlit page flow ---
# Show a placeholder title while the (possibly slow) model load runs,
# then overwrite it with the real app title in place.
title = st.title("Загрузка модели")
tokenizer, model = load_model()
title.title("ruGPT3 Song Writer")

# User inputs: the opening line of the song and the sampling temperature.
context = st.text_input("Введите начало песни", "Батя возвращается трезвым, в руке буханка")
temperature = st.slider("temperature (чем выше, тем текст безумнее, чем ниже, тем ближе к исходным данным)", 0.0, 2.5, 0.95)

if st.button("Поехали", help="Может занять какое-то время"):
    # Randomize RNGs so every click yields a different continuation.
    set_seed()
    st.write("Генерируем...")
    st.write("temperature = {}".format(temperature))
    st.write("_____________")

    prompt_text = context
    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")

    # Sample one continuation of up to 200 new tokens.
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=200 + len(encoded_prompt[0]),
        temperature=temperature,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
    )

    # generate() can return an extra leading dimension; flatten it so we
    # iterate over (num_sequences, sequence_length).
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    for generated_sequence in output_sequences:
        text = tokenizer.decode(generated_sequence.tolist(), clean_up_tokenization_spaces=True)
        # Keep the user's prompt verbatim and append only the newly
        # generated text (strip the re-decoded prompt prefix, whose
        # spacing may differ from the raw input).
        total_sequence = (
            prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
        )
        # Show all but the last 5 lines (drops a likely-truncated tail);
        # lines containing "[" — presumably section markers such as
        # [Куплет] — are rendered as blank separators instead.
        for song_line in total_sequence.splitlines()[:-5]:
            if "[" in song_line:
                st.write("\n")
                continue
            st.write(song_line)
|