functorch
[AOT Autograd] 'aten.transpose' has no overload name 'default'
This decomposition:
    @register_decomposition([aten.transpose], decompositions)
    def transpose(x, dim0: int, dim1: int):
        ndim = x.ndimension()
        dims = list(range(ndim))
        dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
        return torch.permute(x, dims)
Gives this error:
    ../functorch/functorch/_src/decompositions.py:29: in decomposition_decorator
        tree_map(add_op_to_table, aten_op)
    ../pytorch/torch/utils/_pytree.py:179: in tree_map
        return tree_unflatten([fn(i) for i in flat_args], spec)
    ../pytorch/torch/utils/_pytree.py:179: in <listcomp>
        return tree_unflatten([fn(i) for i in flat_args], spec)
    ../functorch/functorch/_src/decompositions.py:24: in add_op_to_table
        op_overload = aten_op.default
    ../pytorch/torch/_ops.py:133: in __getattr__
        raise AttributeError(
    E   AttributeError: The underlying op of 'aten.transpose' has no overload name 'default'
cc @Chillee
aten::transpose has no default (unnamed) overload, so when you pass the bare op the decorator's lookup of aten_op.default fails. You just need to register against a specific overload: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml#L4577
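To make the failure mode concrete, here is a toy model of an op that only has named overloads. ToyOverloadPacket is a hypothetical stand-in written for illustration; it is not PyTorch's implementation, and only the error message wording is taken from the traceback above.

```python
class ToyOverloadPacket:
    """Hypothetical stand-in for an op such as aten.transpose (not PyTorch code)."""

    def __init__(self, name, overloads):
        self._name = name
        self._overloads = dict(overloads)  # overload name -> schema string

    def __getattr__(self, overload_name):
        # __getattr__ only runs when normal attribute lookup fails.
        if overload_name.startswith("_"):
            raise AttributeError(overload_name)
        try:
            return self._overloads[overload_name]
        except KeyError:
            raise AttributeError(
                f"The underlying op of '{self._name}' has no overload name "
                f"'{overload_name}'"
            ) from None


# transpose is declared only with named overloads, so there is no "default" entry.
toy_transpose = ToyOverloadPacket(
    "aten.transpose",
    {
        "int": "transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)",
        "Dimname": "transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a)",
    },
)

print(toy_transpose.int)  # a named overload resolves fine

try:
    toy_transpose.default  # mirrors `aten_op.default` in add_op_to_table
except AttributeError as err:
    print(err)
```

In this sketch the fix is the same as in the real case: ask for an overload that exists (here `int`) instead of `default`.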
So, in this case
    @register_decomposition([aten.transpose.int], decompositions)
    def transpose(x, dim0: int, dim1: int):
        ndim = x.ndimension()
        dims = list(range(ndim))
        dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
        return torch.permute(x, dims)
should work.
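As a quick sanity check on the decomposition body itself, the index swap can be exercised without torch. The helpers below (transpose_permutation, permute_shape) are hypothetical names for this sketch; they only reproduce the list manipulation from the snippet and apply it to a plain shape tuple, not to a tensor.

```python
def transpose_permutation(ndim, dim0, dim1):
    """Build the dims list that the decomposition hands to torch.permute."""
    dims = list(range(ndim))
    # Same swap as in the decomposition; negative indices work because
    # Python lists accept them.
    dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
    return dims


def permute_shape(shape, dims):
    """Apply a permutation to a shape tuple, the way permute reorders axes."""
    return tuple(shape[d] for d in dims)


perm = transpose_permutation(4, 1, 3)
print(perm)                               # [0, 3, 2, 1]
print(permute_shape((2, 3, 5, 7), perm))  # (2, 7, 5, 3)
print(transpose_permutation(3, -1, 0))    # [2, 1, 0]
```

One difference to keep in mind: an out-of-range dim raises a plain IndexError from the list indexing here, so the error a user sees from the decomposition may differ from the one the eager op reports.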
Will add a better error message for this case.
Yup, that worked. Thanks!
Yeah, will also add a docstring to it.