vit-explain
result = torch.eye(attentions[0].size(-1)) problem
/vit-explain/vit_rollout.py", line 10, in rollout
    result = torch.eye(attentions[0].size(-1))
IndexError: list index out of range
Same problem here. Did you solve it? Do you know how to fix it? Thank you in advance.
The issue happens because the vision_transformer implementation in timm apparently changed.
Looking at https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py at lines 285-295:
if self.fused_attn:
    x_attn = F.scaled_dot_product_attention(
        q, k, v,
        dropout_p=self.attn_drop.p,
    )
else:
    q = q * self.scale
    attn = q @ k.transpose(-2, -1)
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)
    x_attn = attn @ v
Since fused_attn defaults to True (as far as I can tell), the attn_drop() call is skipped in each forward pass. That's why no attentions are collected: attentions stays an empty list.
As a quick and dirty fix you can do something like:
for block in model.blocks:
    block.attn.fused_attn = False
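For example, with a timm ViT (a minimal sketch; the model name is illustrative, and VITAttentionRollout is this repo's class from vit_rollout.py):

import timm
import torch
from vit_rollout import VITAttentionRollout

model = timm.create_model('vit_base_patch16_224', pretrained=True).eval()  # illustrative model

# Fall back to the non-fused attention path so attn_drop runs
# and the rollout hook can fire.
for block in model.blocks:
    block.attn.fused_attn = False

attention_rollout = VITAttentionRollout(model, head_fusion='mean', discard_ratio=0.9)
input_tensor = torch.randn(1, 3, 224, 224)  # placeholder for a preprocessed image batch
mask = attention_rollout(input_tensor)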
How do you access the model blocks, or was that just example code?
AttributeError: 'ViTModel' object has no attribute 'blocks'
A quick solution is to roll back the version with pip install timm==0.6.13
I ran into the same problem. You can keep using the current version of timm, but you'll need to set an environment variable to prevent the fused attention. Check here:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/config.py#L31
# use torch.scaled_dot_product_attention where possible
_HAS_FUSED_ATTN = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if 'TIMM_FUSED_ATTN' in os.environ:
    _USE_FUSED_ATTN = int(os.environ['TIMM_FUSED_ATTN'])
else:
    _USE_FUSED_ATTN = 1  # 0 == off, 1 == on (for tested use), 2 == on (for experimental use)
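For example (my own sketch), set the variable before timm is imported, since the flag is read at import time:

import os
os.environ['TIMM_FUSED_ATTN'] = '0'  # 0 == off; must be set before importing timm

import timm

# Illustrative model; every timm ViT created from here on uses the
# explicit (non-fused) attention path, so the attn_drop hook fires again.
model = timm.create_model('vit_base_patch16_224', pretrained=True).eval()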
@jacobgil @qiaoyu1002 @rojinakashefi, I think the "list index out of range" error occurs because the vision transformer code in the timm repo changed. The original code of this repo registers its hook on the attention dropout layer, but with fused attention that layer is never called, so the hook function never fires, the attentions list stays empty, and indexing it raises the error.

The error can be mitigated by registering the hook on the earlier qkv linear layer instead, which always runs. After that, we have to manually recompute the attention from the hooked qkv output, up to the point where the attention dropout would be applied. Here is my attempt to solve the problem. I have implemented the hook class as a context manager, so that when we leave the with block the hooks are removed, as they might otherwise leak memory. The vision transformer code is adapted from the timm repo.
class VITAttentionRollout:
    def __init__(self, model, attention_layer_name='qkv', head_fusion="mean", discard_ratio=0.9):
        self.model = model
        self.head_fusion = head_fusion
        self.discard_ratio = discard_ratio
        self.attentions = []
        self.hooks = []
        # Hook the qkv linear layers; unlike attn_drop, these always run.
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                self.hooks.append(module.register_forward_hook(self.get_attention))

    def get_attention(self, module, input, output):
        self.attentions.append(output.detach().clone())

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Remove the hooks on exit so they do not leak memory.
        for handle in self.hooks:
            handle.remove()
B = x.unsqueeze(0).shape[0]                     # batch size
C = vit_model.embed_dim                         # embedding dimension
num_heads = vit_model.blocks[1].attn.num_heads  # heads per attention layer
N = 577                                         # number of tokens (patches + CLS)
print(B, N, C, num_heads)
scale = (C // num_heads) ** -0.5                # same scaling as timm's Attention
After this code, call the hook class in a context manager.
with VITAttentionRollout(vit_model, discard_ratio=0.9) as hook:
    with torch.no_grad():
        output = hook.model.eval()(x.unsqueeze(0))

# Recompute each layer's attention map from the hooked qkv outputs,
# mirroring timm's Attention.forward up to the softmax.
attn_list = []
for attentions in hook.attentions:
    qkv = attentions
    qkv = qkv.reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)
    q = q * scale
    attn = (q @ k.transpose(-2, -1))
    attn = attn.softmax(dim=-1)
    attn_list.append(attn)
    out = attn @ v
print("Attention shape " + str(attn.shape))
# Sanity check: the number of hooked qkv layers equals the number of blocks.
count = 0
for name, module in vit_model.named_modules():
    if 'qkv' in name:
        count += 1
print("\nTotal count is " + str(count) + "\n")  # equal to the number of layers
print("Attention shape " + str(attn.shape) + "\n")

mask = rollout(attn_list, hook.discard_ratio, hook.head_fusion)
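If you then want to overlay the mask on the input image, a possible sketch (the cv2 calls are my own, not from this thread; I assume mask is a 2-D float array in [0, 1] and img is the original RGB image as a float array in [0, 1]):

import cv2
import numpy as np

# Upsample the patch-grid mask to the image resolution.
mask_resized = cv2.resize(mask, (img.shape[1], img.shape[0]))

# Colorize and blend: 50% heatmap, 50% original image.
heatmap = cv2.applyColorMap(np.uint8(255 * mask_resized), cv2.COLORMAP_JET)
overlay = np.uint8(0.5 * heatmap + 0.5 * 255 * img)
cv2.imwrite('attention_rollout.png', overlay)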
How can I apply the fused_attn fix when using torchvision's vision transformer? I got the same error. My code:
def load_model(weight_path, device):
    model = torchvision.models.vit_l_16()
    model.heads = torch.nn.Linear(
        in_features=model.heads.head.in_features, out_features=2)
    model.load_state_dict(torch.load(
        weight_path, map_location=torch.device(device)))
    model.eval()
    return model

model = load_model(weight_path=path, device=device)
img = preprocess_image(image_path)
attention_rollout = VITAttentionRollout(model, head_fusion="max", discard_ratio=0.9)
mask = attention_rollout(img)
Error:
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[48], line 2
      1 attention_rollout = VITAttentionRollout(model, head_fusion="max", discard_ratio=0.9)
----> 2 mask = attention_rollout(img)

Cell In[46], line 130
    127 with torch.no_grad():
    128     output = self.model(input_tensor)
--> 130 return rollout(self.attentions, self.discard_ratio, self.head_fusion)

Cell In[46], line 75
     74 def rollout(attentions, discard_ratio, head_fusion):
---> 75     result = torch.eye(attentions[0].size(-1))
     76     with torch.no_grad():
     77         for attention in attentions:

IndexError: list index out of range
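torchvision's ViT has no fused_attn flag; its EncoderBlock calls nn.MultiheadAttention with need_weights=False, so attention maps are never returned and a hook on the attention module sees None. A possible workaround (my own sketch, not a confirmed fix from this thread) is to wrap each attention module's forward to force need_weights=True and collect the per-head weights with a forward hook:

import torch
import torchvision

model = torchvision.models.vit_l_16().eval()
attentions = []

def force_weights(orig_forward):
    # torchvision's EncoderBlock passes need_weights=False explicitly,
    # so we override it at call time to get the attention maps back.
    def wrapper(*args, **kwargs):
        kwargs['need_weights'] = True
        kwargs['average_attn_weights'] = False  # keep one map per head
        return orig_forward(*args, **kwargs)
    return wrapper

def save_attention(module, inputs, output):
    # nn.MultiheadAttention returns (attn_output, attn_weights)
    attentions.append(output[1].detach())

for block in model.encoder.layers:
    block.self_attention.forward = force_weights(block.self_attention.forward)
    block.self_attention.register_forward_hook(save_attention)

with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))

# Each entry has shape (B, num_heads, N, N) and can be passed to rollout().
print(len(attentions), attentions[0].shape)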