tta_wrapper
Test Time image Augmentation (TTA) wrapper for Keras model.
TTA wrapper
Test time augmentation wrapper for Keras image segmentation and classification models.
Description
How it works?
The wrapper adds augmentation layers to your Keras model like this:
          Input
            |           # input image; shape 1, H, W, C
       / / / \ \ \      # duplicate image for augmentation; shape N, H, W, C
      | | |   | | |     # apply augmentations (flips, rotation, shifts)
     your Keras model
      | | |   | | |     # reverse transformations
       \ \ \ / / /      # merge predictions (mean, max, gmean)
            |           # output mask; shape 1, H, W, C
          Output
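The pipeline in the diagram can be sketched in plain NumPy. This is only an illustration of the idea, not the wrapper's actual implementation (which adds layers to the Keras model itself): `tta_predict` and `predict_fn` are hypothetical names, and only flips are shown to keep the sketch short.

```python
import numpy as np

def tta_predict(predict_fn, image, merge="mean"):
    """Flip-based TTA sketch for a segmentation-style model.

    predict_fn: hypothetical callable mapping (N, H, W, C) -> (N, H, W, C').
    image: array of shape (1, H, W, C).
    """
    # Each entry is (forward transform, inverse transform); a flip is its own inverse.
    identity = lambda a: a
    h_flip = lambda a: a[:, :, ::-1, :]  # flip along the width axis
    v_flip = lambda a: a[:, ::-1, :, :]  # flip along the height axis
    transforms = [(identity, identity), (h_flip, h_flip), (v_flip, v_flip)]

    # Duplicate and augment the image: shape (N, H, W, C).
    batch = np.concatenate([fwd(image) for fwd, _ in transforms], axis=0)
    preds = predict_fn(batch)

    # Reverse each transformation on its own prediction.
    restored = np.concatenate(
        [inv(preds[i:i + 1]) for i, (_, inv) in enumerate(transforms)], axis=0
    )

    # Merge the N predictions back into a single one: shape (1, H, W, C').
    if merge == "mean":
        return restored.mean(axis=0, keepdims=True)
    if merge == "max":
        return restored.max(axis=0, keepdims=True)
    if merge == "gmean":
        # Geometric mean; assumes non-negative predictions (e.g. probabilities).
        return np.prod(restored, axis=0, keepdims=True) ** (1.0 / len(transforms))
    raise ValueError(f"unknown merge mode: {merge}")
```

With a stand-in model such as `predict_fn = lambda batch: batch * 2.0`, every reversed prediction is identical, so all three merge modes return the same array as a single plain prediction would.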
Arguments
- h_flip - bool, horizontal flip augmentation
- v_flip - bool, vertical flip augmentation
- rotation - list, allowable angles: 90, 180, 270
- h_shift - list of int, horizontal shift augmentation in pixels
- v_shift - list of int, vertical shift augmentation in pixels
- add - list of int/float, additive factor (aug_image = image + factor)
- mul - list of int/float, multiplicative factor (aug_image = image * factor)
- contrast - list of int/float, contrast adjustment factor (aug_image = (image - mean) * factor + mean)
- merge - one of 'mean', 'gmean' and 'max'; mode of merging augmented predictions together
Constraints
- model has to have 1 input and 1 output
- inference batch_size == 1
- image height == width if rotation augmentation is used
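A likely reason for the height == width constraint is that a 90 or 270 degree rotation swaps the two spatial dimensions, so rotated copies of a non-square image cannot be stacked into one batch with the original. The NumPy sketch below illustrates this; `rotated_batch` is a hypothetical helper, not part of the package.

```python
import numpy as np

def rotated_batch(image, angles=(90, 180, 270)):
    """Stack a (1, H, W, C) image with its rotated copies along the batch axis."""
    copies = [image] + [np.rot90(image, k=a // 90, axes=(1, 2)) for a in angles]
    return np.concatenate(copies, axis=0)

square = np.zeros((1, 4, 4, 3))
print(rotated_batch(square).shape)  # (4, 4, 4, 3)

rect = np.zeros((1, 4, 6, 3))
try:
    rotated_batch(rect)
except ValueError as err:
    # The 90 and 270 degree copies have shape (1, 6, 4, 3) and do not fit the batch.
    print("cannot batch rotated copies:", err)
```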
Installation
- PyPI package:
$ pip install tta-wrapper
- Latest version:
$ pip install git+https://github.com/qubvel/tta_wrapper/
Example
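from keras.models import load_model
from tta_wrapper import tta_segmentation

model = load_model('path/to/model.h5')

# wrap the model; the augmentations are applied at inference time
tta_model = tta_segmentation(model, h_flip=True, rotation=(90, 270),
                             h_shift=(-5, 5), merge='mean')

# x: a single image of shape (1, H, W, C), with H == W because rotation is used
y = tta_model.predict(x)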