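"""Streamlit demo for the bankholdup/rugpt3_song_writer checkpoint.

The app takes the opening line of a song and continues it with a
Russian GPT-2 (ruGPT3) language model fine-tuned for song writing.
"""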
import numpy as np
import torch
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel


# Cache the tokenizer and model so they load once, not on every Streamlit rerun.
@st.cache(allow_output_mutation=True)
def load_model():
    model_ckpt = "bankholdup/rugpt3_song_writer"
    tokenizer = GPT2Tokenizer.from_pretrained(model_ckpt)
    model = GPT2LMHeadModel.from_pretrained(model_ckpt)
    return tokenizer, model


def set_seed():
    # Draw a fresh random seed so repeated runs sample different lyrics.
    rd = np.random.randint(100000)
    print('seed =', rd)
    np.random.seed(rd)
    torch.manual_seed(rd)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(rd)
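

# Assumed call site: re-seed on each script rerun so every click of the
# button produces a different continuation.
set_seed()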

# Show a placeholder title while the model loads, then swap in the real one.
title = st.title("Loading model")
tokenizer, model = load_model()
title.title("ruGPT3 Song Writer")

# Input label: "Enter the beginning of the song"; the default prompt is
# Russian lyrics ("How are you? How are you? It's the new Cadillac").
context = st.text_input("Введите начало песни", "Как дела? Как дела? Это новый кадиллак")

# Button label: "Let's go"; tooltip: "This may take some time".
if st.button("Поехали", help="Может занять какое-то время"):
    prefix_text = f"{context}"
    encoded_prompt = tokenizer.encode(prefix_text, add_special_tokens=False, return_tensors="pt")
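
    # Sampling configuration: keep the 50 most likely tokens at each step
    # (top_k), restrict them to the smallest set covering 95% of probability
    # mass (top_p), sharpen the logits slightly (temperature < 1), and
    # generate up to 200 new tokens beyond the prompt.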
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=200 + len(encoded_prompt[0]),
        temperature=0.95,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
    )

    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    generated_sequences = []
    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print("ruGPT: {}".format(generated_sequence_idx + 1))
        generated_sequence = generated_sequence.tolist()

        # Decode text
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # Remove all text after the stop token
        text = text[: text.find("</s>") if "</s>" in text else None]

        # Add the prompt at the beginning of the sequence and drop the decoded
        # prompt tokens from the model output so the prompt is not duplicated.
        total_sequence = (
            prefix_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)):]
        )
        generated_sequences.append(total_sequence)

        st.write(total_sequence)
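
# Run the app locally (assumed environment) with:
#   pip install streamlit torch transformers numpy
#   streamlit run app.py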