1day_1paper
1day_1paper copied to clipboard
[74] Not All Patches are What You Need: Expediting Vision Transformers via Token Reorganizations (EVIT)
attention 좀 더 효율적으로 수행하자. (필요한 것만 쓰자!)
Token Reorganization
image token 들을 identify (background or object) 하고, fusing 하는 방법.
Attentive Token Identification
n 을 ViT 의 encoder 의 input token 개수라 하자. [CLS] token 과 나머지 token 간의 관계는 일반적으로 사용하는 attention 에서 값들을 가져올 수 있다. 관계가 많이 연결되는 애들이 중요한 애들 아닐까? 하는 motivation!
일반적으로 [CLS] token 구하는 식 한 번만 더 보고 가자.
x_class == [CLS] token
a == attention vector
attentive 를 구하기 위해서, attn = mean(attn)
을 수행해 준다. (attention head 는 12 개니까 평균을 내준다.)
이 값을 갖고, top-k 개를 attentive 로 둔다.
이것 만으로는 부족하다.
DeiT-S 에서 (4, 7, 10) layer 에서 inattentive token 들을 지워나가니, acc 가 확확 떨어지더라.
그래서 혼합하는 방법을 생각해 냈다.
InAttentive Token Fusion
그냥, inattentive 한 애들은 weighted average 를 해서 다음 layer 로 넘겨준다.
즉, block 지날 때마다 patch 가 줄어드는 거다.
code
from https://github.com/youweiliang/evit/blob/0999f090edbcb6dea095546b5faeb2750beaf88b/vision_transformer.py#L307-L314
cls_attn = attn[:, :, 0, 1:] # [B, H, N-1]
cls_attn = cls_attn.mean(dim=1) # [B, N-1]
_, idx = torch.topk(cls_attn, left_tokens, dim=1, largest=True, sorted=True) # [B, left_tokens]
# cls_idx = torch.zeros(B, 1, dtype=idx.dtype, device=idx.device)
# index = torch.cat([cls_idx, idx + 1], dim=1)
index = idx.unsqueeze(-1).expand(-1, -1, C) # [B, left_tokens, C]
return x, index, idx, cls_attn, left_tokens
from https://github.com/youweiliang/evit/blob/0999f090edbcb6dea095546b5faeb2750beaf88b/vision_transformer.py#L350-L358
if self.fuse_token:
compl = complement_idx(idx, N - 1) # [B, N-1-left_tokens]
non_topk = torch.gather(non_cls, dim=1, index=compl.unsqueeze(-1).expand(-1, -1, C)) # [B, N-1-left_tokens, C]
non_topk_attn = torch.gather(cls_attn, dim=1, index=compl) # [B, N-1-left_tokens]
extra_token = torch.sum(non_topk * non_topk_attn.unsqueeze(-1), dim=1, keepdim=True) # [B, 1, C]
x = torch.cat([x[:, 0:1], x_others, extra_token], dim=1)
else:
x = torch.cat([x[:, 0:1], x_others], dim=1)
필자 의견
- global 하게 보고, top-k 를 뽑아내는 게 성능이 더 좋지 않을까 생각해 본다.
- pruning 도 global 하게 pruning 하는게 잘 되지 않았는가. (structured 든, unstructured 든)
- hierarchical transformer 에 대한 성능이 어떨 지 궁금하다. (swin 등)
Result
visualize
inattentive token 들을 visualize 하면 다음과 같다.
ImageNet
모델 별 성능.
DeIT-S 에 inattentive fusion 하냐, 안하냐 에 따른 차이
pretrained DeiT-S 를 oracle 로 두어서 실험해 봄. 일종의 distillation 처럼 생각할 수 있음.
DeiT-S 가 무슨 token 이 중요한 지만 뽑아서 알려주는 것임.
Dynamic ViT 와의 성능 비교.
pretrained ==> model initialize 를 pretrained 로 했다는 뜻