|
import torch |
|
import spaces |
|
# Startup diagnostics: report whether CUDA is usable and, if so, the name of
# the currently selected device.
_cuda_ready = torch.cuda.is_available()
print(f"Is CUDA available: {_cuda_ready}")

if _cuda_ready:
    _device_index = torch.cuda.current_device()
    print(f"CUDA device: {torch.cuda.get_device_name(_device_index)}")
|
|
|
# Stylesheet passed to gr.Blocks(css=...). It lays out nested <ul>/<li> lists
# inside a ".tree" element as a top-down tree: each <li> draws its connector
# lines with ::before/::after pseudo-elements, and hovering a node highlights
# that node together with the subtree that follows it.
STYLE = """
.container {
    width: 100%;
    display: grid;
    align-items: center;
    margin: 0!important;
}
.prose ul ul {
    margin: 0!important;
}
.tree {
    padding: 0px;
    margin: 0!important;
    box-sizing: border-box;
    font-size: 16px;
    width: 100%;
    height: auto;
    text-align: center;
}
.tree ul {
    padding-top: 20px;
    position: relative;
    transition: .5s;
    margin: 0!important;
}
.tree li {
    display: inline-table;
    text-align: center;
    list-style-type: none;
    position: relative;
    padding: 10px;
    transition: .5s;
}
.tree li::before, .tree li::after {
    content: '';
    position: absolute;
    top: 0;
    right: 50%;
    border-top: 1px solid #ccc;
    width: 51%;
    height: 10px;
}
.tree li::after {
    right: auto;
    left: 50%;
    border-left: 1px solid #ccc;
}
.tree li:only-child::after, .tree li:only-child::before {
    display: none;
}
.tree li:only-child {
    padding-top: 0;
}
.tree li:first-child::before, .tree li:last-child::after {
    border: 0 none;
}
.tree li:last-child::before {
    border-right: 1px solid #ccc;
    border-radius: 0 5px 0 0;
    -webkit-border-radius: 0 5px 0 0;
    -moz-border-radius: 0 5px 0 0;
}
.tree li:first-child::after {
    border-radius: 5px 0 0 0;
    -webkit-border-radius: 5px 0 0 0;
    -moz-border-radius: 5px 0 0 0;
}
.tree ul ul::before {
    content: '';
    position: absolute;
    top: 0;
    left: 50%;
    border-left: 1px solid #ccc;
    width: 0;
    height: 20px;
}
.tree li a {
    border: 1px solid #ccc;
    padding: 10px;
    display: inline-grid;
    border-radius: 5px;
    text-decoration-line: none;
    transition: .5s;
}
.tree li a span {
    border: 1px solid #ccc;
    border-radius: 5px;
    color: #666;
    padding: 8px;
    font-size: 12px;
    text-transform: uppercase;
    letter-spacing: 1px;
    font-weight: 500;
}
/*Hover-Section*/
.tree li a:hover, .tree li a:hover i, .tree li a:hover span, .tree li a:hover+ul li a {
    background: #c8e4f8;
    color: #000;
    border: 1px solid #94a0b4;
}
.tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before {
    border-color: #94a0b4;
}
"""
|
|
|
from transformers import GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer |
|
import numpy as np |
|
|
|
# Load the tokenizer and the causal language model once, at import time, from
# the same checkpoint name.
_CHECKPOINT = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)

# Use the end-of-sequence id as the padding id.
tokenizer.pad_token_id = tokenizer.eos_token_id

print("Loading finished.")
|
def generate_html(token, node):
    """Recursively render a tree node and its descendants as a nested HTML list.

    Args:
        token: Label shown for this node. It is converted to text and
            HTML-escaped, so characters such as ``<`` or ``&`` are displayed
            literally instead of being interpreted as markup.
        node: Mapping with two keys: ``"table"`` (an HTML fragment inserted
            verbatim inside the node's link, or ``None`` for no table) and
            ``"children"`` (mapping of child label -> child node of the same
            shape).

    Returns:
        A ``<li>...</li>`` string. When the node has children they are
        rendered, in mapping order, inside a nested ``<ul>``.
    """
    # Local import keeps this function self-contained.
    from html import escape

    pieces = [f" <li> <a href='#'> <span> <b>{escape(str(token))}</b> </span> "]

    table = node["table"]
    if table is not None:
        # The table is already HTML built by this application; do not escape it.
        pieces.append(table)
    pieces.append("</a>")

    children = node["children"]
    if children:
        pieces.append("<ul> ")
        pieces.extend(
            generate_html(child_token, child_node)
            for child_token, child_node in children.items()
        )
        pieces.append("</ul>")

    pieces.append("</li>")
    return "".join(pieces)
