Spaces:
Build error
Build error
File size: 2,050 Bytes
67c2dd6 8ff16c1 67c2dd6 931dc99 67c2dd6 8ff16c1 dc60d3c 8ff16c1 73288d7 67c2dd6 8ff16c1 67c2dd6 1e0c655 dc60d3c 9e88c07 73288d7 931dc99 9e88c07 54bed4f 931dc99 0c1b939 931dc99 dc60d3c 1e0c655 9e88c07 b4b3720 86ae349 f5d485d b4b3720 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import transformers
import numpy as np
import torch
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# st.cache is deprecated (Streamlit >= 1.18); st.cache_resource is the cache
# intended for unhashable shared objects such as ML models. Fall back to the
# legacy decorator when running on an older Streamlit that lacks it.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model(model_ckpt="bankholdup/rugpt3_song_writer"):
    """Load the song-writer tokenizer and language model (cached by Streamlit).

    Args:
        model_ckpt: Hugging Face Hub id or local path of the checkpoint.
            Defaults to the ruGPT-3 song-writer model this app was built for.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = GPT2Tokenizer.from_pretrained(model_ckpt)
    model = GPT2LMHeadModel.from_pretrained(model_ckpt)
    return tokenizer, model
def set_seed(rng=100000):
    """Seed NumPy and PyTorch with a freshly drawn random seed.

    The seed is drawn from NumPy's global generator on every call, so each
    call yields a different seed. The seed is returned so that a particular
    generation can be reproduced later.

    Args:
        rng: Exclusive upper bound for the drawn seed; must be positive
            (``np.random.randint`` raises ``ValueError`` otherwise).

    Returns:
        The integer seed that was applied to both NumPy and PyTorch.
    """
    # Plain int so callers can log/display it without NumPy scalar types.
    seed = int(np.random.randint(rng))
    np.random.seed(seed)
    torch.manual_seed(seed)
    return seed
# --- Streamlit page ------------------------------------------------------------
# Streamlit re-executes this script on every interaction; load_model() is
# wrapped in a Streamlit cache decorator, so reruns reuse the loaded model.


def _generate_texts(prompt_text, tokenizer, model, max_new_tokens=200):
    """Sample continuations of *prompt_text*; return one string per sequence.

    The decoded prompt prefix is sliced off each decoded output and replaced
    with the user's exact text, so tokenizer round-trip artefacts never alter
    what the user typed.
    """
    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=max_new_tokens + len(encoded_prompt[0]),
        temperature=1.0,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
    )
    # Loop-invariant: the prompt decodes to the same length for every sequence.
    prompt_len = len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True))
    # generate() returns a 2-D (num_sequences, seq_len) tensor of token ids.
    return [
        prompt_text
        + tokenizer.decode(sequence.tolist(), clean_up_tokenization_spaces=True)[prompt_len:]
        for sequence in output_sequences
    ]


def _song_lines(total_sequence):
    """Return the lines of *total_sequence* that should be displayed.

    The final line is dropped (as the app always did) because generation stops
    at a fixed token budget, so that line is likely cut off mid-way. It is kept
    when it is the only line; otherwise a short output would display nothing.
    """
    lines = total_sequence.splitlines()
    return lines[:-1] if len(lines) > 1 else lines


# Placeholder whose content is replaced once the model has finished loading.
title = st.empty()
title.title("Загрузка модели")
tokenizer, model = load_model()
title.title("ruGPT3 Song Writer")

context = st.text_input("Введите начало песни", "Как дела? Как дела? Это новый кадиллак")
if st.button("Поехали", help="Может занять какое-то время"):
    if not context.strip():
        # An empty prompt encodes to zero tokens; there is nothing to continue.
        st.warning("Введите начало песни")
    else:
        set_seed()
        st.write("Генерируем...")
        for total_sequence in _generate_texts(context, tokenizer, model):
            for line in _song_lines(total_sequence):
                # Lines containing "[" (e.g. section tags) are shown as a blank gap.
                if "[" in line:
                    st.write("\n")
                else:
                    st.write(line)
|