AVX512 quantization (cast from float to uint8) returns wrong results
I originally reported this issue at https://github.com/tensorflow/tensorflow/issues/49944
The problem is that the values returned by the float-to-uint8 cast are always in the wrong order. The affected code is here: https://github.com/openxla/xla/blob/3ae086875f9a42f1a5491f1270253f623668473f/xla/tsl/framework/fixedpoint/TypeCastingAVX512.h#L85-L112
This is easiest to see with a test whose inputs should produce the ordered sequence 0 to 119 as output. The actual output of each code path is shown below, broken into lines of 8.
AVX512BW path:
0 1 2 3 16 17 18 19
32 33 34 35 48 49 50 51
4 5 6 7 20 21 22 23
36 37 38 39 52 53 54 55
8 9 10 11 24 25 26 27
40 41 42 43 56 57 58 59
12 13 14 15 28 29 30 31
44 45 46 47 60 61 62 63
64 65 66 67 68 69 70 71
72 73 74 75 76 77 78 79
80 81 82 83 84 85 86 87
88 89 90 91 92 93 94 95
96 97 98 99 100 101 102 103
104 105 106 107 108 109 110 111
112 113 114 115 116 117 118 119
Fallback:
36 37 38 39 32 33 34 35
52 53 54 55 48 49 50 51
4 5 6 7 0 1 2 3
20 21 22 23 16 17 18 19
44 45 46 47 40 41 42 43
60 61 62 63 56 57 58 59
12 13 14 15 8 9 10 11
28 29 30 31 24 25 26 27
64 65 66 67 68 69 70 71
72 73 74 75 76 77 78 79
80 81 82 83 84 85 86 87
88 89 90 91 92 93 94 95
96 97 98 99 100 101 102 103
104 105 106 107 108 109 110 111
112 113 114 115 116 117 118 119
See https://github.com/tensorflow/tensorflow/issues/49944#issuecomment-2178594002 for a solution to this specific code, although other code paths are likely affected too.
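
The ordering in the AVX512BW listing can be reproduced without AVX512 hardware. The sketch below is a scalar model, not the XLA code: it assumes the cast narrows four vectors of 16 int32 values with two lane-wise pack steps in the style of `_mm512_packs_epi32` and `_mm512_packus_epi16`, both of which work independently on each 128-bit lane. All function and type names here are hypothetical. `fix_lane_order` shows one possible reordering of the 32-bit groups, equivalent in spirit to a `_mm512_permutexvar_epi32` shuffle; it is not necessarily the fix proposed in the linked comment.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Scalar stand-ins for 512-bit registers (hypothetical names).
using I32x16 = std::array<std::int32_t, 16>;
using I16x32 = std::array<std::int16_t, 32>;
using U8x64 = std::array<std::uint8_t, 64>;

constexpr std::size_t kLanes = 4;  // a 512-bit register has four 128-bit lanes

std::int16_t saturate_i16(std::int32_t v) {
  return static_cast<std::int16_t>(v < -32768 ? -32768 : (v > 32767 ? 32767 : v));
}

std::uint8_t saturate_u8(std::int16_t v) {
  return static_cast<std::uint8_t>(v < 0 ? 0 : (v > 255 ? 255 : v));
}

// Models a lane-wise int32 -> int16 pack: in every 128-bit lane the result
// holds four values from `a` followed by four values from `b`.
I16x32 pack_i32_to_i16(const I32x16& a, const I32x16& b) {
  I16x32 out{};
  for (std::size_t lane = 0; lane < kLanes; ++lane) {
    for (std::size_t i = 0; i < 4; ++i) {
      out[lane * 8 + i] = saturate_i16(a[lane * 4 + i]);
      out[lane * 8 + 4 + i] = saturate_i16(b[lane * 4 + i]);
    }
  }
  return out;
}

// Models a lane-wise int16 -> uint8 pack: in every 128-bit lane the result
// holds eight values from `ab` followed by eight values from `cd`.
U8x64 pack_i16_to_u8(const I16x32& ab, const I16x32& cd) {
  U8x64 out{};
  for (std::size_t lane = 0; lane < kLanes; ++lane) {
    for (std::size_t i = 0; i < 8; ++i) {
      out[lane * 16 + i] = saturate_u8(ab[lane * 8 + i]);
      out[lane * 16 + 8 + i] = saturate_u8(cd[lane * 8 + i]);
    }
  }
  return out;
}

// Packs the values 0..63 (four vectors of 16) with no reordering afterwards.
U8x64 pack_sequence() {
  std::array<I32x16, 4> in{};
  for (std::size_t v = 0; v < 4; ++v) {
    for (std::size_t i = 0; i < 16; ++i) {
      in[v][i] = static_cast<std::int32_t>(16 * v + i);
    }
  }
  return pack_i16_to_u8(pack_i32_to_i16(in[0], in[1]),
                        pack_i32_to_i16(in[2], in[3]));
}

// One possible repair: move each group of four bytes (one 32-bit element)
// back to the position it would have in source order.
U8x64 fix_lane_order(const U8x64& packed) {
  U8x64 out{};
  for (std::size_t j = 0; j < 16; ++j) {
    const std::size_t src = 4 * (j % 4) + j / 4;  // 0,4,8,12,1,5,9,13,...
    assert(src < 16);
    for (std::size_t k = 0; k < 4; ++k) {
      out[4 * j + k] = packed[4 * src + k];
    }
  }
  return out;
}
```

Under these assumptions, `pack_sequence()` begins with 0 1 2 3 16 17 18 19 32 33 34 35 48 49 50 51, matching the first two rows of the AVX512BW listing, and `fix_lane_order(pack_sequence())` restores the ordered sequence 0 to 63. The Fallback listing follows a different permutation and is not modeled here.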