tinychat-tutorial
Improve checking to detect `NaN` case
The current `check_two_equal` may not be able to detect the NaN case. We can improve it by using `std::isnan`.
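When the following incorrect SIMD implementation runs in an M1 environment, the test still passes. The cause is that `diff` becomes NaN and the final check does not detect it: any ordered comparison involving NaN (such as `>` or `<`) evaluates to false, so the error threshold is never reported as exceeded.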
The fix is to add a check for NaN.
Here is an incorrect implementation of simd_programming that nevertheless passes the current test, together with the test output:
```
-------- Sanity check of simd_programming implementation: Passed! --------
Section, Total time(ms), Average time(ms), Count, GOPs
simd_programming, 1082.963989, 108.295998, 10, 2.420616
```
```cpp
for (int q = 0; q < num_block; q++) {
    // load 32x4bit (16 bytes) weight
    const uint8x16_t w0 = vld1q_u8(w_start);
    w_start += 16;
    /*
    We will accelerate the program using ARM Intrinsics. You can check the documentation of operations
    at: https://developer.arm.com/architectures/instruction-sets/intrinsics
    */
    // TODO: decode the lower and upper half of the weights as int8x16_t
    // Hint:
    // (1) use `vandq_u8` with the mask_low4bit to get the lower half
    // (2) use `vshrq_n_u8` to right shift 4 bits and get the upper half
    // (3) use `vreinterpretq_s8_u8` to interpret the vector as int8
    // lowbit mask
    const uint8x16_t mask_low4bit = vdupq_n_u8(0xf);
    // ADD correct code here
    // TODO: apply zero_point to weights and convert the range from (0, 15) to (-8, 7)
    // Hint: using `vsubq_s8` to the lower-half and upper-half vectors of weights
    const int8x16_t offsets = vdupq_n_s8(8);
    // ADD correct code here
    // load 32 8-bit activation
    const int8x16_t a0 = vld1q_s8(a_start);
    const int8x16_t a1 = vld1q_s8(a_start + 16);
    a_start += 32;
    // TODO: perform dot product and store the result into the intermediate sum, int_sum0
    // Hint: use `vdotq_s32` to compute sumv0 = a0 * lower-half weights + a1 * upper-half weights
    // int32x4 vector to store intermediate sum
    int32x4_t int_sum0;
    // HERE is an incorrect implementation of the code:
    // both vdotq_s32 results are written to the float accumulator sumv0 instead of
    // int_sum0 (and the second overwrites the first), so int_sum0 stays zero and
    // the scaled update below adds nothing.
    int_sum0 = vdupq_n_s32(0);
    sumv0 = vdotq_s32(int_sum0, w0_low, a0);
    sumv0 = vdotq_s32(int_sum0, w0_high, a1);
    float s_0 = *s_a++ * *s_w++;
    sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum0), s_0);
}
C->data_ptr[row * n + col] = vaddvq_f32(sumv0);
```
PTAL, @meenchen, @RaymondWang0. Thank you.