jax
Custom GPU ops partitioning
PR to update the tutorial to demonstrate the use of custom_partitioning instead of xmap when writing a custom primitive in JAX.
General comment: should we perhaps choose a simpler op/ops for this? RMSNorm needs a lot of CUDA code which is not directly relevant here.
People don't need to read the CUDA code here. If we took something simpler, we wouldn't be able to show all the features, like multi-output.
If we find something simpler that still shows all the features, that would be great, but I don't have one in mind. I think this should be a follow-up PR, though.
@jaro-sevcik had a similar impression, I believe... any ideas here?
If we took something simpler, we wouldn't be able to show all the features, like multi-output.
What are some of the other features we want to demonstrate?
If we find something simpler that still shows all the features, that would be great, but I don't have one in mind. I think this should be a follow-up PR, though.
Yeah, I agree this could be done in a follow-up.
@keshavb96 is the PR ready for another review?
Yes, I think it should be good to go.
Can you squash the commits, please? I suspect that could be why the import/Copybara check is failing.
@superbobry Copybara still seems to be failing, even after squashing all the commits.
It looks like Copybara is unhappy that some files are missing the Apache license header. Can you make sure all files have one, please?
Can you squash the commits, please?
Done!