
Support for data format NHWC in convolutional layers

Open MathiesW opened this issue 1 year ago • 5 comments

Hello, first of all, thank you for your great work! I noticed that Equinox's convolutional layers use the NCHW data format, which, according to Nvidia's Tensor Layout documentation, is not optimal for Tensor Cores on modern GPUs. Changing the (default) data format from NCHW to NHWC may therefore benefit the performance of convolutional operations.

MathiesW avatar Jul 28 '23 14:07 MathiesW

I think this kind of optimisation should happen automatically at the XLA level. Outside of the jit'd region, store your input in NHWC. Then pass it into the jit'd region, transpose to NCHW, and call Equinox.

Under the hood, XLA will remove the transpose + NCHW kernel and just replace it with a call to an NHWC implementation.

(I believe.)
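
Concretely, I mean something along these lines -- an untested sketch with made-up shapes, but it should show the pattern:

```python
import jax
import jax.numpy as jnp
import equinox as eqx

conv = eqx.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, key=jax.random.PRNGKey(0))

@eqx.filter_jit
def forward(conv, x_nhwc):
    # Transpose NHWC -> NCHW only inside the jit'd region; XLA should be free
    # to fold this transpose into whichever convolution layout it picks.
    x_nchw = jnp.transpose(x_nhwc, (0, 3, 1, 2))
    return jax.vmap(conv)(x_nchw)  # Equinox layers are unbatched, hence the vmap

x_nhwc = jnp.zeros((16, 32, 32, 3))  # stored as NHWC outside the jit'd region
y = forward(conv, x_nhwc)            # shape (16, 8, 30, 30)
```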

patrick-kidger avatar Jul 28 '23 15:07 patrick-kidger

Hi @patrick-kidger, thanks for the quick response.

You may be right here, I have no idea what happens under the hood. I just noticed that the convolution operations in Haiku and Flax support this option and figured there may be a difference :-)
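
(For context: at the `jax.lax` level the memory layout is just a `dimension_numbers` argument, which I assume is what Flax and Haiku expose. An untested sketch with made-up shapes:)

```python
import jax
import jax.numpy as jnp

x_nhwc = jnp.zeros((16, 32, 32, 3))  # batch stored as NHWC
kernel = jnp.zeros((3, 3, 3, 8))     # HWIO: height, width, in channels, out channels

y = jax.lax.conv_general_dilated(
    x_nhwc,
    kernel,
    window_strides=(1, 1),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
print(y.shape)  # (16, 32, 32, 8)
```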

MathiesW avatar Aug 09 '23 06:08 MathiesW

> I think this kind of optimisation should happen automatically at the XLA level. Outside of the jit'd region, store your input in NHWC. Then pass it into the jit'd region, transpose to NCHW, and call Equinox.
>
> Under the hood, XLA will remove the transpose + NCHW kernel and just replace it with a call to an NHWC implementation.
>
> (I believe.)

~~Do you mind explaining this a bit further? I've implemented the VDVAE in both Flax and Equinox, and fp32 training is a bit faster in Flax, while float16 training is 40-50% faster in Flax. As far as I can tell, the only implementation difference is Equinox with NCHW and Flax with NHWC. I enjoy Equinox much more, so I'd like to regain as much of that gap as I can.~~

~~Edit: I suppose I could implement my own conv class, but the performance difference seems somewhat nontrivial, so changing the default layout might be worth considering.~~

~~Edit: I'd be happy to take a stab at making the layout an argument of the conv class.~~

Well, I went ahead and rewrote everything to use the "better" order and it didn't make a difference in throughput. The performance gap I saw seems to have come from the per-step overhead addressed by the flattening trick you mention here: https://docs.kidger.site/equinox/tricks/#low-overhead-training-loops. With that change, the gap between my Flax and Equinox implementations has closed.
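
For anyone finding this later, my understanding of that trick is roughly the following (an untested sketch with a placeholder model and optimiser): flatten the model and optimiser state into flat lists of leaves once, outside the loop, and unflatten inside the jit'd step, so each dispatch only has to handle a flat list instead of traversing the whole model pytree.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx
import optax

# Placeholder model/optimiser, just to make the sketch self-contained.
model = eqx.nn.MLP(in_size=2, out_size=1, width_size=32, depth=2, key=jax.random.PRNGKey(0))
optim = optax.adam(1e-3)
opt_state = optim.init(eqx.filter(model, eqx.is_array))

# Flatten once, outside the training loop.
flat_model, treedef_model = jtu.tree_flatten(model)
flat_opt_state, treedef_opt_state = jtu.tree_flatten(opt_state)

@eqx.filter_jit
def make_step(flat_model, flat_opt_state, x, y):
    # Unflatten inside the jit'd region: handling a flat list of leaves at
    # dispatch time is cheaper than traversing the full model pytree.
    model = jtu.tree_unflatten(treedef_model, flat_model)
    opt_state = jtu.tree_unflatten(treedef_opt_state, flat_opt_state)

    def loss_fn(m):
        pred = jax.vmap(m)(x)
        return jnp.mean((pred - y) ** 2)

    grads = eqx.filter_grad(loss_fn)(model)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return jtu.tree_leaves(model), jtu.tree_leaves(opt_state)

x = jnp.zeros((8, 2))
y = jnp.zeros((8, 1))
flat_model, flat_opt_state = make_step(flat_model, flat_opt_state, x, y)
```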

haydn-jones avatar Dec 02 '23 02:12 haydn-jones

Interesting! The overhead from Equinox flattening usually isn't very much -- is your model quite small? Anyway, I'm glad that you were able to resolve this.

patrick-kidger avatar Dec 02 '23 22:12 patrick-kidger

I was testing one with ~14 million parameters and was running training batches at ~20 it/s.

haydn-jones avatar Dec 03 '23 04:12 haydn-jones