lecture-jax

JAX NN row operations

Open · jstac opened this issue 2 months ago · 0 comments

In jax_nn we state the following:

We work with row vectors because Python numerical operations are row-major rather than column-major, so that row-based operations tend to be more efficient.

I'm not sure this makes sense. If we calculate x @ W, where x is a row vector, then entry j of the result is the sum over i of x[i] * W[i, j], so we are summing along the columns of W.
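To make the summation concrete, here is a small sketch in plain NumPy (not JAX), with arbitrary example shapes chosen for illustration. It shows two equivalent readings of x @ W: as dot products of x with the columns of W, and as a weighted sum of the rows of W. Which loop order an optimized matrix multiply actually uses is a backend implementation detail that this sketch does not settle.

```python
import numpy as np

# Arbitrary example shapes: x is a row vector (1, n), W is (n, m).
rng = np.random.default_rng(0)
n, m = 4, 3
x = rng.standard_normal((1, n))
W = rng.standard_normal((n, m))

y = x @ W  # shape (1, m)

# Reading 1: entry j is the dot product of x with column j of W,
# i.e. a sum taken down column j.
y_cols = np.array([[np.sum(x[0, :] * W[:, j]) for j in range(m)]])

# Reading 2: the same vector as a weighted sum of the rows of W,
# y = sum_i x[0, i] * W[i, :], which touches W one row at a time.
y_rows = np.zeros((1, m))
for i in range(n):
    y_rows[0, :] += x[0, i] * W[i, :]

print(np.allclose(y, y_cols), np.allclose(y, y_rows))  # True True
```

Both readings give the same numbers, so the mathematics alone does not determine whether W is traversed by rows or by columns.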

This needs to be investigated and possibly corrected.
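As a possible starting point for that investigation, the following sketch checks what "row-major" means for a NumPy array: NumPy allocates arrays in C (row-major) order by default, so a row slice is contiguous in memory while a column slice is strided. This only probes NumPy's default layout; whether and how it carries over to JAX/XLA on a given device is part of what would still need checking.

```python
import numpy as np

# NumPy arrays are C-ordered (row-major) by default.
W = np.arange(12, dtype=np.float64).reshape(4, 3)
print(W.flags["C_CONTIGUOUS"])  # True

# strides = bytes to step along each axis: (3 * 8, 8) for float64.
print(W.strides)  # (24, 8)

# A row slice is contiguous in memory; a column slice is not.
print(W[0, :].flags["C_CONTIGUOUS"])  # True
print(W[:, 0].flags["C_CONTIGUOUS"])  # False (24-byte stride)

# The transpose is a view with swapped strides, i.e. column-major order.
print(W.T.flags["F_CONTIGUOUS"])  # True
```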

jstac · Sep 19 '25 01:09