Spaces: Runtime error

peter szemraj committed · Commit 88b1e11 · 1 Parent(s): 91d1162

format

app.py CHANGED
@@ -11,6 +11,7 @@ logging.basicConfig(
 
 use_gpu = torch.cuda.is_available()
 
+
 def generate_text(
     prompt: str,
     gen_length=64,
@@ -40,7 +41,7 @@ def generate_text(
     st = time.perf_counter()
 
     input_tokens = generator.tokenizer(prompt)
-    input_len = len(input_tokens['input_ids'])
+    input_len = len(input_tokens["input_ids"])
     if input_len > abs_max_length:
         logging.info(f"Input too long {input_len} > {abs_max_length}, may cause errors")
     result = generator(
@@ -55,9 +56,8 @@ def generate_text(
         early_stopping=True,
         # tokenizer
         truncation=True,
-    )
-
-    response = result[0]['generated_text']
+    )  # generate
+    response = result[0]["generated_text"]
     rt = time.perf_counter() - st
     if verbose:
         logging.info(f"Generated text: {response}")
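Note: for context on the hunks above, here is a minimal standalone sketch of the token-length check pattern they reformat. The model tag comes from this repo's default; the abs_max_length of 512 and the max_length arithmetic are illustrative stand-ins for generate_text's real arguments:

import torch
from transformers import pipeline

generator = pipeline("text-generation", model="postbot/distilgpt2-emailgen")

prompt = "Hello,\n\nJust following up on"
abs_max_length = 512  # illustrative stand-in for the generate_text argument

# the pipeline's tokenizer returns a BatchEncoding; "input_ids" is the token list
input_tokens = generator.tokenizer(prompt)
input_len = len(input_tokens["input_ids"])
if input_len > abs_max_length:
    print(f"Input too long {input_len} > {abs_max_length}, may cause errors")

result = generator(prompt, max_length=input_len + 64, truncation=True)
response = result[0]["generated_text"]
print(response)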
@@ -74,12 +74,12 @@ def get_parser():
     )
 
     parser.add_argument(
-        '-m',
-        '--model',
+        "-m",
+        "--model",
         required=False,
         type=str,
         default="postbot/distilgpt2-emailgen",
-        help='Pass an different huggingface model tag to use a custom model',
+        help="Pass an different huggingface model tag to use a custom model",
     )
 
     parser.add_argument(
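Note: as a usage sketch for the -m/--model flag reshaped above, this standalone parser mirrors its argparse shape (the parser and help text here are illustrative, not the app's full get_parser):

import argparse

parser = argparse.ArgumentParser(description="illustrative email-gen CLI")
parser.add_argument(
    "-m",
    "--model",
    required=False,
    type=str,
    default="postbot/distilgpt2-emailgen",
    help="huggingface model tag to load (illustrative help text)",
)

args = parser.parse_args([])  # no flag -> falls back to the default tag
print(args.model)  # postbot/distilgpt2-emailgen

args = parser.parse_args(["-m", "distilgpt2"])  # override with another tag
print(args.model)  # distilgpt2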
@@ -91,6 +91,7 @@ def get_parser():
     )
     return parser
 
+
 default_prompt = """
 Hello,
 
@@ -109,7 +110,6 @@ if __name__ == "__main__":
         device=0 if use_gpu else -1,
     )
 
-
     demo = gr.Blocks()
 
     logging.info("launching interface...")
@@ -119,7 +119,9 @@ if __name__ == "__main__":
         gr.Markdown(
             "Enter part of an email, and the model will autocomplete it for you!"
         )
-        gr.Markdown("The model used is [postbot/distilgpt2-emailgen](https://huggingface.co/postbot/distilgpt2-emailgen)")
+        gr.Markdown(
+            "The model used is [postbot/distilgpt2-emailgen](https://huggingface.co/postbot/distilgpt2-emailgen)"
+        )
         gr.Markdown("---")
 
         with gr.Column():
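Note: the device=0 if use_gpu else -1 context line above follows the standard transformers convention: 0 selects the first CUDA GPU and -1 runs on CPU. A minimal sketch of that setup:

import torch
from transformers import pipeline

use_gpu = torch.cuda.is_available()
generator = pipeline(
    "text-generation",
    model="postbot/distilgpt2-emailgen",  # default tag from the diff
    device=0 if use_gpu else -1,  # 0 = first CUDA device, -1 = CPU
)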
@@ -151,10 +153,11 @@ if __name__ == "__main__":
                 value=2,
             )
             length_penalty = gr.Slider(
-                minimum=0.5, maximum=1.0, label='length penalty', default=0.8, step=0.05
+                minimum=0.5, maximum=1.0, label="length penalty", default=0.8, step=0.05
             )
             generated_email = gr.Textbox(
-                label="Generated Result",
+                label="Generated Result",
+                placeholder="The completed email will appear here",
             )
 
             generate_button = gr.Button(
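Note: a hedged sketch of the two components this hunk touches. One caveat: in Gradio 3.x Blocks, gr.Slider takes its initial value as value=; the default= keyword seen in the diff is the older Gradio 2.x Interface spelling, so value= is used below:

import gradio as gr

with gr.Blocks() as demo:
    length_penalty = gr.Slider(
        minimum=0.5, maximum=1.0, label="length penalty", value=0.8, step=0.05
    )
    generated_email = gr.Textbox(
        label="Generated Result",
        placeholder="The completed email will appear here",
    )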
@@ -168,16 +171,24 @@ if __name__ == "__main__":
             gr.Markdown(
                 "This model is a fine-tuned version of distilgpt2 on a dataset of 50k emails sourced from the internet, including the classic `aeslc` dataset."
             )
-            gr.Markdown("The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements before accepting/sending something.")
+            gr.Markdown(
+                "The intended use of this model is to provide suggestions to _auto-complete_ the rest of your email. Said another way, it should serve as a **tool to write predictable emails faster**. It is not intended to write entire emails; at least **some input** is required to guide the direction of the model.\n\nPlease verify any suggestions by the model for A) False claims and B) negation statements before accepting/sending something."
+            )
             gr.Markdown("---")
 
     generate_button.click(
         fn=generate_text,
-        inputs=[prompt_text, num_gen_tokens, num_beams, no_repeat_ngram_size, length_penalty],
+        inputs=[
+            prompt_text,
+            num_gen_tokens,
+            num_beams,
+            no_repeat_ngram_size,
+            length_penalty,
+        ],
         outputs=[generated_email],
     )
 
     demo.launch(
         enable_queue=True,
-        share=True,
+        share=True,  # for local testing
     )
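Note: to close, a self-contained sketch of the event-wiring pattern the last hunk reformats: each component listed in inputs is passed positionally to fn, and the return value fills outputs. The handler and component set here are illustrative, not the app's full UI; enable_queue existed in Gradio 3.x launch() and was later superseded by demo.queue():

import gradio as gr

def complete_email(prompt: str, gen_length: float) -> str:
    # illustrative stand-in for generate_text
    return prompt + " [generated continuation of ~%d tokens]" % int(gen_length)

with gr.Blocks() as demo:
    prompt_text = gr.Textbox(label="prompt", value="Hello,")
    num_gen_tokens = gr.Slider(
        minimum=32, maximum=128, value=64, step=8, label="generation tokens"
    )
    generated_email = gr.Textbox(label="Generated Result")
    generate_button = gr.Button("Generate")

    generate_button.click(
        fn=complete_email,
        inputs=[prompt_text, num_gen_tokens],
        outputs=[generated_email],
    )

demo.launch(enable_queue=True)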