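"""
app.py - Gradio demo that auto-completes emails with a fine-tuned GPT-2
text-generation model. See the "About" section rendered in the UI for details.
"""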
import argparse
import logging
import time
from functools import partial

import gradio as gr
import torch
from transformers import pipeline

from utils import make_mailto_form, postprocess, clear

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

use_gpu = torch.cuda.is_available()
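# transformers pipelines map device=0 to the first GPU and device=-1 to CPU,
# so this flag drives the `device=` argument passed to pipeline() below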


def generate_text(
    prompt: str,
    gen_length=64,
    num_beams=4,
    no_repeat_ngram_size=2,
    length_penalty=1.0,
    num_beam_groups=2,
    repetition_penalty=3.5,
    abs_max_length=512,
    verbose=False,
):
    """
    generate_text - generate text from a prompt using the global text-generation pipeline

    Args:
        prompt (str): the prompt to generate text from
        gen_length (int, optional): number of new tokens to generate. Defaults to 64.
        num_beams (int, optional): number of beams for beam search. Defaults to 4.
        no_repeat_ngram_size (int, optional): size of n-grams that may occur only once. Defaults to 2.
        length_penalty (float, optional): exponential penalty on sequence length. Defaults to 1.0.
        num_beam_groups (int, optional): number of groups for diverse beam search. Defaults to 2.
        repetition_penalty (float, optional): penalty applied to repeated tokens. Defaults to 3.5.
        abs_max_length (int, optional): prompt length in tokens above which a warning is logged. Defaults to 512.
        verbose (bool, optional): whether to log prompts, params, and outputs. Defaults to False.

    Returns:
        tuple(str, str): the mailto form (HTML) and the formatted email text
    """
    global generator
    if verbose:
        logging.info(f"Generating text from prompt:\n\n{prompt}")
        logging.info(
            f"params:\tmax_length={gen_length}, num_beams={num_beams}, "
            f"no_repeat_ngram_size={no_repeat_ngram_size}, length_penalty={length_penalty}, "
            f"repetition_penalty={repetition_penalty}, abs_max_length={abs_max_length}"
        )
    st = time.perf_counter()

    input_tokens = generator.tokenizer(prompt)
    input_len = len(input_tokens["input_ids"])
    if input_len > abs_max_length:
        logging.warning(f"Input too long: {input_len} > {abs_max_length}, may cause errors")
    result = generator(
        prompt,
        max_length=gen_length + input_len,
        min_length=input_len + 4,
        num_beams=num_beams,
        num_beam_groups=num_beam_groups,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        length_penalty=length_penalty,
        do_sample=False,
        early_stopping=True,
    )
    response = result[0]["generated_text"]
    rt = time.perf_counter() - st
    if verbose:
        logging.info(f"Generated text: {response}")
    rt_string = f"Generation time: {rt:.2f}s"
    logging.info(rt_string)

    formatted_email = postprocess(response)
    return make_mailto_form(body=formatted_email), formatted_email
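
# Example (hypothetical) direct call, once a pipeline has been loaded into the
# module-level `generator` (e.g. via load_emailgen_model() below):
#   mailto_html, email_text = generate_text("Hello,\n\nFollowing up on", gen_length=48)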


def load_emailgen_model(model_tag: str):
    """
    load_emailgen_model - load a text-generation pipeline for email generation
    and store it in the module-level `generator` variable

    Args:
        model_tag (str): the huggingface model tag to load
    """
    global generator
    generator = pipeline(
        "text-generation",
        model_tag,
        device=0 if use_gpu else -1,
    )
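
# Example: swap to another checkpoint from `available_models` at runtime
# (hypothetical call):
#   load_emailgen_model("postbot/gpt2-medium-emailgen")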


def get_parser():
    """
    get_parser - a helper function for the argparse module
    """
    parser = argparse.ArgumentParser(
        description="Text Generation demo for postbot",
    )
    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="postbot/distilgpt2-emailgen-V2",
        help="Pass a different huggingface model tag to use a custom model",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        required=False,
        action="store_true",
        help="Verbose output",
    )
    parser.add_argument(
        "-nb",
        "--num_beams",
        type=int,
        default=4,
        help="Number of beams for beam search. 1 means no beam search.",
    )
    parser.add_argument(
        "--num_beam_groups",
        type=int,
        default=2,
        help="Number of groups to split num_beams into, to ensure diversity among beams (diverse beam search). 1 means no group beam search.",
    )
    return parser
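
# Example invocation (assumes this file is saved as app.py):
#   python app.py --model postbot/distilgpt2-emailgen --verbose --num_beams 8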


default_prompt = """
Hello,

Following up on last week's bubblegum shipment, I"""

available_models = [
    "postbot/distilgpt2-emailgen-V2",
    "postbot/distilgpt2-emailgen",
    "postbot/gpt2-medium-emailgen",
]


if __name__ == "__main__":
    logging.info("\n\n\nStarting new instance of app.py")
    args = get_parser().parse_args()
    logging.info(f"received args:\t{args}")
    model_tag = args.model
    verbose = args.verbose
    logging.info(f"Loading model: {model_tag}, use GPU = {use_gpu}")
    # reuse the helper rather than duplicating the pipeline construction
    load_emailgen_model(model_tag)

    demo = gr.Blocks()

    logging.info("launching interface...")

    with demo:
        gr.Markdown("# Auto-Complete Emails - Demo")
        gr.Markdown(
            "Enter part of an email, and a text-gen model will complete it! See details below."
        )
        gr.Markdown("---")

        with gr.Column():
            gr.Markdown("## Generate Text")
            gr.Markdown("Edit the prompt and parameters and press **Generate**!")
            prompt_text = gr.Textbox(
                lines=4,
                label="Email Prompt",
                value=default_prompt,
            )

            with gr.Row():
                clear_button = gr.Button(
                    value="Clear Prompt",
                )
                num_gen_tokens = gr.Slider(
                    label="Generation Tokens",
                    value=48,
                    maximum=96,
                    minimum=32,
                    step=8,
                )

            generate_button = gr.Button(
                value="Generate!",
                variant="primary",
            )
            gr.Markdown("---")
            gr.Markdown("### Results")

            generated_email = gr.Textbox(
                label="Generated Text",
                placeholder="This is where the generated text will appear",
                interactive=False,
            )
            email_mailto_button = gr.HTML(
                "<i>a clickable email button will appear here</i>"
            )

            gr.Markdown("---")
            gr.Markdown("## Advanced Options")
            gr.Markdown(
                "This demo generates text via beam search. See details about these parameters [here](https://huggingface.co/blog/how-to-generate); otherwise, they should be fine as-is."
            )
            with gr.Row():
                model_name = gr.Dropdown(
                    choices=available_models,
                    label="Choose a model",
                    value=model_tag,
                )
                load_model_button = gr.Button(
                    "Load Model",
                    variant="secondary",
                )
                no_repeat_ngram_size = gr.Radio(
                    choices=[1, 2, 3, 4],
                    label="no repeat ngram size",
                    value=2,
                )
            with gr.Row():
                # seed the widgets with the CLI values, which were previously
                # parsed but never used
                num_beams = gr.Radio(
                    choices=[2, 4, 8],
                    label="Number of Beams",
                    value=args.num_beams,
                )
                num_beam_groups = gr.Radio(
                    choices=[1, 2],
                    label="Number of Beam Groups",
                    value=args.num_beam_groups,
                )
                length_penalty = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    label="length penalty",
                    value=0.8,
                    step=0.1,
                )
            gr.Markdown("---")

        with gr.Column():
            gr.Markdown("## About")
            gr.Markdown(
                "[This model](https://huggingface.co/postbot/distilgpt2-emailgen) is a fine-tuned version of distilgpt2 on a dataset of 50k emails sourced from the internet, including the classic `aeslc` dataset.\n\nCheck out the model card for details on notebook & command-line usage."
            )
            gr.Markdown(
                "The intended use of this model is to provide suggestions that _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails from scratch; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions made by the model for A) false claims and B) negation statements **before** accepting/sending something."
            )
            gr.Markdown("---")

        clear_button.click(
            fn=clear,
            inputs=[prompt_text],
            outputs=[prompt_text],
        )
        generate_button.click(
            # bind the CLI --verbose flag, which the Gradio inputs cannot carry
            fn=partial(generate_text, verbose=verbose),
            inputs=[
                prompt_text,
                num_gen_tokens,
                num_beams,
                no_repeat_ngram_size,
                length_penalty,
                num_beam_groups,
            ],
            outputs=[email_mailto_button, generated_email],
        )

        load_model_button.click(
            fn=load_emailgen_model,
            inputs=[model_name],
            outputs=[],
        )

    demo.launch(
        enable_queue=True,
        share=True,
    )