LenDigLearn committed on
Commit
d0a99a2
·
1 Parent(s): 031ecb9

added top_k

Browse files
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -51,6 +51,7 @@ def respond(
51
  max_tokens,
52
  temperature,
53
  top_p,
 
54
  repetition_penalty
55
  ):
56
  messages = [{"role": "system", "content": system_message}]
@@ -66,7 +67,7 @@ def respond(
66
  streamer = CustomIterable()
67
 
68
  inputs = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", add_generation_prompt=True)
69
- thread = threading.Thread(target=model.generate, args=([inputs]), kwargs={"max_new_tokens": max_tokens, "do_sample": True, "temperature": temperature, "top_p": top_p, "repetition_penalty": repetition_penalty, "streamer": streamer})
70
  thread.start()
71
  response = ""
72
 
@@ -84,7 +85,7 @@ demo = gr.ChatInterface(
84
  respond,
85
  additional_inputs=[
86
  gr.Textbox(value="Du bist ein hilfreicher Assistent.", label="System message"),
87
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
88
  gr.Slider(minimum=0.1, maximum=4.0, value=0.3, step=0.1, label="Temperature"),
89
  gr.Slider(
90
  minimum=0.1,
@@ -93,6 +94,13 @@ demo = gr.ChatInterface(
93
  step=0.05,
94
  label="Top-p (nucleus sampling)",
95
  ),
 
 
 
 
 
 
 
96
  gr.Slider(
97
  minimum=0.1,
98
  maximum=2.0,
 
51
  max_tokens,
52
  temperature,
53
  top_p,
54
+ top_k,
55
  repetition_penalty
56
  ):
57
  messages = [{"role": "system", "content": system_message}]
 
67
  streamer = CustomIterable()
68
 
69
  inputs = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", add_generation_prompt=True)
70
+ thread = threading.Thread(target=model.generate, args=([inputs]), kwargs={"max_new_tokens": max_tokens, "do_sample": True, "temperature": temperature, "top_p": top_p, "top_k": top_k, "repetition_penalty": repetition_penalty, "streamer": streamer})
71
  thread.start()
72
  response = ""
73
 
 
85
  respond,
86
  additional_inputs=[
87
  gr.Textbox(value="Du bist ein hilfreicher Assistent.", label="System message"),
88
+ gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="Max new tokens"),
89
  gr.Slider(minimum=0.1, maximum=4.0, value=0.3, step=0.1, label="Temperature"),
90
  gr.Slider(
91
  minimum=0.1,
 
94
  step=0.05,
95
  label="Top-p (nucleus sampling)",
96
  ),
97
+ gr.Slider(
98
+ minimum=16,
99
+ maximum=1024,
100
+ value=512,
101
+ step=1,
102
+ label="Top-k",
103
+ ),
104
  gr.Slider(
105
  minimum=0.1,
106
  maximum=2.0,