relationPrediction

Retrieve node embeddings from checkpoint

Open · matteomedioli opened this issue 3 years ago · 1 comment

Hi, thanks for your amazing work. I'm trying to retrieve node embeddings from the GAT and conv checkpoints. What is the right way to do that? I'm trying something like this:

import torch

def get_embedding(args, unique_entities):
    # entity_embeddings, relation_embeddings and Corpus_ are assumed to be the globals set up in main.py
    model_conv = SpKBGATConvOnly(entity_embeddings, relation_embeddings, args.entity_out_dim, args.entity_out_dim,
                                 args.drop_GAT, args.drop_conv, args.alpha, args.alpha_conv,
                                 args.nheads_GAT, args.out_channels)
    model_conv.load_state_dict(torch.load(
        '{0}conv/trained_{1}.pth'.format(args.output_folder, args.epochs_conv - 1)), strict=False)

    model_conv.cuda()
    model_conv.eval()
    with torch.no_grad():
        preds = model_conv(Corpus_, Corpus_.train_adj_matrix, unique_entities)
        print(preds.size())

But I'm not sure about preds = model_conv(Corpus_, Corpus_.train_adj_matrix, unique_entities). Thanks in advance!

matteomedioli · Aug 17 '21 14:08
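A note on the first snippet: if SpKBGATConvOnly follows the usual ConvKB design (an assumption worth checking against models.py), its forward pass takes a batch of (head, relation, tail) index triples and returns one score per triple, so calling it with unique_entities would not give embeddings back. The embeddings would instead live in a parameter that can be read straight from the loaded model. The sketch below uses a hypothetical ToyConvScorer, not the repository's class; its forward drops the Corpus_ and adjacency arguments, uses a linear layer in place of the convolution, and the parameter name final_entity_embeddings is an assumption, so print model.state_dict().keys() on the real checkpoint to confirm it. It also shows why strict=False deserves a check: mismatched names are reported in missing_keys/unexpected_keys instead of raising, so a checkpoint that restored nothing looks just like one that worked.

```python
import torch
import torch.nn as nn


class ToyConvScorer(nn.Module):
    # Hypothetical stand-in for a ConvKB-style scorer (not the repository's class):
    # entity/relation embeddings are parameters, forward() scores triples.
    def __init__(self, num_entities, num_relations, dim):
        super().__init__()
        # Assumed parameter names; check state_dict().keys() on the real checkpoint.
        self.final_entity_embeddings = nn.Parameter(torch.randn(num_entities, dim))
        self.final_relation_embeddings = nn.Parameter(torch.randn(num_relations, dim))
        self.score = nn.Linear(3 * dim, 1)  # simple linear scorer in place of a conv layer

    def forward(self, triples):
        # triples: LongTensor (batch, 3) of (head, relation, tail) ids
        h = self.final_entity_embeddings[triples[:, 0]]
        r = self.final_relation_embeddings[triples[:, 1]]
        t = self.final_entity_embeddings[triples[:, 2]]
        return self.score(torch.cat([h, r, t], dim=1))  # (batch, 1): scores, not embeddings


# Simulate saving and reloading a checkpoint with an in-memory state dict.
trained = ToyConvScorer(num_entities=5, num_relations=2, dim=4)
state = trained.state_dict()

model = ToyConvScorer(num_entities=5, num_relations=2, dim=4)
result = model.load_state_dict(state, strict=False)
# Empty lists mean every parameter was restored; strict=False would not raise otherwise.
print("missing:", result.missing_keys, "unexpected:", result.unexpected_keys)

# A checkpoint saved under different names "loads" without error but restores nothing.
renamed = {"entity_embeddings": state["final_entity_embeddings"]}
bad = model.load_state_dict(renamed, strict=False)
print("missing:", bad.missing_keys, "unexpected:", bad.unexpected_keys)

model.eval()
with torch.no_grad():
    scores = model(torch.tensor([[0, 1, 3], [2, 0, 4]]))
    entity_embeddings = model.final_entity_embeddings.detach().clone()

print(scores.shape)             # one score per triple
print(entity_embeddings.shape)  # one row per entity id
```

With the real model, the equivalent would be reading model_conv.final_entity_embeddings after load_state_dict, provided that name shows up among the checkpoint's keys.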

I think I found a possible solution, though I'm not sure about it: I need the embeddings of all entities, but this uses the training indices:

import pickle
import torch

def get_embeddings():
    # Precomputed 2-hop neighbourhoods
    fl = args.data + "/2hop.pickle"
    with open(fl, 'rb') as handle:
        node_neighbors_2hop = pickle.load(handle)

    current_batch_2hop_indices = Corpus_.get_batch_nhop_neighbors_all(args, Corpus_.unique_entities_train, node_neighbors_2hop)
    # Build the index tensors outside the CUDA branch so train_indices is always defined
    # (torch.autograd.Variable is no longer needed since PyTorch 0.4)
    current_batch_2hop_indices = torch.LongTensor(current_batch_2hop_indices)
    train_indices = torch.LongTensor(Corpus_.train_indices)

    model_gat = SpKBGATModified(entity_embeddings, relation_embeddings, args.entity_out_dim, args.entity_out_dim,
                                args.drop_GAT, args.alpha, args.nheads_GAT)
    # map_location='cpu' lets a GPU-saved checkpoint load on any machine; move to GPU below if available
    model_gat.load_state_dict(torch.load(
        '{}/trained_{}.pth'.format(args.output_folder, args.epochs_gat - 1), map_location='cpu'), strict=False)
    if CUDA:
        current_batch_2hop_indices = current_batch_2hop_indices.cuda()
        train_indices = train_indices.cuda()
        model_gat.cuda()
    model_gat.eval()
    with torch.no_grad():
        entity_embed, relation_embed = model_gat(
            Corpus_, Corpus_.train_adj_matrix, train_indices, current_batch_2hop_indices)
    return entity_embed, relation_embed

matteomedioli · Aug 25 '21 15:08
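On the concern about using training indices: the practical check is whether the returned entity_embed has one row per entity id (compare entity_embed.shape[0] with the number of entities in your entity-to-id mapping). Whether entities that never appear in a training triple still receive the attention update depends on the model's forward pass, so that is worth confirming in models.py. If the row count matches, a minimal sketch for turning the matrix into named vectors and exporting them could look like this. The helpers embeddings_by_name and write_tsv are hypothetical, the layout "row i belongs to the entity with id i" is an assumption, and the toy tensors stand in for the real output.

```python
import io

import torch


def embeddings_by_name(entity_embed, entity2id):
    # Hypothetical helper: assumes row i of entity_embed belongs to the entity with id i.
    if entity_embed.shape[0] != len(entity2id):
        raise ValueError(
            f"{entity_embed.shape[0]} rows but {len(entity2id)} entities; "
            "the matrix does not cover every entity")
    return {name: entity_embed[idx] for name, idx in entity2id.items()}


def write_tsv(named, out):
    # One line per entity: the name, then tab-separated vector components.
    for name, vec in sorted(named.items()):
        out.write(name + "\t" + "\t".join(f"{x:.6f}" for x in vec.tolist()) + "\n")


# Toy data standing in for the GAT output and the entity-to-id mapping.
entity2id = {"cat": 0, "dog": 1, "mammal": 2}
entity_embed = torch.arange(6, dtype=torch.float32).reshape(3, 2)

named = embeddings_by_name(entity_embed.cpu(), entity2id)
buf = io.StringIO()
write_tsv(named, buf)
print(buf.getvalue())
```

With the function above, this would be called as embeddings_by_name(entity_embed.cpu(), entity2id) on the result of get_embeddings(), with entity2id read from the dataset's mapping file (for example entity2id.txt, if your data folder has one). The shape check fails loudly instead of silently pairing names with the wrong rows.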