File size: 2,452 Bytes
83f4a0c
25bd6f8
 
 
 
 
83f4a0c
 
25bd6f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f84a74
25bd6f8
 
 
 
 
f326df9
25bd6f8
f326df9
25bd6f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import io

import gradio as gr
import spaces
import matplotlib.pyplot as plt
import networkx as nx
from PIL import Image

from scorer import DSGPromptProcessor


def draw_colored_graph(dependencies, questions, answers):
    """Render the question dependency graph as a PIL image.

    Each question becomes a node labelled with its text and coloured green
    when its answer is truthy, red otherwise. Each dependency becomes a
    directed edge from the parent id to the dependent question id.

    Args:
        dependencies: Mapping of question id -> iterable of parent ids. Both
            sides are converted with ``int()``, so ints or numeric strings
            are accepted.
        questions: Mapping of question id -> question text.
        answers: Mapping of question id -> truthy/falsy answer. Looked up by
            the same key used in ``questions``, falling back to ``int(id)``.

    Returns:
        A ``PIL.Image.Image`` holding the rendered PNG.

    Raises:
        KeyError: if a question has no entry in ``answers`` under either key.
    """
    graph = nx.DiGraph()

    # One node per question, coloured by its answer.
    for node, question in questions.items():
        node_id = int(node)
        # Accept answers keyed the same way as `questions` or by integer id.
        answer = answers[node] if node in answers else answers[node_id]
        graph.add_node(node_id, label=question, color="green" if answer else "red")

    # Directed edges parent -> dependent question.
    for node, deps in dependencies.items():
        for dep in deps:
            # Normalise both endpoints to int so edges attach to the nodes
            # added above rather than creating look-alike string-keyed nodes.
            graph.add_edge(int(dep), int(node))

    # Parent ids that are not questions get no colour attribute above; use a
    # neutral colour for them instead of raising KeyError.
    node_colors = [graph.nodes[n].get("color", "lightgray") for n in graph.nodes()]
    labels = nx.get_node_attributes(graph, "label")

    # Draw on a dedicated figure so repeated calls do not accumulate on
    # pyplot's global "current" figure, and always close it to free memory.
    fig, ax = plt.subplots()
    try:
        pos = nx.spring_layout(graph)  # other layouts: shell_layout, circular_layout
        nx.draw_networkx_nodes(
            graph, pos, ax=ax,
            node_color=node_colors, node_size=2000, edgecolors="black",
        )
        nx.draw_networkx_edges(
            graph, pos, ax=ax,
            arrowstyle="-|>", arrows=True, arrowsize=20,
            connectionstyle="arc3,rad=0.1",
        )
        nx.draw_networkx_labels(
            graph, pos, labels, ax=ax, font_size=10, font_color="black",
        )
        ax.axis("off")

        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)

    buf.seek(0)
    img = Image.open(buf)
    img.load()  # decode eagerly so a bad render fails here, not downstream
    return img

# Construct the DSG prompt processor once, at module import time; every call
# to process_image reuses this single instance. The string is presumably a
# Hugging Face model id consumed by DSGPromptProcessor — confirm in scorer.
processor = DSGPromptProcessor(
    "mistralai/Mixtral-8x7B-Instruct-v0.1"
)

def process_image(image, prompt):
    """Score ``image`` against ``prompt`` and visualise the per-question results.

    Uses the module-level ``processor`` to turn the prompt into tuples,
    dependencies and questions, then computes a reward per question for the
    image.

    Args:
        image: The uploaded image (a PIL image, given the interface's
            ``gr.Image(type="pil")`` input).
        prompt: The text prompt the image is checked against.

    Returns:
        A ``(graph_image, text)`` pair matching the interface's two outputs
        (``gr.Image`` then ``gr.Textbox``): the rendered dependency graph and
        a text summary of the questions and their rewards.
    """
    tuples, _ = processor.generate_tuples(prompt)
    dependencies, _ = processor.generate_dependencies(tuples)
    questions, _ = processor.generate_questions(
        prompt, tuples.tuples, dependencies
    )
    # get_reward receives a list holding one image; take the entry for it.
    reward = processor.get_reward(questions, dependencies, [image])[0]

    # Key answers by the question ids themselves so draw_colored_graph can look
    # them up with the same keys it iterates over.
    # NOTE(review): assumes get_reward returns scores in the same order as
    # iterating `questions` — confirm against scorer.
    answers = {qid: score > 0.5 for qid, score in zip(questions, reward)}
    graph_img = draw_colored_graph(dependencies, questions, answers)

    # The first output component is gr.Image, so it receives the graph.
    return graph_img, f"""
Question: {questions}.
Reward per question: {reward}"""

# Components for the Gradio UI: an image plus a prompt go into process_image,
# and its two return values fill the image output and the text output, in
# that order.
input_components = [
    gr.Image(type="pil"),
    gr.Textbox(label="Enter your prompt"),
]
output_components = [
    gr.Image(type="pil"),
    gr.Textbox(label="Output text"),
]

interface = gr.Interface(
    title="Image and Prompt Interface",
    description="Upload an image and enter a prompt. The output is an image and text below it.",
    fn=process_image,
    inputs=input_components,
    outputs=output_components,
)

# Start serving the app; this runs whenever the module is executed.
interface.launch()