djl
I implemented an SSD. It runs without exceptions, but the gradients are always zeros
Description
I implemented an SSD. It runs without exceptions, but the gradients are always zeros. The DJL version is 0.12.0.
Expected Behavior
Some of the gradients should be non-zero.
Error Message
No error message.
How to Reproduce?
Below is my code:

```java
try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
    NDList results = net.forward(parameterStore, X, true);
    if (firstTime == 0) {
        requiresGrad(net.getParameters());
    }
    MultiBoxTarget mt = MultiBoxTarget.builder().build();
    NDList inputs = new NDList(results.get(0));
    inputs.add(Y.get(0));
    inputs.add(results.get(1).transpose(0, 2, 1));
    NDList outputs = mt.target(inputs);
    NDArray l = calc_loss(cls_loss, bbox_loss,
            new NDList(results.get(1)), new NDList(outputs.get(2)),
            new NDList(results.get(2)), new NDList(outputs.get(0)), new NDList(outputs.get(1)));
    gc.backward(l);
    acc_sum += cls_eval(new NDList(results.get(1)), new NDList(outputs.get(2)));
    n += outputs.get(2).size();
    mae_sum += bbox_eval(new NDList(results.get(2)), new NDList(outputs.get(0)), new NDList(outputs.get(1)));
    m += outputs.get(0).size();
}
```
What have you tried to solve it?
- I tried version 0.11.0. One day it suddenly started working very well, which made me very happy, but a few days later it stopped working again (all gradients are always zeros).
Environment Info
- IDE: Eclipse
- JDK: JDK 11
@hyandell, can you give me an answer?
Apologies @kameronren, I'm part of the AWS OSPO and not a domain expert on this project. The project has a Slack channel at deepjavalibrary.slack.com, which might be a good place to raise this. I'll also mention it there.
@zachgk Can you take a look?
As a first suggestion, I recommend using the Trainer. It is the main training session API and should make it easier to get everything right. You can refer to our SSD example at https://github.com/deepjavalibrary/djl/blob/master/examples/src/main/java/ai/djl/examples/training/TrainPikachu.java to see what that might look like.
Otherwise, the only thought I have at first glance is that you may need to move the requiresGrad call to before you create the gradient collector. I also don't see how you are initializing the parameters or where the code uses the gradients.
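To illustrate why the order of these calls can matter, here is a toy, pure-Java model of gradient recording. It is not DJL code: `Var`, `Tape`, `Mul`, and `gradWhenFlagged` are hypothetical names invented for this sketch. It assumes, as in MXNet-style autograd, that an operation is recorded only if one of its inputs is already marked as requiring a gradient at the moment the operation runs. Under that assumption, flagging a parameter after the forward pass leaves nothing recorded for it, so its gradient stays zero.

```java
import java.util.ArrayList;
import java.util.List;

public class TapeOrderDemo {
    // Toy scalar variable (hypothetical, NOT the DJL NDArray API).
    static final class Var {
        double value;
        double grad;
        boolean requiresGrad;

        Var(double value) { this.value = value; }
    }

    // One recorded operation: out = a * b.
    static final class Mul {
        final Var a;
        final Var b;
        final Var out;

        Mul(Var a, Var b, Var out) { this.a = a; this.b = b; this.out = out; }
    }

    // Toy "gradient collector": records ops during the forward pass.
    static final class Tape {
        private final List<Mul> ops = new ArrayList<>();

        Var mul(Var a, Var b) {
            Var out = new Var(a.value * b.value);
            // Assumption: an op is recorded only if an input is already tracked
            // when the op executes; flags set later do not affect this decision.
            if (a.requiresGrad || b.requiresGrad) {
                out.requiresGrad = true;
                ops.add(new Mul(a, b, out));
            }
            return out;
        }

        // Reverse-mode pass over the recorded ops (newest first).
        void backward(Var loss) {
            loss.grad = 1.0;
            for (int i = ops.size() - 1; i >= 0; i--) {
                Mul op = ops.get(i);
                if (op.a.requiresGrad) op.a.grad += op.b.value * op.out.grad;
                if (op.b.requiresGrad) op.b.grad += op.a.value * op.out.grad;
            }
        }
    }

    // Returns d(loss)/dw for loss = w * x, flagging w before or after the forward pass.
    public static double gradWhenFlagged(boolean flagBeforeForward) {
        Tape tape = new Tape();
        Var w = new Var(3.0); // the "parameter"
        Var x = new Var(2.0); // the input
        if (flagBeforeForward) w.requiresGrad = true;
        Var loss = tape.mul(w, x); // forward pass is recorded here
        if (!flagBeforeForward) w.requiresGrad = true; // too late: nothing was recorded
        tape.backward(loss);
        return w.grad;
    }

    public static void main(String[] args) {
        System.out.println("flag before forward: " + gradWhenFlagged(true));  // prints 2.0
        System.out.println("flag after forward:  " + gradWhenFlagged(false)); // prints 0.0
    }
}
```

Running `main` prints a gradient of 2.0 (that is, `x`) when `w` is flagged before the forward pass and 0.0 when it is flagged afterwards. In DJL terms, the analogous step would presumably be marking each parameter array as requiring a gradient (for example with `NDArray.setRequiresGradient(true)`) before opening the `GradientCollector`, or letting the Trainer take care of it.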