OpenDiT
fp32 and ZeRO-2
I noticed that the training code stops me from using fp32 with ZeRO-2, which can be traced to update_ema() in train_utils.py:
```python
# In mixed precision, the ZeRO optimizer keeps an fp32 master copy of each
# low-precision working param; the EMA is updated from the master copy.
if param.data.dtype != torch.float32 and isinstance(optimizer, LowLevelZeroOptimizer):
    param_id = id(param)
    master_param = optimizer._param_store.working_to_master_param[param_id]
    param_data = master_param.data
```
Is there any problem with using fp32 and ZeRO-2 while updating the EMA?
I ran into these issues while adapting the flash_attn code path to xformers.
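
For context, the fragment above presumably sits inside an EMA update loop. A minimal sketch of what the surrounding logic might look like, assuming a standard decay-weighted in-place update (the function name `update_ema_sketch`, the `decay` default, the import path, and the fp32 fallback branch are my assumptions, not necessarily OpenDiT's exact code):

```python
import torch
from colossalai.zero.low_level import LowLevelZeroOptimizer  # path may vary by ColossalAI version

@torch.no_grad()
def update_ema_sketch(ema_model, model, optimizer, decay=0.9999):
    """EMA step: ema = decay * ema + (1 - decay) * param, matched by name."""
    ema_params = dict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        if param.data.dtype != torch.float32 and isinstance(optimizer, LowLevelZeroOptimizer):
            # Low-precision ZeRO run: the EMA should track the fp32 master
            # copy kept by the optimizer, not the fp16/bf16 working param.
            master_param = optimizer._param_store.working_to_master_param[id(param)]
            param_data = master_param.data
        else:
            # fp32 params have no separate master copy; read them directly.
            param_data = param.data
        ema_params[name].mul_(decay).add_(param_data, alpha=1 - decay)
```

The question, in other words, is whether the fallback branch that reads `param.data` directly is safe when the fp32 parameters are themselves managed by ZeRO-2.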
ZeRO does not support fp32 for now; you are recommended to use DDP for fp32 training.
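
For what it's worth, the recommended fp32-under-DDP setup would look roughly like this (a minimal sketch; the stand-in `nn.Linear`, torchrun launch, and NCCL backend are assumptions, not OpenDiT's launch script):

```python
import os

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes launch via torchrun, which sets LOCAL_RANK for each process.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = nn.Linear(1024, 1024).float().cuda()  # stand-in for the DiT model
model = DDP(model, device_ids=[local_rank])

# Under plain DDP every rank holds full fp32 params, so the
# `dtype != torch.float32` branch in update_ema() is never taken
# and no master-param lookup is needed.
```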