
[Feature Request] Support `different` randomness settings to train an ensemble of models with TorchOpt

Open Benjamin-eecs opened this issue 1 year ago • 3 comments

Motivation

We recently used TorchOpt as a functional optimizer API, as mentioned in the functorch parallel training example, to make optimization batchable when training small neural networks on one GPU with functorch.vmap.

With TorchOpt, we can mimic the JAX implementation and apply vmap to the init function.

JAX:

# assumed imports: jax.numpy as jnp, jax.random as jr, flax's optim;
# classifier_fns is a Flax module defined elsewhere in the example
def init_fn(input_shape, seed):
    rng = jr.PRNGKey(seed)                                     # jr = jax.random
    dummy_input = jnp.ones((1, *input_shape))
    params = classifier_fns.init(rng, dummy_input)['params']   # do shape inference
    optimizer_def = optim.Adam(learning_rate=1e-3)
    optimizer = optimizer_def.create(params)
    return optimizer
parallel_init_fn = jax.vmap(init_fn, in_axes=(None, 0))
model_states = parallel_init_fn((2,), seeds)

TorchOpt + functorch:

# assumes: import torch, functorch, torchopt; MLPClassifier and DEVICE as in the functorch example
def init_fn(model_idx):  # model_idx is unused; it only gives vmap a dimension to map over
    _, weights = functorch.make_functional(MLPClassifier().to(DEVICE))
    opt_state = torchopt.adam(lr=0.2).init(weights)
    return weights, opt_state
parallel_init_fn = functorch.vmap(init_fn, randomness='same') # only 'same' works
batched_weights, opt_state = parallel_init_fn(torch.ones(num_models, 1))

instead of using combine_state_for_ensemble:

def init_fn(num_models):
    models = [MLPClassifier().to(DEVICE) for _ in range(num_models)]
    _, params, _ = combine_state_for_ensemble(models)
    return params
batched_weights = init_fn(num_models=2)

However, any other randomness setting in functorch.vmap(init_fn) raises an error (e.g. randomness='different'):

Traceback (most recent call last):
  File "parallel_train_torchopt.py", line 196, in <module>
    functorch_original.test_parallel_train_step_fn(num_models=2)
  File "parallel_train_torchopt.py", line 136, in test_parallel_train_step_fn
    weights, opt_state = parallel_init_fn(torch.ones(num_models, 1))
  File "/home/benjamin/miniconda3/envs/torchopt/lib/python3.8/site-packages/functorch/_src/vmap.py", line 365, in wrapped
    batched_outputs = func(*batched_inputs, **kwargs)
  File "parallel_train_torchopt.py", line 109, in init_fn
    _, weights = make_functional(MLPClassifier().to(DEVICE))
  File "parallel_train_torchopt.py", line 49, in __init__
    self.fc1 = nn.Linear(2, self.hidden_dim)
  File "/home/benjamin/miniconda3/envs/torchopt/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 101, in __init__
    self.reset_parameters()
  File "/home/benjamin/miniconda3/envs/torchopt/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 107, in reset_parameters
    init.kaiming_uniform_(self.weight, a=math.sqrt(5))
  File "/home/benjamin/miniconda3/envs/torchopt/lib/python3.8/site-packages/torch/nn/init.py", line 412, in kaiming_uniform_
    return tensor.uniform_(-bound, bound)
RuntimeError: vmap: Cannot ask for different inplace randomness on an unbatched tensor. This will appear like same randomness. If this is necessary for your usage, please file an issue with functorch.

functorch.vmap(init_fn, randomness='same') gives identical initializations for every net in the ensemble, which is undesirable when we want to train ensembles and average across random seeds. Supporting randomness='different' in functorch.vmap(init_fn) is therefore a needed feature for this kind of usage.
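
A minimal sketch of the symptom (assuming MLPClassifier and DEVICE as in the functorch parallel training example; the names and the check below are illustrative, not from the example):

import torch
import functorch

num_models = 2

def init_fn(model_idx):
    # model_idx only gives vmap a dimension to map over
    _, weights = functorch.make_functional(MLPClassifier().to(DEVICE))
    return weights

parallel_init_fn = functorch.vmap(init_fn, randomness='same')
batched_weights = parallel_init_fn(torch.ones(num_models, 1))

# With randomness='same', every ensemble member gets the same initialization,
# so the slices along the model dimension are identical.
for w in batched_weights:
    assert torch.allclose(w[0], w[1])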

cc @waterhorse1 @JieRen98 @XuehaiPan

Solution

https://github.com/metaopt/TorchOpt/pull/32 can be run with functorch.vmap(init_fn, randomness='different').
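
For reference, the intended usage after the fix would look roughly like this (same init_fn as in the Motivation section; the final check only expresses the expected behaviour and has not been verified here):

parallel_init_fn = functorch.vmap(init_fn, randomness='different')
batched_weights, opt_state = parallel_init_fn(torch.ones(num_models, 1))

# Each ensemble member should now get its own random initialization,
# so corresponding slices along the model dimension are expected to differ.
assert not torch.allclose(batched_weights[0][0], batched_weights[0][1])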

Resource

Checklist

  • [x] I have checked that there is no similar issue in the repo. There is one, #909, about improving combine_state_for_ensemble for initializing an ensemble of models, and a related issue, #782, asking for an implementation of this usage, but my request focuses on a specific use case that requires this feature.

Benjamin-eecs avatar Aug 06 '22 10:08 Benjamin-eecs

Thanks for the thorough writeup @Benjamin-eecs! To check a couple of things: is this a current bottleneck for your examples? We had been under the assumption that the training would be much more expensive than the initialization, but that may not be true (or it may be fair to say that we're losing out on performance by not doing this).

We're also currently looking at different ways that JAX libraries build neural nets and this is a great axis I hadn't thought of before. It looks like you might be using Flax or Haiku and I was wondering if you had tried this with Equinox at all?

cc @zou3519 This seems to be the same thing that the federated learning people were asking for. I forget whether we got a clear answer for them.

samdow avatar Aug 09 '22 17:08 samdow

Hi there @samdow , thanks for your quick and detailed feedback.

is this a current bottleneck for your examples?

I think it can be called a bottleneck in some ways. We can certainly initialize the ensemble of models and optimizers with a for-loop, but our TorchOpt example mainly wants to show that functorch.vmap can be used for both initialization and training of an ensemble of models. Also, in our specific use case, where we want to repeat the same training process with different seeds or hyperparameters using functorch.vmap, we think it would be better if the user could write init_fn functionally, as a function of a list of seeds or a list of hyperparameters. For now, we can only initialize a set of models with the same weights.
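
For context, a rough sketch of the for-loop workaround we currently fall back to (init_ensemble and seeds are illustrative names, not part of the example; MLPClassifier and DEVICE as above):

import torch
import functorch
import torchopt

def init_ensemble(seeds):
    weights_list, opt_states = [], []
    for seed in seeds:
        torch.manual_seed(seed)  # per-model randomness
        _, weights = functorch.make_functional(MLPClassifier().to(DEVICE))
        weights_list.append(weights)
        opt_states.append(torchopt.adam(lr=0.2).init(weights))
    # stack per-parameter so a vmap'd training step can consume the ensemble
    batched_weights = tuple(torch.stack(ws) for ws in zip(*weights_list))
    return batched_weights, opt_states

batched_weights, opt_states = init_ensemble(seeds=[0, 1])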

It looks like you might be using Flax or Haiku and I was wondering if you had tried this with Equinox at all?

I am not sure I fully understood; the JAX code snippet in the writeup is only meant to show that our TorchOpt example turns the functorch example into a JAX-style one, with an extra optimizer such as Adam rather than SGD.

Benjamin-eecs avatar Aug 11 '22 09:08 Benjamin-eecs

Also, in our specific use case, where we want to repeat the same training process with different seeds or hyperparameters using functorch.vmap, we think it would be better if the user could write init_fn functionally, as a function of a list of seeds or a list of hyperparameters.

To be clear, if this is the end goal, it will probably always be easier to write this as a for loop. Most of the hyperparameters are scalar values and right now we can't vmap over lists or tensors of scalar values (1D tensors that we vmap over are going to be treated as scalar tensors instead of scalars).

As an example, if we had an ensemble of models like the ones in the TorchOpt PR but where the hidden dimension was being changed:

class MLP(nn.Module):
    def __init__(self, hidden_dim=32, n_classes=2):
        ...
        self.fc1 = nn.Linear(2, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
    ...

we would never be able to vmap over different values for hidden_dim.
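
To make the scalar point concrete, a tiny illustrative sketch (names made up): inside vmap each mapped element arrives as a 0-d tensor, so elementwise math works, but it can't be used where a Python int is required, such as a layer width.

import torch
import functorch

def build(hidden_dim):
    # hidden_dim is a 0-d batched tensor here; elementwise ops are fine,
    # but nn.Linear(2, hidden_dim) would need a Python int and can't work under vmap
    return hidden_dim * 2

hidden_dims = torch.tensor([16, 32, 64])
print(functorch.vmap(build)(hidden_dims))  # => a tensor containing [32, 64, 128]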

the Jax code snippet I showed in the writeup just to present that our TorchOpt example

I see! Thanks for that clarification. I saw classifier_fns.init in the code snippet and assumed 😄

samdow avatar Aug 15 '22 13:08 samdow