DeepLIIF
Updated Model Training
This PR includes a couple of major updates for model training:
- Implementation of Attention UNet compatible with the current framework: Attention UNet is now available as a network architecture (unet_512_attention).
- Validation loss and metrics calculation during training: with the flag --with-val, validation losses (the same types of loss as in training) and metrics (cell count metrics through the postprocess function) are calculated as training progresses. The corresponding support in the visdom visualizer is also implemented.
  - At the moment, cell count metrics are only calculated for DeepLIIF models.
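For readers new to Attention UNet, the core addition over a plain UNet is an attention gate on each skip connection. Below is a minimal, standard-library sketch of the usual additive attention gate, not the code behind unet_512_attention. It works per pixel on plain lists, assumes the gating signal g has already been resampled to the skip features' resolution, and omits batch norm and intermediate biases. The names attention_gate, w_x, w_g and psi are illustrative.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, v: Vector) -> Vector:
    """Apply a weight matrix to one pixel's channel vector (a 1x1 convolution without bias)."""
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in w]


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def attention_gate(
    x: List[Vector],  # skip-connection features: one channel vector per pixel
    g: List[Vector],  # gating features from the decoder, same pixels as x
    w_x: Matrix,      # F_int x C_x projection of x
    w_g: Matrix,      # F_int x C_g projection of g
    psi: Vector,      # length F_int, collapses to one attention logit per pixel
    psi_bias: float = 0.0,
) -> List[Vector]:
    """Return x scaled per pixel by alpha = sigmoid(psi . relu(W_x x + W_g g) + b)."""
    if len(x) != len(g):
        raise ValueError("x and g must cover the same pixels (resample g first)")
    gated = []
    for x_px, g_px in zip(x, g):
        # Additive attention: project both inputs to F_int channels, add, then ReLU
        hidden = [max(0.0, a + b) for a, b in zip(matvec(w_x, x_px), matvec(w_g, g_px))]
        # psi reduces the F_int channels to one logit; sigmoid maps it into (0, 1)
        alpha = sigmoid(sum(p * h for p, h in zip(psi, hidden)) + psi_bias)
        # Scale every channel of this skip-connection pixel by its attention coefficient
        gated.append([alpha * c for c in x_px])
    return gated
```

In a real network the per-pixel products become 1x1 convolutions over whole feature maps, and the gated skip features are then concatenated with the upsampled decoder features as in a regular UNet.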
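With --with-val, validation losses of the same types as the training losses are reported during training. One simple way to turn per-batch values into a single number per loss for plotting is to average them over the validation set. The sketch below assumes each validation batch yields a dict of named loss values; the helper average_val_losses and the loss names in the usage are hypothetical, not taken from the DeepLIIF code.

```python
from collections import defaultdict
from typing import Dict, Iterable


def average_val_losses(batch_losses: Iterable[Dict[str, float]]) -> Dict[str, float]:
    """Average each named loss over the validation batches that report it.

    Each element maps a loss name (the same names used for the training
    losses) to that batch's value.
    """
    totals: Dict[str, float] = defaultdict(float)
    counts: Dict[str, int] = defaultdict(int)
    for losses in batch_losses:
        for name, value in losses.items():
            totals[name] += value
            counts[name] += 1
    # Plain dict: one averaged value per loss name
    return {name: totals[name] / counts[name] for name in totals}
```

For example, averaging [{"G_GAN": 1.0, "G_L1": 0.5}, {"G_GAN": 3.0, "G_L1": 1.5}] gives {"G_GAN": 2.0, "G_L1": 1.0}, which could then be plotted next to the training curve for the same loss name.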
Others
- (cli.py) Allowed specifying the generator architecture for each individual generator, in order (accepts a comma-separated configuration).
- (cli.py) Debug mode is now available for model training. Use it by passing --debug to python cli.py train. Change the approximate number of steps/images to run per epoch in debug mode with --debug-data-size (defaults to --debug-data-size 10). This helps to quickly check whether training runs as expected.
- Allowed returning the generated segmentation output from each individual modality (accessible from infer_modalities()).
- Added test cases for training (--optimizer, --net-g, --net-gs, --with-val) and trainlaunch (GPU test cases only).
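The per-generator architecture option takes a comma-separated list assigned to the generators in order. Here is a minimal sketch of how such a value could be expanded; parse_generator_archs is a hypothetical helper, and reusing a single name for every generator is an assumption of this sketch, not documented behavior of cli.py. unet_512 is used purely as an illustrative second name.

```python
from typing import List


def parse_generator_archs(value: str, n_generators: int) -> List[str]:
    """Expand a comma-separated architecture string into one entry per generator.

    Assumption: a single name is reused for every generator; otherwise the
    number of names must equal the number of generators, assigned in order.
    """
    archs = [name.strip() for name in value.split(",") if name.strip()]
    if len(archs) == 1:
        return archs * n_generators
    if len(archs) != n_generators:
        raise ValueError(
            f"expected 1 or {n_generators} architectures, got {len(archs)}: {archs}"
        )
    return archs
```

For example, parse_generator_archs("unet_512,unet_512_attention", 2) returns ["unet_512", "unet_512_attention"]. Failing loudly on a length mismatch avoids silently giving a generator the wrong architecture.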
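Debug mode shortens each epoch to roughly --debug-data-size steps. A minimal sketch of that idea, assuming the training loop iterates over a data loader: itertools.islice stops the epoch early without touching the loader itself. epoch_batches is a hypothetical name; it counts steps (batches), so it matches an image count only when the batch size is 1. The PR describes the real number as approximate, so the actual implementation may count differently.

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def epoch_batches(loader: Iterable[T], debug: bool = False, debug_data_size: int = 10) -> Iterator[T]:
    """Yield the batches for one epoch, stopping after debug_data_size steps in debug mode."""
    if debug:
        # Only take the first debug_data_size batches so the epoch finishes quickly
        return islice(loader, debug_data_size)
    return iter(loader)
```

A training loop would then use for batch in epoch_batches(dataloader, debug=opt_debug) in place of iterating over the loader directly.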
Notes:
- Files needed for --with-val mode: i) validation images, in the same format as the training images; ii) ground-truth cell count metrics in JSON, which can be generated by running get_cell_count_metrics():
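from deepliif.stat import get_cell_count_metrics
dir_img = 'Datasets/Sample_Dataset/val'
# Pass the validation image directory (the original snippet passed the built-in dir by mistake)
get_cell_count_metrics(dir_img, model='DeepLIIF', tile_size=512)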
The code generates the metrics.json file for the validation data in the same directory as the images.
- To run multiple tests in parallel (e.g., running latest/ext/sdg at the same time), make sure to give each run a different temp directory via --basetemp, so that one pytest process does not delete or modify a temp folder created or used by another. For example:
pytest -v -s --basetemp=../tmp/latest --model_type latest 2>&1 | tee ../log/pytest_latest_20240808.log
pytest -v -s --basetemp=../tmp/ext --model_type ext 2>&1 | tee ../log/pytest_ext_20240808.log
pytest -v -s --basetemp=../tmp/sdg --model_type sdg 2>&1 | tee ../log/pytest_sdg_20240808.log
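If you prefer to start the parallel runs from one script instead of several terminals, here is a minimal sketch using only the standard library's subprocess module. It assumes the same relative ../tmp and ../log layout as the commands in the notes; pytest_command and run_in_parallel are hypothetical helpers, not part of the repository, and the default tag simply reuses the date suffix from those commands.

```python
import subprocess
from pathlib import Path
from typing import List


def pytest_command(model_type: str, tmp_root: str = "../tmp") -> List[str]:
    """Build a pytest invocation whose --basetemp is unique to this model type."""
    return ["pytest", "-v", "-s", f"--basetemp={tmp_root}/{model_type}", "--model_type", model_type]


def run_in_parallel(model_types: List[str], log_dir: str = "../log", tag: str = "20240808") -> List[int]:
    """Start one pytest process per model type, wait for all, and return their exit codes."""
    log_root = Path(log_dir)
    log_root.mkdir(parents=True, exist_ok=True)
    running = []
    for model_type in model_types:
        log_file = open(log_root / f"pytest_{model_type}_{tag}.log", "w")
        # stderr is merged into the log like `2>&1`; unlike `tee`, nothing is echoed to the terminal
        proc = subprocess.Popen(pytest_command(model_type), stdout=log_file, stderr=subprocess.STDOUT)
        running.append((proc, log_file))
    exit_codes = []
    for proc, log_file in running:
        exit_codes.append(proc.wait())
        log_file.close()
    return exit_codes
```

Calling run_in_parallel(["latest", "ext", "sdg"]) would start the same three runs, each with its own --basetemp, so none of them can clean up another's temp folder.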
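Before starting a --with-val run, a quick pre-flight check can catch a missing or broken metrics.json in the validation folder described in the notes. The sketch below is hypothetical (check_val_dir is not a DeepLIIF function). It only checks that the directory contains some images and a parseable metrics.json, and the set of accepted image extensions is an assumption.

```python
import json
from pathlib import Path
from typing import List

# Assumed image extensions; adjust to match your dataset
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".tif", ".tiff"}


def check_val_dir(dir_img: str) -> List[str]:
    """Return a list of problems that would block validation; an empty list means none were found."""
    root = Path(dir_img)
    if not root.is_dir():
        return [f"{root} is not a directory"]
    problems = []
    images = [p for p in root.iterdir() if p.suffix.lower() in IMAGE_SUFFIXES]
    if not images:
        problems.append(f"no validation images found in {root}")
    metrics_path = root / "metrics.json"
    if not metrics_path.is_file():
        problems.append(f"{metrics_path} is missing; run get_cell_count_metrics() first")
    else:
        try:
            json.loads(metrics_path.read_text())
        except json.JSONDecodeError as exc:
            problems.append(f"{metrics_path} is not valid JSON: {exc}")
    return problems
```

Running check_val_dir('Datasets/Sample_Dataset/val') and stopping on a non-empty result fails fast instead of partway through the first epoch.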
Test environment:
- py 3.9
- pytorch 2.4
All tests passed. I ran the ext tests twice and did not see any GPU OOM failures. Test logs are in the OneDrive folder DeepLIIF PR#42 attachments.
The new commit has been tested.
P.S. We still need to solve pytest's GPU memory release issue (e.g., https://github.com/pytest-dev/pytest/discussions/10296), which is currently annoying.