
Implement `ht.broadcast_shapes`

Open: ClaudiaComito opened this issue 3 years ago

Related #891

The heat function `stride_tricks.broadcast_shape` differs unnecessarily from its NumPy and torch counterparts in three ways:

  1. naming: heat's `broadcast_shape` vs. NumPy/torch `broadcast_shapes`
  2. calling: `ht.stride_tricks.broadcast_shape()` vs. `np.broadcast_shapes()` or `torch.broadcast_shapes()`
  3. arguments: `ht.stride_tricks.broadcast_shape` only accepts 2 shapes, while the NumPy/torch versions accept any number of shapes

Feature functionality

`ht.broadcast_shapes` should behave exactly like `np.broadcast_shapes`, i.e. return the broadcast shape as a tuple, while using `torch.broadcast_shapes` under the hood for GPU compatibility.

ADDENDUM: `broadcast_shapes` is only available from torch 1.8.0 and NumPy 1.20 onward. Doh!

ClaudiaComito, Jan 21 '22 12:01