deg argument in DegreeScalerAggregation prevents using the model without the training set
🚀 The feature, motivation and pitch
When training a model with PNA layers, you need to pass a `deg` tensor that sets the following parameters in `DegreeScalerAggregation`:
```python
self.avg_deg: Dict[str, float] = {
    'lin': float((bin_degrees * deg).sum()) / num_nodes,
    'log': float(((bin_degrees + 1).log() * deg).sum()) / num_nodes,
    'exp': float((bin_degrees.exp() * deg).sum()) / num_nodes,
}
```
These are neither registered buffers nor trainable parameters, so saving the model's state dict does nothing to them, and they need to be recomputed every time the model is loaded from a checkpoint. This means that every time I want to use the saved model, I need to have the training set on hand.
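For illustration, here is a minimal sketch (the `Toy` module below is hypothetical, not PyG code) of why these values are lost: only parameters and buffers are collected into `state_dict()`.

```python
import torch

class Toy(torch.nn.Module):
    def __init__(self, avg_deg_lin: float):
        super().__init__()
        # Plain Python attribute, like self.avg_deg in DegreeScalerAggregation:
        # it never shows up in state_dict().
        self.avg_deg = {'lin': avg_deg_lin}
        # Parameters and buffers, by contrast, are checkpointed.
        self.weight = torch.nn.Parameter(torch.randn(4))
        self.register_buffer('scale', torch.ones(1))

m = Toy(avg_deg_lin=3.2)
print(m.state_dict().keys())  # odict_keys(['weight', 'scale']) -- no avg_deg
```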
Alternatives
`avg_deg_exp` should be removed, as it is never used. `deg_lin` and `deg_log` should be either trainable or buffered. These changes will require passing a parameter to the constructor (`train_delta`) and also modifying the PNA conv layer implementation:
```python
if train_delta:
    logging.info('Delta parameters are trainable.')
    self.avg_deg_lin = torch.nn.Parameter(torch.Tensor([self.init_avg_deg_lin]))
    self.avg_deg_log = torch.nn.Parameter(torch.Tensor([self.init_avg_deg_log]))
else:
    self.register_buffer('avg_deg_lin', torch.Tensor([self.init_avg_deg_lin]))
    self.register_buffer('avg_deg_log', torch.Tensor([self.init_avg_deg_log]))
```
Additional context
The current implementation does not allow me to adopt the native PyG code, so I have to maintain a separate implementation of these two classes in my local project.
Thanks for reporting. I am not totally sure I understand the problem. Are you talking about PyTorch Lightning integration?
No, it's not about PyTorch Lightning integration.
Simply put: when I save a checkpoint, PNA conv layers don't save this dictionary
```python
self.avg_deg: Dict[str, float] = {
    'lin': float((bin_degrees * deg).sum()) / num_nodes,
    'log': float(((bin_degrees + 1).log() * deg).sum()) / num_nodes,
    'exp': float((bin_degrees.exp() * deg).sum()) / num_nodes,
}
```
in the `DegreeScalerAggregation` object. So when I load the model from a checkpoint, I have to initialize it with the same `deg` tensor that I used for training (either by recomputing it or by loading it from a saved object). But if `avg_deg_lin` and `avg_deg_log` are registered parameters, they will be saved to the checkpoint file, and when I load the model from the checkpoint, the trained values will be restored.
So rather than passing a `deg` tensor, maybe it's better to pass the two values `avg_deg_lin` and `avg_deg_log` and then treat them as either trainable parameters or registered buffers.
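As a sketch of what callers would compute up front, following the formulas quoted above (assuming `bin_degrees` is the `torch.arange` over the histogram bins, as in the current implementation; the histogram values here are illustrative):

```python
import torch

# Degree histogram as passed to DegreeScalerAggregation today:
# deg[i] = number of nodes with degree i (illustrative values).
deg = torch.tensor([0, 10, 25, 5])
bin_degrees = torch.arange(deg.numel())
num_nodes = int(deg.sum())

avg_deg_lin = float((bin_degrees * deg).sum()) / num_nodes
avg_deg_log = float(((bin_degrees + 1).log() * deg).sum()) / num_nodes
```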
Do you have a minimal example to reproduce? Using `torch.save(model)` will store `avg_deg` as well, and using `model.load_state_dict()` would require the input arguments to be present for initializing the model in the first place (in this case either `deg`, or `avg_deg_lin` and `avg_deg_log`). Sorry for my confusion, it looks like I am missing something.
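The distinction drawn here looks roughly like this (a sketch; `model` stands for any `nn.Module` containing the PNA layers):

```python
import torch

# Pickling the whole module keeps plain attributes such as avg_deg:
torch.save(model, 'model.pt')
model = torch.load('model.pt')

# A state_dict round trip does not: the module must be re-constructed first,
# which is where the deg / avg_deg_* constructor arguments come back in.
torch.save(model.state_dict(), 'model_sd.pt')
model.load_state_dict(torch.load('model_sd.pt'))
```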
I have this kind of idea in mind, but it requires changes to `pna_conv`, so I'm not sure whether it's a desirable approach:
```python
def __init__(
    self,
    aggr: Union[str, List[str], Aggregation],
    scaler: Union[str, List[str]],
    deg_lin: float,
    deg_log: float,
    # deg: Tensor,  # replaced by the two precomputed statistics above
    train_delta: bool = False,
    aggr_kwargs: Optional[List[Dict[str, Any]]] = None,
):
    ...
    self.init_avg_deg_lin = deg_lin
    self.init_avg_deg_log = deg_log

    if train_delta:
        logging.info('Delta parameters are trainable.')
        self.avg_deg_lin = torch.nn.Parameter(torch.Tensor([self.init_avg_deg_lin]))
        self.avg_deg_log = torch.nn.Parameter(torch.Tensor([self.init_avg_deg_log]))
        # self.avg_deg_exp = torch.nn.Parameter(torch.Tensor([self.init_avg_deg_exp]))
    else:
        self.register_buffer('avg_deg_lin', torch.Tensor([self.init_avg_deg_lin]))
        self.register_buffer('avg_deg_log', torch.Tensor([self.init_avg_deg_log]))
        # self.register_buffer('avg_deg_exp', torch.Tensor([self.init_avg_deg_exp]))
```
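Assuming that proposed signature (this is the proposal under discussion, not the released PyG API), the call site would look roughly like:

```python
# Hypothetical usage of the proposed constructor: the statistics are
# precomputed once, and train_delta makes them learnable parameters.
aggr = DegreeScalerAggregation(
    aggr=['mean', 'min', 'max', 'std'],
    scaler=['identity', 'amplification', 'attenuation'],
    deg_lin=3.2,
    deg_log=1.1,
    train_delta=True,
)
```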
I am not a big fan of this, TBH, since it forces users to actually compute `avg_deg_log` themselves, which adds another layer of complexity. Happy to make these values trainable, though :)
Right, this approach requires calculating the stats prior to object creation. I guess I can live with that :) but having trainable deltas would be nice ...
I am happy to include that. Do you wanna send a PR for this?
Yeah, I will send a PR. thx.
Not sure why this issue is still open, but: after upgrading from `2.0.4` to `2.5.2`, I can no longer load previously trained weights of a PNA layer due to:

```
Missing key(s) in state_dict: "gnn.convs.0.aggr_module.avg_deg_lin"
```
I couldn't find any mention in the docs of there being a breaking change in terms of loading previously saved checkpoints.
The only possible workaround for this is to manually add `avg_deg_lin` to your state dict. In general, we cannot necessarily promise backward compatibility of modules between different PyG versions.
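A sketch of that workaround (the key comes from the error above; the filename and the value are illustrative, as the statistic must be recomputed from the original training-set degree histogram):

```python
import torch

state = torch.load('pna_checkpoint_2_0_4.pt')  # weights saved under PyG 2.0.4

# Patch in the key that the new version expects; repeat for any other
# missing keys reported by load_state_dict().
state['gnn.convs.0.aggr_module.avg_deg_lin'] = torch.tensor([3.2])

model.load_state_dict(state)  # `model` built with the new PyG version
```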