File size: 2,045 Bytes
3e248d4
 
6d5cbde
 
 
 
 
 
3e248d4
 
9297e16
 
 
 
 
 
 
 
 
e47c7c5
 
 
 
 
 
 
 
 
 
 
 
cfaf86b
e47c7c5
 
 
d253128
d5bacd1
e47c7c5
d5bacd1
e47c7c5
 
e5723fa
e47c7c5
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import subprocess
def install_spacy_model(model_name):
    """Download a spaCy model by running ``spacy download`` in a subprocess.

    The command is run with the same interpreter that is executing this
    script, so the model is installed into the environment the app runs in.

    Args:
        model_name: Name of the spaCy model to download, e.g. "en_core_web_trf".

    Returns:
        True if the download command exited successfully, False otherwise.
        Failures are reported on stderr instead of being raised, so startup
        continues (best-effort install).
    """
    import sys

    # A bare "python" is resolved from PATH and may be a different interpreter
    # (or missing entirely), which would install the model somewhere this
    # process cannot import it from. Fall back to "python" only when the
    # running interpreter's path is unknown (sys.executable is empty/None).
    python = sys.executable or "python"
    try:
        subprocess.check_call([python, "-m", "spacy", "download", model_name])
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: non-zero exit; OSError: executable not found/runnable.
        print(f"Error occurred while installing the model: {model_name}", file=sys.stderr)
        print(f"Error details: {str(e)}", file=sys.stderr)
        return False
    return True

# Fetch the spaCy model before gradio and the SynGen pipeline module are imported below.
# NOTE(review): presumably syngen_diffusion_pipeline needs this model to parse prompts — confirm.
install_spacy_model(model_name="en_core_web_trf")

import gradio as gr
import torch

from syngen_diffusion_pipeline import SynGenDiffusionPipeline





model_path = 'CompVis/stable-diffusion-v1-4'
# Run on the first CUDA device when one is present, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Load the SynGen pipeline from model_path, then move it onto the chosen device.
pipe = SynGenDiffusionPipeline.from_pretrained(model_path)
pipe = pipe.to(device)


def generate_fn(prompt, seed):
    """Generate one image for ``prompt`` with a seeded random generator.

    Args:
        prompt: Text prompt passed to the module-level SynGen ``pipe``.
        seed: Seed as received from the Seed textbox (a string, or anything
            ``int()`` accepts). A blank or missing value falls back to 42,
            the value shown as the textbox placeholder.

    Returns:
        The first entry of the pipeline result's ``'images'`` field.

    Raises:
        ValueError: if ``seed`` is non-blank but not parseable by ``int()``.
    """
    # An empty textbox yields "" (the placeholder is display-only), and
    # int("") would raise; use the advertised placeholder seed instead.
    if seed is None or (isinstance(seed, str) and not seed.strip()):
        seed = 42
    # The generator lives on the same device type as the pipeline so that
    # seeding is applied where sampling happens.
    generator = torch.Generator(device.type).manual_seed(int(seed))
    result = pipe(prompt=prompt, generator=generator, num_inference_steps=50)
    return result['images'][0]

title = "SynGen"
description = """
This is the demo for [SynGen](https://github.com/RoyiRa/Syntax-Guided-Generation), an image synthesis approach which first syntactically analyses the prompt to identify entities and their modifiers, and then uses a novel loss function that encourages the cross-attention maps to agree with the linguistic binding reflected by the syntax. Preprint: ["Linguistic Binding in Diffusion Models: Enhancing Attribute Correspondence through Attention Map Alignment"](https://arxiv.org/abs/2306.08877).
"""

# Each example is [prompt, seed]; seeds are strings because the seed input
# below is a Textbox.
examples = [
    ["the apple is blue and the carrot is purple", "20"],
    ["a yellow flamingo and a pink sunflower", "16"],
    ["a checkered bowl in a cluttered room", "77"],
    ["a horned lion and a spotted monkey", "1269"]
]

prompt_textbox = gr.Textbox(label="Prompt", placeholder="a pink sunflower and a yellow flamingo", lines=1)
seed_textbox = gr.Textbox(label="Seed", placeholder="42", lines=1)

output = gr.Image(label="generation")
# allow_flagging is documented to take one of the strings "never", "auto" or
# "manual"; "never" hides the flag button (a boolean is not a documented value).
demo = gr.Interface(fn=generate_fn, inputs=[prompt_textbox, seed_textbox], outputs=output, examples=examples,
                    title=title, description=description, allow_flagging="never")

demo.launch()