paxml
Perform gradient clipping on global batch when using gradient accumulation
Refactors gradient clipping so that it is performed on the full batch rather than on sub-batches when using ShardedStaticAccumulator. Note that this refactor maintains support for enable_skip_step_on_gradient_anomalies. With x sub-batches, it requires x + 1 gradient-norm calculations per global batch (once per sub-batch to determine whether the step should be skipped, and once when applying gradient clipping in the base optimizer update) and a single gradient clip per global batch.
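To illustrate the accounting above, here is a minimal JAX sketch of the idea, not the actual Paxml/Praxis implementation: the helper names (`global_norm`, `clip_by_global_norm`, `accumulate_then_clip`) and the anomaly check are hypothetical stand-ins. It computes one gradient norm per sub-batch for step skipping, plus one more inside the single full-batch clip, i.e. x + 1 norm computations and one clip per global batch.

```python
import jax
import jax.numpy as jnp

def global_norm(tree):
    """L2 norm over all leaves of a gradient pytree."""
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))

def clip_by_global_norm(grads, max_norm):
    """Scale grads so their global norm does not exceed max_norm."""
    norm = global_norm(grads)
    scale = jnp.minimum(1.0, max_norm / jnp.maximum(norm, 1e-12))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)

def accumulate_then_clip(sub_batch_grads, max_norm, anomaly_threshold):
    """Accumulate x sub-batch grads, then clip once on the full batch.

    Hypothetical sketch: x norm checks (one per sub-batch, mirroring
    enable_skip_step_on_gradient_anomalies) plus one norm computation
    inside the final clip, for x + 1 total.
    """
    # Per-sub-batch norms are used only to decide whether to skip the step.
    skip_step = any(
        (not bool(jnp.isfinite(global_norm(g))))
        or float(global_norm(g)) > anomaly_threshold
        for g in sub_batch_grads
    )
    # Average sub-batch gradients to form the global-batch gradient.
    mean_grads = jax.tree_util.tree_map(
        lambda *gs: sum(gs) / len(gs), *sub_batch_grads
    )
    # Single clip on the full-batch gradient, as in the base optimizer update.
    return skip_step, clip_by_global_norm(mean_grads, max_norm)
```

Clipping after accumulation means the clip threshold applies to the true global-batch gradient, so the result no longer depends on how the batch happens to be split into sub-batches.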
This PR should be taken together with the corresponding Praxis PR.