decision-forests
tfdf.model_plotter.plot_model() is broken for GradientBoostedTreesModel and CartModel
I am using tfdf 0.2.4 and can successfully train a model and plot it using the plot_model()
function.
model = tfdf.keras.RandomForestModel()
model.fit(train_ds)
model.compile(metrics=["accuracy"])
evaluation = model.evaluate(test_ds)
with open("model.html", "w") as html_file:
    html_file.write(tfdf.model_plotter.plot_model(model, tree_idx=0, max_depth=10))
For my current task I get a decision tree graph consisting of two decision nodes and three outputs. The key line in the generated HTML file seems to be this one:
display_tree({"margin": 10, "node_x_size": 160, "node_y_size": 28, "node_x_offset": 180, "node_y_offset": 33, "font_size": 10, "edge_rounding": 20, "node_padding": 2, "show_plot_bounding_box": false}, {"value": {"type": "PROBABILITY", "distribution": [0.006622516556291391, 0.695364238410596, 0.2781456953642384, 0.019867549668874173], "num_examples": 151.0}, "condition": {"type": "CATEGORICAL_IS_IN", "attribute": "product_group", "mask": ["DIY"]}, "children": [{"value": {"type": "PROBABILITY", "distribution": [0.0, 0.0, 1.0, 0.0], "num_examples": 42.0}}, {"value": {"type": "PROBABILITY", "distribution": [0.009174311926605505, 0.963302752293578, 0.0, 0.027522935779816515], "num_examples": 109.0}, "condition": {"type": "NUMERICAL_IS_HIGHER_THAN", "attribute": "height", "threshold": 42.0}, "children": [{"value": {"type": "PROBABILITY", "distribution": [0.0, 1.0, 0.0, 0.0], "num_examples": 103.0}}, {"value": {"type": "PROBABILITY", "distribution": [0.16666666666666666, 0.3333333333333333, 0.0, 0.5], "num_examples": 6.0}}]}]}, "#tree_plot_24de9183c1d54e6b8c963d372b714bc0")
If I use exactly the same code but replace the RandomForestModel
with a GradientBoostedTreesModel
I only get one decision and two outputs:
display_tree({"margin": 10, "node_x_size": 160, "node_y_size": 28, "node_x_offset": 180, "node_y_offset": 33, "font_size": 10, "edge_rounding": 20, "node_padding": 2, "show_plot_bounding_box": false}, {"value": {"type": "REGRESSION", "value": -0.09703703969717026, "num_examples": 135.0, "standard_deviation": 0.08574694002066838}, "condition": {"type": "NUMERICAL_IS_HIGHER_THAN", "attribute": "length", "threshold": 227.0}, "children": [{"value": {"type": "REGRESSION", "value": -0.020000001415610313, "num_examples": 5.0, "standard_deviation": 0.4}}, {"value": {"type": "REGRESSION", "value": -0.10000000149011612, "num_examples": 130.0, "standard_deviation": 0.0}}]}, "#tree_plot_73421ac8ea9a47a88761b7441afab47c")
This can't be right, since the inferences of the GradientBoostedTreesModel
are perfect (100% correct, thanks!), and that requires taking more features into account than just the length of the classified object.
The model summary is below (I have replaced some sensitive feature names). I am not really an expert, but if I read the summary correctly, the decision tree should have a depth of 5 and 26 to 27 nodes. On the other hand, I would have expected more nodes to show for the RandomForestModel, too. ¯\_(ツ)_/¯
If there is any additional information I can provide please let me know.
Model: "gradient_boosted_trees_model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
=================================================================
Total params: 1
Trainable params: 0
Non-trainable params: 1
_________________________________________________________________
Type: "GRADIENT_BOOSTED_TREES"
Task: CLASSIFICATION
Label: "__LABEL"
Input Features (11):
parcel_count
ft_ot_text
girth
height
length
product_group
tipping_risk
shipping_mode
volume
weight
width
No weights
Variable Importance: MEAN_MIN_DEPTH:
1. "parcel_count" 3.890381 ################
2. "__LABEL" 3.890381 ################
3. "shipping_mode" 3.889688 ###############
4. "girth" 3.541476 #############
5. "ft_ot_text" 3.516287 #############
6. "width" 3.287958 ###########
7. "volume" 3.184331 ##########
8. "length" 3.039927 #########
9. "height" 2.885094 ########
10. "tipping_risk" 2.538273 ######
11. "weight" 2.267620 ####
12. "product_group" 1.719362
Variable Importance: NUM_AS_ROOT:
1. "product_group" 616.000000 ################
2. "height" 183.000000 ####
3. "weight" 172.000000 ####
4. "length" 117.000000 ##
5. "width" 47.000000
6. "volume" 41.000000
7. "tipping_risk" 24.000000
Variable Importance: NUM_NODES:
1. "weight" 2592.000000 ################
2. "tipping_risk" 2367.000000 ##############
3. "volume" 1318.000000 ########
4. "height" 1271.000000 #######
5. "product_group" 1195.000000 #######
6. "width" 1062.000000 ######
7. "girth" 968.000000 #####
8. "ft_ot_text" 730.000000 ####
9. "length" 689.000000 ####
10. "shipping_mode" 5.000000
Variable Importance: SUM_SCORE:
1. "product_group" 212.827222 ################
2. "height" 17.159601 #
3. "weight" 3.552953
4. "tipping_risk" 2.266512
5. "length" 1.447021
6. "volume" 0.999544
7. "girth" 0.891605
8. "width" 0.525099
9. "ft_ot_text" 0.106717
10. "shipping_mode" 0.000000
Loss: MULTINOMIAL_LOG_LIKELIHOOD
Validation loss value: 2.87221e-06
Number of trees per iteration: 4
Node format: NOT_SET
Number of trees: 1200
Total number of nodes: 25594
Number of nodes by tree:
Count: 1200 Average: 21.3283 StdDev: 3.18991
Min: 3 Max: 27 Ignored: 0
----------------------------------------------
[ 3, 4) 2 0.17% 0.17%
[ 4, 5) 0 0.00% 0.17%
[ 5, 6) 2 0.17% 0.33%
[ 6, 8) 0 0.00% 0.33%
[ 8, 9) 0 0.00% 0.33%
[ 9, 10) 0 0.00% 0.33%
[ 10, 11) 0 0.00% 0.33%
[ 11, 13) 8 0.67% 1.00%
[ 13, 14) 12 1.00% 2.00%
[ 14, 15) 0 0.00% 2.00%
[ 15, 16) 21 1.75% 3.75% #
[ 16, 18) 73 6.08% 9.83% ##
[ 18, 19) 0 0.00% 9.83%
[ 19, 20) 262 21.83% 31.67% #######
[ 20, 21) 0 0.00% 31.67%
[ 21, 23) 372 31.00% 62.67% ##########
[ 23, 24) 199 16.58% 79.25% #####
[ 24, 25) 0 0.00% 79.25%
[ 25, 26) 156 13.00% 92.25% ####
[ 26, 27] 93 7.75% 100.00% ###
Depth by leafs:
Count: 13397 Average: 3.9155 StdDev: 1.0663
Min: 1 Max: 5 Ignored: 0
----------------------------------------------
[ 1, 2) 178 1.33% 1.33%
[ 2, 3) 1354 10.11% 11.44% ###
[ 3, 4) 3100 23.14% 34.57% ######
[ 4, 5) 3555 26.54% 61.11% #######
[ 5, 5] 5210 38.89% 100.00% ##########
Number of training obs by leaf:
Count: 13397 Average: 12.0923 StdDev: 18.5167
Min: 5 Max: 130 Ignored: 0
----------------------------------------------
[ 5, 11) 11675 87.15% 87.15% ##########
[ 11, 17) 419 3.13% 90.27%
[ 17, 23) 42 0.31% 90.59%
[ 23, 30) 40 0.30% 90.89%
[ 30, 36) 63 0.47% 91.36%
[ 36, 42) 7 0.05% 91.41%
[ 42, 49) 1 0.01% 91.42%
[ 49, 55) 40 0.30% 91.71%
[ 55, 61) 158 1.18% 92.89%
[ 61, 68) 320 2.39% 95.28%
[ 68, 74) 53 0.40% 95.68%
[ 74, 80) 306 2.28% 97.96%
[ 80, 86) 226 1.69% 99.65%
[ 86, 93) 27 0.20% 99.85%
[ 93, 99) 16 0.12% 99.97%
[ 99, 105) 2 0.01% 99.99%
[ 105, 112) 1 0.01% 99.99%
[ 112, 118) 0 0.00% 99.99%
[ 118, 124) 0 0.00% 99.99%
[ 124, 130] 1 0.01% 100.00%
Attribute in nodes:
2592 : weight [NUMERICAL]
2367 : tipping_risk [NUMERICAL]
1318 : volume [NUMERICAL]
1271 : height [NUMERICAL]
1195 : product_group [CATEGORICAL]
1062 : width [NUMERICAL]
968 : girth [NUMERICAL]
730 : ft_ot_text [CATEGORICAL]
689 : length [NUMERICAL]
5 : shipping_mode [CATEGORICAL]
Attribute in nodes with depth <= 0:
616 : product_group [CATEGORICAL]
183 : height [NUMERICAL]
172 : weight [NUMERICAL]
117 : length [NUMERICAL]
47 : width [NUMERICAL]
41 : volume [NUMERICAL]
24 : tipping_risk [NUMERICAL]
Attribute in nodes with depth <= 1:
709 : weight [NUMERICAL]
627 : product_group [CATEGORICAL]
468 : height [NUMERICAL]
457 : tipping_risk [NUMERICAL]
378 : length [NUMERICAL]
314 : volume [NUMERICAL]
218 : width [NUMERICAL]
156 : girth [NUMERICAL]
95 : ft_ot_text [CATEGORICAL]
Attribute in nodes with depth <= 2:
1550 : weight [NUMERICAL]
1225 : tipping_risk [NUMERICAL]
767 : volume [NUMERICAL]
741 : product_group [CATEGORICAL]
675 : height [NUMERICAL]
479 : length [NUMERICAL]
437 : width [NUMERICAL]
361 : girth [NUMERICAL]
277 : ft_ot_text [CATEGORICAL]
Attribute in nodes with depth <= 3:
2342 : weight [NUMERICAL]
1860 : tipping_risk [NUMERICAL]
1077 : volume [NUMERICAL]
937 : product_group [CATEGORICAL]
927 : height [NUMERICAL]
778 : girth [NUMERICAL]
734 : width [NUMERICAL]
601 : length [NUMERICAL]
336 : ft_ot_text [CATEGORICAL]
Attribute in nodes with depth <= 5:
2592 : weight [NUMERICAL]
2367 : tipping_risk [NUMERICAL]
1318 : volume [NUMERICAL]
1271 : height [NUMERICAL]
1195 : product_group [CATEGORICAL]
1062 : width [NUMERICAL]
968 : girth [NUMERICAL]
730 : ft_ot_text [CATEGORICAL]
689 : length [NUMERICAL]
5 : shipping_mode [CATEGORICAL]
Condition type in nodes:
10267 : HigherCondition
1930 : ContainsBitmapCondition
Condition type in nodes with depth <= 0:
616 : ContainsBitmapCondition
584 : HigherCondition
Condition type in nodes with depth <= 1:
2700 : HigherCondition
722 : ContainsBitmapCondition
Condition type in nodes with depth <= 2:
5494 : HigherCondition
1018 : ContainsBitmapCondition
Condition type in nodes with depth <= 3:
8319 : HigherCondition
1273 : ContainsBitmapCondition
Condition type in nodes with depth <= 5:
10267 : HigherCondition
1930 : ContainsBitmapCondition
None
CartModel has a similar problem of showing only one decision, but at least the mouseover is working.
display_tree({"margin": 10, "node_x_size": 160, "node_y_size": 28, "node_x_offset": 180, "node_y_offset": 33, "font_size": 10, "edge_rounding": 20, "node_padding": 2, "show_plot_bounding_box": false}, {"value": {"type": "PROBABILITY", "distribution": [0.007407407407407408, 0.7333333333333333, 0.24444444444444444, 0.014814814814814815], "num_examples": 135.0}, "condition": {"type": "CATEGORICAL_IS_IN", "attribute": "product_Group", "mask": ["DIY"]}, "children": [{"value": {"type": "PROBABILITY", "distribution": [0.0, 0.0, 1.0, 0.0], "num_examples": 33.0}}, {"value": {"type": "PROBABILITY", "distribution": [0.00980392156862745, 0.9705882352941176, 0.0, 0.0196078431372549], "num_examples": 102.0}}]}, "#tree_plot_e7010c332612435caae222c9a1230050")
Hi, I'm not sure I correctly understand the problem just yet, but let me summarize what I think is going on.
The GradientBoostedTrees model you're building has Number of trees: 1200
i.e. it consists of 1200 trees. You inspect the first tree of this collection using tfdf.model_plotter.plot_model(model, tree_idx=0, max_depth=10)
(this is what tree_idx
does). This tree alone might not be great, but that is expected: the great performance comes from all 1200 trees together, not from any single tree.
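To see why a single tree's output can look uninformative on its own, here is a toy sketch (assumed mechanics for illustration, not tfdf's actual internals): a multinomial GBT keeps one score per class, sums small per-tree regression outputs over many boosting iterations, and only then converts the scores into probabilities with a softmax.

```python
import math

# Toy sketch (assumed, not tfdf internals): one score per class,
# accumulated over many trees, then softmax-ed into probabilities.
def softmax(scores):
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

scores = [0.0, 0.0, 0.0, 0.0]   # 4 classes, as in the summary above
for _ in range(300):            # many boosting iterations
    scores[1] += 0.05           # one tree's small regression output for class 1

probs = softmax(scores)
# Each individual tree nudged class 1's score by only 0.05 (which is why a
# single plotted tree looks unimpressive), yet the accumulated ensemble is
# effectively certain of class 1.
```

This is also why the plotted GBT tree shows REGRESSION values near zero even though the classifier as a whole is accurate.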
For CART, there is indeed just a single tree - but for most problems, CART models do not perform as well as Random Forests or Gradient Boosted Trees.
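To make the numbers concrete (taken from the summary above): with a multinomial log-likelihood loss, one tree is trained per class per boosting iteration, so 1200 trees at 4 trees per iteration correspond to 300 boosting rounds, and tree_idx indexes into those 1200 trees individually.

```python
# Numbers from the model summary pasted above.
num_trees = 1200            # "Number of trees: 1200"
trees_per_iteration = 4     # "Number of trees per iteration: 4" (one per class)
num_iterations = num_trees // trees_per_iteration  # boosting rounds

# tree_idx picks any one of the 1200 trees; for example, to look at a tree
# from a later iteration (hypothetical call, reusing `model` from the thread):
# tfdf.model_plotter.plot_model(model, tree_idx=100, max_depth=10)
```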
Ahh, okay. Did not read the manual properly and misinterpreted the tree_idx parameter.
I had also noticed that the class distribution bars are missing for the gradient boosted trees. Is that intentional?
Can you please clarify what you mean with "missing class distribution bars"?
Closing this as stale