
Handling of Nat at the Jax-Dex boundary

Open · axch opened this issue 2 years ago · 3 comments

Suppose we want to export to Jax a Dex function that takes a Nat argument:

import jax
import dex
from dex.interop.jax import primitive

dex_iota = primitive(dex.eval(r"\(size:Nat). for i:(Fin size). ordinal i"))

If we call it directly, it works:

dex_iota(5)
> [0, 1, 2, 3, 4]

But if we jit it first, it fails with a dtype error:

jax.jit(dex_iota)(5)
> RuntimeError: dtype mismatch in arg 0: expected uint32, got int32

We should probably pick one of these behaviors (either accept the integer argument or reject it, whether or not the function is jitted) and stick with it, though the issue may have to do with Jax's notion of weak types.
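To see where the int32 in the error could come from, here is a small pure-JAX sketch. It does not involve Dex, and `identity` is a hypothetical helper. It traces a function the way `jax.jit` does and inspects the abstract value JAX assigns to a Python int argument. Assuming the default 32-bit mode (`jax_enable_x64` off), the literal `5` becomes a weakly typed int32. Passing an explicit `jnp.uint32` instead yields a strongly typed uint32.

```python
import jax
import jax.numpy as jnp


def identity(x):
    # Hypothetical stand-in for any jitted function; it just returns its input.
    return x


# make_jaxpr traces the function with abstract values, as jax.jit does.
from_int = jax.make_jaxpr(identity)(5).in_avals[0]
print(from_int.dtype, from_int.weak_type)  # int32 True (default 32-bit mode)

# An explicitly typed argument is not weak, so there is nothing to resolve.
from_u32 = jax.make_jaxpr(identity)(jnp.uint32(5)).in_avals[0]
print(from_u32.dtype, from_u32.weak_type)  # uint32 False
```

This suggests the jitted call hands the primitive an int32 that is only weakly typed. A strict dtype check at the boundary would then reject it.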

axch, Sep 02 '22 17:09

Yeah, weak types sound relevant here!

apaszke, Sep 02 '22 20:09

I assume what we actually want is for it to behave like any other Jax primitive. Do we know clearly enough what that behavior is?
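For comparison, ordinary `jax.numpy` operations do not reject a weakly typed integer. They resolve it against the other operand's dtype, and they do so the same way with and without `jax.jit`. Here is a minimal sketch, where `add_offset` is a hypothetical example function that is not part of Dex or JAX:

```python
import jax
import jax.numpy as jnp


def add_offset(x):
    # Hypothetical example: a uint32 array plus a scalar argument.
    base = jnp.arange(3, dtype=jnp.uint32)
    return base + x  # a weakly typed integer argument defers to uint32


eager = add_offset(5)            # Python int, called directly
jitted = jax.jit(add_offset)(5)  # traced as a weakly typed int32
print(eager.dtype, jitted.dtype)  # uint32 uint32
```

If the Dex primitive followed the same convention, it could cast a weakly typed integer argument to uint32 where a Nat is expected instead of raising the dtype mismatch. That is one possible reading of "behave like any other Jax primitive". The thread does not settle it.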

axch, Sep 02 '22 20:09

Not sure... We should check in with others. Jake will know for sure.

apaszke, Sep 03 '22 06:09