jaxga

Memory layout question

Open EelcoHoogendoorn opened this issue 2 years ago • 18 comments

Thanks for making this library; I wanted to do something similar a few months ago but other things got in the way. Awesome to see that the JAX ecosystem is maturing this fast!

One design question I had was whether to go with a struct-of-arrays or an array-of-structs memory layout. Unless I've misread, your design is the latter; that is, if I vmap over a multivector, the components of the multivector will form the last axis, which in JAX is the contiguous stride==1 axis. If you think about how this would get vectorized on a GPU, that wouldn't be ideal: if every thread in a block gets to work on a single element of the vmapped axis, which is the most straightforward parallelization, then the threads in a warp are not performing contiguous memory accesses. Hence struct-of-arrays layouts are generally preferred on the GPU. Also, if you dig into the deepmind/alphafold repo, you will see that they also use a struct-of-arrays layout for their vector types and the like.
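To make the two layouts concrete, here is a minimal sketch in JAX. It is not taken from jaxga: the function `norm_sq`, the sizes `N_BATCH` and `N_COMP`, and the array names are all hypothetical, and it assumes the default row-major layout, which XLA is free to change on a given backend. It only shows that the same per-multivector function can be vmapped over either layout by choosing `in_axes`.

```python
import jax
import jax.numpy as jnp

N_BATCH = 8  # hypothetical number of multivectors in the batch
N_COMP = 4   # hypothetical number of components per multivector


def norm_sq(mv):
    """Toy per-multivector function; mv has shape (N_COMP,)."""
    return jnp.sum(mv * mv)


# Array-of-structs: components form the last axis, shape (N_BATCH, N_COMP).
aos = jnp.arange(N_BATCH * N_COMP, dtype=jnp.float32).reshape(N_BATCH, N_COMP)

# Struct-of-arrays: the batch forms the last axis, shape (N_COMP, N_BATCH).
# Assumption: with a row-major layout this makes the batch axis the
# contiguous one, so neighbouring batch elements sit next to each other.
soa = aos.T

# Same function, mapped over the batch axis of each layout.
out_aos = jax.vmap(norm_sq, in_axes=0)(aos)  # batch is axis 0
out_soa = jax.vmap(norm_sq, in_axes=1)(soa)  # batch is axis 1

print(out_aos.shape, out_soa.shape)
```

Both calls compute the same values; only the axis that carries the batch differs. Whether the layout change actually affects performance depends on what XLA does with it after compilation, so it would need to be measured rather than assumed.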

Now this is all terribly premature optimization as far as the actual goals I'm trying to achieve are concerned; but I guess I'm trying to form a bit of a deeper understanding of JAX and TPUs on a low level. So with that in mind: was this a deliberate choice, or something that you have given any thought to?

EelcoHoogendoorn Dec 27 '21 18:12