Move `full` and `pad` functions to `BaseTensor` trait and make `bool` a proper tensor element type

Open antimora opened this issue 1 year ago • 10 comments

Currently, the full function is implemented separately for each tensor type. To make it work for bool tensors as well, it was moved to the BaseTensor trait. However, this introduced some issues:

  1. The bool type does not implement the Element trait, which is required by functions in BaseTensor.
  2. The Element trait depends on numeric traits like ToPrimitive, Zero, and One, which bool does not implement.
  3. Calling elem() to convert values to the element type is a temporary workaround and should be cleaned up.

To properly address this, we should:

  1. Create our own versions of the Zero, One, and ToPrimitive traits that are not coupled to numeric types.
  2. Remove the dependency on num_traits in the base Element trait.
  3. Make the Element trait not require numeric traits.
  4. Implement the new Element trait for bool.
  5. Move the full function to BaseTensor once bool properly implements Element.

This will allow bool to be treated as a proper tensor element type, simplify the implementation of full and other functions in BaseTensor, and remove the need for workarounds like calling elem().
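A minimal sketch of what the proposed split could look like, written against the standard library only. All names here (`Element`, `NumericElement`, `ToyTensor`) are hypothetical stand-ins for illustration and are not burn's actual definitions; the point is only that `full` needs nothing numeric, so it can live at the base level once the element trait drops its numeric bounds.

```rust
use std::fmt::Debug;

/// Hypothetical base element trait: no numeric bounds, so `bool` qualifies.
pub trait Element: Copy + Debug + PartialEq + 'static {}

/// Hypothetical numeric sub-trait: carries our own `zero`/`one` instead of
/// relying on external numeric traits in the base trait.
pub trait NumericElement: Element {
    fn zero() -> Self;
    fn one() -> Self;
}

impl Element for bool {}
impl Element for i32 {}
impl Element for f32 {}

impl NumericElement for i32 {
    fn zero() -> Self { 0 }
    fn one() -> Self { 1 }
}

impl NumericElement for f32 {
    fn zero() -> Self { 0.0 }
    fn one() -> Self { 1.0 }
}

/// Toy stand-in for a tensor: a shape plus flat storage.
#[derive(Debug, Clone, PartialEq)]
pub struct ToyTensor<E: Element> {
    pub shape: Vec<usize>,
    pub data: Vec<E>,
}

/// Base-level operations: available for every element type, including `bool`.
impl<E: Element> ToyTensor<E> {
    pub fn full(shape: &[usize], value: E) -> Self {
        let len: usize = shape.iter().product();
        Self { shape: shape.to_vec(), data: vec![value; len] }
    }
}

/// Numeric-only operations: built on top of the base-level `full`.
impl<E: NumericElement> ToyTensor<E> {
    pub fn zeros(shape: &[usize]) -> Self {
        Self::full(shape, E::zero())
    }

    pub fn ones(shape: &[usize]) -> Self {
        Self::full(shape, E::one())
    }
}

fn main() {
    // `full` works for bool without any conversion workaround.
    let mask = ToyTensor::full(&[2, 3], true);
    let ones = ToyTensor::<f32>::ones(&[2]);
    println!("{:?}", mask);
    println!("{:?}", ones);
}
```

With this layout, `zeros` and `ones` stay numeric-only while `full` is shared by all element types, which mirrors the ordering of steps 1 to 5 above.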

For now, the pad function changes should be moved to the numeric tensor implementation to unblock the current PR. The full function move and bool element type cleanup can be handled separately in this ticket.

antimora • Mar 26 '24 22:03

bool_full implementation:

    fn bool_full<const D: usize>(
        shape: Shape<D>,
        value: bool,
        device: &Device<B>,
    ) -> BoolTensor<B, D> {
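        // Every element of the zeros tensor equals 0, so comparing against 0
        // yields all `true`, while comparing against 1 yields all `false`.
        B::int_equal_elem(
            B::int_zeros(shape, device),
            if value { 0.elem() } else { 1.elem() },
        )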
    }
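A plain-Rust analogue of this implementation may make the comparison trick easier to follow. Here a `Vec<i64>` stands in for the backend's int tensor, and the function name `bool_full_via_equal` is made up for illustration; it is a sketch of the same logic, not burn code.

```rust
/// Builds a "bool tensor" of `len` elements filled with `value` by comparing
/// an all-zeros integer buffer against a scalar, as `bool_full` does.
fn bool_full_via_equal(len: usize, value: bool) -> Vec<bool> {
    // Stand-in for `B::int_zeros(shape, device)`.
    let zeros = vec![0i64; len];

    // 0 == 0 is true for every element; 0 == 1 is false for every element.
    let target: i64 = if value { 0 } else { 1 };

    // Stand-in for `B::int_equal_elem(zeros, target)`.
    zeros.iter().map(|&x| x == target).collect()
}

fn main() {
    println!("{:?}", bool_full_via_equal(3, true));
    println!("{:?}", bool_full_via_equal(3, false));
}
```

The extra zeros allocation and comparison pass are the cost of this workaround, which is part of why a direct `full` for bool is preferable.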

antimora • Mar 27 '24 15:03

We need the full function to work for bool tensors because the ONNX ConstantOfShape operator can be initialized with bool values, and calling Tensor::full with a bool value currently fails with this error:

       Compiling onnx-tests v0.14.0 (/Users/dilshod/Projects/burn/crates/burn-import/onnx-tests)
    error[E0277]: the trait bound `burn::tensor::Bool: Numeric<B>` is not satisfied
       --> /Users/dilshod/Projects/burn/target/debug/build/onnx-tests-50e78da95971b883/out/model/constant_of_shape_full_like.rs:53:37
        |
    53  |         let constantofshape3_out1 = Tensor::full(shape3_out1, true, &*self.device);
        |                                     ^^^^^^^^^^^^ the trait `Numeric<B>` is not implemented for `burn::tensor::Bool`
        |
        = help: the following other types implement trait `Numeric<B>`:
                  burn::tensor::Float
                  burn::tensor::Int
    note: required by a bound in `burn_tensor::tensor::api::numeric::<impl Tensor<B, D, K>>::full`
       --> /Users/dilshod/Projects/burn/crates/burn-tensor/src/tensor/api/numeric.rs:13:8
        |
    13  |     K: Numeric<B>,
        |        ^^^^^^^^^^ required by this bound in `burn_tensor::tensor::api::numeric::<impl Tensor<B, D, K>>::full`
    ...
    112 |     pub fn full<S: Into<Shape<D>>, E: ElementConversion>(
        |            ---- required by a bound in this associated function

    For more information about this error, try `rustc --explain E0277`.
    error: could not compile `onnx-tests` (test "onnx_tests") due to 1 previous error

antimora • Jul 06 '24 19:07

Workaround to get full for bool tensors:

    // All true
    Tensor::<Backend, 3, Int>::ones(shape, &device).bool();

    // All false
    Tensor::<Backend, 3, Int>::zeros(shape, &device).bool();

antimora • Jul 06 '24 22:07

@laggui, you mentioned you were working in this area. Do you think we are close to making it work?

antimora • Jul 06 '24 22:07

> @laggui, you mentioned you were working in this area. Do you think we are close to making it work?

Bool is now an element type :) but the methods mentioned like full have not yet been refactored.

So this issue is partially completed.

laggui • Jul 07 '24 12:07