keras-unet-collection
filter_num in TransUNet
Hi,
In TransUNet, filter_num specifies the number of filters for the down- and upsampling levels, right? However, if I use three filters, filter_num=[64, 128, 128], instead of the default filter_num=[64, 128, 128, 512], the number of parameters of the network increases and I get an OOM error. Is this a bug, or am I missing something?
If you are running this on GPUs, then a possible reason is that your configuration is too big: going from [64, 128, 128] --> [64, 128, 128, 256] adds a lot of weights.
The problem is that I go from [64, 128, 128, 256] --> [64, 128, 128] and the network gets larger! This doesn't make sense, since I reduce the number of layers but the network gets bigger.
Have you been able to reproduce this issue?
@parniash Would you mind sharing your code? I don't think the network would get bigger.
If you compile these two models (filter_num is the difference):
model = models.transunet_2d((input_height, input_width, 3), filter_num=[32, 64, 128, 256], n_labels=n_classes, embed_dim=600, num_mlp=1000, num_heads=4, num_transformer=4, activation='ReLU', mlp_activation='GELU', output_activation='Sigmoid', batch_norm=True, pool=True, unpool='bilinear', backbone='ResNet50V2', weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='transunet')
vs
model = models.transunet_2d((input_height, input_width, 3), filter_num=[64, 128, 256], n_labels=n_classes, embed_dim=600, num_mlp=1000, num_heads=4, num_transformer=4, activation='ReLU', mlp_activation='GELU', output_activation='Sigmoid', batch_norm=True, pool=True, unpool='bilinear', backbone='ResNet50V2', weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='transunet')
The last one has more parameters. I don't understand why.
If I plug in (128, 128, 3) with n_labels=2:
Model 1: Total params: 30,508,410
Model 2: Total params: 29,812,794
So your second configuration is smaller, there is no problem.
I feel that you are commenting on the number of trainable parameters: the second one has more trainable params because its output head is connected to 64 channels.
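For reference, this comparison can be reproduced with something like the sketch below (the settings are copied from the two snippets above; compare_params is just an illustrative helper name):

```python
from keras_unet_collection import models

def compare_params(filter_num):
    # build a TransUNet with the settings posted above and report its total size
    model = models.transunet_2d(
        (128, 128, 3), filter_num=filter_num, n_labels=2,
        embed_dim=600, num_mlp=1000, num_heads=4, num_transformer=4,
        activation='ReLU', mlp_activation='GELU', output_activation='Sigmoid',
        batch_norm=True, pool=True, unpool='bilinear',
        backbone='ResNet50V2', weights='imagenet',
        freeze_backbone=True, freeze_batch_norm=True, name='transunet')
    print(filter_num, model.count_params())

compare_params([32, 64, 128, 256])  # model 1
compare_params([64, 128, 256])      # model 2
```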
Try digging into your configurations with model.summary(). The total size of a model can be reduced in many ways.
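If it helps, here is a minimal sketch for splitting the two counts without reading through the whole summary (it assumes model is one of the configurations built above):

```python
import numpy as np

# separate trainable weights from frozen ones; with freeze_backbone=True the
# ResNet50V2 backbone falls into the non-trainable group
trainable = int(sum(np.prod(w.shape) for w in model.trainable_weights))
frozen = int(sum(np.prod(w.shape) for w in model.non_trainable_weights))
print(f"trainable: {trainable:,}, non-trainable: {frozen:,}, total: {trainable + frozen:,}")
```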