
Implement cross-batch memory for losses

monatis opened this issue

  • Paper: https://arxiv.org/pdf/1912.06798.pdf
  • Reference for implementation: https://github.com/msight-tech/research-xbm/

How it works

  • XBM relies on the observation that embeddings drift slowly during training, i.e., the embedding of the same object changes at a very slow pace.
  • This lets us store embeddings and their targets in a ring buffer of a fixed size.
  • After a certain number of iterations, start using the buffer: the final loss becomes the weighted sum of the regular mini-batch loss and the loss computed against the buffered embeddings (see the sketch after this list).
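A minimal sketch of the mechanism, assuming a plain PyTorch ring buffer; the names RingBuffer and total_loss are illustrative here, not existing Quaterion classes:

```python
import torch


class RingBuffer:
    """Fixed-size FIFO store for embeddings and their targets (illustrative)."""

    def __init__(self, buffer_size: int, embedding_dim: int):
        self.embeddings = torch.zeros(buffer_size, embedding_dim)
        self.targets = torch.zeros(buffer_size, dtype=torch.long)
        self.buffer_size = buffer_size
        self.ptr = 0     # next slot to overwrite
        self.filled = 0  # number of slots holding valid data

    def push(self, embeddings: torch.Tensor, targets: torch.Tensor) -> None:
        # Detach so buffered embeddings never receive gradients.
        for emb, tgt in zip(embeddings.detach(), targets):
            self.embeddings[self.ptr] = emb
            self.targets[self.ptr] = tgt
            self.ptr = (self.ptr + 1) % self.buffer_size
            self.filled = min(self.filled + 1, self.buffer_size)

    def get(self) -> tuple[torch.Tensor, torch.Tensor]:
        return self.embeddings[: self.filled], self.targets[: self.filled]


def total_loss(batch_loss: torch.Tensor, xbm_loss: torch.Tensor, xbm_weight: float) -> torch.Tensor:
    # Final loss = mini-batch loss + weighted loss against the memory bank.
    return batch_loss + xbm_weight * xbm_loss
```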

Suggested implementation

  1. Introduce an XBMConfig class to hold the configuration values such as buffer_size, start_iteration, xbm_weight.
  2. Add a configure_xbm() hook to TrainableModel that returns None by default.
  3. If it returns an XBMConfig instance instead, create an XBMBuffer instance in the TrainableModel constructor.
  4. Implement the XBM logic in _common_step when the stage is training (see the sketch after this list).
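For illustration, a rough sketch of how these pieces could fit together, reusing the RingBuffer from the sketch above in place of the proposed XBMBuffer. The TrainableModel class here is a simplified stand-in, and encode, loss_fn, global_step, and the four-argument loss call are assumptions made for the example, not Quaterion's actual API; the real change would live in the existing TrainableModel._common_step.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class XBMConfig:
    buffer_size: int = 10_000     # embeddings kept in the ring buffer
    start_iteration: int = 1_000  # iterations to wait before using the buffer
    xbm_weight: float = 1.0       # weight of the buffer loss in the final loss


class TrainableModel:  # simplified stand-in for quaterion.TrainableModel
    def __init__(self) -> None:
        self.global_step = 0
        self._xbm_config = self.configure_xbm()  # None means XBM is disabled
        self._xbm_buffer: Optional[RingBuffer] = None

    def configure_xbm(self) -> Optional[XBMConfig]:
        # Default: no XBM. Users override this to return an XBMConfig.
        return None

    def _common_step(self, batch, stage: str) -> torch.Tensor:
        embeddings, targets = self.encode(batch)  # hypothetical forward pass
        loss = self.loss_fn(embeddings, targets)  # regular mini-batch loss

        if stage == "train" and self._xbm_config is not None:
            cfg = self._xbm_config
            if self._xbm_buffer is None:
                # Lazily size the buffer once the embedding dim is known.
                self._xbm_buffer = RingBuffer(cfg.buffer_size, embeddings.shape[1])
            self._xbm_buffer.push(embeddings, targets)
            if self.global_step >= cfg.start_iteration:
                buf_embeddings, buf_targets = self._xbm_buffer.get()
                # Loss of the current batch against the memorized embeddings;
                # the exact signature depends on the loss implementation.
                xbm_loss = self.loss_fn(embeddings, targets, buf_embeddings, buf_targets)
                loss = loss + cfg.xbm_weight * xbm_loss
        return loss
```

With this layout, enabling XBM would be a matter of overriding configure_xbm() in a user's model subclass to return an XBMConfig instance.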

Notes

  1. We cannot re-use the existing Accumulator classes because they are not ring buffers.
  2. I don't think we need a mixin, because the addition to TrainableModel will be only a few lines of code, and we need to update _common_step anyway.

monatis · Jul 20 '22 02:07

Suggested implementation is in the issue. WDYT? @generall and @joein

monatis · Aug 02 '22 05:08

Completed in #175

monatis · Aug 30 '22 09:08