SimKGC

How to use a trained model to obtain entity embeddings

Open • xga0 opened this issue 1 year ago • 1 comment

Thank you for sharing the code!

After training a model, how do we use it to obtain entity embeddings? Should we use the following function:

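    import torch
    import torch.utils.data
    import tqdm

    # SimKGC project modules
    from config import args
    from doc import Example, Dataset, collate
    from utils import move_to_cuda

    # Method excerpt: self.model and self.use_cuda come from the enclosing predictor class.
    @torch.no_grad()  # inference only: no autograd graph is built
    def predict_by_entities(self, entity_exs) -> torch.Tensor:
        # Wrap each entity as a tail-only Example; head and relation stay
        # empty because only the entity (tail) encoding is used.
        examples = []
        for entity_ex in entity_exs:
            examples.append(Example(head_id='', relation='',
                                    tail_id=entity_ex.entity_id))
        data_loader = torch.utils.data.DataLoader(
            Dataset(path='', examples=examples, task=args.task),
            num_workers=2,
            batch_size=max(args.batch_size, 1024),
            collate_fn=collate,
            shuffle=False)  # keep order: row i belongs to entity_exs[i]

        ent_tensor_list = []
        for batch_dict in tqdm.tqdm(data_loader):
            # Tell the model to run only the entity encoder.
            batch_dict['only_ent_embedding'] = True
            if self.use_cuda:
                batch_dict = move_to_cuda(batch_dict)
            outputs = self.model(**batch_dict)
            ent_tensor_list.append(outputs['ent_vectors'])

        # Concatenate per-batch outputs into one (num_entities, dim) tensor.
        return torch.cat(ent_tensor_list, dim=0)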
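
The function above follows a batch-then-concatenate pattern: entities are encoded in fixed-size batches, and because the loader uses shuffle=False, the rows of the concatenated result line up with the input order. The pure-Python sketch below mimics that flow without PyTorch. The names encode_in_batches and toy_encoder are hypothetical; toy_encoder is only a stand-in for self.model(**batch_dict)['ent_vectors'], not SimKGC code.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def encode_in_batches(entity_ids: Sequence[str],
                      encode_batch: Callable[[Sequence[str]], List[Vector]],
                      batch_size: int = 1024) -> List[Vector]:
    """Encode IDs batch by batch and concatenate the results in input order.

    Mirrors the shape of predict_by_entities: no shuffling, one encoder call
    per batch, then concatenation (list.extend here, torch.cat there).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    vectors: List[Vector] = []
    for start in range(0, len(entity_ids), batch_size):
        batch = entity_ids[start:start + batch_size]
        out = encode_batch(batch)
        if len(out) != len(batch):
            raise ValueError("encoder must return one vector per entity")
        vectors.extend(out)
    return vectors


def toy_encoder(batch: Sequence[str]) -> List[Vector]:
    """Hypothetical stand-in for the model: a 2-d vector per ID."""
    return [[float(len(eid)), float(sum(map(ord, eid)))] for eid in batch]


if __name__ == "__main__":
    ids = ["Q1", "Q42", "Q100", "Q7", "Q9"]
    vecs = encode_in_batches(ids, toy_encoder, batch_size=2)
    for eid, vec in zip(ids, vecs):
        print(eid, vec)
```

With batch_size=2 the five IDs are encoded in three batches, yet the output order matches the input order; in the real function that guarantee comes from shuffle=False.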

Also, is the vectors.json file mentioned in this issue generated by the following line?

    entity_tensor = predictor.predict_by_entities(entity_dict.entity_exs)

Thank you!

xga0, Aug 10 '23 20:08

Yes, that's correct.

intfloat, Aug 11 '23 05:08
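
Once entity_tensor has been computed, each row can be paired with its entity ID: the loader uses shuffle=False, so row i belongs to entity_dict.entity_exs[i]. The sketch below writes such pairs to JSON and reads them back. The record layout and the function names save_entity_vectors and load_entity_vectors are hypothetical; this is not necessarily the format of the vectors.json file referred to in the question. With PyTorch, entity_tensor.tolist() returns the nested Python lists these functions expect, and the IDs could be taken as [ex.entity_id for ex in entity_dict.entity_exs].

```python
import json
import os
import tempfile
from typing import Dict, List, Sequence


def save_entity_vectors(path: str, entity_ids: Sequence[str],
                        vectors: Sequence[Sequence[float]]) -> None:
    """Write one {"entity_id": ..., "vector": [...]} record per entity.

    The record layout is a hypothetical choice for illustration.
    """
    if len(entity_ids) != len(vectors):
        raise ValueError("need exactly one vector per entity ID")
    records = [{"entity_id": eid, "vector": [float(x) for x in vec]}
               for eid, vec in zip(entity_ids, vectors)]
    with open(path, "w", encoding="utf-8") as f:
        json.dump(records, f)


def load_entity_vectors(path: str) -> Dict[str, List[float]]:
    """Read the records back into an entity_id -> vector mapping."""
    with open(path, encoding="utf-8") as f:
        return {r["entity_id"]: r["vector"] for r in json.load(f)}


if __name__ == "__main__":
    ids = ["Q1", "Q42"]
    vecs = [[0.6, 0.8], [1.0, 0.0]]  # e.g. entity_tensor.tolist()
    path = os.path.join(tempfile.mkdtemp(), "entity_vectors.json")
    save_entity_vectors(path, ids, vecs)
    print(load_entity_vectors(path)["Q42"])
```

Keying by entity ID on load avoids depending on row order later, at the cost of a larger file than a bare list of vectors.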