skylersterling committed on
Commit
a819a8c
·
verified ·
1 Parent(s): 5325787

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -10
app.py CHANGED
@@ -13,7 +13,7 @@ model.eval()
13
  model.to('cpu')
14
 
15
  # Define the function that generates text from a prompt
16
- def generate_text(prompt, temperature, top_p):
17
 
18
  print(prompt)
19
 
@@ -29,13 +29,6 @@ def generate_text(prompt, temperature, top_p):
29
  with torch.no_grad():
30
  outputs = model(input_tokens)
31
  predictions = outputs.logits[:, -1, :] / temperature
32
- sorted_logits, sorted_indices = torch.sort(predictions, descending=True)
33
- cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
34
- sorted_indices_to_remove = cumulative_probs > top_p
35
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
36
- sorted_indices_to_remove[..., 0] = 0
37
- indices_to_remove = sorted_indices[sorted_indices_to_remove]
38
- predictions[:, indices_to_remove] = -float('Inf')
39
  next_token = torch.multinomial(torch.softmax(predictions, dim=-1), 1)
40
 
41
  input_tokens = torch.cat((input_tokens, next_token), dim=1)
@@ -46,13 +39,12 @@ def generate_text(prompt, temperature, top_p):
46
  break
47
  yield generated_text[prompt_length:] # Yield the generated text excluding the initial prompt plus "EOS"
48
 
49
- # Create a Gradio interface with a text input, sliders for temperature and top_p, and a text output
50
  interface = gr.Interface(
51
  fn=generate_text,
52
  inputs=[
53
  gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
54
  gr.Slider(minimum=0.1, maximum=1.0, value=0.1, label="Temperature"),
55
- gr.Slider(minimum=0.1, maximum=1.0, value=1.0, label="Top_P"),
56
  ],
57
  outputs=gr.Textbox(),
58
  live=False,
 
13
  model.to('cpu')
14
 
15
  # Define the function that generates text from a prompt
16
+ def generate_text(prompt, temperature):
17
 
18
  print(prompt)
19
 
 
29
  with torch.no_grad():
30
  outputs = model(input_tokens)
31
  predictions = outputs.logits[:, -1, :] / temperature
 
 
 
 
 
 
 
32
  next_token = torch.multinomial(torch.softmax(predictions, dim=-1), 1)
33
 
34
  input_tokens = torch.cat((input_tokens, next_token), dim=1)
 
39
  break
40
  yield generated_text[prompt_length:] # Yield the generated text excluding the initial prompt plus "EOS"
41
 
42
+ # Create a Gradio interface with a text input and a slider for temperature
43
  interface = gr.Interface(
44
  fn=generate_text,
45
  inputs=[
46
  gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
47
  gr.Slider(minimum=0.1, maximum=1.0, value=0.1, label="Temperature"),
 
48
  ],
49
  outputs=gr.Textbox(),
50
  live=False,