m-ric HF staff commited on
Commit
50809fa
·
verified ·
1 Parent(s): 5e72e33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -32
app.py CHANGED
@@ -30,14 +30,12 @@ STYLE = """
30
  .prose table {
31
  margin-bottom: 0!important;
32
  }
33
-
34
  .prose td, th {
35
  padding-left: 2px;
36
  padding-right: 2px;
37
  padding-top: 0;
38
  padding-bottom: 0;
39
  }
40
-
41
  .tree {
42
  padding: 0px;
43
  margin: 0!important;
@@ -48,13 +46,11 @@ STYLE = """
48
  text-align: center;
49
  display:inline-block;
50
  }
51
-
52
  #root {
53
  display: inline-grid!important;
54
  width:auto!important;
55
  min-width: 220px;
56
  }
57
-
58
  .tree ul {
59
  padding-left: 20px;
60
  position: relative;
@@ -75,7 +71,6 @@ STYLE = """
75
  justify-content: start;
76
  align-items: center;
77
  }
78
-
79
  .tree li::before, .tree li::after {
80
  content: "";
81
  position: absolute;
@@ -96,7 +91,6 @@ STYLE = """
96
  .tree li:only-child::after, li:only-child::before {
97
  display: none;
98
  }
99
-
100
  .tree li:first-child::before, .tree li:last-child::after {
101
  border: 0 none;
102
  }
@@ -111,7 +105,6 @@ STYLE = """
111
  -webkit-border-radius: 5px 0 0 0;
112
  -moz-border-radius: 5px 0 0 0;
113
  }
114
-
115
  .tree ul ul::before {
116
  content: "";
117
  position: absolute;
@@ -124,7 +117,6 @@ STYLE = """
124
  .tree ul:has(> li:only-child)::before {
125
  width:40px;
126
  }
127
-
128
  a:before {
129
  border-right: 1px solid var(--body-text-color);
130
  border-bottom: 1px solid var(--body-text-color);
@@ -138,8 +130,6 @@ a:before {
138
  margin-left: 6px;
139
  transform: rotate(315deg);
140
  }
141
-
142
-
143
  .tree li a {
144
  border: 1px solid var(--body-text-color);
145
  padding: 5px;
@@ -155,7 +145,6 @@ a:before {
155
  .tree li a span {
156
  padding: 5px;
157
  font-size: 12px;
158
- text-transform: uppercase;
159
  letter-spacing: 1px;
160
  font-weight: 500;
161
  }
@@ -166,7 +155,7 @@ a:before {
166
  .tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before {
167
  border-color: #7c2d12;
168
  }
169
- .chosen {
170
  background-color: #ea580c;
171
  width:auto!important;
172
  }
@@ -206,33 +195,37 @@ def generate_markdown_table(
206
  def generate_nodes(token_ix, node, step):
207
  """Recursively generate HTML for the tree nodes."""
208
  token = tokenizer.decode([token_ix])
209
- html_content = f" <li> <a href='#' class='{('chosen' if node.table is None else '')}'> <span> <b>{token_ix}:<br>{clean(token)}</b> </span> "
 
 
 
 
 
 
210
  if node.table is not None:
211
  html_content += node.table
212
  html_content += "</a>"
 
213
  if len(node.children.keys()) > 0:
214
  html_content += "<ul> "
215
  for token_ix, subnode in node.children.items():
216
  html_content += generate_nodes(token_ix, subnode, step=step + 1)
217
  html_content += "</ul>"
218
  html_content += "</li>"
 
219
  return html_content
220
 
221
 
222
  def generate_html(start_sentence, original_tree):
223
-
224
  html_output = f"""<div class="custom-container">
225
  <div class="tree">
226
- <ul>
227
- <li> <a href='#' id='root'> <span> <b>{start_sentence}</b> </span> {original_tree.table} </a>"""
228
- if len(original_tree.children.keys()) > 0:
229
- html_output += "<ul> "
230
- for token_ix, subnode in original_tree.children.items():
231
- html_output += generate_nodes(token_ix, subnode, step=1)
232
- html_output += "</ul>"
233
-
234
  html_output += """
235
- </ul>
236
  </div>
237
  </body>
238
  """
@@ -246,11 +239,14 @@ from dataclasses import dataclass
246
 
247
  @dataclass
248
  class BeamNode:
 
249
  cumulative_score: float
250
  children_score_divider: float
251
  table: str
252
  current_sentence: str
253
  children: Dict[int, "BeamNode"]
 
 
254
 
255
 
256
  def generate_beams(start_sentence, scores, sequences, length_penalty):
@@ -258,13 +254,19 @@ def generate_beams(start_sentence, scores, sequences, length_penalty):
258
  input_length = len(tokenizer([start_sentence], return_tensors="pt"))
259
  original_tree = BeamNode(
260
  cumulative_score=0,
 
261
  table=None,
262
  current_sentence=start_sentence,
263
  children={},
264
  children_score_divider=((input_length + 1) ** length_penalty),
 
 
265
  )
266
  n_beams = len(scores[0])
267
  beam_trees = [original_tree] * n_beams
 
 
 
268
  for step, step_scores in enumerate(scores):
269
  (
270
  top_token_indexes,
@@ -273,8 +275,13 @@ def generate_beams(start_sentence, scores, sequences, length_penalty):
273
  current_completions,
274
  top_tokens,
275
  ) = ([], [], [], [], [])
276
- for beam_ix in range(n_beams):
277
  current_beam = beam_trees[beam_ix]
 
 
 
 
 
278
  # Get top cumulative scores for the current beam
279
  current_top_token_indexes = list(
280
  np.array(scores[step][beam_ix].argsort()[-n_beams:])[::-1]
@@ -337,14 +344,31 @@ def generate_beams(start_sentence, scores, sequences, length_penalty):
337
  + scores[step][source_beam_ix][current_token_choice_ix].numpy()
338
  )
339
  beam_trees[source_beam_ix].children[current_token_choice_ix] = BeamNode(
 
340
  table=None,
341
  children={},
342
  current_sentence=beam_trees[source_beam_ix].current_sentence
343
  + current_token_choice,
344
  cumulative_score=cumulative_score,
 
 
345
  children_score_divider=((input_length + step + 1) ** length_penalty),
 
 
 
 
346
  )
347
 
 
 
 
 
 
 
 
 
 
 
348
  # Reassign all beams at once
349
  beam_trees = [
350
  beam_trees[int(top_df_selected.iloc[beam_ix]["beam_index"])]
@@ -355,6 +379,7 @@ def generate_beams(start_sentence, scores, sequences, length_penalty):
355
  for beam_ix in range(n_beams):
356
  current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
357
  beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice_ix]
 
358
 
359
  return original_tree
360
 
@@ -373,9 +398,10 @@ def get_beam_search_html(input_text, number_steps, number_beams, length_penalty)
373
  do_sample=False,
374
  )
375
  markdown = "Output sequences:"
 
376
  decoded_sequences = tokenizer.batch_decode(outputs.sequences)
377
  for i, sequence in enumerate(decoded_sequences):
378
- markdown += f"\n- {clean(sequence.replace('<s> ', ''))} (score {outputs.sequences_scores[i]:.2f})"
379
 
380
  original_tree = generate_beams(
381
  input_text,
@@ -393,7 +419,8 @@ with gr.Blocks(
393
  ),
394
  css=STYLE,
395
  ) as demo:
396
- gr.Markdown("""# Beam search visualizer
 
397
 
398
  Play with the parameters below to understand how beam search decoding works!
399
 
@@ -402,15 +429,29 @@ Play with the parameters below to understand how beam search decoding works!
402
  - **Number of steps**: the number of tokens to generate
403
  - **Number of beams**: the number of beams to use
404
  - **Length penalty**: the length penalty to apply to outputs. `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
405
- """)
406
- text = gr.Textbox(label="Sentence to decode from", value="Conclusion: thanks a lot. This article was originally published on")
 
 
 
 
407
  with gr.Row():
408
- steps = gr.Slider(label="Number of steps", minimum=1, maximum=8, step=1, value=4)
409
- beams = gr.Slider(label="Number of beams", minimum=2, maximum=4, step=1, value=3)
410
- length_penalty = gr.Slider(label="Length penalty", minimum=-4, maximum=4, step=0.5, value=1)
 
 
 
 
 
 
411
  button = gr.Button()
412
  out_html = gr.Markdown()
413
  out_markdown = gr.Markdown()
414
- button.click(get_beam_search_html, inputs=[text, steps, beams, length_penalty], outputs=[out_html, out_markdown])
 
 
 
 
415
 
416
  demo.launch()
 
30
  .prose table {
31
  margin-bottom: 0!important;
32
  }
 
33
  .prose td, th {
34
  padding-left: 2px;
35
  padding-right: 2px;
36
  padding-top: 0;
37
  padding-bottom: 0;
38
  }
 
39
  .tree {
40
  padding: 0px;
41
  margin: 0!important;
 
46
  text-align: center;
47
  display:inline-block;
48
  }
 
49
  #root {
50
  display: inline-grid!important;
51
  width:auto!important;
52
  min-width: 220px;
53
  }
 
54
  .tree ul {
55
  padding-left: 20px;
56
  position: relative;
 
71
  justify-content: start;
72
  align-items: center;
73
  }
 
74
  .tree li::before, .tree li::after {
75
  content: "";
76
  position: absolute;
 
91
  .tree li:only-child::after, li:only-child::before {
92
  display: none;
93
  }
 
94
  .tree li:first-child::before, .tree li:last-child::after {
95
  border: 0 none;
96
  }
 
105
  -webkit-border-radius: 5px 0 0 0;
106
  -moz-border-radius: 5px 0 0 0;
107
  }
 
108
  .tree ul ul::before {
109
  content: "";
110
  position: absolute;
 
117
  .tree ul:has(> li:only-child)::before {
118
  width:40px;
119
  }
 
120
  a:before {
121
  border-right: 1px solid var(--body-text-color);
122
  border-bottom: 1px solid var(--body-text-color);
 
130
  margin-left: 6px;
131
  transform: rotate(315deg);
132
  }
 
 
133
  .tree li a {
134
  border: 1px solid var(--body-text-color);
135
  padding: 5px;
 
145
  .tree li a span {
146
  padding: 5px;
147
  font-size: 12px;
 
148
  letter-spacing: 1px;
149
  font-weight: 500;
150
  }
 
155
  .tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before {
156
  border-color: #7c2d12;
157
  }
158
+ .end-of-text, .chosen {
159
  background-color: #ea580c;
160
  width:auto!important;
161
  }
 
195
  def generate_nodes(token_ix, node, step):
196
  """Recursively generate HTML for the tree nodes."""
197
  token = tokenizer.decode([token_ix])
198
+
199
+ if node.is_final:
200
+ return f"<li> <a href='#' class='end-of-text'> <span> <b>{token_ix}:<br>{clean(token)}</b> <br> Total score: {node.total_score:.2f} </span> </a> </li>"
201
+
202
+ html_content = (
203
+ f"<li> <a href='#'> <span> <b>{token_ix}:<br>{clean(token)}</b> </span>"
204
+ )
205
  if node.table is not None:
206
  html_content += node.table
207
  html_content += "</a>"
208
+
209
  if len(node.children.keys()) > 0:
210
  html_content += "<ul> "
211
  for token_ix, subnode in node.children.items():
212
  html_content += generate_nodes(token_ix, subnode, step=step + 1)
213
  html_content += "</ul>"
214
  html_content += "</li>"
215
+
216
  return html_content
217
 
218
 
219
  def generate_html(start_sentence, original_tree):
 
220
  html_output = f"""<div class="custom-container">
221
  <div class="tree">
222
+ <ul> <li> <a href='#' id='root'> <span> <b>{start_sentence}</b> </span> {original_tree.table} </a>"""
223
+ html_output += "<ul> "
224
+ for token_ix, subnode in original_tree.children.items():
225
+ html_output += generate_nodes(token_ix, subnode, step=1)
226
+ html_output += "</ul>"
 
 
 
227
  html_output += """
228
+ </li> </ul>
229
  </div>
230
  </body>
231
  """
 
239
 
240
  @dataclass
241
  class BeamNode:
242
+ current_token_ix: int
243
  cumulative_score: float
244
  children_score_divider: float
245
  table: str
246
  current_sentence: str
247
  children: Dict[int, "BeamNode"]
248
+ total_score: float
249
+ is_final: bool
250
 
251
 
252
  def generate_beams(start_sentence, scores, sequences, length_penalty):
 
254
  input_length = len(tokenizer([start_sentence], return_tensors="pt"))
255
  original_tree = BeamNode(
256
  cumulative_score=0,
257
+ current_token_ix=None,
258
  table=None,
259
  current_sentence=start_sentence,
260
  children={},
261
  children_score_divider=((input_length + 1) ** length_penalty),
262
+ total_score=None,
263
+ is_final=False,
264
  )
265
  n_beams = len(scores[0])
266
  beam_trees = [original_tree] * n_beams
267
+
268
+ candidate_nodes = []
269
+
270
  for step, step_scores in enumerate(scores):
271
  (
272
  top_token_indexes,
 
275
  current_completions,
276
  top_tokens,
277
  ) = ([], [], [], [], [])
278
+ for beam_ix in range(n_beams): # Get possible descendants for each beam
279
  current_beam = beam_trees[beam_ix]
280
+
281
+ # skip if the beam is already final
282
+ if current_beam.is_final:
283
+ continue
284
+
285
  # Get top cumulative scores for the current beam
286
  current_top_token_indexes = list(
287
  np.array(scores[step][beam_ix].argsort()[-n_beams:])[::-1]
 
344
  + scores[step][source_beam_ix][current_token_choice_ix].numpy()
345
  )
346
  beam_trees[source_beam_ix].children[current_token_choice_ix] = BeamNode(
347
+ current_token_ix=current_token_choice_ix,
348
  table=None,
349
  children={},
350
  current_sentence=beam_trees[source_beam_ix].current_sentence
351
  + current_token_choice,
352
  cumulative_score=cumulative_score,
353
+ total_score=cumulative_score
354
+ / ((input_length + step - 1) ** length_penalty),
355
  children_score_divider=((input_length + step + 1) ** length_penalty),
356
+ is_final=(
357
+ step == len(scores) - 1
358
+ or current_token_choice_ix == tokenizer.eos_token_id
359
+ ),
360
  )
361
 
362
+ # Check this child should be selected as a top beam.
363
+ # Is it a final step or an EOS token?
364
+ if (
365
+ step == len(scores) - 1
366
+ or current_token_choice_ix == tokenizer.eos_token_id
367
+ ):
368
+ candidate_nodes.append(
369
+ beam_trees[source_beam_ix].children[current_token_choice_ix]
370
+ )
371
+
372
  # Reassign all beams at once
373
  beam_trees = [
374
  beam_trees[int(top_df_selected.iloc[beam_ix]["beam_index"])]
 
379
  for beam_ix in range(n_beams):
380
  current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
381
  beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice_ix]
382
+ print("Final nodes", candidate_nodes)
383
 
384
  return original_tree
385
 
 
398
  do_sample=False,
399
  )
400
  markdown = "Output sequences:"
401
+ # Sequences are padded anyway so you can batch decode them
402
  decoded_sequences = tokenizer.batch_decode(outputs.sequences)
403
  for i, sequence in enumerate(decoded_sequences):
404
+ markdown += f"\n- '{clean(sequence.replace('<s> ', ''))}' (score {outputs.sequences_scores[i]:.2f})"
405
 
406
  original_tree = generate_beams(
407
  input_text,
 
419
  ),
420
  css=STYLE,
421
  ) as demo:
422
+ gr.Markdown(
423
+ """# Beam search visualizer
424
 
425
  Play with the parameters below to understand how beam search decoding works!
426
 
 
429
  - **Number of steps**: the number of tokens to generate
430
  - **Number of beams**: the number of beams to use
431
  - **Length penalty**: the length penalty to apply to outputs. `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter sequences.
432
+ """
433
+ )
434
+ text = gr.Textbox(
435
+ label="Sentence to decode from",
436
+ value="Conclusion: thanks a lot. This article was originally published on",
437
+ )
438
  with gr.Row():
439
+ steps = gr.Slider(
440
+ label="Number of steps", minimum=1, maximum=8, step=1, value=4
441
+ )
442
+ beams = gr.Slider(
443
+ label="Number of beams", minimum=2, maximum=4, step=1, value=3
444
+ )
445
+ length_penalty = gr.Slider(
446
+ label="Length penalty", minimum=-4, maximum=4, step=0.5, value=1
447
+ )
448
  button = gr.Button()
449
  out_html = gr.Markdown()
450
  out_markdown = gr.Markdown()
451
+ button.click(
452
+ get_beam_search_html,
453
+ inputs=[text, steps, beams, length_penalty],
454
+ outputs=[out_html, out_markdown],
455
+ )
456
 
457
  demo.launch()