functorch
[WIP,POC] Faster functional modules
Proposes a new method for loading weights in FunctionalModule and FunctionalModuleWithBuffers: a map of module <-> param_name <-> param_value is built once and then used to set the attributes directly.
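As a rough illustration of the idea, here is a minimal sketch (not the actual implementation; `build_access_list` and `load_weights` are hypothetical names): the (submodule, parameter name) pairs are collected once, so loading new weights is a flat loop of dictionary assignments instead of a recursive traversal of the module tree on every call.

```python
# Hedged sketch of the map-based weight loading, NOT functorch's actual code.
import torch


def build_access_list(module):
    """Collect (owning submodule, parameter name) pairs, once, up front."""
    access = []
    for submodule in module.modules():
        for name in list(submodule._parameters):
            access.append((submodule, name))
    return access


def load_weights(access, params):
    """Install new parameter tensors via the precomputed map: a flat loop."""
    for (submodule, name), value in zip(access, params):
        submodule._parameters[name] = value


net = torch.nn.Sequential(torch.nn.Linear(1, 1), torch.nn.Linear(1, 1))
access = build_access_list(net)  # built once
new_params = [p.detach().clone() for p in net.parameters()]
load_weights(access, new_params)  # cheap: no module-tree traversal
```

The pairs in `access` follow the same recursive order as `module.parameters()`, which is what makes the `zip` in `load_weights` line up.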
Test: the following benchmark runs twice as fast on CPU with the new implementation as with the current one:
```python
import tempfile
import timeit

import torch
from functorch import make_functional

if __name__ == "__main__":
    # module with high param allocation cost but few operations
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 1),
        torch.nn.Linear(1, 1),
        torch.nn.Sequential(
            torch.nn.Linear(1, 1),
            torch.nn.Linear(1, 1),
            torch.nn.Linear(1, 1),
            torch.nn.Linear(1, 1),
        ),
    )
    fnet, params = make_functional(net)
    x = torch.randn(1)
    print(timeit.timeit("fnet(params, x)", globals={"fnet": fnet, "x": x, "params": params}, number=10000))
    # 1.7 sec with new, 3.8 with old

    # the implementation supports serialization
    with tempfile.NamedTemporaryFile() as file:
        torch.save(fnet, file.name)
        loaded_fnet = torch.load(file.name)
        assert torch.isclose(fnet(params, x), loaded_fnet(params, x))
```
Other metrics:
On torchrl's DDPG, in a full forward-backward pass with small neural nets (a 2-layer MLP with 256 cells) on CPU, the old implementation of _swap_state takes approx. 20% of the runtime; the new implementation takes approx. 6%.