bug: Batchnorms not replaced by Groupnorm when using `share_rgb_model`
When using `share_rgb_model` to share backbone weights across different image inputs, the BatchNorm layers are not replaced by GroupNorm in `MultiImageObsEncoder`. This leads to very unstable training and reduced performance, as noted in the paper.

It can be fixed as follows:
```python
# handle sharing vision backbone
if share_rgb_model:
    assert isinstance(rgb_model, nn.Module)
    key_model_map['rgb'] = rgb_model
    if use_group_norm:
        key_model_map['rgb'] = replace_submodules(
            root_module=key_model_map['rgb'],
            predicate=lambda x: isinstance(x, nn.BatchNorm2d),
            func=lambda x: nn.GroupNorm(
                num_groups=x.num_features//16,
                num_channels=x.num_features)
        )
```
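For reference, here is a minimal, self-contained sketch of what the swap does, using a simplified recursive `replace_submodules` (an assumption for illustration; the repo's actual helper may differ) and a tiny stand-in CNN instead of the shared RGB backbone. It also checks that no `BatchNorm2d` layers survive the replacement:

```python
import torch.nn as nn

def replace_submodules(root_module, predicate, func):
    # recursively swap every submodule matching `predicate` with `func(module)`
    # (simplified sketch of the helper used in diffusion_policy)
    for name, child in root_module.named_children():
        if predicate(child):
            setattr(root_module, name, func(child))
        else:
            replace_submodules(child, predicate, func)
    return root_module

# tiny CNN standing in for the shared RGB backbone (hypothetical example)
model = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3),
    nn.BatchNorm2d(32),
    nn.Sequential(nn.Conv2d(32, 64, kernel_size=3), nn.BatchNorm2d(64)),
)

model = replace_submodules(
    root_module=model,
    predicate=lambda x: isinstance(x, nn.BatchNorm2d),
    func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16,
                                num_channels=x.num_features),
)

# verify: no BatchNorm2d remains, GroupNorm layers took their place
assert not any(isinstance(m, nn.BatchNorm2d) for m in model.modules())
assert any(isinstance(m, nn.GroupNorm) for m in model.modules())
```

Without this, the shared backbone keeps its BatchNorm layers, whose batch statistics are polluted by mixing features from multiple camera views in one batch, which is exactly the instability the paper avoids by using GroupNorm.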