Wilson Yan
Hi, I have a simple script below that compares runtimes for `pmap` vs `pjit`. I expected that the runtime for `pjit` with full data parallelism would be the same for...
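The script itself is cut off in the preview, so what follows is only a sketch of how such a comparison might be set up, not the author's code. It uses `jax.jit` with a `NamedSharding` over a 1-D device mesh. In recent JAX releases this form subsumes `pjit`. The `loss` function, the batch sizes, and the `bench` helper are hypothetical stand-ins. Both versions run the same math, and the per-device sums from `pmap` should add up to the global sum from the sharded `jit`.

```python
import time
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n_dev = jax.device_count()
per_device_batch, dim = 256, 512  # hypothetical sizes, not from the original script

def loss(w, x):
    # Hypothetical toy computation: one matmul followed by a reduction.
    return jnp.tanh(x @ w).sum()

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (dim, dim)) / np.sqrt(dim)
x = jax.random.normal(key_x, (n_dev, per_device_batch, dim))

# pmap version: the leading axis of x maps to devices; w is broadcast (in_axes=None).
loss_pmap = jax.pmap(loss, in_axes=(None, 0))

# jit + sharding version (pjit-style full data parallelism): a flat global batch
# whose axis 0 is split across the "data" mesh axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
x_global = jax.device_put(
    x.reshape(n_dev * per_device_batch, dim),
    NamedSharding(mesh, P("data", None)),
)
loss_jit = jax.jit(loss)

def bench(fn, *args, iters=20):
    # The first call triggers compilation, so it is excluded from the timing.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(*args)
    out.block_until_ready()  # wait for async dispatch to finish
    return (time.perf_counter() - start) / iters

t_pmap = bench(loss_pmap, w, x)
t_jit = bench(loss_jit, w, x_global)
print(f"pmap: {t_pmap * 1e3:.3f} ms/step, jit+sharding: {t_jit * 1e3:.3f} ms/step")

# Same math, different partitioning: per-device sums add up to the global sum.
total_pmap = float(loss_pmap(w, x).sum())
total_jit = float(loss_jit(w, x_global))
print(f"pmap total: {total_pmap:.4f}, jit total: {total_jit:.4f}")
```

The measured numbers depend on the hardware and on how the inputs are already laid out across devices. For that reason the sketch excludes compilation and waits on `block_until_ready()` before stopping the clock.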
Hi, nice repo! Is this line a bug? I think `batch['images']` is `N x B x H x W x C`, so the indices should be shifted up by...
### Description
I have this simple script below that tests `jnp.cumsum` when sharding along the same axis it is summing over. Tested on a machine with eight 40 GB A100 GPUs. ```python...
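The original script is truncated, so here is a minimal sketch of the setup it describes, under stated assumptions. An array is sharded along axis 0 over a 1-D device mesh, and `jnp.cumsum` is then taken along that same axis under `jax.jit`. The array shape, the `"data"` axis name, and the `cumsum_rows` helper are hypothetical. The sketch uses `int32` so the comparison against a NumPy reference is exact. It also runs on a single CPU device, where the sharding is trivial.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over every visible device (eight A100s in the report; one CPU also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
n_dev = devices.size

# Hypothetical shape: rows are a multiple of the device count so axis 0 splits evenly.
# int32 keeps the running sums exact, so the check below is bit-for-bit.
n_rows, n_cols = 128 * n_dev, 4
x = jnp.arange(n_rows * n_cols, dtype=jnp.int32).reshape(n_rows, n_cols)

# Shard axis 0 across "data": the same axis cumsum runs over, so each device's
# block depends on the running totals of every block before it.
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))

@jax.jit
def cumsum_rows(a):
    return jnp.cumsum(a, axis=0)

out = cumsum_rows(x_sharded)
reference = np.cumsum(np.asarray(x), axis=0)
np.testing.assert_array_equal(np.asarray(out), reference)
print("output sharding:", out.sharding)
```

Comparing against a plain NumPy `cumsum` separates correctness from performance. If the values match, any remaining problem lies in how the cross-device prefix is communicated, not in the result itself.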