[Bug] ms_deform_attn_forward_cuda do not support BF16
Prerequisite
- [X] I have searched Issues and Discussions but cannot get the expected help.
- [X] The bug has not been fixed in the latest version (https://github.com/open-mmlab/mmcv).
Environment
OrderedDict([('sys.platform', 'linux'), ('Python', '3.10.12 (main, Jul 5 2023, 18:54:27) [GCC 11.2.0]'), ('CUDA available', True), ('numpy_random_seed', 2147483648), ('GPU 0', 'NVIDIA GeForce RTX 4090'), ('CUDA_HOME', '/usr/local/cuda'), ('NVCC', 'Cuda compilation tools, release 12.2, V12.2.91'), ('GCC', 'gcc (Ubuntu 11.3.0-1ubuntu1~22.04.1) 11.3.0'), ('PyTorch', '2.0.1'), ('PyTorch compiling details', 'PyTorch built with:\n - GCC 9.3\n - C++ Version: 201703\n - Intel(R) oneAPI Math Kernel Library Version 2023.1-Product Build 20230303 for Intel(R) 64 architecture applications\n - Intel(R) MKL-DNN v2.7.3 (Git Hash 6dbeffbae1f23cbbeae17adb7b5b13f1f37c080e)\n - OpenMP 201511 (a.k.a. OpenMP 4.5)\n - LAPACK is enabled (usually provided by MKL)\n - NNPACK is enabled\n - CPU capability usage: AVX2\n - CUDA Runtime 11.8\n - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_37,code=compute_37\n - CuDNN 8.7\n - Magma 2.6.1\n - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.8, CUDNN_VERSION=8.7.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.0.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, \n'), ('TorchVision', '0.15.2'), ('OpenCV', '4.8.0'), ('MMEngine', '0.8.2'), ('MMCV', '2.0.1'), ('MMCV Compiler', 'GCC 9.3'), ('MMCV CUDA Compiler', '11.8')])
Reproduces the problem - code sample
Add the following override to the Mask2Former config
[mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640.py](mmseg::mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640.py)
(optimizer and custom_keys are the ones already defined in that config):
optim_wrapper = dict(
    _delete_=True,
    type='AmpOptimWrapper',
    dtype='bfloat16',
    optimizer=optimizer,
    clip_grad=dict(max_norm=0.01, norm_type=2),
    paramwise_cfg=dict(
        custom_keys=custom_keys,
        norm_decay_mult=0.0))
Reproduces the problem - command or script
tools/train.py configs/mask2former/mask2former_swin-b-in22k-384x384-pre_8xb2-160k_ade20k-640x640.py --amp
Reproduces the problem - error message
RuntimeError: "ms_deform_attn_forward_cuda" not implemented for 'BFloat16'
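For reference, a standalone call should hit the same dispatch failure outside the trainer (a sketch; shapes are arbitrary):

import torch
from mmcv.ops import MultiScaleDeformableAttention

attn = MultiScaleDeformableAttention(embed_dims=256, num_levels=1).cuda()
query = torch.rand(100, 2, 256, device='cuda')       # (num_query, bs, embed_dims)
value = torch.rand(16 * 16, 2, 256, device='cuda')    # (num_key, bs, embed_dims)
reference_points = torch.rand(2, 100, 1, 2, device='cuda')
spatial_shapes = torch.tensor([[16, 16]], device='cuda')
level_start_index = torch.tensor([0], device='cuda')
with torch.autocast('cuda', dtype=torch.bfloat16):
    # value_proj etc. run in bf16 under autocast, so the CUDA op receives
    # bf16 tensors and raises the RuntimeError above.
    attn(query, value=value, reference_points=reference_points,
         spatial_shapes=spatial_shapes, level_start_index=level_start_index)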
Additional information
Could a conditional check be added so that, when the CUDA op does not support BFloat16, the call falls back to the native PyTorch implementation?
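Something like the following could work (a minimal sketch of the idea, not mmcv's actual code; the helper name ms_deform_attn_dispatch is made up). It keeps the compiled op for fp16/fp32 and falls back to the pure-PyTorch kernel for bf16:

import torch
from mmcv.ops.multi_scale_deform_attn import (
    MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE


def ms_deform_attn_dispatch(value, spatial_shapes, level_start_index,
                            sampling_locations, attention_weights,
                            im2col_step):
    """Use the compiled op unless the value tensor is BFloat16."""
    on_device = ((IS_CUDA_AVAILABLE and value.is_cuda)
                 or (IS_MLU_AVAILABLE and value.is_mlu))
    if on_device and value.dtype != torch.bfloat16:
        return MultiScaleDeformableAttnFunction.apply(
            value, spatial_shapes, level_start_index, sampling_locations,
            attention_weights, im2col_step)
    # The pure-PyTorch kernel supports bf16; it does not need
    # level_start_index.
    return multi_scale_deformable_attn_pytorch(
        value, spatial_shapes, sampling_locations, attention_weights)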
In the meantime, I have built a temporary workaround, but it is ugly and complicated:
from typing import Optional, no_type_check

import torch
from mmcv.ops import MultiScaleDeformableAttention
from mmcv.ops.multi_scale_deform_attn import (
    MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
from mmdet.models import (DeformableDetrTransformerEncoderLayer,
                          Mask2FormerTransformerEncoder,
                          MSDeformAttnPixelDecoder)
from mmdet.registry import MODELS as MMDET_MODELS
from mmengine.registry import MODELS
from mmengine.utils import deprecated_api_warning
from torch.nn import ModuleList
# Module-level flag read by TorchMultiScaleDeformableAttention below; it is
# overwritten by TorchMSDeformAttnPixelDecoder before the encoder is built.
force_not_using_cuda_ops = False


@MODELS.register_module()
class TorchMultiScaleDeformableAttention(MultiScaleDeformableAttention):

    def __init__(self, **kwargs):
        # Stash the module-level flag on the instance so forward() can choose
        # between the compiled CUDA op and the pure-PyTorch path.
        self.force_not_using_cuda_ops = force_not_using_cuda_ops
        super().__init__(**kwargs)
    @no_type_check
    @deprecated_api_warning({'residual': 'identity'},
                            cls_name='MultiScaleDeformableAttention')
    def forward(self,
                query: torch.Tensor,
                key: Optional[torch.Tensor] = None,
                value: Optional[torch.Tensor] = None,
                identity: Optional[torch.Tensor] = None,
                query_pos: Optional[torch.Tensor] = None,
                key_padding_mask: Optional[torch.Tensor] = None,
                reference_points: Optional[torch.Tensor] = None,
                spatial_shapes: Optional[torch.Tensor] = None,
                level_start_index: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        """Forward function of MultiScaleDeformableAttention.

        Args:
            query (torch.Tensor): Query of Transformer with shape
                (num_query, bs, embed_dims).
            key (torch.Tensor): The key tensor with shape
                `(num_key, bs, embed_dims)`.
            value (torch.Tensor): The value tensor with shape
                `(num_key, bs, embed_dims)`.
            identity (torch.Tensor): The tensor used for addition, with the
                same shape as `query`. Default None. If None, `query` will
                be used.
            query_pos (torch.Tensor): The positional encoding for `query`.
                Default: None.
            key_padding_mask (torch.Tensor): ByteTensor for `query`, with
                shape [bs, num_key].
            reference_points (torch.Tensor): The normalized reference
                points with shape (bs, num_query, num_levels, 2), all
                elements in range [0, 1], top-left (0, 0), bottom-right
                (1, 1), including padding area; or with shape
                (N, Length_{query}, num_levels, 4), where the additional
                two dimensions (w, h) form reference boxes.
            spatial_shapes (torch.Tensor): Spatial shape of features in
                different levels. With shape (num_levels, 2), last
                dimension represents (h, w).
            level_start_index (torch.Tensor): The start index of each
                level. A tensor has shape ``(num_levels, )`` and can be
                represented as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].

        Returns:
            torch.Tensor: forwarded results with shape
            [num_query, bs, embed_dims].
        """
        if value is None:
            value = query
        if identity is None:
            identity = query
        if query_pos is not None:
            query = query + query_pos
        if not self.batch_first:
            # change to (bs, num_query, embed_dims)
            query = query.permute(1, 0, 2)
            value = value.permute(1, 0, 2)
        bs, num_query, _ = query.shape
        bs, num_value, _ = value.shape
        assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
        value = self.value_proj(value)
        if key_padding_mask is not None:
            value = value.masked_fill(key_padding_mask[..., None], 0.0)
        value = value.view(bs, num_value, self.num_heads, -1)
        sampling_offsets = self.sampling_offsets(query).view(
            bs, num_query, self.num_heads, self.num_levels, self.num_points,
            2)
        attention_weights = self.attention_weights(query).view(
            bs, num_query, self.num_heads, self.num_levels * self.num_points)
        attention_weights = attention_weights.softmax(-1)
        attention_weights = attention_weights.view(bs, num_query,
                                                   self.num_heads,
                                                   self.num_levels,
                                                   self.num_points)
        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack(
                [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets \
                / offset_normalizer[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.num_points \
                * reference_points[:, :, None, :, None, 2:] \
                * 0.5
        else:
            raise ValueError(
                f'Last dim of reference_points must be'
                f' 2 or 4, but get {reference_points.shape[-1]} instead.')
        # The only change from the stock mmcv forward(): also honour the
        # force_not_using_cuda_ops flag so bf16 inputs can take the
        # pure-PyTorch path instead of the CUDA op.
        if ((IS_CUDA_AVAILABLE and value.is_cuda)
                or (IS_MLU_AVAILABLE and value.is_mlu)) \
                and not self.force_not_using_cuda_ops:
            output = MultiScaleDeformableAttnFunction.apply(
                value, spatial_shapes, level_start_index, sampling_locations,
                attention_weights, self.im2col_step)
        else:
            output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)
        output = self.output_proj(output)
        if not self.batch_first:
            # (num_query, bs, embed_dims)
            output = output.permute(1, 0, 2)
        return self.dropout(output) + identity
class TorchDeformableDetrTransformerEncoderLayer(
        DeformableDetrTransformerEncoderLayer):

    def _init_layers(self) -> None:
        """Initialize self_attn, ffn and norms, then swap in the
        Torch-fallback attention."""
        super()._init_layers()
        self.self_attn = TorchMultiScaleDeformableAttention(
            **self.self_attn_cfg)


class TorchMask2FormerTransformerEncoder(Mask2FormerTransformerEncoder):

    def _init_layers(self) -> None:
        """Initialize encoder layers with the Torch-fallback layer."""
        self.layers = ModuleList([
            TorchDeformableDetrTransformerEncoderLayer(**self.layer_cfg)
            for _ in range(self.num_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
@MMDET_MODELS.register_module()
class TorchMSDeformAttnPixelDecoder(MSDeformAttnPixelDecoder):

    def __init__(self, *args, **kwargs):
        encoder_cfg = kwargs.get('encoder', None)
        # Pop the custom flag before the parent class builds the stock
        # encoder, and expose it through the module-level global so that
        # TorchMultiScaleDeformableAttention can pick it up.
        global force_not_using_cuda_ops
        force_not_using_cuda_ops = kwargs['encoder']['layer_cfg'][
            'self_attn_cfg'].pop('force_not_using_cuda_ops', None)
        super().__init__(*args, **kwargs)
        # Replace the stock encoder with the Torch-fallback version and
        # re-run weight initialization for the new modules.
        self.encoder = TorchMask2FormerTransformerEncoder(**encoder_cfg)
        self.init_weights()
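For completeness, this is roughly how the workaround is wired into the training config. The nesting below assumes the mmsegmentation Mask2Former layout (pixel_decoder under decode_head) and the mmdet. scope prefix used by the base config, so treat these keys as an assumption rather than the exact override:

model = dict(
    decode_head=dict(
        pixel_decoder=dict(
            # scope prefix follows the base config; adjust if your registry
            # scope differs
            type='mmdet.TorchMSDeformAttnPixelDecoder',
            encoder=dict(
                layer_cfg=dict(
                    self_attn_cfg=dict(
                        # popped by TorchMSDeformAttnPixelDecoder.__init__
                        force_not_using_cuda_ops=True))))))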