RoPE implementation: torch.polar op not available in MIL, requesting support
🌱 Describe your Feature Request
- Having support for converting `torch.polar` to MIL would allow using RoPE embeddings, which are fairly common these days. Currently the conversion fails with: `RuntimeError: PyTorch convert function for op 'polar' not implemented.`
How can this feature be used?
In Stable Diffusion and in LLMs that use RoPE.
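For reference, a minimal repro of where `torch.polar` typically shows up when building the RoPE rotation factors (the module and shapes below are made up purely for illustration):

```python
import torch
import coremltools as ct

class RoPEFreqs(torch.nn.Module):
    def forward(self, angles):
        # torch.polar(abs, angle) -> abs * (cos(angle) + j*sin(angle))
        return torch.polar(torch.ones_like(angles), angles)

example = torch.arange(8, dtype=torch.float32).reshape(2, 4)
traced = torch.jit.trace(RoPEFreqs().eval(), example)

# Fails with: RuntimeError: PyTorch convert function for op 'polar' not implemented.
mlmodel = ct.convert(traced, inputs=[ct.TensorType(shape=example.shape)])
```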
FWIW: in some of our internal experiments we realized that computing the RoPE sinusoids at runtime wasn't very efficient (especially via complex numbers, which the stack decomposes into real tensors anyway), so we decided to pre-compute them in the torch Module's `__init__`:
```python
# shape: [max_sequence_len, head_dim]
self.register_buffer("cos_cached", torch.tensor(np.cos(emb, dtype=np.float32)))
self.register_buffer("sin_cached", torch.tensor(np.sin(emb, dtype=np.float32)))
```
and then apply them as follows:
```python
cos, sin = self.cos_cached[token_indices], self.sin_cached[token_indices]
rope_embedded = x * cos + rotate_half(x) * sin
```
where `rotate_half` is something like:

```python
def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.concat((-x2, x1), dim=-1)
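```

Putting those pieces together, here is a minimal self-contained sketch of that approach (the angle-table construction with base 10000 and all module/argument names are assumptions chosen to make the snippet runnable, not our exact code):

```python
import numpy as np
import torch

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.concat((-x2, x1), dim=-1)

class PrecomputedRoPE(torch.nn.Module):
    def __init__(self, head_dim=64, max_sequence_len=2048, base=10000.0):
        super().__init__()
        # Standard RoPE angle table; each frequency covers a pair of channels,
        # so the half-width table is tiled to span the full head_dim.
        inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
        angles = np.outer(np.arange(max_sequence_len), inv_freq)  # [max_sequence_len, head_dim // 2]
        emb = np.concatenate([angles, angles], axis=-1)           # [max_sequence_len, head_dim]
        self.register_buffer("cos_cached", torch.tensor(np.cos(emb, dtype=np.float32)))
        self.register_buffer("sin_cached", torch.tensor(np.sin(emb, dtype=np.float32)))

    def forward(self, x, token_indices):
        # x: [..., seq_len, head_dim], token_indices: [seq_len]
        cos, sin = self.cos_cached[token_indices], self.sin_cached[token_indices]
        return x * cos + rotate_half(x) * sin

rope = PrecomputedRoPE()
x = torch.randn(1, 8, 16, 64)    # [batch, heads, seq_len, head_dim]
out = rope(x, torch.arange(16))  # same shape as x, no complex ops anywhere
```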
@junpeiz since you have some context on the complex numbers stuff, how difficult would it be to support torch.polar for convenience?
Shouldn't be very difficult, as constructing the complex result is straightforward: abs⋅cos(angle) + abs⋅sin(angle)⋅j.
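In plain torch terms that decomposition is exactly (quick sanity check, nothing beyond cos/sin/mul and a complex construction):

```python
import torch

mag, angle = torch.rand(4), torch.rand(4)
expected = torch.polar(mag, angle)
decomposed = torch.complex(mag * torch.cos(angle), mag * torch.sin(angle))
assert torch.allclose(expected, decomposed)
```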
@nighting0le01 Thank you for filing this feature request! We will add it based on priority and bandwidth. Meanwhile, feel free to try it on your side by following how `add` supports complex inputs: https://github.com/apple/coremltools/blob/0e292a072452db19d1e64b687a372c0c54704a90/coremltools/converters/mil/frontend/torch/ops.py#L917
(More examples can be found by searching for `complex` in that file.)
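For reference, a rough, untested sketch of what such a registration might look like, following the composite-operator pattern; the name and arguments of the MIL complex-construction op (`mb.complex(real_data=..., imag_data=...)` below) are an assumption and should be checked against the complex handling in `ops.py`:

```python
from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op

@register_torch_op
def polar(context, node):
    # torch.polar(abs, angle) -> abs*cos(angle) + abs*sin(angle)*j
    abs_val, angle = _get_inputs(context, node, expected=2)
    real = mb.mul(x=abs_val, y=mb.cos(x=angle))
    imag = mb.mul(x=abs_val, y=mb.sin(x=angle))
    # Assumed complex-construction op; mirror however `add` builds its complex result.
    context.add(mb.complex(real_data=real, imag_data=imag, name=node.name))
```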