dex-lang
Handling of Nat at the Jax-Dex boundary
Suppose we wish to export a Dex function that takes a Nat argument to Jax:
dex_iota = primitive(dex.eval(r"\(size:Nat). for i:(Fin size). ordinal i"))
If we call it directly, it works:
dex_iota(5)
> [0, 1, 2, 3, 4]
But if we jit it first, we get a dtype mismatch error:
jax.jit(dex_iota)(5)
E RuntimeError: dtype mismatch in arg 0: expected uint32, got int32
We should probably pick one of these behaviors and stick with it (though the issue may have to do with Jax's notion of weak types).
Yeah, weak types sound relevant here!
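A minimal JAX-only sketch (no Dex involved) of what "weak types" refers to here, assuming the default configuration with jax_enable_x64 off. When a plain Python int is traced, as in jax.jit(dex_iota)(5), JAX gives it a weakly-typed int32 abstract value, which would explain the "got int32" in the error. That the Dex primitive then checks strictly against uint32 is inferred from the error message, not shown here. Ordinary jax.numpy operations instead let a weak int adopt the other operand's dtype, which is one candidate for "behave like any other Jax primitive".

```python
import jax
import jax.numpy as jnp

# Trace an identity function with a plain Python int, like jax.jit(dex_iota)(5) does.
closed = jax.make_jaxpr(lambda size: size)(5)
weak_aval = closed.in_avals[0]
# In the default 32-bit mode the Python int becomes a weakly-typed int32 scalar.
print(weak_aval.dtype, weak_aval.weak_type)

# An explicitly typed argument is not weak and keeps the dtype it was given.
closed_u32 = jax.make_jaxpr(lambda size: size)(jnp.uint32(5))
strong_aval = closed_u32.in_avals[0]
print(strong_aval.dtype, strong_aval.weak_type)

# Regular jax.numpy ops let a weak int defer to the other operand's dtype,
# so adding the Python int 5 to a uint32 value yields a uint32 result.
promoted = jnp.uint32(3) + 5
print(promoted.dtype)
```

Under this sketch, casting the argument explicitly (for example passing jnp.uint32(5)) would hand a uint32 value to the primitive, but whether the Dex wrapper should do that cast itself is exactly the open question in this thread.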
I assume what we actually want is to behave like any other Jax primitive. Do we know clearly enough what behavior that is?
Not sure... We should check in with others. Jake will know for sure.