Add HungarianMatcher
#2
by
emanuelevivoli
- opened
- conditional_detr_utils.py +179 -0
- modelling_magi.py +4 -3
conditional_detr_utils.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2022 Microsoft Research Asia and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch Conditional DETR model."""
|
16 |
+
|
17 |
+
from transformers.utils import (
|
18 |
+
is_scipy_available,
|
19 |
+
is_vision_available,
|
20 |
+
logging
|
21 |
+
)
|
22 |
+
|
23 |
+
import torch
|
24 |
+
from torch import Tensor, nn
|
25 |
+
|
26 |
+
if is_scipy_available():
|
27 |
+
from scipy.optimize import linear_sum_assignment
|
28 |
+
|
29 |
+
if is_vision_available():
|
30 |
+
from transformers.image_transforms import center_to_corners_format
|
31 |
+
|
32 |
+
logger = logging.get_logger(__name__)
|
33 |
+
|
34 |
+
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrHungarianMatcher with DeformableDetr->ConditionalDetr
|
35 |
+
class ConditionalDetrHungarianMatcher(nn.Module):
    """
    Computes a one-to-one assignment between the predictions of the network and the ground-truth targets.

    For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
    predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).

    Args:
        class_cost:
            The relative weight of the classification error in the matching cost.
        bbox_cost:
            The relative weight of the L1 error of the bounding box coordinates in the matching cost.
        giou_cost:
            The relative weight of the giou loss of the bounding box in the matching cost.
        focal_alpha:
            Balancing factor of the focal-style classification cost. The default (0.25) is the value that was
            previously hard-coded in `forward`.
        focal_gamma:
            Exponent of the focal-style classification cost. The default (2.0) is the value that was previously
            hard-coded in `forward`.

    Raises:
        ValueError: If `class_cost`, `bbox_cost` and `giou_cost` are all 0.
    """

    def __init__(
        self,
        class_cost: float = 1,
        bbox_cost: float = 1,
        giou_cost: float = 1,
        focal_alpha: float = 0.25,
        focal_gamma: float = 2.0,
    ):
        super().__init__()

        self.class_cost = class_cost
        self.bbox_cost = bbox_cost
        self.giou_cost = giou_cost
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
            raise ValueError("All costs of the Matcher can't be 0")

    @torch.no_grad()
    def forward(self, outputs, targets):
        """
        Args:
            outputs (`dict`):
                A dictionary that contains at least these entries:
                * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
            targets (`List[dict]`):
                A list of targets (len(targets) = batch_size), where each target is a dict containing:
                * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
                  ground-truth
                 objects in the target) containing the class labels
                * "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.

        Returns:
            `List[Tuple]`: A list of size `batch_size`, containing tuples of (index_i, index_j) where:
            - index_i is the indices of the selected predictions (in order)
            - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes)

        Raises:
            ImportError: If scipy is not installed (the assignment is solved with
                `scipy.optimize.linear_sum_assignment`).
        """
        # `linear_sum_assignment` is only imported at module level when scipy is available; fail with an
        # explicit message instead of a NameError further down.
        if not is_scipy_available():
            raise ImportError(
                "ConditionalDetrHungarianMatcher requires scipy (scipy.optimize.linear_sum_assignment), "
                "but it is not installed."
            )

        batch_size, num_queries = outputs["logits"].shape[:2]

        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["logits"].flatten(0, 1).sigmoid()  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Also concat the target labels and boxes
        target_ids = torch.cat([v["class_labels"] for v in targets])
        target_bbox = torch.cat([v["boxes"] for v in targets])

        # Compute the classification cost (focal-style; 1e-8 keeps the logs finite at probabilities 0 and 1).
        alpha = self.focal_alpha
        gamma = self.focal_gamma
        neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
        pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
        class_cost = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]

        # Compute the L1 cost between boxes
        bbox_cost = torch.cdist(out_bbox, target_bbox, p=1)

        # Compute the giou cost between boxes (boxes are converted from center to corner format first)
        giou_cost = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))

        # Final cost matrix
        cost_matrix = self.bbox_cost * bbox_cost + self.class_cost * class_cost + self.giou_cost * giou_cost
        cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu()

        # Split the target axis back per image and solve each image's assignment on its own slice.
        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
112 |
+
|
113 |
+
|
114 |
+
# Copied from transformers.models.detr.modeling_detr._upcast
|
115 |
+
def _upcast(t: Tensor) -> Tensor:
|
116 |
+
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
|
117 |
+
if t.is_floating_point():
|
118 |
+
return t if t.dtype in (torch.float32, torch.float64) else t.float()
|
119 |
+
else:
|
120 |
+
return t if t.dtype in (torch.int32, torch.int64) else t.int()
|
121 |
+
|
122 |
+
|
123 |
+
# Copied from transformers.models.detr.modeling_detr.box_area
|
124 |
+
def box_area(boxes: Tensor) -> Tensor:
    """
    Computes the area of a set of bounding boxes, which are specified by their (x1, y1, x2, y2) coordinates.

    Args:
        boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
            Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
            < x2` and `0 <= y1 < y2`.

    Returns:
        `torch.FloatTensor`: a tensor containing the area for each box.
    """
    # Widen the dtype first so the width * height product below cannot overflow.
    corners = _upcast(boxes)
    widths = corners[:, 2] - corners[:, 0]
    heights = corners[:, 3] - corners[:, 1]
    return widths * heights
|
138 |
+
|
139 |
+
|
140 |
+
# Copied from transformers.models.detr.modeling_detr.box_iou
|
141 |
+
def box_iou(boxes1, boxes2):
    """
    Computes the pairwise intersection-over-union between two sets of boxes in (x1, y1, x2, y2) corner format.

    Args:
        boxes1: Tensor indexed as `[N, 4]`.
        boxes2: Tensor indexed as `[M, 4]`.

    Returns:
        A tuple `(iou, union)` of `[N, M]` tensors holding the pairwise IoU and the pairwise union area.
        For a pair whose union area is exactly 0 (both boxes have zero area) the IoU is reported as 0 rather
        than the NaN a plain 0/0 division would give; `union` itself is returned unmodified.
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # Non-overlapping pairs get a negative extent, which is clamped to an empty intersection.
    width_height = (right_bottom - left_top).clamp(min=0)  # [N,M,2]
    inter = width_height[:, :, 0] * width_height[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    # Only the divisor of zero-union pairs is replaced (their intersection is 0 as well for well-formed
    # boxes), so every other entry is computed exactly as `inter / union`.
    iou = inter / union.masked_fill(union == 0, 1)
    return iou, union
|
155 |
+
|
156 |
+
|
157 |
+
# Copied from transformers.models.detr.modeling_detr.generalized_box_iou
|
158 |
+
def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.

    Args:
        boxes1: Tensor indexed as `[N, 4]`.
        boxes2: Tensor indexed as `[M, 4]`.

    Returns:
        `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)

    Raises:
        ValueError: If any box in `boxes1` or `boxes2` has `x1 < x0` or `y1 < y0`.
    """
    # Boxes with inverted corners are rejected up front; zero-width/zero-height boxes pass this check
    # (the comparison is `>=`), which is why the division below is guarded.
    if not (boxes1[:, 2:] >= boxes1[:, :2]).all():
        raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
    if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
        raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
    iou, union = box_iou(boxes1, boxes2)

    # Corners of the smallest box enclosing each (boxes1[i], boxes2[j]) pair.
    top_left = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    bottom_right = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    width_height = (bottom_right - top_left).clamp(min=0)  # [N,M,2]
    area = width_height[:, :, 0] * width_height[:, :, 1]

    # A zero-area enclosing box means both boxes are degenerate, so `area - union` is 0 as well; dividing by
    # 1 there makes the penalty term 0 instead of 0/0 = NaN. All other entries are unchanged.
    return iou - (area - union) / area.masked_fill(area == 0, 1)
|
modelling_magi.py
CHANGED
@@ -2,15 +2,15 @@ from transformers import PreTrainedModel, VisionEncoderDecoderModel, ViTMAEModel
|
|
2 |
from transformers.models.conditional_detr.modeling_conditional_detr import (
|
3 |
ConditionalDetrMLPPredictionHead,
|
4 |
ConditionalDetrModelOutput,
|
5 |
-
ConditionalDetrHungarianMatcher,
|
6 |
inverse_sigmoid,
|
7 |
)
|
|
|
8 |
from .configuration_magi import MagiConfig
|
9 |
from .processing_magi import MagiProcessor
|
10 |
from torch import nn
|
11 |
from typing import Optional, List
|
12 |
import torch
|
13 |
-
from einops import rearrange, repeat
|
14 |
from .utils import move_to_device, visualise_single_image_prediction, sort_panels, sort_text_boxes_in_reading_order
|
15 |
|
16 |
class MagiModel(PreTrainedModel):
|
@@ -498,4 +498,5 @@ class MagiModel(PreTrainedModel):
|
|
498 |
if apply_sigmoid:
|
499 |
text_character_affinities = text_character_affinities.sigmoid()
|
500 |
affinity_matrices.append(text_character_affinities)
|
501 |
-
return affinity_matrices
|
|
|
|
2 |
from transformers.models.conditional_detr.modeling_conditional_detr import (
|
3 |
ConditionalDetrMLPPredictionHead,
|
4 |
ConditionalDetrModelOutput,
|
|
|
5 |
inverse_sigmoid,
|
6 |
)
|
7 |
+
from .conditional_detr_utils import ConditionalDetrHungarianMatcher
|
8 |
from .configuration_magi import MagiConfig
|
9 |
from .processing_magi import MagiProcessor
|
10 |
from torch import nn
|
11 |
from typing import Optional, List
|
12 |
import torch
|
13 |
+
from einops import rearrange, repeat
|
14 |
from .utils import move_to_device, visualise_single_image_prediction, sort_panels, sort_text_boxes_in_reading_order
|
15 |
|
16 |
class MagiModel(PreTrainedModel):
|
|
|
498 |
if apply_sigmoid:
|
499 |
text_character_affinities = text_character_affinities.sigmoid()
|
500 |
affinity_matrices.append(text_character_affinities)
|
501 |
+
return affinity_matrices
|
502 |
+
|