Spaces:
Runtime error
Runtime error
File size: 3,532 Bytes
83f4a0c 25bd6f8 e86e3f8 83f4a0c 25bd6f8 69bcda0 25bd6f8 5f84a74 25bd6f8 3f07cae 25bd6f8 f326df9 25bd6f8 f52b9ad 25bd6f8 f52b9ad e86e3f8 f52b9ad 62786d6 25bd6f8 1d0c020 e86e3f8 1d0c020 16508ee 25bd6f8 e86e3f8 69bcda0 1d0c020 16508ee 258fa6a f52b9ad 258fa6a 16508ee d1bc12b 16508ee 25bd6f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import gradio as gr
import spaces
from scorer import DSGPromptProcessor
import matplotlib.pyplot as plt
import networkx as nx
from PIL import Image
import io
def _as_node_id(value):
    """Coerce *value* to ``int`` so ids such as ``"1"`` and ``1`` map to one node.

    Values that cannot be converted (non-numeric ids) are returned unchanged.
    """
    try:
        return int(value)
    except (TypeError, ValueError):
        return value


def draw_colored_graph(dependencies, questions, answers):
    """Render the question dependency graph as a PIL image.

    Every question becomes a node labelled with its text. The node is green
    when its answer is truthy and red otherwise. Every dependency becomes a
    directed edge from the dependency to the question that depends on it.

    Args:
        dependencies: Mapping of node id -> iterable of node ids it depends on.
            Keys are passed through ``int()``; dependency ids are coerced to
            ``int`` when possible so they match the question nodes.
        questions: Mapping of node id -> question text used as the node label.
            Keys are passed through ``int()``.
        answers: Mapping indexed by the same keys as ``questions``. A truthy
            value colours the node green, a falsy one red.

    Returns:
        A ``PIL.Image.Image`` holding the rendered PNG.

    Raises:
        KeyError: if a key of ``questions`` is missing from ``answers``.
    """
    graph = nx.DiGraph()

    # One coloured, labelled node per question.
    for node, question in questions.items():
        color = 'green' if answers[node] else 'red'
        graph.add_node(int(node), label=question, color=color)

    # Edges point from the prerequisite question to the dependent one.
    for node, deps in dependencies.items():
        for dep in deps:
            graph.add_edge(_as_node_id(dep), int(node))

    # Draw on a dedicated figure. Using pyplot's implicit "current figure" would
    # make every call paint on top of the previous graph and never free it.
    fig, ax = plt.subplots()
    buf = io.BytesIO()
    try:
        pos = nx.spring_layout(graph)  # other layouts: shell_layout, circular_layout
        # A node that appears only as a dependency (no matching question) has
        # no colour or label; draw it grey instead of raising KeyError.
        node_colors = [graph.nodes[n].get('color', 'lightgray') for n in graph.nodes()]
        nx.draw_networkx_nodes(
            graph, pos, ax=ax, node_color=node_colors, node_size=2000, edgecolors='white'
        )
        nx.draw_networkx_edges(
            graph, pos, ax=ax, arrowstyle='-|>', arrows=True, arrowsize=20,
            connectionstyle='arc3,rad=0.1',
        )
        labels = nx.get_node_attributes(graph, 'label')
        nx.draw_networkx_labels(graph, pos, labels, ax=ax, font_size=10, font_color='black')
        ax.axis('off')
        fig.savefig(buf, format='png')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

    buf.seek(0)
    img = Image.open(buf)
    img.load()  # decode now rather than lazily reading from the buffer later
    return img
# Hugging Face model id used by the prompt processor to generate the DSG
# tuples, dependencies and questions.
PROMPT_MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Built once when the module is imported and shared by every process_image call.
processor = DSGPromptProcessor(PROMPT_MODEL_ID)
@spaces.GPU()
def process_image(image, prompt):
    """Score how well *image* follows *prompt* with a Davidsonian Scene Graph.

    Pipeline:
    1. The processor turns the prompt into tuples.
    2. It derives dependencies between the tuples.
    3. It generates questions from the prompt, tuples and dependencies.
    4. ``processor.get_reward`` scores those questions against the image and
       returns the per-image rewards together with the questions in the order
       the rewards refer to.

    Args:
        image: The uploaded image, as provided by ``gr.Image(type="pil")``.
        prompt: The text prompt the image should adhere to.

    Returns:
        A tuple ``(graph_image, summary_text)``. The graph shows one node per
        question: green when its reward is > 0.5, red otherwise. The text lists
        the questions in the same order as the per-question rewards.

    Raises:
        gr.Error: if no image was supplied or the prompt is blank. Gradio shows
            this to the user instead of a generic failure.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    tuples, _ = processor.generate_tuples(prompt)
    dependencies, _ = processor.generate_dependencies(tuples)
    questions, _ = processor.generate_questions(
        prompt, tuples.tuples, dependencies
    )
    rewards, ordered_questions = processor.get_reward(questions, dependencies, [image])
    reward = rewards[0]  # only one image was scored
    print(reward)

    # reward[i] belongs to ordered_questions[i]; key both by the same string index.
    # NOTE(review): these keys start at 0. Confirm they match the ids used in
    # `dependencies`, which draw_colored_graph maps onto the same nodes.
    answers = {str(i): v > 0.5 for i, v in enumerate(reward)}
    sorted_questions = {str(i): v for i, v in enumerate(ordered_questions)}
    print(answers, sorted_questions)

    graph_img = draw_colored_graph(dependencies, sorted_questions, answers)
    # Report the questions in reward order so each score lines up with its question.
    return graph_img, f"""
Question: {ordered_questions}.
Reward per question: {reward}"""
# HTML shown under the app title: links to the DSG paper, the author's GitHub,
# and the two models the app uses (VQA answering and question generation).
description = """
<p><center>
<a href="https://arxiv.org/abs/2310.18235" target="_blank">[Original Paper]</a>
<a href="https://github.com/toilaluan" target="_blank">[My Github]</a>
<a href="https://huggingface.co/toilaluan/Florence-2-base-Yes-No-VQA" target="_blank">[Binary VQA Model - Query Answering]</a>
<a href="https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1" target="_blank">[Mixtral 8x7B - Query Generating]</a>
</center></p>
"""
# Custom stylesheet passed to gr.Interface(css=...).
# NOTE(review): none of the components built in this file set elem_id to
# "gen_btn", "title" or "gallery", so these rules may be leftovers from another
# layout. Confirm before relying on them.
css = '''
#gen_btn{height: 100%}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#gallery .grid-wrap{height: 10vh}
'''
# UI inputs: the image to score and the prompt it should follow.
_input_components = [
    gr.Image(type="pil", label="Input Image"),
    gr.Textbox(label="Enter your prompt"),
]

# UI outputs: the coloured question graph and a textual breakdown of the rewards.
_output_components = [
    gr.Image(type="pil", label="Graph Score Image", format="png"),
    gr.Textbox(label="Analyzed Result"),
]

# Example rows shown below the form; cache_examples=True below asks Gradio to
# precompute their outputs.
_example_rows = [
    ["examples/input_image.png", "A cat with red eyes in the jungle. All tree in the jungle has blue color."],
]

interface = gr.Interface(
    fn=process_image,
    inputs=_input_components,
    outputs=_output_components,
    theme=gr.themes.Soft(),
    description=description,
    examples=_example_rows,
    css=css,
    title="T2I Adherence Scorer based on Davidsonian Scene Graph",
    cache_examples=True,
)

# Start the Gradio server.
interface.launch()
|