Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
import functools as ft
|
| 2 |
-
|
| 3 |
import gradio as gr
|
| 4 |
import torch
|
| 5 |
import random
|
|
@@ -17,7 +15,6 @@ tokenizer = T5Tokenizer.from_pretrained("roborovski/superprompt-v1")
|
|
| 17 |
model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype=torch.float16)
|
| 18 |
model.to(device)
|
| 19 |
|
| 20 |
-
@ft.lru_cache(maxsize=1024)
|
| 21 |
def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, model_precision_type, top_p, top_k, seed):
|
| 22 |
if seed == 0:
|
| 23 |
seed = random.randint(1, 2**32-1)
|
|
@@ -30,22 +27,21 @@ def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, model
|
|
| 30 |
|
| 31 |
model.to(dtype)
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
|
| 50 |
|
| 51 |
your_prompt = gr.Textbox(label="Your Prompt", interactive=True)
|
|
@@ -62,7 +58,7 @@ top_p = gr.Slider(value=1, minimum=0, maximum=2, step=0.05, interactive=True, la
|
|
| 62 |
|
| 63 |
top_k = gr.Slider(value=1, minimum=1, maximum=100, step=1, interactive=True, label="Top K", info="Higher k means more diverse outputs by considering a range of tokens")
|
| 64 |
|
| 65 |
-
seed = gr.
|
| 66 |
|
| 67 |
examples = [
|
| 68 |
["A storefront with 'Text to Image' written on it.", 512, 1.2, 0.5, "fp16", 1, 50, 42]
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
import random
|
|
|
|
| 15 |
model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", torch_dtype=torch.float16)
|
| 16 |
model.to(device)
|
| 17 |
|
|
|
|
| 18 |
def generate(your_prompt, max_new_tokens, repetition_penalty, temperature, model_precision_type, top_p, top_k, seed):
|
| 19 |
if seed == 0:
|
| 20 |
seed = random.randint(1, 2**32-1)
|
|
|
|
| 27 |
|
| 28 |
model.to(dtype)
|
| 29 |
|
| 30 |
+
input_text = f"Expand the following prompt to add more detail: {your_prompt}"
|
| 31 |
+
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
|
|
|
|
| 32 |
|
| 33 |
+
outputs = model.generate(
|
| 34 |
+
input_ids,
|
| 35 |
+
max_new_tokens=max_new_tokens,
|
| 36 |
+
repetition_penalty=repetition_penalty,
|
| 37 |
+
do_sample=True,
|
| 38 |
+
temperature=temperature,
|
| 39 |
+
top_p=top_p,
|
| 40 |
+
top_k=top_k,
|
| 41 |
+
)
|
| 42 |
|
| 43 |
+
better_prompt = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 44 |
+
return better_prompt
|
| 45 |
|
| 46 |
|
| 47 |
your_prompt = gr.Textbox(label="Your Prompt", interactive=True)
|
|
|
|
| 58 |
|
| 59 |
top_k = gr.Slider(value=1, minimum=1, maximum=100, step=1, interactive=True, label="Top K", info="Higher k means more diverse outputs by considering a range of tokens")
|
| 60 |
|
| 61 |
+
seed = gr.Slider(value=42, minimum=0, maximum=2**32-1, interactive=True, label="Seed", info="A starting point to initiate the generation process, put 0 for a random one")
|
| 62 |
|
| 63 |
examples = [
|
| 64 |
["A storefront with 'Text to Image' written on it.", 512, 1.2, 0.5, "fp16", 1, 50, 42]
|