efficient_densenet_pytorch
Is this really memory efficient?
I see the memory-consumption chart in the README, but after looking at the code I have doubts that this implementation is fully memory efficient. I see the call to cp.checkpoint in _DenseLayer.forward(), but I don't see some of the other modifications that were called for in the paper, specifically post-activation normalization and contiguous concatenation. Am I missing something?
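For context, the checkpointing pattern referred to above looks roughly like the sketch below (a generic illustration of the `cp.checkpoint` pattern, not this repo's exact code; the class name `_DenseLayerSketch` and its layer layout are assumptions):

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

class _DenseLayerSketch(nn.Module):
    """Sketch of a checkpointed dense layer (illustrative, not the repo's code)."""

    def __init__(self, num_input_features, growth_rate, bn_size):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(num_input_features)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate,
                               kernel_size=1, stride=1, bias=False)
        self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate,
                               kernel_size=3, stride=1, padding=1, bias=False)

    def bn_function(self, *prev_features):
        # Concatenate all preceding feature maps, then BN -> ReLU -> 1x1 conv.
        concatenated = torch.cat(prev_features, dim=1)
        return self.conv1(self.relu1(self.norm1(concatenated)))

    def forward(self, *prev_features):
        # cp.checkpoint discards the intermediate activations of bn_function
        # after the forward pass and recomputes them during the backward pass.
        bottleneck = cp.checkpoint(self.bn_function, *prev_features)
        return self.conv2(self.relu2(self.norm2(bottleneck)))
```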
If I understand your approach, you are using a method that still requires quadratic memory and computation, but tossing the memory-hogging intermediate values and recomputing them later?
> but I don't see some of the other modifications that were called for in the paper, specifically post-activation normalization and contiguous concatenation. Am I missing something?
This implementation and the default DenseNet implementation both use pre-activation normalization and contiguous concatenation. Error increases without pre-activation normalization, and training time suffers significantly without contiguous concatenation. The purpose of the technical report was to design a memory-efficient implementation under the constraint that we keep pre-activation normalization and contiguous concatenation (as in the original DenseNet implementation).
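To make those two constraints concrete, here is a minimal sketch of a single dense-layer step (illustrative shapes and channel counts, not this repo's code), with pre-activation normalization applied to a contiguously concatenated input:

```python
import torch
import torch.nn as nn

# Four preceding feature maps of 16 channels each -> 64 input channels total.
prev_features = [torch.randn(4, 16, 32, 32) for _ in range(4)]

norm = nn.BatchNorm2d(64)
relu = nn.ReLU(inplace=True)
conv = nn.Conv2d(64, 12, kernel_size=1, bias=False)

# Contiguous concatenation: all preceding feature maps are copied into one
# contiguous tensor so the following ops operate on a single buffer.
concatenated = torch.cat(prev_features, dim=1)  # shape (4, 64, 32, 32)

# Pre-activation normalization: BN and ReLU are applied *before* the
# convolution, i.e. to the concatenated input rather than to the conv output.
out = conv(relu(norm(concatenated)))            # shape (4, 12, 32, 32)
```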
> If I understand your approach, you are using a method that still requires quadratic memory and computation, but tossing the memory-hogging intermediate values and recomputing them later?
We are tossing the memory-hogging intermediate values (as described in the technical report), and that is exactly what makes the memory consumption linear. Figure 3 in the technical report explains this: storing the intermediate activations is what causes the quadratic memory consumption, whereas the total number of output features is linear in depth.
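As a rough back-of-the-envelope illustration of that point (the block depth, growth rate, and input width below are made-up numbers, not taken from the report):

```python
# Per layer, the BN/ReLU intermediates are as large as that layer's
# concatenated input, while the layer's own output only adds `growth_rate`
# new feature maps.
num_layers = 100    # depth of one dense block (illustrative)
k0 = 24             # channels entering the block (illustrative)
growth_rate = 12    # new feature maps produced per layer (illustrative)

stored_intermediates = 0   # what a naive implementation keeps around
stored_features = k0       # what every implementation must keep: layer outputs

for i in range(num_layers):
    input_channels = k0 + i * growth_rate
    stored_intermediates += input_channels  # grows with i -> quadratic total
    stored_features += growth_rate          # constant per layer -> linear total

print(stored_intermediates)  # ~ k0*L + k*L*(L-1)/2, quadratic in depth L
print(stored_features)       # k0 + k*L, linear in depth L
```

Recomputing those intermediates during the backward pass means only the linearly growing feature maps need to stay resident.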