algorithmic-efficiency
Models will always be initialized without dropout layers in self-tuning ruleset
In submission_runner.py, under the self-tuning ruleset, the hyperparameters argument to train_once is always None.
Then in this code snippet
dropout_rate = None
aux_dropout_rate = None
if hasattr(hyperparameters, 'dropout_rate'):
    dropout_rate = hyperparameters.dropout_rate
if hasattr(hyperparameters, 'aux_dropout_rate'):
    aux_dropout_rate = hyperparameters.aux_dropout_rate
model_params, model_state = workload.init_model_fn(
    model_init_rng, dropout_rate, aux_dropout_rate)
workload.init_model_fn will always get None for dropout_rate and aux_dropout_rate, so Dropout layers won't ever be added to the model.
Although submissions could call workload.init_model_fn again themselves to make use of its side effect of setting workload._model, this is awkward and also challenging for workloads near the memory limit since it involves superfluously reconstructing model_params again on device.
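To see why both values stay None, here is a minimal, self-contained sketch of the same lookup. The helper read_dropout_rates and the SimpleNamespace stand-in for the hyperparameters object are illustrative assumptions, not code from submission_runner.py.

```python
from types import SimpleNamespace


def read_dropout_rates(hyperparameters):
    """Mirrors the attribute lookup from the snippet (illustrative helper)."""
    dropout_rate = None
    aux_dropout_rate = None
    if hasattr(hyperparameters, 'dropout_rate'):
        dropout_rate = hyperparameters.dropout_rate
    if hasattr(hyperparameters, 'aux_dropout_rate'):
        aux_dropout_rate = hyperparameters.aux_dropout_rate
    return dropout_rate, aux_dropout_rate


# External tuning: hyperparameters come from the search space.
external = SimpleNamespace(dropout_rate=0.1, aux_dropout_rate=0.2)
external_rates = read_dropout_rates(external)

# Self-tuning: hyperparameters is None, so both hasattr checks fail
# and the defaults (None, None) are passed on to init_model_fn.
self_tuning_rates = read_dropout_rates(None)
```

Since hasattr(None, 'dropout_rate') is False, the self-tuning path can never pick up a dropout value from this code.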
Our current API has 2 dropout related limitations:
Currently, in the external tuning ruleset, we read the dropout value from the hparam config and pass it to the model initialization functions. In the self-tuning ruleset, there is no convenient way to specify the dropout value at model initialization. Furthermore, there is no way to change the dropout value during training. A workload function that submitters can call to change the dropout value would remove both of these limitations.
Some considerations about changing the dropout implementation.
Current situation
The dropout probability value is provided as a hyperparameter in the JSON search space. It is then used in submission_runner.py as follows:
model_params, model_state = workload.init_model_fn(
    model_init_rng, dropout_rate, aux_dropout_rate)
After initializing the model, we torch.compile it and initialize the optimizer.
Current limitations
- Self-tuning submissions cannot specify a dropout probability value
- It's not possible to change dropout during training
How can we address these problems?
I can see several possibilities; some require major changes, others are less disruptive.
(A) extend the submission module API to provide initial dropout value ⭐
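A submission should provide a function model_init_hyperparams that returns the hyperparameters used in initialization, such as dropout. Something like get_batch_size, but for dropout. This would address the first limitation (self-tuning submissions cannot specify a dropout value) but not the second (dropout cannot be changed during training).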
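A rough sketch of how option (A) could look, under stated assumptions: only the name model_init_hyperparams comes from the proposal; its workload_name argument, the returned object, and the runner-side helper resolve_init_dropout are hypothetical.

```python
from types import SimpleNamespace


# Hypothetical submission-side hook, similar in spirit to get_batch_size.
def model_init_hyperparams(workload_name):
    """Returns hyperparameters that are only used at model initialization."""
    del workload_name  # A real submission could branch on the workload.
    return SimpleNamespace(dropout_rate=0.1, aux_dropout_rate=0.0)


# Hypothetical runner-side logic: use the tuning-search hyperparameters when
# they exist (external tuning), otherwise fall back to the submission's hook.
def resolve_init_dropout(hyperparameters, workload_name):
    source = hyperparameters
    if source is None:
        source = model_init_hyperparams(workload_name)
    dropout_rate = getattr(source, 'dropout_rate', None)
    aux_dropout_rate = getattr(source, 'aux_dropout_rate', None)
    return dropout_rate, aux_dropout_rate


self_tuning = resolve_init_dropout(None, 'example_workload')
external = resolve_init_dropout(
    SimpleNamespace(dropout_rate=0.3), 'example_workload')
```

With a hook like this, the self-tuning path gets a dropout value at initialization, but the value is still fixed for the whole run.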
(B) re-init and re-compile the model
We could add a change_dropout method to each workload for the submission to call. When triggered, it re-initializes the model with the new dropout probability. However, in PyTorch we would also have to recompile the model, which currently happens in submission_runner, not inside the submitter's code. It's also non-trivial to keep the old parameters while initializing a new model in PyTorch without hitting an OOM error, because both copies temporarily have to be stored at once.
(C) pass dropout to the model fwd call
Not trivial, need to modify all model implementations.
Conclusion
My suggested option is (A), but I am happy to discuss!
Hi Niccolo, are there any open things here left to discuss regarding the plan? I think we agreed on (A) in the eng meeting?
Hey Priya! Nothing left to discuss, just lagging behind! Will submit a PR by this week! Sorry for the delay.
As discussed offline, I have implemented a fix in: #851
Update
After thorough discussion, we have now moved to option (C). I am posting documentation of our changes here for consistency with the previous discussion.
Overview
We introduce dropout_rate in the model call (fwd pass), and modify the model implementations to support a dynamic dropout_rate value.
We also remove dropout_rate from the model initialization. We considered keeping both, but opted for this option to avoid confusion and keep the code cleaner. This way, submitters are responsible for passing dropout_rate themselves inside the submission. This design choice is also motivated by consistency with how label_smoothing is used: submitters pass it directly to model_fn, and the same now applies to dropout_rate.
These two changes require modifying the models.py and workload.py implementations for all frameworks and all workloads except imagenet_resnet, which does not employ dropout.
To be consistent with the original implementation, we define a default dropout_rate value in the model fwd pass. The same value is read from models.py and set as default in the workload's model_fn method.
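As a rough PyTorch illustration of this pattern (a toy sketch, not the actual algoperf code): the default value, TinyModel, and the simplified model_fn signature below are assumptions made for the example.

```python
import torch
import torch.nn.functional as F
from torch import nn

DEFAULT_DROPOUT_RATE = 0.1  # Assumed value, for illustration only.


class TinyModel(nn.Module):
    """Toy model: dropout is functional and its rate is given per call."""

    def __init__(self, dim=4):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x, dropout_rate=DEFAULT_DROPOUT_RATE):
        x = self.linear(x)
        # self.training keeps dropout disabled in eval mode.
        return F.dropout(x, p=dropout_rate, training=self.training)


def model_fn(model, inputs, dropout_rate=DEFAULT_DROPOUT_RATE):
    """Simplified stand-in for a workload's model_fn."""
    return model(inputs, dropout_rate=dropout_rate)


torch.manual_seed(0)
model = TinyModel()
inputs = torch.ones(2, 4)

# Training mode: the submission chooses the rate on every call.
model.train()
no_dropout = model_fn(model, inputs, dropout_rate=0.0)

# Eval mode: the rate is ignored because training=False.
model.eval()
eval_out = model_fn(model, inputs, dropout_rate=0.9)
```

Because the rate is an argument of the call rather than of the constructor, a submission can change it between steps without re-initializing or recompiling the model.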
PyTorch
- When possible, we replace torch.nn.Dropout with torch.nn.functional.dropout.
- When torch.nn.Sequential is used in combination with torch.nn.Dropout, we replace torch.nn.Dropout with CustomDropout and torch.nn.Sequential with SequentialWithDropout. We implement these new modules in algoperf.pytorch_utils. Some workload-specific model classes are also modified to allow usage in combination with these modules.
- We implement tests to check that the new implementations are functionally equivalent to the original ones: https://github.com/mlcommons/algorithmic-efficiency/commit/d8e39b0da371abbcf311ce1d09e06439bd5a0eec
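The Sequential case can be sketched roughly as follows. This is not the code in algoperf.pytorch_utils; the class names DropoutAtCall and SequentialPassingDropout are deliberately different and hypothetical, and the sketch only shows one way a container could forward a call-time dropout_rate to its dropout layers.

```python
import torch
import torch.nn.functional as F
from torch import nn


class DropoutAtCall(nn.Module):
    """Dropout layer whose rate is supplied at call time (illustrative)."""

    def forward(self, x, dropout_rate):
        return F.dropout(x, p=dropout_rate, training=self.training)


class SequentialPassingDropout(nn.Sequential):
    """Sequential that forwards dropout_rate to call-time dropout layers."""

    def forward(self, x, dropout_rate):
        for module in self:
            if isinstance(module, DropoutAtCall):
                x = module(x, dropout_rate)
            else:
                x = module(x)
        return x


torch.manual_seed(0)
block = SequentialPassingDropout(nn.Linear(4, 4), nn.ReLU(), DropoutAtCall())
block.train()
x = torch.ones(2, 4)

kept = block(x, dropout_rate=0.0)     # Dropout is a no-op.
dropped = block(x, dropout_rate=1.0)  # Every activation is zeroed.
```

An equivalence test in the spirit of the one linked above could compare such a block against a plain nn.Sequential with nn.Dropout under the same seed and rate.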
JAX
Priya implemented a custom dropout layer, which supports changes to dropout_rate during training: https://github.com/mlcommons/algorithmic-efficiency/blob/05bff916dee7de6852afc6d95e2564ad57aa77ef/algoperf/jax_utils.py#L13
Dropped aux_dropout_rate
We get rid of aux_dropout_rate, and replace it with dropout_rate.