import argparse
import logging
import time

import gradio as gr
import torch
from transformers import pipeline

from utils import make_mailto_form, postprocess, clear, make_email_link

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

use_gpu = torch.cuda.is_available()


def generate_text(
    prompt: str,
    gen_length=64,
    num_beams=4,
    no_repeat_ngram_size=2,
    length_penalty=1.0,
    num_beam_groups=2,
    # perma params (not set by user)
    repetition_penalty=3.5,
    abs_max_length=512,
    verbose=False,
):
| """ | |
| generate_text - generate text from a prompt using a text generation pipeline | |
| Args: | |
| prompt (str): the prompt to generate text from | |
| model_input (_type_): the text generation pipeline | |
| max_length (int, optional): the maximum length of the generated text. Defaults to 128. | |
| method (str, optional): the generation method. Defaults to "Sampling". | |
| verbose (bool, optional): the verbosity of the output. Defaults to False. | |
| Returns: | |
| str: the generated text | |
| """ | |
    global generator
    if verbose:
        logging.info(f"Generating text from prompt:\n\n{prompt}")
        logging.info(
            f"params:\tgen_length={gen_length}, num_beams={num_beams}, "
            f"no_repeat_ngram_size={no_repeat_ngram_size}, length_penalty={length_penalty}, "
            f"repetition_penalty={repetition_penalty}, abs_max_length={abs_max_length}"
        )
    st = time.perf_counter()

    input_tokens = generator.tokenizer(prompt)
    input_len = len(input_tokens["input_ids"])
    if input_len > abs_max_length:
        logging.warning(
            f"Input is {input_len} tokens, over the {abs_max_length}-token limit; may cause errors"
        )
    result = generator(
        prompt,
        max_length=gen_length + input_len,
        min_length=input_len + 4,
        num_beams=num_beams,
        num_beam_groups=num_beam_groups,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        length_penalty=length_penalty,
        do_sample=False,
        early_stopping=True,
    )  # generate
    response = result[0]["generated_text"]
    rt = time.perf_counter() - st

    if verbose:
        logging.info(f"Generated text: {response}")
    rt_string = f"Generation time: {rt:.2f}s"
    logging.info(rt_string)

    formatted_email = postprocess(response)
    return make_mailto_form(body=formatted_email), formatted_email
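
# A minimal usage sketch, assuming `generator` has been set first via
# load_emailgen_model() below (the prompt string is just an example):
#
#   load_emailgen_model("postbot/distilgpt2-emailgen-V2")
#   mailto_html, email_text = generate_text(
#       "Hi team,\n\nJust checking in on", gen_length=48, verbose=True
#   )
#   print(email_text)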


def load_emailgen_model(model_tag: str):
    """
    load_emailgen_model - load a text generation pipeline for email generation
    and assign it to the module-level `generator` used by generate_text

    Args:
        model_tag (str): the Hugging Face model tag to load
    """
    global generator
    generator = pipeline(
        "text-generation",
        model_tag,
        device=0 if use_gpu else -1,
    )
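
# A minimal sketch of hot-swapping the model at runtime, assuming the tag
# points to a checkpoint compatible with the "text-generation" pipeline task:
#
#   load_emailgen_model("postbot/gpt2-medium-emailgen")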


def get_parser():
    """
    get_parser - a helper function for the argparse module
    """
    parser = argparse.ArgumentParser(
        description="Text Generation demo for postbot",
    )
    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="postbot/distilgpt2-emailgen-V2",
        help="Pass a different Hugging Face model tag to use a custom model",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        required=False,
        action="store_true",
        help="Verbose output",
    )
    parser.add_argument(
        "-nb",
        "--num_beams",
        type=int,
        default=4,
        help="Number of beams for beam search. 1 means no beam search.",
    )
    parser.add_argument(
        "--num_beam_groups",
        type=int,
        default=2,
        help="Number of groups to split num_beams into for diverse beam search. 1 means no group beam search.",
    )
    return parser
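
# Example invocation using the flags defined above (values are illustrative):
#
#   python app.py --model postbot/gpt2-medium-emailgen -v -nb 8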


default_prompt = """
Hello,

Following up on last week's bubblegum shipment, I"""

available_models = [
    "postbot/distilgpt2-emailgen-V2",
    "postbot/distilgpt2-emailgen",
    "postbot/gpt2-medium-emailgen",
]

if __name__ == "__main__":
    logging.info("\n\n\nStarting new instance of app.py")
    args = get_parser().parse_args()
    logging.info(f"received args:\t{args}")
    model_tag = args.model
    verbose = args.verbose

    logging.info(f"Loading model: {model_tag}, use GPU = {use_gpu}")
    load_emailgen_model(model_tag)

    demo = gr.Blocks()
    logging.info("launching interface...")

    with demo:
        gr.Markdown("# Auto-Complete Emails - Demo")
        gr.Markdown(
            "Enter part of an email, and a text-gen model will complete it! See details below."
        )
        gr.Markdown("---")

        with gr.Column():
            gr.Markdown("## Generate Text")
            gr.Markdown("Edit the prompt and parameters and press **Generate**!")
            prompt_text = gr.Textbox(
                lines=4,
                label="Email Prompt",
                value=default_prompt,
            )
            with gr.Row():
                clear_button = gr.Button(
                    value="Clear Prompt",
                )
                num_gen_tokens = gr.Slider(
                    label="Generation Tokens",
                    value=48,
                    minimum=32,
                    maximum=96,
                    step=8,
                )
            generate_button = gr.Button(
                value="Generate!",
                variant="primary",
            )
| gr.Markdown("---") | |
| gr.Markdown("### Results") | |
| # put a large HTML placeholder here | |
| generated_email = gr.Textbox( | |
| label="Generated Text", | |
| placeholder="This is where the generated text will appear", | |
| interactive=False, | |
| ) | |
| email_mailto_button = gr.HTML("<i>a clickable email button will appear here</i>") | |
| gr.Markdown("---") | |
| gr.Markdown("## Advanced Options") | |
| gr.Markdown( | |
| "This demo generates text via beam search. See details about these parameters [here](https://huggingface.co/blog/how-to-generate), otherwise they should be fine as-is." | |
| ) | |
| with gr.Row(): | |
| model_name = gr.Dropdown( | |
| choices=available_models, | |
| label="Choose a model", | |
| value=model_tag, | |
| ) | |
| load_model_button = gr.Button( | |
| "Load Model", | |
| variant="secondary", | |
| ) | |
| no_repeat_ngram_size = gr.Radio( | |
| choices=[1, 2, 3, 4], | |
| label="no repeat ngram size", | |
| value=2, | |
| ) | |
            with gr.Row():
                num_beams = gr.Radio(
                    choices=[2, 3, 4, 8],
                    label="Number of Beams",
                    value=args.num_beams,
                )
                num_beam_groups = gr.Radio(
                    choices=[1, 2],
                    label="Number of Beam Groups",
                    value=args.num_beam_groups,
                )
                length_penalty = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    label="length penalty",
                    value=0.8,
                    step=0.1,
                )
| gr.Markdown("---") | |
| with gr.Column(): | |
| gr.Markdown("## About") | |
| gr.Markdown( | |
| "[This model](https://huggingface.co/postbot/distilgpt2-emailgen) is a fine-tuned version of distilgpt2 on a dataset of 50k emails sourced from the internet, including the classic `aeslc` dataset.\n\nCheck out the model card for details on notebook & command line usage." | |
| ) | |
| gr.Markdown( | |
| "The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails from scratch; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements **before** accepting/sending something." | |
| ) | |
| gr.Markdown("---") | |
        clear_button.click(
            fn=clear,
            inputs=[prompt_text],
            outputs=[prompt_text],
        )
        generate_button.click(
            fn=generate_text,
            inputs=[
                prompt_text,
                num_gen_tokens,
                num_beams,
                no_repeat_ngram_size,
                length_penalty,
                num_beam_groups,
            ],
            outputs=[email_mailto_button, generated_email],
        )
        load_model_button.click(
            fn=load_emailgen_model,
            inputs=[model_name],
            outputs=[],
        )
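        # Note: Gradio passes component values to fn positionally, so the order
        # of `inputs` above must match the parameter order of generate_text.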

    demo.launch(
        enable_queue=True,
        share=True,  # for local testing
    )