pytorch_bayesian_unet
Migrate to PyTorch. Re-implementation of Bayesian Convolutional Neural Networks (BCNNs)
BCNNs
This is a PyTorch re-implementation of Bayesian Convolutional Neural Networks (BCNNs).
(A Chainer implementation is also available: bayesian_unet.)
This project targets the following two scenarios, particularly in medical imaging.
- Two-dimensional segmentation / regression with the 2D U-Net (e.g., 2D X-ray, laparoscopic images, and CT slices)
- Three-dimensional segmentation / regression with the 3D U-Net (e.g., 3D CT volumes)

This work is part of the following publications.
@article{hiasa2019automated,
title={Automated Muscle Segmentation from Clinical CT using Bayesian U-Net for Personalized Musculoskeletal Modeling},
author={Hiasa, Yuta and Otake, Yoshito and Takao, Masaki and Ogawa, Takeshi and Sugano, Nobuhiko and Sato, Yoshinobu},
journal={IEEE Transactions on Medical Imaging},
year={2019 (in press)},
doi={10.1109/TMI.2019.2940555},
}
@article{sakamoto2019bayesian,
title={Bayesian Segmentation of Hip and Thigh Muscles in Metal Artifact-Contaminated CT using Convolutional Neural Network-Enhanced Normalized Metal Artifact Reduction},
author={Sakamoto, Mitsuki and Hiasa, Yuta and Otake, Yoshito and Takao, Masaki and Suzuki, Yuki and Sugano, Nobuhiko and Sato, Yoshinobu},
journal={Journal of Signal Processing Systems},
year={2019 (in press)},
doi={10.1007/s11265-019-01507-z},
}
@inproceedings{hiasa2018surgical,
title={Surgical tools segmentation in laparoscopic images using convolutional neural networks with uncertainty estimation and semi-supervised learning},
author={Hiasa, Y and Otake, Y and Nakatani, S and Harada, H and Kanaji, S and Kakeji, Y and Sato, Yoshinobu},
booktitle={Proc. International Conference of Computer Assisted Radiology and Surgery},
pages={14--15},
year={2018}
}
Requirements
- Python 3
- CPU or NVIDIA GPU + CUDA cuDNN
- PyTorch 1.4
Getting started
Installation
- Install PyTorch and dependencies from https://pytorch.org/
- Install PyTorch Trainer
git clone https://github.com/yuta-hi/pytorch-trainer
cd pytorch-trainer
python setup.py install
- For other requirements, see requirements.txt.
- Install from this repository
git clone https://github.com/yuta-hi/pytorch_bayesian_unet
cd pytorch_bayesian_unet
python setup.py install
Examples
The datasets we used consist of medical images, which are difficult to share due to ethical issues. We therefore prepared the following examples using synthetic or public datasets.
Curve regression
Approximation of the function $y = x\sin x$, reproducing the previous work [Y. Gal, et al.]. The training and validation sets were sampled from the range [-5, 5], whereas the test set was sampled from [-10, 10]. On the test set, the predicted variance was high for samples outside the training range and low for samples within the training distribution.
python examples/curve_regression/train_and_test_epistemic.py
python examples/curve_regression/train_and_test_epistemic.py --test_on_test

python examples/curve_regression/train_and_test_epistemic_aleatoric.py
python examples/curve_regression/train_and_test_epistemic_aleatoric.py --test_on_test
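To make the data setup above concrete, the NumPy sketch below draws training and validation inputs from [-5, 5] and test inputs from [-10, 10] for $y = x\sin x$. The helper `sample_split`, the sample sizes and the noise level are hypothetical; the example scripts choose their own.

```python
import numpy as np


def f(x):
    # Target function of the example: y = x * sin(x)
    return x * np.sin(x)


def sample_split(n, low, high, noise_std=0.0, seed=0):
    """Draw n inputs uniformly from [low, high) and evaluate f, plus optional Gaussian noise."""
    rng = np.random.default_rng(seed)
    x = rng.uniform(low, high, size=(n, 1)).astype(np.float32)
    y = f(x) + noise_std * rng.standard_normal(size=x.shape)
    return x, y.astype(np.float32)


# Hypothetical sizes and noise level; the example scripts define their own.
x_train, y_train = sample_split(200, -5.0, 5.0, noise_std=0.1, seed=0)
x_valid, y_valid = sample_split(50, -5.0, 5.0, noise_std=0.1, seed=1)
x_test, y_test = sample_split(400, -10.0, 10.0, seed=2)

# Test inputs with |x| > 5 were never seen during training, so this is where
# a well-behaved model should report high epistemic variance.
unseen = np.abs(x_test[:, 0]) > 5.0
print(f"{unseen.mean():.0%} of the test inputs lie outside the training range")
```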

MNIST classification
Classification of the ten digits. Only a subset of the samples is used for training: by default, 1,000 samples for training and 1,000 samples for validation. The distributions of predicted variance for correct and incorrect predictions on the test set (10,000 samples) are visualized.
python examples/mnist_classification/train_and_test_epistemic.py
python examples/mnist_classification/train_and_test_epistemic.py --test_on_test
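A minimal NumPy sketch of the two bookkeeping steps described above: drawing disjoint 1,000-sample training and validation subsets from the 60,000 MNIST training images, and grouping predicted variances by whether the prediction was correct. The helper names are hypothetical, the example script may select its subset differently, and the toy arrays are made up rather than results.

```python
import numpy as np


def split_subset(n_total, n_train=1000, n_valid=1000, seed=0):
    """Return disjoint random index arrays for the training and validation subsets."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_total)
    return perm[:n_train], perm[n_train:n_train + n_valid]


def variance_by_correctness(pred_labels, true_labels, variances):
    """Split per-sample predicted variances into (correct, wrong) groups."""
    correct = pred_labels == true_labels
    return variances[correct], variances[~correct]


# MNIST provides 60,000 training images; take 1,000 + 1,000 of them.
train_idx, valid_idx = split_subset(60000)

# Toy usage with made-up values, only to show the grouping.
pred = np.array([3, 1, 4, 1, 5])
true = np.array([3, 1, 4, 0, 0])
var = np.array([0.01, 0.02, 0.03, 0.40, 0.50])
var_correct, var_wrong = variance_by_correctness(pred, true, var)
```

Histograms of `var_correct` and `var_wrong` (e.g., with `matplotlib.pyplot.hist`) then give the kind of comparison described above.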

EndoVis segmentation
Segmentation of surgical instruments from laparoscopic images. The dataset is hosted at https://endovissub-instrument.grand-challenge.org/, and the training and test sets consist of 160 and 140 images, respectively.
python examples/miccai_endovis_segmentation/preprocess.py # download the dataset and convert label format
python examples/miccai_endovis_segmentation/train_and_test_epistemic.py
python examples/miccai_endovis_segmentation/train_and_test_epistemic.py --test_on_test


Image synthesis with adversarial training
Aerial-to-map translation. This example examines how adversarial training affects uncertainty estimates. It mainly follows the previous work [P. Isola, et al.], except that the generator is replaced with a Bayesian U-Net to obtain uncertainty estimates. In addition, spectral normalization [T. Miyato, et al.] is applied to the patch discriminator to stabilize the optimization.
cd examples/map_synthesis
python preprocess.py # download and normalize the dataset
python train_and_test_pix2pix.py --out logs/pix2pix
Note that this example is under construction.
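Spectral normalization divides each weight matrix by its largest singular value, which is estimated with power iteration. The NumPy sketch below illustrates the technique, flattening convolution kernels to shape (out_channels, in_channels * kh * kw) as in Miyato et al. It is an illustration only: the map_synthesis example may implement this differently, and PyTorch itself ships `torch.nn.utils.spectral_norm`. The helper `spectral_normalize` is hypothetical.

```python
import numpy as np


def spectral_normalize(weight, u=None, n_power_iterations=1, eps=1e-12, seed=0):
    """Return (weight / sigma, u, sigma), where sigma estimates the largest singular value.

    The weight is flattened to a matrix of shape (out_channels, -1) first,
    so this works for both dense (out, in) and conv (out, in, kh, kw) weights.
    """
    w = weight.reshape(weight.shape[0], -1)
    if u is None:
        # Random start vector; callers carry it over between calls.
        u = np.random.default_rng(seed).standard_normal(w.shape[0])
    for _ in range(n_power_iterations):
        # Power iteration: alternate updates of the right (v) and left (u) singular vectors.
        v = w.T @ u
        v /= np.linalg.norm(v) + eps
        u = w @ v
        u /= np.linalg.norm(u) + eps
    sigma = u @ w @ v  # estimate of the spectral norm (never exceeds the true value)
    return weight / sigma, u, sigma


# Toy usage: one power iteration per "training step", reusing u each time.
kernel = np.random.default_rng(1).standard_normal((64, 3, 4, 4))
u = None
for step in range(100):
    kernel_sn, u, sigma = spectral_normalize(kernel, u=u)
```

Reusing `u` across calls amortizes the cost: because the weights change slowly during training, a single power iteration per step is typically enough.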
Heatmap regression
Ongoing.
Usage
Define the following objects as described below.
- datasets
- data augmentor
- data normalizer
- model
- visualizer
- validator
- inferencer
- (optional) Singularity image
Setup datasets
You can define your own dataset as shown below. PNG, JPG, BMP and MetaImage (MHD, MHA) formats are supported.
- [case #1] 2D images
from collections import OrderedDict
import numpy as np
from pytorch_bcnn.datasets import ImageDataset
data_root = './data'
patients = ['ID0', 'ID1', 'ID2'] # NOTE: 3 patients
class_list = ['background', 'liver', 'tumor']
augmentor = None # NOTE: set this if you use data augmentation
normalizer = None # NOTE: set this if you use intensity normalization
dtypes = OrderedDict({
    'image': np.float32,
    'label': np.int64, # NOTE: for categorical labels
    # 'mask': np.uint8, # NOTE: uncomment if you have masks
})
filenames = OrderedDict({
    'image': '{root}/{patient}/*_image.mhd',
    'label': '{root}/{patient}/*_label.mhd',
    # 'mask' : '{root}/{patient}/*_mask.mhd', # NOTE: uncomment if you have masks
})
dataset = ImageDataset(data_root, patients, classes=class_list,
dtypes=dtypes, filenames=filenames, augmentor=augmentor, normalizer=normalizer)
- [case #2] 3D volumes
from pytorch_bcnn.datasets import VolumeDataset
...
dataset = VolumeDataset(data_root, patients, classes=class_list,
dtypes=dtypes, filenames=filenames, augmentor=augmentor, normalizer=normalizer)
- [case #3] Custom dataset
from pytorch_bcnn.datasets import BaseDataset
class CustomDataset(BaseDataset):
    ...
    raise NotImplementedError()
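The filename templates in case #1 combine `{root}` and `{patient}` placeholders with a shell-style wildcard. Assuming they are filled in with `str.format` and resolved with `glob` (an assumption about the library internals), the lookup behaves roughly like the hypothetical helper `expand_filenames` below.

```python
import glob


def expand_filenames(templates, root, patients):
    """Resolve '{root}/{patient}/*_image.mhd'-style templates to sorted file lists.

    Hypothetical helper. Returns {key: [files_of_patient_0, files_of_patient_1, ...]}.
    """
    resolved = {}
    for key, template in templates.items():
        resolved[key] = [
            # Fill in the placeholders, then expand the wildcard; sort for a stable order.
            sorted(glob.glob(template.format(root=root, patient=patient)))
            for patient in patients
        ]
    return resolved


# e.g. with the variables defined in case #1:
# files = expand_filenames(filenames, data_root, patients)
```

Sorting keeps the n-th image aligned with the n-th label, so image and label files of the same case should share a common prefix (e.g., `000_image.mhd` and `000_label.mhd`).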
Setup data augmentor
The data augmentor applies stochastic geometric transformations.
from pytorch_bcnn.data.augmentor import DataAugmentor
from pytorch_bcnn.data.augmentor import Crop2D, Flip2D, Affine2D
from pytorch_bcnn.data.augmentor import Crop3D, Flip3D, Affine3D
augmentor = DataAugmentor()
augmentor.add(Crop2D(size=(300,400)))
augmentor.add(Flip2D(axis=1))
augmentor.add(Affine2D(rotation=15.,
translate=(10.,10.),
shear=0.25,
zoom=(0.8, 1.2),
keep_aspect_ratio=True,
fill_mode=('nearest', 'constant'),
cval=(0.,0.),
interp_order=(3,0)))
augmentor.summary('augment.json')
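Geometric augmentation has to move an image and its label together. The NumPy sketch below draws one random crop offset and one flip decision per call and applies both to the image and the label; the helper `random_crop_flip` and its axis convention are hypothetical, not the `Crop2D`/`Flip2D` implementation. The paired settings in `Affine2D` (`fill_mode`, `cval`, `interp_order`) appear to follow the same idea by configuring the image and the label separately: interpolation order 0 (nearest neighbour), for instance, keeps label values as valid class indices.

```python
import numpy as np


def random_crop_flip(image, label, size, flip_axis=1, rng=None):
    """Apply the same random crop and (with probability 0.5) the same flip to image and label.

    Hypothetical helper. image: (C, H, W), label: (H, W); size = (crop_h, crop_w).
    flip_axis indexes the spatial axes (0 = flip rows, 1 = flip columns).
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = label.shape
    crop_h, crop_w = size
    # One random offset shared by both arrays keeps pixels aligned with their labels.
    top = rng.integers(0, h - crop_h + 1)
    left = rng.integers(0, w - crop_w + 1)
    image = image[:, top:top + crop_h, left:left + crop_w]
    label = label[top:top + crop_h, left:left + crop_w]
    if rng.random() < 0.5:
        image = np.flip(image, axis=flip_axis + 1)  # +1 skips the channel axis
        label = np.flip(label, axis=flip_axis)
    return image, label


rng = np.random.default_rng(0)
image = rng.standard_normal((1, 480, 640)).astype(np.float32)
label = rng.integers(0, 3, size=(480, 640))  # 3 classes, as in the dataset example
image_aug, label_aug = random_crop_flip(image, label, size=(300, 400), rng=rng)
```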
Setup data normalizer
The data normalizer applies intensity transformations.
from pytorch_bcnn.data.normalizer import Normalizer
from pytorch_bcnn.data.normalizer import Clip2D, Subtract2D, Divide2D, Quantize2D
from pytorch_bcnn.data.normalizer import Clip3D, Subtract3D, Divide3D, Quantize3D
normalizer = Normalizer()
normalizer.add(Clip2D((-150, 350)))
normalizer.add(Quantize2D(8))
normalizer.add(Subtract2D(0.5))
normalizer.add(Divide2D(1./255.))
normalizer.summary('norm.json')
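The first two steps can be pictured as intensity windowing followed by quantization. The hypothetical helper below assumes that quantizing to `n_bit` bits means mapping the clipped window onto `2**n_bit` evenly spaced levels in [0, 1]; the library's `Clip2D` and `Quantize2D` may differ in detail. Given the CT examples, the window (-150, 350) is presumably in Hounsfield units.

```python
import numpy as np


def window_and_quantize(x, window=(-150.0, 350.0), n_bit=8):
    """Clip intensities to a window, rescale to [0, 1] and snap to 2**n_bit levels.

    Hypothetical helper; the library's Clip2D / Quantize2D may differ in detail.
    """
    lo, hi = window
    x = np.clip(x, lo, hi)
    x = (x - lo) / (hi - lo)        # [lo, hi] -> [0, 1]
    levels = 2 ** n_bit - 1         # 255 steps, i.e. 256 distinct values for 8 bits
    return np.round(x * levels) / levels


hu = np.array([-1000.0, -150.0, 100.0, 350.0, 3000.0])  # example CT intensities
print(window_and_quantize(hu))
```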
Setup model

- [case #1] Segmentation
from functools import partial
from pytorch_bcnn.models import BayesianUNet
from pytorch_bcnn.links import Classifier
predictor = BayesianUNet(ndim=2,
                         in_channels=1,
                         out_channels=3,
                         nlayer=5,
                         nfilter=32)
# NOTE: softmax_cross_entropy and class_weight must be imported / defined beforehand
lossfun = partial(softmax_cross_entropy,
                  normalize=False, class_weight=class_weight)
model = Classifier(predictor,
                   lossfun=lossfun)
- [case #2] Regression
from pytorch_bcnn.links import Regressor
from pytorch_bcnn.functions.loss import sigmoid_soft_cross_entropy
import torch.nn.functional as F
...
lossfun = F.mse_loss
# lossfun = sigmoid_soft_cross_entropy # NOTE: if you want..
model = Regressor(predictor,
lossfun=lossfun)
- [case #3] Other problems (e.g., multi-task)
from pytorch_bcnn.models import UNetBase
# NOTE: _default_*_param are presumably the library's default layer settings;
# import or define them before using this template.
class MultiTaskUNet(UNetBase):
    def __init__(self,
                 ndim,
                 in_channels,
                 foo,  # TODO
                 bar,  # TODO
                 nfilter=32,
                 nlayer=5,
                 conv_param=_default_conv_param,
                 pool_param=_default_pool_param,
                 upconv_param=_default_upconv_param,
                 norm_param=_default_norm_param,
                 activation_param=_default_activation_param,
                 dropout_param=_default_dropout_param,
                 residual=False,
                 ):
        super(MultiTaskUNet, self).__init__(
            ndim,
            in_channels,
            nfilter,
            nlayer,
            conv_param,
            pool_param,
            upconv_param,
            norm_param,
            activation_param,
            dropout_param,
            residual,)
        self._foo = foo
        self._bar = bar
        # TODO: foo, bar

    def forward(self, x):
        h = super().forward(x)
        # TODO: foo, bar
        raise NotImplementedError('foo is bar..')
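The segmentation loss in case #1 weights every pixel by the weight of its true class. The NumPy sketch below assumes that `softmax_cross_entropy` follows Chainer's convention, where `normalize=False` divides the summed loss by the batch size only and `normalize=True` averages over all pixels (masking via an ignore label is omitted). The inverse-frequency helper is one common way to choose `class_weight`, shown for illustration; it is not necessarily what the authors used, and both function names are hypothetical.

```python
import numpy as np


def weighted_softmax_cross_entropy(logits, labels, class_weight, normalize=False):
    """Pixel-wise softmax cross-entropy where each pixel is weighted by its true class.

    logits: (N, C, H, W) scores; labels: (N, H, W) integer class indices.
    normalize=True averages over all pixels; normalize=False divides the sum by N only
    (assumed to mirror Chainer's convention; ignore-label handling is omitted).
    """
    z = logits - logits.max(axis=1, keepdims=True)            # for numerical stability
    log_p = z - np.log(np.exp(z).sum(axis=1, keepdims=True))  # log-softmax over classes
    # Log-probability of the true class at every pixel: (N, H, W)
    log_p_true = np.take_along_axis(log_p, labels[:, None], axis=1)[:, 0]
    loss = -class_weight[labels] * log_p_true
    denom = labels.size if normalize else labels.shape[0]
    return loss.sum() / denom


def inverse_frequency_weights(labels, n_classes):
    """Hypothetical class_weight choice: rarer classes get larger weights."""
    counts = np.bincount(labels.ravel(), minlength=n_classes).astype(np.float64)
    return counts.sum() / (n_classes * np.maximum(counts, 1.0))
```

With the three classes of the dataset example (`background`, `liver`, `tumor`), `class_weight` would be an array of length 3.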
Setup visualizer
- [case #1] 2D segmentation
import numpy as np
from pytorch_bcnn.visualizer import ImageVisualizer
transforms = {
'x': lambda x: x,
'y': lambda x: np.argmax(x, axis=0),
't': lambda x: x,
}
_cmap = np.array([
[0,0,0], # NOTE: background (black)
[1,0,0], # liver (red)
[0,1,0]]) # tumor (green)
cmaps = {
'x': None,
'y': _cmap,
't': _cmap,
}
clims = {
'x': (0., 255.),
'y': None,
't': None,
}
visualizer = ImageVisualizer(transforms=transforms,
cmaps=cmaps,
clims=clims)
- [case #2] 2D regression
from pytorch_bcnn.visualizer import ImageVisualizer
import numpy as np
import matplotlib.pyplot as plt
def sigmoid(x):
    return 1. / (1. + np.exp(-x))
def alpha_blend(heatmaps, cmap='jet'):
    # blend C heatmaps of shape (C, H, W) into one RGB image, one colour per channel
    assert heatmaps.ndim == 3
    ch, h, w = heatmaps.shape
    ret = np.zeros((3, h, w))
    mapper = plt.get_cmap(cmap, ch)
    for i in range(ch):
        color = np.ones((3, h, w)) \
            * np.asarray(mapper(i)[:3]).reshape(-1,1,1)
        ret += (color * heatmaps[i])
    return ret
transforms = {
'x': None,
'y': lambda x: alpha_blend(sigmoid(x)),
't': lambda x: alpha_blend(x),
}
clims = {
'x': (0., 255.),
'y': (0., 1.),
't': (0., 1.),
}
cmaps = None
visualizer = ImageVisualizer(transforms=transforms,
cmaps=cmaps,
clims=clims)
To visualize 3D volumes, pass a volume renderer as one of the transforms
described above.
Setup validator
from pytorch_bcnn.extensions import Validator
...
valid_file = 'iter_{.updater.iteration:08}.png'
n_vis = 20 # NOTE: number of samples for visualization
trainer.extend(Validator(valid_iter, model, valid_file,
visualizer=visualizer, n_vis=n_vis,
device=device))
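The validation filename is a `str.format` template with attribute access. Assuming the `Validator` formats it with the trainer object (as Chainer-style extensions do for snapshot names), `{.updater.iteration:08}` reads `trainer.updater.iteration` and zero-pads it to eight digits. The trainer below is a hypothetical stand-in that only provides that attribute path.

```python
from types import SimpleNamespace

valid_file = 'iter_{.updater.iteration:08}.png'

# Hypothetical stand-in for the trainer: only the attribute path used by the template.
trainer = SimpleNamespace(updater=SimpleNamespace(iteration=1500))

# '{.updater.iteration:08}' reads trainer.updater.iteration from the first
# positional argument and zero-pads it to 8 digits.
print(valid_file.format(trainer))  # → iter_00001500.png
```

Zero-padding keeps the saved images in chronological order when they are sorted by name.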
Setup inferencer
- [case #1] Segmentation / Classification
from functools import partial
from pytorch_bcnn.links import MCSampler
from pytorch_bcnn.inference import Inferencer
import torch
mc_iteration = 50
model = MCSampler(predictor, # NOTE: e.g., BayesianUNet
mc_iteration=mc_iteration,
activation=partial(torch.softmax, dim=1),
reduce_mean=partial(torch.argmax, dim=1),
reduce_var=partial(torch.mean, dim=1))
infer = Inferencer(test_iter, model, device=device)
estimated_labels, predicted_variances = infer.run()
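The sampler above presumably performs Monte Carlo (MC) dropout inference: the predictor is evaluated `mc_iteration` times with dropout active, and the predictive mean and variance are taken over the samples. The NumPy sketch below assumes that `activation` is applied to each sample, `reduce_mean` to the mean over samples and `reduce_var` to the variance over samples, which is what the argument names suggest; `mc_dropout_predict` and the toy network are hypothetical, not the library's code.

```python
import numpy as np


def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def mc_dropout_predict(forward, x, mc_iteration=50, rng=None):
    """Hypothetical NumPy sketch of MC dropout inference for segmentation.

    forward(x, rng) must run the network with dropout enabled and return
    logits of shape (N, C, H, W). Returns (labels, uncertainty), both (N, H, W).
    """
    rng = np.random.default_rng() if rng is None else rng
    # (mc_iteration, N, C, H, W): one softmax output per stochastic forward pass
    probs = np.stack([softmax(forward(x, rng), axis=1) for _ in range(mc_iteration)])
    mean = probs.mean(axis=0)       # predictive mean over MC samples
    var = probs.var(axis=0)         # predictive variance over MC samples
    labels = mean.argmax(axis=1)    # cf. reduce_mean=partial(torch.argmax, dim=1)
    uncertainty = var.mean(axis=1)  # cf. reduce_var=partial(torch.mean, dim=1)
    return labels, uncertainty


# Toy "network": a fixed linear map from 2 input channels to 3 classes,
# with inverted dropout (rate p) on the input.
W = np.array([[1.0, -1.0], [-1.0, 1.0], [0.5, 0.5]])


def toy_forward(x, rng, p=0.5):
    keep = rng.random(x.shape) >= p
    h = x * keep / (1.0 - p)
    return np.einsum('kc,nchw->nkhw', W, h)


x = np.random.default_rng(0).standard_normal((1, 2, 4, 4))
labels, uncertainty = mc_dropout_predict(toy_forward, x, mc_iteration=50,
                                         rng=np.random.default_rng(1))
```

Without dropout every sample would be identical and the variance would vanish, which is why dropout must stay enabled at test time for this estimate to be meaningful.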
(Optional) Setup Singularity image
cd recipe
make all