Jianing Ye

Results 3 issues of Jianing Ye

I followed the instructions in the [tutorial#scan-over-layers](https://flax.readthedocs.io/en/latest/nnx_basics.html#scan-over-layers) to build a multi-layer network with `nnx.vmap` and run the forward pass with `nnx.scan`. However, doing so results in a loss of precision in...

### Describe the bug

I am trying to put the buffer on the CPU to alleviate the memory shortage on the GPU, but I encountered a large slowdown. As shown in `readme.md`,...

bug

I am using mctx to implement MuZero, where actions are selected at each node via a forward pass of a large neural network. Initially, the main bottleneck in terms of...