DecisionTree.jl

Input checking

Open cstjean opened this issue 8 years ago • 5 comments

This is scary:

tree = fit!(DecisionTreeRegressor(), [1.0 2; 3 4], [10, 24.0])
predict(tree, [])
> 17.0

apply_tree also accepts (and ignores) extra values in the feature_vector without complaining.

cstjean, Nov 02 '16 14:11

Yeah, this is an issue. Back to your example, note that the tree generated is actually a leaf, and so there is no decision to be made based on input features:

print_tree(tree.root)
Feature 0, Threshold nothing
L-> 17.0 : 0/2
R-> nothing : 1/1

But if we forced it to 1 sample per leaf, it would complain upon prediction with no input:

tree2 = fit!(DecisionTreeRegressor( maxlabels=1 ), [1.0 2; 3 4], [10, 24.0])

print_tree(tree2.root)
Feature 1, Threshold 3.0
L-> 10.0 : 1/1
R-> 24.0 : 1/1

predict(tree2, [])
ERROR: BoundsError: attempt to access 0-element Array{Any,1} at index [1]

Now if we want to do input feature checking, then we'd need to add this metadata (the number of input features) to the tree model, and add a new variable to the Node type, which would be replicated down all the subsequent Nodes. My concern is around model size; as is, the trees are very bloated, and adding a new var wouldn't help on that front.

What do you think?

bensadeghi, Nov 16 '16 03:11

I think that it's easy to fix for the ScikitLearn interface: we can just add an input_width field to the models, and check once in predict. For apply_tree, it would suck to add an extra 4 bytes to every node and leaf in the tree, just for input checking, so I'm not sure that there's a solution. Is that what you're saying?
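A minimal sketch of the "check once in predict" idea, assuming toy stand-ins for the tree types. `ToyLeaf`, `ToyNode`, `ToyModel`, `apply_toy` and `checked_predict` are hypothetical names for illustration, not DecisionTree.jl's actual API, and the syntax is current Julia rather than what the package used at the time. The width is stored once on the wrapper model, so nodes and leaves stay the same size:

```julia
# Toy stand-ins for the package's tree types; every name here is hypothetical.
struct ToyLeaf
    value::Float64
end

struct ToyNode
    featid::Int                      # column index used for the split
    featval::Float64                 # split threshold
    left::Union{ToyLeaf, ToyNode}
    right::Union{ToyLeaf, ToyNode}
end

# The wrapper stores the training width once, instead of once per node.
struct ToyModel
    root::Union{ToyLeaf, ToyNode}
    input_width::Int
end

# Unchecked traversal: a leaf never looks at x, so any input is accepted.
apply_toy(leaf::ToyLeaf, x::AbstractVector) = leaf.value
function apply_toy(node::ToyNode, x::AbstractVector)
    return x[node.featid] < node.featval ? apply_toy(node.left, x) : apply_toy(node.right, x)
end

# Checked entry point: validate the width once, then traverse as before.
function checked_predict(model::ToyModel, x::AbstractVector)
    if length(x) != model.input_width
        throw(DimensionMismatch("expected $(model.input_width) features, got $(length(x))"))
    end
    return apply_toy(model.root, x)
end

leaf_only   = ToyModel(ToyLeaf(17.0), 2)
split_model = ToyModel(ToyNode(1, 3.0, ToyLeaf(10.0), ToyLeaf(24.0)), 2)
```

With this layout, `apply_toy(leaf_only.root, Float64[])` still returns the leaf value silently, while `checked_predict(leaf_only, Float64[])` throws a `DimensionMismatch`. The same check rejects a vector with extra trailing values.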

> My concern is around model size; as is, the trees are very bloated, and adding a new var wouldn't help on that front.

cstjean, Nov 16 '16 13:11

With regard to bloat, I assume you're referring to leaf.values? We should store the value counts there, instead of storing all values. It would make apply_tree_proba much faster.

cstjean, Nov 16 '16 13:11

I'm still hesitant to add a new field to the Node type. If this issue is handled in SKL.jl, then it's ok.

And yes, the bloated models need to be addressed with label value counts, using something like Dict{Any, Int}. This is something I've been meaning to do for some time, but just haven't gotten around to it... a bunch of functions would need to be reworked.
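Roughly what the value-count layout could look like, as a sketch only. `ValuesLeaf`, `CountsLeaf`, `to_counts` and `leaf_proba` are made-up names, not the package's types or functions, and the keys are `String` here just to keep the example concrete. The leaf keeps one counter per distinct label instead of every training label, so computing probabilities is a lookup per class rather than a scan over all stored values:

```julia
# Hypothetical leaf layouts; neither is the package's actual type.
struct ValuesLeaf                    # keeps every training label that reached the leaf
    values::Vector{String}
end

struct CountsLeaf                    # keeps one counter per distinct label
    counts::Dict{String, Int}
end

# Collapse the stored labels into per-label counts.
function to_counts(leaf::ValuesLeaf)
    counts = Dict{String, Int}()
    for v in leaf.values
        counts[v] = get(counts, v, 0) + 1
    end
    return CountsLeaf(counts)
end

# Class probabilities in the order given by `labels`; unseen labels get 0.
function leaf_proba(leaf::CountsLeaf, labels::Vector{String})
    total = sum(values(leaf.counts))
    return [get(leaf.counts, l, 0) / total for l in labels]
end

big   = ValuesLeaf(vcat(fill("a", 750), fill("b", 250)))   # 1000 stored labels
small = to_counts(big)                                      # 2 dictionary entries
probs = leaf_proba(small, ["a", "b", "c"])
```

Here `probs` comes out as `[0.75, 0.25, 0.0]`, and the leaf holds two entries instead of a thousand labels.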

I've resolved #39 and #41 for now, and will submit to METADATA. Will need to get back to this issue a bit later.

Thanks for all the help!!

bensadeghi, Dec 10 '16 07:12

> I'm still hesitant to add a new field to the Node type. If this issue is handled in SKL.jl, then it's ok.

I feel the same. I'll patch up SKL this week.

cstjean, Dec 10 '16 10:12