
[WIP] Use get_name in forward hook

Open · NihalHarish opened this pull request 5 years ago · 1 comment

Description of bug:

  • Users can wrap their modules with helper classes such as DataParallelCriterion or DataParallel, for example:
custom_loss_module = CustomLossModule()
parallel_custom_loss_module = DataParallelCriterion(custom_loss_module)
  • The smdebug hook registers each module like this:
        for name, submodule in module.named_modules():
            assert submodule not in self.module_set, f"Don't register module={module} twice"
            submodule._module_name = name
            self.module_set.add(submodule)
        module._module_name = module._get_name()
        self.module_set.add(module)

The problem with the line submodule._module_name = name is that the _module_name attribute is attached to parallel_custom_loss_module (the wrapper) and not to the nested custom_loss_module.

  • Instead, we would have needed to set the name on the wrapped module:
submodule.module._module_name = name
  • When the helper module's forward is called, it internally calls the wrapped module, so the forward hook receives the nested module rather than the wrapper:
    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
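The delegation can be reproduced without PyTorch. In this sketch, ToyModule, CustomLossModule, and ToyDataParallel are hypothetical stand-ins: ToyModule models only nn.Module's forward-hook calling convention, hook(module, inputs, output), and _get_name(), and ToyDataParallel models only the no-device branch of forward shown above. Assuming, as described, that only the wrapper carries _module_name, the hook registered on the nested module receives that module and finds no name:

```python
class ToyModule:
    """Hypothetical stand-in for torch.nn.Module: forward hooks and _get_name() only."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # Same calling convention as PyTorch: hook(module, inputs, output)
        self._forward_hooks.append(hook)

    def _get_name(self):
        # Like nn.Module._get_name(): the class name
        return type(self).__name__

    def __call__(self, *inputs):
        output = self.forward(*inputs)
        # Each hook receives the module it was registered on
        for hook in self._forward_hooks:
            hook(self, inputs, output)
        return output


class CustomLossModule(ToyModule):
    def forward(self, prediction, target):
        return (prediction - target) ** 2


class ToyDataParallel(ToyModule):
    """Hypothetical wrapper modeling only the no-device branch of DataParallel.forward."""

    def __init__(self, module):
        super().__init__()
        self.module = module
        self.device_ids = []  # no devices, so forward delegates directly

    def forward(self, *inputs):
        if not self.device_ids:
            return self.module(*inputs)
        raise NotImplementedError("multi-device path is not modeled")


seen_names = []


def forward_hook(module, inputs, output):
    # Reads the registration-time attribute; None means "no name found"
    seen_names.append(getattr(module, "_module_name", None))


custom_loss_module = CustomLossModule()
custom_loss_module.register_forward_hook(forward_hook)
parallel_custom_loss_module = ToyDataParallel(custom_loss_module)
# Assumption from the description: only the wrapper carries the name
parallel_custom_loss_module._module_name = parallel_custom_loss_module._get_name()

loss = parallel_custom_loss_module(3.0, 1.0)
print(loss, seen_names)  # → 4.0 [None]
```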

Instead, the forward hook should simply use module._get_name().
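A minimal sketch of the proposed hook, without PyTorch. ToyModule, CustomLossModule, and ToyDataParallel are hypothetical stand-ins that model only nn.Module's forward-hook convention and _get_name(), plus the no-device branch of DataParallel.forward. The hook computes the name from the module object it is called with, so it does not depend on which object received _module_name at registration time:

```python
class ToyModule:
    """Hypothetical stand-in for torch.nn.Module: forward hooks and _get_name() only."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # Same calling convention as PyTorch: hook(module, inputs, output)
        self._forward_hooks.append(hook)

    def _get_name(self):
        # Like nn.Module._get_name(): the class name
        return type(self).__name__

    def __call__(self, *inputs):
        output = self.forward(*inputs)
        for hook in self._forward_hooks:
            hook(self, inputs, output)
        return output


class CustomLossModule(ToyModule):
    def forward(self, prediction, target):
        return (prediction - target) ** 2


class ToyDataParallel(ToyModule):
    """Hypothetical wrapper modeling only the no-device branch of DataParallel.forward."""

    def __init__(self, module):
        super().__init__()
        self.module = module
        self.device_ids = []

    def forward(self, *inputs):
        if not self.device_ids:
            return self.module(*inputs)
        raise NotImplementedError("multi-device path is not modeled")


recorded = []


def forward_hook(module, inputs, output):
    # Proposed behavior: derive the name at call time from the module itself
    recorded.append((module._get_name(), output))


custom_loss_module = CustomLossModule()
custom_loss_module.register_forward_hook(forward_hook)
parallel_custom_loss_module = ToyDataParallel(custom_loss_module)

parallel_custom_loss_module(3.0, 1.0)
print(recorded)  # → [('CustomLossModule', 4.0)]
```

One consequence of this design: _get_name() returns the class name, so two registered instances of the same class would report the same name.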

Style and formatting:

I have run pre-commit install to ensure that auto-formatting happens with every commit.

Issue number, if available

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

NihalHarish, Oct 30 '20 01:10

Codecov Report

Merging #393 into master will decrease coverage by 2.85%. The diff coverage is 0.00%.


@@            Coverage Diff             @@
##           master     #393      +/-   ##
==========================================
- Coverage   85.49%   82.63%   -2.86%     
==========================================
  Files          86       86              
  Lines        6514     6520       +6     
==========================================
- Hits         5569     5388     -181     
- Misses        945     1132     +187     
| Impacted Files | Coverage Δ |
|---|---|
| smdebug/pytorch/hook.py | 0.00% <0.00%> (-82.41%) ⬇ |
| smdebug/pytorch/__init__.py | 0.00% <0.00%> (-100.00%) ⬇ |
| smdebug/pytorch/singleton_utils.py | 0.00% <0.00%> (-100.00%) ⬇ |
| smdebug/pytorch/collection.py | 0.00% <0.00%> (-90.00%) ⬇ |
| smdebug/rules/action/stop_training_action.py | 56.45% <0.00%> (-20.97%) ⬇ |
| smdebug/pytorch/utils.py | 0.00% <0.00%> (-18.52%) ⬇ |
| smdebug/rules/req_tensors.py | 79.16% <0.00%> (-11.12%) ⬇ |
| smdebug/core/tfevent/util.py | 92.00% <0.00%> (-8.00%) ⬇ |
| smdebug/tensorflow/callable_cache.py | 78.26% <0.00%> (-4.35%) ⬇ |

... and 4 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Powered by Codecov. Last update be862bd...3c9b1ad.

codecov-io, Oct 30 '20 01:10