
Training loop does not update relation representations when continuing training

Open lukas-schwab opened this issue 1 year ago • 2 comments

Describe the bug

When looking at the representations of TransE before and after training multiple times using the training loop with continue_training=True, only the entity representation changes. In contrast, the relation representation does not change. It would be expected that the relation representation changes just like the entity representation.

I believe this problem might exist for other models as well (e.g., DistMult).
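
Before the full reproduction: a quick sanity check (a hypothetical snippet, not from the original report) shows that the relation embedding parameters are trainable and included in what get_grad_params() hands to the optimizer, so the problem is not simply that they are frozen:

from pykeen.datasets import Nations
from pykeen.models import TransE

# hypothetical check: the relation embedding weight requires gradients and is
# among the parameters returned by get_grad_params()
model = TransE(triples_factory=Nations().training, embedding_dim=2)
rel_weight = model.relation_representations[0]._embeddings.weight
print(rel_weight.requires_grad)                               # expected: True
print(any(p is rel_weight for p in model.get_grad_params()))  # expected: True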

How to reproduce

from pykeen.datasets import Nations
from pykeen.models import TransE, DistMult
from pykeen.training import LCWATrainingLoop, SLCWATrainingLoop
from torch.optim import Adam
from tqdm import tqdm
from typing import List
import pykeen.nn

reps = []

training_triples_factory = Nations().training

model = TransE(
    triples_factory=training_triples_factory,
    embedding_dim=2,
    #random_seed=1235,
)

optimizer = Adam(params=model.get_grad_params())
training_loop = SLCWATrainingLoop(
    model=model,
    triples_factory=training_triples_factory,
    optimizer=optimizer,
)

_ = training_loop.train(
    triples_factory=training_triples_factory,
    batch_size=32,
    num_epochs=1,
    use_tqdm=False,
    use_tqdm_batch=False
)

n = 10

for i in tqdm(range(1, n)):
    # Continue training only seems to work for entity embeddings; relation
    # embeddings don't change when using continue_training=True. This might
    # be a bug in PyKEEN.

    loss = training_loop.train(
        triples_factory=training_triples_factory,
        num_epochs=i,
        batch_size=32,
        continue_training=True,
        use_tqdm=False,
        use_tqdm_batch=False
    )

    # saving the representation after training an epoch
    reps.append(
        (
            model.entity_representations[0](indices=None).detach().numpy(),
            model.relation_representations[0](indices=None).detach().numpy(),
        )
    )
   

If we now look at the saved relation representations in reps, we can see that they don't change over time: they are always equal to the first representation in the list. You can verify this by comparing the relation representations like this:

print(reps[0][1] == [reps[i][1] for i in range(1, len(reps))])

All of the items will be True, whereas we'd expect to see differences and therefore False for most of these values. Doing the same for the entity representations shows what it should look like:

print(reps[0][0] == [reps[i][0] for i in range(1, len(reps))])
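
A slightly more readable check (a sketch, not part of the original report, assuming the reps list built above) compares each later snapshot to the first one as a whole array via numpy.array_equal:

import numpy as np

# one boolean per later snapshot: is it identical to the first snapshot?
print([np.array_equal(reps[0][1], r[1]) for r in reps[1:]])  # relations: all True (the bug)
print([np.array_equal(reps[0][0], r[0]) for r in reps[1:]])  # entities: all False (expected)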

Environment

Key Value
OS posix
Platform Linux
Release 5.15.0-91-generic
Time Thu Jan 11 10:36:10 2024
Python 3.11.4
PyKEEN 1.10.1
PyKEEN Hash UNHASHED
PyKEEN Branch
PyTorch 2.1.2+cu121
CUDA Available? true
CUDA Version 12.1
cuDNN Version 8902

Additional information

No response

Issue Template Checks

  • [X] This is not a feature request (use a different issue template if it is)
  • [X] This is not a question (use the discussions forum instead)
  • [X] I've read the text explaining why including environment information is important and understand if I omit this information that my issue will be dismissed

lukas-schwab avatar Jan 11 '24 11:01 lukas-schwab

Hi @lukas-schwab ,

thanks for reporting the issue. I can reproduce it locally, although I have not yet had the time to dive deeper into why this happens.

Interestingly, it only seems to happen with two separate .train calls; if I record weights over multiple epochs of a single training run, they change (as expected):

from collections import defaultdict
from typing import Any

import torch

from pykeen.pipeline import pipeline
from pykeen.training.callbacks import TrainingCallback


class WeightRecorderCallback(TrainingCallback):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weights = defaultdict(list)

    def post_epoch(self, epoch: int, epoch_loss: float, **kwargs: Any) -> None:
        for name, tensor in self.model.named_parameters():
            self.weights[name].append(tensor.detach().clone().cpu())


callback = WeightRecorderCallback()
result = pipeline(dataset="nations", model="Transe", training_kwargs=dict(callbacks=[callback]))
print(
    {
        key: [torch.allclose(weights[0], weights[i]) for i in range(len(weights))]
        for key, weights in callback.weights.items()
    }
)
# {
#   'entity_representations.0._embeddings.weight': [True, False, False, False, False], 
#   'relation_representations.0._embeddings.weight': [True, False, False, False, False]
# }
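
For the explicit SLCWATrainingLoop from the reproduction script, the same recorder can be attached to a single multi-epoch train() call instead of repeated continue_training=True calls. This is only a sketch; it assumes train() accepts the callbacks argument that pipeline forwards via training_kwargs:

# hypothetical adaptation of WeightRecorderCallback to the explicit training loop
callback = WeightRecorderCallback()
_ = training_loop.train(
    triples_factory=training_triples_factory,
    num_epochs=10,
    batch_size=32,
    callbacks=[callback],
    use_tqdm=False,
)
# one list of per-epoch snapshots per parameter name
print(list(callback.weights))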

mberr avatar Jan 12 '24 17:01 mberr

Interestingly, it only seems to happen with two separate .train calls

Yes, it's quite strange. It took me a while to convince myself that the library was at fault here and not me.

Fortunately, the code you provided is the proper solution to what I was actually trying to achieve. So thank you for posting that, and keep up the good work!

lukas-schwab avatar Jan 12 '24 18:01 lukas-schwab