alpa
Does not support this kind of Gather
Please describe the bug
I tried Alpa, and the quick start example works fine. However, I hit an error when running it with my own model. My model (including the optimizer) is written purely in JAX, without Flax or Optax. Does Alpa require Flax or Optax?
2023-06-30 07:29:19.767640: F external/org_tensorflow/tensorflow/compiler/xla/service/spmd/auto_sharding.cc:883] Check failed: operand_dim < ins->operand(0)->shape().rank() (2 vs. 2)Does not support this kind of Gather.
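The report does not include the model code, so the origin of the offending Gather is unknown. In hand-written JAX, one common source of an XLA Gather is integer-array indexing, such as an embedding lookup `table[ids]`. The following is a minimal sketch, assuming (this thread does not confirm it) that such a lookup is what trips the auto-sharding check. It rewrites the lookup as a one-hot matrix product, which lowers to `dot_general` instead of a Gather. The sizes, the parameter names, and `embed_onehot`, `loss_fn`, and `train_step` are all hypothetical. The step uses plain SGD over a parameter dict with only `jax` and `jax.numpy`, matching the Flax- and Optax-free setup described above. It uses `jax.jit` so it runs without Alpa; the Alpa quick start applies `@alpa.parallelize` to the train step at that position.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes and learning rate, for illustration only.
VOCAB, DIM = 16, 4
LR = 0.01


def embed_gather(table, ids):
    # Integer-array indexing lowers to an XLA Gather.
    return table[ids]


def embed_onehot(table, ids):
    # Same values computed as one_hot(ids) @ table, which lowers to
    # dot_general rather than Gather (at the cost of a batch x vocab matrix).
    one_hot = jax.nn.one_hot(ids, table.shape[0], dtype=table.dtype)
    return one_hot @ table


def loss_fn(params, ids, targets):
    # Hypothetical pure-JAX model: embedding lookup, then a linear layer.
    h = embed_onehot(params["embed"], ids)
    preds = h @ params["w"]
    return jnp.mean((preds - targets) ** 2)


# Plain jax.jit so the sketch runs without Alpa; the Alpa quick start
# decorates the train step with @alpa.parallelize in this position.
@jax.jit
def train_step(params, ids, targets):
    loss, grads = jax.value_and_grad(loss_fn)(params, ids, targets)
    # Hand-written SGD over the parameter pytree (no Optax).
    new_params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return new_params, loss


key_e, key_w = jax.random.split(jax.random.PRNGKey(0))
params = {
    "embed": jax.random.normal(key_e, (VOCAB, DIM)),
    "w": jax.random.normal(key_w, (DIM, 1)),
}
ids = jnp.array([1, 5, 9, 3])
targets = jnp.zeros((4, 1))

# Only the indexing version contains a gather primitive in its jaxpr.
print("gather" in str(jax.make_jaxpr(embed_gather)(params["embed"], ids)))  # True
print("gather" in str(jax.make_jaxpr(embed_onehot)(params["embed"], ids)))  # False

params, first_loss = train_step(params, ids, targets)
for _ in range(50):
    params, loss = train_step(params, ids, targets)
print(float(first_loss), float(loss))
```

The one-hot form trades memory (a batch by vocabulary matrix) for avoiding the Gather, so it suits small vocabularies better than large ones. Whether it sidesteps this particular auto-sharding check depends on which op in the real model produces the Gather.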
Please describe the expected behavior
System information and environment
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, docker): docker
- Python version: 3.8.10
- CUDA version: 11.8
- NCCL version:
- cupy version:
- GPU model and memory:
- Alpa version: 0.2.3
- TensorFlow version:
- JAX version: 0.3.22
I ran into the same problem! How did you solve it?
I encountered the same problem. Has yours been solved?