[Feature] More accurate reductions for low precision types
Our reductions are quite naive and can be less accurate, particularly for low precision types (e.g. mx.float16).
NumPy and PyTorch (MPS) seem to use more sophisticated reductions. E.g. for NumPy:
> For floating point numbers the numerical precision of sum (and np.add.reduce) is in general limited by directly adding each number individually to the result causing rounding errors in every step. However, often numpy will use a numerically better approach (partial pairwise summation) leading to improved precision in many use-cases. This improved precision is always provided when no axis is given.
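For context, here's a minimal sketch of the pairwise idea in plain NumPy (not MLX's or NumPy's actual implementation; the helper names and the block size of 128 are illustrative assumptions):

```python
import numpy as np

def naive_sum(x):
    # Sequential accumulation: every add rounds in the accumulator's
    # (low) precision, so the error grows roughly O(n).
    acc = x.dtype.type(0)
    for v in x:
        acc = x.dtype.type(acc + v)
    return acc

def pairwise_sum(x, block=128):
    # Recursively split and add halves: the error grows roughly O(log n).
    # The cutoff of 128 is an arbitrary illustrative choice.
    n = x.shape[0]
    if n <= block:
        return naive_sum(x)
    mid = n // 2
    return x.dtype.type(pairwise_sum(x[:mid], block) + pairwise_sum(x[mid:], block))

rng = np.random.default_rng(0)
x = rng.random(1 << 14).astype(np.float16)

exact = x.astype(np.float64).sum()
print("naive    error:", abs(float(naive_sum(x)) - exact))
print("pairwise error:", abs(float(pairwise_sum(x)) - exact))
print("np.sum   error:", abs(float(x.sum()) - exact))
```

On a typical run the naive float16 error is orders of magnitude larger than the pairwise one (the accumulator saturates once it gets large relative to the summands), which is the gap this issue is about.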
See #483 for a little more discussion.