From 4f62c98d93d1b0899218f200d74246c2b98945ad Mon Sep 17 00:00:00 2001
From: ZhangYiqin <312065559@qq.com>
Date: Sun, 21 Jul 2024 17:44:47 +0800
Subject: [PATCH 01/35] Fix torch FutureWarning

FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated.
Please use `torch.amp.GradScaler('cuda', args...)` instead.
---
 mmengine/optim/optimizer/amp_optimizer_wrapper.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
index 4f3323f2cc..dcdb552943 100644
--- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from contextlib import contextmanager
+from functools import partial
 from typing import Union
 
 import torch
@@ -17,7 +18,9 @@ elif is_mlu_available():
     from torch.mlu.amp import GradScaler
 else:
-    from torch.cuda.amp import GradScaler
+    # from torch.cuda.amp import GradScaler
+    from torch.amp import GradScaler as amp_GradScaler
+    GradScaler = partial(amp_GradScaler, device='cuda')
 
 
 @OPTIM_WRAPPERS.register_module()

From b6b42241e77c19745fc0bb0d20afceff9debcc54 Mon Sep 17 00:00:00 2001
From: ZhangYiqin <312065559@qq.com>
Date: Sun, 21 Jul 2024 17:45:28 +0800
Subject: [PATCH 02/35] Fix torch FutureWarning

FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated.
Please use `torch.amp.GradScaler('cuda', args...)` instead.
---
 mmengine/optim/optimizer/amp_optimizer_wrapper.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
index dcdb552943..60200924b5 100644
--- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -18,7 +18,6 @@ elif is_mlu_available():
     from torch.mlu.amp import GradScaler
 else:
-    # from torch.cuda.amp import GradScaler
     from torch.amp import GradScaler as amp_GradScaler
     GradScaler = partial(amp_GradScaler, device='cuda')
 
 
 @OPTIM_WRAPPERS.register_module()

From 4c7a5d499ff232eaf97f2d4f15fc37088c407bfa Mon Sep 17 00:00:00 2001
From: ZhangYiqin <312065559@qq.com>
Date: Fri, 26 Jul 2024 10:08:17 +0800
Subject: [PATCH 03/35] Optimize the prompt for compile

---
 mmengine/_strategy/base.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mmengine/_strategy/base.py b/mmengine/_strategy/base.py
index 5df3a79c92..05d070f6f2 100644
--- a/mmengine/_strategy/base.py
+++ b/mmengine/_strategy/base.py
@@ -322,7 +322,8 @@ def compile_model(
 
         Returns:
             nn.Module: Compiled model.
""" - if isinstance(compile, bool) and not compile: + if isinstance(compile, bool) and not compile or \ + isinstance(compile, dict) and not compile.get('disable', False): return model assert digit_version(TORCH_VERSION) >= digit_version('2.0.0'), ( From 28d47f849d3ee2fa6041500baf058df1325e047e Mon Sep 17 00:00:00 2001 From: ZhangYiqin <312065559@qq.com> Date: Wed, 21 Aug 2024 14:15:53 +0800 Subject: [PATCH 04/35] Fix Incorrect Optim Param Resume Method FSDP.optim_state_dict_to_load requires the following parameters: model: Module, optim: Optimizer, optim_state_dict: Dict[str, Any] --- mmengine/_strategy/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmengine/_strategy/fsdp.py b/mmengine/_strategy/fsdp.py index 0788fafdab..124dfd7c57 100644 --- a/mmengine/_strategy/fsdp.py +++ b/mmengine/_strategy/fsdp.py @@ -408,7 +408,7 @@ def load_optim_state_dict(self, state_dict: dict) -> None: ``optimizer.state_dict()`` """ optim_state_dict = FSDP.optim_state_dict_to_load( - state_dict, self.model, self.optim_wrapper.optimizer) + self.model, self.optim_wrapper.optimizer, state_dict) self.optim_wrapper.load_state_dict(optim_state_dict) def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None: From 91d945f29205c16c3dd8e9655f80daa73bd6ab63 Mon Sep 17 00:00:00 2001 From: ZhangYiqin <312065559@qq.com> Date: Wed, 28 Aug 2024 23:59:38 +0800 Subject: [PATCH 05/35] Update runner.py to support pure-python style model wrapper configurations The current runner implementation has not yet supported for pure-python style configurations on model wrapper class. I follow the mainstream implementation to support this feature. --- mmengine/runner/runner.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 68716ab253..7160ac84d7 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +import inspect import logging import os import os.path as osp @@ -902,8 +903,18 @@ def wrap_model( find_unused_parameters=find_unused_parameters) else: model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel') - model_wrapper_type = MODEL_WRAPPERS.get( - model_wrapper_cfg.get('type')) # type: ignore + model_wrapper_type = model_wrapper_cfg.get('type') + if isinstance(model_wrapper_type, str): + model_wrapper_type = MODEL_WRAPPERS.get(model_wrapper_type) # type: ignore + elif inspect.isclass(model_wrapper_type): + pass + else: + raise KeyError( + f'{model_wrapper_type} is not in the ' + 'registry. Please check whether the value of ' + f'`{model_wrapper_type}` is correct or it was registered ' + 'as expected. 
More details can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module' # noqa: E501 + ) default_args: dict = dict() if issubclass( model_wrapper_type, # type: ignore From 7103c3e629a189336cac3308add2b080319025d6 Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Mon, 23 Sep 2024 03:00:04 +0000 Subject: [PATCH 06/35] reconstruct --- mmengine/_strategy/fsdp.py | 2 +- mmengine/model/wrappers/distributed.py | 3 +++ mmengine/optim/optimizer/builder.py | 1 + mmengine/runner/loops.py | 4 +++- mmengine/runner/runner.py | 4 +++- 5 files changed, 11 insertions(+), 3 deletions(-) diff --git a/mmengine/_strategy/fsdp.py b/mmengine/_strategy/fsdp.py index 124dfd7c57..0788fafdab 100644 --- a/mmengine/_strategy/fsdp.py +++ b/mmengine/_strategy/fsdp.py @@ -408,7 +408,7 @@ def load_optim_state_dict(self, state_dict: dict) -> None: ``optimizer.state_dict()`` """ optim_state_dict = FSDP.optim_state_dict_to_load( - self.model, self.optim_wrapper.optimizer, state_dict) + state_dict, self.model, self.optim_wrapper.optimizer) self.optim_wrapper.load_state_dict(optim_state_dict) def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None: diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py index 4113aebf9e..b88bc7c2b0 100644 --- a/mmengine/model/wrappers/distributed.py +++ b/mmengine/model/wrappers/distributed.py @@ -95,6 +95,7 @@ def __init__(self, def train_step(self, data: Union[dict, tuple, list], optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: + return self.module.train_step(data, optim_wrapper) """Interface for model forward, backward and parameters updating during training process. @@ -126,6 +127,7 @@ def train_step(self, data: Union[dict, tuple, list], return log_vars def val_step(self, data: Union[dict, tuple, list]) -> list: + return self.module.val_step(data) """Gets the prediction of module during validation process. Args: @@ -137,6 +139,7 @@ def val_step(self, data: Union[dict, tuple, list]) -> list: return self.module.val_step(data) def test_step(self, data: Union[dict, tuple, list]) -> list: + return self.module.test_step(data) """Gets the predictions of module during testing process. Args: diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py index 8557f4d34c..b57ebc315a 100644 --- a/mmengine/optim/optimizer/builder.py +++ b/mmengine/optim/optimizer/builder.py @@ -207,5 +207,6 @@ def build_optim_wrapper(model: nn.Module, type=constructor_type, optim_wrapper_cfg=optim_wrapper_cfg, paramwise_cfg=paramwise_cfg)) + optim_wrapper = optim_wrapper_constructor(model) return optim_wrapper diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 5a678db7b9..25ff690f0b 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -12,6 +12,7 @@ from mmengine.registry import LOOPS from mmengine.structures import BaseDataElement from mmengine.utils import is_list_of +from mmengine.dataset.sampler import InfiniteSampler from .amp import autocast from .base_loop import BaseLoop from .utils import calc_dynamic_intervals @@ -274,13 +275,14 @@ def run(self) -> None: # In iteration-based training loop, we treat the whole training process # as a big epoch and execute the corresponding hook. 
self.runner.call_hook('before_train_epoch') - if self._iter > 0: + if self._iter > 0 and not isinstance(self.dataloader.sampler, InfiniteSampler): print_log( f'Advance dataloader {self._iter} steps to skip data ' 'that has already been trained', logger='current', level=logging.WARNING) for _ in range(self._iter): + break next(self.dataloader_iterator) while self._iter < self._max_iters and not self.stop_training: self.runner.model.train() diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 7160ac84d7..435bd55ac0 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -903,9 +903,11 @@ def wrap_model( find_unused_parameters=find_unused_parameters) else: model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel') + model_wrapper_type = model_wrapper_cfg.get('type') if isinstance(model_wrapper_type, str): - model_wrapper_type = MODEL_WRAPPERS.get(model_wrapper_type) # type: ignore + model_wrapper_type = MODEL_WRAPPERS.get( + model_wrapper_type) # type: ignore elif inspect.isclass(model_wrapper_type): pass else: From eecaa92179bb275f650931f7597b0ead0420f6b5 Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Sun, 3 Nov 2024 05:18:48 +0000 Subject: [PATCH 07/35] PyTorch Profiler within IterBasedTrainLoop --- mmengine/runner/loops.py | 19 +++++++++++++++++-- mmengine/runner/runner.py | 2 +- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 25ff690f0b..7be8995781 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -282,8 +282,21 @@ def run(self) -> None: logger='current', level=logging.WARNING) for _ in range(self._iter): - break + break # NOTE MGAM: override all preprocessing steps during resume. next(self.dataloader_iterator) + + # with torch.profiler.profile( + # activities=[torch.profiler.ProfilerActivity.CPU, + # torch.profiler.ProfilerActivity.CUDA], + # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3), + # on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_log'), + # record_shapes=True, + # profile_memory=True, + # with_stack=False, + # with_flops=True, + # with_modules=True, + # ) as p: + while self._iter < self._max_iters and not self.stop_training: self.runner.model.train() @@ -294,8 +307,10 @@ def run(self) -> None: if (self.runner.val_loop is not None and self._iter >= self.val_begin and (self._iter % self.val_interval == 0 - or self._iter == self._max_iters)): + or self._iter == self._max_iters)): self.runner.val_loop.run() + + # p.step() self.runner.call_hook('after_train_epoch') self.runner.call_hook('after_train') diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 435bd55ac0..6b8dd60e2b 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -1851,7 +1851,7 @@ def call_hook(self, fn_name: str, **kwargs) -> None: try: getattr(hook, fn_name)(self, **kwargs) except TypeError as e: - raise TypeError(f'{e} in {hook}') from None + raise TypeError(f'{e} in {hook}') from e def register_hook( self, From 698ad5ebaed47965fb3c999da2ee82228a4b0600 Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Sun, 3 Nov 2024 13:53:29 +0800 Subject: [PATCH 08/35] enable hook error exception traceback --- mmengine/runner/runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 435bd55ac0..6b8dd60e2b 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -1851,7 +1851,7 @@ def call_hook(self, 
fn_name: str, **kwargs) -> None: try: getattr(hook, fn_name)(self, **kwargs) except TypeError as e: - raise TypeError(f'{e} in {hook}') from None + raise TypeError(f'{e} in {hook}') from e def register_hook( self, From 1e4c2ed17e6bb01af74ccd45923e844a2764bc32 Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Fri, 15 Nov 2024 01:18:32 +0000 Subject: [PATCH 09/35] improve codes --- mmengine/runner/checkpoint.py | 2 +- mmengine/runner/loops.py | 4 ++-- mmengine/visualization/vis_backend.py | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index 2bf5f50f7c..fa0a1eb520 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -344,7 +344,7 @@ def load_from_local(filename, map_location): filename = osp.expanduser(filename) if not osp.isfile(filename): raise FileNotFoundError(f'{filename} can not be found.') - checkpoint = torch.load(filename, map_location=map_location) + checkpoint = torch.load(filename, map_location=map_location, weights_only=False) return checkpoint diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 7be8995781..f511c14e68 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -288,11 +288,11 @@ def run(self) -> None: # with torch.profiler.profile( # activities=[torch.profiler.ProfilerActivity.CPU, # torch.profiler.ProfilerActivity.CUDA], - # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3), + # schedule=torch.profiler.schedule(wait=1, warmup=2, active=3), # on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_log'), # record_shapes=True, # profile_memory=True, - # with_stack=False, + # with_stack=True, # with_flops=True, # with_modules=True, # ) as p: diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py index b752ec85a7..a5bf7d88e7 100644 --- a/mmengine/visualization/vis_backend.py +++ b/mmengine/visualization/vis_backend.py @@ -604,7 +604,8 @@ def add_scalar(self, (int, float, torch.Tensor, np.ndarray, np.number)): self._tensorboard.add_scalar(name, value, step) else: - warnings.warn(f'Got {type(value)}, but numpy array, torch tensor, ' + warnings.warn(f'Got type {type(value)} with name {name}, ' + 'but numpy array, torch tensor, ' f'int or float are expected. skip it!') @force_init_env From 29e3a0882cace17b6d5391224a8c4abcf35419c7 Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Fri, 3 Jan 2025 08:56:10 +0000 Subject: [PATCH 10/35] KeyError: 'Adafactor is already registered in optimizer at torch.optim'. This may be due to the version confliction. Newer PyTorch may have introduced this optimizer. 
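
An alternative to dropping the registration outright, sketched under the
assumption that mmengine's Registry.register_module() is called with its
real `force` flag (force=True overwrites an existing entry instead of
raising KeyError):

    from mmengine.registry import OPTIMIZERS

    try:
        from transformers import Adafactor
    except ImportError:
        pass
    else:
        # Overwrite the entry that newer torch.optim already registered.
        OPTIMIZERS.register_module(
            name='Adafactor', module=Adafactor, force=True)

This patch keeps the simpler route and skips the duplicate registration.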
--- mmengine/optim/optimizer/builder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py index b57ebc315a..e778a3d5bc 100644 --- a/mmengine/optim/optimizer/builder.py +++ b/mmengine/optim/optimizer/builder.py @@ -166,7 +166,8 @@ def register_transformers_optimizers(): except ImportError: pass else: - OPTIMIZERS.register_module(name='Adafactor', module=Adafactor) + # KeyError: 'Adafactor is already registered in optimizer at torch.optim' + # OPTIMIZERS.register_module(name='Adafactor', module=Adafactor) transformer_optimizers.append('Adafactor') return transformer_optimizers From be86710ecb96d5682b7d7e7f1e9e72bccc1bd6a2 Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Sat, 11 Jan 2025 17:57:03 +0800 Subject: [PATCH 11/35] Update support for deep speed and multiple improvements. --- mmengine/_strategy/deepspeed.py | 10 +- mmengine/config/config.py | 218 ++++++++++++++++--------------- mmengine/logging/message_hub.py | 6 +- mmengine/model/averaged_model.py | 1 + 4 files changed, 123 insertions(+), 112 deletions(-) diff --git a/mmengine/_strategy/deepspeed.py b/mmengine/_strategy/deepspeed.py index 3f89ff760d..1fff461bf3 100644 --- a/mmengine/_strategy/deepspeed.py +++ b/mmengine/_strategy/deepspeed.py @@ -63,9 +63,11 @@ def register_deepspeed_optimizers() -> List[str]: @OPTIM_WRAPPERS.register_module() class DeepSpeedOptimWrapper(BaseOptimWrapper): - def __init__(self, optimizer): + def __init__(self, optimizer, accumulative_counts): super().__init__(optimizer) self._model = None + self._inner_count = 0 + self._accumulative_counts = accumulative_counts @property def model(self): @@ -80,11 +82,13 @@ def model(self, value): def update_params(self, loss) -> None: # type: ignore """Update parameters in :attr:`optimizer`.""" self.backward(loss) - self.step() + if self.should_update(): + self.step() def backward(self, loss: torch.Tensor, **kwargs) -> None: """"Perform gradient back propagation.""" self.model.backward(loss) + self._inner_count += 1 def zero_grad(self, **kwargs) -> None: raise NotImplementedError( @@ -107,6 +111,8 @@ def load_state_dict(self, state_dict: dict) -> None: if base_param_settings is not None: self.base_param_settings = base_param_settings + def should_update(self) -> bool: + return (self._inner_count % self._accumulative_counts == 0) @MODEL_WRAPPERS.register_module() class MMDeepSpeedEngineWrapper: diff --git a/mmengine/config/config.py b/mmengine/config/config.py index 36f92f0b3a..5ca06954ed 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -1375,120 +1375,122 @@ def env_variables(self) -> dict: @property def pretty_text(self) -> str: """Get formatted python config text.""" + try: + indent = 4 + + def _indent(s_, num_spaces): + s = s_.split('\n') + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s + + def _format_basic_types(k, v, use_mapping=False): + if isinstance(v, str): + v_str = repr(v) + else: + v_str = str(v) - indent = 4 - - def _indent(s_, num_spaces): - s = s_.split('\n') - if len(s) == 1: - return s_ - first = s.pop(0) - s = [(num_spaces * ' ') + line for line in s] - s = '\n'.join(s) - s = first + '\n' + s - return s - - def _format_basic_types(k, v, use_mapping=False): - if isinstance(v, str): - v_str = repr(v) - else: - v_str = str(v) - - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f'{k_str}: 
{v_str}' - else: - attr_str = f'{str(k)}={v_str}' - attr_str = _indent(attr_str, indent) + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: {v_str}' + else: + attr_str = f'{str(k)}={v_str}' + attr_str = _indent(attr_str, indent) - return attr_str + return attr_str - def _format_list_tuple(k, v, use_mapping=False): - if isinstance(v, list): - left = '[' - right = ']' - else: - left = '(' - right = ')' - - v_str = f'{left}\n' - # check if all items in the list are dict - for item in v: - if isinstance(item, dict): - v_str += f'dict({_indent(_format_dict(item), indent)}),\n' - elif isinstance(item, tuple): - v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501 - elif isinstance(item, list): - v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501 - elif isinstance(item, str): - v_str += f'{_indent(repr(item), indent)},\n' + def _format_list_tuple(k, v, use_mapping=False): + if isinstance(v, list): + left = '[' + right = ']' else: - v_str += str(item) + ',\n' - if k is None: - return _indent(v_str, indent) + right - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f'{k_str}: {v_str}' - else: - attr_str = f'{str(k)}={v_str}' - attr_str = _indent(attr_str, indent) + right - return attr_str - - def _contain_invalid_identifier(dict_str): - contain_invalid_identifier = False - for key_name in dict_str: - contain_invalid_identifier |= \ - (not str(key_name).isidentifier()) - return contain_invalid_identifier - - def _format_dict(input_dict, outest_level=False): - r = '' - s = [] - - use_mapping = _contain_invalid_identifier(input_dict) - if use_mapping: - r += '{' - for idx, (k, v) in enumerate( - sorted(input_dict.items(), key=lambda x: str(x[0]))): - is_last = idx >= len(input_dict) - 1 - end = '' if outest_level or is_last else ',' - if isinstance(v, dict): - v_str = '\n' + _format_dict(v) - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f'{k_str}: dict({v_str}' + left = '(' + right = ')' + + v_str = f'{left}\n' + # check if all items in the list are dict + for item in v: + if isinstance(item, dict): + v_str += f'dict({_indent(_format_dict(item), indent)}),\n' + elif isinstance(item, tuple): + v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501 + elif isinstance(item, list): + v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501 + elif isinstance(item, str): + v_str += f'{_indent(repr(item), indent)},\n' else: - attr_str = f'{str(k)}=dict({v_str}' - attr_str = _indent(attr_str, indent) + ')' + end - elif isinstance(v, (list, tuple)): - attr_str = _format_list_tuple(k, v, use_mapping) + end + v_str += str(item) + ',\n' + if k is None: + return _indent(v_str, indent) + right + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: {v_str}' else: - attr_str = _format_basic_types(k, v, use_mapping) + end - - s.append(attr_str) - r += '\n'.join(s) - if use_mapping: - r += '}' - return r - - cfg_dict = self.to_dict() - text = _format_dict(cfg_dict, outest_level=True) - if self._format_python_code: - # copied from setup.cfg - yapf_style = dict( - based_on_style='pep8', - blank_line_before_nested_class_or_def=True, - split_before_expression_after_opening_paren=True) - try: - if digit_version(yapf.__version__) >= digit_version('0.40.2'): - text, _ = FormatCode(text, style_config=yapf_style) - else: - text, _ = FormatCode( - text, style_config=yapf_style, 
verify=True) - except: # noqa: E722 - raise SyntaxError('Failed to format the config file, please ' - f'check the syntax of: \n{text}') - return text + attr_str = f'{str(k)}={v_str}' + attr_str = _indent(attr_str, indent) + right + return attr_str + + def _contain_invalid_identifier(dict_str): + contain_invalid_identifier = False + for key_name in dict_str: + contain_invalid_identifier |= \ + (not str(key_name).isidentifier()) + return contain_invalid_identifier + + def _format_dict(input_dict, outest_level=False): + r = '' + s = [] + + use_mapping = _contain_invalid_identifier(input_dict) + if use_mapping: + r += '{' + for idx, (k, v) in enumerate( + sorted(input_dict.items(), key=lambda x: str(x[0]))): + is_last = idx >= len(input_dict) - 1 + end = '' if outest_level or is_last else ',' + if isinstance(v, dict): + v_str = '\n' + _format_dict(v) + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: dict({v_str}' + else: + attr_str = f'{str(k)}=dict({v_str}' + attr_str = _indent(attr_str, indent) + ')' + end + elif isinstance(v, (list, tuple)): + attr_str = _format_list_tuple(k, v, use_mapping) + end + else: + attr_str = _format_basic_types(k, v, use_mapping) + end + + s.append(attr_str) + r += '\n'.join(s) + if use_mapping: + r += '}' + return r + + cfg_dict = self.to_dict() + text = _format_dict(cfg_dict, outest_level=True) + if self._format_python_code: + # copied from setup.cfg + yapf_style = dict( + based_on_style='pep8', + blank_line_before_nested_class_or_def=True, + split_before_expression_after_opening_paren=True) + try: + if digit_version(yapf.__version__) >= digit_version('0.40.2'): + text, _ = FormatCode(text, style_config=yapf_style) + else: + text, _ = FormatCode( + text, style_config=yapf_style, verify=True) + except: # noqa: E722 + raise SyntaxError('Failed to format the config file, please ' + f'check the syntax of: \n{text}') + return text + except Exception as e: + return f'Error occurs when formatting config: {e}' def __repr__(self): return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}' diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py index 82565d8832..e4edc3466e 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -342,8 +342,10 @@ def _get_valid_value( else: # check whether value is torch.Tensor but don't want # to import torch in this file - assert hasattr(value, 'numel') and value.numel() == 1 - value = value.item() + if hasattr(value, 'numel') and value.numel() == 1: + value = value.item() + else: + print_log(f"MessageHub got unexpceted log: {value}", level=logging.WARN) return value # type: ignore def state_dict(self) -> dict: diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py index 58457c2a6e..cc83a5976d 100644 --- a/mmengine/model/averaged_model.py +++ b/mmengine/model/averaged_model.py @@ -103,6 +103,7 @@ def update_parameters(self, model: nn.Module) -> None: for k, p_avg in self.avg_parameters.items(): p_avg.data.copy_(src_parameters[k].data) elif self.steps % self.interval == 0: + print(self.avg_parameters) for k, p_avg in self.avg_parameters.items(): if p_avg.dtype.is_floating_point: device = p_avg.device From 861fc1b91b111ceecf9c05895de4156f3725a6f3 Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Sun, 12 Jan 2025 14:40:11 +0800 Subject: [PATCH 12/35] improve multiple mmengine undeveloped issues. 
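
The optimizer builder change below lets `constructor` be a pure-python
style config rather than only a registry string. A hypothetical config
illustrating the intended usage (the optimizer settings are placeholder
values, not defaults):

    from mmengine.optim import DefaultOptimWrapperConstructor

    optim_wrapper = dict(
        type='OptimWrapper',
        optimizer=dict(type='AdamW', lr=1e-4),
        # previously only a string such as 'DefaultOptimWrapperConstructor'
        constructor=dict(type=DefaultOptimWrapperConstructor),
    )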
--- mmengine/model/base_module.py | 1 - mmengine/optim/optimizer/builder.py | 18 ++++++++++-------- mmengine/runner/loops.py | 1 - 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/mmengine/model/base_module.py b/mmengine/model/base_module.py index 3cfe0b14a8..276e6fe218 100644 --- a/mmengine/model/base_module.py +++ b/mmengine/model/base_module.py @@ -65,7 +65,6 @@ def is_init(self, value): def init_weights(self): """Initialize the weights.""" - is_top_level_module = False # check if it is top-level module if not hasattr(self, '_params_init_info'): diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py index 09467a192f..ebba603dbf 100644 --- a/mmengine/optim/optimizer/builder.py +++ b/mmengine/optim/optimizer/builder.py @@ -9,6 +9,7 @@ from mmengine.config import Config, ConfigDict from mmengine.device import is_npu_available, is_npu_support_full_precision from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS +from .default_constructor import DefaultOptimWrapperConstructor from .optimizer_wrapper import OptimWrapper @@ -197,8 +198,9 @@ def build_optim_wrapper(model: nn.Module, OptimWrapper: The built optimizer wrapper. """ optim_wrapper_cfg = copy.deepcopy(cfg) - constructor_type = optim_wrapper_cfg.pop('constructor', - 'DefaultOptimWrapperConstructor') + constructor_cfg = optim_wrapper_cfg.pop('constructor', None) + if constructor_cfg is None: + constructor_cfg = dict(type=DefaultOptimWrapperConstructor) paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None) # Since the current generation of NPU(Ascend 910) only supports @@ -206,12 +208,12 @@ def build_optim_wrapper(model: nn.Module, # to make the training normal if is_npu_available() and not is_npu_support_full_precision(): optim_wrapper_cfg['type'] = 'AmpOptimWrapper' - - optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build( - dict( - type=constructor_type, - optim_wrapper_cfg=optim_wrapper_cfg, - paramwise_cfg=paramwise_cfg)) + constructor_cfg.update(dict( + optim_wrapper_cfg=optim_wrapper_cfg, + paramwise_cfg=paramwise_cfg + )) + + optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(constructor_cfg) optim_wrapper = optim_wrapper_constructor(model) return optim_wrapper diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index f511c14e68..ba9ec9d9dd 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -394,7 +394,6 @@ def run(self) -> dict: self.val_loss.clear() for idx, data_batch in enumerate(self.dataloader): self.run_iter(idx, data_batch) - # compute metrics metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) From 8f37dd2d16f8ee4ea46cb8a9603f0f42eb96280d Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Fri, 17 Jan 2025 09:56:42 +0000 Subject: [PATCH 13/35] Multiple improvements --- mmengine/model/averaged_model.py | 1 - mmengine/runner/loops.py | 4 ++-- mmengine/visualization/vis_backend.py | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py index cc83a5976d..58457c2a6e 100644 --- a/mmengine/model/averaged_model.py +++ b/mmengine/model/averaged_model.py @@ -103,7 +103,6 @@ def update_parameters(self, model: nn.Module) -> None: for k, p_avg in self.avg_parameters.items(): p_avg.data.copy_(src_parameters[k].data) elif self.steps % self.interval == 0: - print(self.avg_parameters) for k, p_avg in self.avg_parameters.items(): if p_avg.dtype.is_floating_point: device = p_avg.device diff --git 
a/mmengine/runner/loops.py b/mmengine/runner/loops.py index ba9ec9d9dd..4411edb412 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -418,9 +418,9 @@ def run_iter(self, idx, data_batch: Sequence[dict]): # outputs should be sequence of BaseDataElement with autocast(enabled=self.fp16): outputs = self.runner.model.val_step(data_batch) - + outputs, self.val_loss = _update_losses(outputs, self.val_loss) - + self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook( 'after_val_iter', diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py index a5bf7d88e7..fcb7d23b05 100644 --- a/mmengine/visualization/vis_backend.py +++ b/mmengine/visualization/vis_backend.py @@ -578,7 +578,7 @@ def add_image(self, step: int = 0, **kwargs) -> None: """Record the image to tensorboard. - + Args: name (str): The image identifier. image (np.ndarray): The image to be saved. The format From d45205c3c48ed1a64364c0867d769ecd718ba1c4 Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Sat, 22 Feb 2025 21:13:42 +0800 Subject: [PATCH 14/35] update dependency and bump versions --- pyproject.toml | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..81294f4a68 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,35 @@ +[build-system] +requires = ["setuptools>=72", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "mmengine" +version = "0.10.5" +description = "Engine of OpenMMLab projects" +readme = "README.md" +license = { text = "Apache License 2.0" } +authors = [ + { name = "MMEngine Authors", email = "openmmlab@gmail.com" }, + { name = "MGAM", email = "312065559@qq.com" } +] +requires-python = ">=3.10" +classifiers = [ + "Development Status :: 4 - Beta", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Utilities", +] +keywords = ["OpenMMLab", "Engine"] +dependencies = [ + "addict", + "matplotlib", + "numpy", + "pyyaml", + "regex;sys_platform=='win32'", + "rich", + "termcolor", + "yapf", +] \ No newline at end of file From c472f2b6df7c4549911415ec85d09cae0bf18963 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E8=B4=BB=E9=92=A6?= <312065559@qq.com> Date: Mon, 3 Mar 2025 09:22:22 +0800 Subject: [PATCH 15/35] fix wrong pyproject config. 
--- pyproject.toml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 81294f4a68..053f4f3bfe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,4 +32,8 @@ dependencies = [ "rich", "termcolor", "yapf", -] \ No newline at end of file +] + +[tool.setuptools.packages.find] +where = ["."] +include = ["mmengine"] \ No newline at end of file From 4b3627ac23173276833609962aa9b0bf524c1aea Mon Sep 17 00:00:00 2001 From: mgam <312065559@qq.com> Date: Sun, 16 Mar 2025 19:09:22 +0800 Subject: [PATCH 16/35] sync version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 053f4f3bfe..e4710b2766 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mmengine" -version = "0.10.5" +version = "0.10.7" description = "Engine of OpenMMLab projects" readme = "README.md" license = { text = "Apache License 2.0" } From c5f5ca7563738492c74204030fc0b88d7f64e624 Mon Sep 17 00:00:00 2001 From: mgam <312065559@qq.com> Date: Sun, 16 Mar 2025 20:09:44 +0800 Subject: [PATCH 17/35] disable HistoryBuffer's torch compile --- mmengine/logging/history_buffer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mmengine/logging/history_buffer.py b/mmengine/logging/history_buffer.py index a50de22c65..3c69e86aca 100644 --- a/mmengine/logging/history_buffer.py +++ b/mmengine/logging/history_buffer.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Union import numpy as np +import torch class HistoryBuffer: @@ -57,6 +58,7 @@ def _set_default_statistics(self) -> None: self._statistics_methods.setdefault('current', HistoryBuffer.current) self._statistics_methods.setdefault('mean', HistoryBuffer.mean) + @torch._dynamo.disable() def update(self, log_val: Union[int, float], count: int = 1) -> None: """Update the log history. From 438eb6463619f892486966ca8203662dfdf7af03 Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Mon, 17 Mar 2025 01:55:03 +0000 Subject: [PATCH 18/35] Fix histort buffer bug when using torch.compile --- mmengine/logging/history_buffer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mmengine/logging/history_buffer.py b/mmengine/logging/history_buffer.py index a50de22c65..3c69e86aca 100644 --- a/mmengine/logging/history_buffer.py +++ b/mmengine/logging/history_buffer.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Union import numpy as np +import torch class HistoryBuffer: @@ -57,6 +58,7 @@ def _set_default_statistics(self) -> None: self._statistics_methods.setdefault('current', HistoryBuffer.current) self._statistics_methods.setdefault('mean', HistoryBuffer.mean) + @torch._dynamo.disable() def update(self, log_val: Union[int, float], count: int = 1) -> None: """Update the log history. From 6d618bc837582f55dce3a15b7c31679203a3e71f Mon Sep 17 00:00:00 2001 From: mgam <312065559@qq.com> Date: Sun, 30 Mar 2025 10:51:23 +0800 Subject: [PATCH 19/35] 1. Undo changes made to history buffer. 2. Add torch compiler disable flag to message hub class. 3. The compile-time fault override has been moved from history buffer to message hub. 4. The MMDistributedDataParallel module has now been recovered to original MMEngine implementation. The reason for the modification at that time may be related to the train_step function modification at earlier projects. Such modification will be achived by inheriting a new class in the future. 
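
A minimal sketch of the effect, assuming PyTorch >= 2.1 where
torch.compiler.disable works as a bare decorator (the names below are
illustrative, not mmengine API):

    import torch

    @torch.compiler.disable
    def log_scalar(history: list, value: float) -> None:
        history.append(value)  # side effect the compiler must not trace

    @torch.compile
    def train_step(x: torch.Tensor, history: list) -> torch.Tensor:
        loss = (x * x).mean()
        log_scalar(history, float(loss))  # executed eagerly
        return loss

Guarding MessageHub.update_scalar covers every logging path at once,
instead of patching HistoryBuffer.update alone.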
--- mmengine/logging/history_buffer.py | 1 - mmengine/logging/message_hub.py | 4 ++-- mmengine/model/wrappers/distributed.py | 3 --- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/mmengine/logging/history_buffer.py b/mmengine/logging/history_buffer.py index 3c69e86aca..be8e3e242c 100644 --- a/mmengine/logging/history_buffer.py +++ b/mmengine/logging/history_buffer.py @@ -58,7 +58,6 @@ def _set_default_statistics(self) -> None: self._statistics_methods.setdefault('current', HistoryBuffer.current) self._statistics_methods.setdefault('mean', HistoryBuffer.mean) - @torch._dynamo.disable() def update(self, log_val: Union[int, float], count: int = 1) -> None: """Update the log history. diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py index e4edc3466e..68a0ee7f38 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -10,8 +10,7 @@ from .history_buffer import HistoryBuffer from .logger import print_log -if TYPE_CHECKING: - import torch +import torch class MessageHub(ManagerMixin): @@ -92,6 +91,7 @@ def get_current_instance(cls) -> 'MessageHub': cls.get_instance('mmengine') return super().get_current_instance() + @torch.compiler.disable def update_scalar(self, key: str, value: Union[int, float, np.ndarray, 'torch.Tensor'], diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py index b88bc7c2b0..4113aebf9e 100644 --- a/mmengine/model/wrappers/distributed.py +++ b/mmengine/model/wrappers/distributed.py @@ -95,7 +95,6 @@ def __init__(self, def train_step(self, data: Union[dict, tuple, list], optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: - return self.module.train_step(data, optim_wrapper) """Interface for model forward, backward and parameters updating during training process. @@ -127,7 +126,6 @@ def train_step(self, data: Union[dict, tuple, list], return log_vars def val_step(self, data: Union[dict, tuple, list]) -> list: - return self.module.val_step(data) """Gets the prediction of module during validation process. Args: @@ -139,7 +137,6 @@ def val_step(self, data: Union[dict, tuple, list]) -> list: return self.module.val_step(data) def test_step(self, data: Union[dict, tuple, list]) -> list: - return self.module.test_step(data) """Gets the predictions of module during testing process. Args: From 61493167787dec0a7fd80a34a989f792061e22bb Mon Sep 17 00:00:00 2001 From: MGAM <312065559@qq.com> Date: Fri, 25 Apr 2025 13:42:58 +0000 Subject: [PATCH 20/35] 1. remove unnecessary distributed warp. --- mmengine/model/wrappers/distributed.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py index b88bc7c2b0..c842cb3067 100644 --- a/mmengine/model/wrappers/distributed.py +++ b/mmengine/model/wrappers/distributed.py @@ -127,7 +127,6 @@ def train_step(self, data: Union[dict, tuple, list], return log_vars def val_step(self, data: Union[dict, tuple, list]) -> list: - return self.module.val_step(data) """Gets the prediction of module during validation process. Args: @@ -139,7 +138,6 @@ def val_step(self, data: Union[dict, tuple, list]) -> list: return self.module.val_step(data) def test_step(self, data: Union[dict, tuple, list]) -> list: - return self.module.test_step(data) """Gets the predictions of module during testing process. 
Args: From 385d0294b0eaa5f08ed125abf806eeadce1adbc6 Mon Sep 17 00:00:00 2001 From: mgam <312065559@qq.com> Date: Mon, 21 Jul 2025 22:07:59 +0800 Subject: [PATCH 21/35] Remove setup.py to fix installation bug using python>=3.13 --- setup.py => setup.py.backup | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename setup.py => setup.py.backup (100%) diff --git a/setup.py b/setup.py.backup similarity index 100% rename from setup.py rename to setup.py.backup From c62a408b859e96c1d2a555561cdeca4368326a72 Mon Sep 17 00:00:00 2001 From: mgam <312065559@qq.com> Date: Fri, 17 Oct 2025 12:27:48 +0800 Subject: [PATCH 22/35] Change the pip install method --- pyproject.toml | 39 ++++++++++++++++++++++++++++++++++++++ setup.cfg => setup.cfg.bak | 0 2 files changed, 39 insertions(+) create mode 100644 pyproject.toml rename setup.cfg => setup.cfg.bak (100%) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..e4710b2766 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,39 @@ +[build-system] +requires = ["setuptools>=72", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "mmengine" +version = "0.10.7" +description = "Engine of OpenMMLab projects" +readme = "README.md" +license = { text = "Apache License 2.0" } +authors = [ + { name = "MMEngine Authors", email = "openmmlab@gmail.com" }, + { name = "MGAM", email = "312065559@qq.com" } +] +requires-python = ">=3.10" +classifiers = [ + "Development Status :: 4 - Beta", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Utilities", +] +keywords = ["OpenMMLab", "Engine"] +dependencies = [ + "addict", + "matplotlib", + "numpy", + "pyyaml", + "regex;sys_platform=='win32'", + "rich", + "termcolor", + "yapf", +] + +[tool.setuptools.packages.find] +where = ["."] +include = ["mmengine"] \ No newline at end of file diff --git a/setup.cfg b/setup.cfg.bak similarity index 100% rename from setup.cfg rename to setup.cfg.bak From ebcee9c5dc843a9b60488379bc21f10f20654062 Mon Sep 17 00:00:00 2001 From: mgam <312065559@qq.com> Date: Fri, 17 Oct 2025 12:50:48 +0800 Subject: [PATCH 23/35] Fix code style and type annotations in builder.py --- mmengine/logging/message_hub.py | 10 +++++----- mmengine/optim/optimizer/builder.py | 25 ++++++++++++++----------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py index dff2f9c8b0..6e4faaee6e 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -2,17 +2,16 @@ import copy import logging from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import Any, Optional, Union import numpy as np +import torch from mmengine.utils import ManagerMixin + from .history_buffer import HistoryBuffer from .logger import print_log -if TYPE_CHECKING: - import torch - class MessageHub(ManagerMixin): """Message hub for component interaction. 
MessageHub is created and @@ -346,7 +345,8 @@ def _get_valid_value( if hasattr(value, 'numel') and value.numel() == 1: value = value.item() else: - print_log(f"MessageHub got unexpceted log: {value}", level=logging.WARN) + print_log(f"MessageHub got unexpceted log: {value}", + level=logging.WARN) return value # type: ignore def state_dict(self) -> dict: diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py index 54c546ca54..a76fd9730c 100644 --- a/mmengine/optim/optimizer/builder.py +++ b/mmengine/optim/optimizer/builder.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy import inspect +import warnings from typing import List, Union import torch @@ -9,6 +10,7 @@ from mmengine.config import Config, ConfigDict from mmengine.device import is_npu_available, is_npu_support_full_precision from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS + from .default_constructor import DefaultOptimWrapperConstructor from .optimizer_wrapper import OptimWrapper @@ -116,7 +118,7 @@ def register_sophia_optimizers() -> List[str]: Returns: List[str]: A list of registered optimizers' name. """ - optimizers = [] + optimizers: List[str] = [] try: import Sophia except ImportError: @@ -129,7 +131,8 @@ def register_sophia_optimizers() -> List[str]: try: OPTIMIZERS.register_module(module=_optim) except Exception as e: - warnings.warn(f"Failed to import {optim_cls.__name__} for {e}") + warnings.warn( + f"Failed to import {_optim.__name__} for {e}") return optimizers @@ -170,8 +173,8 @@ def register_bitsandbytes_optimizers() -> List[str]: BITSANDBYTES_OPTIMIZERS = register_bitsandbytes_optimizers() -def register_transformers_optimizers(): - transformer_optimizers = [] +def register_transformers_optimizers() -> List[str]: + transformer_optimizers: List[str] = [] try: from transformers import Adafactor except ImportError: @@ -180,7 +183,7 @@ def register_transformers_optimizers(): try: OPTIMIZERS.register_module(name='Adafactor', module=Adafactor) except Exception as e: - warnings.warn(f"Failed to import {optim_cls.__name__} for {e}") + warnings.warn(f"Failed to import {Adafactor.__name__} for {e}") transformer_optimizers.append('Adafactor') return transformer_optimizers @@ -216,12 +219,12 @@ def build_optim_wrapper(model: nn.Module, # to make the training normal if is_npu_available() and not is_npu_support_full_precision(): optim_wrapper_cfg['type'] = 'AmpOptimWrapper' - - constructor_cfg.update(dict( - optim_wrapper_cfg=optim_wrapper_cfg, - paramwise_cfg=paramwise_cfg - )) - optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(constructor_cfg) + constructor_cfg.update( + dict( + optim_wrapper_cfg=optim_wrapper_cfg, paramwise_cfg=paramwise_cfg)) + + optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build( + constructor_cfg) optim_wrapper = optim_wrapper_constructor(model) return optim_wrapper From 428bd84407eef4247515638754f47d8fb3124d7f Mon Sep 17 00:00:00 2001 From: mgam <312065559@qq.com> Date: Fri, 17 Oct 2025 14:42:20 +0800 Subject: [PATCH 24/35] Packing according to PyPA. Make pytest happy on new numpy. Using `importlib` instead of `pkg_resources` as the latter one is deprecated. Make mypy, flake, yapf and isort happy in python 3.13. 
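
The pkg_resources replacement follows the stdlib pattern sketched here
(importlib.metadata ships with Python >= 3.8; older interpreters can use
the importlib_metadata backport):

    from importlib.metadata import PackageNotFoundError, distribution

    def installed_version(package: str):
        """Return the installed version of ``package``, or None."""
        try:
            return distribution(package).version
        except PackageNotFoundError:
            return None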
--- .pre-commit-config.yaml | 4 +- mmengine/utils/package_utils.py | 66 +++++--- pyproject.toml | 66 +++++++- setup.py | 144 ------------------ .../lazy_module_config/test_ast_transform.py | 2 +- tests/test_config/test_lazy.py | 12 +- 6 files changed, 115 insertions(+), 179 deletions(-) delete mode 100644 setup.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c8edd013c6..81ddf48216 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,8 +12,8 @@ repos: rev: 5.11.5 hooks: - id: isort - - repo: https://github.com/pre-commit/mirrors-yapf - rev: v0.32.0 + - repo: https://github.com/google/yapf + rev: v0.43.0 hooks: - id: yapf - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/mmengine/utils/package_utils.py b/mmengine/utils/package_utils.py index 1816f47f07..65eef94bea 100644 --- a/mmengine/utils/package_utils.py +++ b/mmengine/utils/package_utils.py @@ -1,6 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp import subprocess +from typing import Any + +# Import distribution function with fallback for older Python versions +try: + from importlib.metadata import PackageNotFoundError, distribution +except ImportError: + from importlib_metadata import ( # type: ignore[import-untyped, no-redef, import-not-found] # noqa: E501 + PackageNotFoundError, distribution) def is_installed(package: str) -> bool: @@ -9,21 +17,16 @@ def is_installed(package: str) -> bool: Args: package (str): Name of package to be checked. """ - # When executing `import mmengine.runner`, - # pkg_resources will be imported and it takes too much time. - # Therefore, import it in function scope to save time. + # Use importlib.metadata instead of deprecated pkg_resources + # importlib.metadata is available in Python 3.8+ + # For Python 3.7, importlib_metadata backport can be used import importlib.util - import pkg_resources - from pkg_resources import get_distribution - - # refresh the pkg_resources - # more datails at https://github.com/pypa/setuptools/issues/373 - importlib.reload(pkg_resources) try: - get_distribution(package) + distribution(package) return True - except pkg_resources.DistributionNotFound: + except Exception: + # If distribution not found, check if module can be imported spec = importlib.util.find_spec(package) if spec is None: return False @@ -45,15 +48,31 @@ def get_installed_path(package: str) -> str: """ import importlib.util - from pkg_resources import DistributionNotFound, get_distribution - # if the package name is not the same as module name, module name should be # inferred. For example, mmcv-full is the package name, but mmcv is module # name. 
If we want to get the installed path of mmcv-full, we should concat # the pkg.location and module name try: - pkg = get_distribution(package) - except DistributionNotFound as e: + dist = distribution(package) + # In importlib.metadata, we use dist.locate_file() or files + if hasattr(dist, 'locate_file'): + # Python 3.9+ + # locate_file returns PathLike, need to access parent + locate_result: Any = dist.locate_file('') + location = str(locate_result.parent) + elif hasattr(dist, '_path'): + # Python 3.8 - _path is a pathlib.Path object + # We know _path exists because we checked with hasattr + dist_any: Any = dist + location = str(dist_any._path.parent) # type: ignore[attr-defined] + else: + # Fallback: try to find via importlib + spec = importlib.util.find_spec(package) + if spec is not None and spec.origin is not None: + return osp.dirname(spec.origin) + raise RuntimeError( + f'Cannot determine installation path for {package}') + except PackageNotFoundError as e: # if the package is not installed, package path set in PYTHONPATH # can be detected by `find_spec` spec = importlib.util.find_spec(package) @@ -69,23 +88,26 @@ def get_installed_path(package: str) -> str: else: raise e - possible_path = osp.join(pkg.location, package) # type: ignore + possible_path = osp.join(location, package) if osp.exists(possible_path): return possible_path else: - return osp.join(pkg.location, package2module(package)) # type: ignore + return osp.join(location, package2module(package)) -def package2module(package: str): +def package2module(package: str) -> str: """Infer module name from package. Args: package (str): Package to infer module name. """ - from pkg_resources import get_distribution - pkg = get_distribution(package) - if pkg.has_metadata('top_level.txt'): - module_name = pkg.get_metadata('top_level.txt').split('\n')[0] + dist = distribution(package) + + # In importlib.metadata, + # top-level modules are in dist.read_text('top_level.txt') + top_level_text = dist.read_text('top_level.txt') + if top_level_text: + module_name = top_level_text.split('\n')[0] return module_name else: raise ValueError(f'can not infer the module name of {package}') diff --git a/pyproject.toml b/pyproject.toml index e4710b2766..5895ecbc3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,8 @@ build-backend = "setuptools.build_meta" [project] name = "mmengine" -version = "0.10.7" +# Version is dynamically set from mmengine/version.py +dynamic = ["version"] description = "Engine of OpenMMLab projects" readme = "README.md" license = { text = "Apache License 2.0" } @@ -12,28 +13,85 @@ authors = [ { name = "MMEngine Authors", email = "openmmlab@gmail.com" }, { name = "MGAM", email = "312065559@qq.com" } ] -requires-python = ">=3.10" +requires-python = ">=3.7" classifiers = [ "Development Status :: 4 - Beta", "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Topic :: Utilities", ] keywords = ["OpenMMLab", "Engine"] +# Core dependencies from requirements/runtime.txt dependencies = [ "addict", "matplotlib", "numpy", "pyyaml", - "regex;sys_platform=='win32'", + "regex; sys_platform=='win32'", "rich", "termcolor", "yapf", ] +# Optional dependency groups +[project.optional-dependencies] +# All 
dependencies (runtime + tests) +all = [ + # Runtime dependencies are already included in base + # Test dependencies from requirements/tests.txt + # Note: aim excluded due to dependency issues (aimrocks not available) + # "aim<=3.17.5; sys_platform!='win32'", + "bitsandbytes", + "clearml", + "coverage", + "dadaptation", + "dvclive", + "lion-pytorch", + "lmdb", + "mlflow", + "parameterized", + "pydantic==1.10.9", + "pytest", + "transformers", +] +# Test dependencies only +tests = [ + "bitsandbytes", + "clearml", + "coverage", + "dadaptation", + "dvclive", + "lion-pytorch", + "lmdb", + "mlflow", + "parameterized", + "pydantic==1.10.9", + "pytest", + "transformers", +] + +[project.urls] +Homepage = "https://github.com/open-mmlab/mmengine" +Repository = "https://github.com/open-mmlab/mmengine" +Documentation = "https://mmengine.readthedocs.io" + +# Setuptools configuration +[tool.setuptools] +# Include package data files (similar to include_package_data=True) +include-package-data = true + [tool.setuptools.packages.find] where = ["."] -include = ["mmengine"] \ No newline at end of file +include = ["mmengine*"] +exclude = ["tests*", "docs*", "examples*"] + +# Dynamic version from mmengine/version.py +[tool.setuptools.dynamic] +version = {attr = "mmengine.version.__version__"} diff --git a/setup.py b/setup.py deleted file mode 100644 index 5b1f7fc803..0000000000 --- a/setup.py +++ /dev/null @@ -1,144 +0,0 @@ -import os -import re -from setuptools import find_packages, setup # type: ignore - -from pkg_resources import DistributionNotFound, get_distribution - - -def readme(): - with open('README.md', encoding='utf-8') as f: - content = f.read() - return content - - -version_file = 'mmengine/version.py' - - -def choose_requirement(primary, secondary): - """If some version of primary requirement installed, return primary, else - return secondary.""" - try: - name = re.split(r'[!<>=]', primary)[0] - get_distribution(name) - except DistributionNotFound: - return secondary - - return str(primary) - - -def get_version(): - with open(version_file) as f: - exec(compile(f.read(), version_file, 'exec')) - return locals()['__version__'] - - -def parse_requirements(fname='requirements/runtime.txt', with_version=True): - """Parse the package dependencies listed in a requirements file but strips - specific versioning information. 
- - Args: - fname (str): path to requirements file - with_version (bool, default=False): if True include version specs - - Returns: - List[str]: list of requirements items - - CommandLine: - python -c "import setup; print(setup.parse_requirements())" - """ - import re - import sys - from os.path import exists - require_fpath = fname - - def parse_line(line): - """Parse information from a line in a requirements text file.""" - if line.startswith('-r '): - # Allow specifying requirements in other files - target = line.split(' ')[1] - for info in parse_require_file(target): - yield info - else: - info = {'line': line} - if line.startswith('-e '): - info['package'] = line.split('#egg=')[1] - else: - # Remove versioning from the package - pat = '(' + '|'.join(['>=', '==', '>']) + ')' - parts = re.split(pat, line, maxsplit=1) - parts = [p.strip() for p in parts] - - info['package'] = parts[0] - if len(parts) > 1: - op, rest = parts[1:] - if ';' in rest: - # Handle platform specific dependencies - # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies - version, platform_deps = map(str.strip, - rest.split(';')) - info['platform_deps'] = platform_deps - else: - version = rest # NOQA - info['version'] = (op, version) - yield info - - def parse_require_file(fpath): - with open(fpath) as f: - for line in f.readlines(): - line = line.strip() - if line and not line.startswith('#'): - yield from parse_line(line) - - def gen_packages_items(): - if exists(require_fpath): - for info in parse_require_file(require_fpath): - parts = [info['package']] - if with_version and 'version' in info: - parts.extend(info['version']) - if not sys.version.startswith('3.4'): - # apparently package_deps are broken in 3.4 - platform_deps = info.get('platform_deps') - if platform_deps is not None: - parts.append(';' + platform_deps) - item = ''.join(parts) - yield item - - packages = list(gen_packages_items()) - return packages - - -if int(os.getenv('MMENGINE_LITE', '0')) == 1: - install_requires = parse_requirements('requirements/runtime_lite.txt') -else: - install_requires = parse_requirements() - -setup( - name='mmengine' - if os.getenv('MMENGINE_LITE', '0') == '0' else 'mmengine-lite', - version=get_version(), - description='Engine of OpenMMLab projects', - long_description=readme(), - long_description_content_type='text/markdown', - url='https://github.com/open-mmlab/mmengine', - author='MMEngine Authors', - author_email='openmmlab@gmail.com', - packages=find_packages(), - include_package_data=True, - classifiers=[ - 'Development Status :: 4 - Beta', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Topic :: Utilities', - ], - python_requires='>=3.7', - install_requires=install_requires, - extras_require={ - 'all': parse_requirements('requirements.txt'), - 'tests': parse_requirements('requirements/tests.txt'), - }, -) diff --git a/tests/data/config/lazy_module_config/test_ast_transform.py b/tests/data/config/lazy_module_config/test_ast_transform.py index a8803dde24..6f0ada1736 100644 --- a/tests/data/config/lazy_module_config/test_ast_transform.py +++ b/tests/data/config/lazy_module_config/test_ast_transform.py @@ -3,7 +3,7 @@ from importlib.util import find_spec as find_module import numpy -import numpy.compat 
+import numpy.fft import numpy.linalg as linalg from mmengine.config import Config diff --git a/tests/test_config/test_lazy.py b/tests/test_config/test_lazy.py index d69822814b..1dda04fdaa 100644 --- a/tests/test_config/test_lazy.py +++ b/tests/test_config/test_lazy.py @@ -8,7 +8,7 @@ from unittest import TestCase import numpy -import numpy.compat +import numpy.fft import numpy.linalg as linalg from rich.progress import Progress @@ -56,17 +56,17 @@ def test_lazy_module(self): # 1.2 getattr as LazyAttr self.assertIsInstance(lazy_numpy.linalg, LazyAttr) - self.assertIsInstance(lazy_numpy.compat, LazyAttr) + self.assertIsInstance(lazy_numpy.fft, LazyAttr) - # 1.3 Build module from LazyObject. amp and functional can be accessed + # 1.3 Build module from LazyObject. linalg and fft can be accessed imported_numpy = lazy_numpy.build() self.assertIs(imported_numpy.linalg, linalg) - self.assertIs(imported_numpy.compat, numpy.compat) + self.assertIs(imported_numpy.fft, numpy.fft) # 1.4.1 Build module from LazyAttr imported_linalg = lazy_numpy.linalg.build() - imported_compat = lazy_numpy.compat.build() - self.assertIs(imported_compat, numpy.compat) + imported_fft = lazy_numpy.fft.build() + self.assertIs(imported_fft, numpy.fft) self.assertIs(imported_linalg, linalg) # 1.4.2 build class method from LazyAttr From c19a00c02270d69fa875cc490d33fb527560132f Mon Sep 17 00:00:00 2001 From: mgam <312065559@qq.com> Date: Fri, 17 Oct 2025 16:01:51 +0800 Subject: [PATCH 25/35] Make Lint Happy. --- .pre-commit-config.yaml | 8 +- examples/distributed_training.py | 27 +- ...stributed_training_with_flexible_runner.py | 77 ++- examples/llama2/fsdp_finetune.py | 40 +- examples/llama2/generate.py | 7 +- examples/segmentation/train.py | 76 ++- examples/test_time_augmentation.py | 17 +- examples/text_classification/train.py | 65 +-- examples/text_translation/train.py | 32 +- mmengine/_strategy/__init__.py | 1 + mmengine/_strategy/base.py | 25 +- mmengine/_strategy/colossalai.py | 44 +- mmengine/_strategy/deepspeed.py | 24 +- mmengine/_strategy/distributed.py | 29 +- mmengine/_strategy/fsdp.py | 20 +- mmengine/_strategy/single_device.py | 11 +- mmengine/analysis/complexity_analysis.py | 4 +- mmengine/analysis/jit_analysis.py | 8 +- mmengine/analysis/print_helper.py | 1 + mmengine/config/config.py | 92 +-- mmengine/dataset/dataset_wrapper.py | 1 + mmengine/dataset/utils.py | 3 +- mmengine/dist/dist.py | 26 +- mmengine/dist/utils.py | 27 +- mmengine/evaluator/evaluator.py | 1 + mmengine/evaluator/metric.py | 18 +- mmengine/fileio/backends/base.py | 7 +- mmengine/fileio/backends/lmdb_backend.py | 11 +- mmengine/fileio/backends/local_backend.py | 5 +- mmengine/fileio/backends/petrel_backend.py | 6 +- mmengine/fileio/file_client.py | 17 +- mmengine/fileio/handlers/registry_utils.py | 1 + mmengine/fileio/io.py | 101 ++-- mmengine/hooks/checkpoint_hook.py | 45 +- mmengine/hooks/early_stopping_hook.py | 1 + mmengine/hooks/ema_hook.py | 33 +- mmengine/hooks/empty_cache_hook.py | 1 + mmengine/hooks/hook.py | 51 +- mmengine/hooks/iter_timer_hook.py | 5 +- mmengine/hooks/logger_hook.py | 24 +- mmengine/hooks/param_scheduler_hook.py | 1 + mmengine/hooks/runtime_info_hook.py | 10 +- mmengine/hooks/sampler_seed_hook.py | 1 + mmengine/hooks/sync_buffer_hook.py | 1 + mmengine/infer/infer.py | 14 +- mmengine/logging/logger.py | 6 +- mmengine/logging/message_hub.py | 7 +- mmengine/model/__init__.py | 1 + mmengine/model/averaged_model.py | 21 +- mmengine/model/base_model/base_model.py | 1 + .../model/base_model/data_preprocessor.py | 1 + 
mmengine/model/base_module.py | 10 +- mmengine/model/efficient_conv_bn_eval.py | 10 +- mmengine/model/test_time_aug.py | 8 +- mmengine/model/utils.py | 7 +- mmengine/model/weight_init.py | 52 +- mmengine/model/wrappers/__init__.py | 1 + mmengine/model/wrappers/distributed.py | 1 + .../wrappers/fully_sharded_distributed.py | 21 +- .../model/wrappers/seperate_distributed.py | 1 + .../optim/optimizer/amp_optimizer_wrapper.py | 1 + .../optim/optimizer/apex_optimizer_wrapper.py | 1 + mmengine/optim/optimizer/builder.py | 7 +- .../optim/optimizer/default_constructor.py | 24 +- mmengine/optim/optimizer/optimizer_wrapper.py | 1 + mmengine/optim/scheduler/lr_scheduler.py | 1 + .../optim/scheduler/momentum_scheduler.py | 1 + mmengine/optim/scheduler/param_scheduler.py | 205 ++++--- mmengine/registry/build_functions.py | 1 + mmengine/registry/registry.py | 1 + mmengine/registry/root.py | 4 +- mmengine/registry/utils.py | 5 +- mmengine/runner/_flexible_runner.py | 117 ++-- mmengine/runner/amp.py | 14 +- mmengine/runner/checkpoint.py | 51 +- mmengine/runner/log_processor.py | 7 +- mmengine/runner/loops.py | 63 ++- mmengine/runner/runner.py | 156 +++-- mmengine/structures/base_data_element.py | 6 +- mmengine/structures/instance_data.py | 1 + mmengine/testing/compare.py | 24 +- mmengine/testing/runner_test_case.py | 51 +- mmengine/utils/dl_utils/collect_env.py | 10 +- mmengine/utils/dl_utils/hub.py | 6 +- mmengine/utils/dl_utils/torch_ops.py | 6 +- mmengine/utils/dl_utils/visualize.py | 10 +- mmengine/utils/progressbar_rich.py | 5 +- mmengine/utils/version_utils.py | 7 +- mmengine/visualization/vis_backend.py | 29 +- mmengine/visualization/visualizer.py | 127 ++--- tests/test_analysis/test_flop_count.py | 14 +- tests/test_analysis/test_jit_analysis.py | 80 ++- tests/test_analysis/test_print_helper.py | 26 +- tests/test_config/test_config.py | 168 +++--- tests/test_data/test_data_utils.py | 44 +- tests/test_dataset/test_base_dataset.py | 498 ++++++++-------- tests/test_dataset/test_sampler.py | 5 +- tests/test_dist/test_dist.py | 57 +- tests/test_dist/test_utils.py | 34 +- tests/test_evaluator/test_evaluator.py | 38 +- tests/test_evaluator/test_metric.py | 7 +- .../test_backends/test_backend_utils.py | 15 +- .../test_backends/test_local_backend.py | 156 +++-- .../test_backends/test_petrel_backend.py | 212 ++++--- tests/test_fileio/test_fileclient.py | 152 ++--- tests/test_fileio/test_fileio.py | 57 +- tests/test_fileio/test_io.py | 70 ++- tests/test_hooks/test_checkpoint_hook.py | 144 ++--- tests/test_hooks/test_early_stopping_hook.py | 73 +-- tests/test_hooks/test_ema_hook.py | 18 +- tests/test_hooks/test_empty_cache_hook.py | 7 +- tests/test_hooks/test_logger_hook.py | 37 +- .../test_naive_visualization_hook.py | 43 +- tests/test_hooks/test_prepare_tta_hook.py | 10 +- tests/test_hooks/test_profiler_hook.py | 76 ++- tests/test_hooks/test_runtime_info_hook.py | 38 +- tests/test_hooks/test_sync_buffers_hook.py | 5 +- tests/test_hub/test_hub.py | 13 +- tests/test_infer/test_infer.py | 21 +- tests/test_logging/test_logger.py | 54 +- tests/test_logging/test_message_hub.py | 28 +- tests/test_model/test_averaged_model.py | 65 +-- .../test_base_model/test_base_model.py | 7 +- .../test_base_model/test_data_preprocessor.py | 67 ++- tests/test_model/test_base_module.py | 124 ++-- .../test_model/test_efficient_conv_bn_eval.py | 5 +- tests/test_model/test_model_utils.py | 8 +- tests/test_model/test_test_aug_time.py | 14 +- .../test_wrappers/test_model_wrapper.py | 52 +- .../test_optimizer/test_optimizer.py | 404 
++++++------- .../test_optimizer/test_optimizer_wrapper.py | 85 +-- .../test_optimizer_wrapper_dict.py | 23 +- .../test_scheduler/test_lr_scheduler.py | 285 +++++----- .../test_scheduler/test_momentum_scheduler.py | 304 +++++----- .../test_scheduler/test_param_scheduler.py | 535 +++++++++--------- tests/test_registry/test_build_functions.py | 41 +- tests/test_registry/test_registry.py | 64 ++- tests/test_runner/test_checkpoint.py | 31 +- tests/test_runner/test_log_processor.py | 76 ++- tests/test_runner/test_runner.py | 460 +++++++-------- tests/test_strategies/test_fsdp.py | 84 ++- tests/test_structures/test_data_element.py | 41 +- tests/test_structures/test_instance_data.py | 36 +- tests/test_structures/test_label_data.py | 13 +- tests/test_structures/test_pixel_data.py | 6 +- tests/test_testing/test_runner_test_case.py | 4 +- .../test_dl_utils/test_setup_env.py | 5 +- tests/test_utils/test_misc.py | 4 +- tests/test_utils/test_package_utils.py | 8 +- tests/test_utils/test_progressbar.py | 55 +- tests/test_utils/test_timer.py | 12 +- tests/test_visualizer/test_vis_backend.py | 28 +- tests/test_visualizer/test_visualizer.py | 424 +++++++------- 153 files changed, 3891 insertions(+), 3788 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 81ddf48216..73dbbb8b71 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,10 +12,14 @@ repos: rev: 5.11.5 hooks: - id: isort - - repo: https://github.com/google/yapf - rev: v0.43.0 + - repo: local hooks: - id: yapf + name: yapf + entry: yapf + language: system + types: [python] + args: ["-i"] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v5.0.0 hooks: diff --git a/examples/distributed_training.py b/examples/distributed_training.py index 236bee234c..c9af4929fa 100644 --- a/examples/distributed_training.py +++ b/examples/distributed_training.py @@ -42,11 +42,10 @@ def compute_metrics(self, results): def parse_args(): parser = argparse.ArgumentParser(description='Distributed Training') - parser.add_argument( - '--launcher', - choices=['none', 'pytorch', 'slurm', 'mpi'], - default='none', - help='job launcher') + parser.add_argument('--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() @@ -73,16 +72,14 @@ def main(): transform=transforms.Compose( [transforms.ToTensor(), transforms.Normalize(**norm_cfg)])) - train_dataloader = dict( - batch_size=32, - dataset=train_set, - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=dict(type='default_collate')) - val_dataloader = dict( - batch_size=32, - dataset=valid_set, - sampler=dict(type='DefaultSampler', shuffle=False), - collate_fn=dict(type='default_collate')) + train_dataloader = dict(batch_size=32, + dataset=train_set, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate')) + val_dataloader = dict(batch_size=32, + dataset=valid_set, + sampler=dict(type='DefaultSampler', shuffle=False), + collate_fn=dict(type='default_collate')) runner = Runner( model=MMResNet50(), work_dir='./work_dirs', diff --git a/examples/distributed_training_with_flexible_runner.py b/examples/distributed_training_with_flexible_runner.py index 99d2cf257d..43772fbf81 100644 --- a/examples/distributed_training_with_flexible_runner.py +++ b/examples/distributed_training_with_flexible_runner.py @@ -70,16 +70,14 @@ def main(): transform=transforms.Compose( [transforms.ToTensor(), 
transforms.Normalize(**norm_cfg)])) - train_dataloader = dict( - batch_size=128, - dataset=train_set, - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=dict(type='default_collate')) - val_dataloader = dict( - batch_size=128, - dataset=valid_set, - sampler=dict(type='DefaultSampler', shuffle=False), - collate_fn=dict(type='default_collate')) + train_dataloader = dict(batch_size=128, + dataset=train_set, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate')) + val_dataloader = dict(batch_size=128, + dataset=valid_set, + sampler=dict(type='DefaultSampler', shuffle=False), + collate_fn=dict(type='default_collate')) if args.use_deepspeed: strategy = dict( @@ -97,30 +95,28 @@ def main(): # bf16=dict( # enabled=True, # ), - zero_optimization=dict( - stage=3, - allgather_partitions=True, - reduce_scatter=True, - allgather_bucket_size=50000000, - reduce_bucket_size=50000000, - overlap_comm=True, - contiguous_gradients=True, - cpu_offload=False), + zero_optimization=dict(stage=3, + allgather_partitions=True, + reduce_scatter=True, + allgather_bucket_size=50000000, + reduce_bucket_size=50000000, + overlap_comm=True, + contiguous_gradients=True, + cpu_offload=False), ) - optim_wrapper = dict( - type='DeepSpeedOptimWrapper', - optimizer=dict(type='AdamW', lr=1e-3)) + optim_wrapper = dict(type='DeepSpeedOptimWrapper', + optimizer=dict(type='AdamW', lr=1e-3)) elif args.use_fsdp: from functools import partial from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy - size_based_auto_wrap_policy = partial( - size_based_auto_wrap_policy, min_num_params=1e7) + size_based_auto_wrap_policy = partial(size_based_auto_wrap_policy, + min_num_params=1e7) strategy = dict( type='FSDPStrategy', model_wrapper=dict(auto_wrap_policy=size_based_auto_wrap_policy)) - optim_wrapper = dict( - type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3)) + optim_wrapper = dict(type='AmpOptimWrapper', + optimizer=dict(type='AdamW', lr=1e-3)) elif args.use_colossalai: from colossalai.tensor.op_wrapper import colo_op_impl @@ -142,20 +138,21 @@ def main(): optim_wrapper = dict(optimizer=dict(type='HybridAdam', lr=1e-3)) else: strategy = None - optim_wrapper = dict( - type='AmpOptimWrapper', optimizer=dict(type='AdamW', lr=1e-3)) - - runner = FlexibleRunner( - model=MMResNet50(), - work_dir='./work_dirs', - strategy=strategy, - train_dataloader=train_dataloader, - optim_wrapper=optim_wrapper, - param_scheduler=dict(type='LinearLR'), - train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=1), - val_dataloader=val_dataloader, - val_cfg=dict(), - val_evaluator=dict(type=Accuracy)) + optim_wrapper = dict(type='AmpOptimWrapper', + optimizer=dict(type='AdamW', lr=1e-3)) + + runner = FlexibleRunner(model=MMResNet50(), + work_dir='./work_dirs', + strategy=strategy, + train_dataloader=train_dataloader, + optim_wrapper=optim_wrapper, + param_scheduler=dict(type='LinearLR'), + train_cfg=dict(by_epoch=True, + max_epochs=10, + val_interval=1), + val_dataloader=val_dataloader, + val_cfg=dict(), + val_evaluator=dict(type=Accuracy)) runner.train() diff --git a/examples/llama2/fsdp_finetune.py b/examples/llama2/fsdp_finetune.py index 0d7e2751b7..d1879c9e1c 100644 --- a/examples/llama2/fsdp_finetune.py +++ b/examples/llama2/fsdp_finetune.py @@ -92,17 +92,14 @@ def parse_args(): def train(): args = parse_args() # Setup distributed related component in Strategy. 
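# A minimal sketch, assuming the functools.partial,
# transformer_auto_wrap_policy and LlamaDecoderLayer imports this file
# already uses, of the auto-wrap idiom that the hunk below merely reflows:
#
#     policy = partial(transformer_auto_wrap_policy,
#                      transformer_layer_cls={LlamaDecoderLayer})
#
# FSDP treats each class listed in `transformer_layer_cls` as its own wrap
# unit, so parameters are gathered and freed one decoder block at a time
# rather than for the whole model at once.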
- strategy = FSDPStrategy( - model_wrapper=dict( - auto_wrap_policy=partial( - transformer_auto_wrap_policy, - transformer_layer_cls={LlamaDecoderLayer})), - state_dict_cfg='full', - env_kwargs=dict(randomness=dict(seed=42))) - visualizer = Visualizer( - name='mmengine', - save_dir=args.output_dir, - vis_backends=[dict(type=WandbVisBackend)]) + strategy = FSDPStrategy(model_wrapper=dict( + auto_wrap_policy=partial(transformer_auto_wrap_policy, + transformer_layer_cls={LlamaDecoderLayer})), + state_dict_cfg='full', + env_kwargs=dict(randomness=dict(seed=42))) + visualizer = Visualizer(name='mmengine', + save_dir=args.output_dir, + vis_backends=[dict(type=WandbVisBackend)]) # Prepare model tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint) @@ -112,21 +109,20 @@ def train(): model.train() # Prepare dataset - train_dataset = AlpacaDataset( - tokenizer=tokenizer, data_path=args.data_root) - train_dataloader = DataLoader( - train_dataset, - batch_size=args.batch_size, - sampler=DefaultSampler(train_dataset, seed=0), - collate_fn=default_data_collator, - drop_last=True) + train_dataset = AlpacaDataset(tokenizer=tokenizer, + data_path=args.data_root) + train_dataloader = DataLoader(train_dataset, + batch_size=args.batch_size, + sampler=DefaultSampler(train_dataset, + seed=0), + collate_fn=default_data_collator, + drop_last=True) # Get the prepared model, scheduler and optimizer from strategy epoch_length = len(train_dataloader) max_iters = epoch_length * args.max_epoch - optim_cfg = dict( - optimizer=dict(type=AdamW, lr=1e-4, weight_decay=0.0), - accumulative_counts=ORI_BATCH_SIZE / args.batch_size) + optim_cfg = dict(optimizer=dict(type=AdamW, lr=1e-4, weight_decay=0.0), + accumulative_counts=ORI_BATCH_SIZE / args.batch_size) scheduler_cfgs = [dict(type=StepLR, step_size=1, gamma=0.85)] model, optimizer, schedulers = strategy.prepare( model, diff --git a/examples/llama2/generate.py b/examples/llama2/generate.py index 85635c37ae..83f80ccaa5 100644 --- a/examples/llama2/generate.py +++ b/examples/llama2/generate.py @@ -30,7 +30,6 @@ def parse_args(): with torch.no_grad(): generate_ids = model.generate(inputs.input_ids.cuda(), max_length=300) print( - tokenizer.batch_decode( - generate_ids, - skip_special_tokens=True, - clean_up_tokenization_spaces=False)[0]) + tokenizer.batch_decode(generate_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False)[0]) diff --git a/examples/segmentation/train.py b/examples/segmentation/train.py index dc045a18b9..a26654952f 100644 --- a/examples/segmentation/train.py +++ b/examples/segmentation/train.py @@ -40,8 +40,9 @@ def __init__(self, mask_folder, transform=None, target_transform=None): - super().__init__( - root, transform=transform, target_transform=target_transform) + super().__init__(root, + transform=transform, + target_transform=target_transform) self.img_folder = img_folder self.mask_folder = mask_folder self.images = list( @@ -72,8 +73,9 @@ def __getitem__(self, index): if self.target_transform is not None: labels = self.target_transform(labels) - data_samples = dict( - labels=labels, img_path=img_path, mask_path=mask_path) + data_samples = dict(labels=labels, + img_path=img_path, + mask_path=mask_path) return img, data_samples def __len__(self): @@ -102,8 +104,8 @@ def process(self, data_batch, data_samples): intersect = (labels == preds).sum() union = (torch.logical_or(preds, labels)).sum() iou = (intersect / union).cpu() - self.results.append( - dict(batch_size=len(labels), iou=iou * len(labels))) + 
self.results.append(dict(batch_size=len(labels), + iou=iou * len(labels))) def compute_metrics(self, results): total_iou = sum(result['iou'] for result in self.results) @@ -151,18 +153,16 @@ def after_val_iter(self, osp.join(saved_dir, osp.basename(img_path))) shutil.copyfile(mask_path, osp.join(saved_dir, osp.basename(mask_path))) - cv2.imwrite( - osp.join(saved_dir, f'pred_{osp.basename(img_path)}'), - pred_mask) + cv2.imwrite(osp.join(saved_dir, f'pred_{osp.basename(img_path)}'), + pred_mask) def parse_args(): parser = argparse.ArgumentParser(description='Distributed Training') - parser.add_argument( - '--launcher', - choices=['none', 'pytorch', 'slurm', 'mpi'], - default='none', - help='job launcher') + parser.add_argument('--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() @@ -181,37 +181,33 @@ def main(): target_transform = transforms.Lambda( lambda x: torch.tensor(np.array(x), dtype=torch.long)) - train_set = CamVid( - 'data/CamVid', - img_folder='train', - mask_folder='train_labels', - transform=transform, - target_transform=target_transform) - - valid_set = CamVid( - 'data/CamVid', - img_folder='val', - mask_folder='val_labels', - transform=transform, - target_transform=target_transform) - - train_dataloader = dict( - batch_size=3, - dataset=train_set, - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=dict(type='default_collate')) - val_dataloader = dict( - batch_size=3, - dataset=valid_set, - sampler=dict(type='DefaultSampler', shuffle=False), - collate_fn=dict(type='default_collate')) + train_set = CamVid('data/CamVid', + img_folder='train', + mask_folder='train_labels', + transform=transform, + target_transform=target_transform) + + valid_set = CamVid('data/CamVid', + img_folder='val', + mask_folder='val_labels', + transform=transform, + target_transform=target_transform) + + train_dataloader = dict(batch_size=3, + dataset=train_set, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate')) + val_dataloader = dict(batch_size=3, + dataset=valid_set, + sampler=dict(type='DefaultSampler', shuffle=False), + collate_fn=dict(type='default_collate')) runner = Runner( model=MMDeeplabV3(num_classes), work_dir='./work_dir', train_dataloader=train_dataloader, - optim_wrapper=dict( - type=AmpOptimWrapper, optimizer=dict(type=AdamW, lr=2e-4)), + optim_wrapper=dict(type=AmpOptimWrapper, + optimizer=dict(type=AdamW, lr=2e-4)), train_cfg=dict(by_epoch=True, max_epochs=10, val_interval=10), val_dataloader=val_dataloader, val_cfg=dict(), diff --git a/examples/test_time_augmentation.py b/examples/test_time_augmentation.py index 0a896a05a2..f2ed739c22 100644 --- a/examples/test_time_augmentation.py +++ b/examples/test_time_augmentation.py @@ -28,15 +28,14 @@ def _merge_single_sample(self, data_samples): cfg.work_dir = 'work_dirs/resnet50_8xb16_cifar10' cfg.model = dict(type='ClsTTAModel', module=cfg.model) test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline) - flip_tta = dict( - type='TestTimeAug', - transforms=[ - [ - dict(type='RandomFlip', prob=1.), - dict(type='RandomFlip', prob=0.) - ], - [test_pipeline[-1]], - ]) + flip_tta = dict(type='TestTimeAug', + transforms=[ + [ + dict(type='RandomFlip', prob=1.), + dict(type='RandomFlip', prob=0.) 
+ ], + [test_pipeline[-1]], + ]) # Replace the last transform with `TestTimeAug` cfg.test_dataloader.dataset.pipeline[-1] = flip_tta cfg.load_from = 'https://download.openmmlab.com/mmclassification/v0' \ diff --git a/examples/text_classification/train.py b/examples/text_classification/train.py index 84a2841729..81e1d17ba3 100644 --- a/examples/text_classification/train.py +++ b/examples/text_classification/train.py @@ -17,11 +17,10 @@ def __init__(self, model): self.model = model def forward(self, label, input_ids, token_type_ids, attention_mask, mode): - output = self.model( - input_ids=input_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask, - labels=label) + output = self.model(input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + labels=label) if mode == 'loss': return {'loss': output.loss} elif mode == 'predict': @@ -45,11 +44,10 @@ def compute_metrics(self, results): def parse_args(): parser = argparse.ArgumentParser(description='Distributed Training') - parser.add_argument( - '--launcher', - choices=['none', 'pytorch', 'slurm', 'mpi'], - default='none', - help='job launcher') + parser.add_argument('--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') parser.add_argument('--local_rank', type=int, default=0) args = parser.parse_args() @@ -71,41 +69,36 @@ def collate_fn(data): token_type_ids = torch.stack(token_type_ids) attention_mask = torch.stack(attention_mask) label = torch.tensor(labels) - return dict( - label=label, - input_ids=input_ids, - token_type_ids=token_type_ids, - attention_mask=attention_mask) + return dict(label=label, + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask) def main(): args = parse_args() - model = BertForSequenceClassification.from_pretrained( - 'bert-base-uncased', num_labels=2) + model = BertForSequenceClassification.from_pretrained('bert-base-uncased', + num_labels=2) tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') train_set = load_dataset('imdb', split='train') test_set = load_dataset('imdb', split='test') - train_set = train_set.map( - lambda x: tokenizer( - x['text'], truncation=True, padding=True, max_length=128), - batched=True) - test_set = test_set.map( - lambda x: tokenizer( - x['text'], truncation=True, padding=True, max_length=128), - batched=True) - - train_loader = dict( - batch_size=32, - dataset=train_set, - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=collate_fn) - test_loader = dict( - batch_size=32, - dataset=test_set, - sampler=dict(type='DefaultSampler', shuffle=False), - collate_fn=collate_fn) + train_set = train_set.map(lambda x: tokenizer( + x['text'], truncation=True, padding=True, max_length=128), + batched=True) + test_set = test_set.map(lambda x: tokenizer( + x['text'], truncation=True, padding=True, max_length=128), + batched=True) + + train_loader = dict(batch_size=32, + dataset=train_set, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=collate_fn) + test_loader = dict(batch_size=32, + dataset=test_set, + sampler=dict(type='DefaultSampler', shuffle=False), + collate_fn=collate_fn) runner = Runner( model=MMBertForClassify(model), train_dataloader=train_loader, diff --git a/examples/text_translation/train.py b/examples/text_translation/train.py index 61f43bafef..12cc11455a 100644 --- a/examples/text_translation/train.py +++ b/examples/text_translation/train.py @@ -19,10 +19,9 @@ def __init__(self, model): def forward(self, label, 
input_ids, attention_mask, mode): if mode == 'loss': - output = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - labels=label) + output = self.model(input_ids=input_ids, + attention_mask=attention_mask, + labels=label) return {'loss': output.loss} elif mode == 'predict': output = self.model.generate(input_ids) @@ -80,10 +79,9 @@ def collate_fn(data): ).input_ids label[label == tokenizer.pad_token_id] = -100 # ignore contribution to loss - return dict( - label=label, - input_ids=input_dict.input_ids, - attention_mask=input_dict.attention_mask) + return dict(label=label, + input_ids=input_dict.input_ids, + attention_mask=input_dict.attention_mask) def main(): @@ -93,16 +91,14 @@ def main(): books = books['train'].train_test_split(test_size=0.2) train_set, test_set = books['train'], books['test'] - train_loader = dict( - batch_size=16, - dataset=train_set, - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=collate_fn) - test_loader = dict( - batch_size=32, - dataset=test_set, - sampler=dict(type='DefaultSampler', shuffle=False), - collate_fn=collate_fn) + train_loader = dict(batch_size=16, + dataset=train_set, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=collate_fn) + test_loader = dict(batch_size=32, + dataset=test_set, + sampler=dict(type='DefaultSampler', shuffle=False), + collate_fn=collate_fn) runner = Runner( model=MMT5ForTranslation(model), train_dataloader=train_loader, diff --git a/mmengine/_strategy/__init__.py b/mmengine/_strategy/__init__.py index 764abcf868..2e1a3b2c19 100644 --- a/mmengine/_strategy/__init__.py +++ b/mmengine/_strategy/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION + from .base import BaseStrategy from .colossalai import ColossalAIStrategy from .deepspeed import DeepSpeedStrategy diff --git a/mmengine/_strategy/base.py b/mmengine/_strategy/base.py index b555df9e94..af444a0d99 100644 --- a/mmengine/_strategy/base.py +++ b/mmengine/_strategy/base.py @@ -270,10 +270,9 @@ def _set_randomness( more details. """ from mmengine.runner import set_random_seed - self._seed = set_random_seed( - seed=seed, - deterministic=deterministic, - diff_rank_seed=diff_rank_seed) + self._seed = set_random_seed(seed=seed, + deterministic=deterministic, + diff_rank_seed=diff_rank_seed) def build_model(self, model: Union[nn.Module, dict]) -> nn.Module: """Build model. @@ -322,8 +321,8 @@ def compile_model( Returns: nn.Module: Compiled model. 
""" - if isinstance(compile, bool) and not compile or \ - isinstance(compile, dict) and not compile.get('disable', False): + if isinstance(compile, bool) and not compile or \ + isinstance(compile, dict) and not compile.get('disable', False): return model assert digit_version(TORCH_VERSION) >= digit_version('2.0.0'), ( @@ -561,10 +560,10 @@ def _build_param_scheduler( 'Use the max epochs/iters of train loop as default.') param_schedulers.append( - PARAM_SCHEDULERS.build( - _scheduler, - default_args=dict( - optimizer=optim_wrapper, **default_args))) + PARAM_SCHEDULERS.build(_scheduler, + default_args=dict( + optimizer=optim_wrapper, + **default_args))) else: raise TypeError( 'scheduler should be a _ParamScheduler object or dict, ' @@ -800,8 +799,10 @@ def load_model_state_dict( else: model = self.model - _load_checkpoint_to_model( - model, state_dict, strict=strict, revise_keys=revise_keys) + _load_checkpoint_to_model(model, + state_dict, + strict=strict, + revise_keys=revise_keys) def load_optim_state_dict(self, state_dict: dict) -> None: """Load optimizer state from dict.""" diff --git a/mmengine/_strategy/colossalai.py b/mmengine/_strategy/colossalai.py index 13d9f38fc3..1a2eac6143 100644 --- a/mmengine/_strategy/colossalai.py +++ b/mmengine/_strategy/colossalai.py @@ -365,8 +365,9 @@ def resume( directly.""" self.logger.info(f'Resume checkpoint from {filename}') - extra_ckpt = self.load_checkpoint( - filename, map_location=map_location, callback=callback) + extra_ckpt = self.load_checkpoint(filename, + map_location=map_location, + callback=callback) if resume_optimizer: self.booster.load_optimizer( @@ -438,10 +439,11 @@ def save_checkpoint( extra_ckpt = dict() if 'meta' not in extra_ckpt: extra_ckpt['meta'] = dict() - extra_ckpt['meta'].update( - seed=self.seed, - time=time.strftime('%Y%m%d_%H%M%S', time.localtime()), - mmengine=mmengine.__version__ + get_git_hash()) + extra_ckpt['meta'].update(seed=self.seed, + time=time.strftime('%Y%m%d_%H%M%S', + time.localtime()), + mmengine=mmengine.__version__ + + get_git_hash()) model_dir = join_path(filename, self.MODEL_DIR) optimizer_dir = join_path(filename, self.OPTIMIZER_DIR) @@ -450,14 +452,14 @@ def save_checkpoint( mkdir_or_exist(optimizer_dir) mkdir_or_exist(schedulers_dir) - self.booster.save_model( - self.model.model_wrapper, checkpoint=model_dir, shard=True) + self.booster.save_model(self.model.model_wrapper, + checkpoint=model_dir, + shard=True) if save_optimizer: - self.booster.save_optimizer( - self.optim_wrapper.optimizer, - checkpoint=optimizer_dir, - shard=True) + self.booster.save_optimizer(self.optim_wrapper.optimizer, + checkpoint=optimizer_dir, + shard=True) if is_main_process() and save_param_scheduler: for i, scheduler in enumerate(self.param_schedulers): @@ -470,8 +472,8 @@ def _build_plugin(self, plugin: Union[str, dict]): if isinstance(plugin, str): if plugin == 'gemini': try: - plugin = colo_plugin.GeminiPlugin( - precision='bf16', placement_policy='auto') + plugin = colo_plugin.GeminiPlugin(precision='bf16', + placement_policy='auto') except AssertionError: from colossalai.zero.gemini.placement_policy import \ PlacementPolicyFactory as colo_placement @@ -545,14 +547,14 @@ def _wrap( model_wrapper, optimizer, *_ = self.booster.boost(model, optimizer) optim_wrapper.optimizer = optimizer default_args = {'model_wrapper': model_wrapper, 'model': model} - model_wrapper = MODEL_WRAPPERS.build( - self.model_wrapper, default_args=default_args) + model_wrapper = MODEL_WRAPPERS.build(self.model_wrapper, + 
default_args=default_args) return model_wrapper, optim_wrapper # type: ignore else: model_wrapper, *_ = self.booster.boost(model) default_args = {'model_wrapper': model_wrapper, 'model': model} - model_wrapper = MODEL_WRAPPERS.build( - self.model_wrapper, default_args=default_args) + model_wrapper = MODEL_WRAPPERS.build(self.model_wrapper, + default_args=default_args) return model_wrapper def _setup_distributed( # type: ignore @@ -561,5 +563,7 @@ def _setup_distributed( # type: ignore backend: str = 'nccl', **kwargs, ): - init_dist( - launcher, backend, init_backend='colossalai', config=self.config) + init_dist(launcher, + backend, + init_backend='colossalai', + config=self.config) diff --git a/mmengine/_strategy/deepspeed.py b/mmengine/_strategy/deepspeed.py index 3f89ff760d..7ff827e9d8 100644 --- a/mmengine/_strategy/deepspeed.py +++ b/mmengine/_strategy/deepspeed.py @@ -24,6 +24,7 @@ STRATEGIES) from mmengine.runner.checkpoint import save_checkpoint, weights_to_cpu from mmengine.utils import apply_to, digit_version, get_git_hash + from .base import BaseStrategy @@ -310,10 +311,10 @@ def __init__( self.config.setdefault('gradient_accumulation_steps', 1) self.config['steps_per_print'] = steps_per_print self._inputs_to_half = inputs_to_half - assert (exclude_frozen_parameters is None or - digit_version(deepspeed.__version__) >= digit_version('0.13.2') - ), ('DeepSpeed >= 0.13.2 is required to enable ' - 'exclude_frozen_parameters') + assert (exclude_frozen_parameters is None or digit_version( + deepspeed.__version__) >= digit_version('0.13.2')), ( + 'DeepSpeed >= 0.13.2 is required to enable ' + 'exclude_frozen_parameters') self.exclude_frozen_parameters = exclude_frozen_parameters register_deepspeed_optimizers() @@ -405,8 +406,8 @@ def _wrap_model(self, model: nn.Module) -> nn.Module: else: engine, *_ = deepspeed.initialize(model=model, config=self.config) - wrapper = MMDeepSpeedEngineWrapper( - model=engine, inputs_to_half=self._inputs_to_half) + wrapper = MMDeepSpeedEngineWrapper(model=engine, + inputs_to_half=self._inputs_to_half) return wrapper def load_checkpoint( @@ -563,12 +564,11 @@ def save_checkpoint( extra_ckpt['optim_wrapper'] = self.optim_state_dict() dirname, basename = osp.split(filename) - self.model.save_checkpoint( - dirname, - tag=basename, - client_state=extra_ckpt, - save_latest=False, - **state_dict_kwargs) + self.model.save_checkpoint(dirname, + tag=basename, + client_state=extra_ckpt, + save_latest=False, + **state_dict_kwargs) else: if self.model.zero_optimization_partition_weights(): state_dict = self.model._zero3_consolidated_16bit_state_dict( diff --git a/mmengine/_strategy/distributed.py b/mmengine/_strategy/distributed.py index dbe17d5aeb..057f8de38d 100644 --- a/mmengine/_strategy/distributed.py +++ b/mmengine/_strategy/distributed.py @@ -9,6 +9,7 @@ from mmengine.dist import init_dist, is_distributed, master_only from mmengine.model import convert_sync_batchnorm, is_model_wrapper from mmengine.registry import MODEL_WRAPPERS, STRATEGIES + from .single_device import SingleDeviceStrategy @@ -93,15 +94,14 @@ def _wrap_model(self, model: nn.Module) -> DistributedDataParallel: if self.model_wrapper is None: # set broadcast_buffers as False to keep compatibility with # OpenMMLab repos - self.model_wrapper = dict( - type='MMDistributedDataParallel', broadcast_buffers=False) - - default_args = dict( - type='MMDistributedDataParallel', - module=model, - device_ids=[int(os.environ['LOCAL_RANK'])]) - model = MODEL_WRAPPERS.build( - self.model_wrapper, 
default_args=default_args) + self.model_wrapper = dict(type='MMDistributedDataParallel', + broadcast_buffers=False) + + default_args = dict(type='MMDistributedDataParallel', + module=model, + device_ids=[int(os.environ['LOCAL_RANK'])]) + model = MODEL_WRAPPERS.build(self.model_wrapper, + default_args=default_args) return model @master_only @@ -114,9 +114,8 @@ def save_checkpoint( extra_ckpt: Optional[dict] = None, callback: Optional[Callable] = None, ) -> None: - super().save_checkpoint( - filename=filename, - save_optimizer=save_optimizer, - save_param_scheduler=save_param_scheduler, - extra_ckpt=extra_ckpt, - callback=callback) + super().save_checkpoint(filename=filename, + save_optimizer=save_optimizer, + save_param_scheduler=save_param_scheduler, + extra_ckpt=extra_ckpt, + callback=callback) diff --git a/mmengine/_strategy/fsdp.py b/mmengine/_strategy/fsdp.py index 0788fafdab..b3fe48a6c0 100644 --- a/mmengine/_strategy/fsdp.py +++ b/mmengine/_strategy/fsdp.py @@ -29,6 +29,7 @@ from mmengine.registry import (FUNCTIONS, MODEL_WRAPPERS, OPTIM_WRAPPERS, PARAM_SCHEDULERS, STRATEGIES, Registry) from mmengine.utils import get_git_hash, mkdir_or_exist + from .distributed import DDPStrategy from .utils import MetaTensorContext @@ -151,12 +152,11 @@ def _wrap_model(self, model: nn.Module) -> None: if self.model_wrapper is None: self.model_wrapper = dict(type='MMFullyShardedDataParallel') - default_args = dict( - module=model, - device_id=int(os.environ['LOCAL_RANK']), - type='MMFullyShardedDataParallel') - model = MODEL_WRAPPERS.build( - self.model_wrapper, default_args=default_args) + default_args = dict(module=model, + device_id=int(os.environ['LOCAL_RANK']), + type='MMFullyShardedDataParallel') + model = MODEL_WRAPPERS.build(self.model_wrapper, + default_args=default_args) model.set_state_dict_type(model, self.state_dict_type, self.state_dict_config, self.optim_state_dict_config) @@ -632,10 +632,10 @@ def _build_param_scheduler( 'Use the max epochs/iters of train loop as default.') param_schedulers.append( - PARAM_SCHEDULERS.build( - _scheduler, - default_args=dict( - optimizer=optim_wrapper, **default_args))) + PARAM_SCHEDULERS.build(_scheduler, + default_args=dict( + optimizer=optim_wrapper, + **default_args))) else: raise TypeError( 'scheduler should be a _ParamScheduler object or dict, ' diff --git a/mmengine/_strategy/single_device.py b/mmengine/_strategy/single_device.py index c7d8accd5a..ddcdce8966 100644 --- a/mmengine/_strategy/single_device.py +++ b/mmengine/_strategy/single_device.py @@ -10,6 +10,7 @@ from mmengine.optim import BaseOptimWrapper, _ParamScheduler from mmengine.registry import STRATEGIES from mmengine.utils import get_git_hash + from .base import BaseStrategy @@ -150,8 +151,9 @@ def load_checkpoint( callback(checkpoint) state_dict = checkpoint.pop('state_dict') - self.load_model_state_dict( - state_dict, strict=strict, revise_keys=revise_keys) + self.load_model_state_dict(state_dict, + strict=strict, + revise_keys=revise_keys) return checkpoint @@ -191,8 +193,9 @@ def resume( """ self.logger.info(f'Resume checkpoint from {filename}') - checkpoint = self.load_checkpoint( - filename, map_location=map_location, callback=callback) + checkpoint = self.load_checkpoint(filename, + map_location=map_location, + callback=callback) if resume_optimizer: self.load_optim_state_dict(checkpoint.pop('optimizer')) diff --git a/mmengine/analysis/complexity_analysis.py b/mmengine/analysis/complexity_analysis.py index 435e5fe5d3..6daeb925b6 100644 --- 
a/mmengine/analysis/complexity_analysis.py +++ b/mmengine/analysis/complexity_analysis.py @@ -342,8 +342,8 @@ def fill(lvl: int, prefix: str) -> None: rows.append(('model', format_size(count.pop('')))) fill(0, '') - table = Table( - title=f'parameter count of {model.__class__.__name__}', box=box.ASCII2) + table = Table(title=f'parameter count of {model.__class__.__name__}', + box=box.ASCII2) table.add_column('name') table.add_column('#elements or shape') diff --git a/mmengine/analysis/jit_analysis.py b/mmengine/analysis/jit_analysis.py index 17b294863a..4c4a628291 100644 --- a/mmengine/analysis/jit_analysis.py +++ b/mmengine/analysis/jit_analysis.py @@ -20,6 +20,7 @@ from torch.jit import TracerWarning, _get_trace_graph from mmengine.logging import print_log + from .jit_handles import Handle T = TypeVar('T', bound='JitModelAnalysis') @@ -628,10 +629,9 @@ def _analyze(self) -> 'Statistics': counts[name] += op_counts uncalled_mods = set(self._aliases.values()) - all_seen - stats = Statistics( - counts=counts, - unsupported_ops=unsupported_ops, - uncalled_mods=uncalled_mods) + stats = Statistics(counts=counts, + unsupported_ops=unsupported_ops, + uncalled_mods=uncalled_mods) self._stats = stats self._warn_unsupported_ops(unsupported_ops['']) self._warn_uncalled_mods(uncalled_mods) diff --git a/mmengine/analysis/print_helper.py b/mmengine/analysis/print_helper.py index 3b87d42373..9c4fadc2e5 100644 --- a/mmengine/analysis/print_helper.py +++ b/mmengine/analysis/print_helper.py @@ -13,6 +13,7 @@ from torch import nn from mmengine.utils import is_tuple_of + from .complexity_analysis import (ActivationAnalyzer, FlopAnalyzer, parameter_count) diff --git a/mmengine/config/config.py b/mmengine/config/config.py index 801243c82d..3ca4a13066 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -26,6 +26,7 @@ from mmengine.utils import (check_file_exist, digit_version, get_installed_path, import_modules_from_strings, is_installed) + from .lazy import LazyAttr, LazyObject from .utils import (ConfigParsingError, ImportTransformer, RemoveAssignFromAST, _gather_abs_import_lazyobj, _get_external_cfg_base_path, @@ -46,9 +47,10 @@ def _lazy2string(cfg_dict, dict_type=None): if isinstance(cfg_dict, dict): dict_type = dict_type or type(cfg_dict) - return dict_type( - {k: _lazy2string(v, dict_type) - for k, v in dict.items(cfg_dict)}) + return dict_type({ + k: _lazy2string(v, dict_type) + for k, v in dict.items(cfg_dict) + }) elif isinstance(cfg_dict, (tuple, list)): return type(cfg_dict)(_lazy2string(v, dict_type) for v in cfg_dict) elif isinstance(cfg_dict, (LazyAttr, LazyObject)): @@ -254,8 +256,8 @@ def _merge_a_into_b(a, b): b.clear() all_keys = list(b.keys()) + list(a.keys()) return { - key: - _merge_a_into_b(a.get(key, default), b.get(key, default)) + key: _merge_a_into_b(a.get(key, default), + b.get(key, default)) for key in all_keys if key != DELETE_KEY } else: @@ -271,13 +273,15 @@ def __reduce_ex__(self, proto): # called by CPython interpreter during pickling. 
See more details in # https://github.com/python/cpython/blob/8d61a71f9c81619e34d4a30b625922ebc83c561b/Objects/typeobject.c#L6196 # noqa: E501 if digit_version(platform.python_version()) < digit_version('3.8'): - return (self.__class__, ({k: v - for k, v in super().items()}, ), None, - None, None) + return (self.__class__, ({ + k: v + for k, v in super().items() + }, ), None, None, None) else: - return (self.__class__, ({k: v - for k, v in super().items()}, ), None, - None, None, None) + return (self.__class__, ({ + k: v + for k, v in super().items() + }, ), None, None, None, None) def __eq__(self, other): if isinstance(other, ConfigDict): @@ -338,12 +342,12 @@ def add_args(parser: ArgumentParser, elif isinstance(v, dict): add_args(parser, v, prefix + k + '.') elif isinstance(v, abc.Iterable): - parser.add_argument( - '--' + prefix + k, type=type(next(iter(v))), nargs='+') + parser.add_argument('--' + prefix + k, + type=type(next(iter(v))), + nargs='+') else: - print_log( - f'cannot parse key {prefix + k} of type {type(v)}', - logger='current') + print_log(f'cannot parse key {prefix + k} of type {type(v)}', + logger='current') return parser @@ -495,10 +499,9 @@ def fromfile(filename: Union[str, Path], # about lazy in the docstring of ConfigDict ConfigDict.lazy = False - cfg = Config( - cfg_dict, - filename=filename, - format_python_code=format_python_code) + cfg = Config(cfg_dict, + filename=filename, + format_python_code=format_python_code) object.__setattr__(cfg, '_imported_names', imported_names) return cfg @@ -529,9 +532,10 @@ def fromstring(cfg_str: str, file_format: str) -> 'Config': # As a workaround we set `delete=False` and close the temporary file # before opening again. - with tempfile.NamedTemporaryFile( - 'w', encoding='utf-8', suffix=file_format, - delete=False) as temp_file: + with tempfile.NamedTemporaryFile('w', + encoding='utf-8', + suffix=file_format, + delete=False) as temp_file: temp_file.write(cfg_str) cfg = Config.fromfile(temp_file.name) @@ -1094,19 +1098,17 @@ def _parse_lazy_import(filename: str) -> Tuple[ConfigDict, set]: # the global dict. 
After the ast transformation, most of import # syntax will be removed (except for the builtin import) and # replaced with the `LazyObject` - transform = ImportTransformer( - global_dict=global_dict, - base_dict=base_dict, - filename=filename) + transform = ImportTransformer(global_dict=global_dict, + base_dict=base_dict, + filename=filename) modified_code = transform.visit(parsed_codes) modified_code, abs_imported = _gather_abs_import_lazyobj( modified_code, filename=filename) imported_names = transform.imported_obj | abs_imported imported_names |= base_imported_names modified_code = ast.fix_missing_locations(modified_code) - exec( - compile(modified_code, filename, mode='exec'), global_dict, - global_dict) + exec(compile(modified_code, filename, mode='exec'), global_dict, + global_dict) ret: dict = {} for key, value in global_dict.items(): @@ -1138,8 +1140,8 @@ def _dict_to_config_dict_lazy(cfg: dict): cfg_dict[key] = Config._dict_to_config_dict_lazy(value) return cfg_dict if isinstance(cfg, (tuple, list)): - return type(cfg)( - Config._dict_to_config_dict_lazy(_cfg) for _cfg in cfg) + return type(cfg)(Config._dict_to_config_dict_lazy(_cfg) + for _cfg in cfg) return cfg @staticmethod @@ -1165,8 +1167,9 @@ def _dict_to_config_dict(cfg: dict, cfg = ConfigDict(cfg) dict.__setattr__(cfg, 'scope', scope) for key, value in cfg.items(): - cfg[key] = Config._dict_to_config_dict( - value, scope=scope, has_scope=has_scope) + cfg[key] = Config._dict_to_config_dict(value, + scope=scope, + has_scope=has_scope) elif isinstance(cfg, tuple): cfg = tuple( Config._dict_to_config_dict(_cfg, scope, has_scope=has_scope) @@ -1475,16 +1478,16 @@ def _format_dict(input_dict, outest_level=False): text = _format_dict(cfg_dict, outest_level=True) if self._format_python_code: # copied from setup.cfg - yapf_style = dict( - based_on_style='pep8', - blank_line_before_nested_class_or_def=True, - split_before_expression_after_opening_paren=True) + yapf_style = dict(based_on_style='pep8', + blank_line_before_nested_class_or_def=True, + split_before_expression_after_opening_paren=True) try: if digit_version(yapf.__version__) >= digit_version('0.40.2'): text, _ = FormatCode(text, style_config=yapf_style) else: - text, _ = FormatCode( - text, style_config=yapf_style, verify=True) + text, _ = FormatCode(text, + style_config=yapf_style, + verify=True) except: # noqa: E722 raise SyntaxError('Failed to format the config file, please ' f'check the syntax of: \n{text}') @@ -1622,8 +1625,9 @@ def merge_from_dict(self, cfg_dict = super().__getattribute__('_cfg_dict') super().__setattr__( '_cfg_dict', - Config._merge_a_into_b( - option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys)) + Config._merge_a_into_b(option_cfg_dict, + cfg_dict, + allow_list_keys=allow_list_keys)) @staticmethod def diff(cfg1: Union[str, 'Config'], cfg2: Union[str, 'Config']) -> str: @@ -1633,8 +1637,8 @@ def diff(cfg1: Union[str, 'Config'], cfg2: Union[str, 'Config']) -> str: if isinstance(cfg2, str): cfg2 = Config.fromfile(cfg2) - res = difflib.unified_diff( - cfg1.pretty_text.split('\n'), cfg2.pretty_text.split('\n')) + res = difflib.unified_diff(cfg1.pretty_text.split('\n'), + cfg2.pretty_text.split('\n')) # Convert into rich format for better visualization console = Console() diff --git a/mmengine/dataset/dataset_wrapper.py b/mmengine/dataset/dataset_wrapper.py index e63860bee0..8e167ba650 100644 --- a/mmengine/dataset/dataset_wrapper.py +++ b/mmengine/dataset/dataset_wrapper.py @@ -11,6 +11,7 @@ from mmengine.logging import print_log from 
mmengine.registry import DATASETS + from .base_dataset import BaseDataset, force_full_init diff --git a/mmengine/dataset/utils.py b/mmengine/dataset/utils.py index 2c9cf96497..d140cc8dc4 100644 --- a/mmengine/dataset/utils.py +++ b/mmengine/dataset/utils.py @@ -158,7 +158,8 @@ def default_collate(data_batch: Sequence) -> Any: return [default_collate(samples) for samples in transposed] elif isinstance(data_item, Mapping): return data_item_type({ - key: default_collate([d[key] for d in data_batch]) + key: + default_collate([d[key] for d in data_batch]) for key in data_item }) else: diff --git a/mmengine/dist/dist.py b/mmengine/dist/dist.py index f70cc3ef46..88e9b4559d 100644 --- a/mmengine/dist/dist.py +++ b/mmengine/dist/dist.py @@ -646,8 +646,9 @@ def _all_gather_object(object_list: List[Any], # Gather all local sizes. This is so that we can find the max size, and # index until the correct size when deserializing the tensors. group_size = get_world_size(group=group) - object_sizes_tensor = torch.zeros( - group_size, dtype=torch.long, device=current_device) + object_sizes_tensor = torch.zeros(group_size, + dtype=torch.long, + device=current_device) object_size_list = [ object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) ] @@ -656,8 +657,9 @@ def _all_gather_object(object_list: List[Any], max_object_size = int(max(object_size_list).item()) # Resize tensor to max size across all ranks. input_tensor.resize_(max_object_size) - coalesced_output_tensor = torch.empty( - max_object_size * group_size, dtype=torch.uint8, device=current_device) + coalesced_output_tensor = torch.empty(max_object_size * group_size, + dtype=torch.uint8, + device=current_device) # Output tensors are nonoverlapping views of coalesced_output_tensor output_tensors = [ coalesced_output_tensor[max_object_size * i:max_object_size * (i + 1)] @@ -800,8 +802,9 @@ def _gather_object(obj: Any, # Gather all local sizes. This is so that we can find the max size, and # index until the correct size when deserializing the tensors. group_size = get_world_size(group=group) - object_sizes_tensor = torch.zeros( - group_size, dtype=torch.long, device=current_device) + object_sizes_tensor = torch.zeros(group_size, + dtype=torch.long, + device=current_device) object_size_list = [ object_sizes_tensor[i].unsqueeze(dim=0) for i in range(group_size) ] @@ -815,10 +818,9 @@ def _gather_object(obj: Any, # Avoid populating output tensors if the result won't be gathered on this # rank. 
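# Serialized objects can differ in length across ranks, so each input
# tensor is padded to max_object_size before the collective, and every
# rank's slice is deserialized afterwards using the true byte counts
# exchanged above via object_sizes_tensor.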
if my_rank == dst: - coalesced_output_tensor = torch.empty( - max_object_size * group_size, - dtype=torch.uint8, - device=current_device) + coalesced_output_tensor = torch.empty(max_object_size * group_size, + dtype=torch.uint8, + device=current_device) # Output tensors are nonoverlapping views of coalesced_output_tensor output_tensors = [ coalesced_output_tensor[max_object_size * i:max_object_size * @@ -996,8 +998,8 @@ def collect_results_cpu(result_part: list, if rank == 0: mmengine.mkdir_or_exist('.dist_test') tmpdir = tempfile.mkdtemp(dir='.dist_test') - tmpdir = torch.tensor( - bytearray(tmpdir.encode()), dtype=torch.uint8) + tmpdir = torch.tensor(bytearray(tmpdir.encode()), + dtype=torch.uint8) dir_tensor[:len(tmpdir)] = tmpdir broadcast(dir_tensor, 0) tmpdir = dir_tensor.numpy().tobytes().decode().rstrip() diff --git a/mmengine/dist/utils.py b/mmengine/dist/utils.py index 5d32cec36b..f41b938155 100644 --- a/mmengine/dist/utils.py +++ b/mmengine/dist/utils.py @@ -105,27 +105,24 @@ def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None: if is_mlu_available(): import torch_mlu # noqa: F401 torch.mlu.set_device(local_rank) - torch_dist.init_process_group( - backend='cncl', - rank=rank, - world_size=int(os.environ['WORLD_SIZE']), - **kwargs) + torch_dist.init_process_group(backend='cncl', + rank=rank, + world_size=int(os.environ['WORLD_SIZE']), + **kwargs) elif is_npu_available(): import torch_npu # noqa: F401 torch.npu.set_device(local_rank) - torch_dist.init_process_group( - backend='hccl', - rank=rank, - world_size=int(os.environ['WORLD_SIZE']), - **kwargs) + torch_dist.init_process_group(backend='hccl', + rank=rank, + world_size=int(os.environ['WORLD_SIZE']), + **kwargs) elif is_musa_available(): import torch_musa # noqa: F401 torch.musa.set_device(rank) - torch_dist.init_process_group( - backend='mccl', - rank=rank, - world_size=int(os.environ['WORLD_SIZE']), - **kwargs) + torch_dist.init_process_group(backend='mccl', + rank=rank, + world_size=int(os.environ['WORLD_SIZE']), + **kwargs) else: torch.cuda.set_device(local_rank) diff --git a/mmengine/evaluator/evaluator.py b/mmengine/evaluator/evaluator.py index 930ce93028..065e057aa8 100644 --- a/mmengine/evaluator/evaluator.py +++ b/mmengine/evaluator/evaluator.py @@ -4,6 +4,7 @@ from mmengine.dataset import pseudo_collate from mmengine.registry import EVALUATOR, METRICS from mmengine.structures import BaseDataElement + from .metric import BaseMetric diff --git a/mmengine/evaluator/metric.py b/mmengine/evaluator/metric.py index 1292ce61ec..06396e103f 100644 --- a/mmengine/evaluator/metric.py +++ b/mmengine/evaluator/metric.py @@ -119,11 +119,10 @@ def evaluate(self, size: int) -> dict: level=logging.WARNING) if self.collect_device == 'cpu': - results = collect_results( - self.results, - size, - self.collect_device, - tmpdir=self.collect_dir) + results = collect_results(self.results, + size, + self.collect_device, + tmpdir=self.collect_dir) else: results = collect_results(self.results, size, self.collect_device) @@ -168,8 +167,8 @@ def __init__(self, out_file_path: str, collect_device: str = 'cpu', collect_dir: Optional[str] = None) -> None: - super().__init__( - collect_device=collect_device, collect_dir=collect_dir) + super().__init__(collect_device=collect_device, + collect_dir=collect_dir) if not out_file_path.endswith(('.pkl', '.pickle')): raise ValueError('The output file must be a pkl file.') self.out_file_path = out_file_path @@ -181,9 +180,8 @@ def process(self, data_batch: Any, predictions: Sequence[dict]) -> 
None: def compute_metrics(self, results: list) -> dict: """Dump the prediction results to a pickle file.""" dump(results, self.out_file_path) - print_log( - f'Results has been saved to {self.out_file_path}.', - logger='current') + print_log(f'Results has been saved to {self.out_file_path}.', + logger='current') return {} diff --git a/mmengine/fileio/backends/base.py b/mmengine/fileio/backends/base.py index 9331edf598..6759d8b2a8 100644 --- a/mmengine/fileio/backends/base.py +++ b/mmengine/fileio/backends/base.py @@ -21,10 +21,9 @@ class BaseStorageBackend(metaclass=ABCMeta): @property def allow_symlink(self): - print_log( - 'allow_symlink will be deprecated in future', - logger='current', - level=logging.WARNING) + print_log('allow_symlink will be deprecated in future', + logger='current', + level=logging.WARNING) return self._allow_symlink @property diff --git a/mmengine/fileio/backends/lmdb_backend.py b/mmengine/fileio/backends/lmdb_backend.py index eb47923e56..60cce2145a 100644 --- a/mmengine/fileio/backends/lmdb_backend.py +++ b/mmengine/fileio/backends/lmdb_backend.py @@ -70,12 +70,11 @@ def get_text(self, filepath, encoding=None): def _get_client(self): import lmdb - return lmdb.open( - self.db_path, - readonly=self.readonly, - lock=self.lock, - readahead=self.readahead, - **self.kwargs) + return lmdb.open(self.db_path, + readonly=self.readonly, + lock=self.lock, + readahead=self.readahead, + **self.kwargs) def __del__(self): if self._client is not None: diff --git a/mmengine/fileio/backends/local_backend.py b/mmengine/fileio/backends/local_backend.py index c7d5f04621..ea7bd9fdc3 100644 --- a/mmengine/fileio/backends/local_backend.py +++ b/mmengine/fileio/backends/local_backend.py @@ -7,6 +7,7 @@ from typing import Generator, Iterator, Optional, Tuple, Union import mmengine + from .base import BaseStorageBackend @@ -156,8 +157,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool: """ return osp.isfile(filepath) - def join_path(self, filepath: Union[str, Path], - *filepaths: Union[str, Path]) -> str: + def join_path(self, filepath: Union[str, Path], *filepaths: + Union[str, Path]) -> str: r"""Concatenate all file paths. Join one or more filepath components intelligently. 
The return value diff --git a/mmengine/fileio/backends/petrel_backend.py b/mmengine/fileio/backends/petrel_backend.py index 3994372f66..21deaf3839 100644 --- a/mmengine/fileio/backends/petrel_backend.py +++ b/mmengine/fileio/backends/petrel_backend.py @@ -10,6 +10,7 @@ import mmengine from mmengine.utils import has_method + from .base import BaseStorageBackend @@ -605,8 +606,9 @@ def rmtree(self, dir_path: Union[str, Path]) -> None: >>> dir_path = 'petrel://path/of/dir' >>> backend.rmtree(dir_path) """ - for path in self.list_dir_or_file( - dir_path, list_dir=False, recursive=True): + for path in self.list_dir_or_file(dir_path, + list_dir=False, + recursive=True): filepath = self.join_path(dir_path, path) self.remove(filepath) diff --git a/mmengine/fileio/file_client.py b/mmengine/fileio/file_client.py index 61551d3d1d..6393f56163 100644 --- a/mmengine/fileio/file_client.py +++ b/mmengine/fileio/file_client.py @@ -7,6 +7,7 @@ from mmengine.logging import print_log from mmengine.utils import is_filepath + from .backends import (BaseStorageBackend, HTTPBackend, LmdbBackend, LocalBackend, MemcachedBackend, PetrelBackend) @@ -271,13 +272,17 @@ def get_text(self, filepath): `New in version 1.3.15.` """ if backend is not None: - cls._register_backend( - name, backend, force=force, prefixes=prefixes) + cls._register_backend(name, + backend, + force=force, + prefixes=prefixes) return def _register(backend_cls): - cls._register_backend( - name, backend_cls, force=force, prefixes=prefixes) + cls._register_backend(name, + backend_cls, + force=force, + prefixes=prefixes) return backend_cls return _register @@ -385,8 +390,8 @@ def isfile(self, filepath: Union[str, Path]) -> bool: """ return self.client.isfile(filepath) - def join_path(self, filepath: Union[str, Path], - *filepaths: Union[str, Path]) -> str: + def join_path(self, filepath: Union[str, Path], *filepaths: + Union[str, Path]) -> str: r"""Concatenate all file paths. Join one or more filepath components intelligently. The return value diff --git a/mmengine/fileio/handlers/registry_utils.py b/mmengine/fileio/handlers/registry_utils.py index 106fc881f2..49a50d35fc 100644 --- a/mmengine/fileio/handlers/registry_utils.py +++ b/mmengine/fileio/handlers/registry_utils.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
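# Registers the built-in file handlers (e.g. JsonHandler, PickleHandler)
# into the `file_handlers` mapping that mmengine.fileio's load/dump
# dispatch on; io.py notes below that these names were moved here.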
from mmengine.utils import is_list_of + from .base import BaseFileHandler from .json_handler import JsonHandler from .pickle_handler import PickleHandler diff --git a/mmengine/fileio/io.py b/mmengine/fileio/io.py index fdeb4dc6df..00b7b52f6d 100644 --- a/mmengine/fileio/io.py +++ b/mmengine/fileio/io.py @@ -38,6 +38,7 @@ from typing import Generator, Iterator, Optional, Tuple, Union from mmengine.utils import is_filepath, is_str + from .backends import backends, prefix_to_backends from .file_client import FileClient # file_handlers and register_handler had been moved to @@ -176,8 +177,9 @@ def get( >>> get(filepath) b'hello world' """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) return backend.get(filepath) @@ -203,8 +205,9 @@ def get_text( >>> get_text(filepath) 'hello world' """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) return backend.get_text(filepath, encoding) @@ -229,8 +232,9 @@ def put( >>> filepath = '/path/of/file' >>> put(b'hello world', filepath) """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) backend.put(obj, filepath) @@ -257,8 +261,9 @@ def put_text( >>> filepath = '/path/of/file' >>> put_text('hello world', filepath) """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) backend.put_text(obj, filepath) @@ -281,8 +286,9 @@ def exists( >>> exists(filepath) True """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) return backend.exists(filepath) @@ -307,8 +313,9 @@ def isdir( >>> isdir(filepath) True """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) return backend.isdir(filepath) @@ -332,8 +339,9 @@ def isfile( >>> isfile(filepath) True """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) return backend.isfile(filepath) @@ -363,8 +371,9 @@ def join_path( >>> join_path(filepath1, filepath2, filepath3) '/path/of/dir/dir2/path/of/file' """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) return backend.join_path(filepath, *filepaths) @@ -395,8 +404,9 @@ def get_local_path( >>> with get_local_path('s3://bucket/abc.jpg') as path: ... 
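Every helper in `io.py` above funnels through `get_file_backend(..., enable_singleton=True)`, so repeated calls with the same path prefix and `backend_args` reuse one backend instance. A small sketch of the resulting top-level API; the paths are hypothetical:

```python
from mmengine import fileio

fileio.put_text('hello world', '/tmp/demo.txt')
assert fileio.exists('/tmp/demo.txt')
print(fileio.get_text('/tmp/demo.txt'))
# Other storage is selected via backend_args, e.g.
# fileio.get_text('petrel://bucket/demo.txt', backend_args={'backend': 'petrel'})
```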
# do something here """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) with backend.get_local_path(str(filepath)) as local_path: yield local_path @@ -439,8 +449,9 @@ def copyfile( >>> copyfile(src, dst) '/path1/of/dir/file' """ - backend = get_file_backend( - src, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(src, + backend_args=backend_args, + enable_singleton=True) return backend.copyfile(src, dst) @@ -473,8 +484,9 @@ def copytree( >>> copytree(src, dst) '/path/of/dir2' """ - backend = get_file_backend( - src, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(src, + backend_args=backend_args, + enable_singleton=True) return backend.copytree(src, dst) @@ -513,8 +525,9 @@ def copyfile_from_local( >>> copyfile_from_local(src, dst) 's3://openmmlab/mmengine/file' """ - backend = get_file_backend( - dst, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(dst, + backend_args=backend_args, + enable_singleton=True) return backend.copyfile_from_local(src, dst) @@ -545,8 +558,9 @@ def copytree_from_local( >>> copyfile_from_local(src, dst) 's3://openmmlab/mmengine/dir' """ - backend = get_file_backend( - dst, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(dst, + backend_args=backend_args, + enable_singleton=True) return backend.copytree_from_local(src, dst) @@ -589,8 +603,9 @@ def copyfile_to_local( >>> copyfile_to_local(src, dst) '/path/of/dir/file' """ - backend = get_file_backend( - dst, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(dst, + backend_args=backend_args, + enable_singleton=True) return backend.copyfile_to_local(src, dst) @@ -621,8 +636,9 @@ def copytree_to_local( >>> copytree_to_local(src, dst) '/path/of/dir' """ - backend = get_file_backend( - dst, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(dst, + backend_args=backend_args, + enable_singleton=True) return backend.copytree_to_local(src, dst) @@ -647,8 +663,9 @@ def remove( >>> filepath = '/path/of/file' >>> remove(filepath) """ - backend = get_file_backend( - filepath, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(filepath, + backend_args=backend_args, + enable_singleton=True) backend.remove(filepath) @@ -667,8 +684,9 @@ def rmtree( >>> dir_path = '/path/of/dir' >>> rmtree(dir_path) """ - backend = get_file_backend( - dir_path, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(dir_path, + backend_args=backend_args, + enable_singleton=True) backend.rmtree(dir_path) @@ -702,8 +720,9 @@ def copy_if_symlink_fails( >>> copy_if_symlink_fails(src, dst) True """ - backend = get_file_backend( - src, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(src, + backend_args=backend_args, + enable_singleton=True) return backend.copy_if_symlink_fails(src, dst) @@ -755,8 +774,9 @@ def list_dir_or_file( >>> for file_path in list_dir_or_file(dir_path, recursive=True): ... 
print(file_path) """ - backend = get_file_backend( - dir_path, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(dir_path, + backend_args=backend_args, + enable_singleton=True) yield from backend.list_dir_or_file(dir_path, list_dir, list_file, suffix, recursive) @@ -784,8 +804,9 @@ def generate_presigned_url( Returns: str: Generated presigned url. """ - backend = get_file_backend( - url, backend_args=backend_args, enable_singleton=True) + backend = get_file_backend(url, + backend_args=backend_args, + enable_singleton=True) return backend.generate_presigned_url(url, client_method, expires_in) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 92a4867bb9..c3e62914b3 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -13,6 +13,7 @@ from mmengine.logging import print_log from mmengine.registry import HOOKS from mmengine.utils import is_list_of, is_seq_of + from .hook import Hook DATA_BATCH = Optional[Union[dict, tuple, list]] @@ -196,10 +197,10 @@ def __init__(self, self.save_best = save_best # rule logic - assert (isinstance(rule, str) or is_list_of(rule, str) - or (rule is None)), ( - '"rule" should be a str or list of str or None, ' - f'but got {type(rule)}') + assert (isinstance(rule, str) or is_list_of(rule, str) or + (rule + is None)), ('"rule" should be a str or list of str or None, ' + f'but got {type(rule)}') if isinstance(rule, list): # check the length of rule list assert len(rule) in [ @@ -440,16 +441,15 @@ def _save_checkpoint_with_step(self, runner, step, meta): ckpt_filename) runner.message_hub.update_info('last_ckpt', self.last_ckpt) - runner.save_checkpoint( - self.out_dir, - ckpt_filename, - self.file_client_args, - save_optimizer=self.save_optimizer, - save_param_scheduler=self.save_param_scheduler, - meta=meta, - by_epoch=self.by_epoch, - backend_args=self.backend_args, - **self.args) + runner.save_checkpoint(self.out_dir, + ckpt_filename, + self.file_client_args, + save_optimizer=self.save_optimizer, + save_param_scheduler=self.save_param_scheduler, + meta=meta, + by_epoch=self.by_epoch, + backend_args=self.backend_args, + **self.args) # Model parallel-like training should involve pulling sharded states # from all ranks, but skip the following procedure. 
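The `rule` assertion above accepts a single string or a list aligned with `save_best`. A configuration sketch that exercises the list form; the metric names are hypothetical:

```python
default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        interval=1,
        max_keep_ckpts=3,
        save_best=['coco/bbox_mAP', 'loss'],  # hypothetical metric keys
        rule=['greater', 'less']))            # one rule per metric
```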
@@ -557,15 +557,14 @@ def _save_best_checkpoint(self, runner, metrics) -> None: runner.message_hub.update_info( runtime_best_ckpt_key, self.best_ckpt_path_dict[key_indicator]) - runner.save_checkpoint( - self.out_dir, - filename=best_ckpt_name, - file_client_args=self.file_client_args, - save_optimizer=False, - save_param_scheduler=False, - meta=meta, - by_epoch=False, - backend_args=self.backend_args) + runner.save_checkpoint(self.out_dir, + filename=best_ckpt_name, + file_client_args=self.file_client_args, + save_optimizer=False, + save_param_scheduler=False, + meta=meta, + by_epoch=False, + backend_args=self.backend_args) runner.logger.info( f'The best checkpoint with {best_score:0.4f} {key_indicator} ' f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.') diff --git a/mmengine/hooks/early_stopping_hook.py b/mmengine/hooks/early_stopping_hook.py index 5533ebc84c..517265f93e 100644 --- a/mmengine/hooks/early_stopping_hook.py +++ b/mmengine/hooks/early_stopping_hook.py @@ -4,6 +4,7 @@ from typing import Optional, Tuple, Union from mmengine.registry import HOOKS + from .hook import Hook DATA_BATCH = Optional[Union[dict, tuple, list]] diff --git a/mmengine/hooks/ema_hook.py b/mmengine/hooks/ema_hook.py index 5bc1051d0b..504fcd30ab 100644 --- a/mmengine/hooks/ema_hook.py +++ b/mmengine/hooks/ema_hook.py @@ -7,6 +7,7 @@ from mmengine.logging import print_log from mmengine.model import is_model_wrapper from mmengine.registry import HOOKS, MODELS + from .hook import DATA_BATCH, Hook @@ -71,8 +72,8 @@ def before_run(self, runner) -> None: if is_model_wrapper(model): model = model.module self.src_model = model - self.ema_model = MODELS.build( - self.ema_cfg, default_args=dict(model=self.src_model)) + self.ema_model = MODELS.build(self.ema_cfg, + default_args=dict(model=self.src_model)) def before_train(self, runner) -> None: """Check the begin_epoch/iter is smaller than max_epochs/iters. @@ -181,8 +182,8 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None: # The original model parameters are actually saved in ema # field swap the weights back to resume ema state. self._swap_ema_state_dict(checkpoint) - self.ema_model.load_state_dict( - checkpoint['ema_state_dict'], strict=self.strict_load) + self.ema_model.load_state_dict(checkpoint['ema_state_dict'], + strict=self.strict_load) # Support load checkpoint without ema state dict. else: @@ -191,22 +192,20 @@ def after_load_checkpoint(self, runner, checkpoint: dict) -> None: 'There is no `ema_state_dict` in checkpoint. 
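In `EMAHook.before_run` above, extra keyword arguments from the hook config are folded into `ema_cfg` and forwarded to `MODELS.build`, so `momentum` below reaches the `ExponentialMovingAverage` constructor. A minimal config sketch:

```python
custom_hooks = [
    dict(type='EMAHook',
         ema_type='ExponentialMovingAverage',
         momentum=0.0002,
         update_buffers=True)
]
```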
' '`EMAHook` will make a copy of `state_dict` as the ' 'initial `ema_state_dict`', 'current', logging.WARNING) - load_state_dict( - self.ema_model.module, - copy.deepcopy(checkpoint['state_dict']), - strict=self.strict_load) + load_state_dict(self.ema_model.module, + copy.deepcopy(checkpoint['state_dict']), + strict=self.strict_load) def _swap_ema_parameters(self) -> None: """Swap the parameter of model with ema_model.""" - avg_param = ( - itertools.chain(self.ema_model.module.parameters(), - self.ema_model.module.buffers()) - if self.ema_model.update_buffers else - self.ema_model.module.parameters()) - src_param = ( - itertools.chain(self.src_model.parameters(), - self.src_model.buffers()) - if self.ema_model.update_buffers else self.src_model.parameters()) + avg_param = (itertools.chain(self.ema_model.module.parameters(), + self.ema_model.module.buffers()) + if self.ema_model.update_buffers else + self.ema_model.module.parameters()) + src_param = (itertools.chain(self.src_model.parameters(), + self.src_model.buffers()) + if self.ema_model.update_buffers else + self.src_model.parameters()) for p_avg, p_src in zip(avg_param, src_param): tmp = p_avg.data.clone() p_avg.data.copy_(p_src.data) diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py index 9a92cdebfe..7b691d107d 100644 --- a/mmengine/hooks/empty_cache_hook.py +++ b/mmengine/hooks/empty_cache_hook.py @@ -4,6 +4,7 @@ import torch from mmengine.registry import HOOKS + from ..device import is_cuda_available, is_musa_available from .hook import Hook diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index 4e1c4ce8bd..27230c6fad 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -183,8 +183,10 @@ def before_train_iter(self, batch_idx (int): The index of the current batch in the train loop. data_batch (dict or tuple or list, optional): Data from dataloader. """ - self._before_iter( - runner, batch_idx=batch_idx, data_batch=data_batch, mode='train') + self._before_iter(runner, + batch_idx=batch_idx, + data_batch=data_batch, + mode='train') def before_val_iter(self, runner, @@ -199,8 +201,10 @@ def before_val_iter(self, data_batch (dict, optional): Data from dataloader. Defaults to None. """ - self._before_iter( - runner, batch_idx=batch_idx, data_batch=data_batch, mode='val') + self._before_iter(runner, + batch_idx=batch_idx, + data_batch=data_batch, + mode='val') def before_test_iter(self, runner, @@ -215,8 +219,10 @@ def before_test_iter(self, data_batch (dict or tuple or list, optional): Data from dataloader. Defaults to None. """ - self._before_iter( - runner, batch_idx=batch_idx, data_batch=data_batch, mode='test') + self._before_iter(runner, + batch_idx=batch_idx, + data_batch=data_batch, + mode='test') def after_train_iter(self, runner, @@ -232,12 +238,11 @@ def after_train_iter(self, data_batch (dict tuple or list, optional): Data from dataloader. outputs (dict, optional): Outputs from model. """ - self._after_iter( - runner, - batch_idx=batch_idx, - data_batch=data_batch, - outputs=outputs, - mode='train') + self._after_iter(runner, + batch_idx=batch_idx, + data_batch=data_batch, + outputs=outputs, + mode='train') def after_val_iter(self, runner, @@ -253,12 +258,11 @@ def after_val_iter(self, data_batch (dict or tuple or list, optional): Data from dataloader. outputs (Sequence, optional): Outputs from model. 
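The `_before_iter`/`_after_iter` dispatch above means a custom hook only overrides the exact extension point it needs. A minimal sketch; the hook name and logging interval are made up:

```python
from mmengine.hooks import Hook
from mmengine.registry import HOOKS


@HOOKS.register_module()
class PrintLossHook(Hook):
    """Hypothetical hook that logs training outputs every 50 iters."""

    def after_train_iter(self, runner, batch_idx, data_batch=None,
                         outputs=None):
        if outputs is not None and batch_idx % 50 == 0:
            runner.logger.info(f'iter {batch_idx}: {outputs}')
```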
""" - self._after_iter( - runner, - batch_idx=batch_idx, - data_batch=data_batch, - outputs=outputs, - mode='val') + self._after_iter(runner, + batch_idx=batch_idx, + data_batch=data_batch, + outputs=outputs, + mode='val') def after_test_iter(self, runner, @@ -274,12 +278,11 @@ def after_test_iter(self, data_batch (dict or tuple or list, optional): Data from dataloader. outputs (Sequence, optional): Outputs from model. """ - self._after_iter( - runner, - batch_idx=batch_idx, - data_batch=data_batch, - outputs=outputs, - mode='test') + self._after_iter(runner, + batch_idx=batch_idx, + data_batch=data_batch, + outputs=outputs, + mode='test') def _before_epoch(self, runner, mode: str = 'train') -> None: """All subclasses should override this method, if they need any diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py index 5632c2b25e..edbf209d3a 100644 --- a/mmengine/hooks/iter_timer_hook.py +++ b/mmengine/hooks/iter_timer_hook.py @@ -3,6 +3,7 @@ from typing import Optional, Sequence, Union from mmengine.registry import HOOKS + from .hook import Hook DATA_BATCH = Optional[Union[dict, tuple, list]] @@ -90,8 +91,8 @@ def _after_iter(self, if mode == 'train': self.time_sec_tot += iter_time.current() # Calculate average iterative time. - time_sec_avg = self.time_sec_tot / ( - runner.iter - self.start_iter + 1) + time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + + 1) # Calculate eta. eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1) runner.message_hub.update_info('eta', eta_sec) diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index fa0b79dcf9..cfd0a2dd36 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -137,8 +137,8 @@ def __init__(self, self.file_client = FileClient.infer_client(file_client_args, self.out_dir) if file_client_args is None: - self.file_backend = get_file_backend( - self.out_dir, backend_args=backend_args) + self.file_backend = get_file_backend(self.out_dir, + backend_args=backend_args) else: self.file_backend = self.file_client @@ -196,8 +196,9 @@ def after_train_iter(self, else: return runner.logger.info(log_str) - runner.visualizer.add_scalars( - tag, step=runner.iter + 1, file_path=self.json_log_path) + runner.visualizer.add_scalars(tag, + step=runner.iter + 1, + file_path=self.json_log_path) def after_val_iter(self, runner, @@ -262,16 +263,18 @@ def after_val_epoch(self, epoch = 0 else: epoch = runner.epoch - runner.visualizer.add_scalars( - tag, step=epoch, file_path=self.json_log_path) + runner.visualizer.add_scalars(tag, + step=epoch, + file_path=self.json_log_path) else: if (isinstance(runner._train_loop, dict) or runner._train_loop is None): iter = 0 else: iter = runner.iter - runner.visualizer.add_scalars( - tag, step=iter, file_path=self.json_log_path) + runner.visualizer.add_scalars(tag, + step=iter, + file_path=self.json_log_path) def after_test_epoch(self, runner, @@ -288,9 +291,8 @@ def after_test_epoch(self, tag, log_str = runner.log_processor.get_log_after_epoch( runner, len(runner.test_dataloader), 'test', with_non_scalar=True) runner.logger.info(log_str) - dump( - self._process_tags(tag), - osp.join(runner.log_dir, self.json_log_path)) # type: ignore + dump(self._process_tags(tag), + osp.join(runner.log_dir, self.json_log_path)) # type: ignore @staticmethod def _process_tags(tags: dict): diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py index 3b2f1e610a..60cb0270fd 100644 --- 
a/mmengine/hooks/param_scheduler_hook.py +++ b/mmengine/hooks/param_scheduler_hook.py @@ -4,6 +4,7 @@ from mmengine.optim import _ParamScheduler from mmengine.registry import HOOKS from mmengine.utils import is_list_of + from .hook import Hook DATA_BATCH = Optional[Union[dict, tuple, list]] diff --git a/mmengine/hooks/runtime_info_hook.py b/mmengine/hooks/runtime_info_hook.py index 49407e4563..34487caa78 100644 --- a/mmengine/hooks/runtime_info_hook.py +++ b/mmengine/hooks/runtime_info_hook.py @@ -7,6 +7,7 @@ from mmengine.registry import HOOKS from mmengine.utils import get_git_hash from mmengine.version import __version__ + from .hook import Hook DATA_BATCH = Optional[Union[dict, tuple, list]] @@ -47,11 +48,10 @@ def before_run(self, runner) -> None: Args: runner (Runner): The runner of the training process. """ - metainfo = dict( - cfg=runner.cfg.pretty_text, - seed=runner.seed, - experiment_name=runner.experiment_name, - mmengine_version=__version__ + get_git_hash()) + metainfo = dict(cfg=runner.cfg.pretty_text, + seed=runner.seed, + experiment_name=runner.experiment_name, + mmengine_version=__version__ + get_git_hash()) runner.message_hub.update_info_dict(metainfo) self.last_loop_stage = None diff --git a/mmengine/hooks/sampler_seed_hook.py b/mmengine/hooks/sampler_seed_hook.py index 9aed9dbcf5..6317fb5a3a 100644 --- a/mmengine/hooks/sampler_seed_hook.py +++ b/mmengine/hooks/sampler_seed_hook.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.registry import HOOKS + from .hook import Hook diff --git a/mmengine/hooks/sync_buffer_hook.py b/mmengine/hooks/sync_buffer_hook.py index 7cc75757fe..5e85bc24bd 100644 --- a/mmengine/hooks/sync_buffer_hook.py +++ b/mmengine/hooks/sync_buffer_hook.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.dist import all_reduce_params, is_distributed from mmengine.registry import HOOKS + from .hook import Hook diff --git a/mmengine/infer/infer.py b/mmengine/infer/infer.py index 322d885224..c46b7b5a7b 100644 --- a/mmengine/infer/infer.py +++ b/mmengine/infer/infer.py @@ -51,9 +51,8 @@ def __init__(self, *args, **kwargs): assert isinstance(self.visualize_kwargs, set) assert isinstance(self.postprocess_kwargs, set) - all_kwargs = ( - self.preprocess_kwargs | self.forward_kwargs - | self.visualize_kwargs | self.postprocess_kwargs) + all_kwargs = (self.preprocess_kwargs | self.forward_kwargs + | self.visualize_kwargs | self.postprocess_kwargs) assert len(all_kwargs) == ( len(self.preprocess_kwargs) + len(self.forward_kwargs) + @@ -215,8 +214,9 @@ def __call__( ) = self._dispatch_kwargs(**kwargs) ori_inputs = self._inputs_to_list(inputs) - inputs = self.preprocess( - ori_inputs, batch_size=batch_size, **preprocess_kwargs) + inputs = self.preprocess(ori_inputs, + batch_size=batch_size, + **preprocess_kwargs) preds = [] for data in (track(inputs, description='Inference') if self.show_progress else inputs): @@ -286,8 +286,8 @@ def __call__(self, inputs, batch_size=1, **kwargs): Yields: Any: Data processed by the ``pipeline`` and ``collate_fn``. 
""" - chunked_data = self._get_chunk_data( - map(self.pipeline, inputs), batch_size) + chunked_data = self._get_chunk_data(map(self.pipeline, inputs), + batch_size) yield from map(self.collate_fn, chunked_data) @torch.no_grad() diff --git a/mmengine/logging/logger.py b/mmengine/logging/logger.py index e6cf9fe37d..b35024f5e5 100644 --- a/mmengine/logging/logger.py +++ b/mmengine/logging/logger.py @@ -57,8 +57,10 @@ class MMFormatter(logging.Formatter): **kwargs: Keyword arguments passed to :meth:`logging.Formatter.__init__`. """ - _color_mapping: dict = dict( - ERROR='red', WARNING='yellow', INFO='white', DEBUG='green') + _color_mapping: dict = dict(ERROR='red', + WARNING='yellow', + INFO='white', + DEBUG='green') def __init__(self, color: bool = True, blink: bool = False, **kwargs): super().__init__(**kwargs) diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py index 6e4faaee6e..f056b5eabb 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -377,10 +377,9 @@ def state_dict(self) -> dict: logger='current', level=logging.WARNING) saved_info[key] = value - return dict( - log_scalars=saved_scalars, - runtime_info=saved_info, - resumed_keys=self._resumed_keys) + return dict(log_scalars=saved_scalars, + runtime_info=saved_info, + resumed_keys=self._resumed_keys) def load_state_dict(self, state_dict: Union['MessageHub', dict]) -> None: """Loads log scalars, runtime information and resumed keys from diff --git a/mmengine/model/__init__.py b/mmengine/model/__init__.py index 033512a985..41f65f41fd 100644 --- a/mmengine/model/__init__.py +++ b/mmengine/model/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.utils.dl_utils import TORCH_VERSION from mmengine.utils.version_utils import digit_version + from .averaged_model import (BaseAveragedModel, ExponentialMovingAverage, MomentumAnnealingEMA, StochasticWeightAverage) from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py index 58457c2a6e..eb14294cf4 100644 --- a/mmengine/model/averaged_model.py +++ b/mmengine/model/averaged_model.py @@ -96,9 +96,8 @@ def update_parameters(self, model: nn.Module) -> None: Args: model (nn.Module): The model whose parameters will be averaged. """ - src_parameters = ( - model.state_dict() - if self.update_buffers else dict(model.named_parameters())) + src_parameters = (model.state_dict() if self.update_buffers else dict( + model.named_parameters())) if self.steps == 0: for k, p_avg in self.avg_parameters.items(): p_avg.data.copy_(src_parameters[k].data) @@ -138,9 +137,8 @@ def avg_func(self, averaged_param: Tensor, source_param: Tensor, steps (int): The number of times the parameters have been updated. 
""" - averaged_param.add_( - source_param - averaged_param, - alpha=1 / float(steps // self.interval + 1)) + averaged_param.add_(source_param - averaged_param, + alpha=1 / float(steps // self.interval + 1)) @MODELS.register_module() @@ -238,12 +236,11 @@ def __init__(self, interval: int = 1, device: Optional[torch.device] = None, update_buffers: bool = False) -> None: - super().__init__( - model=model, - momentum=momentum, - interval=interval, - device=device, - update_buffers=update_buffers) + super().__init__(model=model, + momentum=momentum, + interval=interval, + device=device, + update_buffers=update_buffers) assert gamma > 0, f'gamma must be greater than 0, but got {gamma}' self.gamma = gamma diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py index 299cd67557..660054dc6c 100644 --- a/mmengine/model/base_model/base_model.py +++ b/mmengine/model/base_model/base_model.py @@ -9,6 +9,7 @@ from mmengine.optim import OptimWrapper from mmengine.registry import MODELS from mmengine.utils import is_list_of + from ..base_module import BaseModule from .data_preprocessor import BaseDataPreprocessor diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index 4d621851b0..3cd38b4286 100644 --- a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -9,6 +9,7 @@ from mmengine.registry import MODELS from mmengine.structures import BaseDataElement from mmengine.utils import is_seq_of + from ..utils import stack_batch CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str, diff --git a/mmengine/model/base_module.py b/mmengine/model/base_module.py index 3cfe0b14a8..6eee81b4c0 100644 --- a/mmengine/model/base_module.py +++ b/mmengine/model/base_module.py @@ -10,6 +10,7 @@ from mmengine.dist import master_only from mmengine.logging import MMLogger, print_log + from .weight_init import PretrainedInit, initialize, update_init_info from .wrappers.utils import is_model_wrapper @@ -135,11 +136,10 @@ def init_weights(self): m, 'is_init', False): m.init_weights() # users may overload the `init_weights` - update_init_info( - m, - init_info=f'Initialized by ' - f'user-defined `init_weights`' - f' in {m.__class__.__name__} ') + update_init_info(m, + init_info=f'Initialized by ' + f'user-defined `init_weights`' + f' in {m.__class__.__name__} ') if self.init_cfg and pretrained_cfg: initialize(self, pretrained_cfg) self._is_init = True diff --git a/mmengine/model/efficient_conv_bn_eval.py b/mmengine/model/efficient_conv_bn_eval.py index 9cb2ad6199..ef12ffa818 100644 --- a/mmengine/model/efficient_conv_bn_eval.py +++ b/mmengine/model/efficient_conv_bn_eval.py @@ -111,10 +111,12 @@ def efficient_conv_bn_eval_graph_transform(fx_model): # note that we directly call `create_node` to fill the `name` # argument. `fx_model.graph.get_attr` and # `fx_model.graph.call_function` does not allow the `name` argument. 
- conv_get_node = fx_model.graph.create_node( - op='get_attr', target=conv_node.target, name='get_conv') - bn_get_node = fx_model.graph.create_node( - op='get_attr', target=bn_node.target, name='get_bn') + conv_get_node = fx_model.graph.create_node(op='get_attr', + target=conv_node.target, + name='get_conv') + bn_get_node = fx_model.graph.create_node(op='get_attr', + target=bn_node.target, + name='get_bn') # prepare args for the fused function args = (bn_get_node, conv_get_node, conv_node.args[0]) # create a new node diff --git a/mmengine/model/test_time_aug.py b/mmengine/model/test_time_aug.py index c623eec8bc..65fcab5405 100644 --- a/mmengine/model/test_time_aug.py +++ b/mmengine/model/test_time_aug.py @@ -7,6 +7,7 @@ from mmengine.registry import MODELS from mmengine.structures import BaseDataElement + from .base_model import BaseModel # multi-batch inputs processed by different augmentations from the same batch. @@ -124,9 +125,10 @@ def test_step(self, data): data_list: Union[List[dict], List[list]] if isinstance(data, dict): num_augs = len(data[next(iter(data))]) - data_list = [{key: value[idx] - for key, value in data.items()} - for idx in range(num_augs)] + data_list = [{ + key: value[idx] + for key, value in data.items() + } for idx in range(num_augs)] elif isinstance(data, (tuple, list)): num_augs = len(data[0]) data_list = [[_data[idx] for _data in data] diff --git a/mmengine/model/utils.py b/mmengine/model/utils.py index c78ea3134d..0d30aa44ca 100644 --- a/mmengine/model/utils.py +++ b/mmengine/model/utils.py @@ -199,10 +199,9 @@ def revert_sync_batchnorm(module: nn.Module) -> nn.Module: try: module_output.add_module(name, revert_sync_batchnorm(child)) except Exception: - print_log( - F'Failed to convert {child} from SyncBN to BN!', - logger='current', - level=logging.WARNING) + print_log(F'Failed to convert {child} from SyncBN to BN!', + logger='current', + level=logging.WARNING) del module return module_output diff --git a/mmengine/model/weight_init.py b/mmengine/model/weight_init.py index b6d0186ed7..c1d5a07d08 100644 --- a/mmengine/model/weight_init.py +++ b/mmengine/model/weight_init.py @@ -97,11 +97,15 @@ def kaiming_init(module, assert distribution in ['uniform', 'normal'] if hasattr(module, 'weight') and module.weight is not None: if distribution == 'uniform': - nn.init.kaiming_uniform_( - module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + nn.init.kaiming_uniform_(module.weight, + a=a, + mode=mode, + nonlinearity=nonlinearity) else: - nn.init.kaiming_normal_( - module.weight, a=a, mode=mode, nonlinearity=nonlinearity) + nn.init.kaiming_normal_(module.weight, + a=a, + mode=mode, + nonlinearity=nonlinearity) if hasattr(module, 'bias') and module.bias is not None: nn.init.constant_(module.bias, bias) @@ -109,13 +113,12 @@ def kaiming_init(module, def caffe2_xavier_init(module, bias=0): # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch # Acknowledgment to FAIR's internal code - kaiming_init( - module, - a=1, - mode='fan_in', - nonlinearity='leaky_relu', - bias=bias, - distribution='uniform') + kaiming_init(module, + a=1, + mode='fan_in', + nonlinearity='leaky_relu', + bias=bias, + distribution='uniform') def bias_init_with_prob(prior_prob): @@ -450,12 +453,11 @@ class Caffe2XavierInit(KaimingInit): # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch # Acknowledgment to FAIR's internal code def __init__(self, **kwargs): - super().__init__( - a=1, - mode='fan_in', - nonlinearity='leaky_relu', - distribution='uniform', - **kwargs) 
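The `kaiming_init` hunk above shows that `caffe2_xavier_init` is simply `kaiming_init` pinned to `a=1`, `fan_in`, `leaky_relu`, and a uniform distribution. A usage sketch, assuming both helpers are re-exported from `mmengine.model`:

```python
import torch.nn as nn
from mmengine.model import caffe2_xavier_init, kaiming_init

conv = nn.Conv2d(3, 16, kernel_size=3)
kaiming_init(conv, mode='fan_out', nonlinearity='relu')
caffe2_xavier_init(conv, bias=0)
```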
+ super().__init__(a=1, + mode='fan_in', + nonlinearity='leaky_relu', + distribution='uniform', + **kwargs) def __call__(self, module): super().__call__(module) @@ -487,16 +489,14 @@ def __call__(self, module): load_state_dict) if self.prefix is None: print_log(f'load model from: {self.checkpoint}', logger='current') - load_checkpoint( - module, - self.checkpoint, - map_location=self.map_location, - strict=False, - logger='current') + load_checkpoint(module, + self.checkpoint, + map_location=self.map_location, + strict=False, + logger='current') else: - print_log( - f'load {self.prefix} in model from: {self.checkpoint}', - logger='current') + print_log(f'load {self.prefix} in model from: {self.checkpoint}', + logger='current') state_dict = _load_checkpoint_with_prefix( self.prefix, self.checkpoint, map_location=self.map_location) load_state_dict(module, state_dict, strict=False, logger='current') diff --git a/mmengine/model/wrappers/__init__.py b/mmengine/model/wrappers/__init__.py index 90eddabbe1..35480c8df3 100644 --- a/mmengine/model/wrappers/__init__.py +++ b/mmengine/model/wrappers/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.utils.dl_utils import TORCH_VERSION from mmengine.utils.version_utils import digit_version + from .distributed import MMDistributedDataParallel from .seperate_distributed import MMSeparateDistributedDataParallel from .utils import is_model_wrapper diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py index 4113aebf9e..dda05ff685 100644 --- a/mmengine/model/wrappers/distributed.py +++ b/mmengine/model/wrappers/distributed.py @@ -6,6 +6,7 @@ from mmengine.optim import OptimWrapper from mmengine.registry import MODEL_WRAPPERS + from ..utils import detect_anomalous_params MODEL_WRAPPERS.register_module(module=DistributedDataParallel) diff --git a/mmengine/model/wrappers/fully_sharded_distributed.py b/mmengine/model/wrappers/fully_sharded_distributed.py index df128597b1..d991b7d703 100644 --- a/mmengine/model/wrappers/fully_sharded_distributed.py +++ b/mmengine/model/wrappers/fully_sharded_distributed.py @@ -233,17 +233,16 @@ def parse_dtype(dtype): kwargs['ignored_modules'] = self._get_ignored_modules( module, kwargs['ignored_modules']) - super().__init__( - module=module, - process_group=process_group, - sharding_strategy=sharding_strategy, - auto_wrap_policy=auto_wrap_policy, - cpu_offload=cpu_offload, - backward_prefetch=backward_prefetch, - mixed_precision=mixed_precision, - param_init_fn=param_init_fn, - use_orig_params=use_orig_params, - **kwargs) + super().__init__(module=module, + process_group=process_group, + sharding_strategy=sharding_strategy, + auto_wrap_policy=auto_wrap_policy, + cpu_offload=cpu_offload, + backward_prefetch=backward_prefetch, + mixed_precision=mixed_precision, + param_init_fn=param_init_fn, + use_orig_params=use_orig_params, + **kwargs) def train_step(self, data: dict, optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: diff --git a/mmengine/model/wrappers/seperate_distributed.py b/mmengine/model/wrappers/seperate_distributed.py index ac9c2383c3..43e860c124 100644 --- a/mmengine/model/wrappers/seperate_distributed.py +++ b/mmengine/model/wrappers/seperate_distributed.py @@ -9,6 +9,7 @@ from mmengine.device import get_device from mmengine.optim import OptimWrapperDict from mmengine.registry import MODEL_WRAPPERS + from .distributed import MMDistributedDataParallel diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py 
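`PretrainedInit.__call__` above branches on `prefix`: without it the whole checkpoint is loaded, with it only the matching sub-dict is. A sketch of the corresponding `init_cfg`; the class, URL, and prefix are hypothetical:

```python
from mmengine.model import BaseModule


class Backbone(BaseModule):
    """Hypothetical module initialized from a prefixed checkpoint."""

    def __init__(self):
        super().__init__(init_cfg=dict(
            type='Pretrained',
            checkpoint='https://example.com/ckpt.pth',
            prefix='backbone.'))

# Backbone().init_weights() would then load only the 'backbone.*' keys.
```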
b/mmengine/optim/optimizer/amp_optimizer_wrapper.py index 60200924b5..b3beb4ef2e 100644 --- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py @@ -11,6 +11,7 @@ from mmengine.registry import OPTIM_WRAPPERS from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION + from .optimizer_wrapper import OptimWrapper if is_npu_available(): diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py index a2e6190460..ad38dad21a 100644 --- a/mmengine/optim/optimizer/apex_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py @@ -9,6 +9,7 @@ # from mmengine.model.wrappers import is_model_wrapper import mmengine from mmengine.registry import OPTIM_WRAPPERS + from .optimizer_wrapper import OptimWrapper try: diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py index a76fd9730c..65ac3f378d 100644 --- a/mmengine/optim/optimizer/builder.py +++ b/mmengine/optim/optimizer/builder.py @@ -29,8 +29,8 @@ def register_torch_optimizers() -> List[str]: if inspect.isclass(_optim) and issubclass(_optim, torch.optim.Optimizer): if module_name == 'Adafactor': - OPTIMIZERS.register_module( - name='TorchAdafactor', module=_optim) + OPTIMIZERS.register_module(name='TorchAdafactor', + module=_optim) else: OPTIMIZERS.register_module(module=_optim) torch_optimizers.append(module_name) @@ -221,8 +221,7 @@ def build_optim_wrapper(model: nn.Module, optim_wrapper_cfg['type'] = 'AmpOptimWrapper' constructor_cfg.update( - dict( - optim_wrapper_cfg=optim_wrapper_cfg, paramwise_cfg=paramwise_cfg)) + dict(optim_wrapper_cfg=optim_wrapper_cfg, paramwise_cfg=paramwise_cfg)) optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build( constructor_cfg) diff --git a/mmengine/optim/optimizer/default_constructor.py b/mmengine/optim/optimizer/default_constructor.py index b623a3e70e..344a57d0cd 100644 --- a/mmengine/optim/optimizer/default_constructor.py +++ b/mmengine/optim/optimizer/default_constructor.py @@ -13,6 +13,7 @@ from mmengine.utils import is_list_of from mmengine.utils.dl_utils import mmcv_full_available from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm + from .optimizer_wrapper import OptimWrapper @@ -199,9 +200,8 @@ def add_params(self, # special rules for norm layers and depth-wise conv layers is_norm = isinstance(module, (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm)) - is_dwconv = ( - isinstance(module, torch.nn.Conv2d) - and module.in_channels == module.groups) + is_dwconv = (isinstance(module, torch.nn.Conv2d) + and module.in_channels == module.groups) for name, param in module.named_parameters(recurse=False): param_group = {'params': [param]} @@ -272,9 +272,8 @@ def add_params(self, if key == 'params': continue full_name = f'{prefix}.{name}' if prefix else name - print_log( - f'paramwise_options -- {full_name}:{key}={value}', - logger='current') + print_log(f'paramwise_options -- {full_name}:{key}={value}', + logger='current') if mmcv_full_available(): from mmcv.ops import DeformConv2d, ModulatedDeformConv2d @@ -284,11 +283,10 @@ def add_params(self, is_dcn_module = False for child_name, child_mod in module.named_children(): child_prefix = f'{prefix}.{child_name}' if prefix else child_name - self.add_params( - params, - child_mod, - prefix=child_prefix, - is_dcn_module=is_dcn_module) + self.add_params(params, + child_mod, + prefix=child_prefix, + is_dcn_module=is_dcn_module) def 
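`build_optim_wrapper` above hands `paramwise_cfg` to the constructor, whose `add_params` applies the norm-layer and depth-wise-conv special cases shown in the hunk. A typical config sketch:

```python
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=1e-4),
    # skip weight decay on norm layers and depth-wise convolutions
    paramwise_cfg=dict(norm_decay_mult=0.0, dwconv_decay_mult=0.0))
```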
__call__(self, model: nn.Module) -> OptimWrapper: if hasattr(model, 'module'): @@ -304,8 +302,8 @@ def __call__(self, model: nn.Module) -> OptimWrapper: if isinstance(optimizer_cls, str): with OPTIMIZERS.switch_scope_and_registry(None) as registry: optimizer_cls = registry.get(self.optimizer_cfg['type']) - fisrt_arg_name = next( - iter(inspect.signature(optimizer_cls).parameters)) + fisrt_arg_name = next(iter( + inspect.signature(optimizer_cls).parameters)) # if no paramwise option is specified, just use the global setting if not self.paramwise_cfg: optimizer_cfg[fisrt_arg_name] = model.parameters() diff --git a/mmengine/optim/optimizer/optimizer_wrapper.py b/mmengine/optim/optimizer/optimizer_wrapper.py index 41218ef768..75aa4d08b9 100644 --- a/mmengine/optim/optimizer/optimizer_wrapper.py +++ b/mmengine/optim/optimizer/optimizer_wrapper.py @@ -10,6 +10,7 @@ from mmengine.logging import MessageHub, print_log from mmengine.registry import OPTIM_WRAPPERS from mmengine.utils.dl_utils import has_batch_norm + from .base import BaseOptimWrapper diff --git a/mmengine/optim/scheduler/lr_scheduler.py b/mmengine/optim/scheduler/lr_scheduler.py index 13bc61d542..48405f2770 100644 --- a/mmengine/optim/scheduler/lr_scheduler.py +++ b/mmengine/optim/scheduler/lr_scheduler.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmengine.registry import PARAM_SCHEDULERS + # yapf: disable from .param_scheduler import (ConstantParamScheduler, CosineAnnealingParamScheduler, diff --git a/mmengine/optim/scheduler/momentum_scheduler.py b/mmengine/optim/scheduler/momentum_scheduler.py index e356e70f7b..50df22347a 100644 --- a/mmengine/optim/scheduler/momentum_scheduler.py +++ b/mmengine/optim/scheduler/momentum_scheduler.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from mmengine.registry import PARAM_SCHEDULERS + # yapf: disable from .param_scheduler import (ConstantParamScheduler, CosineAnnealingParamScheduler, diff --git a/mmengine/optim/scheduler/param_scheduler.py b/mmengine/optim/scheduler/param_scheduler.py index 2dcb1af072..9f46034ac7 100644 --- a/mmengine/optim/scheduler/param_scheduler.py +++ b/mmengine/optim/scheduler/param_scheduler.py @@ -258,14 +258,13 @@ def __init__(self, verbose: bool = False): self.step_size = step_size self.gamma = gamma - super().__init__( - optimizer=optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer=optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) @classmethod def build_iter_from_epoch(cls, @@ -288,13 +287,12 @@ def build_iter_from_epoch(cls, begin = int(begin * epoch_length) if end != INF: end = int(end * epoch_length) - return cls( - *args, - step_size=step_size, - begin=begin, - end=end, - by_epoch=by_epoch, - **kwargs) + return cls(*args, + step_size=step_size, + begin=begin, + end=end, + by_epoch=by_epoch, + **kwargs) def _get_value(self): """Compute value using chainable form of the scheduler.""" @@ -346,14 +344,13 @@ def __init__(self, verbose: bool = False): self.milestones = Counter(milestones) self.gamma = gamma - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) @classmethod def build_iter_from_epoch(cls, @@ -376,13 +373,12 @@ def build_iter_from_epoch(cls, begin = int(begin * epoch_length) if end != INF: end = int(end * epoch_length) - return cls( - *args, - milestones=milestones, - begin=begin, - end=end, - by_epoch=by_epoch, - **kwargs) + return cls(*args, + milestones=milestones, + begin=begin, + end=end, + by_epoch=by_epoch, + **kwargs) def _get_value(self): """Compute value using chainable form of the scheduler.""" @@ -438,14 +434,13 @@ def __init__(self, self.factor = factor self.total_iters = end - begin - 1 - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) @classmethod def build_iter_from_epoch(cls, @@ -521,14 +516,13 @@ def __init__(self, by_epoch: bool = True, verbose: bool = False): self.gamma = gamma - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) @classmethod def build_iter_from_epoch(cls, @@ -638,14 +632,13 @@ def __init__(self, self.T_max = T_max or (end - begin) self.eta_min = eta_min self.eta_min_ratio = eta_min_ratio - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) @classmethod def build_iter_from_epoch(cls, @@ -669,13 +662,12 @@ def build_iter_from_epoch(cls, 
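Each `build_iter_from_epoch` classmethod above rescales epoch-based boundaries by `epoch_length`; in a runner config this conversion is requested with `convert_to_iter_based=True`. A sketch:

```python
param_scheduler = [
    dict(type='LinearLR', start_factor=0.001, by_epoch=False,
         begin=0, end=500),
    dict(type='MultiStepLR', by_epoch=True, milestones=[8, 11],
         gamma=0.1, convert_to_iter_based=True)
]
```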
begin = int(begin * epoch_length) if end != INF: end = int(end * epoch_length) - return cls( - *args, - T_max=T_max, - begin=begin, - end=end, - by_epoch=by_epoch, - **kwargs) + return cls(*args, + T_max=T_max, + begin=begin, + end=end, + by_epoch=by_epoch, + **kwargs) def _get_value(self) -> list: """Compute value using chainable form of the scheduler.""" @@ -756,14 +748,13 @@ def __init__(self, self.start_factor = start_factor self.end_factor = end_factor self.total_iters = end - begin - 1 - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) @classmethod def build_iter_from_epoch(cls, @@ -846,14 +837,13 @@ def __init__(self, self.power = power self.total_iters = end - begin - 1 - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) @classmethod def build_iter_from_epoch(cls, @@ -1043,14 +1033,13 @@ def __init__(self, group[f'min_{param_name}'] = \ group[f'initial_{param_name}'] / final_div_factor - super().__init__( - optimizer=optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer=optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) def _format_param(self, name, optimizer, param): """Return correctly formatted lr/momentum for each param group.""" @@ -1098,13 +1087,12 @@ def build_iter_from_epoch(cls, end = int(end * epoch_length) if total_steps is not None: total_steps = total_steps * epoch_length - return cls( - *args, - begin=begin, - end=end, - total_steps=total_steps, - by_epoch=by_epoch, - **kwargs) + return cls(*args, + begin=begin, + end=end, + total_steps=total_steps, + by_epoch=by_epoch, + **kwargs) def _get_value(self): """Compute value using chainable form of the scheduler.""" @@ -1190,14 +1178,13 @@ def __init__(self, sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) ] - super().__init__( - optimizer, - param_name=param_name, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) + super().__init__(optimizer, + param_name=param_name, + begin=begin, + end=end, + last_step=last_step, + by_epoch=by_epoch, + verbose=verbose) @classmethod def build_iter_from_epoch(cls, @@ -1220,13 +1207,12 @@ def build_iter_from_epoch(cls, begin = int(begin * epoch_length) if end != INF: end = int(end * epoch_length) - return cls( - *args, - periods=periods, - begin=begin, - end=end, - by_epoch=by_epoch, - **kwargs) + return cls(*args, + periods=periods, + begin=begin, + end=end, + by_epoch=by_epoch, + **kwargs) def _get_value(self): """Compute value using chainable form of the scheduler.""" @@ -1444,8 +1430,9 @@ def __init__(self, self.eps = eps self.monitor = monitor - self._init_is_better( - rule=rule, threshold=threshold, threshold_rule=threshold_rule) + self._init_is_better(rule=rule, + threshold=threshold, + threshold_rule=threshold_rule) self._reset() # remove call self.step() and init self._global_step = 0 diff --git a/mmengine/registry/build_functions.py b/mmengine/registry/build_functions.py index 
3de6798514..2856bacdf6 100644 --- a/mmengine/registry/build_functions.py +++ b/mmengine/registry/build_functions.py @@ -5,6 +5,7 @@ from mmengine.config import Config, ConfigDict from mmengine.utils import ManagerMixin, digit_version + from .registry import Registry if TYPE_CHECKING: diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index e7d8962be4..387b3e3d43 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -12,6 +12,7 @@ from mmengine.config.utils import MODULE2PACKAGE from mmengine.utils import get_object_from_string, is_seq_of + from .default_scope import DefaultScope diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py index eb9a225a91..06a4817ea0 100644 --- a/mmengine/registry/root.py +++ b/mmengine/registry/root.py @@ -41,8 +41,8 @@ # manage constructors that customize the optimization hyperparameters. OPTIM_WRAPPER_CONSTRUCTORS = Registry('optimizer wrapper constructor') # mangage all kinds of parameter schedulers like `MultiStepLR` -PARAM_SCHEDULERS = Registry( - 'parameter scheduler', build_func=build_scheduler_from_cfg) +PARAM_SCHEDULERS = Registry('parameter scheduler', + build_func=build_scheduler_from_cfg) # manage all kinds of metrics METRICS = Registry('metric') diff --git a/mmengine/registry/utils.py b/mmengine/registry/utils.py index 2737e879a7..66b1ac4cfa 100644 --- a/mmengine/registry/utils.py +++ b/mmengine/registry/utils.py @@ -6,6 +6,7 @@ from mmengine.fileio import dump from mmengine.logging import print_log + from . import root from .default_scope import DefaultScope from .registry import Registry @@ -85,8 +86,8 @@ def count_registered_modules(save_path: Optional[str] = None, scan_date=datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), registries=registries_info) if verbose: - print_log( - f'Finish registry analysis, got: {scan_data}', logger='current') + print_log(f'Finish registry analysis, got: {scan_data}', + logger='current') if save_path is not None: json_path = osp.join(save_path, 'modules_statistic_results.json') dump(scan_data, json_path, indent=2) diff --git a/mmengine/runner/_flexible_runner.py b/mmengine/runner/_flexible_runner.py index 5160a5cfb0..3ad936c3be 100644 --- a/mmengine/runner/_flexible_runner.py +++ b/mmengine/runner/_flexible_runner.py @@ -26,6 +26,7 @@ from mmengine.utils import digit_version from mmengine.utils.dl_utils import TORCH_VERSION from mmengine.visualization import Visualizer + from .base_loop import BaseLoop from .checkpoint import find_latest_checkpoint from .log_processor import LogProcessor @@ -708,10 +709,9 @@ def build_visualizer( Visualizer: A Visualizer object build from ``visualizer``. 
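`PARAM_SCHEDULERS` above swaps in a scheduler-aware `build_func`, and any registry can do the same. A sketch, assuming `build_scheduler_from_cfg` is importable from `mmengine.registry` as the `root.py` hunk suggests; the registry name is hypothetical:

```python
from mmengine.registry import Registry, build_scheduler_from_cfg

# hypothetical registry reusing the scheduler-aware build function
MY_SCHEDULERS = Registry('my scheduler', build_func=build_scheduler_from_cfg)
```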
""" if visualizer is None: - visualizer = dict( - name=self.experiment_name, - vis_backends=[dict(type='LocalVisBackend')], - save_dir=self.log_dir) + visualizer = dict(name=self.experiment_name, + vis_backends=[dict(type='LocalVisBackend')], + save_dir=self.log_dir) return Visualizer.get_instance(**visualizer) if isinstance(visualizer, Visualizer): @@ -833,9 +833,9 @@ def build_dataloader( sampler_cfg = dataloader_cfg.pop('sampler') if isinstance(sampler_cfg, dict): sampler_seed = None if diff_rank_seed else seed - sampler = DATA_SAMPLERS.build( - sampler_cfg, - default_args=dict(dataset=dataset, seed=sampler_seed)) + sampler = DATA_SAMPLERS.build(sampler_cfg, + default_args=dict(dataset=dataset, + seed=sampler_seed)) else: # fallback to raise error in dataloader # if `sampler_cfg` is not a valid type @@ -848,9 +848,8 @@ def build_dataloader( elif isinstance(batch_sampler_cfg, dict): batch_sampler = DATA_SAMPLERS.build( batch_sampler_cfg, - default_args=dict( - sampler=sampler, - batch_size=dataloader_cfg.pop('batch_size'))) + default_args=dict(sampler=sampler, + batch_size=dataloader_cfg.pop('batch_size'))) else: # fallback to raise error in dataloader # if `batch_sampler_cfg` is not a valid type @@ -955,18 +954,20 @@ def build_train_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: 'Only one of `type` or `by_epoch` can exist in `loop_cfg`.') if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, dataloader=self._train_dataloader)) + loop = LOOPS.build(loop_cfg, + default_args=dict( + runner=self, + dataloader=self._train_dataloader)) else: by_epoch = loop_cfg.pop('by_epoch') if by_epoch: - loop = EpochBasedTrainLoop( - **loop_cfg, runner=self, dataloader=self._train_dataloader) + loop = EpochBasedTrainLoop(**loop_cfg, + runner=self, + dataloader=self._train_dataloader) else: - loop = IterBasedTrainLoop( - **loop_cfg, runner=self, dataloader=self._train_dataloader) + loop = IterBasedTrainLoop(**loop_cfg, + runner=self, + dataloader=self._train_dataloader) return loop # type: ignore def build_val_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: @@ -997,18 +998,16 @@ def build_val_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: loop_cfg = copy.deepcopy(loop) if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, - dataloader=self._val_dataloader, - evaluator=self._val_evaluator)) + loop = LOOPS.build(loop_cfg, + default_args=dict( + runner=self, + dataloader=self._val_dataloader, + evaluator=self._val_evaluator)) else: - loop = ValLoop( - **loop_cfg, - runner=self, - dataloader=self._val_dataloader, - evaluator=self._val_evaluator) # type: ignore + loop = ValLoop(**loop_cfg, + runner=self, + dataloader=self._val_dataloader, + evaluator=self._val_evaluator) # type: ignore return loop # type: ignore @@ -1039,18 +1038,16 @@ def build_test_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: loop_cfg = copy.deepcopy(loop) # type: ignore if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, - dataloader=self._test_dataloader, - evaluator=self._test_evaluator)) + loop = LOOPS.build(loop_cfg, + default_args=dict( + runner=self, + dataloader=self._test_dataloader, + evaluator=self._test_evaluator)) else: - loop = TestLoop( - **loop_cfg, - runner=self, - dataloader=self._test_dataloader, - evaluator=self._test_evaluator) # type: ignore + loop = TestLoop(**loop_cfg, + runner=self, + dataloader=self._test_dataloader, + evaluator=self._test_evaluator) # type: ignore return 
loop # type: ignore @@ -1172,12 +1169,11 @@ def train(self) -> nn.Module: compile = copy.copy(self._compile) compile.setdefault('target', 'train_step') - dispatch_kwargs = dict( - epoch_length=len(self.train_dataloader), - max_epochs=self.max_epochs, - max_iters=self.max_iters, - train_micro_batch_size_per_gpu=_get_batch_size( - self.train_dataloader)) # type: ignore + dispatch_kwargs = dict(epoch_length=len(self.train_dataloader), + max_epochs=self.max_epochs, + max_iters=self.max_iters, + train_micro_batch_size_per_gpu=_get_batch_size( + self.train_dataloader)) # type: ignore self.strategy.prepare( self.model, @@ -1215,9 +1211,8 @@ def val(self) -> dict: self._val_loop = self.build_val_loop(self._val_loop) # type: ignore - dispatch_kwargs = dict( - init_weights_for_test_or_val=self.cfg.get( - 'init_weights_for_test_or_val', True)) + dispatch_kwargs = dict(init_weights_for_test_or_val=self.cfg.get( + 'init_weights_for_test_or_val', True)) self.strategy.prepare(self.model, dispatch_kwargs=dispatch_kwargs) self.model = self.strategy.model @@ -1242,9 +1237,8 @@ def test(self) -> dict: '`test_evaluator` arguments when initializing runner.') self._test_loop = self.build_test_loop(self._test_loop) # type: ignore - dispatch_kwargs = dict( - init_weights_for_test_or_val=self.cfg.get( - 'init_weights_for_test_or_val', True)) + dispatch_kwargs = dict(init_weights_for_test_or_val=self.cfg.get( + 'init_weights_for_test_or_val', True)) self.strategy.prepare(self.model, dispatch_kwargs=dispatch_kwargs) self.model = self.strategy.model @@ -1467,8 +1461,8 @@ def callback(checkpoint): # check whether the number of GPU used for current experiment # is consistent with resuming from checkpoint if 'config' in checkpoint['meta']: - config = mmengine.Config.fromstring( - checkpoint['meta']['config'], file_format='.py') + config = mmengine.Config.fromstring(checkpoint['meta']['config'], + file_format='.py') previous_gpu_ids = config.get('gpu_ids', None) if (previous_gpu_ids is not None and len(previous_gpu_ids) > 0 and len(previous_gpu_ids) != self.world_size): @@ -1525,12 +1519,11 @@ def load_checkpoint(self, def callback(checkpoint): self.call_hook('after_load_checkpoint', checkpoint=checkpoint) - self.strategy.load_checkpoint( - filename, - map_location=map_location, - strict=strict, - revise_keys=revise_keys, - callback=callback) + self.strategy.load_checkpoint(filename, + map_location=map_location, + strict=strict, + revise_keys=revise_keys, + callback=callback) def save_checkpoint( self, @@ -1596,8 +1589,8 @@ def save_checkpoint( filepath = join_path( # type: ignore out_dir, filename, backend_args=backend_args) - meta.update( - cfg=self.cfg.pretty_text, experiment_name=self.experiment_name) + meta.update(cfg=self.cfg.pretty_text, + experiment_name=self.experiment_name) if hasattr(self.train_dataloader.dataset, 'metainfo'): meta.update(dataset_meta=self.train_dataloader.dataset.metainfo) diff --git a/mmengine/runner/amp.py b/mmengine/runner/amp.py index 198babc582..0bd23a1f84 100644 --- a/mmengine/runner/amp.py +++ b/mmengine/runner/amp.py @@ -138,8 +138,9 @@ def autocast(device_type: Optional[str] = None, elif device_type == 'musa': if dtype is None: dtype = torch.get_autocast_gpu_dtype() - with torch.musa.amp.autocast( - enabled=enabled, dtype=dtype, cache_enabled=cache_enabled): + with torch.musa.amp.autocast(enabled=enabled, + dtype=dtype, + cache_enabled=cache_enabled): yield return else: @@ -153,9 +154,8 @@ def autocast(device_type: Optional[str] = None, raise ValueError('User specified autocast 
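The `autocast` wrapper above dispatches a single context manager across cuda/npu/mlu/musa/cpu devices. A minimal usage sketch:

```python
import torch
from mmengine.runner import autocast

with autocast(device_type='cuda', dtype=torch.float16):
    pass  # a mixed-precision forward pass would go here
```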
device_type must be ' f'cuda or cpu, but got {device_type}') - with torch.autocast( - device_type=device_type, - enabled=enabled, - dtype=dtype, - cache_enabled=cache_enabled): + with torch.autocast(device_type=device_type, + enabled=enabled, + dtype=dtype, + cache_enabled=cache_enabled): yield diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index fa0a1eb520..7cd323092f 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -48,8 +48,8 @@ def _get_mmengine_home(): mmengine_home = os.path.expanduser( os.getenv( ENV_MMENGINE_HOME, - os.path.join( - os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmengine'))) + os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), + 'mmengine'))) mkdir_or_exist(mmengine_home) return mmengine_home @@ -344,7 +344,9 @@ def load_from_local(filename, map_location): filename = osp.expanduser(filename) if not osp.isfile(filename): raise FileNotFoundError(f'{filename} can not be found.') - checkpoint = torch.load(filename, map_location=map_location, weights_only=False) + checkpoint = torch.load(filename, + map_location=map_location, + weights_only=False) return checkpoint @@ -368,19 +370,17 @@ def load_from_http(filename, """ rank, world_size = get_dist_info() if rank == 0: - checkpoint = load_url( - filename, - model_dir=model_dir, - map_location=map_location, - progress=progress) + checkpoint = load_url(filename, + model_dir=model_dir, + map_location=map_location, + progress=progress) if world_size > 1: torch.distributed.barrier() if rank > 0: - checkpoint = load_url( - filename, - model_dir=model_dir, - map_location=map_location, - progress=progress) + checkpoint = load_url(filename, + model_dir=model_dir, + map_location=map_location, + progress=progress) return checkpoint @@ -432,8 +432,8 @@ def load_from_ceph(filename, map_location=None, backend='petrel'): Returns: dict or OrderedDict: The loaded checkpoint. 
""" - file_backend = get_file_backend( - filename, backend_args={'backend': backend}) + file_backend = get_file_backend(filename, + backend_args={'backend': backend}) with io.BytesIO(file_backend.get(filename)) as buffer: checkpoint = torch.load(buffer, map_location=map_location) return checkpoint @@ -522,8 +522,8 @@ def load_from_mmcls(filename, map_location=None): model_urls = get_mmcls_models() model_name = filename[8:] - checkpoint = load_from_http( - model_urls[model_name], map_location=map_location) + checkpoint = load_from_http(model_urls[model_name], + map_location=map_location) checkpoint = _process_mmcls_checkpoint(checkpoint) return checkpoint @@ -597,9 +597,10 @@ def _load_checkpoint_to_model(model, # strip prefix of state_dict metadata = getattr(state_dict, '_metadata', OrderedDict()) for p, r in revise_keys: - state_dict = OrderedDict( - {re.sub(p, r, k): v - for k, v in state_dict.items()}) + state_dict = OrderedDict({ + re.sub(p, r, k): v + for k, v in state_dict.items() + }) # Keep metadata in state_dict state_dict._metadata = metadata @@ -720,8 +721,10 @@ def get_state_dict(module, destination=None, prefix='', keep_vars=False): module._save_to_state_dict(destination, prefix, keep_vars) for name, child in module._modules.items(): if child is not None: - get_state_dict( - child, destination, prefix + name + '.', keep_vars=keep_vars) + get_state_dict(child, + destination, + prefix + name + '.', + keep_vars=keep_vars) for hook in module._state_dict_hooks.values(): hook_result = hook(module, destination, prefix, local_metadata) if hook_result is not None: @@ -783,8 +786,8 @@ def save_checkpoint(checkpoint, else: file_client = FileClient.infer_client(file_client_args, filename) if file_client_args is None: - file_backend = get_file_backend( - filename, backend_args=backend_args) + file_backend = get_file_backend(filename, + backend_args=backend_args) else: file_backend = file_client diff --git a/mmengine/runner/log_processor.py b/mmengine/runner/log_processor.py index 98183ae317..404000f510 100644 --- a/mmengine/runner/log_processor.py +++ b/mmengine/runner/log_processor.py @@ -301,10 +301,9 @@ def get_log_after_epoch(self, dict(data_src='time', window_size='epoch', method_name='mean')) if 'data_time' not in custom_keys: custom_cfg_copy.append( - dict( - data_src='data_time', - window_size='epoch', - method_name='mean')) + dict(data_src='data_time', + window_size='epoch', + method_name='mean')) parsed_cfg = self._parse_windows_size(runner, batch_idx, custom_cfg_copy) # tag is used to write log information to different backends. diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 4dc0d04ef5..b7ae43b7b5 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -7,12 +7,13 @@ import torch from torch.utils.data import DataLoader +from mmengine.dataset.sampler import InfiniteSampler from mmengine.evaluator import Evaluator from mmengine.logging import HistoryBuffer, print_log from mmengine.registry import LOOPS from mmengine.structures import BaseDataElement from mmengine.utils import is_list_of -from mmengine.dataset.sampler import InfiniteSampler + from .amp import autocast from .base_loop import BaseLoop from .utils import calc_dynamic_intervals @@ -124,19 +125,19 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None: Args: data_batch (Sequence[dict]): Batch of data from dataloader. 
""" - self.runner.call_hook( - 'before_train_iter', batch_idx=idx, data_batch=data_batch) + self.runner.call_hook('before_train_iter', + batch_idx=idx, + data_batch=data_batch) # Enable gradient accumulation mode and avoid unnecessary gradient # synchronization during gradient accumulation process. # outputs should be a dict of loss. outputs = self.runner.model.train_step( data_batch, optim_wrapper=self.runner.optim_wrapper) - self.runner.call_hook( - 'after_train_iter', - batch_idx=idx, - data_batch=data_batch, - outputs=outputs) + self.runner.call_hook('after_train_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=outputs) self._iter += 1 def _decide_current_val_interval(self) -> None: @@ -275,14 +276,14 @@ def run(self) -> None: # In iteration-based training loop, we treat the whole training process # as a big epoch and execute the corresponding hook. self.runner.call_hook('before_train_epoch') - if self._iter > 0 and not isinstance(self.dataloader.sampler, InfiniteSampler): + if self._iter > 0 and not isinstance(self.dataloader.sampler, + InfiniteSampler): print_log( f'Advance dataloader {self._iter} steps to skip data ' 'that has already been trained', logger='current', level=logging.WARNING) for _ in range(self._iter): - break # NOTE MGAM: override all preprocessing steps during resume. next(self.dataloader_iterator) while self._iter < self._max_iters and not self.stop_training: self.runner.model.train() @@ -307,19 +308,19 @@ def run_iter(self, data_batch: Sequence[dict]) -> None: Args: data_batch (Sequence[dict]): Batch of data from dataloader. """ - self.runner.call_hook( - 'before_train_iter', batch_idx=self._iter, data_batch=data_batch) + self.runner.call_hook('before_train_iter', + batch_idx=self._iter, + data_batch=data_batch) # Enable gradient accumulation mode and avoid unnecessary gradient # synchronization during gradient accumulation process. # outputs should be a dict of loss. outputs = self.runner.model.train_step( data_batch, optim_wrapper=self.runner.optim_wrapper) - self.runner.call_hook( - 'after_train_iter', - batch_idx=self._iter, - data_batch=data_batch, - outputs=outputs) + self.runner.call_hook('after_train_iter', + batch_idx=self._iter, + data_batch=data_batch, + outputs=outputs) self._iter += 1 def _decide_current_val_interval(self) -> None: @@ -399,8 +400,9 @@ def run_iter(self, idx, data_batch: Sequence[dict]): data_batch (Sequence[dict]): Batch of data from dataloader. """ - self.runner.call_hook( - 'before_val_iter', batch_idx=idx, data_batch=data_batch) + self.runner.call_hook('before_val_iter', + batch_idx=idx, + data_batch=data_batch) # outputs should be sequence of BaseDataElement with autocast(enabled=self.fp16): outputs = self.runner.model.val_step(data_batch) @@ -408,11 +410,10 @@ def run_iter(self, idx, data_batch: Sequence[dict]): outputs, self.val_loss = _update_losses(outputs, self.val_loss) self.evaluator.process(data_samples=outputs, data_batch=data_batch) - self.runner.call_hook( - 'after_val_iter', - batch_idx=idx, - data_batch=data_batch, - outputs=outputs) + self.runner.call_hook('after_val_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=outputs) @LOOPS.register_module() @@ -482,8 +483,9 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None: Args: data_batch (Sequence[dict]): Batch of data from dataloader. 
""" - self.runner.call_hook( - 'before_test_iter', batch_idx=idx, data_batch=data_batch) + self.runner.call_hook('before_test_iter', + batch_idx=idx, + data_batch=data_batch) # predictions should be sequence of BaseDataElement with autocast(enabled=self.fp16): outputs = self.runner.model.test_step(data_batch) @@ -491,11 +493,10 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None: outputs, self.test_loss = _update_losses(outputs, self.test_loss) self.evaluator.process(data_samples=outputs, data_batch=data_batch) - self.runner.call_hook( - 'after_test_iter', - batch_idx=idx, - data_batch=data_batch, - outputs=outputs) + self.runner.call_hook('after_test_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=outputs) def _parse_losses(losses: Dict[str, HistoryBuffer], diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index c0c48f2947..764d6e7d46 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -42,6 +42,7 @@ from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env, set_multi_processing) from mmengine.visualization import Visualizer + from .activation_checkpointing import turn_on_activation_checkpointing from .base_loop import BaseLoop from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model, @@ -429,8 +430,8 @@ def __init__( model.setdefault('data_preprocessor', data_preprocessor) self.model = self.build_model(model) # wrap model - self.model = self.wrap_model( - self.cfg.get('model_wrapper_cfg'), self.model) + self.model = self.wrap_model(self.cfg.get('model_wrapper_cfg'), + self.model) # get model name from the model class if hasattr(self.model, 'module'): @@ -714,10 +715,9 @@ def set_randomness(self, more details. """ self._deterministic = deterministic - self._seed = set_random_seed( - seed=seed, - deterministic=deterministic, - diff_rank_seed=diff_rank_seed) + self._seed = set_random_seed(seed=seed, + deterministic=deterministic, + diff_rank_seed=diff_rank_seed) def build_logger(self, log_level: Union[int, str] = 'INFO', @@ -788,10 +788,9 @@ def build_visualizer( Visualizer: A Visualizer object build from ``visualizer``. """ if visualizer is None: - visualizer = dict( - name=self._experiment_name, - vis_backends=[dict(type='LocalVisBackend')], - save_dir=self._log_dir) + visualizer = dict(name=self._experiment_name, + vis_backends=[dict(type='LocalVisBackend')], + save_dir=self._log_dir) return Visualizer.get_instance(**visualizer) if isinstance(visualizer, Visualizer): @@ -903,27 +902,28 @@ def wrap_model( find_unused_parameters=find_unused_parameters) else: model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel') - + model_wrapper_type = model_wrapper_cfg.get('type') if isinstance(model_wrapper_type, str): - model_wrapper_type = MODEL_WRAPPERS.get(model_wrapper_type) # type: ignore + model_wrapper_type = MODEL_WRAPPERS.get( + model_wrapper_type) # type: ignore elif inspect.isclass(model_wrapper_type): pass else: raise KeyError( - f'{model_wrapper_type} is not in the ' - 'registry. Please check whether the value of ' - f'`{model_wrapper_type}` is correct or it was registered ' - 'as expected. More details can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module' # noqa: E501 - ) + f'{model_wrapper_type} is not in the ' + 'registry. Please check whether the value of ' + f'`{model_wrapper_type}` is correct or it was registered ' + 'as expected. 
More details can be found at https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#import-the-custom-module' # noqa: E501 + ) default_args: dict = dict() if issubclass( model_wrapper_type, # type: ignore DistributedDataParallel): default_args['device_ids'] = [int(os.environ['LOCAL_RANK'])] default_args['module'] = model - model = MODEL_WRAPPERS.build( - model_wrapper_cfg, default_args=default_args) + model = MODEL_WRAPPERS.build(model_wrapper_cfg, + default_args=default_args) return model def _init_model_weights(self) -> None: @@ -1188,11 +1188,11 @@ def _build_param_scheduler( 'Use the max epochs/iters of train loop as default.') param_schedulers.append( - PARAM_SCHEDULERS.build( - _scheduler, - default_args=dict( - optimizer=optim_wrapper, - epoch_length=len(self.train_dataloader)))) + PARAM_SCHEDULERS.build(_scheduler, + default_args=dict( + optimizer=optim_wrapper, + epoch_length=len( + self.train_dataloader)))) else: raise TypeError( 'scheduler should be a _ParamScheduler object or dict, ' @@ -1390,18 +1390,17 @@ def build_dataloader(dataloader: Union[DataLoader, Dict], num_batch_per_epoch = dataloader_cfg.pop('num_batch_per_epoch', None) if num_batch_per_epoch is not None: world_size = get_world_size() - num_samples = ( - num_batch_per_epoch * _get_batch_size(dataloader_cfg) * - world_size) + num_samples = (num_batch_per_epoch * + _get_batch_size(dataloader_cfg) * world_size) dataset = _SlicedDataset(dataset, num_samples) # build sampler sampler_cfg = dataloader_cfg.pop('sampler') if isinstance(sampler_cfg, dict): sampler_seed = None if diff_rank_seed else seed - sampler = DATA_SAMPLERS.build( - sampler_cfg, - default_args=dict(dataset=dataset, seed=sampler_seed)) + sampler = DATA_SAMPLERS.build(sampler_cfg, + default_args=dict(dataset=dataset, + seed=sampler_seed)) else: # fallback to raise error in dataloader # if `sampler_cfg` is not a valid type @@ -1414,9 +1413,8 @@ def build_dataloader(dataloader: Union[DataLoader, Dict], elif isinstance(batch_sampler_cfg, dict): batch_sampler = DATA_SAMPLERS.build( batch_sampler_cfg, - default_args=dict( - sampler=sampler, - batch_size=dataloader_cfg.pop('batch_size'))) + default_args=dict(sampler=sampler, + batch_size=dataloader_cfg.pop('batch_size'))) else: # fallback to raise error in dataloader # if `batch_sampler_cfg` is not a valid type @@ -1529,18 +1527,20 @@ def build_train_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: 'Only one of `type` or `by_epoch` can exist in `loop_cfg`.') if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, dataloader=self._train_dataloader)) + loop = LOOPS.build(loop_cfg, + default_args=dict( + runner=self, + dataloader=self._train_dataloader)) else: by_epoch = loop_cfg.pop('by_epoch') if by_epoch: - loop = EpochBasedTrainLoop( - **loop_cfg, runner=self, dataloader=self._train_dataloader) + loop = EpochBasedTrainLoop(**loop_cfg, + runner=self, + dataloader=self._train_dataloader) else: - loop = IterBasedTrainLoop( - **loop_cfg, runner=self, dataloader=self._train_dataloader) + loop = IterBasedTrainLoop(**loop_cfg, + runner=self, + dataloader=self._train_dataloader) return loop # type: ignore def build_val_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: @@ -1571,18 +1571,16 @@ def build_val_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: loop_cfg = copy.deepcopy(loop) if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, - dataloader=self._val_dataloader, - evaluator=self._val_evaluator)) + loop = 
LOOPS.build(loop_cfg, + default_args=dict( + runner=self, + dataloader=self._val_dataloader, + evaluator=self._val_evaluator)) else: - loop = ValLoop( - **loop_cfg, - runner=self, - dataloader=self._val_dataloader, - evaluator=self._val_evaluator) # type: ignore + loop = ValLoop(**loop_cfg, + runner=self, + dataloader=self._val_dataloader, + evaluator=self._val_evaluator) # type: ignore return loop # type: ignore @@ -1613,18 +1611,16 @@ def build_test_loop(self, loop: Union[BaseLoop, Dict]) -> BaseLoop: loop_cfg = copy.deepcopy(loop) # type: ignore if 'type' in loop_cfg: - loop = LOOPS.build( - loop_cfg, - default_args=dict( - runner=self, - dataloader=self._test_dataloader, - evaluator=self._test_evaluator)) + loop = LOOPS.build(loop_cfg, + default_args=dict( + runner=self, + dataloader=self._test_dataloader, + evaluator=self._test_evaluator)) else: - loop = TestLoop( - **loop_cfg, - runner=self, - dataloader=self._test_dataloader, - evaluator=self._test_evaluator) # type: ignore + loop = TestLoop(**loop_cfg, + runner=self, + dataloader=self._test_dataloader, + evaluator=self._test_evaluator) # type: ignore return loop # type: ignore @@ -2028,8 +2024,8 @@ def resume(self, device = get_device() checkpoint = self.load_checkpoint(filename, map_location=device) else: - checkpoint = self.load_checkpoint( - filename, map_location=map_location) + checkpoint = self.load_checkpoint(filename, + map_location=map_location) self.train_loop._epoch = checkpoint['meta']['epoch'] self.train_loop._iter = checkpoint['meta']['iter'] @@ -2037,8 +2033,8 @@ def resume(self, # check whether the number of GPU used for current experiment # is consistent with resuming from checkpoint if 'config' in checkpoint['meta']: - config = mmengine.Config.fromstring( - checkpoint['meta']['config'], file_format='.py') + config = mmengine.Config.fromstring(checkpoint['meta']['config'], + file_format='.py') previous_gpu_ids = config.get('gpu_ids', None) if (previous_gpu_ids is not None and len(previous_gpu_ids) > 0 and len(previous_gpu_ids) != self._world_size): @@ -2146,8 +2142,10 @@ def load_checkpoint(self, else: model = self.model - checkpoint = _load_checkpoint_to_model( - model, checkpoint, strict, revise_keys=revise_keys) + checkpoint = _load_checkpoint_to_model(model, + checkpoint, + strict, + revise_keys=revise_keys) self._has_loaded = True @@ -2223,12 +2221,11 @@ def save_checkpoint( filepath = join_path( # type: ignore out_dir, filename, backend_args=backend_args) - meta.update( - cfg=self.cfg.pretty_text, - seed=self.seed, - experiment_name=self.experiment_name, - time=time.strftime('%Y%m%d_%H%M%S', time.localtime()), - mmengine_version=mmengine.__version__ + get_git_hash()) + meta.update(cfg=self.cfg.pretty_text, + seed=self.seed, + experiment_name=self.experiment_name, + time=time.strftime('%Y%m%d_%H%M%S', time.localtime()), + mmengine_version=mmengine.__version__ + get_git_hash()) if hasattr(self.train_dataloader.dataset, 'metainfo'): meta.update(dataset_meta=self.train_dataloader.dataset.metainfo) @@ -2280,11 +2277,10 @@ def save_checkpoint( checkpoint['param_schedulers'].append(state_dict) self.call_hook('before_save_checkpoint', checkpoint=checkpoint) - save_checkpoint( - checkpoint, - filepath, - file_client_args=file_client_args, - backend_args=backend_args) + save_checkpoint(checkpoint, + filepath, + file_client_args=file_client_args, + backend_args=backend_args) @master_only def dump_config(self) -> None: diff --git a/mmengine/structures/base_data_element.py b/mmengine/structures/base_data_element.py index 
8ac5a3d27d..da27a4b16e 100644 --- a/mmengine/structures/base_data_element.py +++ b/mmengine/structures/base_data_element.py @@ -395,8 +395,10 @@ def __setattr__(self, name: str, value: Any): raise AttributeError(f'{name} has been used as a ' 'private attribute, which is immutable.') else: - self.set_field( - name=name, value=value, field_type='data', dtype=None) + self.set_field(name=name, + value=value, + field_type='data', + dtype=None) def __delattr__(self, item: str): """Delete the item in dataelement. diff --git a/mmengine/structures/instance_data.py b/mmengine/structures/instance_data.py index 8633b86037..e841a4d73a 100644 --- a/mmengine/structures/instance_data.py +++ b/mmengine/structures/instance_data.py @@ -7,6 +7,7 @@ import torch from mmengine.device import get_device + from .base_data_element import BaseDataElement BoolTypeTensor: Union[Any] diff --git a/mmengine/testing/compare.py b/mmengine/testing/compare.py index 14c7a97ba7..549fbe64ef 100644 --- a/mmengine/testing/compare.py +++ b/mmengine/testing/compare.py @@ -42,18 +42,20 @@ def assert_allclose( """ if 'parrots' not in TORCH_VERSION and \ digit_version(TORCH_VERSION) >= digit_version('1.6'): - _assert_allclose( - actual, - expected, - rtol=rtol, - atol=atol, - equal_nan=equal_nan, - msg=msg) + _assert_allclose(actual, + expected, + rtol=rtol, + atol=atol, + equal_nan=equal_nan, + msg=msg) else: # torch.testing.assert_allclose has no ``msg`` argument # when PyTorch < 1.6 - _assert_allclose( - actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan) + _assert_allclose(actual, + expected, + rtol=rtol, + atol=atol, + equal_nan=equal_nan) def check_python_script(cmd): @@ -180,8 +182,8 @@ def assert_params_all_zeros(module) -> bool: if hasattr(module, 'bias') and module.bias is not None: bias_data = module.bias.data - is_bias_zero = bias_data.allclose( - bias_data.new_zeros(bias_data.size())) + is_bias_zero = bias_data.allclose(bias_data.new_zeros( + bias_data.size())) else: is_bias_zero = True diff --git a/mmengine/testing/runner_test_case.py b/mmengine/testing/runner_test_case.py index f64594acef..c1dea6bdb4 100644 --- a/mmengine/testing/runner_test_case.py +++ b/mmengine/testing/runner_test_case.py @@ -91,12 +91,11 @@ class RunnerTestCase(TestCase): 3. Provides `build_runner` method to build runner easily. 4. Clean the global variable used by the runner. 
""" - dist_cfg = dict( - MASTER_ADDR='127.0.0.1', - MASTER_PORT=29600, - RANK='0', - WORLD_SIZE='1', - LOCAL_RANK='0') + dist_cfg = dict(MASTER_ADDR='127.0.0.1', + MASTER_PORT=29600, + RANK='0', + WORLD_SIZE='1', + LOCAL_RANK='0') def setUp(self) -> None: self.temp_dir = tempfile.TemporaryDirectory() @@ -108,22 +107,22 @@ def setUp(self) -> None: epoch_based_cfg = dict( work_dir=self.temp_dir.name, model=dict(type='ToyModel'), - train_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - batch_size=3, - num_workers=0), - val_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=3, - num_workers=0), + train_dataloader=dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', + shuffle=True), + batch_size=3, + num_workers=0), + val_dataloader=dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', + shuffle=False), + batch_size=3, + num_workers=0), val_evaluator=[dict(type='ToyMetric')], - test_dataloader=dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=3, - num_workers=0), + test_dataloader=dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', + shuffle=False), + batch_size=3, + num_workers=0), test_evaluator=[dict(type='ToyMetric')], optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.1)), train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1), @@ -145,10 +144,12 @@ def setUp(self) -> None: self.iter_based_cfg.log_processor = dict(by_epoch=False) self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12) - self.iter_based_cfg.default_hooks = dict( - logger=dict(type='LoggerHook', interval=1), - checkpoint=dict( - type='CheckpointHook', interval=12, by_epoch=False)) + self.iter_based_cfg.default_hooks = dict(logger=dict(type='LoggerHook', + interval=1), + checkpoint=dict( + type='CheckpointHook', + interval=12, + by_epoch=False)) def tearDown(self): # `FileHandler` should be closed in Windows, otherwise we cannot diff --git a/mmengine/utils/dl_utils/collect_env.py b/mmengine/utils/dl_utils/collect_env.py index 0ee99abad2..83882425c8 100644 --- a/mmengine/utils/dl_utils/collect_env.py +++ b/mmengine/utils/dl_utils/collect_env.py @@ -11,6 +11,7 @@ import mmengine from mmengine.device import is_cuda_available, is_musa_available + from .parrots_wrapper import TORCH_VERSION, get_build_config, is_rocm_pytorch @@ -77,8 +78,8 @@ def collect_env(): if CUDA_HOME == '/opt/rocm': try: nvcc = osp.join(CUDA_HOME, 'hip/bin/hipcc') - nvcc = subprocess.check_output( - f'"{nvcc}" --version', shell=True) + nvcc = subprocess.check_output(f'"{nvcc}" --version', + shell=True) nvcc = nvcc.decode('utf-8').strip() release = nvcc.rfind('HIP version:') build = nvcc.rfind('') @@ -134,8 +135,9 @@ def collect_env(): from distutils.ccompiler import new_compiler ccompiler = new_compiler() ccompiler.initialize() - cc = subprocess.check_output( - f'{ccompiler.cc}', stderr=subprocess.STDOUT, shell=True) + cc = subprocess.check_output(f'{ccompiler.cc}', + stderr=subprocess.STDOUT, + shell=True) encoding = os.device_encoding( sys.stdout.fileno()) or locale.getpreferredencoding() env_info['MSVC'] = cc.decode(encoding).partition('\n')[0].strip() diff --git a/mmengine/utils/dl_utils/hub.py b/mmengine/utils/dl_utils/hub.py index 7f7f1a087d..41deaa0b1a 100644 --- a/mmengine/utils/dl_utils/hub.py +++ b/mmengine/utils/dl_utils/hub.py @@ -107,8 +107,10 @@ def load_url(url, if check_hash: r 
= HASH_REGEX.search(filename) # r is Optional[Match[str]] hash_prefix = r.group(1) if r else None - download_url_to_file( - url, cached_file, hash_prefix, progress=progress) + download_url_to_file(url, + cached_file, + hash_prefix, + progress=progress) if _is_legacy_zip_format(cached_file): return _legacy_zip_load(cached_file, model_dir, map_location) diff --git a/mmengine/utils/dl_utils/torch_ops.py b/mmengine/utils/dl_utils/torch_ops.py index 2550ae6986..85dc3100d2 100644 --- a/mmengine/utils/dl_utils/torch_ops.py +++ b/mmengine/utils/dl_utils/torch_ops.py @@ -4,9 +4,9 @@ from ..version_utils import digit_version from .parrots_wrapper import TORCH_VERSION -_torch_version_meshgrid_indexing = ( - 'parrots' not in TORCH_VERSION - and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0')) +_torch_version_meshgrid_indexing = ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) + >= digit_version('1.10.0a0')) def torch_meshgrid(*tensors): diff --git a/mmengine/utils/dl_utils/visualize.py b/mmengine/utils/dl_utils/visualize.py index f3361e1d50..6f7b05e095 100644 --- a/mmengine/utils/dl_utils/visualize.py +++ b/mmengine/utils/dl_utils/visualize.py @@ -49,11 +49,11 @@ def fake_run(cfg): cfg.pop('test_cfg') extra_cfg = dict( model=dict(type='ToyModel'), - visualizer=dict( - type='Visualizer', - vis_backends=[ - dict(type='TensorboardVisBackend', save_dir='temp_dir') - ]), + visualizer=dict(type='Visualizer', + vis_backends=[ + dict(type='TensorboardVisBackend', + save_dir='temp_dir') + ]), ) cfg.merge_from_dict(extra_cfg) # build the runner from config diff --git a/mmengine/utils/progressbar_rich.py b/mmengine/utils/progressbar_rich.py index f8e04d8041..44162e2160 100644 --- a/mmengine/utils/progressbar_rich.py +++ b/mmengine/utils/progressbar_rich.py @@ -121,8 +121,9 @@ def track_progress_rich(func: Callable, ) worker = _Worker(func) - task_id = prog_bar.add_task( - total=task_num, color=color, description=description) + task_id = prog_bar.add_task(total=task_num, + color=color, + description=description) tasks = _tasks_with_index(tasks) # Use single process when nproc is 1, else use multiprocess. 
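The `track_progress_rich` hunk above only re-indents the `add_task` call; the helper's behavior is unchanged. A minimal usage sketch, assuming the keyword arguments visible in the hunk (`tasks`, `nproc`, `description`) keep their current meaning and that the function is exported from `mmengine.utils` as in current releases:

    from mmengine.utils import track_progress_rich

    def square(x):
        # toy worker; each entry of `tasks` is passed to the
        # function as a single argument
        return x * x

    # runs the tasks (in a single process here, since nproc=1) while
    # rendering a rich progress bar, and collects the return values
    results = track_progress_rich(square,
                                  tasks=[1, 2, 3, 4],
                                  nproc=1,
                                  description='squaring')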
diff --git a/mmengine/utils/version_utils.py b/mmengine/utils/version_utils.py index 620180547a..2e02ecddd4 100644 --- a/mmengine/utils/version_utils.py +++ b/mmengine/utils/version_utils.py @@ -58,9 +58,10 @@ def _minimal_ext_cmd(cmd): env['LANGUAGE'] = 'C' env['LANG'] = 'C' env['LC_ALL'] = 'C' - out, err = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - env=env).communicate() + out, err = subprocess.Popen(cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env).communicate() return out diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py index a5bf7d88e7..3a350daf07 100644 --- a/mmengine/visualization/vis_backend.py +++ b/mmengine/visualization/vis_backend.py @@ -437,8 +437,8 @@ def add_config(self, config: Config, **kwargs) -> None: """ assert isinstance(self._init_kwargs, dict) allow_val_change = self._init_kwargs.get('allow_val_change', False) - self._wandb.config.update( - config.to_dict(), allow_val_change=allow_val_change) + self._wandb.config.update(config.to_dict(), + allow_val_change=allow_val_change) self._wandb.run.log_code(name=self._log_code_name) @force_init_env @@ -605,7 +605,7 @@ def add_scalar(self, self._tensorboard.add_scalar(name, value, step) else: warnings.warn(f'Got type {type(value)} with name {name}, ' - 'but numpy array, torch tensor, ' + 'but numpy array, torch tensor, ' f'int or float are expected. skip it!') @force_init_env @@ -939,8 +939,10 @@ def add_image(self, should be RGB. step (int): Global step value to record. Defaults to 0. """ - self._logger.report_image( - title=name, series=name, iteration=step, image=image) + self._logger.report_image(title=name, + series=name, + iteration=step, + image=image) @force_init_env def add_scalar(self, @@ -955,8 +957,10 @@ def add_scalar(self, value (int, float, torch.Tensor, np.ndarray): Value to save. step (int): Global step value to record. Defaults to 0. """ - self._logger.report_scalar( - title=name, series=name, value=value, iteration=step) + self._logger.report_scalar(title=name, + series=name, + value=value, + iteration=step) @force_init_env def add_scalars(self, @@ -976,8 +980,10 @@ def add_scalars(self, assert 'step' not in scalar_dict, 'Please set it directly ' \ 'through the step parameter' for key, value in scalar_dict.items(): - self._logger.report_scalar( - title=key, series=key, value=value, iteration=step) + self._logger.report_scalar(title=key, + series=key, + value=value, + iteration=step) def close(self) -> None: """Close the clearml.""" @@ -1093,8 +1099,9 @@ def add_image(self, # values in the array need to be in the [0, 1] range img = image.astype(np.float32) / 255.0 - self._neptune['images'].append( - File.as_image(img), name=name, step=step) + self._neptune['images'].append(File.as_image(img), + name=name, + step=step) @force_init_env def add_scalar(self, diff --git a/mmengine/visualization/visualizer.py b/mmengine/visualization/visualizer.py index 6979395aca..e1525a86e3 100644 --- a/mmengine/visualization/visualizer.py +++ b/mmengine/visualization/visualizer.py @@ -271,9 +271,8 @@ def show(self, # will be updated with `win_name`. 
cv2.namedWindow(winname=f'{id(self)}') cv2.setWindowTitle(f'{id(self)}', win_name) - cv2.imshow( - str(id(self)), - self.get_image() if drawn_img is None else drawn_img) + cv2.imshow(str(id(self)), + self.get_image() if drawn_img is None else drawn_img) cv2.waitKey(int(np.ceil(wait_time * 1000))) else: raise ValueError('backend should be "matplotlib" or "cv2", ' @@ -300,10 +299,9 @@ def set_image(self, image: np.ndarray) -> None: # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) self.ax_save.cla() self.ax_save.axis(False) - self.ax_save.imshow( - image, - extent=(0, self.width, self.height, 0), - interpolation='none') + self.ax_save.imshow(image, + extent=(0, self.width, self.height, 0), + interpolation='none') @master_only def get_image(self) -> np.ndarray: @@ -344,14 +342,16 @@ def _init_manager(self, win_name: str) -> None: from matplotlib.figure import Figure from matplotlib.pyplot import new_figure_manager if getattr(self, 'manager', None) is None: - self.manager = new_figure_manager( - num=1, FigureClass=Figure, **self.fig_show_cfg) + self.manager = new_figure_manager(num=1, + FigureClass=Figure, + **self.fig_show_cfg) try: self.manager.set_window_title(win_name) except Exception: - self.manager = new_figure_manager( - num=1, FigureClass=Figure, **self.fig_show_cfg) + self.manager = new_figure_manager(num=1, + FigureClass=Figure, + **self.fig_show_cfg) self.manager.set_window_title(win_name) @master_only @@ -413,8 +413,11 @@ def draw_points(self, 'The shape of `positions` should be (N, 2), ' f'but got {positions.shape}') colors = color_val_matplotlib(colors) # type: ignore - self.ax_save.scatter( - positions[:, 0], positions[:, 1], c=colors, s=sizes, marker=marker) + self.ax_save.scatter(positions[:, 0], + positions[:, 1], + c=colors, + s=sizes, + marker=marker) return self @master_only @@ -616,11 +619,10 @@ def draw_lines( warnings.warn( 'Warning: The line is out of bounds,' ' the drawn line may not be in the image', UserWarning) - line_collect = LineCollection( - lines.tolist(), - colors=colors, - linestyles=line_styles, - linewidths=line_widths) + line_collect = LineCollection(lines.tolist(), + colors=colors, + linestyles=line_styles, + linewidths=line_widths) self.ax_save.add_collection(line_collect) return self @@ -676,10 +678,9 @@ def draw_circles( assert center.shape == (radius.shape[0], 2), ( 'The shape of `center` should be (radius.shape, 2), ' f'but got {center.shape}') - if not (self._is_posion_valid(center - - np.tile(radius.reshape((-1, 1)), (1, 2))) - and self._is_posion_valid( - center + np.tile(radius.reshape((-1, 1)), (1, 2)))): + if not (self._is_posion_valid(center - np.tile(radius.reshape( + (-1, 1)), (1, 2))) and self._is_posion_valid( + center + np.tile(radius.reshape((-1, 1)), (1, 2)))): warnings.warn( 'Warning: The circle is out of bounds,' ' the drawn circle may not be in the image', UserWarning) @@ -698,13 +699,12 @@ def draw_circles( min(max(linewidth, 1), self._default_font_size / 4) for linewidth in line_widths ] - p = PatchCollection( - circles, - alpha=alpha, - facecolors=face_colors, - edgecolors=edge_colors, - linewidths=line_widths, - linestyles=line_styles) + p = PatchCollection(circles, + alpha=alpha, + facecolors=face_colors, + edgecolors=edge_colors, + linewidths=line_widths, + linestyles=line_styles) self.ax_save.add_collection(p) return self @@ -754,8 +754,9 @@ def draw_bboxes( assert bboxes.shape[-1] == 4, ( f'The shape of `bboxes` should be (N, 4), but got {bboxes.shape}') - assert (bboxes[:, 0] <= bboxes[:, 2]).all() and 
(bboxes[:, 1] <= - bboxes[:, 3]).all() + assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1] + <= bboxes[:, + 3]).all() if not self._is_posion_valid(bboxes.reshape((-1, 2, 2))): warnings.warn( 'Warning: The bbox is out of bounds,' @@ -765,13 +766,12 @@ def draw_bboxes( bboxes[:, 2], bboxes[:, 3], bboxes[:, 0], bboxes[:, 3]), axis=-1).reshape(-1, 4, 2) poly = [p for p in poly] - return self.draw_polygons( - poly, - alpha=alpha, - edge_colors=edge_colors, - line_styles=line_styles, - line_widths=line_widths, - face_colors=face_colors) + return self.draw_polygons(poly, + alpha=alpha, + edge_colors=edge_colors, + line_styles=line_styles, + line_widths=line_widths, + face_colors=face_colors) @master_only def draw_polygons( @@ -837,13 +837,12 @@ def draw_polygons( min(max(linewidth, 1), self._default_font_size / 4) for linewidth in line_widths ] - polygon_collection = PolyCollection( - polygons, - alpha=alpha, - facecolor=face_colors, - linestyles=line_styles, - edgecolors=edge_colors, - linewidths=line_widths) + polygon_collection = PolyCollection(polygons, + alpha=alpha, + facecolor=face_colors, + linestyles=line_styles, + edgecolors=edge_colors, + linewidths=line_widths) self.ax_save.add_collection(polygon_collection) return self @@ -903,14 +902,14 @@ def draw_binary_masks( rgb = np.zeros_like(img) rgb[...] = color rgb = cv2.bitwise_and(rgb, rgb, mask=binary_mask) - img_complement = cv2.bitwise_and( - img, img, mask=binary_mask_complement) + img_complement = cv2.bitwise_and(img, + img, + mask=binary_mask_complement) rgb = rgb + img_complement img = cv2.addWeighted(img, 1 - alpha, rgb, alpha, 0) - self.ax_save.imshow( - img, - extent=(0, self.width, self.height, 0), - interpolation='nearest') + self.ax_save.imshow(img, + extent=(0, self.width, self.height, 0), + interpolation='nearest') return self @staticmethod @@ -991,18 +990,16 @@ def draw_featmap(featmap: torch.Tensor, f'the feature map will be interpolated. 
' f'This may cause mismatch problems !') if resize_shape is None: - featmap = F.interpolate( - featmap[None], - overlaid_image.shape[:2], - mode='bilinear', - align_corners=False)[0] + featmap = F.interpolate(featmap[None], + overlaid_image.shape[:2], + mode='bilinear', + align_corners=False)[0] if resize_shape is not None: - featmap = F.interpolate( - featmap[None], - resize_shape, - mode='bilinear', - align_corners=False)[0] + featmap = F.interpolate(featmap[None], + resize_shape, + mode='bilinear', + align_corners=False)[0] if overlaid_image is not None: overlaid_image = cv2.resize(overlaid_image, resize_shape[::-1]) @@ -1044,8 +1041,12 @@ def draw_featmap(featmap: torch.Tensor, fig = plt.figure(frameon=False) # Set the window layout - fig.subplots_adjust( - left=0, right=1, bottom=0, top=1, wspace=0, hspace=0) + fig.subplots_adjust(left=0, + right=1, + bottom=0, + top=1, + wspace=0, + hspace=0) dpi = fig.get_dpi() fig.set_size_inches((width * col + 1e-2) / dpi, (height * row + 1e-2) / dpi) diff --git a/tests/test_analysis/test_flop_count.py b/tests/test_analysis/test_flop_count.py index 20749a0bab..99d096cbf8 100644 --- a/tests/test_analysis/test_flop_count.py +++ b/tests/test_analysis/test_flop_count.py @@ -243,8 +243,8 @@ def addmm_dummy_flop_jit( custom_ops2: Dict[str, Handle] = { f'aten::{self.lin_op}': addmm_dummy_flop_jit } - flop_dict2, _ = flop_count( - custom_net, (x, ), supported_ops=custom_ops2) + flop_dict2, _ = flop_count(custom_net, (x, ), + supported_ops=custom_ops2) flop = 400000 / 1e9 self.assertEqual( flop_dict2[self.lin_op], @@ -365,9 +365,9 @@ def _test_conv( else: spatial_size = ( (spatial_dim + 2 * padding) - kernel_size) // stride + 1 - gt_flop = ( - batch_size * input_dim * output_dim * (kernel_size**conv_dim) * - (spatial_size**conv_dim) / group_size / 1e9) + gt_flop = (batch_size * input_dim * output_dim * + (kernel_size**conv_dim) * (spatial_size**conv_dim) / + group_size / 1e9) gt_dict = defaultdict(float) gt_dict['conv'] = gt_flop self.assertDictEqual( @@ -849,8 +849,8 @@ def _count_function(self, func, inputs, name) -> Tuple[Any, Any]: def f(*args): return func(*inputs) - graph = torch.jit.trace( - f, tuple(tensor_inputs), check_trace=False).graph + graph = torch.jit.trace(f, tuple(tensor_inputs), + check_trace=False).graph nodes = [k for k in graph.nodes() if k.kind() == name] self.assertEqual(len(nodes), 1) node = nodes[0] diff --git a/tests/test_analysis/test_jit_analysis.py b/tests/test_analysis/test_jit_analysis.py index be10309d0f..b66dcb7f85 100644 --- a/tests/test_analysis/test_jit_analysis.py +++ b/tests/test_analysis/test_jit_analysis.py @@ -44,8 +44,8 @@ def __init__(self, lin_op: str = 'addmm') -> None: fc_flops_ = fc_in * fc_out fc_flops = Counter({lin_op: fc_flops_}) - spatial_pos = (conv_input_size[1] + 2 * padding) - 2 * ( - kernel_size // 2) + spatial_pos = (conv_input_size[1] + + 2 * padding) - 2 * (kernel_size // 2) conv_flops_ = spatial_pos * kernel_size * conv_in * conv_out conv_flops = Counter({'conv': conv_flops_}) model_flops = conv_flops + fc_flops @@ -95,8 +95,8 @@ def __init__(self, lin_op: str = 'addmm') -> None: fc_flops_ = fc_in * fc_out fc_flops = Counter({lin_op: fc_flops_}) - spatial_pos = (self.input_size[1] + 2 * padding) - 2 * ( - kernel_size // 2) + spatial_pos = (self.input_size[1] + + 2 * padding) - 2 * (kernel_size // 2) conv_flops_ = spatial_pos * kernel_size * conv_in * conv_out conv_flops = Counter({'conv': conv_flops_}) @@ -428,8 +428,8 @@ def test_non_forward_func_call(self) -> None: model = NonForwardNet() 
inputs = (torch.randn((1, 10)), ) - analyzer = FlopAnalyzer( - model=model, inputs=inputs).ancestor_mode('caller') + analyzer = FlopAnalyzer(model=model, + inputs=inputs).ancestor_mode('caller') inner_fc_count = model.submod.fc_flops total_count = model.fc_flops + inner_fc_count @@ -441,8 +441,8 @@ def test_non_forward_func_call(self) -> None: # The mod not directly called is registered as such self.assertEqual(analyzer.uncalled_modules(), {'submod'}) - analyzer = FlopAnalyzer( - model=model, inputs=inputs).ancestor_mode('owner') + analyzer = FlopAnalyzer(model=model, + inputs=inputs).ancestor_mode('owner') self.assertEqual(analyzer.total('submod'), inner_fc_count) self.assertEqual(analyzer.total('submod.fc'), inner_fc_count) self.assertEqual(analyzer.total(''), total_count) @@ -455,9 +455,9 @@ def test_shared_module(self) -> None: model = SharedModuleNet() inputs = (torch.randn((1, *model.input_size)), ) - analyzer = ( - FlopAnalyzer(model=model, inputs=inputs).unsupported_ops_warnings( - enabled=False).ancestor_mode('caller')) + analyzer = (FlopAnalyzer(model=model, + inputs=inputs).unsupported_ops_warnings( + enabled=False).ancestor_mode('caller')) # The names `submod2.submod` and `multiname2` are not included, # since only the first name of a module is made the canonical one. @@ -487,14 +487,14 @@ def test_shared_module(self) -> None: ) # Test getting canonical name - self.assertEqual( - analyzer.canonical_module_name('multiname2'), 'multiname1') - self.assertEqual( - analyzer.canonical_module_name('multiname1'), 'multiname1') - self.assertEqual( - analyzer.canonical_module_name('submod2.submod'), 'submod1.submod') - self.assertEqual( - analyzer.canonical_module_name('submod1.submod'), 'submod1.submod') + self.assertEqual(analyzer.canonical_module_name('multiname2'), + 'multiname1') + self.assertEqual(analyzer.canonical_module_name('multiname1'), + 'multiname1') + self.assertEqual(analyzer.canonical_module_name('submod2.submod'), + 'submod1.submod') + self.assertEqual(analyzer.canonical_module_name('submod1.submod'), + 'submod1.submod') # Tests no uncalled modules self.assertEqual(analyzer.uncalled_modules(), set()) @@ -561,13 +561,12 @@ def test_unsupported_ops(self) -> None: model = NestedNet(lin_op=self.lin_op) inputs = (torch.randn((1, *model.input_size)), ) - analyzer = JitModelAnalysis( - model=model, inputs=inputs).set_op_handle( - 'aten::addmm', - addmm_flop_jit, - 'aten::linear', - linear_flop_jit, - ) + analyzer = JitModelAnalysis(model=model, inputs=inputs).set_op_handle( + 'aten::addmm', + addmm_flop_jit, + 'aten::linear', + linear_flop_jit, + ) analyzer.total() skipped_inner_conv = Counter({'aten::_convolution': 1}) @@ -606,8 +605,8 @@ def test_changing_handles(self) -> None: 'aten::linear': linear_flop_jit, } - analyzer = JitModelAnalysis( - model=model, inputs=inputs).set_op_handle(**op_handles) + analyzer = JitModelAnalysis(model=model, + inputs=inputs).set_op_handle(**op_handles) analyzer.unsupported_ops_warnings(enabled=False) # Request a result once to cache flop counts @@ -634,9 +633,10 @@ def dummy_ops_handle(inputs: List[Any], dummy_flops = {} for name, counts in model.flops.items(): - dummy_flops[name] = Counter( - {op: flop - for op, flop in counts.items() if op != self.lin_op}) + dummy_flops[name] = Counter({ + op: flop + for op, flop in counts.items() if op != self.lin_op + }) dummy_flops[''][dummy_name] = 2 * dummy_out dummy_flops['fc'][dummy_name] = dummy_out dummy_flops['submod'][dummy_name] = dummy_out @@ -657,14 +657,12 @@ def test_copy(self) -> None: 
model = RepeatedNet() inputs = (torch.randn((1, *model.input_size)), ) - analyzer = ( - JitModelAnalysis(model=model, inputs=inputs).set_op_handle( - 'aten::addmm', - addmm_flop_jit, - 'aten::linear', - linear_flop_jit, - ).unsupported_ops_warnings(enabled=False).tracer_warnings( - mode='none')) + analyzer = (JitModelAnalysis(model=model, inputs=inputs).set_op_handle( + 'aten::addmm', + addmm_flop_jit, + 'aten::linear', + linear_flop_jit, + ).unsupported_ops_warnings(enabled=False).tracer_warnings(mode='none')) repeated_net_flops = model.fc1_num * model.fc1_flops repeated_net_flops += model.fc2_num * model.fc2_flops @@ -699,8 +697,8 @@ def test_copy(self) -> None: new_model = NonForwardNet() bs = 5 new_inputs = (torch.randn((bs, *new_model.input_size)), ) - analyzer_new = analyzer.copy( - new_model=new_model, new_inputs=new_inputs) + analyzer_new = analyzer.copy(new_model=new_model, + new_inputs=new_inputs) non_forward_flops = new_model.fc_flops + new_model.submod.fc_flops diff --git a/tests/test_analysis/test_print_helper.py b/tests/test_analysis/test_print_helper.py index 14366583d5..3abd0a0bd9 100644 --- a/tests/test_analysis/test_print_helper.py +++ b/tests/test_analysis/test_print_helper.py @@ -60,23 +60,24 @@ def test_get_model_complexity_info(): assert complexity_info['flops'] == flops assert complexity_info['params'] == params - complexity_info = get_model_complexity_info( - model=model, input_shape=input_shape1) - flops = FlopAnalyzer( - model=model, inputs=(torch.randn(1, *input_shape1), )).total() + complexity_info = get_model_complexity_info(model=model, + input_shape=input_shape1) + flops = FlopAnalyzer(model=model, + inputs=(torch.randn(1, *input_shape1), )).total() assert complexity_info['flops'] == flops # test a network that accepts two tensors as input model = NetAcceptTwoTensors() - complexity_info = get_model_complexity_info( - model=model, inputs=(input1, input2)) + complexity_info = get_model_complexity_info(model=model, + inputs=(input1, input2)) flops = FlopAnalyzer(model=model, inputs=(input1, input2)).total() params = parameter_count(model=model)[''] assert complexity_info['flops'] == flops assert complexity_info['params'] == params - complexity_info = get_model_complexity_info( - model=model, input_shape=(input_shape1, input_shape2)) + complexity_info = get_model_complexity_info(model=model, + input_shape=(input_shape1, + input_shape2)) inputs = (torch.randn(1, *input_shape1), torch.randn(1, *input_shape2)) flops = FlopAnalyzer(model=model, inputs=inputs).total() assert complexity_info['flops'] == flops @@ -88,8 +89,8 @@ def test_get_model_complexity_info(): scalar = torch.tensor([ scalar ]) if digit_version(TORCH_VERSION) < digit_version('1.9.0') else scalar - complexity_info = get_model_complexity_info( - model=model, inputs=(input1, scalar)) + complexity_info = get_model_complexity_info(model=model, + inputs=(input1, scalar)) flops = FlopAnalyzer(model=model, inputs=(input1, scalar)).total() params = parameter_count(model=model)[''] assert complexity_info['flops'] == flops @@ -104,5 +105,6 @@ def test_get_model_complexity_info(): # when both `inputs` and `input_shape` are specified model = NetAcceptOneTensor() with pytest.raises(ValueError, match='cannot be both set'): - get_model_complexity_info( - model, inputs=input1, input_shape=input_shape1) + get_model_complexity_info(model, + inputs=input1, + input_shape=input_shape1) diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index e783431441..8d6c2bef5f 100644 --- 
a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -40,8 +40,10 @@ def test_init(self, file_format): Config([0, 1]) # test `filename` parameter - cfg_dict = dict( - item1=[1, 2], item2=dict(a=0), item3=True, item4='test') + cfg_dict = dict(item1=[1, 2], + item2=dict(a=0), + item3=True, + item4='test') cfg_file = osp.join( self.data_path, f'config/{file_format}_config/simple_config.{file_format}') @@ -54,9 +56,9 @@ def test_init(self, file_format): self.data_path, f'config/{file_format}_config/test_reserved_key.{file_format}') # reserved keys cannot be set in config - with pytest.raises( - KeyError, match='filename is reserved for config ' - 'file'): + with pytest.raises(KeyError, + match='filename is reserved for config ' + 'file'): Config.fromfile(cfg_file) def test_fromfile(self): @@ -74,8 +76,8 @@ def test_fromfile(self): Config.fromfile(cfg_file, import_custom_modules=False) assert 'TEST_VALUE' not in os.environ sys.modules.pop('test_custom_import_module') - with pytest.raises( - ImportError, match='Failed to import custom modules from'): + with pytest.raises(ImportError, + match='Failed to import custom modules from'): Config.fromfile(cfg_file, import_custom_modules=True) @pytest.mark.parametrize('file_format', ['py', 'json', 'yaml']) @@ -100,8 +102,10 @@ def test_fromstring(self, file_format): Config.fromstring(cfg_str, '.xml') def test_magic_methods(self): - cfg_dict = dict( - item1=[1, 2], item2=dict(a=0), item3=True, item4='test') + cfg_dict = dict(item1=[1, 2], + item2=dict(a=0), + item3=True, + item4='test') filename = 'py_config/simple_config.py' cfg_file = osp.join(self.data_path, 'config', filename) cfg = Config.fromfile(cfg_file) @@ -218,8 +222,9 @@ def test_auto_argparser(self): sys.argv.extend(tmp) def test_dict_to_config_dict(self): - cfg_dict = dict( - a=1, b=dict(c=dict()), d=[dict(e=dict(f=(dict(g=1), [])))]) + cfg_dict = dict(a=1, + b=dict(c=dict()), + d=[dict(e=dict(f=(dict(g=1), [])))]) cfg_dict = Config._dict_to_config_dict(cfg_dict) assert isinstance(cfg_dict, ConfigDict) assert isinstance(cfg_dict.a, int) @@ -316,8 +321,10 @@ def test_repr(self, tmp_path): def test_dict_action(self): parser = argparse.ArgumentParser(description='Train a detector') - parser.add_argument( - '--options', nargs='+', action=DictAction, help='custom options') + parser.add_argument('--options', + nargs='+', + action=DictAction, + help='custom options') # Nested brackets args = parser.parse_args( ['--options', 'item2.a=a,b', 'item2.b=[(a,b), [1,2], false]']) @@ -471,10 +478,9 @@ def test_pre_substitute_base_vars(self, tmp_path): assert cfg_module_dict['item10'].startswith('_item7') def test_substitute_base_vars(self): - cfg = dict( - item4='_item1.12345', - item5=dict(item3='1', item2='_item2_.fswf'), - item0=('_item0_.12ed21wq', 1)) + cfg = dict(item4='_item1.12345', + item5=dict(item3='1', item2='_item2_.fswf'), + item0=('_item0_.12ed21wq', 1)) cfg_base = dict(item1=0, item2=[1, 2, 3], item0=(1, 2, 3)) base_var_dict = { '_item1.12345': 'item1', @@ -517,9 +523,8 @@ def test_get_cfg_path_local(self): assert scope is None osp.isfile(cfg_path) - @pytest.mark.skipif( - not is_installed('mmdet') or not is_installed('mmcls'), - reason='mmdet and mmcls should be installed') + @pytest.mark.skipif(not is_installed('mmdet') or not is_installed('mmcls'), + reason='mmdet and mmcls should be installed') def test_get_cfg_path_external(self): filename = 'py_config/simple_config.py' filename = osp.join(self.data_path, 'config', filename) @@ -559,20 +564,18 @@ def 
_predefined_vars(self): path = osp.join(self.data_path, 'config/py_config') path = Path(path).as_posix() - cfg_dict_dst = dict( - item1='test_predefined_var.py', - item2=path, - item3='abc_test_predefined_var') + cfg_dict_dst = dict(item1='test_predefined_var.py', + item2=path, + item3='abc_test_predefined_var') assert Config._file2dict(cfg_file)[0]['item1'] == cfg_dict_dst['item1'] assert Config._file2dict(cfg_file)[0]['item2'] == cfg_dict_dst['item2'] assert Config._file2dict(cfg_file)[0]['item3'] == cfg_dict_dst['item3'] # test `use_predefined_variable=False` - cfg_dict_ori = dict( - item1='{{fileBasename}}', - item2='{{ fileDirname}}', - item3='abc_{{ fileBasenameNoExtension }}') + cfg_dict_ori = dict(item1='{{fileBasename}}', + item2='{{ fileDirname}}', + item3='abc_{{ fileBasenameNoExtension }}') assert Config._file2dict(cfg_file, False)[0]['item1'] == cfg_dict_ori['item1'] @@ -652,8 +655,8 @@ def _merge_from_multiple_bases(self): assert cfg_dict['item4'] == 'test' assert cfg_dict['item5'] == dict(a=0, b=1) assert cfg_dict['item6'] == [dict(a=0), dict(b=1)] - assert cfg_dict['item7'] == dict( - a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) + assert cfg_dict['item7'] == dict(a=[0, 1, 2], + b=dict(c=[3.1, 4.2, 5.3])) # Redefine key with pytest.raises(KeyError): Config.fromfile( @@ -674,8 +677,8 @@ def _base_variables(self): assert cfg_dict['item4'] == 'test' assert cfg_dict['item5'] == dict(a=0, b=1) assert cfg_dict['item6'] == [dict(a=0), dict(b=1)] - assert cfg_dict['item7'] == dict( - a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) + assert cfg_dict['item7'] == dict(a=[0, 1, 2], + b=dict(c=[3.1, 4.2, 5.3])) assert cfg_dict['item8'] == file.split('/')[-1] assert cfg_dict['item9'] == dict(a=0) assert cfg_dict['item10'] == [3.1, 4.2, 5.3] @@ -696,8 +699,8 @@ def _base_variables(self): assert cfg_dict['item4'] == 'test' assert cfg_dict['item5'] == dict(a=0, b=1) assert cfg_dict['item6'] == [dict(a=0), dict(b=1)] - assert cfg_dict['item7'] == dict( - a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3])) + assert cfg_dict['item7'] == dict(a=[0, 1, 2], + b=dict(c=[3.1, 4.2, 5.3])) assert cfg_dict['item8'] == 'test_base_variables.py' assert cfg_dict['item9'] == dict(a=0) assert cfg_dict['item10'] == [3.1, 4.2, 5.3] @@ -705,18 +708,17 @@ def _base_variables(self): assert cfg_dict['item12'] == dict(a=0) assert cfg_dict['item13'] == [3.1, 4.2, 5.3] assert cfg_dict['item14'] == [1, 2] - assert cfg_dict['item15'] == dict( - a=dict(b=dict(a=0)), - b=[False], - c=['test'], - d=[[{ - 'e': 0 - }], [{ - 'a': 0 - }, { - 'b': 1 - }]], - e=[1, 2]) + assert cfg_dict['item15'] == dict(a=dict(b=dict(a=0)), + b=[False], + c=['test'], + d=[[{ + 'e': 0 + }], [{ + 'a': 0 + }, { + 'b': 1 + }]], + e=[1, 2]) # test reference assignment for py cfg_file = osp.join( @@ -728,17 +730,16 @@ def _base_variables(self): assert cfg_dict['item22'] == 'test_base_variables.py' assert cfg_dict['item23'] == [3.1, 4.2, 5.3] assert cfg_dict['item24'] == [3.1, 4.2, 5.3] - assert cfg_dict['item25'] == dict( - a=dict(b=[3.1, 4.2, 5.3]), - b=[[3.1, 4.2, 5.3]], - c=[[{ - 'e': 'test_base_variables.py' - }], [{ - 'a': 0 - }, { - 'b': 1 - }]], - e='test_base_variables.py') + assert cfg_dict['item25'] == dict(a=dict(b=[3.1, 4.2, 5.3]), + b=[[3.1, 4.2, 5.3]], + c=[[{ + 'e': 'test_base_variables.py' + }], [{ + 'a': 0 + }, { + 'b': 1 + }]], + e='test_base_variables.py') cfg_file = osp.join(self.data_path, 'config/py_config/test_py_base.py') cfg = Config.fromfile(cfg_file) @@ -780,18 +781,17 @@ def _base_variables(self): assert cfg.item12 == 'test_py_base.py' assert 
cfg.item13 == 3.1 assert cfg.item14 == [1, 2] - assert cfg.item15 == dict( - a=dict(b=dict(a=0, b=[5, 6])), - b=[False], - c=['test'], - d=[[{ - 'e': 0 - }], [{ - 'c': 0 - }, { - 'b': 1 - }]], - e=[1, 2]) + assert cfg.item15 == dict(a=dict(b=dict(a=0, b=[5, 6])), + b=[False], + c=['test'], + d=[[{ + 'e': 0 + }], [{ + 'c': 0 + }, { + 'b': 1 + }]], + e=[1, 2]) # Test use global variable in config function cfg_file = osp.join(self.data_path, @@ -913,8 +913,8 @@ def test_copy(self): assert new_cfg._filename == cfg._filename assert new_cfg._text == cfg._text - @pytest.mark.skipif( - not is_installed('mmdet'), reason='mmdet should be installed') + @pytest.mark.skipif(not is_installed('mmdet'), + reason='mmdet should be installed') def test_get_external_cfg(self): ext_cfg_path = osp.join(self.data_path, 'config/py_config/test_get_external_cfg.py') @@ -927,8 +927,8 @@ def test_get_external_cfg(self): ) assert '_scope_' in ext_cfg._cfg_dict.model - @pytest.mark.skipif( - not is_installed('mmdet'), reason='mmdet should be installed') + @pytest.mark.skipif(not is_installed('mmdet'), + reason='mmdet should be installed') def test_build_external_package(self): # Test load base config. ext_cfg_path = osp.join(self.data_path, @@ -1062,10 +1062,9 @@ def _compare_dict(a, b): 'config/lazy_module_config/error_mix_using1.py')) # Force to import in non-lazy-import mode - Config.fromfile( - osp.join(self.data_path, - 'config/lazy_module_config/error_mix_using1.py'), - lazy_import=False) + Config.fromfile(osp.join( + self.data_path, 'config/lazy_module_config/error_mix_using1.py'), + lazy_import=False) # current lazy-import config, base text config with pytest.raises(RuntimeError, match='_base_ ='): @@ -1131,15 +1130,12 @@ def test_build_lazy(self): self.assertDictEqual(cfg_dict, raw) # Check `items` and `values` will only return the build object - raw = dict( - a=LazyObject('mmengine'), - b=dict( - c=2, - e=[ - dict( - f=dict(h=LazyObject('mmengine')), - g=LazyObject('mmengine')) - ])) + raw = dict(a=LazyObject('mmengine'), + b=dict(c=2, + e=[ + dict(f=dict(h=LazyObject('mmengine')), + g=LazyObject('mmengine')) + ])) cfg_dict = ConfigDict(raw) # check `items` and values self.assertDictEqual(cfg_dict._to_lazy_dict(), raw) diff --git a/tests/test_data/test_data_utils.py b/tests/test_data/test_data_utils.py index 76e30e8642..255a849bd2 100644 --- a/tests/test_data/test_data_utils.py +++ b/tests/test_data/test_data_utils.py @@ -49,30 +49,26 @@ def test_pseudo_collate(self): self.assertIs(batch_data_sample[1], data_sample2) # Test with list of tuple, each tuple is a nested dict instance - data_batch = [(dict( - inputs=input1, - data_sample=data_sample1, - value=1, - name='1', - nested=dict(data_sample=data_sample1)), - dict( - inputs=input2, - data_sample=data_sample2, - value=2, - name='2', - nested=dict(data_sample=data_sample2))), - (dict( - inputs=input1, - data_sample=data_sample1, - value=1, - name='1', - nested=dict(data_sample=data_sample1)), - dict( - inputs=input2, - data_sample=data_sample2, - value=2, - name='2', - nested=dict(data_sample=data_sample2)))] + data_batch = [(dict(inputs=input1, + data_sample=data_sample1, + value=1, + name='1', + nested=dict(data_sample=data_sample1)), + dict(inputs=input2, + data_sample=data_sample2, + value=2, + name='2', + nested=dict(data_sample=data_sample2))), + (dict(inputs=input1, + data_sample=data_sample1, + value=1, + name='1', + nested=dict(data_sample=data_sample1)), + dict(inputs=input2, + data_sample=data_sample2, + value=2, + name='2', + 
nested=dict(data_sample=data_sample2)))] data_batch = pseudo_collate(data_batch) batch_inputs_0 = data_batch[0]['inputs'] batch_inputs_1 = data_batch[1]['inputs'] diff --git a/tests/test_dataset/test_base_dataset.py b/tests/test_dataset/test_base_dataset.py index f4ec815ec2..24a1091a85 100644 --- a/tests/test_dataset/test_base_dataset.py +++ b/tests/test_dataset/test_base_dataset.py @@ -36,8 +36,10 @@ class CustomDataset(BaseDataset): class TestBaseDataset: def setup_method(self): - self.data_info = dict( - filename='test_img.jpg', height=604, width=640, sample_idx=0) + self.data_info = dict(filename='test_img.jpg', + height=604, + width=640, + sample_idx=0) self.imgs = torch.rand((2, 3, 32, 32)) self.ori_meta = BaseDataset.METAINFO self.ori_parse_data_info = BaseDataset.parse_data_info @@ -50,28 +52,28 @@ def teardown_method(self): def test_init(self): # test the instantiation of self.base_dataset - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json') + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json') assert dataset._fully_initialized assert hasattr(dataset, 'data_list') assert hasattr(dataset, 'data_address') - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path=''), - ann_file='annotations/dummy_annotation.json') + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path=''), + ann_file='annotations/dummy_annotation.json') assert dataset._fully_initialized assert hasattr(dataset, 'data_list') assert hasattr(dataset, 'data_address') # test the instantiation of self.base_dataset with # `serialize_data=False` - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - serialize_data=False) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + serialize_data=False) assert dataset._fully_initialized assert hasattr(dataset, 'data_list') assert not hasattr(dataset, 'data_address') @@ -79,54 +81,49 @@ def test_init(self): assert dataset.get_data_info(0) == self.data_info # test the instantiation of self.base_dataset with lazy init - dataset = BaseDataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img_path='imgs'), - ann_file='annotations/dummy_annotation.json', - lazy_init=True) + dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__), + '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + lazy_init=True) assert not dataset._fully_initialized assert not dataset.data_list # test the instantiation of self.base_dataset if ann_file is not # existed. 
         with pytest.raises(FileNotFoundError):
-            BaseDataset(
-                data_root=osp.join(osp.dirname(__file__), '../data/'),
-                data_prefix=dict(img_path='imgs'),
-                ann_file='annotations/not_existed_annotation.json')
+            BaseDataset(data_root=osp.join(osp.dirname(__file__), '../data/'),
+                        data_prefix=dict(img_path='imgs'),
+                        ann_file='annotations/not_existed_annotation.json')

         # Use the default value of ann_file, i.e., ''
         with pytest.raises(TypeError):
-            BaseDataset(
-                data_root=osp.join(osp.dirname(__file__), '../data/'),
-                data_prefix=dict(img_path='imgs'))
+            BaseDataset(data_root=osp.join(osp.dirname(__file__), '../data/'),
+                        data_prefix=dict(img_path='imgs'))

         # test the instantiation of self.base_dataset when the ann_file is
         # wrong
         with pytest.raises(ValueError):
-            BaseDataset(
-                data_root=osp.join(osp.dirname(__file__), '../data/'),
-                data_prefix=dict(img_path='imgs'),
-                ann_file='annotations/annotation_wrong_keys.json')
+            BaseDataset(data_root=osp.join(osp.dirname(__file__), '../data/'),
+                        data_prefix=dict(img_path='imgs'),
+                        ann_file='annotations/annotation_wrong_keys.json')
         with pytest.raises(TypeError):
-            BaseDataset(
-                data_root=osp.join(osp.dirname(__file__), '../data/'),
-                data_prefix=dict(img_path='imgs'),
-                ann_file='annotations/annotation_wrong_format.json')
+            BaseDataset(data_root=osp.join(osp.dirname(__file__), '../data/'),
+                        data_prefix=dict(img_path='imgs'),
+                        ann_file='annotations/annotation_wrong_format.json')
         with pytest.raises(TypeError):
-            BaseDataset(
-                data_root=osp.join(osp.dirname(__file__), '../data/'),
-                data_prefix=dict(img_path=['img']),
-                ann_file='annotations/annotation_wrong_format.json')
+            BaseDataset(data_root=osp.join(osp.dirname(__file__), '../data/'),
+                        data_prefix=dict(img_path=['img']),
+                        ann_file='annotations/annotation_wrong_format.json')

         # test the instantiation of self.base_dataset when `parse_data_info`
         # return `list[dict]`
         BaseDataset.parse_data_info = MagicMock(
             return_value=[self.data_info, self.data_info.copy()])
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path='imgs'),
+                              ann_file='annotations/dummy_annotation.json')
         dataset.pipeline = self.pipeline
         assert dataset._fully_initialized
         assert hasattr(dataset, 'data_list')
@@ -139,25 +136,24 @@ def test_init(self):
         # return unsupported data.
         with pytest.raises(TypeError):
             BaseDataset.parse_data_info = MagicMock(return_value='xxx')
-            dataset = BaseDataset(
-                data_root=osp.join(osp.dirname(__file__), '../data/'),
-                data_prefix=dict(img_path='imgs'),
-                ann_file='annotations/dummy_annotation.json')
+            dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                     '../data/'),
+                                  data_prefix=dict(img_path='imgs'),
+                                  ann_file='annotations/dummy_annotation.json')
         with pytest.raises(TypeError):
             BaseDataset.parse_data_info = MagicMock(
                 return_value=[self.data_info, 'xxx'])
-            BaseDataset(
-                data_root=osp.join(osp.dirname(__file__), '../data/'),
-                data_prefix=dict(img_path='imgs'),
-                ann_file='annotations/dummy_annotation.json')
+            BaseDataset(data_root=osp.join(osp.dirname(__file__), '../data/'),
+                        data_prefix=dict(img_path='imgs'),
+                        ann_file='annotations/dummy_annotation.json')

         # test the instantiation of self.base_dataset without `ann_file`
         BaseDataset.parse_data_info = self.ori_parse_data_info
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='',
-            serialize_data=False,
-            lazy_init=True)
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path='imgs'),
+                              ann_file='',
+                              serialize_data=False,
+                              lazy_init=True)
         assert not dataset.ann_file

         # Test `ann_file` and `data_root` could be None.
@@ -166,125 +162,119 @@ def test_init(self):

     def test_meta(self):
         # test dataset.metainfo with setting the metainfo from annotation file
         # as the metainfo of self.base_dataset.
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path='imgs'),
+                              ann_file='annotations/dummy_annotation.json')

-        assert dataset.metainfo == dict(
-            dataset_type='test_dataset', task_name='test_task', empty_list=[])
+        assert dataset.metainfo == dict(dataset_type='test_dataset',
+                                        task_name='test_task',
+                                        empty_list=[])

         # test dataset.metainfo with setting METAINFO in self.base_dataset
         dataset_type = 'new_dataset'
-        BaseDataset.METAINFO = dict(
-            dataset_type=dataset_type, classes=('dog', 'cat'))
-
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
-        assert dataset.metainfo == dict(
-            dataset_type=dataset_type,
-            task_name='test_task',
-            classes=('dog', 'cat'),
-            empty_list=[])
+        BaseDataset.METAINFO = dict(dataset_type=dataset_type,
+                                    classes=('dog', 'cat'))
+
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path='imgs'),
+                              ann_file='annotations/dummy_annotation.json')
+        assert dataset.metainfo == dict(dataset_type=dataset_type,
+                                        task_name='test_task',
+                                        classes=('dog', 'cat'),
+                                        empty_list=[])

         # test dataset.metainfo with passing metainfo into self.base_dataset
         metainfo = dict(classes=('dog', ), task_name='new_task')
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json',
-            metainfo=metainfo)
-        assert BaseDataset.METAINFO == dict(
-            dataset_type=dataset_type, classes=('dog', 'cat'))
-        assert dataset.metainfo == dict(
-            dataset_type=dataset_type,
-            task_name='new_task',
-            classes=('dog', ),
-            empty_list=[])
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path='imgs'),
+                              ann_file='annotations/dummy_annotation.json',
+                              metainfo=metainfo)
+        assert BaseDataset.METAINFO == dict(dataset_type=dataset_type,
+                                            classes=('dog', 'cat'))
+        assert dataset.metainfo == dict(dataset_type=dataset_type,
+                                        task_name='new_task',
+                                        classes=('dog', ),
+                                        empty_list=[])

         # test dataset.metainfo with passing metainfo as Config into
         # self.base_dataset
         metainfo = Config(dict(classes=('dog', ), task_name='new_task'))
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json',
-            metainfo=metainfo)
-        assert BaseDataset.METAINFO == dict(
-            dataset_type=dataset_type, classes=('dog', 'cat'))
-        assert dataset.metainfo == dict(
-            dataset_type=dataset_type,
-            task_name='new_task',
-            classes=('dog', ),
-            empty_list=[])
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path='imgs'),
+                              ann_file='annotations/dummy_annotation.json',
+                              metainfo=metainfo)
+        assert BaseDataset.METAINFO == dict(dataset_type=dataset_type,
+                                            classes=('dog', 'cat'))
+        assert dataset.metainfo == dict(dataset_type=dataset_type,
+                                        task_name='new_task',
+                                        classes=('dog', ),
+                                        empty_list=[])

         # test dataset.metainfo with passing metainfo as ConfigDict (Mapping)
         # into self.base_dataset
         metainfo = ConfigDict(dict(classes=('dog', ), task_name='new_task'))
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json',
-            metainfo=metainfo)
-        assert BaseDataset.METAINFO == dict(
-            dataset_type=dataset_type, classes=('dog', 'cat'))
-        assert dataset.metainfo == dict(
-            dataset_type=dataset_type,
-            task_name='new_task',
-            classes=('dog', ),
-            empty_list=[])
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path='imgs'),
+                              ann_file='annotations/dummy_annotation.json',
+                              metainfo=metainfo)
+        assert BaseDataset.METAINFO == dict(dataset_type=dataset_type,
+                                            classes=('dog', 'cat'))
+        assert dataset.metainfo == dict(dataset_type=dataset_type,
+                                        task_name='new_task',
+                                        classes=('dog', ),
+                                        empty_list=[])

         # reset `base_dataset.METAINFO`, the `dataset.metainfo` should not
         # change
         BaseDataset.METAINFO['classes'] = ('dog', 'cat', 'fish')
-        assert BaseDataset.METAINFO == dict(
-            dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
-        assert dataset.metainfo == dict(
-            dataset_type=dataset_type,
-            task_name='new_task',
-            classes=('dog', ),
-            empty_list=[])
+        assert BaseDataset.METAINFO == dict(dataset_type=dataset_type,
+                                            classes=('dog', 'cat', 'fish'))
+        assert dataset.metainfo == dict(dataset_type=dataset_type,
+                                        task_name='new_task',
+                                        classes=('dog', ),
+                                        empty_list=[])

         # test dataset.metainfo with passing metainfo containing a file into
         # self.base_dataset
-        metainfo = dict(
-            classes=osp.join(
-                osp.dirname(__file__), '../data/meta/classes.txt'))
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json',
-            metainfo=metainfo)
-        assert dataset.metainfo == dict(
-            dataset_type=dataset_type,
-            task_name='test_task',
-            classes=['dog'],
-            empty_list=[])
+        metainfo = dict(classes=osp.join(osp.dirname(__file__),
+                                         '../data/meta/classes.txt'))
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path='imgs'),
+                              ann_file='annotations/dummy_annotation.json',
+                              metainfo=metainfo)
+        assert dataset.metainfo == dict(dataset_type=dataset_type,
+                                        task_name='test_task',
+                                        classes=['dog'],
+                                        empty_list=[])

         # test dataset.metainfo with passing unsupported metainfo into
         # self.base_dataset
         with pytest.raises(TypeError):
             metainfo = 'dog'
-            dataset = BaseDataset(
-                data_root=osp.join(osp.dirname(__file__), '../data/'),
-                data_prefix=dict(img_path='imgs'),
-                ann_file='annotations/dummy_annotation.json',
-                metainfo=metainfo)
+            dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                     '../data/'),
+                                  data_prefix=dict(img_path='imgs'),
+                                  ann_file='annotations/dummy_annotation.json',
+                                  metainfo=metainfo)

         # test dataset.metainfo with passing metainfo into self.base_dataset
         # and lazy_init is True
         metainfo = dict(classes=('dog', ))
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json',
-            metainfo=metainfo,
-            lazy_init=True)
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path='imgs'),
+                              ann_file='annotations/dummy_annotation.json',
+                              metainfo=metainfo,
+                              lazy_init=True)
         # 'task_name' and 'empty_list' not in dataset.metainfo
-        assert dataset.metainfo == dict(
-            dataset_type=dataset_type, classes=('dog', ))
+        assert dataset.metainfo == dict(dataset_type=dataset_type,
+                                        classes=('dog', ))

         # test whether self.base_dataset.METAINFO is changed when a customize
         # dataset inherit self.base_dataset
@@ -293,26 +283,26 @@ class ToyDataset(BaseDataset):
             METAINFO = dict(xxx='xxx')

         assert ToyDataset.METAINFO == dict(xxx='xxx')
-        assert BaseDataset.METAINFO == dict(
-            dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
+        assert BaseDataset.METAINFO == dict(dataset_type=dataset_type,
+                                            classes=('dog', 'cat', 'fish'))

         # test update METAINFO in ToyDataset.
         class ToyDataset(BaseDataset):
             METAINFO = copy.deepcopy(BaseDataset.METAINFO)
             METAINFO['classes'] = ('bird', )

-        assert ToyDataset.METAINFO == dict(
-            dataset_type=dataset_type, classes=('bird', ))
-        assert BaseDataset.METAINFO == dict(
-            dataset_type=dataset_type, classes=('dog', 'cat', 'fish'))
+        assert ToyDataset.METAINFO == dict(dataset_type=dataset_type,
+                                           classes=('bird', ))
+        assert BaseDataset.METAINFO == dict(dataset_type=dataset_type,
+                                            classes=('dog', 'cat', 'fish'))

     @pytest.mark.parametrize('lazy_init', [True, False])
     def test_length(self, lazy_init):
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json',
-            lazy_init=lazy_init)
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path='imgs'),
+                              ann_file='annotations/dummy_annotation.json',
+                              lazy_init=lazy_init)
         if not lazy_init:
             assert dataset._fully_initialized
             assert hasattr(dataset, 'data_list')
@@ -364,11 +354,11 @@ def test_compose(self):

     @pytest.mark.parametrize('lazy_init', [True, False])
     def test_getitem(self, lazy_init):
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json',
-            lazy_init=lazy_init)
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path='imgs'),
+                              ann_file='annotations/dummy_annotation.json',
+                              lazy_init=lazy_init)
         dataset.pipeline = self.pipeline
         if not lazy_init:
             assert dataset._fully_initialized
@@ -406,11 +396,11 @@ def fake_prepare_data(idx):

     @pytest.mark.parametrize('lazy_init', [True, False])
     def test_get_data_info(self, lazy_init):
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json',
-            lazy_init=lazy_init)
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path='imgs'),
+                              ann_file='annotations/dummy_annotation.json',
+                              lazy_init=lazy_init)
         if not lazy_init:
             assert dataset._fully_initialized
@@ -427,10 +417,10 @@ def test_get_data_info(self, lazy_init):
         # Test parse_data_info with `data_prefix`
         BaseDataset.parse_data_info = self.ori_parse_data_info
         data_root = osp.join(osp.dirname(__file__), '../data/')
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path='imgs'),
+                              ann_file='annotations/dummy_annotation.json')
         data_info = dataset.get_data_info(0)
         assert data_info['img_path'] == osp.join(data_root, 'imgs',
                                                  'test_img.jpg')
@@ -448,11 +438,11 @@ def foo(self):
         class_without_full_init.foo()

     def test_full_init(self):
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json',
-            lazy_init=True)
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path='imgs'),
+                              ann_file='annotations/dummy_annotation.json',
+                              lazy_init=True)
         dataset.pipeline = self.pipeline
         # test `full_init()` when lazy_init is True
         assert not dataset._fully_initialized
@@ -465,11 +455,11 @@ def test_full_init(self):
         assert dataset[0] == dict(imgs=self.imgs)
         assert dataset.get_data_info(0) == self.data_info
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json',
-            lazy_init=False)
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path='imgs'),
+                              ann_file='annotations/dummy_annotation.json',
+                              lazy_init=False)
         dataset.pipeline = self.pipeline
         assert dataset._fully_initialized
@@ -479,10 +469,10 @@ def test_full_init(self):
         assert dataset.get_data_info(0) == self.data_info

         # test the instantiation of self.base_dataset when passing indices
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path=''),
-            ann_file='annotations/dummy_annotation.json')
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path=''),
+                              ann_file='annotations/dummy_annotation.json')
         dataset_sliced = BaseDataset(
             data_root=osp.join(osp.dirname(__file__), '../data/'),
             data_prefix=dict(img_path=''),
@@ -497,12 +487,12 @@ def test_full_init(self):
     def test_get_subset_(self, lazy_init, serialize_data):
         # Test positive int indices.
         indices = 2
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path=''),
-            ann_file='annotations/dummy_annotation.json',
-            lazy_init=lazy_init,
-            serialize_data=serialize_data)
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path=''),
+                              ann_file='annotations/dummy_annotation.json',
+                              lazy_init=lazy_init,
+                              serialize_data=serialize_data)
         dataset_copy = copy.deepcopy(dataset)
         dataset_copy.get_subset_(indices)
@@ -575,12 +565,12 @@ def test_get_subset_(self, lazy_init, serialize_data):
     def test_get_subset(self, lazy_init, serialize_data):
         # Test positive indices.
         indices = 2
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path=''),
-            ann_file='annotations/dummy_annotation.json',
-            lazy_init=lazy_init,
-            serialize_data=serialize_data)
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path=''),
+                              ann_file='annotations/dummy_annotation.json',
+                              lazy_init=lazy_init,
+                              serialize_data=serialize_data)
         dataset_sliced = dataset.get_subset(indices)
         assert len(dataset_sliced) == 2
         assert dataset_sliced[0] == dataset[0]
@@ -621,11 +611,11 @@ def test_get_subset(self, lazy_init, serialize_data):

     def test_rand_another(self):
         # test the instantiation of self.base_dataset when passing num_samples
-        dataset = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path=''),
-            ann_file='annotations/dummy_annotation.json',
-            indices=1)
+        dataset = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                 '../data/'),
+                              data_prefix=dict(img_path=''),
+                              ann_file='annotations/dummy_annotation.json',
+                              indices=1)
         assert dataset._rand_another() >= 0
         assert dataset._rand_another() < len(dataset)
@@ -640,20 +630,20 @@ def setup_method(self):
         dataset.parse_data_info = MagicMock(return_value=data_info)
         imgs = torch.rand((2, 3, 32, 32))
-        self.dataset_a = dataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
+        self.dataset_a = dataset(data_root=osp.join(osp.dirname(__file__),
+                                                    '../data/'),
+                                 data_prefix=dict(img_path='imgs'),
+                                 ann_file='annotations/dummy_annotation.json')
         self.dataset_a.pipeline = MagicMock(return_value=dict(imgs=imgs))

         # create dataset_b
         data_info = dict(filename='gray.jpg', height=288, width=512)
         dataset.parse_data_info = MagicMock(return_value=data_info)
         imgs = torch.rand((2, 3, 32, 32))
-        self.dataset_b = dataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
+        self.dataset_b = dataset(data_root=osp.join(osp.dirname(__file__),
+                                                    '../data/'),
+                                 data_prefix=dict(img_path='imgs'),
+                                 ann_file='annotations/dummy_annotation.json')
         self.dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs))
         # test init
         self.cat_datasets = ConcatDataset(
@@ -661,11 +651,11 @@ def setup_method(self):

     def test_init(self):
         # Test build dataset from cfg.
-        dataset_cfg_b = dict(
-            type=CustomDataset,
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
+        dataset_cfg_b = dict(type=CustomDataset,
+                             data_root=osp.join(osp.dirname(__file__),
+                                                '../data/'),
+                             data_prefix=dict(img_path='imgs'),
+                             ann_file='annotations/dummy_annotation.json')
         cat_datasets = ConcatDataset(datasets=[self.dataset_a, dataset_cfg_b])
         cat_datasets.datasets[1].pipeline = self.dataset_b.pipeline
         assert len(cat_datasets) == len(self.cat_datasets)
@@ -678,8 +668,8 @@ def test_init(self):
             ConcatDataset(datasets=[0])

         with pytest.raises(TypeError):
-            ConcatDataset(
-                datasets=[self.dataset_a, dataset_cfg_b], ignore_keys=1)
+            ConcatDataset(datasets=[self.dataset_a, dataset_cfg_b],
+                          ignore_keys=1)

     def test_full_init(self):
         # test init with lazy_init=True
@@ -696,11 +686,11 @@ def test_full_init(self):
         with pytest.raises(NotImplementedError):
             self.cat_datasets.get_subset(1)

-        dataset_b = BaseDataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json',
-            metainfo=dict(classes=('cat')))
+        dataset_b = BaseDataset(data_root=osp.join(osp.dirname(__file__),
+                                                   '../data/'),
+                                data_prefix=dict(img_path='imgs'),
+                                ann_file='annotations/dummy_annotation.json',
+                                metainfo=dict(classes=('cat')))
         # Regardless of order, different meta information without
         # `ignore_keys` will raise error.
         with pytest.raises(ValueError):
@@ -710,11 +700,11 @@ def test_full_init(self):
         # `ignore_keys` does not contain different meta information keys will
         # raise error.
         with pytest.raises(ValueError):
-            ConcatDataset(
-                datasets=[self.dataset_a, dataset_b], ignore_keys=['a'])
+            ConcatDataset(datasets=[self.dataset_a, dataset_b],
+                          ignore_keys=['a'])

         # Different meta information with `ignore_keys` will not raise error.
-        cat_datasets = ConcatDataset(
-            datasets=[self.dataset_a, dataset_b], ignore_keys='classes')
+        cat_datasets = ConcatDataset(datasets=[self.dataset_a, dataset_b],
+                                     ignore_keys='classes')
         cat_datasets.full_init()
         assert len(cat_datasets) == 6
         cat_datasets.full_init()
@@ -727,19 +717,19 @@ def test_metainfo(self):
         assert self.cat_datasets.metainfo == self.dataset_a.metainfo

     def test_length(self):
-        assert len(self.cat_datasets) == (
-            len(self.dataset_a) + len(self.dataset_b))
+        assert len(self.cat_datasets) == (len(self.dataset_a) +
+                                          len(self.dataset_b))

     def test_getitem(self):
         assert (
             self.cat_datasets[0]['imgs'] == self.dataset_a[0]['imgs']).all()
-        assert (self.cat_datasets[0]['imgs'] !=
-                self.dataset_b[0]['imgs']).all()
+        assert (self.cat_datasets[0]['imgs']
+                != self.dataset_b[0]['imgs']).all()

         assert (
             self.cat_datasets[-1]['imgs'] == self.dataset_b[-1]['imgs']).all()
-        assert (self.cat_datasets[-1]['imgs'] !=
-                self.dataset_a[-1]['imgs']).all()
+        assert (self.cat_datasets[-1]['imgs']
+                != self.dataset_a[-1]['imgs']).all()

     def test_get_data_info(self):
         assert self.cat_datasets.get_data_info(
@@ -768,26 +758,26 @@ def setup_method(self):
         data_info = dict(filename='test_img.jpg', height=604, width=640)
         dataset.parse_data_info = MagicMock(return_value=data_info)
         imgs = torch.rand((2, 3, 32, 32))
-        self.dataset = dataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
+        self.dataset = dataset(data_root=osp.join(osp.dirname(__file__),
+                                                  '../data/'),
+                               data_prefix=dict(img_path='imgs'),
+                               ann_file='annotations/dummy_annotation.json')
         self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))
         self.repeat_times = 5
         # test init
-        self.repeat_datasets = RepeatDataset(
-            dataset=self.dataset, times=self.repeat_times)
+        self.repeat_datasets = RepeatDataset(dataset=self.dataset,
+                                             times=self.repeat_times)

     def test_init(self):
         # Test build dataset from cfg.
-        dataset_cfg = dict(
-            type=CustomDataset,
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
-        repeat_dataset = RepeatDataset(
-            dataset=dataset_cfg, times=self.repeat_times)
+        dataset_cfg = dict(type=CustomDataset,
+                           data_root=osp.join(osp.dirname(__file__),
+                                              '../data/'),
+                           data_prefix=dict(img_path='imgs'),
+                           ann_file='annotations/dummy_annotation.json')
+        repeat_dataset = RepeatDataset(dataset=dataset_cfg,
+                                       times=self.repeat_times)
         repeat_dataset.dataset.pipeline = self.dataset.pipeline
         assert len(repeat_dataset) == len(self.repeat_datasets)
         for i in range(len(repeat_dataset)):
@@ -840,10 +830,10 @@ def setup_method(self):
         dataset.parse_data_info = MagicMock(return_value=data_info)
         imgs = torch.rand((2, 3, 32, 32))
         dataset.get_cat_ids = MagicMock(return_value=[0])
-        self.dataset = dataset(
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
+        self.dataset = dataset(data_root=osp.join(osp.dirname(__file__),
+                                                  '../data/'),
+                               data_prefix=dict(img_path='imgs'),
+                               ann_file='annotations/dummy_annotation.json')
         self.dataset.pipeline = MagicMock(return_value=dict(imgs=imgs))

         self.repeat_indices = [0, 0, 1, 1, 1]
@@ -854,13 +844,13 @@ def setup_method(self):

     def test_init(self):
         # Test build dataset from cfg.
-        dataset_cfg = dict(
-            type=CustomDataset,
-            data_root=osp.join(osp.dirname(__file__), '../data/'),
-            data_prefix=dict(img_path='imgs'),
-            ann_file='annotations/dummy_annotation.json')
-        cls_banlanced_datasets = ClassBalancedDataset(
-            dataset=dataset_cfg, oversample_thr=1e-3)
+        dataset_cfg = dict(type=CustomDataset,
+                           data_root=osp.join(osp.dirname(__file__),
+                                              '../data/'),
+                           data_prefix=dict(img_path='imgs'),
+                           ann_file='annotations/dummy_annotation.json')
+        cls_banlanced_datasets = ClassBalancedDataset(dataset=dataset_cfg,
+                                                      oversample_thr=1e-3)
         cls_banlanced_datasets.repeat_indices = self.repeat_indices
         cls_banlanced_datasets.dataset.pipeline = self.dataset.pipeline
         assert len(cls_banlanced_datasets) == len(self.cls_banlanced_datasets)
diff --git a/tests/test_dataset/test_sampler.py b/tests/test_dataset/test_sampler.py
index 31582a8679..70d510159c 100644
--- a/tests/test_dataset/test_sampler.py
+++ b/tests/test_dataset/test_sampler.py
@@ -44,9 +44,8 @@ def test_dist(self, mock):
         self.assertEqual(sampler.num_samples, np.ceil(self.data_length / 3))
         self.assertEqual(sampler.total_size, sampler.num_samples * 3)
         self.assertEqual(len(sampler), sampler.num_samples)
-        self.assertEqual(
-            list(sampler),
-            list(range(self.data_length))[2::3] + [1])
+        self.assertEqual(list(sampler),
+                         list(range(self.data_length))[2::3] + [1])

         # test round_up=False
         sampler = DefaultSampler(self.dataset, round_up=False, shuffle=False)
diff --git a/tests/test_dist/test_dist.py b/tests/test_dist/test_dist.py
index a2ef07b713..95db0f8bd7 100644
--- a/tests/test_dist/test_dist.py
+++ b/tests/test_dist/test_dist.py
@@ -126,8 +126,9 @@ def _init_dist_env(self, rank, world_size):
         os.environ['MASTER_ADDR'] = '127.0.0.1'
         os.environ['MASTER_PORT'] = '29505'
         os.environ['RANK'] = str(rank)
-        torch_dist.init_process_group(
-            backend='gloo', rank=rank, world_size=world_size)
+        torch_dist.init_process_group(backend='gloo',
+                                      rank=rank,
+                                      world_size=world_size)

     def setUp(self):
         super().setUp()
@@ -193,9 +194,8 @@ def test_broadcast_dist(self):

     def test_sync_random_seed(self):
         self._init_dist_env(self.rank, self.world_size)
-        with patch.object(
-                torch, 'tensor',
-                return_value=torch.tensor(1024)) as mock_tensor:
+        with patch.object(torch, 'tensor',
+                          return_value=torch.tensor(1024)) as mock_tensor:
             output = dist.sync_random_seed()
             assert output == 1024
             mock_tensor.assert_called()
@@ -333,20 +333,17 @@ def test_all_reduce_params(self):
                     torch.tensor([0, 1], dtype=tensor_type)
                     for _ in range(100)
                 ]
             else:
-                data = (
-                    torch.tensor([2, 3], dtype=tensor_type)
-                    for _ in range(100))
+                data = (torch.tensor([2, 3], dtype=tensor_type)
+                        for _ in range(100))

                 data_gen = (item for item in data)
                 if reduce_op == 'sum':
-                    expected = (
-                        torch.tensor([2, 4], dtype=tensor_type)
-                        for _ in range(100))
+                    expected = (torch.tensor([2, 4], dtype=tensor_type)
+                                for _ in range(100))
                 else:
-                    expected = (
-                        torch.tensor([1, 2], dtype=tensor_type)
-                        for _ in range(100))
+                    expected = (torch.tensor([1, 2], dtype=tensor_type)
+                                for _ in range(100))

                 dist.all_reduce_params(data_gen, coalesce=coalesce, op=reduce_op)
@@ -354,8 +351,8 @@ def test_all_reduce_params(self):
             self.assertTrue(torch.allclose(item1, item2))


-@unittest.skipIf(
-    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
+@unittest.skipIf(torch.cuda.device_count() < 2,
+                 reason='need 2 gpu to test nccl')
 class TestDistWithNCCLBackend(MultiProcessTestCase):

     def _init_dist_env(self, rank, world_size):
@@ -366,8 +363,8 @@ def _init_dist_env(self, rank, world_size):
         num_gpus = torch.cuda.device_count()
         torch.cuda.set_device(rank % num_gpus)
-        torch_dist.init_process_group(
-            backend='nccl', rank=rank, world_size=world_size)
+        torch_dist.init_process_group(backend='nccl',
+                                      rank=rank,
+                                      world_size=world_size)

     def setUp(self):
         super().setUp()
@@ -431,9 +429,8 @@ def test_broadcast_dist(self):

     def test_sync_random_seed(self):
         self._init_dist_env(self.rank, self.world_size)
-        with patch.object(
-                torch, 'tensor',
-                return_value=torch.tensor(1024)) as mock_tensor:
+        with patch.object(torch, 'tensor',
+                          return_value=torch.tensor(1024)) as mock_tensor:
             output = dist.sync_random_seed()
             assert output == 1024
             mock_tensor.assert_called()
@@ -580,8 +577,10 @@ def test_collect_results(self):
             # broadcast tmpdir to all ranks to make it consistent
             object_list = [tmpdir]
             dist.broadcast_object_list(object_list)
-            output = dist.collect_results(
-                data, size, device='cpu', tmpdir=object_list[0])
+            output = dist.collect_results(data,
+                                          size,
+                                          device='cpu',
+                                          tmpdir=object_list[0])
             if dist.get_rank() == 0:
                 self.assertEqual(output, expected)
             else:
@@ -646,13 +645,13 @@ def test_all_reduce_params(self):
                 dist.all_reduce_params(data_gen, coalesce=coalesce, op=reduce_op)

                 if reduce_op == 'sum':
-                    expected = (
-                        torch.tensor([2, 4], dtype=tensor_type).to(device_type)
-                        for _ in range(100))
+                    expected = (torch.tensor([2, 4],
+                                             dtype=tensor_type).to(device_type)
+                                for _ in range(100))
                 else:
-                    expected = (
-                        torch.tensor([1, 2], dtype=tensor_type).to(device_type)
-                        for _ in range(100))
+                    expected = (torch.tensor([1, 2],
+                                             dtype=tensor_type).to(device_type)
+                                for _ in range(100))

                 for item1, item2 in zip(data_gen, expected):
                     self.assertTrue(torch.allclose(item1, item2))
diff --git a/tests/test_dist/test_utils.py b/tests/test_dist/test_utils.py
index d9af72f964..4cffb385fa 100644
--- a/tests/test_dist/test_utils.py
+++ b/tests/test_dist/test_utils.py
@@ -101,8 +101,8 @@ def test_get_data_device(self):
                 'data should be a Tensor, sequence of tensor or dict'):
             dist.get_data_device('123')

-    @unittest.skipIf(
-        torch.cuda.device_count() == 0, reason='at lest need 1 gpu to test')
+    @unittest.skipIf(torch.cuda.device_count() == 0,
+                     reason='at least need 1 gpu to test')
     def test_cast_data_device(self):
         expected_device = torch.device('cuda', torch.cuda.current_device())
         # data is a Tensor
@@ -181,8 +181,8 @@ def test_cast_data_device(self):
         self.assertEqual(output['key1'].device, expected_device)
         self.assertTrue(torch.allclose(output['key1'].cpu(), out['key1']))
         self.assertEqual(output['key2'][0].device, expected_device)
-        self.assertTrue(
-            torch.allclose(output['key2'][0].cpu(), out['key2'][0]))
+        self.assertTrue(torch.allclose(output['key2'][0].cpu(),
+                                       out['key2'][0]))

         # data is not a valid type
         with self.assertRaisesRegex(
@@ -218,8 +218,9 @@ def _init_dist_env(self, rank, world_size):
         os.environ['MASTER_PORT'] = '29505'
         os.environ['RANK'] = str(rank)

-        torch_dist.init_process_group(
-            backend='gloo', rank=rank, world_size=world_size)
+        torch_dist.init_process_group(backend='gloo',
+                                      rank=rank,
+                                      world_size=world_size)
         dist.init_local_group(0, world_size)

     def setUp(self):
@@ -247,8 +248,8 @@ def test_local_size(self):

     def test_local_rank(self):
         self._init_dist_env(self.rank, self.world_size)
-        self.assertEqual(
-            torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank())
+        self.assertEqual(torch_dist.get_rank(dist.get_local_group()),
+                         dist.get_local_rank())

     def test_get_dist_info(self):
         self._init_dist_env(self.rank, self.world_size)
@@ -337,8 +338,8 @@ def test_get_comm_device(self):
         assert dist.get_comm_device(group) == torch.device('cpu')


-@unittest.skipIf(
-    torch.cuda.device_count() < 2, reason='need 2 gpu to test nccl')
+@unittest.skipIf(torch.cuda.device_count() < 2,
+                 reason='need 2 gpu to test nccl')
 class TestUtilsWithNCCLBackend(MultiProcessTestCase):

     def _init_dist_env(self, rank, world_size):
@@ -349,8 +350,9 @@ def _init_dist_env(self, rank, world_size):
         num_gpus = torch.cuda.device_count()
         torch.cuda.set_device(rank % num_gpus)

-        torch_dist.init_process_group(
-            backend='nccl', rank=rank, world_size=world_size)
+        torch_dist.init_process_group(backend='nccl',
+                                      rank=rank,
+                                      world_size=world_size)
         dist.init_local_group(0, world_size)

     def setUp(self):
@@ -378,8 +380,8 @@ def test_local_size(self):

     def test_local_rank(self):
         self._init_dist_env(self.rank, self.world_size)
-        self.assertEqual(
-            torch_dist.get_rank(dist.get_local_group()), dist.get_local_rank())
+        self.assertEqual(torch_dist.get_rank(dist.get_local_group()),
+                         dist.get_local_rank())

     def test_get_dist_info(self):
         self._init_dist_env(self.rank, self.world_size)
@@ -579,8 +581,8 @@ def test_cast_data_device(self):
         self.assertEqual(output['key1'].device, expected_device)
         self.assertTrue(torch.allclose(output['key1'].cpu(), out['key1']))
         self.assertEqual(output['key2'][0].device, expected_device)
-        self.assertTrue(
-            torch.allclose(output['key2'][0].cpu(), out['key2'][0]))
+        self.assertTrue(torch.allclose(output['key2'][0].cpu(),
+                                       out['key2'][0]))

         # data is not a valid type
         with self.assertRaisesRegex(
diff --git a/tests/test_evaluator/test_evaluator.py b/tests/test_evaluator/test_evaluator.py
index 58b7e1e6fe..c9f4100b40 100644
--- a/tests/test_evaluator/test_evaluator.py
+++ b/tests/test_evaluator/test_evaluator.py
@@ -100,8 +100,10 @@ def test_single_metric(self):
         size = 10
         batch_size = 4

-        for data_samples, outputs in generate_test_results(
-                size, batch_size, pred=1, label=1):
+        for data_samples, outputs in generate_test_results(size,
+                                                           batch_size,
+                                                           pred=1,
+                                                           label=1):
             evaluator.process(data_samples=outputs, data_batch=data_samples)

         metrics = evaluator.evaluate(size=size)
@@ -126,8 +128,10 @@ def test_composed_metrics(self):
         size = 10
         batch_size = 4

-        for data_samples, outputs in generate_test_results(
-                size, batch_size, pred=1, label=1):
+        for data_samples, outputs in generate_test_results(size,
+                                                           batch_size,
+                                                           pred=1,
+                                                           label=1):
             evaluator.process(data_samples=outputs, data_batch=data_samples)

         metrics = evaluator.evaluate(size=size)
@@ -147,8 +151,10 @@ def test_ambiguous_metric(self):
         size = 10
         batch_size = 4

-        for data_samples, outputs in generate_test_results(
-                size, batch_size, pred=1, label=1):
+        for data_samples, outputs in generate_test_results(size,
+                                                           batch_size,
+                                                           pred=1,
+                                                           label=1):
             evaluator.process(data_samples=outputs, data_batch=data_samples)

         with self.assertRaisesRegex(
@@ -175,10 +181,9 @@ def test_dataset_meta(self):
     def test_collect_device(self):
         cfg = [
             dict(type='ToyMetric', collect_device='cpu'),
-            dict(
-                type='ToyMetric',
-                collect_device='gpu',
-                dummy_metrics=dict(mAP=0.0))
+            dict(type='ToyMetric',
+                 collect_device='gpu',
+                 dummy_metrics=dict(mAP=0.0))
         ]

         evaluator = Evaluator(cfg)
@@ -262,16 +267,15 @@ def test_evaluate_cast_cpu(self):
         size = 10

         all_data = [
-            dict(
-                inputs=torch.zeros((3, 10, 10), device='cuda'),
-                data_sample=BaseDataElement(
-                    label=torch.ones((1, ), device='cuda')))
+            dict(inputs=torch.zeros((3, 10, 10), device='cuda'),
+                 data_sample=BaseDataElement(
+                     label=torch.ones((1, ), device='cuda')))
             for _ in range(size)
         ]

         all_predictions = [
-            BaseDataElement(
-                pred=torch.zeros((1, ), device='cuda'),
-                label=torch.ones((1, ), device='cuda')) for _ in range(size)
+            BaseDataElement(pred=torch.zeros((1, ), device='cuda'),
+                            label=torch.ones((1, ), device='cuda'))
+            for _ in range(size)
         ]
         for data, pred in zip(all_data, all_predictions):
             evaluator.process([pred], [data])
diff --git a/tests/test_evaluator/test_metric.py b/tests/test_evaluator/test_metric.py
index 055bd73ca1..d1a5608ef4 100644
--- a/tests/test_evaluator/test_metric.py
+++ b/tests/test_evaluator/test_metric.py
@@ -19,10 +19,9 @@ def test_init(self):

         # collect_dir could only be configured when collect_device='cpu'
         with self.assertRaises(ValueError):
-            DumpResults(
-                out_file_path='./results.json',
-                collect_device='gpu',
-                collect_dir='./tmp')
+            DumpResults(out_file_path='./results.json',
+                        collect_device='gpu',
+                        collect_dir='./tmp')

     def test_process(self):
         metric = DumpResults(out_file_path='./results.pkl')
diff --git a/tests/test_fileio/test_backends/test_backend_utils.py b/tests/test_fileio/test_backends/test_backend_utils.py
index 7903f5574e..9ed38ff701 100644
--- a/tests/test_fileio/test_backends/test_backend_utils.py
+++ b/tests/test_fileio/test_backends/test_backend_utils.py
@@ -57,8 +57,8 @@ def get(self, filepath):
         def get_text(self, filepath):
             return filepath

-    with pytest.raises(
-            TypeError, match='not a subclass of BaseStorageBackend'):
+    with pytest.raises(TypeError,
+                       match='not a subclass of BaseStorageBackend'):
         register_backend('example3', ExampleBackend2)

     # 4. test `force` parameter
@@ -85,8 +85,9 @@ def get_text(self, filepath):
     assert 'prefix1' in prefix_to_backends

     # 5.2 prefixes is a list (tuple) of strings
-    register_backend(
-        'example4', ExampleBackend3, prefixes=['prefix2', 'prefix3'])
+    register_backend('example4',
+                     ExampleBackend3,
+                     prefixes=['prefix2', 'prefix3'])
     assert 'example4' in backends
     assert 'prefix2' in prefix_to_backends
     assert 'prefix3' in prefix_to_backends
@@ -108,7 +109,9 @@ def get(self, filepath):
         def get_text(self, filepath):
             return filepath

-    register_backend(
-        'example6', ExampleBackend4, prefixes='prefix2', force=True)
+    register_backend('example6',
+                     ExampleBackend4,
+                     prefixes='prefix2',
+                     force=True)
     assert 'example6' in backends
     assert 'prefix2' in prefix_to_backends
diff --git a/tests/test_fileio/test_backends/test_local_backend.py b/tests/test_fileio/test_backends/test_local_backend.py
index 427ebf789a..71b2423504 100644
--- a/tests/test_fileio/test_backends/test_local_backend.py
+++ b/tests/test_fileio/test_backends/test_local_backend.py
@@ -146,15 +146,15 @@ def test_isfile(self, path_type):
     @parameterized.expand([[Path], [str]])
     def test_join_path(self, path_type):
         backend = LocalBackend()
-        filepath = backend.join_path(
-            path_type(self.test_data_dir), path_type('file'))
+        filepath = backend.join_path(path_type(self.test_data_dir),
+                                     path_type('file'))
         expected = osp.join(path_type(self.test_data_dir), path_type('file'))
         self.assertEqual(filepath, expected)

-        filepath = backend.join_path(
-            path_type(self.test_data_dir), path_type('dir'), path_type('file'))
-        expected = osp.join(
-            path_type(self.test_data_dir), path_type('dir'), path_type('file'))
+        filepath = backend.join_path(path_type(self.test_data_dir),
+                                     path_type('dir'), path_type('file'))
+        expected = osp.join(path_type(self.test_data_dir), path_type('dir'),
+                            path_type('file'))
         self.assertEqual(filepath, expected)

     @parameterized.expand([[Path], [str]])
@@ -170,17 +170,15 @@ def test_copyfile(self, path_type):
             src = Path(tmp_dir) / 'test.txt'
             backend.put_text('disk', src)
             dst = Path(tmp_dir) / 'test.txt.bak'
-            self.assertEqual(
-                backend.copyfile(path_type(src), path_type(dst)),
-                path_type(dst))
+            self.assertEqual(backend.copyfile(path_type(src), path_type(dst)),
+                             path_type(dst))
             self.assertEqual(backend.get_text(dst), 'disk')

             # dst is a directory
             dst = Path(tmp_dir) / 'dir'
             dst.mkdir()
-            self.assertEqual(
-                backend.copyfile(path_type(src), path_type(dst)),
-                backend.join_path(path_type(dst), 'test.txt'))
+            self.assertEqual(backend.copyfile(path_type(src), path_type(dst)),
+                             backend.join_path(path_type(dst), 'test.txt'))
             self.assertEqual(
                 backend.get_text(backend.join_path(dst, 'test.txt')), 'disk')
@@ -195,17 +193,16 @@ def test_copytree(self, path_type):
             # src and dst are Path objects
             src = Path(tmp_dir) / 'dir1'
             dst = Path(tmp_dir) / 'dir100'
-            self.assertEqual(
-                backend.copytree(path_type(src), path_type(dst)),
-                path_type(dst))
+            self.assertEqual(backend.copytree(path_type(src), path_type(dst)),
+                             path_type(dst))
             self.assertTrue(backend.isdir(dst))
             self.assertTrue(backend.isfile(dst / 'text3.txt'))
             self.assertEqual(backend.get_text(dst / 'text3.txt'), 'text3')

             # dst should not exist
             with self.assertRaises(FileExistsError):
-                backend.copytree(
-                    path_type(src), path_type(Path(tmp_dir) / 'dir2'))
+                backend.copytree(path_type(src),
+                                 path_type(Path(tmp_dir) / 'dir2'))

     @parameterized.expand([[Path], [str]])
     def test_copyfile_from_local(self, path_type):
@@ -214,16 +211,14 @@ def test_copyfile_from_local(self, path_type):
             src = Path(tmp_dir) / 'test.txt'
             backend.put_text('disk', src)
             dst = Path(tmp_dir) / 'test.txt.bak'
-            self.assertEqual(
-                backend.copyfile(path_type(src), path_type(dst)),
-                path_type(dst))
+            self.assertEqual(backend.copyfile(path_type(src), path_type(dst)),
+                             path_type(dst))
             self.assertEqual(backend.get_text(dst), 'disk')

             dst = Path(tmp_dir) / 'dir'
             dst.mkdir()
-            self.assertEqual(
-                backend.copyfile(path_type(src), path_type(dst)),
-                backend.join_path(path_type(dst), 'test.txt'))
+            self.assertEqual(backend.copyfile(path_type(src), path_type(dst)),
+                             backend.join_path(path_type(dst), 'test.txt'))
             self.assertEqual(
                 backend.get_text(backend.join_path(dst, 'test.txt')), 'disk')
@@ -238,17 +233,16 @@ def test_copytree_from_local(self, path_type):
             # src and dst are Path objects
             src = Path(tmp_dir) / 'dir1'
             dst = Path(tmp_dir) / 'dir100'
-            self.assertEqual(
-                backend.copytree(path_type(src), path_type(dst)),
-                path_type(dst))
+            self.assertEqual(backend.copytree(path_type(src), path_type(dst)),
+                             path_type(dst))
             self.assertTrue(backend.isdir(dst))
             self.assertTrue(backend.isfile(dst / 'text3.txt'))
             self.assertEqual(backend.get_text(dst / 'text3.txt'), 'text3')

             # dst should not exist
             with self.assertRaises(FileExistsError):
-                backend.copytree(
-                    path_type(src), path_type(Path(tmp_dir) / 'dir2'))
+                backend.copytree(path_type(src),
+                                 path_type(Path(tmp_dir) / 'dir2'))

     @parameterized.expand([[Path], [str]])
     def test_copyfile_to_local(self, path_type):
@@ -257,16 +251,14 @@ def test_copyfile_to_local(self, path_type):
             src = Path(tmp_dir) / 'test.txt'
             backend.put_text('disk', src)
             dst = Path(tmp_dir) / 'test.txt.bak'
-            self.assertEqual(
-                backend.copyfile(path_type(src), path_type(dst)),
-                path_type(dst))
+            self.assertEqual(backend.copyfile(path_type(src), path_type(dst)),
+                             path_type(dst))
             self.assertEqual(backend.get_text(dst), 'disk')

             dst = Path(tmp_dir) / 'dir'
             dst.mkdir()
-            self.assertEqual(
-                backend.copyfile(path_type(src), path_type(dst)),
-                backend.join_path(path_type(dst), 'test.txt'))
+            self.assertEqual(backend.copyfile(path_type(src), path_type(dst)),
+                             backend.join_path(path_type(dst), 'test.txt'))
             self.assertEqual(
                 backend.get_text(backend.join_path(dst, 'test.txt')), 'disk')
@@ -281,17 +273,16 @@ def test_copytree_to_local(self, path_type):
             # src and dst are Path objects
             src = Path(tmp_dir) / 'dir1'
             dst = Path(tmp_dir) / 'dir100'
-            self.assertEqual(
-                backend.copytree(path_type(src), path_type(dst)),
-                path_type(dst))
+            self.assertEqual(backend.copytree(path_type(src), path_type(dst)),
+                             path_type(dst))
             self.assertTrue(backend.isdir(dst))
             self.assertTrue(backend.isfile(dst / 'text3.txt'))
             self.assertEqual(backend.get_text(dst / 'text3.txt'), 'text3')

             # dst should not exist
             with self.assertRaises(FileExistsError):
-                backend.copytree(
-                    path_type(src), path_type(Path(tmp_dir) / 'dir2'))
+                backend.copytree(path_type(src),
+                                 path_type(Path(tmp_dir) / 'dir2'))

     @parameterized.expand([[Path], [str]])
     def test_remove(self, path_type):
@@ -361,8 +352,8 @@ def symlink(src, dst):
             with patch.object(os, 'symlink', side_effect=symlink):
                 src = Path(tmp_dir) / 'test.txt'
                 dst = Path(tmp_dir) / 'test_link1.txt'
-                res = backend.copy_if_symlink_fails(
-                    path_type(src), path_type(dst))
+                res = backend.copy_if_symlink_fails(path_type(src),
+                                                    path_type(dst))
                 self.assertFalse(res)
                 self.assertFalse(osp.islink(dst))
                 self.assertTrue(backend.exists(dst))
@@ -371,8 +362,8 @@ def symlink(src, dst):
             with patch.object(os, 'symlink', side_effect=symlink):
                 src = Path(tmp_dir) / 'dir'
                 dst = Path(tmp_dir) / 'dir_link1'
-                res = backend.copy_if_symlink_fails(
-                    path_type(src), path_type(dst))
+                res = backend.copy_if_symlink_fails(path_type(src),
+                                                    path_type(dst))
                 self.assertFalse(res)
                 self.assertFalse(osp.islink(dst))
                 self.assertTrue(backend.exists(dst))
@@ -382,15 +373,14 @@ def test_list_dir_or_file(self, path_type):
         backend = LocalBackend()
         with build_temporary_directory() as tmp_dir:
             # list directories and files
-            self.assertEqual(
-                set(backend.list_dir_or_file(path_type(tmp_dir))),
-                {'dir1', 'dir2', 'text1.txt', 'text2.txt'})
+            self.assertEqual(set(backend.list_dir_or_file(path_type(tmp_dir))),
+                             {'dir1', 'dir2', 'text1.txt', 'text2.txt'})

             # list directories and files recursively
             self.assertEqual(
                 set(
-                    backend.list_dir_or_file(
-                        path_type(tmp_dir), recursive=True)),
+                    backend.list_dir_or_file(path_type(tmp_dir),
+                                             recursive=True)),
                 {
                     'dir1', osp.join('dir1', 'text3.txt'), 'dir2',
@@ -402,35 +392,38 @@ def test_list_dir_or_file(self, path_type):
             # only list directories
             self.assertEqual(
                 set(
-                    backend.list_dir_or_file(
-                        path_type(tmp_dir), list_file=False)),
+                    backend.list_dir_or_file(path_type(tmp_dir),
+                                             list_file=False)),
                 {'dir1', 'dir2'})

             with self.assertRaisesRegex(
                     TypeError,
                     '`suffix` should be None when `list_dir` is True'):
-                backend.list_dir_or_file(
-                    path_type(tmp_dir), list_file=False, suffix='.txt')
+                backend.list_dir_or_file(path_type(tmp_dir),
+                                         list_file=False,
+                                         suffix='.txt')

             # only list directories recursively
             self.assertEqual(
                 set(
-                    backend.list_dir_or_file(
-                        path_type(tmp_dir), list_file=False, recursive=True)),
+                    backend.list_dir_or_file(path_type(tmp_dir),
+                                             list_file=False,
+                                             recursive=True)),
                 {'dir1', 'dir2', osp.join('dir2', 'dir3')})

             # only list files
             self.assertEqual(
                 set(
-                    backend.list_dir_or_file(
-                        path_type(tmp_dir), list_dir=False)),
+                    backend.list_dir_or_file(path_type(tmp_dir),
+                                             list_dir=False)),
                 {'text1.txt', 'text2.txt'})

             # only list files recursively
             self.assertEqual(
                 set(
-                    backend.list_dir_or_file(
-                        path_type(tmp_dir), list_dir=False, recursive=True)),
+                    backend.list_dir_or_file(path_type(tmp_dir),
+                                             list_dir=False,
+                                             recursive=True)),
                 {
                     osp.join('dir1', 'text3.txt'),
                     osp.join('dir2', 'dir3', 'text4.txt'),
@@ -440,45 +433,44 @@ def test_list_dir_or_file(self, path_type):
             # only list files ending with suffix
             self.assertEqual(
                 set(
-                    backend.list_dir_or_file(
-                        path_type(tmp_dir), list_dir=False, suffix='.txt')),
+                    backend.list_dir_or_file(path_type(tmp_dir),
+                                             list_dir=False,
+                                             suffix='.txt')),
                 {'text1.txt', 'text2.txt'})

             self.assertEqual(
                 set(
-                    backend.list_dir_or_file(
-                        path_type(tmp_dir),
-                        list_dir=False,
-                        suffix=('.txt', '.jpg'))), {'text1.txt', 'text2.txt'})
+                    backend.list_dir_or_file(path_type(tmp_dir),
+                                             list_dir=False,
+                                             suffix=('.txt', '.jpg'))),
+                {'text1.txt', 'text2.txt'})

             with self.assertRaisesRegex(
                     TypeError,
                     '`suffix` must be a string or tuple of strings'):
-                backend.list_dir_or_file(
-                    path_type(tmp_dir),
-                    list_dir=False,
-                    suffix=['.txt', '.jpg'])
+                backend.list_dir_or_file(path_type(tmp_dir),
+                                         list_dir=False,
+                                         suffix=['.txt', '.jpg'])

             # only list files ending with suffix recursively
             self.assertEqual(
                 set(
-                    backend.list_dir_or_file(
-                        path_type(tmp_dir),
-                        list_dir=False,
-                        suffix='.txt',
-                        recursive=True)), {
-                            osp.join('dir1', 'text3.txt'),
-                            osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt',
-                            'text2.txt'
-                        })
+                    backend.list_dir_or_file(path_type(tmp_dir),
+                                             list_dir=False,
+                                             suffix='.txt',
+                                             recursive=True)),
+                {
+                    osp.join('dir1', 'text3.txt'),
+                    osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt',
+                    'text2.txt'
+                })

             # only list files ending with suffix
             self.assertEqual(
                 set(
-                    backend.list_dir_or_file(
-                        path_type(tmp_dir),
-                        list_dir=False,
-                        suffix=('.txt', '.jpg'),
-                        recursive=True)),
+                    backend.list_dir_or_file(path_type(tmp_dir),
+                                             list_dir=False,
+                                             suffix=('.txt', '.jpg'),
+                                             recursive=True)),
                 {
                     osp.join('dir1', 'text3.txt'),
                     osp.join('dir2', 'dir3', 'text4.txt'),
diff --git a/tests/test_fileio/test_backends/test_petrel_backend.py b/tests/test_fileio/test_backends/test_petrel_backend.py
index 6f379c3f23..3af60276b5 100644
--- a/tests/test_fileio/test_backends/test_petrel_backend.py
+++ b/tests/test_fileio/test_backends/test_petrel_backend.py
@@ -124,13 +124,13 @@ def test_name(self):

     def test_map_path(self):
         backend = PetrelBackend(path_mapping=None)
-        self.assertEqual(
-            backend._map_path(self.petrel_path), self.petrel_path)
+        self.assertEqual(backend._map_path(self.petrel_path),
+                         self.petrel_path)

         backend = PetrelBackend(
             path_mapping={'data/': 'petrel://user/data/'})
-        self.assertEqual(
-            backend._map_path('data/test.jpg'), self.petrel_path)
+        self.assertEqual(backend._map_path('data/test.jpg'),
+                         self.petrel_path)

     def test_format_path(self):
         backend = PetrelBackend()
@@ -140,37 +140,31 @@ def test_format_path(self):

     def test_replace_prefix(self):
         backend = PetrelBackend()
-        self.assertEqual(
-            backend._replace_prefix(self.petrel_path), self.expected_path)
+        self.assertEqual(backend._replace_prefix(self.petrel_path),
+                         self.expected_path)

     def test_join_path(self):
         backend = PetrelBackend()
-        self.assertEqual(
-            backend.join_path(self.petrel_dir, 'file'),
-            f'{self.petrel_dir}/file')
-        self.assertEqual(
-            backend.join_path(f'{self.petrel_dir}/', 'file'),
-            f'{self.petrel_dir}/file')
-        self.assertEqual(
-            backend.join_path(f'{self.petrel_dir}/', '/file'),
-            f'{self.petrel_dir}/file')
-        self.assertEqual(
-            backend.join_path(self.petrel_dir, 'dir', 'file'),
-            f'{self.petrel_dir}/dir/file')
+        self.assertEqual(backend.join_path(self.petrel_dir, 'file'),
+                         f'{self.petrel_dir}/file')
+        self.assertEqual(backend.join_path(f'{self.petrel_dir}/', 'file'),
+                         f'{self.petrel_dir}/file')
+        self.assertEqual(backend.join_path(f'{self.petrel_dir}/', '/file'),
+                         f'{self.petrel_dir}/file')
+        self.assertEqual(backend.join_path(self.petrel_dir, 'dir', 'file'),
+                         f'{self.petrel_dir}/dir/file')

     def test_get(self):
         backend = PetrelBackend()
-        with patch.object(
-                backend._client, 'Get',
-                return_value=b'petrel') as patched_get:
+        with patch.object(backend._client, 'Get',
+                          return_value=b'petrel') as patched_get:
             self.assertEqual(backend.get(self.petrel_path), b'petrel')
             patched_get.assert_called_once_with(self.expected_path)

     def test_get_text(self):
         backend = PetrelBackend()
-        with patch.object(
-                backend._client, 'Get',
-                return_value=b'petrel') as patched_get:
+        with patch.object(backend._client, 'Get',
+                          return_value=b'petrel') as patched_get:
             self.assertEqual(backend.get_text(self.petrel_path), 'petrel')
             patched_get.assert_called_once_with(self.expected_path)
@@ -201,9 +195,8 @@ def test_exists(self):
         with self.assertRaises(NotImplementedError):
             backend.exists(self.petrel_path)

-        with patch.object(
-                backend._client, 'contains',
-                return_value=True) as patched_contains:
+        with patch.object(backend._client, 'contains',
+                          return_value=True) as patched_contains:
             self.assertTrue(backend.exists(self.petrel_path))
             patched_contains.assert_called_once_with(self.expected_path)
@@ -216,9 +209,8 @@ def test_isdir(self):
         with self.assertRaises(NotImplementedError):
             backend.isdir(self.petrel_path)

-        with patch.object(
-                backend._client, 'isdir',
-                return_value=True) as patched_contains:
+        with patch.object(backend._client, 'isdir',
+                          return_value=True) as patched_contains:
             self.assertTrue(backend.isdir(self.petrel_path))
             patched_contains.assert_called_once_with(self.expected_path)
@@ -231,9 +223,8 @@ def test_isfile(self):
         with self.assertRaises(NotImplementedError):
             backend.isfile(self.petrel_path)

-        with patch.object(
-                backend._client, 'contains',
-                return_value=True) as patched_contains:
+        with patch.object(backend._client, 'contains',
+                          return_value=True) as patched_contains:
             self.assertTrue(backend.isfile(self.petrel_path))
             patched_contains.assert_called_once_with(self.expected_path)
@@ -335,8 +326,8 @@ def test_copyfile_from_local(self):
             src = self.img_path
             dst = f'{self.petrel_dir}/dir'
             expected_dst = f'{self.expected_dir}/dir/color.jpg'
-            self.assertEqual(
-                backend.copyfile_from_local(src, dst), f'{dst}/color.jpg')
+            self.assertEqual(backend.copyfile_from_local(src, dst),
+                             f'{dst}/color.jpg')
             patched_put.assert_called_once_with(expected_dst,
                                                 src.open('rb').read())
             patched_isdir.assert_called_once_with(
@@ -380,8 +371,8 @@ def test_copyfile_to_local(self):
             src = self.petrel_path
             dst = Path(tmp_dir) / 'dir'
             dst.mkdir()
-            self.assertEqual(
-                backend.copyfile_to_local(src, dst), dst / 'test.jpg')
+            self.assertEqual(backend.copyfile_to_local(src, dst),
+                             dst / 'test.jpg')
             patched_get.assert_called_once_with(self.expected_path)
             self.assertEqual((dst / 'test.jpg').open('rb').read(), b'petrel')
@@ -468,9 +459,8 @@ def test_list_dir_or_file(self):

         with build_temporary_directory() as tmp_dir:
             # list directories and files
-            self.assertEqual(
-                set(backend.list_dir_or_file(tmp_dir)),
-                {'dir1', 'dir2', 'text1.txt', 'text2.txt'})
+            self.assertEqual(set(backend.list_dir_or_file(tmp_dir)),
+                             {'dir1', 'dir2', 'text1.txt', 'text2.txt'})

             # list directories and files recursively
             self.assertEqual(
@@ -489,14 +479,16 @@ def test_list_dir_or_file(self):
                     TypeError,
                    '`list_dir` should be False when `suffix` is not None'
             ):
-                backend.list_dir_or_file(
-                    tmp_dir, list_file=False, suffix='.txt')
+                backend.list_dir_or_file(tmp_dir,
+                                         list_file=False,
+                                         suffix='.txt')

             # only list directories recursively
             self.assertEqual(
                 set(
-                    backend.list_dir_or_file(
-                        tmp_dir, list_file=False, recursive=True)),
+                    backend.list_dir_or_file(tmp_dir,
+                                             list_file=False,
+                                             recursive=True)),
                 {'dir1', 'dir2', '/'.join(('dir2', 'dir3'))})

             # only list files
@@ -507,8 +499,9 @@ def test_list_dir_or_file(self):
             # only list files recursively
             self.assertEqual(
                 set(
-                    backend.list_dir_or_file(
-                        tmp_dir, list_dir=False, recursive=True)),
+                    backend.list_dir_or_file(tmp_dir,
+                                             list_dir=False,
+                                             recursive=True)),
                 {
                     '/'.join(('dir1', 'text3.txt')), '/'.join(
                         ('dir2', 'dir3', 'text4.txt')), '/'.join(
@@ -518,41 +511,43 @@ def test_list_dir_or_file(self):
             # only list files ending with suffix
             self.assertEqual(
                 set(
-                    backend.list_dir_or_file(
-                        tmp_dir, list_dir=False, suffix='.txt')),
+                    backend.list_dir_or_file(tmp_dir,
+                                             list_dir=False,
+                                             suffix='.txt')),
                 {'text1.txt', 'text2.txt'})

             self.assertEqual(
                 set(
-                    backend.list_dir_or_file(
-                        tmp_dir, list_dir=False, suffix=('.txt', '.jpg'))),
+                    backend.list_dir_or_file(tmp_dir,
+                                             list_dir=False,
+                                             suffix=('.txt', '.jpg'))),
                 {'text1.txt', 'text2.txt'})

             with self.assertRaisesRegex(
                     TypeError,
                     '`suffix` must be a string or tuple of strings'):
-                backend.list_dir_or_file(
-                    tmp_dir, list_dir=False, suffix=['.txt', '.jpg'])
+                backend.list_dir_or_file(tmp_dir,
+                                         list_dir=False,
+                                         suffix=['.txt', '.jpg'])

             # only list files ending with suffix recursively
             self.assertEqual(
                 set(
-                    backend.list_dir_or_file(
-                        tmp_dir,
-                        list_dir=False,
-                        suffix='.txt',
-                        recursive=True)), {
-                            '/'.join(('dir1', 'text3.txt')), '/'.join(
-                                ('dir2', 'dir3', 'text4.txt')),
-                            'text1.txt', 'text2.txt'
-                        })
+                    backend.list_dir_or_file(tmp_dir,
+                                             list_dir=False,
+                                             suffix='.txt',
+                                             recursive=True)),
+                {
+                    '/'.join(('dir1', 'text3.txt')), '/'.join(
+                        ('dir2', 'dir3', 'text4.txt')), 'text1.txt',
+                    'text2.txt'
+                })

             # only list files ending with suffix
             self.assertEqual(
                 set(
-                    backend.list_dir_or_file(
-                        tmp_dir,
-                        list_dir=False,
-                        suffix=('.txt', '.jpg'),
-                        recursive=True)),
+                    backend.list_dir_or_file(tmp_dir,
+                                             list_dir=False,
+                                             suffix=('.txt', '.jpg'),
+                                             recursive=True)),
                 {
                     '/'.join(('dir1', 'text3.txt')), '/'.join(
                         ('dir2', 'dir3', 'text4.txt')), '/'.join(
@@ -673,9 +668,8 @@ def test_copytree(self):
         dst = f'{self.petrel_dir}/dir3'
         self.assertFalse(backend.exists(dst))
         self.assertEqual(backend.copytree(src, dst), dst)
-        self.assertEqual(
-            list(backend.list_dir_or_file(src)),
-            list(backend.list_dir_or_file(dst)))
+        self.assertEqual(list(backend.list_dir_or_file(src)),
+                         list(backend.list_dir_or_file(dst)))

         # dst should not exist
         with self.assertRaises(FileExistsError):
@@ -696,8 +690,8 @@ def test_copyfile_from_local(self):
         dst = f'{self.petrel_dir}/dir1'
         expected_dst = f'{self.petrel_dir}/dir1/color.jpg'
         self.assertFalse(backend.exists(expected_dst))
-        self.assertEqual(
-            backend.copyfile_from_local(src, dst), expected_dst)
+        self.assertEqual(backend.copyfile_from_local(src, dst),
+                         expected_dst)

     def test_copytree_from_local(self):
         backend = PetrelBackend()
         backend.rmtree(self.petrel_dir)
         with build_temporary_directory() as tmp_dir:
             backend.copytree_from_local(tmp_dir, self.petrel_dir)
-            files = backend.list_dir_or_file(
-                self.petrel_dir, recursive=True)
+            files = backend.list_dir_or_file(self.petrel_dir,
+                                             recursive=True)
             self.assertEqual(len(list(files)), 8)

     def test_copyfile_to_local(self):
@@ -721,8 +715,8 @@ def test_copyfile_to_local(self):
             # dst is a directory
             dst = Path(tmp_dir) / 'dir'
             dst.mkdir()
-            self.assertEqual(
-                backend.copyfile_to_local(src, dst), dst / 'img.jpg')
+            self.assertEqual(backend.copyfile_to_local(src, dst),
+                             dst / 'img.jpg')
             self.assertEqual((dst / 'img.jpg').open('rb').read(), b'img')

     def test_copytree_to_local(self):
@@ -767,9 +761,8 @@ def test_list_dir_or_file(self):
         backend = PetrelBackend()

         # list directories and files
-        self.assertEqual(
-            set(backend.list_dir_or_file(self.petrel_dir)),
-            {'dir1', 'dir2', 'text1.txt', 'text2.txt'})
+        self.assertEqual(set(backend.list_dir_or_file(self.petrel_dir)),
+                         {'dir1', 'dir2', 'text1.txt', 'text2.txt'})

         # list directories and files recursively
         self.assertEqual(
@@ -783,21 +776,22 @@ def test_list_dir_or_file(self):

         # only list directories
         self.assertEqual(
-            set(
-                backend.list_dir_or_file(self.petrel_dir,
+            set(backend.list_dir_or_file(self.petrel_dir,
                                          list_file=False)), {'dir1', 'dir2'})

         with self.assertRaisesRegex(
                 TypeError,
                 '`list_dir` should be False when `suffix` is not None'):
-            backend.list_dir_or_file(
-                self.petrel_dir, list_file=False, suffix='.txt')
+            backend.list_dir_or_file(self.petrel_dir,
+                                     list_file=False,
+                                     suffix='.txt')

         # only list directories recursively
         self.assertEqual(
             set(
-                backend.list_dir_or_file(
-                    self.petrel_dir, list_file=False, recursive=True)),
+                backend.list_dir_or_file(self.petrel_dir,
+                                         list_file=False,
+                                         recursive=True)),
             {'dir1', 'dir2', '/'.join(('dir2', 'dir3'))})

         # only list files
@@ -808,8 +802,9 @@ def test_list_dir_or_file(self):
         # only list files recursively
         self.assertEqual(
             set(
-                backend.list_dir_or_file(
-                    self.petrel_dir, list_dir=False, recursive=True)),
+                backend.list_dir_or_file(self.petrel_dir,
+                                         list_dir=False,
+                                         recursive=True)),
             {
                 '/'.join(('dir1', 'text3.txt')), '/'.join(
                     ('dir2', 'dir3', 'text4.txt')), '/'.join(
@@ -819,42 +814,43 @@ def test_list_dir_or_file(self):
         # only list files ending with suffix
         self.assertEqual(
             set(
-                backend.list_dir_or_file(
-                    self.petrel_dir, list_dir=False, suffix='.txt')),
+                backend.list_dir_or_file(self.petrel_dir,
+                                         list_dir=False,
+                                         suffix='.txt')),
             {'text1.txt', 'text2.txt'})

         self.assertEqual(
             set(
-                backend.list_dir_or_file(
-                    self.petrel_dir,
-                    list_dir=False,
-                    suffix=('.txt', '.jpg'))), {'text1.txt', 'text2.txt'})
+                backend.list_dir_or_file(self.petrel_dir,
+                                         list_dir=False,
+                                         suffix=('.txt', '.jpg'))),
+            {'text1.txt', 'text2.txt'})

         with self.assertRaisesRegex(
                 TypeError,
                 '`suffix` must be a string or tuple of strings'):
-            backend.list_dir_or_file(
-                self.petrel_dir, list_dir=False, suffix=['.txt', '.jpg'])
+            backend.list_dir_or_file(self.petrel_dir,
+                                     list_dir=False,
+                                     suffix=['.txt', '.jpg'])

         # only list files ending with suffix recursively
         self.assertEqual(
             set(
-                backend.list_dir_or_file(
-                    self.petrel_dir,
-                    list_dir=False,
-                    suffix='.txt',
-                    recursive=True)), {
-                        '/'.join(('dir1', 'text3.txt')), '/'.join(
-                            ('dir2', 'dir3', 'text4.txt')), 'text1.txt',
-                        'text2.txt'
-                    })
+                backend.list_dir_or_file(self.petrel_dir,
+                                         list_dir=False,
+                                         suffix='.txt',
+                                         recursive=True)),
+            {
+                '/'.join(('dir1', 'text3.txt')), '/'.join(
+                    ('dir2', 'dir3', 'text4.txt')), 'text1.txt',
+                'text2.txt'
+            })

         # only list files ending with suffix
         self.assertEqual(
             set(
-                backend.list_dir_or_file(
-                    self.petrel_dir,
-                    list_dir=False,
-                    suffix=('.txt', '.jpg'),
-                    recursive=True)),
+                backend.list_dir_or_file(self.petrel_dir,
+                                         list_dir=False,
'.jpg'),
+                                         recursive=True)),
            {
                '/'.join(('dir1', 'text3.txt')), '/'.join(
                    ('dir2', 'dir3', 'text4.txt')), '/'.join(
diff --git a/tests/test_fileio/test_fileclient.py b/tests/test_fileio/test_fileclient.py
index 345832a026..0cc87f3167 100644
--- a/tests/test_fileio/test_fileclient.py
+++ b/tests/test_fileio/test_fileclient.py
@@ -226,23 +226,24 @@ def test_disk_backend(self):
             osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
         }
         # 3. only list directories
-        assert set(
-            disk_backend.list_dir_or_file(
-                tmp_dir, list_file=False)) == {'dir1', 'dir2'}
+        assert set(disk_backend.list_dir_or_file(
+            tmp_dir, list_file=False)) == {'dir1', 'dir2'}
         with pytest.raises(
                 TypeError,
                 match='`suffix` should be None when `list_dir` is True'):
             # Exception is raised among the `list_dir_or_file` of client,
             # so we need to invoke the client to trigger the exception
-            disk_backend.client.list_dir_or_file(
-                tmp_dir, list_file=False, suffix='.txt')
+            disk_backend.client.list_dir_or_file(tmp_dir,
+                                                 list_file=False,
+                                                 suffix='.txt')
         # 4. only list directories recursively
         assert set(
-            disk_backend.list_dir_or_file(
-                tmp_dir, list_file=False, recursive=True)) == {
-                    'dir1', 'dir2',
-                    osp.join('dir2', 'dir3')
-                }
+            disk_backend.list_dir_or_file(tmp_dir,
+                                          list_file=False,
+                                          recursive=True)) == {
+                                              'dir1', 'dir2',
+                                              osp.join('dir2', 'dir3')
+                                          }
         # 5. only list files
         assert set(disk_backend.list_dir_or_file(
             tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'}
@@ -256,18 +257,23 @@ def test_disk_backend(self):
         }
         # 7. only list files ending with suffix
         assert set(
-            disk_backend.list_dir_or_file(
-                tmp_dir, list_dir=False,
-                suffix='.txt')) == {'text1.txt', 'text2.txt'}
+            disk_backend.list_dir_or_file(tmp_dir,
+                                          list_dir=False,
+                                          suffix='.txt')) == {
+                                              'text1.txt', 'text2.txt'
+                                          }
         assert set(
-            disk_backend.list_dir_or_file(
-                tmp_dir, list_dir=False,
-                suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'}
+            disk_backend.list_dir_or_file(tmp_dir,
+                                          list_dir=False,
+                                          suffix=('.txt', '.jpg'))) == {
+                                              'text1.txt', 'text2.txt'
+                                          }
         with pytest.raises(
                 TypeError,
                 match='`suffix` must be a string or tuple of strings'):
-            disk_backend.client.list_dir_or_file(
-                tmp_dir, list_dir=False, suffix=['.txt', '.jpg'])
+            disk_backend.client.list_dir_or_file(tmp_dir,
+                                                 list_dir=False,
+                                                 suffix=['.txt', '.jpg'])
         # 8. 
only list files ending with suffix recursively
         assert set(
             disk_backend.list_dir_or_file(
@@ -326,16 +332,16 @@ def test_petrel_backend(self, backend, prefix):
             == petrel_path

         # test `get`
-        with patch.object(
-                petrel_backend.client._client, 'Get',
-                return_value=b'petrel') as mock_get:
+        with patch.object(petrel_backend.client._client,
+                          'Get',
+                          return_value=b'petrel') as mock_get:
             assert petrel_backend.get(petrel_path) == b'petrel'
             mock_get.assert_called_once_with(petrel_path)

         # test `get_text`
-        with patch.object(
-                petrel_backend.client._client, 'Get',
-                return_value=b'petrel') as mock_get:
+        with patch.object(petrel_backend.client._client,
+                          'Get',
+                          return_value=b'petrel') as mock_get:
             assert petrel_backend.get_text(petrel_path) == 'petrel'
             mock_get.assert_called_once_with(petrel_path)

@@ -381,9 +387,9 @@ def test_petrel_backend(self, backend, prefix):
         with pytest.raises(NotImplementedError):
             petrel_backend.exists(petrel_path)

-        with patch.object(
-                petrel_backend.client._client, 'contains',
-                return_value=True) as mock_contains:
+        with patch.object(petrel_backend.client._client,
+                          'contains',
+                          return_value=True) as mock_contains:
             assert petrel_backend.exists(petrel_path)
             mock_contains.assert_called_once_with(petrel_path)

@@ -394,9 +400,9 @@ def test_petrel_backend(self, backend, prefix):
         with pytest.raises(NotImplementedError):
             petrel_backend.isdir(petrel_path)

-        with patch.object(
-                petrel_backend.client._client, 'isdir',
-                return_value=True) as mock_isdir:
+        with patch.object(petrel_backend.client._client,
+                          'isdir',
+                          return_value=True) as mock_isdir:
             assert petrel_backend.isdir(petrel_dir)
             mock_isdir.assert_called_once_with(petrel_dir)

@@ -408,9 +414,9 @@ def test_petrel_backend(self, backend, prefix):
         with pytest.raises(NotImplementedError):
             petrel_backend.isfile(petrel_path)

-        with patch.object(
-                petrel_backend.client._client, 'contains',
-                return_value=True) as mock_contains:
+        with patch.object(petrel_backend.client._client,
+                          'contains',
+                          return_value=True) as mock_contains:
             assert petrel_backend.isfile(petrel_path)
             mock_contains.assert_called_once_with(petrel_path)

@@ -447,8 +453,8 @@ def test_petrel_backend(self, backend, prefix):
             'dir1', 'dir2', 'text1.txt', 'text2.txt'
         }
         # 2. list directories and files recursively
-        assert set(
-            petrel_backend.list_dir_or_file(tmp_dir, recursive=True)) == {
+        assert set(petrel_backend.list_dir_or_file(
+            tmp_dir, recursive=True)) == {
                 'dir1', '/'.join(('dir1', 'text3.txt')), 'dir2', '/'.join(
                     ('dir2', 'dir3')), '/'.join(
                         ('dir2', 'dir3', 'text4.txt')), '/'.join(
                 'None')):
             # Exception is raised among the `list_dir_or_file` of client,
             # so we need to invoke the client to trigger the exception
-            petrel_backend.client.list_dir_or_file(
-                tmp_dir, list_file=False, suffix='.txt')
+            petrel_backend.client.list_dir_or_file(tmp_dir,
+                                                   list_file=False,
+                                                   suffix='.txt')
         # 4. only list directories recursively
         assert set(
-            petrel_backend.list_dir_or_file(
-                tmp_dir, list_file=False, recursive=True)) == {
-                    'dir1', 'dir2', '/'.join(('dir2', 'dir3'))
-                }
+            petrel_backend.list_dir_or_file(tmp_dir,
+                                            list_file=False,
+                                            recursive=True)) == {
+                                                'dir1', 'dir2', '/'.join(
+                                                    ('dir2', 'dir3'))
+                                            }
         # 5. only list files
-        assert set(
-            petrel_backend.list_dir_or_file(
-                tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'}
+        assert set(petrel_backend.list_dir_or_file(
+            tmp_dir, list_dir=False)) == {'text1.txt', 'text2.txt'}
         # 6. 
only list files recursively assert set( petrel_backend.list_dir_or_file( @@ -486,27 +494,35 @@ def test_petrel_backend(self, backend, prefix): } # 7. only list files ending with suffix assert set( - petrel_backend.list_dir_or_file( - tmp_dir, list_dir=False, - suffix='.txt')) == {'text1.txt', 'text2.txt'} + petrel_backend.list_dir_or_file(tmp_dir, + list_dir=False, + suffix='.txt')) == { + 'text1.txt', 'text2.txt' + } assert set( - petrel_backend.list_dir_or_file( - tmp_dir, list_dir=False, - suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'} + petrel_backend.list_dir_or_file(tmp_dir, + list_dir=False, + suffix=('.txt', '.jpg'))) == { + 'text1.txt', 'text2.txt' + } with pytest.raises( TypeError, match='`suffix` must be a string or tuple of strings'): - petrel_backend.client.list_dir_or_file( - tmp_dir, list_dir=False, suffix=['.txt', '.jpg']) + petrel_backend.client.list_dir_or_file(tmp_dir, + list_dir=False, + suffix=['.txt', '.jpg']) # 8. only list files ending with suffix recursively assert set( - petrel_backend.list_dir_or_file( - tmp_dir, list_dir=False, suffix='.txt', - recursive=True)) == { - '/'.join(('dir1', 'text3.txt')), '/'.join( - ('dir2', 'dir3', 'text4.txt')), 'text1.txt', - 'text2.txt' - } + petrel_backend.list_dir_or_file(tmp_dir, + list_dir=False, + suffix='.txt', + recursive=True)) == { + '/'.join( + ('dir1', 'text3.txt')), + '/'.join(('dir2', 'dir3', + 'text4.txt')), + 'text1.txt', 'text2.txt' + } # 7. only list files ending with suffix assert set( petrel_backend.list_dir_or_file( @@ -782,11 +798,10 @@ def get(self, filepath): def get_text(self, filepath, encoding='utf-8'): return 'text6' - FileClient.register_backend( - 'example4', - Example6Backend, - force=True, - prefixes='example4_prefix') + FileClient.register_backend('example4', + Example6Backend, + force=True, + prefixes='example4_prefix') example_backend = FileClient('example4') assert example_backend.get(self.img_path) == b'bytes6' assert example_backend.get_text(self.text_path) == 'text6' @@ -830,11 +845,10 @@ def get(self, filepath): def get_text(self, filepath, encoding='utf-8'): return 'text8' - FileClient.register_backend( - 'example6', - Example8Backend, - force=True, - prefixes='example6_prefix') + FileClient.register_backend('example6', + Example8Backend, + force=True, + prefixes='example6_prefix') example_backend = FileClient('example6') assert example_backend.get(self.img_path) == b'bytes8' assert example_backend.get_text(self.text_path) == 'text8' diff --git a/tests/test_fileio/test_fileio.py b/tests/test_fileio/test_fileio.py index 33a0956fed..13fba651dc 100644 --- a/tests/test_fileio/test_fileio.py +++ b/tests/test_fileio/test_fileio.py @@ -152,35 +152,37 @@ def test_list_from_file(): # get list from http filename = 'http://path/of/your/file' - with patch.object( - HTTPBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): + with patch.object(HTTPBackend, + 'get_text', + return_value='1.jpg\n2.jpg\n3.jpg'): filelist = mmengine.list_from_file( filename, file_client_args={'backend': 'http'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] - filelist = mmengine.list_from_file( - filename, file_client_args={'prefix': 'http'}) + filelist = mmengine.list_from_file(filename, + file_client_args={'prefix': 'http'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] filelist = mmengine.list_from_file(filename) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] - filelist = mmengine.list_from_file( - filename, backend_args={'backend': 'http'}) + filelist = mmengine.list_from_file(filename, + 
backend_args={'backend': 'http'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] # get list from petrel filename = 's3://path/of/your/file' - with patch.object( - PetrelBackend, 'get_text', return_value='1.jpg\n2.jpg\n3.jpg'): + with patch.object(PetrelBackend, + 'get_text', + return_value='1.jpg\n2.jpg\n3.jpg'): filelist = mmengine.list_from_file( filename, file_client_args={'backend': 'petrel'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] - filelist = mmengine.list_from_file( - filename, file_client_args={'prefix': 's3'}) + filelist = mmengine.list_from_file(filename, + file_client_args={'prefix': 's3'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] filelist = mmengine.list_from_file(filename) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] - filelist = mmengine.list_from_file( - filename, backend_args={'backend': 'petrel'}) + filelist = mmengine.list_from_file(filename, + backend_args={'backend': 'petrel'}) assert filelist == ['1.jpg', '2.jpg', '3.jpg'] @@ -194,35 +196,36 @@ def test_dict_from_file(): # get dict from http filename = 'http://path/of/your/file' - with patch.object( - HTTPBackend, 'get_text', return_value='1 cat\n2 dog cow\n3 panda'): - mapping = mmengine.dict_from_file( - filename, file_client_args={'backend': 'http'}) + with patch.object(HTTPBackend, + 'get_text', + return_value='1 cat\n2 dog cow\n3 panda'): + mapping = mmengine.dict_from_file(filename, + file_client_args={'backend': 'http'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} - mapping = mmengine.dict_from_file( - filename, file_client_args={'prefix': 'http'}) + mapping = mmengine.dict_from_file(filename, + file_client_args={'prefix': 'http'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} mapping = mmengine.dict_from_file(filename) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} - mapping = mmengine.dict_from_file( - filename, backend_args={'backend': 'http'}) + mapping = mmengine.dict_from_file(filename, + backend_args={'backend': 'http'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} # get dict from petrel filename = 's3://path/of/your/file' - with patch.object( - PetrelBackend, 'get_text', - return_value='1 cat\n2 dog cow\n3 panda'): + with patch.object(PetrelBackend, + 'get_text', + return_value='1 cat\n2 dog cow\n3 panda'): mapping = mmengine.dict_from_file( filename, file_client_args={'backend': 'petrel'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} - mapping = mmengine.dict_from_file( - filename, file_client_args={'prefix': 's3'}) + mapping = mmengine.dict_from_file(filename, + file_client_args={'prefix': 's3'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} mapping = mmengine.dict_from_file(filename) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} - mapping = mmengine.dict_from_file( - filename, backend_args={'backend': 'petrel'}) + mapping = mmengine.dict_from_file(filename, + backend_args={'backend': 'petrel'}) assert mapping == {'1': 'cat', '2': ['dog', 'cow'], '3': 'panda'} diff --git a/tests/test_fileio/test_io.py b/tests/test_fileio/test_io.py index c34af47e0b..5fb4c9b596 100644 --- a/tests/test_fileio/test_io.py +++ b/tests/test_fileio/test_io.py @@ -139,8 +139,9 @@ def test_get_file_backend(): backend_args = {'path_mapping': {'src': 'dst'}, 'enable_mc': True} uri = 'petrel://your_bucket/img.png' - backend4 = fileio.get_file_backend( - uri=uri, backend_args=backend_args, enable_singleton=True) + backend4 = fileio.get_file_backend(uri=uri, + 
backend_args=backend_args,
+                                       enable_singleton=True)
     assert isinstance(backend4, fileio.backends.PetrelBackend)
     assert len(fileio.io.backend_instances) == 2
     unique_key = 'petrel:{"path_mapping": {"src": "dst"}, "enable_mc": true}'
@@ -148,16 +149,18 @@ def test_get_file_backend():
     assert backend4 is not backend2

     uri = 'petrel://your_bucket/img1.png'
-    backend5 = fileio.get_file_backend(
-        uri=uri, backend_args=backend_args, enable_singleton=True)
+    backend5 = fileio.get_file_backend(uri=uri,
+                                       backend_args=backend_args,
+                                       enable_singleton=True)
     assert isinstance(backend5, fileio.backends.PetrelBackend)
     assert len(fileio.io.backend_instances) == 2
     assert backend5 is backend4
     assert backend5 is not backend2

     backend_args = {'path_mapping': {'src1': 'dst1'}, 'enable_mc': True}
-    backend6 = fileio.get_file_backend(
-        uri=uri, backend_args=backend_args, enable_singleton=True)
+    backend6 = fileio.get_file_backend(uri=uri,
+                                       backend_args=backend_args,
+                                       enable_singleton=True)
     assert isinstance(backend6, fileio.backends.PetrelBackend)
     assert len(fileio.io.backend_instances) == 3
     unique_key = 'petrel:{"path_mapping": {"src1": "dst1"}, "enable_mc": true}'
@@ -165,8 +168,9 @@ def test_get_file_backend():
     assert backend6 is not backend4
     assert backend6 is not backend5

-    backend7 = fileio.get_file_backend(
-        uri=uri, backend_args=backend_args, enable_singleton=False)
+    backend7 = fileio.get_file_backend(uri=uri,
+                                       backend_args=backend_args,
+                                       enable_singleton=False)
     assert isinstance(backend7, fileio.backends.PetrelBackend)
     assert len(fileio.io.backend_instances) == 3
     assert backend7 is not backend6
@@ -472,8 +476,9 @@ def test_list_dir_or_file():
                 TypeError,
                 match='`suffix` should be None when `list_dir` is True'):
             list(
-                fileio.list_dir_or_file(
-                    tmp_dir, list_file=False, suffix='.txt'))
+                fileio.list_dir_or_file(tmp_dir,
+                                        list_file=False,
+                                        suffix='.txt'))

         # only list directories recursively
         assert set(
@@ -502,34 +507,39 @@ def test_list_dir_or_file():
                 tmp_dir, list_dir=False,
                 suffix='.txt')) == {'text1.txt', 'text2.txt'}
         assert set(
-            fileio.list_dir_or_file(
-                tmp_dir, list_dir=False,
-                suffix=('.txt', '.jpg'))) == {'text1.txt', 'text2.txt'}
+            fileio.list_dir_or_file(tmp_dir,
+                                    list_dir=False,
+                                    suffix=('.txt', '.jpg'))) == {
+                                        'text1.txt', 'text2.txt'
+                                    }
         with pytest.raises(
                 TypeError,
                 match='`suffix` must be a string or tuple of strings'):
             list(
-                fileio.list_dir_or_file(
-                    tmp_dir, list_dir=False, suffix=['.txt', '.jpg']))
+                fileio.list_dir_or_file(tmp_dir,
+                                        list_dir=False,
+                                        suffix=['.txt', '.jpg']))

         # only list files ending with suffix recursively
         assert set(
-            fileio.list_dir_or_file(
-                tmp_dir, list_dir=False, suffix='.txt', recursive=True)) == {
-                    osp.join('dir1', 'text3.txt'),
-                    osp.join('dir2', 'dir3', 'text4.txt'), 'text1.txt',
-                    'text2.txt'
-                }
+            fileio.list_dir_or_file(tmp_dir,
+                                    list_dir=False,
+                                    suffix='.txt',
+                                    recursive=True)) == {
+                                        osp.join('dir1', 'text3.txt'),
+                                        osp.join('dir2', 'dir3', 'text4.txt'),
+                                        'text1.txt', 'text2.txt'
+                                    }
         # only list files ending with suffix recursively
         assert set(
-            fileio.list_dir_or_file(
-                tmp_dir,
-                list_dir=False,
-                suffix=('.txt', '.jpg'),
-                recursive=True)) == {
-                    osp.join('dir1', 'text3.txt'),
-                    osp.join('dir2', 'dir3', 'text4.txt'),
-                    osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt'
-                }
+            fileio.list_dir_or_file(tmp_dir,
+                                    list_dir=False,
+                                    suffix=('.txt', '.jpg'),
+                                    recursive=True)) == {
+                                        osp.join('dir1', 'text3.txt'),
+                                        osp.join('dir2', 'dir3', 'text4.txt'),
+                                        osp.join('dir2', 'img.jpg'),
+                                        'text1.txt', 'text2.txt'
+                                    }
diff --git 
a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py index d731a42b76..eb7ac967cb 100644 --- a/tests/test_hooks/test_checkpoint_hook.py +++ b/tests/test_hooks/test_checkpoint_hook.py @@ -57,9 +57,8 @@ def test_init(self): ValueError, '"file_client_args" and "backend_args" cannot be set ' 'at the same time'): - CheckpointHook( - file_client_args={'backend': 'disk'}, - backend_args={'backend': 'local'}) + CheckpointHook(file_client_args={'backend': 'disk'}, + backend_args={'backend': 'local'}) # Test save best CheckpointHook(save_best='acc') @@ -88,8 +87,9 @@ def test_init(self): hook = CheckpointHook(greater_keys=['acc']) self.assertEqual(hook.greater_keys, ['acc']) - hook = CheckpointHook( - interval=2, by_epoch=False, save_best=['acc', 'mIoU']) + hook = CheckpointHook(interval=2, + by_epoch=False, + save_best=['acc', 'mIoU']) self.assertEqual(hook.key_indicators, ['acc', 'mIoU']) self.assertEqual(hook.rules, ['greater', 'greater']) @@ -123,8 +123,9 @@ def test_before_train(self): self.assertEqual(checkpoint_hook.out_dir, runner.work_dir) # the out_dir of the checkpoint hook is not None - checkpoint_hook = CheckpointHook( - interval=1, by_epoch=True, out_dir='test_dir') + checkpoint_hook = CheckpointHook(interval=1, + by_epoch=True, + out_dir='test_dir') checkpoint_hook.before_train(runner) self.assertEqual(checkpoint_hook.out_dir, osp.join('test_dir', osp.basename(cfg.work_dir))) @@ -162,13 +163,15 @@ def test_after_val_epoch(self): # if metrics is an empty dict, print a warning information with self.assertLogs(runner.logger, level='WARNING'): - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best='auto') + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=True, + save_best='auto') checkpoint_hook.after_val_epoch(runner, {}) # if save_best is None,no best_ckpt meta should be stored - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best=None) + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=True, + save_best=None) checkpoint_hook.before_train(runner) checkpoint_hook.after_val_epoch(runner, {}) self.assertNotIn('best_score', runner.message_hub.runtime_info) @@ -176,8 +179,9 @@ def test_after_val_epoch(self): # when `save_best` is set to `auto`, first metric will be used. 
metrics = {'acc': 0.5, 'map': 0.3} - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best='auto') + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=True, + save_best='auto') checkpoint_hook.before_train(runner) checkpoint_hook.after_val_epoch(runner, metrics) best_ckpt_name = 'best_acc_epoch_9.pth' @@ -186,20 +190,22 @@ def test_after_val_epoch(self): self.assertEqual(checkpoint_hook.key_indicators, ['acc']) self.assertEqual(checkpoint_hook.rules, ['greater']) self.assertEqual(runner.message_hub.get_info('best_score'), 0.5) - self.assertEqual( - runner.message_hub.get_info('best_ckpt'), best_ckpt_path) + self.assertEqual(runner.message_hub.get_info('best_ckpt'), + best_ckpt_path) # # when `save_best` is set to `acc`, it should update greater value - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best='acc') + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=True, + save_best='acc') checkpoint_hook.before_train(runner) metrics['acc'] = 0.8 checkpoint_hook.after_val_epoch(runner, metrics) self.assertEqual(runner.message_hub.get_info('best_score'), 0.8) # # when `save_best` is set to `loss`, it should update less value - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best='loss') + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=True, + save_best='loss') checkpoint_hook.before_train(runner) metrics['loss'] = 0.8 checkpoint_hook.after_val_epoch(runner, metrics) @@ -209,8 +215,10 @@ def test_after_val_epoch(self): # when `rule` is set to `less`,then it should update less value # no matter what `save_best` is - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best='acc', rule='less') + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=True, + save_best='acc', + rule='less') checkpoint_hook.before_train(runner) metrics['acc'] = 0.3 checkpoint_hook.after_val_epoch(runner, metrics) @@ -218,22 +226,26 @@ def test_after_val_epoch(self): # # when `rule` is set to `greater`,then it should update greater value # # no matter what `save_best` is - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=True, save_best='loss', rule='greater') + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=True, + save_best='loss', + rule='greater') checkpoint_hook.before_train(runner) metrics['loss'] = 1.0 checkpoint_hook.after_val_epoch(runner, metrics) self.assertEqual(runner.message_hub.get_info('best_score'), 1.0) # test multi `save_best` with one rule - checkpoint_hook = CheckpointHook( - interval=2, save_best=['acc', 'mIoU'], rule='greater') + checkpoint_hook = CheckpointHook(interval=2, + save_best=['acc', 'mIoU'], + rule='greater') self.assertEqual(checkpoint_hook.key_indicators, ['acc', 'mIoU']) self.assertEqual(checkpoint_hook.rules, ['greater', 'greater']) # test multi `save_best` with multi rules - checkpoint_hook = CheckpointHook( - interval=2, save_best=['FID', 'IS'], rule=['less', 'greater']) + checkpoint_hook = CheckpointHook(interval=2, + save_best=['FID', 'IS'], + rule=['less', 'greater']) self.assertEqual(checkpoint_hook.key_indicators, ['FID', 'IS']) self.assertEqual(checkpoint_hook.rules, ['less', 'greater']) @@ -254,10 +266,10 @@ def test_after_val_epoch(self): checkpoint_hook.out_dir, best_mIoU_name) self.assertEqual(runner.message_hub.get_info('best_score_acc'), 0.5) self.assertEqual(runner.message_hub.get_info('best_score_mIoU'), 0.6) - self.assertEqual( - runner.message_hub.get_info('best_ckpt_acc'), best_acc_path) - self.assertEqual( - 
runner.message_hub.get_info('best_ckpt_mIoU'), best_mIoU_path) + self.assertEqual(runner.message_hub.get_info('best_ckpt_acc'), + best_acc_path) + self.assertEqual(runner.message_hub.get_info('best_ckpt_mIoU'), + best_mIoU_path) # test behavior when by_epoch is False cfg = copy.deepcopy(self.iter_based_cfg) @@ -266,8 +278,10 @@ def test_after_val_epoch(self): # check best ckpt name and best score metrics = {'acc': 0.5, 'map': 0.3} - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=False, save_best='acc', rule='greater') + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=False, + save_best='acc', + rule='greater') checkpoint_hook.before_train(runner) checkpoint_hook.after_val_epoch(runner, metrics) self.assertEqual(checkpoint_hook.key_indicators, ['acc']) @@ -276,8 +290,8 @@ def test_after_val_epoch(self): best_ckpt_path = checkpoint_hook.file_client.join_path( checkpoint_hook.out_dir, best_ckpt_name) - self.assertEqual( - runner.message_hub.get_info('best_ckpt'), best_ckpt_path) + self.assertEqual(runner.message_hub.get_info('best_ckpt'), + best_ckpt_path) self.assertEqual(runner.message_hub.get_info('best_score'), 0.5) # check best score updating @@ -286,13 +300,14 @@ def test_after_val_epoch(self): best_ckpt_name = 'best_acc_iter_9.pth' best_ckpt_path = checkpoint_hook.file_client.join_path( checkpoint_hook.out_dir, best_ckpt_name) - self.assertEqual( - runner.message_hub.get_info('best_ckpt'), best_ckpt_path) + self.assertEqual(runner.message_hub.get_info('best_ckpt'), + best_ckpt_path) self.assertEqual(runner.message_hub.get_info('best_score'), 0.666) # check best checkpoint name with `by_epoch` is False - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=False, save_best=['acc', 'mIoU']) + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=False, + save_best=['acc', 'mIoU']) checkpoint_hook.before_train(runner) metrics = dict(acc=0.5, mIoU=0.6) checkpoint_hook.after_val_epoch(runner, metrics) @@ -305,10 +320,10 @@ def test_after_val_epoch(self): self.assertEqual(runner.message_hub.get_info('best_score_acc'), 0.5) self.assertEqual(runner.message_hub.get_info('best_score_mIoU'), 0.6) - self.assertEqual( - runner.message_hub.get_info('best_ckpt_acc'), best_acc_path) - self.assertEqual( - runner.message_hub.get_info('best_ckpt_mIoU'), best_mIoU_path) + self.assertEqual(runner.message_hub.get_info('best_ckpt_acc'), + best_acc_path) + self.assertEqual(runner.message_hub.get_info('best_ckpt_mIoU'), + best_mIoU_path) # after_val_epoch should not save last_checkpoint self.assertFalse( @@ -321,8 +336,9 @@ def test_after_val_epoch(self): self.clear_work_dir() cfg = copy.deepcopy(cfg) runner = self.build_runner(cfg) - checkpoint_hook = CheckpointHook( - interval=2, by_epoch=by_epoch, save_best='acc') + checkpoint_hook = CheckpointHook(interval=2, + by_epoch=by_epoch, + save_best='acc') checkpoint_hook.before_train(runner) checkpoint_hook.after_val_epoch(runner, metrics) all_files = os.listdir(runner.work_dir) @@ -373,9 +389,8 @@ def test_after_train_epoch(self): checkpoint_hook.before_train(runner) checkpoint_hook.after_train_epoch(runner) self.assertEqual((runner.epoch + 1) % 2, 0) - self.assertEqual( - runner.message_hub.get_info('last_ckpt'), - osp.join(cfg.work_dir, 'epoch_10.pth')) + self.assertEqual(runner.message_hub.get_info('last_ckpt'), + osp.join(cfg.work_dir, 'epoch_10.pth')) last_ckpt_path = osp.join(cfg.work_dir, 'last_checkpoint') self.assertTrue(osp.isfile(last_ckpt_path)) @@ -387,9 +402,8 @@ def test_after_train_epoch(self): # epoch can not be evenly 
divided by 2 runner.train_loop._epoch = 10 checkpoint_hook.after_train_epoch(runner) - self.assertEqual( - runner.message_hub.get_info('last_ckpt'), - osp.join(cfg.work_dir, 'epoch_10.pth')) + self.assertEqual(runner.message_hub.get_info('last_ckpt'), + osp.join(cfg.work_dir, 'epoch_10.pth')) runner.message_hub.runtime_info.clear() # by epoch is False @@ -416,25 +430,22 @@ def test_after_train_iter(self): checkpoint_hook.before_train(runner) checkpoint_hook.after_train_iter(runner, batch_idx=9) self.assertIn('last_ckpt', runner.message_hub.runtime_info) - self.assertEqual( - runner.message_hub.get_info('last_ckpt'), - osp.join(cfg.work_dir, 'iter_10.pth')) + self.assertEqual(runner.message_hub.get_info('last_ckpt'), + osp.join(cfg.work_dir, 'iter_10.pth')) # epoch can not be evenly divided by 2 runner.train_loop._iter = 10 checkpoint_hook.after_train_epoch(runner) - self.assertEqual( - runner.message_hub.get_info('last_ckpt'), - osp.join(cfg.work_dir, 'iter_10.pth')) + self.assertEqual(runner.message_hub.get_info('last_ckpt'), + osp.join(cfg.work_dir, 'iter_10.pth')) @parameterized.expand([['iter'], ['epoch']]) def test_with_runner(self, training_type): common_cfg = getattr(self, f'{training_type}_based_cfg') setattr(common_cfg.train_cfg, f'max_{training_type}s', 11) - checkpoint_cfg = dict( - type='CheckpointHook', - interval=1, - by_epoch=training_type == 'epoch') + checkpoint_cfg = dict(type='CheckpointHook', + interval=1, + by_epoch=training_type == 'epoch') common_cfg.default_hooks = dict(checkpoint=checkpoint_cfg) # Test interval in epoch based training @@ -470,12 +481,11 @@ def test_with_runner(self, training_type): # Test save_param_scheduler=False cfg = copy.deepcopy(common_cfg) cfg.param_scheduler = [ - dict( - type='LinearLR', - start_factor=0.1, - begin=0, - end=500, - by_epoch=training_type == 'epoch') + dict(type='LinearLR', + start_factor=0.1, + begin=0, + end=500, + by_epoch=training_type == 'epoch') ] runner = self.build_runner(cfg) runner.train() diff --git a/tests/test_hooks/test_early_stopping_hook.py b/tests/test_hooks/test_early_stopping_hook.py index 16f8fd981c..08fe4cbac5 100644 --- a/tests/test_hooks/test_early_stopping_hook.py +++ b/tests/test_hooks/test_early_stopping_hook.py @@ -149,8 +149,9 @@ def test_after_val_epoch(self): # if `monitor` does not match and strict=True, crash the training. 
with self.assertRaises(RuntimeError): metrics = {'accuracy/top1': 0.5, 'loss': 0.23} - hook = EarlyStoppingHook( - monitor='acc', rule='greater', strict=True) + hook = EarlyStoppingHook(monitor='acc', + rule='greater', + strict=True) hook.after_val_epoch(runner, metrics) # Check largest value @@ -176,8 +177,9 @@ def test_after_val_epoch(self): # Check stop training runner = get_mock_runner() metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)] - hook = EarlyStoppingHook( - monitor='accuracy/top1', rule='greater', min_delta=1) + hook = EarlyStoppingHook(monitor='accuracy/top1', + rule='greater', + min_delta=1) for metric in metrics: hook.after_val_epoch(runner, metric) if runner.train_loop.stop_training: @@ -187,8 +189,9 @@ def test_after_val_epoch(self): # Check finite runner = get_mock_runner() metrics = [{'accuracy/top1': math.inf} for i in range(5)] - hook = EarlyStoppingHook( - monitor='accuracy/top1', rule='greater', min_delta=1) + hook = EarlyStoppingHook(monitor='accuracy/top1', + rule='greater', + min_delta=1) for metric in metrics: hook.after_val_epoch(runner, metric) if runner.train_loop.stop_training: @@ -198,8 +201,10 @@ def test_after_val_epoch(self): # Check patience runner = get_mock_runner() metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)] - hook = EarlyStoppingHook( - monitor='accuracy/top1', rule='greater', min_delta=1, patience=10) + hook = EarlyStoppingHook(monitor='accuracy/top1', + rule='greater', + min_delta=1, + patience=10) for metric in metrics: hook.after_val_epoch(runner, metric) if runner.train_loop.stop_training: @@ -209,11 +214,10 @@ def test_after_val_epoch(self): # Check stopping_threshold runner = get_mock_runner() metrics = [{'accuracy/top1': i} for i in torch.linspace(98, 99, 8)] - hook = EarlyStoppingHook( - monitor='accuracy/top1', - rule='greater', - stopping_threshold=98.5, - patience=0) + hook = EarlyStoppingHook(monitor='accuracy/top1', + rule='greater', + stopping_threshold=98.5, + patience=0) for metric in metrics: hook.after_val_epoch(runner, metric) if runner.train_loop.stop_training: @@ -230,26 +234,27 @@ def test_with_runner(self): min_delta=1, patience=3, ) - runner = Runner( - model=ToyModel(), - work_dir=work_dir, - train_dataloader=dict( - dataset=DummyDataset(), - sampler=dict(type='DefaultSampler', shuffle=True), - batch_size=3, - num_workers=0), - val_dataloader=dict( - dataset=DummyDataset(), - sampler=dict(type='DefaultSampler', shuffle=False), - batch_size=3, - num_workers=0), - val_evaluator=dict(type=DummyMetric, length=max_epoch), - optim_wrapper=OptimWrapper( - torch.optim.Adam(ToyModel().parameters())), - train_cfg=dict( - by_epoch=True, max_epochs=max_epoch, val_interval=1), - val_cfg=dict(), - custom_hooks=[early_stop_cfg], - experiment_name='earlystop_test') + runner = Runner(model=ToyModel(), + work_dir=work_dir, + train_dataloader=dict(dataset=DummyDataset(), + sampler=dict( + type='DefaultSampler', + shuffle=True), + batch_size=3, + num_workers=0), + val_dataloader=dict(dataset=DummyDataset(), + sampler=dict(type='DefaultSampler', + shuffle=False), + batch_size=3, + num_workers=0), + val_evaluator=dict(type=DummyMetric, length=max_epoch), + optim_wrapper=OptimWrapper( + torch.optim.Adam(ToyModel().parameters())), + train_cfg=dict(by_epoch=True, + max_epochs=max_epoch, + val_interval=1), + val_cfg=dict(), + custom_hooks=[early_stop_cfg], + experiment_name='earlystop_test') runner.train() self.assertEqual(runner.epoch, 6) diff --git a/tests/test_hooks/test_ema_hook.py 
b/tests/test_hooks/test_ema_hook.py index 6dad7ba4f0..398c7f1672 100644 --- a/tests/test_hooks/test_ema_hook.py +++ b/tests/test_hooks/test_ema_hook.py @@ -208,9 +208,9 @@ def test_after_load_checkpoint(self): # Check the weight of state_dict and ema_state_dict have been swapped. # when runner._resume is True runner._resume = True - checkpoint = dict( - state_dict=ToyModel().state_dict(), - ema_state_dict=ExponentialMovingAverage(ToyModel()).state_dict()) + checkpoint = dict(state_dict=ToyModel().state_dict(), + ema_state_dict=ExponentialMovingAverage( + ToyModel()).state_dict()) ori_checkpoint = copy.deepcopy(checkpoint) ema_hook.after_load_checkpoint(runner, checkpoint) for key in ori_checkpoint['state_dict'].keys(): @@ -273,8 +273,8 @@ def test_with_runner(self): cfg.custom_hooks = [ConfigDict(type='EMAHook', begin_epoch=5)] runner = self.build_runner(cfg) runner.train() - state_dict = torch.load( - osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu') + state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_4.pth'), + map_location='cpu') self.assertIn('ema_state_dict', state_dict) for k, v in state_dict['state_dict'].items(): assert_allclose(v, state_dict['ema_state_dict']['module.' + k]) @@ -286,13 +286,13 @@ def test_with_runner(self): cfg.default_hooks.checkpoint.interval = 1 runner = self.build_runner(cfg) runner.train() - state_dict = torch.load( - osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu') + state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_4.pth'), + map_location='cpu') self.assertIn('ema_state_dict', state_dict) for k, v in state_dict['state_dict'].items(): assert_allclose(v, state_dict['ema_state_dict']['module.' + k]) - state_dict = torch.load( - osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu') + state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_5.pth'), + map_location='cpu') self.assertIn('ema_state_dict', state_dict) def _test_swap_parameters(self, func_name, *args, **kwargs): diff --git a/tests/test_hooks/test_empty_cache_hook.py b/tests/test_hooks/test_empty_cache_hook.py index d30972d360..7e722a0e77 100644 --- a/tests/test_hooks/test_empty_cache_hook.py +++ b/tests/test_hooks/test_empty_cache_hook.py @@ -9,8 +9,8 @@ class TestEmptyCacheHook(RunnerTestCase): - @pytest.mark.skipif( - not is_cuda_available(), reason='cuda should be available') + @pytest.mark.skipif(not is_cuda_available(), + reason='cuda should be available') def test_with_runner(self): with patch('torch.cuda.empty_cache') as mock_empty_cache: cfg = self.epoch_based_cfg @@ -47,8 +47,7 @@ def test_with_runner(self): with patch('torch.cuda.empty_cache') as mock_empty_cache: cfg.custom_hooks = [ - dict( - type='EmptyCacheHook', after_iter=True, before_epoch=True) + dict(type='EmptyCacheHook', after_iter=True, before_epoch=True) ] runner = self.build_runner(cfg) diff --git a/tests/test_hooks/test_logger_hook.py b/tests/test_hooks/test_logger_hook.py index 52b8bc1fa3..925226b98a 100644 --- a/tests/test_hooks/test_logger_hook.py +++ b/tests/test_hooks/test_logger_hook.py @@ -49,17 +49,15 @@ def test_init(self): # test deprecated warning raised by `file_client_args` logger = MMLogger.get_current_instance() with self.assertLogs(logger, level='WARNING'): - LoggerHook( - out_dir=self.temp_dir.name, - file_client_args=dict(backend='disk')) + LoggerHook(out_dir=self.temp_dir.name, + file_client_args=dict(backend='disk')) with self.assertRaisesRegex( ValueError, '"file_client_args" and "backend_args" cannot be '): - LoggerHook( - 
out_dir=self.temp_dir.name, - file_client_args=dict(enable_mc=True), - backend_args=dict(enable_mc=True)) + LoggerHook(out_dir=self.temp_dir.name, + file_client_args=dict(enable_mc=True), + backend_args=dict(enable_mc=True)) def test_after_train_iter(self): # Test LoggerHook by iter. @@ -138,8 +136,8 @@ def test_after_val_epoch(self): 'acc': 0.8 }, **args), ] - self.assertEqual( - len(calls), len(runner.visualizer.add_scalars.mock_calls)) + self.assertEqual(len(calls), + len(runner.visualizer.add_scalars.mock_calls)) runner.visualizer.add_scalars.assert_has_calls(calls) # Test when `log_metric_by_epoch` is False @@ -165,8 +163,8 @@ def test_after_val_epoch(self): 'acc': 0.5 }, **args), ] - self.assertEqual( - len(calls), len(runner.visualizer.add_scalars.mock_calls)) + self.assertEqual(len(calls), + len(runner.visualizer.add_scalars.mock_calls)) runner.visualizer.add_scalars.assert_has_calls(calls) def test_after_test_epoch(self): @@ -174,10 +172,9 @@ def test_after_test_epoch(self): runner = MagicMock() runner.log_dir = self.temp_dir.name runner.timestamp = 'test_after_test_epoch' - runner.log_processor.get_log_after_epoch = MagicMock( - return_value=( - dict(a=1, b=2, c={'list': [1, 2]}, d=torch.tensor([1, 2, 3])), - 'log_str')) + runner.log_processor.get_log_after_epoch = MagicMock(return_value=( + dict(a=1, b=2, c={'list': [1, 2]}, d=torch.tensor([1, 2, 3])), + 'log_str')) logger_hook.before_run(runner) logger_hook.after_test_epoch(runner) runner.log_processor.get_log_after_epoch.assert_called() @@ -232,8 +229,9 @@ def test_with_runner(self): shutil.rmtree(osp.join(out_dir, filename)) # Test out_suffix - cfg.default_hooks.logger = dict( - type='LoggerHook', out_dir=out_dir, out_suffix='.log') + cfg.default_hooks.logger = dict(type='LoggerHook', + out_dir=out_dir, + out_suffix='.log') runner = self.build_runner(cfg) runner.train() filenames = scandir(out_dir, recursive=True) @@ -241,8 +239,9 @@ def test_with_runner(self): all(filename.endswith('.log') for filename in filenames)) # Test keep_local=False - cfg.default_hooks.logger = dict( - type='LoggerHook', out_dir=out_dir, keep_local=False) + cfg.default_hooks.logger = dict(type='LoggerHook', + out_dir=out_dir, + keep_local=False) runner = self.build_runner(cfg) runner.train() filenames = scandir(runner._log_dir, recursive=True) diff --git a/tests/test_hooks/test_naive_visualization_hook.py b/tests/test_hooks/test_naive_visualization_hook.py index 2e39e94527..2345dcbc0d 100644 --- a/tests/test_hooks/test_naive_visualization_hook.py +++ b/tests/test_hooks/test_naive_visualization_hook.py @@ -16,47 +16,40 @@ def test_after_train_iter(self): inputs = torch.randn(1, 3, 15, 15) batch_idx = 10 # test with normalize, resize, pad - gt_datasamples = BaseDataElement( - metainfo=dict( - img_norm_cfg=dict( - mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True), - scale=(10, 10), - pad_shape=(15, 15, 3), - ori_height=5, - ori_width=5, - img_path='tmp.jpg')) + gt_datasamples = BaseDataElement(metainfo=dict(img_norm_cfg=dict( + mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True), + scale=(10, 10), + pad_shape=(15, 15, 3), + ori_height=5, + ori_width=5, + img_path='tmp.jpg')) pred_datasamples = [BaseDataElement()] data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with resize, pad - gt_datasamples = BaseDataElement( - metainfo=dict( - scale=(10, 10), - pad_shape=(15, 15, 3), - ori_height=5, - ori_width=5, - img_path='tmp.jpg')) + gt_datasamples = 
BaseDataElement(metainfo=dict(scale=(10, 10), + pad_shape=(15, 15, 3), + ori_height=5, + ori_width=5, + img_path='tmp.jpg')) pred_datasamples = [BaseDataElement()] data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with only resize - gt_datasamples = BaseDataElement( - metainfo=dict( - scale=(15, 15), ori_height=5, ori_width=5, img_path='tmp.jpg')) + gt_datasamples = BaseDataElement(metainfo=dict( + scale=(15, 15), ori_height=5, ori_width=5, img_path='tmp.jpg')) pred_datasamples = [BaseDataElement()] data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, pred_datasamples) # test with only pad - gt_datasamples = BaseDataElement( - metainfo=dict( - pad_shape=(15, 15, 3), - ori_height=5, - ori_width=5, - img_path='tmp.jpg')) + gt_datasamples = BaseDataElement(metainfo=dict(pad_shape=(15, 15, 3), + ori_height=5, + ori_width=5, + img_path='tmp.jpg')) pred_datasamples = [BaseDataElement()] data_batch = [dict(inputs=inputs, data_sample=gt_datasamples)] naive_visualization_hook.after_test_iter(runner, batch_idx, data_batch, diff --git a/tests/test_hooks/test_prepare_tta_hook.py b/tests/test_hooks/test_prepare_tta_hook.py index a356164ef6..0de30d788a 100644 --- a/tests/test_hooks/test_prepare_tta_hook.py +++ b/tests/test_hooks/test_prepare_tta_hook.py @@ -82,9 +82,8 @@ def test_before_test(self): # Test with epoch based runner. cfg = copy.deepcopy(self.epoch_based_cfg) cfg.custom_hooks.append( - dict( - type='PrepareTTAHook', - tta_cfg=dict(type='ToyTestTimeAugModel'))) + dict(type='PrepareTTAHook', + tta_cfg=dict(type='ToyTestTimeAugModel'))) cfg.model = dict(type='ToyModel') cfg.test_dataloader.dataset = dict( type='ToyDatasetTTA', pipeline=dict(type='ToyTTAPipeline')) @@ -96,9 +95,8 @@ def test_before_test(self): # Test with iteration based runner cfg = copy.deepcopy(self.iter_based_cfg) cfg.custom_hooks.append( - dict( - type='PrepareTTAHook', - tta_cfg=dict(type='ToyTestTimeAugModel'))) + dict(type='PrepareTTAHook', + tta_cfg=dict(type='ToyTestTimeAugModel'))) cfg.model = dict(type='ToyModel') cfg.test_dataloader.dataset = dict( type='ToyDatasetTTA', pipeline=dict(type='ToyTTAPipeline')) diff --git a/tests/test_hooks/test_profiler_hook.py b/tests/test_hooks/test_profiler_hook.py index 2db6df01b6..8021664bbd 100644 --- a/tests/test_hooks/test_profiler_hook.py +++ b/tests/test_hooks/test_profiler_hook.py @@ -52,13 +52,13 @@ def deal_profile(_profile): hook.on_trace_ready = dict(type='unknown') hook._parse_trace_config(runner) - hook.on_trace_ready = dict( - type='log_trace', sort_by='self_cpu_time_total', row_limit=10) + hook.on_trace_ready = dict(type='log_trace', + sort_by='self_cpu_time_total', + row_limit=10) hook._parse_trace_config(runner) - @unittest.skipIf( - not is_installed('torch-tb-profiler'), - reason='required torch-tb-profiler') + @unittest.skipIf(not is_installed('torch-tb-profiler'), + reason='required torch-tb-profiler') def test_parse_trace_config_tensorboard(self): # Test on_trace_ready_args runner = MagicMock() @@ -76,16 +76,15 @@ def test_parse_trace_config_tensorboard(self): hook._parse_trace_config(runner) # with self.assertWarns(DeprecationWarning): - hook = ProfilerHook( - on_trace_ready=dict(type='tb_trace'), - json_trace_path=ops.join(self.temp_dir.name, 'demo.json')) + hook = ProfilerHook(on_trace_ready=dict(type='tb_trace'), + json_trace_path=ops.join(self.temp_dir.name, + 
'demo.json')) hook._parse_trace_config(runner) self.epoch_based_cfg['custom_hooks'] = [ - dict( - type='ProfilerHook', - on_trace_ready=dict( - type='tb_trace', dir_name=self.temp_dir.name)) + dict(type='ProfilerHook', + on_trace_ready=dict(type='tb_trace', + dir_name=self.temp_dir.name)) ] runner = self.build_runner(self.epoch_based_cfg) runner.train() @@ -148,19 +147,18 @@ def test_after_train_iter(self): hook.profiler.__exit__.assert_called_once() hook.profiler.step.assert_called_once() - hook = ProfilerHook( - by_epoch=False, - schedule=dict(wait=1, warmup=1, active=3, repeat=1)) + hook = ProfilerHook(by_epoch=False, + schedule=dict(wait=1, warmup=1, active=3, + repeat=1)) hook.profiler = MagicMock() hook.after_train_iter(runner, 1, 1, 1) hook.profiler.step.assert_called_once() def test_with_runner(self): self.epoch_based_cfg['custom_hooks'] = [ - dict( - type='ProfilerHook', - activity_with_cpu=False, - activity_with_cuda=False) + dict(type='ProfilerHook', + activity_with_cpu=False, + activity_with_cuda=False) ] runner = self.build_runner(self.epoch_based_cfg) runner.train() @@ -171,16 +169,14 @@ def test_with_runner(self): ] runner = self.build_runner(self.epoch_based_cfg) runner.train() - self.assertTrue( - ops.exists(json_path), 'ERROR::json file is not generated!') + self.assertTrue(ops.exists(json_path), + 'ERROR::json file is not generated!') self.epoch_based_cfg['custom_hooks'] = [ - dict( - type='ProfilerHook', - on_trace_ready=dict( - type='log_trace', - sort_by='self_cpu_time_total', - row_limit=10)) + dict(type='ProfilerHook', + on_trace_ready=dict(type='log_trace', + sort_by='self_cpu_time_total', + row_limit=10)) ] runner = self.build_runner(self.epoch_based_cfg) runner.train() @@ -200,8 +196,8 @@ def test_with_runner(self): runner.train() -@unittest.skipIf( - not is_npu_available(), reason='Ascend PyTorch and npu devices not exist') +@unittest.skipIf(not is_npu_available(), + reason='Ascend PyTorch and npu devices not exist') class TestNPUProfilerHook(RunnerTestCase): def test_init(self): @@ -243,27 +239,25 @@ def test_after_train_iter(self): def test_with_runner(self): result_path = ops.join(self.temp_dir.name, 'test/cann_profiling') self.epoch_based_cfg['custom_hooks'] = [ - dict( - type='NPUProfilerHook', - begin=0, - result_path=result_path, - exit_after_profiling=False) + dict(type='NPUProfilerHook', + begin=0, + result_path=result_path, + exit_after_profiling=False) ] runner = self.build_runner(self.epoch_based_cfg) runner.train() self.epoch_based_cfg['custom_hooks'] = [ - dict( - type='NPUProfilerHook', - result_path=result_path, - ge_profiling_to_std_out=True, - exit_after_profiling=False) + dict(type='NPUProfilerHook', + result_path=result_path, + ge_profiling_to_std_out=True, + exit_after_profiling=False) ] runner = self.build_runner(self.epoch_based_cfg) runner.train() - self.assertTrue( - ops.exists(result_path), 'profiler result path is not generated!') + self.assertTrue(ops.exists(result_path), + 'profiler result path is not generated!') self.assertTrue( os.getenv('GE_PROFILING_TO_STD_OUT', '0') == '1', diff --git a/tests/test_hooks/test_runtime_info_hook.py b/tests/test_hooks/test_runtime_info_hook.py index c7e7a3c339..5f15f7ddd8 100644 --- a/tests/test_hooks/test_runtime_info_hook.py +++ b/tests/test_hooks/test_runtime_info_hook.py @@ -95,8 +95,8 @@ def test_before_train_iter(self): optim2 = SGD(model.layer2.parameters(), lr=0.02) optim_wrapper1 = OptimWrapper(optim1) optim_wrapper2 = OptimWrapper(optim2) - optim_wrapper_dict = OptimWrapperDict( - 
key1=optim_wrapper1, key2=optim_wrapper2) + optim_wrapper_dict = OptimWrapperDict(key1=optim_wrapper1, + key2=optim_wrapper2) runner.optim_wrapper = optim_wrapper_dict hook.before_train_iter(runner, batch_idx=2, data_batch=None) self.assertEqual( @@ -108,8 +108,10 @@ def test_after_train_iter(self): cfg = copy.deepcopy(self.epoch_based_cfg) runner = self.build_runner(cfg) hook = self._get_runtime_info_hook(runner) - hook.after_train_iter( - runner, batch_idx=2, data_batch=None, outputs={'loss_cls': 1.111}) + hook.after_train_iter(runner, + batch_idx=2, + data_batch=None, + outputs={'loss_cls': 1.111}) self.assertEqual( runner.message_hub.get_scalar('train/loss_cls').current(), 1.111) @@ -167,14 +169,13 @@ def test_scalar_check(self): # check other scalar dtypes val = np.mean([5]) # this is not ndarray but dtype is np.float64. - hook.after_val_epoch( - runner, - metrics={ - 'acc_f32': val.astype(np.float32), - 'acc_i32': val.astype(np.int32), - 'acc_u8': val.astype(np.uint8), - 'acc_ndarray': np.array([5]), - }) + hook.after_val_epoch(runner, + metrics={ + 'acc_f32': val.astype(np.float32), + 'acc_i32': val.astype(np.int32), + 'acc_u8': val.astype(np.uint8), + 'acc_ndarray': np.array([5]), + }) self.assertEqual( runner.message_hub.get_scalar('val/acc_f32').current(), 5) self.assertEqual( @@ -185,13 +186,12 @@ def test_scalar_check(self): runner.message_hub.get_scalar('val/acc_ndarray').current(), 5) val = torch.tensor([5.0]).mean() - hook.after_val_epoch( - runner, - metrics={ - 'acc_f32': val.float(), - 'acc_i64': val.long(), - 'acc_tensor': torch.tensor([5]), - }) + hook.after_val_epoch(runner, + metrics={ + 'acc_f32': val.float(), + 'acc_i64': val.long(), + 'acc_tensor': torch.tensor([5]), + }) self.assertEqual( runner.message_hub.get_scalar('val/acc_f32').current(), 5) self.assertEqual( diff --git a/tests/test_hooks/test_sync_buffers_hook.py b/tests/test_hooks/test_sync_buffers_hook.py index 6d4019dc58..8558f53985 100644 --- a/tests/test_hooks/test_sync_buffers_hook.py +++ b/tests/test_hooks/test_sync_buffers_hook.py @@ -70,5 +70,6 @@ def test_with_runner(self): def setup_dist_env(self): super().setup_dist_env() os.environ['RANK'] = str(self.rank) - torch_dist.init_process_group( - backend='gloo', rank=self.rank, world_size=self.world_size) + torch_dist.init_process_group(backend='gloo', + rank=self.rank, + world_size=self.world_size) diff --git a/tests/test_hub/test_hub.py b/tests/test_hub/test_hub.py index ae21d3dab4..5dd951e478 100644 --- a/tests/test_hub/test_hub.py +++ b/tests/test_hub/test_hub.py @@ -12,9 +12,8 @@ # mmdet has a more typical config structure, while mmpose has a complex # config structure -@pytest.mark.skipif( - not (is_installed('mmdet') and is_installed('mmpose')), - reason='mmdet and mmpose should be installed') +@pytest.mark.skipif(not (is_installed('mmdet') and is_installed('mmpose')), + reason='mmdet and mmpose should be installed') def test_get_config(): # Test load base config. 
base_cfg = get_config('mmdet::_base_/models/faster-rcnn_r50_fpn.py') @@ -32,8 +31,8 @@ def test_get_config(): assert cfg._cfg_dict == test_cfg._cfg_dict # Test pretrained - cfg = get_config( - 'mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py', pretrained=True) + cfg = get_config('mmdet::faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py', + pretrained=True) assert cfg.model_path == 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth' # noqa E301 # Test load mmpose @@ -42,8 +41,8 @@ def test_get_config(): ) -@pytest.mark.skipif( - not is_installed('mmdet'), reason='mmdet and mmpose should be installed') +@pytest.mark.skipif(not is_installed('mmdet'), + reason='mmdet and mmpose should be installed') def test_get_model(): # TODO compatible with downstream codebase. DefaultScope.get_instance('test_get_model', scope_name='test_scope') diff --git a/tests/test_infer/test_infer.py b/tests/test_infer/test_infer.py index 2d020b6300..c0142c98a5 100644 --- a/tests/test_infer/test_infer.py +++ b/tests/test_infer/test_infer.py @@ -133,8 +133,8 @@ def test_call(self): inferencer(imgs) inferencer(img_paths) - @pytest.mark.skipif( - not is_imported('mmdet'), reason='mmdet is not installed') + @pytest.mark.skipif(not is_imported('mmdet'), + reason='mmdet is not installed') def test_load_model_from_meta(self): from mmdet.utils import register_all_modules @@ -154,8 +154,8 @@ def test_get_chunk_data(self): inferencer = ToyInferencer(self.cfg_path, self.ckpt_path) data = list(range(1, 11)) chunk_data = inferencer._get_chunk_data(data, 3) - self.assertEqual( - list(chunk_data), [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]) + self.assertEqual(list(chunk_data), + [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]) def test_init_visualizer(self): cfg = copy.deepcopy(self.epoch_based_cfg) @@ -173,11 +173,10 @@ def test_init_visualizer(self): def test_dispatch_kwargs(self): inferencer = ToyInferencer(self.cfg_path, self.ckpt_path) - kwargs = dict( - pre_arg=dict(a=1), - for_arg=dict(c=2), - vis_arg=dict(b=3), - pos_arg=dict(d=4)) + kwargs = dict(pre_arg=dict(a=1), + for_arg=dict(c=2), + vis_arg=dict(b=3), + pos_arg=dict(d=4)) pre_arg, for_arg, vis_arg, pos_arg = inferencer._dispatch_kwargs( **kwargs) self.assertEqual(pre_arg, dict(pre_arg=dict(a=1))) @@ -217,8 +216,8 @@ def test_preprocess(self): for data in dataloader: self.assertTrue(is_list_of(data, torch.Tensor)) - @pytest.mark.skipif( - not is_imported('mmdet'), reason='mmdet is not installed') + @pytest.mark.skipif(not is_imported('mmdet'), + reason='mmdet is not installed') def test_list_models(self): model_list = BaseInferencer.list_models('mmdet') self.assertTrue(len(model_list) > 0) diff --git a/tests/test_logging/test_logger.py b/tests/test_logging/test_logger.py index 2ac2b3548e..2826c349e1 100644 --- a/tests/test_logging/test_logger.py +++ b/tests/test_logging/test_logger.py @@ -34,16 +34,18 @@ def test_init_rank0(self, tmp_path): # If `rank=0`, the `log_level` of stream_handler and file_handler # depends on the given arguments. 
tmp_file = tmp_path / 'tmp_file.log' - logger = MMLogger.get_instance( - 'rank0.pkg2', log_level='INFO', log_file=str(tmp_file)) + logger = MMLogger.get_instance('rank0.pkg2', + log_level='INFO', + log_file=str(tmp_file)) assert isinstance(logger, logging.Logger) assert len(logger.handlers) == 2 assert isinstance(logger.handlers[0], logging.StreamHandler) assert isinstance(logger.handlers[1], logging.FileHandler) logger_pkg3 = MMLogger.get_instance('rank0.pkg2') assert id(logger_pkg3) == id(logger) - logger = MMLogger.get_instance( - 'rank0.pkg3', logger_name='logger_test', log_level='INFO') + logger = MMLogger.get_instance('rank0.pkg3', + logger_name='logger_test', + log_level='INFO') assert logger.name == 'logger_test' assert logger.instance_name == 'rank0.pkg3' # `FileHandler` should be closed in Windows, otherwise we cannot @@ -59,14 +61,14 @@ def test_init_rank1(self, tmp_path): # If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`. tmp_file = tmp_path / 'tmp_file.log' log_path = tmp_path / 'tmp_file_test_device1_rank1.log' - logger = MMLogger.get_instance( - 'rank1.pkg2', log_level='INFO', log_file=str(tmp_file)) + logger = MMLogger.get_instance('rank1.pkg2', + log_level='INFO', + log_file=str(tmp_file)) assert len(logger.handlers) == 1 - logger = MMLogger.get_instance( - 'rank1.pkg3', - log_level='INFO', - log_file=str(tmp_file), - distributed=True) + logger = MMLogger.get_instance('rank1.pkg3', + log_level='INFO', + log_file=str(tmp_file), + distributed=True) assert logger.handlers[0].level == logging.ERROR assert logger.handlers[1].level == logging.INFO assert len(logger.handlers) == 2 @@ -94,8 +96,9 @@ def test_handler(self, capsys, tmp_path, log_level): # test file_handler output plain text without color. tmp_file = tmp_path / 'tmp_file.log' instance_name = f'test_file_{log_level}' - logger = MMLogger.get_instance( - instance_name, log_level=log_level, log_file=tmp_file) + logger = MMLogger.get_instance(instance_name, + log_level=log_level, + log_file=tmp_file) logger.log(level=log_level, msg='welcome') with open(tmp_file) as f: @@ -209,27 +212,32 @@ def test_filter(self, capsys): def test_file_handlers(self, tmp_path): tmp_file = tmp_path / 'tmp_file.log' fh = None - logger = MMLogger( - name='test_file_handlers', log_file=tmp_file, file_handler_cfg=fh) + logger = MMLogger(name='test_file_handlers', + log_file=tmp_file, + file_handler_cfg=fh) assert isinstance(logger.handlers[-1], logging.FileHandler) fh = dict(type='BaseRotatingHandler', mode='a') - logger = MMLogger( - name='test_file_handlers', log_file=tmp_file, file_handler_cfg=fh) + logger = MMLogger(name='test_file_handlers', + log_file=tmp_file, + file_handler_cfg=fh) assert isinstance(logger.handlers[-1], logging.handlers.BaseRotatingHandler) fh = dict(type='RotatingFileHandler', maxBytes=1024) - logger = MMLogger( - name='test_file_handlers', log_file=tmp_file, file_handler_cfg=fh) + logger = MMLogger(name='test_file_handlers', + log_file=tmp_file, + file_handler_cfg=fh) assert isinstance(logger.handlers[-1], logging.handlers.RotatingFileHandler) fh = dict(type='TimedRotatingFileHandler', when='MIDNIGHT') - logger = MMLogger( - name='test_file_handlers', log_file=tmp_file, file_handler_cfg=fh) + logger = MMLogger(name='test_file_handlers', + log_file=tmp_file, + file_handler_cfg=fh) assert isinstance(logger.handlers[-1], logging.handlers.TimedRotatingFileHandler) fh = dict(type='WatchedFileHandler') - logger = MMLogger( - name='test_file_handlers', log_file=tmp_file, file_handler_cfg=fh) + logger = 
MMLogger(name='test_file_handlers', + log_file=tmp_file, + file_handler_cfg=fh) assert isinstance(logger.handlers[-1], logging.handlers.WatchedFileHandler) # `FileHandler` should be closed in Windows, otherwise we cannot diff --git a/tests/test_logging/test_message_hub.py b/tests/test_logging/test_message_hub.py index 3dc5cef748..b82211ea2d 100644 --- a/tests/test_logging/test_message_hub.py +++ b/tests/test_logging/test_message_hub.py @@ -27,10 +27,9 @@ def test_init(self): MessageHub('hello', log_scalars=OrderedDict(a=1)) # `Resumed_keys` with pytest.raises(AssertionError): - MessageHub( - 'hello', - runtime_info=OrderedDict(iter=1), - resumed_keys=OrderedDict(iters=False)) + MessageHub('hello', + runtime_info=OrderedDict(iter=1), + resumed_keys=OrderedDict(iters=False)) def test_update_scalar(self): message_hub = MessageHub.get_instance('mmengine') @@ -99,11 +98,10 @@ def test_get_runtime(self): def test_get_scalars(self): import torch message_hub = MessageHub.get_instance('mmengine') - log_dict = dict( - loss=1, - loss_cls=torch.tensor(2), - loss_bbox=np.array(3), - loss_iou=dict(value=1, count=2)) + log_dict = dict(loss=1, + loss_cls=torch.tensor(2), + loss_bbox=np.array(3), + loss_iou=dict(value=1, count=2)) message_hub.update_scalars(log_dict) loss = message_hub.get_scalar('loss') loss_cls = message_hub.get_scalar('loss_cls') @@ -169,8 +167,11 @@ def test_load_state_dict(self, capsys): state_dict = OrderedDict() state_dict['log_scalars'] = dict(a=1, b=HistoryBuffer()) state_dict['runtime_info'] = dict(c=1, d=NoDeepCopy(), e=1) - state_dict['resumed_keys'] = dict( - a=True, b=True, c=True, e=False, f=True) + state_dict['resumed_keys'] = dict(a=True, + b=True, + c=True, + e=False, + f=True) message_hub4 = MessageHub.get_instance('test_load_state_dict4') message_hub4.load_state_dict(state_dict) @@ -179,8 +180,9 @@ def test_load_state_dict(self, capsys): assert 'c' in message_hub4.runtime_info and \ state_dict['runtime_info']['d'] is \ message_hub4.runtime_info['d'] - assert message_hub4._resumed_keys == OrderedDict( - b=True, c=True, e=False) + assert message_hub4._resumed_keys == OrderedDict(b=True, + c=True, + e=False) def test_getstate(self): message_hub = MessageHub.get_instance('name') diff --git a/tests/test_model/test_averaged_model.py b/tests/test_model/test_averaged_model.py index 6438b8bde5..f9d3d38ca0 100644 --- a/tests/test_model/test_averaged_model.py +++ b/tests/test_model/test_averaged_model.py @@ -18,9 +18,8 @@ class TestAveragedModel(TestCase): """ # noqa: E501 def _test_swa_model(self, net_device, avg_device): - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), - torch.nn.Linear(5, 10)).to(net_device) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.Linear(5, 10)).to(net_device) averaged_model = StochasticWeightAverage(model, device=avg_device) averaged_params = [ @@ -52,8 +51,8 @@ def test_averaged_model_all_devices(self): def test_swa_mixed_device(self): if not torch.cuda.is_available(): return - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.Linear(5, 10)) model[0].cuda() model[1].cpu() averaged_model = StochasticWeightAverage(model) @@ -73,8 +72,8 @@ def test_swa_mixed_device(self): self.assertTrue(p_avg.device == p_swa.device) def test_swa_state_dict(self): - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + model = 
torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.Linear(5, 10)) averaged_model = StochasticWeightAverage(model) averaged_model2 = StochasticWeightAverage(model) n_updates = 10 @@ -92,19 +91,19 @@ def test_ema(self): # test invalid momentum with self.assertRaisesRegex(AssertionError, 'momentum must be in range'): - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.Linear(5, 10)) ExponentialMovingAverage(model, momentum=3) # Warning should be raised if the value of momentum in EMA is # a large number with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'): - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.Linear(5, 10)) ExponentialMovingAverage(model, momentum=0.9) # test EMA - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.Linear(5, 10)) momentum = 0.1 ema_model = ExponentialMovingAverage(model, momentum=momentum) @@ -129,13 +128,14 @@ def test_ema(self): def test_ema_update_buffers(self): # Test EMA and update_buffers as True. - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), - torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10)) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.BatchNorm2d(5, momentum=0.3), + torch.nn.Linear(5, 10)) momentum = 0.1 - ema_model = ExponentialMovingAverage( - model, momentum=momentum, update_buffers=True) + ema_model = ExponentialMovingAverage(model, + momentum=momentum, + update_buffers=True) averaged_params = [ torch.zeros_like(param) for param in itertools.chain(model.parameters(), model.buffers()) @@ -168,9 +168,9 @@ def test_ema_update_buffers(self): assert_allclose(p_target, p_ema) def test_momentum_annealing_ema(self): - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), - torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10)) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.BatchNorm2d(5, momentum=0.3), + torch.nn.Linear(5, 10)) # Test invalid gamma with self.assertRaisesRegex(AssertionError, 'gamma must be greater than 0'): @@ -180,8 +180,10 @@ def test_momentum_annealing_ema(self): momentum = 0.1 gamma = 4 - ema_model = MomentumAnnealingEMA( - model, gamma=gamma, momentum=momentum, update_buffers=True) + ema_model = MomentumAnnealingEMA(model, + gamma=gamma, + momentum=momentum, + update_buffers=True) averaged_params = [ torch.zeros_like(param) for param in itertools.chain(model.parameters(), model.buffers()) @@ -216,19 +218,18 @@ def test_momentum_annealing_ema(self): def test_momentum_annealing_ema_with_interval(self): # Test EMA with momentum annealing and interval - model = torch.nn.Sequential( - torch.nn.Conv2d(1, 5, kernel_size=3), - torch.nn.BatchNorm2d(5, momentum=0.3), torch.nn.Linear(5, 10)) + model = torch.nn.Sequential(torch.nn.Conv2d(1, 5, kernel_size=3), + torch.nn.BatchNorm2d(5, momentum=0.3), + torch.nn.Linear(5, 10)) momentum = 0.1 gamma = 4 interval = 3 - ema_model = MomentumAnnealingEMA( - model, - gamma=gamma, - momentum=momentum, - interval=interval, - update_buffers=True) + ema_model = MomentumAnnealingEMA(model, + gamma=gamma, + momentum=momentum, + interval=interval, + update_buffers=True) averaged_params = [ 
torch.zeros_like(param) for param in itertools.chain(model.parameters(), model.buffers()) diff --git a/tests/test_model/test_base_model/test_base_model.py b/tests/test_model/test_base_model/test_base_model.py index 8dc23eec86..484c95ec71 100644 --- a/tests/test_model/test_base_model/test_base_model.py +++ b/tests/test_model/test_base_model/test_base_model.py @@ -94,10 +94,9 @@ def test_parse_losses(self): ] losses = dict(loss_cls=loss_cls, loss_list=loss_list) target_parsed_losses = torch.tensor(6, dtype=torch.float32) - targe_log_vars = dict( - loss=torch.tensor(6, dtype=torch.float32), - loss_cls=torch.tensor(1, dtype=torch.float32), - loss_list=torch.tensor(5, dtype=torch.float32)) + targe_log_vars = dict(loss=torch.tensor(6, dtype=torch.float32), + loss_cls=torch.tensor(1, dtype=torch.float32), + loss_list=torch.tensor(5, dtype=torch.float32)) parse_losses, log_vars = model.parse_losses(losses) assert_allclose(parse_losses, target_parsed_losses) for key in log_vars: diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py index c409260a50..e429db032c 100644 --- a/tests/test_model/test_base_model/test_data_preprocessor.py +++ b/tests/test_model/test_base_model/test_data_preprocessor.py @@ -97,12 +97,11 @@ def test_init(self): assert_allclose(data_processor.pad_value, torch.tensor(0)) # Initiate model with bgr2rgb, mean, std .etc.. - data_processor = ImgDataPreprocessor( - bgr_to_rgb=True, - mean=[0, 0, 0], - std=[255, 255, 255], - pad_size_divisor=16, - pad_value=10) + data_processor = ImgDataPreprocessor(bgr_to_rgb=True, + mean=[0, 0, 0], + std=[255, 255, 255], + pad_size_divisor=16, + pad_value=10) self.assertTrue(data_processor._enable_normalize) self.assertTrue(data_processor._channel_conversion, True) assert_allclose(data_processor.mean, @@ -122,15 +121,15 @@ def test_init(self): ImgDataPreprocessor(bgr_to_rgb=True, rgb_to_bgr=True) with self.assertRaisesRegex(AssertionError, 'mean and std should be'): - ImgDataPreprocessor( - bgr_to_rgb=True, - mean=None, - std=[255, 255, 255], - pad_size_divisor=16, - pad_value=10) - - data_processor = ImgDataPreprocessor( - bgr_to_rgb=True, pad_size_divisor=16, pad_value=10) + ImgDataPreprocessor(bgr_to_rgb=True, + mean=None, + std=[255, 255, 255], + pad_size_divisor=16, + pad_value=10) + + data_processor = ImgDataPreprocessor(bgr_to_rgb=True, + pad_size_divisor=16, + pad_value=10) self.assertFalse(data_processor._enable_normalize) def test_forward(self): @@ -147,10 +146,9 @@ def test_forward(self): data_sample1 = InstanceData(bboxes=torch.randn(5, 4)) data_sample2 = InstanceData(bboxes=torch.randn(5, 4)) - data = dict( - inputs=[inputs1.clone(), inputs2.clone()], - data_sample=[data_sample1.clone(), - data_sample2.clone()]) + data = dict(inputs=[inputs1.clone(), inputs2.clone()], + data_sample=[data_sample1.clone(), + data_sample2.clone()]) std = torch.tensor([1, 2, 3]).view(-1, 1, 1) target_inputs1 = (inputs1.clone()[[2, 1, 0], ...] 
- 127.5) / std @@ -193,26 +191,27 @@ def test_forward(self): assert_allclose(data_sample.bboxes, target_data_sample.bboxes) # Test gray image with 3 dim mean will raise error - data_preprocessor = ImgDataPreprocessor( - mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5)) - data = dict( - inputs=[torch.ones(10, 10), torch.ones(10, 10)], data_sample=None) + data_preprocessor = ImgDataPreprocessor(mean=(127.5, 127.5, 127.5), + std=(127.5, 127.5, 127.5)) + data = dict(inputs=[torch.ones(10, 10), + torch.ones(10, 10)], + data_sample=None) with self.assertRaisesRegex(AssertionError, 'If the mean has 3 values'): data_preprocessor(data) - data = dict( - inputs=[torch.ones(10, 10), torch.ones(10, 10)], data_sample=None) + data = dict(inputs=[torch.ones(10, 10), + torch.ones(10, 10)], + data_sample=None) with self.assertRaisesRegex(AssertionError, 'If the mean has 3 values'): data_preprocessor(data) # Test stacked batch inputs and batch data samples - data_preprocessor = ImgDataPreprocessor( - mean=(127.5, 127.5, 127.5), - std=(127.5, 127.5, 127.5), - rgb_to_bgr=True, - pad_size_divisor=16) + data_preprocessor = ImgDataPreprocessor(mean=(127.5, 127.5, 127.5), + std=(127.5, 127.5, 127.5), + rgb_to_bgr=True, + pad_size_divisor=16) _batch_inputs = torch.randn(2, 3, 10, 10) _batch_labels = [torch.randn(1), torch.randn(1)] data = dict(inputs=_batch_inputs, data_sample=_batch_labels) @@ -226,8 +225,8 @@ def test_forward(self): assert_allclose(target_batch_inputs, inputs) # Test batch inputs without convert channel order and pad - data_preprocessor = ImgDataPreprocessor( - mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5)) + data_preprocessor = ImgDataPreprocessor(mean=(127.5, 127.5, 127.5), + std=(127.5, 127.5, 127.5)) _batch_inputs = torch.randn(2, 3, 10, 10) _batch_labels = [torch.randn(1), torch.randn(1)] data = dict(inputs=_batch_inputs, data_sample=_batch_labels) @@ -239,8 +238,8 @@ def test_forward(self): assert_allclose(target_batch_inputs, inputs) # Test empty `data_sample` - data = dict( - inputs=[inputs1.clone(), inputs2.clone()], data_sample=None) + data = dict(inputs=[inputs1.clone(), inputs2.clone()], + data_sample=None) output = data_preprocessor(data, True) inputs, data_samples = output['inputs'], output['data_sample'] self.assertIsNone(data_samples) diff --git a/tests/test_model/test_base_module.py b/tests/test_model/test_base_module.py index 1401eed298..bf9489aa76 100644 --- a/tests/test_model/test_base_module.py +++ b/tests/test_model/test_base_module.py @@ -97,20 +97,27 @@ class TestBaseModule(TestCase): def setUp(self) -> None: self.temp_dir = tempfile.TemporaryDirectory() self.BaseModule = BaseModule() - self.model_cfg = dict( - type='FooModel', - init_cfg=[ - dict(type='Constant', val=1, bias=2, layer='Linear'), - dict(type='Constant', val=3, bias=4, layer='Conv1d'), - dict(type='Constant', val=5, bias=6, layer='Conv2d') - ], - component1=dict(type='FooConv1d'), - component2=dict(type='FooConv2d'), - component3=dict(type='FooLinear'), - component4=dict( - type='FooLinearConv1d', - linear=dict(type='FooLinear'), - conv1d=dict(type='FooConv1d'))) + self.model_cfg = dict(type='FooModel', + init_cfg=[ + dict(type='Constant', + val=1, + bias=2, + layer='Linear'), + dict(type='Constant', + val=3, + bias=4, + layer='Conv1d'), + dict(type='Constant', + val=5, + bias=6, + layer='Conv2d') + ], + component1=dict(type='FooConv1d'), + component2=dict(type='FooConv2d'), + component3=dict(type='FooLinear'), + component4=dict(type='FooLinearConv1d', + linear=dict(type='FooLinear'), + 
conv1d=dict(type='FooConv1d'))) self.model = build_from_cfg(self.model_cfg, FOOMODELS) self.logger = MMLogger.get_instance(self._testMethodName) @@ -212,8 +219,8 @@ def __init__(self, torch.save(self.model.state_dict(), checkpoint_path) model_cfg = copy.deepcopy(self.model_cfg) model_cfg['type'] = 'PratrainedModel' - model_cfg['init_cfg'] = dict( - type='Pretrained', checkpoint=checkpoint_path) + model_cfg['init_cfg'] = dict(type='Pretrained', + checkpoint=checkpoint_path) model = FOOMODELS.build(model_cfg) ori_layer_weight = model.linear.linear.weight.clone() ori_layer_bias = model.linear.linear.bias.clone() @@ -280,8 +287,8 @@ def test_dump_init_info(self): model1.init_weights() assert len(os.listdir(dump_dir)) == 0 log_path = os.path.join(dump_dir, 'out.log') - MMLogger.get_instance( - 'logger2', log_file=log_path) # add logger with FileHandler + MMLogger.get_instance('logger2', + log_file=log_path) # add logger with FileHandler model2 = build_from_cfg(self.model_cfg, FOOMODELS) model2.init_weights() assert len(os.listdir(dump_dir)) == 1 @@ -297,14 +304,16 @@ class TestModuleList(TestCase): def test_modulelist_weight_init(self): models_cfg = [ - dict( - type='FooConv1d', - init_cfg=dict( - type='Constant', layer='Conv1d', val=0., bias=1.)), - dict( - type='FooConv2d', - init_cfg=dict( - type='Constant', layer='Conv2d', val=2., bias=3.)), + dict(type='FooConv1d', + init_cfg=dict(type='Constant', + layer='Conv1d', + val=0., + bias=1.)), + dict(type='FooConv2d', + init_cfg=dict(type='Constant', + layer='Conv2d', + val=2., + bias=3.)), ] layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg] modellist = ModuleList(layers) @@ -323,10 +332,11 @@ def test_modulelist_weight_init(self): torch.full(modellist[1].conv2d.bias.shape, 3.))) # inner init_cfg has higher priority layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg] - modellist = ModuleList( - layers, - init_cfg=dict( - type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)) + modellist = ModuleList(layers, + init_cfg=dict(type='Constant', + layer=['Conv1d', 'Conv2d'], + val=4., + bias=5.)) modellist.init_weights() self.assertTrue( torch.equal(modellist[0].conv1d.weight, @@ -346,14 +356,16 @@ class TestModuleDict(TestCase): def test_moduledict_weight_init(self): models_cfg = dict( - foo_conv_1d=dict( - type='FooConv1d', - init_cfg=dict( - type='Constant', layer='Conv1d', val=0., bias=1.)), - foo_conv_2d=dict( - type='FooConv2d', - init_cfg=dict( - type='Constant', layer='Conv2d', val=2., bias=3.)), + foo_conv_1d=dict(type='FooConv1d', + init_cfg=dict(type='Constant', + layer='Conv1d', + val=0., + bias=1.)), + foo_conv_2d=dict(type='FooConv2d', + init_cfg=dict(type='Constant', + layer='Conv2d', + val=2., + bias=3.)), ) layers = { name: build_from_cfg(cfg, COMPONENTS) @@ -382,10 +394,11 @@ def test_moduledict_weight_init(self): name: build_from_cfg(cfg, COMPONENTS) for name, cfg in models_cfg.items() } - modeldict = ModuleDict( - layers, - init_cfg=dict( - type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)) + modeldict = ModuleDict(layers, + init_cfg=dict(type='Constant', + layer=['Conv1d', 'Conv2d'], + val=4., + bias=5.)) modeldict.init_weights() self.assertTrue( torch.equal( @@ -409,14 +422,16 @@ class TestSequential(TestCase): def test_sequential_model_weight_init(self): seq_model_cfg = [ - dict( - type='FooConv1d', - init_cfg=dict( - type='Constant', layer='Conv1d', val=0., bias=1.)), - dict( - type='FooConv2d', - init_cfg=dict( - type='Constant', layer='Conv2d', val=2., bias=3.)), + 
dict(type='FooConv1d', + init_cfg=dict(type='Constant', + layer='Conv1d', + val=0., + bias=1.)), + dict(type='FooConv2d', + init_cfg=dict(type='Constant', + layer='Conv2d', + val=2., + bias=3.)), ] layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg] seq_model = Sequential(*layers) @@ -435,10 +450,11 @@ def test_sequential_model_weight_init(self): torch.full(seq_model[1].conv2d.bias.shape, 3.))) # inner init_cfg has higher priority layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg] - seq_model = Sequential( - *layers, - init_cfg=dict( - type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.)) + seq_model = Sequential(*layers, + init_cfg=dict(type='Constant', + layer=['Conv1d', 'Conv2d'], + val=4., + bias=5.)) seq_model.init_weights() self.assertTrue( torch.equal(seq_model[0].conv1d.weight, diff --git a/tests/test_model/test_efficient_conv_bn_eval.py b/tests/test_model/test_efficient_conv_bn_eval.py index eb91a6d090..e8cee21f89 100644 --- a/tests/test_model/test_efficient_conv_bn_eval.py +++ b/tests/test_model/test_efficient_conv_bn_eval.py @@ -46,9 +46,8 @@ def forward(self, x): return x -@unittest.skipIf( - digit_version(TORCH_VERSION) < digit_version('1.8'), - reason='torch.fx needs Pytorch 1.8 or higher') +@unittest.skipIf(digit_version(TORCH_VERSION) < digit_version('1.8'), + reason='torch.fx needs Pytorch 1.8 or higher') class TestEfficientConvBNEval(TestCase): """Test the turn_on_efficient_conv_bn_eval function.""" diff --git a/tests/test_model/test_model_utils.py b/tests/test_model/test_model_utils.py index a08ff67d77..203e6000e4 100644 --- a/tests/test_model/test_model_utils.py +++ b/tests/test_model/test_model_utils.py @@ -25,8 +25,8 @@ def add_module(self, name, module): raise ValueError() -@pytest.mark.skipif( - torch.__version__ == 'parrots', reason='not supported in parrots now') +@pytest.mark.skipif(torch.__version__ == 'parrots', + reason='not supported in parrots now') def test_revert_syncbn(): # conv = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN')) conv = nn.Sequential(nn.Conv2d(3, 8, 2), nn.SyncBatchNorm(8)) @@ -40,8 +40,8 @@ def test_revert_syncbn(): revert_sync_batchnorm(conv) -@pytest.mark.skipif( - torch.__version__ == 'parrots', reason='not supported in parrots now') +@pytest.mark.skipif(torch.__version__ == 'parrots', + reason='not supported in parrots now') def test_convert_syncbn(): # conv = ConvModule(3, 8, 2, norm_cfg=dict(type='SyncBN')) conv = nn.Sequential(nn.Conv2d(3, 8, 2), nn.BatchNorm2d(8)) diff --git a/tests/test_model/test_test_aug_time.py b/tests/test_model/test_test_aug_time.py index d2b8c97190..62f44bb1cc 100644 --- a/tests/test_model/test_test_aug_time.py +++ b/tests/test_model/test_test_aug_time.py @@ -79,10 +79,12 @@ def test_test_step(self): ] tuple_dataset = [([1, 2], [3, 4]) for _ in range(10)] - dict_dataloader = DataLoader( - dict_dataset, batch_size=2, collate_fn=pseudo_collate) - tuple_dataloader = DataLoader( - tuple_dataset, batch_size=2, collate_fn=pseudo_collate) + dict_dataloader = DataLoader(dict_dataset, + batch_size=2, + collate_fn=pseudo_collate) + tuple_dataloader = DataLoader(tuple_dataset, + batch_size=2, + collate_fn=pseudo_collate) for data in dict_dataloader: result = tta_model.test_step(data) @@ -103,8 +105,8 @@ def test_init(self): def test_with_runner(self): cfg = copy.deepcopy(self.epoch_based_cfg) - cfg.model = dict( - type='ToyTestTimeAugModel', module=dict(type='ToyModel')) + cfg.model = dict(type='ToyTestTimeAugModel', + module=dict(type='ToyModel')) cfg.test_dataloader.dataset 
= dict(type='ToyDatasetTTA') cfg.test_dataloader.dataset['pipeline'] = dict(type='ToyTTAPipeline') runner = self.build_runner(cfg) diff --git a/tests/test_model/test_wrappers/test_model_wrapper.py b/tests/test_model/test_wrappers/test_model_wrapper.py index ea657acac1..31f7beeea0 100644 --- a/tests/test_model/test_wrappers/test_model_wrapper.py +++ b/tests/test_model/test_wrappers/test_model_wrapper.py @@ -79,8 +79,8 @@ def setUp(self): super().setUp() self._spawn_processes() - @unittest.skipIf( - not torch.cuda.is_available(), reason='cuda should be available') + @unittest.skipIf(not torch.cuda.is_available(), + reason='cuda should be available') def test_train_step(self): self._init_dist_env(self.rank, self.world_size) # Mixed precision training and gradient asynchronous should be valid at @@ -88,8 +88,8 @@ def test_train_step(self): model = ToyModel().cuda() ddp_model = MMDistributedDataParallel(module=model) optimizer = SGD(ddp_model.parameters(), lr=0) - optim_wrapper = AmpOptimWrapper( - optimizer=optimizer, accumulative_counts=3) + optim_wrapper = AmpOptimWrapper(optimizer=optimizer, + accumulative_counts=3) inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255 data = dict(inputs=inputs, data_sample=None) res = ddp_model.train_step(data, optim_wrapper=optim_wrapper)['loss'] @@ -113,11 +113,11 @@ def test_train_step(self): self.assertIsNone(grad) # Test enable detect_anomalous_params. - ddp_model = MMDistributedDataParallel( - module=model, detect_anomalous_params=True) + ddp_model = MMDistributedDataParallel(module=model, + detect_anomalous_params=True) optimizer = SGD(ddp_model.parameters(), lr=0) - optim_wrapper = AmpOptimWrapper( - optimizer=optimizer, accumulative_counts=3) + optim_wrapper = AmpOptimWrapper(optimizer=optimizer, + accumulative_counts=3) inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255 data = dict(inputs=inputs, data_sample=None) res = ddp_model.train_step(data, optim_wrapper=optim_wrapper)['loss'] @@ -148,12 +148,13 @@ def _init_dist_env(self, rank, world_size): os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = '29510' os.environ['RANK'] = str(rank) - torch_dist.init_process_group( - backend='gloo', rank=rank, world_size=world_size) + torch_dist.init_process_group(backend='gloo', + rank=rank, + world_size=world_size) -@unittest.skipIf( - not torch.cuda.is_available(), reason='cuda should be available') +@unittest.skipIf(not torch.cuda.is_available(), + reason='cuda should be available') class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel): def test_init(self): @@ -178,8 +179,8 @@ def test_train_step(self): optimizer2 = SGD(model.conv1.parameters(), lr=0.2) optim_wrapper1 = OptimWrapper(optimizer1, 1) optim_wrapper2 = OptimWrapper(optimizer2, 1) - optim_wrapper_dict = OptimWrapperDict( - optim_wrapper1=optim_wrapper1, optim_wrapper2=optim_wrapper2) + optim_wrapper_dict = OptimWrapperDict(optim_wrapper1=optim_wrapper1, + optim_wrapper2=optim_wrapper2) inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255 data = dict(inputs=inputs, data_sample=None) # Automatically sync grads of `optim_wrapper1` since @@ -215,15 +216,15 @@ def _init_dist_env(self, rank, world_size): os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = '29515' os.environ['RANK'] = str(rank) - torch_dist.init_process_group( - backend='gloo', rank=rank, world_size=world_size) + torch_dist.init_process_group(backend='gloo', + rank=rank, + world_size=world_size) -@unittest.skipIf( - torch.cuda.device_count() < 2, reason='need 2 gpu to 
test fsdp') -@unittest.skipIf( - digit_version(TORCH_VERSION) < digit_version('2.0.0'), - reason='fsdp needs Pytorch 2.0.0 or higher') +@unittest.skipIf(torch.cuda.device_count() < 2, + reason='need 2 gpu to test fsdp') +@unittest.skipIf(digit_version(TORCH_VERSION) < digit_version('2.0.0'), + reason='fsdp needs Pytorch 2.0.0 or higher') class TestMMFullyShardedDataParallel(MultiProcessTestCase): def _init_dist_env(self, rank, world_size): @@ -234,8 +235,9 @@ def _init_dist_env(self, rank, world_size): num_gpus = torch.cuda.device_count() torch.cuda.set_device(rank % num_gpus) - torch_dist.init_process_group( - backend='nccl', rank=rank, world_size=world_size) + torch_dist.init_process_group(backend='nccl', + rank=rank, + world_size=world_size) def setUp(self) -> None: super().setUp() @@ -266,8 +268,8 @@ def wrap_policy(module, recurse=True, *args, **kwargs): return True return isinstance(module, nn.Conv2d) - fsdp_model = MMFullyShardedDataParallel( - module=model.cuda(), auto_wrap_policy=wrap_policy) + fsdp_model = MMFullyShardedDataParallel(module=model.cuda(), + auto_wrap_policy=wrap_policy) optimizer = SGD(fsdp_model.parameters(), lr=0.1) optim_wrapper = OptimWrapper(optimizer, accumulative_counts=1) inputs = torch.randn(1, 3, 1, 1) * self.rank * 255 diff --git a/tests/test_optim/test_optimizer/test_optimizer.py b/tests/test_optim/test_optimizer/test_optimizer.py index 113aacd6c8..cbebdd9b49 100644 --- a/tests/test_optim/test_optimizer/test_optimizer.py +++ b/tests/test_optim/test_optimizer/test_optimizer.py @@ -26,8 +26,8 @@ MMCV_FULL_AVAILABLE = mmcv_full_available() if not MMCV_FULL_AVAILABLE: - sys.modules['mmcv.ops'] = MagicMock( - DeformConv2d=dict, ModulatedDeformConv2d=dict) + sys.modules['mmcv.ops'] = MagicMock(DeformConv2d=dict, + ModulatedDeformConv2d=dict) def has_dadaptation() -> bool: @@ -73,8 +73,10 @@ def __init__(self): self.sub = SubModel() if MMCV_FULL_AVAILABLE: from mmcv.ops import DeformConv2dPack - self.dcn = DeformConv2dPack( - 3, 4, kernel_size=3, deformable_groups=1) + self.dcn = DeformConv2dPack(3, + 4, + kernel_size=3, + deformable_groups=1) class ExampleDuplicateModel(nn.Module): @@ -90,8 +92,10 @@ def __init__(self): self.conv3[0] = self.conv1[0] if MMCV_FULL_AVAILABLE: from mmcv.ops import DeformConv2dPack - self.dcn = DeformConv2dPack( - 3, 4, kernel_size=3, deformable_groups=1) + self.dcn = DeformConv2dPack(3, + 4, + kernel_size=3, + deformable_groups=1) def forward(self, x): return x @@ -271,23 +275,19 @@ def test_transformers_optimizers(self): def test_build_optimizer(self): # test build function without ``constructor`` and ``paramwise_cfg`` - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) optim_wrapper = build_optim_wrapper(self.model, optim_wrapper_cfg) self._check_default_optimizer(optim_wrapper.optimizer, self.model) # test build optimizer without type in optim_wrapper_cfg - optim_wrapper_cfg = dict( - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optim_wrapper_cfg = dict(optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) optim_wrapper = build_optim_wrapper(self.model, optim_wrapper_cfg) self.assertIsInstance(optim_wrapper, OptimWrapper) 
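        # When ``type`` is omitted from ``optim_wrapper_cfg``,
        # ``build_optim_wrapper`` falls back to the plain ``OptimWrapper``
        # class, which is what the ``isinstance`` assertion above and the
        # default-optimizer check below verify.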
self._check_default_optimizer(optim_wrapper.optimizer, self.model) @@ -310,24 +310,20 @@ def test_build_optimizer(self): lambda: build_optim_wrapper(self.model, optim_wrapper_cfg)) def test_build_default_optimizer_constructor(self): - optim_wrapper = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) - paramwise_cfg = dict( - bias_lr_mult=2, - bias_decay_mult=0.5, - norm_decay_mult=0, - dwconv_decay_mult=0.1, - dcn_offset_lr_mult=0.1, - flat_decay_mult=0.3) - optim_constructor_cfg = dict( - type='DefaultOptimWrapperConstructor', - optim_wrapper_cfg=optim_wrapper, - paramwise_cfg=paramwise_cfg) + optim_wrapper = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) + paramwise_cfg = dict(bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1, + flat_decay_mult=0.3) + optim_constructor_cfg = dict(type='DefaultOptimWrapperConstructor', + optim_wrapper_cfg=optim_wrapper, + paramwise_cfg=paramwise_cfg) optim_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build( optim_constructor_cfg) optim_wrapper = optim_constructor(self.model) @@ -335,13 +331,11 @@ def test_build_default_optimizer_constructor(self): **paramwise_cfg) def test_build_custom_optimizer_constructor(self): - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) @OPTIM_WRAPPER_CONSTRUCTORS.register_module() class MyOptimizerConstructor(DefaultOptimWrapperConstructor): @@ -363,10 +357,9 @@ def __call__(self, model): return build_from_cfg(self.optimizer_cfg, OPTIMIZERS) paramwise_cfg = dict(conv1_lr_mult=5) - optim_constructor_cfg = dict( - type='MyOptimizerConstructor', - optim_wrapper_cfg=optim_wrapper_cfg, - paramwise_cfg=paramwise_cfg) + optim_constructor_cfg = dict(type='MyOptimizerConstructor', + optim_wrapper_cfg=optim_wrapper_cfg, + paramwise_cfg=paramwise_cfg) optim_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build( optim_constructor_cfg) optimizer = optim_constructor(self.model) @@ -394,9 +387,9 @@ def test_default_optimizer_constructor(self): with self.assertRaises(TypeError): # paramwise_cfg must be a dict or None - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict(lr=0.0001, weight_decay=None)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(lr=0.0001, + weight_decay=None)) paramwise_cfg = ['error'] optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg, paramwise_cfg) @@ -405,29 +398,28 @@ def test_default_optimizer_constructor(self): with self.assertRaises(ValueError): # bias_decay_mult/norm_decay_mult is specified but weight_decay # is None - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict(lr=0.0001, weight_decay=None)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(lr=0.0001, + weight_decay=None)) paramwise_cfg = dict(bias_decay_mult=1, norm_decay_mult=1) optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg, paramwise_cfg) optim_constructor(self.model) # basic config with ExampleModel - optimizer_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optimizer_cfg = 
dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) optim_constructor = DefaultOptimWrapperConstructor(optimizer_cfg) optim_wrapper = optim_constructor(self.model) self._check_default_optimizer(optim_wrapper.optimizer, self.model) # Support building custom optimizers - CUSTOM_OPTIMIZERS = Registry( - 'custom optimizer', scope='custom optimizer', parent=OPTIMIZERS) + CUSTOM_OPTIMIZERS = Registry('custom optimizer', + scope='custom optimizer', + parent=OPTIMIZERS) class CustomOptimizer(torch.optim.SGD): @@ -444,93 +436,84 @@ def __init__(self, model_params, *args, **kwargs): def test_default_optimizer_constructor_with_model_wrapper(self): # basic config with pseudo data parallel model = PseudoDataParallel() - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = None optim_constructor = DefaultOptimWrapperConstructor(optim_wrapper_cfg) optim_wrapper = optim_constructor(model) - self._check_default_optimizer( - optim_wrapper.optimizer, model, prefix='module.') + self._check_default_optimizer(optim_wrapper.optimizer, + model, + prefix='module.') # paramwise_cfg with pseudo data parallel model = PseudoDataParallel() - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) - paramwise_cfg = dict( - bias_lr_mult=2, - bias_decay_mult=0.5, - norm_decay_mult=0, - dwconv_decay_mult=0.1, - dcn_offset_lr_mult=0.1, - flat_decay_mult=0.3) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) + paramwise_cfg = dict(bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1, + flat_decay_mult=0.3) optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg, paramwise_cfg) optim_wrapper = optim_constructor(model) - self._check_sgd_optimizer( - optim_wrapper.optimizer, model, prefix='module.', **paramwise_cfg) + self._check_sgd_optimizer(optim_wrapper.optimizer, + model, + prefix='module.', + **paramwise_cfg) # basic config with DataParallel if torch.cuda.is_available(): model = torch.nn.DataParallel(ExampleModel()) - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = None optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg) optim_wrapper = optim_constructor(model) - self._check_default_optimizer( - optim_wrapper.optimizer, model, prefix='module.') + self._check_default_optimizer(optim_wrapper.optimizer, + model, + prefix='module.') # paramwise_cfg with DataParallel if torch.cuda.is_available(): model = torch.nn.DataParallel(self.model) - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) - paramwise_cfg = dict( - bias_lr_mult=2, - bias_decay_mult=0.5, - norm_decay_mult=0, - dwconv_decay_mult=0.1, - dcn_offset_lr_mult=0.1, 
- flat_decay_mult=0.3) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) + paramwise_cfg = dict(bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1, + flat_decay_mult=0.3) optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg, paramwise_cfg) optim_wrapper = optim_constructor(model) - self._check_sgd_optimizer( - optim_wrapper.optimizer, - model, - prefix='module.', - **paramwise_cfg) + self._check_sgd_optimizer(optim_wrapper.optimizer, + model, + prefix='module.', + **paramwise_cfg) def test_default_optimizer_constructor_with_empty_paramwise_cfg(self): # Empty paramwise_cfg with ExampleModel - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = dict() optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg, paramwise_cfg) @@ -541,13 +524,11 @@ def test_default_optimizer_constructor_with_empty_paramwise_cfg(self): model = ExampleModel() for param in model.parameters(): param.requires_grad = False - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) paramwise_cfg = dict() optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg, paramwise_cfg) @@ -556,20 +537,17 @@ def test_default_optimizer_constructor_with_empty_paramwise_cfg(self): def test_default_optimizer_constructor_with_paramwise_cfg(self): # paramwise_cfg with ExampleModel - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) - paramwise_cfg = dict( - bias_lr_mult=2, - bias_decay_mult=0.5, - norm_decay_mult=0, - dwconv_decay_mult=0.1, - dcn_offset_lr_mult=0.1, - flat_decay_mult=0.3) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) + paramwise_cfg = dict(bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1, + flat_decay_mult=0.3) optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg, paramwise_cfg) optim_wrapper = optim_constructor(self.model) @@ -578,19 +556,16 @@ def test_default_optimizer_constructor_with_paramwise_cfg(self): def test_default_optimizer_constructor_no_grad(self): # paramwise_cfg with ExampleModel and no grad - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) - paramwise_cfg = dict( - bias_lr_mult=2, - bias_decay_mult=0.5, - norm_decay_mult=0, - dwconv_decay_mult=0.1, - dcn_offset_lr_mult=0.1) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) + paramwise_cfg = dict(bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1) self.model.conv1.requires_grad_(False) 
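        # ``conv1`` is frozen above; as far as the constructor's documented
        # behavior goes, ``DefaultOptimWrapperConstructor`` still appends
        # frozen parameters to the optimizer but skips the paramwise
        # multipliers for them, since they receive no updates anyway.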
optim_constructor = DefaultOptimWrapperConstructor( @@ -606,18 +581,15 @@ def test_default_optimizer_constructor_no_grad(self): def test_default_optimizer_constructor_bypass_duplicate(self): # paramwise_cfg with bypass_duplicate option model = ExampleDuplicateModel() - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) - paramwise_cfg = dict( - bias_lr_mult=2, - bias_decay_mult=0.5, - norm_decay_mult=0, - dwconv_decay_mult=0.1) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) + paramwise_cfg = dict(bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1) with self.assertRaisesRegex( ValueError, @@ -626,14 +598,13 @@ def test_default_optimizer_constructor_bypass_duplicate(self): optim_wrapper_cfg, paramwise_cfg) optim_constructor(model) - paramwise_cfg = dict( - bias_lr_mult=2, - bias_decay_mult=0.5, - norm_decay_mult=0, - dwconv_decay_mult=0.1, - dcn_offset_lr_mult=0.1, - flat_decay_mult=0.3, - bypass_duplicate=True) + paramwise_cfg = dict(bias_lr_mult=2, + bias_decay_mult=0.5, + norm_decay_mult=0, + dwconv_decay_mult=0.1, + dcn_offset_lr_mult=0.1, + flat_decay_mult=0.3, + bypass_duplicate=True) optim_constructor = DefaultOptimWrapperConstructor( optim_wrapper_cfg, paramwise_cfg) @@ -663,21 +634,18 @@ def test_default_optimizer_constructor_bypass_duplicate(self): def test_default_optimizer_constructor_custom_key(self): # test DefaultOptimWrapperConstructor with custom_keys and # ExampleModel - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) - paramwise_cfg = dict( - custom_keys={ - 'param1': dict(lr_mult=10), - 'sub': dict(lr_mult=0.1, decay_mult=0), - 'sub.gn': dict(lr_mult=0.01), - 'non_exist_key': dict(lr_mult=0.0) - }, - norm_decay_mult=0.5) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) + paramwise_cfg = dict(custom_keys={ + 'param1': dict(lr_mult=10), + 'sub': dict(lr_mult=0.1, decay_mult=0), + 'sub.gn': dict(lr_mult=0.01), + 'non_exist_key': dict(lr_mult=0.0) + }, + norm_decay_mult=0.5) with self.assertRaises(TypeError): # custom_keys should be a dict @@ -689,8 +657,8 @@ def test_default_optimizer_constructor_custom_key(self): with self.assertRaises(ValueError): # if 'decay_mult' is specified in custom_keys, weight_decay # should be specified - optim_wrapper_cfg_ = dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)) + optim_wrapper_cfg_ = dict(type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.01)) paramwise_cfg_ = dict( custom_keys={'.backbone': dict(decay_mult=0.5)}) optim_constructor = DefaultOptimWrapperConstructor( @@ -760,10 +728,10 @@ def test_default_optimizer_constructor_custom_key(self): # test DefaultOptimWrapperConstructor with custom_keys and # ExampleModel 2 - optim_wrapper_cfg = dict( - type='OptimWrapper', - optimizer=dict( - type='SGD', lr=self.base_lr, momentum=self.momentum)) + optim_wrapper_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=self.base_lr, + momentum=self.momentum)) paramwise_cfg = dict(custom_keys={'param1': dict(lr_mult=10)}) optim_constructor = DefaultOptimWrapperConstructor( @@ -849,24 +817,21 @@ def test_zero_redundancy_optimizer(self): self.base_wd = 0.9 # test build 
function - optim_wrapper_cfg = dict( - optimizer=dict( - type='ZeroRedundancyOptimizer', - optimizer_type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optim_wrapper_cfg = dict(optimizer=dict(type='ZeroRedundancyOptimizer', + optimizer_type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg) self._check_default_optimizer(optim_wrapper.optimizer, model) # test build optimizer without ``optimizer_type`` with self.assertRaises(TypeError): optim_wrapper_cfg = dict( - optimizer=dict( - type='ZeroRedundancyOptimizer', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum)) + optimizer=dict(type='ZeroRedundancyOptimizer', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum)) optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg) @unittest.skipIf( @@ -885,14 +850,12 @@ def test_zero_redundancy_optimizer_with_paramwise_cfg(self): 'conv1': dict(lr_mult=0.0, decay_mult=0.0), 'conv2': dict(lr_mult=1.0, decay_mult=2.0) }) - optim_wrapper_cfg = dict( - optimizer=dict( - type='ZeroRedundancyOptimizer', - optimizer_type='SGD', - lr=self.base_lr, - weight_decay=self.base_wd, - momentum=self.momentum), - paramwise_cfg=paramwise_cfg) + optim_wrapper_cfg = dict(optimizer=dict(type='ZeroRedundancyOptimizer', + optimizer_type='SGD', + lr=self.base_lr, + weight_decay=self.base_wd, + momentum=self.momentum), + paramwise_cfg=paramwise_cfg) optim_wrapper = build_optim_wrapper(model, optim_wrapper_cfg) self._check_default_optimizer(optim_wrapper.optimizer, model) @@ -901,5 +864,6 @@ def _init_dist_env(self, rank, world_size): os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = '29510' os.environ['RANK'] = str(rank) - torch.distributed.init_process_group( - backend='gloo', rank=rank, world_size=world_size) + torch.distributed.init_process_group(backend='gloo', + rank=rank, + world_size=world_size) diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py index ef1db241dd..12723686d3 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py @@ -176,8 +176,8 @@ def test_ger_lr(self): optim_wrapper = OptimWrapper(optim) self.assertEqual(optim_wrapper.get_lr(), dict(lr=[0.1])) model = ToyModel() - optimizer_cfg = dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.1)) + optimizer_cfg = dict(type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.1)) paramwise_cfg = dict(custom_keys={'conv1.weight': dict(lr_mult=0.1)}) optim_constructor = DefaultOptimWrapperConstructor( optimizer_cfg, paramwise_cfg) @@ -226,8 +226,8 @@ def test_step(self): @unittest.skipIf(True, reason='Solved in the future') def test_clip_grads(self): # Test `clip_grad` with `clip_norm_` - optim_wrapper = OptimWrapper( - self.optimizer, clip_grad=dict(max_norm=35)) + optim_wrapper = OptimWrapper(self.optimizer, + clip_grad=dict(max_norm=35)) loss = self.model(torch.Tensor(1, 1, 1, 1)) loss.backward() optim_wrapper._clip_grad() @@ -236,8 +236,9 @@ def test_clip_grads(self): self.message_hub._log_scalars.clear() # Test `clip_grad` with `clip_value_` - optim_wrapper = OptimWrapper( - self.optimizer, clip_grad=dict(type='value', clip_value=0.5)) + optim_wrapper = OptimWrapper(self.optimizer, + clip_grad=dict(type='value', + clip_value=0.5)) loss = self.model(torch.Tensor(1, 1, 1, 1)) loss.backward() 
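        # ``_clip_grad`` dispatches on the ``clip_grad`` dict: the default
        # ``type='norm'`` routes to ``torch.nn.utils.clip_grad_norm_`` (here
        # bounded by ``max_norm=35``), while ``type='value'`` in the next
        # block routes to ``torch.nn.utils.clip_grad_value_``.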
optim_wrapper._clip_grad() @@ -300,8 +301,9 @@ def _init_dist_env(self, rank, world_size): os.environ['MASTER_ADDR'] = '127.0.0.1' os.environ['MASTER_PORT'] = '29515' os.environ['RANK'] = str(rank) - torch_dist.init_process_group( - backend='gloo', rank=rank, world_size=world_size) + torch_dist.init_process_group(backend='gloo', + rank=rank, + world_size=world_size) # TODO Test the real interface after add testing tool function which can # test the function or method is read called. @@ -328,8 +330,9 @@ def setUp(self) -> None: reason='`apex` is not available, Please install apex from ' 'https://www.github.com/nvidia/apex') def test_init(self): - apex_optim_wrapper = ApexOptimWrapper( - optimizer=self.optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer, + opt_level='O1', + loss_scale=1) with apex_optim_wrapper.optim_context(self.model): pass @@ -339,8 +342,9 @@ def test_init(self): 'https://www.github.com/nvidia/apex') def test_step(self): optimizer = MagicMock(spec=Optimizer) - apex_optim_wrapper = ApexOptimWrapper( - optimizer=optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=optimizer, + opt_level='O1', + loss_scale=1) with apex_optim_wrapper.optim_context(self.model): loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) apex_optim_wrapper.backward(loss) @@ -351,8 +355,9 @@ def test_step(self): reason='`apex` is not available, Please install apex from ' 'https://www.github.com/nvidia/apex') def test_backward(self): - apex_optim_wrapper = ApexOptimWrapper( - optimizer=self.optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer, + opt_level='O1', + loss_scale=1) with apex_optim_wrapper.optim_context(self.model): loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) apex_optim_wrapper.backward(loss) @@ -362,8 +367,9 @@ def test_backward(self): reason='`apex` is not available, Please install apex from ' 'https://www.github.com/nvidia/apex') def test_state_dict(self): - apex_optim_wrapper = ApexOptimWrapper( - optimizer=self.optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer, + opt_level='O1', + loss_scale=1) with apex_optim_wrapper.optim_context(self.model): loss = self.model(torch.Tensor(1, 1, 1, 1).cuda()) apex_optim_wrapper.update_params(loss) @@ -380,8 +386,9 @@ def test_state_dict(self): reason='`apex` is not available, Please install apex from ' 'https://www.github.com/nvidia/apex') def test_load_state_dict(self): - apex_optim_wrapper = ApexOptimWrapper( - optimizer=self.optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer, + opt_level='O1', + loss_scale=1) with apex_optim_wrapper.optim_context(self.model): # Test load from optimizer optimizer = SGD(self.model.parameters(), lr=0.1) @@ -403,8 +410,9 @@ def test_load_state_dict(self): reason='`apex` is not available, Please install apex from ' 'https://www.github.com/nvidia/apex') def test_optim_context(self): - apex_optim_wrapper = ApexOptimWrapper( - optimizer=self.optimizer, opt_level='O1', loss_scale=1) + apex_optim_wrapper = ApexOptimWrapper(optimizer=self.optimizer, + opt_level='O1', + loss_scale=1) with apex_optim_wrapper.optim_context(self.model): x = torch.randn(1, 1, 1, 1).cuda() y = nn.Conv2d(1, 1, 1).cuda()(x) @@ -426,24 +434,25 @@ def test_init(self): self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler) # Test with dynamic. 
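        # ('dynamic' is also the default ``loss_scale``: it builds a
        # ``GradScaler`` with its standard adaptive scaling, so no fixed
        # ``_scale_update_param`` is stored; the assertions below expect
        # ``None``.)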
- amp_optim_wrapper = AmpOptimWrapper( - 'dynamic', optimizer=self.optimizer) + amp_optim_wrapper = AmpOptimWrapper('dynamic', + optimizer=self.optimizer) self.assertIsNone(amp_optim_wrapper._scale_update_param) self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler) # Test with dtype float16 - amp_optim_wrapper = AmpOptimWrapper( - dtype='float16', optimizer=self.optimizer) + amp_optim_wrapper = AmpOptimWrapper(dtype='float16', + optimizer=self.optimizer) self.assertIs(amp_optim_wrapper.cast_dtype, torch.float16) # Test with dtype bfloat16 - amp_optim_wrapper = AmpOptimWrapper( - dtype='bfloat16', optimizer=self.optimizer) + amp_optim_wrapper = AmpOptimWrapper(dtype='bfloat16', + optimizer=self.optimizer) self.assertIs(amp_optim_wrapper.cast_dtype, torch.bfloat16) # Test with dict loss_scale. - amp_optim_wrapper = AmpOptimWrapper( - dict(init_scale=1, growth_factor=2), optimizer=self.optimizer) + amp_optim_wrapper = AmpOptimWrapper(dict(init_scale=1, + growth_factor=2), + optimizer=self.optimizer) self.assertIsInstance(amp_optim_wrapper.loss_scaler, GradScaler) self.assertIsNone(amp_optim_wrapper._scale_update_param) with self.assertRaisesRegex(TypeError, @@ -455,8 +464,8 @@ def test_init(self): not torch.cuda.is_available(), reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_step(self, dtype): - if dtype is not None and (digit_version(TORCH_VERSION) < - digit_version('1.10.0')): + if dtype is not None and (digit_version(TORCH_VERSION) + < digit_version('1.10.0')): raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' 'support `dtype` argument in autocast') if dtype == 'bfloat16' and not bf16_supported(): @@ -478,14 +487,14 @@ def test_step(self, dtype): not torch.cuda.is_available(), reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_backward(self, dtype): - if dtype is not None and (digit_version(TORCH_VERSION) < - digit_version('1.10.0')): + if dtype is not None and (digit_version(TORCH_VERSION) + < digit_version('1.10.0')): raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' 'support `dtype` argument in autocast') if dtype == 'bfloat16' and not bf16_supported(): raise unittest.SkipTest('bfloat16 not supported by device') - amp_optim_wrapper = AmpOptimWrapper( - optimizer=self.optimizer, dtype=dtype) + amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer, + dtype=dtype) loss_scaler = MagicMock() scale_return = MagicMock() scale_fn = MagicMock(return_value=scale_return) @@ -539,14 +548,14 @@ def test_load_state_dict(self): not torch.cuda.is_available(), reason='`torch.cuda.amp` is only available when pytorch-gpu installed') def test_optim_context(self, dtype, target_dtype): - if dtype is not None and (digit_version(TORCH_VERSION) < - digit_version('1.10.0')): + if dtype is not None and (digit_version(TORCH_VERSION) + < digit_version('1.10.0')): raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to ' 'support `dtype` argument in autocast') if dtype == 'bfloat16' and not bf16_supported(): raise unittest.SkipTest('bfloat16 not supported by device') - amp_optim_wrapper = AmpOptimWrapper( - optimizer=self.optimizer, dtype=dtype) + amp_optim_wrapper = AmpOptimWrapper(optimizer=self.optimizer, + dtype=dtype) with amp_optim_wrapper.optim_context(self.model): x = torch.randn(1, 1, 1, 1).cuda() y = nn.Conv2d(1, 1, 1).cuda()(x) diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py index 
3925a33ac9..990cbad757 100644 --- a/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py +++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper_dict.py @@ -18,8 +18,8 @@ def setUp(self) -> None: self.optim2 = SGD(self.model2.parameters(), lr=0.2, momentum=0.9) self.optim_wrapper1 = OptimWrapper(self.optim1) self.optim_wrapper2 = OptimWrapper(self.optim2) - self.optimizers_wrappers = dict( - optim1=self.optim_wrapper1, optim2=self.optim_wrapper2) + self.optimizers_wrappers = dict(optim1=self.optim_wrapper1, + optim2=self.optim_wrapper2) def test_init(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) @@ -111,8 +111,8 @@ def test_load_state_dict(self): optim_wrapper_load2 = OptimWrapper(optim2) optim_wrapper_dict_save = OptimWrapperDict(**self.optimizers_wrappers) - optim_wrapper_dict_load = OptimWrapperDict( - optim1=optim_wrapper_load1, optim2=optim_wrapper_load2) + optim_wrapper_dict_load = OptimWrapperDict(optim1=optim_wrapper_load1, + optim2=optim_wrapper_load2) state_dict = optim_wrapper_dict_save.state_dict() optim_wrapper_dict_load.load_state_dict(state_dict) @@ -121,21 +121,18 @@ def test_load_state_dict(self): def test_items(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - self.assertListEqual( - list(optim_wrapper_dict.items()), - list(self.optimizers_wrappers.items())) + self.assertListEqual(list(optim_wrapper_dict.items()), + list(self.optimizers_wrappers.items())) def test_values(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - self.assertListEqual( - list(optim_wrapper_dict.values()), - list(self.optimizers_wrappers.values())) + self.assertListEqual(list(optim_wrapper_dict.values()), + list(self.optimizers_wrappers.values())) def test_keys(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) - self.assertListEqual( - list(optim_wrapper_dict.keys()), - list(self.optimizers_wrappers.keys())) + self.assertListEqual(list(optim_wrapper_dict.keys()), + list(self.optimizers_wrappers.keys())) def test_getitem(self): optim_wrapper_dict = OptimWrapperDict(**self.optimizers_wrappers) diff --git a/tests/test_optim/test_scheduler/test_lr_scheduler.py b/tests/test_optim/test_scheduler/test_lr_scheduler.py index 22787e4709..4d3380b3cf 100644 --- a/tests/test_optim/test_scheduler/test_lr_scheduler.py +++ b/tests/test_optim/test_scheduler/test_lr_scheduler.py @@ -118,8 +118,10 @@ def call_sch_before_optim(): group['initial_lr'] = 0.01 def call_sch_before_optim_resume(): - scheduler = StepLR( - self.optimizer, gamma=0.1, step_size=3, last_step=10) + scheduler = StepLR(self.optimizer, + gamma=0.1, + step_size=3, + last_step=10) scheduler.step() self.optimizer.step() @@ -179,17 +181,16 @@ def test_effective_interval(self): interpolation = [ start_factor + i * (1 - start_factor) / iters for i in range(iters) ] - single_targets = [0.05] * begin + [x * 0.05 - for x in interpolation] + [0.05] * ( - epochs - iters - begin) + single_targets = [0.05] * begin + [ + x * 0.05 for x in interpolation + ] + [0.05] * (epochs - iters - begin) targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = LinearLR( - self.optimizer, - start_factor=start_factor, - begin=begin, - end=begin + iters + 1) + scheduler = LinearLR(self.optimizer, + start_factor=start_factor, + begin=begin, + end=begin + iters + 1) self._test_scheduler_value(scheduler, targets, epochs) def _test_scheduler_value(self, @@ -233,8 +234,10 @@ def test_step_scheduler(self): targets = [ single_targets, [x * 
self.layer2_mult for x in single_targets] ] - scheduler = StepLR( - self.optimizer, gamma=0.1, step_size=3, verbose=True) + scheduler = StepLR(self.optimizer, + gamma=0.1, + step_size=3, + verbose=True) self._test_scheduler_value(scheduler, targets, epochs) def test_multi_step_scheduler(self): @@ -248,8 +251,9 @@ def test_multi_step_scheduler(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = MultiStepLR( - self.optimizer, gamma=0.1, milestones=[2, 5, 9]) + scheduler = MultiStepLR(self.optimizer, + gamma=0.1, + milestones=[2, 5, 9]) self._test_scheduler_value(scheduler, targets, epochs) def test_constant_scheduler(self): @@ -287,13 +291,14 @@ def test_linear_scheduler(self): interpolation = [ start_factor + i * (1 - start_factor) / iters for i in range(iters) ] - single_targets = [x * 0.05 for x in interpolation] + [0.05] * ( - epochs - iters) + single_targets = [x * 0.05 + for x in interpolation] + [0.05] * (epochs - iters) targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = LinearLR( - self.optimizer, start_factor=start_factor, end=iters + 1) + scheduler = LinearLR(self.optimizer, + start_factor=start_factor, + end=iters + 1) self._test_scheduler_value(scheduler, targets, epochs) def test_exp_scheduler(self): @@ -320,8 +325,10 @@ def test_cos_anneal_scheduler(self): self._test_scheduler_value(scheduler, targets, epochs) # Test default `T_max` - scheduler = CosineAnnealingLR( - self.optimizer, begin=5, end=100, eta_min=eta_min) + scheduler = CosineAnnealingLR(self.optimizer, + begin=5, + end=100, + eta_min=eta_min) self.assertEqual(scheduler.T_max, 100 - 5) def test_poly_scheduler(self): @@ -332,32 +339,30 @@ def test_poly_scheduler(self): targets_layer1 = [ min_lr + (0.05 - min_lr) * (1 - i / iters)**power for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + ] + [min_lr] * (epochs - iters) targets_layer2 = [ min_lr + (0.05 * self.layer2_mult - min_lr) * (1 - i / iters)**power for i in range(iters) - ] + [min_lr] * ( - epochs - iters) + ] + [min_lr] * (epochs - iters) targets = [targets_layer1, targets_layer2] - scheduler = PolyLR( - self.optimizer, power=power, eta_min=min_lr, end=iters + 1) + scheduler = PolyLR(self.optimizer, + power=power, + eta_min=min_lr, + end=iters + 1) self._test_scheduler_value(scheduler, targets, epochs=10) def test_cosine_restart_scheduler(self): with self.assertRaises(AssertionError): - CosineRestartLR( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0, - eta_min_ratio=0.1) + CosineRestartLR(self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0, + eta_min_ratio=0.1) with self.assertRaises(AssertionError): - CosineRestartLR( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5, 0.0], - eta_min=0) + CosineRestartLR(self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5, 0.0], + eta_min=0) single_targets = [ 0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271, 0.0086372, 0.0023872, 0.0023872 @@ -365,11 +370,10 @@ def test_cosine_restart_scheduler(self): targets = [ single_targets, [t * self.layer2_mult for t in single_targets] ] - scheduler = CosineRestartLR( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0) + scheduler = CosineRestartLR(self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0) self._test_scheduler_value(scheduler, targets, epochs=10) def test_reduce_on_plateau_scheduler(self): @@ -429,8 +433,10 @@ def _test_value(epochs, targets, 
metrics_list, monitor, rule, factor, cooldown=cooldown, min_value=min_value, ) - self._test_scheduler_value( - scheduler, targets, epochs=epochs, step_kwargs=metrics_list) + self._test_scheduler_value(scheduler, + targets, + epochs=epochs, + step_kwargs=metrics_list) # reset the state of optimizers self.optimizer = optim.SGD([{ @@ -559,9 +565,8 @@ def test_step_scheduler_state_dict(self): def test_multi_step_scheduler_state_dict(self): self._check_scheduler_state_dict( lambda: MultiStepLR( - self.optimizer, gamma=0.1, milestones=[2, 5, 9]), - lambda: MultiStepLR( - self.optimizer, gamma=0.01, milestones=[1, 4, 6])) + self.optimizer, gamma=0.1, milestones=[2, 5, 9]), lambda: + MultiStepLR(self.optimizer, gamma=0.01, milestones=[1, 4, 6])) def test_exp_scheduler_state_dict(self): self._check_scheduler_state_dict( @@ -593,52 +598,50 @@ def test_poly_scheduler_state_dict(self): def test_cosine_restart_scheduler_state_dict(self): self._check_scheduler_state_dict( - lambda: CosineRestartLR( - self.optimizer, - periods=[4, 5], - restart_weights=[1, 0.5], - eta_min=0), - lambda: CosineRestartLR( - self.optimizer, - periods=[4, 6], - restart_weights=[1, 0.5], - eta_min=0), + lambda: CosineRestartLR(self.optimizer, + periods=[4, 5], + restart_weights=[1, 0.5], + eta_min=0), + lambda: CosineRestartLR(self.optimizer, + periods=[4, 6], + restart_weights=[1, 0.5], + eta_min=0), epochs=10) def test_reduce_on_plateau_scheduler_state_dict(self): epochs = 10 metrics_list = [dict(metrics=dict(loss=1.0)) for _ in range(epochs)] self._check_scheduler_state_dict( - lambda: ReduceOnPlateauLR( - self.optimizer, - monitor='loss', - rule='less', - factor=0.01, - patience=5, - threshold=1e-4, - threshold_rule='rel', - cooldown=0, - min_value=0.0, - eps=1e-8), - lambda: ReduceOnPlateauLR( - self.optimizer, - monitor='loss_foo', - rule='greater', - factor=0.05, - patience=10, - threshold=1e-5, - threshold_rule='abs', - cooldown=5, - min_value=0.1, - eps=1e-9), + lambda: ReduceOnPlateauLR(self.optimizer, + monitor='loss', + rule='less', + factor=0.01, + patience=5, + threshold=1e-4, + threshold_rule='rel', + cooldown=0, + min_value=0.0, + eps=1e-8), + lambda: ReduceOnPlateauLR(self.optimizer, + monitor='loss_foo', + rule='greater', + factor=0.05, + patience=10, + threshold=1e-5, + threshold_rule='abs', + cooldown=5, + min_value=0.1, + eps=1e-9), epochs=epochs, step_kwargs=metrics_list) def test_step_scheduler_convert_iterbased(self): # invalid epoch_length with self.assertRaises(AssertionError): - scheduler = StepLR.build_iter_from_epoch( - self.optimizer, gamma=0.1, step_size=2, epoch_length=-1) + scheduler = StepLR.build_iter_from_epoch(self.optimizer, + gamma=0.1, + step_size=2, + epoch_length=-1) # lr = 0.05 if epoch < 2 # lr = 0.005 if 2 <= epoch < 4 @@ -648,10 +651,14 @@ def test_step_scheduler_convert_iterbased(self): targets = [ single_targets, [x * self.layer2_mult for x in single_targets] ] - scheduler = StepLR.build_iter_from_epoch( - self.optimizer, gamma=0.1, step_size=2, epoch_length=epoch_length) - self._test_scheduler_value( - scheduler, targets, epochs * epoch_length, param_name='lr') + scheduler = StepLR.build_iter_from_epoch(self.optimizer, + gamma=0.1, + step_size=2, + epoch_length=epoch_length) + self._test_scheduler_value(scheduler, + targets, + epochs * epoch_length, + param_name='lr') def test_multi_step_scheduler_convert_iterbased(self): # lr = 0.05 if epoch < 2 @@ -684,8 +691,10 @@ def test_constant_scheduler_convert_iterbased(self): targets = [ single_targets, [x * self.layer2_mult for x in 
single_targets]
         ]
-        scheduler = ConstantLR.build_iter_from_epoch(
-            self.optimizer, factor=1.0 / 2, end=5, epoch_length=epoch_length)
+        scheduler = ConstantLR.build_iter_from_epoch(self.optimizer,
+                                                     factor=1.0 / 2,
+                                                     end=5,
+                                                     epoch_length=epoch_length)
         self._test_scheduler_value(scheduler, targets, epochs * epoch_length)

     def test_linear_scheduler_convert_iterbased(self):
@@ -698,16 +707,15 @@ def test_linear_scheduler_convert_iterbased(self):
         interpolation = [
             start_factor + i * (1 - start_factor) / iters for i in range(iters)
         ]
-        single_targets = [x * 0.05 for x in interpolation] + [0.05] * (
-            epochs * epoch_length - iters)
+        single_targets = [x * 0.05 for x in interpolation
+                          ] + [0.05] * (epochs * epoch_length - iters)
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler = LinearLR.build_iter_from_epoch(
-            self.optimizer,
-            start_factor=start_factor,
-            end=end,
-            epoch_length=epoch_length)
+        scheduler = LinearLR.build_iter_from_epoch(self.optimizer,
+                                                   start_factor=start_factor,
+                                                   end=end,
+                                                   epoch_length=epoch_length)
         self._test_scheduler_value(scheduler, targets, epochs)

     def test_exp_scheduler_convert_iterbased(self):
@@ -755,20 +763,17 @@ def test_poly_scheduler_convert_iterbased(self):
         targets_layer1 = [
             min_lr + (0.05 - min_lr) * (1 - i / iters)**power
             for i in range(iters)
-        ] + [min_lr] * (
-            epochs - iters)
+        ] + [min_lr] * (epochs - iters)
         targets_layer2 = [
             min_lr + (0.05 * self.layer2_mult - min_lr) *
             (1 - i / iters)**power for i in range(iters)
-        ] + [min_lr] * (
-            epochs - iters)
+        ] + [min_lr] * (epochs - iters)
         targets = [targets_layer1, targets_layer2]
-        scheduler = PolyLR.build_iter_from_epoch(
-            self.optimizer,
-            power=power,
-            eta_min=min_lr,
-            end=end,
-            epoch_length=epoch_length)
+        scheduler = PolyLR.build_iter_from_epoch(self.optimizer,
+                                                 power=power,
+                                                 eta_min=min_lr,
+                                                 end=end,
+                                                 epoch_length=epoch_length)
         self._test_scheduler_value(scheduler, targets, epochs=10)

     def test_multi_scheduler_without_overlap_linear_multi_step(self):
@@ -779,10 +784,15 @@ def test_multi_scheduler_without_overlap_linear_multi_step(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler1 = LinearLR(
-            self.optimizer, start_factor=1 / 2, begin=0, end=5)
-        scheduler2 = MultiStepLR(
-            self.optimizer, gamma=0.1, milestones=[3, 6], begin=5, end=12)
+        scheduler1 = LinearLR(self.optimizer,
+                              start_factor=1 / 2,
+                              begin=0,
+                              end=5)
+        scheduler2 = MultiStepLR(self.optimizer,
+                                 gamma=0.1,
+                                 milestones=[3, 6],
+                                 begin=5,
+                                 end=12)
         self._test_scheduler_value([scheduler1, scheduler2], targets, epochs)

     def test_multi_scheduler_without_overlap_exp_cosine(self):
@@ -800,8 +810,11 @@ def test_multi_scheduler_without_overlap_exp_cosine(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler2 = CosineAnnealingLR(
-            self.optimizer, T_max=5, eta_min=eta_min, begin=5, end=10)
+        scheduler2 = CosineAnnealingLR(self.optimizer,
+                                       T_max=5,
+                                       eta_min=eta_min,
+                                       begin=5,
+                                       end=10)
         self._test_scheduler_value([scheduler1, scheduler2], targets, epochs)

@@ -813,10 +826,13 @@ def test_multi_scheduler_with_overlap(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler1 = LinearLR(
-            self.optimizer, start_factor=1 / 2, begin=0, end=5)
-        scheduler2 = MultiStepLR(
-            self.optimizer, gamma=0.1, milestones=[3, 6, 9])
+        scheduler1 = LinearLR(self.optimizer,
+                              start_factor=1 / 2,
+                              begin=0,
+                              end=5)
+        scheduler2 = MultiStepLR(self.optimizer,
+                                 gamma=0.1,
+                                 milestones=[3, 6,
9])
         self._test_scheduler_value([scheduler1, scheduler2], targets, epochs)

     def test_multi_scheduler_with_gap(self):
@@ -836,32 +852,33 @@ def test_multi_scheduler_with_gap(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler2 = CosineAnnealingLR(
-            self.optimizer, T_max=5, eta_min=eta_min, begin=10, end=15)
+        scheduler2 = CosineAnnealingLR(self.optimizer,
+                                       T_max=5,
+                                       eta_min=eta_min,
+                                       begin=10,
+                                       end=15)
         self._test_scheduler_value([scheduler1, scheduler2], targets, epochs)

     def test_onecycle_lr(self):
         # test linear annealing
         target = [1., 13., 25., 21.5, 18., 14.5, 11., 7.5, 4., 0.5]
-        scheduler = OneCycleLR(
-            self.optimizer,
-            eta_max=25,
-            final_div_factor=2,
-            total_steps=10,
-            anneal_strategy='linear')
+        scheduler = OneCycleLR(self.optimizer,
+                               eta_max=25,
+                               final_div_factor=2,
+                               total_steps=10,
+                               anneal_strategy='linear')
         self._test_scheduler_value(scheduler, [target], 10)

         # test linear annealing three phase
         target = [1., 9., 17., 25., 17., 9., 1., 0.75, 0.5, 0.25]
-        scheduler = OneCycleLR(
-            self.optimizer,
-            eta_max=25,
-            div_factor=25,
-            total_steps=10,
-            anneal_strategy='linear',
-            pct_start=0.4,
-            final_div_factor=4,
-            three_phase=True)
+        scheduler = OneCycleLR(self.optimizer,
+                               eta_max=25,
+                               div_factor=25,
+                               total_steps=10,
+                               anneal_strategy='linear',
+                               pct_start=0.4,
+                               final_div_factor=4,
+                               three_phase=True)
         self._test_scheduler_value(scheduler, [target], 10)

         # test cosine annealing
@@ -878,6 +895,8 @@ def annealing_cos(start, end, pct):
             annealing_cos(25, 0.5, 5 / 7.0),
             annealing_cos(25, 0.5, 6 / 7.0), 0.5
         ]
-        scheduler = OneCycleLR(
-            self.optimizer, eta_max=25, final_div_factor=2, total_steps=10)
+        scheduler = OneCycleLR(self.optimizer,
+                               eta_max=25,
+                               final_div_factor=2,
+                               total_steps=10)
         self._test_scheduler_value(scheduler, [target], 10)
diff --git a/tests/test_optim/test_scheduler/test_momentum_scheduler.py b/tests/test_optim/test_scheduler/test_momentum_scheduler.py
index 60a9713ee2..171d8f0977 100644
--- a/tests/test_optim/test_scheduler/test_momentum_scheduler.py
+++ b/tests/test_optim/test_scheduler/test_momentum_scheduler.py
@@ -104,8 +104,9 @@ def test_resume(self):
             if epoch == 4:
                 break
             scheduler.step()
-        scheduler2 = ExponentialMomentum(
-            self.optimizer, gamma=0.9, last_step=4)
+        scheduler2 = ExponentialMomentum(self.optimizer,
+                                         gamma=0.9,
+                                         last_step=4)
         for epoch in range(6):
             results.append(self.optimizer.param_groups[0]['momentum'])
             scheduler2.step()
@@ -136,8 +137,10 @@ def call_sch_before_optim():
             group['initial_momentum'] = 0.01

         def call_sch_before_optim_resume():
-            scheduler = StepMomentum(
-                self.optimizer, gamma=0.1, step_size=3, last_step=10)
+            scheduler = StepMomentum(self.optimizer,
+                                     gamma=0.1,
+                                     step_size=3,
+                                     last_step=10)
             scheduler.step()
             self.optimizer.step()

@@ -182,8 +185,11 @@ def test_effective_interval(self):
         # check invalid begin end
         with self.assertRaisesRegex(ValueError,
                                     'end should be larger than begin'):
-            StepMomentum(
-                self.optimizer, gamma=0.1, step_size=3, begin=10, end=5)
+            StepMomentum(self.optimizer,
+                         gamma=0.1,
+                         step_size=3,
+                         begin=10,
+                         end=5)

         # momentum = 0.05 if epoch == 0
         # momentum = 0.025 if epoch == 1
@@ -198,17 +204,16 @@ def test_effective_interval(self):
         interpolation = [
             start_factor + i * (1 - start_factor) / iters for i in range(iters)
         ]
-        single_targets = [0.05] * begin + [x * 0.05
-                                           for x in interpolation] + [0.05] * (
-                                               epochs - iters - begin)
+        single_targets = [0.05] * begin + [
+            x * 0.05 for x in interpolation
+        ] + [0.05] * (epochs - iters - begin)
         targets = [
single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler = LinearMomentum(
-            self.optimizer,
-            start_factor=start_factor,
-            begin=begin,
-            end=begin + iters + 1)
+        scheduler = LinearMomentum(self.optimizer,
+                                   start_factor=start_factor,
+                                   begin=begin,
+                                   end=begin + iters + 1)
         self._test_scheduler_value(self.optimizer, scheduler, targets, epochs)

     def _test_scheduler_value(self,
@@ -261,12 +266,16 @@ def test_step_scheduler(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler = StepMomentum(
-            self.optimizer, gamma=0.1, step_size=3, verbose=True)
+        scheduler = StepMomentum(self.optimizer,
+                                 gamma=0.1,
+                                 step_size=3,
+                                 verbose=True)
         self._test_scheduler_value(self.optimizer, scheduler, targets, epochs)

-        scheduler = StepMomentum(
-            self.optimizer_with_betas, gamma=0.1, step_size=3, verbose=True)
+        scheduler = StepMomentum(self.optimizer_with_betas,
+                                 gamma=0.1,
+                                 step_size=3,
+                                 verbose=True)
         self._test_scheduler_value(self.optimizer_with_betas, scheduler,
                                    targets, epochs)

@@ -281,12 +290,14 @@ def test_multi_step_scheduler(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler = MultiStepMomentum(
-            self.optimizer, gamma=0.1, milestones=[2, 5, 9])
+        scheduler = MultiStepMomentum(self.optimizer,
+                                      gamma=0.1,
+                                      milestones=[2, 5, 9])
         self._test_scheduler_value(self.optimizer, scheduler, targets, epochs)

-        scheduler = MultiStepMomentum(
-            self.optimizer_with_betas, gamma=0.1, milestones=[2, 5, 9])
+        scheduler = MultiStepMomentum(self.optimizer_with_betas,
+                                      gamma=0.1,
+                                      milestones=[2, 5, 9])
         self._test_scheduler_value(self.optimizer_with_betas, scheduler,
                                    targets, epochs)

@@ -305,8 +316,9 @@ def test_constant_scheduler(self):
         scheduler = ConstantMomentum(self.optimizer, factor=1.0 / 2, end=5)
         self._test_scheduler_value(self.optimizer, scheduler, targets, epochs)

-        scheduler = ConstantMomentum(
-            self.optimizer_with_betas, factor=1.0 / 2, end=5)
+        scheduler = ConstantMomentum(self.optimizer_with_betas,
+                                     factor=1.0 / 2,
+                                     end=5)
         self._test_scheduler_value(self.optimizer_with_betas, scheduler,
                                    targets, epochs)

@@ -330,19 +342,19 @@ def test_linear_scheduler(self):
         interpolation = [
             start_factor + i * (1 - start_factor) / iters for i in range(iters)
         ]
-        single_targets = [x * 0.05 for x in interpolation] + [0.05] * (
-            epochs - iters)
+        single_targets = [x * 0.05
+                          for x in interpolation] + [0.05] * (epochs - iters)
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler = LinearMomentum(
-            self.optimizer, start_factor=start_factor, end=iters + 1)
+        scheduler = LinearMomentum(self.optimizer,
+                                   start_factor=start_factor,
+                                   end=iters + 1)
         self._test_scheduler_value(self.optimizer, scheduler, targets, epochs)

-        scheduler = LinearMomentum(
-            self.optimizer_with_betas,
-            start_factor=start_factor,
-            end=iters + 1)
+        scheduler = LinearMomentum(self.optimizer_with_betas,
+                                   start_factor=start_factor,
+                                   end=iters + 1)
         self._test_scheduler_value(self.optimizer_with_betas, scheduler,
                                    targets, epochs)

@@ -370,18 +382,22 @@ def test_cos_anneal_scheduler(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler = CosineAnnealingMomentum(
-            self.optimizer, T_max=t, eta_min=eta_min)
+        scheduler = CosineAnnealingMomentum(self.optimizer,
+                                            T_max=t,
+                                            eta_min=eta_min)
         self._test_scheduler_value(self.optimizer, scheduler, targets, epochs)

-        scheduler = CosineAnnealingMomentum(
-            self.optimizer_with_betas, T_max=t, eta_min=eta_min)
+        scheduler =
CosineAnnealingMomentum(self.optimizer_with_betas,
+                                            T_max=t,
+                                            eta_min=eta_min)
         self._test_scheduler_value(self.optimizer_with_betas, scheduler,
                                    targets, epochs)

         # Test default `T_max`
-        scheduler = CosineAnnealingMomentum(
-            self.optimizer, begin=5, end=100, eta_min=eta_min)
+        scheduler = CosineAnnealingMomentum(self.optimizer,
+                                            begin=5,
+                                            end=100,
+                                            eta_min=eta_min)
         self.assertEqual(scheduler.T_max, 100 - 5)

     def test_poly_scheduler(self):
@@ -392,41 +408,42 @@ def test_poly_scheduler(self):
         layer1_targets = [
             min_lr + (0.05 - min_lr) * (1 - i / iters)**power
             for i in range(iters)
-        ] + [min_lr] * (
-            epochs - iters)
+        ] + [min_lr] * (epochs - iters)
         layer2_targets = [
             min_lr + (0.05 * self.layer2_mult - min_lr) *
             (1 - i / iters)**power for i in range(iters)
-        ] + [min_lr] * (
-            epochs - iters)
+        ] + [min_lr] * (epochs - iters)
         targets = [layer1_targets, layer2_targets]
-        scheduler = PolyMomentum(
-            self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
-        self._test_scheduler_value(
-            self.optimizer, scheduler, targets, epochs=10)
-
-        scheduler = PolyMomentum(
-            self.optimizer_with_betas,
-            power=power,
-            eta_min=min_lr,
-            end=iters + 1)
-        self._test_scheduler_value(
-            self.optimizer_with_betas, scheduler, targets, epochs=10)
+        scheduler = PolyMomentum(self.optimizer,
+                                 power=power,
+                                 eta_min=min_lr,
+                                 end=iters + 1)
+        self._test_scheduler_value(self.optimizer,
+                                   scheduler,
+                                   targets,
+                                   epochs=10)
+
+        scheduler = PolyMomentum(self.optimizer_with_betas,
+                                 power=power,
+                                 eta_min=min_lr,
+                                 end=iters + 1)
+        self._test_scheduler_value(self.optimizer_with_betas,
+                                   scheduler,
+                                   targets,
+                                   epochs=10)

     def test_cosine_restart_scheduler(self):
         with self.assertRaises(AssertionError):
-            CosineRestartMomentum(
-                self.optimizer,
-                periods=[4, 5],
-                restart_weights=[1, 0.5],
-                eta_min=0,
-                eta_min_ratio=0.1)
+            CosineRestartMomentum(self.optimizer,
+                                  periods=[4, 5],
+                                  restart_weights=[1, 0.5],
+                                  eta_min=0,
+                                  eta_min_ratio=0.1)
         with self.assertRaises(AssertionError):
-            CosineRestartMomentum(
-                self.optimizer,
-                periods=[4, 5],
-                restart_weights=[1, 0.5, 0.0],
-                eta_min=0)
+            CosineRestartMomentum(self.optimizer,
+                                  periods=[4, 5],
+                                  restart_weights=[1, 0.5, 0.0],
+                                  eta_min=0)
         single_targets = [
             0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712,
             0.01636271, 0.0086372, 0.0023872, 0.0023872
@@ -434,21 +451,23 @@ def test_cosine_restart_scheduler(self):
         ]
         targets = [
             single_targets, [t * self.layer2_mult for t in single_targets]
         ]
-        scheduler = CosineRestartMomentum(
-            self.optimizer,
-            periods=[4, 5],
-            restart_weights=[1, 0.5],
-            eta_min=0)
-        self._test_scheduler_value(
-            self.optimizer, scheduler, targets, epochs=10)
-
-        scheduler = CosineRestartMomentum(
-            self.optimizer_with_betas,
-            periods=[4, 5],
-            restart_weights=[1, 0.5],
-            eta_min=0)
-        self._test_scheduler_value(
-            self.optimizer_with_betas, scheduler, targets, epochs=10)
+        scheduler = CosineRestartMomentum(self.optimizer,
+                                          periods=[4, 5],
+                                          restart_weights=[1, 0.5],
+                                          eta_min=0)
+        self._test_scheduler_value(self.optimizer,
+                                   scheduler,
+                                   targets,
+                                   epochs=10)
+
+        scheduler = CosineRestartMomentum(self.optimizer_with_betas,
+                                          periods=[4, 5],
+                                          restart_weights=[1, 0.5],
+                                          eta_min=0)
+        self._test_scheduler_value(self.optimizer_with_betas,
+                                   scheduler,
+                                   targets,
+                                   epochs=10)

     def test_reduce_on_plateau_scheduler(self):
         # inherit _ParamScheduler but not call super().__init__(),
@@ -474,8 +493,8 @@ def test_reduce_on_plateau_scheduler(self):
             ReduceOnPlateauMomentum(self.optimizer, factor=2.0)
         ReduceOnPlateauMomentum(self.optimizer,
min_value=[0.1, 0.1])
         with self.assertRaises(ValueError):
-            ReduceOnPlateauMomentum(
-                self.optimizer, min_value=[0.1, 0.1, 0.1, 0.1])
+            ReduceOnPlateauMomentum(self.optimizer,
+                                    min_value=[0.1, 0.1, 0.1, 0.1])
         with self.assertRaises(ValueError):
             ReduceOnPlateauMomentum(self.optimizer, threshold=-1.0)
         with self.assertRaises(ValueError):
@@ -512,12 +531,11 @@ def _test_value(epochs, targets, metrics_list, optimizer, monitor,
                 cooldown=cooldown,
                 min_value=min_value,
             )
-            self._test_scheduler_value(
-                optimizer,
-                scheduler,
-                targets,
-                epochs=epochs,
-                step_kwargs=metrics_list)
+            self._test_scheduler_value(optimizer,
+                                       scheduler,
+                                       targets,
+                                       epochs=epochs,
+                                       step_kwargs=metrics_list)

         # reset the state of optimizers
         self.optimizer = optim.SGD([{
@@ -700,44 +718,40 @@ def test_poly_scheduler_state_dict(self):

     def test_cosine_restart_scheduler_state_dict(self):
         self._check_scheduler_state_dict(
-            lambda: CosineRestartMomentum(
-                self.optimizer,
-                periods=[4, 5],
-                restart_weights=[1, 0.5],
-                eta_min=0),
-            lambda: CosineRestartMomentum(
-                self.optimizer,
-                periods=[4, 6],
-                restart_weights=[1, 0.5],
-                eta_min=0),
+            lambda: CosineRestartMomentum(self.optimizer,
+                                          periods=[4, 5],
+                                          restart_weights=[1, 0.5],
+                                          eta_min=0),
+            lambda: CosineRestartMomentum(self.optimizer,
+                                          periods=[4, 6],
+                                          restart_weights=[1, 0.5],
+                                          eta_min=0),
             epochs=10)

     def test_reduce_on_plateau_scheduler_state_dict(self):
         epochs = 10
         metrics_list = [dict(metrics=dict(loss=1.0)) for _ in range(epochs)]
         self._check_scheduler_state_dict(
-            lambda: ReduceOnPlateauMomentum(
-                self.optimizer,
-                monitor='loss',
-                rule='less',
-                factor=0.01,
-                patience=5,
-                threshold=1e-4,
-                threshold_rule='rel',
-                cooldown=0,
-                min_value=0.0,
-                eps=1e-8),
-            lambda: ReduceOnPlateauMomentum(
-                self.optimizer,
-                monitor='loss_foo',
-                rule='greater',
-                factor=0.05,
-                patience=10,
-                threshold=1e-5,
-                threshold_rule='abs',
-                cooldown=5,
-                min_value=0.1,
-                eps=1e-9),
+            lambda: ReduceOnPlateauMomentum(self.optimizer,
+                                            monitor='loss',
+                                            rule='less',
+                                            factor=0.01,
+                                            patience=5,
+                                            threshold=1e-4,
+                                            threshold_rule='rel',
+                                            cooldown=0,
+                                            min_value=0.0,
+                                            eps=1e-8),
+            lambda: ReduceOnPlateauMomentum(self.optimizer,
+                                            monitor='loss_foo',
+                                            rule='greater',
+                                            factor=0.05,
+                                            patience=10,
+                                            threshold=1e-5,
+                                            threshold_rule='abs',
+                                            cooldown=5,
+                                            min_value=0.1,
+                                            eps=1e-9),
             epochs=epochs,
             step_kwargs=metrics_list)

@@ -749,10 +763,15 @@ def test_multi_scheduler_without_overlap_linear_multi_step(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler1 = LinearMomentum(
-            self.optimizer, start_factor=1 / 2, begin=0, end=5)
-        scheduler2 = MultiStepMomentum(
-            self.optimizer, gamma=0.1, milestones=[3, 6], begin=5, end=12)
+        scheduler1 = LinearMomentum(self.optimizer,
+                                    start_factor=1 / 2,
+                                    begin=0,
+                                    end=5)
+        scheduler2 = MultiStepMomentum(self.optimizer,
+                                       gamma=0.1,
+                                       milestones=[3, 6],
+                                       begin=5,
+                                       end=12)
         self._test_scheduler_value(self.optimizer, [scheduler1, scheduler2],
                                    targets, epochs)

@@ -760,8 +779,10 @@ def test_multi_scheduler_without_overlap_exp_cosine(self):
         # use Exp in the first 5 epochs and then use Cosine
         epochs = 10
         single_targets1 = [0.05 * (0.9**x) for x in range(5)]
-        scheduler1 = ExponentialMomentum(
-            self.optimizer, gamma=0.9, begin=0, end=5)
+        scheduler1 = ExponentialMomentum(self.optimizer,
+                                         gamma=0.9,
+                                         begin=0,
+                                         end=5)

         eta_min = 1e-10
         single_targets2 = [
@@ -772,8 +793,11 @@ def test_multi_scheduler_without_overlap_exp_cosine(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in
single_targets]
         ]
-        scheduler2 = CosineAnnealingMomentum(
-            self.optimizer, T_max=5, eta_min=eta_min, begin=5, end=10)
+        scheduler2 = CosineAnnealingMomentum(self.optimizer,
+                                             T_max=5,
+                                             eta_min=eta_min,
+                                             begin=5,
+                                             end=10)
         self._test_scheduler_value(self.optimizer, [scheduler1, scheduler2],
                                    targets, epochs)

@@ -786,10 +810,13 @@ def test_multi_scheduler_with_overlap(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler1 = LinearMomentum(
-            self.optimizer, start_factor=1 / 2, begin=0, end=5)
-        scheduler2 = MultiStepMomentum(
-            self.optimizer, gamma=0.1, milestones=[3, 6, 9])
+        scheduler1 = LinearMomentum(self.optimizer,
+                                    start_factor=1 / 2,
+                                    begin=0,
+                                    end=5)
+        scheduler2 = MultiStepMomentum(self.optimizer,
+                                       gamma=0.1,
+                                       milestones=[3, 6, 9])
         self._test_scheduler_value(self.optimizer, [scheduler1, scheduler2],
                                    targets, epochs)

@@ -798,8 +825,10 @@ def test_multi_scheduler_with_gap(self):
         # no scheduler in the middle 5 epochs
         epochs = 15
         single_targets1 = [0.05 * (0.9**x) for x in range(5)]
-        scheduler1 = ExponentialMomentum(
-            self.optimizer, gamma=0.9, begin=0, end=5)
+        scheduler1 = ExponentialMomentum(self.optimizer,
+                                         gamma=0.9,
+                                         begin=0,
+                                         end=5)

         eta_min = 1e-10
         single_targets2 = [
@@ -811,8 +840,11 @@ def test_multi_scheduler_with_gap(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler2 = CosineAnnealingMomentum(
-            self.optimizer, T_max=5, eta_min=eta_min, begin=10, end=15)
+        scheduler2 = CosineAnnealingMomentum(self.optimizer,
+                                             T_max=5,
+                                             eta_min=eta_min,
+                                             begin=10,
+                                             end=15)
         self._test_scheduler_value(self.optimizer, [scheduler1, scheduler2],
                                    targets, epochs)
diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py
index a13072dc6e..f9e39596b6 100644
--- a/tests/test_optim/test_scheduler/test_param_scheduler.py
+++ b/tests/test_optim/test_scheduler/test_param_scheduler.py
@@ -68,13 +68,15 @@ def test_base_scheduler_step(self):

     def test_invalid_optimizer(self):
         with self.assertRaisesRegex(TypeError, 'should be an Optimizer'):
-            StepParamScheduler(
-                'invalid_optimizer', step_size=1, param_name='lr')
+            StepParamScheduler('invalid_optimizer',
+                               step_size=1,
+                               param_name='lr')

     def test_overwrite_optimzer_step(self):
         # raise warning if the counter in optimizer.step() is overwritten
-        scheduler = ExponentialParamScheduler(
-            self.optimizer, param_name='lr', gamma=0.9)
+        scheduler = ExponentialParamScheduler(self.optimizer,
+                                              param_name='lr',
+                                              gamma=0.9)

         def overwrite_fun():
             pass
@@ -88,18 +90,18 @@ def test_resume(self):
         # test invalid case: optimizer and scheduler are not both resumed
         with self.assertRaisesRegex(KeyError,
                                     "param 'initial_lr' is not specified"):
-            StepParamScheduler(
-                self.optimizer,
-                param_name='lr',
-                gamma=0.1,
-                step_size=3,
-                last_step=10)
+            StepParamScheduler(self.optimizer,
+                               param_name='lr',
+                               gamma=0.1,
+                               step_size=3,
+                               last_step=10)

         # test manually resume with ``last_step`` instead of load_state_dict
         epochs = 10
         targets = [0.05 * (0.9**x) for x in range(epochs)]
-        scheduler = ExponentialParamScheduler(
-            self.optimizer, param_name='lr', gamma=0.9)
+        scheduler = ExponentialParamScheduler(self.optimizer,
+                                              param_name='lr',
+                                              gamma=0.9)

         results = []
         for epoch in range(5):
@@ -111,8 +113,10 @@ def test_resume(self):
             if epoch == 4:
                 break
             scheduler.step()
-        scheduler2 = ExponentialParamScheduler(
-            self.optimizer, param_name='lr', gamma=0.9, last_step=4)
+        scheduler2 =
ExponentialParamScheduler(self.optimizer,
+                                               param_name='lr',
+                                               gamma=0.9,
+                                               last_step=4)
         for epoch in range(6):
             results.append(self.optimizer.param_groups[0]['lr'])
             scheduler2.step()
@@ -130,8 +134,10 @@ def test_scheduler_before_optim_warning(self):
         """Warns if scheduler is used before optimizer."""

         def call_sch_before_optim():
-            scheduler = StepParamScheduler(
-                self.optimizer, param_name='lr', gamma=0.1, step_size=3)
+            scheduler = StepParamScheduler(self.optimizer,
+                                           param_name='lr',
+                                           gamma=0.1,
+                                           step_size=3)
             scheduler.step()
             self.optimizer.step()

@@ -144,12 +150,11 @@ def call_sch_before_optim():
             group['initial_lr'] = 0.01

         def call_sch_before_optim_resume():
-            scheduler = StepParamScheduler(
-                self.optimizer,
-                param_name='lr',
-                gamma=0.1,
-                step_size=3,
-                last_step=10)
+            scheduler = StepParamScheduler(self.optimizer,
+                                           param_name='lr',
+                                           gamma=0.1,
+                                           step_size=3,
+                                           last_step=10)
             scheduler.step()
             self.optimizer.step()

@@ -163,8 +168,10 @@ def test_get_last_value(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler = StepParamScheduler(
-            self.optimizer, param_name='lr', step_size=3, gamma=0.1)
+        scheduler = StepParamScheduler(self.optimizer,
+                                       param_name='lr',
+                                       step_size=3,
+                                       gamma=0.1)
         for epoch in range(epochs):
             result = scheduler.get_last_value()
             if isinstance(scheduler.optimizer, OptimWrapper) \
@@ -184,8 +191,10 @@ def test_get_last_value(self):

     def test_scheduler_step_count(self):
         iteration = 10
-        scheduler = StepParamScheduler(
-            self.optimizer, param_name='lr', gamma=0.1, step_size=3)
+        scheduler = StepParamScheduler(self.optimizer,
+                                       param_name='lr',
+                                       gamma=0.1,
+                                       step_size=3)
         self.assertEqual(scheduler.last_step, 0)
         target = [i + 1 for i in range(iteration)]
         step_counts = []
@@ -199,13 +208,12 @@ def test_effective_interval(self):
         # check invalid begin end
         with self.assertRaisesRegex(ValueError,
                                     'end should be larger than begin'):
-            StepParamScheduler(
-                self.optimizer,
-                param_name='lr',
-                gamma=0.1,
-                step_size=3,
-                begin=10,
-                end=5)
+            StepParamScheduler(self.optimizer,
+                               param_name='lr',
+                               gamma=0.1,
+                               step_size=3,
+                               begin=10,
+                               end=5)

         # lr = 0.05 if epoch == 0
         # lr = 0.025 if epoch == 1
@@ -220,24 +228,24 @@ def test_effective_interval(self):
         interpolation = [
             start_factor + i * (1 - start_factor) / iters for i in range(iters)
         ]
-        single_targets = [0.05] * begin + [x * 0.05
-                                           for x in interpolation] + [0.05] * (
-                                               epochs - iters - begin)
+        single_targets = [0.05] * begin + [
+            x * 0.05 for x in interpolation
+        ] + [0.05] * (epochs - iters - begin)
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler = LinearParamScheduler(
-            self.optimizer,
-            param_name='lr',
-            start_factor=start_factor,
-            begin=begin,
-            end=begin + iters + 1)
+        scheduler = LinearParamScheduler(self.optimizer,
+                                         param_name='lr',
+                                         start_factor=start_factor,
+                                         begin=begin,
+                                         end=begin + iters + 1)
         self._test_scheduler_value(scheduler, targets, epochs)

     def test_param_name(self):
         with self.assertRaises(KeyError):
-            StepParamScheduler(
-                self.optimizer, param_name='invalid_name', step_size=10)
+            StepParamScheduler(self.optimizer,
+                               param_name='invalid_name',
+                               step_size=10)

     def _test_scheduler_value(self,
                               schedulers,
@@ -280,12 +288,11 @@ def test_step_scheduler(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler = StepParamScheduler(
-            self.optimizer,
-            param_name='lr',
-            gamma=0.1,
-            step_size=3,
-            verbose=True)
+        scheduler = StepParamScheduler(self.optimizer,
+                                       param_name='lr',
+
gamma=0.1,
+                                       step_size=3,
+                                       verbose=True)
         self._test_scheduler_value(scheduler, targets, epochs)

         # momentum = 0.01 if epoch < 2
@@ -295,10 +302,14 @@ def test_step_scheduler(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler = StepParamScheduler(
-            self.optimizer, param_name='momentum', gamma=0.1, step_size=2)
-        self._test_scheduler_value(
-            scheduler, targets, epochs, param_name='momentum')
+        scheduler = StepParamScheduler(self.optimizer,
+                                       param_name='momentum',
+                                       gamma=0.1,
+                                       step_size=2)
+        self._test_scheduler_value(scheduler,
+                                   targets,
+                                   epochs,
+                                   param_name='momentum')

     def test_multi_step_scheduler(self):
         # lr = 0.05 if epoch < 2
@@ -311,8 +322,10 @@ def test_multi_step_scheduler(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler = MultiStepParamScheduler(
-            self.optimizer, param_name='lr', gamma=0.1, milestones=[2, 5, 9])
+        scheduler = MultiStepParamScheduler(self.optimizer,
+                                            param_name='lr',
+                                            gamma=0.1,
+                                            milestones=[2, 5, 9])
         self._test_scheduler_value(scheduler, targets, epochs)

     def test_constant_scheduler(self):
@@ -327,23 +340,33 @@ def test_constant_scheduler(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler = ConstantParamScheduler(
-            self.optimizer, param_name='lr', factor=1.0 / 2, end=5)
+        scheduler = ConstantParamScheduler(self.optimizer,
+                                           param_name='lr',
+                                           factor=1.0 / 2,
+                                           end=5)
         self._test_scheduler_value(scheduler, targets, epochs)

     def test_linear_scheduler(self):
         with self.assertRaises(ValueError):
-            LinearParamScheduler(
-                self.optimizer, param_name='lr', start_factor=10, end=900)
+            LinearParamScheduler(self.optimizer,
+                                 param_name='lr',
+                                 start_factor=10,
+                                 end=900)
         with self.assertRaises(ValueError):
-            LinearParamScheduler(
-                self.optimizer, param_name='lr', start_factor=-1, end=900)
+            LinearParamScheduler(self.optimizer,
+                                 param_name='lr',
+                                 start_factor=-1,
+                                 end=900)
         with self.assertRaises(ValueError):
-            LinearParamScheduler(
-                self.optimizer, param_name='lr', end_factor=1.001, end=900)
+            LinearParamScheduler(self.optimizer,
+                                 param_name='lr',
+                                 end_factor=1.001,
+                                 end=900)
         with self.assertRaises(ValueError):
-            LinearParamScheduler(
-                self.optimizer, param_name='lr', end_factor=-0.00001, end=900)
+            LinearParamScheduler(self.optimizer,
+                                 param_name='lr',
+                                 end_factor=-0.00001,
+                                 end=900)
         # lr = 0.025 if epoch == 0
         # lr = 0.03125 if epoch == 1
         # lr = 0.0375 if epoch == 2
@@ -355,16 +378,15 @@ def test_linear_scheduler(self):
         interpolation = [
             start_factor + i * (1 - start_factor) / iters for i in range(iters)
         ]
-        single_targets = [x * 0.05 for x in interpolation] + [0.05] * (
-            epochs - iters)
+        single_targets = [x * 0.05
+                          for x in interpolation] + [0.05] * (epochs - iters)
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler = LinearParamScheduler(
-            self.optimizer,
-            param_name='lr',
-            start_factor=start_factor,
-            end=iters + 1)
+        scheduler = LinearParamScheduler(self.optimizer,
+                                         param_name='lr',
+                                         start_factor=start_factor,
+                                         end=iters + 1)
         self._test_scheduler_value(scheduler, targets, epochs)

     def test_exp_scheduler(self):
@@ -373,18 +395,18 @@ def test_exp_scheduler(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler = ExponentialParamScheduler(
-            self.optimizer, param_name='lr', gamma=0.9)
+        scheduler = ExponentialParamScheduler(self.optimizer,
+                                              param_name='lr',
+                                              gamma=0.9)
         self._test_scheduler_value(scheduler, targets, epochs)

     def
test_cos_anneal_scheduler(self):
         with self.assertRaises(AssertionError):
-            CosineAnnealingParamScheduler(
-                self.optimizer,
-                param_name='lr',
-                T_max=10,
-                eta_min=0,
-                eta_min_ratio=0.1)
+            CosineAnnealingParamScheduler(self.optimizer,
+                                          param_name='lr',
+                                          T_max=10,
+                                          eta_min=0,
+                                          eta_min_ratio=0.1)
         epochs = 12
         t = 10
         eta_min = 5e-3
@@ -397,8 +419,10 @@ def test_cos_anneal_scheduler(self):
         targets1 = [
             eta_min + (0.05 - eta_min) *
             (1 + math.cos(math.pi * x / t)) / 2 for x in range(epochs)
         ]
         targets2 = [
             eta_min + (0.05 * self.layer2_mult - eta_min) *
             (1 + math.cos(math.pi * x / t)) / 2
             for x in range(epochs)
         ]
         targets = [targets1, targets2]
-        scheduler = CosineAnnealingParamScheduler(
-            self.optimizer, param_name='lr', T_max=t, eta_min=eta_min)
+        scheduler = CosineAnnealingParamScheduler(self.optimizer,
+                                                  param_name='lr',
+                                                  T_max=t,
+                                                  eta_min=eta_min)
         self._test_scheduler_value(scheduler, targets, epochs)

         # Test `eta_min_ratio`
@@ -413,16 +437,18 @@ def test_cos_anneal_scheduler(self):
             (1 + math.cos(math.pi * x / t)) / 2 for x in range(epochs)
         ]
         targets = [targets1, targets2]
-        scheduler = CosineAnnealingParamScheduler(
-            self.optimizer,
-            param_name='lr',
-            T_max=t,
-            eta_min_ratio=eta_min_ratio)
+        scheduler = CosineAnnealingParamScheduler(self.optimizer,
+                                                  param_name='lr',
+                                                  T_max=t,
+                                                  eta_min_ratio=eta_min_ratio)
         self._test_scheduler_value(scheduler, targets, epochs)

         # Test default `T_max`
-        scheduler = CosineAnnealingParamScheduler(
-            self.optimizer, param_name='lr', begin=5, end=100, eta_min=eta_min)
+        scheduler = CosineAnnealingParamScheduler(self.optimizer,
+                                                  param_name='lr',
+                                                  begin=5,
+                                                  end=100,
+                                                  eta_min=eta_min)
         self.assertEqual(scheduler.T_max, 100 - 5)

     def test_poly_scheduler(self):
@@ -433,38 +459,33 @@ def test_poly_scheduler(self):
         targets_layer1 = [
             min_lr + (0.05 - min_lr) * (1 - i / iters)**power
             for i in range(iters)
-        ] + [min_lr] * (
-            epochs - iters)
+        ] + [min_lr] * (epochs - iters)
         targets_layer2 = [
             min_lr + (0.05 * self.layer2_mult - min_lr) *
             (1 - i / iters)**power for i in range(iters)
-        ] + [min_lr] * (
-            epochs - iters)
+        ] + [min_lr] * (epochs - iters)
         targets = [targets_layer1, targets_layer2]
-        scheduler = PolyParamScheduler(
-            self.optimizer,
-            param_name='lr',
-            power=power,
-            eta_min=min_lr,
-            end=iters + 1)
+        scheduler = PolyParamScheduler(self.optimizer,
+                                       param_name='lr',
+                                       power=power,
+                                       eta_min=min_lr,
+                                       end=iters + 1)
         self._test_scheduler_value(scheduler, targets, epochs=10)

     def test_cosine_restart_scheduler(self):
         with self.assertRaises(AssertionError):
-            CosineRestartParamScheduler(
-                self.optimizer,
-                param_name='lr',
-                periods=[4, 5],
-                restart_weights=[1, 0.5],
-                eta_min=0,
-                eta_min_ratio=0.1)
+            CosineRestartParamScheduler(self.optimizer,
+                                        param_name='lr',
+                                        periods=[4, 5],
+                                        restart_weights=[1, 0.5],
+                                        eta_min=0,
+                                        eta_min_ratio=0.1)
         with self.assertRaises(AssertionError):
-            CosineRestartParamScheduler(
-                self.optimizer,
-                param_name='lr',
-                periods=[4, 5],
-                restart_weights=[1, 0.5, 0.0],
-                eta_min=0)
+            CosineRestartParamScheduler(self.optimizer,
+                                        param_name='lr',
+                                        periods=[4, 5],
+                                        restart_weights=[1, 0.5, 0.0],
+                                        eta_min=0)
         single_targets = [
             0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712,
             0.01636271, 0.0086372, 0.0023872, 0.0023872
@@ -474,12 +495,11 @@ def test_cosine_restart_scheduler(self):
         ]
         targets = [
             single_targets, [t * self.layer2_mult for t in single_targets]
         ]

         # Test with non-zero eta-min.
-        scheduler = CosineRestartParamScheduler(
-            self.optimizer,
-            param_name='lr',
-            periods=[4, 5],
-            restart_weights=[1, 0.5],
-            eta_min=0)
+        scheduler = CosineRestartParamScheduler(self.optimizer,
+                                                param_name='lr',
+                                                periods=[4, 5],
+                                                restart_weights=[1, 0.5],
+                                                eta_min=0)
         self._test_scheduler_value(scheduler, targets, epochs=10)

         epochs = 10
@@ -494,12 +514,11 @@ def test_cosine_restart_scheduler(self):
             for x in range(epochs)
         ]
         targets = [targets1, targets2]
-        scheduler = CosineRestartParamScheduler(
-            self.optimizer,
-            param_name='lr',
-            periods=[t],
-            restart_weights=[1],
-            eta_min=eta_min)
+        scheduler = CosineRestartParamScheduler(self.optimizer,
+                                                param_name='lr',
+                                                periods=[t],
+                                                restart_weights=[1],
+                                                eta_min=eta_min)
         self._test_scheduler_value(scheduler, targets, epochs=10)

     def test_reduce_on_plateau_scheduler(self):
@@ -510,34 +529,41 @@ def test_reduce_on_plateau_scheduler(self):
         with self.assertRaises(TypeError):
             ReduceOnPlateauParamScheduler('invalid_optimizer', param_name='lr')
         with self.assertRaises(ValueError):
-            ReduceOnPlateauParamScheduler(
-                self.optimizer, 'lr', begin=10, end=5)
+            ReduceOnPlateauParamScheduler(self.optimizer,
+                                          'lr',
+                                          begin=10,
+                                          end=5)
         with self.assertRaises(AssertionError):
             ReduceOnPlateauParamScheduler(self.optimizer, 'lr', by_epoch=False)
         for last_step in (1.5, -2):
             with self.assertRaises(AssertionError):
-                ReduceOnPlateauParamScheduler(
-                    self.optimizer, 'lr', last_step=last_step)
+                ReduceOnPlateauParamScheduler(self.optimizer,
+                                              'lr',
+                                              last_step=last_step)
         with self.assertRaises(ValueError):
             ReduceOnPlateauParamScheduler(self.optimizer, 'lr', factor=2.0)
-        ReduceOnPlateauParamScheduler(
-            self.optimizer, 'lr', min_value=[0.1, 0.1])
+        ReduceOnPlateauParamScheduler(self.optimizer,
+                                      'lr',
+                                      min_value=[0.1, 0.1])
         with self.assertRaises(ValueError):
-            ReduceOnPlateauParamScheduler(
-                self.optimizer, 'lr', min_value=[0.1, 0.1, 0.1, 0.1])
+            ReduceOnPlateauParamScheduler(self.optimizer,
+                                          'lr',
+                                          min_value=[0.1, 0.1, 0.1, 0.1])
         with self.assertRaises(ValueError):
             ReduceOnPlateauParamScheduler(self.optimizer, 'lr', threshold=-1.0)
         with self.assertRaises(ValueError):
             ReduceOnPlateauParamScheduler(self.optimizer, 'lr', rule='foo')
         with self.assertRaises(ValueError):
-            ReduceOnPlateauParamScheduler(
-                self.optimizer, 'lr', threshold_rule='foo')
+            ReduceOnPlateauParamScheduler(self.optimizer,
+                                          'lr',
+                                          threshold_rule='foo')

         # Test error in step method
-        scheduler = ReduceOnPlateauParamScheduler(
-            self.optimizer, param_name='lr', monitor='loss')
+        scheduler = ReduceOnPlateauParamScheduler(self.optimizer,
+                                                  param_name='lr',
+                                                  monitor='loss')
         assert scheduler.step() is None

         with self.assertRaises(TypeError):
@@ -566,8 +592,10 @@ def _test_value(epochs, targets, metrics_list, monitor, rule, factor,
                 cooldown=cooldown,
                 min_value=min_value,
             )
-            self._test_scheduler_value(
-                scheduler, targets, epochs=epochs, step_kwargs=metrics_list)
+            self._test_scheduler_value(scheduler,
+                                       targets,
+                                       epochs=epochs,
+                                       step_kwargs=metrics_list)

         # reset the state of optimizers
         self.optimizer = optim.SGD(
@@ -703,15 +731,14 @@ def test_step_scheduler_state_dict(self):

     def test_multi_step_scheduler_state_dict(self):
         self._check_scheduler_state_dict(
-            lambda: MultiStepParamScheduler(
-                self.optimizer,
-                param_name='lr',
-                gamma=0.1,
-                milestones=[2, 5, 9]), lambda: MultiStepParamScheduler(
-                    self.optimizer,
-                    param_name='lr',
-                    gamma=0.01,
-                    milestones=[1, 4, 6]))
+            lambda: MultiStepParamScheduler(self.optimizer,
+                                            param_name='lr',
+                                            gamma=0.1,
+                                            milestones=[2, 5, 9]),
+            lambda:
MultiStepParamScheduler(self.optimizer,
+                                    param_name='lr',
+                                    gamma=0.01,
+                                    milestones=[1, 4, 6]))

     def test_exp_scheduler_state_dict(self):
         self._check_scheduler_state_dict(
@@ -723,27 +750,24 @@ def test_exp_scheduler_state_dict(self):
     def test_cosine_scheduler_state_dict(self):
         epochs = 10
         eta_min = 1e-10
-        self._check_scheduler_state_dict(
-            lambda: CosineAnnealingParamScheduler(
-                self.optimizer, param_name='lr', T_max=epochs, eta_min=eta_min
-            ),
-            lambda: CosineAnnealingParamScheduler(
-                self.optimizer,
-                param_name='lr',
-                T_max=epochs // 2,
-                eta_min=eta_min / 2),
-            epochs=epochs)
+        self._check_scheduler_state_dict(lambda: CosineAnnealingParamScheduler(
+            self.optimizer, param_name='lr', T_max=epochs, eta_min=eta_min),
+                                         lambda: CosineAnnealingParamScheduler(
+                                             self.optimizer,
+                                             param_name='lr',
+                                             T_max=epochs // 2,
+                                             eta_min=eta_min / 2),
+                                         epochs=epochs)

     def test_linear_scheduler_state_dict(self):
         epochs = 10
         self._check_scheduler_state_dict(
             lambda: LinearParamScheduler(
                 self.optimizer, param_name='lr', start_factor=1 / 3),
-            lambda: LinearParamScheduler(
-                self.optimizer,
-                param_name='lr',
-                start_factor=0,
-                end_factor=0.3),
+            lambda: LinearParamScheduler(self.optimizer,
+                                         param_name='lr',
+                                         start_factor=0,
+                                         end_factor=0.3),
             epochs=epochs)

     def test_poly_scheduler_state_dict(self):
@@ -756,48 +780,44 @@ def test_poly_scheduler_state_dict(self):

     def test_cosine_restart_scheduler_state_dict(self):
         self._check_scheduler_state_dict(
-            lambda: CosineRestartParamScheduler(
-                self.optimizer,
-                param_name='lr',
-                periods=[4, 5],
-                restart_weights=[1, 0.5],
-                eta_min=0),
-            lambda: CosineRestartParamScheduler(
-                self.optimizer,
-                param_name='lr',
-                periods=[4, 6],
-                restart_weights=[1, 0.5],
-                eta_min=0),
+            lambda: CosineRestartParamScheduler(self.optimizer,
+                                                param_name='lr',
+                                                periods=[4, 5],
+                                                restart_weights=[1, 0.5],
+                                                eta_min=0),
+            lambda: CosineRestartParamScheduler(self.optimizer,
+                                                param_name='lr',
+                                                periods=[4, 6],
+                                                restart_weights=[1, 0.5],
+                                                eta_min=0),
             epochs=10)

     def test_reduce_on_plateau_scheduler_state_dict(self):
         epochs = 10
         metrics_list = [dict(metrics=dict(loss=1.0)) for _ in range(epochs)]
         self._check_scheduler_state_dict(
-            lambda: ReduceOnPlateauParamScheduler(
-                self.optimizer,
-                param_name='lr',
-                monitor='loss',
-                rule='less',
-                factor=0.01,
-                patience=5,
-                threshold=1e-4,
-                threshold_rule='rel',
-                cooldown=0,
-                min_value=0.0,
-                eps=1e-8),
-            lambda: ReduceOnPlateauParamScheduler(
-                self.optimizer,
-                param_name='lr',
-                monitor='loss_foo',
-                rule='greater',
-                factor=0.05,
-                patience=10,
-                threshold=1e-5,
-                threshold_rule='abs',
-                cooldown=5,
-                min_value=0.1,
-                eps=1e-9),
+            lambda: ReduceOnPlateauParamScheduler(self.optimizer,
+                                                  param_name='lr',
+                                                  monitor='loss',
+                                                  rule='less',
+                                                  factor=0.01,
+                                                  patience=5,
+                                                  threshold=1e-4,
+                                                  threshold_rule='rel',
+                                                  cooldown=0,
+                                                  min_value=0.0,
+                                                  eps=1e-8),
+            lambda: ReduceOnPlateauParamScheduler(self.optimizer,
+                                                  param_name='lr',
+                                                  monitor='loss_foo',
+                                                  rule='greater',
+                                                  factor=0.05,
+                                                  patience=10,
+                                                  threshold=1e-5,
+                                                  threshold_rule='abs',
+                                                  cooldown=5,
+                                                  min_value=0.1,
+                                                  eps=1e-9),
             epochs=epochs,
             step_kwargs=metrics_list)

@@ -825,8 +845,10 @@ def test_step_scheduler_convert_iterbased(self):
             gamma=0.1,
             step_size=2,
             epoch_length=epoch_length)
-        self._test_scheduler_value(
-            scheduler, targets, epochs * epoch_length, param_name='momentum')
+        self._test_scheduler_value(scheduler,
+                                   targets,
+                                   epochs * epoch_length,
+                                   param_name='momentum')

     def test_multi_step_scheduler_convert_iterbased(self):
         # lr = 0.05 if epoch < 2
@@ -878,8
+900,8 @@ def test_linear_scheduler_convert_iterbased(self):
         interpolation = [
             start_factor + i * (1 - start_factor) / iters for i in range(iters)
         ]
-        single_targets = [x * 0.05 for x in interpolation] + [0.05] * (
-            epochs * epoch_length - iters)
+        single_targets = [x * 0.05 for x in interpolation
+                          ] + [0.05] * (epochs * epoch_length - iters)
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
@@ -940,13 +962,11 @@ def test_poly_scheduler_convert_iterbased(self):
         targets_layer1 = [
             min_lr + (0.05 - min_lr) * (1 - i / iters)**power
             for i in range(iters)
-        ] + [min_lr] * (
-            epochs - iters)
+        ] + [min_lr] * (epochs - iters)
         targets_layer2 = [
             min_lr + (0.05 * self.layer2_mult - min_lr) *
             (1 - i / iters)**power for i in range(iters)
-        ] + [min_lr] * (
-            epochs - iters)
+        ] + [min_lr] * (epochs - iters)
         targets = [targets_layer1, targets_layer2]
         scheduler = PolyParamScheduler.build_iter_from_epoch(
             self.optimizer,
@@ -965,27 +985,28 @@ def test_multi_scheduler_without_overlap_linear_multi_step(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler1 = LinearParamScheduler(
-            self.optimizer,
-            param_name='lr',
-            start_factor=1 / 2,
-            begin=0,
-            end=5)
-        scheduler2 = MultiStepParamScheduler(
-            self.optimizer,
-            param_name='lr',
-            gamma=0.1,
-            milestones=[3, 6],
-            begin=5,
-            end=12)
+        scheduler1 = LinearParamScheduler(self.optimizer,
+                                          param_name='lr',
+                                          start_factor=1 / 2,
+                                          begin=0,
+                                          end=5)
+        scheduler2 = MultiStepParamScheduler(self.optimizer,
+                                             param_name='lr',
+                                             gamma=0.1,
+                                             milestones=[3, 6],
+                                             begin=5,
+                                             end=12)
         self._test_scheduler_value([scheduler1, scheduler2], targets, epochs)

     def test_multi_scheduler_without_overlap_exp_cosine(self):
         # use Exp in the first 5 epochs and then use Cosine
         epochs = 10
         single_targets1 = [0.05 * (0.9**x) for x in range(5)]
-        scheduler1 = ExponentialParamScheduler(
-            self.optimizer, param_name='lr',
gamma=0.9, begin=0, end=5)
+        scheduler1 = ExponentialParamScheduler(self.optimizer,
+                                               param_name='lr',
+                                               gamma=0.9,
+                                               begin=0,
+                                               end=5)

         eta_min = 1e-10
         single_targets2 = [
@@ -996,13 +1017,12 @@ def test_multi_scheduler_without_overlap_exp_cosine(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler2 = CosineAnnealingParamScheduler(
-            self.optimizer,
-            param_name='lr',
-            T_max=5,
-            eta_min=eta_min,
-            begin=5,
-            end=10)
+        scheduler2 = CosineAnnealingParamScheduler(self.optimizer,
+                                                   param_name='lr',
+                                                   T_max=5,
+                                                   eta_min=eta_min,
+                                                   begin=5,
+                                                   end=10)
         self._test_scheduler_value([scheduler1, scheduler2], targets, epochs)

@@ -1014,14 +1034,15 @@ def test_multi_scheduler_with_overlap(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler1 = LinearParamScheduler(
-            self.optimizer,
-            param_name='lr',
-            start_factor=1 / 2,
-            begin=0,
-            end=5)
-        scheduler2 = MultiStepParamScheduler(
-            self.optimizer, param_name='lr', gamma=0.1, milestones=[3, 6, 9])
+        scheduler1 = LinearParamScheduler(self.optimizer,
+                                          param_name='lr',
+                                          start_factor=1 / 2,
+                                          begin=0,
+                                          end=5)
+        scheduler2 = MultiStepParamScheduler(self.optimizer,
+                                             param_name='lr',
+                                             gamma=0.1,
+                                             milestones=[3, 6, 9])
         self._test_scheduler_value([scheduler1, scheduler2], targets, epochs)

     def test_multi_scheduler_with_gap(self):
@@ -1029,8 +1050,11 @@ def test_multi_scheduler_with_gap(self):
         epochs = 15
         single_targets1 = [0.05 * (0.9**x) for x in range(5)]
-        scheduler1 = ExponentialParamScheduler(
-            self.optimizer, param_name='lr', gamma=0.9, begin=0, end=5)
+        scheduler1 = ExponentialParamScheduler(self.optimizer,
+                                               param_name='lr',
+                                               gamma=0.9,
+                                               begin=0,
+                                               end=5)

         eta_min = 1e-10
         single_targets2 = [
@@ -1042,32 +1066,33 @@ def test_multi_scheduler_with_gap(self):
         targets = [
             single_targets, [x * self.layer2_mult for x in single_targets]
         ]
-        scheduler2 = CosineAnnealingParamScheduler(
-            self.optimizer,
-            param_name='lr',
-            T_max=5,
-            eta_min=eta_min,
-            begin=10,
-            end=15)
+        scheduler2 = CosineAnnealingParamScheduler(self.optimizer,
+                                                   param_name='lr',
+                                                   T_max=5,
+                                                   eta_min=eta_min,
+                                                   begin=10,
+                                                   end=15)
         self._test_scheduler_value([scheduler1, scheduler2], targets, epochs)

     def test_onecycle_scheduler(self):
         # test invalid total steps
         with self.assertRaises(ValueError):
-            OneCycleParamScheduler(
-                self.optimizer, param_name='lr', total_steps=-1)
+            OneCycleParamScheduler(self.optimizer,
+                                   param_name='lr',
+                                   total_steps=-1)
         # test invalid pct_start
         with self.assertRaises(ValueError):
-            OneCycleParamScheduler(
-                self.optimizer, param_name='lr', total_steps=10, pct_start=-1)
+            OneCycleParamScheduler(self.optimizer,
+                                   param_name='lr',
+                                   total_steps=10,
+                                   pct_start=-1)
         # test invalid anneal_strategy
         with self.assertRaises(ValueError):
-            OneCycleParamScheduler(
-                self.optimizer,
-                param_name='lr',
-                total_steps=10,
-                anneal_strategy='a')
+            OneCycleParamScheduler(self.optimizer,
+                                   param_name='lr',
+                                   total_steps=10,
+                                   anneal_strategy='a')


 class TestParameterSchedulerOptimWrapper(TestParameterScheduler):
diff --git a/tests/test_registry/test_build_functions.py b/tests/test_registry/test_build_functions.py
index 80094ae107..b570b89fa8 100644
--- a/tests/test_registry/test_build_functions.py
+++ b/tests/test_registry/test_build_functions.py
@@ -90,19 +90,22 @@ def __init__(self, depth, stages=4):
     # cfg or default_args should contain the key "type"
     with pytest.raises(KeyError, match='must contain the key "type"'):
         cfg = cfg_type(dict(depth=50))
-        model = build_from_cfg(
-            cfg, BACKBONES, default_args=cfg_type(dict(stages=4)))
+        model = build_from_cfg(cfg,
+                               BACKBONES,
+                               default_args=cfg_type(dict(stages=4)))

     # "type" defined using default_args
     cfg = cfg_type(dict(depth=50))
-    model = build_from_cfg(
-        cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet')))
+    model = build_from_cfg(cfg,
+                           BACKBONES,
+                           default_args=cfg_type(dict(type='ResNet')))
     assert isinstance(model, ResNet)
     assert model.depth == 50 and model.stages == 4

     cfg = cfg_type(dict(depth=50))
-    model = build_from_cfg(
-        cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet)))
+    model = build_from_cfg(cfg,
+                           BACKBONES,
+                           default_args=cfg_type(dict(type=ResNet)))
     assert isinstance(model, ResNet)
     assert model.depth == 50 and model.stages == 4

@@ -197,24 +200,22 @@ def test_build_scheduler_from_cfg():
     from torch.optim import SGD
     model = nn.Conv2d(1, 1, 1)
     optimizer = SGD(model.parameters(), lr=0.1)
-    cfg = dict(
-        type='LinearParamScheduler',
-        optimizer=optimizer,
-        param_name='lr',
-        begin=0,
-        end=100)
+    cfg = dict(type='LinearParamScheduler',
+               optimizer=optimizer,
+               param_name='lr',
+               begin=0,
+               end=100)
     scheduler = PARAM_SCHEDULERS.build(cfg)
     assert scheduler.begin == 0
     assert scheduler.end == 100

-    cfg = dict(
-        type='LinearParamScheduler',
-        convert_to_iter_based=True,
-        optimizer=optimizer,
-        param_name='lr',
-        begin=0,
-        end=100,
-        epoch_length=10)
+    cfg = dict(type='LinearParamScheduler',
+               convert_to_iter_based=True,
+               optimizer=optimizer,
+               param_name='lr',
+               begin=0,
+               end=100,
+               epoch_length=10)
     scheduler = PARAM_SCHEDULERS.build(cfg)
     assert scheduler.begin
== 0
diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py
index eb99b3dc8e..f339d37420 100644
--- a/tests/test_registry/test_registry.py
+++ b/tests/test_registry/test_registry.py
@@ -134,10 +134,9 @@ class BritishShorthair:
         # test `module` parameter, which is either None or a class
         # when the `register_module`` is called as a method rather than a
         # decorator, which must be a class
-        with pytest.raises(
-                TypeError,
-                match='module must be Callable,'
-                " but got <class 'str'>"):
+        with pytest.raises(TypeError,
+                           match='module must be Callable,'
+                           " but got <class 'str'>"):
             CATS.register_module(module='string')

         class SphynxCat:
@@ -183,15 +182,17 @@ def _build_registry(self):
         registries.append(DOGS)
         HOUNDS = Registry('hounds', parent=DOGS, scope='hound')
         registries.append(HOUNDS)
-        LITTLE_HOUNDS = Registry(
-            'little hounds', parent=HOUNDS, scope='little_hound')
+        LITTLE_HOUNDS = Registry('little hounds',
+                                 parent=HOUNDS,
+                                 scope='little_hound')
         registries.append(LITTLE_HOUNDS)
         MID_HOUNDS = Registry('mid hounds', parent=HOUNDS, scope='mid_hound')
         registries.append(MID_HOUNDS)
         SAMOYEDS = Registry('samoyeds', parent=DOGS, scope='samoyed')
         registries.append(SAMOYEDS)
-        LITTLE_SAMOYEDS = Registry(
-            'little samoyeds', parent=SAMOYEDS, scope='little_samoyed')
+        LITTLE_SAMOYEDS = Registry('little samoyeds',
+                                   parent=SAMOYEDS,
+                                   scope='little_samoyed')
         registries.append(LITTLE_SAMOYEDS)

         return registries
@@ -408,14 +409,14 @@ class Beagle:

         # test `default_scope`
         # switch the current registry to another registry
-        DefaultScope.get_instance(
-            f'test-{time.time()}', scope_name='mid_hound')
+        DefaultScope.get_instance(f'test-{time.time()}',
+                                  scope_name='mid_hound')
         dog = LITTLE_HOUNDS.build(b_cfg)
         assert isinstance(dog, Beagle)

         # `default_scope` can not be found
-        DefaultScope.get_instance(
-            f'test2-{time.time()}', scope_name='scope-not-found')
+        DefaultScope.get_instance(f'test2-{time.time()}',
+                                  scope_name='scope-not-found')
         dog = MID_HOUNDS.build(b_cfg)
         assert isinstance(dog, Beagle)

@@ -431,20 +432,18 @@ class YourSamoyed:
             pass

         s_cfg = cfg_type(
-            dict(
-                _scope_='samoyed',
-                type='MySamoyed',
-                friend=dict(type='hound.BloodHound')))
+            dict(_scope_='samoyed',
+                 type='MySamoyed',
+                 friend=dict(type='hound.BloodHound')))
         dog = DOGS.build(s_cfg)
         assert isinstance(dog, MySamoyed)
         assert isinstance(dog.friend, BloodHound)
         assert DefaultScope.get_current_instance().scope_name != 'samoyed'

         s_cfg = cfg_type(
-            dict(
-                _scope_='samoyed',
-                type='MySamoyed',
-                friend=dict(type='YourSamoyed')))
+            dict(_scope_='samoyed',
+                 type='MySamoyed',
+                 friend=dict(type='YourSamoyed')))
         dog = DOGS.build(s_cfg)
         assert isinstance(dog, MySamoyed)
         assert isinstance(dog.friend, YourSamoyed)
@@ -456,9 +455,9 @@ class YourSamoyed:
         lambda_cfg = cfg_type(dict(type='lambda_dog', name='unknown'))
         assert DOGS.build(lambda_cfg) == 'unknown'

-        DOGS.register_module(
-            name='patial dog',
-            module=functools.partial(lambda_dog, name='patial'))
+        DOGS.register_module(name='patial dog',
+                             module=functools.partial(lambda_dog,
+                                                      name='patial'))
         unknown_cfg = cfg_type(dict(type='patial dog'))
         assert DOGS.build(unknown_cfg) == 'patial'

@@ -474,8 +473,8 @@ def test_switch_scope_and_registry(self):
         #       |                  |                  |
         #    HOUNDS (hound)    SAMOYEDS (samoyed)  CHIHUAHUA (chihuahua)

-        DefaultScope.get_instance(
-            f'scope_{time.time()}', scope_name='chihuahua')
+        DefaultScope.get_instance(f'scope_{time.time()}',
+                                  scope_name='chihuahua')
         assert DefaultScope.get_current_instance().scope_name == 'chihuahua'

         # Test switch scope and get target registry.
@@ -597,19 +596,22 @@ def __init__(self, depth, stages=4):
     # cfg or default_args should contain the key "type"
     with pytest.raises(KeyError, match='must contain the key "type"'):
         cfg = cfg_type(dict(depth=50))
-        model = build_from_cfg(
-            cfg, BACKBONES, default_args=cfg_type(dict(stages=4)))
+        model = build_from_cfg(cfg,
+                               BACKBONES,
+                               default_args=cfg_type(dict(stages=4)))

     # "type" defined using default_args
     cfg = cfg_type(dict(depth=50))
-    model = build_from_cfg(
-        cfg, BACKBONES, default_args=cfg_type(dict(type='ResNet')))
+    model = build_from_cfg(cfg,
+                           BACKBONES,
+                           default_args=cfg_type(dict(type='ResNet')))
    assert isinstance(model, ResNet)
     assert model.depth == 50 and model.stages == 4

     cfg = cfg_type(dict(depth=50))
-    model = build_from_cfg(
-        cfg, BACKBONES, default_args=cfg_type(dict(type=ResNet)))
+    model = build_from_cfg(cfg,
+                           BACKBONES,
+                           default_args=cfg_type(dict(type=ResNet)))
     assert isinstance(model, ResNet)
     assert model.depth == 50 and model.stages == 4
diff --git a/tests/test_runner/test_checkpoint.py b/tests/test_runner/test_checkpoint.py
index 4655a4c5da..844dd4d80b 100644
--- a/tests/test_runner/test_checkpoint.py
+++ b/tests/test_runner/test_checkpoint.py
@@ -211,14 +211,16 @@ def __init__(self):

     # add prefix
     torch.save(model.state_dict(), checkpoint_path)
-    state_dict = load_checkpoint(
-        pmodel, checkpoint_path, revise_keys=[(r'^', 'backbone.')])
+    state_dict = load_checkpoint(pmodel,
+                                 checkpoint_path,
+                                 revise_keys=[(r'^', 'backbone.')])
     for key in pmodel.backbone.state_dict().keys():
         assert torch.equal(pmodel.backbone.state_dict()[key],
                            state_dict[key])
     # strip prefix
     torch.save(pmodel.state_dict(), checkpoint_path)
-    state_dict = load_checkpoint(
-        model, checkpoint_path, revise_keys=[(r'^backbone\.', '')])
+    state_dict = load_checkpoint(model,
+                                 checkpoint_path,
+                                 revise_keys=[(r'^backbone\.', '')])

     for key in state_dict.keys():
         key_stripped = re.sub(r'^backbone\.', '', key)
@@ -366,17 +368,19 @@ def test_save_checkpoint(tmp_path):
     save_checkpoint(model.state_dict(), filename)

     filename = str(tmp_path / 'checkpoint2.pth')
-    checkpoint = dict(
-        model=model.state_dict(), optimizer=optimizer.state_dict())
+    checkpoint = dict(model=model.state_dict(),
+                      optimizer=optimizer.state_dict())
     save_checkpoint(checkpoint, filename)

     filename = str(tmp_path / 'checkpoint3.pth')
-    save_checkpoint(
-        model.state_dict(), filename, backend_args={'backend': 'local'})
+    save_checkpoint(model.state_dict(),
+                    filename,
+                    backend_args={'backend': 'local'})

     filename = str(tmp_path / 'checkpoint4.pth')
-    save_checkpoint(
-        model.state_dict(), filename, file_client_args={'backend': 'disk'})
+    save_checkpoint(model.state_dict(),
+                    filename,
+                    file_client_args={'backend': 'disk'})

     # 2.
save to petrel oss
     with patch.object(PetrelBackend, 'put') as mock_method:
@@ -386,10 +390,9 @@ def test_save_checkpoint(tmp_path):

     with patch.object(PetrelBackend, 'put') as mock_method:
         filename = 's3://path//of/your/checkpoint2.pth'
-        save_checkpoint(
-            model.state_dict(),
-            filename,
-            file_client_args={'backend': 'petrel'})
+        save_checkpoint(model.state_dict(),
+                        filename,
+                        file_client_args={'backend': 'petrel'})
     mock_method.assert_called()
diff --git a/tests/test_runner/test_log_processor.py b/tests/test_runner/test_log_processor.py
index d7fae5722a..b48b218c9e 100644
--- a/tests/test_runner/test_log_processor.py
+++ b/tests/test_runner/test_log_processor.py
@@ -16,8 +16,9 @@ class TestLogProcessor(RunnerTestCase):

     def test_init(self):
-        log_processor = LogProcessor(
-            window_size=10, by_epoch=True, custom_cfg=None)
+        log_processor = LogProcessor(window_size=10,
+                                     by_epoch=True,
+                                     custom_cfg=None)
         assert log_processor.by_epoch
         assert log_processor.window_size == 10
         assert log_processor.custom_cfg == []
@@ -81,8 +82,8 @@ def test_parse_windows_size(self):
     # yapf: enable
     def test_get_log_after_iter(self, by_epoch, mode, log_with_hierarchy):
         # Prepare LoggerHook
-        log_processor = LogProcessor(
-            by_epoch=by_epoch, log_with_hierarchy=log_with_hierarchy)
+        log_processor = LogProcessor(by_epoch=by_epoch,
+                                     log_with_hierarchy=log_with_hierarchy)
         log_processor._get_max_memory = MagicMock(return_value='100')
         eta = 40
         self.runner.message_hub.update_info('eta', eta)
@@ -157,15 +158,15 @@ def test_get_log_after_iter(self, by_epoch, mode, log_with_hierarchy):
                             [False, 'val', False], [True, 'test', True],
                             [False, 'test', False]))
     def test_log_val(self, by_epoch, mode, log_with_hierarchy):
         # Prepare LoggerHook
-        log_processor = LogProcessor(
-            by_epoch=by_epoch, log_with_hierarchy=log_with_hierarchy)
+        log_processor = LogProcessor(by_epoch=by_epoch,
+                                     log_with_hierarchy=log_with_hierarchy)
         # Prepare validation information.
         scalar_logs = dict(accuracy=0.9, data_time=1.0)
-        non_scalar_logs = dict(
-            recall={
-                'cat': 1,
-                'dog': 0
-            }, cm=torch.tensor([1, 2, 3]))
+        non_scalar_logs = dict(recall={
+            'cat': 1,
+            'dog': 0
+        },
+                               cm=torch.tensor([1, 2, 3]))
         log_processor._collect_scalars = MagicMock(return_value=scalar_logs)
         log_processor._collect_non_scalars = MagicMock(
             return_value=non_scalar_logs)
@@ -207,8 +208,9 @@ def test_collect_scalars(self):
             'val/metric': history_metric_buffer
         }
         self.runner.message_hub._log_scalars = log_scalars
-        tag = log_processor._collect_scalars(
-            copy.deepcopy(custom_cfg), self.runner, mode='train')
+        tag = log_processor._collect_scalars(copy.deepcopy(custom_cfg),
+                                             self.runner,
+                                             mode='train')
         # Training key in tag.
assert list(tag.keys()) == ['time', 'loss_cls', 'time_max']
         # Test statistics lr with `current`, loss and time with 'mean'
@@ -217,17 +219,17 @@ def test_collect_scalars(self):
         assert tag['loss_cls'] == loss_cls_scalars[-10:].mean()

         # Validation key in tag
-        tag = log_processor._collect_scalars(
-            copy.deepcopy(custom_cfg), self.runner, mode='val')
+        tag = log_processor._collect_scalars(copy.deepcopy(custom_cfg),
+                                             self.runner,
+                                             mode='val')
         assert list(tag.keys()) == ['metric']
         assert tag['metric'] == metric_scalars[-1]

         # reserve_prefix=True
-        tag = log_processor._collect_scalars(
-            copy.deepcopy(custom_cfg),
-            self.runner,
-            mode='train',
-            reserve_prefix=True)
+        tag = log_processor._collect_scalars(copy.deepcopy(custom_cfg),
+                                             self.runner,
+                                             mode='train',
+                                             reserve_prefix=True)
         assert list(
             tag.keys()) == ['train/time', 'train/loss_cls', 'train/time_max']
         # Test statistics lr with `current`, loss and time with 'mean'
@@ -315,31 +317,27 @@ def setUp(self):

     def test_with_runner(self):
         cfg = self.epoch_based_cfg.copy()
-        cfg.log_processor = dict(
-            custom_cfg=[
-                dict(
-                    data_src='time',
-                    window_size='epoch',
-                    log_name='iter_time',
-                    method_name='mean')
-            ],
-            log_with_hierarchy=True)
+        cfg.log_processor = dict(custom_cfg=[
+            dict(data_src='time',
+                 window_size='epoch',
+                 log_name='iter_time',
+                 method_name='mean')
+        ],
+                                 log_with_hierarchy=True)
         runner = self.build_runner(cfg)
         runner.train()
         runner.val()
         runner.test()

         cfg = self.iter_based_cfg.copy()
-        cfg.log_processor = dict(
-            by_epoch=False,
-            custom_cfg=[
-                dict(
-                    data_src='time',
-                    window_size=100,
-                    log_name='iter_time',
-                    method_name='mean')
-            ],
-            log_with_hierarchy=True)
+        cfg.log_processor = dict(by_epoch=False,
+                                 custom_cfg=[
+                                     dict(data_src='time',
+                                          window_size=100,
+                                          log_name='iter_time',
+                                          method_name='mean')
+                                 ],
+                                 log_with_hierarchy=True)
         runner = self.build_runner(cfg)
         runner.train()
         runner.val()
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index e7668054bb..b4801710c1 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -389,40 +389,42 @@ def setUp(self):
         epoch_based_cfg = dict(
             model=dict(type='ToyModel'),
             work_dir=self.temp_dir,
-            train_dataloader=dict(
-                dataset=dict(type='ToyDataset'),
-                sampler=dict(type='DefaultSampler', shuffle=True),
-                batch_size=3,
-                num_workers=0),
-            val_dataloader=dict(
-                dataset=dict(type='ToyDataset'),
-                sampler=dict(type='DefaultSampler', shuffle=False),
-                batch_size=3,
-                num_workers=0),
-            test_dataloader=dict(
-                dataset=dict(type='ToyDataset'),
-                sampler=dict(type='DefaultSampler', shuffle=False),
-                batch_size=3,
-                num_workers=0),
+            train_dataloader=dict(dataset=dict(type='ToyDataset'),
+                                  sampler=dict(type='DefaultSampler',
+                                               shuffle=True),
+                                  batch_size=3,
+                                  num_workers=0),
+            val_dataloader=dict(dataset=dict(type='ToyDataset'),
+                                sampler=dict(type='DefaultSampler',
+                                             shuffle=False),
+                                batch_size=3,
+                                num_workers=0),
+            test_dataloader=dict(dataset=dict(type='ToyDataset'),
+                                 sampler=dict(type='DefaultSampler',
+                                              shuffle=False),
+                                 batch_size=3,
+                                 num_workers=0),
             auto_scale_lr=dict(base_batch_size=16, enable=False),
-            optim_wrapper=dict(
-                type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)),
+            optim_wrapper=dict(type='OptimWrapper',
+                               optimizer=dict(type='SGD', lr=0.01)),
             param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
             val_evaluator=dict(type='ToyMetric1'),
             test_evaluator=dict(type='ToyMetric1'),
-            train_cfg=dict(
-                by_epoch=True, max_epochs=3, val_interval=1, val_begin=1),
+
train_cfg=dict(by_epoch=True, + max_epochs=3, + val_interval=1, + val_begin=1), val_cfg=dict(), test_cfg=dict(), custom_hooks=[], - default_hooks=dict( - runtime_info=dict(type='RuntimeInfoHook'), - timer=dict(type='IterTimerHook'), - logger=dict(type='LoggerHook'), - param_scheduler=dict(type='ParamSchedulerHook'), - checkpoint=dict( - type='CheckpointHook', interval=1, by_epoch=True), - sampler_seed=dict(type='DistSamplerSeedHook')), + default_hooks=dict(runtime_info=dict(type='RuntimeInfoHook'), + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook'), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', + interval=1, + by_epoch=True), + sampler_seed=dict(type='DistSamplerSeedHook')), data_preprocessor=None, launcher='none', env_cfg=dict(dist_cfg=dict(backend='nccl')), @@ -620,23 +622,25 @@ def test_init(self): train_dataloader = DataLoader(ToyDataset(), collate_fn=collate_fn) val_dataloader = DataLoader(ToyDataset(), collate_fn=collate_fn) test_dataloader = DataLoader(ToyDataset(), collate_fn=collate_fn) - runner = Runner( - model=model, - work_dir=self.temp_dir, - train_cfg=dict( - by_epoch=True, max_epochs=3, val_interval=1, val_begin=1), - train_dataloader=train_dataloader, - optim_wrapper=optim_wrapper, - param_scheduler=MultiStepLR(optim_wrapper, milestones=[1, 2]), - val_cfg=dict(), - val_dataloader=val_dataloader, - val_evaluator=[ToyMetric1()], - test_cfg=dict(), - test_dataloader=test_dataloader, - test_evaluator=[ToyMetric1()], - default_hooks=dict(param_scheduler=toy_hook), - custom_hooks=[toy_hook2], - experiment_name='test_init14') + runner = Runner(model=model, + work_dir=self.temp_dir, + train_cfg=dict(by_epoch=True, + max_epochs=3, + val_interval=1, + val_begin=1), + train_dataloader=train_dataloader, + optim_wrapper=optim_wrapper, + param_scheduler=MultiStepLR(optim_wrapper, + milestones=[1, 2]), + val_cfg=dict(), + val_dataloader=val_dataloader, + val_evaluator=[ToyMetric1()], + test_cfg=dict(), + test_dataloader=test_dataloader, + test_evaluator=[ToyMetric1()], + default_hooks=dict(param_scheduler=toy_hook), + custom_hooks=[toy_hook2], + experiment_name='test_init14') runner.train() runner.test() @@ -693,8 +697,8 @@ def test_init(self): # 6.6 Test initializing with `_ParameterScheduler`. optimizer = SGD(nn.Linear(1, 1).parameters(), lr=0.1) - cfg.param_scheduler = MultiStepLR( - milestones=[1, 2], optimizer=optimizer) + cfg.param_scheduler = MultiStepLR(milestones=[1, 2], + optimizer=optimizer) cfg.experiment_name = 'test_init22' Runner(**cfg) @@ -706,9 +710,10 @@ def test_init(self): Runner(**cfg) # 6.8 Test initializing with 2 `_ParameterScheduler` for 2 optimizers. 
- cfg.param_scheduler = dict( - linear1=MultiStepLR(milestones=[1, 2], optimizer=optimizer), - linear2=MultiStepLR(milestones=[1, 2], optimizer=optimizer)) + cfg.param_scheduler = dict(linear1=MultiStepLR(milestones=[1, 2], + optimizer=optimizer), + linear2=MultiStepLR(milestones=[1, 2], + optimizer=optimizer)) cfg.experiment_name = 'test_init24' Runner(**cfg) @@ -747,9 +752,8 @@ def test_dump_config(self): temp_config_file = tempfile.NamedTemporaryFile( dir=temp_config_dir, suffix='.py', delete=False) temp_config_file.close() - file_cfg = Config( - self.epoch_based_cfg._cfg_dict, - filename=temp_config_file.name) + file_cfg = Config(self.epoch_based_cfg._cfg_dict, + filename=temp_config_file.name) file_cfg.experiment_name = f'test_dump2{idx}' runner = Runner.from_cfg(cfg=file_cfg) assert osp.exists( @@ -813,9 +817,8 @@ def test_build_visualizer(self): runner.visualizer.instance_name) # input is a Visualizer object - self.assertEqual( - id(runner.build_visualizer(runner.visualizer)), - id(runner.visualizer)) + self.assertEqual(id(runner.build_visualizer(runner.visualizer)), + id(runner.visualizer)) # input is a dict visualizer_cfg = dict(type='Visualizer', name='test_build_visualizer2') @@ -835,8 +838,9 @@ def test_build_visualizer(self): runner.build_visualizer('invalid-type') def test_default_scope(self): - TOY_SCHEDULERS = Registry( - 'parameter scheduler', parent=PARAM_SCHEDULERS, scope='toy') + TOY_SCHEDULERS = Registry('parameter scheduler', + parent=PARAM_SCHEDULERS, + scope='toy') @TOY_SCHEDULERS.register_module(force=True) class ToyScheduler(MultiStepLR): @@ -844,8 +848,8 @@ class ToyScheduler(MultiStepLR): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.epoch_based_cfg.param_scheduler = dict( - type='ToyScheduler', milestones=[1, 2]) + self.epoch_based_cfg.param_scheduler = dict(type='ToyScheduler', + milestones=[1, 2]) self.epoch_based_cfg.default_scope = 'toy' cfg = copy.deepcopy(self.epoch_based_cfg) @@ -1019,20 +1023,21 @@ def test_build_optim_wrapper(self): # "constructor" are not in optimizer optimizer1 = SGD(runner.model.linear1.parameters(), lr=0.01) optim_wrapper1 = OptimWrapper(optimizer1) - optim_wrapper2 = dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.01)) + optim_wrapper2 = dict(type='OptimWrapper', + optimizer=dict(type='Adam', lr=0.01)) optim_cfg = dict(key1=optim_wrapper1, key2=optim_wrapper2) with self.assertRaisesRegex(ValueError, 'each item mush be an optimizer object'): runner.build_optim_wrapper(optim_cfg) # 2.3 input is a dict which contains multiple configs - optim_wrapper_cfg = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), - constructor='ToyMultipleOptimizerConstructor') + optim_wrapper_cfg = dict(linear1=dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=0.01)), + linear2=dict(type='OptimWrapper', + optimizer=dict(type='Adam', + lr=0.02)), + constructor='ToyMultipleOptimizerConstructor') optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) self.assertIsInstance(optim_wrapper, OptimWrapperDict) self.assertIsInstance(optim_wrapper['linear1'].optimizer, SGD) @@ -1049,8 +1054,9 @@ def test_build_optim_wrapper(self): # Specify the type of optimizer wrapper model = nn.Linear(1, 1) optimizer = SGD(model.parameters(), lr=0.1) - optim_wrapper_cfg = dict( - optimizer=optimizer, type='ToyOptimWrapper', accumulative_counts=2) + optim_wrapper_cfg = dict(optimizer=optimizer, + 
type='ToyOptimWrapper', + accumulative_counts=2) optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg) self.assertIsInstance(optim_wrapper, ToyOptimWrapper) self.assertIs(optim_wrapper.optimizer, optimizer) @@ -1065,10 +1071,10 @@ def test_build_param_scheduler(self): # `build_param_scheduler` cfg = dict(type='MultiStepLR', milestones=[1, 2]) runner.optim_wrapper = dict( - key1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - key2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), + key1=dict(type='OptimWrapper', optimizer=dict(type='SGD', + lr=0.01)), + key2=dict(type='OptimWrapper', + optimizer=dict(type='Adam', lr=0.02)), ) with self.assertRaisesRegex(AssertionError, 'should be called before'): runner.build_param_scheduler(cfg) @@ -1129,12 +1135,11 @@ def test_build_param_scheduler(self): self.assertEqual(len(param_schedulers['key2']), 2) # 4. test multiple optimizers and multiple parameter shceduers - cfg = dict( - key1=dict(type='MultiStepLR', milestones=[1, 2]), - key2=[ - dict(type='MultiStepLR', milestones=[1, 2]), - dict(type='StepLR', step_size=1) - ]) + cfg = dict(key1=dict(type='MultiStepLR', milestones=[1, 2]), + key2=[ + dict(type='MultiStepLR', milestones=[1, 2]), + dict(type='StepLR', step_size=1) + ]) param_schedulers = runner.build_param_scheduler(cfg) self.assertIsInstance(param_schedulers, dict) self.assertEqual(len(param_schedulers), 2) @@ -1146,16 +1151,16 @@ def test_build_param_scheduler(self): dict(type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01))) # 5.1 train loop should be built before converting scheduler - cfg = dict( - type='MultiStepLR', milestones=[1, 2], convert_to_iter_based=True) + cfg = dict(type='MultiStepLR', + milestones=[1, 2], + convert_to_iter_based=True) # 5.2 convert epoch-based to iter-based scheduler - cfg = dict( - type='MultiStepLR', - milestones=[1, 2], - begin=1, - end=7, - convert_to_iter_based=True) + cfg = dict(type='MultiStepLR', + milestones=[1, 2], + begin=1, + end=7, + convert_to_iter_based=True) runner._train_loop = runner.build_train_loop(runner.train_loop) param_schedulers = runner.build_param_scheduler(cfg) self.assertFalse(param_schedulers[0].by_epoch) @@ -1170,11 +1175,10 @@ def test_build_param_scheduler(self): # runner.max_epochs = 3 self.assertEqual(param_schedulers[0].end, 3) - cfg = dict( - type='MultiStepLR', - milestones=[1, 2], - begin=1, - convert_to_iter_based=True) + cfg = dict(type='MultiStepLR', + milestones=[1, 2], + begin=1, + convert_to_iter_based=True) param_schedulers = runner.build_param_scheduler(cfg) self.assertFalse(param_schedulers[0].by_epoch) self.assertEqual(param_schedulers[0].begin, 4) @@ -1217,12 +1221,11 @@ def test_build_evaluator(self): self.assertEqual(_evaluator.metrics[1].collect_device, 'gpu') # test build a customize evaluator - evaluator = dict( - type='ToyEvaluator', - metrics=[ - dict(type='ToyMetric1', collect_device='cpu'), - dict(type='ToyMetric2', collect_device='gpu') - ]) + evaluator = dict(type='ToyEvaluator', + metrics=[ + dict(type='ToyMetric1', collect_device='cpu'), + dict(type='ToyMetric2', collect_device='gpu') + ]) _evaluator = runner.build_evaluator(evaluator) self.assertIsInstance(runner.build_evaluator(evaluator), ToyEvaluator) self.assertEqual(_evaluator.metrics[0].collect_device, 'cpu') @@ -1237,11 +1240,10 @@ def test_build_dataloader(self): cfg.experiment_name = 'test_build_dataloader' runner = Runner.from_cfg(cfg) - cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), 
- batch_size=1, - num_workers=0) + cfg = dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + batch_size=1, + num_workers=0) seed = np.random.randint(2**31) dataloader = runner.build_dataloader(cfg, seed=seed) self.assertIsInstance(dataloader, DataLoader) @@ -1250,28 +1252,27 @@ def test_build_dataloader(self): self.assertEqual(dataloader.sampler.seed, seed) # diff_rank_seed is True - dataloader = runner.build_dataloader( - cfg, seed=seed, diff_rank_seed=True) + dataloader = runner.build_dataloader(cfg, + seed=seed, + diff_rank_seed=True) self.assertNotEqual(dataloader.sampler.seed, seed) # custom worker_init_fn - cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - worker_init_fn=dict(type='custom_worker_init'), - batch_size=1, - num_workers=2) + cfg = dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + worker_init_fn=dict(type='custom_worker_init'), + batch_size=1, + num_workers=2) dataloader = runner.build_dataloader(cfg) self.assertIs(dataloader.worker_init_fn.func, custom_worker_init) # collate_fn is a dict - cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - worker_init_fn=dict(type='custom_worker_init'), - batch_size=1, - num_workers=2, - collate_fn=dict(type='pseudo_collate')) + cfg = dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + worker_init_fn=dict(type='custom_worker_init'), + batch_size=1, + num_workers=2, + collate_fn=dict(type='pseudo_collate')) dataloader = runner.build_dataloader(cfg) self.assertIsInstance(dataloader.collate_fn, partial) @@ -1279,36 +1280,33 @@ def test_build_dataloader(self): def custom_collate(data_batch): return data_batch - cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - worker_init_fn=dict(type='custom_worker_init'), - batch_size=1, - num_workers=2, - collate_fn=custom_collate) + cfg = dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + worker_init_fn=dict(type='custom_worker_init'), + batch_size=1, + num_workers=2, + collate_fn=custom_collate) dataloader = runner.build_dataloader(cfg) self.assertIs(dataloader.collate_fn, custom_collate) # collate_fn is a invalid value with self.assertRaisesRegex( TypeError, 'collate_fn should be a dict or callable object'): - cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - worker_init_fn=dict(type='custom_worker_init'), - batch_size=1, - num_workers=2, - collate_fn='collate_fn') + cfg = dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + worker_init_fn=dict(type='custom_worker_init'), + batch_size=1, + num_workers=2, + collate_fn='collate_fn') dataloader = runner.build_dataloader(cfg) self.assertIsInstance(dataloader.collate_fn, partial) # num_batch_per_epoch is not None - cfg = dict( - dataset=dict(type='ToyDataset'), - sampler=dict(type='DefaultSampler', shuffle=True), - collate_fn=dict(type='default_collate'), - batch_size=3, - num_workers=2, - num_batch_per_epoch=2) + cfg = dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate'), + batch_size=3, + num_workers=2, + num_batch_per_epoch=2) dataloader = runner.build_dataloader(cfg) self.assertEqual(len(dataloader.dataset), 6) @@ -1432,8 +1430,8 @@ def test_build_log_processor(self): 
self.assertIsInstance(log_processor, LogProcessor) # input is a LogProcessor object - self.assertEqual( - id(runner.build_log_processor(log_processor)), id(log_processor)) + self.assertEqual(id(runner.build_log_processor(log_processor)), + id(log_processor)) # test custom validation log_processor cfg = dict(type='CustomLogProcessor') @@ -1525,8 +1523,10 @@ def before_val_iter(self, runner, batch_idx, data_batch=None): cfg = copy.deepcopy(self.iter_based_cfg) cfg.experiment_name = 'test_train3' cfg.custom_hooks = [dict(type='TestIterHook', priority=50)] - cfg.train_cfg = dict( - by_epoch=False, max_iters=12, val_interval=4, val_begin=4) + cfg.train_cfg = dict(by_epoch=False, + max_iters=12, + val_interval=4, + val_begin=4) runner = Runner.from_cfg(cfg) runner.train() @@ -1562,11 +1562,13 @@ def before_val_iter(self, runner, batch_idx, data_batch=None): cfg = copy.deepcopy(self.iter_based_cfg) cfg.experiment_name = 'test_train4' - cfg.train_dataloader.sampler = dict( - type='DefaultSampler', shuffle=True) + cfg.train_dataloader.sampler = dict(type='DefaultSampler', + shuffle=True) cfg.custom_hooks = [dict(type='TestIterHook', priority=50)] - cfg.train_cfg = dict( - by_epoch=False, max_iters=12, val_interval=4, val_begin=4) + cfg.train_cfg = dict(by_epoch=False, + max_iters=12, + val_interval=4, + val_begin=4) runner = Runner.from_cfg(cfg) # Warning should be raised since the sampler is not InfiniteSampler. with self.assertLogs(MMLogger.get_current_instance(), level='WARNING'): @@ -1610,16 +1612,15 @@ def before_train_iter(self, runner, batch_idx, data_batch=None): cfg = copy.deepcopy(self.iter_based_cfg) cfg.experiment_name = 'test_train5' - cfg.train_dataloader.sampler = dict( - type='DefaultSampler', shuffle=True) + cfg.train_dataloader.sampler = dict(type='DefaultSampler', + shuffle=True) cfg.custom_hooks = [ dict(type='TestIterDynamicIntervalHook', priority=50) ] - cfg.train_cfg = dict( - by_epoch=False, - max_iters=max_iters, - val_interval=interval, - dynamic_intervals=dynamic_intervals) + cfg.train_cfg = dict(by_epoch=False, + max_iters=max_iters, + val_interval=interval, + dynamic_intervals=dynamic_intervals) runner = Runner.from_cfg(cfg) runner.train() for result, target, in zip(iter_results, iter_targets): @@ -1647,16 +1648,15 @@ def before_train_epoch(self, runner): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_train6' - cfg.train_dataloader.sampler = dict( - type='DefaultSampler', shuffle=True) + cfg.train_dataloader.sampler = dict(type='DefaultSampler', + shuffle=True) cfg.custom_hooks = [ dict(type='TestEpochDynamicIntervalHook', priority=50) ] - cfg.train_cfg = dict( - by_epoch=True, - max_epochs=max_epochs, - val_interval=interval, - dynamic_intervals=dynamic_intervals) + cfg.train_cfg = dict(by_epoch=True, + max_epochs=max_epochs, + val_interval=interval, + dynamic_intervals=dynamic_intervals) runner = Runner.from_cfg(cfg) runner.train() for result, target, in zip(epoch_results, epoch_targets): @@ -1687,12 +1687,13 @@ def init_weights(self): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_train8' cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2]) - cfg.optim_wrapper = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), - constructor='ToyMultipleOptimizerConstructor') + cfg.optim_wrapper = dict(linear1=dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=0.01)), + linear2=dict(type='OptimWrapper', + 
optimizer=dict(type='Adam', + lr=0.02)), + constructor='ToyMultipleOptimizerConstructor') cfg.model = dict(type='ToyGANModel') runner = runner.from_cfg(cfg) runner.train() @@ -1701,12 +1702,13 @@ def init_weights(self): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_train8.1.1' cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2]) - cfg.optim_wrapper = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), - constructor='ToyMultipleOptimizerConstructor') + cfg.optim_wrapper = dict(linear1=dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=0.01)), + linear2=dict(type='OptimWrapper', + optimizer=dict(type='Adam', + lr=0.02)), + constructor='ToyMultipleOptimizerConstructor') cfg.model = dict(type='ToyGANModel') runner = runner.from_cfg(cfg) runner.train() @@ -1759,8 +1761,8 @@ def init_weights(self): # 10.3 Test build dataloader with custom worker_init function cfg = copy.deepcopy(self.iter_based_cfg) cfg.experiment_name = 'test_train10.3' - cfg.train_dataloader.update( - worker_init_fn=dict(type='custom_worker_init')) + cfg.train_dataloader.update(worker_init_fn=dict( + type='custom_worker_init')) runner = Runner.from_cfg(cfg) runner.train() @@ -1835,9 +1837,8 @@ def train_step(self, *args, **kwargs): runner.train() self.assertEqual(runner.iter, 3 * 2) - @skipIf( - SKIP_TEST_COMPILE, - reason='torch.compile is not valid, please install PyTorch>=2.0.0') + @skipIf(SKIP_TEST_COMPILE, + reason='torch.compile is not valid, please install PyTorch>=2.0.0') def test_train_with_compile(self): # 1. test with simple configuration cfg = copy.deepcopy(self.epoch_based_cfg) @@ -1947,9 +1948,8 @@ def after_val_iter(self, runner.val() self.assertEqual(val_result, 2) - @skipIf( - SKIP_TEST_COMPILE, - reason='torch.compile is not valid, please install PyTorch>=2.0.0') + @skipIf(SKIP_TEST_COMPILE, + reason='torch.compile is not valid, please install PyTorch>=2.0.0') def test_val_with_compile(self): # 1. test with simple configuration cfg = copy.deepcopy(self.epoch_based_cfg) @@ -2052,9 +2052,8 @@ def after_test_iter(self, runner.test() self.assertEqual(test_result, 2) - @skipIf( - SKIP_TEST_COMPILE, - reason='torch.compile is not valid, please install PyTorch>=2.0.0') + @skipIf(SKIP_TEST_COMPILE, + reason='torch.compile is not valid, please install PyTorch>=2.0.0') def test_test_with_compile(self): # 1. 
test with simple configuration cfg = copy.deepcopy(self.epoch_based_cfg) @@ -2088,8 +2087,8 @@ def test_register_hook(self): self.assertEqual(len(runner._hooks), 1) self.assertTrue(isinstance(runner._hooks[0], IterTimerHook)) # default priority of `IterTimerHook` is 'NORMAL' - self.assertEqual( - get_priority(runner._hooks[0].priority), get_priority('NORMAL')) + self.assertEqual(get_priority(runner._hooks[0].priority), + get_priority('NORMAL')) runner._hooks = [] # 1.2.1 `hook` is a dict and contains `priority` field @@ -2098,9 +2097,8 @@ def test_register_hook(self): runner.register_hook(timer_cfg) self.assertEqual(len(runner._hooks), 1) self.assertTrue(isinstance(runner._hooks[0], IterTimerHook)) - self.assertEqual( - get_priority(runner._hooks[0].priority), - get_priority('BELOW_NORMAL')) + self.assertEqual(get_priority(runner._hooks[0].priority), + get_priority('BELOW_NORMAL')) # 1.3 `hook` is a hook object runtime_info_hook = RuntimeInfoHook() @@ -2110,8 +2108,8 @@ def test_register_hook(self): # `IterTimerHook`, so the first item of `_hooks` should be # `runtime_info_hook` self.assertTrue(isinstance(runner._hooks[0], RuntimeInfoHook)) - self.assertEqual( - get_priority(runner._hooks[0].priority), get_priority('VERY_HIGH')) + self.assertEqual(get_priority(runner._hooks[0].priority), + get_priority('VERY_HIGH')) # 2. test `priority` parameter # `priority` argument is not None and it will be set as priority of @@ -2120,16 +2118,16 @@ def test_register_hook(self): runner.register_hook(param_scheduler_cfg, priority='VERY_LOW') self.assertEqual(len(runner._hooks), 3) self.assertTrue(isinstance(runner._hooks[2], ParamSchedulerHook)) - self.assertEqual( - get_priority(runner._hooks[2].priority), get_priority('VERY_LOW')) + self.assertEqual(get_priority(runner._hooks[2].priority), + get_priority('VERY_LOW')) # `priority` is Priority logger_cfg = dict(type='LoggerHook', priority='BELOW_NORMAL') runner.register_hook(logger_cfg, priority=Priority.VERY_LOW) self.assertEqual(len(runner._hooks), 4) self.assertTrue(isinstance(runner._hooks[3], LoggerHook)) - self.assertEqual( - get_priority(runner._hooks[3].priority), get_priority('VERY_LOW')) + self.assertEqual(get_priority(runner._hooks[3].priority), + get_priority('VERY_LOW')) def test_default_hooks(self): cfg = copy.deepcopy(self.epoch_based_cfg) @@ -2189,8 +2187,9 @@ class CustomTrainLoop2(IterBasedTrainLoop): def __init__(self, runner, dataloader, max_iters, warmup_loader, max_warmup_iters): - super().__init__( - runner=runner, dataloader=dataloader, max_iters=max_iters) + super().__init__(runner=runner, + dataloader=dataloader, + max_iters=max_iters) self.warmup_loader = self.runner.build_dataloader( warmup_loader) self.max_warmup_iters = max_warmup_iters @@ -2213,13 +2212,13 @@ def run(self): self.runner.call_hook('after_train') def warmup_iter(self, data_batch): - self.runner.call_hook( - 'before_warmup_iter', data_batch=data_batch) + self.runner.call_hook('before_warmup_iter', + data_batch=data_batch) train_logs = self.runner.model.train_step( data_batch, self.runner.optim_wrapper) self.runner.message_hub.update_info('train_logs', train_logs) - self.runner.call_hook( - 'after_warmup_iter', data_batch=data_batch) + self.runner.call_hook('after_warmup_iter', + data_batch=data_batch) before_warmup_iter_results = [] after_warmup_iter_results = [] @@ -2237,11 +2236,11 @@ def after_warmup_iter(self, runner, data_batch=None, outputs=None): self.iter_based_cfg.train_cfg = dict( type='CustomTrainLoop2', max_iters=10, - warmup_loader=dict( - 
dataset=dict(type='ToyDataset'), - sampler=dict(type='InfiniteSampler', shuffle=True), - batch_size=1, - num_workers=0), + warmup_loader=dict(dataset=dict(type='ToyDataset'), + sampler=dict(type='InfiniteSampler', + shuffle=True), + batch_size=1, + num_workers=0), max_warmup_iters=5) self.iter_based_cfg.custom_hooks = [ dict(type='TestWarmupHook', priority=50) @@ -2304,8 +2303,8 @@ def test_checkpoint(self): # 1.3.1 test `resume` cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_checkpoint3' - cfg.optim_wrapper = dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.2)) + cfg.optim_wrapper = dict(type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.2)) cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3]) runner = Runner.from_cfg(cfg) runner.resume(path) @@ -2380,12 +2379,13 @@ def test_checkpoint(self): # 1.6 multiple optimizers cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_checkpoint6' - cfg.optim_wrapper = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), - constructor='ToyMultipleOptimizerConstructor') + cfg.optim_wrapper = dict(linear1=dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=0.01)), + linear2=dict(type='OptimWrapper', + optimizer=dict(type='Adam', + lr=0.02)), + constructor='ToyMultipleOptimizerConstructor') cfg.model = dict(type='ToyGANModel') # disable OptimizerHook because it only works with one optimizer runner = Runner.from_cfg(cfg) @@ -2400,12 +2400,13 @@ def test_checkpoint(self): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_checkpoint7' - cfg.optim_wrapper = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.2)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.03)), - constructor='ToyMultipleOptimizerConstructor') + cfg.optim_wrapper = dict(linear1=dict(type='OptimWrapper', + optimizer=dict(type='SGD', + lr=0.2)), + linear2=dict(type='OptimWrapper', + optimizer=dict(type='Adam', + lr=0.03)), + constructor='ToyMultipleOptimizerConstructor') cfg.model = dict(type='ToyGANModel') cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3]) runner = Runner.from_cfg(cfg) @@ -2518,12 +2519,11 @@ def test_checkpoint(self): # 2.7.1 test `resume` 2 optimizers and 1 scheduler list. 
path = osp.join(self.temp_dir, 'epoch_3.pth') - optim_cfg = dict( - linear1=dict( - type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), - linear2=dict( - type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)), - constructor='ToyMultipleOptimizerConstructor') + optim_cfg = dict(linear1=dict(type='OptimWrapper', + optimizer=dict(type='SGD', lr=0.01)), + linear2=dict(type='OptimWrapper', + optimizer=dict(type='Adam', lr=0.02)), + constructor='ToyMultipleOptimizerConstructor') cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_checkpoint14' cfg.optim_wrapper = optim_cfg @@ -2546,9 +2546,11 @@ def test_checkpoint(self): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_checkpoint16' cfg.optim_wrapper = optim_cfg - cfg.param_scheduler = dict( - linear1=dict(type='MultiStepLR', milestones=[1, 2, 3]), - linear2=dict(type='StepLR', gamma=0.1, step_size=3)) + cfg.param_scheduler = dict(linear1=dict(type='MultiStepLR', + milestones=[1, 2, 3]), + linear2=dict(type='StepLR', + gamma=0.1, + step_size=3)) cfg.model = dict(type='ToyGANModel') resumed_cfg = copy.deepcopy(cfg) runner = Runner.from_cfg(cfg) diff --git a/tests/test_strategies/test_fsdp.py b/tests/test_strategies/test_fsdp.py index 64b900d2f8..545651b5da 100644 --- a/tests/test_strategies/test_fsdp.py +++ b/tests/test_strategies/test_fsdp.py @@ -59,33 +59,29 @@ def test_init(self): strategy = FSDPStrategy(state_dict_cfg='full') self._assert_full(strategy) - strategy = FSDPStrategy( - state_dict_cfg=dict( - state_dict_type=StateDictType.LOCAL_STATE_DICT)) + strategy = FSDPStrategy(state_dict_cfg=dict( + state_dict_type=StateDictType.LOCAL_STATE_DICT)) self._assert_local(strategy) - strategy = FSDPStrategy( - state_dict_cfg=dict( - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=FullStateDictConfig(), - optim_state_dict_config=FullOptimStateDictConfig(), - )) + strategy = FSDPStrategy(state_dict_cfg=dict( + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig(), + optim_state_dict_config=FullOptimStateDictConfig(), + )) self._assert_full(strategy) - strategy = FSDPStrategy( - state_dict_cfg=dict( - state_dict_type='FULL_STATE_DICT', - state_dict_config=dict(type='FullStateDictConfig'), - optim_state_dict_config=dict(type='FullOptimStateDictConfig'), - )) + strategy = FSDPStrategy(state_dict_cfg=dict( + state_dict_type='FULL_STATE_DICT', + state_dict_config=dict(type='FullStateDictConfig'), + optim_state_dict_config=dict(type='FullOptimStateDictConfig'), + )) self._assert_full(strategy) - strategy = FSDPStrategy( - state_dict_cfg=dict( - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=dict(type=FullStateDictConfig), - optim_state_dict_config=dict(type=FullOptimStateDictConfig), - )) + strategy = FSDPStrategy(state_dict_cfg=dict( + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=dict(type=FullStateDictConfig), + optim_state_dict_config=dict(type=FullOptimStateDictConfig), + )) self._assert_full(strategy) with self.assertRaises(ValueError): @@ -97,33 +93,28 @@ def test_init(self): # state_dict_type must be a str or a enumerate of StateDictType with self.assertRaises(TypeError): - strategy = FSDPStrategy( - state_dict_cfg=dict( - state_dict_type=[], - state_dict_config=dict(type=FullStateDictConfig), - optim_state_dict_config=dict( - type=FullOptimStateDictConfig), - )) + strategy = FSDPStrategy(state_dict_cfg=dict( + state_dict_type=[], + state_dict_config=dict(type=FullStateDictConfig), + 
optim_state_dict_config=dict(type=FullOptimStateDictConfig), + )) # state_dict_config should be a dict or a subclass of StateDictConfig with self.assertRaises(TypeError): - strategy = FSDPStrategy( - state_dict_cfg=dict( - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=[], - optim_state_dict_config=dict( - type=FullOptimStateDictConfig), - )) + strategy = FSDPStrategy(state_dict_cfg=dict( + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=[], + optim_state_dict_config=dict(type=FullOptimStateDictConfig), + )) # optim_state_dict_config should be a dict or a subclass of # OptimStateDictConfig with self.assertRaises(TypeError): - strategy = FSDPStrategy( - state_dict_cfg=dict( - state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=dict(type=FullStateDictConfig), - optim_state_dict_config=[], - )) + strategy = FSDPStrategy(state_dict_cfg=dict( + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=dict(type=FullStateDictConfig), + optim_state_dict_config=[], + )) def run_strategy(self): # Strategy can run with the built model, optimizer and schedulers. @@ -168,8 +159,8 @@ def run_strategy(self): # optimizer with multiple param_groups can be reconstructed. model = ToyModel() - strategy = FSDPStrategy( - model_wrapper=dict(auto_wrap_policy=linear_wrap_policy)) + strategy = FSDPStrategy(model_wrapper=dict( + auto_wrap_policy=linear_wrap_policy)) param_groups = [] for param in model.parameters(): param_groups.append(dict(params=[param], lr=0.1)) @@ -204,10 +195,9 @@ def _worker(cls, rank, func): self.tearDown() def test_run_strategy(self): - start_processes( - TestStrategy._worker, - args=('run_strategy', ), - nprocs=self.world_size) + start_processes(TestStrategy._worker, + args=('run_strategy', ), + nprocs=self.world_size) def test_build_model(self): ... 
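The hunks in this stretch of the series are almost entirely mechanical re-indentation: call arguments move from a hanging indent onto the column of the opening parenthesis, with no behavioural change. The one semantic edit sits further below, in tests/test_utils/test_package_utils.py, which swaps the deprecated pkg_resources lookup for importlib.metadata. What follows is a minimal sketch of that fallback import pattern and is not part of the patch itself; `probe_version` is a hypothetical helper used only for illustration, and the sketch assumes the stdlib module is available from Python 3.8 onwards, with the importlib_metadata backport covering older interpreters:

try:
    from importlib.metadata import PackageNotFoundError, version
except ImportError:
    # Pre-3.8 interpreters need the importlib_metadata backport instead.
    from importlib_metadata import PackageNotFoundError, version  # type: ignore


def probe_version(package: str) -> str:
    """Return the installed version of ``package``.

    Raises ``PackageNotFoundError`` when the distribution is missing,
    which is the same role ``pkg_resources.DistributionNotFound`` used
    to play in the test below.
    """
    return version(package)


if __name__ == '__main__':
    try:
        print(probe_version('pip'))
    except PackageNotFoundError:
        print('pip is not installed')

Since the two exceptions are interchangeable here, the pytest.raises assertion in that hunk only needs its import swapped; the test logic itself is untouched.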
diff --git a/tests/test_structures/test_data_element.py b/tests/test_structures/test_data_element.py index 1cb7cd1745..d1c1e9afb9 100644 --- a/tests/test_structures/test_data_element.py +++ b/tests/test_structures/test_data_element.py @@ -29,8 +29,9 @@ def gt_instances(self): @gt_instances.setter def gt_instances(self, value): - self.set_field( - value=value, name='_gt_instances', dtype=BaseDataElement) + self.set_field(value=value, + name='_gt_instances', + dtype=BaseDataElement) @gt_instances.deleter def gt_instances(self): @@ -42,8 +43,9 @@ def pred_instances(self): @pred_instances.setter def pred_instances(self, value): - self.set_field( - value=value, name='_pred_instances', dtype=BaseDataElement) + self.set_field(value=value, + name='_pred_instances', + dtype=BaseDataElement) @pred_instances.deleter def pred_instances(self): @@ -53,13 +55,13 @@ def pred_instances(self): class TestBaseDataElement(TestCase): def setup_data(self): - metainfo = dict( - img_id=random.randint(0, 100), - img_shape=(random.randint(400, 600), random.randint(400, 600))) - gt_instances = BaseDataElement( - bboxes=torch.rand((5, 4)), labels=torch.rand((5, ))) - pred_instances = BaseDataElement( - bboxes=torch.rand((5, 4)), scores=torch.rand((5, ))) + metainfo = dict(img_id=random.randint(0, 100), + img_shape=(random.randint(400, 600), + random.randint(400, 600))) + gt_instances = BaseDataElement(bboxes=torch.rand((5, 4)), + labels=torch.rand((5, ))) + pred_instances = BaseDataElement(bboxes=torch.rand((5, 4)), + scores=torch.rand((5, ))) data = dict(gt_instances=gt_instances, pred_instances=pred_instances) return metainfo, data @@ -232,8 +234,8 @@ def test_set_data(self): def test_update(self): metainfo, data = self.setup_data() instances = BaseDataElement(metainfo=metainfo, **data) - proposals = BaseDataElement( - bboxes=torch.rand((5, 4)), scores=torch.rand((5, ))) + proposals = BaseDataElement(bboxes=torch.rand((5, 4)), + scores=torch.rand((5, ))) new_instances = BaseDataElement(proposals=proposals) instances.update(new_instances) self.check_key_value(instances, metainfo, @@ -267,8 +269,8 @@ def test_delete_modify(self): del instances.gt_instances del instances.img_id - assert not self.is_equal( - instances.pop('pred_instances', None), data['pred_instances']) + assert not self.is_equal(instances.pop('pred_instances', None), + data['pred_instances']) with self.assertRaises(AttributeError): del instances.pred_instances @@ -293,8 +295,8 @@ def test_delete_modify(self): with self.assertRaises(AttributeError): del instances._data_fields - @pytest.mark.skipif( - not torch.cuda.is_available(), reason='GPU is required!') + @pytest.mark.skipif(not torch.cuda.is_available(), + reason='GPU is required!') def test_cuda(self): metainfo, data = self.setup_data() instances = BaseDataElement(metainfo=metainfo, **data) @@ -338,8 +340,9 @@ def test_detach(self): def test_repr(self): metainfo = dict(img_shape=(800, 1196, 3)) - gt_instances = BaseDataElement( - metainfo=metainfo, det_labels=torch.LongTensor([0, 1, 2, 3])) + gt_instances = BaseDataElement(metainfo=metainfo, + det_labels=torch.LongTensor( + [0, 1, 2, 3])) sample = BaseDataElement(metainfo=metainfo, gt_instances=gt_instances) address = hex(id(sample)) address_gt_instances = hex(id(sample.gt_instances)) diff --git a/tests/test_structures/test_instance_data.py b/tests/test_structures/test_instance_data.py index fe4a1b2603..20009741ef 100644 --- a/tests/test_structures/test_instance_data.py +++ b/tests/test_structures/test_instance_data.py @@ -73,9 +73,9 @@ def 
__repr__(self): class TestInstanceData(TestCase): def setup_data(self): - metainfo = dict( - img_id=random.randint(0, 100), - img_shape=(random.randint(400, 600), random.randint(400, 600))) + metainfo = dict(img_id=random.randint(0, 100), + img_shape=(random.randint(400, 600), + random.randint(400, 600))) instances_infos = [1] * 5 bboxes = torch.rand((5, 4)) labels = np.random.rand(5) @@ -83,15 +83,14 @@ def setup_data(self): ids = (1, 2, 3, 4, 5) name_ids = '12345' polygons = TmpObject(np.arange(25).reshape((5, -1)).tolist()) - instance_data = InstanceData( - metainfo=metainfo, - bboxes=bboxes, - labels=labels, - polygons=polygons, - kps=kps, - ids=ids, - name_ids=name_ids, - instances_infos=instances_infos) + instance_data = InstanceData(metainfo=metainfo, + bboxes=bboxes, + labels=labels, + polygons=polygons, + kps=kps, + ids=ids, + name_ids=name_ids, + instances_infos=instances_infos) return instance_data def test_set_data(self): @@ -189,8 +188,8 @@ def test_cat(self): assert len(cat_instance_data) == 10 # All inputs must be InstanceData - instance_data_2 = BaseDataElement( - bboxes=torch.rand((5, 4)), labels=torch.rand((5, ))) + instance_data_2 = BaseDataElement(bboxes=torch.rand((5, 4)), + labels=torch.rand((5, ))) with self.assertRaises(AssertionError): InstanceData.cat([instance_data_1, instance_data_2]) @@ -208,11 +207,10 @@ def test_cat(self): instance_data_1.polygons = TmpObjectWithoutCat( np.arange(25).reshape((5, -1)).tolist()) instance_data_2 = instance_data_1.clone() - with pytest.raises( - ValueError, - match=('The type of `polygons` is ' - f'`{type(instance_data_1.polygons)}` ' - 'which has no attribute of `cat`')): + with pytest.raises(ValueError, + match=('The type of `polygons` is ' + f'`{type(instance_data_1.polygons)}` ' + 'which has no attribute of `cat`')): cat_instance_data = InstanceData.cat( [instance_data_1, instance_data_2]) diff --git a/tests/test_structures/test_label_data.py b/tests/test_structures/test_label_data.py index 8c73bca767..7cb771019f 100644 --- a/tests/test_structures/test_label_data.py +++ b/tests/test_structures/test_label_data.py @@ -21,10 +21,11 @@ def test_label_to_onehot(self): # item'max bigger than num_classes with self.assertRaises(AssertionError): - LabelData.label_to_onehot( - torch.tensor([11], dtype=torch.int64), num_classes) - onehot = LabelData.label_to_onehot( - label=torch.tensor([], dtype=torch.int64), num_classes=num_classes) + LabelData.label_to_onehot(torch.tensor([11], dtype=torch.int64), + num_classes) + onehot = LabelData.label_to_onehot(label=torch.tensor( + [], dtype=torch.int64), + num_classes=num_classes) assert (onehot == torch.zeros((num_classes, ), dtype=torch.int64)).all() @@ -50,8 +51,8 @@ def test_onehot_to_label(self): assert label == item assert label.device == item.device - @pytest.mark.skipif( - not torch.cuda.is_available(), reason='GPU is required!') + @pytest.mark.skipif(not torch.cuda.is_available(), + reason='GPU is required!') def test_cuda(self): item = torch.arange(0, 9).cuda() onehot = LabelData.label_to_onehot(item, num_classes=10) diff --git a/tests/test_structures/test_pixel_data.py b/tests/test_structures/test_pixel_data.py index 1ca80373af..34fcc249b8 100644 --- a/tests/test_structures/test_pixel_data.py +++ b/tests/test_structures/test_pixel_data.py @@ -12,9 +12,9 @@ class TestPixelData(TestCase): def setup_data(self): - metainfo = dict( - img_id=random.randint(0, 100), - img_shape=(random.randint(400, 600), random.randint(400, 600))) + metainfo = dict(img_id=random.randint(0, 100), + 
img_shape=(random.randint(400, 600), + random.randint(400, 600))) image = np.random.randint(0, 255, (4, 20, 40)) featmap = torch.randint(0, 255, (10, 20, 40)) pixel_data = PixelData(metainfo=metainfo, image=image, featmap=featmap) diff --git a/tests/test_testing/test_runner_test_case.py b/tests/test_testing/test_runner_test_case.py index 5d41c03531..be93e74ee6 100644 --- a/tests/test_testing/test_runner_test_case.py +++ b/tests/test_testing/test_runner_test_case.py @@ -46,8 +46,8 @@ def test_experiment_name(self): def test_init_dist(self): self.setup_dist_env() - self.assertEqual( - str(self.dist_cfg['MASTER_PORT']), os.environ['MASTER_PORT']) + self.assertEqual(str(self.dist_cfg['MASTER_PORT']), + os.environ['MASTER_PORT']) self.assertEqual(self.dist_cfg['MASTER_ADDR'], os.environ['MASTER_ADDR']) self.assertEqual(self.dist_cfg['RANK'], os.environ['RANK']) diff --git a/tests/test_utils/test_dl_utils/test_setup_env.py b/tests/test_utils/test_dl_utils/test_setup_env.py index 9ca98b4311..74c6881233 100644 --- a/tests/test_utils/test_dl_utils/test_setup_env.py +++ b/tests/test_utils/test_dl_utils/test_setup_env.py @@ -38,8 +38,9 @@ def test_setup_multi_processes(): assert os.getenv('OMP_NUM_THREADS') == '4' # test manually set opencv threads and mp start method - config = dict( - mp_start_method='spawn', opencv_num_threads=4, distributed=True) + config = dict(mp_start_method='spawn', + opencv_num_threads=4, + distributed=True) set_multi_processing(**config) assert cv2.getNumThreads() == 4 assert mp.get_start_method() == 'spawn' diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py index 7c43d04853..71af2e7c6a 100644 --- a/tests/test_utils/test_misc.py +++ b/tests/test_utils/test_misc.py @@ -158,8 +158,8 @@ def test_import_modules_from_strings(): with pytest.raises(ImportError): import_modules_from_strings('_not_implemented_module') with pytest.warns(UserWarning): - imported = import_modules_from_strings( - '_not_implemented_module', allow_failed_imports=True) + imported = import_modules_from_strings('_not_implemented_module', + allow_failed_imports=True) assert imported is None with pytest.warns(UserWarning): imported = import_modules_from_strings(['os.path', '_not_implemented'], diff --git a/tests/test_utils/test_package_utils.py b/tests/test_utils/test_package_utils.py index bed91b6c18..e271e9d314 100644 --- a/tests/test_utils/test_package_utils.py +++ b/tests/test_utils/test_package_utils.py @@ -2,9 +2,13 @@ import os.path as osp import sys -import pkg_resources import pytest +try: + from importlib.metadata import PackageNotFoundError +except ImportError: + from importlib_metadata import PackageNotFoundError # type: ignore[import-untyped, no-redef, import-not-found] # noqa: E501 + from mmengine.utils import get_installed_path, is_installed @@ -33,5 +37,5 @@ def test_get_install_path(): assert get_installed_path('optim') == osp.join(PYTHONPATH, 'optim') sys.path.pop() - with pytest.raises(pkg_resources.DistributionNotFound): + with pytest.raises(PackageNotFoundError): get_installed_path('unknown') diff --git a/tests/test_utils/test_progressbar.py b/tests/test_utils/test_progressbar.py index 0636e25e1d..c2635f2d6c 100644 --- a/tests/test_utils/test_progressbar.py +++ b/tests/test_utils/test_progressbar.py @@ -23,8 +23,9 @@ def test_start(self): prog_bar = mmengine.ProgressBar(bar_width=bar_width, file=out) assert out.getvalue() == 'completed: 0, elapsed: 0s' reset_string_io(out) - prog_bar = mmengine.ProgressBar( - bar_width=bar_width, start=False, file=out) + prog_bar = 
mmengine.ProgressBar(bar_width=bar_width, + start=False, + file=out) assert out.getvalue() == '' reset_string_io(out) prog_bar.start() @@ -34,16 +35,17 @@ def test_start(self): prog_bar = mmengine.ProgressBar(10, bar_width=bar_width, file=out) assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:' reset_string_io(out) - prog_bar = mmengine.ProgressBar( - 10, bar_width=bar_width, start=False, file=out) + prog_bar = mmengine.ProgressBar(10, + bar_width=bar_width, + start=False, + file=out) assert out.getvalue() == '' reset_string_io(out) prog_bar.start() assert out.getvalue() == f'[{" " * bar_width}] 0/10, elapsed: 0s, ETA:' - @skipIf( - platform.system() != 'Linux', - reason='Only test `TestProgressBar.test_update` in Linux') + @skipIf(platform.system() != 'Linux', + reason='Only test `TestProgressBar.test_update` in Linux') def test_update(self): out = StringIO() bar_width = 20 @@ -62,9 +64,8 @@ def test_update(self): assert out.getvalue() == f'\r[{">" * 2 + " " * 18}] 1/10, 1.0 ' \ 'task/s, elapsed: 1s, ETA: 9s' - @skipIf( - platform.system() != 'Linux', - reason='Only test `TestProgressBar.test_adaptive_length` in Linux') + @skipIf(platform.system() != 'Linux', + reason='Only test `TestProgressBar.test_adaptive_length` in Linux') def test_adaptive_length(self): with patch.dict('os.environ', {'COLUMNS': '80'}): out = StringIO() @@ -108,13 +109,16 @@ def test_track_progress(): assert ret == [1, 2, 3] # tasks is an iterable object - ret = mmengine.track_progress( - return_itself, ((i for i in [1, 2, 3]), 3), bar_width=3, file=out) + ret = mmengine.track_progress(return_itself, ((i for i in [1, 2, 3]), 3), + bar_width=3, + file=out) assert ret == [1, 2, 3] # tasks is a range object - ret = mmengine.track_progress( - return_itself, range(1, 4), bar_width=3, file=out) + ret = mmengine.track_progress(return_itself, + range(1, 4), + bar_width=3, + file=out) assert ret == [1, 2, 3] @@ -143,19 +147,24 @@ def test_track_iter_progress(): def test_track_parallel_progress(): # tasks is a list out = StringIO() - ret = mmengine.track_parallel_progress( - return_itself, [1, 2, 3, 4], 2, bar_width=4, file=out) + ret = mmengine.track_parallel_progress(return_itself, [1, 2, 3, 4], + 2, + bar_width=4, + file=out) assert ret == [1, 2, 3, 4] # tasks is an iterable object - ret = mmengine.track_parallel_progress( - return_itself, ((i for i in [1, 2, 3, 4]), 4), - 2, - bar_width=4, - file=out) + ret = mmengine.track_parallel_progress(return_itself, + ((i for i in [1, 2, 3, 4]), 4), + 2, + bar_width=4, + file=out) assert ret == [1, 2, 3, 4] # tasks is a range object - ret = mmengine.track_parallel_progress( - return_itself, range(1, 5), 2, bar_width=4, file=out) + ret = mmengine.track_parallel_progress(return_itself, + range(1, 5), + 2, + bar_width=4, + file=out) assert ret == [1, 2, 3, 4] diff --git a/tests/test_utils/test_timer.py b/tests/test_utils/test_timer.py index 570f7ea380..de83d17527 100644 --- a/tests/test_utils/test_timer.py +++ b/tests/test_utils/test_timer.py @@ -7,8 +7,8 @@ import mmengine -@pytest.mark.skipif( - platform.system() != 'Linux', reason='Only test `Timer` in linux!') +@pytest.mark.skipif(platform.system() != 'Linux', + reason='Only test `Timer` in linux!') def test_timer_init(): timer = mmengine.Timer(start=False) assert not timer.is_running @@ -18,8 +18,8 @@ def test_timer_init(): assert timer.is_running -@pytest.mark.skipif( - platform.system() != 'Linux', reason='Only test `Timer` in linux!') +@pytest.mark.skipif(platform.system() != 'Linux', + reason='Only test 
`Timer` in linux!') def test_timer_run(): timer = mmengine.Timer() time.sleep(1) @@ -36,8 +36,8 @@ def test_timer_run(): timer.since_last_check() -@pytest.mark.skipif( - platform.system() != 'Linux', reason='Only test `Timer` in linux!') +@pytest.mark.skipif(platform.system() != 'Linux', + reason='Only test `Timer` in linux!') def test_timer_context(capsys): with mmengine.Timer(): time.sleep(1) diff --git a/tests/test_visualizer/test_vis_backend.py b/tests/test_visualizer/test_vis_backend.py index c991462ef9..b04e24a7fd 100644 --- a/tests/test_visualizer/test_vis_backend.py +++ b/tests/test_visualizer/test_vis_backend.py @@ -156,8 +156,9 @@ def test_add_scalar(self): tensorboard_vis_backend.add_scalar('map', np.array(9), step=0) tensorboard_vis_backend.add_scalar('map', np.array(95), step=1) tensorboard_vis_backend.add_scalar('map', np.array([9])[0], step=0) - tensorboard_vis_backend.add_scalar( - 'map', np.array([95])[0], step=1) + tensorboard_vis_backend.add_scalar('map', + np.array([95])[0], + step=1) assert len(record) == 0 # test with tensor tensorboard_vis_backend.add_scalar('map', torch.tensor(0.9), step=0) @@ -266,8 +267,8 @@ def test_define_metric_cfg(self): wandb_vis_backend = WandbVisBackend( 'temp_dir', define_metric_cfg=define_metric_cfg) wandb_vis_backend._init_env() - wandb_vis_backend._wandb.define_metric.assert_any_call( - 'test3', summary='max') + wandb_vis_backend._wandb.define_metric.assert_any_call('test3', + summary='max') shutil.rmtree('temp_dir') @@ -284,11 +285,11 @@ def test_experiment(self): def test_create_experiment(self): with patch('mlflow.create_experiment') as mock_create_experiment: - MLflowVisBackend( - 'temp_dir', exp_name='test', - artifact_location='foo')._init_env() - mock_create_experiment.assert_any_call( - 'test', artifact_location='foo') + MLflowVisBackend('temp_dir', + exp_name='test', + artifact_location='foo')._init_env() + mock_create_experiment.assert_any_call('test', + artifact_location='foo') def test_add_config(self): cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) @@ -366,8 +367,8 @@ def test_close(self): clearml_vis_backend.close() -@pytest.mark.skipif( - not is_installed('neptune'), reason='Neptune is not installed.') +@pytest.mark.skipif(not is_installed('neptune'), + reason='Neptune is not installed.') class TestNeptuneVisBackend: def test_init(self): @@ -457,9 +458,8 @@ def test_close(self): shutil.rmtree('temp_dir') -@pytest.mark.skipif( - platform.system() == 'Windows', - reason='Aim does not support Windows for now.') +@pytest.mark.skipif(platform.system() == 'Windows', + reason='Aim does not support Windows for now.') class TestAimVisBackend: def test_init(self): diff --git a/tests/test_visualizer/test_visualizer.py b/tests/test_visualizer/test_visualizer.py index e4ababc637..f7d9a06f1d 100644 --- a/tests/test_visualizer/test_visualizer.py +++ b/tests/test_visualizer/test_visualizer.py @@ -57,8 +57,8 @@ def setUp(self): TestCase calls functions in this order: setUp() -> testMethod() -> tearDown() -> cleanUp() """ - self.image = np.random.randint( - 0, 256, size=(10, 10, 3)).astype('uint8') + self.image = np.random.randint(0, 256, + size=(10, 10, 3)).astype('uint8') self.vis_backend_cfg = [ dict(type='MockVisBackend', name='mock1'), dict(type='MockVisBackend', name='mock2') @@ -72,35 +72,33 @@ def test_init(self): visualizer = Visualizer( vis_backends=copy.deepcopy(self.vis_backend_cfg)) - visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = 
Visualizer(vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') assert isinstance(visualizer.get_backend('mock1'), MockVisBackend) assert len(visualizer._vis_backends) == 2 # The name fields cannot be the same with pytest.raises(RuntimeError): - Visualizer( - vis_backends=[ - dict(type='MockVisBackend'), - dict(type='MockVisBackend') - ], - save_dir='temp_dir') + Visualizer(vis_backends=[ + dict(type='MockVisBackend'), + dict(type='MockVisBackend') + ], + save_dir='temp_dir') with pytest.raises(RuntimeError): - Visualizer( - vis_backends=[ - dict(type='MockVisBackend', name='mock1'), - dict(type='MockVisBackend', name='mock1') - ], - save_dir='temp_dir') + Visualizer(vis_backends=[ + dict(type='MockVisBackend', name='mock1'), + dict(type='MockVisBackend', name='mock1') + ], + save_dir='temp_dir') # test global init instance_name = 'visualizer' + str(time.time()) - visualizer = Visualizer.get_instance( - instance_name, - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer.get_instance(instance_name, + vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') assert len(visualizer._vis_backends) == 2 visualizer_any = Visualizer.get_instance(instance_name) assert visualizer_any == visualizer @@ -120,9 +118,10 @@ def __init__(self, save_dir: str) -> None: VISBACKENDS.module_dict.pop('CustomLocalVisBackend') - visualizer = Visualizer.get_instance( - 'test_save_dir', - vis_backends=dict(type='CustomLocalVisBackend', save_dir='tmp')) + visualizer = Visualizer.get_instance('test_save_dir', + vis_backends=dict( + type='CustomLocalVisBackend', + save_dir='tmp')) visualizer = Visualizer.get_instance( 'test_save_dir', vis_backends=[CustomLocalVisBackend('tmp')]) @@ -148,8 +147,10 @@ def test_draw_bboxes(self): # valid bbox visualizer.draw_bboxes(torch.tensor([1, 1, 1, 2])) bboxes = torch.tensor([[1, 1, 2, 2], [1, 2, 2, 2.5]]) - visualizer.draw_bboxes( - bboxes, alpha=0.5, edge_colors=(255, 0, 0), line_styles='-') + visualizer.draw_bboxes(bboxes, + alpha=0.5, + edge_colors=(255, 0, 0), + line_styles='-') bboxes = bboxes.numpy() visualizer.draw_bboxes(bboxes) @@ -159,10 +160,9 @@ def test_draw_bboxes(self): visualizer.draw_bboxes(torch.tensor([5, 1, 2, 2])) # test out of bounds - with pytest.warns( - UserWarning, - match='Warning: The bbox is out of bounds,' - ' the drawn bbox may not be in the image'): + with pytest.warns(UserWarning, + match='Warning: The bbox is out of bounds,' + ' the drawn bbox may not be in the image'): visualizer.draw_bboxes(torch.tensor([1, 1, 20, 2])) # test incorrect bbox format @@ -170,10 +170,10 @@ def test_draw_bboxes(self): visualizer.draw_bboxes([1, 1, 2, 2]) def test_close(self): - visualizer = Visualizer( - image=self.image, - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer(image=self.image, + vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') for name in ['mock1', 'mock2']: assert visualizer.get_backend(name)._close is False @@ -189,36 +189,33 @@ def test_draw_points(self): with pytest.raises(AssertionError): visualizer.draw_points(positions=np.array([1, 2, 3], dtype=object)) # test color - visualizer.draw_points( - positions=torch.tensor([[1, 1], [3, 3]]), - colors=['g', (255, 255, 0)]) - visualizer.draw_points( - positions=torch.tensor([[1, 1], [3, 3]]), - colors=['g', (255, 255, 0)], - marker='.', - sizes=[1, 5]) + visualizer.draw_points(positions=torch.tensor([[1, 1], [3, 3]]), + colors=['g', (255, 255, 
0)]) + visualizer.draw_points(positions=torch.tensor([[1, 1], [3, 3]]), + colors=['g', (255, 255, 0)], + marker='.', + sizes=[1, 5]) def test_draw_texts(self): visualizer = Visualizer(image=self.image) # only support tensor and numpy - visualizer.draw_texts( - 'text1', positions=torch.tensor([5, 5]), colors=(0, 255, 0)) + visualizer.draw_texts('text1', + positions=torch.tensor([5, 5]), + colors=(0, 255, 0)) visualizer.draw_texts(['text1', 'text2'], positions=torch.tensor([[5, 5], [3, 3]]), colors=[(255, 0, 0), (255, 0, 0)]) visualizer.draw_texts('text1', positions=np.array([5, 5])) visualizer.draw_texts(['text1', 'text2'], positions=np.array([[5, 5], [3, 3]])) - visualizer.draw_texts( - 'text1', - positions=torch.tensor([5, 5]), - bboxes=dict(facecolor='r', alpha=0.6)) + visualizer.draw_texts('text1', + positions=torch.tensor([5, 5]), + bboxes=dict(facecolor='r', alpha=0.6)) # test out of bounds - with pytest.warns( - UserWarning, - match='Warning: The text is out of bounds,' - ' the drawn text may not be in the image'): + with pytest.warns(UserWarning, + match='Warning: The text is out of bounds,' + ' the drawn text may not be in the image'): visualizer.draw_texts('text1', positions=torch.tensor([15, 5])) # test incorrect format @@ -230,8 +227,8 @@ def test_draw_texts(self): visualizer.draw_texts(['text1', 'text2'], positions=torch.tensor([5, 5])) with pytest.raises(AssertionError): - visualizer.draw_texts( - 'text1', positions=torch.tensor([[5, 5], [3, 3]])) + visualizer.draw_texts('text1', + positions=torch.tensor([[5, 5], [3, 3]])) with pytest.raises(AssertionError): visualizer.draw_texts(['text1', 'test2'], positions=torch.tensor([[5, 5], [3, 3]]), @@ -259,24 +256,21 @@ def test_draw_lines(self): visualizer = Visualizer(image=self.image) # only support tensor and numpy - visualizer.draw_lines( - x_datas=torch.tensor([1, 5]), y_datas=torch.tensor([2, 6])) - visualizer.draw_lines( - x_datas=np.array([[1, 5], [2, 4]]), - y_datas=np.array([[2, 6], [4, 7]])) - visualizer.draw_lines( - x_datas=np.array([[1, 5], [2, 4]]), - y_datas=np.array([[2, 6], [4, 7]]), - colors='r', - line_styles=['-', '-.'], - line_widths=[1, 2]) + visualizer.draw_lines(x_datas=torch.tensor([1, 5]), + y_datas=torch.tensor([2, 6])) + visualizer.draw_lines(x_datas=np.array([[1, 5], [2, 4]]), + y_datas=np.array([[2, 6], [4, 7]])) + visualizer.draw_lines(x_datas=np.array([[1, 5], [2, 4]]), + y_datas=np.array([[2, 6], [4, 7]]), + colors='r', + line_styles=['-', '-.'], + line_widths=[1, 2]) # test out of bounds - with pytest.warns( - UserWarning, - match='Warning: The line is out of bounds,' - ' the drawn line may not be in the image'): - visualizer.draw_lines( - x_datas=torch.tensor([12, 5]), y_datas=torch.tensor([2, 6])) + with pytest.warns(UserWarning, + match='Warning: The line is out of bounds,' + ' the drawn line may not be in the image'): + visualizer.draw_lines(x_datas=torch.tensor([12, 5]), + y_datas=torch.tensor([2, 6])) # test incorrect format with pytest.raises(TypeError): @@ -286,9 +280,8 @@ def test_draw_lines(self): # test length mismatch with pytest.raises(AssertionError): - visualizer.draw_lines( - x_datas=torch.tensor([1, 5]), - y_datas=torch.tensor([[2, 6], [4, 7]])) + visualizer.draw_lines(x_datas=torch.tensor([1, 5]), + y_datas=torch.tensor([[2, 6], [4, 7]])) def test_draw_circles(self): visualizer = Visualizer(image=self.image) @@ -296,33 +289,30 @@ def test_draw_circles(self): # only support tensor and numpy visualizer.draw_circles(torch.tensor([1, 5]), torch.tensor([1])) 
visualizer.draw_circles(np.array([1, 5]), np.array([1])) - visualizer.draw_circles( - torch.tensor([[1, 5], [2, 6]]), radius=torch.tensor([1, 2])) + visualizer.draw_circles(torch.tensor([[1, 5], [2, 6]]), + radius=torch.tensor([1, 2])) # test face_colors - visualizer.draw_circles( - torch.tensor([[1, 5], [2, 6]]), - radius=torch.tensor([1, 2]), - face_colors=(255, 0, 0), - edge_colors=(255, 0, 0)) + visualizer.draw_circles(torch.tensor([[1, 5], [2, 6]]), + radius=torch.tensor([1, 2]), + face_colors=(255, 0, 0), + edge_colors=(255, 0, 0)) # test config - visualizer.draw_circles( - torch.tensor([[1, 5], [2, 6]]), - radius=torch.tensor([1, 2]), - edge_colors=['g', 'r'], - line_styles=['-', '-.'], - line_widths=[1, 2]) + visualizer.draw_circles(torch.tensor([[1, 5], [2, 6]]), + radius=torch.tensor([1, 2]), + edge_colors=['g', 'r'], + line_styles=['-', '-.'], + line_widths=[1, 2]) # test out of bounds - with pytest.warns( - UserWarning, - match='Warning: The circle is out of bounds,' - ' the drawn circle may not be in the image'): - visualizer.draw_circles( - torch.tensor([12, 5]), radius=torch.tensor([1])) - visualizer.draw_circles( - torch.tensor([1, 5]), radius=torch.tensor([10])) + with pytest.warns(UserWarning, + match='Warning: The circle is out of bounds,' + ' the drawn circle may not be in the image'): + visualizer.draw_circles(torch.tensor([12, 5]), + radius=torch.tensor([1])) + visualizer.draw_circles(torch.tensor([1, 5]), + radius=torch.tensor([10])) # test incorrect format with pytest.raises(TypeError): @@ -332,8 +322,8 @@ def test_draw_circles(self): # test length mismatch with pytest.raises(AssertionError): - visualizer.draw_circles( - torch.tensor([[1, 5]]), radius=torch.tensor([1, 2])) + visualizer.draw_circles(torch.tensor([[1, 5]]), + radius=torch.tensor([1, 2])) def test_draw_polygons(self): visualizer = Visualizer(image=self.image) @@ -344,27 +334,24 @@ def test_draw_polygons(self): np.array([[1, 1], [2, 2], [3, 4]]), torch.tensor([[1, 1], [2, 2], [3, 4]]) ]) - visualizer.draw_polygons( - polygons=[ - np.array([[1, 1], [2, 2], [3, 4]]), - torch.tensor([[1, 1], [2, 2], [3, 4]]) - ], - face_colors=(255, 0, 0), - edge_colors=(255, 0, 0)) - visualizer.draw_polygons( - polygons=[ - np.array([[1, 1], [2, 2], [3, 4]]), - torch.tensor([[1, 1], [2, 2], [3, 4]]) - ], - edge_colors=['r', 'g'], - line_styles='-', - line_widths=[2, 1]) + visualizer.draw_polygons(polygons=[ + np.array([[1, 1], [2, 2], [3, 4]]), + torch.tensor([[1, 1], [2, 2], [3, 4]]) + ], + face_colors=(255, 0, 0), + edge_colors=(255, 0, 0)) + visualizer.draw_polygons(polygons=[ + np.array([[1, 1], [2, 2], [3, 4]]), + torch.tensor([[1, 1], [2, 2], [3, 4]]) + ], + edge_colors=['r', 'g'], + line_styles='-', + line_widths=[2, 1]) # test out of bounds - with pytest.warns( - UserWarning, - match='Warning: The polygon is out of bounds,' - ' the drawn polygon may not be in the image'): + with pytest.warns(UserWarning, + match='Warning: The polygon is out of bounds,' + ' the drawn polygon may not be in the image'): visualizer.draw_polygons(torch.tensor([[1, 1], [2, 2], [16, 4]])) def test_draw_binary_masks(self): @@ -388,8 +375,8 @@ def test_draw_binary_masks(self): # test color dim with pytest.raises(AssertionError): - visualizer.draw_binary_masks( - binary_mask, colors=np.array([1, 22, 4, 45])) + visualizer.draw_binary_masks(binary_mask, + colors=np.array([1, 22, 4, 45])) binary_mask = np.random.randint(0, 2, size=(10, 10)) with pytest.raises(AssertionError): visualizer.draw_binary_masks(binary_mask) @@ -399,15 +386,14 @@ def 
test_draw_featmap(self):
         image = np.random.randint(0, 256, size=(3, 3, 3), dtype='uint8')

         # must be Tensor
-        with pytest.raises(
-                AssertionError,
-                match='`featmap` should be torch.Tensor, but got '
-                "<class 'numpy.ndarray'>"):
+        with pytest.raises(AssertionError,
+                           match='`featmap` should be torch.Tensor, but got '
+                           "<class 'numpy.ndarray'>"):
             visualizer.draw_featmap(np.ones((3, 3, 3)))

         # test tensor format
-        with pytest.raises(
-                AssertionError, match='Input dimension must be 3, but got 4'):
+        with pytest.raises(AssertionError,
+                           match='Input dimension must be 3, but got 4'):
             visualizer.draw_featmap(torch.randn(1, 1, 3, 3))

         # test overlaid_image shape
@@ -415,29 +401,29 @@ def test_draw_featmap(self):
             visualizer.draw_featmap(torch.randn(1, 4, 3), overlaid_image=image)

         # test resize_shape
-        featmap = visualizer.draw_featmap(
-            torch.randn(1, 4, 3), resize_shape=(6, 7))
+        featmap = visualizer.draw_featmap(torch.randn(1, 4, 3),
+                                          resize_shape=(6, 7))
         assert featmap.shape[:2] == (6, 7)
-        featmap = visualizer.draw_featmap(
-            torch.randn(1, 4, 3), overlaid_image=image, resize_shape=(6, 7))
+        featmap = visualizer.draw_featmap(torch.randn(1, 4, 3),
+                                          overlaid_image=image,
+                                          resize_shape=(6, 7))
         assert featmap.shape[:2] == (6, 7)

         # test channel_reduction parameter
         # mode only supports 'squeeze_mean' and 'select_max'
         with pytest.raises(AssertionError):
-            visualizer.draw_featmap(
-                torch.randn(2, 3, 3), channel_reduction='xx')
+            visualizer.draw_featmap(torch.randn(2, 3, 3),
+                                    channel_reduction='xx')

-        featmap = visualizer.draw_featmap(
-            torch.randn(2, 3, 3), channel_reduction='squeeze_mean')
+        featmap = visualizer.draw_featmap(torch.randn(2, 3, 3),
+                                          channel_reduction='squeeze_mean')
         assert featmap.shape[:2] == (3, 3)
-        featmap = visualizer.draw_featmap(
-            torch.randn(2, 3, 3), channel_reduction='select_max')
+        featmap = visualizer.draw_featmap(torch.randn(2, 3, 3),
+                                          channel_reduction='select_max')
         assert featmap.shape[:2] == (3, 3)
-        featmap = visualizer.draw_featmap(
-            torch.randn(2, 4, 3),
-            overlaid_image=image,
-            channel_reduction='select_max')
+        featmap = visualizer.draw_featmap(torch.randn(2, 4, 3),
+                                          overlaid_image=image,
+                                          channel_reduction='select_max')
         assert featmap.shape[:2] == (3, 3)

         # test topk parameter
@@ -448,53 +434,54 @@ def test_draw_featmap(self):
                 'dimension you input is 6, you can use the '
                 'channel_reduction parameter or set topk '
                 'greater than 0 to solve the error'):
-            visualizer.draw_featmap(
-                torch.randn(6, 3, 3), channel_reduction=None, topk=0)
+            visualizer.draw_featmap(torch.randn(6, 3, 3),
+                                    channel_reduction=None,
+                                    topk=0)

-        featmap = visualizer.draw_featmap(
-            torch.randn(6, 3, 3), channel_reduction='select_max', topk=10)
+        featmap = visualizer.draw_featmap(torch.randn(6, 3, 3),
+                                          channel_reduction='select_max',
+                                          topk=10)
         assert featmap.shape[:2] == (3, 3)
-        featmap = visualizer.draw_featmap(
-            torch.randn(1, 4, 3), channel_reduction=None, topk=-1)
+        featmap = visualizer.draw_featmap(torch.randn(1, 4, 3),
+                                          channel_reduction=None,
+                                          topk=-1)
         assert featmap.shape[:2] == (4, 3)
-        featmap = visualizer.draw_featmap(
-            torch.randn(3, 4, 3),
-            overlaid_image=image,
-            channel_reduction=None,
-            topk=-1)
+        featmap = visualizer.draw_featmap(torch.randn(3, 4, 3),
+                                          overlaid_image=image,
+                                          channel_reduction=None,
+                                          topk=-1)
         assert featmap.shape[:2] == (3, 3)
-        featmap = visualizer.draw_featmap(
-            torch.randn(6, 3, 3),
-            channel_reduction=None,
-            topk=4,
-            arrangement=(2, 2))
+        featmap = visualizer.draw_featmap(torch.randn(6, 3, 3),
+                                          channel_reduction=None,
+                                          topk=4,
+                                          arrangement=(2, 2))
         assert featmap.shape[:2]
== (6, 6) - featmap = visualizer.draw_featmap( - torch.randn(6, 3, 3), - channel_reduction=None, - topk=4, - arrangement=(1, 4)) + featmap = visualizer.draw_featmap(torch.randn(6, 3, 3), + channel_reduction=None, + topk=4, + arrangement=(1, 4)) assert featmap.shape[:2] == (3, 12) with pytest.raises( AssertionError, match='The product of row and col in the `arrangement` ' 'is less than topk, please set ' 'the `arrangement` correctly'): - visualizer.draw_featmap( - torch.randn(6, 3, 3), - channel_reduction=None, - topk=4, - arrangement=(1, 2)) + visualizer.draw_featmap(torch.randn(6, 3, 3), + channel_reduction=None, + topk=4, + arrangement=(1, 2)) # test gray - featmap = visualizer.draw_featmap( - torch.randn(6, 3, 3), - overlaid_image=np.random.randint( - 0, 256, size=(3, 3), dtype='uint8'), - channel_reduction=None, - topk=4, - arrangement=(2, 2)) + featmap = visualizer.draw_featmap(torch.randn(6, 3, 3), + overlaid_image=np.random.randint( + 0, + 256, + size=(3, 3), + dtype='uint8'), + channel_reduction=None, + topk=4, + arrangement=(2, 2)) assert featmap.shape[:2] == (6, 6) def test_chain_call(self): @@ -509,17 +496,17 @@ def test_chain_call(self): draw_binary_masks(binary_mask) def test_get_backend(self): - visualizer = Visualizer( - image=self.image, - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer(image=self.image, + vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') for name in ['mock1', 'mock2']: assert isinstance(visualizer.get_backend(name), MockVisBackend) def test_add_config(self): - visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer(vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) visualizer.add_config(cfg) @@ -527,9 +514,9 @@ def test_add_config(self): assert visualizer.get_backend(name)._add_config is True def test_add_graph(self): - visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer(vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') class Model(nn.Module): @@ -546,26 +533,26 @@ def forward(self, x, y=None): def test_add_image(self): image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer(vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') visualizer.add_image('img', image) for name in ['mock1', 'mock2']: assert visualizer.get_backend(name)._add_image is True def test_add_scalar(self): - visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer(vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') visualizer.add_scalar('map', 0.9, step=0) for name in ['mock1', 'mock2']: assert visualizer.get_backend(name)._add_scalar is True def test_add_scalars(self): - visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg), - save_dir='temp_dir') + visualizer = Visualizer(vis_backends=copy.deepcopy( + self.vis_backend_cfg), + save_dir='temp_dir') input_dict = {'map': 0.7, 'acc': 0.9} visualizer.add_scalars(input_dict) for name in ['mock1', 'mock2']: @@ -597,51 +584,44 @@ def test_show(self): patch('mmengine.visualization.visualizer.wait_continue', wait_continue): # test default backend - visualizer.show( 
-                drawn_img=img,
-                win_name='test_show',
-                wait_time=0,
-                backend='matplotlib')
+            visualizer.show(drawn_img=img,
+                            win_name='test_show',
+                            wait_time=0,
+                            backend='matplotlib')
             assert hasattr(visualizer, 'manager')
             calls = [
-                call(
-                    visualizer.manager.canvas.figure,
-                    timeout=0,
-                    continue_key=' ')
+                call(visualizer.manager.canvas.figure,
+                     timeout=0,
+                     continue_key=' ')
             ]
             wait_continue.assert_has_calls(calls)

             # matplotlib backend
-            visualizer.show(
-                drawn_img=img,
-                win_name='test_show',
-                wait_time=0,
-                backend='matplotlib')
+            visualizer.show(drawn_img=img,
+                            win_name='test_show',
+                            wait_time=0,
+                            backend='matplotlib')
             assert hasattr(visualizer, 'manager')
             calls = [
-                call(
-                    visualizer.manager.canvas.figure,
-                    timeout=0,
-                    continue_key=' '),
-                call(
-                    visualizer.manager.canvas.figure,
-                    timeout=0,
-                    continue_key=' ')
+                call(visualizer.manager.canvas.figure,
+                     timeout=0,
+                     continue_key=' '),
+                call(visualizer.manager.canvas.figure,
+                     timeout=0,
+                     continue_key=' ')
             ]
             wait_continue.assert_has_calls(calls)

             # cv2 backend
-            visualizer.show(
-                drawn_img=img,
-                win_name='test_show',
-                wait_time=0,
-                backend='cv2')
+            visualizer.show(drawn_img=img,
+                            win_name='test_show',
+                            wait_time=0,
+                            backend='cv2')
             cv2.imshow.assert_called_once_with(str(id(visualizer)), img)

             # unknown backend
             with pytest.raises(ValueError):
-                visualizer.show(
-                    drawn_img=img,
-                    win_name='test_show',
-                    wait_time=0,
-                    backend='unknown')
+                visualizer.show(drawn_img=img,
+                                win_name='test_show',
+                                wait_time=0,
+                                backend='unknown')

From e93af0a3f05ab08ce2641b79018d2fbfc0887f88 Mon Sep 17 00:00:00 2001
From: HAOCHENYE <21724054@zju.edu.cn>
Date: Sat, 25 Oct 2025 18:20:33 +0000
Subject: [PATCH 26/35] [Fix] Fix config bug in python312

---
 mmengine/config/utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/mmengine/config/utils.py b/mmengine/config/utils.py
index 81b58fb49a..bb15d689bd 100644
--- a/mmengine/config/utils.py
+++ b/mmengine/config/utils.py
@@ -175,6 +175,8 @@ def _is_builtin_module(module_name: str) -> bool:
     origin_path = getattr(spec, 'origin', None)
     if origin_path is None:
         return False
+    if origin_path == 'frozen':
+        return True
     origin_path = osp.abspath(origin_path)
     if ('site-package' in origin_path or 'dist-package' in origin_path
             or not origin_path.startswith(

From 93afd47d8377afd648417d187051a0080da13715 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Sun, 26 Oct 2025 02:48:23 +0800
Subject: [PATCH 27/35] [Fix] Load checkpoint with `weights_only=False` (#1670)

---
 mmengine/runner/checkpoint.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py
index 7cd323092f..20c8f9c814 100644
--- a/mmengine/runner/checkpoint.py
+++ b/mmengine/runner/checkpoint.py
@@ -412,7 +412,9 @@ def load_from_pavi(filename, map_location=None):
     with TemporaryDirectory() as tmp_dir:
         downloaded_file = osp.join(tmp_dir, model.name)
         model.download(downloaded_file)
-        checkpoint = torch.load(downloaded_file, map_location=map_location)
+        checkpoint = torch.load(downloaded_file,
+                                map_location=map_location,
+                                weights_only=False)
     return checkpoint
@@ -435,7 +437,9 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
     file_backend = get_file_backend(
         filename, backend_args={'backend': backend})
     with io.BytesIO(file_backend.get(filename)) as buffer:
-        checkpoint = torch.load(buffer, map_location=map_location)
+        checkpoint = torch.load(buffer,
+                                map_location=map_location,
+                                weights_only=False)
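+        # The explicit `weights_only=False` above restores the pre-2.6
+        # torch.load default: mmengine checkpoints pickle plain Python
+        # objects (config dicts, message-hub buffers) alongside the
+        # tensors, which the weights-only unpickler would reject.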
return checkpoint @@ -504,7 +508,9 @@ def load_from_openmmlab(filename, map_location=None): filename = osp.join(_get_mmengine_home(), model_url) if not osp.isfile(filename): raise FileNotFoundError(f'{filename} can not be found.') - checkpoint = torch.load(filename, map_location=map_location) + checkpoint = torch.load(filename, + map_location=map_location, + weights_only=False) return checkpoint From fed514af89d3b7e5d18279c360365ffb452e52f5 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Sat, 25 Oct 2025 20:32:31 +0000 Subject: [PATCH 28/35] [Fix] Fix unit test of checkpoint hook --- tests/test_hooks/test_checkpoint_hook.py | 32 +++++++++++++++++------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/tests/test_hooks/test_checkpoint_hook.py b/tests/test_hooks/test_checkpoint_hook.py index eb7ac967cb..fa95d0b5ce 100644 --- a/tests/test_hooks/test_checkpoint_hook.py +++ b/tests/test_hooks/test_checkpoint_hook.py @@ -469,13 +469,17 @@ def test_with_runner(self, training_type): cfg = copy.deepcopy(common_cfg) runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_11.pth'), + weights_only=False) self.assertIn('optimizer', ckpt) cfg.default_hooks.checkpoint.save_optimizer = False runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_11.pth'), + weights_only=False) self.assertNotIn('optimizer', ckpt) # Test save_param_scheduler=False @@ -489,13 +493,17 @@ def test_with_runner(self, training_type): ] runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_11.pth'), + weights_only=False) self.assertIn('param_schedulers', ckpt) cfg.default_hooks.checkpoint.save_param_scheduler = False runner = self.build_runner(cfg) runner.train() - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_11.pth'), + weights_only=False) self.assertNotIn('param_schedulers', ckpt) self.clear_work_dir() @@ -543,7 +551,9 @@ def test_with_runner(self, training_type): self.assertFalse( osp.isfile(osp.join(cfg.work_dir, f'{training_type}_{i}.pth'))) - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_11.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_11.pth'), + weights_only=False) self.assertEqual(ckpt['message_hub']['runtime_info']['keep_ckpt_ids'], [9, 10, 11]) @@ -584,9 +594,11 @@ def test_with_runner(self, training_type): runner.train() best_ckpt_path = osp.join(cfg.work_dir, f'best_test_acc_{training_type}_5.pth') - best_ckpt = torch.load(best_ckpt_path) + best_ckpt = torch.load(best_ckpt_path, weights_only=False) - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_5.pth'), + weights_only=False) self.assertEqual(best_ckpt_path, ckpt['message_hub']['runtime_info']['best_ckpt']) @@ -613,11 +625,13 @@ def test_with_runner(self, training_type): runner.train() best_ckpt_path = osp.join(cfg.work_dir, f'best_test_acc_{training_type}_5.pth') - best_ckpt = torch.load(best_ckpt_path) + best_ckpt = torch.load(best_ckpt_path, weights_only=False) # if the current ckpt is the best, the interval will be ignored the # the ckpt will also 
be saved - ckpt = torch.load(osp.join(cfg.work_dir, f'{training_type}_5.pth')) + ckpt = torch.load( + osp.join(cfg.work_dir, f'{training_type}_5.pth'), + weights_only=False) self.assertEqual(best_ckpt_path, ckpt['message_hub']['runtime_info']['best_ckpt']) From 1e12eafd18e4d1f463c2eb4b0d5bee8e15010fec Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Sat, 25 Oct 2025 20:57:46 +0000 Subject: [PATCH 29/35] [Test] Fix unittest of file client --- tests/test_fileio/test_fileclient.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_fileio/test_fileclient.py b/tests/test_fileio/test_fileclient.py index 0cc87f3167..72eea97b88 100644 --- a/tests/test_fileio/test_fileclient.py +++ b/tests/test_fileio/test_fileclient.py @@ -16,8 +16,6 @@ from mmengine.utils import has_method sys.modules['ceph'] = MagicMock() -sys.modules['petrel_client'] = MagicMock() -sys.modules['petrel_client.client'] = MagicMock() sys.modules['mc'] = MagicMock() @@ -295,7 +293,9 @@ def test_disk_backend(self): osp.join('dir2', 'img.jpg'), 'text1.txt', 'text2.txt' } - @patch('petrel_client.client.Client', MockPetrelClient) + @patch.dict( + sys.modules, + {'petrel_client': MagicMock(**{'client.Client': MockPetrelClient})}) @pytest.mark.parametrize('backend,prefix', [('petrel', None), (None, 's3')]) def test_petrel_backend(self, backend, prefix): From e8a92b825874bd627a7da589cfc09c6326d81a9d Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Sat, 25 Oct 2025 21:46:52 +0000 Subject: [PATCH 30/35] [Fix] Fix unittest of empty cache hook --- tests/test_hooks/test_empty_cache_hook.py | 29 +++++++++++++---------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/tests/test_hooks/test_empty_cache_hook.py b/tests/test_hooks/test_empty_cache_hook.py index 7e722a0e77..02b0e9970e 100644 --- a/tests/test_hooks/test_empty_cache_hook.py +++ b/tests/test_hooks/test_empty_cache_hook.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy from unittest.mock import patch import pytest @@ -13,7 +14,7 @@ class TestEmptyCacheHook(RunnerTestCase): reason='cuda should be available') def test_with_runner(self): with patch('torch.cuda.empty_cache') as mock_empty_cache: - cfg = self.epoch_based_cfg + cfg = deepcopy(self.epoch_based_cfg) cfg.custom_hooks = [dict(type='EmptyCacheHook')] cfg.train_cfg.val_interval = 1e6 # disable validation during training # noqa: E501 runner = self.build_runner(cfg) @@ -24,12 +25,14 @@ def test_with_runner(self): # Call `torch.cuda.empty_cache` after each epoch: # runner.train: `max_epochs` times. + # runner.val: last epoch will always trigger validation (BC caused by `e258c848`) # noqa: E501 # runner.val: `1` time. # runner.test: `1` time. - target_called_times = runner.max_epochs + 2 + target_called_times = runner.max_epochs + 3 self.assertEqual(mock_empty_cache.call_count, target_called_times) - + # with patch('torch.cuda.empty_cache') as mock_empty_cache: + cfg = deepcopy(self.epoch_based_cfg) cfg.custom_hooks = [dict(type='EmptyCacheHook', before_epoch=True)] runner = self.build_runner(cfg) @@ -39,13 +42,15 @@ def test_with_runner(self): # Call `torch.cuda.empty_cache` after/before each epoch: # runner.train: `max_epochs*2` times. - # runner.val: `1*2` times. + # runner.val: (max_epochs + 1)*2 times, last epoch will always trigger validation (BC caused by `e258c848`) # noqa: E501 # runner.test: `1*2` times. 
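+            # e.g. assuming max_epochs == 2, the expectation below is
+            # 2*2 (train) + (2+1)*2 (val) + 1*2 (test) == 12 calls.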
-            target_called_times = runner.max_epochs * 2 + 4
+            target_called_times = runner.max_epochs * 2 + (runner.max_epochs +
+                                                           1) * 2 + 1 * 2
             self.assertEqual(mock_empty_cache.call_count, target_called_times)

         with patch('torch.cuda.empty_cache') as mock_empty_cache:
+            cfg = deepcopy(self.epoch_based_cfg)
             cfg.custom_hooks = [
                 dict(type='EmptyCacheHook', after_iter=True, before_epoch=True)
             ]
@@ -57,13 +62,13 @@ def test_with_runner(self):

             # Call `torch.cuda.empty_cache` after/before each epoch,
             # after each iteration:
-            # runner.train: `max_epochs*2 + len(dataloader)*max_epochs` times.  # noqa: E501
-            # runner.val: `1*2 + len(val_dataloader)` times.
-            # runner.test: `1*2 + len(val_dataloader)` times.
+            # runner.train: max_epochs * (2 + len(train_dataloader)) times.
+            # runner.val: (max_epochs(interval) + 1(last)) * (2 + len(val_dataloader)) times  # noqa: E501
+            # runner.test: 1 * (2 + len(test_dataloader)) times
             target_called_times = \
-                runner.max_epochs * 2 + 4 + \
-                len(runner.train_dataloader) * runner.max_epochs + \
-                len(runner.val_dataloader) + \
-                len(runner.test_dataloader)
+                runner.max_epochs * (2 + len(runner.train_dataloader)) + \
+                (runner.max_epochs + 1) * (2 + len(runner.val_dataloader)) + \
+                1 * (2 + len(runner.test_dataloader))
+
             self.assertEqual(mock_empty_cache.call_count, target_called_times)

From 74f989b8a3949329c3399b5c0f8cc2bbca9a764a Mon Sep 17 00:00:00 2001
From: HAOCHENYE <21724054@zju.edu.cn>
Date: Sat, 25 Oct 2025 21:55:43 +0000
Subject: [PATCH 31/35] [Test] Fix unittest of EMAHook

---
 tests/test_hooks/test_ema_hook.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py
index 398c7f1672..3c54d40ae2 100644
--- a/tests/test_hooks/test_ema_hook.py
+++ b/tests/test_hooks/test_ema_hook.py
@@ -230,7 +230,8 @@ def test_with_runner(self):
         self.assertTrue(
             isinstance(ema_hook.ema_model, ExponentialMovingAverage))

-        checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
+        checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'),
+                                weights_only=False)
         self.assertTrue('ema_state_dict' in checkpoint)
         self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)

@@ -245,7 +246,8 @@ def test_with_runner(self):
         runner.test()

         # Test load checkpoint without ema_state_dict
-        checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
+        checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'),
+                                weights_only=False)
         checkpoint.pop('ema_state_dict')
         torch.save(checkpoint,
                    osp.join(self.temp_dir.name, 'without_ema_state_dict.pth'))
@@ -274,7 +276,8 @@ def test_with_runner(self):
         runner = self.build_runner(cfg)
         runner.train()
         state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_4.pth'),
-                                map_location='cpu')
+                                map_location='cpu',
+                                weights_only=False)
         self.assertIn('ema_state_dict', state_dict)
         for k, v in state_dict['state_dict'].items():
             assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
@@ -287,12 +290,14 @@ def test_with_runner(self):
         runner = self.build_runner(cfg)
         runner.train()
         state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_4.pth'),
-                                map_location='cpu')
+                                map_location='cpu',
+                                weights_only=False)
         self.assertIn('ema_state_dict', state_dict)
         for k, v in state_dict['state_dict'].items():
             assert_allclose(v, state_dict['ema_state_dict']['module.'
+ k]) state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_5.pth'), - map_location='cpu') + map_location='cpu', + weights_only=False) self.assertIn('ema_state_dict', state_dict) def _test_swap_parameters(self, func_name, *args, **kwargs): From c2d973953cb9a173e99e11d7dc20250d72e4380c Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Sat, 25 Oct 2025 23:18:56 +0000 Subject: [PATCH 32/35] [Test] Fix ut of runner --- tests/test_runner/test_checkpoint.py | 1 + tests/test_runner/test_runner.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_runner/test_checkpoint.py b/tests/test_runner/test_checkpoint.py index 844dd4d80b..51efc7db2f 100644 --- a/tests/test_runner/test_checkpoint.py +++ b/tests/test_runner/test_checkpoint.py @@ -356,6 +356,7 @@ def load_from_abc(filename, map_location): assert loader.__name__ == 'load_from_abc' +@patch.dict(sys.modules, {'petrel_client': MagicMock()}) def test_save_checkpoint(tmp_path): model = Model() optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9) diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index b4801710c1..6d0f3da0f1 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -2271,7 +2271,7 @@ def test_checkpoint(self): self.assertTrue(osp.exists(path)) self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_4.pth'))) - ckpt = torch.load(path) + ckpt = torch.load(path, weights_only=False) self.assertEqual(ckpt['meta']['epoch'], 3) self.assertEqual(ckpt['meta']['iter'], 12) self.assertEqual(ckpt['meta']['experiment_name'], @@ -2445,7 +2445,7 @@ def test_checkpoint(self): self.assertTrue(osp.exists(path)) self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth'))) - ckpt = torch.load(path) + ckpt = torch.load(path, weights_only=False) self.assertEqual(ckpt['meta']['epoch'], 0) self.assertEqual(ckpt['meta']['iter'], 12) assert isinstance(ckpt['optimizer'], dict) @@ -2456,7 +2456,7 @@ def test_checkpoint(self): self.assertEqual(message_hub.get_info('iter'), 11) # 2.1.2 check class attribute _statistic_methods can be saved HistoryBuffer._statistics_methods.clear() - ckpt = torch.load(path) + ckpt = torch.load(path, weights_only=False) self.assertIn('min', HistoryBuffer._statistics_methods) # 2.2 test `load_checkpoint` From ddf6c837de05d0974d3ea8574b211a7a74bc3f66 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Sat, 25 Oct 2025 23:19:24 +0000 Subject: [PATCH 33/35] [Test] Fix test of strategy --- mmengine/_strategy/fsdp.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mmengine/_strategy/fsdp.py b/mmengine/_strategy/fsdp.py index b3fe48a6c0..1a1cec07c4 100644 --- a/mmengine/_strategy/fsdp.py +++ b/mmengine/_strategy/fsdp.py @@ -408,7 +408,9 @@ def load_optim_state_dict(self, state_dict: dict) -> None: ``optimizer.state_dict()`` """ optim_state_dict = FSDP.optim_state_dict_to_load( - state_dict, self.model, self.optim_wrapper.optimizer) + optim_state_dict=state_dict, + model=self.model, + optim=self.optim_wrapper.optimizer) self.optim_wrapper.load_state_dict(optim_state_dict) def _init_state_dict_cfg(self, state_dict_cfg: Union[str, dict]) -> None: @@ -539,7 +541,9 @@ def build_optim_wrapper( # Force to load the converted optim_state_dict in full mode. 
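+        # Keyword arguments below guard against the parameter reordering
+        # of `FSDP.optim_state_dict_to_load`: torch 2.0 took
+        # (optim_state_dict, model, optim), while torch >= 2.1 expects
+        # (model, optim, optim_state_dict); recheck against the installed
+        # torch release.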
        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
             optim_state_dict = FSDP.optim_state_dict_to_load(
-                optim_state_dict, model, new_optimizer)
+                optim_state_dict=optim_state_dict,
+                model=model,
+                optim=new_optimizer)
             new_optimizer.load_state_dict(optim_state_dict)
         optim_wrapper.optimizer = new_optimizer
         return optim_wrapper

From fed514af89d3b7e5d18279c360365ffb452e52f5 Mon Sep 17 00:00:00 2001
From: HAOCHENYE <21724054@zju.edu.cn>
Date: Sun, 26 Oct 2025 00:35:56 +0000
Subject: [PATCH 34/35] [Test] Fix sync buffer hook

---
 tests/test_hooks/test_sync_buffers_hook.py | 31 +++++++++++-----------
 1 file changed, 11 insertions(+), 20 deletions(-)

diff --git a/tests/test_hooks/test_sync_buffers_hook.py b/tests/test_hooks/test_sync_buffers_hook.py
index 8558f53985..71db44e38a 100644
--- a/tests/test_hooks/test_sync_buffers_hook.py
+++ b/tests/test_hooks/test_sync_buffers_hook.py
@@ -1,15 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os
 from unittest.mock import MagicMock

 import torch
 import torch.distributed as torch_dist
 import torch.nn as nn
+from torch.testing._internal.common_distributed import DistributedTestBase

 from mmengine.dist import all_gather
 from mmengine.hooks import SyncBuffersHook
 from mmengine.registry import MODELS
-from mmengine.testing._internal import MultiProcessTestCase
 from mmengine.testing.runner_test_case import RunnerTestCase, ToyModel


@@ -23,22 +22,14 @@ def __init__(self, data_preprocessor=None):
     def init_weights(self):
         for buffer in self.buffers():
             buffer.fill_(
-                torch.tensor(int(os.environ['RANK']), dtype=torch.float32))
+                torch.tensor(torch_dist.get_rank(), dtype=torch.float32))
         return super().init_weights()


-class TestSyncBuffersHook(MultiProcessTestCase, RunnerTestCase):
-
-    def setUp(self) -> None:
-        super().setUp()
-        self._spawn_processes()
-
-    def prepare_subprocess(self):
-        MODELS.register_module(module=ToyModuleWithNorm, force=True)
-        super(MultiProcessTestCase, self).setUp()
+class TestSyncBuffersHook(DistributedTestBase, RunnerTestCase):

     def test_sync_buffers_hook(self):
-        self.setup_dist_env()
+        self.create_pg('cuda')
         runner = MagicMock()
         runner.model = ToyModuleWithNorm()
         runner.model.init_weights()
@@ -53,9 +44,12 @@ def test_sync_buffers_hook(self):
         for buffer in runner.model.buffers():
             buffer1, buffer2 = all_gather(buffer)
             self.assertTrue(torch.allclose(buffer1, buffer2))
+        torch_dist.destroy_process_group()

     def test_with_runner(self):
-        self.setup_dist_env()
+        MODELS.register_module(module=ToyModuleWithNorm, force=True)
+        self.create_pg('cuda')
+        RunnerTestCase.setUp(self)
         cfg = self.epoch_based_cfg
         cfg.model = dict(type='ToyModuleWithNorm')
         cfg.launcher = 'pytorch'
         runner = self.build_runner(cfg)
@@ -67,9 +61,6 @@ def test_with_runner(self):
             buffer1, buffer2 = all_gather(buffer)
             self.assertTrue(torch.allclose(buffer1, buffer2))

-    def setup_dist_env(self):
-        super().setup_dist_env()
-        os.environ['RANK'] = str(self.rank)
-        torch_dist.init_process_group(backend='gloo',
-                                      rank=self.rank,
-                                      world_size=self.world_size)
+    @property
+    def world_size(self) -> int:
+        return 2

From 0a07aad4cc8de828c690174e5563752ffc3f818f Mon Sep 17 00:00:00 2001
From: HAOCHENYE <21724054@zju.edu.cn>
Date: Sun, 26 Oct 2025 00:48:59 +0000
Subject: [PATCH 35/35] [Lint] Fix lint

---
 mmengine/optim/optimizer/builder.py | 9 ++++-----
 mmengine/utils/package_utils.py     | 5 +++++
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/mmengine/optim/optimizer/builder.py b/mmengine/optim/optimizer/builder.py
index 65ac3f378d..edb36a3c56 100644
--- 
a/mmengine/optim/optimizer/builder.py
+++ b/mmengine/optim/optimizer/builder.py
@@ -118,7 +118,7 @@ def register_sophia_optimizers() -> List[str]:
     Returns:
         List[str]: A list of registered optimizers' name.
     """
-    optimizers: List[str] = []
+    optimizers = []  # type: ignore
     try:
         import Sophia
     except ImportError:
@@ -131,8 +131,7 @@ def register_sophia_optimizers() -> List[str]:
         try:
             OPTIMIZERS.register_module(module=_optim)
         except Exception as e:
-            warnings.warn(
-                f"Failed to import {_optim.__name__} for {e}")
+            warnings.warn(f'Failed to import {_optim} for {e}')
     return optimizers

@@ -165,7 +164,7 @@ def register_bitsandbytes_optimizers() -> List[str]:
         try:
             OPTIMIZERS.register_module(module=optim_cls, name=name)
         except Exception as e:
-            warnings.warn(f"Failed to import {optim_cls.__name__} for {e}")
+            warnings.warn(f'Failed to import {optim_cls.__name__} for {e}')
         dadaptation_optimizers.append(name)
     return dadaptation_optimizers

@@ -183,7 +182,7 @@ def register_transformers_optimizers():
     try:
         OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
     except Exception as e:
-        warnings.warn(f"Failed to import {Adafactor.__name__} for {e}")
+        warnings.warn(f'Failed to import Adafactor for {e}')
     transformer_optimizers.append('Adafactor')
     return transformer_optimizers

diff --git a/mmengine/utils/package_utils.py b/mmengine/utils/package_utils.py
index 65eef94bea..3d78194a5d 100644
--- a/mmengine/utils/package_utils.py
+++ b/mmengine/utils/package_utils.py
@@ -22,6 +22,11 @@ def is_installed(package: str) -> bool:
     # For Python 3.7, importlib_metadata backport can be used
     import importlib.util

+    import pkg_resources  # type: ignore
+
+    # refresh the pkg_resources
+    # more details at https://github.com/pypa/setuptools/issues/373
+    importlib.reload(pkg_resources)
     try:
         distribution(package)
         return True
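
For reference, the probe this final hunk touches boils down to the following
minimal sketch (stdlib importlib.metadata, Python >= 3.8). The except clause
is a reconstruction of the obvious failure path rather than a copy of the
upstream helper, and the pkg_resources reload above only matters when
packages are installed into the already-running interpreter, as the linked
setuptools issue explains:

    from importlib.metadata import PackageNotFoundError, distribution

    def is_installed(package: str) -> bool:
        """Return True if `package` has installed distribution metadata."""
        try:
            distribution(package)  # raises when no metadata is found
            return True
        except PackageNotFoundError:
            return False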