
[BUG]: activation checkpoint function can't pass kwargs

Open Gy-Lu opened this issue 3 years ago • 2 comments

🐛 Describe the bug

See the checkpoint function and how it is used in CheckpointModule. Currently, the only keyword argument that the checkpoint function accepts is use_reentrant, so keyword arguments cannot be passed through to _forward.

The definition of the checkpoint function:

def checkpoint(function, activation_offload, *args, use_reentrant: bool = True):

Its usage in CheckpointModule:

    def forward(self, *args, **kwargs):
        if self._use_checkpoint:
            return checkpoint(self._forward, self._offload, *args, **kwargs)
        else:
            return self._forward(*args, **kwargs)
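
A minimal sketch of the failure and one possible workaround, using plain Python so it runs without ColossalAI installed. The `checkpoint` below is a stand-in that only copies the signature quoted above and calls the function directly; it does no real activation checkpointing. `PatchedModule`, the `scale` argument, and `call_raises_type_error` are hypothetical names made up for illustration, not part of the library.

```python
from functools import partial


def checkpoint(function, activation_offload, *args, use_reentrant: bool = True):
    # Stand-in with the signature from the issue. The real function would
    # perform activation checkpointing; this one just calls through.
    return function(*args)


class CheckpointModule:
    def __init__(self, use_checkpoint=True, offload=False):
        self._use_checkpoint = use_checkpoint
        self._offload = offload

    def _forward(self, x, scale=1):
        # Hypothetical forward pass that takes a keyword argument.
        return x * scale

    def forward(self, *args, **kwargs):
        if self._use_checkpoint:
            # Any kwarg other than use_reentrant is rejected by checkpoint.
            return checkpoint(self._forward, self._offload, *args, **kwargs)
        else:
            return self._forward(*args, **kwargs)


class PatchedModule(CheckpointModule):
    def forward(self, *args, **kwargs):
        if self._use_checkpoint:
            # Bind the keyword arguments first, so checkpoint only ever
            # sees positional arguments.
            bound = partial(self._forward, **kwargs)
            return checkpoint(bound, self._offload, *args)
        else:
            return self._forward(*args, **kwargs)


def call_raises_type_error(module, *args, **kwargs):
    # Helper for the demo: report whether forward() raises TypeError.
    try:
        module.forward(*args, **kwargs)
    except TypeError:
        return True
    return False
```

With this sketch, `CheckpointModule().forward(3, scale=2)` raises a `TypeError` because `checkpoint` has no `scale` parameter, while `PatchedModule().forward(3, scale=2)` reaches `_forward`. Whether tensors bound through `partial` are handled correctly by the real checkpoint implementation is not something this sketch checks.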

Environment

No response

Gy-Lu avatar Oct 26 '22 03:10 Gy-Lu

Hmm, it seems that keyword args in _forward of CheckpointModule are not supported.

Gy-Lu avatar Oct 26 '22 03:10 Gy-Lu

cc @Cypher30

super-dainiu avatar Oct 27 '22 05:10 super-dainiu

We have updated a lot. This issue was closed due to inactivity. Thanks.

binmakeswell avatar Apr 13 '23 04:04 binmakeswell