
How to reduce the memory usage of DKM algorithm?

Open CaptainDario opened this issue 7 months ago • 4 comments

I am trying to use DKM to palettize the weights of my Torch model during training. However, on a somewhat large model (~500M parameters), this always leads to OOM. Therefore, I would like to ask: is there a way to use DKM (or any kind of training-time palettization) with larger models?

CaptainDario avatar Jun 17 '25 16:06 CaptainDario

DKM does consume a lot of memory, as it needs to store the full soft k-means scores for all weight tensors. For a large model, either use the post-training palettization APIs, or, if you do use DKM, stick to a smaller number of bits (<=4).
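A rough back-of-envelope (my own estimate, not from the coremltools docs) shows why those soft scores dominate memory: DKM keeps one score per (weight, cluster) pair, so with `2**n_bits` clusters the score tensor is many times larger than the weights themselves.

```python
def dkm_score_memory_gb(num_params, n_bits, bytes_per_score=4):
    """Approximate memory for DKM's soft k-means scores, assuming
    float32 scores and 2**n_bits clusters per weight element.
    The scaling, not the exact constant, is the point."""
    clusters = 2 ** n_bits
    return num_params * clusters * bytes_per_score / 1e9

print(dkm_score_memory_gb(500e6, 4))  # ~32 GB of scores for a 500M-param model at 4 bits
```

Even at 2 bits the scores for a 500M-parameter model would be on the order of 8 GB, on top of weights, gradients, and optimizer state.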

aseemw avatar Jun 17 '25 16:06 aseemw

I will try that, but I think the OOM error already happened with n_bits=4.

CaptainDario avatar Jun 17 '25 16:06 CaptainDario

I also tried TF MOT (TensorFlow Model Optimization Toolkit), and their clustering worked. Are there plans to support a more memory-efficient training-time clustering?

CaptainDario avatar Jun 17 '25 16:06 CaptainDario

Post-training clustering is more memory-efficient than training-time clustering. If you can file a bug report, attaching your script to reproduce the OOM and details about the configuration of the device used, that would be helpful in driving the investigation and subsequent perf improvements.
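For contrast with DKM, post-training palettization only needs hard assignments: one `n_bits` index per weight plus a tiny lookup table, instead of a dense float score per (weight, cluster) pair. A minimal NumPy sketch of hard 1-D k-means palettization (an illustration of the idea, not the coremltools implementation):

```python
import numpy as np

def palettize_hard(w, n_bits, iters=20):
    """Hard k-means palettization of a weight tensor.

    Returns a (2**n_bits,) lookup table of centroids and one integer
    index per weight -- the only state that must be stored, unlike
    DKM's dense (num_weights, 2**n_bits) soft-score matrix.
    """
    k = 2 ** n_bits
    flat = w.ravel().astype(np.float64)
    # Initialize centroids evenly across the weight range.
    centroids = np.linspace(flat.min(), flat.max(), k)
    for _ in range(iters):
        # Hard assignment: nearest centroid for each weight.
        idx = np.abs(flat[:, None] - centroids[None, :]).argmin(axis=1)
        # Update step: move each centroid to the mean of its members.
        for j in range(k):
            members = flat[idx == j]
            if members.size:
                centroids[j] = members.mean()
    return centroids, idx.reshape(w.shape).astype(np.uint8)

rng = np.random.default_rng(0)
w = rng.normal(size=(64, 64)).astype(np.float32)
lut, idx = palettize_hard(w, n_bits=4)
w_hat = lut[idx]  # reconstructed (palettized) weights
print(lut.shape, idx.dtype)  # (16,) uint8
```

Note that even this sketch materializes an (N, k) distance matrix inside `argmin`; a production implementation would process weights in chunks, but the persistent state is still just the indices and the lookup table.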

aseemw avatar Jun 17 '25 19:06 aseemw