Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| from .. import losses | |
| import ignite.distributed as idist | |
| import torch_optimizer | |
| from tqdm import tqdm | |
| import matplotlib.pyplot as plt | |
| from torch.nn import functional as F | |
| import os | |
| import shutil | |
| from modelguidedattacks.cls_models.registry import MMPretrainVisualTransformerWrapper | |
| from modelguidedattacks.data.imagenet_metadata import imgnet_idx_to_name | |
class Unguided(nn.Module):
    """Top-K targeted adversarial attack with an energy (perturbation-norm) penalty.

    ``attack`` optimizes an additive perturbation so that the model's top-K
    predicted classes equal ``attack_targets`` in order. The objective is the
    configured loss, weighted per sample, plus the squared L2 norm of the
    perturbation. ``forward`` wraps ``attack`` in a binary search over that
    per-sample loss weight.
    """

    def __init__(self, model: nn.Module, config, optimizer=torch.optim.AdamW, seed=0, iterations=1000,
                 loss_fn=losses.CVXProjLoss, lr=1e-3,
                 binary_search_steps=1, topk_loss_coef_upper=10.,
                 topk_loss_coef_lower=0.) -> None:
        """
        Args:
            model: classifier called as ``model(x, return_features=True)`` that
                returns ``(logits, features)``.
            config: attack configuration; this class reads ``opt_warmup_its``,
                ``dump_plots``, ``plot_out`` and ``plot_idx`` from it.
            optimizer: optimizer factory, called as ``optimizer([param], lr=lr)``.
            seed: stored on the instance; not used by this class itself.
            iterations: optimization steps per ``attack`` call.
            loss_fn: loss class, instantiated with no arguments.
            lr: learning rate passed to the optimizer.
            binary_search_steps: number of binary-search rounds in ``forward``.
            topk_loss_coef_upper: initial upper bound of the top-K loss weight.
            topk_loss_coef_lower: initial lower bound of the top-K loss weight.
        """
        super().__init__()
        self.guided = False
        self.model = model
        self.seed = seed
        self.iterations = iterations
        self.loss = loss_fn()
        self.optimizer = optimizer
        self.lr = lr
        self.binary_search_steps = binary_search_steps
        self.topk_loss_coef_upper = topk_loss_coef_upper
        self.topk_loss_coef_lower = topk_loss_coef_lower
        self.config = config

    def surject_perturbation(self, x, max_norm=5.):
        """Project each sample of ``x`` onto the L2 ball of radius ``max_norm``.

        Samples whose flattened L2 norm exceeds ``max_norm`` are rescaled to
        have norm exactly ``max_norm``; all other samples are returned unchanged.

        Args:
            x: tensor of shape ``[B, ...]``; the norm is taken per sample over
                all non-batch dimensions.
            max_norm: radius of the ball.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x_shape = x.shape
        x = x.flatten(1)                                   # [B, D]
        x_norm = x.norm(dim=-1, keepdim=True)              # [B, 1]
        # Clamp the divisor so all-zero samples do not produce 0/0 = NaN
        # (torch.where would otherwise propagate NaN gradients). Such samples
        # are inside the ball and are passed through unchanged below.
        x_unit = x / x_norm.clamp_min(torch.finfo(x.dtype).eps)
        # The mask must stay [B, 1]: a [B] mask would broadcast against the
        # trailing feature axis of the [B, D] tensor instead of the batch axis.
        x_norm_outside = x_norm > max_norm
        x = torch.where(x_norm_outside, x_unit * max_norm, x)
        return x.reshape(x_shape)

    def attack(self, x, attack_targets, gt_labels, topk_coefs):
        """
        For a given set of topk coefficients, this function computes
        best energy attack in the given number of iterations and configuration

        Args:
            x: [B, C, H, W] images (0-1 for colors).
            attack_targets: [B, K] (long) target classes, in the desired rank order.
            gt_labels: [B] (long) ground-truth labels; only forwarded to
                ``self.loss.precompute``.
            topk_coefs: [B] (floats) per-sample weight of the top-K loss. It is
                copied, so the caller's tensor is not modified.

        Returns:
            prediction_logits: model logits on ``x + best_perturbations``.
            best_perturbations: [B, C, H, W] lowest-energy successful perturbation
                per sample; zeros where no step succeeded.
            best_energy: [B] L2 norm of ``best_perturbations``; ``inf`` where no
                step succeeded.
        """
        topk_coefs = topk_coefs.clone()
        K = attack_targets.shape[-1]
        B = x.shape[0]

        x_perturbation = nn.Parameter(torch.randn(x.shape, device=x.device) * 2e-3)
        optimizer = self.optimizer([x_perturbation], lr=self.lr)
        precomputed_state = self.loss.precompute(attack_targets, gt_labels, self.config)

        with torch.no_grad():
            prediction_logits_0, prediction_feats_0 = self.model(x, return_features=True)

        # Best-so-far bookkeeping. These buffers are only written under no_grad,
        # so they never pick up autograd history from x_perturbation.
        best_perturbations = torch.zeros_like(x)                                     # [B, C, H, W]
        has_successful_attack = torch.zeros(B, dtype=torch.bool, device=x.device)    # [B]
        best_energy = torch.full((B,), float('inf'), device=x.device)                # [B]

        pbar = tqdm(range(self.iterations))
        for i in pbar:
            if i == self.config.opt_warmup_its:
                # Reset optimizer state once the warmup phase is over.
                optimizer = self.optimizer([x_perturbation], lr=self.lr)

            x_perturbed = x + x_perturbation
            prediction_logits, prediction_feats = self.model(x_perturbed, return_features=True)

            with torch.no_grad():
                # Success: the top-K predictions equal the targets exactly, in order.
                pred_classes = prediction_logits.argsort(dim=-1, descending=True)        # [B, C]
                attack_successful = (pred_classes[:, :K] == attack_targets).all(dim=-1)  # [B]
                attack_energy = x_perturbation.flatten(1).norm(dim=-1)                   # [B]
                attack_improved = attack_successful & (attack_energy <= best_energy)
                best_perturbations[attack_improved] = x_perturbation[attack_improved]
                has_successful_attack[attack_improved] = True
                best_energy[attack_improved] = attack_energy[attack_improved]

            loss = self.loss(logits_pred=prediction_logits,
                             feats_pred=prediction_feats,
                             feats_pred_0=prediction_feats_0,
                             attack_targets=attack_targets,
                             model=self.model, **precomputed_state)
            # Assumes the loss yields one value per sample so it can be weighted by topk_coefs.
            loss = (loss * topk_coefs).sum()
            pbar.set_description(f"Loss: {loss.item():.3f}")

            # Energy term: squared L2 norm of the perturbation.
            loss = loss + x_perturbation.flatten(1).square().sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Once a sample has succeeded, shrink its top-K weight so the energy
            # term dominates and the perturbation norm is driven down.
            topk_coefs[attack_improved] *= 0.75

            # Project the perturbation so the perturbed image stays within [0, 1].
            with torch.no_grad():
                x_perturbed = (x + x_perturbation).clamp_(min=0., max=1.)
                x_perturbation.copy_(x_perturbed - x)

        x_perturbed_best = x + best_perturbations
        with torch.no_grad():
            prediction_logits, _ = self.model(x_perturbed_best, return_features=True)

        if self.config.dump_plots:
            self._dump_plots(x, x_perturbed_best, best_perturbations, best_energy,
                             has_successful_attack, attack_targets,
                             prediction_logits_0, prediction_logits)

        return prediction_logits, best_perturbations, best_energy

    def _dump_plots(self, x, x_perturbed_best, best_perturbations, best_energy,
                    has_successful_attack, attack_targets, prediction_logits_0,
                    prediction_logits):
        """Debug helper: dump images, attention maps and class-probability plots for one sample.

        Always clears ``config.plot_out``. If at least one sample was attacked
        successfully, it picks one sample: a random successful one when
        ``config.plot_idx == "find"``, otherwise ``int(config.plot_idx)``. It
        writes that sample's plots and text files into ``config.plot_out`` and
        then terminates the process with ``sys.exit(1)``.
        """
        import sys

        plot_out = self.config.plot_out
        if os.path.isdir(plot_out):
            shutil.rmtree(plot_out)
        if not has_successful_attack.any():
            return

        os.makedirs(plot_out, exist_ok=True)
        K = attack_targets.shape[-1]

        successful_idxs = has_successful_attack.nonzero()[:, 0]
        if self.config.plot_idx == "find":
            selected_idx = int(successful_idxs[torch.randperm(len(successful_idxs))[0]])
        else:
            selected_idx = int(self.config.plot_idx)
        print("Selected idx", selected_idx)

        def imgnet_names(idxs):
            # Short ImageNet class name: text before the first comma.
            return [imgnet_idx_to_name[int(idx)].split(",")[0] for idx in idxs]

        def show_image(img):
            # Channel-first -> HWC, then reverse the channel order for display
            # (presumably BGR -> RGB; TODO confirm the model's input channel order).
            plt.imshow(img.permute(1, 2, 0).flip(dims=(-1,)).detach().cpu())

        def plot_attn_map(attn_map):
            # Reduce to one weight per token and drop the first (class) token.
            # Assumes the remaining tokens form a square patch grid
            # (e.g. 196 = 14x14); view() raises if they do not.
            attn_map = attn_map[0].mean(dim=0)[1:]
            side = int(round(attn_map.numel() ** 0.5))
            attn_map = attn_map.view(side, side)
            attn_map = F.interpolate(
                attn_map[None, None],
                x.shape[-2:],
                mode="bilinear"
            ).view(x.shape[-2:])
            plt.imshow(attn_map.detach().cpu(), alpha=0.5)

        top_classes = prediction_logits_0[selected_idx].argsort(dim=-1, descending=True)
        attack_targets_selected = attack_targets[selected_idx]
        top_class_names = imgnet_names(top_classes[:K])
        attack_targets_selected_names = imgnet_names(attack_targets_selected)

        plt.figure()
        show_image(x[selected_idx])
        plt.axis("off")
        plt.savefig(f"{plot_out}/clean_image.png", bbox_inches="tight", pad_inches=0)

        plt.figure()
        show_image(x_perturbed_best[selected_idx])
        plt.axis("off")
        plt.savefig(f"{plot_out}/perturbed_image.png", bbox_inches="tight", pad_inches=0)

        plt.figure()
        plt.imshow(best_perturbations[selected_idx].mean(dim=0).abs().detach().cpu(), cmap="hot")
        plt.colorbar()
        plt.savefig(f"{plot_out}/perturbation.png", bbox_inches="tight")

        if isinstance(self.model, MMPretrainVisualTransformerWrapper):
            attn_maps_clean = self.model.get_attention_maps(x)[-1][selected_idx]
            attn_maps_attacked = self.model.get_attention_maps(x_perturbed_best)[-1][selected_idx]
            # Both attention maps are overlaid on the clean image.
            for attn_maps, file_stem in ((attn_maps_clean, "clean_map"),
                                         (attn_maps_attacked, "attacked_map")):
                plt.figure()
                show_image(x[selected_idx])
                plot_attn_map(attn_maps)
                plt.axis("off")
                plt.savefig(f"{plot_out}/{file_stem}.png", bbox_inches="tight", pad_inches=0)

        with open(f'{plot_out}/clean_classes_names.txt', 'w') as f:
            f.write(", ".join(top_class_names))
        with open(f'{plot_out}/attack_targets_names.txt', 'w') as f:
            f.write(", ".join(attack_targets_selected_names))
        with open(f'{plot_out}/selected_idx.txt', 'w') as f:
            f.write(str(selected_idx))
        with open(f'{plot_out}/energy.txt', 'w') as f:
            f.write(str(best_energy[selected_idx].item()))

        C = prediction_logits_0.shape[-1]
        class_idxs = torch.arange(C) + 1
        clean_probs = prediction_logits_0[selected_idx].detach().cpu().softmax(dim=-1)
        attacked_probs = prediction_logits[selected_idx].detach().cpu().softmax(dim=-1)
        target_list = attack_targets_selected.tolist()

        def label_classes(bars):
            # Mark each target class's bar with its rank "[i]" in the target list.
            # A label is lifted above already-placed labels of targets within 40
            # classes so that neighbouring labels do not overlap.
            adjusted_heights = {}
            for i, cls_idx in enumerate(target_list):
                bar = bars[cls_idx]
                height = bar.get_height()
                ann_x = bar.get_x() + bar.get_width()
                max_neighboring_height = max(
                    (adjusted_heights[other] for other in target_list
                     if other != cls_idx and abs(cls_idx - other) <= 40
                     and other in adjusted_heights),
                    default=-1,
                )
                if max_neighboring_height > 0:
                    height = max_neighboring_height + 0.05
                adjusted_heights[cls_idx] = height
                plt.text(ann_x, height, f"[{i}]", rotation=90,
                         ha='center', va='bottom', fontsize=10, color='red')

        for probs, file_stem in ((clean_probs, "clean_probs"), (attacked_probs, "attacked_probs")):
            plt.figure()
            bars = plt.bar(class_idxs, probs, width=4)
            plt.ylim(0, 1)
            label_classes(bars)
            plt.savefig(f"{plot_out}/{file_stem}.png", bbox_inches="tight", pad_inches=0)

        print("Idx", selected_idx)
        print(best_energy[selected_idx])
        print("Finished plotting")
        # NOTE(review): this debug mode stops the whole process after dumping a
        # single sample, and exits with a non-zero status.
        sys.exit(1)

    def forward(self, x, attack_targets, gt_labels):
        """
        This function is in charge of performing a binary search through
        topk loss coefficients and running attacks on each.

        Per sample, a failed attack raises the lower bound of the coefficient
        (a stronger top-K push is needed). A successful attack lowers the upper
        bound, to look for a lower-energy attack.

        Args:
            x: [B, C, H, W] images.
            attack_targets: [B, K] target classes, in the desired rank order.
            gt_labels: [B] ground-truth labels, forwarded to ``attack``.

        Returns:
            best_prediction_logits: model logits for the best attack found per
                sample. Samples that never improved keep the logits from the
                first search step.
            best_perturbations: [B, C, H, W] lowest-energy successful
                perturbation per sample; zeros where every step failed.
        """
        B = x.shape[0]
        device = x.device

        topk_coefs_lower = torch.full((B,), fill_value=self.topk_loss_coef_lower,
                                      device=device, dtype=torch.float)
        topk_coefs_upper = torch.full((B,), fill_value=self.topk_loss_coef_upper,
                                      device=device, dtype=torch.float)

        best_perturbations = torch.zeros_like(x)                         # [B, C, H, W]
        best_energy = torch.full((B,), float('inf'), device=device)      # [B]
        best_prediction_logits = None

        for search_step_i in range(self.binary_search_steps):
            # Only print from the first device (or from CPU).
            if x.device.index is None or x.device.index == 0:
                print("Running binary search step", search_step_i + 1)

            current_topk_coefs = (topk_coefs_lower + topk_coefs_upper) / 2
            current_logits, current_perturbations, current_energy = \
                self.attack(x, attack_targets, gt_labels, current_topk_coefs)

            # attack() reports inf energy for samples with no successful step.
            current_attack_succeeded = ~torch.isinf(current_energy)
            current_attack_failed = ~current_attack_succeeded

            update_mask = current_energy < best_energy
            best_perturbations[update_mask] = current_perturbations[update_mask]
            best_energy[update_mask] = current_energy[update_mask]

            if best_prediction_logits is None:
                best_prediction_logits = current_logits.clone()
            else:
                best_prediction_logits[update_mask] = current_logits[update_mask]

            # If we fail to attack, we must increase our topk coef
            topk_coefs_lower[current_attack_failed] = current_topk_coefs[current_attack_failed]
            # If we succeed, we must lower to seek a more frugal attack
            topk_coefs_upper[current_attack_succeeded] = current_topk_coefs[current_attack_succeeded]

            idist.barrier()

        return best_prediction_logits, best_perturbations