
[WIP,POC] Faster functional modules

Open · vmoens opened this issue 3 years ago · 1 comment

Proposes a new method to load weights into FunctionalModule and FunctionalModuleWithBuffers.

A mapping of module <-> param_name <-> param_value is created and then used to set the attributes.
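
A minimal sketch of the idea (build_param_map and load_weights are hypothetical names, not the actual functorch API, and the direct write into _parameters is an assumption about how the attribute setting could be done):

    import torch

    def build_param_map(root: torch.nn.Module):
        # Walk the module tree once, recording (submodule, attr_name)
        # for every parameter: this is the cached map described above.
        return [
            (mod, name)
            for mod in root.modules()
            for name, _ in mod.named_parameters(recurse=False)
        ]

    def load_weights(param_map, params):
        # Flat loop over the cached map: one attribute write per
        # parameter, with no recursive traversal of the module tree.
        for (mod, name), value in zip(param_map, params):
            # Writing into _parameters directly accepts plain tensors
            # (e.g. coming from vmap) without nn.Module's checks.
            mod._parameters[name] = value

Building the map once amortizes the tree traversal; every subsequent call to the functional module then only pays for the flat loop.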

Test: the following benchmark runs twice as fast on CPU with the new implementation as with the current one:

import timeit

import torch
from functorch import make_functional

if __name__ == "__main__":
    # module with a high parameter-allocation cost but few operations
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 1),
        torch.nn.Linear(1, 1),
        torch.nn.Sequential(
            torch.nn.Linear(1, 1),
            torch.nn.Linear(1, 1),
            torch.nn.Linear(1, 1),
            torch.nn.Linear(1, 1),
        ),
    )

    fnet, params = make_functional(net)
    x = torch.randn(1)
    print(timeit.timeit(
        "fnet(params, x)",
        globals={"fnet": fnet, "x": x, "params": params},
        number=10000,
    ))
    # approx. 1.7 s with the new implementation, 3.8 s with the old

    # the implementation supports serialization
    import tempfile
    with tempfile.NamedTemporaryFile() as file:
        torch.save(fnet, file.name)
        loaded_fnet = torch.load(file.name)
        assert torch.isclose(fnet(params, x), loaded_fnet(params, x))

Other metrics: on torchrl's DDPG, in a full forward-backward pass, the old implementation of _swap_state takes approx. 20% of the runtime with small neural nets (a 2-layer MLP with 256 cells) on CPU. The new implementation takes approx. 6% of the runtime.
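
A rough way to reproduce this kind of breakdown (a sketch, assuming fnet, params, and x from the benchmark above; profile_swap_cost is a hypothetical helper): Python-level profiling with cProfile makes the cumulative cost of weight-loading helpers such as _swap_state visible next to the actual compute.

    import cProfile
    import pstats

    def profile_swap_cost(fnet, params, x, steps=1000):
        # Run full forward-backward passes under the profiler; the
        # cumulative-time column of the report shows how much of the
        # runtime is spent in the weight-loading path.
        with cProfile.Profile() as prof:
            for _ in range(steps):
                fnet(params, x).sum().backward()
        pstats.Stats(prof).sort_stats("cumulative").print_stats(20)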

vmoens · Jul 24 '22 08:07