File size: 2,446 Bytes
7d6f77f
 
 
 
518485c
da9ee26
7d6f77f
 
47d0a4a
7d6f77f
 
724876e
1bc822d
724876e
7d6f77f
724876e
7d6f77f
 
 
 
 
a16dba0
d2e6254
bb72c45
7d6f77f
b86439f
7d6f77f
 
 
 
 
 
d2e6254
1bc822d
7d6f77f
 
 
 
 
e1458fb
7d6f77f
da9ee26
d2e6254
da9ee26
 
7d6f77f
 
 
b86439f
7d6f77f
 
 
ded5ed0
da9ee26
 
4bd4566
 
 
 
 
 
1bc822d
 
 
d2e6254
4bd4566
ded5ed0
da9ee26
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import transformers
import torch
import tokenizers
import streamlit as st
import re
from PIL import Image


@st.cache(
    hash_funcs={
        # Map these objects to a constant so st.cache does not try to hash them.
        tokenizers.Tokenizer: lambda _: None,
        tokenizers.AddedToken: lambda _: None,
        re.Pattern: lambda _: None,
    },
    allow_output_mutation=True,
    suppress_st_warning=True,
)
def get_model(model_name, model_path):
    """Build the GPT-2 tokenizer/model pair and load fine-tuned weights into it.

    Memoized with ``st.cache``, so Streamlit reruns reuse the loaded objects
    instead of reloading them for the same arguments.

    Args:
        model_name: Hugging Face identifier (or local dir) of the base
            GPT-2 model and its tokenizer.
        model_path: path to a ``torch.save``-d state dict produced by
            fine-tuning that base model.

    Returns:
        ``(model, tokenizer)``: the model with weights mapped to CPU and
        switched to eval mode, and the tokenizer whose EOS token is ``[EOS]``.
    """
    tok = transformers.GPT2Tokenizer.from_pretrained(model_name)
    # Make '[EOS]' the end-of-sequence token. It is added to the vocabulary
    # if it is not already there.
    tok.add_special_tokens({'eos_token': '[EOS]'})

    lm = transformers.GPT2LMHeadModel.from_pretrained(model_name)
    # Grow the embedding matrix to cover the (possibly) enlarged vocabulary.
    # Presumably the checkpoint was saved after the same resize; otherwise
    # load_state_dict would fail on a shape mismatch.
    lm.resize_token_embeddings(len(tok))

    # NOTE: torch.load unpickles the file, so only point this at trusted checkpoints.
    state = torch.load(model_path, map_location=torch.device('cpu'))
    lm.load_state_dict(state)
    lm.eval()  # inference only: disables dropout
    return lm, tok


def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, length_of_generated=300,
            skip_special_tokens=True):
    """Continue ``text`` with the language model and return the decoded result.

    Decoding uses sampling combined with beam search (``do_sample=True`` with
    ``num_beams=n_beams``). Generation stops at the tokenizer's EOS token or
    after ``length_of_generated`` new tokens, whichever comes first.

    Args:
        text: prompt string. A newline is appended before encoding.
        model: object exposing a Hugging Face-style ``generate`` method.
        tokenizer: tokenizer with ``encode``, ``decode`` and ``eos_token_id``.
        n_beams: number of beams passed to ``generate``.
        temperature: sampling temperature. Values above 1 flatten the
            next-token distribution.
        top_p: nucleus-sampling probability mass.
        length_of_generated: maximum number of tokens to add after the prompt.
        skip_special_tokens: drop special tokens (e.g. the terminating EOS
            marker) when decoding, so they do not appear in the returned text.

    Returns:
        The first generated sequence decoded to a string. The prompt (with
        its added newline) is included.
    """
    # Start generation on a fresh line after the prompt.
    input_ids = tokenizer.encode(text + '\n', return_tensors="pt")
    length_of_prompt = input_ids.shape[1]
    with torch.no_grad():
        out = model.generate(
            input_ids,
            do_sample=True,
            num_beams=n_beams,
            temperature=temperature,
            top_p=top_p,
            # max_length counts the prompt too, so offset by its length.
            max_length=length_of_prompt + length_of_generated,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Only one sequence is requested, so decode just the first row. The EOS
    # token that ended generation is kept in `out`; skipping special tokens
    # stops it from leaking into the user-visible text as a literal marker.
    return tokenizer.decode(out[0], skip_special_tokens=skip_special_tokens)


# Load the fine-tuned model once. get_model is wrapped in st.cache, so
# Streamlit reruns reuse the same objects.
model, tokenizer = get_model(
    'sberbank-ai/rugpt3medium_based_on_gpt2',
    'korzh-medium_best_eval_loss.bin',
)

# Page header: the banner picture and its caption stand in for a title.
st.image(Image.open('korzh.jpg'), caption='NeuroKorzh')
st.markdown("\n")

text = st.text_input(
    label='Starting point for text generation',
    value='Что делать, Макс?',
)

# st.button is True only on the rerun triggered by the click.
if st.button('Go'):
    with st.spinner("Generation in progress"):
        result = predict(text, model, tokenizer)

    # Show the whole generated text (prompt included) in one tall box.
    st.text_area(label='', value=result, height=1200)