Ross Wightman
@collinmccarthy you are correct on all counts, I didn't explicitly support this when I added forward_intermediates() as I was focused on getting it working / integrated and then didn't revisit....
@dzhulgakov Can we bump this in priority? More and more networks would benefit from this. The lack of a native impl appears to be governing some design decisions in cnn-transformer...
> I am not sure why you tagged me here, but why won't you just use [InstanceNorm](https://pytorch.org/docs/stable/generated/torch.nn.InstanceNorm2d.html)?

Sorry, I was a bit quick on the @ autocomplete there :/ InstanceNorm...
@vadimkantorov yeah, I'd really want to have channels_last support, it would be much less useful without.
@ngimel as in this? `F.layer_norm(x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)`? I've spent some time running through iterations of this, and while the above is...
So, back to this thread and the original ask: LayerNorm w/ an arbitrary axis. @ngimel demo'd some hacks that can be used with current PyTorch codegen to get some better performance...
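For reference, the permute-based hack discussed above can be wrapped in a small module. This is just a sketch of that pattern (the `LayerNorm2d` name here is illustrative, not a claim about any particular library's API): normalize over the channel dim of an NCHW tensor by permuting to NHWC, applying `F.layer_norm`, and permuting back.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dim of NCHW tensors via the permute hack (sketch)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC, channels become the last dim
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x.permute(0, 3, 1, 2)  # NHWC -> NCHW


x = torch.randn(2, 8, 4, 4)
ln = LayerNorm2d(8)
y = ln(x)
print(y.shape)  # torch.Size([2, 8, 4, 4])
```

Whether this is fast enough depends heavily on whether the permutes are fused / elided by the backend, which is exactly the performance question being debated in this thread.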
@brianhou0208 thanks for the work, it looks like you did a good job getting it in shape. I took a closer look using your code, but I have some doubts about this...
```python
def single_attn(self, x: torch.Tensor) -> torch.Tensor:
    k, q, v = torch.split(self.kqv(x), self.emb, dim=-1)
    if not torch.jit.is_scripting():
        with torch.autocast(device_type=v.device.type, enabled=False):
            y = self._attn_impl(k, q, v)
    else:
        y = self._attn_impl(k, q, ...
```
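The core pattern in the snippet above is forcing the attention math to run in fp32 even when the surrounding model is under autocast. A minimal self-contained sketch of that pattern (the `attn_fp32` helper and its shapes are assumptions for illustration, not the actual `_attn_impl`):

```python
import torch


def attn_fp32(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Disable autocast so the matmuls / softmax below run in full fp32,
    # which avoids the numerical issues low-precision attention can hit.
    with torch.autocast(device_type=q.device.type, enabled=False):
        q, k, v = q.float(), k.float(), v.float()
        scale = q.shape[-1] ** -0.5  # standard 1/sqrt(head_dim) scaling
        attn = (q @ k.transpose(-2, -1)) * scale
        attn = attn.softmax(dim=-1)
        return attn @ v


# (B, heads, seq, head_dim) layout assumed here
q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 16, 32)
v = torch.randn(2, 4, 16, 32)
out = attn_fp32(q, k, v)
print(out.shape)  # torch.Size([2, 4, 16, 32])
```

The `torch.jit.is_scripting()` guard in the original exists because `torch.autocast` context managers aren't usable under TorchScript, hence the separate else branch.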
@brianhou0208 I don't know if not having the input conv is a 'feature'; my very first vit impl here, before the official JAX code was released that used the Conv2D...
@brianhou0208 known issue, along with #1597 ... I don't really have a good way to handle the channel count changes in these pruned models; I feel it'd be quite a bit...