Updated Model Training #42

Open · wants to merge 42 commits into main

Conversation

@wendywangwwt (Contributor) commented Aug 14, 2024

This PR includes two major updates to model training:

  1. Implementation of Attention UNet compatible with the current framework: Attention UNet can now be used as a network architecture (unet_512_attention)
  2. Validation loss and metrics calculation during training: with the flag --with-val, validation losses (the same loss types as in training) and metrics (cell count metrics computed through the postprocess function) are calculated as training progresses; corresponding support in the visdom visualizer is also implemented
    • At the moment, cell count metrics are only calculated for DeepLIIF models.
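Conceptually, the per-epoch validation pass in item 2 mirrors the training losses on held-out data and optionally runs a postprocess step for metrics. A rough sketch follows; all names here are illustrative stand-ins, not the actual DeepLIIF API:

```python
def validation_pass(model, val_batches, loss_fns, postprocess=None):
    """Compute training-style losses (and, optionally, metrics) on held-out data.

    Illustrative sketch only: the real implementation lives in the DeepLIIF
    training loop and reports its results to the visdom visualizer.
    """
    totals = {name: 0.0 for name in loss_fns}
    n_batches = 0
    for inputs, targets in val_batches:
        outputs = model(inputs)
        for name, fn in loss_fns.items():
            totals[name] += fn(outputs, targets)
        n_batches += 1
    losses = {name: total / max(n_batches, 1) for name, total in totals.items()}
    metrics = postprocess(val_batches) if postprocess else {}
    return losses, metrics

# Toy usage with stand-in components:
toy_model = lambda x: 2 * x
batches = [(1.0, 2.0), (3.0, 6.0)]
losses, metrics = validation_pass(toy_model, batches, {'L1': lambda o, t: abs(o - t)})
print(losses['L1'])  # 0.0
```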

Others

  • (cli.py) Allowed specifying the generator architecture for each individual generator, in order (accepts a comma-separated configuration)
  • (cli.py) Debug mode is now available for model training: enable it by passing --debug to python cli.py train. The approximate number of steps/images to run per epoch in debug mode can be changed with --debug-data-size (default: 10). This helps to quickly check whether training runs as expected.
  • Allowed returning the generated segmentation output from each individual modality (accessible from infer_modalities())
  • Added test cases for training (--optimizer, --net-g, --net-gs, --with-val) and trainlaunch (GPU test cases only)
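One plausible way to interpret the comma-separated per-generator architecture spec is sketched below; this is a guess at the behavior, not the actual parsing code in cli.py:

```python
def parse_arch_spec(spec, n_generators):
    """Split a comma-separated architecture spec into one arch per generator.

    Sketch under assumed semantics: a single value applies to all generators;
    otherwise the number of values must match the number of generators.
    """
    parts = [p.strip() for p in spec.split(',')]
    if len(parts) == 1:
        return parts * n_generators
    if len(parts) != n_generators:
        raise ValueError(f'expected 1 or {n_generators} archs, got {len(parts)}')
    return parts

print(parse_arch_spec('unet_512_attention', 4))
print(parse_arch_spec('unet_512,unet_512_attention,unet_512,unet_512', 4))
```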

Notes:

  • Files needed for with-val mode:
    i) val images, in the same format as the training images
    ii) ground-truth cell count metrics in JSON; this can be generated by running get_cell_count_metrics():
from deepliif.stat import get_cell_count_metrics

dir_img = 'Datasets/Sample_Dataset/val'
get_cell_count_metrics(dir_img, model='DeepLIIF', tile_size=512)

The code generates a metrics.json file for the validation data in the same directory as the images.
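Since the output is plain JSON, the generated file can be sanity-checked with the standard json module. Note that this PR does not document the metrics.json schema, so the key structure below is entirely hypothetical:

```python
import json
import os
import tempfile

# Stand-in for the val directory; in practice this would be the dir_img passed
# to get_cell_count_metrics(). The schema written here is made up purely for
# illustration -- the real file's keys are produced by DeepLIIF.
dir_img = tempfile.mkdtemp()
with open(os.path.join(dir_img, 'metrics.json'), 'w') as f:
    json.dump({'image_001': {'num_total': 42}}, f)

# Loading the file back for inspection:
with open(os.path.join(dir_img, 'metrics.json')) as f:
    metrics = json.load(f)
print(sorted(metrics))  # ['image_001']
```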

  • To run multiple tests in parallel (e.g., run latest/ext/sdg at the same time), make sure to use a different tmp directory in --basetemp for each run, so that the pytest processes do not delete or modify a temp folder created or used by another process. For example:
pytest -v -s --basetemp=../tmp/latest --model_type latest 2>&1 | tee ../log/pytest_latest_20240808.log
pytest -v -s --basetemp=../tmp/ext --model_type ext 2>&1 | tee ../log/pytest_ext_20240808.log
pytest -v -s --basetemp=../tmp/sdg --model_type sdg 2>&1 | tee ../log/pytest_sdg_20240808.log

@wendywangwwt (Contributor Author) commented:

Test environment:

  • py 3.9
  • pytorch 2.4

All tests passed. I ran the ext tests twice and did not see a GPU OOM failure. Test logs are in the OneDrive folder DeepLIIF PR#42 attachments.
