m-ric HF staff commited on
Commit
ad36776
·
verified ·
1 Parent(s): bc46ee1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +270 -71
app.py CHANGED
@@ -1,79 +1,278 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer
3
 
4
- bert_tokenizer = AutoTokenizer.from_pretrained('openai-community/gpt2')
5
 
6
- def display_next_step_tokens(sentence, step):
7
- return (
8
- gr.Textbox.update(visible=(split_selection==LABEL_RECURSIVE)),
9
- gr.Radio.update(visible=(split_selection==LABEL_RECURSIVE)),
10
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
-
13
- with gr.Blocks(theme=gr.themes.Soft(text_size='lg', font=["monospace"], primary_hue=gr.themes.colors.green)) as demo:
14
- text = gr.Textbox(label="Your prompt to start decoding", value="Ok, I")
15
-
16
- with gr.Row():
17
- split_selection = gr.Dropdown(
18
- choices=[
19
- LABEL_TEXTSPLITTER,
20
- LABEL_RECURSIVE,
21
- ],
22
- value=LABEL_RECURSIVE,
23
- label="Method to split chunks 🍞",
24
- )
25
- separators_selection = gr.Textbox(
26
- elem_id="textbox_id",
27
- value=["\n\n", "\n", " ", ""],
28
- info="Separators used in RecursiveCharacterTextSplitter",
29
- show_label=False, # or set label to an empty string if you want to keep its space
30
- visible=True,
31
- )
32
- separator_preset_selection = gr.Radio(
33
- ['Default', 'Python', 'Markdown'],
34
- label="Choose a preset",
35
- info="This will apply a specific set of separators to RecursiveCharacterTextSplitter.",
36
- visible=True,
37
- )
38
- with gr.Row():
39
- length_unit_selection = gr.Dropdown(
40
- choices=[
41
- "Character count",
42
- "Token count (BERT tokens)",
43
- ],
44
- value="Character count",
45
- label="Length function",
46
- info="How should we measure our chunk lengths?",
47
- )
48
- slider_count = gr.Slider(
49
- 50, 500, value=200, step=1, label="Chunk length 📏", info="In the chosen unit."
50
- )
51
- chunk_overlap = gr.Slider(
52
- 0, 50, value=10, step=1, label="Overlap between chunks", info="In the chosen unit."
53
- )
54
- out = gr.HighlightedText(
55
- label="Output",
56
- show_legend=True,
57
- show_label=False,
58
- color_map={'Overlap': '#DADADA'}
59
- )
60
 
61
- split_selection.change(
62
- fn=change_split_selection,
63
- inputs=split_selection,
64
- outputs=[separators_selection, separator_preset_selection],
65
- )
66
- separator_preset_selection.change(
67
- fn=change_preset_separators,
68
- inputs=separator_preset_selection,
69
- outputs=separators_selection,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  )
71
- gr.on(
72
- [text.change, length_unit_selection.change, separators_selection.change, split_selection.change, slider_count.change, chunk_overlap.change],
73
- chunk,
74
- [text, slider_count, split_selection, separators_selection, length_unit_selection, chunk_overlap],
75
- outputs=out
76
  )
77
- demo.load(chunk, inputs=[text, slider_count, split_selection, separators_selection, length_unit_selection, chunk_overlap], outputs=out)
78
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  demo.launch()
 
1
  import gradio as gr
 
2
 
 
3
 
4
+ STYLE = """
5
+ @import url('https://fonts.googleapis.com/css2?family=Poppins:ital,wght@0,100;0,200;0,300;0,400;0,500;0,600;0,700;0,800;0,900;1,100;1,200;1,300;1,400;1,500;1,600;1,700;1,800;1,900&display=swap');
6
+ * {
7
+ padding: 0px;
8
+ margin: 0px;
9
+ box-sizing: border-box;
10
+ font-size: 16px;
11
+ }
12
+ body {
13
+ height: 100vh;
14
+ width: 100vw;
15
+ display: grid;
16
+ align-items: center;
17
+ font-family: 'Poppins', sans-serif;
18
+ }
19
+ .tree {
20
+ width: 100%;
21
+ height: auto;
22
+ text-align: center;
23
+ }
24
+ .tree ul {
25
+ padding-top: 20px;
26
+ position: relative;
27
+ transition: .5s;
28
+ }
29
+ .tree li {
30
+ display: flex;
31
+ flex-direction:row;
32
+ text-align: center;
33
+ list-style-type: none;
34
+ position: relative;
35
+ padding: 10px;
36
+ transition: .5s;
37
+ }
38
+ .tree li::before, .tree li::after {
39
+ content: '';
40
+ position: absolute;
41
+ top: 0;
42
+ right: 50%;
43
+ border-top: 1px solid #ccc;
44
+ width: 51%;
45
+ height: 10px;
46
+ }
47
+ .tree li::after {
48
+ right: auto;
49
+ left: 50%;
50
+ border-left: 1px solid #ccc;
51
+ }
52
+ .tree li:only-child::after, .tree li:only-child::before {
53
+ display: none;
54
+ }
55
+ .tree li:only-child {
56
+ padding-top: 0;
57
+ }
58
+ .tree li:first-child::before, .tree li:last-child::after {
59
+ border: 0 none;
60
+ }
61
+ .tree li:last-child::before {
62
+ border-right: 1px solid #ccc;
63
+ border-radius: 0 5px 0 0;
64
+ -webkit-border-radius: 0 5px 0 0;
65
+ -moz-border-radius: 0 5px 0 0;
66
+ }
67
+ .tree li:first-child::after {
68
+ border-radius: 5px 0 0 0;
69
+ -webkit-border-radius: 5px 0 0 0;
70
+ -moz-border-radius: 5px 0 0 0;
71
+ }
72
+ .tree ul ul::before {
73
+ content: '';
74
+ position: absolute;
75
+ top: 0;
76
+ left: 50%;
77
+ border-left: 1px solid #ccc;
78
+ width: 0;
79
+ height: 20px;
80
+ }
81
+ .tree li a {
82
+ border: 1px solid #ccc;
83
+ padding: 10px;
84
+ display: inline-grid;
85
+ border-radius: 5px;
86
+ text-decoration-line: none;
87
+ border-radius: 5px;
88
+ transition: .5s;
89
+ }
90
+ .tree li a img {
91
+ width: 50px;
92
+ height: 50px;
93
+ margin-bottom: 10px !important;
94
+ border-radius: 100px;
95
+ margin: auto;
96
+ }
97
+ .tree li a span {
98
+ border: 1px solid #ccc;
99
+ border-radius: 5px;
100
+ color: #666;
101
+ padding: 8px;
102
+ font-size: 12px;
103
+ text-transform: uppercase;
104
+ letter-spacing: 1px;
105
+ font-weight: 500;
106
+ }
107
+ /*Hover-Section*/
108
+ .tree li a:hover, .tree li a:hover i, .tree li a:hover span, .tree li a:hover+ul li a {
109
+ background: #c8e4f8;
110
+ color: #000;
111
+ border: 1px solid #94a0b4;
112
+ }
113
+ .tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before {
114
+ border-color: #94a0b4;
115
+ }
116
+ """
117
 
118
+ from transformers import GPT2Tokenizer, AutoModelForCausalLM
119
+ import numpy as np
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
122
+ model = AutoModelForCausalLM.from_pretrained("gpt2")
123
+ tokenizer.pad_token_id = tokenizer.eos_token_id
124
+
125
+ def display_top_k_tokens(scores, sequences, beam_indices):
126
+ display = "<div style='display: flex; flex-direction:row;'>"
127
+ for i, sequence in enumerate(sequences):
128
+ markdown_table = f"""<p>Sequence {i}: {tokenizer.batch_decode(sequence)}<p><br>
129
+ <table>
130
+ <tr>
131
+ <th><b>Token</b></th>
132
+ <th><b>Probability</b></th>
133
+ </tr>"""
134
+ for step, step_scores in enumerate(scores):
135
+ markdown_table += f"""
136
+ <tr>
137
+ <td><b>Step {step}</b></td>
138
+ <td>=====</td>
139
+ </tr>"""
140
+ current_beam = beam_indices[i, step]
141
+ chosen_token = sequences[i, step]
142
+ for token_idx in np.argsort(step_scores[current_beam, :])[-5:]:
143
+ if token_idx == chosen_token:
144
+ markdown_table += f"""
145
+ <tr style="background-color:red">
146
+ <td>{tokenizer.decode([token_idx])}</td>
147
+ <td>{step_scores[current_beam, token_idx]}</td>
148
+ </tr>"""
149
+ else:
150
+ markdown_table += f"""
151
+ <tr>
152
+ <td>{tokenizer.decode([token_idx])}</td>
153
+ <td>{step_scores[current_beam, token_idx]}</td>
154
+ </tr>"""
155
+ markdown_table += "</table>"
156
+ display += markdown_table
157
+ display += "</div>"
158
+ print(display)
159
+ return display
160
+
161
+
162
+ def generate_html(token, node):
163
+ """Recursively generate HTML for the tree."""
164
+
165
+ html_content = f" <ul> <a href='#'> <p> <b>{token}</b> </p> "
166
+ html_content += node["table"] if node["table"] is not None else ""
167
+ html_content += "</a>"
168
+ if len(node["children"].keys()) > 0:
169
+ html_content += "<li> "
170
+ for token, subnode in node["children"].items():
171
+ html_content += generate_html(token, subnode)
172
+ html_content += "</li>"
173
+
174
+ html_content += "</ul>"
175
+
176
+ return html_content
177
+
178
+
179
+ def generate_markdown_table(scores, top_k=4, chosen_tokens=None):
180
+ markdown_table = """
181
+ <table>
182
+ <tr>
183
+ <th><b>Token</b></th>
184
+ <th><b>Probability</b></th>
185
+ </tr>"""
186
+ for token_idx in np.argsort(scores)[-top_k:]:
187
+ token = tokenizer.decode([token_idx])
188
+ style = ""
189
+ if chosen_tokens and token in chosen_tokens:
190
+ style = "background-color:red"
191
+ markdown_table += f"""
192
+ <tr style={style}>
193
+ <td>{token}</td>
194
+ <td>{scores[token_idx]}</td>
195
+ </tr>"""
196
+ markdown_table += """
197
+ </table>"""
198
+ return markdown_table
199
+
200
+
201
+ def display_tree(scores, sequences, beam_indices):
202
+ display = """<body>
203
+ <div class="container">
204
+ <div class="row">
205
+ <div class="tree">"""
206
+ sequences = sequences.cpu().numpy()
207
+ print(tokenizer.batch_decode(sequences))
208
+ original_tree = {"table": None, "children": {}}
209
+ for sequence_ix in range(len(sequences)):
210
+ current_tree = original_tree
211
+ for step, step_scores in enumerate(scores):
212
+ current_token_choice = tokenizer.decode([sequences[sequence_ix, step]])
213
+ current_beam = beam_indices[sequence_ix, step]
214
+
215
+ if current_token_choice not in current_tree["children"]:
216
+ current_tree["children"][current_token_choice] = {
217
+ "table": None,
218
+ "children": {},
219
+ }
220
+
221
+ # Rewrite the probs table even if it was there before, since new chosen nodes have appeared in the children of current tree
222
+ markdown_table = generate_markdown_table(
223
+ step_scores[current_beam, :],
224
+ chosen_tokens=current_tree["children"].keys(),
225
+ )
226
+ current_tree["table"] = markdown_table
227
+
228
+ current_tree = current_tree["children"][current_token_choice]
229
+
230
+ display += generate_html("Today is", original_tree)
231
+
232
+ display += """
233
+ </div>
234
+ </div>
235
+ </div>
236
+ </body>
237
+ """
238
+ print(display)
239
+ return display
240
+
241
+
242
+ def get_tables(input_text, number_steps, number_beams):
243
+ inputs = tokenizer([input_text], return_tensors="pt")
244
+
245
+ outputs = model.generate(
246
+ **inputs,
247
+ max_new_tokens=number_steps,
248
+ num_beams=number_beams,
249
+ num_return_sequences=number_beams,
250
+ return_dict_in_generate=True,
251
+ output_scores=True,
252
+ top_k=5,
253
+ temperature=1.0,
254
+ do_sample=True,
255
  )
256
+
257
+ tables = display_tree(
258
+ outputs.scores,
259
+ outputs.sequences[:, len(inputs) :],
260
+ outputs.beam_indices[:, : -len(inputs)],
261
  )
262
+ return tables
263
+
264
+
265
+ with gr.Blocks(
266
+ theme=gr.themes.Soft(
267
+ text_size="lg", font=["monospace"], primary_hue=gr.themes.colors.green
268
+ ),
269
+ css=STYLE,
270
+ ) as demo:
271
+ text = gr.Textbox(label="Sentence to decode from🪶", value="Today is")
272
+ steps = gr.Slider(label="Number of steps", minimum=1, maximum=10, step=1, value=4)
273
+ beams = gr.Slider(label="Number of beams", minimum=1, maximum=3, step=1, value=3)
274
+ button = gr.Button()
275
+ out = gr.Markdown(label="Output")
276
+ button.click(get_tables, inputs=[text, steps, beams], outputs=out)
277
+
278
  demo.launch()