
Improve checking to detect `NaN` case

Open · insop opened this issue 1 year ago · 1 comment

The current `check_two_equal` may not be able to detect the `NaN` case. We can improve it by using `std::isnan`.

When the following incorrect SIMD implementation runs on an M1 machine, the test still passes. The cause is that `diff` becomes `NaN`, and the final threshold check cannot detect that.
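Under IEEE 754 semantics, every ordered comparison involving `NaN` evaluates to false, so a threshold check like `diff > ERROR_MAX` silently passes when `diff` is `NaN`. A minimal illustration (generic C++, not tied to the tutorial code):

    #include <cmath>
    #include <cstdio>

    int main() {
        float diff = std::nanf("");
        printf("%d\n", diff > 1e-6f);      // 0: NaN compares false
        printf("%d\n", diff <= 1e-6f);     // 0: ...in both directions
        printf("%d\n", std::isnan(diff));  // 1: only isnan catches it
        return 0;
    }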

The fix is to add a check for NaN.
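A minimal sketch of the proposed fix (the exact signature and threshold of `check_two_equal` in the tutorial may differ; `MAX_SQ_ERROR` here is an assumed name):

    #include <cmath>

    bool check_two_equal(const float *a, const float *b, int size) {
        const float MAX_SQ_ERROR = 1e-6f;  // assumed threshold
        float sq_diff = 0.0f;
        for (int i = 0; i < size; i++) {
            float diff = a[i] - b[i];
            // New: explicitly reject NaN, which would otherwise make
            // the final threshold comparison silently evaluate to false.
            if (std::isnan(diff)) return false;
            sq_diff += diff * diff;
        }
        return (sq_diff / size) <= MAX_SQ_ERROR;
    }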

Here is an incorrect implementation of `simd_programming`; however, it passes the current test.

-------- Sanity check of simd_programming implementation: Passed! --------
Section, Total time(ms), Average time(ms), Count, GOPs
simd_programming, 1082.963989, 108.295998, 10, 2.420616

            for (int q = 0; q < num_block; q++) {
                // load 32x4bit (16 bytes) weight
                const uint8x16_t w0 = vld1q_u8(w_start);
                w_start += 16;

                /*
                   We will accelerate the program using ARM Intrinsics. You can check the documentation of operations
                   at: https://developer.arm.com/architectures/instruction-sets/intrinsics
                */
                // TODO: decode the lower and upper half of the weights as int8x16_t
                // Hint:
                // (1) use `vandq_u8` with the mask_low4bit to get the lower half
                // (2) use `vshrq_n_u8` to right shift 4 bits and get the upper half
                // (3) use `vreinterpretq_s8_u8` to interpret the vector as int8
                // lowbit mask
                const uint8x16_t mask_low4bit = vdupq_n_u8(0xf);

                // ADD correct code here
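
                // Sketch following the hint above (not the official solution):
                // decode the two 4-bit halves into signed 8-bit vectors.
                int8x16_t w0_low  = vreinterpretq_s8_u8(vandq_u8(w0, mask_low4bit));
                int8x16_t w0_high = vreinterpretq_s8_u8(vshrq_n_u8(w0, 4));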

                // TODO: apply zero_point to weights and convert the range from (0, 15) to (-8, 7)
                // Hint: using `vsubq_s8` to the lower-half and upper-half vectors of weights
                const int8x16_t offsets = vdupq_n_s8(8);

                // ADD correct code here
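
                // Sketch following the hint above: shift the range (0, 15) to (-8, 7).
                w0_low = vsubq_s8(w0_low, offsets);
                w0_high = vsubq_s8(w0_high, offsets);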

                // load 32 8-bit activation
                const int8x16_t a0 = vld1q_s8(a_start);
                const int8x16_t a1 = vld1q_s8(a_start + 16);
                a_start += 32;

                // TODO: perform dot product and store the result into the intermediate sum, int_sum0
                // Hint: use `vdotq_s32` to compute sumv0 = a0 * lower-half weights + a1 * upper-half weights
                // int32x4 vector to store intermediate sum
                int32x4_t int_sum0;

                // HERE is the incorrect part: the dot products below are
                // stored into sumv0 instead of being accumulated into
                // int_sum0, so int_sum0 stays zero.
                int_sum0 = vdupq_n_s32(0);

                sumv0 = vdotq_s32(int_sum0, w0_low, a0);
                sumv0 = vdotq_s32(int_sum0, w0_high, a1);

                float s_0 = *s_a++ * *s_w++;
                sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum0), s_0);
            }
            C->data_ptr[row * n + col] = vaddvq_f32(sumv0);
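
For contrast, the accumulation was presumably meant to look like the following (a sketch based on the hints in the skeleton, assuming `sumv0` is a `float32x4_t` accumulator; not the official solution):

    int32x4_t int_sum0 = vdupq_n_s32(0);
    // Accumulate the dot products into int_sum0, not into sumv0.
    int_sum0 = vdotq_s32(int_sum0, w0_low, a0);
    int_sum0 = vdotq_s32(int_sum0, w0_high, a1);
    float s_0 = *s_a++ * *s_w++;
    sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum0), s_0);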

insop · Jan 01 '24

PTAL, @meenchen, @RaymondWang0. Thank you.

insop · Jan 02 '24