LightGBM icon indicating copy to clipboard operation
LightGBM copied to clipboard

I cannot replicate LightGBM's `refit` behaviours when using a custom loss function

Open mlondschien opened this issue 7 months ago • 4 comments

related: https://github.com/microsoft/LightGBM/issues/6838

import numpy as np
import lightgbm as lgb

class RegressionObjective:
    """Squared-error loss exposed as a custom objective / score for LightGBM.

    ``f`` is always the array of current raw predictions and ``data`` is any
    object exposing ``get_label()`` (an ``lgb.Dataset`` in the script below).
    """

    higher_is_better = False
    # Defaults for the label built in ``score``. Without them ``score`` raised
    # ``AttributeError`` because neither attribute was defined on the class.
    # Subclasses or instances may override either one.
    name = "l2"
    gamma = None

    def objective(self, f, data):
        """Objective function for LGBM: return the ``(grad, hess)`` pair."""
        return self.grad(f, data), self.hess(f, data)

    def score(self, f, data):
        """Score function for LGBM.

        Returns a ``(label, mean loss, higher_is_better)`` tuple. The label is
        ``"<name> (<gamma>)"`` when ``gamma`` is set and just ``name`` otherwise.
        """
        if self.gamma is None:
            label = self.name
        else:
            label = f"{self.name} ({self.gamma})"
        return (
            label,
            self.loss(f, data).mean(),
            self.higher_is_better,
        )

    def init_score(self, y):
        """Return a constant initial score: ``y.mean()`` repeated ``len(y)`` times."""
        return np.tile(y.mean(), len(y))

    def grad(self, f, data):
        """Return the derivative of ``loss`` with respect to ``f`` (``f - label``)."""
        # Replicate LGBM behaviour
        # https://github.com/microsoft/LightGBM/blob/e9fbd19d7cbaeaea1ca54a091b160868fc\
        # 5c79ec/src/objective/regression_objective.hpp#L130-L131
        return -(data.get_label() - f)

    def hess(self, f, data):
        """Return the second derivative of ``loss``: one for every label."""
        # Replicate LGBM behaviour
        # https://github.com/microsoft/LightGBM/blob/e9fbd19d7cbaeaea1ca54a091b160868fc\
        # 5c79ec/src/objective/regression_objective.hpp#L130-L131
        return np.ones(len(data.get_label()))

    def loss(self, f, data):
        """Return the per-sample loss ``0.5 * (label - f) ** 2``."""
        return 0.5 * (data.get_label() - f) ** 2

# Reproducible synthetic data: n samples, p features, fixed seed.
n = 100
rng = np.random.RandomState(0)

p = 2

X = rng.normal(size=(n, p))
# Step-shaped signal in the features plus Gaussian noise scaled by 0.25.
# NOTE(review): both comparisons yield boolean arrays, so the first `+` is an
# element-wise logical OR (and `*` a logical AND), not an integer sum; the noiseless
# target therefore only takes the values 0 and 1 -- confirm this is intended.
y = (X[:, 0] <= 0) + (X[:, 0] <= -0.5) * (X[:, 1] <= 1) + 0.25 * rng.normal(size=n)

loss = RegressionObjective()
# data0 has no init_score; data1 starts every row at the constant y.mean().
data0 = lgb.Dataset(X, y)
data1 = lgb.Dataset(X, y, init_score=loss.init_score(y))

# model0: built-in "regression" objective on the plain dataset.
model0 = lgb.train(
    params={"learning_rate": 0.1, "objective": "regression"},
    train_set=data0,
    num_boost_round=10,
)

# model1: the custom objective callable on the dataset carrying init_score.
model1 = lgb.train(
    params={"learning_rate": 0.1, "objective": loss.objective},
    train_set=data1,
    num_boost_round=10,
)

pred0 = model0.predict(X)
pred1 = model1.predict(X)

# model1's predictions match model0's once the init score is added back.
np.testing.assert_allclose(pred0, pred1 + loss.init_score(y), rtol=1e-5)  # this passes

