Changed nnunet LR Schedule and fixed memory leak
PR Type
[Feature | Fix | Documentation | Other ]
Short Description
- Updated the nnunet client LR scheduler
  - Now updates every step
  - Takes the total number of rounds into account when determining the max step (see the sketch below)
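A minimal sketch of the idea, not the client's actual implementation (the round and step counts are made up; the real client derives them from the server config and the nnunet trainer): the scheduler's max step is computed from the total number of server rounds, and `step()` is called once per optimizer step rather than once per epoch.

```python
import torch

# Hypothetical numbers for illustration only.
num_server_rounds = 50
steps_per_round = 250
max_steps = num_server_rounds * steps_per_round

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

# nnU-Net style poly decay: lr = initial_lr * (1 - step / max_steps) ** 0.9,
# stepped once per optimizer step instead of once per epoch.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: (1 - step / max_steps) ** 0.9
)

for _ in range(max_steps):
    # ... forward pass, loss.backward() ...
    optimizer.step()
    scheduler.step()  # advance the LR every step
```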
- Fixed an issue where RAM usage would steadily increase throughout training and cause OOMs
  - Not sure if this issue was specific to the nnunet client or affects all clients
  - RAM usage would steadily increase after each round, eventually leading to an OOM
  - Resolved the issue by running garbage collection at the beginning of each round, implemented only for the nnunet client via a hook (see the sketch below)
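A minimal sketch of the kind of hook described above; the hook name and how it gets registered are assumptions on my part, not the repository's actual API:

```python
import gc


class NnunetClientSketch:
    """Illustrative only: run a full garbage collection at the start of each round."""

    def on_round_start(self, current_round: int) -> None:
        # Force a full collection (all generations) before training resumes so
        # memory held over from the previous round is actually released.
        unreachable = gc.collect()
        print(f"Round {current_round}: collected {unreachable} unreachable objects")
```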
- Added gradient clipping to the nnunet client
  - By default, nnunet 2.5.1 does some gradient clipping during training
  - Added this to the nnunet client so that we stay consistent with it (see the sketch below)
  - Adding this functionality required another hook function in the basic client
  - This repository definitely needs a cleanup/refactor, particularly the basic client
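If I remember correctly, nnunet v2 clips the global gradient norm (to 12) in its train step. A hedged sketch of the equivalent hook; the function name, default value, and call site are assumptions:

```python
import torch


def clip_gradients(model: torch.nn.Module, max_norm: float = 12.0) -> None:
    # Called between loss.backward() and optimizer.step(), mirroring the
    # gradient norm clipping nnunet applies in its own training loop.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
```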
- Changed how plans files are named
  - Previously, if the client was asked to generate a plans file, it was given the nnunet default name, nnUNetPlans
  - The final plans file had a prefix specifying it was for FL; however, if another client in a previous experiment had been asked to generate the default plans, those plans would have had the same name and been overwritten
  - Source plans are now named after the dataset from which they were generated (see the sketch below)
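Roughly the naming change; the helper name and example dataset name below are purely illustrative:

```python
# Hypothetical helper; the actual naming logic in the client may differ.
def source_plans_name(dataset_name: str) -> str:
    # e.g. "Dataset005_Prostate" -> "Dataset005_ProstatePlans" instead of the
    # nnunet default "nnUNetPlans", so source plans generated from different
    # datasets no longer collide and overwrite each other.
    return f"{dataset_name}Plans"
```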
I should note that I don't think I've fixed a memory leak so much as addressed an issue where memory was not being cleaned up frequently enough, causing OOMs. Also worth noting that gen 2 objects are supposed to be cleaned up automatically after every 10 collections of gen 1 (https://docs.python.org/3/library/gc.html). For some reason this is not occurring: Python refuses to clean up the oldest-generation objects after the first round (I presume this could be narrowed down further to the first pass) unless the collection is triggered manually.
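For reference, the generational behaviour described above can be inspected directly with the standard library; this is just stock `gc` usage, not code from this PR:

```python
import gc

# Thresholds: gen 0 is collected after roughly `threshold0` net allocations,
# gen 1 after `threshold1` gen 0 collections, gen 2 after `threshold2` gen 1
# collections. Commonly (700, 10, 10).
print(gc.get_threshold())
print(gc.get_count())  # current counts for each generation

# Despite those thresholds, the oldest generation was not being reclaimed
# during training, so the hook calls gc.collect(), which collects every
# generation.
gc.collect()
```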
This issue was helpful in fixing the OOMs: https://github.com/pytorch/pytorch/issues/95462
Also, the nnunet client would sometimes randomly just not work. It might have been OOM issues, but it also might have been something to do with torch.compile. The nnunet model is compiled with torch.compile, which for some reason spawns around 30 processes. If one of these processes was terminated or killed (perhaps due to an OOM?), an error was automatically raised because of the change flwr made that causes all process terminations or interruptions to raise an exception (https://github.com/adap/flower/issues/3837).

I was already getting around this for the dataloaders, which also spawn their own workers, by temporarily changing the signal handlers back to the Python defaults before creating the dataloaders and spawning new processes, and then changing them back. I've now moved the entire setup within the scope of the default signal handlers, so if one of the pytorch dynamo processes is terminated it hopefully won't generate an exception, we won't get random failures, and the client and smoke tests will be more consistent.
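A rough sketch of how I read the workaround (the exact placement in the client and the handlers swapped are assumptions): temporarily restore Python's default signal handlers while the dataloaders and other processes are created, then put the previous handlers back.

```python
import signal
from contextlib import contextmanager


@contextmanager
def default_signal_handlers():
    # Save whatever handlers are currently installed (e.g. flwr's), swap in the
    # Python defaults while the setup runs, then restore the saved handlers.
    previous = {sig: signal.getsignal(sig) for sig in (signal.SIGINT, signal.SIGTERM)}
    signal.signal(signal.SIGINT, signal.default_int_handler)
    signal.signal(signal.SIGTERM, signal.SIG_DFL)
    try:
        yield
    finally:
        for sig, handler in previous.items():
            if handler is not None:
                signal.signal(sig, handler)


# Usage sketch: run the whole nnunet setup (dataloaders, torch.compile, ...)
# inside this scope, as described above.
# with default_signal_handlers():
#     client.setup_client(config)  # hypothetical call
```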
All in all, things look good! A few small comments, but mostly clarifying questions for my understanding. Perhaps we can meet tomorrow after the meeting with Masoom to discuss some of the issues that popped up, so you can get me on the same page and we can figure out how to look into them going forward. This won't delay getting the PR merged in, but perhaps we can put some tickets on the backlog for further investigation.
Pretty much good to go! I just want to get your take on the one comment I left before getting this merged in.
I also tried to detail some of the discussions today in a clickup ticket to monitor and follow up on. If there is anything I didn't capture that you think should be added, please feel free to add it or let me know!
https://app.clickup.com/t/8689e8urk
I made a few edits: I'm freezing the objects left in memory after the first round, not the second. The reason I check for current_round == 2 is that the hook function runs before training.
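Concretely, something along these lines (the hook name is illustrative, not the actual code): since the hook runs before training, checking current_round == 2 means the objects being frozen are the ones left over from round 1.

```python
import gc


def on_round_start(current_round: int) -> None:
    # Illustrative hook body; it runs before the round's training begins.
    gc.collect()
    if current_round == 2:
        # Everything still alive at this point was allocated during round 1
        # (model, dataloaders, compiled graphs, ...). Freeze it so the
        # collector stops rescanning these long-lived objects later on.
        gc.freeze()
```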