|
|
|
|
|
def generate_markdown_table(scores, top_k=4, chosen_tokens=None):
    """Build an HTML table listing the ``top_k`` highest-scoring tokens.

    Args:
        scores: One-dimensional scores indexed by token id. The caller in
            this file passes the score row of a single beam at a single
            generation step.
        top_k: Number of highest-scoring tokens to list. Values of zero or
            less produce a table with a header and no rows.
        chosen_tokens: Optional collection of decoded token strings; rows
            whose token is in it get a red background.

    Returns:
        An HTML ``<table>`` string with a header row followed by one row per
        selected token. Rows are in ascending score order because
        ``np.argsort`` sorts ascending. The score is interpolated unchanged
        under the "Probability" header.
    """
    # Local import keeps this function self-contained.
    from html import escape

    ranked = np.argsort(scores)
    # ``ranked[-0:]`` would select every entry, so handle non-positive top_k
    # explicitly.
    top_indices = ranked[-top_k:] if top_k > 0 else ranked[:0]
    highlighted = chosen_tokens or ()

    parts = [
        """
<table>
    <tr>
        <th><b>Token</b></th>
        <th><b>Probability</b></th>
    </tr>"""
    ]
    for token_idx in top_indices:
        token = tokenizer.decode([token_idx])
        # Emit the attribute only when it has a value, and quote it, so the
        # markup never contains an empty ``style=``.
        style_attr = ' style="background-color:red"' if token in highlighted else ""
        # Decoded tokens can contain "<", ">" or "&"; escape them for display.
        parts.append(
            f"""
    <tr{style_attr}>
        <td>{escape(token)}</td>
        <td>{scores[token_idx]}</td>
    </tr>"""
        )
    parts.append(
        """
</table>"""
    )
    return "".join(parts)
|
|
|
|
|
def display_tree(start_sentence, scores, sequences, beam_indices):
    """Render the generated sequences as an HTML tree rooted at the prompt.

    Every sequence is walked step by step. Sequences that share a prefix of
    decoded tokens share the corresponding tree nodes, and each visited node
    stores a table of the top-scoring candidates for the step that follows it.

    Args:
        start_sentence: Text used as the label of the root node.
        scores: Iterable with one entry per generation step; each entry is
            indexed as ``[beam, :]`` to get the scores of one beam.
        sequences: Tensor of generated token ids indexed as
            ``[sequence, step]``. It is moved to CPU and converted to NumPy.
        beam_indices: Indexed as ``[sequence, step]``; gives the beam whose
            score row is shown at that step.

    Returns:
        An HTML string: the tree markup wrapped in
        ``<div class="container"><div class="tree"><ul>`` and the matching
        closing tags.
    """
    sequences = sequences.cpu().numpy()
    print(tokenizer.batch_decode(sequences))

    original_tree = {"table": None, "children": {}}
    for sequence_ix, sequence in enumerate(sequences):
        current_tree = original_tree
        for step, step_scores in enumerate(scores):
            current_token_choice = tokenizer.decode([sequence[step]])
            # NOTE(review): a negative beam index would select a row counted
            # from the end of step_scores. Confirm whether generate() can
            # return padding values in the columns read here.
            current_beam = beam_indices[sequence_ix, step]

            children = current_tree["children"]
            if current_token_choice not in children:
                children[current_token_choice] = {"table": None, "children": {}}

            # Rebuilt on every visit so the highlighted set includes every
            # child chosen so far from this node.
            current_tree["table"] = generate_markdown_table(
                step_scores[current_beam, :],
                chosen_tokens=children.keys(),
            )
            current_tree = children[current_token_choice]

    opening = """<div class="container">
    <div class="tree">
        <ul>"""
    # Close both <div> elements opened above so the markup is balanced.
    closing = """
        </ul>
    </div>
</div>
"""
    return opening + generate_html(start_sentence, original_tree) + closing
|
|
|
@spaces.GPU
def get_tables(input_text, number_steps, number_beams):
    """Generate continuations of ``input_text`` and return them as an HTML tree.

    Args:
        input_text: Prompt to continue.
        number_steps: Number of new tokens to generate (``max_new_tokens``).
        number_beams: Number of beams; the same number of sequences is
            returned.

    Returns:
        The HTML string produced by ``display_tree`` for the generated
        sequences.
    """
    # Coerce to int in case the UI sliders deliver floats.
    number_steps = int(number_steps)
    number_beams = int(number_beams)

    inputs = tokenizer([input_text], return_tensors="pt")
    # Number of prompt tokens. ``len(inputs)`` is not this value: it counts
    # the entries of the encoding rather than the tokens.
    prompt_length = inputs["input_ids"].shape[-1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=number_beams,
        return_dict_in_generate=True,
        output_scores=True,
        top_k=5,
        temperature=1.0,
        do_sample=True,
    )

    # Drop the prompt so column ``i`` of the sequences matches ``scores[i]``.
    generated = outputs.sequences[:, prompt_length:]
    # display_tree reads one beam index per entry of ``scores``, so keep only
    # that many leading columns.
    # Assumes generate() stores the per-step beam indices in the leading
    # columns -- TODO confirm against the installed transformers version.
    step_beams = outputs.beam_indices[:, : len(outputs.scores)]

    return display_tree(input_text, outputs.scores, generated, step_beams)
|
|
|
import gradio as gr |
|
|
|
# Theme for the demo: the Soft preset with large text, a monospace font and a
# green primary hue.
_theme = gr.themes.Soft(
    primary_hue=gr.themes.colors.green,
    font=["monospace"],
    text_size="lg",
)

# UI: a prompt box, two sliders (steps and beams), a button and a Markdown
# output area that receives the value returned by get_tables.
with gr.Blocks(css=STYLE, theme=_theme) as demo:
    text = gr.Textbox(value="Today is", label="Sentence to decode from🪶")
    steps = gr.Slider(minimum=1, maximum=10, step=1, value=4, label="Number of steps")
    beams = gr.Slider(minimum=1, maximum=3, step=1, value=3, label="Number of beams")
    button = gr.Button()
    out = gr.Markdown(label="Output")
    # Clicking the button calls get_tables(text, steps, beams) and writes the
    # result into `out`.
    button.click(get_tables, inputs=[text, steps, beams], outputs=out)

demo.launch()