vmap should accept a dim_size=None argument
vmap should accept a dim_size=None argument that lets the user specify the size of the dimension being vmapped over. It should behave similarly to JAX's axis_size argument.
The net effect of this is that one should be able to vmap over functions that do not take Tensors as input!
import torch
from functorch import vmap

def f():
    return torch.tensor(1.)

# Proposed: dim_size supplies the batch size, since there are no batched inputs to infer it from.
result = vmap(f, dim_size=5)()
assert torch.allclose(result, torch.tensor([1., 1., 1., 1., 1.]))
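For comparison, here is a rough sketch of how one has to emulate this today without dim_size: thread a dummy batched tensor through the function so vmap has something to infer the batch size from (the dummy name and the lambda below are just for illustration):

import torch
from functorch import vmap

def f():
    return torch.tensor(1.)

# Make the output depend (trivially) on the batched dummy so that vmap
# can infer the batch size of 5 from it.
dummy = torch.zeros(5)
result = vmap(lambda d: f() + 0 * d)(dummy)
assert torch.allclose(result, torch.ones(5))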
We should also investigate if there are other things that the axis_size arg in JAX provides.
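For reference, a minimal sketch (assuming jax is installed; g is just an illustrative name) of how axis_size behaves in JAX today when no inputs are mapped:

import jax
import jax.numpy as jnp

def g(x):
    return x + 1.

# With in_axes=None nothing is mapped, so vmap cannot infer the batch size
# from the inputs; axis_size supplies it, and the unbatched output is
# broadcast along out_axes (0 by default).
out = jax.vmap(g, in_axes=None, axis_size=5)(jnp.array(1.))
# out has shape (5,): [2., 2., 2., 2., 2.]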