paxml

Perform gradient clipping on global batch when using gradient accumulation

Open · ashors1 opened this pull request 1 year ago · 3 comments

Refactoring to allow gradient clipping to be performed on the full batch rather than on subbatches when using ShardedStaticAccumulator. Note that this refactor maintains support for enable_skip_step_on_gradient_anomalies. When using ShardedStaticAccumulator with x subbatches, it requires x+1 gradient norm calculations per global batch (one per subbatch to determine whether the step should be skipped, plus one when applying gradient clipping in the base optimizer update) and one gradient clip per global batch.
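For illustration, here is a minimal JAX/Optax sketch of the pattern described above, not the actual Paxml/Praxis implementation. All names (`loss_fn`, `accumulate_and_clip`, `anomaly_norm`, etc.) are hypothetical; it shows the x per-subbatch norm checks plus the single clip on the accumulated global-batch gradient.

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
    # Hypothetical loss for the sketch: simple mean-squared error.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)

def global_norm(grads):
    # L2 norm over all gradient leaves.
    return jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(grads)))

def accumulate_and_clip(params, subbatches, clip_norm=1.0, anomaly_norm=1e3):
    grad_fn = jax.grad(loss_fn)
    acc = jax.tree_util.tree_map(jnp.zeros_like, params)
    skip_step = jnp.array(False)

    for sub in subbatches:
        g = grad_fn(params, sub)
        # Per-subbatch grad norm check: one of the x norm calculations
        # used to decide whether the step should be skipped.
        skip_step = jnp.logical_or(skip_step, global_norm(g) > anomaly_norm)
        acc = jax.tree_util.tree_map(jnp.add, acc, g)

    # Average the accumulated gradient over the global batch.
    acc = jax.tree_util.tree_map(lambda g: g / len(subbatches), acc)

    # Single clip on the full-batch gradient; the +1 norm calculation
    # happens inside clip_by_global_norm.
    clipper = optax.clip_by_global_norm(clip_norm)
    clipped, _ = clipper.update(acc, clipper.init(params))
    return clipped, skip_step
```

The key design point is that clipping is deferred until after accumulation, so the clip threshold applies to the global-batch gradient norm rather than to each subbatch's norm, while the per-subbatch norms are still available for the skip-step check.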

This PR should be taken together with the corresponding Praxis PR.

ashors1 · Feb 14 '23 18:02