pykeen
Training loop does not update relation representations when continuing training
Describe the bug
When inspecting the representations of TransE before and after repeated calls to the training loop with continue_training=True, only the entity representations change; the relation representations do not. I would expect the relation representations to change just like the entity representations.
I believe this problem might exist for other models as well (e.g. DistMult).
How to reproduce
from pykeen.datasets import Nations
from pykeen.models import TransE, DistMult
from pykeen.training import LCWATrainingLoop, SLCWATrainingLoop
from torch.optim import Adam
from tqdm import tqdm
from typing import List
import pykeen.nn

reps = []
training_triples_factory = Nations().training
model = TransE(
    triples_factory=training_triples_factory,
    embedding_dim=2,
    # random_seed=1235,
)
optimizer = Adam(params=model.get_grad_params())
training_loop = SLCWATrainingLoop(
    model=model,
    triples_factory=training_triples_factory,
    optimizer=optimizer,
)
# initial training run of one epoch
_ = training_loop.train(
    triples_factory=training_triples_factory,
    batch_size=32,
    num_epochs=1,
    use_tqdm=False,
    use_tqdm_batch=False,
)
n = 10
for i in tqdm(range(1, n)):
    # Continue training only seems to work for entity embeddings.
    # Relation embeddings don't change when using continue training.
    # This might be a bug in PyKEEN.
    loss = training_loop.train(
        triples_factory=training_triples_factory,
        num_epochs=i,
        batch_size=32,
        continue_training=True,
        use_tqdm=False,
        use_tqdm_batch=False,
    )
    # save the entity and relation representations after each training call
    reps.append(
        (
            model.entity_representations[0](indices=None).detach().numpy(),
            model.relation_representations[0](indices=None).detach().numpy(),
        )
    )
If we now look at the saved relation representations in reps we can see that they don't change over time. They are always equal to the first representation in the list. You can try this by comparing the relation representations like this:
print(reps[0][1] == [reps[i][1] for i in range(1, len(reps))])
All of the items will be True. We'd expect to see differences and therefore False for most of these values. You can see a perfect example of what it should look like when doing the same for the entity representation:
print(reps[0][0] == [reps[i][0] for i in range(1, len(reps))])
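The element-wise `==` against a list relies on NumPy broadcasting and prints one large boolean array. A more compact check, sketched below, reduces each snapshot to a single boolean. The helper name `changed_since_first` and the toy arrays are hypothetical stand-ins for the `reps` entries above; nothing here is part of PyKEEN.

```python
import numpy as np


def changed_since_first(snapshots):
    """Return one bool per later snapshot: True if it differs from the first."""
    first = snapshots[0]
    return [not np.array_equal(first, later) for later in snapshots[1:]]


# Toy data standing in for reps[i][1] (relations) and reps[i][0] (entities)
frozen = [np.ones((3, 2)) for _ in range(4)]
moving = [np.full((3, 2), float(i)) for i in range(4)]

print(changed_since_first(frozen))  # [False, False, False]
print(changed_since_first(moving))  # [True, True, True]
```

With the reproduction script, this could be called as `changed_since_first([r[1] for r in reps])` for the relation representations.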
Environment
| Key | Value |
|---|---|
| OS | posix |
| Platform | Linux |
| Release | 5.15.0-91-generic |
| Time | Thu Jan 11 10:36:10 2024 |
| Python | 3.11.4 |
| PyKEEN | 1.10.1 |
| PyKEEN Hash | UNHASHED |
| PyKEEN Branch | |
| PyTorch | 2.1.2+cu121 |
| CUDA Available? | true |
| CUDA Version | 12.1 |
| cuDNN Version | 8902 |
Additional information
No response
Issue Template Checks
- [X] This is not a feature request (use a different issue template if it is)
- [X] This is not a question (use the discussions forum instead)
- [X] I've read the text explaining why including environment information is important and understand if I omit this information that my issue will be dismissed
Hi @lukas-schwab,
thanks for reporting the issue. I can reproduce it locally, although I have not yet had the time to dive deeper into why this happens.
Interestingly, it only seems to happen with two separate .train calls; if I record weights over multiple epochs of a single training run, they change (as expected):
from collections import defaultdict
from typing import Any

import torch

from pykeen.pipeline import pipeline
from pykeen.training.callbacks import TrainingCallback


class WeightRecorderCallback(TrainingCallback):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weights = defaultdict(list)

    def post_epoch(self, epoch: int, epoch_loss: float, **kwargs: Any) -> None:
        for name, tensor in self.model.named_parameters():
            self.weights[name].append(tensor.detach().clone().cpu())


callback = WeightRecorderCallback()
result = pipeline(dataset="nations", model="Transe", training_kwargs=dict(callbacks=[callback]))
print(
    {
        key: [torch.allclose(weights[0], weights[i]) for i in range(len(weights))]
        for key, weights in callback.weights.items()
    }
)
# {
#     'entity_representations.0._embeddings.weight': [True, False, False, False, False],
#     'relation_representations.0._embeddings.weight': [True, False, False, False, False]
# }
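One difference between the two snippets that may be worth ruling out: the callback stores `tensor.detach().clone()`, an independent copy, whereas the reproduction stores `.detach().numpy()`, which in PyTorch shares memory with the source tensor. If a representation module returned its weight tensor directly (an assumption about PyKEEN internals that this thread does not confirm), every stored snapshot would alias the same buffer and always compare equal. The sketch below shows the effect with a plain `torch.nn.Parameter`; the variable names are illustrative only.

```python
import torch

weight = torch.nn.Parameter(torch.zeros(2, 2))

aliased = weight.detach().numpy()         # shares memory with `weight`
copied = weight.detach().clone().numpy()  # independent snapshot

# Simulate an in-place optimizer step
with torch.no_grad():
    weight.add_(1.0)

aliased_tracks_update = bool((aliased == 1.0).all())
copied_keeps_old_value = bool((copied == 0.0).all())
print(aliased_tracks_update, copied_keeps_old_value)  # True True
```

If this turns out to apply, copying before storing (for example `.detach().clone().numpy()` or `.numpy().copy()`) would make the snapshots independent of later updates.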
> Interestingly, it only seems to happen with two separate .train calls
Yes, it's quite strange. Took me a while to get convinced that the library was at fault here and not me.
Fortunately the code you provided is the proper solution to what I was actually trying to achieve. So thank you for posting that and keep up the good work!