
Training Issues (NaN) When Migrating from PyTorch

Open oiwn opened this issue 10 months ago • 17 comments

First, I want to express appreciation for the Burn framework - it's a great step toward bringing ML capabilities to Rust. I'm working on migrating several small PyTorch models to Rust, but I've encountered some issues with BCE loss calculation and training behavior.

When migrating a simple binary classifier from PyTorch to Burn, I'm seeing significant differences in:

  • BCE loss calculation
  • Training behaviour
  • Final model accuracy (46% with Burn, 92% with PyTorch)

The model architecture is identical between frameworks (input->64->32->1 with ReLU/Sigmoid).
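
For reference, a minimal sketch of what that architecture might look like as a Burn module; the field names follow the model summary below, but the constructor and forward pass are assumptions rather than the repository's actual code:

use burn::{
    nn::{Linear, LinearConfig, Relu, Sigmoid},
    prelude::*,
};

// Hypothetical reconstruction of the classifier described above (20 -> 64 -> 32 -> 1).
#[derive(Module, Debug)]
pub struct DemoClassifierModel<B: Backend> {
    input_layer: Linear<B>,
    hidden_layer1: Linear<B>,
    output_layer: Linear<B>,
    activation: Relu,
    sigmoid: Sigmoid,
}

impl<B: Backend> DemoClassifierModel<B> {
    pub fn new(device: &B::Device) -> Self {
        Self {
            input_layer: LinearConfig::new(20, 64).init(device),
            hidden_layer1: LinearConfig::new(64, 32).init(device),
            output_layer: LinearConfig::new(32, 1).init(device),
            activation: Relu::new(),
            sigmoid: Sigmoid::new(),
        }
    }

    // Forward pass: [batch, 20] features -> [batch, 1] probabilities in (0, 1).
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.activation.forward(self.input_layer.forward(x));
        let x = self.activation.forward(self.hidden_layer1.forward(x));
        self.sigmoid.forward(self.output_layer.forward(x))
    }
}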

Reproducible Example

I've created a minimal example repository: https://github.com/oiwn/burn-problems

Key test cases demonstrate the issues:

BCE Loss Test

# Python/PyTorch
Test Case 1 - Perfect predictions:
Predictions: tensor([1., 0., 1., 0.])
Targets:     tensor([1., 0., 1., 0.])
Loss:        0.00000000

Test Case 2 - Wrong predictions:
Predictions: tensor([0., 1., 0., 1.])
Targets:     tensor([1., 0., 1., 0.])
Loss:        100.00000000

Test Case 3 - Uncertain predictions:
Predictions: tensor([0.5000, 0.5000, 0.5000, 0.5000])
Targets:     tensor([1., 0., 1., 0.])
Loss:        0.69314718

// Rust/Burn
Test Case 1 - Perfect predictions:
Predictions: tensor([1.0000, 0.0000, 1.0000, 0.0000])
Targets:     tensor([1, 0, 1, 0])
Loss:        NaN

Test Case 2 - Wrong predictions:
Predictions: tensor([0.0000, 1.0000, 0.0000, 1.0000])
Targets:     tensor([1, 0, 1, 0])
Loss:        inf

Test Case 3 - Uncertain predictions:
Predictions: tensor([0.5000, 0.5000, 0.5000, 0.5000])
Targets:     tensor([1, 0, 1, 0])
Loss:        0.69314718

Training Results

  • PyTorch achieves 93.36% accuracy
  • Burn implementation stops at ~46% accuracy with NaN loss

Model:
DemoClassifierModel {
  input_layer: Linear {d_input: 20, d_output: 64, bias: true, params: 1344}
  hidden_layer1: Linear {d_input: 64, d_output: 32, bias: true, params: 2080}
  output_layer: Linear {d_input: 32, d_output: 1, bias: true, params: 33}
  activation: Relu
  sigmoid: Sigmoid
  params: 3457
}
Total Epochs: 20


| Split | Metric   | Min.     | Epoch    | Max.     | Epoch    |
|-------|----------|----------|----------|----------|----------|
| Train | Accuracy | 46.113   | 1        | 46.113   | 20       |
| Train | Loss     | 0.279    | 10       | NaN      | 20       |
| Valid | Accuracy | 45.766   | 1        | 45.766   | 20       |
| Valid | Loss     | 0.282    | 10       | NaN      | 20       |

I noticed that the accuracy matches the share of zero labels in the training set:

Train distribution: total=14690, zeroes=6742 (45.9%), ones=7948 (54.1%)

The critical section appears to be the BCE loss calculation:

// Current implementation
let loss = BinaryCrossEntropyLossConfig::new()
    .init(&output.device())
    .forward(output.clone().squeeze(1), targets.clone());

// Alternative attempt
// let loss = BinaryCrossEntropyLossConfig::new()
//     .init(&output.device())
//     .forward(output.clone(), targets.clone().reshape([batch_size, 1]));
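
For reference, here is a minimal self-contained sketch of the shape handling I would expect: sigmoid outputs squeezed from [batch_size, 1] down to [batch_size], with integer 0/1 targets of the same length. The NdArray backend and the bce_loss helper are assumptions for illustration, not code from the repository.

use burn::{
    backend::NdArray,
    nn::loss::BinaryCrossEntropyLossConfig,
    tensor::{Int, Tensor, TensorData},
};

type B = NdArray;

// Hypothetical helper: `probs` are sigmoid outputs of shape [batch_size, 1],
// `targets` are 0/1 labels of shape [batch_size].
fn bce_loss(probs: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
    let device = probs.device();
    BinaryCrossEntropyLossConfig::new()
        .init(&device)
        // Squeeze the trailing dim so predictions and targets are both [batch_size].
        .forward(probs.squeeze(1), targets)
}

fn main() {
    let device = Default::default();
    let probs = Tensor::<B, 2>::from_floats([[0.9], [0.2], [0.7], [0.4]], &device);
    let targets = Tensor::<B, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);
    let loss = bce_loss(probs, targets).into_data();
    println!("loss = {:?}", loss.to_vec::<f32>().unwrap());
}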

Questions

  1. What is the recommended way to handle tensor shapes for BCE loss in Burn?
  2. Are there any known issues with batch dimension handling that could cause these discrepancies?
  3. Should the loss calculation approach differ from PyTorch's implementation?

oiwn avatar Jan 24 '25 10:01 oiwn

Thanks for flagging this!

I believe this is due to the current BCE loss implementation: tensor.log() results in -inf for values of 0.0. We need to clamp the values to make sure we don't have this issue.
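
For context, PyTorch's BCELoss clamps each log term at -100, which is why the border cases above come out as 0 and 100 there instead of NaN/inf. Below is a plain-Rust sketch of that clamping behaviour, not Burn's actual implementation:

// Binary cross-entropy for probabilities in [0, 1], mirroring PyTorch's
// convention of clamping each log term at -100.
fn bce_clamped(preds: &[f64], targets: &[f64]) -> f64 {
    // Without this clamp, ln(0.0) = -inf and 0.0 * -inf = NaN,
    // which is exactly what the unclamped loss reports above.
    let log = |x: f64| x.ln().max(-100.0);
    let n = preds.len() as f64;
    preds
        .iter()
        .zip(targets)
        .map(|(&p, &t)| -(t * log(p) + (1.0 - t) * log(1.0 - p)))
        .sum::<f64>()
        / n
}

fn main() {
    let targets = [1.0, 0.0, 1.0, 0.0];
    println!("{}", bce_clamped(&[1.0, 0.0, 1.0, 0.0], &targets)); // 0       (perfect)
    println!("{}", bce_clamped(&[0.0, 1.0, 0.0, 1.0], &targets)); // 100     (completely wrong)
    println!("{}", bce_clamped(&[0.5, 0.5, 0.5, 0.5], &targets)); // ~0.6931 (uncertain)
}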

Should have a PR to fix this soon.

laggui avatar Jan 24 '25 14:01 laggui

@laggui The latest version of Burn fails to compile with the Wgpu backend:

error[E0275]: overflow evaluating the requirement `wgpu_core::validation::NumericType: Sync`                                                                                               
    |                                                                                                                                                                                      
    = help: consider increasing the recursion limit by adding a `#![recursion_limit = "256"]` attribute to your crate (`burn_problems`)                                                    
note: required because it appears within the type `wgpu_core::validation::InterfaceVar`                                                                                                    
   --> .cargo/registry/src/index.crates.io-6f17d22bba15001f/wgpu-core-24.0.0/src/validation.rs:109:12
    |
109 | pub struct InterfaceVar {
    |            ^^^^^^^^^^^^
note: required because it appears within the type `wgpu_core::validation::Varying`
   --> .cargo/registry/src/index.crates.io-6f17d22bba15001f/wgpu-core-24.0.0/src/validation.rs:136:6
    |
136 | enum Varying {
    |      ^^^^^^^
note: required because it appears within the type `PhantomData<wgpu_core::validation::Varying>`
   --> .rustup/toolchains/stable-aarch64-apple-darwin/lib/rustlib/src/rust/library/core/src/marker.rs:753:12
    |
753 | pub struct PhantomData<T: ?Sized>;

....

With the NdArray backend I got the same results - 46%, so the issue is probably not in the BCE loss =(

https://github.com/oiwn/burn-problems/commit/c36f56a6ff3b2969d798cc3c25e4fb70120103f6

Model:                                                                                                                                                                                     
DemoClassifierModel {                                                                                                                                                                      
  input_layer: Linear {d_input: 20, d_output: 64, bias: true, params: 1344}                                                                                                                
  hidden_layer1: Linear {d_input: 64, d_output: 32, bias: true, params: 2080}                                                                                                              
  output_layer: Linear {d_input: 32, d_output: 1, bias: true, params: 33}                                                                                                                  
  activation: Relu                                                                                                                                                                         
  sigmoid: Sigmoid                                                                                                                                                                         
  params: 3457                                                                                                                                                                             
}                                                                                                                                                                                          
Total Epochs: 20                                                                                                                                                                           
                                                                                                                                                                                           
                                                                                                                                                                                           
| Split | Metric   | Min.     | Epoch    | Max.     | Epoch    |                                                                                                                           
|-------|----------|----------|----------|----------|----------|                                                                                                                           
| Train | Loss     | 0.265    | 13       | NaN      | 20       |                                                                                                                           
| Train | Accuracy | 46.052   | 1        | 46.052   | 20       |                                                                                                                           
| Valid | Loss     | 0.258    | 13       | NaN      | 20       |                                                                                                                           
| Valid | Accuracy | 46.011   | 1        | 46.011   | 20       |  

oiwn avatar Jan 24 '25 17:01 oiwn

Yeah, we realized this the other day with the upgrade to wgpu 24.0.0; see this Discord convo for reference.

This seems to stem from new complex types in wgpu. As a temporary fix you can actually follow the compiler's help: increase the recursion limit (default is 128). You probably don't need to double it to 256, something around 140 should work.
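
For example, at the crate root (the exact value is just a guess; anything at or above what wgpu needs will do):

// Crate root of `burn_problems`, e.g. src/main.rs
#![recursion_limit = "140"]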

With the NdArray backend I got the same results - 46%, so the issue is probably not in the BCE loss =(

I haven't actually tested the whole thing, just isolated the bce loss bug initially 😅 Seems weird that your loss still NaNs 🤔 I'll check it out

/edit: just took a quick glance, looks like it's actually coming from the first linear layer parameters becoming NaN at some point. I'm assuming you validated the input data?

laggui avatar Jan 24 '25 17:01 laggui

@laggui Input data are identical for PyTorch and Burn.

PyTorch:

=== Training Data Check (Python) ===

Dataset sizes:
Train: 14690 Test: 3673

Feature statistics (train) (first 3):

Feature 0 (feature1):
Mean:   0.0029
StdDev: 1.0035
Min:    -2.1987
Max:    5.7738

Feature 1 (feature2):
Mean:   -0.0012
StdDev: 0.9764
Min:    -1.3040
Max:    17.6505

Feature 2 (feature3):
Mean:   0.0044
StdDev: 1.0126
Min:    -0.7085
Max:    7.9158


Burn:

=== Training Data Check (Rust) ===

Dataset sizes:
Train: 14690 Test: 3673

Feature statistics (train) (first 3):

Feature 0:
Mean:   -0.0034
StdDev: 1.0025
Min:    -2.1987
Max:    5.7738

Feature 1:
Mean:   0.0006
StdDev: 0.9915
Min:    -1.3040
Max:    21.7539

Feature 2:
Mean:   0.0012
StdDev: 1.0048
Min:    -0.7085
Max:    7.9158

oiwn avatar Jan 25 '25 14:01 oiwn

Some slight variations, but as long as the inputs during training don't contain weird outlier values, I don't think that will be the issue.

I'll reopen this issue but it doesn't seem to be specific to the BCE loss anymore.

laggui avatar Jan 27 '25 13:01 laggui

It might also be a difference in the training configuration. If you have a higher learning rate or are missing weight decay, it might lead to unstable training, resulting in NaN values, which render the model useless.

nathanielsimard avatar Jan 28 '25 14:01 nathanielsimard

It might also be a difference in the training configuration. If you have a higher learning rate or are missing weight decay, it might lead to unstable training, resulting in NaN values, which render the model useless.

I tried different learning rates with no luck. There is a strange correlation between the training accuracy (46%) and the share of zeroes in the training set labels.

oiwn avatar Jan 29 '25 12:01 oiwn

There is a strange correlation between the training accuracy (46%) and the share of zeroes in the training set labels.

If you're getting NaNs during training, the model has not converged as expected, so the run that gives ~46% accuracy probably defaulted to a simple heuristic that always predicts the same label. With ~46% of the labels being zero, always predicting zero would yield exactly that accuracy 🙂

Regarding the cause for the NaNs, not quite sure at first glance. Would have to spend a bit of time to investigate.

laggui avatar Jan 29 '25 16:01 laggui

@laggui thank you!

oiwn avatar Jan 31 '25 16:01 oiwn

Hey all,

I seem to have the same issue. Every now and then, and seemingly at random, I get NaN as the loss.

Here is an example that caused the loss to be nan:

loss_d_real contains nan Tensor {
  data:
[NaN],
  shape:  [1],
  device:  DefaultDevice,
  backend:  "autodiff<fusion<jit<wgpu<wgsl>>>>",
  kind:  "Float",
  dtype:  "f32",
}. it was made from Tensor {
  data:
[0.9999989, 0.9999994, 1.0, 0.9999983],
  shape:  [4],
  device:  DefaultDevice,
  backend:  "autodiff<fusion<jit<wgpu<wgsl>>>>",
  kind:  "Float",
  dtype:  "f32",
} and Tensor {
  data:
[1, 1, 1, 1],
  shape:  [4],
  device:  DefaultDevice,
  backend:  "autodiff<fusion<jit<wgpu<wgsl>>>>",
  kind:  "Int",
  dtype:  "i32",
} with bce loss

VirtualNonsense avatar Apr 09 '25 16:04 VirtualNonsense

use burn::{
    backend::{Autodiff, Wgpu},
    nn::loss::BinaryCrossEntropyLossConfig,
    tensor::{Int, Tensor, TensorData},
};

#[test]
fn test_binary_cross_entropy_preds_almost_correct() {
    type MyBackend = Wgpu<f32, i32>;
    type MyAutodiffBackend = Autodiff<MyBackend>;

    let device = burn::backend::wgpu::WgpuDevice::default();
    let preds = Tensor::<MyAutodiffBackend, 1>::from_floats(
        [0.9999989, 0.9999994, 1.0, 0.9999983],
        &device,
    );
    let targets =
        Tensor::<MyAutodiffBackend, 1, Int>::from_data(TensorData::from([1, 1, 1, 1]), &device);
    let bce = BinaryCrossEntropyLossConfig::new().init(&device);
    let loss_actual = bce.forward(preds, targets).into_data();

    let loss_expected = TensorData::from([8.49366756483505e-07]);
    loss_actual.assert_approx_eq(&loss_expected, 7);
}

I adjusted one of your test cases to reproduce this behavior.

edit: I cross-checked it with PyTorch and adjusted the test with the expected value.

VirtualNonsense avatar Apr 09 '25 16:04 VirtualNonsense

edit: I cross-checked it with PyTorch and adjusted the test with the expected value.

I think the problem is in the BCE loss calculation: it shows NaN/inf in border cases where PyTorch clamps them to 0 and 100.

oiwn avatar Apr 14 '25 04:04 oiwn

use burn::{
    backend::{Autodiff, Wgpu},
    nn::loss::BinaryCrossEntropyLossConfig,
    tensor::{Int, Tensor, TensorData},
};

#[test]
fn test_binary_cross_entropy_preds_almost_correct() {
    type MyBackend = Wgpu<f32, i32>;
    type MyAutodiffBackend = Autodiff<MyBackend>;

    let device = burn::backend::wgpu::WgpuDevice::default();
    let preds = Tensor::<MyAutodiffBackend, 1>::from_floats(
        [0.9999989, 0.9999994, 1.0, 0.9999983],
        &device,
    );
    let targets =
        Tensor::<MyAutodiffBackend, 1, Int>::from_data(TensorData::from([1, 1, 1, 1]), &device);
    let bce = BinaryCrossEntropyLossConfig::new().init(&device);
    let loss_actual = bce.forward(preds, targets).into_data();

    let loss_expected = TensorData::from([8.49366756483505e-07]);
    loss_actual.assert_approx_eq(&loss_expected, 7);
}

I adjusted one of your test cases to reproduce this behavior.

edit: I cross-checked it with PyTorch and adjusted the test with the expected value.

Is this test case a reproducible failure on your machine? Just tried it locally and it passes 🤔

laggui avatar Apr 14 '25 12:04 laggui

Currently it is.

thread 'test_binary_cross_entropy_preds_almost_correct' panicked at crates\sketchy_pix2pix\tests\bce_test.rs:23:17:
Tensors are not approx eq:
  => Position 0: NaN != 0.000000849366756483505 | difference NaN > tolerance 0.00000010000000000000004
stack backtrace:
   0: std::panicking::begin_panic_handler
             at /rustc/4eb161250e340c8f48f66e2b929ef4a5bed7c181/library\std\src\panicking.rs:692
   1: core::panicking::panic_fmt
             at /rustc/4eb161250e340c8f48f66e2b929ef4a5bed7c181/library\core\src\panicking.rs:75
   2: burn_tensor::tensor::data::TensorData::assert_approx_eq_diff
             at C:\Users\user\.cargo\registry\src\index.crates.io-1949cf8c6b5b557f\burn-tensor-0.16.1\src\tensor\data.rs:594
   3: burn_tensor::tensor::data::TensorData::assert_approx_eq
             at C:\Users\user\.cargo\registry\src\index.crates.io-1949cf8c6b5b557f\burn-tensor-0.16.1\src\tensor\data.rs:434
   4: bce_test::test_binary_cross_entropy_preds_almost_correct
             at .\tests\bce_test.rs:23
   5: bce_test::test_binary_cross_entropy_preds_almost_correct::closure$0
             at .\tests\bce_test.rs:8
   6: core::ops::function::FnOnce::call_once<bce_test::test_binary_cross_entropy_preds_almost_correct::closure_env$0,tuple$<> >
             at C:\Users\user\.rustup\toolchains\stable-x86_64-pc-windows-msvc\lib\rustlib\src\rust\library\core\src\ops\function.rs:250
   7: core::ops::function::FnOnce::call_once
             at /rustc/4eb161250e340c8f48f66e2b929ef4a5bed7c181/library\core\src\ops\function.rs:250
note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.


failures:
    test_binary_cross_entropy_preds_almost_correct

test result: FAILED. 0 passed; 1 failed; 0 ignored; 0 measured; 0 filtered out; finished in 1.22s

However, there have been days when this was not an issue.

Since this seems to be related to the graphics card, here is some information about the driver and the card I'm using:

nvidia-smi
Mon Apr 14 15:28:36 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 572.83                 Driver Version: 572.83         CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 2060 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   59C    P8              5W /   65W |       0MiB /   6144MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Let me know if I can help to get to the bottom of this issue.

VirtualNonsense avatar Apr 14 '25 13:04 VirtualNonsense

Ahhh ok I can reproduce on 0.16

---- test_binary_cross_entropy_preds_almost_correct stdout ----

thread 'test_binary_cross_entropy_preds_almost_correct' panicked at src\main.rs:23:17:
Tensors are not approx eq:
  => Position 0: NaN != 0.000000849366756483505 | difference NaN > tolerance 0.00000010000000000000004
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

I was testing on the latest changes on main, which seems to have fixed the underlying issue

test test_binary_cross_entropy_preds_almost_correct ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.41s

If you use the main branch instead of the last released version does it work?

laggui avatar Apr 14 '25 13:04 laggui

With Burn updated to the version on main:

burn = { git = "https://github.com/tracel-ai/burn.git", branch="main", features = ["wgpu", "ndarray", "train"] }
| Split | Metric   | Min.     | Epoch    | Max.     | Epoch    |
|-------|----------|----------|----------|----------|----------|
| Train | Accuracy | 45.861   | 1        | 45.861   | 20       |
| Train | Loss     | 0.267    | 13       | NaN      | 20       |
| Valid | Accuracy | 46.774   | 1        | 46.774   | 20       |
| Valid | Loss     | 0.264    | 13       | NaN      | 20       |

oiwn avatar Apr 15 '25 06:04 oiwn

Ahhh ok I can reproduce on 0.16

---- test_binary_cross_entropy_preds_almost_correct stdout ----

thread 'test_binary_cross_entropy_preds_almost_correct' panicked at src\main.rs:23:17:
Tensors are not approx eq:
  => Position 0: NaN != 0.000000849366756483505 | difference NaN > tolerance 0.00000010000000000000004
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

I was testing on the latest changes on main, which seems to have fixed the underlying issue

test test_binary_cross_entropy_preds_almost_correct ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.41s

If you use the main branch instead of the last released version does it work?

Sorry for the delay! This particular case was resolved by using the main branch. I'll restart the training program and see if it still affects that process.

Edit: The issue seems to be gone!

VirtualNonsense avatar Apr 15 '25 06:04 VirtualNonsense

Closing this since the issue has been resolved.

antimora avatar Oct 01 '25 00:10 antimora