
[Feature] More accurate reductions for low precision types

[Open] awni opened this issue 1 year ago · 0 comments

Our reductions are quite naive and can lose accuracy, particularly in lower precision types (e.g. mx.float16).
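To illustrate the failure mode (a NumPy sketch rather than MLX code; the array size and constant are arbitrary), sequentially accumulating in float16 stalls once the running sum grows so large that its grid spacing (ulp) exceeds each addend:

```python
import numpy as np

# Sketch of the problem (NumPy stand-in, not MLX code): every
# partial sum is rounded back to float16, so once the accumulator
# is large enough, small addends are rounded away entirely.
x = np.full(4096, 0.1, dtype=np.float16)

acc = np.float16(0.0)
for v in x:
    acc = np.float16(acc + v)  # each step rounds to float16

exact = x.astype(np.float64).sum()  # true sum is ~409.5
print(float(acc), float(exact))     # the naive sum stalls far below the true value
```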

NumPy and PyTorch (MPS) seem to use more sophisticated reductions. E.g. for NumPy:

For floating point numbers the numerical precision of sum (and np.add.reduce) is in general limited by directly adding each number individually to the result causing rounding errors in every step. However, often numpy will use a numerically better approach (partial pairwise summation) leading to improved precision in many use-cases. This improved precision is always provided when no axis is given.
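The pairwise idea can be sketched as follows (hypothetical Python/NumPy; the `block` cutoff and helper names are illustrative, not NumPy's actual implementation):

```python
import numpy as np

def naive_sum(x):
    # Sequential accumulation in the input dtype: rounding error
    # grows roughly linearly in the number of elements.
    acc = x.dtype.type(0)
    for v in x:
        acc = x.dtype.type(acc + v)
    return acc

def pairwise_sum(x, block=8):
    # Recursively split the array and add balanced partial sums;
    # rounding error now grows roughly with log2(n) instead of n.
    # (The block size and recursion scheme are illustrative only.)
    n = len(x)
    if n <= block:
        return naive_sum(x)
    mid = n // 2
    return x.dtype.type(pairwise_sum(x[:mid], block) + pairwise_sum(x[mid:], block))

rng = np.random.default_rng(0)
x = rng.random(1 << 15).astype(np.float16)
exact = x.astype(np.float64).sum()

print(abs(float(naive_sum(x)) - exact))     # naive: large error
print(abs(float(pairwise_sum(x)) - exact))  # pairwise: much smaller error
```

In practice NumPy unrolls the base case and uses a larger block, but even this naive recursion shows the accuracy gap between the two strategies on float16 inputs.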

See #483 for a little more discussion.

awni · Jan 18 '24 14:01