# I would expect that refitting the model on _the same data_ would not change the tree
# leaves. However, it does.
print(model0.trees_to_dataframe().T)
#                           0         1         2         3         4   ...        65       66        67       68        69
# tree_index                 0         0         0         0         0  ...         9        9         9        9         9
# node_depth                 1         2         3         3         2  ...         3        3         2        3         3
# node_index              0-S0      0-S2      0-L0      0-L3      0-S1  ...      9-L0     9-L3      9-S1     9-L1      9-L2
# left_child              0-S2      0-L0      None      None      0-L1  ...      None     None      9-L1     None      None
# right_child             0-S1      0-L3      None      None      0-L2  ...      None     None      9-L2     None      None
# parent_index            None      0-S0      0-S2      0-S2      0-S0  ...      9-S2     9-S2      9-S0     9-S1      9-S1
# split_feature       Column_0  Column_0      None      None  Column_0  ...      None     None  Column_0     None      None
# split_gain         25.214001  0.023952       NaN       NaN   0.39059  ...       NaN      NaN  0.058625      NaN       NaN
# threshold                0.0 -0.757335       NaN       NaN  0.892648  ...       NaN      NaN  0.892648      NaN       NaN
# decision_type             <=        <=      None      None        <=  ...      None     None        <=     None      None
# missing_direction       left      left      None      None      left  ...      None     None      left     None      None
# missing_type            None      None      None      None      None  ...      None     None      None     None      None
# value               0.505192  0.553436  0.555501  0.551205  0.452928  ...  0.016743  0.02001 -0.020248 -0.02333 -0.016285
# weight                     0        52        27        25        48  ...        21       31        48       27        21
# count                    100        52        27        25        48  ...        21       31        48       27        21
#
# [15 rows x 70 columns]

# Refit model0 on its own training data and show the resulting trees. Per the LightGBM
# docs, leaf_output = decay_rate * old + (1 - decay_rate) * new, so decay_rate=0 keeps
# nothing of the old leaf values.
print(model0.refit(X, y, decay_rate=0).trees_to_dataframe().T)
#                           0         1        2         3         4         5         6          7         8   ...    61    62        63        64        65        66        67    68    69
# tree_index                 0         0        0         0         0         0         0          1         1  ...     8     8         9         9         9         9         9     9     9
# node_depth                 1         2        3         3         2         3         3          1         2  ...     3     3         1         2         3         3         2     3     3
# node_index              0-S0      0-S2     0-L0      0-L3      0-S1      0-L1      0-L2       1-S0      1-S2  ...  8-L1  8-L2      9-S0      9-S2      9-L0      9-L3      9-S1  9-L1  9-L2
# left_child              0-S2      0-L0     None      None      0-L1      None      None       1-S2      1-L0  ...  None  None      9-S2      9-L0      None      None      9-L1  None  None
# right_child             0-S1      0-L3     None      None      0-L2      None      None       1-S1      1-L3  ...  None  None      9-S1      9-L3      None      None      9-L2  None  None
# parent_index            None      0-S0     0-S2      0-S2      0-S0      0-S1      0-S1       None      1-S0  ...  8-S1  8-S1      None      9-S0      9-S2      9-S2      9-S0  9-S1  9-S1
# split_feature       Column_0  Column_0     None      None  Column_0      None      None   Column_0  Column_0  ...  None  None  Column_0  Column_1      None      None  Column_0  None  None
# split_gain         25.214001  0.023952      NaN       NaN   0.39059       NaN       NaN  20.423401  0.019401  ...   NaN   NaN   3.78449  0.013367       NaN       NaN  0.058625   NaN   NaN
# threshold                0.0 -0.757335      NaN       NaN  0.892648       NaN       NaN        0.0 -0.757335  ...   NaN   NaN       0.0       0.0       NaN       NaN  0.892648   NaN   NaN
# decision_type             <=        <=     None      None        <=      None      None         <=        <=  ...  None  None        <=        <=      None      None        <=  None  None
# missing_direction       left      left     None      None      left      None      None       left      left  ...  None  None      left      left      None      None      left  None  None
# missing_type            None      None     None      None      None      None      None       None      None  ...  None  None      None      None      None      None      None  None  None
# value               0.505192  0.553436  1.00828  0.965324  0.452928 -0.097003  0.084837        0.0  0.043419  ...  -0.0  -0.0       0.0  0.018691 -0.001813  0.001228 -0.020248  -0.0  -0.0
# weight                     0        52       27        25        48        27        21          0        52  ...    27    21         0        52        21        31        48    27    21
# count                    100        52       27        25        48        27        21        100        52  ...    27    21       100        52        21        31        48    27    21
#
# [15 rows x 70 columns]

