
Add a `dtype` option for `load_from_checkpoint`

Open arijit-hub opened this issue 11 months ago • 0 comments

Description & Motivation

Hi, it would be nice to have a `dtype` argument for `load_from_checkpoint`, alongside the very handy `map_location`. It would let the user set the dtype directly via `load_from_checkpoint` without having to call `.to(dtype)` manually afterwards.

Pitch

The fix should be pretty straightforward. Currently we have:

def load_from_checkpoint(
        cls,
        checkpoint_path: Union[_PATH, IO],
        map_location: _MAP_LOCATION_TYPE = None,
        hparams_file: Optional[_PATH] = None,
        strict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Self:
...

It would just get one more argument:

def load_from_checkpoint(
        cls,
        checkpoint_path: Union[_PATH, IO],
        map_location: _MAP_LOCATION_TYPE = None,
        dtype: Optional[torch.dtype] = None,
        hparams_file: Optional[_PATH] = None,
        strict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Self:
...

Since this actually delegates to `_load_from_checkpoint` https://github.com/Lightning-AI/pytorch-lightning/blob/3d398240d2f62f2ad05e9eff557d2d5cb44f235c/src/lightning/pytorch/core/saving.py#L53

we can thread the extra `dtype` argument through there too and change this line https://github.com/Lightning-AI/pytorch-lightning/blob/3d398240d2f62f2ad05e9eff557d2d5cb44f235c/src/lightning/pytorch/core/saving.py#L99

to

return model.to(device=device, dtype=dtype)

(`Module.to` accepts `dtype=None` as a no-op, so the default keeps the current behavior.)
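As a minimal pure-Python sketch of the proposed behavior (this is not the actual Lightning implementation: `DummyModel` is a stand-in for a `LightningModule`, and string dtypes/devices stand in for the torch types):

```python
from typing import Optional


class DummyModel:
    """Stand-in for a LightningModule; records device/dtype conversions."""

    def __init__(self) -> None:
        self.device = "cpu"
        self.dtype = "float32"

    def to(self, device: Optional[str] = None, dtype: Optional[str] = None) -> "DummyModel":
        # Mirrors torch's Module.to semantics: only update what was passed.
        if device is not None:
            self.device = device
        if dtype is not None:
            self.dtype = dtype
        return self


def load_from_checkpoint(dtype: Optional[str] = None, device: str = "cpu") -> DummyModel:
    # Hypothetical sketch: after the checkpoint weights are restored,
    # the optional dtype is applied in the same call as the device move.
    model = DummyModel()  # in real code: model restored from the checkpoint
    return model.to(device=device, dtype=dtype)


half_model = load_from_checkpoint(dtype="float16", device="cuda:0")
default_model = load_from_checkpoint()  # dtype omitted: model is left unchanged
```

Passing `dtype=None` through `to()` unchanged is what makes the argument backwards compatible: existing callers that never pass `dtype` get exactly today's behavior.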

Hope this is taken into consideration :)

@Borda @awaelchli @lantiga

Alternatives

No response

Additional context

No response

arijit-hub · May 17 '25 19:05