gbnet

Add support for batch training

Open mthorrell opened this issue 1 year ago • 7 comments

True mini-batch training uses a new batch for each training round; this is the usual way NNs are trained. Unfortunately, one thing that makes GB packages fast is prediction caching (i.e., to fit the next round, you only need the predictions from the previous round). This makes naive batch training with GB packages slow (but not impossible).
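A toy sketch of the caching argument, in plain Python rather than XGBoost internals: each hypothetical "tree" is a stand-in function, and the ensemble prediction is their sum. On a fixed training set the cached predictions only need the newest tree added, while a fresh batch must be scored by every tree built so far, so round t costs O(t) tree evaluations instead of O(1).

```python
# Toy stand-in for a boosted ensemble (not XGBoost internals): each "tree"
# is a small function of x and the prediction is the sum over all trees.

def make_tree(t):
    # Hypothetical fitted tree for round t: a one-split step function on x.
    return lambda x: 0.1 * t if x > t % 5 else -0.1 * t


class ToyEnsemble:
    def __init__(self):
        self.trees = []
        self.evals = 0  # number of single-tree evaluations performed

    def predict(self, rows):
        # No cache: every tree is evaluated on every row.
        self.evals += len(self.trees) * len(rows)
        return [sum(tree(x) for tree in self.trees) for x in rows]


def train_cached(rows, n_rounds):
    """Fixed training rows: keep one running prediction per row."""
    model, cache = ToyEnsemble(), [0.0] * len(rows)
    for t in range(n_rounds):
        # (gradients for round t would be computed from `cache` here)
        tree = make_tree(t)
        model.trees.append(tree)
        cache = [p + tree(x) for p, x in zip(cache, rows)]  # newest tree only
        model.evals += len(rows)
    return model, cache


def train_fresh_batches(batches):
    """A new batch each round: no cache exists for the new rows."""
    model = ToyEnsemble()
    for t, batch in enumerate(batches):
        preds = model.predict(batch)  # every existing tree scores the batch
        # (gradients for round t would be computed from `preds` here)
        model.trees.append(make_tree(t))
    return model
```

With T rounds and batch size B, the fresh-batch loop performs B * T * (T - 1) / 2 tree evaluations, against B * T for the cached loop.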

I see a couple methods that might make this possible:

  1. For data that fits in memory, use eval datasets to gain prediction caching while still operating on batches.
  2. Re-enable batch predictions and accept that each training round costs O(number of trees). Perhaps this ultimately means changing the update cadence between the trees and the neural network parameters so that not too much time is lost.
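The cost of option 2 can be estimated with simple arithmetic. The sketch below assumes a hypothetical schedule in which one tree is added every `k` neural-network steps and each step scores a fresh batch of `batch_size` rows with all trees built so far (no caching); it counts the total single-tree evaluations.

```python
def tree_evals(n_steps, batch_size, k):
    """Total single-tree evaluations when one tree is added every k NN steps.

    Assumes each NN step scores a fresh batch with every tree built so far
    (no prediction caching). The cadence k is a hypothetical tuning knob.
    """
    total, n_trees = 0, 0
    for step in range(n_steps):
        total += n_trees * batch_size  # score the new batch with all trees
        if (step + 1) % k == 0:
            n_trees += 1               # boosting round on this cadence
    return total
```

For example, `tree_evals(1000, 32, 1)` is roughly ten times `tree_evals(1000, 32, 10)`, but the second schedule also ends with ten times fewer trees, which is the trade-off a slower tree cadence makes.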

Related issue https://github.com/mthorrell/gboost_module/issues/9

(mthorrell, Jul 02 '24 20:07)

First possibly-related question: in the context of using XGBModule as, say, the first layer of a network, how does XGB sample weighting jibe with network training?

In XGBModule the main mechanism for "training" is XGBoost's Booster.boost, which takes the training DMatrix, the iteration count, and the gradient and hessian. Therefore, the only opportunity to set weights is in XGBModule._input_checking_setting(), where we set the training DMatrix, which accepts sample weights as a parameter. It seems technically possible, but I'm still trying to figure out whether it's unsound: that is, is there a problem during back-propagation if this layer treats the input data differently from the rest of the network? Not sure.

EDIT: Possibly answering my own question, Pytorch loss functions accept weights so this may be the solution to this problem.

Second question: any leads on a solution? Interested in taking a look, myself. If the above method is sound then one possible solution, albeit hacky and requiring some significant data index management, is to set zero weights outside the batch sample.
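A toy calculation of why zero weights would effectively drop rows, assuming each row's weight ends up scaling its gradient and hessian (for instance because the weight was applied inside the PyTorch loss before the gradients reach the booster). In XGBoost's second-order formulation a leaf's value is -G / (H + lambda), where G and H sum the gradients and hessians of the rows in that leaf, so a row with zero gradient and hessian contributes nothing. `leaf_value` and `batch_weights` are hypothetical helpers, not gbnet or XGBoost functions.

```python
# Toy XGBoost-style leaf value: w* = -G / (H + lambda), with G and H the
# summed gradients and hessians of the rows that land in the leaf.
# Assumption: a row's weight multiplies its gradient and hessian.

def leaf_value(grads, hessians, weights, reg_lambda=1.0):
    G = sum(w * g for g, w in zip(grads, weights))
    H = sum(w * h for h, w in zip(hessians, weights))
    return -G / (H + reg_lambda)


def batch_weights(n_rows, batch_indices):
    """Hypothetical helper: weight 1 inside the mini-batch, 0 outside it."""
    in_batch = set(batch_indices)
    return [1.0 if i in in_batch else 0.0 for i in range(n_rows)]


grads = [0.5, -1.0, 2.0, 0.25]
hessians = [1.0] * 4
masked = leaf_value(grads, hessians, batch_weights(4, [0, 1]))
alone = leaf_value(grads[:2], hessians[:2], [1.0, 1.0])  # rows 0 and 1 only
```

The masked rows still pass through the booster, so this changes what each tree fits but does not save any computation.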

(cswiercz, Dec 02 '24 17:12)

> Pytorch loss functions accept weights so this may be the solution to this problem.

Yes, I think this would be the "right" way to do this. Weights, as I know them, are usually incorporated by weighting the individual terms of the loss, assuming the loss is a sum of per-sample losses.
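In code, a weighted loss multiplies each per-sample loss by its weight before summing, so each sample's gradient is scaled by its weight and a zero weight removes that sample's gradient entirely. The sketch below uses squared error in plain Python; in PyTorch the same effect comes from computing per-sample losses with `reduction="none"` and multiplying by the weights before summing.

```python
def weighted_loss_and_grad(preds, targets, weights):
    """Weighted sum of per-sample squared errors and its gradient w.r.t. preds.

    L = sum_i w_i * (p_i - y_i)**2, so dL/dp_i = 2 * w_i * (p_i - y_i):
    each sample's gradient is scaled by its own weight.
    """
    loss = sum(w * (p - y) ** 2 for p, y, w in zip(preds, targets, weights))
    grad = [2.0 * w * (p - y) for p, y, w in zip(preds, targets, weights)]
    return loss, grad


loss, grad = weighted_loss_and_grad([1.0, 2.0, 3.0], [0.0, 2.0, 1.0], [1.0, 5.0, 0.0])
```

The third sample has a nonzero error but weight 0, so it adds nothing to the loss and gets a zero gradient.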

> Second question: any leads on a solution? Interested in taking a look, myself. If the above method is sound then one possible solution, albeit hacky and requiring some significant data index management, is to set zero weights outside the batch sample.

Yes, please take a look if you'd like. I think the possibly hacky solution you mention (or something like it) could be the most immediate (and possibly even the best) way to do this. Since the cost of predicting on unseen data necessarily scales with the number of trees, getting O(1) updates requires some kind of clever caching. Maybe lightgbm.Dataset.subset could also provide a way.

As a side note, I think there are enough compelling use cases that extend xgb/lgb that my plan right now is to focus on those; they don't require mini-batching, since you provide all the data up front and train in the usual xgb way. Does your use case have a hard requirement on batching?

(mthorrell, Dec 02 '24 23:12)

> I think there are enough compelling cases that extend xgb/lgb, that my plan right now is to focus on those, and this does not require the mini-batching since you'd provide all the data up-front and train in the usual xgb way. Does your use case have a hard requirement on batching?

I'm not 100% sure it's a hard requirement, but there are some tools that take a torch model as input and then run their own training loop with it. Here is an example I'm exploring: CACM. I just started looking at the source code (link to parent class source), and I'm not yet sure whether the batch sizes are user-controllable; there is a dependence on PyTorch Lightning, which I'm not yet familiar with.

If the batch sizing is user-controllable then the point is moot.

FWIW, the network I'm experimenting with is a simple XGBModule -> Linear Layer / MLP with sigmoid activations for classification.

(cswiercz, Dec 03 '24 16:12)

Just trying to read the code there, it looks like potentially you input the batch itself? So you would control the batch size, I think? I'm also not familiar with PyTorch Lightning.

It sounds like you may have this working already but just in case, a basic version of the network you are interested in is here:

```python
import numpy as np
import torch
from gbnet.xgbmodule import XGBModule

class Wideboost(torch.nn.Module):
    def __init__(self, n, input_dim, intermediate_dim, output_dim, params=None):
        super().__init__()
        # n: number of rows in the (fixed) training data passed to forward
        self.xgb = XGBModule(n, input_dim, intermediate_dim, params=params or {})
        self.linear = torch.nn.Linear(intermediate_dim, output_dim)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, X):
        # boosted features -> linear layer -> sigmoid
        return self.sigmoid(self.linear(self.xgb(X)))

    def gb_step(self):
        # grow the next round of trees from the gradients of the last backward()
        self.xgb.gb_step()

wb = Wideboost(10, 100, 20, 5)
wb(np.random.random([10, 100]))
```

(mthorrell, Dec 04 '24 03:12)

> It sounds like you may have this working already but just in case, a basic version of the network you are interested in is here:

I did indeed get it working. The pro tip is to omit the sigmoid activation at the end and use BCEWithLogitsLoss instead, which uses the log-sum-exp trick for numerical stability.
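A plain-Python sketch of why this helps: applying the sigmoid first and then taking logs breaks once the sigmoid rounds to exactly 0 or 1, while the algebraically equivalent form max(z, 0) - z*y + log(1 + exp(-|z|)) stays finite. This is the standard stable form of binary cross-entropy on logits; the function names are illustrative, not PyTorch's.

```python
import math

def bce_naive(logit, target):
    # Sigmoid, then log: fails when the sigmoid saturates to exactly 0.0 or 1.0
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

def bce_with_logits(logit, target):
    # Same loss rearranged so exp() only ever sees a non-positive argument
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))
```

With a logit of 40 and target 0, `bce_naive` raises a math domain error because `1 - p` is exactly 0.0 in double precision, whereas `bce_with_logits` returns about 40. For the Wideboost module above, this means returning `self.linear(self.xgb(X))` from `forward` and passing those raw logits to `BCEWithLogitsLoss`.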

> Just trying to read the code there, it looks like potentially you input the batch itself? So you would control the batch size I think?

That's what it seems but I haven't yet confirmed.

(cswiercz, Dec 04 '24 16:12)

I was playing around with a Word2Vec example to show off categorical inputs, and it seems to me that one reasonable solution is to do the subsetting inside the forward function.

Posting the Module for posterity:

```python
import torch
import xgboost as xgb
from gbnet.xgbmodule import XGBModule

class W2V(torch.nn.Module):
    def __init__(self, size, vocab_size, dim, params):
        super().__init__()
        self.size = size              # rows in df (one context per row)
        self.vocab_size = vocab_size  # rows in emb (one per vocabulary item)
        self.dim = dim
        self.emb_left = XGBModule(vocab_size, 1, dim, params=params)
        self.emb_right = XGBModule(size, 4, dim, params=params)

    def forward(self, df, emb, i):
        # Uniform noise in [-1, 1) that shrinks as the iteration count i grows
        left = self.emb_left(xgb.DMatrix(emb[[2]], enable_categorical=True)) + (torch.rand([self.vocab_size, self.dim]) * 2 - 1) / (1 + i)
        right = self.emb_right(xgb.DMatrix(df[[0, 1, 3, 4]], enable_categorical=True)) + (torch.rand([self.size, self.dim]) * 2 - 1) / (1 + i)

        # Subsetting inside forward: pull the vocabulary row named by each label
        row_indices = torch.tensor(df['label'].to_list(), dtype=torch.long)
        real_logits = (left[row_indices, :] * right).sum(1, keepdim=True)

        batch_size = row_indices.shape[0]
        fake_logits = []
        for _ in range(10):
            # here I'm sampling randomly, but no reason why a mini-batch couldn't contain indices that facilitate sampling
            random_indices = torch.randint(
                low=0,
                high=self.vocab_size,
                size=(batch_size,),
                device=left.device
            )
            fake_logits.append((left[random_indices, :] * right).sum(1, keepdim=True))

        # Column 0 holds the real pair; columns 1-10 hold the negative samples
        logits = torch.cat([real_logits] + fake_logits, dim=1)

        return logits, left, right

    def gb_step(self):
        self.emb_left.gb_step()
        self.emb_right.gb_step()
```

(mthorrell, Jan 27 '25 23:01)

The more I've thought about it, I think this is the way to go:

  1. Each iteration computes and updates the GBM predictions for all rows.
  2. The data being fed in specifies the row indices the mini-batch accesses.
  3. PyTorch indexing pulls out those specific rows.
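A minimal sketch of these three steps, with a plain list standing in for the GBM's full prediction vector and a hypothetical `train_step` helper. In gbnet the full predictions would come from the module's forward pass over all rows and the row pull would be ordinary tensor indexing; the returned gradient is zero outside the batch and accumulates for repeated indices, matching how autograd treats an indexing operation.

```python
def train_step(full_preds, targets, batch_idx):
    """One mini-batch step on top of predictions for every row.

    full_preds: predictions for ALL rows (step 1, refreshed once per round)
    batch_idx:  row indices carried by the mini-batch (step 2)
    Returns the batch MSE and a full-length gradient that is zero outside
    the batch, i.e. what the GBM would be fit to in the next round.
    """
    batch_preds = [full_preds[i] for i in batch_idx]  # step 3: pull rows
    batch_targets = [targets[i] for i in batch_idx]
    n = len(batch_idx)
    loss = sum((p - y) ** 2 for p, y in zip(batch_preds, batch_targets)) / n
    grad = [0.0] * len(full_preds)
    for i, p, y in zip(batch_idx, batch_preds, batch_targets):
        grad[i] += 2.0 * (p - y) / n  # += so repeated indices accumulate
    return loss, grad
```

For example, `train_step([0.0, 1.0, 2.0, 3.0], [0.0] * 4, [1, 3])` touches only rows 1 and 3 of the gradient.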

To do this, I think the only gap (other than the PyTorch subsetting logic) might be something like a data loader that explicitly keeps track of row indices and, for very large datasets, periodically swaps in large chunks of data (not every training round).
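One possible shape for such a loader, sketched under assumptions: `ChunkedIndexLoader` is a hypothetical class, not part of gbnet. Rows are shuffled into large chunks; each chunk stays resident for `rounds_per_chunk` passes (so the GBM's predictions on it could be cached) before the next chunk is swapped in, and every mini-batch carries indices rather than bare tensors.

```python
import random


class ChunkedIndexLoader:
    """Hypothetical loader yielding (chunk_row_ids, local_batch_indices).

    chunk_row_ids: global row ids of the resident chunk
    local_batch_indices: positions within that chunk for one mini-batch
    """

    def __init__(self, n_rows, chunk_size, batch_size, rounds_per_chunk, seed=0):
        self.n_rows = n_rows
        self.chunk_size = chunk_size
        self.batch_size = batch_size
        self.rounds_per_chunk = rounds_per_chunk
        self.rng = random.Random(seed)

    def chunks(self):
        order = list(range(self.n_rows))
        self.rng.shuffle(order)
        for start in range(0, self.n_rows, self.chunk_size):
            yield order[start:start + self.chunk_size]

    def __iter__(self):
        for chunk in self.chunks():
            # Swap point: load these rows and refresh cached GBM predictions
            # for them here (not shown).
            for _ in range(self.rounds_per_chunk):
                local = list(range(len(chunk)))
                self.rng.shuffle(local)
                for start in range(0, len(local), self.batch_size):
                    yield chunk, local[start:start + self.batch_size]
```

Inside the training loop, the chunk's cached GBM predictions would be indexed with the local indices, while the global ids say which rows of the full dataset those are. How often to swap chunks is a tuning choice this sketch leaves to `rounds_per_chunk`.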

(mthorrell, Feb 16 '25 20:02)