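"""Streamlit demo for the bankholdup/rugpt3_song_writer checkpoint.

Loads the ruGPT3 song-writer model with the GPT-2 classes, takes the opening
lines of a song from the user, and samples a continuation. Run locally with
`streamlit run app.py`.
"""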
import numpy as np
import torch
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel
@st.cache(allow_output_mutation=True)
def load_model():
    """Download the tokenizer and model once and keep them in Streamlit's cache."""
    model_ckpt = "bankholdup/rugpt3_song_writer"
    tokenizer = GPT2Tokenizer.from_pretrained(model_ckpt)
    model = GPT2LMHeadModel.from_pretrained(model_ckpt)
    return tokenizer, model
def set_seed(rng=100000):
    """Seed NumPy and PyTorch with a random value below `rng` so each run differs."""
    rd = np.random.randint(rng)
    np.random.seed(rd)
    torch.manual_seed(rd)
title = st.title("Загрузка модели")  # "Loading the model"
tokenizer, model = load_model()
title.title("ruGPT3 Song Writer")  # replace the loading banner once the model is ready
context = st.text_input("Введите начало песни", "Как дела? Как дела? Это новый кадиллак")  # "Enter the beginning of a song"
if st.button("Поехали", help="Может занять какое-то время"):  # "Go" / "May take a while"
    set_seed()
    st.write("Генерируем")  # "Generating"

    prompt_text = f"{context}"
    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")

    # Sample a continuation of the prompt with top-k and nucleus (top-p) sampling.
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=250 + len(encoded_prompt[0]),
        temperature=1.0,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
    )

    # Drop the extra batch dimension if generate() returned one.
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    st.title("Топовая песня")  # "Top song"
    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        generated_sequence = generated_sequence.tolist()
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
        # text = text[: text.find("\n") if "\n" in text else None]

        # Re-attach the original prompt and strip the decoded prompt tokens
        # from the front of the generated text.
        total_sequence = (
            prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)):]
        )
        st.write(total_sequence)