Horace He
@vadimkantorov You can either save the input to the function or the relu mask. For fp32 it's better to save the mask, but for fp16 it's about equal.
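For reference, a minimal sketch of what saving the mask could look like, assuming a custom `autograd.Function` (an illustration, not the implementation being discussed here): a `torch.bool` mask takes 1 byte per element, versus 4 bytes for an fp32 input.

```python
import torch

class ReLUSaveMask(torch.autograd.Function):
    """Illustrative ReLU that saves only the boolean mask for backward."""

    @staticmethod
    def forward(ctx, x):
        mask = x > 0
        ctx.save_for_backward(mask)  # bool mask: 1 byte per element
        return x * mask

    @staticmethod
    def backward(ctx, grad_out):
        (mask,) = ctx.saved_tensors
        return grad_out * mask

x = torch.randn(1 << 20, requires_grad=True)
y = ReLUSaveMask.apply(x)
y.sum().backward()
```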
> Can't the mask be recovered again from the torch.sign applied to saved_output?

You could have a gradient formula that called torch.sign on saved_output, but that would require saving the...
@vadimkantorov At least from my benchmarks, I think torch.compile will just give as good/better results as your manual impl?

```
import torch
import torch.nn.functional as F

def bench(f, name=None, iters=100,...
```
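As a rough illustration of that kind of comparison (a hypothetical setup with a stand-in op, not the original benchmark, which is truncated above), `torch.utils.benchmark` can time an eager function against its torch.compile'd version:

```python
import torch
import torch.nn.functional as F
from torch.utils.benchmark import Timer

# Hypothetical stand-in for the op under test.
def f(x):
    return F.relu(x) * 2

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1 << 20, device=device)

compiled_f = torch.compile(f)
compiled_f(x)  # warm-up so compilation time isn't measured

for name, fn in [("eager", f), ("compiled", compiled_f)]:
    t = Timer(stmt="fn(x)", globals={"fn": fn, "x": x})
    print(name, t.timeit(100))
```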
It might be worth checking out timm models with channels last. Iirc, one of the big places where we see this kind of pattern messing up Inductor is outer-dim reductions,...
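To make the pattern concrete (a hedged sketch of my own, not a specific timm model): with a channels-last tensor, a global average pool reduces over H and W, which are the strided "outer" dimensions in that memory layout rather than the contiguous channel dimension.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# NCHW shape, but channels-last (NHWC) memory layout, as timm models often use.
x = torch.randn(8, 64, 56, 56, device=device).to(memory_format=torch.channels_last)

def global_avg_pool(x):
    # Reduces over H and W; in channels-last memory these are outer,
    # strided dimensions rather than the contiguous one (C).
    return x.mean(dim=(2, 3))

compiled = torch.compile(global_avg_pool)
out = compiled(x)  # shape (8, 64)
```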
This has the same root cause as https://github.com/pytorch/pytorch/issues/53678, except that the issue in that case was previously hidden, but now it's not. @bhosmer
Assigning this to @wconstab for hook support.
Yeah, the minifier doesn't really work with symbolic shapes right now.
@L-Reichardt Can you post the exact operation you want compiled? In my experience, torch.compile already works fairly well on many of these KeOps-like patterns.
I was suggesting that you might be able to replace the KeOps kernel by just writing it directly in PyTorch. Unfortunately, I don't think we support an operation like `argkmin`...
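For instance, a brute-force nearest-neighbor search, one of the classic KeOps patterns, can be written as a pairwise distance plus an argmin and handed to torch.compile. This is a sketch under my own assumptions, not the kernel from this issue, and the k > 1 case would need topk, which hits the limitation just mentioned:

```python
import torch

def nearest_neighbor(x, y):
    # x: (N, D) queries, y: (M, D) database points.
    # Squared Euclidean distances, shape (N, M), followed by an argmin reduction.
    d2 = (x * x).sum(-1, keepdim=True) - 2 * x @ y.T + (y * y).sum(-1)
    return d2.argmin(dim=-1)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4096, 64, device=device)
y = torch.randn(8192, 64, device=device)

compiled_nn = torch.compile(nearest_neighbor)
idx = compiled_nn(x, y)  # (4096,): index in y of the closest point to each query
```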
Yeah we don't properly fuse topk today either :)