
Continuing training based on checkpoint using torch tile

Open • jubueche opened this issue • 7 comments

Description

Saving and reloading the model and optimizer state before resuming training changes the training behavior: the loss trajectory diverges from that of an uninterrupted run.

How to reproduce

MWE:

import torch
from aihwkit.optim.analog_optimizer import AnalogSGD
from aihwkit.simulator.configs import TorchInferenceRPUConfig
from aihwkit.simulator.configs.utils import (
    BoundManagementType,
    NoiseManagementType,
)
from aihwkit.nn.conversion import convert_to_analog


def train_linear_regression(reload: bool):
    def generate_toy_data(num_samples=100):
        torch.manual_seed(0)
        X = 2 * torch.rand(num_samples, 1)
        y = 4 + 3 * X + torch.rand(num_samples, 1)
        return X, y

    def mean_squared_error(y_true, y_pred):
        return torch.mean((y_true - y_pred) ** 2)

    torch.manual_seed(0)
    num_epochs = 1000
    learning_rate = 0.001
    X, y = generate_toy_data()
    model = torch.nn.Linear(1, 1)

    # Deterministic forward pass: no bound/noise management, no output noise
    rpu_config = TorchInferenceRPUConfig()
    rpu_config.forward.bound_management = BoundManagementType.NONE
    rpu_config.forward.noise_management = NoiseManagementType.NONE
    rpu_config.forward.out_noise = 0.0
    rpu_config.forward.is_perfect = True
    # Input range is initialized from the first 1000 batches and then kept fixed
    rpu_config.pre_post.input_range.enable = True
    rpu_config.pre_post.input_range.init_value = 3.0
    rpu_config.pre_post.input_range.init_from_data = 1000
    rpu_config.pre_post.input_range.learn_input_range = False
    rpu_config.pre_post.input_range.decay = 0.0

    model = convert_to_analog(model, rpu_config)
    optimizer = AnalogSGD(params=model.parameters(), lr=learning_rate)

    losses = []
    for epoch in range(num_epochs):
        predictions = model(X)
        loss = mean_squared_error(y, predictions)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        if epoch == 500 and reload:
            # Save model and optimizer state, rebuild both from scratch,
            # and restore the state, simulating a resume from a checkpoint.
            sd = model.state_dict()
            optimizer_sd = optimizer.state_dict()
            model = torch.nn.Linear(1, 1)
            model = convert_to_analog(model, rpu_config)
            optimizer = AnalogSGD(params=model.parameters(), lr=learning_rate)
            model.load_state_dict(sd)
            optimizer.load_state_dict(optimizer_sd)
    return torch.tensor(losses)

if __name__ == "__main__":
    losses_false = train_linear_regression(reload=False)
    losses_true = train_linear_regression(reload=True)
    # This assertion fails: reloading mid-training changes the loss trajectory.
    assert torch.allclose(losses_false, losses_true, atol=1e-4)

Expected behavior

The losses should be identical whether or not the model and optimizer state is saved and reloaded mid-training.

jubueche • Jan 26 '24

@maljoras I have identified the source of the bug, but I can't solve it in an elegant fashion. What causes the bug: in the __init__ of the TorchInferenceTile (see here), the test passes if I pass ignore_analog_state=False. This is because no new tile, and therefore no new Parameter objects, are created. When a new Parameter object is created, the corresponding parameter held by the optimizer still points to the old object's memory. This is also why the MWE above works if I call optimizer = AnalogSGD(params=model.parameters(), lr=learning_rate) after model.load_state_dict(sd). What is the best way to solve this?

jubueche • Jan 26 '24
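
For illustration, here is a minimal plain-PyTorch sketch of the aliasing problem described above (no aihwkit involved; all names are made up for the example): once a module swaps in a new Parameter object, an optimizer built earlier keeps referencing the old one, so optimizer steps no longer reach the module's live weights.

import torch

torch.manual_seed(0)
lin = torch.nn.Linear(1, 1)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)

# Simulate what re-creating the tile does: replace the weight with a
# brand-new Parameter object holding the same values.
lin.weight = torch.nn.Parameter(lin.weight.detach().clone())
before = lin.weight.detach().clone()

loss = lin(torch.ones(1, 1)).sum()
opt.zero_grad()
loss.backward()
opt.step()  # only sees the old Parameter (which received no grad), so nothing moves

# The module's live weight is unchanged despite a nonzero gradient.
assert torch.equal(lin.weight.detach(), before)
assert lin.weight.grad is not None

# Fix used in the MWE: rebuild the optimizer so it references the live parameters.
opt = torch.optim.SGD(lin.parameters(), lr=0.1)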

The problem is that the torch tile might bypass important tile-loading code. You should make sure that this code is called when setting the tile; see aihwkit.simulator.tile.module.py, line 120. Originally it was called; maybe something has changed since.

maljoras • Jan 26 '24

Just for reference, this is the code:

def __setstate__(self, state: Dict) -> None:
    # pylint: disable=no-member

    if hasattr(super(Module, self), "__setstate__"):
        # The TileWrapper is handling all the attributes
        super(Module, self).__setstate__(state)
    else:
        Module.__setstate__(self, state)

    # update parameter IDs
    for name in self._parameters:  # type: ignore
        self._parameters[name] = getattr(self, name)  # type: ignore

    for name in self._buffers:
        self._buffers[name] = getattr(self, name)

    for name in self._modules:
        self._modules[name] = getattr(self, name)

So, the line super(Module, self).__setstate__(state) creates the new tile, including a brand-new Parameter object that is then populated with the weights. The code below it also runs, but when we loop over the _parameters of self, which is the TorchInferenceTile, there are no weight parameters to re-point, since they live on the tile, i.e. on self.tile. Maybe that is the issue?

jubueche • Jan 29 '24
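
To make the ownership point concrete, a tiny standalone sketch (illustrative only; Tile and TileWrapper are stand-in names, not aihwkit classes): parameters registered on a child module do not appear in the parent's _parameters dict, so a loop like the one in __setstate__ above never re-points them.

import torch

class Tile(torch.nn.Module):
    """Stand-in for a simulator tile that owns the weight Parameter."""
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(1))

class TileWrapper(torch.nn.Module):
    """Stand-in for the tile module; the weight lives one level down."""
    def __init__(self):
        super().__init__()
        self.tile = Tile()

m = TileWrapper()
print(dict(m._parameters))         # {} -- nothing registered at the wrapper level
print(dict(m.named_parameters()))  # {'tile.weight': Parameter containing ...}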

Yes, I think so, too. For that reason I defined all parameters on the Tile Module level and not on the SimulatorTile (self.tile) level even for TorchInferenceTile. Has that changed?

maljoras • Jan 29 '24

For the torch tile it was always like that, I think. How should we proceed?

jubueche • Jan 29 '24
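
One generic direction, sketched here with a hypothetical helper (repoint_optimizer is not part of the aihwkit API): after load_state_dict has swapped in new Parameter objects, re-point the optimizer's param_groups at the module's current parameters and carry the per-parameter state across by position. This assumes the parameter order is unchanged between the old and the rebuilt model.

import torch

def repoint_optimizer(optimizer: torch.optim.Optimizer,
                      model: torch.nn.Module) -> None:
    """Make the optimizer reference the model's current Parameter objects."""
    old_params = [p for g in optimizer.param_groups for p in g["params"]]
    new_params = list(model.parameters())
    assert len(old_params) == len(new_params), "parameter layout changed"
    # Move per-parameter state (e.g. momentum buffers) to the new objects.
    for old_p, new_p in zip(old_params, new_params):
        if old_p in optimizer.state:
            optimizer.state[new_p] = optimizer.state.pop(old_p)
    # Re-point each param group at the corresponding new parameters.
    offset = 0
    for group in optimizer.param_groups:
        n = len(group["params"])
        group["params"] = new_params[offset:offset + n]
        offset += n

In the MWE above, such a helper would be called after model.load_state_dict(sd), so that subsequent optimizer steps act on the live parameters.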

@jubueche Is this only the case for init_from_data=True? Could it be that the counter is not saved, so that the init is done again after loading?

maljoras • Feb 12 '24
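
As background for the counter hypothesis, a quick standalone illustration (plain PyTorch; InputRangeTracker is a made-up name): a counter kept as a plain Python attribute is not part of the state_dict and resets after loading, whereas a registered buffer survives the save/load round trip.

import torch

class InputRangeTracker(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.seen = 0  # plain attribute: not included in state_dict
        self.register_buffer("seen_buf", torch.zeros((), dtype=torch.long))

tracker = InputRangeTracker()
tracker.seen += 5
tracker.seen_buf += 5

restored = InputRangeTracker()
restored.load_state_dict(tracker.state_dict())
print(restored.seen)      # 0         -- init-from-data would start over
print(restored.seen_buf)  # tensor(5) -- buffer survives the round trip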

@maljoras That is also the case (I made a private PR for that), but here I explicitly avoided it by using 1000 epochs and setting init_from_data to 1000 as well.

jubueche • Feb 12 '24