burn
Move `full` and `pad` functions to `BaseTensor` trait and make `bool` a proper tensor element type
Currently, the `full` function is implemented separately for each tensor type. To make it work for bool tensors as well, it was moved to the `BaseTensor` trait. However, this introduced some issues:
- The `bool` type does not implement the `Element` trait, which is required by functions in `BaseTensor`.
- The `Element` trait depends on numeric traits like `ToPrimitive`, `Zero`, and `One`, which `bool` does not implement.
- Calling `elem()` to convert values to the element type is a temporary workaround and should be cleaned up.
To properly address this, we should:
- Create our own versions of the `Zero`, `One`, and `ToPrimitive` traits that are not coupled to numeric types.
- Remove the dependency on `num_traits` in the base `Element` trait.
- Make the `Element` trait not require numeric traits.
- Implement the new `Element` trait for `bool`.
- Move the `full` function to `BaseTensor` once `bool` properly implements `Element`.
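A minimal, self-contained sketch of the first four steps, written in plain Rust rather than against burn. The trait names `ElementZero`, `ElementOne`, `ElementToPrimitive`, the simplified `Element` trait, and the `filled` helper are all hypothetical; they only illustrate how dropping the `num_traits` bounds lets `bool` satisfy the same element trait as `f32` and `i64`.

```rust
// Hypothetical sketch: element traits decoupled from `num_traits`.
// None of these names are burn's actual API.

/// Additive identity (for `bool`, `false`).
pub trait ElementZero {
    fn zero() -> Self;
}

/// Multiplicative identity (for `bool`, `true`).
pub trait ElementOne {
    fn one() -> Self;
}

/// Conversion to primitives without `num_traits::ToPrimitive`.
pub trait ElementToPrimitive {
    fn to_f64(&self) -> f64;
    fn to_i64(&self) -> i64;
}

/// Base element trait: no arithmetic bounds, so `bool` can implement it.
pub trait Element:
    Copy + PartialEq + std::fmt::Debug + ElementZero + ElementOne + ElementToPrimitive
{
}

impl ElementZero for f32 {
    fn zero() -> Self { 0.0 }
}
impl ElementOne for f32 {
    fn one() -> Self { 1.0 }
}
impl ElementToPrimitive for f32 {
    fn to_f64(&self) -> f64 { *self as f64 }
    fn to_i64(&self) -> i64 { *self as i64 }
}
impl Element for f32 {}

impl ElementZero for i64 {
    fn zero() -> Self { 0 }
}
impl ElementOne for i64 {
    fn one() -> Self { 1 }
}
impl ElementToPrimitive for i64 {
    fn to_f64(&self) -> f64 { *self as f64 }
    fn to_i64(&self) -> i64 { *self }
}
impl Element for i64 {}

impl ElementZero for bool {
    fn zero() -> Self { false }
}
impl ElementOne for bool {
    fn one() -> Self { true }
}
impl ElementToPrimitive for bool {
    // `bool as f64` is not a valid cast in Rust, so branch explicitly.
    fn to_f64(&self) -> f64 { if *self { 1.0 } else { 0.0 } }
    fn to_i64(&self) -> i64 { *self as i64 }
}
impl Element for bool {}

/// One generic fill that works for every element type, `bool` included.
pub fn filled<E: Element>(numel: usize, value: E) -> Vec<E> {
    vec![value; numel]
}

fn main() {
    let mask: Vec<bool> = filled(4, bool::one());
    let zeros: Vec<f32> = filled(2, f32::zero());
    println!("{:?} {:?} {}", mask, zeros, true.to_i64());
}
```

Once `bool` implements the element trait, a single generic `full` no longer needs per-kind copies or `elem()` conversions for this case.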
This will allow `bool` to be treated as a proper tensor element type, simplify the implementation of `full` and other functions in `BaseTensor`, and remove the need for workarounds like calling `elem()`.
For now, the `pad` function changes should be moved to the numeric tensor implementation to unblock the current PR. The `full` function move and the `bool` element type cleanup can be handled separately in this ticket.
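`bool_full` implementation:
```rust
// Assumes this lives in a context where `B: Backend` is in scope
// and `ElementConversion` (which provides `.elem()`) is imported.
fn bool_full<const D: usize>(
    shape: Shape<D>,
    value: bool,
    device: &Device<B>,
) -> BoolTensor<B, D> {
    // Compare an all-zeros Int tensor against a scalar:
    // `0 == 0` yields `true` everywhere, `0 == 1` yields `false` everywhere.
    B::int_equal_elem(
        B::int_zeros(shape, device),
        if value { 0.elem() } else { 1.elem() },
    )
}
```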
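To make the comparison trick in `bool_full` explicit, here is a dependency-free sketch that models tensors as flat `Vec`s. `int_zeros` and `int_equal_elem` below are hypothetical stand-ins for the backend ops, not burn code. Comparing an all-zeros tensor with `0` yields all `true`, and comparing it with `1` yields all `false`, which is why the scalar is inverted relative to `value`.

```rust
/// Hypothetical stand-in for `B::int_zeros`: a flat buffer of zeros for `shape`.
fn int_zeros(shape: &[usize]) -> Vec<i64> {
    let numel: usize = shape.iter().product();
    vec![0; numel]
}

/// Hypothetical stand-in for `B::int_equal_elem`: element-wise `x == rhs`.
fn int_equal_elem(lhs: &[i64], rhs: i64) -> Vec<bool> {
    lhs.iter().map(|&x| x == rhs).collect()
}

/// Same trick as `bool_full`: zeros == 0 is all true, zeros == 1 is all false.
fn bool_full(shape: &[usize], value: bool) -> Vec<bool> {
    let rhs = if value { 0 } else { 1 };
    int_equal_elem(&int_zeros(shape), rhs)
}

fn main() {
    println!("{:?}", bool_full(&[2, 2], true)); // [true, true, true, true]
    println!("{:?}", bool_full(&[3], false));
}
```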
We need `full` to work for bool tensors because `ConstantOfShape` can be initialized with bool values, and we can't use `full` because of this error:
```
Compiling onnx-tests v0.14.0 (/Users/dilshod/Projects/burn/crates/burn-import/onnx-tests)
error[E0277]: the trait bound `burn::tensor::Bool: Numeric<B>` is not satisfied
--> /Users/dilshod/Projects/burn/target/debug/build/onnx-tests-50e78da95971b883/out/model/constant_of_shape_full_like.rs:53:37
|
53 | let constantofshape3_out1 = Tensor::full(shape3_out1, true, &*self.device);
| ^^^^^^^^^^^^ the trait `Numeric<B>` is not implemented for `burn::tensor::Bool`
|
= help: the following other types implement trait `Numeric<B>`:
burn::tensor::Float
burn::tensor::Int
note: required by a bound in `burn_tensor::tensor::api::numeric::<impl Tensor<B, D, K>>::full`
--> /Users/dilshod/Projects/burn/crates/burn-tensor/src/tensor/api/numeric.rs:13:8
|
13 | K: Numeric<B>,
| ^^^^^^^^^^ required by this bound in `burn_tensor::tensor::api::numeric::<impl Tensor<B, D, K>>::full`
...
112 | pub fn full<S: Into<Shape<D>>, E: ElementConversion>(
| ---- required by a bound in this associated function
For more information about this error, try `rustc --explain E0277`.
error: could not compile `onnx-tests` (test "onnx_tests") due to 1 previous error
```
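The error comes from where `full` is declared: it sits in an `impl` block bounded by `K: Numeric<B>`, and `Bool` does not implement `Numeric`. The sketch below models that structure in plain Rust with hypothetical `BaseKind` and `NumericKind` marker traits (simplified, with no backend parameter) to show the idea behind the proposed fix: declaring `full` on the base impl makes it available to `Bool`, while arithmetic stays restricted to numeric kinds.

```rust
use std::marker::PhantomData;
use std::ops::Add;

// Hypothetical tensor-kind markers, loosely mirroring Float / Int / Bool.
struct Float;
struct Int;
struct Bool;

/// Operations every kind supports (hypothetical stand-in for the base trait).
trait BaseKind {
    type Elem: Copy;
}

/// Arithmetic kinds only (hypothetical stand-in for `Numeric`).
trait NumericKind: BaseKind {}

impl BaseKind for Float { type Elem = f32; }
impl BaseKind for Int { type Elem = i64; }
impl BaseKind for Bool { type Elem = bool; }

impl NumericKind for Float {}
impl NumericKind for Int {}
// No `impl NumericKind for Bool`: anything bounded by `NumericKind` rejects `Bool`.

struct Tensor<K: BaseKind> {
    data: Vec<K::Elem>,
    _kind: PhantomData<K>,
}

// `full` lives on the base impl, so every kind (including `Bool`) gets it.
impl<K: BaseKind> Tensor<K> {
    fn full(numel: usize, value: K::Elem) -> Self {
        Tensor { data: vec![value; numel], _kind: PhantomData }
    }
}

// Arithmetic stays behind the numeric bound.
impl<K: NumericKind> Tensor<K>
where
    K::Elem: Add<Output = K::Elem>,
{
    fn add_scalar(self, rhs: K::Elem) -> Self {
        let data = self.data.into_iter().map(|x| x + rhs).collect();
        Tensor { data, _kind: PhantomData }
    }
}

fn main() {
    let mask = Tensor::<Bool>::full(3, true);
    let x = Tensor::<Int>::full(3, 2).add_scalar(1);
    let y = Tensor::<Float>::full(2, 0.5).add_scalar(0.25);
    println!("{:?} {:?} {:?}", mask.data, x.data, y.data);
    // `Tensor::<Bool>::full(3, true).add_scalar(true)` would not compile,
    // because `Bool` does not implement `NumericKind`.
}
```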
Workaround to get `full` for bool tensors:
```rust
// All true
Tensor::<Backend, 3, Int>::ones(shape, &device).bool();
// All false
Tensor::<Backend, 3, Int>::zeros(shape, &device).bool();
```
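The workaround relies on casting an Int tensor to bool. Assuming the cast follows the usual rule that nonzero becomes `true` and zero becomes `false` (which is all that `ones` and `zeros` need), the idea can be modelled without burn as below. `int_full`, `int_to_bool`, and `bool_full_via_cast` are hypothetical helpers over flat `Vec`s, not burn functions.

```rust
/// Hypothetical stand-in for `Tensor::<B, D, Int>::ones` / `zeros` over a flat buffer.
fn int_full(shape: &[usize], value: i64) -> Vec<i64> {
    let numel: usize = shape.iter().product();
    vec![value; numel]
}

/// Hypothetical stand-in for `.bool()`, assuming nonzero maps to `true`.
fn int_to_bool(data: &[i64]) -> Vec<bool> {
    data.iter().map(|&x| x != 0).collect()
}

/// The workaround wrapped in one helper: ones -> all true, zeros -> all false.
fn bool_full_via_cast(shape: &[usize], value: bool) -> Vec<bool> {
    let ints = int_full(shape, if value { 1 } else { 0 });
    int_to_bool(&ints)
}

fn main() {
    println!("{:?}", bool_full_via_cast(&[2, 3], true));
    println!("{:?}", bool_full_via_cast(&[2, 3], false));
}
```

Folding both branches into one helper like this could keep generated `ConstantOfShape` code to a single call until `full` itself accepts bool values.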
@laggui, you mentioned you were working in this area. Do you think we are close to making it work?
`Bool` is now an element type :) but the methods mentioned, like `full`, have not yet been refactored, so this issue is only partially complete.