neural-backed-decision-trees
How to train with new dataset
Hi, I'm trying to use `gen_train_eval_nopretrained.sh` to train on a new dataset I implemented. However, `main.py` contains the line `tree = Tree.create_from_args(args, classes=trainset.classes)`, which fails with a FileNotFound error for `nbdt/hierarchies/mydataset/graph-induced.json`. It seems this line requires an already-generated hierarchy, but I can only generate a hierarchy after I've trained the model, so I'm a bit confused about what to do.
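The failure comes from a missing file rather than a bug in the training loop. Below is a minimal sketch of a hypothetical pre-flight check (not part of nbdt) that fails early with a clearer message. The path pattern `nbdt/hierarchies/<dataset>/graph-<hierarchy>.json` is only inferred from the error message above, and `hierarchy_path` / `check_hierarchy` are illustrative names.

```python
import os


def hierarchy_path(dataset: str, hierarchy: str = "induced",
                   root: str = "nbdt/hierarchies") -> str:
    # Assumed layout, inferred from the error message:
    # nbdt/hierarchies/<dataset>/graph-<hierarchy>.json
    return os.path.join(root, dataset, f"graph-{hierarchy}.json")


def check_hierarchy(dataset: str, hierarchy: str = "induced",
                    root: str = "nbdt/hierarchies") -> str:
    # Raise a descriptive error if the hierarchy file has not been generated yet.
    path = hierarchy_path(dataset, hierarchy, root)
    if not os.path.isfile(path):
        raise FileNotFoundError(
            f"No hierarchy found at {path}. Train a regular model on "
            f"'{dataset}' and generate a hierarchy from it first."
        )
    return path
```

Calling `check_hierarchy("mydataset")` before building the tree would surface the missing-hierarchy problem with an explanation instead of a bare FileNotFound error.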
@XAVILLA Ah, good question! Do you have a regular pre-trained model for that dataset? The most stable way to train is to:
- Train a regular model
- Generate a hierarchy from the trained regular model
- Train an NBDT model using this generated hierarchy
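Step 2 is what produces the missing `graph-induced.json`. As an illustration of the idea only (the nbdt repo has its own hierarchy-generation code, and this is not it), the sketch below assumes, as described in the NBDT paper, that an induced hierarchy comes from agglomerative clustering of the rows of the trained model's final fully-connected layer, one row per class (e.g. `model.fc.weight` for torchvision ResNets). It uses a simplified centroid linkage in pure Python; `induce_hierarchy` is a hypothetical name.

```python
from typing import Dict, List, Tuple, Union

Vector = List[float]
# A leaf is a class name; an inner node is a (left, right) pair.
Tree = Union[str, Tuple["Tree", "Tree"]]


def _mean(vectors: List[Vector]) -> Vector:
    # Element-wise mean of a cluster's member weight rows.
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]


def _sq_dist(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def induce_hierarchy(class_weights: Dict[str, Vector]) -> Tree:
    """Greedily merge the two closest clusters until one binary tree remains.

    Simplified centroid linkage: each cluster is represented by the mean
    of its member classes' final-layer weight rows.
    """
    if not class_weights:
        raise ValueError("need at least one class")
    # Each cluster: (subtree, list of member weight rows).
    clusters = [(name, [list(w)]) for name, w in class_weights.items()]
    while len(clusters) > 1:
        best = None
        for i in range(len(clusters)):
            for j in range(i + 1, len(clusters)):
                d = _sq_dist(_mean(clusters[i][1]), _mean(clusters[j][1]))
                if best is None or d < best[0]:
                    best = (d, i, j)
        _, i, j = best
        merged = ((clusters[i][0], clusters[j][0]), clusters[i][1] + clusters[j][1])
        # Pop the larger index first so index i stays valid.
        clusters.pop(j)
        clusters.pop(i)
        clusters.append(merged)
    return clusters[0][0]
```

With toy weights such as `{"cat": [1.0, 0.0], "dog": [0.9, 0.1], "car": [0.0, 1.0], "truck": [0.1, 0.9]}`, the root splits into an animal subtree and a vehicle subtree. The point is that the tree depends on weights learned in step 1, which is why the hierarchy cannot exist before a regular model has been trained.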
There are tidbits of code that let you fine-tune in step 3 instead of retraining from scratch, or skip steps 1 and 2 entirely, but they're not as stable, so I don't recommend them.