Spaces:
Runtime error
Runtime error
File size: 2,452 Bytes
83f4a0c 25bd6f8 83f4a0c 25bd6f8 5f84a74 25bd6f8 f326df9 25bd6f8 f326df9 25bd6f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import io

import gradio as gr
import spaces
import matplotlib.pyplot as plt
import networkx as nx
from PIL import Image

from scorer import DSGPromptProcessor
def draw_colored_graph(dependencies, questions, answers):
    """Render the question-dependency graph as a PIL image.

    Each question becomes a node labelled with its text. The node is filled
    green when its answer is truthy and red otherwise. A directed edge
    ``dep -> node`` is drawn for every parent listed in ``dependencies``.

    Args:
        dependencies: Mapping of node id -> iterable of parent node ids. Both
            node and parent ids are normalised with ``int()``.
        questions: Mapping of node id -> question text, used as the node label.
        answers: Mapping looked up with the same keys as ``questions``; a
            truthy value colours the node green.

    Returns:
        A ``PIL.Image.Image`` holding the rendered PNG.

    Raises:
        KeyError: if a key of ``questions`` is missing from ``answers``.
    """
    graph = nx.DiGraph()

    for node, question in questions.items():
        color = 'green' if answers[node] else 'red'
        graph.add_node(int(node), label=question, color=color)

    for node, deps in dependencies.items():
        for dep in deps:
            # Normalise parents the same way as nodes, so a parent merges with
            # the node added above instead of becoming a separate,
            # attribute-less node.
            graph.add_edge(int(dep), int(node))

    pos = nx.spring_layout(graph)  # 'shell_layout' or 'circular_layout' also work

    # Parents that have no question entry carry no colour/label attribute.
    # Fall back to a neutral colour instead of raising KeyError.
    node_colors = [graph.nodes[n].get('color', 'lightgray') for n in graph.nodes()]
    labels = nx.get_node_attributes(graph, 'label')

    # Draw on a dedicated figure and always close it. The implicit global
    # pyplot figure would otherwise accumulate drawings (and memory) across
    # calls.
    fig, ax = plt.subplots()
    buf = io.BytesIO()
    try:
        nx.draw_networkx_nodes(
            graph, pos, ax=ax, node_color=node_colors, node_size=2000, edgecolors='black'
        )
        nx.draw_networkx_edges(
            graph, pos, ax=ax, arrowstyle='-|>', arrows=True, arrowsize=20,
            connectionstyle='arc3,rad=0.1',
        )
        nx.draw_networkx_labels(graph, pos, labels, font_size=10, font_color='black', ax=ax)
        ax.axis('off')
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)

    buf.seek(0)
    img = Image.open(buf)
    img.load()  # decode now rather than lazily on first use
    return img
# Model identifier handed (positionally) to DSGPromptProcessor.
_DSG_MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Single module-level processor, reused by process_image on every request.
processor = DSGPromptProcessor(_DSG_MODEL_ID)
def process_image(image, prompt):
    """Score ``image`` against ``prompt`` and visualise the per-question result.

    Uses the module-level ``processor`` to go from prompt to tuples, then to
    dependencies, then to questions. The questions are then scored against
    the image.

    Args:
        image: Image from the ``gr.Image(type="pil")`` input component.
        prompt: Text from the prompt textbox.

    Returns:
        ``(graph_img, text)``, in the same order as the interface outputs:
        the rendered dependency graph and a text summary of the questions and
        the per-question reward.
    """
    tuples, _ = processor.generate_tuples(prompt)
    dependencies, _ = processor.generate_dependencies(tuples)
    questions, _ = processor.generate_questions(
        prompt, tuples.tuples, dependencies
    )

    # get_reward receives a one-element image list; keep that image's scores.
    reward = processor.get_reward(questions, dependencies, [image])
    reward = reward[0]

    # NOTE(review): answers are keyed by position 0..n-1, while
    # draw_colored_graph looks them up with the keys of `questions`. This
    # assumes those keys match — confirm against DSGPromptProcessor's output.
    answers = {i: v > 0.5 for i, v in enumerate(reward)}
    graph_img = draw_colored_graph(dependencies, questions, answers)

    # The first output component is gr.Image(type="pil"), so it must receive
    # the rendered graph, not the raw reward scores.
    return graph_img, f"""
Question: {questions}.
Reward per question: {reward}"""
# Input components, in the order process_image takes its arguments:
# the image to score, then the prompt.
image_input = gr.Image(type="pil")
prompt_input = gr.Textbox(label="Enter your prompt")

# Output components, in the order process_image returns its values.
image_output = gr.Image(type="pil")
text_output = gr.Textbox(label="Output text")

interface = gr.Interface(
    fn=process_image,
    inputs=[image_input, prompt_input],
    outputs=[image_output, text_output],
    title="Image and Prompt Interface",
    description="Upload an image and enter a prompt. The output is an image and text below it.",
)

# Start the Gradio server as soon as this module runs.
interface.launch()
|