#!/usr/bin/env python3
"""
Clean up unused train model checkpoints in the work dir,
i.e. checkpoints of ReturnnTrainingJobs which are no longer part of the active Sisyphus graph.
"""

import os
import sys
import argparse
import logging
from functools import reduce
from typing import TypeVar

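# The path setup below assumes the standard layout: this file lives under
# <setup>/<recipe_dir>/i6_experiments/users/zeyer/sis_tools/, so going up four directories
# yields the recipe root (_base_dir), and the Sisyphus and RETURNN checkouts are expected in <setup>/tools/.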
_my_dir = os.path.dirname(__file__)
_base_dir = reduce(lambda p, _: os.path.dirname(p), range(4), _my_dir)
_sis_dir = os.path.dirname(_base_dir) + "/tools/sisyphus"
_returnn_dir = os.path.dirname(_base_dir) + "/tools/returnn"

T = TypeVar("T")


def _setup():
    # In case the user started this script directly.
    if not globals().get("__package__"):
        globals()["__package__"] = "i6_experiments.users.zeyer.sis_tools"
        if _base_dir not in sys.path:
            sys.path.append(_base_dir)
    if _sis_dir not in sys.path:
        sys.path.append(_sis_dir)
    if _returnn_dir not in sys.path:
        sys.path.append(_returnn_dir)


_setup()


def main():
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("config")
    arg_parser.add_argument("--log-level", type=int, default=20, help="Python logging level as int, e.g. 10 (DEBUG), 20 (INFO, default)")
    arg_parser.add_argument("--mode", default="dryrun", help="one of: dryrun (default), remove")
    args = arg_parser.parse_args()

    # See Sisyphus __main__ for reference.

    import sisyphus.logging_format
    from sisyphus.loader import config_manager
    import sisyphus.toolkit as tk  # noqa: F401  # currently unused here
    from sisyphus import graph
    from sisyphus import gs
    from i6_core.returnn.training import ReturnnTrainingJob
    from i6_experiments.users.zeyer.utils import job_aliases_from_log
    from returnn.util.basic import human_bytes_size

    gs.WARNING_ABSPATH = False

    sisyphus.logging_format.add_coloring_to_logging()
    logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=args.log_level)

    config_manager.load_configs(args.config)

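    # Loading the config(s) above populates the Sisyphus job graph (via the outputs registered in the config),
    # the same way the Sisyphus manager would. Collect the work-dir paths of all ReturnnTrainingJob instances
    # which are still part of the active graph; their checkpoints must be kept.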
    active_train_job_paths = set()
    for job in graph.graph.jobs():
        if not isinstance(job, ReturnnTrainingJob):
            continue
        # print("active train job:", job._sis_path())
        active_train_job_paths.add(job._sis_path())
    print("Num active train jobs:", len(active_train_job_paths))

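    # Now scan the train jobs in the work dir on disk. Every ReturnnTrainingJob dir which is not
    # referenced by the active graph is considered unused, and its checkpoints become removal candidates.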
    total_model_size_to_remove = 0
    total_train_job_count = 0
    total_train_job_with_models_to_remove_count = 0
    model_fns_to_remove = []
    found_active_count = 0  # as a sanity check
    # All paths are relative to the setup root dir (the current working dir),
    # consistent with job._sis_path() above.
    work_train_dir = "work/i6_core/returnn/training"
    for basename in os.listdir(work_train_dir):
        if not basename.startswith("ReturnnTrainingJob."):
            continue
        fn = work_train_dir + "/" + basename
        if fn in active_train_job_paths:
            found_active_count += 1
            continue

        total_train_job_count += 1
        aliases = job_aliases_from_log.get_job_aliases(fn)
        alias = aliases[0]
        alias_path = os.path.basename(os.readlink(alias))
        if alias_path != basename:
            # Can happen, e.g. when the job was cleared by Sisyphus due to an error (cleared.0001 etc.),
            # or when the config was changed later such that the alias now points to another job.
            # print("Warning: Alias path mismatch:", alias_path, "actual:", basename)
            # It does not matter here: clean up anyway, maybe even more so.
            pass

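        # Collect the checkpoint files of this unused job. RETURNN keeps them under output/models,
        # typically as epoch.<N>.pt files (PyTorch backend).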
        model_dir = fn + "/output/models"
        if not os.path.isdir(model_dir):
            # E.g. a cleared or broken job might not have a models dir (anymore).
            continue
        model_count = 0
        model_size = 0
        with os.scandir(model_dir) as it:
            for model_base_fn in it:
                model_base_fn: os.DirEntry
                if not model_base_fn.name.endswith(".pt"):
                    print("Unexpected model file:", model_base_fn.name)
                    continue
                model_fns_to_remove.append(model_base_fn.path)
                model_size += model_base_fn.stat().st_size
                model_count += 1
        if model_count == 0:
            continue
        print("Unused train job:", alias, "model size:", human_bytes_size(model_size))
        total_model_size_to_remove += model_size
        total_train_job_with_models_to_remove_count += 1

    print("Total unused train job count:", total_train_job_count)
    print("Unused train jobs with models to remove:", total_train_job_with_models_to_remove_count)
    print("Can remove total model size:", human_bytes_size(total_model_size_to_remove))
    assert found_active_count == len(active_train_job_paths), (found_active_count, len(active_train_job_paths))

    if args.mode == "remove":
        for fn in model_fns_to_remove:
            print("Remove model:", fn)
            os.remove(fn)
    elif args.mode == "dryrun":
        print("Dry-run mode, not removing.")
    else:
        raise ValueError("invalid mode: %r" % args.mode)


if __name__ == "__main__":
    main()