toilaluan's picture
update
f52b9ad
import gradio as gr
import spaces
from scorer import DSGPromptProcessor
import matplotlib.pyplot as plt
import networkx as nx
from PIL import Image
import io
def draw_colored_graph(dependencies, questions, answers):
    """Render the question dependency graph as a Pillow image.

    Every entry of ``questions`` becomes a node labelled with the question
    text, drawn green when the matching entry of ``answers`` is truthy and
    red otherwise. Every entry of ``dependencies`` adds directed edges from
    each listed dependency to the dependent node.

    Args:
        dependencies: Mapping of node id -> iterable of node ids it depends
            on. Keys are converted with ``int()``; the dependency ids are
            used exactly as given.
        questions: Mapping of node id -> question text (keys are converted
            with ``int()``).
        answers: Mapping indexed by the same keys as ``questions``.

    Returns:
        A ``PIL.Image.Image`` opened from the in-memory PNG rendering.

    Raises:
        KeyError: If a key of ``questions`` is missing from ``answers``.
    """
    # Create a directed graph
    G = nx.DiGraph()

    # Add nodes with labels and colors based on answers
    for node, question in questions.items():
        color = 'green' if answers[node] else 'red'
        G.add_node(int(node), label=question, color=color)

    # Add edges based on dependencies
    for node, deps in dependencies.items():
        for dep in deps:
            G.add_edge(dep, int(node))

    # Set node positions using a layout
    pos = nx.spring_layout(G)  # Other layouts: 'shell_layout', 'circular_layout'

    # add_edge() implicitly creates endpoints that were never passed to
    # add_node(); those carry no 'color' attribute, so fall back to a neutral
    # color instead of raising KeyError.
    node_colors = [G.nodes[node].get('color', 'lightgray') for node in G.nodes()]
    labels = nx.get_node_attributes(G, 'label')

    # Draw on a dedicated figure instead of pyplot's implicit global one:
    # the global figure is never cleared, so repeated calls would paint each
    # new graph on top of the previous ones and keep the figure alive forever.
    fig, ax = plt.subplots()
    try:
        # Draw nodes with custom colors
        nx.draw_networkx_nodes(
            G, pos, ax=ax, node_color=node_colors, node_size=2000, edgecolors='white'
        )
        # Draw edges with arrows
        nx.draw_networkx_edges(
            G, pos, ax=ax, arrowstyle='-|>', arrows=True, arrowsize=20,
            connectionstyle='arc3,rad=0.1',
        )
        # Draw labels
        nx.draw_networkx_labels(G, pos, labels, ax=ax, font_size=10, font_color='black')

        # Save the graph as a Pillow image
        buf = io.BytesIO()
        ax.axis('off')
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    buf.seek(0)
    img = Image.open(buf)
    return img
# Identifier handed to DSGPromptProcessor; kept as a named constant so it is
# easy to spot and swap.
_PROMPT_MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Created once at import time so every call to process_image shares it.
processor = DSGPromptProcessor(_PROMPT_MODEL_ID)
@spaces.GPU()
def process_image(image, prompt):
    """Score ``image`` against ``prompt`` and build the outputs shown in the UI.

    Runs the module-level ``processor`` through its four steps (tuples,
    dependencies, questions, reward), thresholds each reward at 0.5 to get a
    pass/fail answer, and draws the resulting dependency graph.

    Args:
        image: The image to score; it is forwarded to ``processor.get_reward``
            wrapped in a single-element list.
        prompt: The text prompt the image is checked against.

    Returns:
        A ``(graph_image, text)`` pair: the image produced by
        ``draw_colored_graph`` and a summary string with the questions and
        the per-question rewards.
    """
    tuple_result, _ = processor.generate_tuples(prompt)
    dependencies, _ = processor.generate_dependencies(tuple_result)
    questions, _ = processor.generate_questions(
        prompt, tuple_result.tuples, dependencies
    )

    # One image is submitted, so only the first reward entry is used
    # (presumably get_reward returns one entry per image - verify in scorer).
    all_rewards, ordered_questions = processor.get_reward(
        questions, dependencies, [image]
    )
    reward = all_rewards[0]
    print(reward)

    # Key both mappings by the stringified position so they line up with the
    # node ids draw_colored_graph derives via int().
    answers = {}
    for idx, score in enumerate(reward):
        answers[str(idx)] = score > 0.5
    sorted_questions = dict(
        (str(idx), text) for idx, text in enumerate(ordered_questions)
    )
    print(answers, sorted_questions)

    graph_img = draw_colored_graph(dependencies, sorted_questions, answers)

    # NOTE(review): the summary prints `questions`, while the rewards are
    # indexed like `sorted_questions` - confirm the two orders agree.
    summary = f"\nQuestion: {questions}.\nReward per question: {reward}"
    return graph_img, summary
# HTML snippet passed to gr.Interface as its description: a centered row of
# reference links (paper, author's GitHub, and the two models).
# Fixes: the arXiv id had a stray trailing "A" (broken link), and the Mixtral
# label now matches the linked model name (8x7B).
description = """
<p><center>
<a href="https://arxiv.org/abs/2310.18235" target="_blank">[Original Paper]</a>
<a href="https://github.com/toilaluan" target="_blank">[My Github]</a>
<a href="https://huggingface.co/toilaluan/Florence-2-base-Yes-No-VQA" target="_blank">[Binary VQA Model - Query Answering]</a>
<a href="https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1" target="_blank">[Mixtral 8x7B - Query Generating]</a>
</center></p>
"""
# Custom stylesheet handed to gr.Interface through its `css` argument.
css = """
#gen_btn{height: 100%}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: 10vh}
"""

# Input components: the image to score and the prompt it should follow.
image_input = gr.Image(type="pil", label="Input Image")
prompt_input = gr.Textbox(label="Enter your prompt")

# Output components: the rendered graph and the textual summary.
graph_output = gr.Image(type="pil", label="Graph Score Image", format="png")
result_output = gr.Textbox(label="Analyzed Result")

# Each example row supplies one value per input component, in order.
example_rows = [
    [
        "examples/input_image.png",
        "A cat with red eyes in the jungle. All tree in the jungle has blue color.",
    ],
]

# Assemble the Gradio interface around process_image.
interface = gr.Interface(
    fn=process_image,
    inputs=[image_input, prompt_input],
    outputs=[graph_output, result_output],
    theme=gr.themes.Soft(),
    description=description,
    examples=example_rows,
    css=css,
    title="T2I Adherence Scorer based on Davidsonian Scene Graph",
    cache_examples=True,
)

# Start the Gradio app.
interface.launch()