mmengine
gy77/add freeze hook
Motivation
- Freeze selected model parameters during training.
Goal:
- Specify the epoch (or iteration) at which the specified network layers are frozen.
- Make the hook available to all downstream repositories.
Modification
Add `FreezeHook` and unit tests for it.
Use cases
- Network layers matching `freeze_layers` are frozen before `freeze_iter`/`freeze_epoch` starts.
- Network layers matching `unfreeze_layers` are unfrozen before `unfreeze_iter`/`unfreeze_epoch` starts.
- `freeze_layers`/`unfreeze_layers` select network layers via regular expressions.
- `iter`/`epoch` indices start at 0, so `epoch=0` is the first epoch.
- `unfreeze_iter`, `unfreeze_epoch` and `unfreeze_layers` are optional. If `unfreeze_iter`/`unfreeze_epoch` is not None, `unfreeze_layers` must not be None.
- Only one of `freeze_iter` and `freeze_epoch` can be set; likewise, only one of `unfreeze_iter` and `unfreeze_epoch` can be set.
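The argument rules above can be expressed as a small up-front check. This is a hypothetical sketch, not the validation code in this PR: the function name `check_freeze_args` is invented for illustration, and it assumes the `unfreeze_layers` requirement is triggered by `unfreeze_iter`/`unfreeze_epoch` (otherwise a freeze-only config such as `freeze_epoch=0` would be rejected).

```python
from typing import Optional


def check_freeze_args(freeze_layers: Optional[str] = None,
                      freeze_iter: Optional[int] = None,
                      freeze_epoch: Optional[int] = None,
                      unfreeze_layers: Optional[str] = None,
                      unfreeze_iter: Optional[int] = None,
                      unfreeze_epoch: Optional[int] = None) -> None:
    """Hypothetical check of the FreezeHook argument rules.

    freeze_layers is accepted for signature parity but not checked here.
    Raises ValueError on the first violated rule; returns None otherwise.
    """
    # Freezing is scheduled by iteration OR by epoch, never both.
    if freeze_iter is not None and freeze_epoch is not None:
        raise ValueError("Only one of freeze_iter and freeze_epoch can be set.")
    # The same restriction applies to the unfreeze schedule.
    if unfreeze_iter is not None and unfreeze_epoch is not None:
        raise ValueError(
            "Only one of unfreeze_iter and unfreeze_epoch can be set.")
    # Assumption: an unfreeze schedule needs a pattern saying what to unfreeze.
    unfreeze_scheduled = unfreeze_iter is not None or unfreeze_epoch is not None
    if unfreeze_scheduled and unfreeze_layers is None:
        raise ValueError(
            "unfreeze_layers must not be None when unfreeze_iter or "
            "unfreeze_epoch is set.")
```

For example, `check_freeze_args(freeze_layers="backbone.*", freeze_epoch=0, unfreeze_layers="backbone.*", unfreeze_epoch=9)` passes, while setting both `freeze_iter` and `freeze_epoch` raises `ValueError`.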
The examples below assume a model with the following structure:

```text
ImageClassifier(
  (backbone): ResNet(
    ...
    (layer1): Sequential(...)
    (layer2): Sequential(...)
    (layer3): Sequential(...)
    (layer4): Sequential(...)
  )
  (neck): GlobalAveragePooling2d(...)
  (head): Linear(...)
)
```
- Freeze the backbone parameters before the first training epoch starts.

```python
custom_hooks = [
    # ... other hooks
    dict(
        type="FreezeHook",
        freeze_layers="backbone.*",
        freeze_epoch=0)
]
```
- Freeze the `layer1` and `layer2` parameters of the backbone before the 11th training epoch starts (`freeze_epoch=10`, since epoch indices start at 0).

```python
custom_hooks = [
    # ... other hooks
    dict(
        type="FreezeHook",
        freeze_layers="backbone.layer1.*|backbone.layer2.*",
        freeze_epoch=10)
]
```
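To see which parameters a pattern such as `backbone.layer1.*|backbone.layer2.*` selects, here is a minimal sketch. It assumes names are tested with Python's `re.match` (anchored at the start of the name); the helper `match_layers` and the parameter names are hypothetical and only illustrate the ResNet-based classifier above.

```python
import re
from typing import Iterable, List


def match_layers(param_names: Iterable[str], pattern: str) -> List[str]:
    """Return the parameter names selected by a freeze/unfreeze pattern.

    Assumption: names are tested with re.match, i.e. the pattern is anchored
    at the start of the name but may stop before its end.
    """
    regex = re.compile(pattern)
    return [name for name in param_names if regex.match(name)]


if __name__ == "__main__":
    # Illustrative parameter names for the ResNet-based classifier above.
    names = [
        "backbone.conv1.weight",
        "backbone.layer1.0.conv1.weight",
        "backbone.layer2.0.conv1.weight",
        "backbone.layer3.0.conv1.weight",
        "head.weight",
    ]
    # Prints ['backbone.layer1.0.conv1.weight', 'backbone.layer2.0.conv1.weight']
    print(match_layers(names, "backbone.layer1.*|backbone.layer2.*"))
```

Note that the unescaped `.` matches any character, so `backbone.layer1.*` would also match a name like `backbone_layer1...`; escaping the dots (`backbone\.layer1\..*`) makes the match strict. Using `re.search` instead of `re.match` would additionally select names where the pattern appears mid-string.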
- Freeze the backbone parameters before the first training epoch starts, and unfreeze them before the 10th training epoch starts.

```python
custom_hooks = [
    # ... other hooks
    dict(
        type="FreezeHook",
        freeze_layers="backbone.*",
        freeze_epoch=0,
        unfreeze_layers="backbone.*",
        unfreeze_epoch=9)
]
```
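The freeze/unfreeze schedule can be modelled without a training framework. The sketch below is hypothetical and is not the PR's `FreezeHook`: the class `FreezeScheduleSketch` and the stub model are invented for illustration. In an mmengine hook, this logic would plausibly sit in `before_train_epoch(runner)`, reading `runner.epoch` and `runner.model`; here the epoch and model are passed in directly, and the model only needs a `named_parameters()` method like `torch.nn.Module`.

```python
import re
from typing import Dict, Iterator, List, Optional, Tuple


class FreezeScheduleSketch:
    """Hypothetical, framework-free model of epoch-based freeze/unfreeze."""

    def __init__(self, freeze_layers: str, freeze_epoch: int,
                 unfreeze_layers: Optional[str] = None,
                 unfreeze_epoch: Optional[int] = None) -> None:
        self.freeze_re = re.compile(freeze_layers)
        self.freeze_epoch = freeze_epoch
        self.unfreeze_re = (re.compile(unfreeze_layers)
                            if unfreeze_layers is not None else None)
        self.unfreeze_epoch = unfreeze_epoch

    def before_train_epoch(self, epoch: int, model) -> None:
        """Apply the schedule for a 0-based epoch index (0 = first epoch)."""
        # Re-applying the state every epoch (>= rather than ==) keeps the
        # result correct when training resumes after freeze_epoch.
        if epoch >= self.freeze_epoch:
            self._set(model, self.freeze_re, requires_grad=False)
        if (self.unfreeze_re is not None and self.unfreeze_epoch is not None
                and epoch >= self.unfreeze_epoch):
            self._set(model, self.unfreeze_re, requires_grad=True)

    @staticmethod
    def _set(model, regex: "re.Pattern[str]", requires_grad: bool) -> None:
        # Toggle requires_grad on every parameter whose name matches.
        for name, param in model.named_parameters():
            if regex.match(name):
                param.requires_grad = requires_grad


class StubParam:
    """Stand-in for torch.nn.Parameter: only carries requires_grad."""

    def __init__(self) -> None:
        self.requires_grad = True


class StubModel:
    """Stand-in for torch.nn.Module: only provides named_parameters()."""

    def __init__(self, names: List[str]) -> None:
        self.params: Dict[str, StubParam] = {n: StubParam() for n in names}

    def named_parameters(self) -> Iterator[Tuple[str, StubParam]]:
        return iter(self.params.items())


if __name__ == "__main__":
    # Mirrors the config above: freeze_epoch=0, unfreeze_epoch=9.
    schedule = FreezeScheduleSketch("backbone.*", 0, "backbone.*", 9)
    model = StubModel(["backbone.conv1.weight", "head.weight"])
    for epoch in range(12):
        schedule.before_train_epoch(epoch, model)
        print(epoch, model.params["backbone.conv1.weight"].requires_grad)
```

Run as a script, the backbone parameter reports `False` for epochs 0 to 8 and `True` from epoch 9 onwards, while `head.weight` stays trainable throughout. Applying the freeze before the unfreeze on each call also lets a narrower `unfreeze_layers` (for example only `backbone.layer4.*`) re-enable part of a frozen backbone.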
- The `verbose` parameter determines whether the `requires_grad` flag of each model parameter is logged.

```python
custom_hooks = [
    # ... other hooks
    dict(
        type="FreezeHook",
        freeze_layers="backbone.*",
        freeze_epoch=1,
        verbose=True)
]
```
Sample log output:

```text
mmengine - INFO - backbone.conv1.weight requires_grad: True
mmengine - INFO - backbone.bn1.weight requires_grad: True
...
mmengine - INFO - head.light_head.weight requires_grad: True
mmengine - INFO - head.light_head.bias requires_grad: True
```
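A `verbose`-style report like the sample above can be produced by walking `named_parameters()` and logging each flag. This is a minimal sketch, not the PR's code: `log_requires_grad` and `StubModel` are hypothetical, and it uses the standard-library `logging` module with a logger named `mmengine` only to mimic the prefix; a real mmengine hook would more likely log through the runner's logger (`runner.logger`).

```python
import logging
from types import SimpleNamespace
from typing import List, Tuple


def log_requires_grad(model, logger: logging.Logger) -> List[str]:
    """Log '<name> requires_grad: <flag>' for every parameter of `model`.

    `model` can be anything exposing named_parameters(), such as a
    torch.nn.Module. The messages are also returned for inspection.
    """
    messages = []
    for name, param in model.named_parameters():
        message = f"{name} requires_grad: {param.requires_grad}"
        logger.info(message)
        messages.append(message)
    return messages


class StubModel:
    """Stand-in for torch.nn.Module built from (name, requires_grad) pairs."""

    def __init__(self, flags: List[Tuple[str, bool]]) -> None:
        self._params = [(name, SimpleNamespace(requires_grad=flag))
                        for name, flag in flags]

    def named_parameters(self):
        return iter(self._params)


if __name__ == "__main__":
    # A "name - level - message" format approximates the sample output above.
    logging.basicConfig(level=logging.INFO,
                        format="%(name)s - %(levelname)s - %(message)s")
    model = StubModel([("backbone.conv1.weight", False),
                       ("head.weight", True)])
    log_requires_grad(model, logging.getLogger("mmengine"))
```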