skylersterling committed on
Commit c4dece9 · verified · 1 Parent(s): 2c4becd

Update app.py

Files changed (1)
  1. app.py +24 -5
app.py CHANGED
@@ -14,7 +14,7 @@ model.eval()
 model.to('cpu')
 
 # Define the function that generates text from a prompt
-def generate_text(prompt):
+def generate_text(prompt, temperature, top_p):
     input_tokens = tokenizer.encode(prompt, return_tensors='pt')
     input_tokens = input_tokens.to('cpu')
 
@@ -23,15 +23,34 @@ def generate_text(prompt):
     for _ in range(80):  # Adjust the range to control the number of tokens generated
         with torch.no_grad():
             outputs = model(input_tokens)
-            predictions = outputs.logits
+            predictions = outputs.logits / temperature
+            sorted_logits, sorted_indices = torch.sort(predictions[:, -1, :], descending=True)
+            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+            sorted_indices_to_remove = cumulative_probs > top_p
+            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+            sorted_indices_to_remove[..., 0] = 0
+            indices_to_remove = sorted_indices[sorted_indices_to_remove]
+            predictions[:, -1, indices_to_remove] = -float('Inf')
         next_token = torch.multinomial(torch.softmax(predictions[:, -1, :], dim=-1), 1)
 
         input_tokens = torch.cat((input_tokens, next_token), dim=1)
 
         decoded_token = tokenizer.decode(next_token.item())
         generated_text += decoded_token  # Append the new token to the generated text
+        if decoded_token == "#":  # Stop if the end of sequence token is generated
+            break
         yield generated_text  # Yield the entire generated text so far
 
-# Create a Gradio interface with a text input and a text output
-interface = gr.Interface(fn=generate_text, inputs='text', outputs='text', live=False)
-interface.launch()
+# Create a Gradio interface with a text input, sliders for temperature and top_p, and a text output
+interface = gr.Interface(
+    fn=generate_text,
+    inputs=[
+        gr.inputs.Textbox(lines=2, placeholder="Enter your prompt here..."),
+        gr.inputs.Slider(minimum=0.1, maximum=1.0, default=1.0, label="Temperature"),
+        gr.inputs.Slider(minimum=0.1, maximum=1.0, default=0.9, label="Top-p")
+    ],
+    outputs='text',
+    live=False
+)
+
+interface.launch()
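
For reference, what this commit adds is standard temperature scaling followed by top-p (nucleus) filtering before sampling. Below is a minimal standalone sketch of that step; the sample_next_token helper is hypothetical, not part of app.py. It assumes the same single-prompt, batch-size-1 usage as the diff above, but maps the mask back to vocabulary order with scatter rather than the diff's boolean indexing (sorted_indices[sorted_indices_to_remove], which is only correct for a single sequence), so it also stays valid for larger batches.

import torch

def sample_next_token(logits, temperature=1.0, top_p=0.9):
    # Hypothetical helper, not part of app.py.
    # `logits` is the last-position slice of the model output,
    # e.g. outputs.logits[:, -1, :], with shape (batch, vocab_size).
    scaled = logits / temperature  # lower temperature sharpens the distribution

    # Sort descending and accumulate softmax probabilities.
    sorted_logits, sorted_indices = torch.sort(scaled, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Mark every token after the cumulative probability exceeds top_p.
    sorted_mask = cumulative_probs > top_p
    # Shift right so the token that crosses the threshold is kept,
    # and never mask the single most likely token.
    sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
    sorted_mask[..., 0] = False

    # Map the mask from sorted order back to vocabulary order,
    # then disable the masked logits before sampling.
    mask = sorted_mask.scatter(-1, sorted_indices, sorted_mask)
    scaled = scaled.masked_fill(mask, float('-inf'))

    # Sample one token id from the renormalized distribution.
    return torch.multinomial(torch.softmax(scaled, dim=-1), num_samples=1)

# Example with dummy logits over a 10-token vocabulary:
token_id = sample_next_token(torch.randn(1, 10), temperature=0.8, top_p=0.9)

One caveat on the interface code: the gr.inputs namespace used here (gr.inputs.Textbox, gr.inputs.Slider with default=) is the legacy pre-3.0 Gradio API; on recent Gradio releases the equivalents are gr.Textbox and gr.Slider(..., value=...).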