TypeError: dynamic_update_slice update shape must be smaller than operand shape, got update shape (1,) for operand shape (0,)

jifa513 opened this issue 9 months ago (1 comment)

I have eight RTX 4090 GPUs, and running grok-1 fails. Steps to reproduce:

  1. Clone the grok-1 repository.
  2. Install the requirements: `pip install -r requirements.txt`
  3. Download the Hugging Face weights: `git clone https://huggingface.co/xai-org/grok-1`
  4. In run.py, change the local mesh config to `local_mesh_config=(1, 1)`.
  5. Run run.py: `python3 run.py`
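One possible explanation, not confirmed in this thread: the traceback below shows a settings buffer with operand shape (0,), which suggests the runner ended up with a batch size of zero. Assuming the stock run.py passes a fractional `bs_per_device` (0.125 is assumed here) and the runner truncates `bs_per_device` times the number of devices in the mesh to an integer, shrinking the mesh from eight devices to `(1, 1)` would produce exactly that. The `effective_batch_size` helper below is hypothetical and only sketches this assumed arithmetic; it is not grok-1's code.

```python
import math


def effective_batch_size(bs_per_device: float, local_mesh_config: tuple[int, int]) -> int:
    """Hypothetical helper: assumes batch size = int(bs_per_device * devices in the mesh)."""
    n_devices = math.prod(local_mesh_config)
    # int() truncates toward zero, so any product below 1.0 becomes 0.
    return int(bs_per_device * n_devices)


print(effective_batch_size(0.125, (1, 8)))  # eight devices
print(effective_batch_size(0.125, (1, 1)))  # one device
```

If this assumption holds, raising `bs_per_device` to 1 or keeping a mesh that spans all eight GPUs should yield a non-zero batch size; this has not been tested here.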

Error output:

    Traceback (most recent call last):
      File "/app/grok-1/run.py", line 87, in <module>
        main()
      File "/app/grok-1/run.py", line 82, in main
        print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
      File "/app/grok-1/runners.py", line 597, in sample_from_model
        next(server)
      File "/app/grok-1/runners.py", line 481, in run
        rngs, last_output, memory, settings = self.prefill_memory(
      File "/usr/local/lib/python3.10/dist-packages/haiku/_src/multi_transform.py", line 314, in apply_fn
        return f.apply(params, None, *args, **kwargs)
      File "/usr/local/lib/python3.10/dist-packages/haiku/_src/transform.py", line 183, in apply_fn
        out, state = f.apply(params, None, *args, **kwargs)
      File "/usr/local/lib/python3.10/dist-packages/haiku/_src/transform.py", line 456, in apply_fn
        out = f(*args, **kwargs)
      File "/app/grok-1/runners.py", line 352, in hk_prefill_memory
        settings = jax.tree_map(
      File "/app/grok-1/runners.py", line 353, in <lambda>
        lambda o, v: jax.lax.dynamic_update_index_in_dim(o, v, i, axis=0),
    TypeError: dynamic_update_slice update shape must be smaller than operand shape, got update shape (1,) for operand shape (0,).
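The final frame of the traceback can be reproduced in isolation. Assuming each leaf of `settings` is a buffer of shape (batch_size,) and each prompt's value is a scalar, `jax.lax.dynamic_update_index_in_dim` expands the scalar to shape (1,) before calling `dynamic_update_slice`, so a zero-length buffer fails with the same TypeError. The `write_setting` function and the float32 buffer below are hypothetical stand-ins, not code from runners.py.

```python
import jax.numpy as jnp
from jax import lax


def write_setting(batch_size: int, value: float, index: int = 0):
    """Hypothetical stand-in for the tree_map lambda: write one scalar into a (batch_size,) buffer."""
    buffer = jnp.zeros((batch_size,), dtype=jnp.float32)
    # The scalar update is expanded to shape (1,) along axis 0 before the slice update.
    return lax.dynamic_update_index_in_dim(buffer, jnp.float32(value), index, axis=0)


# A batch of one works.
print(write_setting(1, 0.01))

# A batch of zero raises the reported TypeError.
message = None
try:
    write_setting(0, 0.01)
except TypeError as err:
    message = str(err)
print(message)
```

This suggests the error is a symptom of an empty batch rather than a problem with the weights or the GPUs themselves.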

Posted by jifa513 on May 14 '24 at 12:05