
Parallel Batch Training on SAINT

Open yanis-falaki opened this issue 1 year ago • 2 comments

I'm training the SAINT algorithm on a massive dataset; however, I'm restricted to batch sizes of 32-38 because only subsets of the data are highly correlated with each other (they share a collection timestamp as a feature) and would benefit from intersample attention.

I haven't found anything natively built into the library to allow this, but it seems likely that something exists or that someone has come up with a solution, since it feels like a natural extension of intersample attention.

Advice would be appreciated! Thanks.

yanis-falaki avatar Nov 30 '24 19:11 yanis-falaki

Hey @yanis-falaki Sorry for the delay in replying.

So, I have not implemented parallel training since, to be honest, I never had consistent access to multi-GPU machines to implement the code and then test it. But I might have now (or if you have a server I can ssh into, that would be great).

Nonetheless, one thing you can do is to define the model using the library and then proceed as you would with any other pytorch model, i.e.

import torch
from torch import nn

...
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(model)  # wrap the widedeep model so batches are split across GPUs
model.to(device)
...

you would have to implement your own training loop but I think that is easy (?)
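something like this would do (just a rough sketch; the loader, the loss and the exact forward signature of the model are placeholders here, so adapt them to your case):

import torch
import torch.nn.functional as F

# `model` and `device` as above; `train_loader` is assumed to be a regular
# DataLoader yielding (X, y) batches already formatted for the tabular model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
n_epochs = 10  # or however many epochs you need

model.train()
for epoch in range(n_epochs):
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(X)                        # forward pass
        loss = F.mse_loss(y_pred.squeeze(1), y)  # swap in the loss for your task
        loss.backward()
        optimizer.step()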

let me know if this helps

I will see if I can get my hands on a multi-GPU server and implement this functionality; it should be easy.

jrzaurin avatar Dec 03 '24 10:12 jrzaurin

Hey @jrzaurin thanks for the reply! Quick note: I'm not intending to do parallel training across multiple GPUs, but rather to process batches in parallel on a single GPU. Increasing the batch size isn't an option because only subsets of the data are correlated with each other (and this correlation will hold at inference time). So in effect I need to process "batches of batches". At the moment, with the number of "sub-batches" I have, which need to be loaded iteratively, there's an excessive amount of I/O, which could be reduced ~128x if done in parallel.

I only took a quick glance at the SAINT implementation, but my immediate intuition is that I could add a parameter that accepts a list of "sub-batch" sizes. That would allow me to submit one big tensor containing every sample, perform the operations that are independent across samples on the whole tensor, and then, when it comes to the intersample attention block, split the tensor into sub-batch tensors, apply the block to each sub-batch in a loop, and concatenate the results back together.

This isn't exactly parallel because I'll still be running a loop on the sub-batches in the intersample attention block, but it would decrease I/O operations a lot since the data will already be on GPU.
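Roughly what I have in mind is something like this (only a sketch; `intersample_attention` stands in for SAINT's intersample attention block and is not the library's actual function name):

import torch

def intersample_attention_by_subbatch(x, sub_batch_sizes, intersample_attention):
    """Apply intersample attention independently to each sub-batch.

    x: (total_batch, n_features, embed_dim) tensor containing all sub-batches
    sub_batch_sizes: list of ints summing to total_batch
    intersample_attention: module/function applying attention across samples
    """
    outputs = []
    for sub_x in torch.split(x, sub_batch_sizes, dim=0):
        # attention only mixes information across rows within this sub-batch
        outputs.append(intersample_attention(sub_x))
    return torch.cat(outputs, dim=0)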

I've also seen that torch.cuda.stream() may be able to deal with this without needing to modify the implementation.
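For example, something along these lines (again only a sketch; whether the per-sub-batch kernels actually overlap depends on their size and the GPU's free SMs, and everything must be synchronized before the results are used):

import torch

def intersample_attention_streamed(sub_batches, intersample_attention):
    """Launch intersample attention for each sub-batch on its own CUDA stream."""
    current = torch.cuda.current_stream()
    streams = [torch.cuda.Stream() for _ in sub_batches]
    outputs = [None] * len(sub_batches)
    for i, (stream, sub_x) in enumerate(zip(streams, sub_batches)):
        stream.wait_stream(current)  # make sure the inputs are ready on this stream
        with torch.cuda.stream(stream):
            outputs[i] = intersample_attention(sub_x)
    torch.cuda.synchronize()  # wait for all streams before concatenating
    return torch.cat(outputs, dim=0)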

I'll keep you updated.

yanis-falaki avatar Dec 03 '24 20:12 yanis-falaki