candle
Error while training a multi-class classification model with ResNet-50
I want to build a simple multi-class classifier using a ResNet-50 model, but I am getting an error. The error occurs at the step `adam.backward_step(&loss).unwrap()`.
// One training step: build a ResNet-50 classifier inside a sequential model,
// run a forward pass on the first dataset batch, compute the loss, and take
// a single AdamW optimisation step.
let num_classes = 2;
// All trainable variables created through `vs` are registered in `vm`, which is
// how the optimiser below gets hold of them via `vm.all_vars()`.
let vm = VarMap::new();
let vs = VarBuilder::from_varmap(&vm, DType::F32, &dev);
let model = seq();
// NOTE(review): the class count is passed as a literal `2` here rather than
// `num_classes`; keep the two in sync (or pass `num_classes`).
let model = model.add(resnet50(2, vs).unwrap());
let mut adam = candle_nn::optim::AdamW::new(vm.all_vars(), ParamsAdamW::default()).unwrap();
let (a, b) = dataset.get_batch(0); // a.shape = [5, 224, 224, 3], b.shape = [1, 5]; a and b are vectors
// NOTE(review): assuming `a` is laid out as [N, H, W, C] (as the shape comment
// above suggests), `transpose(1, 3)` gives [N, C, W, H], i.e. height and width
// are swapped as well. For 224x224 inputs the shape looks right, but the images
// are transposed. A permutation to [0, 3, 1, 2] would give [N, C, H, W].
let x_train = Tensor::new(a, &dev)
.unwrap()
.transpose(1, 3)
.unwrap()
.to_dtype(DType::F32)
.unwrap();
println!("{:?}", x_train.shape()); // x_train.shape = [5, 3, 224, 224]
// Turns the [1, 5] label tensor into a [5, 1] column.
// TODO confirm which target shape the loss function below expects.
let y_train = Tensor::new(b, &dev).unwrap().transpose(0, 1).unwrap();
println!("{:?}", y_train.shape()); // y_train.shape = [5, 1]
// NOTE(review): softmax is applied to the model output over dim 1 (the class
// axis, per y_pred.shape = [5, 2]). Confirm that
// `sparse_categorical_cross_entropy` expects probabilities rather than logits.
// If it applies its own (log-)softmax, the outputs get normalised twice.
let y_pred = softmax(&model.forward(&x_train).unwrap(), 1).unwrap();
println!("{:?}", y_pred.shape()); // y_pred.shape = [5, 2]
let loss = sparse_categorical_cross_entropy(&y_pred, &y_train, num_classes).unwrap();
// `{:?}` only borrows its argument, so the `.clone()` calls are not needed for
// printing (presumably `to_scalar` also works on a borrow; verify).
println!("{:?}", loss.clone()); // Tensor[16.030241: f32]
println!("{:?}", loss.clone().to_scalar::<f32>().unwrap()); // 16.030241
// Computes gradients of `loss` and applies one AdamW update. The forward pass
// and loss above succeed. This call panics with "backward not supported for
// maxpool2d if ksize (3, 3) != stride (2, 2)", an error raised from
// Tensor::backward while backpropagating through the model's max-pool layer.
adam.backward_step(&loss).unwrap();
The detailed error is given below:
thread 'main' panicked at src\main.rs:51:31:
called `Result::unwrap()` on an `Err` value: WithBacktrace { inner: Msg("backward not supported for maxpool2d if ksize (3, 3) != stride (2, 2)"), backtrace: Backtrace [{ fn: "std::backtrace_rs::backtrace::dbghelp::trace", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\..\..\backtrace\src\backtrace\dbghelp.rs", line: 131 }, { fn: "std::backtrace_rs::backtrace::trace_unsynchronized", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\..\..\backtrace\src\backtrace\mod.rs", line: 66 }, { fn: "std::backtrace::Backtrace::create", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\backtrace.rs", line: 331 }, { fn: "std::backtrace::Backtrace::capture", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\backtrace.rs", line: 296 }, { fn: "enum2$<candle_core::error::Error>::bt", file: "C:\Users\Dell\.cargo\git\checkouts\candle-0c2b4fa9e5801351\3318fe3\candle-core\src\error.rs", line: 227 }, { fn: "candle_core::tensor::Tensor::backward", file: "C:\Users\Dell\.cargo\git\checkouts\candle-0c2b4fa9e5801351\3318fe3\candle-core\src\backprop.rs", line: 337 }, { fn: "candle_nn::optim::Optimizer::backward_step<candle_nn::optim::AdamW>", file: "C:\Users\Dell\.cargo\git\checkouts\candle-0c2b4fa9e5801351\3318fe3\candle-nn\src\optim.rs", line: 21 }, { fn: "image_preprocessing::main", file: ".\src\main.rs", line: 51 }, { fn: "core::ops::function::FnOnce::call_once<void (*)(),tuple$<> >", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce\library\core\src\ops\function.rs", line: 250 }, { fn: "core::hint::black_box", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce\library\core\src\hint.rs", line: 286 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace<void (*)(),tuple$<> >", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce\library\std\src\sys_common\backtrace.rs", line: 155 }, { fn: "std::rt::lang_start::closure$0<tuple$<> >", file: 
"/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce\library\std\src\rt.rs", line: 166 }, { fn: "std::rt::lang_start_internal::closure$2", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\panicking.rs", line: 552 }, { fn: "std::panicking::try", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\panicking.rs", line: 516 }, { fn: "std::panic::catch_unwind", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\rt.rs", line: 148 }, { fn: "std::rt::lang_start<tuple$<> >", file: "/rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce\library\std\src\rt.rs", line: 165 }, { fn: "main" }, { fn: "invoke_main", file: "D:\a\_work\1\s\src\vctools\crt\vcstartup\src\startup\exe_common.inl", line: 78 }, { fn: "__scrt_common_main_seh", file: "D:\a\_work\1\s\src\vctools\crt\vcstartup\src\startup\exe_common.inl", line: 288 }, { fn: "BaseThreadInitThunk" }, { fn: "RtlUserThreadStart" }] }
stack backtrace:
0: std::panicking::begin_panic_handler
at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\std\src\panicking.rs:645
1: core::panicking::panic_fmt
at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\core\src\panicking.rs:72
2: core::result::unwrap_failed
at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library\core\src\result.rs:1649
3: enum2$<core::result::Result<tuple$<>,enum2$<candle_core::error::Error> > >::unwrap<tuple$<>,enum2$<candle_core::error::Error> >
at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce\library\core\src\result.rs:1073
4: image_preprocessing::main
at .\src\main.rs:51
5: core::ops::function::FnOnce::call_once<void (*)(),tuple$<> >
at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce\library\core\src\ops\function.rs:250
6: core::hint::black_box
at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce\library\core\src\hint.rs:286
Can anyone suggest what I am doing wrong?
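The error message says "backward not supported for maxpool2d if ksize (3, 3) != stride (2, 2)". The ResNet model uses a maxpool2d layer with a 3×3 kernel and a stride of 2, and our backward pass does not support that configuration. This is why the forward pass and the loss computation work, and the failure only appears when `backward_step` computes the gradients. Hopefully we can fix the backward step to be compatible with this kind of maxpool2d usage at some point, but this hasn't been done yet.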