Natalia Gimelshein
The existing implementation (`F.layer_norm(x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)`) accidentally puts the rest of the model in channels last format, which is bad for fp32/older...
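A minimal sketch of the permute-based pattern quoted above (a hypothetical `LayerNorm2d` wrapper, not the exact module from the repo), showing how the trailing `.permute(0, 3, 1, 2)` hands the rest of the model a channels-last tensor:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical LayerNorm2d sketch of the permute-based pattern quoted above.
class LayerNorm2d(nn.LayerNorm):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NCHW -> NHWC so F.layer_norm normalizes over the channel dim,
        # then NHWC -> NCHW. The final permute is a view, so the output
        # keeps channels-last strides instead of contiguous NCHW.
        return F.layer_norm(
            x.permute(0, 2, 3, 1), self.normalized_shape,
            self.weight, self.bias, self.eps,
        ).permute(0, 3, 1, 2)

ln = LayerNorm2d(8)
y = ln(torch.randn(2, 8, 4, 4))
print(y.shape)
print(y.is_contiguous(memory_format=torch.channels_last))
```

The output has channels-last strides even though the input was contiguous NCHW, which is exactly how the layout "leaks" into downstream ops.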
cc @csarofeen for regressions in backward; my understanding was that (at least for non-channels-last) AOT is a win. @rwightman I agree with your points, and btw if TorchScript is important...
Let me open a new issue to review all ts/aot problems in 1.12 Edit: https://github.com/rwightman/pytorch-image-models/discussions/1350
Thanks for the data, Ross! I'm really surprised that channels-last LN is so slow - after all, it's just calling existing layer norm kernels directly, and it shouldn't be using...
>Would the Apex LN provide any benefits here if combined with the needed permute to/from NHWC? Just to be clear, for channels-last inputs permutation is a very fast op...
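To illustrate the point about permutes on channels-last inputs (a sketch using nothing beyond stock PyTorch): for a channels-last tensor, the NCHW-to-NHWC permute is a stride-only view with no data movement, and the resulting view is already contiguous, so an NHWC kernel can consume it directly.

```python
import torch

# A channels-last NCHW tensor already stores data in NHWC order.
x = torch.randn(2, 8, 4, 4).to(memory_format=torch.channels_last)

nhwc = x.permute(0, 2, 3, 1)  # NCHW -> NHWC: a view, just swaps strides

print(nhwc.is_contiguous())             # the NHWC view is plain-contiguous
print(nhwc.data_ptr() == x.data_ptr())  # same storage, no copy happened
```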
Yes, layer_norm is force-cast to fp32 by amp (tbh, I don't know if it's strictly necessary or if it's done out of an abundance of caution; I've heard some stories where it...
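Autocast's per-op cast policies can be observed directly. This sketch uses CPU autocast with bfloat16 so it runs without a GPU; on CUDA, amp's fp32 cast list is what forces `layer_norm` up to fp32, as noted above (the exact policy for `layer_norm` under CPU autocast is not asserted here).

```python
import torch
import torch.nn.functional as F

# Sketch of autocast's per-op dtype policies, using CPU autocast so no GPU
# is needed. matmul is on the lower-precision list, so h comes out bfloat16.
# (On CUDA, amp additionally force-casts layer_norm inputs to fp32.)
x = torch.randn(8, 16)
w = torch.randn(16, 16)
with torch.autocast("cpu", dtype=torch.bfloat16):
    h = x @ w
    y = F.layer_norm(h, (16,))

print(h.dtype)  # torch.bfloat16
print(y.dtype)  # whatever the layer_norm cast policy picks on this backend
```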
Sorry, wouldn't `_cast_if_autocast_enabled` cast all inputs to fp32? I couldn't find this function.
It is kind of a problem for triton, because currently it's set up to load compiled code to the current device on the first run, and this logic will need...
@voznesenskym is there a test forthcoming or should we land this?
Why does it have to be ugly? It could be as simple as

```python
def fn(x):
    return x / 3

opt_fn = torch.compile(fn)
x = torch.randn(4, device="cuda")
opt_fn(x)
x = torch.randn(4, device="cuda:1")
opt_fn(x)
torch.cuda.synchronize()
```