
Add custom remat policy

Open raymondzouu opened this issue 1 year ago • 0 comments

To use this remat policy, set the remat_policy field to custom, then list in the remat_saved_tensors field the labeled tensors from the forward pass that should be saved in HBM instead of being recomputed during the backward pass. Existing labeled tensors include query_proj, key_proj, and value_proj (https://github.com/google/maxtext/blob/main/MaxText/layers.py#L332-L341).

For example, to save only the key and value projections, modify the fields in base.yml as follows:

remat_policy: 'custom'
remat_saved_tensors: ['key_proj', 'value_proj']

Users can also label their own tensors, add those labels to remat_saved_tensors, and experiment with the trade-off between recomputation and memory use.
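For illustration, here is a minimal JAX sketch of how labeling and saving by name can work. It is not the MaxText implementation: it assumes the custom policy behaves like jax.checkpoint_policies.save_only_these_names, and the function attention_loss, the parameter names wq/wk/wv, and the toy shapes are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def attention_loss(params, x):
    # Hypothetical toy layer: label each projection so a policy can refer to it.
    q = checkpoint_name(x @ params["wq"], "query_proj")
    k = checkpoint_name(x @ params["wk"], "key_proj")
    v = checkpoint_name(x @ params["wv"], "value_proj")
    weights = jax.nn.softmax(q @ k.T / jnp.sqrt(q.shape[-1]), axis=-1)
    return jnp.sum(weights @ v)


# Mirrors remat_saved_tensors: ['key_proj', 'value_proj'] from the example above.
remat_saved_tensors = ["key_proj", "value_proj"]

# Assumption: only the named tensors are kept from the forward pass;
# everything else (including query_proj) is recomputed in the backward pass.
policy = jax.checkpoint_policies.save_only_these_names(*remat_saved_tensors)
remat_loss = jax.checkpoint(attention_loss, policy=policy)

kq, kk, kv, kx = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "wq": 0.1 * jax.random.normal(kq, (8, 8)),
    "wk": 0.1 * jax.random.normal(kk, (8, 8)),
    "wv": 0.1 * jax.random.normal(kv, (8, 8)),
}
x = jax.random.normal(kx, (4, 8))

# Rematerialization changes what is stored, not the gradient values.
grads_remat = jax.grad(remat_loss)(params, x)
grads_plain = jax.grad(attention_loss)(params, x)
```

A new label only requires wrapping the tensor in checkpoint_name with a fresh name and adding that name to the saved list.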

raymondzouu, Jul 18 '23 16:07