"""Streamlit app that continues song lyrics with a ruGPT-3 model.

The user enters the opening of a song, picks a temperature and presses the
button; the app generates a continuation with the
``bankholdup/rugpt3_song_writer`` checkpoint and prints it line by line.
"""
import transformers
import numpy as np
import torch
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Tokens generated on top of the prompt (``max_length`` counts the prompt too).
MAX_NEW_TOKENS = 200
# Trailing lines hidden from the output; presumably because generation is cut
# off at ``max_length`` and the last lines tend to be incomplete.
TRAILING_LINES_TO_DROP = 5

# ``st.cache`` is deprecated in newer Streamlit releases in favour of
# ``st.cache_resource`` (intended for shared, unhashable objects such as
# models); fall back to the old decorator on versions that predate it.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Load the tokenizer and language model of the song-writer checkpoint.

    The result is cached by Streamlit, so script reruns reuse the already
    loaded objects instead of loading the checkpoint again.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    model_ckpt = "bankholdup/rugpt3_song_writer"
    tokenizer = GPT2Tokenizer.from_pretrained(model_ckpt)
    model = GPT2LMHeadModel.from_pretrained(model_ckpt)
    return tokenizer, model


def set_seed(rng=100000):
    """Seed NumPy and PyTorch with a freshly drawn random seed.

    A seed in ``[0, rng)`` is drawn from NumPy's global generator and then
    used to seed both NumPy and ``torch``.

    Args:
        rng: Exclusive upper bound for the drawn seed.
    """
    rd = np.random.randint(rng)
    np.random.seed(rd)
    torch.manual_seed(rd)


def _generation_kwargs(temperature):
    """Return the decoding options passed to ``model.generate``.

    Sampling requires a strictly positive temperature (``generate`` fails at
    0), so a non-positive temperature selects greedy decoding instead, which
    is the limit of sampling as the temperature approaches 0.
    """
    if temperature > 0:
        return {
            "do_sample": True,
            "temperature": temperature,
            "top_k": 50,
            "top_p": 0.95,
        }
    return {"do_sample": False}


def _generate_songs(tokenizer, model, prompt_text, temperature):
    """Generate continuations of ``prompt_text``.

    Args:
        tokenizer: Tokenizer returned by ``load_model``.
        model: Language model returned by ``load_model``.
        prompt_text: Non-empty opening of the song typed by the user.
        temperature: Sampling temperature; ``<= 0`` means greedy decoding.

    Returns:
        A list of strings, each the user's prompt followed by the generated
        continuation.
    """
    encoded_prompt = tokenizer.encode(
        prompt_text, add_special_tokens=False, return_tensors="pt"
    )
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=MAX_NEW_TOKENS + len(encoded_prompt[0]),
        repetition_penalty=1.0,
        num_return_sequences=1,
        **_generation_kwargs(temperature),
    )
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    # The decoded output starts with the prompt; cut it off by the length of
    # the decoded prompt (computed once, outside the loop) and prepend the
    # user's exact prompt text instead.
    prompt_length = len(
        tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)
    )
    songs = []
    for sequence in output_sequences:
        text = tokenizer.decode(sequence.tolist(), clean_up_tokenization_spaces=True)
        songs.append(prompt_text + text[prompt_length:])
    return songs


def _visible_lines(text, drop=TRAILING_LINES_TO_DROP):
    """Split ``text`` into lines, hiding the last ``drop`` of them.

    If hiding would leave nothing (a short song), every line is returned so
    the user never sees an empty result.
    """
    lines = text.splitlines()
    kept = lines[: max(len(lines) - drop, 0)]
    return kept or lines


def _render_song(text):
    """Write the visible lines of ``text`` to the Streamlit page."""
    for line in _visible_lines(text):
        # Lines containing "[" (presumably section tags such as "[Припев]")
        # are rendered as blank separators instead of being printed.
        st.write("\n" if "[" in line else line)


title = st.title("Загрузка модели")
tokenizer, model = load_model()
title.title("ruGPT3 Song Writer")
context = st.text_input("Введите начало песни", "Как дела? Как дела? Это новый кадиллак")
temperature = st.slider("temperature", 0.0, 2.5, 0.95)

if st.button("Поехали", help="Может занять какое-то время"):
    set_seed()
    if not context:
        # An empty prompt encodes to zero tokens, which ``generate`` cannot
        # start from.
        st.warning("Введите начало песни")
    else:
        st.write("Генерируем...")
        st.write("temperature = {}".format(temperature))
        st.write("_____________")
        for song in _generate_songs(tokenizer, model, context, temperature):
            _render_song(song)