functorch

vmap should accept a dim_size=None argument

Open zou3519 opened this issue 3 years ago • 0 comments

vmap should accept a dim_size=None argument that lets the user specify the size of the dimension being vmapped over. It should behave similarly to the axis_size argument of JAX's vmap.

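The net effect is that vmap could then be applied to functions that take no Tensors as input. Today the size of the vmapped dimension is inferred from the batched inputs, so a function without Tensor inputs gives vmap nothing to infer it from.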

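```python
import torch
from functorch import vmap

def f():  # takes no Tensor inputs
    return torch.tensor(1.)

result = vmap(f, dim_size=5)()  # proposed API: dim_size is not yet a vmap argument
assert torch.allclose(result, torch.tensor([1., 1., 1., 1., 1.]))
```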
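To pin down the intended semantics without depending on functorch internals, here is a minimal loop-and-stack sketch in plain Python. `naive_vmap` is hypothetical, not functorch's API or implementation: Python lists stand in for Tensors, every positional argument is assumed to be batched along its first index, and the rule that an explicit `dim_size` must agree with the size of any batched inputs is an assumption, not something this issue specifies.

```python
from typing import Any, Callable, List, Optional, Sequence


def naive_vmap(f: Callable[..., Any], dim_size: Optional[int] = None) -> Callable[..., List[Any]]:
    """Hypothetical reference semantics for vmap with an optional explicit dim_size.

    Each positional argument is treated as batched along its first index.
    The batch size is inferred from the arguments; dim_size is required when there are none.
    """

    def wrapped(*args: Sequence[Any]) -> List[Any]:
        # Collect the batch size of every input; they must all agree.
        sizes = {len(a) for a in args}
        if len(sizes) > 1:
            raise ValueError(f"inputs have inconsistent batch sizes: {sorted(sizes)}")
        if dim_size is None:
            if not sizes:
                # The case this issue is about: nothing to infer the size from.
                raise ValueError("no batched inputs: pass dim_size to set the batch size")
            size = sizes.pop()
        else:
            # Assumption: an explicit dim_size must match any batched inputs.
            if sizes and sizes != {dim_size}:
                raise ValueError(f"dim_size={dim_size} does not match input batch size {sizes.pop()}")
            size = dim_size
        # Run f once per batch index and "stack" the per-index results into a list.
        return [f(*(a[i] for a in args)) for i in range(size)]

    return wrapped


def g():
    return 1.0

print(naive_vmap(g, dim_size=5)())                      # [1.0, 1.0, 1.0, 1.0, 1.0]
print(naive_vmap(lambda x, y: x + y)([1, 2], [10, 20]))  # [11, 22]
```

Without `dim_size`, `naive_vmap(g)()` raises, which mirrors why the proposed argument is needed for functions that take no Tensor inputs.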

We should also investigate whether JAX's axis_size argument provides anything beyond this.

zou3519, Dec 09 '22 18:12