Capture Activation and Attention Maps using PyTorch Hooks¶
Computation Graph¶
PyTorch Hooks¶
Why we use PyTorch hooks over other methods¶
How to capture all activations¶
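A forward hook is just a callable that PyTorch invokes every time a module's forward pass finishes, passing it the module, its (positional) inputs, and its output. Here is a minimal sketch on a plain nn.Linear (a toy example, separate from the ViT workflow below):
In [ ]:
import torch
import torch.nn as nn

captured = {}

def save_output(module, inp, out):
    # inp is a tuple of the module's positional inputs; out is its return value
    captured['linear'] = out

layer = nn.Linear(4, 2)
handle = layer.register_forward_hook(save_output)
_ = layer(torch.randn(1, 4))
print(captured['linear'].shape)  # torch.Size([1, 2])
handle.remove()  # detach the hook when done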
In [1]:
# load a sample image and a pretrained ViT image classifier
import torch
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import requests
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
device = "cpu"
vit_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
vit_model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    device_map=device
)
In [2]:
vit_inputs = vit_processor(images=image, return_tensors="pt")
vit_outputs = vit_model(**vit_inputs,
                        output_hidden_states=True,
                        output_attentions=True
                        )
vit_logits = vit_outputs.logits
# model predicts one of the 1000 ImageNet classes
predicted_class_idx = vit_logits.argmax(-1).item()
print("Predicted class:", vit_model.config.id2label[predicted_class_idx])
Predicted class: Egyptian cat
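Note that with output_hidden_states=True and output_attentions=True the model already returns its intermediates, which is the main built-in alternative to hooks: vit_outputs.hidden_states holds the embedding output plus one tensor per encoder layer, and vit_outputs.attentions one attention map per layer. A quick check (shapes assume ViT-base: 12 layers, 12 heads, 197 tokens):
In [ ]:
print(len(vit_outputs.hidden_states))      # 13 = embeddings + 12 layers
print(vit_outputs.hidden_states[0].shape)  # torch.Size([1, 197, 768])
print(len(vit_outputs.attentions))         # 12
print(vit_outputs.attentions[0].shape)     # torch.Size([1, 12, 197, 197])
Hooks become necessary when you want tensors these flags do not expose, for example the inputs and outputs of the individual Linear layers inside each attention block.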
In [ ]:
image
Out[ ]:
[image: the sample COCO photo of two cats]
In [ ]:
vit_model
Out[ ]:
ViTForImageClassification(
(vit): ViTModel(
(embeddings): ViTEmbeddings(
(patch_embeddings): ViTPatchEmbeddings(
(projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
)
(dropout): Dropout(p=0.0, inplace=False)
)
(encoder): ViTEncoder(
(layer): ModuleList(
(0-11): 12 x ViTLayer(
(attention): ViTAttention(
(attention): ViTSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
)
(output): ViTSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
)
(intermediate): ViTIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): ViTOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
)
)
(layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
(classifier): Linear(in_features=768, out_features=1000, bias=True)
)
I am using an OrderedDict to store the activations so that the order of the forward pass is preserved (a plain dict also preserves insertion order in Python 3.7+, but OrderedDict makes the intent explicit).
In [ ]:
from collections import OrderedDict
In [ ]:
vit_activations = OrderedDict()
In [ ]:
for name, layer in vit_model._modules.items():
print(name)
vit
classifier
In [ ]:
type(vit_model._modules['vit'])
Out[ ]:
transformers.models.vit.modeling_vit.ViTModel
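_modules only lists a module's immediate children, while named_modules() walks the entire tree including the root. A quick sketch of the difference (the leaf count of 136 assumes this exact ViT-base architecture):
In [ ]:
print(len(vit_model._modules))  # 2: 'vit' and 'classifier'
leaves = [n for n, m in vit_model.named_modules() if len(list(m.children())) == 0]
print(len(leaves))              # 136 leaf modules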
Get the execution order¶
In [ ]:
vit_model.eval()
Out[ ]:
ViTForImageClassification(
(vit): ViTModel(
(embeddings): ViTEmbeddings(
(patch_embeddings): ViTPatchEmbeddings(
(projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
)
(dropout): Dropout(p=0.0, inplace=False)
)
(encoder): ViTEncoder(
(layer): ModuleList(
(0-11): 12 x ViTLayer(
(attention): ViTAttention(
(attention): ViTSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
)
(output): ViTSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
)
(intermediate): ViTIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): ViTOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(dropout): Dropout(p=0.0, inplace=False)
)
(layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
)
)
(layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
)
(classifier): Linear(in_features=768, out_features=1000, bias=True)
)
In [ ]:
execution_order = []
vit_activations = OrderedDict()
vit_layer_inputs = OrderedDict()
def make_hook(name):
    def hook(module, inp, out):
        execution_order.append(name)
        vit_layer_inputs[name] = inp
        vit_activations[name] = out
    return hook

# register hooks on all leaf modules
hooks = []
for name, module in vit_model.named_modules():
    if len(list(module.children())) == 0:  # leaf only; drop this if you want all
        h = module.register_forward_hook(make_hook(name))
        hooks.append(h)

# run a forward pass
x = vit_inputs
_ = vit_model(**x)
In [ ]:
len(execution_order)
Out[ ]:
136
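That count checks out against the architecture: 2 embedding-level leaves (patch projection + dropout) + 12 encoder layers × 11 leaf modules each + the final layernorm + the classifier = 2 + 132 + 2 = 136.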
In [ ]:
len(vit_layer_inputs)
Out[ ]:
136
In [ ]:
vit_model.vit.encoder.layer[11].attention.attention
Out[ ]:
ViTSelfAttention(
  (query): Linear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(in_features=768, out_features=768, bias=True)
)
In [ ]:
from transformers.models.vit.modeling_vit import ViTAttention
attention_modules = [m for m in vit_model.modules() if isinstance(m, ViTAttention)]
In [ ]:
selected_layer = execution_order[6]
vit_layer_inputs[selected_layer][0].shape
Out[ ]:
torch.Size([1, 197, 768])
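The 197 tokens are one [CLS] token plus (224/16)² = 196 patch tokens, and 768 is the hidden size of ViT-base.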
In [ ]:
for i in range(0, 15):
    selected_layer = execution_order[i]
    shape = vit_layer_inputs[selected_layer][0].shape
    print(i, " : ", selected_layer, " : ", shape)
0 : vit.embeddings.patch_embeddings.projection : torch.Size([1, 3, 224, 224])
1 : vit.embeddings.dropout : torch.Size([1, 197, 768])
2 : vit.encoder.layer.0.layernorm_before : torch.Size([1, 197, 768])
3 : vit.encoder.layer.0.attention.attention.key : torch.Size([1, 197, 768])
4 : vit.encoder.layer.0.attention.attention.value : torch.Size([1, 197, 768])
5 : vit.encoder.layer.0.attention.attention.query : torch.Size([1, 197, 768])
6 : vit.encoder.layer.0.attention.output.dense : torch.Size([1, 197, 768])
7 : vit.encoder.layer.0.attention.output.dropout : torch.Size([1, 197, 768])
8 : vit.encoder.layer.0.layernorm_after : torch.Size([1, 197, 768])
9 : vit.encoder.layer.0.intermediate.dense : torch.Size([1, 197, 768])
10 : vit.encoder.layer.0.intermediate.intermediate_act_fn : torch.Size([1, 197, 3072])
11 : vit.encoder.layer.0.output.dense : torch.Size([1, 197, 3072])
12 : vit.encoder.layer.0.output.dropout : torch.Size([1, 197, 768])
13 : vit.encoder.layer.1.layernorm_before : torch.Size([1, 197, 768])
14 : vit.encoder.layer.1.attention.attention.key : torch.Size([1, 197, 768])
In [ ]:
torch.equal(vit_layer_inputs[execution_order[1]][0],
            vit_layer_inputs[execution_order[2]][0]
            )
# in eval mode (or with p=0.0) dropout is the identity, so the input to the
# embeddings dropout equals the input to the first layernorm_before;
# in train mode dropout randomly zeroes elements, so these would differ
Out[ ]:
True
In [ ]:
len(execution_order)
Out[ ]:
136
In [ ]:
for i, layer in enumerate(execution_order):
    print(f"{i} : {layer}")
0 : vit.embeddings.patch_embeddings.projection
1 : vit.embeddings.dropout
2 : vit.encoder.layer.0.layernorm_before
3 : vit.encoder.layer.0.attention.attention.key
4 : vit.encoder.layer.0.attention.attention.value
5 : vit.encoder.layer.0.attention.attention.query
6 : vit.encoder.layer.0.attention.output.dense
7 : vit.encoder.layer.0.attention.output.dropout
8 : vit.encoder.layer.0.layernorm_after
9 : vit.encoder.layer.0.intermediate.dense
10 : vit.encoder.layer.0.intermediate.intermediate_act_fn
11 : vit.encoder.layer.0.output.dense
12 : vit.encoder.layer.0.output.dropout
13 : vit.encoder.layer.1.layernorm_before
14 : vit.encoder.layer.1.attention.attention.key
15 : vit.encoder.layer.1.attention.attention.value
16 : vit.encoder.layer.1.attention.attention.query
17 : vit.encoder.layer.1.attention.output.dense
18 : vit.encoder.layer.1.attention.output.dropout
19 : vit.encoder.layer.1.layernorm_after
20 : vit.encoder.layer.1.intermediate.dense
21 : vit.encoder.layer.1.intermediate.intermediate_act_fn
22 : vit.encoder.layer.1.output.dense
23 : vit.encoder.layer.1.output.dropout
24 : vit.encoder.layer.2.layernorm_before
25 : vit.encoder.layer.2.attention.attention.key
26 : vit.encoder.layer.2.attention.attention.value
27 : vit.encoder.layer.2.attention.attention.query
28 : vit.encoder.layer.2.attention.output.dense
29 : vit.encoder.layer.2.attention.output.dropout
30 : vit.encoder.layer.2.layernorm_after
31 : vit.encoder.layer.2.intermediate.dense
32 : vit.encoder.layer.2.intermediate.intermediate_act_fn
33 : vit.encoder.layer.2.output.dense
34 : vit.encoder.layer.2.output.dropout
35 : vit.encoder.layer.3.layernorm_before
36 : vit.encoder.layer.3.attention.attention.key
37 : vit.encoder.layer.3.attention.attention.value
38 : vit.encoder.layer.3.attention.attention.query
39 : vit.encoder.layer.3.attention.output.dense
40 : vit.encoder.layer.3.attention.output.dropout
41 : vit.encoder.layer.3.layernorm_after
42 : vit.encoder.layer.3.intermediate.dense
43 : vit.encoder.layer.3.intermediate.intermediate_act_fn
44 : vit.encoder.layer.3.output.dense
45 : vit.encoder.layer.3.output.dropout
46 : vit.encoder.layer.4.layernorm_before
47 : vit.encoder.layer.4.attention.attention.key
48 : vit.encoder.layer.4.attention.attention.value
49 : vit.encoder.layer.4.attention.attention.query
50 : vit.encoder.layer.4.attention.output.dense
51 : vit.encoder.layer.4.attention.output.dropout
52 : vit.encoder.layer.4.layernorm_after
53 : vit.encoder.layer.4.intermediate.dense
54 : vit.encoder.layer.4.intermediate.intermediate_act_fn
55 : vit.encoder.layer.4.output.dense
56 : vit.encoder.layer.4.output.dropout
57 : vit.encoder.layer.5.layernorm_before
58 : vit.encoder.layer.5.attention.attention.key
59 : vit.encoder.layer.5.attention.attention.value
60 : vit.encoder.layer.5.attention.attention.query
61 : vit.encoder.layer.5.attention.output.dense
62 : vit.encoder.layer.5.attention.output.dropout
63 : vit.encoder.layer.5.layernorm_after
64 : vit.encoder.layer.5.intermediate.dense
65 : vit.encoder.layer.5.intermediate.intermediate_act_fn
66 : vit.encoder.layer.5.output.dense
67 : vit.encoder.layer.5.output.dropout
68 : vit.encoder.layer.6.layernorm_before
69 : vit.encoder.layer.6.attention.attention.key
70 : vit.encoder.layer.6.attention.attention.value
71 : vit.encoder.layer.6.attention.attention.query
72 : vit.encoder.layer.6.attention.output.dense
73 : vit.encoder.layer.6.attention.output.dropout
74 : vit.encoder.layer.6.layernorm_after
75 : vit.encoder.layer.6.intermediate.dense
76 : vit.encoder.layer.6.intermediate.intermediate_act_fn
77 : vit.encoder.layer.6.output.dense
78 : vit.encoder.layer.6.output.dropout
79 : vit.encoder.layer.7.layernorm_before
80 : vit.encoder.layer.7.attention.attention.key
81 : vit.encoder.layer.7.attention.attention.value
82 : vit.encoder.layer.7.attention.attention.query
83 : vit.encoder.layer.7.attention.output.dense
84 : vit.encoder.layer.7.attention.output.dropout
85 : vit.encoder.layer.7.layernorm_after
86 : vit.encoder.layer.7.intermediate.dense
87 : vit.encoder.layer.7.intermediate.intermediate_act_fn
88 : vit.encoder.layer.7.output.dense
89 : vit.encoder.layer.7.output.dropout
90 : vit.encoder.layer.8.layernorm_before
91 : vit.encoder.layer.8.attention.attention.key
92 : vit.encoder.layer.8.attention.attention.value
93 : vit.encoder.layer.8.attention.attention.query
94 : vit.encoder.layer.8.attention.output.dense
95 : vit.encoder.layer.8.attention.output.dropout
96 : vit.encoder.layer.8.layernorm_after
97 : vit.encoder.layer.8.intermediate.dense
98 : vit.encoder.layer.8.intermediate.intermediate_act_fn
99 : vit.encoder.layer.8.output.dense
100 : vit.encoder.layer.8.output.dropout
101 : vit.encoder.layer.9.layernorm_before
102 : vit.encoder.layer.9.attention.attention.key
103 : vit.encoder.layer.9.attention.attention.value
104 : vit.encoder.layer.9.attention.attention.query
105 : vit.encoder.layer.9.attention.output.dense
106 : vit.encoder.layer.9.attention.output.dropout
107 : vit.encoder.layer.9.layernorm_after
108 : vit.encoder.layer.9.intermediate.dense
109 : vit.encoder.layer.9.intermediate.intermediate_act_fn
110 : vit.encoder.layer.9.output.dense
111 : vit.encoder.layer.9.output.dropout
112 : vit.encoder.layer.10.layernorm_before
113 : vit.encoder.layer.10.attention.attention.key
114 : vit.encoder.layer.10.attention.attention.value
115 : vit.encoder.layer.10.attention.attention.query
116 : vit.encoder.layer.10.attention.output.dense
117 : vit.encoder.layer.10.attention.output.dropout
118 : vit.encoder.layer.10.layernorm_after
119 : vit.encoder.layer.10.intermediate.dense
120 : vit.encoder.layer.10.intermediate.intermediate_act_fn
121 : vit.encoder.layer.10.output.dense
122 : vit.encoder.layer.10.output.dropout
123 : vit.encoder.layer.11.layernorm_before
124 : vit.encoder.layer.11.attention.attention.key
125 : vit.encoder.layer.11.attention.attention.value
126 : vit.encoder.layer.11.attention.attention.query
127 : vit.encoder.layer.11.attention.output.dense
128 : vit.encoder.layer.11.attention.output.dropout
129 : vit.encoder.layer.11.layernorm_after
130 : vit.encoder.layer.11.intermediate.dense
131 : vit.encoder.layer.11.intermediate.intermediate_act_fn
132 : vit.encoder.layer.11.output.dense
133 : vit.encoder.layer.11.output.dropout
134 : vit.layernorm
135 : classifier
In [ ]:
# clean up: remove the hooks so they don't fire on future forward passes
for h in hooks:
    h.remove()
Visualize activations¶
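With the hooks removed, the captured tensors are still in vit_activations, so we can render any layer's activations over the 14×14 patch grid. A minimal sketch (assumes matplotlib is installed and the capture cells above were run; the layer chosen here is just an example):
In [ ]:
import matplotlib.pyplot as plt

# example layer: the first block's FFN output, shape [1, 197, 768]
act = vit_activations['vit.encoder.layer.0.output.dense']
patch_act = act[0, 1:, :]                      # drop the [CLS] token -> [196, 768]
grid = patch_act.mean(dim=-1).reshape(14, 14)  # average over channels -> 14 x 14 patch grid
plt.imshow(grid.detach().numpy(), cmap='viridis')
plt.title('layer.0.output.dense: mean activation per patch')
plt.colorbar()
plt.show()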
Different Attention Calculation Types¶
What are the different attention calculation types¶
- eager
- sdpa
- flash_attention_2 (FlashAttention)
How to capture attention maps¶
- Here we use the 'eager' implementation to capture attention maps. With 'sdpa', the softmax attention weights are computed inside the fused `torch.nn.functional.scaled_dot_product_attention` kernel and never materialized as a tensor, so there is nothing for a hook (or `output_attentions=True`) to return; forcing eager makes the full `[batch, heads, seq, seq]` probability tensor available (see the sketch below).
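The simplest way to guarantee eager attention is to request it when loading the model; `attn_implementation` is a standard `from_pretrained` argument. A sketch (the rest of this notebook instead relies on the automatic sdpa-to-eager fallback shown later):
In [ ]:
vit_model_eager = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    attn_implementation="eager",
    device_map=device,
)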
Visualize attention maps¶
In [ ]:
execution_order_2 = []
vit_attentions = OrderedDict()
vit_layer_inputs = OrderedDict()

def capture_attention_map_hook(name):
    def attention_map_hook(module, inp, out):
        execution_order_2.append(name)
        vit_layer_inputs[name] = inp
        # with output_attentions=True, out is (context_layer, attention_probs)
        vit_attentions[name] = out[1]
    return attention_map_hook

# register hooks on the self-attention modules (not the leaves) so we see the attention probs
from transformers.models.vit.modeling_vit import ViTSelfAttention
hooks = []
for name, module in vit_model.named_modules():
    if isinstance(module, ViTSelfAttention):
        h = module.register_forward_hook(capture_attention_map_hook(name))
        hooks.append(h)

# run a forward pass; output_attentions=True forces the eager fallback under 'sdpa'
x = vit_inputs
_ = vit_model(**x, output_attentions=True)
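Each captured map has shape [1, 12, 197, 197] (batch, heads, query token, key token). A minimal visualization sketch (assumes matplotlib): average the last layer's [CLS]-to-patch attention over heads and render it on the 14×14 patch grid:
In [ ]:
import matplotlib.pyplot as plt

attn = vit_attentions['vit.encoder.layer.11.attention.attention']  # [1, 12, 197, 197]
cls_attn = attn[0, :, 0, 1:].mean(dim=0)  # [CLS] -> patch attention, head average -> [196]
plt.imshow(cls_attn.reshape(14, 14).detach().numpy(), cmap='viridis')
plt.title('layer 11: [CLS] attention over patches (head average)')
plt.colorbar()
plt.show()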
In [ ]:
import inspect
In [ ]:
for name, module in vit_model.named_modules():
    if len(list(module.children())) == 0:  # leaf only; drop this if you want all
        # h = module.register_forward_hook(make_hook(name))
        # hooks.append(h)
        # print(name)
        if "attention.output.dense" in name:
            print(inspect.getdoc(module.forward))
            break
# # run a forward pass
# x = vit_inputs
# _ = vit_model(**x)
Define the computation performed at every call.
Should be overridden by all subclasses.
.. note::
Although the recipe for forward pass needs to be defined within
this function, one should call the :class:`Module` instance afterwards
instead of this since the former takes care of running the
registered hooks while the latter silently ignores them.
In [ ]:
import inspect
print(inspect.getsource(vit_model.vit.encoder.layer[0].attention.attention.forward))
def forward(
    self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
    key_layer = self.transpose_for_scores(self.key(hidden_states))
    value_layer = self.transpose_for_scores(self.value(hidden_states))
    query_layer = self.transpose_for_scores(self.query(hidden_states))

    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
        if self.config._attn_implementation == "sdpa" and output_attentions:
            logger.warning_once(
                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    context_layer, attention_probs = attention_interface(
        self,
        query_layer,
        key_layer,
        value_layer,
        head_mask,
        is_causal=self.is_causal,
        scaling=self.scaling,
        dropout=0.0 if not self.training else self.dropout_prob,
    )

    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    context_layer = context_layer.reshape(new_context_layer_shape)

    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

    return outputs
In [ ]:
vit_model.config._attn_implementation
Out[ ]:
'sdpa'
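Because the configured implementation is 'sdpa', passing output_attentions=True hits the warning branch in the source above and the model falls back to eager attention for that call. That fallback is exactly why the capture cell earlier could read attention_probs off the hook output without reloading the model.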