
Suppress `reset_parameters` of `torch.nn.Linear,Conv2d...` inside `no_init_weights`

Open · YouJiacheng opened this issue 3 years ago · 0 comments

Feature request

torch.nn.Linear, torch.nn.Conv2d, and other torch built-in modules call self.reset_parameters() inside their __init__. I'd like reset_parameters to become a no-op inside the no_init_weights context manager.

Motivation

no_init_weights is used in from_pretrained to speed up loading large models. However, torch built-in modules like torch.nn.Linear are used heavily throughout transformers models, and their weight initialization cannot be disabled by no_init_weights — yet the docstring of no_init_weights says it should "globally disable weight initialization".
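For concreteness, here's a rough sketch of the cost being paid (the 8192 sizes are arbitrary and timings vary by machine): most of the construction time of a large nn.Linear is spent inside reset_parameters, which nothing in no_init_weights currently reaches.

import time
from torch import nn

# Normal construction: __init__ allocates the parameters and then calls
# self.reset_parameters(), which runs kaiming-uniform / uniform fills.
t0 = time.perf_counter()
layer = nn.Linear(8192, 8192)
print(f"with init:    {time.perf_counter() - t0:.3f}s")

# Stub out reset_parameters on the class to skip initialization entirely.
saved = nn.Linear.reset_parameters
nn.Linear.reset_parameters = lambda self: None
try:
    t0 = time.perf_counter()
    layer = nn.Linear(8192, 8192)  # parameters stay as torch.empty left them
    print(f"without init: {time.perf_counter() - t0:.3f}s")
finally:
    nn.Linear.reset_parameters = saved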

Your contribution

Possible implementation:

from contextlib import contextmanager
from typing import Iterable, Protocol, Type

from torch import nn

class SupportsResetParameters(Protocol):
    def reset_parameters(self): ...

@contextmanager
def no_init(module_classes: Iterable[Type[SupportsResetParameters]]):
    # Snapshot each class's *own* reset_parameters (None if it is only inherited).
    saved = {m: vars(m).get('reset_parameters') for m in module_classes}
    def no_op(_): pass
    # Iterate over `saved`, not `module_classes`: an Iterable can only be safely iterated through once.
    for m in saved:
        m.reset_parameters = no_op
    try:
        yield
    finally:
        for m, init in saved.items():
            # Drop the stub; re-attach the original only if the class defined
            # reset_parameters itself, so inherited lookups fall back to the parent.
            del m.reset_parameters
            if init is not None:
                m.reset_parameters = init

TORCH_BUILT_IN_MODULES = [nn.Linear, nn.Conv2d, ...]  # etc. for the remaining built-ins

@contextmanager
def no_init_weights():
    """
    Context manager to globally disable weight initialization to speed up loading large models.
    """
    global _init_weights  # the existing module-level flag in modeling_utils
    saved = _init_weights
    _init_weights = False
    try:
        with no_init(TORCH_BUILT_IN_MODULES):
            yield
    finally:
        _init_weights = saved
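
And a quick sanity check of no_init itself (hypothetical usage, reusing the definitions above):

from torch import nn

with no_init([nn.Linear, nn.Conv2d]):
    m = nn.Linear(4, 4)     # weight keeps whatever torch.empty returned
    c = nn.Conv2d(3, 8, 3)  # same for conv layers

m2 = nn.Linear(4, 4)        # outside the context, initialization runs again

Saving vars(m).get('reset_parameters') instead of m.reset_parameters is deliberate: nn.Conv2d, for example, only inherits reset_parameters from _ConvNd, so on exit the stub is simply deleted and attribute lookup falls back to the parent class, rather than pinning a copy of the inherited method onto the subclass.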

YouJiacheng · Aug 06 '22 11:08