Spaces:
Build error
Build error
File size: 2,613 Bytes
67c2dd6 8ff16c1 67c2dd6 848eb9d 931dc99 848eb9d 67c2dd6 8ff16c1 dc60d3c 8ff16c1 73288d7 67c2dd6 1427040 5dd41b9 1427040 8ff16c1 67c2dd6 0cde2e7 f5b970b 0cde2e7 f5b970b 0cde2e7 f5b970b 0cde2e7 3e8e7bc 0cde2e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import transformers
import numpy as np
import torch
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel
@st.cache(allow_output_mutation=True)
def load_model():
model_ckpt = "bankholdup/rugpt3_song_writer"
# model_ckpt = "bankholdup/mgpt_song_writer"
tokenizer = GPT2Tokenizer.from_pretrained(model_ckpt)
model = GPT2LMHeadModel.from_pretrained(model_ckpt)
# model = GPT2LMHeadModel.from_pretrained(model_ckpt, low_cpu_mem_usage=True)
return tokenizer, model
def set_seed(rng=100000):
rd = np.random.randint(rng)
np.random.seed(rd)
torch.manual_seed(rd)
title = st.title("Загрузка модели")
tokenizer, model = load_model()
title.title("Генератор текстов русского рэпа на основе ruGPT3 ")
context = st.text_input("Введите начало песни", "Нету милфы сексапильней, чем Екатерина Шульман")
temperature= st.slider("temperature (чем выше, тем модель сильнее импровизирует; чем ниже, тем больше повторяется)", 0.0, 2.5, 0.95)
if st.button("Поехали", help="Может занять какое-то время"):
with st.spinner("Генерируем..."):
generated_sequences = []
set_seed()
prompt_text = f"{context}"
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
output_sequences = model.generate(
input_ids=encoded_prompt,
max_length=200 + len(encoded_prompt[0]),
temperature=temperature,
top_k=50,
top_p=0.95,
repetition_penalty=1.0,
do_sample=True,
num_return_sequences=1
)
if len(output_sequences.shape) > 2:
output_sequences.squeeze_()
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
generated_sequence = generated_sequence.tolist()
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
total_sequence = (
prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
)
splits = total_sequence.splitlines()
for line in range(len(splits)-5):
if "[" in splits[line]:
st.write("\n")
continue
st.write(splits[line])
|