
result = torch.eye(attentions[0].size(-1)) problem

Open rojinakashefi opened this issue 1 year ago • 7 comments

/vit-explain/vit_rollout.py", line 10, in rollout
    result = torch.eye(attentions[0].size(-1))
IndexError: list index out of range

rojinakashefi avatar Jul 15 '23 06:07 rojinakashefi

Same problem here. Did you solve it, or do you know how to solve it? Thank you in advance.

qiaoyu1002 avatar Aug 15 '23 09:08 qiaoyu1002

The issue happens because the vision transformer implementation in timm apparently changed.

Looking at https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py, lines 285-295:

        if self.fused_attn:
            x_attn = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x_attn = attn @ v

Since fused_attn defaults to True (as far as I can tell), the attn_drop() call is skipped in each forward pass. That's why no attentions are collected: attentions stays an empty list.

As a quick and dirty fix you can do something like:

for block in model.blocks:
    block.attn.fused_attn = False
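
For reference, here is a minimal end-to-end sketch of this workaround. The model name (vit_base_patch16_224) and the commented-out rollout call are just examples, not the exact setup from this repo:

import timm

# Load any timm ViT; recent timm versions enable fused attention by default.
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()

# Force the non-fused code path so attn_drop is called and forward hooks on it fire.
for block in model.blocks:
    block.attn.fused_attn = False

# The rollout code from this repo should now collect one attention map per block, e.g.:
# from vit_rollout import VITAttentionRollout
# attention_rollout = VITAttentionRollout(model, head_fusion="mean", discard_ratio=0.9)
# mask = attention_rollout(input_tensor)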

raphaelspiekermann avatar Aug 17 '23 09:08 raphaelspiekermann

for block in model.blocks:
    block.attn.fused_attn = False

How do you access the model blocks, or was that just example code?

AttributeError: 'ViTModel' object has no attribute 'blocks'

Song-z-h avatar Oct 01 '23 09:10 Song-z-h

A quick solution is to roll back to an older version with pip install timm==0.6.13

orientino avatar Nov 24 '23 10:11 orientino

I ran into the same problem. You can keep using the current version of timm, but you'll need to set an environment variable to disable fused attention. Check here:

https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/config.py#L31

# use torch.scaled_dot_product_attention where possible
_HAS_FUSED_ATTN = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if 'TIMM_FUSED_ATTN' in os.environ:
    _USE_FUSED_ATTN = int(os.environ['TIMM_FUSED_ATTN'])
else:
    _USE_FUSED_ATTN = 1  # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
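
Since _USE_FUSED_ATTN is read at import time, the variable has to be set before timm is imported. A minimal sketch (the model name is just an example):

import os

# Must be set before importing timm; the config module reads it at import time.
os.environ['TIMM_FUSED_ATTN'] = '0'  # 0 == off

import timm

model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()
# The attn_drop hooks used by this repo's rollout code should now fire on every forward pass.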

ajmeek avatar Apr 16 '24 00:04 ajmeek

@jacobgil @qiaoyu1002 @rojinakashefi, I think the "list index out of range" error occurs because the vision transformer code in the timm repo changed. The original code of this repo registers its hook on the attention dropout layer, but with fused attention that layer is never called in the forward pass, so the hook function never fires, the attentions list stays empty, and indexing it raises the error. This can be mitigated by registering the hook on the preceding qkv linear layer instead and then recomputing the attention matrix manually, up to the attention dropout, from the captured qkv output. Here is my attempt to solve the problem. I implemented the hook class as a context manager, so the hooks are removed when we leave the context and do not leak memory. The vision transformer code is adapted from the timm repo.

class VITAttentionRollout:
    def __init__(self, model, attention_layer_name='qkv', head_fusion="mean", discard_ratio=0.9):
        self.model = model
        self.head_fusion = head_fusion
        self.discard_ratio = discard_ratio
        self.attentions = []
        self.hooks = []
        # Register a forward hook on every qkv linear layer instead of attn_drop.
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                self.hooks.append(module.register_forward_hook(self.get_attention))

    def get_attention(self, module, input, output):
        # Store the raw qkv output; the attention matrix is recomputed afterwards.
        self.attentions.append(output.detach().clone())

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        # Remove the hooks when leaving the context so they do not leak memory.
        for handle in self.hooks:
            handle.remove()

B = x.unsqueeze(0).shape[0]  # batch size
C = vit_model.embed_dim  # embedding dimension
num_heads = vit_model.blocks[1].attn.num_heads
N = 577  # number of tokens (patches + class token)
print(B, N, C, num_heads)
scale = (C // num_heads) ** -0.5

After this, run the model with the hook class used as a context manager:

with VITAttentionRollout(vit_model, discard_ratio=0.9) as hook:
    with torch.no_grad():
        output = hook.model.eval()(x.unsqueeze(0))
    attn_list = []
    for attentions in hook.attentions:
        # Recompute the attention matrix from the captured qkv output.
        qkv = attentions.reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q = q * scale
        attn = (q @ k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        attn_list.append(attn)
        out = attn @ v
    print("Attention shape " + str(attn.shape))
    count = 0
    for name, module in vit_model.named_modules():
        if 'qkv' in name:
            count += 1
    print("\nTotal count is " + str(count) + "\n")  # equal to the number of blocks

mask = rollout(attn_list, hook.discard_ratio, hook.head_fusion)
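
For completeness, here is a minimal sketch of an attention-rollout step that is compatible with the attn_list built above (a list of [B, num_heads, N, N] maps). This is not the exact rollout() from this repo, just the general idea: fuse the heads, discard the weakest attention entries, add the identity for the residual connection, row-normalize, and multiply across layers:

import torch

def rollout_sketch(attentions, discard_ratio=0.9, head_fusion="mean"):
    # attentions: list of [B, num_heads, N, N] attention maps, one per block (B assumed to be 1).
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attention in attentions:
            # Fuse the heads into a single [B, N, N] map.
            if head_fusion == "mean":
                fused = attention.mean(dim=1)
            elif head_fusion == "max":
                fused = attention.max(dim=1)[0]
            else:
                fused = attention.min(dim=1)[0]
            # Zero out the lowest-attention entries, keeping the class token.
            flat = fused.view(fused.size(0), -1)
            _, idx = flat.topk(int(flat.size(-1) * discard_ratio), dim=-1, largest=False)
            idx = idx[idx != 0]
            flat[0, idx] = 0
            # Add the identity for the residual connection, then re-normalize the rows.
            a = (fused + torch.eye(fused.size(-1))) / 2
            a = a / a.sum(dim=-1, keepdim=True)
            result = a @ result
    # Attention flowing from the class token to each patch token.
    return result[0, 0, 1:]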

vivekh2000 avatar May 16 '24 13:05 vivekh2000

Quoting the quick fix from above:

for block in model.blocks:
    block.attn.fused_attn = False

How can I do that when using PyTorch's (torchvision) vision transformer? I get the same error.

import torch
import torchvision

def load_model(weight_path, device):
    model = torchvision.models.vit_l_16()
    # Replace the classification head with a 2-class linear layer.
    model.heads = torch.nn.Linear(
        in_features=model.heads.head.in_features, out_features=2)
    model.load_state_dict(torch.load(
        weight_path, map_location=torch.device(device)))
    model.eval()

    return model

model = load_model(weight_path=path, device=device)
img = preprocess_image(image_path)
attention_rollout = VITAttentionRollout(model, head_fusion="max", discard_ratio=0.9)
mask = attention_rollout(img)

Error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[48], line 2
      1 attention_rollout = VITAttentionRollout(model, head_fusion="max", discard_ratio=0.9)
----> 2 mask = attention_rollout(img)

Cell In[46], line 130
    127 with torch.no_grad():
    128     output = self.model(input_tensor)
--> 130 return rollout(self.attentions, self.discard_ratio, self.head_fusion)

Cell In[46], line 75
     74 def rollout(attentions, discard_ratio, head_fusion):
---> 75     result = torch.eye(attentions[0].size(-1))
     76     with torch.no_grad():
     77         for attention in attentions:

IndexError: list index out of range

allail-qadrillah avatar Jul 05 '24 07:07 allail-qadrillah