MXNet train error
VERSION:
Block:

public static SequentialBlock build() {
    SequentialBlock resNet = new SequentialBlock();
    resNet.add(Conv2d.builder()
                    .setKernelShape(new Shape(3, 3)) // 3x3 kernel: combines 9 input values into 1 output
                    .setFilters(FILTERS_NUM)         // number of output channels to produce
                    .optStride(new Shape(1, 1))      // stride
                    .optPadding(new Shape(1, 1))     // padding
                    .optBias(true)
                    .build())
            .add(BatchNorm.builder()
                    .optEpsilon(1e-5f)
                    .optMomentum(BATCH_NORM_MOMENTUM)
                    .build())
            .add(Activation.reluBlock());
    for (int i = 0; i < NUM_REST_BLOCKS; i++) {
        resNet.add(residualBlock());
    }
    ParallelBlock parallelBlock = new ParallelBlock(
            list -> {
                NDList policy = list.get(0);
                NDList value = list.get(1);
                return new NDList(policy.singletonOrThrow(), value.singletonOrThrow());
            },
            Arrays.asList(policyBlock(), valueBlock()));
    resNet.add(parallelBlock);
    return resNet;
}
private static Block policyBlock() {
    SequentialBlock policyNet = new SequentialBlock();
    policyNet.add(Conv2d.builder()
                    .setKernelShape(new Shape(1, 1))
                    .setFilters(2)
                    .optStride(new Shape(1, 1))
                    .optPadding(new Shape(0, 0))
                    .optBias(true)
                    .build())
            .add(BatchNorm.builder()
                    .optEpsilon(1e-5f)
                    .optMomentum(BATCH_NORM_MOMENTUM)
                    .build())
            .add(Activation.reluBlock())
            .add(Blocks.batchFlattenBlock())
            .add(Linear.builder().setUnits(OUT_SIZE).build())
            .add(new LambdaBlock(RNBuild::softMax));
    return policyNet;
}
public static NDList softMax(NDList arrays) {
    return new NDList(arrays.singletonOrThrow().softmax(1));
}
private static Block valueBlock() {
    SequentialBlock valueNet = new SequentialBlock();
    valueNet.add(Conv2d.builder()
                    .setKernelShape(new Shape(1, 1))
                    .setFilters(2)
                    .optStride(new Shape(1, 1))
                    .optPadding(new Shape(0, 0))
                    .optBias(true)
                    .build())
            .add(BatchNorm.builder()
                    .optEpsilon(1e-5f)
                    .optMomentum(BATCH_NORM_MOMENTUM)
                    .build())
            .add(Activation.reluBlock())
            .add(Blocks.batchFlattenBlock())
            .add(Linear.builder().setUnits(FILTERS_NUM).build())
            .add(Activation.reluBlock())
            .add(Linear.builder().setUnits(1).build())
            .add(Activation.tanhBlock());
    return valueNet;
}
private static Block residualBlock() {
    SequentialBlock shortcut = new SequentialBlock();
    shortcut.add(Blocks.identityBlock());
    SequentialBlock resUnit = new SequentialBlock();
    resUnit.add(Conv2d.builder()
                    .setKernelShape(new Shape(3, 3))
                    .setFilters(FILTERS_NUM)
                    .optStride(new Shape(1, 1))
                    .optPadding(new Shape(1, 1))
                    .optBias(true)
                    .build())
            .add(BatchNorm.builder()
                    .optEpsilon(1e-5f)
                    .optMomentum(BATCH_NORM_MOMENTUM)
                    .build())
            .add(Activation::relu)
            .add(Conv2d.builder()
                    .setKernelShape(new Shape(3, 3))
                    .setFilters(FILTERS_NUM)
                    .optStride(new Shape(1, 1))
                    .optPadding(new Shape(1, 1))
                    .optBias(false)
                    .build())
            .add(BatchNorm.builder()
                    .optEpsilon(1e-5f)
                    .optMomentum(BATCH_NORM_MOMENTUM)
                    .build());
    return new ParallelBlock(
            list -> {
                NDList unit = list.get(0);
                NDList parallel = list.get(1);
                return new NDList(
                        unit.singletonOrThrow()
                                .add(parallel.singletonOrThrow())
                                .getNDArrayInternal()
                                .relu());
            },
            Arrays.asList(resUnit, shortcut));
}
train error:

ai.djl.engine.EngineException: MXNet engine call failed: MXNetError: Check failed: AGInfo::IsNone(*output): Inplace operations (+=, -=, x[:]=, etc) are not supported when recording with autograd.
Stack trace:
  File "C:\source\mxnet\src\imperative\imperative.cc", line 261
    at ai.djl.mxnet.jna.JnaUtils.checkCall(JnaUtils.java:1942)
    at ai.djl.mxnet.jna.JnaUtils.imperativeInvoke(JnaUtils.java:521)
    at ai.djl.mxnet.jna.FunctionInfo.invoke(FunctionInfo.java:60)
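For context on the failure mode: the MXNet engine refuses in-place NDArray updates while autograd is recording gradients. A minimal sketch of the kind of operation this error refers to (this is not taken from the issue author's code; the array names and shapes are purely illustrative):

import ai.djl.engine.Engine;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.GradientCollector;

// Illustrative only: an in-place update (addi) performed inside an
// autograd recording scope, the pattern the MXNet check rejects.
try (NDManager manager = NDManager.newBaseManager();
        GradientCollector collector = Engine.getInstance().newGradientCollector()) {
    NDArray x = manager.ones(new Shape(2, 2));
    x.setRequiresGradient(true);
    NDArray y = x.mul(2);
    y.addi(1);          // in-place += while recording -> MXNetError on the MXNet engine
    // y = y.add(1);    // out-of-place variant, accepted while recording
    collector.backward(y);
}

The out-of-place variant y.add(1) allocates a new array instead of mutating y, which is why autograd accepts it.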
Can you share your training code as well? It's hard for us to debug with just the block/model definition.
Can you also share the full stack trace, or at least where the error originates from your code?
Please see the attachment "TestDjl.zip".
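For readers without the attachment, a rough sketch of what a DJL training loop around the block above could look like. This is hypothetical: the loss, input shape, and dataset are placeholders, and the issue author's actual training code is in TestDjl.zip.

import ai.djl.Model;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;

// Hypothetical training skeleton; RNBuild.build() is the block above,
// while the loss, channels, boardSize, and dataset are placeholders.
static void train(Dataset dataset, int channels, int boardSize) throws Exception {
    try (Model model = Model.newInstance("resnet")) {
        model.setBlock(RNBuild.build());
        DefaultTrainingConfig config =
                new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) // single loss, for illustration only
                        .addTrainingListeners(TrainingListener.Defaults.logging());
        try (Trainer trainer = model.newTrainer(config)) {
            trainer.initialize(new Shape(1, channels, boardSize, boardSize));
            for (Batch batch : trainer.iterateDataset(dataset)) {
                // forward pass, loss, and backward pass run inside a GradientCollector
                EasyTrain.trainBatch(trainer, batch);
                trainer.step();
                batch.close();
            }
        }
    }
}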