# Hugging Face Spaces status: Runtime error
import io

import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import spaces
from PIL import Image

from scorer import DSGPromptProcessor
def draw_colored_graph(dependencies, questions, answers):
    """Render the question dependency graph as a PIL image.

    Each question becomes a node labelled with its text. The node is green when
    its entry in ``answers`` is truthy and red otherwise. Each dependency
    becomes a directed edge from the prerequisite question to the dependent one.

    Args:
        dependencies: Mapping from a question id to an iterable of the ids it
            depends on. Ids are normalised with ``int()``, so ints and int-like
            strings both work.
        questions: Mapping from question id to question text (used as label).
        answers: Mapping keyed like ``questions`` to a truthy/falsy outcome.

    Returns:
        A ``PIL.Image.Image`` containing the rendered PNG.

    Raises:
        KeyError: if a key of ``questions`` is missing from ``answers``.
        ValueError: if an id cannot be converted with ``int()``.
    """
    graph = nx.DiGraph()
    for node, question in questions.items():
        color = 'green' if answers[node] else 'red'
        graph.add_node(int(node), label=question, color=color)
    for node, deps in dependencies.items():
        for dep in deps:
            # Normalise BOTH endpoints to int so an edge attaches to the nodes
            # created above even when ids arrive as strings.
            graph.add_edge(int(dep), int(node))

    # Use a dedicated figure per call. Drawing on pyplot's implicit current
    # figure would overlay every previous request's graph and leak figures.
    fig, ax = plt.subplots()
    try:
        pos = nx.spring_layout(graph)
        # A node referenced only in ``dependencies`` (not in ``questions``) has
        # no color/label attributes. Draw it neutral instead of raising KeyError.
        # NOTE(review): if the dependency format uses a sentinel parent id
        # (e.g. 0 for "no parent"), that sentinel shows up here — confirm.
        node_colors = [graph.nodes[n].get('color', 'lightgray') for n in graph.nodes()]
        nx.draw_networkx_nodes(
            graph, pos, ax=ax, node_color=node_colors, node_size=2000, edgecolors='black'
        )
        nx.draw_networkx_edges(
            graph, pos, ax=ax, arrowstyle='-|>', arrows=True, arrowsize=20,
            connectionstyle='arc3,rad=0.1',
        )
        labels = nx.get_node_attributes(graph, 'label')
        nx.draw_networkx_labels(graph, pos, labels, ax=ax, font_size=10, font_color='black')
        ax.axis('off')
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)

    buf.seek(0)
    img = Image.open(buf)
    img.load()  # decode now so the image does not depend on lazy reads later
    return img
# Prompt processor built once at import time and reused by the request handler.
# NOTE(review): what DSGPromptProcessor does with this model id (local load vs.
# remote inference) is not visible here — confirm in scorer.py.
_DSG_LLM_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"
processor = DSGPromptProcessor(_DSG_LLM_ID)
def process_image(image, prompt):
    """Score ``image`` against ``prompt`` and visualise the per-question results.

    The prompt is turned into tuples, dependencies and questions by the
    module-level ``processor``. ``processor.get_reward`` then scores each
    question on the image. A question counts as satisfied when its score
    exceeds 0.5.

    Args:
        image: PIL image from the Gradio input (``None`` if nothing uploaded).
        prompt: Text prompt describing the expected image content.

    Returns:
        ``(graph_image, summary_text)``, in the order of the interface's
        outputs (``gr.Image`` then ``gr.Textbox``).

    Raises:
        gr.Error: if the image or the prompt is missing.
    """
    # NOTE(review): ``spaces`` is imported at file level but never used. If this
    # Space runs on ZeroGPU, this handler likely needs ``@spaces.GPU`` — confirm.
    if image is None or not (prompt or "").strip():
        raise gr.Error("Please upload an image and enter a prompt.")

    tuples, _ = processor.generate_tuples(prompt)
    dependencies, _ = processor.generate_dependencies(tuples)
    questions, _ = processor.generate_questions(prompt, tuples.tuples, dependencies)
    # get_reward receives a one-element image list. Index 0 selects that
    # image's per-question scores.
    reward = processor.get_reward(questions, dependencies, [image])[0]
    # Key answers by question id, not by position. draw_colored_graph looks
    # them up with the keys of ``questions``. This assumes the scores come back
    # in the same order as ``questions`` — TODO confirm against get_reward.
    answers = {qid: score > 0.5 for qid, score in zip(questions, reward)}
    graph_img = draw_colored_graph(dependencies, questions, answers)
    # The first output component is a gr.Image, so it gets the rendered graph.
    return graph_img, f"""
Question: {questions}.
Reward per question: {reward}"""
# Gradio UI: an image upload plus a text prompt go in; an image and a text box
# come out, in the order returned by the handler.
_input_components = [gr.Image(type="pil"), gr.Textbox(label="Enter your prompt")]
_output_components = [gr.Image(type="pil"), gr.Textbox(label="Output text")]

interface = gr.Interface(
    fn=process_image,
    inputs=_input_components,
    outputs=_output_components,
    title="Image and Prompt Interface",
    description="Upload an image and enter a prompt. The output is an image and text below it.",
)

# Start serving the app.
interface.launch()