
Wrong result of NESoftmaxLayer with multi-dimensions

Open: daoxian opened this issue 2 years ago (1 comment)

Output of 'strings libarm_compute.so | grep arm_compute_version':

arm_compute_version=v22.02 Build options: {'Werror': '1', 'debug': '1', 'neon': '1', 'opencl': '0', 'os': 'linux', 'arch': 'arm64-v8.2-a-sve', 'benchmark_tests': '1', 'validation_tests': '1'} Git hash=unknown

Platform: Arm64

Operating System: Ubuntu 20.04.3 LTS

Input data: float32 array [4, 5, 8, 2, 9, 15]

Input shape: 2x3x1x1

Softmax axis: 1

ComputeLibrary's softmax output (printed with Tensor::print()):

 0.00490169 4.53978e-05
   0.267623 2.26022e-06
   0.727475    0.999952

Softmax output from PyTorch, NumPy, TensorFlow, etc.:

tensor([[[[1.7148e-02]],

         [[4.6613e-02]],

         [[9.3624e-01]]],


        [[[2.2547e-06]],

         [[2.4726e-03]],

         [[9.9753e-01]]]])
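As a cross-check, the standalone sketch below recomputes both sets of numbers with a plain, numerically stable softmax. The helper names softmax_1d and gather are hypothetical, written only for this illustration. Applying the softmax to the contiguous triples [4, 5, 8] and [2, 9, 15] reproduces the PyTorch values, while applying it to [4, 8, 9] and [5, 2, 15] (every second element, i.e. normalizing down the columns when the buffer is read as three rows of two values) reproduces the ComputeLibrary values.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable 1-D softmax: subtract the maximum before exponentiating
// so that large inputs (such as 15 here) cannot overflow exp().
std::vector<double> softmax_1d(const std::vector<double> &x)
{
    assert(!x.empty());
    const double max_val = *std::max_element(x.begin(), x.end());
    std::vector<double> y(x.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i)
    {
        y[i] = std::exp(x[i] - max_val);
        sum += y[i];
    }
    for (double &v : y)
    {
        v /= sum;
    }
    return y;
}

// Hypothetical helper: collects `count` elements of `buf`, starting at `start`
// and stepping by `stride`. stride = 1 picks a contiguous group; stride = 2 picks
// one column of the buffer viewed as rows of two values.
std::vector<double> gather(const std::vector<double> &buf, std::size_t start,
                           std::size_t stride, std::size_t count)
{
    std::vector<double> out;
    for (std::size_t i = 0; i < count; ++i)
    {
        out.push_back(buf.at(start + i * stride));
    }
    return out;
}
```

For example, softmax_1d(gather(buf, 0, 1, 3)) with buf = {4, 5, 8, 2, 9, 15} gives roughly 0.0171, 0.0466, 0.9362 (the first PyTorch group), whereas softmax_1d(gather(buf, 0, 2, 3)) gives roughly 0.0049, 0.2676, 0.7275 (the first ComputeLibrary column). Both outputs are therefore valid softmaxes; they just normalize over different groupings of the same six values.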

Below is the test code:

#include "arm_compute/runtime/NEON/NEFunctions.h"
#include "arm_compute/core/Types.h"
#include "utils/Utils.h"

#include <iostream>
#include <memory>
#include <vector>

using namespace arm_compute;
using namespace utils;

class NEONSOFTMAXExample : public Example
{
public:
    bool do_setup(int argc, char **argv) override
    {
        ARM_COMPUTE_UNUSED(argc);
        ARM_COMPUTE_UNUSED(argv);

        softmax = std::make_unique<NESoftmaxLayer>();

        // TensorShape lists dimensions innermost-first: dimension 0 (size 2) is the
        // fastest-varying one. apply_dim_correction = false keeps the trailing 1s.
        TensorShape src_shape(0);
        src_shape.set(0, 2, false, true);
        src_shape.set(1, 3, false, true);
        src_shape.set(2, 1, false, true);
        src_shape.set(3, 1, false, true);

        src.allocator()->init(TensorInfo(src_shape, 1, DataType::F32));

        // The output has the same shape as the input.
        TensorShape out_shape_softmax(0);
        out_shape_softmax.set(0, 2, false, true);
        out_shape_softmax.set(1, 3, false, true);
        out_shape_softmax.set(2, 1, false, true);
        out_shape_softmax.set(3, 1, false, true);
        out_softmax.allocator()->init(TensorInfo(out_shape_softmax, 1, DataType::F32));

        // beta = 1.0f, axis = 1 (ACL dimension 1, the dimension of size 3).
        softmax->configure(&src, &out_softmax, 1.0f, 1);

        // Allocate backing memory after configure(), then fill the input.
        src.allocator()->allocate();
        out_softmax.allocator()->allocate();
        std::vector<float> vec = {4, 5, 8, 2, 9, 15};
        fill_tensor_vector(src, vec);

        return true;
    }
    void do_run() override
    {
        src.print(std::cout);
        softmax->run();
        out_softmax.print(std::cout);
    }

private:
    Tensor                          src{};
    Tensor                          out_softmax{};
    std::unique_ptr<NESoftmaxLayer> softmax{};
};

/** Main program for softmax test
 *
 */
int main(int argc, char **argv)
{
    return utils::run_example<NEONSOFTMAXExample>(argc, argv);
}


daoxian, Aug 09 '22 09:08

Hasn't anyone else encountered this problem? Any help will be appreciated! @GeorgeARM @morgolock

daoxian, Aug 10 '22 03:08

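I've found the reason: ComputeLibrary orders tensor dimensions in the reverse of NumPy/PyTorch. ACL dimension 0 is the innermost (fastest-varying) dimension, which corresponds to the last NumPy axis, so both the shape and the softmax axis have to be reversed: for a rank-4 tensor, NumPy axis k corresponds to ACL axis 3 - k. The ACL shape (2, 3, 1, 1) built above is therefore NumPy shape (1, 1, 3, 2), and ACL axis 1 normalizes over [4, 8, 9] and [5, 2, 15], which is exactly what the ACL output shows.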
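A minimal sketch of that conversion, assuming only the convention that ACL dimension 0 is the last NumPy axis. numpy_shape_to_acl and numpy_axis_to_acl are hypothetical helpers written for this illustration; they are not part of ComputeLibrary.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not a ComputeLibrary API): converts a NumPy/PyTorch shape,
// listed outermost-first, into ACL TensorShape order, listed innermost-first.
std::vector<std::size_t> numpy_shape_to_acl(const std::vector<std::size_t> &np_shape)
{
    return std::vector<std::size_t>(np_shape.rbegin(), np_shape.rend());
}

// Hypothetical helper (not a ComputeLibrary API): converts a NumPy axis into the
// matching ACL axis. Negative NumPy axes count from the end, as in NumPy.
int numpy_axis_to_acl(int np_axis, int rank)
{
    assert(rank > 0 && np_axis >= -rank && np_axis < rank);
    if (np_axis < 0)
    {
        np_axis += rank;
    }
    return rank - 1 - np_axis;
}
```

Applied to the reproducer, NumPy shape (2, 3, 1, 1) maps to ACL order (1, 1, 3, 2) and NumPy axis 1 maps to ACL axis 2. This suggests describing both src and out_softmax as TensorShape(1U, 1U, 3U, 2U) and calling softmax->configure(&src, &out_softmax, 1.0f, 2); with the same buffer {4, 5, 8, 2, 9, 15}, the output should then match the PyTorch values.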

daoxian, Dec 30 '22 12:12