After using `slice_assign`, gradient descent cannot track the model parameters

wcshds opened this issue 1 year ago • 6 comments

I found that after using slice_assign in the loss function, gradient descent cannot track the model parameters. I believe this is the main reason why the loss becomes NaN after the first iteration when I apply my implementation of CTC loss to the CRNN model.

pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 1> {
    let device = input.device();
    let [d1, d2, d3] = input.dims();

    let input2 = Tensor::empty(input.shape(), &device);
    input2
        .clone()
        .slice_assign([0..d1, 0..d2, 0..d3], input.clone());

    input2.mean()
}

wcshds avatar Dec 25 '23 04:12 wcshds

The actual reason my CTC loss implementation becomes NaN after some iterations is that the logarithm of zero is taken, not slice_assign itself. I am still not certain whether there is a bug in slice_assign... perhaps more investigation is needed.
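
For what it's worth, the log-of-zero NaN can be avoided by clamping probabilities away from zero before taking the logarithm. A minimal sketch of the idea (the epsilon is illustrative, not the value from my CTC implementation):

use burn::tensor::{backend::Backend, Tensor};

// Numerically safe logarithm: clamp probabilities away from zero before taking the log.
fn safe_log<B: Backend, const D: usize>(probs: Tensor<B, D>) -> Tensor<B, D> {
    probs.clamp_min(1e-12).log()
}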

wcshds avatar Dec 27 '23 08:12 wcshds

@nathanielsimard I'm trying to train a CRNN model from scratch, but after a day of training, there's still no sign of convergence in the model. Then I noticed that only the parameters of the last layer were being updated. Here is the minimal reproducible example:

use burn::{
    backend::{ndarray::NdArrayDevice, Autodiff, NdArray},
    module::Module,
    nn::{Linear, LinearConfig, Lstm, LstmConfig},
    optim::{AdamConfig, GradientsParams, Optimizer},
    record::{FullPrecisionSettings, PrettyJsonFileRecorder},
    tensor::{
        backend::{AutodiffBackend, Backend},
        Tensor,
    },
};

fn main() {
    run::<Autodiff<NdArray>>(NdArrayDevice::Cpu);
}

fn run<B: AutodiffBackend>(device: B::Device) {
    let mut model = Model::<B>::new(&device);
    let mut optim = AdamConfig::new().init();
    let pfr = PrettyJsonFileRecorder::<FullPrecisionSettings>::new();

    for iteration in 0..51 {
        let input = Tensor::random(
            [2, 10, 5],
            burn::tensor::Distribution::Uniform(-1.0, 1.0),
            &device,
        );
        let output = model.forward(input);
        let loss = output.mean();

        println!(
            "[Train - Iteration {}] Loss {:.5}",
            iteration,
            loss.clone().into_scalar()
        );

        let grads = loss.backward();
        let grads = GradientsParams::from_grads(grads, &model);

        model = optim.step(0.001, model, grads);

        // Periodically save the LSTM and Linear parameters so their evolution can be inspected.
        if iteration % 10 == 0 {
            model
                .clone()
                .lstm
                .save_file(format!("./lstm-{:02}", iteration), &pfr)
                .unwrap();
            model
                .clone()
                .linear
                .save_file(format!("./linear-{:02}", iteration), &pfr)
                .unwrap();
        }
    }
}

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    lstm: Lstm<B>,
    linear: Linear<B>,
}

impl<B: Backend> Model<B> {
    pub fn new(device: &B::Device) -> Self {
        Self {
            lstm: LstmConfig::new(5, 10, true).init(device),
            linear: LinearConfig::new(10, 20).init(device),
        }
    }

    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
        let (_, x) = self.lstm.forward(input, None);
        let [batch_size, seq_length, d_hidden] = x.dims();
        let x = x.reshape([batch_size * seq_length, d_hidden]);
        let x = self.linear.forward(x);

        x
    }
}

After some investigation, I believe this is due to the use of slice_assign in the implementation of LSTM. After replacing Tensor::slice_assign with Tensor::cat, the LSTM parameters are updated correctly.

pub fn forward(
    &self,
    batched_input: Tensor<B, 3>,
    state: Option<(Tensor<B, 2>, Tensor<B, 2>)>,
) -> (Tensor<B, 3>, Tensor<B, 3>) {
    let [batch_size, seq_length, _] = batched_input.shape().dims;
    let device = &batched_input.device();

    let (mut cell_state, mut hidden_state) = match state {
        Some((cell_state, hidden_state)) => (cell_state, hidden_state),
        None => (
            Tensor::zeros([batch_size, self.d_hidden], device),
            Tensor::zeros([batch_size, self.d_hidden], device),
        ),
    };

    let mut batched_cell_state_vec = Vec::with_capacity(seq_length);
    let mut batched_hidden_state_vec = Vec::with_capacity(seq_length);

    for input_t in batched_input.iter_dim(1) {
        let input_t = input_t.squeeze(1);
        // f(orget)g(ate) tensors
        let biased_fg_input_sum = self.gate_product(&input_t, &hidden_state, &self.forget_gate);
        let forget_values = activation::sigmoid(biased_fg_input_sum); // to multiply with cell state

        // i(nput)g(ate) tensors
        let biased_ig_input_sum = self.gate_product(&input_t, &hidden_state, &self.input_gate);
        let add_values = activation::sigmoid(biased_ig_input_sum);

        // o(utput)g(ate) tensors
        let biased_og_input_sum = self.gate_product(&input_t, &hidden_state, &self.output_gate);
        let output_values = activation::sigmoid(biased_og_input_sum);

        // c(ell)g(ate) tensors
        let biased_cg_input_sum = self.gate_product(&input_t, &hidden_state, &self.cell_gate);
        let candidate_cell_values = biased_cg_input_sum.tanh();

        cell_state = forget_values * cell_state.clone() + add_values * candidate_cell_values;
        hidden_state = output_values * cell_state.clone().tanh();

        let unsqueezed_shape = [cell_state.shape().dims[0], 1, cell_state.shape().dims[1]];

        let unsqueezed_cell_state = cell_state.clone().reshape(unsqueezed_shape);
        let unsqueezed_hidden_state = hidden_state.clone().reshape(unsqueezed_shape);

        // store the state for this timestep
        batched_cell_state_vec.push(unsqueezed_cell_state);
        batched_hidden_state_vec.push(unsqueezed_hidden_state);
    }

    let batched_cell_state = Tensor::cat(batched_cell_state_vec, 1);
    let batched_hidden_state = Tensor::cat(batched_hidden_state_vec, 1);

    (batched_cell_state, batched_hidden_state)
}
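
For context, gate_product is not shown in the snippet above; it essentially combines the input and hidden projections of a gate. A rough sketch of the idea (the standalone signature and the two Linear parameters are illustrative stand-ins, not Burn's actual gate controller API):

use burn::{
    nn::Linear,
    tensor::{backend::Backend, Tensor},
};

// Sketch of what a gate product computes for one gate:
// input_transform(x_t) + hidden_transform(h_{t-1}), with biases handled by the Linear layers.
fn gate_product<B: Backend>(
    input: &Tensor<B, 2>,
    hidden: &Tensor<B, 2>,
    input_transform: &Linear<B>,
    hidden_transform: &Linear<B>,
) -> Tensor<B, 2> {
    input_transform.forward(input.clone()) + hidden_transform.forward(hidden.clone())
}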

wcshds avatar Jan 10 '24 16:01 wcshds

Thanks @wcshds for the example with the LSTM, it will help us fix this.

nathanielsimard avatar Jan 16 '24 15:01 nathanielsimard

I found that after using slice_assign in the loss function, gradient descent cannot track the model parameters. I believe this is the main reason why the loss becomes NaN after the first iteration when I apply my implementation of CTC loss to the CRNN model.

pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 1> {
    let device = input.device();
    let [d1, d2, d3] = input.dims();

    let input2 = Tensor::empty(input.shape(), &device);
    input2
        .clone()
        .slice_assign([0..d1, 0..d2, 0..d3], input.clone());

    input2.mean()
}

I found the actual problem in this code. Slice assign doesn't actually mutate any data in input2, it returns a new tensor handle that should be used afterward:

pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 1> {
    let device = input.device();
    let [d1, d2, d3] = input.dims();

    let input2 = Tensor::empty(input.shape(), &device);
    let x = input2.slice_assign([0..d1, 0..d2, 0..d3], input);

    x.mean()
}

There are no mutable operations in the tensor API; every operation returns a result that should be used!
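
To illustrate the pattern with a generic sketch (add_scalar is just an arbitrary op here, nothing specific to slice_assign):

use burn::tensor::{backend::Backend, Tensor};

fn example<B: Backend>(x: Tensor<B, 2>) -> Tensor<B, 2> {
    // Dropped: the returned tensor is never bound, so `x` stays unchanged.
    // x.clone().add_scalar(1.0);

    // Correct: bind and return the result of the operation.
    x.add_scalar(1.0)
}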

Though it doesn't explain the bug with LSTM.

nathanielsimard avatar Jan 16 '24 15:01 nathanielsimard

There are no mutable operations in the tensor API; every operation returns a result that should be used!

Thank you very much for pointing out the actual problem! I often forget that.

wcshds avatar Jan 16 '24 16:01 wcshds

@wcshds I added some tests comparing the backward passes of slice_assign and cat in #1146, but I can't find a bug.
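
For reference, the comparison is along these lines (a sketch assuming the Autodiff<NdArray> backend and the require_grad/grad APIs; the actual tests are in #1146):

use burn::backend::{ndarray::NdArrayDevice, Autodiff, NdArray};
use burn::tensor::{Distribution, Tensor};

type B = Autodiff<NdArray>;

fn compare_slice_assign_and_cat_grads() {
    let device = NdArrayDevice::Cpu;
    let y = Tensor::<B, 2>::random([2, 3], Distribution::Uniform(-1.0, 1.0), &device).require_grad();
    let z = Tensor::<B, 2>::random([2, 3], Distribution::Uniform(-1.0, 1.0), &device).require_grad();

    // Combine the two tensors via slice_assign into a zeroed [4, 3] tensor.
    let combined_a = Tensor::<B, 2>::zeros([4, 3], &device)
        .slice_assign([0..2, 0..3], y.clone())
        .slice_assign([2..4, 0..3], z.clone());
    let grads_a = combined_a.mean().backward();

    // Combine the same two tensors via cat along the first dimension.
    let combined_b = Tensor::cat(vec![y.clone(), z.clone()], 0);
    let grads_b = combined_b.mean().backward();

    // The gradients w.r.t. y should be identical for both paths.
    y.grad(&grads_a)
        .unwrap()
        .to_data()
        .assert_approx_eq(&y.grad(&grads_b).unwrap().to_data(), 3);
}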

louisfd avatar Jan 17 '24 21:01 louisfd