`DynamicScale` behaves unexpectedly when computing per-sample gradients with `vmap`.
When running `jax.vmap`, e.g. to compute per-sample gradients, the `fin_steps` and `scale` attributes of `DynamicScale` might become arrays, leading to an error in the next training step if not handled manually. The raised `TypeError` does not directly hint at the actual problem, namely a non-scalar `scale` attribute.
System information
- jax==0.4.28 and flax==0.8.5
Problem you have encountered:
Due to `self.scale` becoming an array in the output of the first `vmap` call, the `loss_wrapper` inside `DynamicScale` also starts to return an array instead of a scalar.
What you expected to happen:
The `scale` and `fin_steps` attributes should be either averaged or enforced to be scalars, and thus not cause the `TypeError`.
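For example, an explicit scalar check along these lines (a sketch of mine, not existing flax behaviour; the function name is made up) would surface the problem far more clearly than the current `TypeError`:

```python
import jax.numpy as jnp

def check_scalar_scale(ds):
  # Hypothetical guard: fail fast if a vmapped step left batched attributes behind.
  if jnp.ndim(ds.scale) != 0 or jnp.ndim(ds.fin_steps) != 0:
    raise ValueError(
      "DynamicScale.scale and fin_steps must be scalars, got shapes "
      f"{jnp.shape(ds.scale)} and {jnp.shape(ds.fin_steps)}; "
      "did you vmap DynamicScale.value_and_grad?"
    )
  return ds
```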
Logs, error messages, etc:
```
File ~/miniforge3/envs/test/lib/python3.12/site-packages/flax/training/dynamic_scale.py:132, in DynamicScale.value_and_grad.<locals>.grad_fn_wrapper(*args)
    131 def grad_fn_wrapper(*args):
--> 132   aux, grad = grad_fn(*args)
    133   aux = (aux[0] / self.scale, aux[1]) if has_aux else aux / self.scale
    135   grad = jax.tree_util.tree_map(
    136     lambda g: jnp.asarray(g, jnp.float32) / self.scale, grad
    137   )

TypeError: Gradient only defined for scalar-output functions. Output had shape: (32,).
```
Steps to reproduce:
```python
from typing import Sequence

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import dynamic_scale


class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x


def cross_entropy_loss(params, model, image, label):
  """Loss function used for training."""
  logits = model.apply({"params": params}, image)
  loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, label))
  return loss, logits


model = MLP([12, 8, 4])
input = jnp.ones((32, 10))
labels = jnp.ones((32,), dtype=int)
variables = model.init(jax.random.key(0), input)
output = model.apply(variables, input)

ds = dynamic_scale.DynamicScale()

# 1st batch
ds, is_fin, (loss, logits), per_sample_grads = jax.vmap(
  ds.value_and_grad(cross_entropy_loss, has_aux=True),
  in_axes=(None, None, 0, 0),
)(variables["params"], model, input, labels)

# 2nd batch
ds, is_fin, (loss, logits), per_sample_grads = jax.vmap(
  ds.value_and_grad(cross_entropy_loss, has_aux=True),
  in_axes=(None, None, 0, 0),
)(variables["params"], model, input, labels)
```
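The second call fails because the first call already returned a batched `DynamicScale`. Inspecting it between the two batches (these prints are my addition, not part of the original script) shows attributes with exactly the `(32,)` shape reported in the traceback:

```python
# Run between the 1st and 2nd batch above:
print(jnp.shape(ds.scale))      # (32,) instead of a scalar
print(jnp.shape(ds.fin_steps))  # (32,)
```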
This can be fixed manually with `ds = ds.replace(fin_steps=ds.fin_steps.mean(), scale=ds.scale.mean())` after each step.
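As a stop-gap, that reduction can be wrapped in a small helper and applied to the result of every vmapped step (the helper name is mine; the `.mean()` reduction simply mirrors the manual fix above):

```python
import jax.numpy as jnp

def collapse_dynamic_scale(ds):
  # Reduce the batched per-sample attributes back to scalars so that the next
  # DynamicScale.value_and_grad call sees a scalar loss scale again.
  return ds.replace(
    fin_steps=jnp.asarray(ds.fin_steps).mean(),
    scale=jnp.asarray(ds.scale).mean(),
  )

# e.g. after each vmapped step:
# ds = collapse_dynamic_scale(ds)
```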
In my opinion, this should be handled automatically / enforced within `DynamicScale`.