import transformers
import numpy as np
import torch
import streamlit as st
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Hugging Face checkpoint loaded by this app.
MODEL_CKPT = "bankholdup/rugpt3_song_writer"
# Number of tokens sampled on top of the prompt length.
MAX_NEW_TOKENS = 250


def _resource_cache(func):
    """Cache *func*'s return value across Streamlit reruns.

    Uses ``st.cache_resource`` when the installed Streamlit provides it and
    falls back to the deprecated ``st.cache(allow_output_mutation=True)``
    on older versions, so the model is loaded only once either way.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_resource_cache
def load_model():
    """Load and return ``(tokenizer, model)`` for ``MODEL_CKPT``."""
    tokenizer = GPT2Tokenizer.from_pretrained(MODEL_CKPT)
    model = GPT2LMHeadModel.from_pretrained(MODEL_CKPT)
    return tokenizer, model


def set_seed(rng=100000):
    """Reseed NumPy and PyTorch with a value drawn from ``[0, rng)``.

    The new seed is itself drawn from NumPy's global RNG, so successive
    calls generally use different seeds (and give different songs).
    """
    seed = np.random.randint(rng)
    np.random.seed(seed)
    torch.manual_seed(seed)


def _generate(tokenizer, model, prompt_text):
    """Sample one continuation of *prompt_text*; return prompt + continuation."""
    encoded_prompt = tokenizer.encode(
        prompt_text, add_special_tokens=False, return_tensors="pt"
    )
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=MAX_NEW_TOKENS + len(encoded_prompt[0]),
        temperature=1.0,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
    )
    # Exactly one sequence was requested, so take the first row.
    text = tokenizer.decode(
        output_sequences[0].tolist(), clean_up_tokenization_spaces=True
    )
    # Strip the decoded prompt and re-attach the user's original text, so any
    # tokenizer round-trip differences in the prompt are not shown.
    decoded_prompt = tokenizer.decode(
        encoded_prompt[0], clean_up_tokenization_spaces=True
    )
    return prompt_text + text[len(decoded_prompt):]


def _lines_to_show(song):
    """Return the lines of *song* that should be displayed.

    Generation stops at a token budget, so the final line is usually cut off
    mid-phrase and is dropped. A song consisting of a single line is shown
    as-is, so the user never gets an empty result.
    """
    lines = song.splitlines()
    return lines[:-1] if len(lines) > 1 else lines


title = st.title("Загрузка модели")
tokenizer, model = load_model()
title.title("ruGPT3 Song Writer")

context = st.text_input("Введите начало песни", "Как дела? Как дела? Это новый кадиллак")

if st.button("Поехали", help="Может занять какое-то время"):
    if not context.strip():
        # An empty prompt would hand generate() zero input tokens.
        st.warning("Введите начало песни")
    else:
        set_seed()
        st.write("Генерируем")
        song = _generate(tokenizer, model, context)
        for line in _lines_to_show(song):
            st.write(line)