Horace He


@vadimkantorov You can either save the input to the function or the ReLU mask. For fp32 it's better to save the mask (a bool tensor is much smaller than the fp32 input), but for fp16 it's about equal.
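As a sketch of the mask-saving variant (hypothetical code, not from the thread): a custom `autograd.Function` that stores the boolean mask for backward instead of the input tensor.

```python
import torch

class ReLUSaveMask(torch.autograd.Function):
    """ReLU that saves the boolean mask (input > 0) for backward,
    rather than saving the input tensor itself."""

    @staticmethod
    def forward(ctx, x):
        mask = x > 0                 # bool tensor: 1 byte per element
        ctx.save_for_backward(mask)
        return x * mask              # zero out non-positive entries

    @staticmethod
    def backward(ctx, grad_out):
        (mask,) = ctx.saved_tensors
        return grad_out * mask       # gradient flows where input was positive

x = torch.randn(4, 4, requires_grad=True)
y = ReLUSaveMask.apply(x)
y.sum().backward()
```

The memory trade-off is the one described above: the saved mask is a bool tensor, versus a 4-byte-per-element fp32 input.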

> Can't the mask be recovered again from the torch.sign applied to saved_output?

You could have a gradient formula that called torch.sign on saved_output, but that would require saving the...

@vadimkantorov At least from my benchmarks, I think torch.compile will give results as good as or better than your manual impl?

```
import torch
import torch.nn.functional as F

def bench(f, name=None, iters=100,...
```
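The benchmark helper in that snippet is cut off; a minimal stand-in (my reconstruction, not the original helper) could look like:

```python
import time
import torch
import torch.nn.functional as F

def bench(f, name=None, iters=100, warmup=5):
    # Simple wall-clock timing loop; a guess at the truncated helper above.
    for _ in range(warmup):
        f()
    start = time.perf_counter()
    for _ in range(iters):
        f()
    per_iter = (time.perf_counter() - start) / iters
    if name is not None:
        print(f"{name}: {per_iter * 1e6:.1f} us/iter")
    return per_iter

x = torch.randn(256, 256, requires_grad=True)

def relu_fwd_bwd():
    # One forward + backward through ReLU, resetting the grad each iteration.
    F.relu(x).sum().backward()
    x.grad = None

t = bench(relu_fwd_bwd, name="eager relu fwd+bwd")
```

The same `bench` can then be pointed at a `torch.compile`-d version of the function for a side-by-side comparison.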

It might be worth checking out timm models with channels last. Iirc, one of the big places where we see this kind of pattern messing up Inductor is outer-dim reductions,...
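For reference, putting a model and its inputs into channels-last is just a `memory_format` change (a minimal sketch with a toy conv standing in for a timm model; timm itself is not required):

```python
import torch

# Toy conv model standing in for a timm model.
model = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
model = model.to(memory_format=torch.channels_last)

x = torch.randn(2, 3, 32, 32).to(memory_format=torch.channels_last)
out = model(x)

# Conv outputs stay channels-last; this changes which dimension is
# contiguous and can expose the outer-dim-reduction pattern mentioned above.
print(out.is_contiguous(memory_format=torch.channels_last))
```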

This has the same root cause as https://github.com/pytorch/pytorch/issues/53678, except that the issue in that case was previously hidden, but now it's not. @bhosmer

Yeah, the minifier doesn't really work with symbolic shapes right now.

@L-Reichardt Can you post the exact operation you want compiled? In my experience, torch.compile already works fairly well on many of these KeOps-like patterns.

I was suggesting that you might be able to replace the KeOps kernel by just writing it directly in PyTorch. Unfortunately, I don't think we support an operation like `argkmin`...
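For what it's worth, an `argkmin` over pairwise distances can be written in plain PyTorch by materializing the distance matrix (a sketch; unlike KeOps this uses O(N·M) memory):

```python
import torch

def argkmin(x, y, k):
    # Indices of the k nearest rows of y for each row of x.
    # Materializes the full (N, M) distance matrix, unlike KeOps.
    d = torch.cdist(x, y)  # (N, M) pairwise Euclidean distances
    return d.topk(k, dim=1, largest=False).indices  # (N, k)

x = torch.randn(128, 16)
y = torch.randn(256, 16)
idx = argkmin(x, y, k=4)
```

`torch.compile` can fuse around a pattern like this, though the O(N·M) intermediate remains; that materialization is exactly what KeOps avoids.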