No replication rule for `pmax`

Open reinerp opened this issue 11 months ago • 0 comments

Description

Calling `jax.lax.pmax` inside a function wrapped with `shard_map` fails with:

NotImplementedError: No replication rule for pmax. As a workaround, pass the `check_rep=False` argument to `shard_map`.

The suggested workaround, passing `check_rep=False` to `shard_map`, does work.
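A minimal sketch of the workaround, not the original reproduction (the report does not include one). It assumes the `jax.experimental.shard_map` import path used by jax 0.4.25; the fallback to `jax.shard_map` with `check_vma` is an assumption about newer releases. The mesh axis name `"i"` and the function name `global_max` are hypothetical, chosen only for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    # Import path in jax 0.4.25, the version from the report.
    from jax.experimental.shard_map import shard_map
    no_check = {"check_rep": False}
except ImportError:
    # Assumption: newer releases expose jax.shard_map and call the flag check_vma.
    from jax import shard_map
    no_check = {"check_vma": False}

# One-dimensional mesh over all available devices (a single CPU device in the report).
mesh = Mesh(np.array(jax.devices()), axis_names=("i",))


def global_max(x):
    # x is the per-device shard; pmax reduces the local maximum across mesh axis "i".
    return jax.lax.pmax(jnp.max(x), "i")


# Disabling the replication check avoids the "No replication rule for pmax" error.
f = shard_map(global_max, mesh=mesh, in_specs=P("i"), out_specs=P(), **no_check)

n = 4 * len(jax.devices())
x = jnp.arange(n, dtype=jnp.float32)
result = f(x)
print(result)
```

With the check disabled, `shard_map` no longer verifies that outputs declared as replicated in `out_specs` really are replicated, so that becomes the caller's responsibility.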

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.25
jaxlib: 0.4.25
numpy:  1.26.4
python: 3.10.13 (main, Sep 11 2023, 08:16:02) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='reiners-mbp.lan', release='23.2.0', version='Darwin Kernel Version 23.2.0: Wed Nov 15 21:55:06 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T6020', machine='arm64')

reinerp, Mar 20 '24 14:03