[FEATURE]: Confusing variable names in LowLevelZeroOptimizer class
Describe the feature
I would like to suggest a feature that updates the variable names related to Mixed Precision and Zero Optimizer in the LowLevelZeroOptimizer class.
Is your feature request related to a problem? Please describe.
The current variable names in the LowLevelZeroOptimizer class assume that forced_dtype will be FP16. However, if forced_dtype is not given (default value), the optimizer operates as a regular Zero Optimizer without Mixed Precision. This can be confusing for users and may lead to incorrect assumptions about the code's behavior.
Describe the solution you'd like
1. Add a check for `forced_dtype` to determine whether it is FP16 (`torch.cuda.HalfTensor`, `torch.cuda.BFloat16Tensor`) or None.
2. Update the variable names to better reflect their purpose and to group them:

   2.1. param_groups: `_fp16_param_groups` stores all parameter group information (it may not even be FP16). In contrast, `_fp32_flat_param_groups_of_current_rank` is a sharded flat parameter group. How about renaming those variables to something like `full_optim_state_dict` and `flatten_sharded_optim_state_dict`, like the PyTorch FSDP state dict variable names? (ref: https://pytorch.org/docs/stable/fsdp.html)
   https://github.com/hpcaitech/ColossalAI/blob/2e16f842a9e5b1fb54e7e41070e9d2bb5cd64d7c/colossalai/zero/sharded_optim/low_level_optim.py#L92
   https://github.com/hpcaitech/ColossalAI/blob/2e16f842a9e5b1fb54e7e41070e9d2bb5cd64d7c/colossalai/zero/sharded_optim/low_level_optim.py#L93

   2.2. `fp16_param`, `fp16_avg_grads`: If there is no `forced_dtype` or if it is not FP16, the current names are invalid. Please suggest more suitable names.

   2.3. Methods of ParamStore: The current bookkeeping doesn't seem to have anything to do with FP16. Therefore, it may be better to rename the relevant methods to something more generic that reflects their actual purpose.
   https://github.com/hpcaitech/ColossalAI/blob/2e16f842a9e5b1fb54e7e41070e9d2bb5cd64d7c/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py
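The check proposed in the first point could look roughly like the sketch below. It is only an illustration, not the existing ColossalAI code: the helper name `uses_half_precision` and the constant `HALF_PRECISION_TENSOR_TYPES` are hypothetical, and I am assuming the forced type is compared by the tensor type string (the kind of string `Tensor.type()` returns), so the sketch runs without importing torch.

```python
from typing import Optional

# Tensor type names this issue treats as FP16. Assumption: the forced type is
# available as the string that torch.Tensor.type() returns for CUDA tensors.
HALF_PRECISION_TENSOR_TYPES = frozenset({
    "torch.cuda.HalfTensor",
    "torch.cuda.BFloat16Tensor",
})


def uses_half_precision(forced_tensor_type: Optional[str]) -> bool:
    """Hypothetical helper: True only when a half-precision type is forced."""
    if forced_tensor_type is None:
        # Default case: nothing is forced, so the optimizer acts as a regular
        # Zero Optimizer without mixed precision.
        return False
    # Any other forced type (e.g. a float32 tensor type) is not mixed precision.
    return forced_tensor_type in HALF_PRECISION_TENSOR_TYPES


if __name__ == "__main__":
    for name in (None, "torch.cuda.HalfTensor", "torch.cuda.FloatTensor"):
        print(name, uses_half_precision(name))
```

With a flag like this, the FP16-specific names would only need to apply when the check returns True, and generic names could be used otherwise.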
I'm posting this in the hope that this framework will evolve further and that no one else will have to go through the confusion I experienced. If you give me feedback, I will open a suitable PR. Thank you 😀