import spaces
import gradio as gr
from huggingface_hub import list_models
from typing import List
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import json
import re
import logging
from datasets import load_dataset
import os
import numpy as np
from datetime import datetime
# Import the local utils module (saliency image composition helpers)
import utils

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Paths to the static image and GIF
README_IMAGE_PATH = os.path.join("figs", "saliencies-merit-dataset.png")
GIF_PATH = os.path.join("figs", "demo-samples.gif")

# Global cache for the MERIT dataset (loaded lazily)
dataset = None

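# Download and cache the MERIT test split on first use.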
def load_merit_dataset():
    global dataset
    if dataset is None:
        dataset = load_dataset(
            "de-Rodrigo/merit", name="en-digital-seq", split="test", num_proc=8
        )
    return dataset

def get_image_from_dataset(index):
    global dataset
    if dataset is None:
        dataset = load_merit_dataset()
    image_data = dataset[int(index)]["image"]
    return image_data

def get_collection_models(tag: str) -> List[str]:
    """List de-Rodrigo models on the Hub that carry the given tag."""
    models = list_models(author="de-Rodrigo")
    return [model.id for model in models if tag in model.tags]

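# The model is loaded per request, so it only lives inside GPU-allocated calls.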
def initialize_donut():
    try:
        donut_model = VisionEncoderDecoderModel.from_pretrained(
            "de-Rodrigo/donut-merit"
        )
        donut_processor = DonutProcessor.from_pretrained("de-Rodrigo/donut-merit")
        donut_model = donut_model.to("cuda")
        logger.info("Donut model loaded successfully on GPU")
        return donut_model, donut_processor
    except Exception as e:
        logger.error(f"Error loading Donut model: {str(e)}")
        raise

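# Gradient-based saliency: backpropagate each generated token's probability
# to the input pixels and use the gradient magnitude as a heatmap of the
# image regions that most influenced that token.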
def compute_saliency(outputs, pixels, donut_p, image):
    token_logits = torch.stack(outputs.scores, dim=1)
    token_probs = torch.softmax(token_logits, dim=-1)
    token_texts = []
    saliency_images = []

    # outputs.sequences includes the decoder prompt, while outputs.scores only
    # covers the generated steps; offset the sequence index so the two align.
    prompt_len = outputs.sequences.shape[1] - len(outputs.scores)

    for token_index in range(len(token_probs[0])):
        target_token_prob = token_probs[
            0, token_index, outputs.sequences[0, token_index + prompt_len]
        ]

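        # Clear the gradient accumulated by the previous token's backward pass.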
        if pixels.grad is not None:
            pixels.grad.zero_()

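        # retain_graph=True keeps the graph alive for one backward pass per token.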
        target_token_prob.backward(retain_graph=True)

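        # Saliency map: absolute input gradient, averaged over the color channels.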
        saliency = pixels.grad.data.abs().squeeze().mean(dim=0)

        token_id = outputs.sequences[0][token_index + prompt_len].item()
        token_text = donut_p.tokenizer.decode([token_id])
        logger.info(f"Considered sequence token: {token_text}")

        safe_token_text = re.sub(r'[<>:"/\\|?*]', "_", token_text)
        current_datetime = datetime.now().strftime("%Y%m%d%H%M%S")

        unique_safe_token_text = f"{safe_token_text}_{current_datetime}"
        file_name = f"saliency_{unique_safe_token_text}.png"

        saliency = utils.convert_tensor_to_rgba_image(saliency)

        # Blend the saliency overlay onto the original image twice
        # (second pass at 0.7 alpha) for stronger contrast
        saliency = utils.add_transparent_image(np.array(image), saliency)
        saliency = utils.convert_rgb_to_rgba_image(saliency)
        saliency = utils.add_transparent_image(np.array(image), saliency, 0.7)

        saliency = utils.label_frame(saliency, token_text)

        saliency_images.append(saliency)
        token_texts.append(token_text)

    return saliency_images, token_texts

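# Reserve a ZeroGPU device for this call (duration is in seconds).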
@spaces.GPU(duration=300)
def process_image_donut(image):
    try:
        model, processor = initialize_donut()

        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)

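        # Track gradients on the input pixels; compute_saliency differentiates
        # the token probabilities with respect to them.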
        pixel_values = processor(image, return_tensors="pt").pixel_values.to("cuda")
        pixel_values.requires_grad = True

        task_prompt = "<s_cord-v2>"
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        )["input_ids"].to("cuda")

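        # model.generate is decorated with @torch.no_grad(); calling the
        # undecorated function via __wrapped__ lets gradients flow back to
        # pixel_values so saliency can be computed.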
        outputs = model.generate.__wrapped__(
            model,
            pixel_values,
            decoder_input_ids=decoder_input_ids,
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            output_scores=True,
        )

        saliency_images, token_texts = compute_saliency(outputs, pixel_values, processor, image)

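        # Decode, then strip the EOS/PAD tokens and the leading task-prompt tag.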
        sequence = processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

        result = processor.token2json(sequence)
        return saliency_images, json.dumps(result, indent=2)
    except Exception as e:
        logger.error(f"Error processing image with Donut: {str(e)}")
        return None, f"Error: {str(e)}"

@spaces.GPU(duration=300)
def process_image(model_name, image=None, dataset_image_index=None):
    # Prefer an uploaded image; the slider always carries a value, so only
    # fall back to the dataset index when nothing was uploaded.
    if image is None and dataset_image_index is not None:
        image = get_image_from_dataset(dataset_image_index)

    if model_name == "de-Rodrigo/donut-merit":
        saliency_images, result = process_image_donut(image)
    else:
        # Processing for other models should be implemented here
        saliency_images, result = None, f"Processing for model {model_name} not implemented"

    return saliency_images, result

def update_image(dataset_image_index):
    return get_image_from_dataset(dataset_image_index)

if __name__ == "__main__":
    # Load the dataset
    load_merit_dataset()

    models = get_collection_models("saliency")
    models.append("de-Rodrigo/donut-merit")

    with gr.Blocks() as demo:
        gr.Markdown("# Saliency Maps with the MERIT Dataset 🎒📃🏆")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Image(value=README_IMAGE_PATH, height=400)
            with gr.Column(scale=1):
                gr.Image(
                    value=GIF_PATH, label="Dataset samples you can process", height=400
                )

        with gr.Tab("Introduction"):
            gr.Markdown(
                """
            ## Welcome to Saliency Maps with the [MERIT Dataset](https://huggingface.co/datasets/de-Rodrigo/merit) 🎒📃🏆

            This space demonstrates the capabilities of different Vision-Language Models
            for document understanding tasks.

            ### Key Features:
            - Process images from the [MERIT Dataset](https://huggingface.co/datasets/de-Rodrigo/merit) or upload your own image.
            - Use fine-tuned versions of the available models to extract grades from documents.
            - Visualize saliency maps to understand where the model is looking (WIP 🛠️).
            """
            )

        with gr.Tab("Try It Yourself"):
            gr.Markdown(
                "Select a model and an image from the dataset, or upload your own image."
            )

            with gr.Row():
                with gr.Column():
                    model_dropdown = gr.Dropdown(choices=models, label="Select Model")
                    dataset_slider = gr.Slider(
                        minimum=0,
                        maximum=len(dataset) - 1,
                        step=1,
                        label="Dataset Image Index",
                    )
                    upload_image = gr.Image(
                        type="pil", label="Or Upload Your Own Image"
                    )

                preview_image = gr.Image(label="Selected/Uploaded Image")

            process_button = gr.Button("Process Image")

            with gr.Row():
                output_image = gr.Gallery(label="Processed Saliency Images")
                output_text = gr.Textbox(label="Result")

            # Update preview image when slider changes
            dataset_slider.change(
                fn=update_image, inputs=[dataset_slider], outputs=[preview_image]
            )

            # Update preview image when an image is uploaded
            upload_image.change(
                fn=lambda x: x, inputs=[upload_image], outputs=[preview_image]
            )

            # Process image when button is clicked
            process_button.click(
                fn=process_image,
                inputs=[model_dropdown, upload_image, dataset_slider],
                outputs=[output_image, output_text],
            )

    demo.launch()