File size: 3,532 Bytes
83f4a0c
25bd6f8
 
 
 
 
e86e3f8
83f4a0c
25bd6f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69bcda0
25bd6f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f84a74
25bd6f8
3f07cae
25bd6f8
 
 
 
f326df9
25bd6f8
f52b9ad
25bd6f8
f52b9ad
e86e3f8
f52b9ad
 
 
62786d6
25bd6f8
 
 
1d0c020
 
e86e3f8
 
1d0c020
 
 
 
 
16508ee
 
 
 
 
 
 
 
25bd6f8
 
 
e86e3f8
69bcda0
1d0c020
16508ee
258fa6a
f52b9ad
258fa6a
16508ee
d1bc12b
16508ee
25bd6f8
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import spaces
from scorer import DSGPromptProcessor
import matplotlib.pyplot as plt
import networkx as nx
from PIL import Image
import io

def draw_colored_graph(dependencies, questions, answers):
    """Render the question dependency graph as a PNG-backed Pillow image.

    Every entry of ``questions`` becomes a node labelled with the question
    text and coloured green when the matching entry of ``answers`` is truthy,
    red otherwise. Every entry of ``dependencies`` adds directed edges from
    each listed dependency to the dependent node.

    Args:
        dependencies: Mapping of node id -> iterable of ids that node depends
            on. Keys are converted with ``int()``; the dependency ids are
            used exactly as given.
        questions: Mapping of node id -> question text. Keys are converted
            with ``int()`` to form the node ids.
        answers: Mapping with the same keys as ``questions`` whose values are
            interpreted by truthiness.

    Returns:
        A ``PIL.Image.Image`` opened from an in-memory PNG of the drawing.

    Raises:
        KeyError: If a key of ``questions`` is missing from ``answers``.
    """
    # Create a directed graph
    G = nx.DiGraph()

    # Add nodes with labels and colors based on answers
    for node, question in questions.items():
        color = 'green' if answers[node] else 'red'
        G.add_node(int(node), label=question, color=color)

    # Add edges based on dependencies
    for node, deps in dependencies.items():
        for dep in deps:
            G.add_edge(dep, int(node))

    # Set node positions using a layout
    pos = nx.spring_layout(G)  # You can use other layouts like 'shell_layout' or 'circular_layout'

    buf = io.BytesIO()

    # Draw on a dedicated figure instead of pyplot's implicit global one:
    # otherwise successive calls paint on top of the previous graph and the
    # figure is never released.
    fig, ax = plt.subplots()
    try:
        # Nodes that only exist because an edge referenced them carry no
        # 'color' attribute; give them a neutral colour instead of failing.
        node_colors = [G.nodes[node].get('color', 'gray') for node in G.nodes()]
        nx.draw_networkx_nodes(
            G, pos, node_color=node_colors, node_size=2000, edgecolors='white', ax=ax
        )

        # Draw edges with arrows
        nx.draw_networkx_edges(
            G, pos, arrowstyle='-|>', arrows=True, arrowsize=20,
            connectionstyle='arc3,rad=0.1', ax=ax
        )

        # Draw labels (only nodes that have a 'label' attribute are labelled)
        labels = nx.get_node_attributes(G, 'label')
        nx.draw_networkx_labels(G, pos, labels, font_size=10, font_color='black', ax=ax)

        # Save the graph into the in-memory buffer
        ax.axis('off')
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if drawing failed.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)

# Module-level DSGPromptProcessor instance, constructed once at import time and
# reused by every process_image call. The argument is the Mixtral-8x7B-Instruct
# identifier (presumably the model the scorer loads/queries — confirm in scorer).
processor = DSGPromptProcessor("mistralai/Mixtral-8x7B-Instruct-v0.1")

@spaces.GPU()
def process_image(image, prompt):
    """Score how well ``image`` matches ``prompt`` and visualise the result.

    The module-level ``processor`` is asked, in order, for tuples,
    dependencies and questions derived from the prompt, and then for a
    per-question reward for the image. A reward above 0.5 counts as a
    positive answer when colouring the graph.

    Args:
        image: The input image, forwarded to ``processor.get_reward`` inside
            a one-element list.
        prompt: The text prompt the image is checked against.

    Returns:
        A ``(graph_image, summary_text)`` pair: the rendered dependency graph
        and a string listing the questions and the reward per question.
    """
    tuple_output, _ = processor.generate_tuples(prompt)
    dependencies, _ = processor.generate_dependencies(tuple_output)
    questions, _ = processor.generate_questions(
        prompt,
        tuple_output.tuples,
        dependencies,
    )

    rewards, ordered_questions = processor.get_reward(
        questions, dependencies, [image]
    )
    # Only one image was submitted, so keep the first (only) reward entry.
    scores = rewards[0]
    print(scores)

    # Re-key both collections by stringified position for the graph drawing.
    answers = {}
    for idx, score in enumerate(scores):
        answers[str(idx)] = score > 0.5

    question_by_id = {}
    for idx, question in enumerate(ordered_questions):
        question_by_id[str(idx)] = question

    print(answers, question_by_id)

    graph_img = draw_colored_graph(dependencies, question_by_id, answers)
    summary = f"""
Question: {questions}.
Reward per question: {scores}"""
    return graph_img, summary

# HTML shown under the app title: links to the paper, the author's GitHub and
# the two models referenced by the demo.
# Fixes: the arXiv link had a stray trailing "A" after the paper id, and the
# Mixtral link label read "7x8" although the linked model is Mixtral-8x7B.
description = """
<p><center>
<a href="https://arxiv.org/abs/2310.18235" target="_blank">[Original Paper]</a>
<a href="https://github.com/toilaluan" target="_blank">[My Github]</a>
<a href="https://huggingface.co/toilaluan/Florence-2-base-Yes-No-VQA" target="_blank">[Binary VQA Model - Query Answering]</a>
<a href="https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1" target="_blank">[Mixtral 8x7B - Query Generating]</a>
</center></p>
"""

# Custom stylesheet handed to the Gradio interface; one rule per line.
css = (
    "\n"
    "#gen_btn{height: 100%}\n"
    "#title{text-align: center}\n"
    "#title h1{font-size: 3em; display:inline-flex; align-items:center}\n"
    "#title img{width: 100px; margin-right: 0.5em}\n"
    "#gallery .grid-wrap{height: 10vh}\n"
)

# Assemble the Gradio UI around process_image: an image and a prompt go in,
# the rendered graph image and a text summary come out.
interface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Enter your prompt"),
    ],
    outputs=[
        gr.Image(type="pil", label="Graph Score Image", format="png"),
        gr.Textbox(label="Analyzed Result"),
    ],
    theme=gr.themes.Soft(),
    description=description,
    examples=[
        [
            "examples/input_image.png",
            "A cat with red eyes in the jungle. All tree in the jungle has blue color.",
        ],
    ],
    css=css,
    title="T2I Adherence Scorer based on Davidsonian Scene Graph",
    cache_examples=True,
)

# Start the app as soon as this module is executed.
interface.launch()