Spaces:
Runtime error
Runtime error
import gradio as gr | |
import spaces | |
from scorer import DSGPromptProcessor | |
import matplotlib.pyplot as plt | |
import networkx as nx | |
from PIL import Image | |
import io | |
def draw_colored_graph(dependencies, questions, answers):
    """Render the question dependency graph as a Pillow image.

    Every entry of ``questions`` becomes a node labelled with the question
    text and coloured green when the matching entry of ``answers`` is truthy,
    red otherwise. Every entry of ``dependencies`` adds directed edges from
    each listed dependency to the dependent node.

    Args:
        dependencies: Mapping of node id -> iterable of node ids it depends on.
        questions: Mapping of node id -> question text (used as node label).
        answers: Mapping with the same keys as ``questions``; truthy values
            are drawn green, falsy values red.

    Returns:
        A ``PIL.Image.Image`` opened from the PNG rendering of the graph.

    Raises:
        KeyError: If a key of ``questions`` is missing from ``answers``.
    """
    node_size = 2000

    # Create a directed graph
    graph = nx.DiGraph()

    # Add nodes with labels and colors based on answers
    for node, question in questions.items():
        color = 'green' if answers[node] else 'red'
        graph.add_node(int(node), label=question, color=color)

    # Add edges based on dependencies. Ids are normalised with int() on both
    # ends so that "1" and 1 refer to the same node, as for the nodes above.
    for node, deps in dependencies.items():
        for dep in deps:
            graph.add_edge(int(dep), int(node))

    # Nodes that only appear through an edge carry no attributes; give them a
    # neutral colour instead of failing with KeyError.
    node_colors = [graph.nodes[node].get('color', 'lightgray') for node in graph.nodes()]
    labels = nx.get_node_attributes(graph, 'label')

    # Draw on a dedicated figure: the implicit pyplot figure is shared between
    # calls, so successive requests would otherwise be painted on top of each
    # other and the figure would never be released.
    fig, ax = plt.subplots()
    try:
        # Set node positions using a layout
        pos = nx.spring_layout(graph)

        # Draw nodes with custom colors
        nx.draw_networkx_nodes(
            graph, pos, ax=ax, node_color=node_colors,
            node_size=node_size, edgecolors='white',
        )
        # Draw edges with arrows; node_size lets networkx stop the arrows at
        # the node border instead of hiding the heads under the nodes.
        nx.draw_networkx_edges(
            graph, pos, ax=ax, arrowstyle='-|>', arrows=True, arrowsize=20,
            connectionstyle='arc3,rad=0.1', node_size=node_size,
        )
        # Draw labels
        nx.draw_networkx_labels(graph, pos, labels, ax=ax, font_size=10, font_color='black')

        # Save the graph into an in-memory PNG
        ax.axis('off')
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
# Module-level scorer instance, created once at import time and used by
# process_image for every request.
processor = DSGPromptProcessor(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
)
def process_image(image, prompt):
    """Score how well ``image`` follows ``prompt`` and visualise the result.

    The module-level ``processor`` is asked, in order, for tuples,
    dependencies and questions derived from the prompt, and then for the
    reward of the image against those questions. A question counts as
    answered positively when its reward is above 0.5.

    Args:
        image: The image to evaluate (passed to the processor in a
            one-element list).
        prompt: The text prompt the image is checked against.

    Returns:
        A pair ``(graph_image, summary_text)``: the dependency graph drawn by
        ``draw_colored_graph`` and a text block listing the questions and the
        per-question rewards.
    """
    scene_tuples, _ = processor.generate_tuples(prompt)
    dependencies, _ = processor.generate_dependencies(scene_tuples)
    questions, _ = processor.generate_questions(prompt, scene_tuples.tuples, dependencies)

    batch_rewards, ordered_questions = processor.get_reward(questions, dependencies, [image])
    # Only one image was submitted, so keep the first entry of the result.
    reward = batch_rewards[0]
    print(reward)

    # Key both mappings by the stringified position of each entry.
    answers = {}
    for idx, score in enumerate(reward):
        answers[str(idx)] = score > 0.5
    questions_by_id = dict((str(idx), text) for idx, text in enumerate(ordered_questions))
    print(answers, questions_by_id)

    graph_img = draw_colored_graph(dependencies, questions_by_id, answers)
    summary = f"""
Question: {questions}.
Reward per question: {reward}"""
    return graph_img, summary
# HTML shown under the app title: links to the paper, the author and the two
# models used (question answering and question generation).
description = """
<p><center>
<a href="https://arxiv.org/abs/2310.18235" target="_blank">[Original Paper]</a>
<a href="https://github.com/toilaluan" target="_blank">[My Github]</a>
<a href="https://huggingface.co/toilaluan/Florence-2-base-Yes-No-VQA" target="_blank">[Binary VQA Model - Query Answering]</a>
<a href="https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1" target="_blank">[Mixtral 8x7B - Query Generating]</a>
</center></p>
"""
# Custom stylesheet passed to the Gradio interface through its ``css`` argument.
css = """
#gen_btn{height: 100%}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: 10vh}
"""
# Components of the Gradio interface: an image plus a prompt go in, the
# rendered graph plus a text summary come out.
input_components = [
    gr.Image(type="pil", label="Input Image"),
    gr.Textbox(label="Enter your prompt"),
]
output_components = [
    gr.Image(type="pil", label="Graph Score Image", format="png"),
    gr.Textbox(label="Analyzed Result"),
]

# One (image path, prompt) row per example offered in the UI.
example_rows = [
    ["examples/input_image.png", "A cat with red eyes in the jungle. All tree in the jungle has blue color."],
]

# Define the Gradio interface
interface = gr.Interface(
    fn=process_image,
    inputs=input_components,
    outputs=output_components,
    theme=gr.themes.Soft(),
    description=description,
    examples=example_rows,
    css=css,
    title="T2I Adherence Scorer based on Davidsonian Scene Graph",
    cache_examples=True,
)

# Launch the Gradio app
interface.launch()