Spaces:
Build error
Build error
import transformers | |
import numpy as np | |
import torch | |
import streamlit as st | |
from transformers import GPT2Tokenizer, GPT2LMHeadModel | |
def load_model(): | |
# model_ckpt = "bankholdup/rugpt3_song_writer" | |
model_ckpt = "bankholdup/mgpt_song_writer" | |
tokenizer = GPT2Tokenizer.from_pretrained(model_ckpt) | |
# model = GPT2LMHeadModel.from_pretrained(model_ckpt) | |
model = GPT2LMHeadModel.from_pretrained(model_ckpt, low_cpu_mem_usage=True) | |
return tokenizer, model | |
def set_seed(rng=100000): | |
rd = np.random.randint(rng) | |
np.random.seed(rd) | |
torch.manual_seed(rd) | |
title = st.title("Загрузка модели") | |
tokenizer, model = load_model() | |
title.title("Генератор текстов русского рэпа на основе ruGPT3 ") | |
context = st.text_input("Введите начало песни", "Нету милфы сексапильней, чем Екатерина Шульман") | |
temperature= st.slider("temperature (чем выше, тем модель сильнее импровизирует; чем ниже, тем больше повторяется)", 0.0, 2.5, 0.95) | |
if st.button("Поехали", help="Может занять какое-то время"): | |
with st.spinner("Генерируем..."): | |
generated_sequences = [] | |
set_seed() | |
prompt_text = f"{context}" | |
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt") | |
output_sequences = model.generate( | |
input_ids=encoded_prompt, | |
max_length=200 + len(encoded_prompt[0]), | |
temperature=temperature, | |
top_k=50, | |
top_p=0.95, | |
repetition_penalty=1.0, | |
do_sample=True, | |
num_return_sequences=1 | |
) | |
if len(output_sequences.shape) > 2: | |
output_sequences.squeeze_() | |
for generated_sequence_idx, generated_sequence in enumerate(output_sequences): | |
generated_sequence = generated_sequence.tolist() | |
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) | |
total_sequence = ( | |
prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :] | |
) | |
splits = total_sequence.splitlines() | |
for line in range(len(splits)-5): | |
if "[" in splits[line]: | |
st.write("\n") | |
continue | |
st.write(splits[line]) | |