bankholdup's picture
Update app.py
dc60d3c
raw
history blame
No virus
2.05 kB
import transformers
import numpy as np
import torch
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import pipeline
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the tokenizer and language model for the song-writer checkpoint.

    The function is wrapped in ``st.cache`` (with output mutation allowed) so
    that Streamlit can reuse the loaded objects instead of reloading them.

    Returns:
        A ``(tokenizer, model)`` tuple, both built with ``from_pretrained``
        from the ``bankholdup/rugpt3_song_writer`` checkpoint.
    """
    checkpoint = "bankholdup/rugpt3_song_writer"
    return (
        GPT2Tokenizer.from_pretrained(checkpoint),
        GPT2LMHeadModel.from_pretrained(checkpoint),
    )
def set_seed(rng=100000):
    """Seed NumPy's global RNG and PyTorch with a freshly drawn random seed.

    Args:
        rng: Exclusive upper bound for the seed; the seed is drawn uniformly
            from ``[0, rng)``.

    Returns:
        The integer seed that was applied, so a run can be logged or
        reproduced later.

    Raises:
        ValueError: If ``rng`` is not positive.
    """
    # Draw from a new, OS-entropy-seeded generator rather than from the global
    # NumPy state: this function reseeds that global state on every call, so
    # drawing from it would make each seed a deterministic function of the
    # previous one (and, with a small ``rng``, quickly fall into a cycle).
    seed = int(np.random.default_rng().integers(rng))
    np.random.seed(seed)
    torch.manual_seed(seed)
    return seed
# Streamlit page script: load the model, read a song opening from the user and
# display a sampled continuation when the button is pressed.

# Marker the script treats as end-of-text in the decoded output.
STOP_TOKEN = "</s>"
# Number of tokens allowed on top of the prompt length.
GENERATED_TOKENS = 250

# Show a "loading" heading while load_model() runs, then reuse the same
# element for the real application title.
title = st.title("Загрузка модели")
tokenizer, model = load_model()
title.title("ruGPT3 Song Writer")

context = st.text_input("Введите начало песни", "Как дела? Как дела? Это новый кадиллак")

if st.button("Поехали", help="Может занять какое-то время"):
    prompt_text = f"{context}"
    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")

    if encoded_prompt.numel() == 0:
        # An empty prompt encodes to a zero-length tensor, which generate()
        # cannot continue from; ask for input instead of failing.
        st.warning("Сначала введите начало песни")
    else:
        set_seed()
        output_sequences = model.generate(
            input_ids=encoded_prompt,
            max_length=GENERATED_TOKENS + encoded_prompt.shape[-1],
            temperature=1.95,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.0,
            do_sample=True,
            num_return_sequences=1,
        )
        # Drop any extra singleton dimensions so iteration yields one
        # sequence of token ids at a time.
        if output_sequences.dim() > 2:
            output_sequences.squeeze_()

        # Length of the prompt after an encode/decode round trip; used to cut
        # the echoed prompt off the decoded output. Loop-invariant, so it is
        # computed once.
        decoded_prompt_len = len(
            tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)
        )

        for sequence_number, generated_sequence in enumerate(output_sequences, start=1):
            print(f"ruGPT {sequence_number}:")
            text = tokenizer.decode(generated_sequence.tolist(), clean_up_tokenization_spaces=True)
            # Keep only the text before the first stop token. When the token
            # is absent the text is left whole (the previous slice on
            # find() == -1 silently dropped the final character).
            text = text.partition(STOP_TOKEN)[0]
            # Re-attach the user's original prompt in front of the newly
            # generated part.
            total_sequence = prompt_text + text[decoded_prompt_len:]
            st.write(total_sequence)