
ValueError: Incompatible shapes for broadcasting

Open radna0 opened this issue 3 months ago • 2 comments

  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/text/_sampler.py", line 311, in sample
    init_state = _prefill.prefill(
                 ^^^^^^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/text/_prefill.py", line 110, in prefill
    out = model.apply(
          ^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/kauldron/utils/train_property.py", line 141, in decorated
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/utils/_jax_utils.py", line 96, in decorated
    output = fn(*bound_args.args, **bound_args.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/kauldron/typing/type_check.py", line 270, in _reraise_with_shape_info
    retval = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/nn/_transformer.py", line 247, in __call__
    x, new_cache = self._apply_attention(inputs, cache)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/nn/_transformer.py", line 292, in _apply_attention
    layer_cache, x = block(
                     ^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/nn/_modules.py", line 467, in __call__
    cache, attn_output = self.attn(
                         ^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/gemma/gm/nn/_modules.py", line 277, in __call__
    padded_logits = jnp.where((jnp.expand_dims(attn_mask, -2)), logits, K_MASK)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 2821, in where
    return util._where(condition, x, y)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/_src/numpy/util.py", line 311, in _where
    condition, x_arr, y_arr = _broadcast_arrays(condition, x, y)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/kojoe/miniconda3/envs/vllm/lib/python3.12/site-packages/jax/_src/numpy/util.py", line 264, in _broadcast_arrays
    result_shape = lax.broadcast_shapes(*shapes)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Incompatible shapes for broadcasting: shapes=[(1, 1447, 1, 5234), (1, 1447, 8, 4096), ()]
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
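The failing call can be reproduced in isolation. A minimal sketch, with the shapes scaled down but keeping the same pattern as the traceback (the mask's trailing axis is longer than the logits' trailing axis; `K_MASK` here is just a stand-in large negative, not necessarily gemma's exact constant):

```python
import jax.numpy as jnp

K_MASK = -1e30  # stand-in for gemma's masking sentinel

# Scaled-down analogues of the traceback shapes:
#   attn_mask (1, 1447, 5234)    -> (1, 3, 10)
#   logits    (1, 1447, 8, 4096) -> (1, 3, 2, 8)
attn_mask = jnp.ones((1, 3, 10), dtype=bool)
logits = jnp.zeros((1, 3, 2, 8))

try:
    # expand_dims gives the mask a size-1 head axis: (1, 3, 1, 10).
    # The size-1 head axis broadcasts fine, but 10 vs 8 on the last
    # axis cannot, which raises the same ValueError as the traceback.
    jnp.where(jnp.expand_dims(attn_mask, -2), logits, K_MASK)
except ValueError as err:
    print(type(err).__name__)  # ValueError
```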

radna0 avatar Sep 13 '25 20:09 radna0

Noticed that attn_mask has shape (1, 1447, 1, 5234) while the logits are (1, 1447, 8, 4096). The size-1 head axis broadcasts fine against the 8 heads; the incompatibility is in the last axis, 5234 vs 4096. Since 4096 looks like the sampler's KV cache length, the mask appears to cover more positions than the cache holds (e.g. the prompt plus generation budget exceeding cache_length). Worth checking how the attention mask is constructed/applied relative to the cache length.
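For comparison, a sketch of the layout that does broadcast: when the mask's last axis matches the logits' cache axis, the size-1 head axis expands against all heads with no error (toy shapes, not the real model sizes):

```python
import jax.numpy as jnp

B, T, H, cache_len = 1, 3, 8, 16

attn_mask = jnp.ones((B, T, cache_len), dtype=bool)  # one mask row per query position
logits = jnp.zeros((B, T, H, cache_len))             # per-head attention logits

# (B, T, 1, cache_len) broadcasts cleanly against (B, T, H, cache_len).
out = jnp.where(jnp.expand_dims(attn_mask, -2), logits, -1e30)
print(out.shape)  # (1, 3, 8, 16)
```

In the reported shapes the mask covers 5234 positions while the logits cover only 4096, which is why broadcasting fails on that axis rather than on the head axis.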

Roaimkhan avatar Sep 15 '25 15:09 Roaimkhan

Hi, I'd like to help fix this issue. I'm currently investigating the shape mismatch in the attention-mask construction, specifically why attn_mask has shape (1, 1447, 1, 5234) while the logits are (1, 1447, 8, 4096).

If no one is working on this, I can prepare a PR.

Aditya-Ware-ds avatar Nov 24 '25 14:11 Aditya-Ware-ds

Following up — happy to prepare a fix if this issue is unassigned.

Aditya-Ware-ds avatar Nov 30 '25 18:11 Aditya-Ware-ds