relationPrediction
relationPrediction copied to clipboard
Retrieve node embeddings from checkpoint
Hi, thanks for your amazing work. I'm trying to retrieve node embeddings from the GAT and conv checkpoints. What is the right way to do that? I'm trying something like this:
def get_embedding(args, unique_entities):
    """Load the trained conv-stage checkpoint and run it on ``unique_entities``.

    Builds a ``SpKBGATConvOnly`` model from the module-level
    ``entity_embeddings`` / ``relation_embeddings``, restores the weights saved
    after the last conv epoch
    (``<output_folder>/conv/trained_<epochs_conv - 1>.pth``) and runs a single
    forward pass in eval mode without gradient tracking.

    Args:
        args: parsed options; reads ``entity_out_dim``, ``drop_GAT``,
            ``drop_conv``, ``alpha``, ``alpha_conv``, ``nheads_GAT``,
            ``out_channels``, ``output_folder`` and ``epochs_conv``.
        unique_entities: passed unchanged as the third argument of the model's
            forward call.
            NOTE(review): the type/shape this argument must have (entity ids?
            triples?) is fixed by ``SpKBGATConvOnly.forward``, which is not
            visible here -- confirm against the model definition.

    Returns:
        The forward-pass output (``preds``); its size is also printed.
        NOTE(review): it is unclear from here whether the forward pass returns
        embeddings or per-triple scores. If it returns scores, the entity
        embeddings are presumably held on the model itself and should be read
        from there instead -- verify before relying on this.
    """
    import os
    import warnings

    # Use the GPU only when one is present; the model was previously moved
    # with .cuda() unconditionally, which fails on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_conv = SpKBGATConvOnly(entity_embeddings, relation_embeddings,
                                 args.entity_out_dim, args.entity_out_dim,
                                 args.drop_GAT, args.drop_conv,
                                 args.alpha, args.alpha_conv,
                                 args.nheads_GAT, args.out_channels)

    # os.path.join gives the same path as before when output_folder ends with
    # a separator, and no longer yields '<folder>conv/...' when it does not.
    checkpoint_path = os.path.join(
        args.output_folder, "conv", "trained_{}.pth".format(args.epochs_conv - 1))
    # map_location="cpu" lets a checkpoint saved from GPU tensors load on any
    # machine; the model is moved to the target device afterwards.
    state_dict = torch.load(checkpoint_path, map_location="cpu")

    # strict=False is kept, but skipped keys are now reported: a parameter
    # missing from the checkpoint would otherwise silently keep whatever value
    # the constructor gave it. getattr guards against a None return value.
    incompatible = model_conv.load_state_dict(state_dict, strict=False)
    missing = getattr(incompatible, "missing_keys", [])
    unexpected = getattr(incompatible, "unexpected_keys", [])
    if missing or unexpected:
        warnings.warn(
            "Checkpoint {} did not match the model: missing keys {}, "
            "unexpected keys {}".format(checkpoint_path, missing, unexpected))

    model_conv.to(device)
    model_conv.eval()
    with torch.no_grad():
        preds = model_conv(Corpus_, Corpus_.train_adj_matrix, unique_entities)
    print(preds.size())
    return preds
But I'm not sure about `preds = model_conv(Corpus_, Corpus_.train_adj_matrix, unique_entities)`.
Thanks in advance.
I think I found a possible solution (though I'm not sure about it, since I need embeddings for all entities but I'm only using the training indices):
def get_embeddings():
    """Recompute entity and relation embeddings with the trained GAT model.

    Loads the 2-hop neighbourhood pickle from ``args.data``, builds the 2-hop
    index batch for ``Corpus_.unique_entities_train``, restores the GAT weights
    saved after the last GAT epoch
    (``<output_folder>/trained_<epochs_gat - 1>.pth``) and runs one forward
    pass over ``Corpus_.train_indices`` in eval mode without gradient tracking.

    Relies on the module-level names ``args``, ``CUDA``, ``Corpus_``,
    ``entity_embeddings`` and ``relation_embeddings``.

    Returns:
        ``(entity_embed, relation_embed)`` exactly as returned by the model's
        forward pass.
        NOTE(review): only training triples and training entities are fed in.
        Whether embeddings of entities that never appear in them are updated
        is decided inside ``SpKBGATModified`` (not visible here) -- confirm
        before treating the result as "all entity embeddings".
    """
    import os
    import warnings

    # Choose the device once. Previously train_indices was only created inside
    # `if CUDA:` (NameError on CPU runs), the 2-hop indices stayed a non-tensor
    # on CPU, and the model was sent to the GPU unconditionally.
    device = torch.device("cuda" if CUDA else "cpu")

    # NOTE: pickle can execute arbitrary code while loading -- only use a
    # 2hop.pickle file you generated yourself.
    with open(os.path.join(args.data, "2hop.pickle"), "rb") as handle:
        node_neighbors_2hop = pickle.load(handle)

    nhop_indices = Corpus_.get_batch_nhop_neighbors_all(
        args, Corpus_.unique_entities_train, node_neighbors_2hop)
    # Plain tensors replace the deprecated autograd.Variable wrapper.
    current_batch_2hop_indices = torch.LongTensor(nhop_indices).to(device)
    train_indices = torch.LongTensor(Corpus_.train_indices).to(device)

    model_gat = SpKBGATModified(entity_embeddings, relation_embeddings,
                                args.entity_out_dim, args.entity_out_dim,
                                args.drop_GAT, args.alpha, args.nheads_GAT)

    checkpoint_path = os.path.join(
        args.output_folder, "trained_{}.pth".format(args.epochs_gat - 1))
    # map_location="cpu" lets a checkpoint saved from GPU tensors load on any
    # machine; the model is moved to the target device afterwards.
    state_dict = torch.load(checkpoint_path, map_location="cpu")

    # strict=False is kept, but skipped keys are now reported so that a
    # parameter missing from the checkpoint does not go unnoticed.
    # getattr guards against a None return value.
    incompatible = model_gat.load_state_dict(state_dict, strict=False)
    missing = getattr(incompatible, "missing_keys", [])
    unexpected = getattr(incompatible, "unexpected_keys", [])
    if missing or unexpected:
        warnings.warn(
            "Checkpoint {} did not match the model: missing keys {}, "
            "unexpected keys {}".format(checkpoint_path, missing, unexpected))

    model_gat.to(device)
    model_gat.eval()
    with torch.no_grad():
        entity_embed, relation_embed = model_gat(
            Corpus_, Corpus_.train_adj_matrix, train_indices,
            current_batch_2hop_indices)
    return entity_embed, relation_embed