How to change the lidar backbone to pointpillars
I'm running the fusion model, and I want to change the lidar backbone from VoxelNet to PointPillars. I modified the configs, but it didn't work.
I hope you can help.
Could you provide more details about why "it didn't work"?
I modified https://github.com/mit-han-lab/bevfusion/blob/main/configs/nuscenes/det/transfusion/secfpn/camera%2Blidar/default.yaml as follows:
```yaml
lidar:
  voxelize_reduce: false
  voxelize:
    max_num_points: 20
    point_cloud_range: ${point_cloud_range}
    voxel_size: ${voxel_size}
    max_voxels: ${max_voxels}
  backbone:
    type: PointPillarsEncoder
    pts_voxel_encoder:
      type: PillarFeatureNet
      in_channels: 5
      feat_channels: [64, 64]
      with_distance: false
      point_cloud_range: ${point_cloud_range}
      voxel_size: ${voxel_size}
      norm_cfg:
        type: BN1d
        eps: 1.0e-3
        momentum: 0.01
    pts_middle_encoder:
      type: PointPillarsScatter
      in_channels: 64
      output_shape: [512, 512]
```
I changed the lidar backbone and added a Conv2d and an interpolate in https://github.com/mit-han-lab/bevfusion/blob/main/mmdet3d/models/backbones/pillar_encoder.py#L244 to resize the tensor to fit the fusion model:
```python
@BACKBONES.register_module()
class PointPillarsEncoder(nn.Module):
    def __init__(
        self,
        pts_voxel_encoder: Dict[str, Any],
        pts_middle_encoder: Dict[str, Any],
        **kwargs,
    ):
        super().__init__()
        self.pts_voxel_encoder = build_backbone(pts_voxel_encoder)
        self.pts_middle_encoder = build_backbone(pts_middle_encoder)
        self.conv1 = nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1)  # added

    def forward(self, feats, coords, batch_size, sizes):
        # feats: [[30000, 20, 5]] coords: [30000, 4] sizes: [30000]
        batch_size = batch_size.item()
        x = self.pts_voxel_encoder(feats, sizes, coords)
        x = self.pts_middle_encoder(x, coords, batch_size)  # [1, 64, 512, 512]
        # import pdb; pdb.set_trace()
        x = self.conv1(x)
        x = F.interpolate(x, size=(180, 180), mode='bilinear', align_corners=False)
        return x
```
It runs, but I got 0 mAP. I know my approach is too rough, so I'm thinking about how to improve it.
Can you share some experience? Thanks very much!
Hello, did you solve this problem? Can you send me the relevant configuration? Thank you very much @wyf0414
@wyf0414 Hi, I had this problem too. Could you please tell me if you've solved it now? Thanks.
@kentang-mit Hi, I'm using the config file configs/nuscenes/det/transfusion/secfpn/lidar/pointpillars.yaml to get the PointPillars backbone model. There is a part of the code that confuses me:
```python
@BACKBONES.register_module()
class PointPillarsScatter(nn.Module):
    def __init__(self, in_channels=64, output_shape=(512, 512), **kwargs):
        """
        Point Pillar's Scatter.
        Converts learned features from dense tensor to sparse pseudo image. This replaces SECOND's
        second.pytorch.voxelnet.SparseMiddleExtractor.
        :param output_shape: ([int]: 4). Required output shape of features.
        :param num_input_features:
        """
        super().__init__()
        self.in_channels = in_channels
        self.output_shape = output_shape
        self.nx = output_shape[0]
        self.ny = output_shape[1]

    def extra_repr(self):
        return (
            f"in_channels={self.in_channels}, output_shape={tuple(self.output_shape)}"
        )

    def forward(self, voxel_features, coords, batch_size):
        # batch_canvas will be the final output.
        batch_canvas = []
        for batch_itt in range(batch_size):
            # Create the canvas for this sample
            canvas = torch.zeros(
                self.in_channels,
                self.nx * self.ny,
                dtype=voxel_features.dtype,
                device=voxel_features.device,
            )

            # Only include non-empty pillars
            batch_mask = coords[:, 0] == batch_itt
            this_coords = coords[batch_mask, :]
            # modified -> xyz coords
            indices = this_coords[:, 1] * self.ny + this_coords[:, 2]
            indices = indices.type(torch.long)
            voxels = voxel_features[batch_mask, :]
            voxels = voxels.t()

            # Now scatter the blob back to the canvas.
            canvas[:, indices] = voxels

            # Append to a list for later stacking.
            batch_canvas.append(canvas)

        # Stack to 3-dim tensor (batch-size, nchannels, nrows*ncols)
        batch_canvas = torch.stack(batch_canvas, 0)

        # Undo the column stacking to final 4-dim tensor
        batch_canvas = batch_canvas.view(batch_size, self.in_channels, self.nx, self.ny)
        return batch_canvas
```
As written in the code, the resulting batch_canvas is [batch_size, in_channels, W, H]. Shouldn't the correct dimension order be [batch, in_channels, H, W]?
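To see where a pillar actually lands, here is a tiny standalone sketch (not repo code, just tracing the `indices = x * ny + y` and final `.view()` from the quoted `forward`):

```python
import torch

# Minimal sketch (not repo code): trace one non-empty pillar through the
# scatter indexing and the final .view() used in PointPillarsScatter.
in_channels, nx, ny = 1, 4, 3          # tiny canvas: 4 cells along x, 3 along y
x, y = 2, 1                            # one pillar at grid cell (x=2, y=1)

canvas = torch.zeros(in_channels, nx * ny)
index = x * ny + y                     # same formula as this_coords[:, 1] * ny + this_coords[:, 2]
canvas[:, index] = 7.0                 # scatter a dummy feature value

canvas = canvas.view(in_channels, nx, ny)
print(canvas[0])                       # the 7 appears at position [x=2, y=1]
# -> the first spatial dimension indexes x, so the output is [B, C, nx, ny];
#    whether that reads as [B, C, H, W] depends on which axis you treat as rows.
```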
@zyh321 @LINSJGM Hey guys, sorry for the late reply. I've solved this problem. It took me, a beginner, quite some time. The changes are as follows.
For convenience, you can use the "configs.yaml" from any runs/run-*****.
Move it to the "configs/" directory and rename it "my_config.yaml".
Then you can train with `CUDA_VISIBLE_DEVICES=1 torchpack dist-run -np 1 python tools/train.py configs/my_config.yaml --model.encoders.camera.backbone.init_cfg.checkpoint pretrained/swint-nuimages-pretrained.pth --load_from pretrained/lidar-only-det.pth`
In "fuser/conv.py", the camera feature [B, 80, 180, 180] and lidar feature [B, 256, 180, 180] are concat and conv. But in "backbones/pillar_encoder.py", the lidar feature is [B, 64, 512, 512] finally. We need the lidar feature to be [B, 256, 180, 180].
The lidar feature is [B, 64, 512, 512] because the point_cloud_range is [-51.2, -51.2, -5, 51.2, 51.2, 3] and the voxel_size is [0.2, 0.2, 8], so the grid size and output_shape are both [512, 512]. We need to change these parameters to match the official configs.yaml. There are mainly two options (see the arithmetic sketch below):

Option A:
- voxel_size: [0.075, 0.075, 8]
- point_cloud_range: [-54.0, -54.0, -5.0, 54.0, 54.0, 3.0]
- grid_size: [1440, 1440, 1]  # 54x2/0.075=1440
- output_shape: [1440, 1440]
- out_size_factor: 8  # 1440/8=180

Option B:
- voxel_size: [0.15, 0.15, 8]
- point_cloud_range: [-54.0, -54.0, -5.0, 54.0, 54.0, 3.0]
- grid_size: [720, 720, 1]  # 54x2/0.15=720
- output_shape: [720, 720]
- out_size_factor: 4  # 720/4=180
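Just to sanity-check the arithmetic above, here is a small standalone sketch (not part of the repo) computing the BEV grid size and the downsampled BEV size for both options:

```python
# Minimal sketch (not repo code): verify that both options yield a 180x180 BEV map,
# matching the 180x180 camera/lidar features that the fuser expects.
point_cloud_range = [-54.0, -54.0, -5.0, 54.0, 54.0, 3.0]

for name, voxel_size, out_size_factor in [
    ("Option A", [0.075, 0.075, 8], 8),
    ("Option B", [0.15, 0.15, 8], 4),
]:
    # grid cells along x/y = (range_max - range_min) / voxel_size
    grid_x = round((point_cloud_range[3] - point_cloud_range[0]) / voxel_size[0])
    grid_y = round((point_cloud_range[4] - point_cloud_range[1]) / voxel_size[1])
    bev = grid_x // out_size_factor
    print(f"{name}: grid_size=[{grid_x}, {grid_y}], BEV after downsampling = {bev}x{bev}")

# Option A: grid_size=[1440, 1440], BEV after downsampling = 180x180
# Option B: grid_size=[720, 720],  BEV after downsampling = 180x180
```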
Then we are going to modify "my_config.yaml".
First, replace the lidar backbone with PointPillars at lines 939--994 (using Option B here):
```yaml
lidar:
  backbone:
    pts_middle_encoder:
      in_channels: 64
      output_shape:
        - 720
        - 720
      type: PointPillarsScatter
    pts_voxel_encoder:
      feat_channels:
        - 64
        - 64
      in_channels: 5
      norm_cfg:
        eps: 0.001
        momentum: 0.01
        type: BN1d
      point_cloud_range:
        - -54.0
        - -54.0
        - -5.0
        - 54.0
        - 54.0
        - 3.0
      type: PillarFeatureNet
      voxel_size:
        - 0.15
        - 0.15
        - 8
      with_distance: false
    type: PointPillarsEncoder
  voxelize:
    max_num_points: 20
    max_voxels:
      - 30000
      - 60000
    point_cloud_range:
      - -54.0
      - -54.0
      - -5.0
      - 54.0
      - 54.0
      - 3.0
    voxel_size:
      - 0.15
      - 0.15
      - 8
  voxelize_reduce: false
```
Second, throughout the config:
- change all "voxel_size" to "0.15 0.15 8"
- change all "point_cloud_range" to "-54.0 -54.0 -5.0 54.0 54.0 3.0"
- change all "output_shape" to "720 720"
- change all "grid_size" to "720 720 1"
- change all "out_size_factor" to "4"
Finally, you should add some convs in "backbones/pillar_encoder.py" to change the lidar feature from [B, 64, 720, 720] to [B, 256, 180, 180]; a rough sketch is below.
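For example, here is a minimal sketch of what such a downsampling head could look like (this is an illustration, not the author's exact code; the class name `PillarBEVDownsample` and the 64 -> 128 -> 256 channel schedule are assumptions). You would call it at the end of `PointPillarsEncoder.forward`, in place of the `conv1` + `interpolate` from earlier in the thread:

```python
import torch
from torch import nn

# Minimal sketch (assumed layers, not the author's exact code): two stride-2 conv
# blocks reduce the pseudo-image from [B, 64, 720, 720] to [B, 256, 180, 180],
# which matches the lidar feature shape the fuser expects.
class PillarBEVDownsample(nn.Module):  # hypothetical helper name
    def __init__(self, in_channels=64, out_channels=256):
        super().__init__()
        self.down = nn.Sequential(
            nn.Conv2d(in_channels, 128, kernel_size=3, stride=2, padding=1),  # 720 -> 360
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, out_channels, kernel_size=3, stride=2, padding=1),  # 360 -> 180
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.down(x)


if __name__ == "__main__":
    x = torch.randn(1, 64, 720, 720)
    print(PillarBEVDownsample()(x).shape)  # torch.Size([1, 256, 180, 180])
```

Learned strided convs like this are generally preferable to a plain bilinear interpolate, since the downsampling weights are trained together with the rest of the model.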
Then you can train BEVFusion with PointPillars as the lidar backbone. Good luck!