import numpy as np
import torch


class PRSLogger(object):
    """Accumulates per-layer residual stream contributions from a hooked CLIP model."""

    def __init__(self, model, device, spatial: bool = True):
        self.current_layer = 0
        self.device = device
        self.attentions = []  # per-layer attention-block contributions
        self.mlps = []  # per-layer MLP contributions (plus the post-ln_pre embedding)
        self.ks = []  # per-layer attention keys
        self.qs = []  # per-layer attention queries
        self.vs = []  # per-layer attention values
        self.attn_mats = []  # per-layer attention matrices (before masking)
        self.spatial = spatial
        self.post_ln_std = None  # std recorded inside the final LayerNorm
        self.post_ln_mean = None  # mean recorded inside the final LayerNorm
        self.model = model

    @torch.no_grad()
    def compute_attentions_spatial(self, ret):
        assert (
            len(ret.shape) == 5
        ), "Verify that you use method=`head` and not method=`head_no_spatial`"
        assert (
            self.spatial
        ), "Verify that you use method=`head` and not method=`head_no_spatial`"
        bias_term = self.model.visual.transformer.resblocks[
            self.current_layer
        ].attn.out_proj.bias
        self.current_layer += 1
        return_value = ret[:, 0]  # per-(token, head) contributions to the cls token
        # Distribute the output-projection bias evenly across the tokens * heads
        # summands so that the logged terms still sum to the block's output.
        self.attentions.append(
            return_value
            + bias_term[np.newaxis, np.newaxis, np.newaxis]
            / (return_value.shape[1] * return_value.shape[2])
        )
        return ret

    @torch.no_grad()
    def compute_attentions_non_spatial(self, ret):
        assert (
            len(ret.shape) == 4
        ), "Verify that you use method=`head_no_spatial` and not method=`head`"
        assert (
            not self.spatial
        ), "Verify that you use method=`head_no_spatial` and not method=`head`"
        bias_term = self.model.visual.transformer.resblocks[
            self.current_layer
        ].attn.out_proj.bias
        self.current_layer += 1
        # Keep only the cls-token contribution (the downstream normalization expects
        # per-layer entries of shape [b, h, d]) and distribute the output-projection
        # bias evenly across the per-head summands.
        return_value = ret[:, 0]
        self.attentions.append(return_value + bias_term / return_value.shape[-2])
        return ret

    @torch.no_grad()
    def compute_k(self, ret):
        self.ks.append(ret)
        return ret

    @torch.no_grad()
    def compute_q(self, ret):
        self.qs.append(ret)
        return ret

    @torch.no_grad()
    def compute_v(self, ret):
        self.vs.append(ret)
        return ret

    @torch.no_grad()
    def compute_attn_mat(self, ret):
        self.attn_mats.append(ret)
        return ret

    @torch.no_grad()
    def compute_mlps(self, ret):
        self.mlps.append(ret)
        return ret

    @torch.no_grad()
    def log_post_ln_mean(self, ret):
        self.post_ln_mean = ret
        return ret

    @torch.no_grad()
    def log_post_ln_std(self, ret):
        self.post_ln_std = ret
        return ret

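    # The three `_normalize_*` helpers below all apply the same identity: with the
    # mean and std already recorded from the real forward pass, the final LayerNorm
    # acts linearly, ln_post(x) = w * (x - mean) / std + b. The residual stream is
    # a sum of N intermediate contributions x = sum_i c_i (every attention block,
    # every MLP, and the initial embedding), so ln_post(x) decomposes into one term
    # w * (c_i - mean / N) / std + b / N per contribution; the attention variants
    # additionally split each layer's share across its tokens/heads summands. Each
    # term is then mapped through `proj` into the joint text-image space.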
    def _normalize_mlps(self):
        len_intermediates = self.attentions.shape[1] + self.mlps.shape[1]
        mean_centered = (
            self.mlps
            - self.post_ln_mean[:, :, np.newaxis].to(self.device) / len_intermediates
        )
        weighted_mean_centered = (
            self.model.visual.ln_post.weight.detach().to(self.device) * mean_centered
        )
        weighted_mean_by_std = weighted_mean_centered / self.post_ln_std[
            :, :, np.newaxis
        ].to(self.device)
        bias_term = (
            self.model.visual.ln_post.bias.detach().to(self.device) / len_intermediates
        )
        post_ln = weighted_mean_by_std + bias_term
        return post_ln @ self.model.visual.proj.detach().to(self.device)

    def _normalize_attentions_spatial(self):
        len_intermediates = self.attentions.shape[1] + self.mlps.shape[1]
        # Each attention layer is further split into tokens * heads summands:
        normalization_term = self.attentions.shape[2] * self.attentions.shape[3]
        mean_centered = self.attentions - self.post_ln_mean[
            :, :, np.newaxis, np.newaxis, np.newaxis
        ].to(self.device) / (len_intermediates * normalization_term)
        weighted_mean_centered = (
            self.model.visual.ln_post.weight.detach().to(self.device) * mean_centered
        )
        weighted_mean_by_std = weighted_mean_centered / self.post_ln_std[
            :, :, np.newaxis, np.newaxis, np.newaxis
        ].to(self.device)
        bias_term = self.model.visual.ln_post.bias.detach().to(self.device) / (
            len_intermediates * normalization_term
        )
        post_ln = weighted_mean_by_std + bias_term
        return post_ln @ self.model.visual.proj.detach().to(self.device)

    def _normalize_attentions_non_spatial(self):
        len_intermediates = self.attentions.shape[1] + self.mlps.shape[1]
        # Each attention layer is further split into per-head summands:
        normalization_term = self.attentions.shape[2]
        mean_centered = self.attentions - self.post_ln_mean[
            :, :, np.newaxis, np.newaxis
        ].to(self.device) / (len_intermediates * normalization_term)
        weighted_mean_centered = (
            self.model.visual.ln_post.weight.detach().to(self.device) * mean_centered
        )
        weighted_mean_by_std = weighted_mean_centered / self.post_ln_std[
            :, :, np.newaxis, np.newaxis
        ].to(self.device)
        bias_term = self.model.visual.ln_post.bias.detach().to(self.device) / (
            len_intermediates * normalization_term
        )
        post_ln = weighted_mean_by_std + bias_term
        return post_ln @ self.model.visual.proj.detach().to(self.device)

    @torch.no_grad()
    def finalize(self, representation):
        """Apply the post-ln scaling to every logged contribution, project into the
        joint space, and normalize by the representation's norm."""
        self.attentions = torch.stack(self.attentions, dim=1).to(
            self.device
        )  # [b, l, n, h, d] (spatial) or [b, l, h, d] (non-spatial)
        self.mlps = torch.stack(self.mlps, dim=1).to(self.device)  # [b, l + 1, d]
        if self.spatial:
            projected_attentions = self._normalize_attentions_spatial()
        else:
            projected_attentions = self._normalize_attentions_non_spatial()
        projected_mlps = self._normalize_mlps()
        norm = representation.norm(dim=-1).detach()
        if self.spatial:
            return (
                projected_attentions
                / norm[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis],
                projected_mlps / norm[:, np.newaxis, np.newaxis],
            )
        return (
            projected_attentions / norm[:, np.newaxis, np.newaxis, np.newaxis],
            projected_mlps / norm[:, np.newaxis, np.newaxis],
        )

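    # Output contract of `finalize` (shapes inferred from the stacking and
    # broadcasting above): with spatial=True, `attentions` is
    # [b, layers, tokens, heads, d] and `mlps` is [b, layers + 1, d], where the
    # extra entry is the post-ln_pre embedding logged via the `visual.ln_pre_post`
    # hook. Both live in the joint text-image space and are scaled by
    # 1 / ||representation||, so summing every term recovers the unit-normalized
    # image representation.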
    def reinit(self):
        """Clear all logged tensors; call this before every new forward pass."""
        self.current_layer = 0
        self.attentions = []
        self.mlps = []
        self.ks = []
        self.qs = []
        self.vs = []
        self.attn_mats = []
        self.post_ln_mean = None
        self.post_ln_std = None
        torch.cuda.empty_cache()


def hook_prs_logger(model, device, spatial: bool = True):
    """Hooks a projected residual stream logger to the model."""
    prs = PRSLogger(model, device, spatial=spatial)
    if spatial:
        model.hook_manager.register(
            "visual.transformer.resblocks.*.attn.out.post",
            prs.compute_attentions_spatial,
        )
    else:
        model.hook_manager.register(
            "visual.transformer.resblocks.*.attn.out.post",
            prs.compute_attentions_non_spatial,
        )
    model.hook_manager.register(
        "visual.transformer.resblocks.*.mlp.c_proj.post", prs.compute_mlps
    )
    model.hook_manager.register(
        "visual.transformer.resblocks.*.attn.attention.pre_mask", prs.compute_attn_mat
    )
    # Log the post-ln_pre patch embedding as the residual stream's initial
    # contribution; it becomes the extra (layer-0) entry in `prs.mlps`.
    model.hook_manager.register("visual.ln_pre_post", prs.compute_mlps)
    # Record the final LayerNorm's statistics so it can be applied linearly to
    # each contribution during `finalize`.
    model.hook_manager.register("visual.ln_post.mean", prs.log_post_ln_mean)
    model.hook_manager.register("visual.ln_post.sqrt_var", prs.log_post_ln_std)
    return prs
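

# Minimal usage sketch (not part of the library): the loader below and the
# `encode_image(..., attn_method=..., normalize=False)` signature follow this
# repo's companion compute_prs.py and are assumptions here; substitute your own
# hooked-CLIP factory if your checkout differs.
if __name__ == "__main__":
    from utils.factory import create_model_and_transforms  # assumed repo-local loader

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, _, preprocess = create_model_and_transforms("ViT-B-16", pretrained="openai")
    model.to(device).eval()
    prs = hook_prs_logger(model, device, spatial=True)

    image = torch.randn(1, 3, 224, 224, device=device)  # stand-in preprocessed batch
    prs.reinit()  # clear the buffers before every forward pass
    with torch.no_grad():
        representation = model.encode_image(
            image, attn_method="head", normalize=False
        )
        attentions, mlps = prs.finalize(representation)

    # Sanity check: the decomposition should sum back, up to numerical error, to
    # the unit-normalized image representation.
    total = attentions.sum(dim=(1, 2, 3)) + mlps.sum(dim=1)
    expected = representation / representation.norm(dim=-1, keepdim=True)
    print((total - expected).abs().max())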