mlx
[BUG] CUDA random crashes on large sizes
import mlx.core as mx
a = mx.random.uniform(shape=(1024, 1024, 1024, 3))
mx.eval(a)
Fails with:
RuntimeError: cudaGraphAddKernelNode(&node, graph_, NULL, 0, &params) failed: invalid argument
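One detail that may be relevant: the requested array has 1024 × 1024 × 1024 × 3 = 3,221,225,472 elements. That is more than the largest signed 32-bit integer (2,147,483,647), and at float32 it is 12 GiB. The error message alone does not show whether the CUDA random kernel stores sizes, indices, or launch dimensions as 32-bit values, so treat that as an assumption and not the confirmed cause. The sketch below only does the arithmetic. It makes no MLX calls, and the helpers `num_elements` and `fits_in_int32` are hypothetical names chosen for illustration.

```python
import math

# Largest value representable by a signed 32-bit integer.
INT32_MAX = 2**31 - 1


def num_elements(shape):
    """Total number of elements in an array of the given shape (hypothetical helper)."""
    return math.prod(shape)


def fits_in_int32(shape):
    """True if the element count itself is representable as a signed 32-bit int.

    Hypothetical check: it assumes (not confirmed) that the backend stores the
    element count in an int32.
    """
    return num_elements(shape) <= INT32_MAX


def nbytes(shape, itemsize=4):
    """Size in bytes, assuming float32 (4 bytes per element) by default."""
    return num_elements(shape) * itemsize


shape = (1024, 1024, 1024, 3)
print(num_elements(shape))   # → 3221225472
print(fits_in_int32(shape))  # → False
print(nbytes(shape) / 2**30) # → 12.0 (GiB at float32)
```

If the 32-bit hypothesis holds, shapes just below the `INT32_MAX` boundary should succeed and shapes just above it should fail. For example, `(1024, 1024, 512, 3)` has 1,610,612,736 elements and fits. Bisecting on the element count would help narrow down where the CUDA launch starts rejecting its arguments.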