[torch] Add Canonicalize Pattern for embedding op
Converts PrimConvertOp followed by Embedding -> Embedding followed by PrimConvertOp. We don't need to cast the entire weight matrix, just the output of the embedding op.
Issue: https://github.com/iree-org/iree/issues/17226#issuecomment-2089655421
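For reference, a minimal PyTorch sketch of the rewrite in the widening case (this is illustrative only, not the PR's MLIR pattern; the table size, dtypes, and index shape below are assumptions). Gathering first and converting only the gathered rows is bit-for-bit equivalent when the convert widens:

```python
import torch
import torch.nn.functional as F

weight = torch.randn(1000, 64, dtype=torch.float16)   # embedding table stored as f16
indices = torch.randint(0, 1000, (8,))

# Before the rewrite: widen the whole table, then gather.
before = F.embedding(indices, weight.to(torch.float32))

# After the rewrite: gather first, then widen only the gathered rows.
after = F.embedding(indices, weight).to(torch.float32)

# f16 -> f32 widening is exact, so the two orderings agree bit-for-bit,
# while the second one only reads the rows it actually needs at f16 width.
assert torch.equal(before, after)
```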
I can see the benefit of this optimization; however, it is more about avoiding the compilation issue we have been encountering than about preventing the crash.
Note this could also be a pessimization: if you have your embeddings as f32, gather them, and convert to f16, you really want the conversion to fold into the embeddings so you aren't shipping (and doing the memory transactions on) f32 if you don't need those bits. This may get taken care of later in the pipeline, but it's important to note that there are some massive implications of things like this (it's almost always better to hoist narrowing operations and sink widening operations, almost never the opposite).
To avoid hurting performance we should only perform the swap in the widening case; otherwise we would potentially load more data just to truncate it back down, whereas truncating the whole table up front is beneficial overall.
There's a tradeoff between memory and compute: doing this might take more memory but is less compute-intensive, whereas the suggested ordering might be more compute-intensive since we are not able to fuse both kernels in the backend. I will add a check so the swap is performed only in the widening case.
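A minimal sketch of such a guard, assuming floating-point weights; `is_widening` is a hypothetical helper name, not the actual check in the PR's C++ canonicalizer:

```python
import torch

def is_widening(src_dtype: torch.dtype, dst_dtype: torch.dtype) -> bool:
    """True only when the convert goes to a wider float type, i.e. when
    sinking it below the embedding cannot increase the amount of weight
    data that has to be loaded."""
    return torch.finfo(dst_dtype).bits > torch.finfo(src_dtype).bits

# Swap only in the widening case (e.g. an f16 table converted to f32).
assert is_widening(torch.float16, torch.float32)
# Do not swap in the narrowing case (e.g. an f32 table converted to f16):
# that would gather from the larger f32 table just to truncate afterwards.
assert not is_widening(torch.float32, torch.float16)
```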
I've made the necessary changes. Please review.
Hi @pashu123, it seems the PR has been there for quite some time. Can you please update it in order to get merged?
It’s not needed. I’ll close the PR.