bankholdup commited on
Commit
67c2dd6
1 Parent(s): 2d36036

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -21
app.py CHANGED
@@ -1,7 +1,17 @@
 
1
  import numpy as np
2
  import torch
3
- import gradio as gr
4
- from gradio import mix
 
 
 
 
 
 
 
 
 
5
 
6
  def set_seed(args):
7
  rd = np.random.randint(100000)
@@ -11,23 +21,26 @@ def set_seed(args):
11
  if args.n_gpu > 0:
12
  torch.cuda.manual_seed_all(rd)
13
 
14
- title = "ruGPT3 Song Writer"
15
- description = "Generate russian songs via fine-tuned ruGPT3"
16
-
17
- io = gr.Interface.load("models/bankholdup/rugpt3_song_writer")
18
-
19
- examples = [
20
- ['Как дела? Как дела? Это новый кадиллак']
21
- ]
22
-
23
- def inference(text):
24
- return io(text)
25
 
26
- gr.Interface(
27
- inference,
28
- [gr.inputs.Textbox(label="Input text")],
29
- gr.outputs.Textbox(label="Output text"),
30
- examples=examples,
31
- title=title,
32
- description=description,
33
- ).launch(enable_queue=True,cache_examples=True)
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
  import numpy as np
3
  import torch
4
+ import streamlit as st
5
+
6
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
7
+ from transformers import pipeline
8
+
9
+ @st.cache(allow_output_mutation=True)
10
+ def load_model():
11
+ model_ckpt = "bankholdup/rugpt3_song_writer"
12
+ tokenizer = GPT2Tokenizer.from_pretrained(model_ckpt,from_flax=True)
13
+ model = GPT2LMHeadModel.from_pretrained(model_ckpt,from_flax=True)
14
+ return tokenizer, model
15
 
16
  def set_seed(args):
17
  rd = np.random.randint(100000)
 
21
  if args.n_gpu > 0:
22
  torch.cuda.manual_seed_all(rd)
23
 
24
+ title = st.title("Loading model")
25
+ tokenizer, model = load_model()
26
+ text_generation = pipeline("text-generation", model=model, tokenizer=tokenizer)
27
+ title.title("ruGPT3 Song Writer")
28
+ context = st.text_input("Введите начало песни", "Как дела? Как дела? Это новый кадиллак")
 
 
 
 
 
 
29
 
30
+ if st.button("Поехали", help="Может занять какое-то время"):
31
+ st.title(f"{context}")
32
+ prefix_text = f"{context}"
33
+ generated_song = text_generation(prefix_text, max_length=200, temperature=0.9, k=50, do_sample=True)[0]
34
+ for count, line in enumerate(generated_song['generated_text'].split("\n")):
35
+ if"<EOS>" in line:
36
+ break
37
+ if count == 0:
38
+ st.markdown(f"**{line[line.find('['):]}**")
39
+ continue
40
+ if "<BOS>" in line:
41
+ st.write(line[5:])
42
+ continue
43
+ if line.startswith("["):
44
+ st.markdown(f"**{line}**")
45
+ continue
46
+ st.write(line)