RNNLogic
RNNLogic copied to clipboard
Case Study of Generated Logic Rules
Hi, I'm interested in how to obtain the generated logic rules shown in Table 4 and Table 7 of your paper. Also, I can't find the relations mentioned in Table 4 and Table 7 in the FB15k-237 dataset.
The relation names in FB15k-237 are very long, so in the paper we renamed them according to their meaning.
The following code can be used to print rules from trained models, with various constraints:
import gc
import copy
gc.enable()
import os
import sys
from sys import *
#from random import *
from collections import defaultdict
import torch
import torch.nn as nn
from torch import *
from torch.nn import *
from torch.optim import *
from random import shuffle
from random import randint
import time
import datetime
import json
import torch.nn.functional as F
from model import *
def _shorten_relation(relation):
    """Return a short display name made of the trailing '/'-separated parts.

    The relation is scanned from its end. Scanning stops at the second '/'
    when at least 30 characters have already been collected, otherwise at
    the third '/'. The '/' that stops the scan is not included.
    """
    short = ""
    slashes = 0
    for ch in reversed(relation):
        if ch == '/':
            slashes += 1
            # NOTE(review): the pasted script lost its indentation; these two
            # checks are assumed to run only at '/' boundaries so that names
            # are never cut in the middle of a path segment -- confirm
            # against the original script.
            if slashes == 2 and len(short) >= 30:
                break
            if slashes >= 3:
                break
        short = ch + short
    return short


# relation2id: full relation name -> integer id.
# id2relation: integer id -> shortened display name.
relation2id = dict()
id2relation = dict()
with open('dataset/FB15k-237/relations.dict', encoding='utf-8') as fin:
    for line in fin:
        line = line.strip()
        if not line:
            # Tolerate blank/trailing lines instead of failing on the unpack.
            continue
        rid, relation = line.split('\t')
        relation2id[relation] = int(rid)
        id2relation[int(rid)] = _shorten_relation(relation)
# Build the inverse-relation table: ids [0, R) are the relations read from the
# dictionary file and ids [R, 2R) are their inverses. inv maps each id to its
# counterpart, and an inverse is displayed with a leading "!".
R = len(relation2id)
mov = R
inv = [0] * 2 * R
for i in range(R):
    inv[i + mov] = i
    inv[i] = i + mov
    id2relation[i + mov] = "!" + id2relation[i]

# The model file is the only command-line argument; fail with a usage message
# rather than a bare IndexError when it is missing.
if len(sys.argv) < 2:
    raise SystemExit(
        "usage: python <this_script> <model_file>  "
        "(e.g. ./workspace/model_0.pth)"
    )

# NOTE: torch.load unpickles the file, which can execute arbitrary code --
# only load model files that you trust.
a = torch.load(sys.argv[1])
r = int(a['r'])
def has_revlink(rule):
    """Tell whether ``rule`` contains two adjacent relations that are inverses.

    Returns True as soon as some element equals ``inv`` of the element that
    directly follows it, and False otherwise (including for rules with fewer
    than two elements).
    """
    return any(
        rule[pos] == inv[rule[pos + 1]]
        for pos in range(len(rule) - 1)
    )
def contains_r(rule):
    """Tell whether the loaded model's relation id ``a['r']`` occurs in ``rule``.

    Every element of ``rule`` is converted with ``int`` before comparison.
    """
    target = int(a['r'])
    relation_ids = [int(rel_id) for rel_id in rule]
    return target in relation_ids
def prt(p, n=10):
    """Print up to ``n`` rules as LaTeX table rows, in the order given by ``p``.

    Args:
        p: sequence of indices into ``a['rules']``, already sorted in the
            order the rules should be printed.
        n: maximum number of rules to print. When ``p`` holds fewer than
            ``n`` indices, every rule in ``p`` is printed.

    Each row has the form ``&$\\gets$&$ ... Y$\\\\``. A relation is written
    as ``<var>\\relarr{name}``, or as ``<var>\\relarrl{name}`` when its
    display name starts with "!" (the "!" itself is dropped). Underscores in
    names are escaped for LaTeX.

    Raises:
        IndexError: if a rule has more than four relations, because only the
            variable letters X, U, V, W are available.
    """
    letters = "XUVW"
    rules = a['rules']
    # Slice instead of indexing range(n) so a short ``p`` cannot overrun;
    # max() keeps a negative ``n`` printing nothing, as range(n) did.
    for rule_idx in p[:max(n, 0)]:
        print("&$\\gets$&$", end='')
        for pos, rel_id in enumerate(rules[rule_idx]):
            name = id2relation[int(rel_id)].replace('_', '\\_')
            if name.startswith('!'):
                print(f"{letters[pos]}\\relarrl{{{name[1:]}}}", end="")
            else:
                print(f"{letters[pos]}\\relarr{{{name}}}", end="")
        print("Y$\\\\")
# Report driver: print the relation name, then three rankings of the rules.
# Every ranking sorts rule indices in descending order of its key.
weight = a['predictor']['rule_weight_raw']
print("Relation:", id2relation[int(a['r'])])

sections = (
    # (heading, sort key over rule indices, number of rules to print)
    ("general:",
     lambda i: weight[i],
     10),
    ("no self:",
     lambda i: (not contains_r(a['rules'][i]),
                not has_revlink(a['rules'][i]),
                weight[i]),
     40),
    ("revlink:",
     lambda i: (has_revlink(a['rules'][i]), weight[i]),
     40),
)
for heading, sort_key, limit in sections:
    print(heading)
    p = sorted(range(len(a['rules'])), key=sort_key, reverse=True)
    prt(p, n=limit)
When I run the above code, I get the following error at the line `a = torch.load(sys.argv[1])`: "IndexError: list index out of range". How can I solve this problem?
Hello,
`sys.argv[1]` is the path to the model file that is learned during training. Set it to './workspace/model_0.pth' to see the rules for relation 0 (r=0), and so on for the other relations. Also, the rules printed by this code are LaTeX source (used to produce the PDF), not plain text: the output uses the `\relarr` and `\relarrl` macros, which your LaTeX document needs to define. Paste the output of this script into a LaTeX file to render it. I hope this helps.