How to reduce the memory usage of the DKM algorithm?
I am trying to use DKM to palettize the weights of my Torch model during training. However, on a somewhat large model (~500M parameters), this always leads to an OOM error. Is there a way to use DKM (or any other kind of training-time palettization) with larger models?
DKM does consume a lot of memory, since it needs to store the full soft k-means scores for all weight tensors. For a large model, either use the post-training palettization APIs, or, if you stay with DKM, stick to a lower number of bits (<= 4).
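A rough back-of-envelope calculation shows why those scores dominate memory: DKM keeps one soft score per weight-centroid pair, so the score matrix alone scales as parameters × 2^n_bits. The sketch below is a hypothetical estimate (fp32 scores assumed, gradients and optimizer state ignored), not a measurement of the actual implementation:

```python
def dkm_score_memory_gb(n_params, n_bits, bytes_per_score=4):
    """Rough size of DKM's soft k-means score matrix: one score per
    weight-centroid pair. Ignores gradients and optimizer state."""
    n_clusters = 2 ** n_bits
    return n_params * n_clusters * bytes_per_score / 1e9

# ~500M-parameter model at different bit widths
for n_bits in (2, 4, 6):
    print(n_bits, "bits ->", dkm_score_memory_gb(500e6, n_bits), "GB")
# 4 bits (16 clusters) already implies ~32 GB for the scores alone,
# which is consistent with OOM even at n_bits=4 on a 500M model.
```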
I will try that, but I believe the OOM error occurred even with n_bits=4.
I also tried TF MOT (the TensorFlow Model Optimization Toolkit), and its clustering worked. Are there plans to support a more memory-efficient training-time clustering?
Post-training clustering is more memory-efficient than training-time clustering. If you can file a bug report, attaching a script that reproduces the OOM along with details about the configuration of the device used, that would help drive the investigation and subsequent performance improvements.
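The memory difference comes from the assignments: post-training clustering only needs one hard index per weight (n_bits each) plus a tiny codebook, instead of a full per-weight-per-cluster score matrix. Below is a minimal pure-Python sketch of hard 1-D k-means palettization to illustrate the idea; `palettize` is a hypothetical helper, not the coremltools API:

```python
def palettize(weights, n_bits, iters=20):
    """Tiny 1-D k-means (Lloyd's algorithm): returns a 2**n_bits-entry
    codebook and one hard index per weight. No per-weight-per-cluster
    soft scores are kept, unlike DKM."""
    k = 2 ** n_bits
    lo, hi = min(weights), max(weights)
    # Initialize centroids evenly across the weight range
    centroids = [lo + (hi - lo) * i / (k - 1) for i in range(k)]
    assign = [0] * len(weights)
    for _ in range(iters):
        # Hard assignment: nearest centroid for each weight
        assign = [min(range(k), key=lambda j: (w - centroids[j]) ** 2)
                  for w in weights]
        # Update step: move each centroid to the mean of its members
        for j in range(k):
            members = [w for w, a in zip(weights, assign) if a == j]
            if members:
                centroids[j] = sum(members) / len(members)
    return centroids, assign

codebook, idx = palettize([0.1, 0.12, -0.5, -0.48, 0.9, 0.88], n_bits=2)
recon = [codebook[i] for i in idx]  # dequantized weights
```

Each weight is then stored as a 2-bit index into the 4-entry codebook, which is why the post-training path fits where DKM runs out of memory.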