transformers
Suppress `reset_parameters` of `torch.nn.Linear`, `Conv2d`, etc. inside `no_init_weights`
Feature request
Built-in modules such as `torch.nn.Linear` and `torch.nn.Conv2d` call `self.reset_parameters()` inside their `__init__`.
I'd like `reset_parameters` to become a no-op inside the `no_init_weights` context manager.
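For example, a quick timing sketch (the class-level monkey-patch below is for demonstration only) shows how much of a large layer's construction time goes into `reset_parameters`:

```python
import time

import torch.nn as nn

# Normal construction: __init__ calls self.reset_parameters(), which fills
# the weight and bias with kaiming_uniform_/uniform_ values.
start = time.perf_counter()
nn.Linear(8192, 8192)
print(f"with init:    {time.perf_counter() - start:.3f}s")

# Shadow reset_parameters with a no-op on the class: the parameters are
# still allocated (torch.empty), just never filled.
orig = nn.Linear.reset_parameters
nn.Linear.reset_parameters = lambda self: None
try:
    start = time.perf_counter()
    nn.Linear(8192, 8192)
    print(f"without init: {time.perf_counter() - start:.3f}s")
finally:
    nn.Linear.reset_parameters = orig
```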
Motivation
`no_init_weights` is used in `from_pretrained` to speed up loading large models.
However, torch built-in modules like `torch.nn.Linear` are used heavily throughout transformers models, and their weight initialization is not disabled by `no_init_weights`.
Yet according to its own docstring, `no_init_weights` should "globally disable weight initialization".
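A small check makes this concrete (assuming `no_init_weights` is importable from `transformers.modeling_utils`; `reset_parameters` is instrumented here just to count calls):

```python
import torch.nn as nn
from transformers.modeling_utils import no_init_weights

calls = []
orig = nn.Linear.reset_parameters

def counting_reset(self):
    calls.append(type(self).__name__)
    orig(self)

nn.Linear.reset_parameters = counting_reset
try:
    with no_init_weights():
        nn.Linear(4, 4)
finally:
    nn.Linear.reset_parameters = orig

# reset_parameters ran even inside no_init_weights, because the context
# manager only flips transformers' internal _init_weights flag.
print(calls)  # ['Linear']
```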
Your contribution
A possible implementation:
```python
from contextlib import contextmanager
from typing import Iterable, Protocol, Type

import torch.nn as nn


class SupportsResetParameters(Protocol):
    def reset_parameters(self): ...


@contextmanager
def no_init(module_classes: Iterable[Type[SupportsResetParameters]]):
    # Remember each class's *own* `reset_parameters` (None when it is only
    # inherited), so the original can be restored on exit.
    saved = {m: vars(m).get('reset_parameters') for m in module_classes}

    def no_op(_):
        pass

    # Iterate over `saved`, not `module_classes`: an Iterable can only be
    # safely iterated through once.
    for m in saved:
        m.reset_parameters = no_op
    try:
        yield
    finally:
        for m, init in saved.items():
            # Drop the no-op, then put back the class's own method if it
            # had one (inherited implementations need no restoring).
            del m.reset_parameters
            if init is not None:
                m.reset_parameters = init


# `...` stands for the remaining torch built-in module classes to cover.
TORCH_BUILT_IN_MODULES = [nn.Linear, nn.Conv2d, ...]


@contextmanager
def no_init_weights():
    """
    Context manager to globally disable weight initialization to speed up
    loading large models.
    """
    global _init_weights
    saved = _init_weights
    _init_weights = False
    try:
        with no_init(TORCH_BUILT_IN_MODULES):
            yield
    finally:
        _init_weights = saved
```
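With this patch, constructing torch built-in modules inside `no_init_weights` would skip their weight initialization entirely, e.g.:

```python
with no_init_weights():
    layer = nn.Linear(8192, 8192)  # returns almost immediately;
                                   # weights hold uninitialized memory
```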