joaogante (HF staff) committed
Commit c3cbdc6 · 1 Parent(s): b79fb01

assistant toggle

Files changed (1): app.py +17 -9

app.py CHANGED
@@ -19,7 +19,12 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
 assistant_model = AutoModelForCausalLM.from_pretrained(assistant_id).to(torch_device)
 
 
-def run_generation(user_text, top_p, temperature, top_k, max_new_tokens):
+def run_generation(user_text, use_assistant, top_p, temperature, top_k, max_new_tokens):
+    if temperature == 0.0:
+        do_sample = False
+    else:
+        do_sample = True
+
     # Get the model and tokenizer, and tokenize the user text.
     model_inputs = tokenizer([user_text], return_tensors="pt").to(torch_device)
 
@@ -28,9 +33,10 @@ def run_generation(user_text, top_p, temperature, top_k, max_new_tokens):
     streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         model_inputs,
+        assistant_model=assistant_model if use_assistant else None,
         streamer=streamer,
         max_new_tokens=max_new_tokens,
-        do_sample=True,
+        do_sample=do_sample,
         top_p=top_p,
         temperature=float(temperature),
         top_k=top_k
@@ -53,34 +59,36 @@ def reset_textbox():
 with gr.Blocks() as demo:
     gr.Markdown(
         "# 🤗 Assisted Generation Demo\n"
-        f"Model: {model_id} (using INT8)\n"
+        f"Model: {model_id} (using INT8)\n\n"
         f"Assistant Model: {assistant_id}"
     )
 
     with gr.Row():
         with gr.Column(scale=4):
             user_text = gr.Textbox(
-                placeholder="Write an email about an alpaca that likes flan",
+                placeholder="Question: What is the meaning of life? Answer:",
                 label="User input"
             )
             model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
             button_submit = gr.Button(value="Submit")
 
         with gr.Column(scale=1):
+            use_assistant = gr.Checkbox(label="Use Assistant", default=True)
             max_new_tokens = gr.Slider(
-                minimum=1, maximum=1000, value=250, step=1, interactive=True, label="Max New Tokens",
+                minimum=1, maximum=500, value=250, step=1, interactive=True, label="Max New Tokens",
             )
             top_p = gr.Slider(
-                minimum=0.05, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p (nucleus sampling)",
+                minimum=0.05, maximum=1.0, value=0.95, step=0.05, interactive=True, label="Top-p",
             )
             top_k = gr.Slider(
                 minimum=1, maximum=50, value=50, step=1, interactive=True, label="Top-k",
             )
             temperature = gr.Slider(
-                minimum=0.1, maximum=5.0, value=0.8, step=0.1, interactive=True, label="Temperature",
+                minimum=0.0, maximum=2.0, value=0.0, step=0.1, interactive=True, label="Temperature (0.0 = Greedy)",
             )
 
-    user_text.submit(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)
-    button_submit.click(run_generation, [user_text, top_p, temperature, top_k, max_new_tokens], model_output)
+    generate_inputs = [user_text, use_assistant, top_p, temperature, top_k, max_new_tokens]
+    user_text.submit(run_generation, generate_inputs, model_output)
+    button_submit.click(run_generation, generate_inputs, model_output)
 
 demo.queue(max_size=32).launch(enable_queue=True)
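
For reference, the `assistant_model` argument wired in above is the standard assisted-generation entry point in 🤗 Transformers: `generate()` drafts tokens with the small assistant model and verifies them with the main model, and falls back to regular decoding when the argument is `None`. The new `temperature == 0.0` branch simply maps the slider to greedy decoding (`do_sample=False`), in which case the sampling parameters have no effect. A minimal, self-contained sketch of that path; the checkpoint names below are placeholders, not necessarily the ones this Space loads:

# Minimal assisted-generation sketch (placeholder checkpoints).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b").to(torch_device)
assistant_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").to(torch_device)

model_inputs = tokenizer(["Question: What is the meaning of life? Answer:"], return_tensors="pt").to(torch_device)

# With `assistant_model` set, generate() drafts tokens with the small model and
# verifies them with the large one; with it set to None, decoding is unchanged.
outputs = model.generate(**model_inputs, assistant_model=assistant_model, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

The `streamer`/`generate_kwargs` context lines that the second hunk extends are typically consumed from a background thread, since `generate()` blocks until completion. A sketch of that common pattern (the helper name is illustrative, not necessarily the Space's exact code):

from threading import Thread

def stream_generation(model, generate_kwargs, streamer):
    # Run generation in a worker thread so the streamer can be iterated here,
    # yielding the partially generated text as it arrives.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output
    thread.join()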