[BUG] nan is clipped to the upper

Open • Redempt1onzzZZ opened this issue 1 year ago • 3 comments

Describe the bug

For a = mx.array([float('nan')]), mx.clip(a, -1, 1) returns array([-1], dtype=float32) instead of NaN.

To Reproduce

import mlx.core as mx

a = mx.array([float('nan')])
print(a)
b = mx.clip(a,-1,1)
print(b)

Expected behavior

mx.clip(a, -1, 1) should return array([nan], dtype=float32); NaN should pass through clip unchanged.

Desktop:

  • OS version: macOS 14.2.1
  • Version: 0.7.0

Redempt1onzzZZ, Jan 16 '24

I think this is a bug in mx.maximum and mx.minimum 🤔 but I'm not sure we should fix it. It would be tricky to fix in the general case without doing extra work to handle NaNs...

>>> import mlx.core as mx
>>> mx.set_default_device(mx.cpu)
>>> a = mx.array([float("nan")])
>>> print(mx.minimum(a, 1))
array([1], dtype=float32)

This happens because:

>>> 3 < float("nan")
False
>>> 3 > float("nan")
False

The minimum is computed with the condition x < y ? x : y, which always returns y when x is NaN.
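To make the comparison argument concrete, here is a small pure-Python sketch. It assumes, for illustration only, that maximum uses the mirror-image condition x > y ? x : y and that clip is composed as minimum(maximum(x, a_min), a_max); the thread does not show MLX's actual kernels. Under those assumptions a NaN input comes out as the lower bound -1, which matches the value in the original report, and the NaN only "disappears" when it is the first operand.

```python
def naive_maximum(x: float, y: float) -> float:
    # `x > y ? x : y`: any comparison involving NaN is False, so NaN in x yields y
    return x if x > y else y


def naive_minimum(x: float, y: float) -> float:
    # `x < y ? x : y`, the condition quoted above
    return x if x < y else y


def naive_clip(x: float, a_min: float, a_max: float) -> float:
    # Assumed composition (illustrative, not MLX's source): min(max(x, a_min), a_max)
    return naive_minimum(naive_maximum(x, a_min), a_max)


nan = float("nan")
print(naive_minimum(nan, 1.0))     # 1.0, like mx.minimum(a, 1) above
print(naive_clip(nan, -1.0, 1.0))  # -1.0, like the value in the bug report
print(naive_minimum(1.0, nan))     # nan: NaN survives only as the second operand
```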

awni, Jan 16 '24

It seems NaN also disrupts mx.where. This is the bug I encountered when implementing softplus:

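import mlx.core as mx

def softplus(x, beta=1, threshold=20):
    scaled_x = beta * x
    # Fall back to the identity once beta * x exceeds the threshold
    mask = scaled_x > threshold

    # Both branches are computed before mx.where selects between them
    return mx.where(mask, x, 1 / beta * mx.log1p(mx.exp(scaled_x)))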

When I run:

softplus(mx.array([1, 210]))

I get array([1.31326, nan], dtype=float32). Although the mask is correct, having a NaN messes with mx.where.

alxndrTL, Jan 16 '24

@alxndrTL your comment is about a different problem from this issue. You are taking exp(210), which overflows to inf in float32, and that inf causes a NaN when you use where. I would avoid passing such large values to an exponential if it can be helped.
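A select such as mx.where only chooses between values that have already been computed, so exp(210) is still evaluated in the discarded branch. The pure-Python sketch below does not reproduce MLX's where (the thread does not show how it is implemented); it only illustrates the IEEE 754 rule that makes a discarded inf dangerous: once the masked-out operand is multiplied by 0, as in a hypothetical arithmetic blend mask * a + (1 - mask) * b, the result is NaN because 0 * inf is NaN.

```python
import math


def select(mask: bool, a: float, b: float) -> float:
    # A true select: the unused operand never takes part in any arithmetic
    return a if mask else b


def arithmetic_blend(mask: bool, a: float, b: float) -> float:
    # Hypothetical select written as arithmetic: the masked-out operand is
    # multiplied by 0.0, and 0.0 * inf is NaN under IEEE 754
    m = 1.0 if mask else 0.0
    return m * a + (1.0 - m) * b


big = math.inf  # stands in for exp(210) after float32 overflow
print(select(True, 210.0, big))            # 210.0
print(arithmetic_blend(True, 210.0, big))  # nan
```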

You should use mx.logaddexp(0, x) instead of mx.log1p(mx.exp(x)).
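The reason logaddexp(0, x) helps is that log(1 + e^x) can be rewritten as max(0, x) + log1p(e^(-|x|)), so the exponential only ever sees a non-positive argument. The sketch below shows that identity in plain Python floats rather than MLX arrays; it is a stand-in for, not a copy of, how mx.logaddexp is implemented, and the beta scaling follows the softplus definition used above.

```python
import math


def softplus_naive(x: float) -> float:
    # log(1 + e^x) computed directly: math.exp raises OverflowError above
    # roughly x = 709 in float64 (float32 overflows near x = 88.7)
    return math.log1p(math.exp(x))


def softplus_stable(x: float, beta: float = 1.0) -> float:
    # logaddexp(0, z) / beta with z = beta * x, using
    # log(e^0 + e^z) = max(0, z) + log1p(e^(-|z|)),
    # so exp() only ever sees a non-positive argument and cannot overflow
    z = beta * x
    return (max(z, 0.0) + math.log1p(math.exp(-abs(z)))) / beta


print(softplus_stable(1.0))    # about 1.31326, the first value reported above
print(softplus_stable(210.0))  # 210.0
try:
    softplus_naive(1000.0)
except OverflowError:
    print("naive softplus overflows at x = 1000")
```

With this form no threshold or mx.where is needed, because the large-x branch is already handled by max(0, z).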

awni, Jan 17 '24

It seems the clipping no longer occurs for nan:

>>> a = mx.array([float('nan')])
>>> mx.clip(a, -1, 1)
array([nan], dtype=float32)

and

>>> mx.set_default_device(mx.cpu)
>>> a = mx.array([float("nan")])
>>> print(mx.minimum(a, 1))
array([nan], dtype=float32)

Version:

>>> mx.__version__
'0.16.0.dev20240723+6768c6a5'

Perhaps this is due to the changes to maximum and minimum in ops.cpp made in https://github.com/ml-explore/mlx/pull/871/files#diff-f1a5f4c64261d0588116760d0811fd474fdba3a0fdfe976a40b23d26de5e78db
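The "extra work to handle NaNs" mentioned earlier in the thread amounts to checking for NaN before comparing. The sketch below is a hypothetical pure-Python version of that idea; it is not the code from the linked PR, which the thread does not quote, but it produces the same observable behavior reported here: NaN propagates through minimum, maximum, and a clip built from them.

```python
import math


def nan_maximum(x: float, y: float) -> float:
    # Hypothetical NaN-propagating maximum: check for NaN before comparing
    if math.isnan(x) or math.isnan(y):
        return math.nan
    return x if x > y else y


def nan_minimum(x: float, y: float) -> float:
    # Same idea for minimum
    if math.isnan(x) or math.isnan(y):
        return math.nan
    return x if x < y else y


def nan_clip(x: float, a_min: float, a_max: float) -> float:
    # Built from the NaN-aware helpers, so a NaN input stays NaN
    return nan_minimum(nan_maximum(x, a_min), a_max)


print(nan_clip(float("nan"), -1.0, 1.0))  # nan
print(nan_clip(5.0, -1.0, 1.0))           # 1.0
print(nan_clip(-3.0, -1.0, 1.0))          # -1.0
```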

plpxsk, Jul 23 '24

Cool. I guess we can close this! Thanks for noticing it.

awni, Jul 23 '24