nobrainer
Add attention-unet to the models
Reference: https://arxiv.org/pdf/1804.03999.pdf (Attention U-Net, Oktay et al., 2018), or the PyTorch implementation at https://github.com/sabeenlohawala/tissue_labeling/blob/6a1ab8466ba6629a1d9d4e553793a26d72f2b60b/TissueLabeling/models/attention_unet.py
- Code for the paper: https://github.com/ozan-oktay/Attention-Gated-Networks (PyTorch)
- The model as proposed in the paper (with gating attention) is at https://github.com/ozan-oktay/Attention-Gated-Networks/blob/master/models/networks/unet_grid_attention_3D.py
- Print the model summary as follows:
from torchinfo import summary
from models.networks.unet_grid_attention_3D import unet_grid_attention_3D
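input_shape = (128, 128, 128)  # spatial size: (depth, height, width)
batch_size = 1
image_channels = 3
# PyTorch is channels-first: input is (batch, channels, depth, height, width)
summary(
    unet_grid_attention_3D(),
    input_size=(batch_size, image_channels, *input_shape),
    col_names=["input_size", "output_size", "num_params"],
    # depth=5,  # optionally show more levels of nested modules
)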
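The gating attention in the paper is an additive attention gate: for each voxel, q_att = ψᵀ σ1(W_xᵀ x + W_gᵀ g + b_g) + b_ψ, α = σ2(q_att), and the skip-connection features x are multiplied by α (σ1 = ReLU, σ2 = sigmoid). The NumPy sketch below illustrates only that math, as an aid for porting it to Keras. It assumes channels-last arrays (the Keras convention), expresses the 1x1x1 convolutions as per-voxel matrix products, and assumes the gating signal g has already been resampled to x's grid (the reference PyTorch code instead downsamples x with a strided convolution and upsamples the attention map). The function and argument names are hypothetical, not nobrainer or Attention-Gated-Networks APIs.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def additive_attention_gate(x, g, w_x, w_g, b_g, psi, b_psi):
    """Hypothetical NumPy sketch of an additive attention gate.

    x:     skip-connection features, shape (D, H, W, C_x)
    g:     gating signal, ASSUMED already resampled to x's grid, shape (D, H, W, C_g)
    w_x:   (C_x, C_int) weights of the 1x1x1 conv applied to x
    w_g:   (C_g, C_int) weights of the 1x1x1 conv applied to g
    b_g:   (C_int,) bias
    psi:   (C_int,) weights of the 1x1x1 conv producing one attention channel
    b_psi: scalar bias
    Returns (gated features, attention coefficients alpha).
    """
    # 1x1x1 convolutions are linear maps over the channel axis at every voxel.
    q = (np.einsum("dhwc,ck->dhwk", x, w_x)
         + np.einsum("dhwc,ck->dhwk", g, w_g)
         + b_g)
    q = np.maximum(q, 0.0)  # sigma_1 = ReLU
    q_att = np.einsum("dhwk,k->dhw", q, psi) + b_psi
    alpha = sigmoid(q_att)  # sigma_2 = sigmoid, coefficients in (0, 1)
    # Broadcast the single-channel coefficient over all channels of x.
    return x * alpha[..., None], alpha


# Usage with random weights (illustrative shapes only).
rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8, 8, 16))
g = rng.standard_normal((8, 8, 8, 32))
out, alpha = additive_attention_gate(
    x, g,
    w_x=rng.standard_normal((16, 4)),
    w_g=rng.standard_normal((32, 4)),
    b_g=np.zeros(4),
    psi=rng.standard_normal(4),
    b_psi=0.0,
)
print(out.shape, alpha.shape)  # (8, 8, 8, 16) (8, 8, 8)
```

In a Keras port, the einsum calls would correspond to `Conv3D` layers with `kernel_size=1`, and the gated output `x * alpha` replaces the plain skip connection before concatenation in the decoder.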
- The closest TensorFlow/Keras implementation of this paper is at https://github.com/robinvvinod/unet
- However, it uses Inception blocks. Plan: get it running first, then replace the Inception blocks with standard U-Net conv blocks (needs further investigation).
- Print the model summary as follows:
from network import network
import tensorflow as tf
from tensorflow.keras.layers import Input
input_dimensions = (128, 128, 128, 1)  # Keras is channels-last: (depth, height, width, channels)
input_img = Input(input_dimensions)
model = network(input_img)
model.summary()
- [ ] Add tests and update __init__.py