import numpy as np
import torch
import streamlit as st

from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Cache the tokenizer and model across Streamlit reruns so the checkpoint is
# downloaded and loaded only once (st.cache with allow_output_mutation was the
# caching idiom before st.cache_resource existed).
@st.cache(allow_output_mutation=True)
def load_model():
    model_ckpt = "bankholdup/rugpt3_song_writer"
    tokenizer = GPT2Tokenizer.from_pretrained(model_ckpt)
    model = GPT2LMHeadModel.from_pretrained(model_ckpt)
    return tokenizer, model

def set_seed(max_seed=100000):
    # Pick a fresh random seed on every run so repeated generations differ.
    seed = np.random.randint(max_seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

title = st.title("Loading model...")  # placeholder shown while the checkpoint loads
tokenizer, model = load_model()
title.title("ruGPT3 Song Writer")  # replace the placeholder once loading is done
# The model writes Russian lyrics, so the default prompt stays in Russian;
# it translates to "Dad comes home sober, a loaf of bread in his hand".
context = st.text_input("Enter the opening line of the song", "Батя возвращается трезвым, в руке буханка")
temperature = st.slider("temperature (higher makes the text wilder, lower stays closer to the training data)", 0.0, 2.5, 0.95)

if st.button("Go", help="This may take a while"):
    set_seed()
    st.write("Generating...")
    st.write("temperature = {}".format(temperature))
    st.write("_____________")
    prompt_text = f"{context}"
    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
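    # Sampling background (editor's note): with do_sample=True, generate()
    # samples from the model's output distribution instead of taking the
    # argmax. top_k=50 keeps only the 50 most likely tokens at each step,
    # top_p=0.95 further restricts sampling to the smallest token set covering
    # 95% of the probability mass, and temperature rescales the logits first.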
    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=200 + len(encoded_prompt[0]),  # up to 200 new tokens
        temperature=temperature,
        top_k=50,
        top_p=0.95,
        repetition_penalty=1.0,
        do_sample=True,
        num_return_sequences=1,
    )
    # Remove the extra batch dimension if generate() returned one.
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()
        
    for generated_sequence in output_sequences:
        generated_sequence = generated_sequence.tolist()
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # Strip the echoed prompt from the decoded text, then prepend the
        # user's original prompt so it is shown exactly as typed.
        total_sequence = (
            prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)):]
        )

        # Print the lyrics line by line. The last five lines are dropped
        # because sampling usually stops mid-sentence, and section markers
        # such as "[Припев]" ("[Chorus]") are replaced with a blank line.
        splits = total_sequence.splitlines()
        for line in splits[:-5]:
            if "[" in line:
                st.write("\n")
                continue
            st.write(line)
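# Usage (a minimal sketch, assuming this script is saved as app.py and that
# the streamlit, torch, and transformers packages are installed):
#
#     pip install streamlit torch transformers
#     streamlit run app.py
#
# Streamlit then serves the app locally, by default at http://localhost:8501.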