mlx
[bug/performance] Random state is updated even when unused
```python
import mlx.core as mx

def fun():
    for _ in range(1000):
        mx.random.randint(1, 10)

fun()
print(mx.random.randint(0, 10, shape=(32, 32)))
```
Evaluating the last line runs 1,000 `split` kernels. Each `mx.random.randint` call inside `fun` advances the internal random state, even though none of those results are ever used. I am not sure there is an easy fix for this, but it would be nice to avoid running those 1,000 superfluous kernels.
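The sketch below illustrates why the unused draws still cost kernels. It is a hypothetical toy model of a lazily evaluated global PRNG state, not MLX's implementation, and every name in it (`LazyNode`, `ToyRandom`, `evaluate`) is invented for illustration. Each implicit draw replaces the global state with a lazy `split` node that depends on the previous state. The first sample that is actually evaluated therefore has to materialize the whole chain. A draw made with an explicit key does not depend on the global chain at all.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LazyNode:
    """Hypothetical lazy graph node: `op` is only computed when evaluated."""
    op: str
    parent: Optional["LazyNode"] = None
    evaluated: bool = False


def evaluate(node: LazyNode, counter: dict) -> dict:
    """Run every not-yet-evaluated ancestor of `node`, counting ops by name."""
    chain = []
    while node is not None and not node.evaluated:
        chain.append(node)
        node = node.parent
    for n in reversed(chain):  # oldest dependency first
        counter[n.op] = counter.get(n.op, 0) + 1
        n.evaluated = True
    return counter


class ToyRandom:
    """Toy global PRNG whose state is advanced lazily on every implicit draw."""

    def __init__(self) -> None:
        self.state = LazyNode("seed", evaluated=True)

    def next_key(self) -> LazyNode:
        # The new global state depends on the old one, so the chain keeps growing.
        self.state = LazyNode("split", self.state)
        return self.state

    def randint(self, key: Optional[LazyNode] = None) -> LazyNode:
        # Without an explicit key, the draw touches (and extends) the global state.
        if key is None:
            key = self.next_key()
        return LazyNode("sample", key)


# Mirror the report: 1,000 discarded draws, then one draw that is evaluated.
rng = ToyRandom()
for _ in range(1000):
    rng.randint()
print(evaluate(rng.randint(), {}))  # {'split': 1001, 'sample': 1}

# Same pattern, but the evaluated draw uses its own already-materialized key.
rng2 = ToyRandom()
for _ in range(1000):
    rng2.randint()
explicit_key = LazyNode("key", evaluated=True)
print(evaluate(rng2.randint(key=explicit_key), {}))  # {'sample': 1}
```

In MLX, the analogous user-side workaround would be to pass an explicit `key=` (for example one created with `mx.random.key`) to draws whose results matter, so that they do not depend on the implicitly advanced global state. This sidesteps the cost but does not remove the underlying chain built by the discarded implicit draws.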