Add custom remat policy
To use this remat policy, set the `remat_policy` field to `custom` and list in the `remat_saved_tensors` field the labeled tensors that should be saved in HBM during the forward pass instead of being recomputed during the backward pass. Existing labeled tensors include `query_proj`, `key_proj`, and `value_proj` (https://github.com/google/maxtext/blob/main/MaxText/layers.py#L332-L341).
For example, to save only the key and value projections, modify the fields in `base.yml` as follows:
```yaml
remat_policy: 'custom'
remat_saved_tensors: ['key_proj', 'value_proj']
```
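For readers less familiar with name-based remat in JAX, here is a minimal sketch of how these two fields could map onto a JAX checkpoint policy. It assumes the labels are attached with `jax.ad_checkpoint.checkpoint_name` and that a `custom` policy corresponds to `jax.checkpoint_policies.save_only_these_names`; the `config` dict, `build_remat_policy`, `attention`, and `loss` are hypothetical stand-ins, not MaxText's actual code. The final lines check that remat changes only what is stored, not the gradients.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Hypothetical stand-in for the two fields read from base.yml.
config = {
    "remat_policy": "custom",
    "remat_saved_tensors": ["key_proj", "value_proj"],
}


def build_remat_policy(cfg):
    """Map the config fields to a JAX remat policy (illustrative sketch only)."""
    if cfg["remat_policy"] == "custom":
        # Keep only activations tagged with these names; recompute everything else.
        return jax.checkpoint_policies.save_only_these_names(*cfg["remat_saved_tensors"])
    # Fallback for this sketch: save nothing, recompute everything.
    return jax.checkpoint_policies.nothing_saveable


def attention(x, wq, wk, wv):
    # Tag each projection so a name-based policy can refer to it.
    q = checkpoint_name(x @ wq, "query_proj")
    k = checkpoint_name(x @ wk, "key_proj")
    v = checkpoint_name(x @ wv, "value_proj")
    probs = jax.nn.softmax(q @ k.T / jnp.sqrt(q.shape[-1]), axis=-1)
    return probs @ v


def loss(params, x):
    return jnp.sum(attention(x, *params) ** 2)


# Wrap the loss so the backward pass follows the configured policy.
remat_loss = jax.checkpoint(loss, policy=build_remat_policy(config))

key = jax.random.PRNGKey(0)
kx, kq, kk, kv = jax.random.split(key, 4)
x = jax.random.normal(kx, (4, 8))
params = tuple(jax.random.normal(k, (8, 8)) for k in (kq, kk, kv))

grads_plain = jax.grad(loss)(params, x)
grads_remat = jax.grad(remat_loss)(params, x)
```

With this policy, `key_proj` and `value_proj` are kept from the forward pass while `query_proj` and the softmax are recomputed, trading extra compute for lower HBM usage.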
Users can also label their own tensors to save and experiment with the trade-off between recomputation (compute) and HBM usage (memory).
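As a sketch of labeling a new tensor, the example below tags an MLP hidden activation with a hypothetical name, `mlp_hidden`, and uses `jax.ad_checkpoint.print_saved_residuals` to list which values the backward pass keeps. The `mlp` function and the `mlp_hidden` label are illustrative assumptions; in a real setup the new label would presumably also be added to `remat_saved_tensors`.

```python
import contextlib
import io

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name, print_saved_residuals


def mlp(params, x):
    w_in, w_out = params
    # Hypothetical new label on the MLP hidden activation.
    h = checkpoint_name(jax.nn.relu(x @ w_in), "mlp_hidden")
    return jnp.sum((h @ w_out) ** 2)


# In this sketch, equivalent to remat_saved_tensors: ['mlp_hidden'].
policy = jax.checkpoint_policies.save_only_these_names("mlp_hidden")
remat_mlp = jax.checkpoint(mlp, policy=policy)

key = jax.random.PRNGKey(0)
kx, k1, k2 = jax.random.split(key, 3)
x = jax.random.normal(kx, (4, 8))
params = (jax.random.normal(k1, (8, 16)), jax.random.normal(k2, (16, 8)))

# Capture the report of residuals saved from the forward pass.
buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    print_saved_residuals(remat_mlp, params, x)
report = buf.getvalue()
print(report)
```

Each line of the report gives a residual's shape and its origin; with this policy the `mlp_hidden` activation should appear alongside the function arguments, while other intermediates are recomputed.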