Add a `dtype` option for `load_from_checkpoint`
Description & Motivation
Hi,
It would be nice to have a `dtype` argument for `load_from_checkpoint`, alongside the very cool `map_location`. It would let the user set the model's dtype directly when loading, without having to call `.to(dtype)` manually afterwards.
Pitch
The fix should be pretty straightforward. Currently we have:
def load_from_checkpoint(
    cls,
    checkpoint_path: Union[_PATH, IO],
    map_location: _MAP_LOCATION_TYPE = None,
    hparams_file: Optional[_PATH] = None,
    strict: Optional[bool] = None,
    **kwargs: Any,
) -> Self:
    ...
It would just get one more argument:
def load_from_checkpoint(
    cls,
    checkpoint_path: Union[_PATH, IO],
    map_location: _MAP_LOCATION_TYPE = None,
    dtype: Optional[torch.dtype] = None,
    hparams_file: Optional[_PATH] = None,
    strict: Optional[bool] = None,
    **kwargs: Any,
) -> Self:
    ...
Since this actually delegates to `_load_from_checkpoint` (https://github.com/Lightning-AI/pytorch-lightning/blob/3d398240d2f62f2ad05e9eff557d2d5cb44f235c/src/lightning/pytorch/core/saving.py#L53), we can add the extra `dtype` argument there too and change this line (https://github.com/Lightning-AI/pytorch-lightning/blob/3d398240d2f62f2ad05e9eff557d2d5cb44f235c/src/lightning/pytorch/core/saving.py#L99)
to
    return model.to(dtype).to(device)
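To make the idea concrete, here is a minimal sketch in plain PyTorch of what that last step could look like. Note that `finalize_loaded_model` is a hypothetical helper name used only for illustration; it is not part of Lightning, and the real change would live inside `_load_from_checkpoint`. The sketch assumes the cast should be skipped when `dtype` is `None`, so that the default behavior stays the same as today.

```python
from typing import Optional, Union

import torch
from torch import nn


def finalize_loaded_model(
    model: nn.Module,
    device: Union[str, torch.device],
    dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Hypothetical helper: optionally cast a freshly loaded model, then move it."""
    if dtype is not None:
        # nn.Module.to(dtype) casts floating-point (and complex) parameters
        # and buffers only; integer buffers keep their dtype.
        model = model.to(dtype)
    return model.to(device)


# A small stand-in module instead of a real checkpoint.
model = finalize_loaded_model(nn.Linear(4, 2), device="cpu", dtype=torch.float16)
print(model.weight.dtype)  # torch.float16
```

With `dtype=None` the helper only moves the model to `device`, which is why making the argument optional should keep existing callers unaffected.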
Hope this is taken into consideration :)
@Borda @awaelchli @lantiga
Alternatives
No response
Additional context
No response