efficient_densenet_pytorch
How can I apply this to my own model?
Thank you for your nice work. My model uses DenseNet-style connections like this:
tensorFeat = torch.cat([self.moduleOne(tensorFeat), tensorFeat], 1)
tensorFeat = torch.cat([self.moduleTwo(tensorFeat), tensorFeat], 1)
tensorFeat = torch.cat([self.moduleThr(tensorFeat), tensorFeat], 1)
tensorFeat = torch.cat([self.moduleFou(tensorFeat), tensorFeat], 1)
tensorFeat = torch.cat([self.moduleFiv(tensorFeat), tensorFeat], 1)
What do I need to do to apply the memory-efficient technique to this part of my model? The DenseNet connections are just one part of my full model.
It really depends on the other aspects of your model. This implementation uses torch's gradient checkpointing feature (https://github.com/gpleiss/efficient_densenet_pytorch/blob/master/models/densenet.py#L38), which trades extra computation time for lower memory use: intermediate activations are not stored during the forward pass and are recomputed during the backward pass.
See these docs for more information.
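As a rough sketch of how checkpointing could be applied to the concatenation pattern in the question: each `torch.cat([module(x), x], 1)` step is wrapped in `torch.utils.checkpoint.checkpoint`, so its intermediate activations are recomputed in the backward pass instead of being kept in memory. The names `dense_step` and `DenseBlock`, the `in_channels` and `growth` parameters, and the conv layers standing in for `moduleOne` ... `moduleFiv` are all hypothetical, and `use_reentrant=False` assumes a reasonably recent PyTorch release. This is not necessarily how the linked repository structures its checkpointed function.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def dense_step(module, x, use_checkpoint=True):
    """Compute cat([module(x), x], 1), optionally under gradient checkpointing."""

    def run(inp):
        return torch.cat([module(inp), inp], 1)

    # Checkpointing only matters when gradients are being recorded.
    if use_checkpoint and torch.is_grad_enabled():
        # Assumption: a PyTorch version that accepts use_reentrant.
        return checkpoint(run, x, use_reentrant=False)
    return run(x)


class DenseBlock(nn.Module):
    """Hypothetical stand-in for the five densely connected modules."""

    def __init__(self, in_channels=8, growth=4, use_checkpoint=True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        channels = in_channels
        for name in ["moduleOne", "moduleTwo", "moduleThr", "moduleFou", "moduleFiv"]:
            # Placeholder layers; substitute the real sub-modules here.
            layer = nn.Sequential(
                nn.Conv2d(channels, growth, kernel_size=3, padding=1),
                nn.ReLU(),
            )
            setattr(self, name, layer)
            channels += growth  # each step appends `growth` feature maps
        self.out_channels = channels

    def forward(self, tensorFeat):
        for module in (self.moduleOne, self.moduleTwo, self.moduleThr,
                       self.moduleFou, self.moduleFiv):
            tensorFeat = dense_step(module, tensorFeat, self.use_checkpoint)
        return tensorFeat
```

The outputs and gradients should match the non-checkpointed version; only peak memory and run time change. How much memory is saved depends on what each sub-module contains, so it is worth measuring on the real model.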