
How to train with a new dataset

Open • XAVILLA opened this issue 3 years ago • 1 comment

Hi, I'm trying to use `gen_train_eval_nopretrained.sh` to train on a new dataset I implemented. However, `main.py` contains the line `tree = Tree.create_from_args(args, classes=trainset.classes)`, and it fails with a FileNotFound error for `nbdt/hierarchies/mydataset/graph-induced.json`. It seems this line requires an already generated hierarchy, but I can only generate a hierarchy after I've trained the model. So I'm a bit confused about what to do.
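One way to make this failure easier to diagnose is to check for the hierarchy file before the tree is built, and fail with a message that points at the missing prerequisite. The sketch below assumes hierarchies live at `<root>/<dataset>/graph-<name>.json`, which is inferred only from the path in the error message above; `hierarchy_path` and `require_hierarchy` are hypothetical helpers, not part of the nbdt package.

```python
from pathlib import Path


def hierarchy_path(dataset: str, name: str = "induced",
                   root: str = "nbdt/hierarchies") -> Path:
    # Hypothetical helper: mirrors the layout seen in the error message,
    # e.g. nbdt/hierarchies/mydataset/graph-induced.json
    return Path(root) / dataset / f"graph-{name}.json"


def require_hierarchy(dataset: str, name: str = "induced",
                      root: str = "nbdt/hierarchies") -> Path:
    # Hypothetical guard: fail early, before Tree.create_from_args runs,
    # with a hint about the missing prerequisite step.
    path = hierarchy_path(dataset, name, root)
    if not path.is_file():
        raise FileNotFoundError(
            f"{path} not found. Train a regular model on '{dataset}' first, "
            "then generate a hierarchy from it before training the NBDT."
        )
    return path
```

For example, calling `require_hierarchy("mydataset")` just before the `Tree.create_from_args(...)` line would turn the bare FileNotFound error into one that names the step that still needs to run.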

XAVILLA • Aug 30 '21 07:08

@XAVILLA Ah, good question! Do you have a regular pre-trained model for that dataset? The most stable way to train is to:

  1. Train a regular model
  2. Generate a hierarchy from the trained regular model
  3. Train an NBDT model using this generated hierarchy
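Step 2 is why the hierarchy has to come after training: an induced hierarchy is derived from the trained model's weights. The NBDT paper describes taking the rows of the final fully-connected layer as one vector per class and running hierarchical agglomerative clustering on them, with each parent represented by the mean of its children. The sketch below is a simplified, pure-Python illustration of that idea under those assumptions (Euclidean distance, no weight normalization, toy 2-D vectors made up for the example). It is not the nbdt library's implementation, and `induce_hierarchy` is a hypothetical name.

```python
import math
from typing import Dict, List, Tuple, Union

# A node is either a class name (leaf) or a pair of child nodes.
Node = Union[str, tuple]


def induce_hierarchy(class_weights: Dict[str, List[float]]) -> Node:
    if not class_weights:
        raise ValueError("need at least one class")
    # Start with one cluster per class: (subtree, representative vector).
    clusters: List[Tuple[Node, List[float]]] = [
        (name, list(vec)) for name, vec in class_weights.items()
    ]
    while len(clusters) > 1:
        # Find the closest pair of clusters by Euclidean distance.
        best = None
        for i in range(len(clusters)):
            for j in range(i + 1, len(clusters)):
                d = math.dist(clusters[i][1], clusters[j][1])
                if best is None or d < best[0]:
                    best = (d, i, j)
        _, i, j = best
        (left, va), (right, vb) = clusters[i], clusters[j]
        # The parent's representative vector is the mean of its two children.
        parent = ((left, right), [(x + y) / 2 for x, y in zip(va, vb)])
        clusters = [c for k, c in enumerate(clusters) if k not in (i, j)]
        clusters.append(parent)
    return clusters[0][0]
```

With toy weights `{"cat": [1.0, 0.0], "dog": [0.9, 0.1], "car": [0.0, 1.0], "truck": [0.1, 0.9]}`, this returns `(("cat", "dog"), ("car", "truck"))`: classes whose weight vectors are similar end up as siblings. Those vectors only carry this structure once the regular model has been trained.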

There are tidbits of code that let you fine-tune in step 3 instead of retraining from scratch, or skip steps 1 and 2 entirely, but they're less stable, so I don't recommend them.

alvinwan • Aug 31 '21 19:08