Add support for batch training
True batch training uses a new batch for each training round, which is the usual way NNs are trained. Unfortunately, one thing that makes GB packages fast is prediction caching (i.e., to fit the next round, you only need the predictions from the previous round). A new batch has no cached predictions, so every existing tree must be evaluated on it. This makes naive batch training with GB packages slow (but not impossible).
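To make the caching point concrete, here is a toy sketch in plain NumPy. It is an illustration under simplifying assumptions, not how xgboost or lightgbm implement boosting: the "trees" are depth-1 stumps fit to squared-error residuals, and `fit_stump`, `boost_cached`, and `predict_fresh` are hypothetical helpers. Rows that are seen every round only need the newest tree evaluated, while a fresh batch has to pass through every tree.

```python
import numpy as np

LR = 0.1  # learning rate (shrinkage) applied to every stump


def stump_predict(stump, X):
    feat, thr, left, right = stump
    return np.where(X[:, feat] <= thr, left, right)


def fit_stump(X, residual):
    """Exhaustively pick the single split with the lowest squared error on `residual`."""
    best_err, best = np.inf, (0, np.inf, residual.mean(), residual.mean())
    for feat in range(X.shape[1]):
        for thr in np.unique(X[:, feat])[:-1]:  # skip the max so both sides are non-empty
            mask = X[:, feat] <= thr
            left, right = residual[mask].mean(), residual[~mask].mean()
            err = ((residual - np.where(mask, left, right)) ** 2).sum()
            if err < best_err:
                best_err, best = err, (feat, thr, left, right)
    return best


def boost_cached(X, y, rounds):
    """Full-data boosting with a prediction cache: one tree evaluation per round."""
    trees, cache, tree_evals = [], np.zeros(len(y)), 0
    for _ in range(rounds):
        stump = fit_stump(X, y - cache)        # residual = negative gradient of squared loss
        trees.append(stump)
        cache += LR * stump_predict(stump, X)  # update the cache with the NEW tree only
        tree_evals += 1
    return trees, cache, tree_evals


def predict_fresh(trees, X):
    """A batch with no cache must go through every tree: O(number of trees)."""
    pred = np.zeros(len(X))
    for stump in trees:
        pred += LR * stump_predict(stump, X)
    return pred, len(trees)


rng = np.random.default_rng(0)
X = rng.random((200, 3))
y = (X[:, 0] > 0.5).astype(float)
trees, cache, evals = boost_cached(X, y, rounds=50)
fresh, fresh_evals = predict_fresh(trees, X[:32])
```

After 50 rounds the cached path has done one tree evaluation per round, while scoring a single unseen batch needs all 50 trees.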
I see a couple of methods that might make this possible:
- For data that fits in memory, use eval datasets to gain prediction caching while still operating on batches.
- Re-enable batch predictions and accept that each training round costs O(number of trees). Perhaps this ultimately changes the update cadence between the trees and the neural network parameters so that not too much time is lost.
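A back-of-the-envelope sketch of why cadence matters for the second option, under the hypothetical assumptions that one tree is added every `steps_per_tree` mini-batch steps and that each step scores a brand-new batch through all existing trees. `naive_batch_tree_evals` is an illustrative helper, not part of any package.

```python
def naive_batch_tree_evals(n_steps, batch_size, steps_per_tree):
    """Count row-by-tree evaluations when every mini-batch is unseen (no cache).

    Assumes a new tree is added every `steps_per_tree` steps, so at step s the
    model has s // steps_per_tree trees and each of the `batch_size` rows
    must pass through all of them.
    """
    total = 0
    for step in range(n_steps):
        total += batch_size * (step // steps_per_tree)
    return total


# One tree per step: 32 * (0 + 1 + ... + 99) = 158400
print(naive_batch_tree_evals(100, 32, steps_per_tree=1))
# Ten NN steps per tree: 32 * 10 * (0 + 1 + ... + 9) = 14400
print(naive_batch_tree_evals(100, 32, steps_per_tree=10))
```

With one tree per step the cost grows roughly like batch_size * T^2 / 2 for T trees; doing several NN updates per tree keeps the tree count, and therefore the per-batch prediction cost, growing more slowly.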
Related issue https://github.com/mthorrell/gboost_module/issues/9
First possibly-related question: in the context of using XGBModule as, say, the first layer of a network, how does XGB sample weighting jibe with network training?
In XGBModule the main mechanism for "training" is via xgboost.boost(dmatrix, iter_count, gradient, hessian). Therefore, the only opportunity to set weights is via XGBModule._input_checking_setting(), where we set the training DMatrix, which accepts sample weights as a parameter. It seems technically possible, but I'm still trying to figure out whether it's unsound: that is, is there a problem during back-propagation if this layer treats the input data differently? Not sure.
EDIT: Possibly answering my own question, PyTorch loss functions accept weights, so this may be the solution to this problem.
Second question: any leads on a solution? I'm interested in taking a look myself. If the above method is sound, then one possible solution, albeit hacky and requiring significant data index management, is to set zero weights outside the batch sample.
> Pytorch loss functions accept weights so this may be the solution to this problem.
Yes, I think this would be the "right" way to do this. Weights, as I know them, are usually incorporated by weighting the individual terms in a sum of per-example losses, assuming the loss function is such a sum.
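A small NumPy check of that point, using binary log-loss on logits as the example loss (an assumption; any loss written as a sum of per-example terms behaves the same way). Because each term depends only on its own prediction, the weight simply multiplies that row's gradient and Hessian, and a zero weight removes the row entirely. The function names are illustrative.

```python
import numpy as np


def weighted_logloss(z, y, w):
    # L(z) = sum_i w_i * [log(1 + exp(z_i)) - y_i * z_i], the binary log-loss on logits z
    return np.sum(w * (np.logaddexp(0.0, z) - y * z))


def weighted_grad_hess(z, y, w):
    # dL/dz_i = w_i * (p_i - y_i); the Hessian is diagonal with entries w_i * p_i * (1 - p_i)
    p = 1.0 / (1.0 + np.exp(-z))
    return w * (p - y), w * p * (1.0 - p)


z = np.array([0.3, -1.2, 2.0, 0.0])
y = np.array([1.0, 0.0, 1.0, 0.0])
w = np.array([1.0, 0.0, 2.0, 1.0])  # row 1 plays the role of "outside the batch"
grad, hess = weighted_grad_hess(z, y, w)
print(grad[1], hess[1])  # 0.0 0.0: the zero-weight row contributes nothing
```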
> Second question: any leads on a solution? Interested in taking a look, myself. If the above method is sound then one possible solution, albeit hacky and requiring some significant data index management, is to set zero weights outside the batch sample.
Yes, please take a look if you'd like. I think the possibly hacky solution you mention (or something like it) could be the most immediate (and possibly even the best) way to do this. Since predictions on unseen data necessarily scale with the number of trees, we need to get clever with some kind of caching to get O(1) updates. Maybe lightgbm.Dataset.subset could also provide a way.
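One way the zero-weight idea could look, sketched in NumPy with a hypothetical `batch_weight_masks` helper that does the index bookkeeping. Every round the GBM still sees all rows (so its prediction cache stays valid), but rows outside the current mini-batch get weight 0, so the weighted full-data loss equals the loss on the batch alone. This sketch does not use `lightgbm.Dataset.subset`.

```python
import numpy as np


def batch_weight_masks(n_rows, batch_size, rng):
    """Yield (batch_indices, weights) for one shuffled epoch.

    `weights` has one entry per row of the full dataset: 1.0 for rows in the
    current mini-batch, 0.0 everywhere else.
    """
    order = rng.permutation(n_rows)
    for start in range(0, n_rows, batch_size):
        idx = order[start:start + batch_size]
        weights = np.zeros(n_rows)
        weights[idx] = 1.0
        yield idx, weights


rng = np.random.default_rng(0)
per_row_loss = rng.random(10)  # stand-in for per-example losses on ALL rows
for idx, w in batch_weight_masks(10, 4, rng):
    # The masked full-data loss equals the plain loss over the batch rows
    assert np.isclose(np.sum(w * per_row_loss), per_row_loss[idx].sum())
```

Each epoch visits every row exactly once (batches of 4, 4, and 2 here); the cost is that each GBM round still runs over the full dataset.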
As a side note: I think there are enough compelling cases that extend xgb/lgb that my plan right now is to focus on those. These do not require mini-batching, since you'd provide all the data up-front and train in the usual xgb way. Does your use case have a hard requirement on batching?
> I think there are enough compelling cases that extend xgb/lgb, that my plan right now is to focus on those, and this does not require the mini-batching since you'd provide all the data up-front and train in the usual xgb way. Does your use case have a hard requirement on batching?
I'm not 100% sure it's a hard requirement, but there are some tools that accept a torch model as input and then conduct their own training loop with the model. Here is an example I'm exploring: CACM. I just started looking at the source code (link to parent class source), though, and I'm not yet 100% sure whether the batch sizes are user-controllable; there is a dependence on PyTorch Lightning, which I'm not yet familiar with.
If the batch sizing is user-controllable then the point is moot.
FWIW, the network I'm experimenting with is a simple XGBModule -> Linear Layer / MLP with sigmoid activations for classification.
Just trying to read the code there, it looks like potentially you input the batch itself? So you would control the batch size, I think? I'm also not familiar with PyTorch Lightning.
It sounds like you may have this working already but just in case, a basic version of the network you are interested in is here:
```python
import numpy as np
import torch
from gbnet.xgbmodule import XGBModule


class Wideboost(torch.nn.Module):
    def __init__(self, n, input_dim, intermediate_dim, output_dim, params=None):
        super(Wideboost, self).__init__()
        # Boosted trees map input_dim features to an intermediate_dim-wide representation
        self.xgb = XGBModule(n, input_dim, intermediate_dim, params=params or {})
        self.linear = torch.nn.Linear(intermediate_dim, output_dim)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, X):
        return self.sigmoid(self.linear(self.xgb(X)))

    def gb_step(self):
        # Add the next boosting round using gradients from the latest backward pass
        self.xgb.gb_step()


wb = Wideboost(10, 100, 20, 5)
wb(np.random.random([10, 100]))
```
> It sounds like you may have this working already but just in case, a basic version of the network you are interested in is here:
I did indeed get it working. The pro-tip is to omit the sigmoid activation at the end and use BCEWithLogitsLoss. (Log-sum-exp trick for numerical stability.)
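For anyone curious why this matters, here is a NumPy comparison (a sketch, not PyTorch's implementation). Applying the sigmoid first and then taking logs breaks down for large logits, while writing the loss as log(1 + exp(z)) - y*z and computing log(1 + exp(z)) with `np.logaddexp(0, z)` (the log-sum-exp trick) stays finite. `bce_naive` and `bce_with_logits` are illustrative names.

```python
import numpy as np


def bce_naive(z, y):
    # Sigmoid first, then log: for large |z| the sigmoid rounds to exactly 0 or 1
    p = 1.0 / (1.0 + np.exp(-z))
    return -(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))


def bce_with_logits(z, y):
    # Same loss written on logits: log(1 + exp(z)) - y * z, via log-sum-exp
    return np.logaddexp(0.0, z) - y * z


print(bce_naive(0.5, 1.0), bce_with_logits(0.5, 1.0))  # agree for moderate logits
with np.errstate(divide="ignore"):
    print(bce_naive(40.0, 0.0))    # inf, because log(1 - 1.0) = log(0)
print(bce_with_logits(40.0, 0.0))  # about 40.0, the correct loss
```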
> Just trying to read the code there, it looks like potentially you input the batch itself? So you would control the batch size I think?
That's what it seems but I haven't yet confirmed.
I was playing around with a Word2Vec example to show off categorical inputs, and it seems to me that one reasonable solution is to do the subsetting within the forward function.
Posting the Module for posterity:
```python
import torch
import xgboost as xgb
from gbnet.xgbmodule import XGBModule


class W2V(torch.nn.Module):
    def __init__(self, size, vocab_size, dim, params):
        super(W2V, self).__init__()
        self.size = size              # number of rows in df
        self.vocab_size = vocab_size  # number of rows in emb
        self.dim = dim                # embedding dimension
        # emb_left: one dim-sized vector per row of emb (1 input column)
        # emb_right: one dim-sized vector per row of df (4 input columns)
        self.emb_left = XGBModule(vocab_size, 1, dim, params=params)
        self.emb_right = XGBModule(size, 4, dim, params=params)

    def forward(self, df, emb, i):
        # GBM predictions for ALL rows, plus uniform noise in [-1, 1) that shrinks with iteration i
        left = self.emb_left(xgb.DMatrix(emb[[2]], enable_categorical=True))
        left = left + (torch.rand([self.vocab_size, self.dim]) * 2 - 1) / (1 + i)
        right = self.emb_right(xgb.DMatrix(df[[0, 1, 3, 4]], enable_categorical=True))
        right = right + (torch.rand([self.size, self.dim]) * 2 - 1) / (1 + i)

        # df['label'] gives, for each row of df, the row of `left` it is paired with
        row_indices = torch.tensor(df['label'].to_list(), dtype=torch.long)
        real_logits = (left[row_indices, :] * right).sum(1, keepdim=True)

        batch_size = row_indices.shape[0]
        fake_logits = []
        for _ in range(10):
            # here I'm sampling randomly, but no reason why a mini-batch couldn't contain indices that facilitate sampling
            random_indices = torch.randint(
                low=0,
                high=self.vocab_size,
                size=(batch_size,),
                device=left.device,
            )
            fake_logits.append((left[random_indices, :] * right).sum(1, keepdim=True))

        # Column 0 holds the real pairs; columns 1-10 hold randomly sampled pairs
        logits = torch.cat([real_logits] + fake_logits, dim=1)
        return logits, left, right

    def gb_step(self):
        self.emb_left.gb_step()
        self.emb_right.gb_step()
```
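Why subsetting inside `forward` works, sketched in NumPy: indexing the full prediction matrix (as `left[row_indices, :]` does above) and backpropagating sends gradient only to the selected rows. Every other row gets zero, and repeated indices accumulate (PyTorch's backward for indexing also sums over duplicates). For a batch without repeats this matches the zero-weights-outside-the-batch idea from earlier in the thread. `scatter_batch_grad` is a hypothetical helper, and squared error is just an example loss.

```python
import numpy as np


def scatter_batch_grad(n_rows, batch_idx, grad_batch):
    """Map d(loss)/d(preds[batch_idx]) back onto the full (n_rows, dim) prediction matrix."""
    full_grad = np.zeros((n_rows,) + grad_batch.shape[1:])
    np.add.at(full_grad, batch_idx, grad_batch)  # duplicates are summed, not overwritten
    return full_grad


rng = np.random.default_rng(0)
n_rows, dim = 8, 3
preds = rng.normal(size=(n_rows, dim))    # GBM outputs for ALL rows (the cached part)
targets = rng.normal(size=(n_rows, dim))
batch_idx = np.array([1, 4, 6])           # rows the mini-batch points at

# Squared-error loss on the batch only: 0.5 * sum((preds[b] - targets[b]) ** 2)
grad_batch = preds[batch_idx] - targets[batch_idx]
full_grad = scatter_batch_grad(n_rows, batch_idx, grad_batch)

# Same result as weighting the full-data loss with 1 inside the batch, 0 outside
w = np.zeros((n_rows, 1))
w[batch_idx] = 1.0
assert np.allclose(full_grad, w * (preds - targets))
```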
The more I've thought about it, I think this is the way to go:
- Each iteration generates and updates all GBM predictions.
- The data being fed in specifies the row indices accessed by the mini-batch.
- PyTorch indexing pulls out those specific rows.
To do this, I think the only gap (other than the PyTorch subsetting logic) might be a data loader that explicitly keeps track of indices and, for very big datasets, periodically swaps in large chunks of data (not every training round).
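A rough sketch of such a loader, with hypothetical names (`ChunkedIndexLoader`, `rounds_per_chunk`) that do not exist in gbnet. It keeps one chunk of row ids in memory for several rounds and yields mini-batch positions local to that chunk; the chunk is what the GBM modules would predict on every round. The `new_chunk` flag marks a swap, the one point where the GBM has no cached predictions for those rows and would need to score them through all existing trees.

```python
import itertools

import numpy as np


class ChunkedIndexLoader:
    """Yield (new_chunk, chunk_ids, batch_ids) forever.

    chunk_ids: global row ids currently held in memory (size chunk_size).
    batch_ids: positions within the chunk that form this round's mini-batch,
               so chunk_ids[batch_ids] are the global ids of the batch rows.
    """

    def __init__(self, n_rows, chunk_size, batch_size, rounds_per_chunk, seed=0):
        self.n_rows = n_rows
        self.chunk_size = chunk_size
        self.batch_size = batch_size
        self.rounds_per_chunk = rounds_per_chunk
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        for round_ in itertools.count():
            new_chunk = round_ % self.rounds_per_chunk == 0
            if new_chunk:
                # Swap in a fresh chunk only every `rounds_per_chunk` rounds
                chunk_ids = self.rng.choice(self.n_rows, size=self.chunk_size, replace=False)
            batch_ids = self.rng.choice(self.chunk_size, size=self.batch_size, replace=False)
            yield new_chunk, chunk_ids, batch_ids


loader = ChunkedIndexLoader(n_rows=1000, chunk_size=100, batch_size=16, rounds_per_chunk=3)
for new_chunk, chunk_ids, batch_ids in itertools.islice(loader, 6):
    print(new_chunk, chunk_ids[batch_ids][:4])
```

Choosing `rounds_per_chunk` trades memory for how often the O(number of trees) re-scoring happens.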