import json
import os
import re

import gradio as gr
import transformers


device = "cpu"
model_id = os.environ.get("MODEL_ID") or "treadon/prompt-fungineer-355M"
# HUB_TOKEN is optional; True tells transformers to fall back to any locally
# cached Hugging Face credentials.
auth_token = os.environ.get("HUB_TOKEN") or True
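
# Example invocation overriding the defaults (placeholder values; the script
# filename is assumed):
#   MODEL_ID=some-user/some-model HUB_TOKEN=hf_xxx python app.py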

print(f"Using model {model_id}.")

if auth_token is not True:
    print("Using auth token.")

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id, low_cpu_mem_usage=True, use_auth_token=auth_token
)
# The fine-tuned checkpoint uses the stock GPT-2 vocabulary, so the base
# "gpt2" tokenizer is loaded rather than one from the model repo.
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")


def format_prompt(prompt, enhancers=True, inspiration=False, negative_prompt=False):
    """Parse the model's tagged output (BRF:/POS:/ENH:/INS:/NEG:) into a usable prompt."""
    try:
        # Capture each "TAG: value" segment up to the next tag (or end of string).
        pattern = r"(BRF:|POS:|ENH:|INS:|NEG:) (.*?)(?= (BRF:|POS:|ENH:|INS:|NEG:)|$)"
        matches = re.findall(pattern, prompt)
        vals = {key: value.strip() for key, value, _ in matches}
        result = vals["POS:"]
        if enhancers:
            result += " " + vals["ENH:"]
        if inspiration:
            result += " " + vals["INS:"]
        if negative_prompt:
            result += "\n\n--no " + vals["NEG:"]

        return result
    except Exception:
        return "Failed to generate prompt."

    
def generate_text(prompt, extra=False, top_k=100, top_p=0.95, temperature=0.85,
                  enhancers=True, inspiration=False, negative_prompt=False):
    
    # The model expects input to start with the BRF: (base/brief prompt) tag.
    if not prompt.startswith("BRF:"):
        prompt = "BRF: " + prompt

    # Without extra imagination, jump the model straight to the POS: section.
    if not extra:
        prompt = prompt + " POS:"

    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    samples = []
    try:
        # Sample four candidate completions in a single batch.
        outputs = model.generate(
            **inputs,
            max_length=256,
            do_sample=True,
            top_k=int(top_k),  # Gradio sliders deliver floats; generate wants an int here
            top_p=top_p,
            temperature=temperature,
            num_return_sequences=4,
            pad_token_id=tokenizer.eos_token_id,
        )
        for output in outputs:
            sample = tokenizer.decode(output, skip_special_tokens=True)
            sample = format_prompt(sample, enhancers, inspiration, negative_prompt)
            samples.append(sample)
    except Exception as e:
        print(e)

    return samples
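
# Quick sanity check outside the UI (hypothetical prompt; sampling is
# non-deterministic, so each run prints four different expansions):
#
#   for s in generate_text("An astronaut in space"):
#       print(s, "\n---")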


with gr.Blocks() as fungineer:
    with gr.Row():
        gr.Markdown("""# Midjourney / Dalle 2 / Stable Diffusion Prompt Generator

This is the 355M parameter model.  There is also a 7B parameter model that is much better but far slower (access coming soon).

Just enter a basic prompt and the fungineering model will use its wildest imagination to expand the prompt in detail.""")
    with gr.Row():
        with gr.Column():
            base_prompt = gr.Textbox(lines=5, label="Base Prompt", placeholder="An astronaut in space", info="Enter a very simple prompt that will be fungineered into something exciting!")
            extra = gr.Checkbox(value=True, label="Extra Fungineer Imagination", info="If checked, the model will be allowed to go wild with its imagination.")
            with gr.Accordion("Advanced Generation Settings", open=False):
                top_k = gr.Slider(minimum=10, maximum=1000, value=100, label="Top K", info="Top-k sampling: only the k most likely tokens are considered at each step.")
                top_p = gr.Slider(minimum=0.1, maximum=1, value=0.95, step=0.01, label="Top P", info="Top-p (nucleus) sampling: only the smallest set of tokens whose probabilities sum to p is considered.")
                temperature = gr.Slider(minimum=0.1, maximum=1.2, value=0.85, step=0.01, label="Temperature", info="Sampling temperature. Higher values make the model more creative.")

            with gr.Accordion("Advanced Output Settings", open=False):
                enh = gr.Checkbox(value=True, label="Enhancers", info="Add image meta information such as lens type, shutter speed, camera model, etc.")
                insp = gr.Checkbox(value=False, label="Inspiration", info="Include inspirational photographers known for this type of photography. Sometimes random names appear here; the model needs more training.")
                neg = gr.Checkbox(value=False, label="Negative Prompt", info="Include a negative prompt, more often used in Stable Diffusion. If you're a Stable Diffusion user, chances are you already have a better negative prompt you like to use.")

        with gr.Column():
            outputs = [
                gr.Textbox(lines=2, label="Fungineered Text 1"),
                gr.Textbox(lines=2, label="Fungineered Text 2"),
                gr.Textbox(lines=2, label="Fungineered Text 3"),
                gr.Textbox(lines=2, label="Fungineered Text 4"),
            ]

    inputs = [base_prompt, extra, top_k, top_p, temperature, enh, insp, neg]


    submit = gr.Button("Fungineer", variant="primary")
    submit.click(generate_text, inputs=inputs, outputs=outputs)

    with open("examples.json") as f:
        examples = json.load(f)
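
    # Inferred shape of examples.json, based on the fields read below
    # (paths and prompts are placeholders):
    # [
    #   {
    #     "base": {"src": "path/to/base.png", "prompt": "..."},
    #     "355M": {"src": "path/to/355m.png", "prompt": "..."}
    #   }
    # ]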

    for i, example in enumerate(examples):
        with gr.Tab(f"Example {i+1}"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Base Prompt")
                    gr.Image(value=example["base"]["src"])
                    gr.Markdown(example["base"]["prompt"])
                with gr.Column():
                    gr.Markdown("### 355M Prompt Fungineered")
                    gr.Image(value=example["355M"]["src"])
                    gr.Markdown(example["355M"]["prompt"])
                with gr.Column():
                    gr.Markdown("### 7B Prompt Fungineered")
                    gr.Markdown("Coming Soon!")


fungineer.launch(enable_queue=True)