# Now, refit model1 with the same data and compare to the previous tree leaves.
print(model1.trees_to_dataframe().T)
#                           0         1         2         3         4         5         6          7   ...        62        63        64        65       66        67       68        69
# tree_index                 0         0         0         0         0         0         0          1  ...         8         9         9         9        9         9        9         9
# node_depth                 1         2         3         3         2         3         3          1  ...         3         1         2         3        3         2        3         3
# node_index              0-S0      0-S2      0-L0      0-L3      0-S1      0-L1      0-L2       1-S0  ...      8-L2      9-S0      9-S2      9-L0     9-L3      9-S1     9-L1      9-L2
# left_child              0-S2      0-L0      None      None      0-L1      None      None       1-S2  ...      None      9-S2      9-L0      None     None      9-L1     None      None
# right_child             0-S1      0-L3      None      None      0-L2      None      None       1-S1  ...      None      9-S1      9-L3      None     None      9-L2     None      None
# parent_index            None      0-S0      0-S2      0-S2      0-S0      0-S1      0-S1       None  ...      8-S1      None      9-S0      9-S2     9-S2      9-S0     9-S1      9-S1
# split_feature       Column_0  Column_0      None      None  Column_0      None      None   Column_0  ...      None  Column_0  Column_1      None     None  Column_0     None      None
# split_gain         25.214001  0.023952       NaN       NaN   0.39059       NaN       NaN  20.423401  ...       NaN   3.78449  0.013367       NaN      NaN  0.058625      NaN       NaN
# threshold                0.0 -0.757335       NaN       NaN  0.892648       NaN       NaN        0.0  ...       NaN       0.0       0.0       NaN      NaN  0.892648      NaN       NaN
# decision_type             <=        <=      None      None        <=      None      None         <=  ...      None        <=        <=      None     None        <=     None      None
# missing_direction       left      left      None      None      left      None      None       left  ...      None      left      left      None     None      left     None      None
# missing_type            None      None      None      None      None      None      None       None  ...      None      None      None      None     None      None     None      None
# value                    0.0  0.048244  0.050309  0.046013 -0.052264 -0.060219 -0.042035        0.0  ... -0.018095       0.0  0.018691  0.016743  0.02001 -0.020248 -0.02333 -0.016285
# weight                     0        52        27        25        48        27        21          0  ...        21         0        52        21       31        48       27        21
# count                    100        52        27        25        48        27        21        100  ...        21       100        52        21       31        48       27        21
#
# [15 rows x 70 columns]

