
Mean, Sum across arbitrary dimensions (especially batch dimension)

Open • vikigenius opened this issue 2 years ago • 4 comments

This might depend on Reshape, but it would be good to be able to compute mean/sum across arbitrary dimensions by letting mean/sum take a dim or axis argument, as PyTorch does.
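To make the proposal concrete, here is a minimal sketch of what a PyTorch-style axis argument does. It assumes a contiguous row-major buffer with a runtime shape (not dfdx's compile-time tensor types), and the names sum_axis and mean_axis are hypothetical, not existing dfdx API. The tensor is viewed as [outer, n, inner] around the reduced axis, and that axis is dropped from the output shape.

```rust
// Hypothetical helper (not dfdx API): sum a row-major tensor along one axis.
// `shape` describes `data`; the reduced axis is removed from the output shape.
fn sum_axis(data: &[f32], shape: &[usize], axis: usize) -> (Vec<f32>, Vec<usize>) {
    assert!(axis < shape.len(), "axis out of range");
    assert_eq!(data.len(), shape.iter().product::<usize>());
    // View the tensor as [outer, n, inner], where n = shape[axis].
    let outer: usize = shape[..axis].iter().product();
    let n = shape[axis];
    let inner: usize = shape[axis + 1..].iter().product();
    let mut out = vec![0.0f32; outer * inner];
    for o in 0..outer {
        for k in 0..n {
            for i in 0..inner {
                out[o * inner + i] += data[(o * n + k) * inner + i];
            }
        }
    }
    let mut out_shape = shape.to_vec();
    out_shape.remove(axis);
    (out, out_shape)
}

// Hypothetical helper (not dfdx API): mean = sum divided by the axis length.
fn mean_axis(data: &[f32], shape: &[usize], axis: usize) -> (Vec<f32>, Vec<usize>) {
    let n = shape[axis] as f32;
    let (mut out, out_shape) = sum_axis(data, shape, axis);
    out.iter_mut().for_each(|v| *v /= n);
    (out, out_shape)
}

fn main() {
    // A 2x3 tensor: [[1, 2, 3], [4, 5, 6]]
    let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let shape = [2, 3];
    println!("{:?}", sum_axis(&data, &shape, 0)); // ([5.0, 7.0, 9.0], [3])
    println!("{:?}", mean_axis(&data, &shape, 1)); // ([2.0, 5.0], [2])
}
```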

This would also help with the batch_norm issue https://github.com/coreylowman/dfdx/issues/78, since there is currently no API or easy way to compute the mean across the batch dimension.
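For batch norm, the statistics needed are per-feature reductions over axis 0. A minimal sketch, assuming a [batch, features] input represented as a const-generic nested array as a stand-in for a dfdx 2-D tensor; batch_mean and batch_var are hypothetical helpers:

```rust
// Hypothetical helper (not dfdx API): per-feature mean over the batch axis
// of a [B, F] input.
fn batch_mean<const B: usize, const F: usize>(x: &[[f32; F]; B]) -> [f32; F] {
    let mut mean = [0.0f32; F];
    for sample in x.iter() {
        for (m, v) in mean.iter_mut().zip(sample.iter()) {
            *m += *v;
        }
    }
    for m in mean.iter_mut() {
        *m /= B as f32;
    }
    mean
}

// Hypothetical helper (not dfdx API): per-feature biased variance
// (divides by B) over the batch axis.
fn batch_var<const B: usize, const F: usize>(x: &[[f32; F]; B]) -> [f32; F] {
    let mean = batch_mean(x);
    let mut var = [0.0f32; F];
    for sample in x.iter() {
        for ((s, v), m) in var.iter_mut().zip(sample.iter()).zip(mean.iter()) {
            *s += (v - m) * (v - m);
        }
    }
    for s in var.iter_mut() {
        *s /= B as f32;
    }
    var
}

fn main() {
    // Batch of 4 samples with 2 features each.
    let x = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]];
    println!("{:?}", batch_mean(&x)); // [2.5, 25.0]
    println!("{:?}", batch_var(&x)); // [1.25, 125.0]
}
```

The batch and feature sizes are checked at compile time here, which is roughly the guarantee a typed mean over the batch axis would give.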

vikigenius • Jul 21 '22 18:07

In fact, torch goes one step further and allows reduction across multiple dimensions simultaneously.
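Reducing several dimensions in one call can be sketched like this: visit every element once, track its multi-index, and accumulate it into the output slot addressed by the kept axes only. This is a hypothetical sum_axes over a row-major buffer with a runtime shape, not PyTorch's or dfdx's implementation:

```rust
// Hypothetical helper (not PyTorch or dfdx API): sum over several axes at once.
// Reduced axes are removed from the output shape.
fn sum_axes(data: &[f32], shape: &[usize], axes: &[usize]) -> (Vec<f32>, Vec<usize>) {
    assert!(axes.iter().all(|&a| a < shape.len()), "axis out of range");
    assert_eq!(data.len(), shape.iter().product::<usize>());
    // Output shape: the input shape with every reduced axis removed.
    let out_shape: Vec<usize> = (0..shape.len())
        .filter(|d| !axes.contains(d))
        .map(|d| shape[d])
        .collect();
    let mut out = vec![0.0f32; out_shape.iter().product::<usize>()];
    // Multi-index of the element currently being visited.
    let mut index = vec![0usize; shape.len()];
    for &v in data {
        // Row-major offset into `out`, using only the kept axes.
        let mut flat = 0;
        for d in 0..shape.len() {
            if !axes.contains(&d) {
                flat = flat * shape[d] + index[d];
            }
        }
        out[flat] += v;
        // Advance the multi-index in row-major order (last axis fastest).
        for d in (0..shape.len()).rev() {
            index[d] += 1;
            if index[d] < shape[d] {
                break;
            }
            index[d] = 0;
        }
    }
    (out, out_shape)
}

fn main() {
    // Shape [2, 2, 2] holding 0..8; reduce axes 0 and 2, keep axis 1.
    let data: Vec<f32> = (0..8).map(|x| x as f32).collect();
    println!("{:?}", sum_axes(&data, &[2, 2, 2], &[0, 2])); // ([10.0, 18.0], [2])
}
```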

This is the PR where they first added this feature (only for sum): https://github.com/pytorch/pytorch/pull/6152. The other reductions were added in later PRs. I'm linking it for reference in case I decide to work on this, though we still need to flesh out the API for handling this properly.

vikigenius • Jul 21 '22 18:07

Now that we have reshape implemented, what is the right way to tackle this?
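One thing reshape already buys: on a contiguous row-major buffer, reshaping a tensor to [prod(shape[..k]), prod(shape[k..])] is free, so reducing the trailing axes shape[k..] is just a last-dimension reduction on the reshaped view. A minimal sketch with a hypothetical sum_trailing_axes helper, not dfdx API:

```rust
// Hypothetical helper (not dfdx API): sum over the trailing axes shape[k..].
fn sum_trailing_axes(data: &[f32], shape: &[usize], k: usize) -> Vec<f32> {
    // "Reshape" to [prod(shape[..k]), prod(shape[k..])]: on a contiguous
    // row-major buffer this only changes how the data is chunked.
    let inner: usize = shape[k..].iter().product();
    assert!(inner > 0 && data.len() % inner == 0);
    // Reduce the last dimension of the reshaped 2-D view.
    data.chunks(inner).map(|row| row.iter().sum::<f32>()).collect()
}

fn main() {
    let data: Vec<f32> = (0..8).map(|x| x as f32).collect(); // shape [2, 2, 2]
    // Summing axes 1 and 2 is a last-dim sum of the [2, 4] reshape.
    println!("{:?}", sum_trailing_axes(&data, &[2, 2, 2], 1)); // [6.0, 22.0]
}
```

This only covers trailing, contiguous axes. Reducing a leading axis such as the batch dimension this way would still need a permute first, so a dedicated per-axis reduction remains useful.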

I don't know whether it's possible to write a generic trait like ReduceOverDims, similar to ReduceLastDim. Depending on the shape of the tensor, different dimensions are valid for reduction, so a single associated constant like LAST_DIM won't work.
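One possible way around the LAST_DIM limitation is to make the axis a const generic parameter of the trait, express the reduced shape as an associated type, and implement the trait only for valid (shape, axis) pairs. A minimal sketch using nested arrays in place of dfdx tensors; SumAxis, sum_along and SumAxisExt are hypothetical names, and the extension trait exists only to get method-call syntax like x.sum_axis::<0>():

```rust
// Hypothetical trait (not dfdx API): "this type can be summed along axis I".
// The reduced type is an associated type rather than an associated constant.
trait SumAxis<const I: usize> {
    type Output;
    fn sum_along(&self) -> Self::Output;
}

// [[f32; N]; M] has shape (M, N). Axis 0 sums over rows.
impl<const M: usize, const N: usize> SumAxis<0> for [[f32; N]; M] {
    type Output = [f32; N];
    fn sum_along(&self) -> [f32; N] {
        let mut out = [0.0; N];
        for row in self.iter() {
            for (o, v) in out.iter_mut().zip(row.iter()) {
                *o += *v;
            }
        }
        out
    }
}

// Axis 1 sums within each row.
impl<const M: usize, const N: usize> SumAxis<1> for [[f32; N]; M] {
    type Output = [f32; M];
    fn sum_along(&self) -> [f32; M] {
        let mut out = [0.0; M];
        for (o, row) in out.iter_mut().zip(self.iter()) {
            *o = row.iter().sum::<f32>();
        }
        out
    }
}

// Hypothetical extension trait: lets callers pick the axis at the call site.
trait SumAxisExt {
    fn sum_axis<const I: usize>(&self) -> <Self as SumAxis<I>>::Output
    where
        Self: SumAxis<I>,
    {
        <Self as SumAxis<I>>::sum_along(self)
    }
}
impl<T> SumAxisExt for T {}

fn main() {
    let x: [[f32; 3]; 2] = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
    let cols: [f32; 3] = x.sum_axis::<0>(); // [5.0, 7.0, 9.0]
    let rows: [f32; 2] = x.sum_axis::<1>(); // [6.0, 15.0]
    println!("{:?} {:?}", cols, rows);
}
```

An out-of-range axis such as x.sum_axis::<2>() on this 2-D array fails to compile, because no SumAxis<2> impl exists for that type.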

vikigenius • Aug 07 '22 22:08

Yeah, I'm actually working on this now! It will support reducing one dimension at first, e.g.:

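```rust
// A, B, C, D are the tensor's dimension sizes; `...` stands for any construction.
let t: Tensor4D<A, B, C, D> = ...;
// Reducing axis I removes the I-th dimension from the output type.
let _: Tensor3D<B, C, D> = t.sum_axis::<0>();
let _: Tensor3D<A, C, D> = t.sum_axis::<1>();
let _: Tensor3D<A, B, D> = t.sum_axis::<2>();
let _: Tensor3D<A, B, C> = t.sum_axis::<3>();
```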

coreylowman • Aug 08 '22 22:08

Nice, looking forward to it. Reducing across multiple dimensions at once isn't a big deal; it's just a nice-to-have. You can always apply reduce_over_dim multiple times. It just comes with the cognitive load of recalculating axis positions after each reduction, since removing an axis shifts the index of every axis after it.
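The re-indexing burden goes away if the axes are reduced from highest to lowest, since removing an axis only shifts the indices of the axes after it. A minimal sketch over a row-major buffer; sum_axis and sum_axes_repeated are hypothetical helpers, not dfdx API:

```rust
// Hypothetical helper (not dfdx API): single-axis sum on a row-major buffer.
fn sum_axis(data: &[f32], shape: &[usize], axis: usize) -> (Vec<f32>, Vec<usize>) {
    let outer: usize = shape[..axis].iter().product();
    let n = shape[axis];
    let inner: usize = shape[axis + 1..].iter().product();
    let mut out = vec![0.0f32; outer * inner];
    for o in 0..outer {
        for k in 0..n {
            for i in 0..inner {
                out[o * inner + i] += data[(o * n + k) * inner + i];
            }
        }
    }
    let mut out_shape = shape.to_vec();
    out_shape.remove(axis);
    (out, out_shape)
}

// Hypothetical helper (not dfdx API): reduce several axes by repeated
// single-axis sums, highest axis first, so no index ever needs adjusting.
fn sum_axes_repeated(data: &[f32], shape: &[usize], axes: &[usize]) -> (Vec<f32>, Vec<usize>) {
    let mut axes = axes.to_vec();
    axes.sort_unstable_by(|a, b| b.cmp(a)); // descending order
    axes.dedup();
    let (mut out, mut out_shape) = (data.to_vec(), shape.to_vec());
    for axis in axes {
        let (d, s) = sum_axis(&out, &out_shape, axis);
        out = d;
        out_shape = s;
    }
    (out, out_shape)
}

fn main() {
    let data: Vec<f32> = (0..8).map(|x| x as f32).collect(); // shape [2, 2, 2]
    // Manual chaining: to drop original axes 0 and 2, reduce axis 0 first;
    // the old axis 2 has then become axis 1.
    let (d, s) = sum_axis(&data, &[2, 2, 2], 0);
    let manual = sum_axis(&d, &s, 1);
    // Descending order needs no re-indexing: reduce axis 2, then axis 0.
    let auto = sum_axes_repeated(&data, &[2, 2, 2], &[0, 2]);
    assert_eq!(manual, auto);
    println!("{:?}", auto); // ([10.0, 18.0], [2])
}
```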

vikigenius • Aug 08 '22 22:08