Question about create_compressed_model(): RuntimeError: CUDA error: device-side assert triggered
🐛 Describe the bug
Hi, I want to ask about quantizing the non-neural-network parts of a neural network. I want to quantize a PointNet-based network for a classification task.
The rest of the network is just some Conv1d and ReLU layers. I modified the train function with my own dataset and evaluation method, and then the error below occurred. Can you help me fix this problem? This is the error information:
Traceback (most recent call last):
File "/home/lalala/examples/torch/classification/quant_train_cls.py", line 874, in <module>
main(sys.argv[1:])
File "/home/lalala/examples/torch/classification/quant_train_cls.py", line 151, in main
start_worker(main_worker, config)
File "/home/lalala/examples/torch/common/execution.py", line 114, in start_worker
mp.spawn(main_worker, nprocs=config.ngpus_per_node, args=(config,))
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 241, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 197, in start_processes
while not context.join():
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 158, in join
raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException:
-- Process 1 terminated with the following error:
Traceback (most recent call last):
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 68, in _wrap
fn(i, *args)
File "/home/lalala/examples/torch/classification/quant_train_cls.py", line 267, in main_worker
compression_ctrl, model = create_compressed_model(
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/telemetry/decorator.py", line 72, in wrapped
retval = fn(*args, **kwargs)
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/model_creation.py", line 134, in create_compressed_model
compressed_model = builder.apply_to(nncf_network)
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/compression_method_api.py", line 124, in apply_to
transformed_model = transformer.transform(transformation_layout)
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/model_transformer.py", line 78, in transform
model.nncf.rebuild_graph()
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/nncf_network.py", line 555, in rebuild_graph
compressed_traced_graph = builder.build_dynamic_graph(
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/graph/graph_builder.py", line 53, in build_dynamic_graph
return tracer.trace_graph(model, context_to_use, as_eval, trace_parameters)
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/dynamic_graph/graph_tracer.py", line 53, in trace_graph
self.custom_forward_fn(model)
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/dynamic_graph/graph_tracer.py", line 96, in default_dummy_forward_fn
retval = model(*args, **kwargs)
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/nncf_network.py", line 1004, in __call__
return ORIGINAL_CALL(self, *args, **kwargs)
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/nncf_network.py", line 1036, in forward
retval = wrap_module_call(self.nncf._original_unbound_forward)(self, *args, **kwargs)
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 154, in wrapped
retval = module_call(self, *args, **kwargs)
File "/home/lalala/examples/torch/classification/models/pointnet2_cls.py", line 39, in forward
l1_xyz, l1_points = self.sa1(xyz, norm)
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 154, in wrapped
retval = module_call(self, *args, **kwargs)
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/home/lalala/examples/torch/classification/models/pointnet2_utils.py", line 197, in forward
new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
File "/home/lalala/examples/torch/classification/models/pointnet2_utils.py", line 131, in sample_and_group
grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
File "/home/lalala/examples/torch/classification/models/pointnet2_utils.py", line 59, in index_points
new_points = points[batch_indices, idx, :]
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 98, in wrapped
result = _execute_op(op_address, operator_info, operator, ctx, *args, **kwargs)
File "/home/lalala/anaconda3/envs/quantization/lib/python3.9/site-packages/nncf/torch/dynamic_graph/wrappers.py", line 179, in _execute_op
result = operator(*args, **kwargs)
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Environment
nncf==2.10.0 pytorch==2.2.2 python==3.9
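(Note: CUDA launches are asynchronous, so without the setting below the stack trace would point at a later, unrelated op. train.py below therefore sets this at the very top, before any CUDA call:)

# Set before the first CUDA call so kernel launches are synchronous and the
# Python stack trace points at the op that actually triggered the assert.
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'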
Minimal Reproducible Example
The non-neural-network parts of the network are as follows:
def square_distance(src, dst):
    """
    Calculate Euclid distance between each two points.
    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist
def index_points(points, idx):
    """
    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]
    Return:
        new_points: indexed points data, [B, S, C]
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points
def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    distance = torch.ones(B, N).to(device) * 1e10
    # Take the point farthest from the origin as the first sample (see the full
    # file below); without this line `farthest` would be used before assignment.
    farthest = torch.argmax(torch.square(xyz).sum(dim=2), dim=1)
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        farthest = torch.max(distance, -1)[1]
    return centroids
def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Input:
        npoint:
        radius:
        nsample:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, nsample, 3]
        new_points: sampled points data, [B, npoint, nsample, 3+D]
    """
    B, N, C = xyz.shape
    S = npoint
    fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint, C]
    new_xyz = index_points(xyz, fps_idx)  # [B, npoint, 3]
    idx = query_ball_point(radius, nsample, xyz, new_xyz)
    grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
    if points is not None:
        grouped_points = index_points(points, idx)
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # [B, npoint, nsample, C+D]
    else:
        new_points = grouped_xyz_norm
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    else:
        return new_xyz, new_points
def sample_and_group_all(xyz, points):
    """
    Input:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, 1, 3]
        new_points: sampled points data, [B, 1, N, 3+D]
    """
    device = xyz.device
    B, N, C = xyz.shape
    new_xyz = torch.zeros(B, 1, C).to(device)
    grouped_xyz = xyz.view(B, 1, N, C)
    if points is not None:
        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
    else:
        new_points = grouped_xyz
    return new_xyz, new_points
class PointNetSetAbstraction(nn.Module):
    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)
        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        # print(new_points.shape)
        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, nsample, npoint]
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        new_points = torch.max(new_points, 2)[0]  # max at dimension of nsample
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points
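From the stack trace, the assert fires inside index_points() at points[batch_indices, idx, :]. My working guess is that some index goes out of range under NNCF tracing; a small hypothetical sanity check (my own helper, not part of NNCF or the repo) would be:

# Hypothetical sanity check (my own helper): a device-side assert raised by
# advanced indexing usually means idx holds a value outside [0, N-1].
def check_index_range(points, idx):
    n = points.shape[1]  # points: [B, N, C]
    bad = (idx < 0) | (idx >= n)
    if bad.any():
        raise ValueError("idx out of range: min=%d, max=%d, allowed [0, %d]"
                         % (int(idx.min()), int(idx.max()), n - 1))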
Hi @zbnlala
Could you please provide a full script to reproduce the issue?
Hi @AlexanderDokuchaev. Sure!
First, my code is modified from examples/torch/classification, and I put the train.py below into that directory.
Note that in train.py you only need to focus on the function main_worker().
Moreover, the modified dataset code is ModelNetDataLoader.py, placed in examples/torch/classification/data_utils (a directory I created myself). For now the code has to run on the ModelNet dataset, since I do not know how to run create_compressed_model() without the dataloader (see the sketch after the run command below); that is why the dataset loading code is included here.
The model architecture is described in pointnet2_cls.py and pointnet2_utils.py, placed in examples/torch/classification/models.
Then, the JSON config is pointnet_v2_classification_int8.json, placed in examples/torch/classification/configs/quantization.
Finally, running the following script triggers the bug:
CUDA_VISIBLE_DEVICES="0,1" NCCL_P2P_DISABLE=1 python train.py -m test --config configs/quantization/pointnet_v2_classification_int8.json --log-dir=../../results/quantization/pointnet_v2_int8/ --multiprocessing-distributed
However, if I run the script with the extra --cpu-only flag, there is no bug. (CUDA version 12.3)
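As an aside, here is a minimal sketch of what I believe initialization-free compression would look like (my assumption, untested): register_default_init_args() is the only step that needs a dataloader, and setting the init sample counts to 0 should let create_compressed_model() trace the model from "input_info" alone.

# Minimal sketch (my assumption, untested): only initialization needs a dataloader.
# Setting the init sample counts to 0 disables range init and BN adaptation, so
# create_compressed_model() traces the model using just a synthetic input built
# from "input_info".
from nncf import NNCFConfig
from nncf.torch import create_compressed_model
from examples.torch.classification.models.pointnet2_cls import pointnetv2_cls

nncf_config = NNCFConfig.from_dict({
    "input_info": {"sample_size": [1, 1024, 3]},
    "compression": {
        "algorithm": "quantization",
        "initializer": {
            "range": {"num_init_samples": 0},
            "batchnorm_adaptation": {"num_bn_adaptation_samples": 0},
        },
    },
})
model = pointnetv2_cls(num_class=40)
compression_ctrl, compressed_model = create_compressed_model(model, nncf_config)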
train.py
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import os.path as osp
import sys
import time
import warnings
from copy import deepcopy
from functools import partial
from pathlib import Path
from shutil import copyfile
from typing import Any
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from torch import nn
from torch.backends import cudnn
from torch.cuda.amp.autocast_mode import autocast
from torch.nn.modules.loss import _Loss
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import datasets
from torchvision import models
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
from torchvision.models import InceptionOutputs
from examples.common.paths import configure_paths
from examples.common.sample_config import SampleConfig
from examples.common.sample_config import create_sample_config
from examples.torch.common.argparser import get_common_argument_parser
from examples.torch.common.argparser import parse_args
from examples.torch.common.example_logger import logger
from examples.torch.common.execution import ExecutionMode
from examples.torch.common.execution import get_execution_mode
from examples.torch.common.execution import prepare_model_for_execution
from examples.torch.common.execution import set_seed
from examples.torch.common.execution import start_worker
from examples.torch.common.export import export_model
from examples.torch.common.model_loader import COMPRESSION_STATE_ATTR
from examples.torch.common.model_loader import MODEL_STATE_ATTR
from examples.torch.common.model_loader import extract_model_and_compression_states
from examples.torch.common.model_loader import load_model
from examples.torch.common.model_loader import load_resuming_checkpoint
from examples.torch.common.optimizer import get_parameter_groups
from examples.torch.common.optimizer import make_optimizer
from examples.torch.common.utils import MockDataset
from examples.torch.common.utils import NullContextManager
from examples.torch.common.utils import configure_device
from examples.torch.common.utils import configure_logging
from examples.torch.common.utils import create_code_snapshot
from examples.torch.common.utils import get_run_name
from examples.torch.common.utils import is_pretrained_model_requested
from examples.torch.common.utils import is_staged_quantization
from examples.torch.common.utils import make_additional_checkpoints
from examples.torch.common.utils import print_args
from examples.torch.common.utils import write_metrics
from nncf.api.compression import CompressionStage
from nncf.common.accuracy_aware_training import create_accuracy_aware_training_loop
from nncf.common.utils.tensorboard import prepare_for_tensorboard
from nncf.config.utils import is_accuracy_aware_training
from nncf.torch import create_compressed_model
from nncf.torch.checkpoint_loading import load_state
from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
from nncf.torch.initialization import default_criterion_fn
from nncf.torch.initialization import register_default_init_args
from nncf.torch.structures import ExecutionParameters
from nncf.torch.utils import is_main_process
from nncf.torch.utils import safe_thread_call
from examples.torch.classification.models.pointnet2_cls import pointnetv2_cls
from examples.torch.common import restricted_pickle_module
from data_utils.ModelNetDataLoader import ModelNetDataLoader
model_names = sorted(
name
for name, val in models.__dict__.items()
if name.islower() and not name.startswith("__") and callable(val)
)
def get_argument_parser():
parser = get_common_argument_parser()
parser.add_argument(
"--dataset",
help="Dataset to use.",
choices=["imagenet", "cifar100", "cifar10"],
default=None,
)
parser.add_argument(
"--local_rank",
default=None,
)
parser.add_argument(
"--test-every-n-epochs",
default=1,
type=int,
help="Enables running validation every given number of epochs",
)
parser.add_argument(
"--mixed-precision",
dest="mixed_precision",
help="Enables torch.cuda.amp autocasting during training and validation steps",
action="store_true",
)
return parser
def main(argv):
parser = get_argument_parser()
args = parse_args(parser, argv)
config = create_sample_config(args, parser)
if config.dist_url == "env://":
config.update_from_env()
configure_paths(config, get_run_name(config))
copyfile(args.config, osp.join(config.log_dir, "config.json"))
source_root = Path(__file__).absolute().parents[2] # nncf root
create_code_snapshot(source_root, osp.join(config.log_dir, "snapshot.tar.gz"))
if config.seed is not None:
warnings.warn(
"You have chosen to seed training. "
"This will turn on the CUDNN deterministic setting, "
"which can slow down your training considerably! "
"You may see unexpected behavior when restarting "
"from checkpoints."
)
config.execution_mode = get_execution_mode(config)
if config.metrics_dump is not None:
write_metrics(0, config.metrics_dump)
if not is_staged_quantization(config):
start_worker(main_worker, config)
else:
from examples.torch.classification.staged_quantization_worker import (
staged_quantization_main_worker,
)
start_worker(staged_quantization_main_worker, config)
def inception_criterion_fn(
model_outputs: Any, target: Any, criterion: _Loss
) -> torch.Tensor:
# From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
output, aux_outputs = model_outputs
loss1 = criterion(output, target)
loss2 = criterion(aux_outputs, target)
return loss1 + 0.4 * loss2
def main_worker(current_gpu, config: SampleConfig):
configure_device(current_gpu, config)
if is_main_process():
configure_logging(logger, config)
print_args(config)
else:
config.tb = None
set_seed(config)
# define loss function (criterion)
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(config.device)
model_name = config["model"]
train_criterion_fn = (
inception_criterion_fn if "inception" in model_name else default_criterion_fn
)
train_loader = train_sampler = val_loader = None
resuming_checkpoint_path = config.resuming_checkpoint_path
nncf_config = config.nncf_config
pretrained = is_pretrained_model_requested(config)
is_export_only = "export" in config.mode and (
"train" not in config.mode and "test" not in config.mode
)
if is_export_only:
assert pretrained or (resuming_checkpoint_path is not None)
else:
# Data loading code
# train_dataset, val_dataset = create_datasets(config)
train_dataset, val_dataset = pointnet_dataset(config)
train_loader, train_sampler, val_loader, init_loader = create_data_loaders(
config, train_dataset, val_dataset
)
def train_steps_fn(loader, model, optimizer, compression_ctrl, train_steps):
train_epoch(
loader,
model,
criterion,
train_criterion_fn,
optimizer,
compression_ctrl,
0,
config,
train_iters=train_steps,
log_training_info=False,
)
def validate_model_fn(model, eval_loader):
instance_acc,class_acc = validate(
eval_loader, model, criterion, config
)
return instance_acc,class_acc
def model_eval_fn(model):
acc,_ = validate(val_loader, model, criterion, config)
return acc
execution_params = ExecutionParameters(config.cpu_only, config.current_gpu)
nncf_config = register_default_init_args(
nncf_config,
init_loader,
criterion=criterion,
criterion_fn=train_criterion_fn,
train_steps_fn=train_steps_fn,
validate_fn=lambda *x: validate_model_fn(*x)[::2],
autoq_eval_fn=lambda *x: validate_model_fn(*x)[1],
val_loader=val_loader,
model_eval_fn=model_eval_fn,
device=config.device,
execution_parameters=execution_params,
)
# create model
num_classes=config.get("num_classes", 1000)
model_params=config.get("model_params")
weights_path=config.get("weights")
load_model_fn = partial(pointnetv2_cls, num_class=num_classes,pretrained=pretrained,load_path=weights_path)
model = safe_thread_call(load_model_fn)
if not pretrained and weights_path is not None:
# Check if provided path is a url and download the checkpoint if yes
sd = torch.load(weights_path, map_location="cpu", pickle_module=restricted_pickle_module)
sd=sd["model_state_dict"]
if MODEL_STATE_ATTR in sd:
sd = sd[MODEL_STATE_ATTR]
load_state(model, sd, is_resume=False)
model.to(config.device)
if "train" in config.mode and is_accuracy_aware_training(config):
uncompressed_model_accuracy = model_eval_fn(model)
resuming_checkpoint = None
if resuming_checkpoint_path is not None:
resuming_checkpoint = load_resuming_checkpoint(resuming_checkpoint_path)
model_state_dict, compression_state = extract_model_and_compression_states(
resuming_checkpoint
)
compression_ctrl, model = create_compressed_model(
model, nncf_config, compression_state
)
def train(
config,
compression_ctrl,
model,
criterion,
criterion_fn,
lr_scheduler,
model_name,
optimizer,
train_loader,
train_sampler,
val_loader,
best_acc1=0,
):
best_compression_stage = CompressionStage.UNCOMPRESSED
for epoch in range(config.start_epoch, config.epochs):
# update compression scheduler state at the begin of the epoch
compression_ctrl.scheduler.epoch_step()
if config.distributed:
train_sampler.set_epoch(epoch)
# train for one epoch
train_epoch(
train_loader,
model,
criterion,
criterion_fn,
optimizer,
compression_ctrl,
epoch,
config,
)
# Learning rate scheduling should be applied after optimizer’s update
lr_scheduler.step(
epoch if not isinstance(lr_scheduler, ReduceLROnPlateau) else best_acc1
)
# compute compression algo statistics
statistics = compression_ctrl.statistics()
acc1 = best_acc1
best_instance_acc = 0.0
best_class_acc = 0.0
if epoch % config.test_every_n_epochs == 0:
# evaluate on validation set
instance_acc, class_acc = validate(val_loader, model, criterion, config, epoch=epoch)
if (instance_acc >= best_instance_acc):
best_instance_acc = instance_acc
best_epoch = epoch + 1
if (class_acc >= best_class_acc):
best_class_acc = class_acc
logger.info('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
logger.info('Best Instance Accuracy: %f, Class Accuracy: %f' % (best_instance_acc, best_class_acc))
compression_stage = compression_ctrl.compression_stage()
        # remember best acc@1, considering compression stage. If the current acc@1 is less than the best acc@1,
        # the checkpoint can still be best if the current compression stage is larger than the best one.
        # Compression stages in ascending order: UNCOMPRESSED, PARTIALLY_COMPRESSED, FULLY_COMPRESSED.
is_best_by_accuracy = (
acc1 > best_acc1 and compression_stage == best_compression_stage
)
is_best = is_best_by_accuracy or compression_stage > best_compression_stage
if is_best:
best_acc1 = acc1
best_compression_stage = max(compression_stage, best_compression_stage)
if is_main_process():
logger.info(statistics.to_str())
if config.metrics_dump is not None:
acc = best_acc1 / 100
write_metrics(acc, config.metrics_dump)
checkpoint_path = osp.join(
config.checkpoint_save_dir, get_run_name(config) + "_last.pth"
)
checkpoint = {
"epoch": epoch + 1,
"arch": model_name,
MODEL_STATE_ATTR: model.state_dict(),
COMPRESSION_STATE_ATTR: compression_ctrl.get_compression_state(),
"best_acc1": best_acc1,
"acc1": acc1,
"optimizer": optimizer.state_dict(),
}
torch.save(checkpoint, checkpoint_path)
make_additional_checkpoints(checkpoint_path, is_best, epoch + 1, config)
for key, value in prepare_for_tensorboard(statistics).items():
config.tb.add_scalar(
"compression/statistics/{0}".format(key),
value,
len(train_loader) * epoch,
)
def get_dataset(dataset_config, config, transform, is_train):
if dataset_config == "imagenet":
prefix = "train" if is_train else "val"
return datasets.ImageFolder(osp.join(config.dataset_dir, prefix), transform)
# For testing purposes
num_images = config.get("num_mock_images", 1000)
if dataset_config == "mock_32x32":
return MockDataset(
img_size=(32, 32), transform=transform, num_images=num_images
)
if dataset_config == "mock_299x299":
return MockDataset(
img_size=(299, 299), transform=transform, num_images=num_images
)
return create_cifar(config, dataset_config, is_train, transform)
def create_cifar(config, dataset_config, is_train, transform):
create_cifar_fn = None
if dataset_config in ["cifar100", "cifar100_224x224"]:
create_cifar_fn = partial(
CIFAR100, config.dataset_dir, train=is_train, transform=transform
)
if dataset_config == "cifar10":
create_cifar_fn = partial(
CIFAR10, config.dataset_dir, train=is_train, transform=transform
)
if create_cifar_fn:
return safe_thread_call(
partial(create_cifar_fn, download=True),
partial(create_cifar_fn, download=False),
)
return None
def create_datasets(config):
dataset_config = config.dataset if config.dataset is not None else "imagenet"
dataset_config = dataset_config.lower()
assert dataset_config in [
"imagenet",
"cifar100",
"cifar10",
"cifar100_224x224",
"mock_32x32",
"mock_299x299",
], "Unknown dataset option"
if dataset_config == "imagenet":
normalize = transforms.Normalize(
mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
)
elif dataset_config in ["cifar100", "cifar100_224x224"]:
normalize = transforms.Normalize(
mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2761)
)
elif dataset_config == "cifar10":
normalize = transforms.Normalize(
mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)
)
elif dataset_config in ["mock_32x32", "mock_299x299"]:
normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
input_info = FillerInputInfo.from_nncf_config(config)
image_size = input_info.elements[0].shape[-1]
size = int(image_size / 0.875)
if dataset_config in ["cifar10", "cifar100_224x224", "cifar100"]:
list_val_transforms = [transforms.ToTensor(), normalize]
if dataset_config == "cifar100_224x224":
list_val_transforms.insert(0, transforms.Resize(image_size))
val_transform = transforms.Compose(list_val_transforms)
list_train_transforms = [
transforms.RandomCrop(image_size, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]
if dataset_config == "cifar100_224x224":
list_train_transforms.insert(0, transforms.Resize(image_size))
train_transforms = transforms.Compose(list_train_transforms)
elif dataset_config in ["mock_32x32", "mock_299x299"]:
val_transform = transforms.Compose(
[
transforms.Resize(size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
normalize,
]
)
train_transforms = transforms.Compose(
[
transforms.Resize(size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
normalize,
]
)
else:
val_transform = transforms.Compose(
[
transforms.Resize(size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
normalize,
]
)
train_transforms = transforms.Compose(
[
transforms.RandomResizedCrop(image_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]
)
val_dataset = get_dataset(dataset_config, config, val_transform, is_train=False)
train_dataset = get_dataset(dataset_config, config, train_transforms, is_train=True)
return train_dataset, val_dataset
def pointnet_dataset(config):
data_path=config.get("dataset_dir")
num_point = config.get("num_point")
use_uniform_sample = config.get("use_uniform_sample")
use_normals = config.get("use_normals")
num_category = config.get("num_classes")
process_data=config.get("process_data")
train_dataset = ModelNetDataLoader(root=data_path, args=None, split='train',
process_data=process_data,num_point=num_point,
use_uniform_sample=use_uniform_sample,
use_normals=use_normals,num_category=num_category)
val_dataset = ModelNetDataLoader(root=data_path, args=None, split='test',
process_data=process_data,num_point=num_point,
use_uniform_sample=use_uniform_sample,
use_normals=use_normals,num_category=num_category)
return train_dataset, val_dataset
def create_data_loaders(config, train_dataset, val_dataset):
pin_memory = config.execution_mode != ExecutionMode.CPU_ONLY
# When using a single GPU per process and per
# DistributedDataParallel, we need to divide the batch size
# ourselves based on the total number of GPUs we have
batch_size = int(config.batch_size)
workers = int(config.workers)
batch_size_val = (
int(config.batch_size_val)
if config.batch_size_val is not None
else int(config.batch_size)
)
if config.execution_mode == ExecutionMode.MULTIPROCESSING_DISTRIBUTED:
batch_size //= config.ngpus_per_node
batch_size_val //= config.ngpus_per_node
workers //= config.ngpus_per_node
val_sampler = torch.utils.data.SequentialSampler(val_dataset)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=batch_size_val,
shuffle=False,
num_workers=workers,
pin_memory=pin_memory,
sampler=val_sampler,
drop_last=False,
)
train_sampler = None
if config.distributed:
sampler_seed = 0 if config.seed is None else config.seed
dist_sampler_shuffle = config.seed is None
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset, seed=sampler_seed, shuffle=dist_sampler_shuffle
)
train_shuffle = train_sampler is None and config.seed is None
def create_train_data_loader(batch_size_):
return torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size_,
shuffle=train_shuffle,
num_workers=workers,
pin_memory=pin_memory,
sampler=train_sampler,
drop_last=True,
)
train_loader = create_train_data_loader(batch_size)
if config.batch_size_init:
init_loader = create_train_data_loader(config.batch_size_init)
else:
init_loader = deepcopy(train_loader)
return train_loader, train_sampler, val_loader, init_loader
def train_epoch(
train_loader,
model,
criterion,
criterion_fn,
optimizer,
compression_ctrl,
epoch,
config,
train_iters=None,
log_training_info=True,
):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
compression_losses = AverageMeter()
criterion_losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
if train_iters is None:
train_iters = len(train_loader)
compression_scheduler = compression_ctrl.scheduler
casting = autocast if config.mixed_precision else NullContextManager
# switch to train mode
model.train()
end = time.time()
for i, (input_, target) in enumerate(train_loader):
# measure data loading time
data_time.update(time.time() - end)
compression_scheduler.step()
input_ = input_.to(config.device)
target = target.to(config.device)
# compute output
with casting():
output = model(input_)
criterion_loss = criterion_fn(output, target, criterion)
# compute compression loss
compression_loss = compression_ctrl.loss()
loss = criterion_loss + compression_loss
if isinstance(output, InceptionOutputs):
output = output.logits
# measure accuracy and record loss
acc1, acc5 = accuracy(output, target, topk=(1, 5))
losses.update(loss.item(), input_.size(0))
comp_loss_val = (
compression_loss.item()
if isinstance(compression_loss, torch.Tensor)
else compression_loss
)
compression_losses.update(comp_loss_val, input_.size(0))
criterion_losses.update(criterion_loss.item(), input_.size(0))
top1.update(acc1, input_.size(0))
top5.update(acc5, input_.size(0))
# compute gradient and do SGD step
optimizer.zero_grad()
loss.backward()
optimizer.step()
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % config.print_freq == 0 and log_training_info:
logger.info(
"{rank}: "
"Epoch: [{0}][{1}/{2}] "
"Lr: {3:.3} "
"Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) "
"Data: {data_time.val:.3f} ({data_time.avg:.3f}) "
"CE_loss: {ce_loss.val:.4f} ({ce_loss.avg:.4f}) "
"CR_loss: {cr_loss.val:.4f} ({cr_loss.avg:.4f}) "
"Loss: {loss.val:.4f} ({loss.avg:.4f}) "
"Acc@1: {top1.val:.3f} ({top1.avg:.3f}) "
"Acc@5: {top5.val:.3f} ({top5.avg:.3f})".format(
epoch,
i,
len(train_loader),
get_lr(optimizer),
batch_time=batch_time,
data_time=data_time,
ce_loss=criterion_losses,
cr_loss=compression_losses,
loss=losses,
top1=top1,
top5=top5,
rank=(
"{}:".format(config.rank)
if config.multiprocessing_distributed
else ""
),
)
)
if is_main_process() and log_training_info:
global_step = train_iters * epoch
config.tb.add_scalar(
"train/learning_rate", get_lr(optimizer), i + global_step
)
config.tb.add_scalar(
"train/criterion_loss", criterion_losses.val, i + global_step
)
config.tb.add_scalar(
"train/compression_loss", compression_losses.val, i + global_step
)
config.tb.add_scalar("train/loss", losses.val, i + global_step)
config.tb.add_scalar("train/top1", top1.val, i + global_step)
config.tb.add_scalar("train/top5", top5.val, i + global_step)
statistics = compression_ctrl.statistics(quickly_collected_only=True)
for stat_name, stat_value in prepare_for_tensorboard(statistics).items():
config.tb.add_scalar(
"train/statistics/{}".format(stat_name), stat_value, i + global_step
)
if i >= train_iters:
break
def validate(val_loader, model, criterion, config, epoch=0, log_validation_info=True):
mean_correct = []
num_classes=config.get("num_classes")
class_acc = np.zeros((num_classes, 3))
# switch to evaluate mode
model.eval()
casting = autocast if config.mixed_precision else NullContextManager
with torch.no_grad():
for j, (points, target) in tqdm(enumerate(val_loader),total=len(val_loader)):
points, target = points.to(config.device), target.to(config.device)
with casting():
pred, _ = model(points)
# print(pred)
pred_choice = pred.data.max(1)[1]
for cat in np.unique(target.cpu()):
classacc = pred_choice[target == cat].eq(target[target == cat].long().data).cpu().sum()
class_acc[cat, 0] += classacc.item() / float(points[target == cat].size()[0])
class_acc[cat, 1] += 1
correct = pred_choice.eq(target.long().data).cpu().sum()
mean_correct.append(correct.item() / float(points.size()[0]))
class_acc[:, 2] = class_acc[:, 0] / class_acc[:, 1]
class_acc = np.mean(class_acc[:, 2])
instance_acc = np.mean(mean_correct)
return instance_acc,class_acc
class AverageMeter:
"""Computes and stores the average and current value"""
def __init__(self):
self.val = None
self.avg = None
self.sum = None
self.count = None
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).sum(0, keepdim=True)
res.append(correct_k.float().mul_(100.0 / batch_size).item())
return res
def get_lr(optimizer):
return optimizer.param_groups[0]["lr"]
if __name__ == "__main__":
main(sys.argv[1:])
ModelNetDataLoader.py
import os
import numpy as np
import warnings
import pickle
from tqdm import tqdm
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')
def pc_normalize(pc):
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m
return pc
def farthest_point_sample(point, npoint):
"""
Input:
xyz: pointcloud data, [N, D]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [npoint, D]
"""
N, D = point.shape
xyz = point[:,:3]
centroids = np.zeros((npoint,))
distance = np.ones((N,)) * 1e10
farthest = np.random.randint(0, N)
for i in range(npoint):
centroids[i] = farthest
centroid = xyz[farthest, :]
dist = np.sum((xyz - centroid) ** 2, -1)
mask = dist < distance
distance[mask] = dist[mask]
farthest = np.argmax(distance, -1)
point = point[centroids.astype(np.int32)]
return point
class ModelNetDataLoader(Dataset):
def __init__(self, root, args, split='train', process_data=False,num_point=1024,use_uniform_sample=None,use_normals=False,num_category=40):
self.root = root
if args is not None:
self.npoints = args.num_point
self.uniform = args.use_uniform_sample
self.use_normals = args.use_normals
self.num_category = args.num_category
else:
self.npoints = num_point
self.uniform = use_uniform_sample
self.use_normals = use_normals
self.num_category = num_category
self.process_data = process_data
if self.num_category == 10:
self.catfile = os.path.join(self.root, 'modelnet10_shape_names.txt')
else:
self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')
self.cat = [line.rstrip() for line in open(self.catfile)]
self.classes = dict(zip(self.cat, range(len(self.cat))))
shape_ids = {}
if self.num_category == 10:
shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_train.txt'))]
shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet10_test.txt'))]
else:
shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]
assert (split == 'train' or split == 'test')
shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
in range(len(shape_ids[split]))]
print('The size of %s data is %d' % (split, len(self.datapath)))
if self.uniform:
self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, split, self.npoints))
else:
self.save_path = os.path.join(root, 'modelnet%d_%s_%dpts.dat' % (self.num_category, split, self.npoints))
if self.process_data:
if not os.path.exists(self.save_path):
print('Processing data %s (this only runs the first time)...' % self.save_path)
self.list_of_points = [None] * len(self.datapath)
self.list_of_labels = [None] * len(self.datapath)
for index in tqdm(range(len(self.datapath)), total=len(self.datapath)):
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
cls = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
if self.uniform:
point_set = farthest_point_sample(point_set, self.npoints)
else:
point_set = point_set[0:self.npoints, :]
self.list_of_points[index] = point_set
self.list_of_labels[index] = cls
with open(self.save_path, 'wb') as f:
pickle.dump([self.list_of_points, self.list_of_labels], f)
else:
print('Load processed data from %s...' % self.save_path)
with open(self.save_path, 'rb') as f:
self.list_of_points, self.list_of_labels = pickle.load(f)
def __len__(self):
return len(self.datapath)
def _get_item(self, index):
if self.process_data:
point_set, label = self.list_of_points[index], self.list_of_labels[index]
else:
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
label = np.array([cls]).astype(np.int32)
point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
if self.uniform:
point_set = farthest_point_sample(point_set, self.npoints)
else:
point_set = point_set[0:self.npoints, :]
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
if not self.use_normals:
point_set = point_set[:, 0:3]
return point_set, label[0]
def __getitem__(self, index):
return self._get_item(index)
if __name__ == '__main__':
import torch
data = ModelNetDataLoader('/data/modelnet40_normal_resampled/', args=None, split='train')
DataLoader = torch.utils.data.DataLoader(data, batch_size=12, shuffle=True)
for point, label in DataLoader:
print(point.shape)
print(label.shape)
pointnet2_utils.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np
def timeit(tag, t):
print("{}: {}s".format(tag, time() - t))
return time()
def pc_normalize(pc):
l = pc.shape[0]
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m
return pc
def square_distance(src, dst):
"""
Calculate Euclid distance between each two points.
src^T * dst = xn * xm + yn * ym + zn * zm;
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
Input:
src: source points, [B, N, C]
dst: target points, [B, M, C]
Output:
dist: per-point square distance, [B, N, M]
"""
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src ** 2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
return dist
def index_points(points, idx):
"""
Input:
points: input points data, [B, N, C]
idx: sample index data, [B, S]
Return:
new_points: indexed points data, [B, S, C]
"""
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
new_points = points[batch_indices, idx, :]
return new_points
#1
def farthest_point_sample(xyz, npoint):
"""
Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint]
"""
device = xyz.device
B, N, C = xyz.shape
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
distance = torch.ones(B, N).to(device) * 1e10
# Modified: take the point farthest from (0, 0, 0) as the first sample rather than sampling randomly
# farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
farthest = torch.argmax(torch.square(xyz).sum(dim=2), dim=1)
batch_indices = torch.arange(B, dtype=torch.long).to(device)
for i in range(npoint):
centroids[:, i] = farthest
cent = xyz[batch_indices, farthest, :]
if cent.shape[-1]!=3:
print("cent.shape",cent.shape)
print("xyz.shape",xyz.shape)
cent = cent.view(B, 1, 3)
dist = torch.sum((xyz - cent) ** 2, -1)
mask = dist < distance
distance[mask] = dist[mask]
farthest = torch.max(distance, -1)[1]
return centroids
#2
def query_ball_point(radius, nsample, xyz, new_xyz):
"""
Input:
radius: local region radius
nsample: max sample number in local region
xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3]
Return:
group_idx: grouped points index, [B, S, nsample]
"""
device = xyz.device
B, N, C = xyz.shape
_, S, _ = new_xyz.shape
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
sqrdists = square_distance(new_xyz, xyz)
group_idx[sqrdists > radius ** 2] = N
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
mask = group_idx == N
group_idx[mask] = group_first[mask]
return group_idx
#3
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
"""
Input:
npoint:
radius:
nsample:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, npoint, nsample, 3]
new_points: sampled points data, [B, npoint, nsample, 3+D]
"""
B, N, C = xyz.shape
S = npoint
fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C]
new_xyz = index_points(xyz, fps_idx) # [B, npoint, 3]
idx = query_ball_point(radius, nsample, xyz, new_xyz)
grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
if points is not None:
grouped_points = index_points(points, idx)
new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
else:
new_points = grouped_xyz_norm
if returnfps:
return new_xyz, new_points, grouped_xyz, fps_idx
else:
return new_xyz, new_points
# simply concatenate the xyz coordinates and point features
def sample_and_group_all(xyz, points):
"""
Input:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, 1, 3]
new_points: sampled points data, [B, 1, N, 3+D]
"""
device = xyz.device
B, N, C = xyz.shape
new_xyz = torch.zeros(B, 1, C).to(device)
grouped_xyz = xyz.view(B, 1, N, C)
if points is not None:
new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
else:
new_points = grouped_xyz
return new_xyz, new_points
#4
class PointNetSetAbstraction(nn.Module):
def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
super(PointNetSetAbstraction, self).__init__()
self.npoint = npoint
self.radius = radius
self.nsample = nsample
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = in_channel
for out_channel in mlp:
self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
self.mlp_bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
self.group_all = group_all
def forward(self, xyz, points):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points_concat: sample points feature data, [B, D', S]
"""
xyz = xyz.permute(0, 2, 1)
if points is not None:
points = points.permute(0, 2, 1)
if self.group_all:
new_xyz, new_points = sample_and_group_all(xyz, points)
else:
new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
# new_xyz: sampled points position data, [B, npoint, C]
# new_points: sampled points data, [B, npoint, nsample, C+D]
# print(new_points.shape)
new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
new_points = torch.max(new_points, 2)[0] # max at dimension of nsample
new_xyz = new_xyz.permute(0, 2, 1)
return new_xyz, new_points
class PointNetSetAbstractionMsg(nn.Module):
def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
super(PointNetSetAbstractionMsg, self).__init__()
self.npoint = npoint
self.radius_list = radius_list
self.nsample_list = nsample_list
self.conv_blocks = nn.ModuleList()
self.bn_blocks = nn.ModuleList()
for i in range(len(mlp_list)):
convs = nn.ModuleList()
bns = nn.ModuleList()
last_channel = in_channel + 3
for out_channel in mlp_list[i]:
convs.append(nn.Conv2d(last_channel, out_channel, 1))
bns.append(nn.BatchNorm2d(out_channel))
last_channel = out_channel
self.conv_blocks.append(convs)
self.bn_blocks.append(bns)
def forward(self, xyz, points):
"""
Input:
xyz: input points position data, [B, C, N]
points: input points data, [B, D, N]
Return:
new_xyz: sampled points position data, [B, C, S]
new_points_concat: sample points feature data, [B, D', S]
"""
xyz = xyz.permute(0, 2, 1)
if points is not None:
points = points.permute(0, 2, 1)
B, N, C = xyz.shape
S = self.npoint
new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
new_points_list = []
for i, radius in enumerate(self.radius_list):
K = self.nsample_list[i]
group_idx = query_ball_point(radius, K, xyz, new_xyz)
grouped_xyz = index_points(xyz, group_idx)
grouped_xyz -= new_xyz.view(B, S, 1, C)
if points is not None:
grouped_points = index_points(points, group_idx)
grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
else:
grouped_points = grouped_xyz
grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S]
for j in range(len(self.conv_blocks[i])):
conv = self.conv_blocks[i][j]
bn = self.bn_blocks[i][j]
grouped_points = F.relu(bn(conv(grouped_points)))
new_points = torch.max(grouped_points, 2)[0] # [B, D', S]
new_points_list.append(new_points)
new_xyz = new_xyz.permute(0, 2, 1)
new_points_concat = torch.cat(new_points_list, dim=1)
return new_xyz, new_points_concat
class PointNetFeaturePropagation(nn.Module):
def __init__(self, in_channel, mlp):
super(PointNetFeaturePropagation, self).__init__()
self.mlp_convs = nn.ModuleList()
self.mlp_bns = nn.ModuleList()
last_channel = in_channel
for out_channel in mlp:
self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
self.mlp_bns.append(nn.BatchNorm1d(out_channel))
last_channel = out_channel
def forward(self, xyz1, xyz2, points1, points2):
"""
Input:
xyz1: input points position data, [B, C, N]
xyz2: sampled input points position data, [B, C, S]
points1: input points data, [B, D, N]
points2: input points data, [B, D, S]
Return:
new_points: upsampled points data, [B, D', N]
"""
xyz1 = xyz1.permute(0, 2, 1)
xyz2 = xyz2.permute(0, 2, 1)
points2 = points2.permute(0, 2, 1)
B, N, C = xyz1.shape
_, S, _ = xyz2.shape
if S == 1:
interpolated_points = points2.repeat(1, N, 1)
else:
dists = square_distance(xyz1, xyz2)
dists, idx = dists.sort(dim=-1)
dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
dist_recip = 1.0 / (dists + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm
interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
if points1 is not None:
points1 = points1.permute(0, 2, 1)
new_points = torch.cat([points1, interpolated_points], dim=-1)
else:
new_points = interpolated_points
new_points = new_points.permute(0, 2, 1)
for i, conv in enumerate(self.mlp_convs):
bn = self.mlp_bns[i]
new_points = F.relu(bn(conv(new_points)))
return new_points
pointnet2_cls.py
import torch.nn as nn
import torch
import torch.nn.functional as F
from .pointnet2_utils import PointNetSetAbstraction
import logging
class get_model(nn.Module):
def __init__(self,num_class,normal_channel=True,pretrained=False,load_path=None):
super(get_model, self).__init__()
in_channel = 6 if normal_channel else 3
self.normal_channel = normal_channel
self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=64, in_channel=in_channel, mlp=[64, 64, 128], group_all=False)
self.sa2 = PointNetSetAbstraction(npoint=256, radius=0.4, nsample=64, in_channel=128+3, mlp=[128, 128, 256], group_all=False)
self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256+3, mlp=[256, 512, 1024], group_all=True)
self.fc1 = nn.Linear(1024, 512)
self.bn1 = nn.BatchNorm1d(512)
self.drop1 = nn.Dropout(0.4)
self.fc2 = nn.Linear(512, 256)
self.bn2 = nn.BatchNorm1d(256)
self.drop2 = nn.Dropout(0.4)
self.fc3 = nn.Linear(256, num_class)
if pretrained:
if load_path is not None:
self.load_state_dict(torch.load(load_path)['model_state_dict'])
logging.info("=> done loading ")
else:
logging.info("=> no provided checkpoint path ")
def forward(self, xyz):
B, _, _ = xyz.shape
xyz = xyz.transpose(2, 1)
print(xyz.shape)
if self.normal_channel:
norm = xyz[:, 3:, :]
xyz = xyz[:, :3, :]
else:
norm = None
l1_xyz, l1_points = self.sa1(xyz, norm)
# print(l1_xyz.shape, l1_points.shape)
l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
# print(l2_xyz.shape, l2_points.shape)
l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
# print(l3_xyz.shape, l3_points.shape)
x = l3_points.view(B, 1024)
# print(x)
x = self.drop1(F.relu(self.bn1(self.fc1(x))))
x = self.drop2(F.relu(self.bn2(self.fc2(x))))
x = self.fc3(x)
# x = F.log_softmax(x, -1)
return x, l3_points
def pointnetv2_cls(num_class,normal_channel=False,pretrained=False,load_path=None):
return get_model(num_class,normal_channel=normal_channel,pretrained=pretrained,load_path=load_path)
class get_cls_loss(nn.Module):
def __init__(self):
super(get_cls_loss, self).__init__()
def forward(self, pred, target, trans_feat):
total_loss = F.nll_loss(pred, target)
return total_loss
pointnet_v2_classification_int8.json
{
"model": "pointnetv2",
"pretrained": false,
"dataset_dir":"/data1/modelnet/modelnet40_normal_resampled",
"input_info": {
"sample_size": [1, 1024, 3]
},
"num_classes": 40,
"batch_size" : 64,
"epochs": 80,
"optimizer": {
"type": "Adam",
"base_lr": 0.00001,
"schedule_type": "multistep",
"steps": [
5
]
},
"compression": {
"algorithm": "quantization",
"initializer": {
"range": {
"num_init_samples": 2560
}
}
}
}
@zbnlala, unfortunately this issue does not reproduce on our side.
Could you check whether the issue also reproduces without NNCF?
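For example, something along these lines (a sketch; adjust num_class and the input shape to your setup). If the plain forward also asserts on GPU, the problem is in the model code rather than in NNCF:

# Run the uncompressed model on GPU with a batch shaped like the config's
# "sample_size" ([1, 1024, 3]); the model transposes it to [B, C, N] internally.
import torch
from examples.torch.classification.models.pointnet2_cls import pointnetv2_cls

model = pointnetv2_cls(num_class=40).cuda().eval()
points = torch.rand(1, 1024, 3).cuda()
with torch.no_grad():
    logits, _ = model(points)
print(logits.shape)  # expected: torch.Size([1, 40])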
@zbnlala, do you have any feedback?
Closing the issue, as there has been no answer for a long time. Feel free to re-open it if needed.