RNNLogic icon indicating copy to clipboard operation
RNNLogic copied to clipboard

Case Study of Generated Logic Rules

Open LuMflowers opened this issue 3 years ago • 3 comments

Hi, I'm interested in how to obtain the generated logic rules shown in Table 4 and Table 7 of your paper. Also, I can't find the relations mentioned in Table 4 and Table 7 in the FB15k-237 dataset.

LuMflowers avatar Aug 08 '21 17:08 LuMflowers

The relation names of FB15k-237 are very long, so we renamed them in the paper according to their meaning.

The following code can be used to print rules from trained models, with various constraints:

import gc
import copy
gc.enable()

import os
import sys
from sys import *
#from random import *
from collections import defaultdict
import torch
import torch.nn as nn
from torch import *
from torch.nn import *
from torch.optim import *
from random import shuffle
from random import randint
import time
import datetime
import json
import torch.nn.functional as F
from model import *

def _short_relation_name(relation):
	"""Return a shortened, display-friendly tail of a '/'-separated relation name.

	Keeps the last two path segments when together (including the '/'
	between them) they are at least 30 characters long; otherwise keeps
	the last three segments. Names with fewer slashes than that are
	returned unchanged.
	"""
	pieces = relation.rsplit('/', 2)
	if len(pieces) == 3:
		last_two = pieces[1] + '/' + pieces[2]
		if len(last_two) >= 30:
			return last_two

	pieces = relation.rsplit('/', 3)
	if len(pieces) == 4:
		return '/'.join(pieces[1:])
	return relation


# relation name -> id, and id -> shortened relation name used for printing.
relation2id = {}
id2relation = {}

# Each line of relations.dict is "<id>\t<relation name>".
with open('dataset/FB15k-237/relations.dict') as fin:
	for line in fin:
		raw_id, relation = line.strip().split('\t')
		rel_id = int(raw_id)
		relation2id[relation] = rel_id
		id2relation[rel_id] = _short_relation_name(relation)

# Number of forward relations; inverse relation ids are offset by this amount.
R = len(relation2id)
mov = R

# inv[k] is the id of the inverse of relation k:
# forward ids 0..R-1 map to R..2R-1, and inverse ids map back.
inv = list(range(R, 2 * R)) + list(range(R))

# Inverse relations are displayed as the forward name prefixed with "!".
for fwd_id in range(R):
	id2relation[fwd_id + mov] = "!" + id2relation[fwd_id]

# The checkpoint path is taken from the first command-line argument.
# Fail with a usage message instead of a bare IndexError when it is missing.
if len(sys.argv) < 2:
	sys.exit(f"usage: python {sys.argv[0]} <model checkpoint, e.g. ./workspace/model_0.pth>")

# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# map_location='cpu' lets a checkpoint saved on a GPU be inspected on a
# CPU-only machine; this script only reads and sorts the stored values.
a = torch.load(sys.argv[1], map_location='cpu')
# Id of the relation stored in this checkpoint under the key 'r'.
r = int(a['r'])

def has_revlink(rule):
	"""Return True if two consecutive relations in ``rule`` are inverses.

	A pair (x, y) of neighbouring entries counts when ``x == inv[y]``,
	where ``inv`` is the module-level inverse-relation table. Rules with
	fewer than two entries always yield False.
	"""
	for current, following in zip(rule, rule[1:]):
		if current == inv[following]:
			return True
	return False


def contains_r(rule):
	"""Return True if the relation id stored in ``a['r']`` occurs in ``rule``.

	Every entry of ``rule`` is converted with ``int`` before comparison.
	"""
	return int(a['r']) in {int(rel_id) for rel_id in rule}

def prt(p, n=10):
	"""Print up to ``n`` rules as LaTeX table rows, in the order given by ``p``.

	Args:
		p: sequence of indices into ``a['rules']``, already ranked.
		n: maximum number of rules to print. Fewer are printed when ``p``
			is shorter than ``n``.

	Each row has the form ``&$\\gets$&$X\\relarr{rel}...Y$\\\\``. A relation
	whose display name starts with "!" (an inverse relation) is emitted
	with ``\\relarrl`` and the "!" removed; all others use ``\\relarr``.
	Underscores in names are escaped for LaTeX. The ``\\relarr`` and
	``\\relarrl`` macros are not defined here and are presumably provided
	by the LaTeX document that consumes this output.

	Reads the module-level ``a`` (loaded checkpoint) and ``id2relation``.
	"""
	# Variable names for the body atoms; the final variable is always Y.
	letters = "XUVW"

	# Clamp to the number of available rules so a short rule list does not
	# raise IndexError. A conditional is used rather than min() because the
	# file's wildcard torch import may rebind that builtin.
	count = n if n < len(p) else len(p)

	for rank in range(count):
		rule = a['rules'][p[rank]]

		print("&$\\gets$&$", end='')
		for pos, rel_id in enumerate(rule):
			name = id2relation[int(rel_id)].replace('_', '\\_')
			# Rules longer than the predefined letters get indexed variables
			# instead of failing on an out-of-range lookup.
			var = letters[pos] if pos < len(letters) else f"Z_{{{pos}}}"
			if name.startswith('!'):
				print(f"{var}\\relarrl{{{name[1:]}}}", end="")
			else:
				print(f"{var}\\relarr{{{name}}}", end="")

		print("Y$\\\\")



# Raw learned weight of every rule, indexed like a['rules'].
weight = a['predictor']['rule_weight_raw']


def _rank(key):
	"""Return all rule indices ordered by ``key``, largest first (ties keep index order)."""
	return sorted(range(len(a['rules'])), key=key, reverse=True)


print("Relation:", id2relation[int(a['r'])])

# Plain ranking by weight.
print("general:")
p = _rank(lambda i: weight[i])
prt(p)

# Rules that neither mention the head relation nor contain an
# inverse pair come first, then by weight.
print("no self:")
p = _rank(lambda i: (
	not contains_r(a['rules'][i]),
	not has_revlink(a['rules'][i]),
	weight[i],
))
prt(p, n=40)

# Rules containing an inverse pair come first, then by weight.
print("revlink:")
p = _rank(lambda i: (has_revlink(a['rules'][i]), weight[i]))
prt(p, n=40)

immortalCO avatar Aug 09 '21 15:08 immortalCO

When I run the above code, I get an error at `a = torch.load(sys.argv[1])`: "IndexError: list index out of range". How can I solve this problem?

LuMflowers avatar Aug 12 '21 07:08 LuMflowers

Hello,

sys.argv[1] is the path to the model file that is learnt during training. Set it to './workspace/model_0.pth' to see the rules for relation 0 (r=0), and so on for the other relations. Also, the rules printed by this code are LaTeX source (to be compiled to PDF), not plain text: paste the output of this script into a LaTeX file to render them. I hope this helps.

navdeepkjohal avatar May 08 '22 17:05 navdeepkjohal