
Fix (+suggestions) to instantiation of LocalNet in 3D registration tutorial

Open · brudfors opened this issue 10 months ago · 1 comment

The LocalNet in the 3D registration tutorial is currently instantiated as:

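import torch
from monai.networks.nets import LocalNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = LocalNet(
    spatial_dims=3,
    in_channels=2,
    out_channels=3,
    num_channel_initial=32,
    extract_levels=[3],
    out_activation=None,
    out_kernel_initializer="zeros",
).to(device)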

However, the extraction levels should be set to extract_levels=[3, 2, 1, 0] so that the average over feature maps is taken over all available resolutions. With extract_levels=[3], the average is taken only over the coarsest (deepest) feature map, so the network predicts a very coarse displacement field, which hurts performance.
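To illustrate the effect, here is a simplified NumPy sketch. It assumes each level's prediction is upsampled to full resolution with nearest-neighbour interpolation and the upsampled maps are then averaged; it is not MONAI's implementation, and the 16^3 volume and random "predictions" are hypothetical stand-ins for real network outputs.

```python
import numpy as np


def upsample_nearest(x: np.ndarray, factor: int) -> np.ndarray:
    """Nearest-neighbour upsampling of a (C, D, H, W) array by an integer factor."""
    for axis in (1, 2, 3):
        x = np.repeat(x, factor, axis=axis)
    return x


def average_over_levels(level_outputs: dict, extract_levels: list) -> np.ndarray:
    """Upsample each extracted level to full resolution and average the results.

    level_outputs maps level k -> array of shape (C, D / 2**k, H / 2**k, W / 2**k).
    """
    upsampled = [upsample_nearest(level_outputs[k], 2**k) for k in extract_levels]
    return np.mean(np.stack(upsampled, axis=0), axis=0)


rng = np.random.default_rng(0)
size, channels = 16, 3  # hypothetical 16^3 volume, 3 displacement components
# Hypothetical per-level predictions: level k has spatial size 16 / 2**k
outputs = {k: rng.standard_normal((channels,) + (size // 2**k,) * 3) for k in range(4)}

coarse = average_over_levels(outputs, [3])
full = average_over_levels(outputs, [3, 2, 1, 0])

# Distinct values along one line of voxels in the first displacement component
print(len(np.unique(coarse[0, :, 0, 0])))  # 2: one value per 8-voxel block
print(len(np.unique(full[0, :, 0, 0])))    # finer levels let values vary per voxel
```

With only level 3, the field is constant over every 8x8x8 block; including levels 2, 1 and 0 adds progressively finer detail to the averaged field.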

I would further recommend setting out_kernel_initializer to "kaiming_uniform": this is the default value in the LocalNet constructor (not "zeros"), and we have seen instabilities during training when using "zeros". A further proposal is to add one more level to the network (extract_levels=[4, 3, 2, 1, 0]) and adjust num_channel_initial accordingly (here from 32 to 16).
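The reason for lowering num_channel_initial when adding a level can be checked with a little arithmetic. This sketch assumes that the channel count doubles at each level (num_channel_initial * 2**d) and that the network depth equals max(extract_levels), which is my understanding of MONAI's RegUNet family; the helper functions are hypothetical and not part of MONAI.

```python
def level_channels(num_channel_initial: int, extract_levels: list) -> list:
    """Channel width at each level d = 0..depth, assuming it doubles per level
    and that depth == max(extract_levels)."""
    depth = max(extract_levels)
    return [num_channel_initial * 2**d for d in range(depth + 1)]


def min_divisor(extract_levels: list) -> int:
    """Each spatial dimension is halved `depth` times, so it should be divisible by 2**depth."""
    return 2 ** max(extract_levels)


current = level_channels(32, [3])
proposed = level_channels(16, [4, 3, 2, 1, 0])

print(current)   # [32, 64, 128, 256]
print(proposed)  # [16, 32, 64, 128, 256]
print(min_divisor([3]), min_divisor([4, 3, 2, 1, 0]))  # 8 16
```

Under these assumptions, halving num_channel_initial keeps the deepest feature map at 256 channels while adding a level, and it halves the width of the full-resolution level. The trade-off is that each spatial dimension of the input should then be divisible by 16 rather than 8.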

All in all, the instantiation would become:

model = LocalNet(
    spatial_dims=3,
    in_channels=2,
    out_channels=3,
    num_channel_initial=16,
    extract_levels=[4, 3, 2, 1, 0],
    out_activation=None,
    out_kernel_initializer="kaiming_uniform",
).to(device)

Or, simplified by relying on the constructor defaults for out_activation and out_kernel_initializer:

model = LocalNet(
    spatial_dims=3,
    in_channels=2,
    out_channels=3,
    num_channel_initial=16,
    extract_levels=[4, 3, 2, 1, 0],
).to(device)

brudfors, Aug 11 '23 15:08