Document limitations of broadcasting for matmul in Burn

Open antimora opened this issue 1 year ago • 9 comments

The current documentation for the matmul operation in Burn does not clearly explain the limitations of broadcasting support. This has led to confusion among users who expect broadcasting behavior similar to that of PyTorch or NumPy.

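The main limitation is that Burn requires the input tensors to have the same rank (number of dimensions) for broadcasting to work. This is due to constraints imposed by Rust's type system: a tensor's rank is a const generic parameter, so the output rank of an operation has to be expressible in its signature. In contrast, frameworks like PyTorch and NumPy allow broadcasting even when the input tensors have different ranks: shapes are aligned from the trailing dimension, the lower-rank tensor is treated as if it had extra leading dimensions of size 1, and each pair of dimensions must then be equal or contain a 1.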
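To make the difference concrete, here is a small shape-only sketch in plain Rust (no Burn dependency; both function names are hypothetical). `broadcast_shape_numpy` follows the NumPy/PyTorch rule of right-aligning shapes and treating missing leading dimensions as size 1, while `broadcast_shape_same_rank` adds the same-rank requirement described above. It only computes output shapes and is not meant to reflect how Burn implements broadcasting internally.

```rust
/// NumPy/PyTorch-style rule: align shapes on the right, treat missing
/// leading dimensions as size 1, then each pair must be equal or contain a 1.
fn broadcast_shape_numpy(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let rank = a.len().max(b.len());
    let mut out = vec![0; rank];
    for i in 0..rank {
        // Index from the right; a missing dimension behaves like size 1.
        let da = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let db = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
        out[rank - 1 - i] = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None, // incompatible sizes
        };
    }
    Some(out)
}

/// Same-rank rule, as described for Burn: the ranks must match first,
/// then each dimension pair must be equal or contain a 1.
fn broadcast_shape_same_rank(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    if a.len() != b.len() {
        return None; // the caller would have to unsqueeze one side first
    }
    broadcast_shape_numpy(a, b)
}

fn main() {
    // Same rank: [2, 2, 1] with [1, 2, 2] works under both rules.
    println!("{:?}", broadcast_shape_numpy(&[2, 2, 1], &[1, 2, 2])); // Some([2, 2, 2])
    // Rank 3 with rank 1: allowed NumPy-style, rejected by the same-rank rule.
    println!("{:?}", broadcast_shape_numpy(&[2, 3, 4], &[4])); // Some([2, 3, 4])
    println!("{:?}", broadcast_shape_same_rank(&[2, 3, 4], &[4])); // None
}
```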

To address this issue, we should:

  1. Update the documentation for matmul to explicitly state that broadcasting requires the input tensors to have the same rank.
  2. Provide examples showing how to use unsqueeze to manually adjust the rank of tensors before performing matmul.
  3. Explain the technical reasons behind this limitation (stable Rust does not support the const generic expressions needed to declare the output rank as the maximum of the two input ranks).
  4. Consider automatically performing unsqueeze in certain cases to improve API usability, while being mindful of potential unintended consequences.
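Items 2 and 3 can be illustrated together with a shape-level sketch that uses only the standard library. The hypothetical `unsqueeze_shape` prepends size-1 dimensions to go from rank `D` to rank `D2`; because stable Rust cannot spell a return type whose length is computed from two const parameters (that would need the unstable `generic_const_exprs` feature), the caller has to name the target rank explicitly. Burn's `Tensor::unsqueeze::<D2>()` takes the target rank the same way and, as far as we know, also adds the new dimensions at the front; check the API docs before relying on that.

```rust
/// Hypothetical shape-level analogue of a rank-raising unsqueeze: prepend
/// size-1 dimensions so a rank-`D` shape becomes rank `D2`.
///
/// `D2` must be chosen by the caller. Computing it from two input ranks in
/// the return type (e.g. the larger of `D1` and `D2`) would need the
/// unstable `generic_const_exprs` feature.
fn unsqueeze_shape<const D: usize, const D2: usize>(shape: [usize; D]) -> [usize; D2] {
    assert!(D2 >= D, "cannot unsqueeze from rank {} to lower rank {}", D, D2);
    let mut out = [1usize; D2];
    // Keep the original dimensions in the trailing positions.
    out[D2 - D..].copy_from_slice(&shape);
    out
}

fn main() {
    // A rank-2 [k, n] = [3, 5] shape lifted to rank 3 so it matches a batched lhs.
    let rhs: [usize; 3] = unsqueeze_shape::<2, 3>([3, 5]);
    println!("{:?}", rhs); // [1, 3, 5]
}
```

With Burn, the analogous workaround for multiplying a rank-3 `lhs` by a rank-2 `rhs` would presumably be `lhs.matmul(rhs.unsqueeze::<3>())`, which gives `rhs` a leading batch dimension of size 1 that can then be broadcast.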

Additionally, we should enhance the Burn Book by linking relevant API documentation pages to provide users with quick access to exact function signatures and usage examples. This could be automated using pre-processor plug-ins (see #1197).

By improving the documentation and clearly communicating the limitations and workarounds, we can prevent user confusion and create a more frictionless API experience, even if full PyTorch-style broadcasting is not currently feasible in Burn.

antimora, Mar 21 '24 17:03

This isn't just about matmul: all element-wise operations can be broadcast (mul, add, sub, div, you name it). But of course, both LHS and RHS must have the same rank.
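As a rough illustration of what same-rank broadcasting means for element-wise ops, the sketch below (plain Rust, hypothetical `broadcast_binary` helper, row-major `f32` data) applies any binary op to two tensors of the same rank `D`. Each size-1 dimension gets a stride of 0, so its single element is reused along that axis. This is a common way to implement broadcasting, not necessarily how Burn's backends do it.

```rust
/// Row-major strides, with stride 0 for size-1 (broadcast) dimensions.
fn broadcast_strides<const D: usize>(shape: [usize; D]) -> [usize; D] {
    let mut strides = [0usize; D];
    let mut acc = 1;
    for i in (0..D).rev() {
        strides[i] = if shape[i] == 1 { 0 } else { acc };
        acc *= shape[i];
    }
    strides
}

/// Applies `op` element-wise to two row-major tensors of the SAME rank `D`,
/// broadcasting every dimension where one side has size 1.
fn broadcast_binary<const D: usize>(
    a: &[f32],
    shape_a: [usize; D],
    b: &[f32],
    shape_b: [usize; D],
    op: impl Fn(f32, f32) -> f32,
) -> (Vec<f32>, [usize; D]) {
    // Output shape: sizes must match, or one of them must be 1.
    let mut out_shape = [0usize; D];
    for i in 0..D {
        let (x, y) = (shape_a[i], shape_b[i]);
        assert!(x == y || x == 1 || y == 1, "dimension {i} not broadcastable: {x} vs {y}");
        out_shape[i] = if x == 1 { y } else { x };
    }
    let (sa, sb) = (broadcast_strides(shape_a), broadcast_strides(shape_b));
    let len: usize = out_shape.iter().product();
    let mut out = Vec::with_capacity(len);
    let mut idx = [0usize; D]; // current multi-index into the output
    for _ in 0..len {
        // Size-1 dims have stride 0, so they keep reading the same element.
        let ia: usize = (0..D).map(|i| idx[i] * sa[i]).sum();
        let ib: usize = (0..D).map(|i| idx[i] * sb[i]).sum();
        out.push(op(a[ia], b[ib]));
        // Advance the multi-index in row-major order (last axis fastest).
        for i in (0..D).rev() {
            idx[i] += 1;
            if idx[i] < out_shape[i] {
                break;
            }
            idx[i] = 0;
        }
    }
    (out, out_shape)
}

fn main() {
    // Both operands are rank 2, as Burn requires: [2, 1] + [1, 3] -> [2, 3].
    let a = [1.0f32, 2.0];
    let b = [10.0f32, 20.0, 30.0];
    let (out, shape) = broadcast_binary(&a, [2, 1], &b, [1, 3], |x, y| x + y);
    println!("{:?} {:?}", shape, out); // [2, 3] [11.0, 21.0, 31.0, 12.0, 22.0, 32.0]
}
```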

nathanielsimard, Mar 22 '24 14:03

CCing @laggui, @ashdtu

antimora, Apr 15 '24 20:04

Right now the docs for tensors basically only mention this:

With Burn, you can focus more on what the model should do, rather than on how to do it. We take the responsibility of making your code run as fast as possible during training as well as inference. The same principles apply to broadcasting; all operations support broadcasting unless specified otherwise.

We should definitely improve that to be more explicit.

For example, we could add a small section explaining the broadcasting semantics that Burn supports.

laggui, Apr 16 '24 12:04

I think we should start by tackling 1, 2 and 3.

  4. Consider automatically performing unsqueeze in certain cases to improve API usability, while being mindful of potential unintended consequences.

Could be captured in another issue/PR. I'd start with properly documenting the current broadcasting semantics.

Btw, as opposed to other matmul implementations (e.g., NumPy and torch), we currently support broadcasting of leading dimensions with matmul, while they don't. For example, a @ b where a has shape [2, 2, 1] and b is [1, 2, 2] works with Burn (dim 0 broadcast), but other implementations don't go that far, since matmul is different from element-wise operations (where mul/div/etc. would broadcast for the same tensors (a, b)). I guess this is an opinionated decision, which should also be reflected in the docs.

laggui, Apr 16 '24 14:04

@laggui I think torch has the same behavior; it's batch matrix multiplication where the batch dimension can be broadcasted.

nathanielsimard, Apr 17 '24 11:04

@laggui I think torch has the same behavior; it's batch matrix multiplication where the batch dimension can be broadcasted.

You're right! At least for element-wise ops, but not for matmul:

```pycon
>>> torch.matmul(torch.empty(2,2,1), torch.empty(1,2,2)).shape
RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [2, 1] but got: [2, 2].
```

Same with NumPy:

```pycon
>>> np.matmul(np.empty((2,2,1)), np.empty((1,2,2))).shape
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 2 is different from 1)
```

Element-wise multiplication, on the other hand, does broadcast:

```pycon
>>> torch.mul(torch.empty(2,2,1), torch.empty(1,2,2)).shape
torch.Size([2, 2, 2])
```

On the other hand, this works with Burn:

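```rust
use burn::tensor::{backend::Backend, Tensor};

// Batched matmul where the batch dimension of tensor_2 (size 1) is broadcast.
fn matmul_broadcast<B: Backend>(device: &B::Device) {
    let tensor_1 = Tensor::<B, 3>::ones([2, 2, 2], device);
    let tensor_2 = Tensor::<B, 3>::ones([1, 2, 2], device);
    let dims1 = tensor_1.dims();
    let dims2 = tensor_2.dims();
    let tensor_3 = tensor_1.matmul(tensor_2);
    println!("{:?} @ {:?} = {:?}", dims1, dims2, tensor_3.dims());
}
```

Output:

```text
[2, 2, 2] @ [1, 2, 2] = [2, 2, 2]
```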
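One detail worth noting when documenting this: the Burn example uses a left operand of shape [2, 2, 2], while the torch and NumPy examples use [2, 2, 1]. With [2, 2, 1] @ [1, 2, 2], the contracted dimension is 1 on the left and 2 on the right, which is the mismatch both error messages point at. The shape-only sketch below (hypothetical `matmul_shape` helper, plain Rust, not Burn's code) assumes the usual batched-matmul rule for two rank-3 operands: the batch dimension broadcasts, the contracted dimension must match exactly.

```rust
/// Output shape of a batched matmul [batch, m, k] @ [batch, k, n] on rank-3 shapes.
/// The batch dimension broadcasts (equal sizes, or one side is 1); the contracted
/// dimension k must match exactly and is never broadcast.
fn matmul_shape(a: [usize; 3], b: [usize; 3]) -> Result<[usize; 3], String> {
    let [ba, m, ka] = a;
    let [bb, kb, n] = b;
    if ka != kb {
        return Err(format!("inner dimensions differ: {} vs {}", ka, kb));
    }
    let batch = match (ba, bb) {
        (x, y) if x == y => x,
        (1, y) => y,
        (x, 1) => x,
        _ => return Err(format!("batch dimensions not broadcastable: {} vs {}", ba, bb)),
    };
    Ok([batch, m, n])
}

fn main() {
    // Shapes from the Burn example: batch size 1 broadcasts against 2.
    println!("{:?}", matmul_shape([2, 2, 2], [1, 2, 2])); // Ok([2, 2, 2])
    // Shapes from the torch/NumPy examples: k is 1 on the left, 2 on the right.
    println!("{:?}", matmul_shape([2, 2, 1], [1, 2, 2])); // Err("inner dimensions differ: 1 vs 2")
}
```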

laggui, Apr 17 '24 12:04

@laggui Oh, that's funny, since it works with the LibTorch bindings we are using. I guess the behavior is different from Python torch.

nathanielsimard, Apr 17 '24 12:04