
Replace all uses of `TensorBase::broadcast_iter`

Open · robertknight opened this issue 1 year ago · 0 comments

A common operation in various ONNX operators is to broadcast multiple inputs, each of dynamically determined rank, to a common shape, and then apply some operation to matching elements of the broadcast views. Until now this was implemented with `TensorBase::broadcast_iter`, which is a regular Rust iterator that supports inputs of varying ranks.
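As background for "broadcast to a common shape", the sketch below computes the shape that several inputs broadcast to under the NumPy/ONNX multidirectional broadcasting rules: shapes are aligned on the right, and each pair of dimensions must either match or include a 1. The function name `broadcast_shape` is hypothetical and is not part of rten's API.

```rust
/// Hypothetical helper (not rten's API): compute the shape that all `shapes`
/// broadcast to, aligning them on the right. Returns `None` if two
/// non-1 dimensions disagree.
fn broadcast_shape(shapes: &[&[usize]]) -> Option<Vec<usize>> {
    // The output rank is the largest input rank.
    let ndim = shapes.iter().map(|s| s.len()).max().unwrap_or(0);
    let mut out = vec![1; ndim];
    for shape in shapes {
        // Lower-rank inputs behave as if padded with 1s on the left.
        let offset = ndim - shape.len();
        for (i, &dim) in shape.iter().enumerate() {
            let o = &mut out[offset + i];
            if *o == 1 {
                *o = dim;
            } else if dim != 1 && dim != *o {
                return None;
            }
        }
    }
    Some(out)
}
```

For example, shapes `[2, 3]` and `[3]` broadcast to `[2, 3]`, while `[2]` and `[3]` are incompatible.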

I've since learned that a different pattern generates much more efficient code:

  1. Broadcast each input view to the output shape (which has dynamic rank).
  2. Pad the views to at least N dims (currently always 4) by inserting size-1 axes on the left. N is a compile-time constant.
  3. Loop over matching static-rank views of the innermost N dims, and apply the operation to the paired elements from each input.
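The three steps above can be sketched as follows. This is a minimal illustration under stated assumptions, not rten's implementation: `View`, `broadcast`, `pad_to`, `last_n` and `binary_op` are hypothetical names, broadcasting is modeled by giving broadcast axes a stride of 0, and only the outer dims beyond the innermost N are walked with a dynamic-rank counter.

```rust
/// Hypothetical minimal strided view over a flat buffer (not rten's API).
struct View<'a> {
    data: &'a [f32],
    shape: Vec<usize>,
    strides: Vec<usize>,
}

impl<'a> View<'a> {
    /// Contiguous, row-major view of `data` with the given shape.
    fn new(data: &'a [f32], shape: &[usize]) -> Self {
        let mut strides = vec![0; shape.len()];
        let mut stride = 1;
        for i in (0..shape.len()).rev() {
            strides[i] = stride;
            stride *= shape[i];
        }
        assert_eq!(stride, data.len(), "shape does not match data length");
        View { data, shape: shape.to_vec(), strides }
    }

    /// Step 1: broadcast to `shape`. Size-1 and missing leading axes get
    /// stride 0, so the same element is reused along them.
    fn broadcast(&self, shape: &[usize]) -> View<'a> {
        let offset = shape.len() - self.shape.len();
        let mut strides = vec![0; shape.len()];
        for (i, (&dim, &stride)) in self.shape.iter().zip(&self.strides).enumerate() {
            assert!(dim == shape[offset + i] || dim == 1, "cannot broadcast");
            strides[offset + i] = if dim == 1 { 0 } else { stride };
        }
        View { data: self.data, shape: shape.to_vec(), strides }
    }

    /// Step 2: insert size-1 axes on the left until the rank is at least `n`.
    fn pad_to(&self, n: usize) -> View<'a> {
        let pad = n.saturating_sub(self.shape.len());
        let mut shape = vec![1; pad];
        shape.extend_from_slice(&self.shape);
        let mut strides = vec![0; pad];
        strides.extend_from_slice(&self.strides);
        View { data: self.data, shape, strides }
    }
}

/// Static rank of the inner loops (4, as in the issue).
const N: usize = 4;

/// Copy the last `N` entries of a slice into a fixed-size array.
fn last_n(v: &[usize]) -> [usize; N] {
    let mut arr = [0; N];
    arr.copy_from_slice(&v[v.len() - N..]);
    arr
}

/// Step 3: apply `op` to paired elements of `a` and `b`, both broadcast to
/// `out_shape`, and return the results in row-major order.
fn binary_op(a: &View, b: &View, out_shape: &[usize], op: impl Fn(f32, f32) -> f32) -> Vec<f32> {
    let a = a.broadcast(out_shape).pad_to(N);
    let b = b.broadcast(out_shape).pad_to(N);
    let outer = a.shape.len() - N; // number of dynamic-rank outer dims
    // Fixed-size arrays give the inner loops a compile-time-known rank.
    let (shape, sa, sb) = (last_n(&a.shape), last_n(&a.strides), last_n(&b.strides));

    let mut out = Vec::with_capacity(out_shape.iter().product());
    let outer_count: usize = a.shape[..outer].iter().product();
    let mut idx = vec![0usize; outer];
    for _ in 0..outer_count {
        let base_a: usize = idx.iter().zip(&a.strides[..outer]).map(|(i, s)| i * s).sum();
        let base_b: usize = idx.iter().zip(&b.strides[..outer]).map(|(i, s)| i * s).sum();
        for i0 in 0..shape[0] {
            for i1 in 0..shape[1] {
                for i2 in 0..shape[2] {
                    for i3 in 0..shape[3] {
                        let oa = base_a + i0 * sa[0] + i1 * sa[1] + i2 * sa[2] + i3 * sa[3];
                        let ob = base_b + i0 * sb[0] + i1 * sb[1] + i2 * sb[2] + i3 * sb[3];
                        out.push(op(a.data[oa], b.data[ob]));
                    }
                }
            }
        }
        // Advance the row-major counter over the outer dims.
        for d in (0..outer).rev() {
            idx[d] += 1;
            if idx[d] < a.shape[d] {
                break;
            }
            idx[d] = 0;
        }
    }
    out
}
```

Adding a `[2, 3]` view to a `[3]` view through `binary_op` broadcasts the second input along the first axis without copying it. The design point is that the per-element work happens inside four loops whose depth is fixed at compile time; the dynamic-rank bookkeeping runs only once per inner N-dim block.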

This issue covers applying this pattern everywhere `broadcast_iter` is currently used in potentially performance-sensitive code. Where performance is not a concern, `tensor.broadcast(shape).iter()` can be used instead.
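For contrast, the sketch below illustrates the kind of per-element bookkeeping a general dynamic-rank broadcast iteration involves: a `Vec`-based index counter whose update and offset computation cost scales with the runtime rank on every element. It is a hypothetical illustration of the general technique (`broadcast_offsets` is not an rten function), not the actual implementation of `broadcast_iter` or `broadcast(shape).iter()`.

```rust
/// Hypothetical sketch: walk every index of `shape` in row-major order with a
/// dynamic-rank counter and return the element offset for a view with the
/// given strides (broadcast axes have stride 0).
fn broadcast_offsets(shape: &[usize], strides: &[usize]) -> Vec<usize> {
    assert_eq!(shape.len(), strides.len());
    let total: usize = shape.iter().product();
    let mut offsets = Vec::with_capacity(total);
    let mut index = vec![0usize; shape.len()];
    for _ in 0..total {
        // Offset computation touches every dim, for every element.
        let offset: usize = index.iter().zip(strides).map(|(i, s)| i * s).sum();
        offsets.push(offset);
        // Increment the counter, carrying into higher dims as needed.
        for d in (0..shape.len()).rev() {
            index[d] += 1;
            if index[d] < shape[d] {
                break;
            }
            index[d] = 0;
        }
    }
    offsets
}
```

With shape `[2, 3]` and strides `[0, 1]` (a `[3]` input broadcast to `[2, 3]`), the offsets repeat `0, 1, 2` twice. This is simple and general, which is why it remains reasonable where performance does not matter.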

robertknight · May 19 '24 17:05