vladyur commited on
Commit
7d6f77f
·
1 Parent(s): d2b2606

Create new file

Browse files
Files changed (1) hide show
  1. app.py +48 -0
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import torch
3
+ import tokenizers
4
+ import streamlit as st
5
+
6
+
7
+ @st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None}, suppress_st_warning=True)
8
+ def get_model(model_name, model_path):
9
+ tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_name)
10
+ model = transformers.GPT2LMHeadModel.from_pretrained(model_name)
11
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
12
+ model.eval()
13
+ return model, tokenizer
14
+
15
+
16
+ @st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None}, suppress_st_warning=True)
17
+ def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, max_length=300):
18
+ input_ids = tokenizer.encode(text, return_tensors="pt")
19
+ with torch.no_grad():
20
+ out = model.generate(input_ids,
21
+ do_sample=True,
22
+ num_beams=n_beams,
23
+ temperature=temperature,
24
+ top_p=top_p,
25
+ max_length=max_length,
26
+ )
27
+
28
+ return list(map(tokenizer.decode, out))[0]
29
+
30
+
31
+ model, tokenizer = get_model('sberbank-ai/rugpt3medium_based_on_gpt2', 'SOME_CHECKPOINT.bin')
32
+
33
+ st.title("NeuroKorzh")
34
+ st.markdown("<img width=200px src='https://avatars.yandex.net/get-music-content/2399641/5d26d7e5.p.975699/m1000x1000'>",
35
+ unsafe_allow_html=True)
36
+
37
+ st.markdown("\n")
38
+
39
+ text = st.text_area(label='Starting point for text generation', height=200)
40
+ button = st.button('Go')
41
+
42
+ if button:
43
+ try:
44
+ result = predict(text, model, tokenizer)
45
+ st.subheader('Max Korzh:')
46
+ st.write(result)
47
+ except Exception:
48
+ st.error("Ooooops, something went wrong. Try again please and report to me, tg: @vladyur")