# Point model1 at the built-in "regression" objective before refitting. The private
# (name-mangled) Booster flag is reset by hand as a workaround; see the linked issue
# comment. NOTE(review): this relies on LightGBM internals and may break across versions.
model1._Booster__set_objective_to_none = False  # why do I need to do this?
# https://github.com/microsoft/LightGBM/issues/5609#issuecomment-1342172997
model1.params["objective"] = "regression"
# Refit on the training data with decay_rate=0 and show the resulting trees.
print(model1.refit(X, y, decay_rate=0).trees_to_dataframe().T)
#                           0         1         2         3         4       5         6          7   ...        62        63        64        65        66        67        68        69
# tree_index                 0         0         0         0         0       0         0          1  ...         8         9         9         9         9         9         9         9
# node_depth                 1         2         3         3         2       3         3          1  ...         3         1         2         3         3         2         3         3
# node_index              0-S0      0-S2      0-L0      0-L3      0-S1    0-L1      0-L2       1-S0  ...      8-L2      9-S0      9-S2      9-L0      9-L3      9-S1      9-L1      9-L2
# left_child              0-S2      0-L0      None      None      0-L1    None      None       1-S2  ...      None      9-S2      9-L0      None      None      9-L1      None      None
# right_child             0-S1      0-L3      None      None      0-L2    None      None       1-S1  ...      None      9-S1      9-L3      None      None      9-L2      None      None
# parent_index            None      0-S0      0-S2      0-S2      0-S0    0-S1      0-S1       None  ...      8-S1      None      9-S0      9-S2      9-S2      9-S0      9-S1      9-S1
# split_feature       Column_0  Column_0      None      None  Column_0    None      None   Column_0  ...      None  Column_0  Column_1      None      None  Column_0      None      None
# split_gain         25.214001  0.023952       NaN       NaN   0.39059     NaN       NaN  20.423401  ...       NaN   3.78449  0.013367       NaN       NaN  0.058625       NaN       NaN
# threshold                0.0 -0.757335       NaN       NaN  0.892648     NaN       NaN        0.0  ...       NaN       0.0       0.0       NaN       NaN  0.892648       NaN       NaN
# decision_type             <=        <=      None      None        <=    None      None         <=  ...      None        <=        <=      None      None        <=      None      None
# missing_direction       left      left      None      None      left    None      None       left  ...      None      left      left      None      None      left      None      None
# missing_type            None      None      None      None      None    None      None       None  ...      None      None      None      None      None      None      None      None
# value                    0.0  0.048244  0.100828  0.096532 -0.052264 -0.0097  0.008484        0.0  ...  0.003652       0.0  0.018691  0.036315  0.039582 -0.020248 -0.003758  0.003287
# weight                     0        52        27        25        48      27        21          0  ...        21         0        52        21        31        48        27        21
# count                    100        52        27        25        48      27        21        100  ...        21       100        52        21        31        48        27        21
# 
# [15 rows x 70 columns]

# Note that the first tree of model1 before refitting has leaf values shifted by
# loss.init_score(y)[0] = y.mean() = 0.505192 compared to the first tree of model0. We
# can "fix this" by manually setting the leaves.
# Leaf index and prediction for every row, restricted to the first boosting iteration.
leaves = model1.predict(X, pred_leaf=True, num_iteration=1)
preds = model1.predict(X, num_iteration=1)
# return_index yields one representative row per distinct leaf; use it to overwrite each
# leaf of tree 0 with its current first-iteration prediction plus the constant init score.
for idx in np.unique(leaves, return_index=True)[1]:
    model1.set_leaf_output(0, leaves[idx], preds[idx] + loss.init_score(y)[0])

# After the shift, model1 alone reproduces model0's predictions.
np.testing.assert_allclose(pred0, model1.predict(X), rtol=1e-5)  # this passes

# Show model1's trees after the leaf outputs of tree 0 were set manually.
print(model1.trees_to_dataframe().T)
#                           0         1         2         3         4         5         6          7   ...        62        63        64        65       66        67       68        69
# tree_index                 0         0         0         0         0         0         0          1  ...         8         9         9         9        9         9        9         9
# node_depth                 1         2         3         3         2         3         3          1  ...         3         1         2         3        3         2        3         3
# node_index              0-S0      0-S2      0-L0      0-L3      0-S1      0-L1      0-L2       1-S0  ...      8-L2      9-S0      9-S2      9-L0     9-L3      9-S1     9-L1      9-L2
# left_child              0-S2      0-L0      None      None      0-L1      None      None       1-S2  ...      None      9-S2      9-L0      None     None      9-L1     None      None
# right_child             0-S1      0-L3      None      None      0-L2      None      None       1-S1  ...      None      9-S1      9-L3      None     None      9-L2     None      None
# parent_index            None      0-S0      0-S2      0-S2      0-S0      0-S1      0-S1       None  ...      8-S1      None      9-S0      9-S2     9-S2      9-S0     9-S1      9-S1
# split_feature       Column_0  Column_0      None      None  Column_0      None      None   Column_0  ...      None  Column_0  Column_1      None     None  Column_0     None      None
# split_gain         25.214001  0.023952       NaN       NaN   0.39059       NaN       NaN  20.423401  ...       NaN   3.78449  0.013367       NaN      NaN  0.058625      NaN       NaN
# threshold                0.0 -0.757335       NaN       NaN  0.892648       NaN       NaN        0.0  ...       NaN       0.0       0.0       NaN      NaN  0.892648      NaN       NaN
# decision_type             <=        <=      None      None        <=      None      None         <=  ...      None        <=        <=      None     None        <=     None      None
# missing_direction       left      left      None      None      left      None      None       left  ...      None      left      left      None     None      left     None      None
# missing_type            None      None      None      None      None      None      None       None  ...      None      None      None      None     None      None     None      None
# value                    0.0  0.048244  0.555501  0.551205 -0.052264  0.444972  0.463156        0.0  ... -0.018095       0.0  0.018691  0.016743  0.02001 -0.020248 -0.02333 -0.016285
# weight                     0        52        27        25        48        27        21          0  ...        21         0        52        21       31        48       27        21
# count                    100        52        27        25        48        27        21        100  ...        21       100        52        21       31        48       27        21
#
# [15 rows x 70 columns]

