Fraser committed
Commit 7bbddfb • 1 Parent(s): 5f81dcb

add wiki model

Files changed (4):
  1. app.py +78 -2
  2. assets/autoencoder.png +0 -0
  3. assets/t5-vae.png +0 -0
  4. info.py +5 -0
app.py CHANGED
@@ -3,9 +3,27 @@ import jax.numpy as jnp
 from transformers import AutoTokenizer
 from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
 from t5_vae_flax_alt.src.t5_vae import FlaxT5VaeForAutoencoding
+import info
 
 
-st.title('T5-VAE')
+st.set_page_config(
+    page_title="T5-VAE",
+    page_icon="🙂😐🙁",
+    layout="wide",
+    initial_sidebar_state="expanded"
+)
+
+
+st.title('T5-VAE 🙂😐🙁')
+
+st.markdown('''
+This is a variational autoencoder trained on text.
+
+It allows interpolating between texts at a high level. Try it out!
+
+See how it works [here](http://fras.uk/ml/large%20prior-free%20models/transformer-vae/2020/08/13/Transformers-as-Variational-Autoencoders.html).
+''')
+
 st.text('''
 Try interpolating between lines of Python code using this T5-VAE.
 ''')
@@ -79,11 +97,13 @@ def slerp(ratio, t1, t2):
     return res
 
 
-def decode(ratio, txt_1, txt_2):
+def decode(cnt, ratio, txt_1, txt_2):
     if not txt_1 or not txt_2:
         return ''
+    cnt.write('Getting latents...')
     lt_1, lt_2 = get_latent(txt_1), get_latent(txt_2)
     lt_new = slerp(ratio, lt_1, lt_2)
+    cnt.write('Decoding latent...')
     tkns = tokens_from_latent(lt_new)
     return tokenizer.decode(tkns.sequences[0], skip_special_tokens=True)
 
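For context on the hunk above: `decode` now threads the Streamlit container (`cnt`) through so progress can be reported while the model runs. The `slerp` it calls is defined just above this hunk, so its body is not part of the diff; a typical JAX spherical linear interpolation looks roughly like this (a sketch of the standard technique, not necessarily the repo's exact implementation):

```python
import jax.numpy as jnp

def slerp(ratio, t1, t2):
    # Sketch of standard spherical linear interpolation between two
    # latent vectors; the app's real implementation may differ.
    v1 = t1.flatten() / jnp.linalg.norm(t1)
    v2 = t2.flatten() / jnp.linalg.norm(t2)
    # Angle between the two unit-normalised latents.
    omega = jnp.arccos(jnp.clip(jnp.dot(v1, v2), -1.0, 1.0))
    so = jnp.sin(omega)
    # Interpolate along the great circle from t1 (ratio=0) to t2 (ratio=1).
    res = (jnp.sin((1.0 - ratio) * omega) / so) * t1 \
        + (jnp.sin(ratio * omega) / so) * t2
    return res
```

Unlike straight linear interpolation, slerp keeps the intermediate points at a plausible distance from the origin, which tends to keep decoded text coherent for Gaussian-prior latents.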
@@ -93,6 +113,62 @@ in_2 = st.text_input("Another line of Python code.", "x = a + 10 * 2")
 r = st.slider('Interpolation Ratio', min_value=0.0, max_value=1.0, value=0.5)
 container = st.empty()
 container.write('Loading...')
+out = decode(container, r, in_1, in_2)
+container.empty()
+st.write(out)
+
+
+st.text('''
+Try interpolating between sentences from Wikipedia using this T5-VAE.
+''')
+
+
+@st.cache(allow_output_mutation=True)
+def get_wiki_model():
+    tokenizer = AutoTokenizer.from_pretrained("t5-base")
+    model = FlaxT5VaeForAutoencoding.from_pretrained("flax-community/t5-vae-wiki")
+    assert model.params['t5']['shared']['embedding'].shape[0] == len(tokenizer), "T5 Tokenizer doesn't match T5Vae embedding size."
+    return model, tokenizer
+
+
+model, tokenizer = get_wiki_model()
+
+
+in_1 = st.text_input("A sentence.", "Children are looking for the water to be clear.")
+in_2 = st.text_input("Another sentence.", "There are two people playing soccer.")
+r = st.slider('Interpolation Ratio', min_value=0.0, max_value=1.0, value=0.5, key='wiki_ratio')
+container = st.empty()
+container.write('Loading...')
-out = decode(r, in_1, in_2)
+out = decode(container, r, in_1, in_2)
 container.empty()
 st.write(out)
+
+
+st.text('''
+Try arithmetic in latent space.
+''')
+
+
+def arithmetic(cnt, txt_a, txt_b, txt_c):
+    if not txt_a or not txt_b or not txt_c:
+        return ''
+    cnt.write('getting latents...')
+    lt_a, lt_b, lt_c = get_latent(txt_a), get_latent(txt_b), get_latent(txt_c)
+    lt_d = lt_c + (lt_b - lt_a)
+    cnt.write('decoding C + (B - A)...')
+    tkns = tokens_from_latent(lt_d)
+    return tokenizer.decode(tkns.sequences[0], skip_special_tokens=True)
+
+
+in_a = st.text_input("A", "Children are looking for the water to be clear.")
+in_b = st.text_input("B", "There are two people playing soccer.")
+in_c = st.text_input("C", "Children are looking for the water to be clear.")
+
+st.text('''
+A is to B as C is to...
+''')
+container = st.empty()
+container.write('Loading...')
+out = arithmetic(container, in_a, in_b, in_c)
+container.empty()
+st.write(out)
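The `arithmetic` helper in the hunk above performs the classic analogy offset in latent space: the vector that carries A to B is applied to C, i.e. D = C + (B - A), and D is then decoded back to text. A toy, model-free illustration of that vector operation (values are made up; in the app the latents come from `get_latent`):

```python
# Toy illustration of the analogy offset used by `arithmetic` above:
# D = C + (B - A). Values are hypothetical; real latents are model outputs.
import jax.numpy as jnp

lt_a = jnp.array([1.0, 0.0, 0.0])   # latent for text A
lt_b = jnp.array([1.0, 1.0, 0.0])   # latent for text B
lt_c = jnp.array([0.0, 0.0, 1.0])   # latent for text C

# Apply the offset that maps A to B onto C.
lt_d = lt_c + (lt_b - lt_a)
print(lt_d)  # [0. 1. 1.] -> in the app, decoded via tokens_from_latent()
```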
assets/autoencoder.png ADDED
assets/t5-vae.png ADDED
info.py ADDED
@@ -0,0 +1,5 @@
+
+BACKGROUND = """
+
+
+"""