mmengine
Support mps device in IndexType
Motivation
When running inference with the MMDet CoDINO model on an mps device, it fails with an AssertionError raised at this line: https://github.com/open-mmlab/mmengine/blob/85c83ba61689907fb1775713622b1b146d82277b/mmengine/structures/instance_data.py#L175
The issue is that no tensor types are defined for tensors on the mps device, and torch.mps does not provide typed tensor classes, so referencing them fails with:
AttributeError: module 'torch.mps' has no attribute 'BoolTensor'
AttributeError: module 'torch.mps' has no attribute 'LongTensor'
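A quick way to confirm this on a given PyTorch install is sketched below. It only checks whether the typed classes exist as attributes of torch.mps; the getattr guard is an assumption added so the check also runs on PyTorch versions that predate the torch.mps module.

```python
import torch

# torch.mps is not present on older PyTorch versions; guard the lookup.
mps_module = getattr(torch, "mps", None)

for name in ("BoolTensor", "LongTensor"):
    has_mps_type = mps_module is not None and hasattr(mps_module, name)
    # On versions that raise the AttributeError above, this prints False.
    print(f"torch.mps.{name} exists: {has_mps_type}")
```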
If I instead add an mps branch that falls back to torch.Tensor:
elif get_device() == 'mps':
    BoolTypeTensor = Union[torch.BoolTensor, torch.Tensor]
    LongTypeTensor = Union[torch.LongTensor, torch.Tensor]
it throws this error:
  File "../site-packages/mmengine/structures/instance_data.py", line 207, in __getitem__
    assert len(item) == len(self), 'The shape of the ' \
AssertionError: The shape of the input(BoolTensor) 257 does not match the shape of the indexed tensor in results_field 300 at first dimension.
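The following minimal sketch reproduces the mechanism behind this error on CPU, without an mps device. It reuses the BoolTypeTensor union from the attempted mps branch; num_instances is a hypothetical stand-in for the 300 results, and the short integer index plays the role of the 257 kept indices.

```python
from typing import Union

import torch

# The union from the attempted mps branch above.
BoolTypeTensor = Union[torch.BoolTensor, torch.Tensor]

num_instances = 6                      # stands in for the 300 results
long_index = torch.tensor([0, 2, 5])   # int64 indices, not a boolean mask

# Every tensor is an instance of torch.Tensor, so the index tensor passes
# the "is this a BoolTensor?" check even though it holds integers.
treated_as_mask = isinstance(long_index, BoolTypeTensor.__args__)
print(treated_as_mask)                           # True
print(isinstance(long_index, torch.BoolTensor))  # False

# A boolean mask needs one entry per instance, so a length check like the
# one in the traceback fails here (3 vs. 6, analogous to 257 vs. 300).
print(len(long_index) == num_instances)          # False
```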
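This happens because every tensor is an instance of torch.Tensor, so with torch.Tensor in BoolTypeTensor the BoolTensor check in __getitem__ also matches a LongTensor of indices, which is then validated as if it were a boolean mask (257 indices vs. 300 instances). A simple fix is therefore to add torch.Tensor to the IndexType union instead, leaving BoolTypeTensor and LongTypeTensor unchanged.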
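A sketch of the proposed change follows. It assumes the default aliases mirror mmengine's CUDA branch (Union[torch.LongTensor, torch.cuda.LongTensor] and the Bool equivalent); check the linked instance_data.py for the exact definitions. OldIndexType and NewIndexType are hypothetical names for the union before and after the change. Since no mps device is assumed, an int32 CPU tensor stands in for an mps index tensor: like an mps tensor, it is a torch.Tensor that matches none of the legacy typed classes in the union.

```python
from typing import Union

import numpy as np
import torch

# Assumed to mirror mmengine's default (CUDA) aliases; see instance_data.py.
BoolTypeTensor = Union[torch.BoolTensor, torch.cuda.BoolTensor]
LongTypeTensor = Union[torch.LongTensor, torch.cuda.LongTensor]

# Hypothetical names: the index union before and after adding torch.Tensor.
OldIndexType = Union[str, slice, int, list, LongTypeTensor,
                     BoolTypeTensor, np.ndarray]
NewIndexType = Union[str, slice, int, list, LongTypeTensor,
                     BoolTypeTensor, np.ndarray, torch.Tensor]

# Stand-in for an mps index tensor: a torch.Tensor matching none of the legacy
# typed classes (torch.tensor([0, 2], device="mps") should behave the same).
index = torch.tensor([0, 2], dtype=torch.int32)

print(isinstance(index, OldIndexType.__args__))    # False: rejected before
print(isinstance(index, NewIndexType.__args__))    # True: accepted after
print(isinstance(index, BoolTypeTensor.__args__))  # False: not treated as a mask
```

Because typing.Union flattens nested unions, NewIndexType.__args__ is a flat tuple of classes that isinstance can check directly, and BoolTypeTensor stays narrow, so the mask-length assertion is no longer triggered by integer indices.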
Thanks!!