import transformers
import torch
import tokenizers
import streamlit as st
import re
from PIL import Image
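
# Cache the model and tokenizer across Streamlit reruns. Tokenizer objects and
# compiled regexes are not hashable by st.cache, so hash_funcs maps them to a
# constant to skip hashing.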
@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None, re.Pattern: lambda _: None}, allow_output_mutation=True, suppress_st_warning=True)
def get_model(model_name, model_path):
    # Load the base GPT-2 tokenizer and register a custom end-of-text token
    tokenizer = transformers.GPT2Tokenizer.from_pretrained(model_name)
    tokenizer.add_special_tokens({
        'eos_token': '[EOS]'
    })

    # Load the base model, grow the embedding matrix for the new token,
    # then overwrite the weights with the fine-tuned checkpoint (CPU only)
    model = transformers.GPT2LMHeadModel.from_pretrained(model_name)
    model.resize_token_embeddings(len(tokenizer))
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()

    return model, tokenizer

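# Generate a continuation of `text`: sampling combined with beam search,
# stopping at the custom [EOS] token or after `length_of_generated` new tokens.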
#@st.cache(hash_funcs={tokenizers.Tokenizer: lambda _: None, tokenizers.AddedToken: lambda _: None, re.Pattern: lambda _: None}, allow_output_mutation=True, suppress_st_warning=True)
def predict(text, model, tokenizer, n_beams=5, temperature=2.5, top_p=0.8, length_of_generated=300):
    # Start generation on a new line after the user prompt
    text += '\n'
    input_ids = tokenizer.encode(text, return_tensors="pt")
    length_of_prompt = len(input_ids[0])

    with torch.no_grad():
        out = model.generate(input_ids,
                             do_sample=True,
                             num_beams=n_beams,
                             temperature=temperature,
                             top_p=top_p,
                             max_length=length_of_prompt + length_of_generated,
                             eos_token_id=tokenizer.eos_token_id
                             )

    # generate() returns a batch of sequences; decode and keep the first one
    return list(map(tokenizer.decode, out))[0]

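
# Base ruGPT-3 medium model with fine-tuned weights loaded from a local checkpoint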
model, tokenizer = get_model('sberbank-ai/rugpt3medium_based_on_gpt2', 'korzh-medium_best_eval_loss.bin')
# st.title("NeuroKorzh")
image = Image.open('korzh.jpg')
st.image(image, caption='NeuroKorzh')
st.markdown("\n")
text = st.text_input(label='Starting point for text generation', value='Что делать, Макс?')
button = st.button('Go')
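
# On "Go": run generation with a spinner and show the result in a text area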
if button:
    # try:
    with st.spinner("Generation in progress"):
        result = predict(text, model, tokenizer)

    # st.subheader('Max Korzh:')
    # lines = result.split('\n')
    # for line in lines:
    #     st.write(line)

    # lines = result.replace('\n', '\n\n')
    # st.write(lines)

    st.text_area(label='', value=result, height=1200)
    # except Exception:
    #     st.error("Ooooops, something went wrong. Try again please and report to me, tg: @vladyur")