Update app.py
app.py CHANGED
@@ -13,7 +13,7 @@ model.eval()
 model.to('cpu')
 
 # Define the function that generates text from a prompt
-def generate_text(prompt, temperature, top_p):
+def generate_text(prompt, temperature):
 
     print(prompt)
 
@@ -29,13 +29,6 @@ def generate_text(prompt, temperature, top_p):
         with torch.no_grad():
             outputs = model(input_tokens)
             predictions = outputs.logits[:, -1, :] / temperature
-            sorted_logits, sorted_indices = torch.sort(predictions, descending=True)
-            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
-            sorted_indices_to_remove = cumulative_probs > top_p
-            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-            sorted_indices_to_remove[..., 0] = 0
-            indices_to_remove = sorted_indices[sorted_indices_to_remove]
-            predictions[:, indices_to_remove] = -float('Inf')
             next_token = torch.multinomial(torch.softmax(predictions, dim=-1), 1)
 
             input_tokens = torch.cat((input_tokens, next_token), dim=1)
@@ -46,13 +39,12 @@ def generate_text(prompt, temperature, top_p):
                 break
         yield generated_text[prompt_length:]  # Yield the generated text excluding the initial prompt plus "EOS"
 
-# Create a Gradio interface with a text input
+# Create a Gradio interface with a text input and a slider for temperature
 interface = gr.Interface(
     fn=generate_text,
     inputs=[
         gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
         gr.Slider(minimum=0.1, maximum=1.0, value=0.1, label="Temperature"),
-        gr.Slider(minimum=0.1, maximum=1.0, value=1.0, label="Top_P"),
     ],
     outputs=gr.Textbox(),
     live=False,
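
For reference, the seven deleted lines implemented top-p (nucleus) sampling: keep only the smallest set of tokens whose cumulative probability exceeds top_p, and sample from that set. Below is a minimal sketch of that filter, assuming logits of shape [batch, vocab_size]; the function name is illustrative, and it uses the common scatter-based masking rather than the deleted code's fancy indexing (predictions[:, indices_to_remove]), which only behaves correctly for batch size 1.

import torch

def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # Sort each row's logits from most to least likely.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    # Cumulative probability mass in sorted order.
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    # Mark tokens past the top_p mass for removal ...
    sorted_indices_to_remove = cumulative_probs > top_p
    # ... shifted right by one so the token that crosses the threshold is kept.
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    # Map the mask from sorted order back to vocabulary order, row by row.
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(indices_to_remove, -float('inf'))

The filtered logits would then feed the same torch.multinomial call that survives the change.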
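With top_p gone, generation depends only on temperature. A self-contained sketch of how the simplified app.py plausibly fits together, streaming partial output to Gradio via the generator's yield; the checkpoint name, tokenizer, max_new_tokens cap, and EOS stop condition are assumptions, since the diff does not show the top of the file.

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed setup; the actual checkpoint used by the Space is not shown in this diff.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
model.eval()
model.to('cpu')

def generate_text(prompt, temperature, max_new_tokens=64):
    print(prompt)
    input_tokens = tokenizer(prompt, return_tensors='pt').input_ids
    prompt_length = len(prompt)
    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(input_tokens)
            # Temperature scales the logits: lower values sharpen the distribution.
            predictions = outputs.logits[:, -1, :] / temperature
            next_token = torch.multinomial(torch.softmax(predictions, dim=-1), 1)
            input_tokens = torch.cat((input_tokens, next_token), dim=1)
        generated_text = tokenizer.decode(input_tokens[0], skip_special_tokens=True)
        if next_token.item() == tokenizer.eos_token_id:
            break
        yield generated_text[prompt_length:]  # Stream everything after the prompt

interface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.1, label="Temperature"),
    ],
    outputs=gr.Textbox(),
    live=False,
)

if __name__ == "__main__":
    interface.launch()

Because generate_text is a generator, Gradio streams each partial string into the output Textbox; live=False means generation starts on Submit rather than on every keystroke.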