lightning-thunder
Fix `_broadcast_in_dim_prim_grad` when `reduce_dims` is empty
>>> import jax
>>> import numpy as np
>>> x = np.random.normal(size=(5,))
>>> x
array([ 0.55060031, -0.03787608, -0.81081705, 0.26335944, -1.4178462 ])
>>> jax.lax.broadcast_in_dim(x, [5], [0])
Array([ 0.5506003 , -0.03787608, -0.81081706, 0.26335943, -1.4178462 ], dtype=float32)
This ^^^ broadcast_in_dim is a no-op. In the gradient function, the reduce_dims tuple is empty, meaning no dimensions should be removed. However, when the dim tuple is empty, torch.sum reduces over all dimensions. The fix is to skip the reduction entirely when reduce_dims is empty.
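The guard can be sketched as follows. This is a simplified model, not Thunder's actual gradient rule: it uses NumPy, assumes the broadcast only inserts new dimensions (it does not handle size-1 dims expanded in place, which would need a keepdims sum), and the function name is illustrative.

```python
import numpy as np

def broadcast_in_dim_grad(grad, orig_shape, broadcast_dims):
    """Simplified sketch of the backward of broadcast_in_dim.

    `broadcast_dims[i]` is the output dimension that input dim i was
    mapped to; every other output dimension was added by the broadcast
    and must be summed out of the incoming gradient.
    """
    reduce_dims = tuple(d for d in range(grad.ndim) if d not in broadcast_dims)
    # The fix: only reduce when there is something to reduce.
    # torch.sum(grad, dim=()) collapses *all* dimensions, whereas an
    # empty reduce_dims means "remove no dimensions".  (NumPy's
    # np.sum(..., axis=()) happens to be a no-op already, so the guard
    # here mirrors the torch fix rather than being required by NumPy.)
    if reduce_dims:
        grad = grad.sum(axis=reduce_dims)
    return grad.reshape(orig_shape)
```

In the no-op case above (input shape `(5,)`, target shape `[5]`, `broadcast_dimensions=[0]`), `reduce_dims` is `()` and the gradient passes through unchanged; without the guard, a torch-backed implementation would collapse it to a scalar.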
@IvanYashchuk and/or @nikitaved -- would you review this, please?
@t-vi or @mruberry, we need your approval to merge this. The tests with the updated opinfo samples would fail on main without this fix.
I discovered it when updating https://github.com/Lightning-AI/lightning-thunder/pull/260.
input shape = (5,)
target shape = ()
weight = True
reduction = mean
label_smoothing = 0.5
In this ^^^ single example case, the input shape is (C,), which matches the shape of the weight tensor (C,), so the broadcast_in_dim is a no-op.
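To make the no-op concrete, here is a small NumPy model of broadcast_in_dim semantics (a sketch, not JAX's or Thunder's implementation): each input dimension is placed at the output position named by broadcast_dimensions, and the result is broadcast to the target shape. When the weight already has shape (C,) and the target is (C,), nothing changes.

```python
import numpy as np

def broadcast_in_dim(x, shape, broadcast_dims):
    """NumPy sketch of broadcast_in_dim: map input dim i to output
    position broadcast_dims[i], fill the rest with size-1 dims, then
    broadcast to the target shape."""
    new_shape = [1] * len(shape)
    for i, d in enumerate(broadcast_dims):
        new_shape[d] = x.shape[i]
    return np.broadcast_to(x.reshape(new_shape), shape)

w = np.arange(5.0)                       # weight of shape (C,) with C = 5
y = broadcast_in_dim(w, (5,), (0,))      # target shape equals input shape
# y is identical to w: the broadcast is a no-op, so its gradient
# function should remove no dimensions.
```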