gabrielmotablima committed on
Commit c16cb4b · 1 Parent(s): d1c139b

:tada: first commit

Files changed (2)
  1. app.py +169 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,169 @@
+ import torch
+ import einops
+ from torchvision.transforms import ToPILImage
+ from PIL import Image
+ import os
+
+ from transformers import AutoTokenizer, AutoImageProcessor, VisionEncoderDecoderModel
+ import gradio as gr
+ from concurrent.futures import ThreadPoolExecutor
+
+ ############################# RATIONALE BEHIND ##############################
+
+ # Load the model, tokenizer, and image processor for a given checkpoint
+ def load_model_and_components(model_name):
+     model = VisionEncoderDecoderModel.from_pretrained(model_name)
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+     image_processor = AutoImageProcessor.from_pretrained(model_name)
+     return model, tokenizer, image_processor
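+ # Minimal usage sketch (hypothetical, outside the Gradio flow) of the returned
+ # triple for plain captioning:
+ #   model, tokenizer, image_processor = load_model_and_components(
+ #       "laicsiifes/swin-distilbertimbau")
+ #   pixel_values = image_processor(img, return_tensors="pt").pixel_values
+ #   ids = model.generate(pixel_values, max_length=25, num_beams=5)
+ #   caption = tokenizer.decode(ids[0], skip_special_tokens=True)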
+
+ # Preload the selectable model(s) in parallel at startup
+ def preload_models():
+     models = {}
+     model_names = ["laicsiifes/swin-distilbertimbau"]  #, "laicsiifes/swin-gportuguese-2"]
+     with ThreadPoolExecutor() as executor:
+         results = executor.map(load_model_and_components, model_names)
+         for name, result in zip(model_names, results):
+             models[name] = result
+     return models
+
+ models = preload_models()
+
+ # Predefined example images for selection
+ image_folder = "images"
+ predefined_images = [
+     Image.open(os.path.join(image_folder, fname)).convert("RGB")
+     for fname in os.listdir(image_folder)
+     if fname.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.ppm'))
+ ]
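+ # Note: this assumes an `images/` folder with sample files ships alongside
+ # app.py; os.listdir raises FileNotFoundError if the folder is missing.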
+
+ # Convert an uploaded image to RGB; the second return value clears the gallery
+ def preprocess_image(image):
+     if image is None:
+         return None, None
+     pil_image = image.convert("RGB")
+     return pil_image, None
+
+ # Generate a caption and collect the cross-attention map for each token
+ def get_attn_map(model, image, processor, tokenizer):
+     pixel_values = processor(image, return_tensors="pt").pixel_values
+     model.eval()
+     with torch.no_grad():
+         output = model.generate(
+             pixel_values=pixel_values,
+             return_dict_in_generate=True,
+             output_hidden_states=True,
+             output_attentions=True,
+             max_length=25,
+             num_beams=5
+         )
+
+     # Keep only the last decoder layer's cross-attention for each step
+     last_layers = [tensor_tuple[-1] for tensor_tuple in output.cross_attentions]
+     attention_maps = torch.stack(last_layers, dim=0)
+     attention_maps = einops.reduce(
+         attention_maps,
+         'token batch head sequence (height width) -> token sequence (height width)',
+         height=7, width=7,
+         reduction='mean'
+     )
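+     # Shape sketch (assuming Swin's final 7x7 feature grid, i.e. 49 encoder
+     # positions): output.cross_attentions holds one tuple per generated token,
+     # with one tensor per decoder layer of shape
+     # (batch * num_beams, heads, decoder_seq, 49). The einops reduce averages
+     # over beams and heads, leaving one 49-dim spatial map per token.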
+
+     tokens = output.sequences[0]
+     token_texts = tokenizer.convert_ids_to_tokens(tokens)
+     valid_token_texts = token_texts[1:]  # drop the decoder start token
+
+     return valid_token_texts, attention_maps, output
+
+ # Merge WordPiece subword tokens (and their attention maps) into whole words,
+ # e.g. tokens `sent` and `##ada` yield the word `sentada`
+ def join_tokens(text_tokens, attention_maps, connect_symbol='##'):
+     tokens = text_tokens.copy()
+     attn_map = attention_maps.detach().clone()
+
+     i = 0
+     while i < len(tokens) and tokens[i] != '[SEP]':
+         if tokens[i].startswith(connect_symbol):
+             # Append the piece to the previous token, sum their attention,
+             # and drop the previous entry
+             tokens[i] = tokens[i - 1] + tokens[i].replace(connect_symbol, '')
+             tokens.pop(i - 1)
+             attn_map[i][0] = attn_map[i - 1][0] + attn_map[i][0]
+             attn_map = torch.cat((attn_map[:i - 1], attn_map[i:]), dim=0)
+             i -= 1
+         i += 1
+
+     # Trim the boundary tokens around the caption body
+     tokens = tokens[1:i - 1]
+     attn_map = attn_map[1:i - 1]
+
+     return tokens, attn_map
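+ # Hypothetical walkthrough: given pieces ['uma', 'mulher', 'sent', '##ada', ...],
+ # 'sent' + '##ada' collapses to 'sentada' and their two attention rows are
+ # summed, so each gallery entry pairs one whole word with one attention map.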
+
+ # Build the gallery of attention-map overlays, one per generated word
+ def generate_attention_gallery(image, selected_model):
+     if image is None:
+         return []
+
+     model, tokenizer, processor = models[selected_model]
+     tokens, attention_maps, _ = get_attn_map(model, image, processor, tokenizer)
+     joined_tokens, joined_attn_maps = join_tokens(tokens, attention_maps)
+
+     grid_size = int(joined_attn_maps.size(-1) ** 0.5)
+     gallery_output = []
+
+     for i, token in enumerate(joined_tokens):
+         att_map = joined_attn_maps[i].view(grid_size, grid_size)
+         # Min-max normalize so each map spans the full grayscale range
+         att_map = (att_map - att_map.min()) / (att_map.max() - att_map.min())
+
+         # Nearest-neighbor upsample the coarse grid before the final resize
+         att_map = att_map.repeat_interleave(32, dim=0).repeat_interleave(32, dim=1)
+
+         att_map_resized = ToPILImage()(
+             att_map.unsqueeze(0).repeat(3, 1, 1)
+         ).resize(image.size)
+
+         blended = Image.blend(image, att_map_resized, alpha=0.75)
+         gallery_output.append((blended, token))
+
+     return gallery_output
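+ # Each overlay blends the photo (25%) with its attention map (75%), so the
+ # brightest regions are where the decoder attended while emitting that word.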
+
+ ################################### PAGE ####################################
+
+ # Define UI
+ with gr.Blocks(theme=gr.themes.Citrus(primary_hue="blue", secondary_hue="orange")) as interface:
+     gr.Markdown("""
+     # Welcome to the LAICSI-IFES Vision Encoder-Decoder Demo
+     ---
+     ### Select a pretrained model and upload an image to visualize attention maps.
+     """)
+
+     with gr.Row(variant='panel'):
+         model_selector = gr.Dropdown(
+             choices=list(models.keys()),
+             value="laicsiifes/swin-distilbertimbau",
+             label="Select Model"
+         )
+
+     gr.Markdown("""---\n### Upload or select an image and click 'Generate' to view attention maps.""")
+
+     with gr.Row(variant='panel'):
+         with gr.Column():
+             image_display = gr.Image(type="pil", label="Image Preview", image_mode="RGB", height=400)
+         with gr.Column():
+             output_gallery = gr.Gallery(label="Attention Maps", columns=4, rows=3, height=600)
+             generate_button = gr.Button("Generate")
+
+     gr.Markdown("""---""")
+
+     with gr.Row(variant='panel'):
+         examples = gr.Examples(
+             examples=predefined_images,
+             fn=preprocess_image,
+             inputs=[image_display],
+             outputs=[image_display, output_gallery],
+             label="Examples"
+         )
+
+     # Actions
+     model_selector.change(fn=lambda: (None, []), outputs=[image_display, output_gallery])
+     image_display.upload(fn=preprocess_image, inputs=[image_display], outputs=[image_display, output_gallery])
+     image_display.clear(fn=lambda: None, outputs=[output_gallery])
+     generate_button.click(fn=generate_attention_gallery, inputs=[image_display, model_selector], outputs=output_gallery)
+
+ interface.launch(share=False)
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ transformers==4.33.0
+ Pillow==9.5.0
+ requests==2.31.0
+ gradio>=5.0  # gr.themes.Citrus and Gallery(columns=...) need a newer Gradio than 3.29.0
+ torch==2.0.1
+ torchvision==0.15.2  # needed for ToPILImage; matches torch 2.0.1
+ einops  # needed for einops.reduce in app.py
+ numpy==1.26.4