# Streamlit app: generate Russian song lyrics with ruGPT3 ("bankholdup/rugpt3_song_writer").
import transformers
import numpy as np
import torch
import streamlit as st

from transformers import GPT2Tokenizer, GPT2LMHeadModel

@st.cache(allow_output_mutation=True)
def load_model():
  """Download (once, then serve from Streamlit's cache) the ruGPT3
  song-writer tokenizer and model from the HuggingFace hub."""
  checkpoint = "bankholdup/rugpt3_song_writer"
  return (
    GPT2Tokenizer.from_pretrained(checkpoint),
    GPT2LMHeadModel.from_pretrained(checkpoint),
  )

def set_seed(rng: int = 100000) -> int:
    """Seed numpy's and torch's RNGs with a randomly drawn seed.

    A seed is drawn uniformly from [0, rng) using numpy's current RNG
    state, then applied to both numpy and torch so a generation run is
    internally consistent.

    Args:
        rng: Exclusive upper bound for the randomly chosen seed.

    Returns:
        The seed that was applied — handy for logging/reproducing a run.
        (Previously the function returned None; callers that ignore the
        return value are unaffected.)
    """
    rd = int(np.random.randint(rng))
    np.random.seed(rd)
    torch.manual_seed(rd)
    return rd

# Show a temporary heading while the model loads, then swap it for the
# real app title once load_model() returns.
placeholder = st.title("Загрузка модели")
tokenizer, model = load_model()
placeholder.title("ruGPT3 Song Writer")
# Seed line for the song; `context` is read by the generation handler below.
context = st.text_input("Введите начало песни", "Как дела? Как дела? Это новый кадиллак")

if st.button("Поехали", help="Может занять какое-то время"):
    set_seed()
    st.write("Генерируем...")
    # Encode the user prompt; special tokens are omitted so the model
    # continues the text seamlessly. (The original wrapped `context` in a
    # redundant f-string and kept an unused `generated_sequences` list.)
    encoded_prompt = tokenizer.encode(context, add_special_tokens=False, return_tensors="pt")
    output_sequences = model.generate(
            input_ids=encoded_prompt,
            max_length=200 + len(encoded_prompt[0]),  # ~200 new tokens on top of the prompt
            temperature=1.0,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.0,
            do_sample=True,
            num_return_sequences=1,
        )
    # Drop any extra leading dimension generate() may add.
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    for generated_sequence in output_sequences:
        generated_sequence = generated_sequence.tolist()
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # Re-attach the raw prompt and strip the decoded prompt prefix so
        # the user's input is shown exactly as typed.
        prompt_decoded = tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)
        total_sequence = context + text[len(prompt_decoded):]

        # Emit every line except the last (presumably dropped because
        # max_length usually cuts it off mid-sentence — TODO confirm).
        # Lines containing "[" (section markers) become a blank line.
        for line in total_sequence.splitlines()[:-1]:
            if "[" in line:
                st.write("\n")
                continue

            st.write(line)