Natalia Gimelshein

Results: 90 comments by Natalia Gimelshein

The existing implementation (`F.layer_norm(x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)`) accidentally puts the rest of the model in channels last format, which is bad for fp32/older...
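For context, a minimal sketch of the permute-based layer-norm module being described (illustrative, not the exact upstream class; assumes a timm-style `LayerNorm2d`):

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dim of an NCHW tensor via a permute round-trip."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC so C is the last (normalized) dim
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # layer_norm produces a contiguous NHWC result, so the permute back yields an
        # NCHW-shaped tensor whose memory layout is channels-last -- which is how the
        # rest of the model ends up running in channels-last format.
        return x.permute(0, 3, 1, 2)
```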

cc @csarofeen for the regressions in backward; my understanding was that (at least for non-channels-last) aot is a win. @rwightman I agree with your points, and btw, if TorchScript is important...

Let me open a new issue to review all ts/aot problems in 1.12. Edit: https://github.com/rwightman/pytorch-image-models/discussions/1350

Thanks for the data, Ross! I'm really surprised that channels-last LN is so slow - after all, it's just calling existing layer norm kernels directly, and it shouldn't be using...

> Would the Apex LN provide any benefits here if combined with the needed permute to/from NHWC?

Just to be clear, for channels-last inputs permutation is a very fast op...
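To illustrate why the permute is nearly free for channels-last inputs, a small sketch (shapes are arbitrary):

```
import torch

# For a channels_last (NHWC) tensor, permuting to (N, H, W, C) is just a
# stride/metadata change -- no data is copied.
x = torch.randn(8, 64, 32, 32).to(memory_format=torch.channels_last)
y = x.permute(0, 2, 3, 1)            # view over the same storage
print(y.is_contiguous())             # True: the NHWC data is already laid out this way
print(y.data_ptr() == x.data_ptr())  # True: no copy was made
```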

Yes, layer_norm is force-cast to fp32 by amp (tbh, I don't know if it's strictly necessary or just out of an abundance of caution; I've heard some stories where it...
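A quick sketch of the cast being described, assuming a PyTorch build where `layer_norm` is on autocast's fp32 list (as the comment states):

```
import torch
import torch.nn.functional as F

x = torch.randn(8, 128, device="cuda", dtype=torch.float16)
weight = torch.randn(128, device="cuda", dtype=torch.float16)
bias = torch.randn(128, device="cuda", dtype=torch.float16)

with torch.autocast("cuda"):
    out = F.layer_norm(x, (128,), weight, bias)

# Inputs are upcast and the op runs in fp32, so the result comes back
# as float32 even though the inputs were fp16.
print(out.dtype)  # torch.float32
```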

Sorry, wouldn't `_cast_if_autocast_enabled` cast all inputs to fp32? I couldn't find where this function is defined.

It is kind of a problem for Triton, because currently it's set up to load the compiled code onto the current device on the first run, and this logic will need...

@voznesenskym is there a test forthcoming or should we land this?

Why does it have to be ugly? It could be as simple as
```
import torch

def fn(x):
    return x / 3

opt_fn = torch.compile(fn)

# run the compiled function on two different CUDA devices
x = torch.randn(4, device="cuda")
opt_fn(x)
x = torch.randn(4, device="cuda:1")
opt_fn(x)
torch.cuda.synchronize()
```