SimKGC
How to use a trained model to obtain entity embeddings
Thank you for sharing the code!
After training a model, how can we use it to obtain entity embeddings? Should we use the following function?
def predict_by_entities(self, entity_exs) -> torch.tensor:
    examples = []
    for entity_ex in entity_exs:
        examples.append(Example(head_id='', relation='',
                                tail_id=entity_ex.entity_id))
    data_loader = torch.utils.data.DataLoader(
        Dataset(path='', examples=examples, task=args.task),
        num_workers=2,
        batch_size=max(args.batch_size, 1024),
        collate_fn=collate,
        shuffle=False)
    ent_tensor_list = []
    for idx, batch_dict in enumerate(tqdm.tqdm(data_loader)):
        batch_dict['only_ent_embedding'] = True
        if self.use_cuda:
            batch_dict = move_to_cuda(batch_dict)
        outputs = self.model(**batch_dict)
        ent_tensor_list.append(outputs['ent_vectors'])
    return torch.cat(ent_tensor_list, dim=0)
Also, is the vectors.json file mentioned in this issue generated by the following line?
entity_tensor = predictor.predict_by_entities(entity_dict.entity_exs)
Thank you!
Yes, that's correct.
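For anyone who wants to keep the embeddings around afterwards, here is a rough sketch of pairing each row of the returned tensor with its entity ID and dumping the result to JSON. Since the data loader above uses shuffle=False, row i of the output should correspond to entity_exs[i]. The helper names (pair_ids_with_vectors, save_vectors) and the ID-to-vector JSON layout are my own assumptions for illustration; this thread does not say how vectors.json is actually laid out. The toy lists stand in for [ex.entity_id for ex in entity_dict.entity_exs] and entity_tensor.cpu().tolist().

```python
import json
from typing import Dict, List, Sequence


def pair_ids_with_vectors(entity_ids: Sequence[str],
                          vectors: Sequence[Sequence[float]]) -> Dict[str, List[float]]:
    """Map each entity ID to its embedding row (hypothetical helper).

    Assumes vectors[i] belongs to entity_ids[i], which holds when the
    embeddings were produced in input order (shuffle=False).
    """
    if len(entity_ids) != len(vectors):
        raise ValueError('number of entity IDs and embedding rows differ')
    return {eid: [float(x) for x in vec]
            for eid, vec in zip(entity_ids, vectors)}


def save_vectors(mapping: Dict[str, List[float]], path: str) -> None:
    """Write the ID -> vector mapping as JSON (assumed layout, not SimKGC's)."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(mapping, f)


# Toy stand-ins; with a real model these would come from
# [ex.entity_id for ex in entity_dict.entity_exs] and entity_tensor.cpu().tolist().
toy_ids = ['e1', 'e2']
toy_vectors = [[0.1, 0.2], [0.3, 0.4]]
toy_mapping = pair_ids_with_vectors(toy_ids, toy_vectors)
```

Calling save_vectors(toy_mapping, 'my_vectors.json') would then write the mapping to disk. For a large entity set, JSON gets bulky, so torch.save on the tensor plus a separate ID list may be a better fit.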