share `apply_strategy` method between autounit and autopredictunit
Summary:
Context:
Both AutoUnit and AutoPredictUnit use the same code block to apply the strategy to the module and check for any incompatibilities:
```python
if strategy:
    if isinstance(strategy, str):
        strategy = _convert_str_to_strategy(strategy)
    if isinstance(strategy, DDPStrategy):
        if torch_compile_params and strategy.static_graph is True:
            # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
            raise RuntimeError(
                "Torch compile requires DDPStrategy's static_graph to be False"
            )
        module = prepare_ddp(module, self.device, strategy)
    elif isinstance(strategy, FSDPStrategy):
        if swa_params:
            raise RuntimeError(
                "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
            )
        # as stated here https://pytorch.org/get-started/pytorch-2.0/
        rank_zero_warn(
            "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
        )
        module = prepare_fsdp(module, self.device, strategy)
else:
    module = module.to(self.device)
```
If changes are made to this logic, they must be made in both of those classes, which can easily be missed.
This Diff:
Creates a helper function `_apply_strategy_and_check(...)` to apply the strategy to the module, and calls this function in both AutoUnit and AutoPredictUnit (other name suggestions are also welcome).
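For reference, a minimal sketch of what the extracted helper could look like. The signature, parameter types, and import paths below are assumptions for illustration, not the exact code in this diff:

```python
# Sketch only: the helper's name comes from this diff, but the import
# locations, parameter types, and exact signature here are assumptions.
from typing import Any, Optional, Union

import torch

from torchtnt.utils.prepare_module import (  # assumed import locations
    DDPStrategy,
    FSDPStrategy,
    Strategy,
    prepare_ddp,
    prepare_fsdp,
)
from torchtnt.utils.rank_zero_log import rank_zero_warn  # assumed location


def _apply_strategy_and_check(
    module: torch.nn.Module,
    device: torch.device,
    strategy: Optional[Union[Strategy, str]],
    torch_compile_params: Optional[Any] = None,  # TorchCompileParams in the real code
    swa_params: Optional[Any] = None,  # SWAParams in the real code
) -> torch.nn.Module:
    """Apply `strategy` to `module`, raising on known incompatibilities.

    Shared by AutoUnit and AutoPredictUnit so the checks live in one place.
    """
    if strategy:
        if isinstance(strategy, str):
            # Same private conversion helper the units already use; where it
            # is imported from in the real code is an assumption.
            strategy = _convert_str_to_strategy(strategy)
        if isinstance(strategy, DDPStrategy):
            if torch_compile_params and strategy.static_graph is True:
                # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
                raise RuntimeError(
                    "Torch compile requires DDPStrategy's static_graph to be False"
                )
            module = prepare_ddp(module, device, strategy)
        elif isinstance(strategy, FSDPStrategy):
            if swa_params:
                raise RuntimeError(
                    "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
                )
            # as stated here https://pytorch.org/get-started/pytorch-2.0/
            rank_zero_warn(
                "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
            )
            module = prepare_fsdp(module, device, strategy)
    else:
        module = module.to(device)
    return module
```

Each unit can then replace its inlined block with a single call, e.g. `module = _apply_strategy_and_check(module, self.device, strategy, torch_compile_params, swa_params)`, so any future strategy check only needs to change in one place.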
Differential Revision: D48612629