Implement `ht.broadcast_shapes`
Related #891
The heat function `stride_tricks.broadcast_shape` differs unnecessarily from its numpy and torch counterparts:
- naming: heat `broadcast_shape` vs. numpy/torch `broadcast_shapes`
- calling: `ht.stride_tricks.broadcast_shape()` vs. `np.broadcast_shapes()` or `torch.broadcast_shapes()`
- arguments: `ht.stride_tricks.broadcast_shape` takes only 2 shapes, while the numpy/torch versions accept any number of shapes
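To illustrate the argument-count gap: any two-shape broadcasting rule can be lifted to an arbitrary number of shapes by folding it with `functools.reduce`. The following is a minimal pure-Python sketch, assuming NumPy's broadcasting rules (right-align the shapes; each dimension pair must be equal or contain a 1). `_broadcast_pair` is a hypothetical helper, not heat's actual implementation.

```python
from functools import reduce
from typing import Sequence, Tuple


def _broadcast_pair(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: broadcast two shapes using NumPy-style rules."""
    # Right-align the shapes by left-padding the shorter one with 1s.
    ndim = max(len(a), len(b))
    pa = (1,) * (ndim - len(a)) + tuple(a)
    pb = (1,) * (ndim - len(b)) + tuple(b)
    result = []
    for x, y in zip(pa, pb):
        # Dimensions are compatible if they are equal or one of them is 1.
        if x != y and 1 not in (x, y):
            raise ValueError(f"shapes {tuple(a)} and {tuple(b)} are not broadcastable")
        result.append(y if x == 1 else x)
    return tuple(result)


def broadcast_shapes(*shapes: Sequence[int]) -> Tuple[int, ...]:
    """Fold the pairwise rule over any number of shapes."""
    # Starting from the empty shape (), so zero arguments yield ().
    return reduce(_broadcast_pair, shapes, ())
```

For example, `broadcast_shapes((5, 1, 4), (3, 1), (1,))` gives `(5, 3, 4)`, whereas a two-argument function would need nested calls.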
Feature functionality
`ht.broadcast_shapes` should behave exactly like `np.broadcast_shapes`, i.e. return the broadcast shape as a tuple, while using `torch.broadcast_shapes` under the hood for GPU compatibility.
ADDENDUM: `broadcast_shapes` is only available starting with torch 1.8.0 and numpy 1.20. Doh!
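One possible way to reconcile the feature request with the version constraint, sketched under assumptions: delegate to `torch.broadcast_shapes` when the installed torch provides it, and otherwise fall back to a pure-Python implementation of the broadcasting rule. The result is converted to a plain `tuple` (torch returns a `torch.Size`, a tuple subclass), and torch's `RuntimeError` on mismatched shapes is re-raised as `ValueError`, which is what `np.broadcast_shapes` raises. The names `broadcast_shapes` and `_pair` here are hypothetical, not heat's API.

```python
from functools import reduce
from typing import Sequence, Tuple

try:
    import torch
except ImportError:  # torch not installed: only the fallback path is used
    torch = None


def _pair(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical fallback: NumPy-style broadcasting of two shapes."""
    ndim = max(len(a), len(b))
    pa = (1,) * (ndim - len(a)) + tuple(a)
    pb = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for x, y in zip(pa, pb):
        if x != y and 1 not in (x, y):
            raise ValueError(f"shapes {tuple(a)} and {tuple(b)} are not broadcastable")
        out.append(y if x == 1 else x)
    return tuple(out)


def broadcast_shapes(*shapes: Sequence[int]) -> Tuple[int, ...]:
    """Return the broadcast shape of ``shapes`` as a plain tuple."""
    # Feature check instead of a version comparison: present in torch >= 1.8.0.
    if torch is not None and hasattr(torch, "broadcast_shapes"):
        try:
            # torch returns a torch.Size; convert to a plain tuple like NumPy.
            return tuple(torch.broadcast_shapes(*shapes))
        except RuntimeError as err:  # torch reports mismatches as RuntimeError
            raise ValueError(str(err)) from err
    # Older or missing torch: fold the pure-Python pairwise rule.
    return reduce(_pair, shapes, ())
```

Checking for the attribute with `hasattr` rather than parsing `torch.__version__` keeps the sketch independent of version-string formats; an explicit version check would work equally well.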