efficient_densenet_pytorch

Is this really memory efficient?

Open · leonardishere opened this issue 4 years ago · 1 comment

I see the memory consumption chart in the readme, but after looking at the code, I have doubts that this implementation is fully memory efficient. I see the call to cp.checkpoint in _DenseLayer.forward(), but I don't see some of the other modifications that were called for in the paper, specifically post-activation normalization and contiguous concatenation. Am I missing something?

If I understand your approach, you are using a method that still requires quadratic memory and computation, but tossing the memory-hogging intermediate values and recomputing them later?

leonardishere · Mar 31 '20 15:03

> but I don't see some of the other modifications that were called for in the paper, specifically post-activation normalization and contiguous concatenation. Am I missing something?

This implementation and the default DenseNet implementation both use pre-activation normalization and contiguous concatenation. Error increases without pre-activation normalization, and training time suffers significantly without contiguous concatenation. The purpose of the technical report was to design a memory-efficient implementation under the constraint that we keep pre-activation normalization and contiguous concatenation (as in the original DenseNet implementation).
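The sketch below illustrates how checkpointing can coexist with those two constraints: the pre-activation bottleneck (contiguous concatenation, then BN, ReLU, 1x1 conv) runs inside `torch.utils.checkpoint.checkpoint`, so its intermediates are discarded after the forward pass and recomputed during backward. It is a hedged illustration, not the repository's `_DenseLayer`: the class `CheckpointedDenseLayer`, its `_bottleneck` helper, and the `growth_rate`/`bn_size` defaults are hypothetical, and it assumes a recent PyTorch release that accepts the `use_reentrant` argument.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp


class CheckpointedDenseLayer(nn.Module):
    """Hypothetical dense layer (illustrative names, not the repo's code)."""

    def __init__(self, num_input_features, growth_rate=12, bn_size=4):
        super().__init__()
        inter = bn_size * growth_rate
        # Pre-activation bottleneck: BN -> ReLU -> 1x1 conv on the concatenated input
        self.norm1 = nn.BatchNorm2d(num_input_features)
        self.conv1 = nn.Conv2d(num_input_features, inter, kernel_size=1, bias=False)
        # Pre-activation 3x3 conv that produces growth_rate new feature maps
        self.norm2 = nn.BatchNorm2d(inter)
        self.conv2 = nn.Conv2d(inter, growth_rate, kernel_size=3, padding=1, bias=False)

    def _bottleneck(self, *prev_features):
        # Contiguous concatenation + BN + ReLU + 1x1 conv. The concatenated
        # tensor and the normalized copy are the memory-hogging intermediates.
        x = torch.cat(prev_features, dim=1)
        return self.conv1(F.relu(self.norm1(x)))

    def forward(self, prev_features):
        if self.training and any(t.requires_grad for t in prev_features):
            # Do not keep the bottleneck intermediates; recompute them in backward
            out = cp.checkpoint(self._bottleneck, *prev_features, use_reentrant=False)
        else:
            out = self._bottleneck(*prev_features)
        return self.conv2(F.relu(self.norm2(out)))


# Usage: a dense block grows its list of feature maps one layer at a time
features = [torch.randn(2, 16, 8, 8, requires_grad=True)]
for i in range(3):
    layer = CheckpointedDenseLayer(16 + 12 * i)
    features.append(layer(features))
torch.cat(features, dim=1).sum().backward()
```

One caveat of this simple version: because the checkpointed segment is re-run in training mode during backward, BatchNorm's running statistics get updated a second time for that segment.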

> If I understand your approach, you are using a method that still requires quadratic memory and computation, but tossing the memory-hogging intermediate values and recomputing them later?

We are tossing the memory-hogging intermediate values (as described in the technical report), but doing so makes the memory consumption linear. Figure 3 in the technical report explains this. Storing the intermediate activations (each layer's concatenated input and its normalized copy) causes the quadratic memory consumption, whereas the total number of new features produced by the layers grows only linearly with depth.
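A back-of-the-envelope count can make the quadratic-versus-linear argument concrete. This is a simplified model and an assumption on my part, not the report's exact accounting: layer `l` of a block sees `k0 + l*k` input channels, a naive implementation keeps both the concatenated input and its normalized copy alive for every layer, and a checkpointed one keeps only the block input plus the `k` new channels per layer. Spatial size, bottleneck outputs, and parameters are ignored, and `k0=24`, `k=12` are illustrative values.

```python
def stored_channels(num_layers, k0, k):
    """Channels kept alive for backward in one dense block (simplified model).

    Returns (naive, checkpointed). The naive count also stores, for every
    layer, the concatenated input and its normalized copy; the checkpointed
    count stores only the block input and each layer's k new channels.
    """
    naive = 0
    for l in range(num_layers):
        inputs = k0 + l * k  # channels in layer l's concatenated input
        naive += 2 * inputs  # concatenation + normalized copy, stored per layer
    outputs = k0 + num_layers * k  # block input + all new features: needed either way
    naive += outputs
    checkpointed = outputs  # intermediates are recomputed during backward instead
    return naive, checkpointed


for depth in (12, 24, 48):
    naive, efficient = stored_channels(depth, k0=24, k=12)
    print(f"{depth:3d} layers: naive {naive:6d} channels, checkpointed {efficient:4d} channels")
```

Under this model, doubling the depth roughly quadruples the naive count while the checkpointed count roughly doubles, which is the gap the reply describes.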

gpleiss · Apr 07 '20 11:04