Capture Activation and Attention Maps using PyTorch Hooks¶

Computation Graph¶

PyTorch Hooks¶

Why we use PyTorch hooks over other methods¶

How to capture all activations¶
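Before working through the ViT model, here is a minimal, self-contained sketch of the mechanism on a toy network (the toy module and names below are illustrative, not part of this notebook). `register_forward_hook` attaches a callback that is called with the module, its inputs, and its output on every forward pass:

In [ ]:
import torch
import torch.nn as nn

toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
captured = {}

def save_output(name):
    def hook(module, inputs, output):
        # `inputs` is the tuple of positional args passed to forward()
        captured[name] = output.detach()
    return hook

handle = toy[1].register_forward_hook(save_output("relu"))
_ = toy(torch.randn(1, 4))
print(captured["relu"].shape)  # torch.Size([1, 8])
handle.remove()  # detach the hook once it is no longer needed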

In [1]:
import torch
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

device = "cpu"
vit_processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
vit_model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    device_map=device
    )
/opt/homebrew/Caskroom/miniconda/base/envs/sasika-test/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
In [2]:
vit_inputs = vit_processor(images=image, return_tensors="pt")
vit_outputs = vit_model(**vit_inputs,
                    output_hidden_states=True,
                    output_attentions=True
                    )
vit_logits = vit_outputs.logits
# model predicts one of the 1000 ImageNet classes
predicted_class_idx = vit_logits.argmax(-1).item()
print("Predicted class:", vit_model.config.id2label[predicted_class_idx])
In [ ]:
image
Out[ ]:
(the downloaded COCO image of two cats is displayed here)
In [ ]:
vit_model
Out[ ]:
ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): ViTOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        )
      )
    )
    (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  )
  (classifier): Linear(in_features=768, out_features=1000, bias=True)
)

I am using an OrderedDict to store the activations so that the order of the forward pass is preserved (a plain dict also preserves insertion order in Python 3.7+, but OrderedDict makes the intent explicit).

In [ ]:
from collections import OrderedDict
In [ ]:
vit_activations = OrderedDict()
In [ ]:
for name, layer in vit_model._modules.items():
  print(name)
vit
classifier
In [ ]:
type(vit_model._modules['vit'])
Out[ ]:
transformers.models.vit.modeling_vit.ViTModel

Get the execution order¶

In [ ]:
vit_model.eval();  # put the model in eval mode so the Dropout layers act as identity
In [ ]:
execution_order = []
vit_activations = OrderedDict()
vit_layer_inputs = OrderedDict()

def make_hook(name):
  def hook(module, inp, out):
    execution_order.append(name)
    vit_layer_inputs[name] = inp
    vit_activations[name] = out

  return hook

# register hooks on all leaf modules
hooks = []
for name, module in vit_model.named_modules():
  if len(list(module.children())) == 0: # leaf only; drop this if you want all
    h = module.register_forward_hook(make_hook(name))
    hooks.append(h)

# run a forward pass
x = vit_inputs
_ = vit_model(**x)
In [ ]:
len(execution_order)
Out[ ]:
136
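That count checks out against the module tree: 2 embedding-level modules (patch projection and dropout) + 12 encoder layers × 11 leaf modules per layer + the final layernorm + the classifier = 2 + 132 + 1 + 1 = 136 hook calls.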
In [ ]:
vit_inputs;
In [ ]:
len(vit_layer_inputs)
Out[ ]:
136
In [ ]:
vit_model.vit.encoder.layer[11].attention.attention
Out[ ]:
ViTSelfAttention(
  (query): Linear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(in_features=768, out_features=768, bias=True)
)
In [ ]:
from transformers.models.vit.modeling_vit import ViTSelfAttention

attention_modules = [m for m in vit_model.modules() if isinstance(m, ViTSelfAttention)]
In [ ]:
selected_layer = execution_order[6]
vit_layer_inputs[selected_layer][0].shape
Out[ ]:
torch.Size([1, 197, 768])
In [ ]:
for i in range(0, 15):
  selected_layer = execution_order[i]
  shape = vit_layer_inputs[selected_layer][0].shape

  print(i, " : ", selected_layer , " : ", shape )
0  :  vit.embeddings.patch_embeddings.projection  :  torch.Size([1, 3, 224, 224])
1  :  vit.embeddings.dropout  :  torch.Size([1, 197, 768])
2  :  vit.encoder.layer.0.layernorm_before  :  torch.Size([1, 197, 768])
3  :  vit.encoder.layer.0.attention.attention.key  :  torch.Size([1, 197, 768])
4  :  vit.encoder.layer.0.attention.attention.value  :  torch.Size([1, 197, 768])
5  :  vit.encoder.layer.0.attention.attention.query  :  torch.Size([1, 197, 768])
6  :  vit.encoder.layer.0.attention.output.dense  :  torch.Size([1, 197, 768])
7  :  vit.encoder.layer.0.attention.output.dropout  :  torch.Size([1, 197, 768])
8  :  vit.encoder.layer.0.layernorm_after  :  torch.Size([1, 197, 768])
9  :  vit.encoder.layer.0.intermediate.dense  :  torch.Size([1, 197, 768])
10  :  vit.encoder.layer.0.intermediate.intermediate_act_fn  :  torch.Size([1, 197, 3072])
11  :  vit.encoder.layer.0.output.dense  :  torch.Size([1, 197, 3072])
12  :  vit.encoder.layer.0.output.dropout  :  torch.Size([1, 197, 768])
13  :  vit.encoder.layer.1.layernorm_before  :  torch.Size([1, 197, 768])
14  :  vit.encoder.layer.1.attention.attention.key  :  torch.Size([1, 197, 768])
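The sequence length of 197 comes from the 224×224 input being split into 16×16 patches: (224/16)² = 196 patch tokens plus one [CLS] token, each embedded in 768 dimensions (expanded to 3072 inside each block's MLP).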
In [ ]:
torch.equal(vit_layer_inputs[execution_order[1]][0],
            vit_layer_inputs[execution_order[2]][0]
            )

# in model.eval() Dropout is a no-op (identity), so the input to embeddings.dropout equals the input to the first layernorm_before;
# in model.train() Dropout would randomly zero elements (here p=0.0, so it is a no-op either way)
Out[ ]:
True
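A quick standalone illustration of that train/eval behaviour (illustrative only, independent of the tensors captured above):

In [ ]:
drop = torch.nn.Dropout(p=0.5)
x = torch.ones(6)
drop.train()
print(drop(x))   # roughly half the entries zeroed, survivors scaled by 1/(1-p) = 2
drop.eval()
print(drop(x))   # identity: tensor([1., 1., 1., 1., 1., 1.])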
In [ ]:
for i, layer in enumerate(execution_order):
  print(f"{i} : {layer}")
0 : vit.embeddings.patch_embeddings.projection
1 : vit.embeddings.dropout
2 : vit.encoder.layer.0.layernorm_before
3 : vit.encoder.layer.0.attention.attention.key
4 : vit.encoder.layer.0.attention.attention.value
5 : vit.encoder.layer.0.attention.attention.query
6 : vit.encoder.layer.0.attention.output.dense
7 : vit.encoder.layer.0.attention.output.dropout
8 : vit.encoder.layer.0.layernorm_after
9 : vit.encoder.layer.0.intermediate.dense
10 : vit.encoder.layer.0.intermediate.intermediate_act_fn
11 : vit.encoder.layer.0.output.dense
12 : vit.encoder.layer.0.output.dropout
13 : vit.encoder.layer.1.layernorm_before
14 : vit.encoder.layer.1.attention.attention.key
15 : vit.encoder.layer.1.attention.attention.value
16 : vit.encoder.layer.1.attention.attention.query
17 : vit.encoder.layer.1.attention.output.dense
18 : vit.encoder.layer.1.attention.output.dropout
19 : vit.encoder.layer.1.layernorm_after
20 : vit.encoder.layer.1.intermediate.dense
21 : vit.encoder.layer.1.intermediate.intermediate_act_fn
22 : vit.encoder.layer.1.output.dense
23 : vit.encoder.layer.1.output.dropout
24 : vit.encoder.layer.2.layernorm_before
25 : vit.encoder.layer.2.attention.attention.key
26 : vit.encoder.layer.2.attention.attention.value
27 : vit.encoder.layer.2.attention.attention.query
28 : vit.encoder.layer.2.attention.output.dense
29 : vit.encoder.layer.2.attention.output.dropout
30 : vit.encoder.layer.2.layernorm_after
31 : vit.encoder.layer.2.intermediate.dense
32 : vit.encoder.layer.2.intermediate.intermediate_act_fn
33 : vit.encoder.layer.2.output.dense
34 : vit.encoder.layer.2.output.dropout
35 : vit.encoder.layer.3.layernorm_before
36 : vit.encoder.layer.3.attention.attention.key
37 : vit.encoder.layer.3.attention.attention.value
38 : vit.encoder.layer.3.attention.attention.query
39 : vit.encoder.layer.3.attention.output.dense
40 : vit.encoder.layer.3.attention.output.dropout
41 : vit.encoder.layer.3.layernorm_after
42 : vit.encoder.layer.3.intermediate.dense
43 : vit.encoder.layer.3.intermediate.intermediate_act_fn
44 : vit.encoder.layer.3.output.dense
45 : vit.encoder.layer.3.output.dropout
46 : vit.encoder.layer.4.layernorm_before
47 : vit.encoder.layer.4.attention.attention.key
48 : vit.encoder.layer.4.attention.attention.value
49 : vit.encoder.layer.4.attention.attention.query
50 : vit.encoder.layer.4.attention.output.dense
51 : vit.encoder.layer.4.attention.output.dropout
52 : vit.encoder.layer.4.layernorm_after
53 : vit.encoder.layer.4.intermediate.dense
54 : vit.encoder.layer.4.intermediate.intermediate_act_fn
55 : vit.encoder.layer.4.output.dense
56 : vit.encoder.layer.4.output.dropout
57 : vit.encoder.layer.5.layernorm_before
58 : vit.encoder.layer.5.attention.attention.key
59 : vit.encoder.layer.5.attention.attention.value
60 : vit.encoder.layer.5.attention.attention.query
61 : vit.encoder.layer.5.attention.output.dense
62 : vit.encoder.layer.5.attention.output.dropout
63 : vit.encoder.layer.5.layernorm_after
64 : vit.encoder.layer.5.intermediate.dense
65 : vit.encoder.layer.5.intermediate.intermediate_act_fn
66 : vit.encoder.layer.5.output.dense
67 : vit.encoder.layer.5.output.dropout
68 : vit.encoder.layer.6.layernorm_before
69 : vit.encoder.layer.6.attention.attention.key
70 : vit.encoder.layer.6.attention.attention.value
71 : vit.encoder.layer.6.attention.attention.query
72 : vit.encoder.layer.6.attention.output.dense
73 : vit.encoder.layer.6.attention.output.dropout
74 : vit.encoder.layer.6.layernorm_after
75 : vit.encoder.layer.6.intermediate.dense
76 : vit.encoder.layer.6.intermediate.intermediate_act_fn
77 : vit.encoder.layer.6.output.dense
78 : vit.encoder.layer.6.output.dropout
79 : vit.encoder.layer.7.layernorm_before
80 : vit.encoder.layer.7.attention.attention.key
81 : vit.encoder.layer.7.attention.attention.value
82 : vit.encoder.layer.7.attention.attention.query
83 : vit.encoder.layer.7.attention.output.dense
84 : vit.encoder.layer.7.attention.output.dropout
85 : vit.encoder.layer.7.layernorm_after
86 : vit.encoder.layer.7.intermediate.dense
87 : vit.encoder.layer.7.intermediate.intermediate_act_fn
88 : vit.encoder.layer.7.output.dense
89 : vit.encoder.layer.7.output.dropout
90 : vit.encoder.layer.8.layernorm_before
91 : vit.encoder.layer.8.attention.attention.key
92 : vit.encoder.layer.8.attention.attention.value
93 : vit.encoder.layer.8.attention.attention.query
94 : vit.encoder.layer.8.attention.output.dense
95 : vit.encoder.layer.8.attention.output.dropout
96 : vit.encoder.layer.8.layernorm_after
97 : vit.encoder.layer.8.intermediate.dense
98 : vit.encoder.layer.8.intermediate.intermediate_act_fn
99 : vit.encoder.layer.8.output.dense
100 : vit.encoder.layer.8.output.dropout
101 : vit.encoder.layer.9.layernorm_before
102 : vit.encoder.layer.9.attention.attention.key
103 : vit.encoder.layer.9.attention.attention.value
104 : vit.encoder.layer.9.attention.attention.query
105 : vit.encoder.layer.9.attention.output.dense
106 : vit.encoder.layer.9.attention.output.dropout
107 : vit.encoder.layer.9.layernorm_after
108 : vit.encoder.layer.9.intermediate.dense
109 : vit.encoder.layer.9.intermediate.intermediate_act_fn
110 : vit.encoder.layer.9.output.dense
111 : vit.encoder.layer.9.output.dropout
112 : vit.encoder.layer.10.layernorm_before
113 : vit.encoder.layer.10.attention.attention.key
114 : vit.encoder.layer.10.attention.attention.value
115 : vit.encoder.layer.10.attention.attention.query
116 : vit.encoder.layer.10.attention.output.dense
117 : vit.encoder.layer.10.attention.output.dropout
118 : vit.encoder.layer.10.layernorm_after
119 : vit.encoder.layer.10.intermediate.dense
120 : vit.encoder.layer.10.intermediate.intermediate_act_fn
121 : vit.encoder.layer.10.output.dense
122 : vit.encoder.layer.10.output.dropout
123 : vit.encoder.layer.11.layernorm_before
124 : vit.encoder.layer.11.attention.attention.key
125 : vit.encoder.layer.11.attention.attention.value
126 : vit.encoder.layer.11.attention.attention.query
127 : vit.encoder.layer.11.attention.output.dense
128 : vit.encoder.layer.11.attention.output.dropout
129 : vit.encoder.layer.11.layernorm_after
130 : vit.encoder.layer.11.intermediate.dense
131 : vit.encoder.layer.11.intermediate.intermediate_act_fn
132 : vit.encoder.layer.11.output.dense
133 : vit.encoder.layer.11.output.dropout
134 : vit.layernorm
135 : classifier
In [ ]:
vit_model.vit.encoder.layer[0].output.dropout
In [ ]:
# clean up
for h in hooks:
  h.remove()

Visualize activations¶
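As a quick sanity check (a minimal sketch, assuming vit_activations from the hook cell above is still populated and matplotlib is available), one simple view is the distribution of a single layer's output:

In [ ]:
import matplotlib.pyplot as plt

# distribution of the GELU outputs in the first encoder block, shape (1, 197, 3072)
act = vit_activations['vit.encoder.layer.0.intermediate.intermediate_act_fn']
plt.hist(act.detach().flatten().numpy(), bins=100)
plt.title("layer 0 GELU activation distribution")
plt.show()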

Different Attention Calculation Types¶

What are the different attention calculation types¶

  • eager – the attention computation written out in plain PyTorch ops; the attention probabilities exist as an intermediate tensor
  • sdpa – torch.nn.functional.scaled_dot_product_attention, a fused kernel
  • flash_attention_2 – the FlashAttention kernels

How to capture attention maps¶

  • To capture attention maps we want the 'eager' implementation: with 'sdpa' the attention probabilities are computed inside the fused scaled_dot_product_attention kernel and never materialized as a tensor, so there is nothing for a hook (or output_attentions) to return. Transformers therefore falls back to eager attention when output_attentions=True is requested, as the forward source printed below shows.

Visualize attention maps¶

In [ ]:
from transformers.models.vit.modeling_vit import ViTSelfAttention

execution_order_2 = []
vit_attentions = OrderedDict()
vit_layer_inputs = OrderedDict()

def capture_attention_map_hook(name):
  def attention_map_hook(module, inp, out):
    execution_order_2.append(name)
    vit_layer_inputs[name] = inp
    # with output_attentions=True, ViTSelfAttention returns (context_layer, attention_probs)
    vit_attentions[name] = out[1].detach()

  return attention_map_hook

# register hooks on the self-attention modules only
hooks = []
for name, module in vit_model.named_modules():
  if isinstance(module, ViTSelfAttention):
    h = module.register_forward_hook(capture_attention_map_hook(name))
    hooks.append(h)

# run a forward pass and ask the model to materialize the attention probabilities
x = vit_inputs
_ = vit_model(**x, output_attentions=True)
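A minimal sketch for looking at one of the captured maps (assumes the cell above ran and populated vit_attentions with tensors of shape (1, num_heads, 197, 197); matplotlib is an extra dependency not used elsewhere here):

In [ ]:
import matplotlib.pyplot as plt

# attention probabilities from the last encoder layer's self-attention
last_attn = vit_attentions[execution_order_2[-1]]   # (1, 12, 197, 197)
cls_attn = last_attn[0, :, 0, 1:]                   # CLS-token attention to the 196 patch tokens, per head
cls_attn = cls_attn.mean(0).reshape(14, 14)         # average over heads, back to the 14x14 patch grid

plt.imshow(cls_attn.numpy(), cmap="viridis")
plt.title("mean CLS attention, last layer")
plt.colorbar()
plt.show()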
In [ ]:
import inspect
In [ ]:
for name, module in vit_model.named_modules():
    if len(list(module.children())) == 0: # leaf only
        if "attention.output.dense" in name:
            print(inspect.getdoc(module.forward))
            break
Define the computation performed at every call.

Should be overridden by all subclasses.

.. note::
    Although the recipe for forward pass needs to be defined within
    this function, one should call the :class:`Module` instance afterwards
    instead of this since the former takes care of running the
    registered hooks while the latter silently ignores them.
In [ ]:
import inspect

print(inspect.getsource(vit_model.vit.encoder.layer[0].attention.attention.forward))
    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and output_attentions:
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        context_layer, attention_probs = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            is_causal=self.is_causal,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.dropout_prob,
        )

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs

In [ ]:
vit_model.config._attn_implementation
Out[ ]:
'sdpa'
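Since the config reports 'sdpa', every attention-probability capture above relies on the fallback path shown in the source. A sketch of reloading the model with the eager implementation so the probabilities are always materialized (assumes a transformers version recent enough to accept the attn_implementation argument; vit_model_eager is a new name introduced here for illustration):

In [ ]:
vit_model_eager = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    attn_implementation="eager",   # attention probs computed explicitly, visible to hooks
    device_map=device,
    )
vit_model_eager.eval();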