
Inconsistency in compile with kwargs

Open awni opened this issue 1 year ago • 2 comments

The fact that some of the following calls work and some don't seems inconsistent and unexpected. Filing this here mostly so I don't forget about it.

import mlx.core as mx

@mx.compile
def fun(x, y=None):
    if y is not None:
        return x + y
    else:
        return x + 1

fun(mx.array(1.0)) # ok
fun(mx.array(1.0), mx.array(2.0)) # ok
fun(mx.array(1.0), None) # exception
fun(mx.array(1.0), y=None) # exception

awni, Mar 12 '24 23:03

Second time writing this comment, so sorry for any notification spam. It seems related enough, but if this is the wrong place for it, let me know and I'll delete it.

I've run into another edge case with mx.compile and custom data classes (here a namedtuple), and I'm not sure why this behavior occurs:

import mlx.core as mx
from collections import namedtuple

exampleClass = namedtuple('Example', ['x','y'])
example_tuple = exampleClass(x=0,y=1)

def foo(mytuple):
    return mytuple[0] + mytuple[1] 

print(foo(mx.array([0,1]))) # works, outputs array(1, dtype=int32)
print(foo(example_tuple)) # works, outputs 1

compiled_foo = mx.compile(foo)

print(compiled_foo(mx.array([0,1]))) # outputs array(1, dtype=int32)
print(compiled_foo(example_tuple)) # outputs None (huh?)

romanoneg, Mar 15 '24 02:03

In the namedtuple example you are not doing any array operations, so compiling through it doesn't make sense (the 0 and 1 never get cast to mx.array). You can fix it by doing:

import mlx.core as mx
from collections import namedtuple

exampleClass = namedtuple('Example', ['x','y'])
example_tuple = exampleClass(x=mx.array(0),y=mx.array(1))

def foo(mytuple):
    return mytuple[0] + mytuple[1]

compiled_foo = mx.compile(foo)
print(compiled_foo(example_tuple))

We should have a better error message for non-array inputs like this (or find a way to support them).

awni, Mar 15 '24 13:03