
Share `apply_strategy` method between AutoUnit and AutoPredictUnit

Open · JKSenthil opened this issue 1 year ago · 2 comments

Summary:

Context:

Both AutoUnit and AutoPredictUnit use the same code block to apply the strategy to the module and check for any incompatibilities:

if strategy:
    if isinstance(strategy, str):
        strategy = _convert_str_to_strategy(strategy)
    if isinstance(strategy, DDPStrategy):
        if torch_compile_params and strategy.static_graph is True:
            # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
            raise RuntimeError(
                "Torch compile requires DDPStrategy's static_graph to be False"
            )
        module = prepare_ddp(module, self.device, strategy)
    elif isinstance(strategy, FSDPStrategy):
        if swa_params:
            raise RuntimeError(
                "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
            )
        # as stated here https://pytorch.org/get-started/pytorch-2.0/
        rank_zero_warn(
            "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
        )
        module = prepare_fsdp(module, self.device, strategy)
else:
    module = module.to(self.device)

If changes are made to this logic, they must be made in both classes, which is easy to miss.

This Diff

Creates a helper function `_apply_strategy_and_check(...)` that applies the strategy to the module, and calls it from both AutoUnit and AutoPredictUnit (other name suggestions are also welcome; see the sketch below).
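
A minimal sketch of what such a helper might look like is below. This is an assumption, not the code that landed in the diff: the exact name, signature, and parameter types are hypothetical, and `DDPStrategy`, `FSDPStrategy`, `prepare_ddp`, `prepare_fsdp`, `rank_zero_warn`, and `_convert_str_to_strategy` are the same identifiers used in the excerpt above, assumed to be in scope:

from typing import Optional, Union

import torch


def _apply_strategy_and_check(
    module: torch.nn.Module,
    device: torch.device,
    strategy: Optional[Union[str, DDPStrategy, FSDPStrategy]] = None,
    torch_compile_params=None,  # assumed to be the unit's torch compile params, or None
    swa_params=None,  # assumed to be the unit's SWA params, or None
) -> torch.nn.Module:
    """Applies `strategy` to `module`, raising on known incompatibilities."""
    if strategy:
        if isinstance(strategy, str):
            strategy = _convert_str_to_strategy(strategy)
        if isinstance(strategy, DDPStrategy):
            # torch.compile is incompatible with DDP static graphs:
            # https://dev-discuss.pytorch.org/t/torchdynamo-update-9-making-ddp-work-with-torchdynamo/860
            if torch_compile_params and strategy.static_graph is True:
                raise RuntimeError(
                    "Torch compile requires DDPStrategy's static_graph to be False"
                )
            module = prepare_ddp(module, device, strategy)
        elif isinstance(strategy, FSDPStrategy):
            if swa_params:
                raise RuntimeError(
                    "Stochastic Weight Averaging is currently not supported with the FSDP strategy"
                )
            # as stated here https://pytorch.org/get-started/pytorch-2.0/
            rank_zero_warn(
                "We recommend setting FSDPStrategy's use_original_params to True when using torch compile."
            )
            module = prepare_fsdp(module, device, strategy)
    else:
        module = module.to(device)
    return module

Both classes could then replace the duplicated block with a single call:

module = _apply_strategy_and_check(
    module, self.device, strategy, torch_compile_params, swa_params
)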

Differential Revision: D48612629

JKSenthil · Aug 23 '23 18:08