
Using weights creates a confusing confusion matrix and inaccurate accuracy score

Open · CodingDoug opened this issue 6 days ago · 0 comments

I'm using RandomForestLearner to train a 10-class classification model on roughly 15000 examples with 12 features. My dataset is imbalanced across categories, so I use class-based weighting to boost the under-represented classes.

I'm post-processing my dataset with weights computed from the entire set:

# count = total number of examples; category_counts[c] = examples in class c.
for row in rows:
    row["weight"] = count / (len(category_counts) * category_counts[row["category"]])

The resulting model is effective, but the confusion matrix is confusing. Here is part of the output from model.describe():

Confusion Table:
truth\prediction
          n     p     e     h     c     b     t     a     d     s
    n1504.9912.17848.717155.255933.717610.384580.384581.794711.281931.79471
    p10.05111525.09     0     0     00.670074     0     0     04.69052
    e23.4119     01517.09     0     0     0     0     0     0     0
    h31.82856.3657     01502.31     0     0     0     0     0     0
    c332.748     0     0     01170.78     0     0     024.64812.324
    b28.527828.5278     0     0     01483.44     0     0     0     0
    t55.6807     0     0     0     0     01484.82     0     0     0
    a418.407     0     0     0     0     0     01122.09     0     0
    d303.432     0     0     023.3409     0     046.68181143.723.3409
    s252.08284.0273     0     028.0091     0     0     028.00911148.37
Total: 15405

Basically unreadable. Here it is again from model.self_evaluation():

accuracy: 0.883005
confusion matrix:
    label (row) \ prediction (col)
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |          |        n |        p |        e |        h |        c |        b |        t |        a |        d |        s |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        n |  1504.99 |  12.1784 |  8.71715 |  5.25593 |  3.71761 |  0.38458 |  0.38458 |  1.79471 |  1.28193 |  1.79471 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        p |  10.0511 |  1525.09 |        0 |        0 |        0 | 0.670074 |        0 |        0 |        0 |  4.69052 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        e |  23.4119 |        0 |  1517.09 |        0 |        0 |        0 |        0 |        0 |        0 |        0 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        h |  31.8285 |   6.3657 |        0 |  1502.31 |        0 |        0 |        0 |        0 |        0 |        0 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        c |  332.748 |        0 |        0 |        0 |  1170.78 |        0 |        0 |        0 |   24.648 |   12.324 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        b |  28.5278 |  28.5278 |        0 |        0 |        0 |  1483.44 |        0 |        0 |        0 |        0 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        t |  55.6807 |        0 |        0 |        0 |        0 |        0 |  1484.82 |        0 |        0 |        0 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        a |  418.407 |        0 |        0 |        0 |        0 |        0 |        0 |  1122.09 |        0 |        0 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        d |  303.432 |        0 |        0 |        0 |  23.3409 |        0 |        0 |  46.6818 |   1143.7 |  23.3409 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        s |  252.082 |  84.0273 |        0 |        0 |  28.0091 |        0 |        0 |        0 |  28.0091 |  1148.37 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
loss: 1.2135
num examples: 15405
num examples (weighted): 15405

Without weights, the confusion matrix prints integers, as I would expect. With weights, it contains these floating-point values that don't make much sense to me. I also believe the accuracy number is incorrect: if I run predictions against the model on the same training dataset, I count only 186 incorrect predictions out of 15405, a 1.2% error rate, i.e. roughly 98.8% accuracy rather than the reported 0.883.
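
For reference, the error count above comes from a tally like the following sketch (assuming model.predict() returns one probability per class for each example and model.label_classes() gives the class names in matching column order; if I'm misreading that API, the same tally can be done by hand):

import numpy as np
import pandas as pd

ds = pd.DataFrame(rows)          # the same data the model was trained on
probs = model.predict(ds)        # shape: (num_examples, num_classes)
classes = model.label_classes()  # class names in prediction-column order
predicted = [classes[i] for i in np.argmax(probs, axis=1)]

errors = sum(p != t for p, t in zip(predicted, ds["category"]))
print(f"{errors} / {len(ds)} incorrect")  # prints 186 / 15405 on my data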

CodingDoug · Jul 02 '24 14:07