skylersterling committed
Commit 3aeba0a · verified · 1 Parent(s): 34df75b

Update app.py

Files changed (1): app.py +4 -4
app.py CHANGED
@@ -25,15 +25,15 @@ def generate_text(prompt, temperature, top_p):
     for _ in range(80):  # Adjust the range to control the number of tokens generated
         with torch.no_grad():
             outputs = model(input_tokens)
-            predictions = outputs.logits / temperature
-            sorted_logits, sorted_indices = torch.sort(predictions[:, -1, :], descending=True)
+            predictions = outputs.logits[:, -1, :] / temperature
+            sorted_logits, sorted_indices = torch.sort(predictions, descending=True)
             cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
             sorted_indices_to_remove = cumulative_probs > top_p
             sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
             sorted_indices_to_remove[..., 0] = 0
             indices_to_remove = sorted_indices[sorted_indices_to_remove]
-            predictions[:, -1, indices_to_remove] = -float('Inf')
-            next_token = torch.multinomial(torch.softmax(predictions[:, -1, :], dim=-1), 1)
+            predictions[:, indices_to_remove] = -float('Inf')
+            next_token = torch.multinomial(torch.softmax(predictions, dim=-1), 1)

         input_tokens = torch.cat((input_tokens, next_token), dim=1)
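
For context, the following is a minimal, self-contained sketch of the sampling loop as it looks after this change. It assumes a GPT-2-style causal LM loaded through Hugging Face transformers; the "gpt2" checkpoint, the prompt, and the default temperature/top_p values are illustrative placeholders, not taken from this repository.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: any causal LM checkpoint works here; "gpt2" is only a placeholder.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

def generate_text(prompt, temperature=0.8, top_p=0.9):
    input_tokens = tokenizer(prompt, return_tensors="pt").input_ids
    for _ in range(80):  # Adjust the range to control the number of tokens generated
        with torch.no_grad():
            outputs = model(input_tokens)
            # Keep only the logits for the last position, scaled by temperature.
            predictions = outputs.logits[:, -1, :] / temperature
            # Nucleus (top-p) filtering: sort logits, keep the smallest set of tokens
            # whose cumulative probability reaches top_p, and mask out the rest.
            sorted_logits, sorted_indices = torch.sort(predictions, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the mask right so the first token past the threshold is still kept.
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            predictions[:, indices_to_remove] = -float('Inf')
            # Sample the next token from the filtered distribution.
            next_token = torch.multinomial(torch.softmax(predictions, dim=-1), 1)

        input_tokens = torch.cat((input_tokens, next_token), dim=1)
    return tokenizer.decode(input_tokens[0], skip_special_tokens=True)

print(generate_text("Once upon a time"))

The effect of the change itself: slicing outputs.logits[:, -1, :] once up front means the masking and softmax operate on a 2-D [batch, vocab] tensor throughout, instead of mixing 3-D and 2-D indexing of the full logits tensor.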