Training Issues (NaN) When Migrating from PyTorch
First, I want to express appreciation for the Burn framework - it's a great step toward bringing ML capabilities to Rust. I'm working on migrating several small PyTorch models to Rust, but I've encountered some issues with BCE loss calculation and training behavior.
When migrating a simple binary classifier from PyTorch to Burn, I'm seeing significant differences in:
- BCE loss calculation
- Training behavior
- Final model accuracy (46% with Burn, 92% with PyTorch)
The model architecture is identical between frameworks (input->64->32->1 with ReLU/Sigmoid).
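For context, here is a rough sketch of that architecture as a Burn module (field names guessed from the model printout under Training Results, not the exact code from the repository):

use burn::{
    module::Module,
    nn::{Linear, LinearConfig, Relu, Sigmoid},
    tensor::{backend::Backend, Tensor},
};

// 20 -> 64 -> 32 -> 1, ReLU between layers, Sigmoid on the output.
#[derive(Module, Debug)]
pub struct DemoClassifierModel<B: Backend> {
    input_layer: Linear<B>,
    hidden_layer1: Linear<B>,
    output_layer: Linear<B>,
    activation: Relu,
    sigmoid: Sigmoid,
}

impl<B: Backend> DemoClassifierModel<B> {
    pub fn new(device: &B::Device) -> Self {
        Self {
            input_layer: LinearConfig::new(20, 64).init(device),
            hidden_layer1: LinearConfig::new(64, 32).init(device),
            output_layer: LinearConfig::new(32, 1).init(device),
            activation: Relu::new(),
            sigmoid: Sigmoid::new(),
        }
    }

    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.activation.forward(self.input_layer.forward(x));
        let x = self.activation.forward(self.hidden_layer1.forward(x));
        self.sigmoid.forward(self.output_layer.forward(x))
    }
}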
Reproducible Example
I've created a minimal example repository: [burn-problems]
Key test cases demonstrate the issues:
BCE Loss Test
# Python/PyTorch
Test Case 1 - Perfect predictions:
Predictions: tensor([1., 0., 1., 0.])
Targets: tensor([1., 0., 1., 0.])
Loss: 0.00000000
Test Case 2 - Wrong predictions:
Predictions: tensor([0., 1., 0., 1.])
Targets: tensor([1., 0., 1., 0.])
Loss: 100.00000000
Test Case 3 - Uncertain predictions:
Predictions: tensor([0.5000, 0.5000, 0.5000, 0.5000])
Targets: tensor([1., 0., 1., 0.])
Loss: 0.69314718
// Rust/Burn
Test Case 1 - Perfect predictions:
Predictions: tensor([1.0000, 0.0000, 1.0000, 0.0000])
Targets: tensor([1, 0, 1, 0])
Loss: NaN
Test Case 2 - Wrong predictions:
Predictions: tensor([0.0000, 1.0000, 0.0000, 1.0000])
Targets: tensor([1, 0, 1, 0])
Loss: inf
Test Case 3 - Uncertain predictions:
Predictions: tensor([0.5000, 0.5000, 0.5000, 0.5000])
Targets: tensor([1, 0, 1, 0])
Loss: 0.69314718
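For reference, the Burn-side numbers above can be reproduced by calling the BCE loss directly on hand-built tensors. A minimal sketch along these lines (the NdArray backend and the printing are my choice; the linked repository has the actual test code):

use burn::{
    backend::{ndarray::NdArrayDevice, NdArray},
    nn::loss::BinaryCrossEntropyLossConfig,
    tensor::{Int, Tensor, TensorData},
};

fn main() {
    type B = NdArray<f32>;
    let device = NdArrayDevice::default();

    // Targets are fixed at [1, 0, 1, 0]; only the predictions vary per case.
    let cases = [
        ("Perfect predictions", [1.0, 0.0, 1.0, 0.0]),
        ("Wrong predictions", [0.0, 1.0, 0.0, 1.0]),
        ("Uncertain predictions", [0.5, 0.5, 0.5, 0.5]),
    ];

    for (name, preds) in cases {
        let preds = Tensor::<B, 1>::from_floats(preds, &device);
        let targets = Tensor::<B, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);
        let loss = BinaryCrossEntropyLossConfig::new()
            .init(&device)
            .forward(preds, targets);
        // Observed above on the affected version: NaN, inf, 0.69314718
        println!("{name}: {:?}", loss.into_data());
    }
}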
Training Results
- PyTorch achieves 93.36% accuracy
- Burn implementation stops at ~46% accuracy with NaN loss
Model:
DemoClassifierModel {
input_layer: Linear {d_input: 20, d_output: 64, bias: true, params: 1344}
hidden_layer1: Linear {d_input: 64, d_output: 32, bias: true, params: 2080}
output_layer: Linear {d_input: 32, d_output: 1, bias: true, params: 33}
activation: Relu
sigmoid: Sigmoid
params: 3457
}
Total Epochs: 20
| Split | Metric | Min. | Epoch | Max. | Epoch |
|-------|----------|----------|----------|----------|----------|
| Train | Accuracy | 46.113 | 1 | 46.113 | 20 |
| Train | Loss | 0.279 | 10 | NaN | 20 |
| Valid | Accuracy | 45.766 | 1 | 45.766 | 20 |
| Valid | Loss | 0.282 | 10 | NaN | 20 |
I noticed that the accuracy matches the proportion of zero labels in the training set:
Train distribution: total=14690, zeroes=6742 (45.9%), ones=7948 (54.1%)
The critical section appears to be the BCE loss calculation:
// Current implementation
let loss = BinaryCrossEntropyLossConfig::new()
    .init(&output.device())
    .forward(output.clone().squeeze(1), targets.clone());

// Alternative attempt
// let loss = BinaryCrossEntropyLossConfig::new()
//     .init(&output.device())
//     .forward(output.clone(), targets.clone().reshape([batch_size, 1]));
Questions
- Recommended way to handle tensor shapes for BCE loss in Burn?
- Are there any known issues with batch dimension handling that could cause these discrepancies?
- Should the loss calculation approach differ from PyTorch's implementation?
Thanks for flagging this!
I believe this is due to the current implementation of the BCE loss: tensor.log() returns -inf for values of 0.0. We need to clamp the values to make sure we don't hit this issue.
Should have a PR to fix this soon.
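In the meantime, a hand-rolled BCE that clamps the probabilities away from 0 and 1 works as a stopgap. This is only a sketch of the idea (the eps value and function name are mine), not Burn's implementation:

use burn::tensor::{backend::Backend, Int, Tensor};

/// Hand-rolled BCE that clamps probabilities into [eps, 1 - eps] so that
/// log() never sees exactly 0.0. Sketch only, not Burn's implementation.
fn bce_clamped<B: Backend>(
    preds: Tensor<B, 1>,        // sigmoid outputs in [0, 1]
    targets: Tensor<B, 1, Int>, // 0/1 labels
) -> Tensor<B, 1> {
    let eps = 1e-7;
    let p = preds.clamp(eps, 1.0 - eps);
    let y = targets.float();
    let one_minus_y = y.clone().neg().add_scalar(1.0);
    let one_minus_p = p.clone().neg().add_scalar(1.0);
    // -(y * ln(p) + (1 - y) * ln(1 - p)), averaged over the batch
    (y * p.log() + one_minus_y * one_minus_p.log()).neg().mean()
}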
@laggui The latest version of Burn fails to compile with backend::Wgpu:
error[E0275]: overflow evaluating the requirement `wgpu_core::validation::NumericType: Sync`
|
= help: consider increasing the recursion limit by adding a `#![recursion_limit = "256"]` attribute to your crate (`burn_problems`)
note: required because it appears within the type `wgpu_core::validation::InterfaceVar`
--> .cargo/registry/src/index.crates.io-6f17d22bba15001f/wgpu-core-24.0.0/src/validation.rs:109:12
|
109 | pub struct InterfaceVar {
| ^^^^^^^^^^^^
note: required because it appears within the type `wgpu_core::validation::Varying`
--> .cargo/registry/src/index.crates.io-6f17d22bba15001f/wgpu-core-24.0.0/src/validation.rs:136:6
|
136 | enum Varying {
| ^^^^^^^
note: required because it appears within the type `PhantomData<wgpu_core::validation::Varying>`
--> .rustup/toolchains/stable-aarch64-apple-darwin/lib/rustlib/src/rust/library/core/src/marker.rs:753:12
|
753 | pub struct PhantomData<T: ?Sized>;
....
With backend::NdArray I got the same results (46%), so the issue is probably not in the BCE loss =(
https://github.com/oiwn/burn-problems/commit/c36f56a6ff3b2969d798cc3c25e4fb70120103f6
Model:
DemoClassifierModel {
input_layer: Linear {d_input: 20, d_output: 64, bias: true, params: 1344}
hidden_layer1: Linear {d_input: 64, d_output: 32, bias: true, params: 2080}
output_layer: Linear {d_input: 32, d_output: 1, bias: true, params: 33}
activation: Relu
sigmoid: Sigmoid
params: 3457
}
Total Epochs: 20
| Split | Metric | Min. | Epoch | Max. | Epoch |
|-------|----------|----------|----------|----------|----------|
| Train | Loss | 0.265 | 13 | NaN | 20 |
| Train | Accuracy | 46.052 | 1 | 46.052 | 20 |
| Valid | Loss | 0.258 | 13 | NaN | 20 |
| Valid | Accuracy | 46.011 | 1 | 46.011 | 20 |
Yeah, we realized this the other day with the upgrade to wgpu 24.0.0... see this Discord conversation for reference.
This seems to stem from new complex types in wgpu. As a temporary fix you can actually follow the compiler's help and increase the recursion limit (the default is 128). You probably don't need to double it to 256; something around 140 should work.
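Concretely, that means adding the inner attribute at the very top of the crate root; a minimal sketch (the value is just the ballpark suggested above):

// src/main.rs (or src/lib.rs): inner attributes must come before any items.
#![recursion_limit = "140"]

fn main() {
    // ... training code unchanged
}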
> With backend::NdArray I got the same results (46%), so the issue is probably not in the BCE loss =(
I haven't actually tested the whole thing, just isolated the BCE loss bug initially 😅 Seems weird that your loss still NaNs 🤔 I'll check it out.
Edit: just took a quick glance; it looks like it's actually coming from the first linear layer's parameters becoming NaN at some point. I'm assuming you validated the input data?
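One way to pinpoint where that happens is to spot-check the parameters between optimizer steps. A rough debugging sketch (the helper name and the f32 assumption are mine):

use burn::tensor::{backend::Backend, Tensor};

/// Debug helper: copy a tensor back to the host and panic if any element is
/// NaN, so the exact step that corrupts the parameters can be found.
fn assert_no_nan<B: Backend, const D: usize>(name: &str, tensor: Tensor<B, D>) {
    let values: Vec<f32> = tensor.into_data().to_vec().expect("expected an f32 tensor");
    assert!(
        values.iter().all(|v| !v.is_nan()),
        "{name} contains NaN values"
    );
}

// Usage sketch, e.g. after each optimizer step:
// assert_no_nan("input_layer.weight", model.input_layer.weight.val());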
@laggui Input data are identical for PyTorch and Burn.
Pytorch:
=== Training Data Check (Python) ===
Dataset sizes:
Train: 14690 Test: 3673
Feature statistics (train) (first 3):
Feature 0 (feature1):
Mean: 0.0029
StdDev: 1.0035
Min: -2.1987
Max: 5.7738
Feature 1 (feature2):
Mean: -0.0012
StdDev: 0.9764
Min: -1.3040
Max: 17.6505
Feature 2 (feature3):
Mean: 0.0044
StdDev: 1.0126
Min: -0.7085
Max: 7.9158
Burn:
=== Training Data Check (Rust) ===
Dataset sizes:
Train: 14690 Test: 3673
Feature statistics (train) (first 3):
Feature 0:
Mean: -0.0034
StdDev: 1.0025
Min: -2.1987
Max: 5.7738
Feature 1:
Mean: 0.0006
StdDev: 0.9915
Min: -1.3040
Max: 21.7539
Feature 2:
Mean: 0.0012
StdDev: 1.0048
Min: -0.7085
Max: 7.9158
Some slight variations, but as long as the inputs during training don't contain any weird outlier values, I don't think that will be the issue.
I'll reopen this issue but it doesn't seem to be specific to the BCE loss anymore.
It might also be a difference in the training configuration. If you have a higher learning rate or are missing weight decay, it might lead to unstable training, resulting in NaN values, which render the model useless.
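For illustration, a more conservative setup might look like the sketch below; the values are assumptions, not tuned settings for this dataset:

use burn::optim::{decay::WeightDecayConfig, AdamConfig};

// Hypothetical conservative training setup: Adam with a small weight decay
// and an explicitly low learning rate (the learning rate is passed to the
// learner/optimizer step separately).
const LEARNING_RATE: f64 = 1.0e-3;

fn optimizer_config() -> AdamConfig {
    AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(1.0e-5)))
}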
> It might also be a difference in the training configuration. If you have a higher learning rate or are missing weight decay, it might lead to unstable training, resulting in NaN values, which render the model useless.
Tried different learning rates with no luck. There is a strange correlation between the training accuracy (46%) and the proportion of zero labels in the training set.
> There is a strange correlation between the training accuracy (46%) and the proportion of zero labels in the training set.
The model has not converged as expected if you're getting NaNs during training, so the run that reports 46% accuracy has probably collapsed to a simple heuristic that always predicts the same label. A constant predictor is correct on exactly the share of examples carrying that label, which is why the accuracy tracks the label distribution: effectively always predicting zero would be right on the ~46% of zero-labeled examples 🙂
Regarding the cause for the NaNs, not quite sure at first glance. Would have to spend a bit of time to investigate.
@laggui thank you!
Hey all,
I seem to have the same issue. Every now and then, and seemingly at random, I get NaN as the loss.
Here is an example that caused the loss to be NaN:
loss_d_real contains nan Tensor {
data:
[NaN],
shape: [1],
device: DefaultDevice,
backend: "autodiff<fusion<jit<wgpu<wgsl>>>>",
kind: "Float",
dtype: "f32",
}. it was made from Tensor {
data:
[0.9999989, 0.9999994, 1.0, 0.9999983],
shape: [4],
device: DefaultDevice,
backend: "autodiff<fusion<jit<wgpu<wgsl>>>>",
kind: "Float",
dtype: "f32",
} and Tensor {
data:
[1, 1, 1, 1],
shape: [4],
device: DefaultDevice,
backend: "autodiff<fusion<jit<wgpu<wgsl>>>>",
kind: "Int",
dtype: "i32",
} with bce loss
use burn::{
    backend::{Autodiff, Wgpu},
    nn::loss::BinaryCrossEntropyLossConfig,
    tensor::{Int, Tensor, TensorData},
};

#[test]
fn test_binary_cross_entropy_preds_almost_correct() {
    type MyBackend = Wgpu<f32, i32>;
    type MyAutodiffBackend = Autodiff<MyBackend>;

    let device = burn::backend::wgpu::WgpuDevice::default();
    let preds = Tensor::<MyAutodiffBackend, 1>::from_floats(
        [0.9999989, 0.9999994, 1.0, 0.9999983],
        &device,
    );
    let targets =
        Tensor::<MyAutodiffBackend, 1, Int>::from_data(TensorData::from([1, 1, 1, 1]), &device);

    let bce = BinaryCrossEntropyLossConfig::new().init(&device);
    let loss_actual = bce.forward(preds, targets).into_data();
    let loss_expected = TensorData::from([8.49366756483505e-07]);

    loss_actual.assert_approx_eq(&loss_expected, 7);
}
I adjusted one of your test cases to reproduce this behavior.
Edit: I cross-checked it with PyTorch and adjusted the test with the expected value.
> Edit: I cross-checked it with PyTorch and adjusted the test with the expected value.
I think the problem is in the BCE loss calculation: it shows NaN/inf on edge cases where PyTorch clamps them to 0 and 100.
Is this test case a reproducible failure on your machine? Just tried it locally and it passes 🤔
Currently it is.
thread 'test_binary_cross_entropy_preds_almost_correct' panicked at crates\sketchy_pix2pix\tests\bce_test.rs:23:17:
Tensors are not approx eq:
=> Position 0: NaN != 0.000000849366756483505 | difference NaN > tolerance 0.00000010000000000000004
stack backtrace:
0: std::panicking::begin_panic_handler
at /rustc/4eb161250e340c8f48f66e2b929ef4a5bed7c181/library\std\src\panicking.rs:692
1: core::panicking::panic_fmt
at /rustc/4eb161250e340c8f48f66e2b929ef4a5bed7c181/library\core\src\panicking.rs:75
2: burn_tensor::tensor::data::TensorData::assert_approx_eq_diff
at C:\Users\user\.cargo\registry\src\index.crates.io-1949cf8c6b5b557f\burn-tensor-0.16.1\src\tensor\data.rs:594
3: burn_tensor::tensor::data::TensorData::assert_approx_eq
at C:\Users\user\.cargo\registry\src\index.crates.io-1949cf8c6b5b557f\burn-tensor-0.16.1\src\tensor\data.rs:434
4: bce_test::test_binary_cross_entropy_preds_almost_correct
at .\tests\bce_test.rs:23
5: bce_test::test_binary_cross_entropy_preds_almost_correct::closure$0
at .\tests\bce_test.rs:8
6: core::ops::function::FnOnce::call_once<bce_test::test_binary_cross_entropy_preds_almost_correct::closure_env$0,tuple$<> >
at C:\Users\user\.rustup\toolchains\stable-x86_64-pc-windows-msvc\lib\rustlib\src\rust\library\core\src\ops\function.rs:250
7: core::ops::function::FnOnce::call_once
at /rustc/4eb161250e340c8f48f66e2b929ef4a5bed7c181/library\core\src\ops\function.rs:250
note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.
failures:
test_binary_cross_entropy_preds_almost_correct
test result: FAILED. 0 passed; 1 failed; 0 ignored; 0 measured; 0 filtered out; finished in 1.22s
However, there were days when this was not an issue.
Since this seems to be related to the graphics card, here is some information about the driver and card I'm using:
nvidia-smi
Mon Apr 14 15:28:36 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 572.83 Driver Version: 572.83 CUDA Version: 12.8 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Driver-Model | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 2060 ... WDDM | 00000000:01:00.0 Off | N/A |
| N/A 59C P8 5W / 65W | 0MiB / 6144MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
Let me know if I can help to get to the bottom of this issue.
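To rule the GPU/driver in or out, the same check can be run on the CPU backend. A hypothetical variant of the test above (requires the ndarray feature; I have not run this):

use burn::{
    backend::{ndarray::NdArrayDevice, NdArray},
    nn::loss::BinaryCrossEntropyLossConfig,
    tensor::{Int, Tensor, TensorData},
};

#[test]
fn test_binary_cross_entropy_preds_almost_correct_ndarray() {
    type B = NdArray<f32>;
    let device = NdArrayDevice::default();

    let preds = Tensor::<B, 1>::from_floats([0.9999989, 0.9999994, 1.0, 0.9999983], &device);
    let targets = Tensor::<B, 1, Int>::from_data(TensorData::from([1, 1, 1, 1]), &device);

    let loss_actual = BinaryCrossEntropyLossConfig::new()
        .init(&device)
        .forward(preds, targets)
        .into_data();

    // Same expected value as above; if this also fails, the GPU isn't the culprit.
    loss_actual.assert_approx_eq(&TensorData::from([8.49366756483505e-07]), 7);
}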
Ahhh ok I can reproduce on 0.16
---- test_binary_cross_entropy_preds_almost_correct stdout ----
thread 'test_binary_cross_entropy_preds_almost_correct' panicked at src\main.rs:23:17:
Tensors are not approx eq:
=> Position 0: NaN != 0.000000849366756483505 | difference NaN > tolerance 0.00000010000000000000004
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
I was testing on the latest changes on main, which seem to have fixed the underlying issue:
test test_binary_cross_entropy_preds_almost_correct ... ok
test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.41s
If you use the main branch instead of the last released version, does it work?
With Burn updated to the version on main:
burn = { git = "https://github.com/tracel-ai/burn.git", branch="main", features = ["wgpu", "ndarray", "train"] }
| Split | Metric | Min. | Epoch | Max. | Epoch |
|-------|----------|----------|----------|----------|----------|
| Train | Accuracy | 45.861 | 1 | 45.861 | 20 |
| Train | Loss | 0.267 | 13 | NaN | 20 |
| Valid | Accuracy | 46.774 | 1 | 46.774 | 20 |
| Valid | Loss | 0.264 | 13 | NaN | 20 |
> If you use the main branch instead of the last released version, does it work?
Sorry for the delay! This particular case was resolved by using the main branch. I'll restart the training program and see if it still affects that process.
Edit: The issue seems to be gone!
Closing this since the issue has been resolved.