dice-embeddings

Kolmogorov-Arnold Networks (KANs) over triple embeddings

Open · Demirrr opened this issue 1 year ago · 3 comments

Integrate a Kolmogorov-Arnold Network (KAN) model that scores full triples (forward_triples) and head-relation pairs (KvsAll):

- forward_triples: KAN(s, p, o) -> [0, 1]
- KvsAll: KAN(s, p) -> [0, ..., ], a vector with one score per entity
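Spelled out, the two mappings could look as follows. This is a sketch under assumptions not fixed by the issue: the embeddings $\mathbf{e}$ are concatenated ($[\cdot\,;\cdot]$) before entering the KAN, and $\sigma$ is the element-wise logistic sigmoid. A KAN layer, as defined in the KAN paper, replaces the weight matrix of an MLP with learnable univariate functions $\phi$ on every edge and sums them per output.

$$
\begin{aligned}
\text{KAN layer:}\quad & x_{l+1,j} = \sum_{i=1}^{n_l} \phi_{l,j,i}\big(x_{l,i}\big) \\
\text{forward\_triples:}\quad & f(s,p,o) = \sigma\big(\mathrm{KAN}([\mathbf{e}_s;\mathbf{e}_p;\mathbf{e}_o])\big) \in [0,1], \qquad \mathrm{KAN}: \mathbb{R}^{3d} \to \mathbb{R} \\
\text{KvsAll:}\quad & f(s,p) = \sigma\big(\mathrm{KAN}([\mathbf{e}_s;\mathbf{e}_p])\big) \in [0,1]^{|\mathcal{E}|}, \qquad \mathrm{KAN}: \mathbb{R}^{2d} \to \mathbb{R}^{|\mathcal{E}|}
\end{aligned}
$$

Here $d$ is the embedding dimension and $|\mathcal{E}|$ the number of entities.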

Demirrr, May 08 '24 18:05

Original code shared by the KAN authors: https://github.com/KindXiaoming/pykan/blob/master/kan/KAN.py

Demirrr, May 08 '24 18:05

The previous link is broken; here is a permalink to the KANLayer implementation: https://github.com/KindXiaoming/pykan/blob/ecde4ec3274d3bef1ad737479cf126aed38ab530/kan/KANLayer.py#L8
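For orientation, the KAN paper parameterizes each edge function as a residual SiLU plus a B-spline, $\phi(x) = w_b\,\mathrm{silu}(x) + w_s \sum_i c_i B_i(x)$, on a grid extended by $k$ knots on each side. The snippet below is a minimal sketch of that idea in plain PyTorch, not the pykan API: `b_spline_basis` and `kan_edge` are hypothetical names, the basis is evaluated with the standard Cox-de Boor recursion, and the grid size `G = 5` and cubic order `k = 3` are illustrative choices.

```python
import torch
import torch.nn.functional as F


def b_spline_basis(x: torch.Tensor, grid: torch.Tensor, k: int = 3) -> torch.Tensor:
    """Evaluate all order-k B-spline basis functions at x with the Cox-de Boor recursion.

    x:    (n,) input points
    grid: (m,) increasing knots, already extended by k knots on each side
    returns (n, m - 1 - k) basis values
    """
    x = x.unsqueeze(-1)                                          # (n, 1)
    # Order 0: indicator of each knot interval [grid[i], grid[i + 1])
    basis = ((x >= grid[:-1]) & (x < grid[1:])).to(x.dtype)      # (n, m - 1)
    for p in range(1, k + 1):
        left = (x - grid[:-(p + 1)]) / (grid[p:-1] - grid[:-(p + 1)]) * basis[:, :-1]
        right = (grid[p + 1:] - x) / (grid[p + 1:] - grid[1:-p]) * basis[:, 1:]
        basis = left + right                                     # (n, m - 1 - p)
    return basis


def kan_edge(x, grid, coef, w_base=1.0, w_spline=1.0, k=3):
    """One hypothetical KAN edge: phi(x) = w_base * silu(x) + w_spline * sum_i coef[i] * B_i(x)."""
    return w_base * F.silu(x) + w_spline * (b_spline_basis(x, grid, k) @ coef)


G, k = 5, 3                                    # 5 intervals on [-1, 1], cubic splines
h = 2.0 / G                                    # knot spacing
grid = torch.linspace(-1 - k * h, 1 + k * h, G + 2 * k + 1, dtype=torch.float64)
x = torch.linspace(-0.99, 0.99, 50, dtype=torch.float64)
B = b_spline_basis(x, grid, k)                 # (50, G + k) = (50, 8)
print(B.sum(dim=1).min().item(), B.sum(dim=1).max().item())  # both ~1.0 (partition of unity)
coef = torch.randn(G + k, dtype=torch.float64)  # would be a learnable parameter in a real layer
y = kan_edge(x, grid, coef)                    # (50,)
```

Extending the grid by `k` knots per side is what makes the `G + k` basis functions sum to one everywhere inside `[-1, 1)`, so the spline can represent any shape on that interval.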

Demirrr, May 05 '25 07:05

@Amgad-Abdallah-Mahmoud already has experience with KANs.

KANKGE can be implemented by inheriting from BaseKGE, as shown below:


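```python
from dicee.models.base_model import BaseKGE  # module path in the dice-embeddings repo; adjust if it has moved


class KANKGE(BaseKGE):
    def __init__(self, args):
        super().__init__(args)
        self.name = 'KANKGE'
```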
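A fuller, self-contained sketch of what KANKGE's two scoring paths might look like is below. Several parts are assumptions: `RBFKANLayer` and `HypotheticalKANKGE` are hypothetical names; Gaussian radial basis functions stand in for pykan's B-splines to keep the layer short; and a plain `nn.Module` with its own embeddings stands in for `BaseKGE` so the code runs without dicee installed. For KvsAll, instead of giving the KAN an output of width |E|, the sketch maps (s, p) to a d-dimensional query and scores it against every entity embedding, the same matrix product DistMult's `k_vs_all_score` uses.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RBFKANLayer(nn.Module):
    """Simplified KAN layer: every edge i -> j carries a learnable univariate function
    phi_ji(x) = w_base[j, i] * silu(x) + sum_g coef[j, i, g] * exp(-((x - c_g) / width)^2),
    and output j sums phi_ji over all inputs i. Gaussian RBFs replace pykan's B-splines."""

    def __init__(self, in_dim: int, out_dim: int, num_centers: int = 8,
                 x_min: float = -2.0, x_max: float = 2.0):
        super().__init__()
        self.register_buffer("centers", torch.linspace(x_min, x_max, num_centers))
        self.width = (x_max - x_min) / (num_centers - 1)
        self.base_weight = nn.Parameter(torch.randn(out_dim, in_dim) / in_dim ** 0.5)
        self.coef = nn.Parameter(0.1 * torch.randn(out_dim, in_dim, num_centers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (b, in_dim) -> basis: (b, in_dim, num_centers)
        basis = torch.exp(-((x.unsqueeze(-1) - self.centers) / self.width) ** 2)
        spline_part = torch.einsum("big,oig->bo", basis, self.coef)  # sum over inputs i and centers g
        base_part = F.silu(x) @ self.base_weight.T                    # sum over inputs i
        return base_part + spline_part                                # (b, out_dim)


class HypotheticalKANKGE(nn.Module):
    """Stand-in for KANKGE(BaseKGE); here the embeddings are created directly."""

    def __init__(self, num_entities: int, num_relations: int, embedding_dim: int, hidden_dim: int = 16):
        super().__init__()
        self.entity_embeddings = nn.Embedding(num_entities, embedding_dim)
        self.relation_embeddings = nn.Embedding(num_relations, embedding_dim)
        # KAN(s, p, o) -> one logit per triple
        self.triple_kan = nn.Sequential(RBFKANLayer(3 * embedding_dim, hidden_dim), RBFKANLayer(hidden_dim, 1))
        # KAN(s, p) -> a d-dimensional query that is scored against every entity
        self.query_kan = nn.Sequential(RBFKANLayer(2 * embedding_dim, hidden_dim), RBFKANLayer(hidden_dim, embedding_dim))

    def forward_triples(self, x: torch.LongTensor) -> torch.Tensor:
        # x: (b, 3) indices of (subject, predicate, object) -> (b,) logits
        h = self.entity_embeddings(x[:, 0])
        r = self.relation_embeddings(x[:, 1])
        t = self.entity_embeddings(x[:, 2])
        return self.triple_kan(torch.cat([h, r, t], dim=1)).squeeze(-1)

    def forward_k_vs_all(self, x: torch.LongTensor) -> torch.Tensor:
        # x: (b, 2) indices of (subject, predicate) -> (b, num_entities) logits
        h = self.entity_embeddings(x[:, 0])
        r = self.relation_embeddings(x[:, 1])
        query = self.query_kan(torch.cat([h, r], dim=1))  # (b, d)
        return query @ self.entity_embeddings.weight.T     # (b, num_entities)


model = HypotheticalKANKGE(num_entities=10, num_relations=3, embedding_dim=4)
triples = torch.tensor([[0, 1, 2], [3, 0, 9]])
print(torch.sigmoid(model.forward_triples(triples)).shape)           # torch.Size([2])
print(torch.sigmoid(model.forward_k_vs_all(triples[:, :2])).shape)   # torch.Size([2, 10])
```

Both methods return raw logits; `torch.sigmoid` maps them to the [0, 1] scores described in the issue. Returning logits keeps the door open for a numerically stable loss such as `BCEWithLogitsLoss`, but whether dicee's training loop expects logits should be checked against `BaseKGE`. Note that the RBF centers cover a fixed range, while pykan updates its spline grids from the data.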

For reference, here is the full implementation of DistMult, a simple KGE model:


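```python
import torch
from dicee.models.base_model import BaseKGE


class DistMult(BaseKGE):
    """Embedding Entities and Relations for Learning and Inference in Knowledge Bases (https://arxiv.org/abs/1412.6575)"""

    def __init__(self, args):
        super().__init__(args)
        self.name = 'DistMult'

    def k_vs_all_score(self, emb_h: torch.FloatTensor, emb_r: torch.FloatTensor, emb_E: torch.FloatTensor):
        # (b, d) * (b, d) -> (b, d); then (b, d) @ (d, |E|) -> (b, |E|): one score per entity
        return torch.mm(self.hidden_dropout(self.hidden_normalizer(emb_h * emb_r)), emb_E.transpose(1, 0))

    def forward_k_vs_all(self, x: torch.LongTensor):
        # x: (b, 2) indices of (head, relation)
        emb_head, emb_rel = self.get_head_relation_representation(x)
        return self.k_vs_all_score(emb_h=emb_head, emb_r=emb_rel, emb_E=self.entity_embeddings.weight)

    def forward_k_vs_sample(self, x: torch.LongTensor, target_entity_idx: torch.LongTensor):
        # x: (b, 2) indices of (head, relation); target_entity_idx: (b, k) candidate tail indices
        emb_head, emb_rel = self.get_head_relation_representation(x)
        hr = self.hidden_dropout(self.hidden_normalizer(emb_head * emb_rel))  # (b, d)
        t = self.entity_embeddings(target_entity_idx)                         # (b, k, d)
        return torch.einsum('bd, bkd -> bk', hr, t)                           # (b, k)

    def score(self, h, r, t):
        # h, r, t: (b, d) embeddings of a batch of triples -> (b,) scores
        return (self.hidden_dropout(self.hidden_normalizer(h * r)) * t).sum(dim=1)
```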
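DistMult's three scoring paths compute the same trilinear product $\langle \mathbf{h}, \mathbf{r}, \mathbf{t} \rangle$ and differ only in how many tails they score at once. The standalone check below is a sketch that illustrates this with random tensors; it assumes `hidden_dropout` and `hidden_normalizer` act as identities (for example, dropout in eval mode and no normalization), and the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
b, k, d, num_entities = 4, 5, 8, 10
E = torch.randn(num_entities, d)                    # plays the role of entity_embeddings.weight
h, r = torch.randn(b, d), torch.randn(b, d)         # head and relation embeddings
hr = h * r                                          # DistMult interaction; dropout/normalizer omitted

# k_vs_all_score: one score per entity
kvsall = torch.mm(hr, E.transpose(1, 0))            # (b, num_entities)

# forward_k_vs_sample: scores for k sampled tails per row
target_idx = torch.randint(0, num_entities, (b, k))
t = E[target_idx]                                   # (b, k, d), like entity_embeddings(target_idx)
kvsample = torch.einsum('bd, bkd -> bk', hr, t)     # (b, k)

# score: one (h, r, t) triple per row, using the first sampled tail
per_triple = (hr * E[target_idx[:, 0]]).sum(dim=1)  # (b,)

# The three views agree on the entries they share
assert torch.allclose(kvsample, kvsall.gather(1, target_idx), atol=1e-5)
assert torch.allclose(per_triple, kvsample[:, 0], atol=1e-5)
```

A KANKGE that keeps the same method layout (`k_vs_all_score`, `forward_k_vs_all`, `forward_k_vs_sample`, `score`) can be tested the same way: its KvsAll scores, gathered at sampled indices, should match its per-sample scores.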

Is everything clear, @Amgad-Abdallah-Mahmoud?

Demirrr, May 05 '25 07:05