Continuing training from a checkpoint changes behavior when using the torch tile
Description
Saving and then reloading the model and optimizer state before resuming training changes the training behavior: the loss trajectory differs from an uninterrupted run.
How to reproduce
MWE:
```python
import torch

from aihwkit.optim.analog_optimizer import AnalogSGD
from aihwkit.simulator.configs import TorchInferenceRPUConfig
from aihwkit.simulator.configs.utils import (
    BoundManagementType,
    NoiseManagementType,
)
from aihwkit.nn.conversion import convert_to_analog


def train_linear_regression(reload: bool):
    def generate_toy_data(num_samples=100):
        torch.manual_seed(0)
        X = 2 * torch.rand(num_samples, 1)
        y = 4 + 3 * X + torch.rand(num_samples, 1)
        return X, y

    def mean_squared_error(y_true, y_pred):
        return torch.mean((y_true - y_pred) ** 2)

    torch.manual_seed(0)
    num_epochs = 1000
    learning_rate = 0.001
    X, y = generate_toy_data()

    model = torch.nn.Linear(1, 1)

    # Deterministic forward pass: no output noise, no bound/noise management.
    rpu_config = TorchInferenceRPUConfig()
    rpu_config.forward.bound_management = BoundManagementType.NONE
    rpu_config.forward.noise_management = NoiseManagementType.NONE
    rpu_config.forward.out_noise = 0.0
    rpu_config.pre_post.input_range.enable = True
    rpu_config.pre_post.input_range.init_value = 3.0
    rpu_config.forward.is_perfect = True
    rpu_config.pre_post.input_range.init_from_data = 1000
    rpu_config.pre_post.input_range.learn_input_range = False
    rpu_config.pre_post.input_range.decay = 0.0
    model = convert_to_analog(model, rpu_config)

    optimizer = AnalogSGD(params=model.parameters(), lr=learning_rate)

    losses = []
    for epoch in range(num_epochs):
        predictions = model(X)
        loss = mean_squared_error(y, predictions)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        if epoch == 500 and reload:
            # Save model and optimizer state, re-create both, and reload
            # the saved state before resuming training.
            sd = model.state_dict()
            optimizer_sd = optimizer.state_dict()
            model = torch.nn.Linear(1, 1)
            model = convert_to_analog(model, rpu_config)
            optimizer = AnalogSGD(params=model.parameters(), lr=learning_rate)
            model.load_state_dict(sd)
            optimizer.load_state_dict(optimizer_sd)

    return torch.tensor(losses)


if __name__ == "__main__":
    losses_false = train_linear_regression(reload=False)
    losses_true = train_linear_regression(reload=True)
    assert torch.allclose(losses_false, losses_true, atol=1e-4)
```
Expected behavior
The losses should be identical whether or not training is interrupted by a save/reload round trip.
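To pin down where the two runs diverge, a small debugging aid on top of the MWE (assuming the two loss tensors returned by `train_linear_regression`):

```python
# Locate the first epoch at which the reloaded run deviates from the
# uninterrupted one; with the MWE this is shortly after the reload at
# epoch 500.
diff = (losses_false - losses_true).abs()
divergent = torch.nonzero(diff > 1e-4)
print("first divergent epoch:", int(divergent[0]) if len(divergent) else None)
```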
@maljoras I have identified the source of the bug, but I can't really solve it in an elegant fashion.
What causes the bug:
In the `__init__` of the `TorchInferenceTile` (see here), the test passes if I pass `ignore_analog_state=False`.
This is because no new tile is created, and therefore no new `Parameter` objects are created either. When a new `Parameter` object is created, the corresponding entry in the optimizer's parameter groups still points to the old memory. This is also why the MWE above works if I call `optimizer = AnalogSGD(params=model.parameters(), lr=learning_rate)` only after `model.load_state_dict(sd)`; the same failure mode can be reproduced in plain PyTorch, as the sketch below shows.
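A minimal sketch of that mechanism using only plain PyTorch (a hypothetical toy example, not part of the MWE): replacing a `Parameter` object after the optimizer has been built leaves the optimizer stepping a stale tensor.

```python
import torch

lin = torch.nn.Linear(1, 1)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)

# Replace the weight with a brand-new Parameter object, much like
# re-creating the tile does after the optimizer was built.
lin.weight = torch.nn.Parameter(torch.zeros(1, 1))

w_before = lin.weight.detach().clone()
lin(torch.ones(4, 1)).sum().backward()
opt.step()

# The new weight is not in the optimizer, and the old weight (which is)
# never received a gradient, so the step changes nothing visible.
assert torch.equal(lin.weight.detach(), w_before)
```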
What is the best way to solve this?
The problem is that the torch tile might bypass important tile-loading code. You should make sure that this code is called when setting the tile; see `aihwkit.simulator.tile.module.py`, line 120. Originally it was called; maybe something has changed since.
Just for reference, this is the code:
```python
def __setstate__(self, state: Dict) -> None:
    # pylint: disable=no-member
    if hasattr(super(Module, self), "__setstate__"):
        # The TileWrapper is handling all the attributes
        super(Module, self).__setstate__(state)
    else:
        Module.__setstate__(self, state)

    # update parameter IDs
    for name in self._parameters:  # type: ignore
        self._parameters[name] = getattr(self, name)  # type: ignore
    for name in self._buffers:
        self._buffers[name] = getattr(self, name)
    for name in self._modules:
        self._modules[name] = getattr(self, name)
```
So this line, `super(Module, self).__setstate__(state)`, creates the new tile, including a brand-new `Parameter` object that is then populated with the weights. The code below it is also executed, but when we loop over the `_parameters` of `self`, which is the `TorchInferenceTile`, we don't find any weights, since the weight parameters live on the tile, i.e. `self.tile`. Maybe that is the issue?
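For context, plain `torch.nn.Module.load_state_dict` sidesteps this by copying the loaded values into the existing tensors in place, which preserves object identity and thus keeps optimizer references valid. A sketch of that pattern (illustration only, not the actual tile code; `loaded_tile_state` is a hypothetical dict of tensors):

```python
import torch

def restore_in_place(module: torch.nn.Module, loaded_tile_state: dict) -> None:
    """Copy loaded values into the existing parameters and buffers.

    No new Parameter objects are created, so an optimizer built before
    loading keeps valid references. Sketch only; the real tile code
    would need more careful key handling.
    """
    with torch.no_grad():
        for name, param in module.named_parameters():
            if name in loaded_tile_state:
                param.copy_(loaded_tile_state[name])
        for name, buf in module.named_buffers():
            if name in loaded_tile_state:
                buf.copy_(loaded_tile_state[name])
```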
Yes, I think so, too. For that reason I defined all parameters on the tile-module level and not on the `SimulatorTile` (`self.tile`) level, even for the `TorchInferenceTile`. Has that changed?
For the torch tile it was always like that, I think. How should we proceed?
@jubueche Is this only the case for `init_from_data=True`? Could it be that the counter is not saved, so that the initialization from data is done again after loading?
@maljoras That is also an issue (I made a private PR for that), but here I explicitly avoided it by using 1000 epochs and setting `init_from_data` to 1000 as well.
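For completeness, a quick way to check what actually gets serialized (a sketch against the MWE's `model`, assuming the input-range state appears under keys containing `input_range`): if no counter-like entry is listed, the data-driven initialization would indeed restart after loading.

```python
# Print the serialized input-range entries of the analog model; a missing
# counter entry would mean the data-driven init restarts after loading.
for key, value in model.state_dict().items():
    if "input_range" in key:
        print(key, getattr(value, "shape", type(value)))
```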