optax
Passing arguments to train multiple models in parallel
Hi,
I want to perform a grid search over different arguments, training multiple models in parallel with optax and flax. My initial idea was to pass an array of learning rates to an initialization function using vmap, but this results in a side-effect transformation error.
What is the best way to pass a list of arguments, and can this be solved? The issue seems to be related to the adamw optimizer, which I believe modifies the learning rate parameter.
I have attached a reduced example of my code:
```python
import hydra
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

# get_dataloaders and FCNN_2 are defined elsewhere in the project (not shown).


def calculate_loss_acc(state, params, batch):
    data_input, labels = batch
    logits = state.apply_fn(params, data_input)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    acc = jnp.mean(jnp.argmax(logits, -1) == labels)
    return loss, acc


@jax.jit  # Jit the function for efficiency
def train_step(state, batch):
    # Gradient function
    grad_fn = jax.value_and_grad(calculate_loss_acc,  # Function to calculate the loss
                                 argnums=1,  # Parameters are second argument of the function
                                 has_aux=True  # Function has additional outputs, here accuracy
                                 )
    # Determine gradients for current model, parameters and batch
    (loss, acc), grads = grad_fn(state, state.params, batch)
    # Perform parameter update with gradients and optimizer
    state = state.apply_gradients(grads=grads)
    # Return state and any other value we might want
    return state, loss, acc


def initialization(model, learning_rate, input_size, seed):
    rng = jax.random.PRNGKey(seed)
    rng, init_rng = jax.random.split(rng)
    dummy_input = jax.random.normal(init_rng, (8, input_size))  # Batch size 8, input_size features
    params = model.init(init_rng, dummy_input)
    model.apply(params, dummy_input)
    optimizer = optax.adamw(learning_rate=learning_rate)
    model_state = train_state.TrainState.create(apply_fn=model.apply,
                                                params=params,
                                                tx=optimizer)
    return model_state


@hydra.main(version_base=None, config_name="main", config_path="config")
def main(cfg) -> None:
    seed = 0
    num_epochs = 1
    input_size = 194
    output_size = 97
    learning_rates = jnp.array([0.01, 0.1])
    train_dataloader, test_dataloader = get_dataloaders(cfg)
    model = FCNN_2(num_hidden=1000,
                   num_outputs=output_size,
                   activation=cfg.model.parameters.activation)
    parallel_init_fn = jax.vmap(initialization, in_axes=(None, 0, None, None))
    parallel_train_step_fn = jax.vmap(train_step, in_axes=(0, None))
    params = parallel_init_fn(model, learning_rates, input_size, seed)
    for epoch in range(num_epochs):
        # Run training on epoch
        for batch in train_dataloader:
            params, loss, acc = parallel_train_step_fn(params, batch)
            print(loss)
```
Hello @kclauw,
- What error do you get exactly?
- Why are you saying that the issue is with adamw? Adamw does not modify the learning rate internally. Have you tried with sgd and did that produce the same error?
Thanks for reaching out
Hi,
Thanks
- I am still learning JAX, coming from PyTorch, but my understanding of the error is that something is changing the value of the learning rate parameter in the initialization function:
```
    params, loss, acc = parallel_train_step_fn(params, batch)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type int32[] wrapped in a BatchTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.Detail: Different traces at same level: Traced<ShapedArray(int32[], weak_type=True)>with<BatchTrace(level=1/0)> with
  val = Array([0, 0], dtype=int32, weak_type=True)
  batch_dim = 0, BatchTrace(level=1/0)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
```
- The code works when using a fixed learning rate. However, when the learning rate is passed through vmap, it gives the error. Changing to SGD did not resolve the issue. Based on this, I figured the optax optimizer might be changing the learning rate.
- I also passed a list of seeds as an argument to initialization (the seed is not used by the optimizer). This works fine, so the issue seems to happen only when passing the learning rate in combination with any optax optimizer.
I looked at the code of adamw:
```python
def adamw(
    learning_rate: base.ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[Any] = None,
    weight_decay: float = 1e-4,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
    *,
    nesterov: bool = False,
) -> base.GradientTransformation:
  return combine.chain(
      transform.scale_by_adam(
          b1=b1,
          b2=b2,
          eps=eps,
          eps_root=eps_root,
          mu_dtype=mu_dtype,
          nesterov=nesterov,
      ),
      transform.add_decayed_weights(weight_decay, mask),
      transform.scale_by_learning_rate(learning_rate),
  )


def scale_by_learning_rate(
    learning_rate: base.ScalarOrSchedule,
    *,
    flip_sign: bool = True,
) -> base.GradientTransformation:
  m = -1 if flip_sign else 1
  if callable(learning_rate):
    return scale_by_schedule(lambda count: m * learning_rate(count))
  return scale(m * learning_rate)
```
The problem seems to be that adamw (and SGD, etc.) changes the learning rate via `transform.scale_by_learning_rate(learning_rate)`; see `scale(m * learning_rate)`.
What would be the best way to pass arguments that vary across the vmap, if this is even possible? I expect this will also become a problem when passing weight decay values.
When dealing with parameters that vary across vmap, like learning rates or weight decay values, you can use partial function application or closures. This lets you fix certain arguments while leaving others flexible. For instance, you can create a function that takes only the parameters that remain constant during vmap, then partially apply it with the varying parameters inside the vmapped function. This ensures that only the necessary parameters are passed through vmap, avoiding unexpected tracer errors.
Hello @kclauw,
Sorry for the delayed answer.
- It would help if you could make the example minimal while still reproducing the same error (some dependencies are not defined in what you sent). You may also try to trace the leak as the error message suggests, just to be sure. It's not clear to me yet that the learning rate is really the culprit here.
- If it truly is the learning rate, one quick workaround would be to use `optax.inject_hyperparams`. You would instantiate the optimizer outside the vmap as `opt = optax.inject_hyperparams(optax.adamw)(learning_rate=1.)`, and inside the vmap you would call the optimizer's init, `state = opt.init(params)`. In the resulting state, you can then set the learning rate of your choice with `state = optax.tree_utils.tree_set(state, learning_rate=your_learning_rate)`. The optimizer would then run with the learning rate you chose in the vmap. Happy to try it out to be sure, but I'd need a minimal example for that.