pytorch_geometric

deg argument in DegreeScalerAggregation prevents using a saved model without the training set

Open pgniewko opened this issue 2 years ago • 4 comments

🚀 The feature, motivation and pitch

When training a model with PNA layers, one needs to pass a deg tensor that sets the parameters in DegreeScalerAggregation:

    self.avg_deg: Dict[str, float] = {
        'lin': float((bin_degrees * deg).sum()) / num_nodes,
        'log': float(((bin_degrees + 1).log() * deg).sum()) / num_nodes,
        'exp': float((bin_degrees.exp() * deg).sum()) / num_nodes,
    }

These values are neither registered buffers nor trainable parameters, so saving the model does not preserve them, and they need to be recalculated every time the model is loaded from a checkpoint. As a result, whenever I want to use a saved model I need to have the training set on hand.
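For context, recomputing deg means re-running the degree histogram over the training data, roughly as in the PyG PNA example (train_dataset here is just a placeholder for whatever training set is available):

    import torch
    from torch_geometric.utils import degree

    # Build the in-degree histogram over the training set; this is the `deg`
    # tensor that DegreeScalerAggregation / PNAConv expect at construction time.
    max_degree = -1
    for data in train_dataset:
        d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
        max_degree = max(max_degree, int(d.max()))

    deg = torch.zeros(max_degree + 1, dtype=torch.long)
    for data in train_dataset:
        d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
        deg += torch.bincount(d, minlength=deg.numel())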

Alternatives

avg_deg_exp should be removed as it's never used. deg_lin and deg_log should be either trainable or buffered. These changes will require passing a parameter to the constructor (train_delta) and also modifying the PNA conv layer implementation.

    if train_delta:
        logging.info('Delta parameters are trainable.')
        self.avg_deg_lin = torch.nn.Parameter(torch.Tensor([self.init_avg_deg_lin]))
        self.avg_deg_log = torch.nn.Parameter(torch.Tensor([self.init_avg_deg_log]))
    else:
        self.register_buffer('avg_deg_lin', torch.Tensor([self.init_avg_deg_lin]))
        self.register_buffer('avg_deg_log', torch.Tensor([self.init_avg_deg_log]))
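For completeness, a rough sketch of how the scaler step in DegreeScalerAggregation.forward could then read the stored tensors instead of the self.avg_deg dict (the formulas roughly follow the existing implementation; out and deg denote the aggregated features and per-node degrees already computed earlier in forward):

    # Sketch only: per-scaler rescaling using the registered tensors.
    if scaler == 'amplification':
        out = out * (torch.log(deg + 1) / self.avg_deg_log)
    elif scaler == 'attenuation':
        out = out * (self.avg_deg_log / torch.log(deg + 1))
    elif scaler == 'linear':
        out = out * (deg / self.avg_deg_lin)
    elif scaler == 'inverse_linear':
        out = out * (self.avg_deg_lin / deg)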

Additional context

The current implementation does not let me adopt the native PyG code, so I have to maintain a separate implementation of these two classes in my local project.

pgniewko avatar Sep 21 '22 07:09 pgniewko

Thanks for reporting. I am not totally sure I understand the problem. Are you talking about PyTorch Lightning integration?

rusty1s avatar Sep 21 '22 10:09 rusty1s

No, it's not about PyTorch Lightning integration.

Simply put, when I save a checkpoint, the PNA conv layers don't save this dictionary

    self.avg_deg: Dict[str, float] = {
        'lin': float((bin_degrees * deg).sum()) / num_nodes,
        'log': float(((bin_degrees + 1).log() * deg).sum()) / num_nodes,
        'exp': float((bin_degrees.exp() * deg).sum()) / num_nodes,
    }

in the DegreeScalerAggregation object. So when I load the model from a checkpoint, I have to initialize it with the same deg tensor that I used for training (either by recomputing it or by loading it from a saved object). But if avg_deg_lin and avg_deg_log were registered parameters, they would be saved to the checkpoint file, and the training-time values would be restored when the model is loaded.
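To make the pain point concrete, a minimal sketch of what loading looks like today (file names, channel sizes, and the aggregator/scaler lists are placeholders):

    import torch
    from torch_geometric.nn import PNAConv

    # The same deg histogram used during training must be available again
    # just to construct the layer, before the checkpoint can be restored.
    deg = torch.load('deg.pt')  # recomputed, or saved separately alongside the model
    conv = PNAConv(in_channels=64, out_channels=64,
                   aggregators=['mean', 'min', 'max', 'std'],
                   scalers=['identity', 'amplification', 'attenuation'],
                   deg=deg)
    conv.load_state_dict(torch.load('pna_checkpoint.pt'))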

pgniewko avatar Sep 22 '22 04:09 pgniewko

So rather than passing a deg tensor, maybe it's better to pass two values, avg_deg_lin and avg_deg_log, and then treat them as either trainable parameters or registered buffers.

pgniewko avatar Sep 22 '22 04:09 pgniewko

Do you have a minimal example to reproduce? Using torch.save(model) will store avg_deg as well, and using model.load_state_dict() would require the input arguments to be present for initializing the model in the first place (in this case either deg, or avg_deg_lin and avg_deg_log). Sorry for my confusion, looks like I am missing something.

rusty1s avatar Sep 22 '22 07:09 rusty1s

I have this kind of an idea in mind, but it requires changes to pna_conv, so I'm not sure whether it's a desirable approach:

    def __init__(
        self,
        aggr: Union[str, List[str], Aggregation],
        scaler: Union[str, List[str]],
        deg_lin: float,
        deg_log: float,
        # deg: Tensor,
        train_delta: bool = False,
        aggr_kwargs: Optional[List[Dict[str, Any]]] = None,
    ):
        ...
        self.init_avg_deg_lin = deg_lin
        self.init_avg_deg_log = deg_log
        if train_delta:
            logging.info('Delta parameters are trainable.')
            self.avg_deg_lin = torch.nn.Parameter(torch.Tensor([self.init_avg_deg_lin]))
            self.avg_deg_log = torch.nn.Parameter(torch.Tensor([self.init_avg_deg_log]))
            # self.avg_deg_exp = torch.nn.Parameter(torch.Tensor([self.init_avg_deg_exp]))
        else:
            self.register_buffer('avg_deg_lin', torch.Tensor([self.init_avg_deg_lin]))
            self.register_buffer('avg_deg_log', torch.Tensor([self.init_avg_deg_log]))
            # self.register_buffer('avg_deg_exp', torch.Tensor([self.init_avg_deg_exp]))

pgniewko avatar Oct 06 '22 00:10 pgniewko

I am not a big fan of this TBH, since it forces users to actually compute avg_deg_log, which adds another layer of complexity. Happy to make these values trainable though :)

rusty1s avatar Oct 06 '22 12:10 rusty1s

Right, this approach requires calculating the stats prior to object creation. I guess I can live with that :) but having trainable deltas would be nice ...

pgniewko avatar Oct 11 '22 22:10 pgniewko

I am happy to include that. Do you wanna send a PR for this?

rusty1s avatar Oct 13 '22 12:10 rusty1s

Yeah, I will send a PR. thx.

pgniewko avatar Oct 14 '22 17:10 pgniewko

Not sure why this issue is still open, but: after upgrading from 2.0.4 to 2.5.2, I can no longer load previously trained weights of a PNA layer due to

Missing key(s) in state_dict: "gnn.convs.0.aggr_module.avg_deg_lin"

I couldn't find any mention in the docs of there being a breaking change in terms of loading previously saved checkpoints.

kmaziarz avatar Jun 06 '24 16:06 kmaziarz

The only possible workaround for this is to manually add avg_deg_lin to your state dict. In general, we cannot necessarily promise backward compatibility of modules between different PyG versions.
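A rough sketch of that workaround, assuming the recomputed average-degree values are at hand (the key name follows the error message above; the file path and the avg_deg_log handling are placeholders and may differ per model):

    import torch

    state_dict = torch.load('old_pna_checkpoint.pt')  # hypothetical path

    # Inject the buffers that newer PyG versions expect but the old checkpoint lacks.
    # avg_deg_lin / avg_deg_log must be recomputed from the original deg histogram.
    state_dict['gnn.convs.0.aggr_module.avg_deg_lin'] = torch.tensor([avg_deg_lin])
    state_dict['gnn.convs.0.aggr_module.avg_deg_log'] = torch.tensor([avg_deg_log])

    model.load_state_dict(state_dict)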

rusty1s avatar Jun 14 '24 06:06 rusty1s