grok-1
TypeError: dynamic_update_slice update shape must be smaller than operand shape, got update shape (1,) for operand shape (0,)
I have 8 RTX 4090 cards, and running grok-1 fails. Steps to reproduce:
- Clone the grok-1 repository.
- Install the requirements:
pip install -r requirements.txt
- Download the Hugging Face weights:
git clone https://huggingface.co/xai-org/grok-1
- Change the local mesh config in run.py:
local_mesh_config=(1, 1)
- Run run.py:
python3 run.py
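One possible explanation for the operand shape (0,) in the error is the mesh change above. Suppose, hypothetically, that the runner computes its batch size as a per-device fraction multiplied by the number of devices in the mesh and truncates the result to an integer. Shrinking the mesh from eight devices to (1, 1) could then round the batch size down to zero. The sketch below only illustrates that arithmetic. effective_batch_size and the 0.125 value are illustrative assumptions and are not taken from the grok-1 code.

```python
import math


def effective_batch_size(bs_per_device: float, mesh_shape: tuple[int, ...]) -> int:
    """Hypothetical rule: per-device batch fraction times device count, truncated to int."""
    num_devices = math.prod(mesh_shape)  # total devices in the (hypothetical) mesh
    return int(bs_per_device * num_devices)


BS_PER_DEVICE = 0.125  # hypothetical value, chosen for illustration only

for mesh in [(1, 8), (1, 1)]:
    print(mesh, "->", effective_batch_size(BS_PER_DEVICE, mesh))
```

Under this assumption, an 8-device mesh yields a batch size of 1, while a single-device mesh yields 0.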
The error is below:
Traceback (most recent call last):
  File "/app/grok-1/run.py", line 87, in <module>
    main()
  File "/app/grok-1/run.py", line 82, in main
    print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
  File "/app/grok-1/runners.py", line 597, in sample_from_model
    next(server)
  File "/app/grok-1/runners.py", line 481, in run
    rngs, last_output, memory, settings = self.prefill_memory(
  File "/usr/local/lib/python3.10/dist-packages/haiku/_src/multi_transform.py", line 314, in apply_fn
    return f.apply(params, None, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haiku/_src/transform.py", line 183, in apply_fn
    out, state = f.apply(params, None, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/haiku/_src/transform.py", line 456, in apply_fn
    out = f(*args, **kwargs)
  File "/app/grok-1/runners.py", line 352, in hk_prefill_memory
    settings = jax.tree_map(
  File "/app/grok-1/runners.py", line 353, in <lambda>
    lambda o, v: jax.lax.dynamic_update_index_in_dim(o, v, i, axis=0),
TypeError: dynamic_update_slice update shape must be smaller than operand shape, got update shape (1,) for operand shape (0,).
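The failing call can be reproduced in isolation. The following is a minimal sketch that assumes only that jax is installed. update_first_slot is a hypothetical helper and is not part of grok-1. jax.lax.dynamic_update_index_in_dim expands a scalar update to shape (1,) along the given axis before calling dynamic_update_slice. Writing into a zero-length array therefore raises the same TypeError seen in the traceback, while a non-empty array works.

```python
import jax
import jax.numpy as jnp


def update_first_slot(batch_size: int) -> jax.Array:
    """Hypothetical helper: write one scalar into slot 0 of a length-`batch_size` array."""
    settings = jnp.zeros((batch_size,), dtype=jnp.float32)
    new_value = jnp.float32(0.01)
    # The scalar is expanded to shape (1,) along axis 0, then dynamic_update_slice
    # requires the update shape to fit inside the operand shape.
    return jax.lax.dynamic_update_index_in_dim(settings, new_value, 0, axis=0)


print(update_first_slot(1))  # non-empty operand: the update succeeds

try:
    update_first_slot(0)  # zero-length operand: update (1,) does not fit in (0,)
except TypeError as err:
    print("TypeError:", err)
```

This suggests that the settings array reached hk_prefill_memory with a batch dimension of 0. The update value itself does not appear to be the problem.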