
Fix `_broadcast_in_dim_prim_grad` when `reduce_dims` is empty

Open · rdspring1 opened this issue 1 year ago · 1 comment

```python
>>> import jax
>>> import numpy as np
>>> x = np.random.normal(size=(5,))
>>> x
array([ 0.55060031, -0.03787608, -0.81081705,  0.26335944, -1.4178462 ])
>>> jax.lax.broadcast_in_dim(x, [5], [0])
Array([ 0.5506003 , -0.03787608, -0.81081706,  0.26335943, -1.4178462 ],
      dtype=float32)
```

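The `broadcast_in_dim` above is a no-op: the target shape `[5]` equals the input shape, and `broadcast_dimensions=[0]` maps the only input dimension onto the only output dimension. In the gradient function, the `reduce_dims` tuple is therefore empty, meaning no dimensions should be reduced. However, `torch.sum` reduces over all dimensions when `dim` is an empty tuple. The fix is to skip the reduction when `reduce_dims` is empty.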
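A minimal PyTorch sketch of the fix follows. It is a standalone illustration, not Thunder's actual `_broadcast_in_dim_prim_grad`; the helper name `broadcast_in_dim_grad` and its signature are hypothetical. It works out which output dimensions the broadcast introduced or expanded, sums the incoming gradient over them only when there is at least one, and reshapes the result back to the input shape. It assumes `broadcast_dimensions` is strictly increasing, as `broadcast_in_dim` requires.

```python
import torch


def broadcast_in_dim_grad(g, input_shape, broadcast_dimensions):
    """Hypothetical VJP of broadcast_in_dim (illustrative, not Thunder's code).

    g: gradient with respect to the broadcast output.
    input_shape: shape of the original, pre-broadcast input.
    broadcast_dimensions: output dim that each input dim maps to (increasing).
    """
    # Output dims that carry a non-unit input dim keep their size. Everything
    # else was either added by the broadcast or expanded from a size-1 dim.
    kept_dims = {
        out_dim
        for in_dim, out_dim in enumerate(broadcast_dimensions)
        if input_shape[in_dim] != 1
    }
    reduce_dims = tuple(d for d in range(g.ndim) if d not in kept_dims)

    # The fix: torch.sum with dim=() reduces over *all* dimensions, so only
    # call it when there is actually something to reduce.
    if reduce_dims:
        g = torch.sum(g, dim=reduce_dims)

    # Re-insert the size-1 input dims that the sum removed.
    return g.reshape(input_shape)


# No-op broadcast from the example above: the gradient passes through unchanged.
g = torch.arange(5.0)
print(broadcast_in_dim_grad(g, (5,), (0,)).shape)  # torch.Size([5])
```

With the guard, the no-op case returns the gradient untouched. Without it, the empty-tuple `torch.sum` would collapse `g` to a 0-dim tensor, and reshaping that to `(5,)` would raise an error.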

rdspring1 · Apr 27 '24 16:04

@IvanYashchuk and/or @nikitaved -- would you review this, please?

mruberry · May 01 '24 14:05

@t-vi or @mruberry, this needs your approval to merge. The tests with the updated OpInfo samples would fail on main without this fix.

IvanYashchuk · May 28 '24 16:05

I discovered it when updating https://github.com/Lightning-AI/lightning-thunder/pull/260.

input shape = (5,)
target shape = () 
weight = True
reduction = mean
label_smoothing = 0.5

In this single sample, the input has shape (C,), the same shape as the weight tensor (C,), so the `broadcast_in_dim` that broadcasts the weight to the input shape is a no-op.
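To see why this sample reaches the empty-tuple path, here is a small pure-Python sketch of the dimension bookkeeping. The helper `reduce_dims_for` is hypothetical, and the batched case assumes the weight (C,) is broadcast onto the class dimension of an (N, C) input with `broadcast_dimensions=(1,)`.

```python
def reduce_dims_for(input_shape, output_shape, broadcast_dimensions):
    """Hypothetical helper: output dims a broadcast_in_dim gradient must sum over."""
    kept = {
        out_dim
        for in_dim, out_dim in enumerate(broadcast_dimensions)
        if input_shape[in_dim] != 1
    }
    return tuple(d for d in range(len(output_shape)) if d not in kept)


C, N = 5, 3
# Unbatched input (C,): the weight broadcast is a no-op, so nothing is reduced.
print(reduce_dims_for((C,), (C,), (0,)))  # ()
# Batched input (N, C): the weight is expanded along the batch dimension.
print(reduce_dims_for((C,), (N, C), (1,)))  # (0,)
```

Only the unbatched sample produces `()`, which is why the bug stayed hidden until the OpInfo samples included an input of shape (C,).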

rdspring1 · May 28 '24 17:05