
MXNet train error

Open luosoy opened this issue 2 years ago • 2 comments

VERSION:

    <dependency>
        <groupId>ai.djl.mxnet</groupId>
        <artifactId>mxnet-engine</artifactId>
        <version>0.21.0</version>
    </dependency>
    <dependency>
        <groupId>ai.djl.mxnet</groupId>
        <artifactId>mxnet-native-mkl</artifactId>
        <classifier>win-x86_64</classifier>
        <version>1.9.1</version>
    </dependency>

Block:

public static SequentialBlock build() {
    SequentialBlock resNet = new SequentialBlock();

    resNet.add(Conv2d.builder()
                    .setKernelShape(new Shape(3, 3))   // 3x3 kernel: each output combines a 3x3 patch
                    .setFilters(FILTERS_NUM)           // number of output channels
                    .optStride(new Shape(1, 1))        // stride
                    .optPadding(new Shape(1, 1))       // padding of 1 preserves the spatial size
                    .optBias(true)
                    .build())
            .add(BatchNorm.builder()
                    .optEpsilon(1e-5f)
                    .optMomentum(BATCH_NORM_MOMENTUM)
                    .build())
            .add(Activation.reluBlock());

    for (int i = 0; i < NUM_REST_BLOCKS; i++) {
        resNet.add(residualBlock());
    }

    ParallelBlock parallelBlock = new ParallelBlock(
            list -> {
                NDList policy = list.get(0);
                NDList value = list.get(1);
                return new NDList(policy.singletonOrThrow(), value.singletonOrThrow());
            },
            Arrays.asList(policyBlock(), valueBlock()));

    resNet.add(parallelBlock);

    return resNet;
}
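
A quick sanity check before training is to initialize the block and inspect its output shapes. A minimal sketch, assuming these methods live in a class named RNBuild (per the RNBuild::softMax reference below) and using placeholder input dimensions:

    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.types.DataType;
    import ai.djl.ndarray.types.Shape;
    import ai.djl.nn.SequentialBlock;

    try (NDManager manager = NDManager.newBaseManager()) {
        SequentialBlock resNet = RNBuild.build();
        Shape input = new Shape(1, CHANNELS, HEIGHT, WIDTH); // placeholder dimensions
        resNet.initialize(manager, DataType.FLOAT32, input);
        // the final ParallelBlock should yield two outputs:
        // policy (1, OUT_SIZE) and value (1, 1)
        Shape[] outputs = resNet.getOutputShapes(new Shape[] {input});
    }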

private static Block policyBlock() {
    SequentialBlock policyNet = new SequentialBlock();

    policyNet.add(Conv2d.builder()
                    .setKernelShape(new Shape(1, 1))
                    .setFilters(2)
                    .optStride(new Shape(1, 1))
                    .optPadding(new Shape(0, 0))
                    .optBias(true)
                    .build())
            .add(BatchNorm.builder()
                    .optEpsilon(1e-5f)
                    .optMomentum(BATCH_NORM_MOMENTUM)
                    .build())
            .add(Activation.reluBlock())
            .add(Blocks.batchFlattenBlock())
            .add(Linear.builder().setUnits(OUT_SIZE).build())
            .add(new LambdaBlock(RNBuild::softMax));

    return policyNet;
}

public static NDList softMax(NDList arrays) {
    // apply softmax along axis 1 (the class axis)
    return new NDList(arrays.singletonOrThrow().softmax(1));
}

private static Block valueBlock() {
    SequentialBlock valueNet = new SequentialBlock();
    valueNet.add(Conv2d.builder()
                    .setKernelShape(new Shape(1, 1))
                    .setFilters(2)
                    .optStride(new Shape(1, 1))
                    .optPadding(new Shape(0, 0))
                    .optBias(true)
                    .build())
            .add(BatchNorm.builder()
                    .optEpsilon(1e-5f)
                    .optMomentum(BATCH_NORM_MOMENTUM)
                    .build())
            .add(Activation.reluBlock())
            .add(Blocks.batchFlattenBlock())
            .add(Linear.builder().setUnits(FILTERS_NUM).build())
            .add(Activation.reluBlock())
            .add(Linear.builder().setUnits(1).build())
            .add(Activation.tanhBlock());

    return valueNet;
}


private static Block residualBlock() {
    SequentialBlock shortcut = new SequentialBlock();
    shortcut.add(Blocks.identityBlock());

    SequentialBlock resUnit = new SequentialBlock();
    resUnit.add(Conv2d.builder()
                    .setKernelShape(new Shape(3, 3))
                    .setFilters(FILTERS_NUM)
                    .optStride(new Shape(1, 1))
                    .optPadding(new Shape(1, 1))
                    .optBias(true)
                    .build())
            .add(BatchNorm.builder()
                    .optEpsilon(1e-5f)
                    .optMomentum(BATCH_NORM_MOMENTUM)
                    .build())
            .add(Activation::relu)
            .add(Conv2d.builder()
                    .setKernelShape(new Shape(3, 3))
                    .setFilters(FILTERS_NUM)
                    .optStride(new Shape(1, 1))
                    .optPadding(new Shape(1, 1))
                    .optBias(false)
                    .build())
            .add(BatchNorm.builder()
                    .optEpsilon(1e-5f)
                    .optMomentum(BATCH_NORM_MOMENTUM)
                    .build());

    return new ParallelBlock(
            list -> {
                NDList unit = list.get(0);
                NDList parallel = list.get(1);
                return new NDList(
                        unit.singletonOrThrow()
                                .add(parallel.singletonOrThrow())
                                .getNDArrayInternal()
                                .relu());
            },
            Arrays.asList(resUnit, shortcut));

}
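
Worth noting about the merge above: NDArray.add is out-of-place (it allocates a new array), which is what MXNet's autograd requires while recording. A sketch of the same merge with that distinction spelled out in comments:

    return new ParallelBlock(
            list -> {
                NDArray sum = list.get(0).singletonOrThrow()
                        .add(list.get(1).singletonOrThrow()); // out-of-place: returns a new NDArray
                // the in-place addi(...) would mutate a recorded array and fail under
                // autograd with the "Inplace operations ... not supported" MXNetError
                return new NDList(Activation.relu(sum));
            },
            Arrays.asList(resUnit, shortcut));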

train error:

    ai.djl.engine.EngineException: MXNet engine call failed: MXNetError: Check failed: AGInfo::IsNone(*output): Inplace operations (+=, -=, x[:]=, etc) are not supported when recording with autograd.
    Stack trace:
      File "C:\source\mxnet\src\imperative\imperative.cc", line 261

        at ai.djl.mxnet.jna.JnaUtils.checkCall(JnaUtils.java:1942)
        at ai.djl.mxnet.jna.JnaUtils.imperativeInvoke(JnaUtils.java:521)
        at ai.djl.mxnet.jna.FunctionInfo.invoke(FunctionInfo.java:60)
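
For context on what MXNet is rejecting here (the error class, not necessarily the root cause in this model): while a GradientCollector is recording, out-of-place ops such as add are fine, but the i-suffixed in-place variants (addi, subi, muli) mutate a recorded array and raise exactly this MXNetError. A minimal sketch, assuming the MXNet engine:

    import ai.djl.engine.Engine;
    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.types.Shape;
    import ai.djl.training.GradientCollector;

    try (NDManager manager = NDManager.newBaseManager()) {
        NDArray x = manager.ones(new Shape(2, 2));
        x.setRequiresGradient(true);
        try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
            NDArray y = x.mul(2);  // out-of-place: safe while recording
            // y.addi(1);          // in-place: "Inplace operations ... not supported"
            gc.backward(y);
        }
    }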

luosoy commented Mar 06 '23 03:03

Can you share your training code as well? It's hard for us to debug with just the block/model definition.

Can you also share the full stack trace, or at least where the error originates from your code?
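
For reference, a minimal skeleton of the kind of snippet that helps reproduce such errors; the loss, shapes, and constants here are placeholders, not the reporter's actual code (a two-output policy/value block would need a composite loss, which only the real training code can show):

    import ai.djl.Model;
    import ai.djl.ndarray.types.Shape;
    import ai.djl.training.DefaultTrainingConfig;
    import ai.djl.training.Trainer;
    import ai.djl.training.loss.Loss;

    try (Model model = Model.newInstance("resnet")) {
        model.setBlock(RNBuild.build());
        DefaultTrainingConfig config =
                new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()); // placeholder loss
        try (Trainer trainer = model.newTrainer(config)) {
            trainer.initialize(new Shape(BATCH_SIZE, CHANNELS, HEIGHT, WIDTH)); // placeholder shape
            // then iterate the dataset:
            // EasyTrain.trainBatch(trainer, batch);
            // trainer.step();
        }
    }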

siddvenk commented Mar 06 '23 18:03

Please see the attachment TestDjl.zip.


luosoy commented Mar 07 '23 01:03