mmengine

Supports mps device IndexType

Open · bibekyess opened this issue 1 year ago · 1 comment

Motivation

When running inference with the MMDet CoDINO model on an mps device, inference fails with an AssertionError raised at this line: https://github.com/open-mmlab/mmengine/blob/85c83ba61689907fb1775713622b1b146d82277b/mmengine/structures/instance_data.py#L175

The issue is that no tensor types are defined for tensors on mps, and torch.mps does not provide the typed tensor classes:

AttributeError: module 'torch.mps' has no attribute 'BoolTensor'
AttributeError: module 'torch.mps' has no attribute 'LongTensor'
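To make the first failure concrete without an Apple GPU, here is a minimal plain-Python model of the type check. Everything in it is a hypothetical stand-in: `Tensor` mimics `torch.Tensor`, and `legacy_type` builds classes whose `isinstance` check, like `torch.BoolTensor` or `torch.cuda.LongTensor`, matches only one device type and dtype. It also assumes that an mps device falls into the default branch, which lists only the CPU and CUDA classes.

```python
from typing import Union


class Tensor:
    """Hypothetical stand-in for torch.Tensor: only length, dtype and device."""

    def __init__(self, length, dtype, device):
        self.length, self.dtype, self.device = length, dtype, device

    def __len__(self):
        return self.length


class _LegacyTypeMeta(type):
    """Models legacy classes such as torch.BoolTensor, whose isinstance()
    check matches only tensors with one specific device and dtype."""

    def __instancecheck__(cls, obj):
        return (isinstance(obj, Tensor) and obj.device == cls.device
                and obj.dtype == cls.dtype)


def legacy_type(device, dtype):
    """Create a stand-in class such as torch.BoolTensor ('cpu', 'bool')."""
    namespace = {"device": device, "dtype": dtype}
    return _LegacyTypeMeta(f"{device}_{dtype}", (), namespace)


CpuBool, CudaBool = legacy_type("cpu", "bool"), legacy_type("cuda", "bool")
CpuLong, CudaLong = legacy_type("cpu", "long"), legacy_type("cuda", "long")

# Assumed default branch that a device reported as 'mps' ends up in.
BoolTypeTensor = Union[CpuBool, CudaBool]
LongTypeTensor = Union[CpuLong, CudaLong]
IndexType = Union[str, slice, int, list, LongTypeTensor, BoolTypeTensor]


def index_is_supported(item):
    # Same form as the first assertion: isinstance(item, IndexType.__args__)
    return isinstance(item, IndexType.__args__)


print(index_is_supported(Tensor(300, "bool", "cpu")))  # True
print(index_is_supported(Tensor(300, "bool", "mps")))  # False
```

The mps tensor is still a `Tensor`, but it matches none of the device-specific classes in the union, so the assertion fails.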

If I define the following branch for mps:

elif get_device() == 'mps':
    # torch.mps has no BoolTensor/LongTensor, so fall back to torch.Tensor
    BoolTypeTensor = Union[torch.BoolTensor, torch.Tensor]
    LongTypeTensor = Union[torch.LongTensor, torch.Tensor]

it raises a different assertion error instead:

File "../site-packages/mmengine/structures/instance_data.py", line 207, in __getitem__
    assert len(item) == len(self), 'The shape of the ' \
AssertionError: The shape of the input(BoolTensor) 257 does not match the shape of the indexed tensor in results_field 300 at first dimension.

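Because torch.Tensor is then part of BoolTypeTensor, every tensor index passes the BoolTensor check regardless of its dtype, so the length assertion intended for boolean masks is applied to index tensors of any dtype. A simple fix is to add torch.Tensor to the IndexType union instead. Thanks!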
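The following plain-Python model compares the two type-driven checks in `InstanceData.__getitem__` under the mps workaround and under the proposed fix. All names are hypothetical: `Tensor` stands in for `torch.Tensor`, `legacy_type` builds classes whose `isinstance` check matches one device type and dtype (as `torch.BoolTensor` does), and `getitem_checks` is a simplified replay rather than mmengine's code. `np.ndarray` is left out of the index union, and the 257-element index is assumed to be an integer tensor.

```python
from typing import Union


class Tensor:
    """Hypothetical stand-in for torch.Tensor: only length, dtype and device."""

    def __init__(self, length, dtype, device):
        self.length, self.dtype, self.device = length, dtype, device

    def __len__(self):
        return self.length


class _LegacyTypeMeta(type):
    """isinstance() matches one device/dtype pair, like torch.BoolTensor."""

    def __instancecheck__(cls, obj):
        return (isinstance(obj, Tensor) and obj.device == cls.device
                and obj.dtype == cls.dtype)


def legacy_type(device, dtype):
    namespace = {"device": device, "dtype": dtype}
    return _LegacyTypeMeta(f"{device}_{dtype}", (), namespace)


CpuBool, CudaBool = legacy_type("cpu", "bool"), legacy_type("cuda", "bool")
CpuLong, CudaLong = legacy_type("cpu", "long"), legacy_type("cuda", "long")
BASIC = (str, slice, int, list)  # np.ndarray left out to avoid dependencies

# Workaround from the description: torch.Tensor added to both unions.
BoolWorkaround = Union[CpuBool, Tensor]
LongWorkaround = Union[CpuLong, Tensor]
IndexWorkaround = Union[BASIC + (LongWorkaround, BoolWorkaround)]

# Proposed fix: default CPU/CUDA unions kept, torch.Tensor added to IndexType.
BoolDefault = Union[CpuBool, CudaBool]
IndexFixed = Union[BASIC + (CpuLong, CudaLong, BoolDefault, Tensor)]


def getitem_checks(item, num_instances, bool_type, index_type):
    """Simplified replay of the two type-driven assertions; returns 'ok'
    or the reason the real method would fail."""
    if not isinstance(item, index_type.__args__):
        return "unsupported index type"
    if isinstance(item, bool_type.__args__) and len(item) != num_instances:
        return "bool mask length mismatch"
    return "ok"


kept = Tensor(257, "long", "mps")  # assumed integer index of kept instances
print(getitem_checks(kept, 300, BoolWorkaround, IndexWorkaround))
# bool mask length mismatch
print(getitem_checks(kept, 300, BoolDefault, IndexFixed))
# ok
```

One side effect visible in this model: under the fix, a boolean mask on mps passes the first check but skips the length check, because it matches neither the CPU nor the CUDA BoolTensor class.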

bibekyess · Sep 05 '24 02:09

CLA assistant check
All committers have signed the CLA.

CLAassistant · Sep 05 '24 02:09