
Poor performance in back propagation

Open alex opened this issue 8 months ago • 5 comments

(All numbers were measured on an M1 Max with samply, a sampling CPU profiler.)

The performance numbers I'll share are for the following model:

const N_EMBD: usize = 32;
const N_HIDDEN: usize = 128;

let model = Sequential::new()
    // Embedding layer
    .add(candle_nn::embedding(
        num_hero_tokens.into(),
        N_EMBD,
        vb.push_prefix("embedding"),
    )?)
    // First hidden layer
    .add(FlattenConsecutive::new(5))
    .add(candle_nn::linear_no_bias(
        N_EMBD * 5,
        N_HIDDEN,
        vb.push_prefix("linear[0]"),
    )?)
    .add(TransposedBatchNorm::new(candle_nn::batch_norm(
        N_HIDDEN,
        candle_nn::BatchNormConfig::default(),
        vb.push_prefix("batch_norm[0]"),
    )?))
    .add(|t: &Tensor| t.tanh())
    // Second hidden layer
    .add(FlattenConsecutive::new(2))
    .add(|t: &Tensor| t.squeeze(1))
    .add(candle_nn::linear_no_bias(
        N_HIDDEN * 2,
        N_HIDDEN,
        vb.push_prefix("linear[1]"),
    )?)
    .add(candle_nn::batch_norm(
        N_HIDDEN,
        candle_nn::BatchNormConfig::default(),
        vb.push_prefix("batch_norm[1]"),
    )?)
    .add(|t: &Tensor| t.tanh())
    // Output layer
    .add(candle_nn::linear(
        N_HIDDEN,
        1,
        vb.push_prefix("linear[output]"),
    )?)
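    // Sigmoid ensures that probabilities are between 0 and 1.
    // Note: binary_cross_entropy_with_logit (used in train below) applies its own
    // sigmoid, so during training the output is effectively squashed twice.
    .add(candle_nn::Activation::Sigmoid)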
    .add(|t: &Tensor| t.squeeze(1));

And training is done with the following code:

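use anyhow::Context; // provides .context(...) on Results
use candle_core::Tensor;
use candle_nn::{ModuleT, Optimizer}; // Optimizer provides backward_step
use rand::Rng; // provides random_range

fn train(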
    dev: &candle_core::Device,
    vm: &candle_nn::VarMap,
    model: &impl ModuleT,
    x: &Tensor,
    y: &Tensor,
) -> anyhow::Result<()> {
    let mut opt = candle_nn::AdamW::new_lr(vm.all_vars(), 0.01)?;
    let mut rng = rand::rng();
    for i in 0..TRAINING_EPOCHS {
        let idx = Tensor::from_iter(
            (0..TRAINING_BATCH_SIZE).map(|_| rng.random_range(0..x.dims()[0]) as u32),
            dev,
        )?;
        let x_batch = x.index_select(&idx, 0)?;
        let y_batch = y.index_select(&idx, 0)?;

        let logits = model.forward_t(&x_batch, true)?;
        let loss = candle_nn::loss::binary_cross_entropy_with_logit(&logits, &y_batch)
            .context("cross_entropy")?;

        if i % 1000 == 0 {
            println!("{i:7}/{TRAINING_EPOCHS}: {}", loss.to_scalar::<f32>()?);
        }

        opt.backward_step(&loss)?;
    }

    Ok(())
}

Running this on a CPU device, I see that:

  • 95% of time is spent in backward_step
    • 89% in Tensor::backward
    • 6% in AdamW::step
  • 4% of time is spent in forward_t

I'm not much of an ML hand, but a friend with far more expertise tells me that:

Very roughly, backprop is typically 2x the FLOPs of the forward pass, and between 1x and 2x the number of kernels/operators. So if it's much more than 2x the time, that's suspicious.
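Applying that rule of thumb to the CPU profile above gives the comparison below. Tensor::backward accounts for 89% of samples and forward_t for 4%. This is a rough estimate that assumes each function's share of profiler samples is a fair proxy for its share of wall-clock time:

```latex
\frac{t_{\text{backward}}}{t_{\text{forward}}} \approx \frac{89\%}{4\%} \approx 22 \gg 2
```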

Breaking down Tensor::backward:

  • 57% in Tensor::add
  • 15% in Storage::reduce_op
  • 6% in Arc::drop_slow
  • 5% in GradStore::or_insert
  • 3.5% in Tensor::matmul
  • etc.

I also ran with the Metal backend (which gave no significant speedup over CPU) and obtained the following rough breakdown:

  • 60% of time in backprop
    • 20% GradStore::or_insert
  • 33% of time in AdamW::step
  • 6% in forward_t

However, these numbers should be considered less reliable. Because they were captured with a sampling CPU profiler, much of the attributed time is really spent at synchronization points between the CPU and GPU (or between userspace and the kernel). So they may not reflect where the computation time is actually spent.
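One way to get per-phase numbers that don't depend on where samples happen to land is to accumulate wall-clock time around each phase of the training loop. Below is a minimal, std-only sketch. `PhaseTimer` is a hypothetical helper, not part of candle. Presumably, on a GPU backend each timed phase would also need to end with something that waits for the device before its wall-clock time means much. Reading a value back with `to_scalar`, as the training loop already does every 1000 steps, is one example.

```rust
use std::collections::BTreeMap;
use std::time::{Duration, Instant};

/// Hypothetical helper: accumulates wall-clock time per named phase.
#[derive(Default)]
struct PhaseTimer {
    totals: BTreeMap<&'static str, Duration>,
}

impl PhaseTimer {
    /// Run `f`, add its elapsed time to `phase`, and return its result.
    fn time<R>(&mut self, phase: &'static str, f: impl FnOnce() -> R) -> R {
        let start = Instant::now();
        let out = f();
        *self.totals.entry(phase).or_default() += start.elapsed();
        out
    }

    /// Fraction of all measured time spent in `phase` (0.0 if nothing was recorded).
    fn share(&self, phase: &str) -> f64 {
        let total: Duration = self.totals.values().sum();
        match self.totals.get(phase) {
            Some(d) if !total.is_zero() => d.as_secs_f64() / total.as_secs_f64(),
            _ => 0.0,
        }
    }
}

fn main() {
    let mut timer = PhaseTimer::default();
    for _ in 0..3 {
        // In the real loop these closures would wrap model.forward_t(...)
        // and opt.backward_step(...); here they just do some CPU work.
        let fwd = timer.time("forward", || (0..100_000u64).sum::<u64>());
        let bwd = timer.time("backward", || (0..400_000u64).sum::<u64>());
        std::hint::black_box((fwd, bwd));
    }
    for phase in ["forward", "backward"] {
        println!("{phase}: {:.1}%", 100.0 * timer.share(phase));
    }
}
```

Unlike a sampling profile, this attributes time only to the phases you wrap. It is a cross-check rather than a replacement.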

With all that said, the conclusion that the backward pass takes too much time seems inescapable, and I want to offer a few observations:

  • There is no way to reuse a GradStore (zeroing out the gradient tensors instead of allocating them fresh on each Tensor::backward call). This leads to additional allocations. (I'm far from an expert, but my understanding is that many GPU allocators are quite slow, so avoiding lots of extra allocation traffic is desirable.)
  • Because Tensor is immutable, all of the Tensor::add calls in the backward pass really add up (pun not intended). The number of adds is proportional to the number of ops in backpropagation, because each op has to add its contribution to the accumulated gradient. In the current design, each of those adds is an allocation (and then a free!) rather than an in-place op.
  • AdamW::step time feels excessive in my Metal measurements, but since it didn't reproduce under CPU, I'm ignoring it for now.
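To make the first two observations concrete, here is a minimal, std-only sketch. Plain `Vec<f32>` buffers stand in for gradient tensors. `ReusableGrads`, `accumulate_out_of_place`, and `accumulate_in_place` are hypothetical names, not candle APIs. The out-of-place path allocates a new buffer for every gradient contribution, like `acc = acc.add(&g)` on an immutable tensor. The reusable store allocates once, then only zeroes and adds in place on each step.

```rust
/// Out-of-place accumulation: every add allocates a fresh buffer,
/// mirroring `acc = acc.add(&g)` on an immutable tensor.
fn accumulate_out_of_place(acc: &[f32], grad: &[f32]) -> Vec<f32> {
    acc.iter().zip(grad).map(|(a, g)| a + g).collect()
}

/// In-place accumulation: adds into an existing buffer, no allocation.
fn accumulate_in_place(acc: &mut [f32], grad: &[f32]) {
    for (a, g) in acc.iter_mut().zip(grad) {
        *a += g;
    }
}

/// Hypothetical reusable gradient store: buffers are allocated once and
/// zeroed at the start of each step instead of being freed and reallocated.
struct ReusableGrads {
    bufs: Vec<Vec<f32>>,
}

impl ReusableGrads {
    fn new(sizes: &[usize]) -> Self {
        Self { bufs: sizes.iter().map(|&n| vec![0.0; n]).collect() }
    }

    /// Reset every gradient to zero, keeping the existing allocations.
    fn zero(&mut self) {
        for buf in &mut self.bufs {
            buf.fill(0.0);
        }
    }

    /// Add one op's gradient contribution for parameter `idx` in place.
    fn add(&mut self, idx: usize, grad: &[f32]) {
        accumulate_in_place(&mut self.bufs[idx], grad);
    }
}

fn main() {
    // Three ops each contribute a gradient to the same parameter.
    let contributions = [vec![1.0, 2.0], vec![0.5, 0.5], vec![-1.0, 1.0]];

    // Out-of-place: one new allocation per contribution.
    let mut acc = vec![0.0f32; 2];
    for g in &contributions {
        acc = accumulate_out_of_place(&acc, g);
    }
    println!("out-of-place: {acc:?}"); // prints "out-of-place: [0.5, 3.5]"

    // In-place with reuse: allocation happens once, in ReusableGrads::new.
    let mut grads = ReusableGrads::new(&[2]);
    for _step in 0..2 {
        grads.zero();
        for g in &contributions {
            grads.add(0, g);
        }
    }
    println!("in-place:     {:?}", grads.bufs[0]); // prints "in-place:     [0.5, 3.5]"
}
```

With the in-place variant, the per-step allocation count no longer grows with the number of ops in the graph. Whether something similar could fit candle's immutable `Tensor` design is the open question raised above, and this sketch doesn't answer it.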

I'm not proposing any specific actions, as some of these cut right to the core of candle's design decisions and I don't want to get ahead of myself.

Happy to share more of the code if that'd be helpful, and also happy to run any additional tests or experiments. Cheers!

alex • Apr 22 '25 03:04

the tensor is the problem

you don't need it for backprop

machine_learning_test.rs:

#[test]
fn its_a_neural_network() {
    // 1. replace matrix multiplication with angle composition: Wx+b → [|W||x|, θW+θx]

    // create input and weight geometric numbers
    let input = Geonum {
        length: 2.0,
        angle: 0.5,
        blade: 1,
    };
    let weight = Geonum {
        length: 1.5,
        angle: 0.3,
        blade: 1,
    };
    let bias = Geonum {
        length: 0.5,
        angle: 0.0,
        blade: 0, // scalar (grade 0) - bias is a pure magnitude without direction
    };

    // traditional neural network: output = activation(Wx + b)
    // with geometric numbers, we directly compose lengths and angles

    // compute layer output using forward_pass method
    let output = input.forward_pass(&weight, &bias);

    // apply activation function using activate method with Activation enum
    let activated = output.activate(Activation::ReLU);

    // 2. eliminate backpropagation matrix chain rule with reverse angle adjustment

    // traditional backpropagation requires matrix operations through the network
    // with geometric numbers, we can directly adjust angles and lengths

    // compute error gradient (simplified)
    let target = Geonum {
        length: 3.0,
        angle: 1.0,
        blade: 1,
    };
    let error = Geonum {
        length: (target.length - activated.length).abs(),
        angle: target.angle - activated.angle,
        blade: 0, // scalar (grade 0) - error magnitude is a pure scalar value
    };

    // update weights via direct angle and length adjustments
    let learning_rate = 0.1;
    let _updated_weight = Geonum {
        length: weight.length + learning_rate * error.length * input.length,
        angle: weight.angle + learning_rate * error.angle,
        blade: 1,
    };

    // 3. demonstrate activation functions as angle threshold operations

    // use the built-in activate method for sigmoid activation with Activation enum
    let sigmoid_output = output.activate(Activation::Sigmoid);

    // 4. measure performance: neural network operations with geometric numbers
    // are O(n) vs O(n²) for traditional networks
    assert!(
        sigmoid_output.length > 0.0,
        "activation should produce non-zero output"
    );
}

and you don't need them in general: tensor_test.rs

run the tests yourselves:

git clone https://github.com/mxfactorial/geonum.git
cd geonum
cargo test -- --show-output

you're welcome to extend the test suites to reveal feature gaps

https://crates.io/crates/geonum enables replacing O(k^n) tensors with O(1) multivectors

mxfactorial • Apr 24 '25 04:04

I'm not sure that "use a completely different project" is a super productive reply :-) The goal here is to find ways to improve candle's performance.

alex • Apr 24 '25 10:04

you can be sure "use a completely different project" doesn't appear in the reply

tensor is an O(k^n) module

https://crates.io/crates/geonum tests empirically prove it's an O(1) crate

replacing the module with a crate is an O(k^n) to O(1) way "to improve candle's performance"

mxfactorial • Apr 24 '25 12:04

Thanks @alex for the feedback. It's great to have a detailed analysis of candle's current limitations and how we can improve on them; hopefully there will be a few updates soon that should help.

LaurentMazare • Apr 24 '25 12:04

@LaurentMazare great! If there are particular issues/PRs to follow along with, that'd be helpful.

alex • Apr 24 '25 21:04