candle

Error while training a multi-class classification model with ResNet-50

Open — bhavaysavaliya opened this issue 1 year ago · 1 comment

I want to build a simple multi-class classifier using a ResNet-50 model, but I am getting an error at the step `adam.backward_step(&loss).unwrap()`.

    let num_classes = 2;
    let vm = VarMap::new();
    let vs = VarBuilder::from_varmap(&vm, DType::F32, &dev);

    let model = seq();
    let model = model.add(resnet50(2, vs).unwrap());
    let mut adam = candle_nn::optim::AdamW::new(vm.all_vars(), ParamsAdamW::default()).unwrap();
    let (a, b) = dataset.get_batch(0); //a.shape = [5, 224, 224, 3] b.shape = [1,5], a and b are vectors
    let x_train = Tensor::new(a, &dev) 
        .unwrap()
        .transpose(1, 3)
        .unwrap()
        .to_dtype(DType::F32)
        .unwrap();
    println!("{:?}", x_train.shape()); //x_train.shape = [5, 3, 224, 224]
    let y_train = Tensor::new(b, &dev).unwrap().transpose(0, 1).unwrap(); 
    println!("{:?}", y_train.shape()); //y_train.shape = [5, 1]
    let y_pred = softmax(&model.forward(&x_train).unwrap(), 1).unwrap();
    println!("{:?}", y_pred.shape()); //y_pred.shape = [5, 2]
    let loss = sparse_categorical_cross_entropy(&y_pred, &y_train, num_classes).unwrap();
    println!("{:?}", loss.clone()); //Tensor[16.030241: f32]
    println!("{:?}", loss.clone().to_scalar::<f32>().unwrap()); //16.030241
    adam.backward_step(&loss).unwrap();

The detailed error is given below:

thread 'main' panicked at src\main.rs:51:31:
called `Result::unwrap()` on an `Err` value: WithBacktrace { inner: Msg("backward not supported for maxpool2d if ksize (3, 3) != stride (2, 2)"), backtrace: Backtrace [{ fn: "std::backtrace_rs::backtrace::dbghelp::trace", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\..\..\backtrace\src\backtrace\dbghelp.rs", line: 131 }, { fn: "std::backtrace_rs::backtrace::trace_unsynchronized", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\..\..\backtrace\src\backtrace\mod.rs", line: 66 }, { fn: "std::backtrace::Backtrace::create", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\backtrace.rs", line: 331 }, { fn: "std::backtrace::Backtrace::capture", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\backtrace.rs", line: 296 }, { fn: "enum2$<candle_core::error::Error>::bt", file: "C:\Users\Dell\.cargo\git\checkouts\candle-0c2b4fa9e5801351\3318fe3\candle-core\src\error.rs", line: 227 }, { fn: "candle_core::tensor::Tensor::backward", file: "C:\Users\Dell\.cargo\git\checkouts\candle-0c2b4fa9e5801351\3318fe3\candle-core\src\backprop.rs", line: 337 }, { fn: "candle_nn::optim::Optimizer::backward_step<candle_nn::optim::AdamW>", file: "C:\Users\Dell\.cargo\git\checkouts\candle-0c2b4fa9e5801351\3318fe3\candle-nn\src\optim.rs", line: 21 }, { fn: "image_preprocessing::main", file: ".\src\main.rs", line: 51 }, { fn: "core::ops::function::FnOnce::call_once<void (*)(),tuple$<> >", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce\library\core\src\ops\function.rs", line: 250 }, { fn: "core::hint::black_box", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce\library\core\src\hint.rs", line: 286 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace<void (*)(),tuple$<> >", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce\library\std\src\sys_common\backtrace.rs", line: 155 }, { fn: "std::rt::lang_start::closure$0<tuple$<> >", file: 
"/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce\library\std\src\rt.rs", line: 166 }, { fn: "std::rt::lang_start_internal::closure$2", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\panicking.rs", line: 552 }, { fn: "std::panicking::try", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\panicking.rs", line: 516 }, { fn: "std::panic::catch_unwind", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\rt.rs", line: 148 }, { fn: "std::rt::lang_start<tuple$<> >", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce\library\std\src\rt.rs", line: 165 }, { fn: "main" }, { fn: "invoke_main", file: "D:\a\_work\1\s\src\vctools\crt\vcstartup\src\startup\exe_common.inl", line: 78 }, { fn: "__scrt_common_main_seh", file: "D:\a\_work\1\s\src\vctools\crt\vcstartup\src\startup\exe_common.inl", line: 288 }, { fn: "BaseThreadInitThunk" }, { fn: "RtlUserThreadStart" }] }
stack backtrace:
   0: std::panicking::begin_panic_handler
             at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\panicking.rs:645
   1: core::panicking::panic_fmt
             at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\core\src\panicking.rs:72
   2: core::result::unwrap_failed
             at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\core\src\result.rs:1649
   3: enum2$<core::result::Result<tuple$<>,enum2$<candle_core::error::Error> > >::unwrap<tuple$<>,enum2$<candle_core::error::Error> >
             at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce\library\core\src\result.rs:1073
   4: image_preprocessing::main
             at .\src\main.rs:51
   5: core::ops::function::FnOnce::call_once<void (*)(),tuple$<> >
             at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce\library\core\src\ops\function.rs:250
   6: core::hint::black_box
             at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce\library\core\src\hint.rs:286

Can anyone suggest what I am doing wrong?

bhavaysavaliya avatar Mar 15 '24 05:03 bhavaysavaliya

The error message says "backward not supported for maxpool2d if ksize (3, 3) != stride (2, 2)". The ResNet model uses a maxpool2d layer with a 3×3 kernel and a stride of 2. Our backward pass does not support that configuration yet, because it only handles max-pooling where the kernel size equals the stride. Hopefully we can extend the backward step to support this kind of maxpool2d usage at some point, but this hasn't been done yet.

LaurentMazare avatar Mar 16 '24 21:03 LaurentMazare