rotary-embedding-torch

`torch.cat` fails in `apply_rotary_emb` when `freqs.shape[-1] == t.shape[-1]` and `start_index = 0`

Open · mattaltberg opened this issue 1 year ago · 1 comment

Two issues rolled into one, as they seem to come from the same function.

Right now, when I trace a module that uses rotate_queries_or_keys() with torch.jit.trace, I get the following TracerWarning:

TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
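One possible workaround for the warning, sketched under the assumption that skipping this check during tracing is acceptable: guard the assert with torch.jit.is_tracing(). The helper name check_rot_dim is hypothetical and is not part of rotary-embedding-torch.

```python
import torch


def check_rot_dim(rot_dim, t):
    # Hypothetical helper: only run the Python-level shape assert in eager mode,
    # so the tracer never has to convert a traced value into a Python bool.
    if not torch.jit.is_tracing():
        assert rot_dim <= t.shape[-1], (
            f"feature dimension {t.shape[-1]} is not of sufficient size "
            f"to rotate in all the positions {rot_dim}"
        )
```

The trade-off is that the check no longer runs while tracing, so a mismatched freqs tensor would only be caught when the module runs eagerly.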

In addition, when I use torch.neuron.trace to trace the same module, I hit this error:

MultiheadAttentionRoPE_8/aten_slice_14/StridedSlice", begin=[0, 0, 0, 0], end=[1, 8, 2, 0], strides=[1, 1, 1, 1]) // an internal invariant was violated while typechecking your program [21:33:36] /opt/brazil-pkg-cache/packages/DmlcTvm/DmlcTvm-1.18.2.0/AL2_x86_64/generic-flavor/src/src/relay/op/tensor/transform.cc:2111: Check failed: begin_v < end_v (0 vs. 0) : strided_slice get empty slice at axis 3

Tracing succeeds when I comment out the lines that call rotate_queries_or_keys.

UPDATE

I debugged this further, and the problem appears when t_left and t_right end up with a zero-length axis. For example, my input t has shape [1, 8, 2, 80] and freqs has shape [x, 80], so rot_dim == t.shape[-1]. With start_index = 0, the rotated slice covers the entire feature dimension, so after slicing I end up with:

t.shape = [1, 8, 2, 80]
t_left.shape = [1, 8, 2, 0]
t_right.shape = [1, 8, 2, 0]

These zero-length tensors are what break the torch.cat in the torch.neuron trace posted above.
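The slice widths can be checked without PyTorch by applying the same three slices to a plain range of feature indices. This is a minimal sketch: slice_widths is a hypothetical helper, and it assumes the split is t[..., :start_index], t[..., start_index:start_index + rot_dim] and t[..., start_index + rot_dim:].

```python
def slice_widths(feature_dim, rot_dim, start_index=0):
    # Hypothetical helper: widths of the left, rotated and right slices
    # along the last axis, mirroring the three-way split described above.
    end_index = start_index + rot_dim
    dims = range(feature_dim)
    return (
        len(dims[:start_index]),          # t_left
        len(dims[start_index:end_index]),  # rotated part
        len(dims[end_index:]),            # t_right
    )


print(slice_widths(80, 80))  # (0, 80, 0): both t_left and t_right are empty
print(slice_widths(80, 64))  # (0, 64, 16): only t_left is empty
```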

mattaltberg, Feb 29 '24 22:02

UPDATE 2

The fix was to run torch.cat only when t_left or t_right is non-empty, i.e. when every entry in its shape (torch.Size) is non-zero.
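A minimal sketch of that fix, assuming the rotation is t * cos(freqs) * scale + rotate_half(t) * sin(freqs) * scale on interleaved feature pairs. apply_rotary_emb_sketch and rotate_half below are illustrative stand-ins written for this example, not the library's code. Besides filtering empty pieces out of torch.cat, it also skips slicing altogether in the full-width case, on the assumption that the empty StridedSlice in the Neuron error comes from those slices.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Illustrative interleaved rotation: each adjacent pair (x1, x2) becomes (-x2, x1).
    x = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = x.unbind(dim=-1)
    return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)


def apply_rotary_emb_sketch(freqs, t, start_index=0, scale=1.0):
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    # Full-width case: nothing to the left or right, so avoid creating
    # zero-length slices and skip torch.cat entirely.
    if start_index == 0 and rot_dim == t.shape[-1]:
        return t * freqs.cos() * scale + rotate_half(t) * freqs.sin() * scale

    t_left = t[..., :start_index]
    t_mid = t[..., start_index:end_index]
    t_right = t[..., end_index:]
    t_mid = t_mid * freqs.cos() * scale + rotate_half(t_mid) * freqs.sin() * scale

    # Only concatenate pieces whose last dimension is non-empty.
    pieces = [p for p in (t_left, t_mid, t_right) if p.shape[-1] > 0]
    return torch.cat(pieces, dim=-1)
```

With t of shape [1, 8, 2, 80] and freqs of shape [2, 80], this takes the early-return path, so neither the empty slices nor torch.cat appear in the trace.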

mattaltberg, Mar 01 '24 20:03