Wgpu backend of image-classification-web example does not work
When I choose the wgpu backend, I get errors in the console.
After I disable the autotune feature of burn-wgpu, the wgpu backend still does not work. The live demo works fine, so I don't think it's a problem with my device.
I tried printing the tensor converted from the input f32 slice in the console. It seems that the tensor's data has been corrupted from the very beginning.
@nathanielsimard @louisfd just be aware of this jit related bug on WASM/WebGPU
It would be great if we had WebGPU tests on CI #810
Is this bug caused by Pool2dEagerKernel? @nathanielsimard @louisfd
[START_KERNEL_COMPILATION]
name: burn_jit::kernel::pool::pool2d_shader::Pool2dEagerKernel<
burn_jit::kernel::pool::max_pool2d::MaxPool<
f32,
>,
cubecl_wgpu::runtime::WgpuRuntime,
f32,
>
cube_dim: (16, 16, 1)
shared_memory: 0 bytes
source:
```
@group(0)
@binding(0)
var<storage, read_write> input_0_global: array<f32>;
@group(0)
@binding(1)
var<storage, read_write> output_0_global: array<f32>;
@group(0)
@binding(2)
var<storage, read_write> info: array<u32>;
@group(0)
@binding(3)
var<storage, read_write> scalars_uint: array<u32, 6>;
const WORKGROUP_SIZE_X = 16u;
const WORKGROUP_SIZE_Y = 16u;
const WORKGROUP_SIZE_Z = 1u;
@compute
@workgroup_size(16, 16, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {let id = (global_id.z * num_workgroups.x * WORKGROUP_SIZE_X * num_workgroups.y * WORKGROUP_SIZE_Y) + (global_id.y * num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let rank: u32 = info[0];
let rank_2: u32 = rank * 2u;
var l_0_0: u32;
var l_0_1: u32;
var l_0_2: u32;
var l_0_3: u32;
var l_0_4: u32;
var l_0_5: u32;
var l_0_6: u32;
var l_0_7: u32;
var l_0_8: u32;
var l_0_9: u32;
var l_0_10: u32;
var l_0_11: u32;
var l_0_12: u32;
var l_0_13: u32;
var l_0_14: u32;
var l_0_15: u32;
var l_0_16: u32;
var l_0_17: u32;
var l_0_18: u32;
var l_0_19: u32;
var l_0_20: u32;
var l_0_21: u32;
var l_0_22: u32;
var l_0_23: u32;
var l_0_24: u32;
var l_0_25: f32;
var l_0_26: u32;
var l_0_27: u32;
var l_0_28: u32;
var l_0_29: u32;
var l_0_30: u32;
var l_0_31: u32;
var l_0_32: bool;
var l_0_33: bool;
var l_0_34: bool;
var l_0_35: u32;
var l_0_36: u32;
var l_0_37: f32;
l_0_0 = info[(0u * rank_2) + 0u + 1u];
l_0_1 = info[(0u * rank_2) + 1u + 1u];
l_0_2 = info[(0u * rank_2) + 2u + 1u];
l_0_3 = info[(0u * rank_2) + 3u + 1u];
l_0_4 = info[(0u * rank_2) + rank + 2u + 1u];
l_0_5 = info[(0u * rank_2) + rank + 3u + 1u];
l_0_6 = info[(0u * rank_2) + rank + 2u + 1u];
l_0_7 = info[(0u * rank_2) + rank + 3u + 1u];
l_0_8 = info[(1u * rank_2) + 0u + 1u];
l_0_9 = info[(1u * rank_2) + 1u + 1u];
l_0_10 = info[(1u * rank_2) + 2u + 1u];
l_0_11 = info[(1u * rank_2) + 3u + 1u];
l_0_12 = info[(1u * rank_2) + rank + 0u + 1u];
l_0_13 = info[(1u * rank_2) + rank + 1u + 1u];
l_0_14 = info[(1u * rank_2) + rank + 2u + 1u];
l_0_15 = info[(1u * rank_2) + rank + 3u + 1u];
l_0_16 = id / l_0_8;
l_0_16 = l_0_16 % l_0_12;
l_0_17 = id / l_0_9;
l_0_17 = l_0_17 % l_0_13;
l_0_18 = id / l_0_10;
l_0_18 = l_0_18 % l_0_14;
l_0_19 = id / l_0_11;
l_0_19 = l_0_19 % l_0_15;
l_0_35 = l_0_6 + scalars_uint[4];
l_0_36 = l_0_7 + scalars_uint[5];
l_0_27 = l_0_16 * l_0_0;
l_0_28 = l_0_17 * l_0_1;
l_0_37 = f32(-340282350000000000000000000000000000000f);
l_0_20 = l_0_18 * scalars_uint[0];
l_0_22 = 0u * scalars_uint[2];
l_0_20 = l_0_20 + l_0_22;
l_0_32 = l_0_20 >= scalars_uint[4];
l_0_34 = l_0_20 < l_0_35;
l_0_32 = l_0_32 && l_0_34;
if l_0_32 {
l_0_21 = l_0_19 * scalars_uint[1];
l_0_22 = 0u * scalars_uint[3];
l_0_21 = l_0_21 + l_0_22;
l_0_33 = l_0_21 >= scalars_uint[5];
l_0_34 = l_0_21 < l_0_36;
l_0_33 = l_0_33 && l_0_34;
if l_0_33 {
var l_2_0: bool;
l_0_23 = l_0_20 - scalars_uint[4];
l_0_24 = l_0_21 - scalars_uint[5];
l_0_29 = l_0_23 * l_0_2;
l_0_31 = u32(l_0_29);
l_0_31 = l_0_31 + l_0_24;
l_0_30 = l_0_24 * l_0_3;
l_0_26 = u32(l_0_27);
l_0_26 = l_0_26 + l_0_28;
l_0_26 = l_0_26 + l_0_29;
l_0_26 = l_0_26 + l_0_30;
l_0_25 = f32(input_0_global[l_0_26]);
l_2_0 = l_0_25 > l_0_37;
if l_2_0 {
l_0_37 = f32(l_0_25);
}
}
l_0_21 = l_0_19 * scalars_uint[1];
l_0_22 = 1u * scalars_uint[3];
l_0_21 = l_0_21 + l_0_22;
l_0_33 = l_0_21 >= scalars_uint[5];
l_0_34 = l_0_21 < l_0_36;
l_0_33 = l_0_33 && l_0_34;
if l_0_33 {
var l_2_0: bool;
l_0_23 = l_0_20 - scalars_uint[4];
l_0_24 = l_0_21 - scalars_uint[5];
l_0_29 = l_0_23 * l_0_2;
l_0_31 = u32(l_0_29);
l_0_31 = l_0_31 + l_0_24;
l_0_30 = l_0_24 * l_0_3;
l_0_26 = u32(l_0_27);
l_0_26 = l_0_26 + l_0_28;
l_0_26 = l_0_26 + l_0_29;
l_0_26 = l_0_26 + l_0_30;
l_0_25 = f32(input_0_global[l_0_26]);
l_2_0 = l_0_25 > l_0_37;
if l_2_0 {
l_0_37 = f32(l_0_25);
}
}
l_0_21 = l_0_19 * scalars_uint[1];
l_0_22 = 2u * scalars_uint[3];
l_0_21 = l_0_21 + l_0_22;
l_0_33 = l_0_21 >= scalars_uint[5];
l_0_34 = l_0_21 < l_0_36;
l_0_33 = l_0_33 && l_0_34;
if l_0_33 {
var l_2_0: bool;
l_0_23 = l_0_20 - scalars_uint[4];
l_0_24 = l_0_21 - scalars_uint[5];
l_0_29 = l_0_23 * l_0_2;
l_0_31 = u32(l_0_29);
l_0_31 = l_0_31 + l_0_24;
l_0_30 = l_0_24 * l_0_3;
l_0_26 = u32(l_0_27);
l_0_26 = l_0_26 + l_0_28;
l_0_26 = l_0_26 + l_0_29;
l_0_26 = l_0_26 + l_0_30;
l_0_25 = f32(input_0_global[l_0_26]);
l_2_0 = l_0_25 > l_0_37;
if l_2_0 {
l_0_37 = f32(l_0_25);
}
}
}
l_0_20 = l_0_18 * scalars_uint[0];
l_0_22 = 1u * scalars_uint[2];
l_0_20 = l_0_20 + l_0_22;
l_0_32 = l_0_20 >= scalars_uint[4];
l_0_34 = l_0_20 < l_0_35;
l_0_32 = l_0_32 && l_0_34;
if l_0_32 {
l_0_21 = l_0_19 * scalars_uint[1];
l_0_22 = 0u * scalars_uint[3];
l_0_21 = l_0_21 + l_0_22;
l_0_33 = l_0_21 >= scalars_uint[5];
l_0_34 = l_0_21 < l_0_36;
l_0_33 = l_0_33 && l_0_34;
if l_0_33 {
var l_2_0: bool;
l_0_23 = l_0_20 - scalars_uint[4];
l_0_24 = l_0_21 - scalars_uint[5];
l_0_29 = l_0_23 * l_0_2;
l_0_31 = u32(l_0_29);
l_0_31 = l_0_31 + l_0_24;
l_0_30 = l_0_24 * l_0_3;
l_0_26 = u32(l_0_27);
l_0_26 = l_0_26 + l_0_28;
l_0_26 = l_0_26 + l_0_29;
l_0_26 = l_0_26 + l_0_30;
l_0_25 = f32(input_0_global[l_0_26]);
l_2_0 = l_0_25 > l_0_37;
if l_2_0 {
l_0_37 = f32(l_0_25);
}
}
l_0_21 = l_0_19 * scalars_uint[1];
l_0_22 = 1u * scalars_uint[3];
l_0_21 = l_0_21 + l_0_22;
l_0_33 = l_0_21 >= scalars_uint[5];
l_0_34 = l_0_21 < l_0_36;
l_0_33 = l_0_33 && l_0_34;
if l_0_33 {
var l_2_0: bool;
l_0_23 = l_0_20 - scalars_uint[4];
l_0_24 = l_0_21 - scalars_uint[5];
l_0_29 = l_0_23 * l_0_2;
l_0_31 = u32(l_0_29);
l_0_31 = l_0_31 + l_0_24;
l_0_30 = l_0_24 * l_0_3;
l_0_26 = u32(l_0_27);
l_0_26 = l_0_26 + l_0_28;
l_0_26 = l_0_26 + l_0_29;
l_0_26 = l_0_26 + l_0_30;
l_0_25 = f32(input_0_global[l_0_26]);
l_2_0 = l_0_25 > l_0_37;
if l_2_0 {
l_0_37 = f32(l_0_25);
}
}
l_0_21 = l_0_19 * scalars_uint[1];
l_0_22 = 2u * scalars_uint[3];
l_0_21 = l_0_21 + l_0_22;
l_0_33 = l_0_21 >= scalars_uint[5];
l_0_34 = l_0_21 < l_0_36;
l_0_33 = l_0_33 && l_0_34;
if l_0_33 {
var l_2_0: bool;
l_0_23 = l_0_20 - scalars_uint[4];
l_0_24 = l_0_21 - scalars_uint[5];
l_0_29 = l_0_23 * l_0_2;
l_0_31 = u32(l_0_29);
l_0_31 = l_0_31 + l_0_24;
l_0_30 = l_0_24 * l_0_3;
l_0_26 = u32(l_0_27);
l_0_26 = l_0_26 + l_0_28;
l_0_26 = l_0_26 + l_0_29;
l_0_26 = l_0_26 + l_0_30;
l_0_25 = f32(input_0_global[l_0_26]);
l_2_0 = l_0_25 > l_0_37;
if l_2_0 {
l_0_37 = f32(l_0_25);
}
}
}
l_0_20 = l_0_18 * scalars_uint[0];
l_0_22 = 2u * scalars_uint[2];
l_0_20 = l_0_20 + l_0_22;
l_0_32 = l_0_20 >= scalars_uint[4];
l_0_34 = l_0_20 < l_0_35;
l_0_32 = l_0_32 && l_0_34;
if l_0_32 {
l_0_21 = l_0_19 * scalars_uint[1];
l_0_22 = 0u * scalars_uint[3];
l_0_21 = l_0_21 + l_0_22;
l_0_33 = l_0_21 >= scalars_uint[5];
l_0_34 = l_0_21 < l_0_36;
l_0_33 = l_0_33 && l_0_34;
if l_0_33 {
var l_2_0: bool;
l_0_23 = l_0_20 - scalars_uint[4];
l_0_24 = l_0_21 - scalars_uint[5];
l_0_29 = l_0_23 * l_0_2;
l_0_31 = u32(l_0_29);
l_0_31 = l_0_31 + l_0_24;
l_0_30 = l_0_24 * l_0_3;
l_0_26 = u32(l_0_27);
l_0_26 = l_0_26 + l_0_28;
l_0_26 = l_0_26 + l_0_29;
l_0_26 = l_0_26 + l_0_30;
l_0_25 = f32(input_0_global[l_0_26]);
l_2_0 = l_0_25 > l_0_37;
if l_2_0 {
l_0_37 = f32(l_0_25);
}
}
l_0_21 = l_0_19 * scalars_uint[1];
l_0_22 = 1u * scalars_uint[3];
l_0_21 = l_0_21 + l_0_22;
l_0_33 = l_0_21 >= scalars_uint[5];
l_0_34 = l_0_21 < l_0_36;
l_0_33 = l_0_33 && l_0_34;
if l_0_33 {
var l_2_0: bool;
l_0_23 = l_0_20 - scalars_uint[4];
l_0_24 = l_0_21 - scalars_uint[5];
l_0_29 = l_0_23 * l_0_2;
l_0_31 = u32(l_0_29);
l_0_31 = l_0_31 + l_0_24;
l_0_30 = l_0_24 * l_0_3;
l_0_26 = u32(l_0_27);
l_0_26 = l_0_26 + l_0_28;
l_0_26 = l_0_26 + l_0_29;
l_0_26 = l_0_26 + l_0_30;
l_0_25 = f32(input_0_global[l_0_26]);
l_2_0 = l_0_25 > l_0_37;
if l_2_0 {
l_0_37 = f32(l_0_25);
}
}
l_0_21 = l_0_19 * scalars_uint[1];
l_0_22 = 2u * scalars_uint[3];
l_0_21 = l_0_21 + l_0_22;
l_0_33 = l_0_21 >= scalars_uint[5];
l_0_34 = l_0_21 < l_0_36;
l_0_33 = l_0_33 && l_0_34;
if l_0_33 {
var l_2_0: bool;
l_0_23 = l_0_20 - scalars_uint[4];
l_0_24 = l_0_21 - scalars_uint[5];
l_0_29 = l_0_23 * l_0_2;
l_0_31 = u32(l_0_29);
l_0_31 = l_0_31 + l_0_24;
l_0_30 = l_0_24 * l_0_3;
l_0_26 = u32(l_0_27);
l_0_26 = l_0_26 + l_0_28;
l_0_26 = l_0_26 + l_0_29;
l_0_26 = l_0_26 + l_0_30;
l_0_25 = f32(input_0_global[l_0_26]);
l_2_0 = l_0_25 > l_0_37;
if l_2_0 {
l_0_37 = f32(l_0_25);
}
}
}
output_0_global[id] = f32(l_0_37);
}
```
[END_KERNEL_COMPILATION]
This bug is caused by the limited default precision of Rust's Display formatting for f32 values: f32::MIN (-3.40282347E+38) is printed into the generated shader as -340282350000000000000000000000000000000f (see the l_0_37 initialization above), a literal that is out of range for f32. It should be an easy fix.
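For illustration, a minimal standalone sketch of the formatting behavior (plain Rust, independent of burn):

```rust
fn main() {
    let v = f32::MIN; // -3.40282347E+38, the running-max init value in the shader above
    // Default Display prints the shortest decimal that round-trips as f32;
    // read back as an exact decimal, its magnitude exceeds f32::MAX, so a
    // strict WGSL front end rejects the literal.
    println!("{}", v); // -340282350000000000000000000000000000000
    // A fixed precision prints the exact decimal expansion instead:
    println!("{:.9}", v); // -340282346638528859811704183484516925440.000000000
}
```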
I changed the code at https://github.com/tracel-ai/cubecl/blob/cfe0b0204380cbd0931f478194a053a6ac35d1cb/crates/cubecl-wgpu/src/compiler/wgsl/base.rs#L262-L263 from:

```rust
FloatKind::F32 => f.write_fmt(format_args!("{}f", *val as f32)),
FloatKind::F64 => f.write_fmt(format_args!("{}f", { *val })),
```

to:

```rust
FloatKind::F32 => f.write_fmt(format_args!("{:.9}f", *val as f32)),
FloatKind::F64 => f.write_fmt(format_args!("{:.17}f", { *val })),
```

This fixes the bug.
I thought maybe it was due to incompatible versions between the main branch and the examples, but I have the same issue after downloading the 0.13.2 release out-of-the-box, then cd examples/image-classification-web, then ./build-for-web.sh. In fact, it would not even compile without first manually adding features = ["autotune"] to the burn-wgpu crate dependency.
Then I tried the 0.13.1 release as well as the 0.13.0 release, and both produce the same issue as described in the first place (also, the Candle backend does not work either, as per issue #1034).
I am just not able to reproduce image-classification-web with any 0.13+ version (I didn't look at earlier versions), except for the ndarray backend, which works seamlessly.
@antimora did you by any chance keep around the original repo you used to make your published version work? It would be very helpful to diff it and check what's wrong.
For instance, how did you make the Candle backend work without the AvgPool2d op in the first place? Did you switch to another model than squeezenet for the sake of this example? Also, how did you manage to get the wgpu backend to load? (If I am correct, the solution suggested above by @wcshds only applies to 0.14+ versions that include the cubecl dependency, correct?)
By the way, thanks for sharing this great project :)
In fact, I've tried every release version of Burn since 0.10.0 on my device, and the wgpu backend in the image-classification-web example has never worked. However, after the small modifications I mentioned above, I can now run this example successfully on the main branch.
Reproduction steps:

```sh
# Clone the repo and get into it
git clone git@github.com:tracel-ai/burn.git
cd burn
# Change the cubecl dependency to a revision that includes the suggested bugfix
# (https://github.com/tracel-ai/cubecl/commit/32feabc5140170d45d4365a56106db930ed79a33).
# For reproduction purposes I use the sd utility (https://github.com/chmln/sd),
# but one can just edit Cargo.toml manually for both cubecl AND cubecl-common.
sd '(cubecl.* rev =).*(\})' '$1 "32feabc5140170d45d4365a56106db930ed79a33" $2' Cargo.toml
# cd into the relevant example, compile it, and run the server
cd examples/image-classification-web
./build-for-web.sh
./run-server.sh
```
RESULTS:
✅ NdArray backend: working (slower than the 0.13.2 version by an order of magnitude, but still okay)
❗ Candle backend: the backend LOADS correctly but cannot do inference: "Candle does not support excluding pad count in pooling"
❌ Wgpu backend: cannot load: "An home directory should exist"
Did I miss something here?
@Jonarod You also need to disable the autotune feature for burn-wgpu.
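In manifest terms, a rough sketch of that change (the dependency spec is assumed, not copied from the repo; presumably the autotune cache wants a home directory, which a WASM target lacks):

```toml
# examples/image-classification-web/Cargo.toml
# before (wgpu fails to load with "An home directory should exist"):
# burn-wgpu = { path = "../../crates/burn-wgpu", features = ["autotune"] }
# after, with autotune disabled:
burn-wgpu = { path = "../../crates/burn-wgpu" }
```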
As for Candle, I don't know how to make it work either.
Hey @wcshds 👋
I was going to update the cubecl dep to the latest to include some fixes, but some wgpu tests are failing with your merged PR.
```
failures:
tests::jit::gradients::tests::should_update_tensor_when_grad_replace
tests::jit::kernel::bernoulli::tests::number_of_1_proportional_to_prob
tests::jit::kernel::bernoulli::tests::runs_test
tests::jit::kernel::normal::tests::empirical_mean_close_to_expectation
tests::jit::kernel::normal::tests::normal_respects_68_95_99_rule
tests::jit::kernel::uniform::tests::at_least_one_value_per_bin_int_uniform
tests::jit::kernel::uniform::tests::at_least_one_value_per_bin_uniform
tests::jit::kernel::uniform::tests::runs_test
tests::jit_fusion::gradients::tests::should_update_tensor_when_grad_replace
```
I tried to understand the problem in this issue but I think I'm missing a bit of context... could you explain why the precision change was required?
/edit: see PRs #2159 #2158 as reference.
Worked like a charm.
Thanks for your help.
I submitted a corresponding PR that should solve this issue.
If you read this before the PR is merged to main, the solution, as suggested by @wcshds, is to:

- change the `cubecl` and `cubecl-common` revisions to `32feabc5140170d45d4365a56106db930ed79a33` in burn's root `Cargo.toml` (sketched below)
- remove `burn-wgpu`'s `features = [ "autotune" ]` from `examples/image-classification-web/Cargo.toml`
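For concreteness, a rough sketch of the first edit (the git source and table layout are assumed, not copied from the repo; the autotune removal is sketched a few comments above):

```toml
# burn's root Cargo.toml: pin both cubecl crates to the fixed revision
cubecl = { git = "https://github.com/tracel-ai/cubecl", rev = "32feabc5140170d45d4365a56106db930ed79a33" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", rev = "32feabc5140170d45d4365a56106db930ed79a33" }
```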
I think this can be closed now.
@laggui This is because I saw in the browser console that `-3.40282347E+38f32` was rounded to `-340282350000000000000000000000000000000f`, which caused a WGSL compilation error, so I believe this is an issue with Rust's default display precision. It's strange that the tests failed, but changing the precision to 13 decimal places solved the problem:

```rust
FloatKind::F32 => f.write_fmt(format_args!("{:.13}f", *val as f32)),
```

I originally thought that a precision of 9 decimal places would be sufficient for f32.
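One guess (not confirmed in this thread) as to why a fixed precision can still break things: `{:.N}` fixes digits after the decimal point, not significant digits, so very small constants get truncated to zero. A minimal standalone illustration:

```rust
fn main() {
    // {:.9} keeps 9 digits *after the decimal point*, not 9 significant
    // digits, so a small constant collapses to zero:
    println!("{:.9}", 1.0e-10_f32); // 0.000000000
    // whereas the default Display preserves it:
    println!("{}", 1.0e-10_f32); // 0.0000000001
}
```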
Weird 😅
Not sure if the fix is the proper way to address this or if it's just a patch for a more specific issue. It doesn't seem to happen anywhere else 🤔
/edit: fyi, we have decided to revert the changes applied to the precision for now. The current workaround is at least documented for users to try while we investigate why this happens in this specific example.
It looks like this issue has been resolved in cubecl 130.