Horace He
Haha, this is kind of a dumb issue, if I understand correctly. Basically, `make_functional` returns a "functional" `nn.Module` (that doesn't actually have any params on it). But... PyTorch still thinks...
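For reference, a minimal sketch of the API shape being described (the toy `nn.Linear` model here is just for illustration):

```python
import torch
from functorch import make_functional

model = torch.nn.Linear(3, 2)
# make_functional pulls the parameters out and returns a "stateless" module
fmodel, params = make_functional(model)

x = torch.randn(4, 3)
# the functional module takes the parameters explicitly on every call
out = fmodel(params, x)
```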
Yeah, it's removed on main, I believe - this was a hack we used to work around nvfuser limitations, but that's been fixed now.
Kinda :) We have something called AOTAutograd that can be used similarly to `jax.jit` (although certainly not as mature). Specifically, you can use `from functorch.compile import memory_efficient_fusion`. So, you can...
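A minimal sketch of that `jax.jit`-style usage, assuming a CUDA device since `memory_efficient_fusion` targets nvfuser (the function `f` here is made up):

```python
import torch
from functorch.compile import memory_efficient_fusion

def f(x, y):
    # a small pointwise chain that nvfuser can fuse
    return torch.tanh(x + y) * torch.sigmoid(x - y)

fused_f = memory_efficient_fusion(f)

x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
y = torch.randn(1024, 1024, device="cuda", requires_grad=True)
out = fused_f(x, y)      # compiled forward
out.sum().backward()     # compiled backward via AOTAutograd
```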
Mmmm, so... this is probably the wrong way to compile through vmap :P It might technically work today, but it's very much not intended usage. In general we cannot (currently)...
Unfortunately, we don't support kwargs properly right now with AOTAutograd (particularly with static_argnums). So, it works if you remove the `p=0.3` from the function.
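One way to read that workaround, sketched under the assumption that the constant can simply be baked into the function body (the function names are made up, and the nvfuser path assumes a CUDA device):

```python
import torch
from functorch.compile import memory_efficient_fusion

# Problematic form: a keyword argument with a default on the traced function.
def f_with_kwarg(x, p=0.3):
    return torch.nn.functional.leaky_relu(x, p) * x

# Workaround: drop the kwarg so the compiled function only takes tensor inputs,
# with the constant baked in.
def f_positional_only(x):
    return torch.nn.functional.leaky_relu(x, 0.3) * x

fused = memory_efficient_fusion(f_positional_only)
out = fused(torch.randn(64, 64, device="cuda"))
```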
@mohamad-amin Your colab link needs to be shared. The library is still under rapid development, so I wouldn't be shocked if there were fixes between 1.10 and nightly.
@ain-soph Well, you're only computing the explicit empirical NTK in this example :P You may find computing the implicit NTK harder with existing PyTorch APIs. That being said, in many...
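For context, a sketch of what computing the explicit empirical NTK with functorch looks like, following the usual `vmap`/`jacrev` pattern (the toy network and shapes here are illustrative):

```python
import torch
from functorch import make_functional, vmap, jacrev

net = torch.nn.Sequential(torch.nn.Linear(5, 16), torch.nn.ReLU(), torch.nn.Linear(16, 2))
fnet, params = make_functional(net)

def fnet_single(params, x):
    # evaluate the network on a single example
    return fnet(params, x.unsqueeze(0)).squeeze(0)

x1 = torch.randn(10, 5)
x2 = torch.randn(7, 5)

# Per-example Jacobians of the outputs w.r.t. all parameters.
jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)

# Flatten each parameter's Jacobian and contract over parameter entries:
# K[i, j, a, b] = sum_p J1[i, a, p] * J2[j, b, p]
jac1 = [j.flatten(2) for j in jac1]
jac2 = [j.flatten(2) for j in jac2]
ntk = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2)
                   for j1, j2 in zip(jac1, jac2)]).sum(0)
print(ntk.shape)  # (10, 7, 2, 2)
```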
I think this is the same issue as https://github.com/pytorch/pytorch/issues/124423
`aten::transpose` doesn't have a non-overloaded version, so you just need to add a specific overload: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml#L4577 So, in this case:
```
@register_decomposition([aten.transpose.int], decompositions)
def transpose(x, dim0: int, dim1: int):
    ndim...
```
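For illustration only, a hedged guess at how the truncated body might continue, written as a standalone function that expresses transpose as a permute with `dim0` and `dim1` swapped (in the actual decompositions file it would carry the `@register_decomposition([aten.transpose.int], decompositions)` decorator shown above):

```python
import torch
from torch import Tensor

def transpose_decomp(x: Tensor, dim0: int, dim1: int) -> Tensor:
    # Sketch: transpose(x, dim0, dim1) as a permute with the two dims swapped.
    ndim = x.dim()
    dim0 = dim0 % ndim   # normalize negative dims (assumes ndim > 0)
    dim1 = dim1 % ndim
    perm = list(range(ndim))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    return torch.permute(x, perm)
```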
Yeah, will also add a docstring to it.