import transformers import numpy as np import torch import streamlit as st from transformers import GPT2Tokenizer, GPT2LMHeadModel @st.cache(allow_output_mutation=True) def load_model(): # model_ckpt = "bankholdup/rugpt3_song_writer" model_ckpt = "bankholdup/mgpt_song_writer" tokenizer = GPT2Tokenizer.from_pretrained(model_ckpt) model = GPT2LMHeadModel.from_pretrained(model_ckpt) return tokenizer, model def set_seed(rng=100000): rd = np.random.randint(rng) np.random.seed(rd) torch.manual_seed(rd) title = st.title("Загрузка модели") tokenizer, model = load_model() title.title("Генератор текстов русского рэпа на основе ruGPT3 ") context = st.text_input("Введите начало песни", "Нету милфы сексапильней, чем Екатерина Шульман") temperature= st.slider("temperature (чем выше, тем модель сильнее импровизирует; чем ниже, тем больше повторяется)", 0.0, 2.5, 0.95) if st.button("Поехали", help="Может занять какое-то время"): with st.spinner("Генерируем..."): generated_sequences = [] set_seed() #st.write("Генерируем...") #st.write("temperature = {}".format(temperature)) #st.write("_____________") prompt_text = f"{context}" #encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt") encoded_prompt = tokenizer.encode(text, return_tensors="pt").cuda(device) output_sequences = model.generate( input_ids=encoded_prompt, max_length=200 + len(encoded_prompt[0]), temperature=temperature, top_k=50, top_p=0.95, repetition_penalty=1.0, do_sample=True, num_return_sequences=1, eos_token_id=5, pad_token=1, ) if len(output_sequences.shape) > 2: output_sequences.squeeze_() for generated_sequence_idx, generated_sequence in enumerate(output_sequences): generated_sequence = generated_sequence.tolist() text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) total_sequence = ( prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :] ) splits = total_sequence.splitlines() for line in range(len(splits)-5): if "[" in splits[line]: st.write("\n") continue st.write(splits[line])