Poor performance in back propagation
(All numbers are measured on an M1 Max, with samply -- a sampling CPU profiler)
The performance numbers I'll share are for the following model:
```rust
const N_EMBD: usize = 32;
const N_HIDDEN: usize = 128;

let model = Sequential::new()
    // Embedding layer
    .add(candle_nn::embedding(
        num_hero_tokens.into(),
        N_EMBD,
        vb.push_prefix("embedding"),
    )?)
    // First hidden layer
    .add(FlattenConsecutive::new(5))
    .add(candle_nn::linear_no_bias(
        N_EMBD * 5,
        N_HIDDEN,
        vb.push_prefix("linear[0]"),
    )?)
    .add(TransposedBatchNorm::new(candle_nn::batch_norm(
        N_HIDDEN,
        candle_nn::BatchNormConfig::default(),
        vb.push_prefix("batch_norm[0]"),
    )?))
    .add(|t: &Tensor| t.tanh())
    // Second hidden layer
    .add(FlattenConsecutive::new(2))
    .add(|t: &Tensor| t.squeeze(1))
    .add(candle_nn::linear_no_bias(
        N_HIDDEN * 2,
        N_HIDDEN,
        vb.push_prefix("linear[1]"),
    )?)
    .add(candle_nn::batch_norm(
        N_HIDDEN,
        candle_nn::BatchNormConfig::default(),
        vb.push_prefix("batch_norm[1]"),
    )?)
    .add(|t: &Tensor| t.tanh())
    // Output layer
    .add(candle_nn::linear(
        N_HIDDEN,
        1,
        vb.push_prefix("linear[output]"),
    )?)
    // Sigmoid ensures that probabilities are between 0 and 1
    .add(candle_nn::Activation::Sigmoid)
    .add(|t: &Tensor| t.squeeze(1));
```
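`FlattenConsecutive` and `TransposedBatchNorm` are small helpers of my own, not candle_nn modules. Roughly, `FlattenConsecutive` merges groups of `n` consecutive time steps into the channel dimension; here is a simplified sketch of the idea (not the exact code I'm running):

```rust
use candle_core::{Module, Result, Tensor};

/// Merges groups of `n` consecutive time steps into the channel dimension,
/// turning a (batch, time, channels) tensor into (batch, time / n, channels * n).
struct FlattenConsecutive {
    n: usize,
}

impl FlattenConsecutive {
    fn new(n: usize) -> Self {
        Self { n }
    }
}

impl Module for FlattenConsecutive {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let (b, t, c) = xs.dims3()?;
        xs.reshape((b, t / self.n, c * self.n))
    }
}
```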
And training is done with the following code:
```rust
fn train(
    dev: &candle_core::Device,
    vm: &candle_nn::VarMap,
    model: &impl ModuleT,
    x: &Tensor,
    y: &Tensor,
) -> anyhow::Result<()> {
    let mut opt = candle_nn::AdamW::new_lr(vm.all_vars(), 0.01)?;
    let mut rng = rand::rng();
    for i in 0..TRAINING_EPOCHS {
        let idx = Tensor::from_iter(
            (0..TRAINING_BATCH_SIZE).map(|_| rng.random_range(0..x.dims()[0]) as u32),
            dev,
        )?;
        let x_batch = x.index_select(&idx, 0)?;
        let y_batch = y.index_select(&idx, 0)?;
        let logits = model.forward_t(&x_batch, true)?;
        let loss = candle_nn::loss::binary_cross_entropy_with_logit(&logits, &y_batch)
            .context("cross_entropy")?;
        if i % 1000 == 0 {
            println!("{i:7}/{TRAINING_EPOCHS}: {}", loss.to_scalar::<f32>()?);
        }
        opt.backward_step(&loss)?;
    }
    Ok(())
}
```
Running this on a `Cpu` device, I see that:
- 95% of time is spent in `backward_step`
  - 89% in `Tensor::backward`
  - 6% in `AdamW::step`
- 4% of time is spent in `forward_t`
I'm not much of an ML hand, but a friend with far more expertise tells me that:

> Very roughly, backprop is typically 2x the FLOPs of the forward pass, and between 1-2x the number of kernels/operators. So if it's much more than 2x the time, that's suspicious.
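To sanity-check that rule of thumb, the two phases could be timed separately for a single batch, along these lines (same imports and types as the `train` function above; this is just a sketch, and only meaningful on the synchronous CPU backend):

```rust
use std::time::Instant;

// Crude comparison of the forward pass vs. the gradient computation for one batch.
// Only the gradients are computed here; no optimizer update is applied.
fn time_one_step(model: &impl ModuleT, x_batch: &Tensor, y_batch: &Tensor) -> anyhow::Result<()> {
    let t0 = Instant::now();
    let logits = model.forward_t(x_batch, true)?;
    let loss = candle_nn::loss::binary_cross_entropy_with_logit(&logits, y_batch)?;
    let forward = t0.elapsed();

    let t1 = Instant::now();
    let grads = loss.backward()?; // gradient computation only
    let backward = t1.elapsed();
    drop(grads);

    println!("forward: {forward:?}, backward: {backward:?}");
    Ok(())
}
```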
Breaking down `Tensor::backward`:
- 57% in `Tensor::add`
- 15% in `Storage::reduce_op`
- 6% in `Arc::drop_slow`
- 5% in `GradStore::or_insert`
- 3.5% in `Tensor::matmul`
- etc.
I also ran with the Metal backend (and did not experience a significant speedup relative to CPU), and obtained the following rough breakdown:
- 60% of time in backprop
  - 20% in `GradStore::or_insert`
- 33% of time in `AdamW::step`
- 6% in `forward_t`
However, these numbers should be considered less reliable: because they come from a sampling CPU profiler, much of the sampled time is really spent at synchronization points between the CPU and GPU (or userspace and kernel), so they may not properly reflect where the computation time actually goes.
With all that said, the conclusion that the backward pass costs too much seems inescapable, and I want to offer a few observations:
- There is no way to reuse a `GradStore` (i.e. zeroing out the gradient `Tensor`s instead of allocating them fresh on each `Tensor::backward` call). This leads to additional allocations. (I'm far from an expert, but my understanding is that many GPU allocators are quite slow, so avoiding tons of extra allocation traffic is desirable.)
- Because `Tensor` is immutable, all of the `Tensor::add` calls in `backward` really add up (pun not intended): the number of `add`s is proportional to the number of ops in back propagation (because every op has to add to the accumulated gradient), and in the current design each one of those is an allocation (and then a free!) rather than an in-place op. (A rough illustration of this pattern follows the list.)
- `AdamW::step` time feels excessive in my Metal measurements, but since it didn't reproduce under CPU, I'm ignoring it for now.
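To make the second point a bit more concrete, the accumulation pattern behind all those `Tensor::add` samples looks roughly like this; purely illustrative, not candle's actual `GradStore` internals, and the in-place variant is hypothetical:

```rust
use candle_core::{Result, Tensor};

// What effectively happens today: every op contributing to a node's gradient
// produces the sum as a brand-new tensor (and the previous value is freed once
// dropped), because `Tensor` is immutable.
fn accumulate_today(grad_so_far: &Tensor, contribution: &Tensor) -> Result<Tensor> {
    grad_so_far.add(contribution)
}

// A hypothetical in-place alternative (no such public op exists today) would
// write the sum into an existing, reusable buffer instead, e.g.:
//
//   fn accumulate_in_place(grad_buf: &mut GradBuffer, contribution: &Tensor) -> Result<()>;
//
// and a reusable `GradStore` would zero those buffers between `backward` calls
// rather than allocating fresh gradient tensors each time.
```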
I'm not proposing any specific actions, as some of these cut right to the core of candle's design decisions and I don't want to get ahead of myself.
Happy to share more of the code if that'd be helpful, also happy to run any additional tests or experiments. Cheers!
the tensor is the problem
you don't need it for backprop
```rust
#[test]
fn its_a_neural_network() {
    // 1. replace matrix multiplication with angle composition: Wx+b → [|W||x|, θW+θx]
    // create input and weight geometric numbers
    let input = Geonum {
        length: 2.0,
        angle: 0.5,
        blade: 1,
    };
    let weight = Geonum {
        length: 1.5,
        angle: 0.3,
        blade: 1,
    };
    let bias = Geonum {
        length: 0.5,
        angle: 0.0,
        blade: 0, // scalar (grade 0) - bias is a pure magnitude without direction
    };
    // traditional neural network: output = activation(Wx + b)
    // with geometric numbers, we directly compose lengths and angles
    // compute layer output using forward_pass method
    let output = input.forward_pass(&weight, &bias);
    // apply activation function using activate method with Activation enum
    let activated = output.activate(Activation::ReLU);
    // 2. eliminate backpropagation matrix chain rule with reverse angle adjustment
    // traditional backpropagation requires matrix operations through the network
    // with geometric numbers, we can directly adjust angles and lengths
    // compute error gradient (simplified)
    let target = Geonum {
        length: 3.0,
        angle: 1.0,
        blade: 1,
    };
    let error = Geonum {
        length: (target.length - activated.length).abs(),
        angle: target.angle - activated.angle,
        blade: 0, // scalar (grade 0) - error magnitude is a pure scalar value
    };
    // update weights via direct angle and length adjustments
    let learning_rate = 0.1;
    let _updated_weight = Geonum {
        length: weight.length + learning_rate * error.length * input.length,
        angle: weight.angle + learning_rate * error.angle,
        blade: 1,
    };
    // 3. demonstrate activation functions as angle threshold operations
    // use the built-in activate method for sigmoid activation with Activation enum
    let sigmoid_output = output.activate(Activation::Sigmoid);
    // 4. measure performance: neural network operations with geometric numbers
    // are O(n) vs O(n²) for traditional networks
    assert!(
        sigmoid_output.length > 0.0,
        "activation should produce non-zero output"
    );
}
```
and you don't need them in general: tensor_test.rs
run the tests yourselves:

```sh
git clone https://github.com/mxfactorial/geonum.git
cd geonum
cargo test -- --show-output
```
you're welcome to extend the test suites to reveal feature gaps
https://crates.io/crates/geonum enables replacing O(k^n) tensors with O(1) multivectors
I'm not sure that "use a completely different project" is a super productive reply :-) The goal here is to find ways to improve candle's performance.
you can be sure "use a completely different project" doesn't appear in the reply
tensor is an O(k^n) module
https://crates.io/crates/geonum tests empirically prove its an O(1) crate
replacing the module with a crate is an O(k^n) to O(1) way "to improve candle's performance"
Thanks @alex for the feedback, it's great to have a detailed analysis of the current limitations of candle and how we can improve on them. Hopefully there will be a few updates that should help soon.
@LaurentMazare great! If there are particular issues/PRs to follow along with, that'd be helpful.