dfdx
Mean, Sum across arbitrary dimensions (especially batch dimension)
This might depend on Reshape, but it would be good to compute the mean/sum across arbitrary dimensions by letting mean/sum take a dim or axis argument, just like PyTorch does.
This should also help with the batch_norm issue https://github.com/coreylowman/dfdx/issues/78, because there is currently no API or easy way to compute the mean across the batch dimension.
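A minimal sketch of what a PyTorch-style dim/axis argument computes, assuming a plain row-major buffer with a runtime shape rather than dfdx's compile-time tensor types. sum_along and mean_along are hypothetical helpers, not part of dfdx. Taking the mean over axis 0 of a [batch, features] tensor gives the per-feature batch mean that batch_norm needs.

```rust
/// Hypothetical helper: sum a row-major tensor (flat `data`, dimensions `shape`)
/// along `axis`. Returns the reduced data and the shape with `axis` removed.
fn sum_along(data: &[f32], shape: &[usize], axis: usize) -> (Vec<f32>, Vec<usize>) {
    assert!(axis < shape.len(), "axis out of range");
    assert_eq!(data.len(), shape.iter().product::<usize>());
    // View the shape as [outer, len, inner]; empty products are 1.
    let outer: usize = shape[..axis].iter().product();
    let len = shape[axis];
    let inner: usize = shape[axis + 1..].iter().product();
    let mut out = vec![0.0f32; outer * inner];
    for o in 0..outer {
        for k in 0..len {
            for i in 0..inner {
                // Row-major offset of element (o, k, i).
                out[o * inner + i] += data[(o * len + k) * inner + i];
            }
        }
    }
    let mut out_shape = shape.to_vec();
    out_shape.remove(axis);
    (out, out_shape)
}

/// Hypothetical helper: mean along `axis` = sum along `axis` / size of `axis`.
fn mean_along(data: &[f32], shape: &[usize], axis: usize) -> (Vec<f32>, Vec<usize>) {
    let (mut out, out_shape) = sum_along(data, shape, axis);
    let n = shape[axis] as f32;
    for v in out.iter_mut() {
        *v /= n;
    }
    (out, out_shape)
}

fn main() {
    // A batch of 2 samples with 3 features each, i.e. shape [2, 3].
    let x: [f32; 6] = [1.0, 2.0, 3.0, 5.0, 6.0, 7.0];
    // Mean across the batch dimension (axis 0): one value per feature.
    let (mean, shape) = mean_along(&x, &[2, 3], 0);
    println!("{:?} {:?}", mean, shape); // [3.0, 4.0, 5.0] [3]
}
```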
In fact, torch goes one step further and allows reduction across multiple dimensions simultaneously.
This is the PR where they first added the feature, only for sum: https://github.com/pytorch/pytorch/pull/6152. The other reductions were added in later PRs. I'm linking it for reference in case I decide to work on this, but we need to flesh out the API for handling this properly first.
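A sketch of reducing several dimensions at once, assuming a plain row-major Vec<f32> with a runtime shape (dfdx tensors instead carry their shape in the type). sum_axes is a hypothetical helper, not PyTorch's or dfdx's implementation: it decodes each flat index into coordinates and accumulates into the output position formed by the axes that are kept.

```rust
/// Hypothetical helper: sum a row-major tensor over every axis in `axes` at once.
/// Reduced axes are dropped from the output shape.
fn sum_axes(data: &[f32], shape: &[usize], axes: &[usize]) -> (Vec<f32>, Vec<usize>) {
    assert_eq!(data.len(), shape.iter().product::<usize>());
    assert!(axes.iter().all(|&a| a < shape.len()), "axis out of range");
    // Axes that survive the reduction, in their original order.
    let keep: Vec<usize> = (0..shape.len()).filter(|d| !axes.contains(d)).collect();
    let out_shape: Vec<usize> = keep.iter().map(|&d| shape[d]).collect();
    let mut out = vec![0.0f32; out_shape.iter().product::<usize>()];
    for (flat, &v) in data.iter().enumerate() {
        // Decode the flat index into coordinates (last axis varies fastest).
        let mut rem = flat;
        let mut coord = vec![0usize; shape.len()];
        for d in (0..shape.len()).rev() {
            coord[d] = rem % shape[d];
            rem /= shape[d];
        }
        // Re-encode a row-major index using only the kept axes.
        let mut o = 0;
        for &d in &keep {
            o = o * shape[d] + coord[d];
        }
        out[o] += v;
    }
    (out, out_shape)
}

fn main() {
    // Values 0..12 laid out as a [2, 2, 3] tensor; reduce axes 0 and 2 together.
    let data: Vec<f32> = (0..12).map(|x| x as f32).collect();
    let (out, shape) = sum_axes(&data, &[2, 2, 3], &[0, 2]);
    println!("{:?} {:?}", out, shape); // [24.0, 42.0] [2]
}
```

Reducing every axis leaves an empty shape and a single element holding the total.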
Now that we have reshape implemented, what is the right way to tackle this?
I don't know if it is possible to write a generic trait like ReduceOverDims, similar to ReduceLastDim. Depending on the shape of the tensor, different dimensions are valid for reduction, so a single associated constant like LAST_DIM is not going to work.
Yeah, I'm actually working on this now! At first it will support reducing one dimension at a time, e.g.:
```rust
let t: Tensor4D<A, B, C, D> = ...;
// Each call removes the axis given as the const generic argument.
let _: Tensor3D<B, C, D> = t.sum_axis::<0>();
let _: Tensor3D<A, C, D> = t.sum_axis::<1>();
let _: Tensor3D<A, B, D> = t.sum_axis::<2>();
let _: Tensor3D<A, B, C> = t.sum_axis::<3>();
```
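One way the sum_axis::<I> calls above could be typed, sketched with plain nested arrays [[f32; N]; M] standing in for dfdx tensors (an assumption; SumAxis and SumAxisExt are hypothetical names, not dfdx's actual traits). The axis becomes a const generic parameter on the trait, so each (shape, axis) pair gets its own impl and its own Output type. That sidesteps the problem with a single associated constant like LAST_DIM: an axis with no impl simply fails to compile.

```rust
/// Hypothetical trait: "sum over axis I". Each (shape, axis) pair has its own
/// impl and Output type, instead of one associated constant like LAST_DIM.
trait SumAxis<const I: usize> {
    type Output;
    fn reduce_sum(&self) -> Self::Output;
}

// Axis 0 of an M x N matrix: add the rows together, leaving N values.
impl<const M: usize, const N: usize> SumAxis<0> for [[f32; N]; M] {
    type Output = [f32; N];
    fn reduce_sum(&self) -> [f32; N] {
        let mut out = [0.0; N];
        for row in self {
            for (o, x) in out.iter_mut().zip(row) {
                *o += x;
            }
        }
        out
    }
}

// Axis 1 of an M x N matrix: sum within each row, leaving M values.
impl<const M: usize, const N: usize> SumAxis<1> for [[f32; N]; M] {
    type Output = [f32; M];
    fn reduce_sum(&self) -> [f32; M] {
        let mut out = [0.0; M];
        for (o, row) in out.iter_mut().zip(self) {
            *o = row.iter().sum();
        }
        out
    }
}

/// Extension trait so the axis can be passed as `x.sum_axis::<I>()`.
trait SumAxisExt {
    fn sum_axis<const I: usize>(&self) -> <Self as SumAxis<I>>::Output
    where
        Self: SumAxis<I>;
}

impl<T> SumAxisExt for T {
    fn sum_axis<const I: usize>(&self) -> <Self as SumAxis<I>>::Output
    where
        Self: SumAxis<I>,
    {
        <Self as SumAxis<I>>::reduce_sum(self)
    }
}

fn main() {
    let t: [[f32; 3]; 2] = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
    let cols: [f32; 3] = t.sum_axis::<0>();
    let rows: [f32; 2] = t.sum_axis::<1>();
    println!("{:?} {:?}", cols, rows); // [5.0, 7.0, 9.0] [6.0, 15.0]
}
```

A 4D tensor would need four such impls, one per axis; a real library would likely generate them with a macro.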
Nice, looking forward to it. Reducing across multiple dimensions at once is not that big of a deal and is just a nice-to-have, since you could always apply reduce_over_dim multiple times. That just comes with the cognitive load of recalculating axis positions after each reduction, because removing an axis shifts every later axis down by one.
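To illustrate that bookkeeping, here is a small sketch (chained_axis_indices is a hypothetical helper) that converts the original axes you want to reduce into the index to pass to each successive single-axis reduction. Under the sum_axis::<I> semantics shown earlier, reducing original axes 1 and 3 of a Tensor4D<A, B, C, D> would mean reducing axis 1 and then axis 2, because D moves from axis 3 to axis 2 once B is gone.

```rust
/// Hypothetical helper: map original-tensor axes (in the order they will be
/// reduced) to the axis index each successive single-axis reduction needs.
/// Every reduction removes one axis, shifting all axes above it down by one.
fn chained_axis_indices(axes: &[usize]) -> Vec<usize> {
    let mut done: Vec<usize> = Vec::new();
    axes.iter()
        .map(|&a| {
            assert!(!done.contains(&a), "axis reduced twice");
            // Count already-removed axes that sat below this one.
            let shift = done.iter().filter(|&&d| d < a).count();
            done.push(a);
            a - shift
        })
        .collect()
}

fn main() {
    // Reduce original axes 1 then 3 of a 4D tensor.
    println!("{:?}", chained_axis_indices(&[1, 3])); // [1, 2]
}
```

Reducing from the highest axis down avoids the shifting entirely, since removing an axis never changes the index of the axes before it.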