nested-transformer
Is the bulk of the convergence speedup from the ConvPool?!
Reading https://github.com/google-research/nested-transformer/blob/main/models/nest_net.py#L107 and https://github.com/google-research/nested-transformer/blob/main/libml/self_attention.py#L266, I see there's a 3x3 convolution at each change in resolution (so 2 total).
Did you get a chance to do any ablation studies on this?
For example, comparing the current ConvPool (Conv2d 3x3, LayerNorm, MaxPool 2x2) against a patch-embed-style convolution (Conv2d 2x2, stride=2), which wouldn't have as much "power" for modeling the image contents? A sketch of the two variants I mean is below.
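To make the comparison concrete, here's a minimal Flax sketch of the two downsampling variants (module names, feature sizes, and padding choices are my assumptions, not the repo's actual code):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class ConvPool(nn.Module):
    """Current downsampler: Conv 3x3 -> LayerNorm -> MaxPool 2x2 (hyperparams assumed)."""
    features: int

    @nn.compact
    def __call__(self, x):  # x: [N, H, W, C]
        x = nn.Conv(self.features, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.LayerNorm()(x)
        # 2x2 max pool, stride 2: halves the spatial resolution
        return nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))


class PatchEmbedDownsample(nn.Module):
    """Hypothetical weaker baseline: a single strided 2x2 conv, no norm or pooling."""
    features: int

    @nn.compact
    def __call__(self, x):  # x: [N, H, W, C]
        return nn.Conv(self.features, kernel_size=(2, 2), strides=(2, 2))(x)


# Both map e.g. (1, 56, 56, 96) -> (1, 28, 28, 192):
x = jnp.ones((1, 56, 56, 96))
y, _ = ConvPool(features=192).init_with_output(jax.random.PRNGKey(0), x)
print(y.shape)  # (1, 28, 28, 192)
```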
Alternatively, did you try ablating the early transformer blocks, so the model would be just PatchEmbed -> ConvPool -> ConvPool -> Transformers -> head? A rough sketch of that variant follows.
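Roughly this, reusing the ConvPool sketch above (the 4x4 patch size and feature dims are guesses on my part, not values from the repo):

```python
class ConvStemOnly(nn.Module):
    """Hypothetical ablation: all downsampling done convolutionally,
    with transformer blocks only at the final resolution."""
    dims: tuple = (96, 192, 384)

    @nn.compact
    def __call__(self, x):  # x: [N, H, W, 3]
        # PatchEmbed: non-overlapping 4x4 strided conv (patch size assumed)
        x = nn.Conv(self.dims[0], kernel_size=(4, 4), strides=(4, 4))(x)
        x = ConvPool(self.dims[1])(x)  # /2
        x = ConvPool(self.dims[2])(x)  # /2
        # ...followed by the final-stage transformer blocks and the classification head
        return x
```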