|
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import pandas as pd
import gradio as gr
import spaces
from typing import Dict
from dataclasses import dataclass

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

print("Loading finished.")

print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
|
|
STYLE = """
.custom-container {
    display: grid;
    align-items: center;
    margin: 0!important;
    overflow-y: hidden;
}
.prose ul ul {
    font-size: 10px!important;
}
.prose li {
    margin-bottom: 0!important;
}
.prose table {
    margin-bottom: 0!important;
}
.prose td, th {
    padding-left: 2px;
    padding-right: 2px;
    padding-top: 0;
    padding-bottom: 0;
    text-wrap: nowrap;
}
.tree {
    padding: 0px;
    margin: 0!important;
    box-sizing: border-box;
    font-size: 10px;
    width: 100%;
    height: auto;
    text-align: center;
    display: inline-block;
}
#root {
    display: inline-grid!important;
    width: auto!important;
    min-width: 220px;
}
.tree ul {
    padding-left: 20px;
    position: relative;
    transition: all 0.5s ease 0s;
    display: flex;
    flex-direction: column;
    gap: 10px;
    margin: 0px !important;
}
.tree li {
    display: flex;
    text-align: center;
    list-style-type: none;
    position: relative;
    padding-left: 20px;
    transition: all 0.5s ease 0s;
    flex-direction: row;
    justify-content: start;
    align-items: center;
}
.tree li::before, .tree li::after {
    content: "";
    position: absolute;
    left: 0px;
    border-left: 1px solid var(--body-text-color);
    width: 20px;
}
.tree li::before {
    top: 0;
    height: 50%;
}
.tree li::after {
    top: 50%;
    height: 55%;
    bottom: auto;
    border-top: 1px solid var(--body-text-color);
}
.tree li:only-child::after, .tree li:only-child::before {
    display: none;
}
.tree li:first-child::before, .tree li:last-child::after {
    border: 0 none;
}
.tree li:last-child::before {
    border-bottom: 1px solid var(--body-text-color);
    border-radius: 0px 0px 0px 5px;
    -webkit-border-radius: 0px 0px 0px 5px;
    -moz-border-radius: 0px 0px 0px 5px;
}
.tree li:first-child::after {
    border-radius: 5px 0 0 0;
    -webkit-border-radius: 5px 0 0 0;
    -moz-border-radius: 5px 0 0 0;
}
.tree ul ul::before {
    content: "";
    position: absolute;
    left: 0;
    top: 50%;
    border-top: 1px solid var(--body-text-color);
    width: 20px;
    height: 0;
}
.tree ul:has(> li:only-child)::before {
    width: 40px;
}
.child::before {
    border-right: 2px solid var(--body-text-color);
    border-bottom: 2px solid var(--body-text-color);
    content: "";
    position: absolute;
    width: 10px;
    left: 8px;
    height: 10px;
    top: 50%;
    margin-top: -5px;
    transform: rotate(315deg);
}
.box {
    border: 1px solid var(--body-text-color);
    padding: 5px;
    border-radius: 5px;
    text-decoration-line: none;
    transition: .5s;
    display: flex;
    align-items: center;
    justify-content: space-between;
    overflow: hidden;
    cursor: pointer;
}
.box span {
    padding: 5px;
    font-size: 12px;
    letter-spacing: 1px;
    font-weight: 500;
}
/* Hover section */
.box:hover, .box:hover+ul li .box {
    background: var(--primary-500);
}
.box:hover+ul li::after, .box:hover+ul li::before, .box:hover+ul::before, .box:hover+ul ul::before, .box:hover+ul .box::before {
    border-color: var(--primary-500);
}
.chosen-token {
    background-color: var(--primary-400);
}
.chosen-token td, .chosen-token tr {
    color: black!important;
}
.end-of-text {
    width: auto!important;
}
.nonfinal {
    width: 280px;
    min-width: 280px;
}
.selected-sequence {
    background-color: var(--secondary-500);
}
.nonselected-sequence {
    background-color: var(--primary-500);
}
.nomargin {
    padding-left: 0!important;
}
"""
|
|
|
|
|
def clean(s):
    """Escape newlines and tabs, and strip whitespace, so tokens render on one line."""
    return s.replace("\n", r"\n").replace("\t", r"\t").strip()
|
|
|
|
|
def generate_markdown_table(
    scores, previous_cumul_score, score_divider, top_k=4, chosen_tokens=None
):
    """Build an HTML table of the top_k candidate tokens at one decoding step."""
    markdown_table = """
    <table>
        <tr>
            <th><b>Token</b></th>
            <th><b>Step score</b></th>
            <th><b>Total score</b></th>
        </tr>"""
    for token_idx in np.array(np.argsort(scores)[-top_k:])[::-1]:
        token = tokenizer.decode([token_idx])
        item_class = ""
        if chosen_tokens and token in chosen_tokens:
            item_class = "chosen-token"
        markdown_table += f"""
        <tr class="{item_class}">
            <td>{clean(token)}</td>
            <td>{scores[token_idx]:.4f}</td>
            <td>{(scores[token_idx] + previous_cumul_score) / score_divider:.4f}</td>
        </tr>"""
    markdown_table += """
    </table>"""
    return markdown_table
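
# Note: the "Step score" rendered above is the per-step log-probability that
# `generate(..., output_scores=True)` reports under beam search. A hedged,
# illustrative helper (not used by the app) to view one as a plain probability:
def _step_score_to_probability(step_score: float) -> float:
    """Illustrative only: map a log-probability step score to a probability."""
    return float(np.exp(step_score))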
|
|
|
|
|
def generate_nodes(node, step):
    """Recursively generate HTML for the tree nodes."""
    token = tokenizer.decode([node.current_token_ix])

    if node.is_final:
        if node.is_selected_sequence:
            selected_class = "selected-sequence"
        else:
            selected_class = "nonselected-sequence"
        return f"<li> <div class='box end-of-text child {selected_class}'> <span> <b>{clean(token)}</b> <br>Total score: {node.total_score:.2f}</span> </div> </li>"

    html_content = (
        f"<li> <div class='box nonfinal child'> <span> <b>{clean(token)}</b> </span>"
    )
    if node.table is not None:
        html_content += node.table
    html_content += "</div>"

    if len(node.children) > 0:
        html_content += "<ul> "
        for subnode in node.children.values():
            html_content += generate_nodes(subnode, step=step + 1)
        html_content += "</ul>"
    html_content += "</li>"

    return html_content
|
|
|
|
|
def generate_html(start_sentence, original_tree):
    """Wrap the recursive node HTML in the top-level tree container."""
    html_output = f"""<div class="custom-container">
<div class="tree"> <ul class="nomargin"><li class="nomargin">
<div class="box" id='root'> <span> <b>{start_sentence}</b> </span> {original_tree.table} </div>"""
    html_output += "<ul> "
    for subnode in original_tree.children.values():
        html_output += generate_nodes(subnode, step=1)
    html_output += "</ul>"
    html_output += """
</li></ul></div>
</div>
"""
    return html_output
|
|
|
|
|
|
@dataclass
class BeamNode:
    """One node of the beam search tree: a token choice plus its bookkeeping."""

    current_token_ix: int  # token id chosen at this node
    cumulative_score: float  # sum of log-probabilities up to this node
    children_score_divider: float  # length-penalty divider applied to children
    table: str  # rendered HTML table of candidate next tokens
    current_sequence: str  # decoded text up to and including this node
    children: Dict[int, "BeamNode"]  # next-token id -> child node
    total_score: float  # cumulative_score normalized by the length penalty
    is_final: bool  # True if EOS was produced or generation ended
    is_selected_sequence: bool  # True if this path is among the returned sequences
|
|
|
|
|
def generate_beams(start_sentence, scores, length_penalty, decoded_sequences):
    """Rebuild the beam search tree from the per-step scores returned by `generate`."""
    # NOTE: len() of a BatchEncoding counts its keys (here input_ids and
    # attention_mask), so input_length is the constant 2 rather than the prompt
    # length; the divisors below therefore effectively normalize by the number
    # of generated tokens, which is how `generate` applies `length_penalty` to
    # decoder-only models in recent transformers versions.
    input_length = len(tokenizer([start_sentence], return_tensors="pt"))
    original_tree = BeamNode(
        cumulative_score=0,
        current_token_ix=None,
        table=None,
        current_sequence=start_sentence,
        children={},
        children_score_divider=((input_length + 1) ** length_penalty),
        total_score=None,
        is_final=False,
        is_selected_sequence=False,
    )
    n_beams = len(scores[0])
    beam_trees = [original_tree] * n_beams

    for step, step_scores in enumerate(scores):
        # Gather the top candidate continuations of every live beam.
        (
            top_token_indexes,
            top_cumulative_scores,
            beam_indexes,
            current_sequence,
            top_tokens,
        ) = ([], [], [], [], [])
        for beam_ix in range(n_beams):
            current_beam = beam_trees[beam_ix]
            if current_beam.is_final:
                continue
            current_top_token_indexes = list(
                np.array(scores[step][beam_ix].argsort()[-n_beams:])[::-1]
            )
            top_token_indexes += current_top_token_indexes
            top_cumulative_scores += list(
                np.array(scores[step][beam_ix][current_top_token_indexes])
                + current_beam.cumulative_score
            )
            beam_indexes += [beam_ix] * n_beams
            current_sequence += [beam_trees[beam_ix].current_sequence] * n_beams
            top_tokens += [tokenizer.decode([el]) for el in current_top_token_indexes]

        top_df = pd.DataFrame.from_dict(
            {
                "token_index": top_token_indexes,
                "cumulative_score": top_cumulative_scores,
                "beam_index": beam_indexes,
                "current_sequence": current_sequence,
                "token": top_tokens,
            }
        )
        # If the same continuation is reachable from several beams, keep the best one.
        maxes = top_df.groupby(["token_index", "current_sequence"])[
            "cumulative_score"
        ].idxmax()
        top_df = top_df.loc[maxes]

        # Keep only the n_beams best continuations overall.
        top_df_selected = top_df.sort_values("cumulative_score", ascending=False).iloc[
            :n_beams
        ]

        # Attach a candidate-token table to every beam that does not have one yet.
        for beam_ix in reversed(list(range(n_beams))):
            current_beam = beam_trees[beam_ix]
            if current_beam.table is None:
                selected_tokens = top_df_selected.loc[
                    top_df_selected["current_sequence"] == current_beam.current_sequence
                ]
                markdown_table = generate_markdown_table(
                    step_scores[beam_ix, :],
                    current_beam.cumulative_score,
                    current_beam.children_score_divider,
                    chosen_tokens=list(selected_tokens["token"].values),
                )
                beam_trees[beam_ix].table = markdown_table

        # Add the selected continuations as children of their source beams.
        cumulative_scores = [beam.cumulative_score for beam in beam_trees]
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])
            source_beam_ix = int(top_df_selected.iloc[beam_ix]["beam_index"])
            cumulative_score = (
                cumulative_scores[source_beam_ix]
                + scores[step][source_beam_ix][current_token_choice_ix].numpy()
            )
            current_sequence = (
                beam_trees[source_beam_ix].current_sequence + current_token_choice
            )
            beam_trees[source_beam_ix].children[current_token_choice_ix] = BeamNode(
                current_token_ix=current_token_choice_ix,
                table=None,
                children={},
                current_sequence=current_sequence,
                cumulative_score=cumulative_score,
                total_score=cumulative_score
                / ((input_length + step - 1) ** length_penalty),
                children_score_divider=((input_length + step) ** length_penalty),
                is_final=(
                    step == len(scores) - 1
                    or current_token_choice_ix == tokenizer.eos_token_id
                ),
                is_selected_sequence=(
                    current_sequence.replace("<|endoftext|>", "")
                    in [el.replace("<|endoftext|>", "") for el in decoded_sequences]
                ),
            )

        # Re-point each beam slot to the child that was just selected for it.
        beam_trees = [
            beam_trees[int(top_df_selected.iloc[beam_ix]["beam_index"])]
            for beam_ix in range(n_beams)
        ]
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice_ix]

    return original_tree
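
# A hedged cross-check (not used by the app): recent transformers versions
# expose `compute_transition_scores`, which recovers the same per-step
# transition scores that `generate_beams` walks through above; treat the exact
# signature as an assumption for your installed version.
def _transition_scores_cross_check(outputs):
    """Illustrative only: per-step scores for each returned beam."""
    return model.compute_transition_scores(
        outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
    )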
|
|
|
|
|
@spaces.GPU
def get_beam_search_html(
    input_text, number_steps, number_beams, length_penalty, num_return_sequences
):
    """Run beam search decoding and render it as an HTML tree plus a summary."""
    inputs = tokenizer([input_text], return_tensors="pt")

    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=num_return_sequences,
        return_dict_in_generate=True,
        length_penalty=length_penalty,
        output_scores=True,
        do_sample=False,
    )
    markdown = "The final sequences are those that end in an `<|endoftext|>` token or reach the last generation step."
    markdown += "\n\nThey are ranked by their scores, as given by the formula `score = cumulative_score / (output_length ** length_penalty)`.\n\n"
    markdown += "Only the `num_return_sequences` highest-scoring sequences are returned: in the tree they are highlighted in **<span style='color:var(--secondary-500)!important'>blue</span>**."
    markdown += " The non-selected sequences are also shown in the tree, highlighted in **<span style='color:var(--primary-500)!important'>yellow</span>**."
    markdown += "\n#### <span style='color:var(--secondary-500)!important'>Output sequences:</span>"

    decoded_sequences = tokenizer.batch_decode(outputs.sequences)
    for i, sequence in enumerate(decoded_sequences):
        markdown += f"\n- Score `{outputs.sequences_scores[i]:.2f}`: `{clean(sequence)}`"

    original_tree = generate_beams(
        input_text,
        outputs.scores[:],
        length_penalty,
        decoded_sequences,
    )
    html = generate_html(input_text, original_tree)
    return html, markdown
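
# A worked example of the ranking formula above, with illustrative numbers:
# a cumulative log-probability of -6.0 over an output of 8 tokens and
# length_penalty=1.0 gives score = -6.0 / (8 ** 1.0) = -0.75.
def _example_sequence_score(cumulative_score=-6.0, output_length=8, length_penalty=1.0):
    """Illustrative only: score = cumulative_score / (output_length ** length_penalty)."""
    return cumulative_score / (output_length**length_penalty)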
|
|
|
|
|
def change_num_return_sequences(n_beams):
    """Keep the num_return_sequences slider capped at the current number of beams."""
    return gr.Slider(
        label="Number of sequences", minimum=1, maximum=n_beams, step=1, value=n_beams
    )
|
|
|
|
|
with gr.Blocks(
    theme=gr.themes.Soft(
        primary_hue=gr.themes.colors.yellow,
        secondary_hue=gr.themes.colors.blue,
    ),
    css=STYLE,
) as demo:
    gr.Markdown(
        """# <span style='color:var(--primary-500)!important'>Beam Search Visualizer</span>

Play with the parameters below to understand how beam search decoding works!

#### <span style='color:var(--primary-500)!important'>Parameters:</span>
- **Sentence to decode from** (`inputs`): the input sequence for the decoder.
- **Number of steps** (`max_new_tokens`): the number of tokens to generate.
- **Number of beams** (`num_beams`): the number of beams to use.
- **Length penalty** (`length_penalty`): the length penalty applied to outputs. `length_penalty` > 0.0 promotes longer sequences, while `length_penalty` < 0.0 encourages shorter ones. It does not change the beam search paths themselves; it only biases the final ranking of finished sequences towards longer or shorter ones.
- **Number of return sequences** (`num_return_sequences`): the number of sequences returned at the end of generation. Must be `<= num_beams`.
"""
    )
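
    # For reference, the controls below map onto a single `generate` call; a
    # hedged sketch (the real call lives in `get_beam_search_html` above):
    #
    #   model.generate(**inputs, max_new_tokens=n_steps, num_beams=n_beams,
    #                  length_penalty=length_penalty,
    #                  num_return_sequences=num_return_sequences, do_sample=False)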
|
    text = gr.Textbox(
        label="Sentence to decode from",
        value="Conclusion: thanks a lot. This article was originally published on",
    )
    with gr.Row():
        n_steps = gr.Slider(
            label="Number of steps", minimum=1, maximum=10, step=1, value=4
        )
        n_beams = gr.Slider(
            label="Number of beams", minimum=2, maximum=4, step=1, value=3
        )
        length_penalty = gr.Slider(
            label="Length penalty", minimum=-3, maximum=3, step=0.5, value=1
        )
        num_return_sequences = gr.Slider(
            label="Number of return sequences", minimum=1, maximum=3, step=1, value=2
        )

    n_beams.change(
        fn=change_num_return_sequences, inputs=n_beams, outputs=num_return_sequences
    )
    button = gr.Button()
    out_html = gr.Markdown()
    out_markdown = gr.Markdown()
    button.click(
        get_beam_search_html,
        inputs=[text, n_steps, n_beams, length_penalty, num_return_sequences],
        outputs=[out_html, out_markdown],
    )

demo.launch()