No replication rule for `pmax`
Description
`pmax` under `shard_map` fails with:
NotImplementedError: No replication rule for pmax. As a workaround, pass the `check_rep=False` argument to `shard_map`.
The suggested workaround (passing `check_rep=False` to `shard_map`) does work.
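For illustration, here is a minimal sketch of the failing pattern with the workaround applied. It assumes the jax 0.4.25 API from the system info below (`jax.experimental.shard_map.shard_map` with its `check_rep` argument). The mesh axis name `"x"`, the input array, and the function `local_max` are hypothetical, not taken from the original report. Removing `check_rep=False` from the `shard_map` call should reproduce the `NotImplementedError` on the reported version.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# One-axis mesh over all available devices (a single CPU device in this report).
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

def local_max(block):
    # Hypothetical per-shard function: elementwise max across the "x" mesh axis.
    # The result is identical on every shard, i.e. replicated over "x".
    return jax.lax.pmax(block, axis_name="x")

x = jnp.arange(8.0)

# check_rep=False skips the replication check that has no rule for pmax.
f = shard_map(local_max, mesh=mesh, in_specs=P("x"), out_specs=P(), check_rep=False)
out = f(x)
print(out)
```

On the single-CPU setup from the report the mesh axis has size 1, so `pmax` returns the block unchanged and `out` equals `x`.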
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.25
jaxlib: 0.4.25
numpy: 1.26.4
python: 3.10.13 (main, Sep 11 2023, 08:16:02) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='reiners-mbp.lan', release='23.2.0', version='Darwin Kernel Version 23.2.0: Wed Nov 15 21:55:06 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6020', machine='arm64')