# Now the leaves (`value` for nodes with split_feature=None) of model1 and model0 are
# the same. Still, refitting on the same data the model was trained on yields different
# results.
print(model1.refit(X, y, decay_rate=0).trees_to_dataframe().T)
#                           0         1         2         3         4       5         6          7   ...        62        63        64        65        66        67        68        69
# tree_index                 0         0         0         0         0       0         0          1  ...         8         9         9         9         9         9         9         9
# node_depth                 1         2         3         3         2       3         3          1  ...         3         1         2         3         3         2         3         3
# node_index              0-S0      0-S2      0-L0      0-L3      0-S1    0-L1      0-L2       1-S0  ...      8-L2      9-S0      9-S2      9-L0      9-L3      9-S1      9-L1      9-L2
# left_child              0-S2      0-L0      None      None      0-L1    None      None       1-S2  ...      None      9-S2      9-L0      None      None      9-L1      None      None
# right_child             0-S1      0-L3      None      None      0-L2    None      None       1-S1  ...      None      9-S1      9-L3      None      None      9-L2      None      None
# parent_index            None      0-S0      0-S2      0-S2      0-S0    0-S1      0-S1       None  ...      8-S1      None      9-S0      9-S2      9-S2      9-S0      9-S1      9-S1
# split_feature       Column_0  Column_0      None      None  Column_0    None      None   Column_0  ...      None  Column_0  Column_1      None      None  Column_0      None      None
# split_gain         25.214001  0.023952       NaN       NaN   0.39059     NaN       NaN  20.423401  ...       NaN   3.78449  0.013367       NaN       NaN  0.058625       NaN       NaN
# threshold                0.0 -0.757335       NaN       NaN  0.892648     NaN       NaN        0.0  ...       NaN       0.0       0.0       NaN       NaN  0.892648       NaN       NaN
# decision_type             <=        <=      None      None        <=    None      None         <=  ...      None        <=        <=      None      None        <=      None      None
# missing_direction       left      left      None      None      left    None      None       left  ...      None      left      left      None      None      left      None      None
# missing_type            None      None      None      None      None    None      None       None  ...      None      None      None      None      None      None      None      None
# value                    0.0  0.048244  0.100828  0.096532 -0.052264 -0.0097  0.008484        0.0  ...  0.003652       0.0  0.018691  0.036315  0.039582 -0.020248 -0.003758  0.003287
# weight                     0        52        27        25        48      27        21          0  ...        21         0        52        21        31        48        27        21
# count                    100        52        27        25        48      27        21        100  ...        21       100        52        21        31        48        27        21
#
# [15 rows x 70 columns]

# If I don't use an init score, then I can no longer replicate the output of the 
# "regression" objective model with my custom loss function:
# Same custom objective as model1, but trained on data0, which carries no init_score.
model2 = lgb.train(
    params={"learning_rate": 0.1, "objective": loss.objective},
    train_set=data0,
    num_boost_round=10,
)
# Left commented out because the comparison against model0 does not hold here:
# np.testing.assert_allclose(pred0, model2.predict(X), rtol=1e-5)  # this fails

mlondschien avatar Mar 12 '25 17:03 mlondschien