diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3890607
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,199 @@
+unitr_pretrain.pth
+logs/
+output/
+
+# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python
+# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+### Python Patch ###
+# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+poetry.toml
+
+# ruff
+.ruff_cache/
+
+# LSP config files
+pyrightconfig.json
+
+### VisualStudioCode ###
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+!.vscode/*.code-snippets
+
+# Local History for Visual Studio Code
+.history/
+
+# Built Visual Studio Code Extensions
+*.vsix
+
+### VisualStudioCode Patch ###
+# Ignore all local history of files
+.history
+.ionide
+
+# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python
\ No newline at end of file
diff --git a/README.md b/README.md
index 31e05c3..74e4d60 100644
--- a/README.md
+++ b/README.md
@@ -205,7 +205,7 @@ bash scripts/dist_train.sh 8 --cfg_file ./cfgs/nuscenes_models/unitr_map.yaml --
 
 ## add lss
 cd tools
-bash scripts/dist_train.sh 8 --cfg_file ./cfgs/nuscenes_models/unitr_map.yaml --sync_bn --eval_map --logger_iter_interval 1000
+bash scripts/dist_train.sh 8 --cfg_file ./cfgs/nuscenes_models/unitr_map+lss.yaml --sync_bn --eval_map --logger_iter_interval 1000
 ```
 
 ### Testing
diff --git a/pcdet/datasets/processor/data_processor.py b/pcdet/datasets/processor/data_processor.py
index 915c563..b904bc1 100644
--- a/pcdet/datasets/processor/data_processor.py
+++ b/pcdet/datasets/processor/data_processor.py
@@ -85,9 +85,15 @@ def mask_points_and_boxes_outside_range(self, data_dict=None, config=None):
         mask = common_utils.mask_points_by_range(data_dict['points'], self.point_cloud_range)
         data_dict['points'] = data_dict['points'][mask]
 
+        # apply the same range mask to the previous-frame points
+        if 'prev_points' in data_dict:
+            prev_mask = common_utils.mask_points_by_range(data_dict['prev_points'], self.point_cloud_range)
+            data_dict['prev_points'] = data_dict['prev_points'][prev_mask]
+
         if data_dict.get('gt_boxes', None) is not None and config.REMOVE_OUTSIDE_BOXES and self.training:
             mask = box_utils.mask_boxes_outside_range_numpy(
-                data_dict['gt_boxes'], self.point_cloud_range, min_num_corners=config.get('min_num_corners', 1),
+                data_dict['gt_boxes'], self.point_cloud_range,
+                min_num_corners=config.get('min_num_corners', 1),
                 use_center_to_filter=config.get('USE_CENTER_TO_FILTER', True)
             )
             data_dict['gt_boxes'] = data_dict['gt_boxes'][mask]
@@ -100,8 +106,13 @@ def shuffle_points(self, data_dict=None, config=None):
         if config.SHUFFLE_ENABLED[self.mode]:
             points = data_dict['points']
             shuffle_idx = np.random.permutation(points.shape[0])
-            points = points[shuffle_idx]
-            data_dict['points'] = points
+            data_dict['points'] = points[shuffle_idx]
+
+            # shuffle the previous-frame points the same way
+            if 'prev_points' in data_dict:
+                prev_points = data_dict['prev_points']
+                prev_shuffle_idx = np.random.permutation(prev_points.shape[0])
+                data_dict['prev_points'] = prev_points[prev_shuffle_idx]
 
         return data_dict
 
@@ -150,6 +161,12 @@ def transform_points_to_voxels(self, data_dict=None, config=None):
             )
 
         points = data_dict['points']
+
+        # merge the previous frame's point cloud with the current one before voxelization
+        if 'prev_points' in data_dict:
+            prev_points = data_dict['prev_points']
+            points = np.concatenate((prev_points, points), axis=0)
+
         voxel_output = self.voxel_generator.generate(points)
         voxels, coordinates, num_points = voxel_output
 
@@ -210,6 +227,33 @@ def sample_points(self, data_dict=None, config=None):
                 choice = np.concatenate((choice, extra_choice), axis=0)
             np.random.shuffle(choice)
         data_dict['points'] = points[choice]
+
+        # subsample the previous-frame points as well
+        if 'prev_points' in data_dict and len(data_dict['prev_points']) > 0:
+            prev_points = data_dict['prev_points']
+            if num_points < len(prev_points):
+                prev_pts_depth = np.linalg.norm(prev_points[:, 0:3], axis=1)
+                prev_pts_near_flag = prev_pts_depth < 40.0
+                prev_far_idxs_choice = np.where(prev_pts_near_flag == 0)[0]
+                prev_near_idxs = np.where(prev_pts_near_flag == 1)[0]
+                prev_choice = []
+                if num_points > len(prev_far_idxs_choice):
+                    prev_near_idxs_choice = np.random.choice(prev_near_idxs,
+                                                             num_points - len(prev_far_idxs_choice), replace=False)
+                    prev_choice = np.concatenate((prev_near_idxs_choice, prev_far_idxs_choice), axis=0) \
+                        if len(prev_far_idxs_choice) > 0 else prev_near_idxs_choice
+                else:
+                    prev_choice = np.arange(0, len(prev_points), dtype=np.int32)
+                    prev_choice = np.random.choice(prev_choice, num_points, replace=False)
+                np.random.shuffle(prev_choice)
+            else:
+                prev_choice = np.arange(0, len(prev_points), dtype=np.int32)
+                if num_points > len(prev_points):
+                    prev_extra_choice = np.random.choice(prev_choice, num_points - len(prev_points), replace=False)
+                    prev_choice = np.concatenate((prev_choice, prev_extra_choice), axis=0)
+                np.random.shuffle(prev_choice)
+            data_dict['prev_points'] = prev_points[prev_choice]
+
         return data_dict
 
     def calculate_grid_size(self, data_dict=None, config=None):
diff --git a/pcdet/models/mm_backbone/unitr.py b/pcdet/models/mm_backbone/unitr.py
index ad2f90c..a0ea84e 100644
--- a/pcdet/models/mm_backbone/unitr.py
+++ b/pcdet/models/mm_backbone/unitr.py
@@ -10,6 +10,8 @@ from pcdet.ops.ingroup_inds.ingroup_inds_op import ingroup_inds
 
 get_inner_win_inds_cuda = ingroup_inds
 
+from deformable_attention import DeformableAttention  # lucidrains' implementation
+
 class UniTR(nn.Module):
     '''
     UniTR: A Unified and Efficient Multi-Modal Transformer for Bird's-Eye-View Representation.
@@ -281,7 +283,7 @@ def _image2lidar_preprocess(self, batch_dict, multi_feat, multi_pos_embed_list):
             batch_dict)
         image2lidar_coords_bzyx = torch.cat(
             [batch_dict['patch_coords'][:, :1].clone(), image2lidar_coords_zyx], dim=1)
-        image2lidar_coords_bzyx[:, 0] = image2lidar_coords_bzyx[:, 0] // N
+        image2lidar_coords_bzyx[:, 0] = torch.div(image2lidar_coords_bzyx[:, 0], N, rounding_mode='floor')
         image2lidar_batch_dict = {}
         image2lidar_batch_dict['voxel_features'] = multi_feat.clone()
         image2lidar_batch_dict['voxel_coords'] = torch.cat(
@@ -310,7 +312,8 @@ def _lidar2image_preprocess(self, batch_dict, multi_feat, multi_pos_embed_list):
         lidar2image_coords_bzyx = torch.cat(
             [batch_dict['voxel_coords'][:, :1].clone(), lidar2image_coords_zyx], dim=1)
         multiview_coords = batch_dict['patch_coords'].clone()
-        multiview_coords[:, 0] = batch_dict['patch_coords'][:, 0] // N
+        multiview_coords[:, 0] = torch.div(batch_dict['patch_coords'][:, 0], N, rounding_mode='floor')
+        # multiview_coords[:, 0] = batch_dict['patch_coords'][:, 0] // N
         multiview_coords[:, 1] = batch_dict['patch_coords'][:, 0] % N
         multiview_coords[:, 2] += hw_shape[1]
         multiview_coords[:, 3] += hw_shape[0]
@@ -392,8 +395,13 @@ class UniTR_EncoderLayer(nn.Module):
     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                  activation="relu", batch_first=True, mlp_dropout=0, dout=None, layer_cfg=dict()):
         super().__init__()
-        self.win_attn = SetAttention(
-            d_model, nhead, dropout, dim_feedforward, activation, batch_first, mlp_dropout, layer_cfg)
+        if layer_cfg.get('deformable', False):
+            self.win_attn = SetDeformableAttention(
+                d_model, nhead, dropout, dim_feedforward, activation, batch_first, mlp_dropout, layer_cfg)
+        else:
+            self.win_attn = SetAttention(
+                d_model, nhead, dropout, dim_feedforward, activation, batch_first, mlp_dropout, layer_cfg)
+
         if dout is None:
             dout = d_model
         self.norm = nn.LayerNorm(dout)
@@ -407,6 +415,102 @@ def forward(self, src, set_voxel_inds, set_voxel_masks, pos=None, voxel_num=0):
         src = self.norm(src)
         return src
 
+
+class SetDeformableAttention(nn.Module):
+    def __init__(self, d_model, nhead, dropout, dim_feedforward=2048, activation="relu", batch_first=True, mlp_dropout=0, layer_cfg=dict()):
+        super().__init__()
+        self.nhead = nhead
+
+        # use deformable attention in place of the original nn.MultiheadAttention
+        self.self_attn = DeformableAttention(
+            dim=d_model,
+            heads=nhead,
+            dim_head=d_model // nhead,
+            dropout=dropout,
+            downsample_factor=1,
+            offset_kernel_size=3  # reduced kernel size
+        )
+
+        # Feedforward network
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(mlp_dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+        self.d_model = d_model
+        self.layer_cfg = layer_cfg
+
+        # Layer normalization
+        use_bn = layer_cfg.get('use_bn', False)
+        if not use_bn:
+            self.norm1 = nn.LayerNorm(d_model)
+            self.norm2 = nn.LayerNorm(d_model)
+
+        # Split FFN (for lidar and image data)
+        if layer_cfg.get('split_ffn', False):
+            self.lidar_linear1 = nn.Linear(d_model, dim_feedforward)
+            self.lidar_dropout = nn.Dropout(mlp_dropout)
+            self.lidar_linear2 = nn.Linear(dim_feedforward, d_model)
+            if not use_bn:
+                self.lidar_norm1 = nn.LayerNorm(d_model)
+                self.lidar_norm2 = nn.LayerNorm(d_model)
+
+        self.dropout1 = nn.Identity()
+        self.dropout2 = nn.Identity()
+
+        self.activation = nn.ReLU() if activation == "relu" else nn.GELU()
+
+    def forward(self, src, pos=None, key_padding_mask=None, voxel_inds=None, voxel_num=0):
+        set_features = src[voxel_inds]  # [win_num, 36, d_model]
+        if pos is not None:
+            set_pos = pos[voxel_inds]
+        else:
+            set_pos = None
+
+        # the deformable-attention query is the feature plus its positional encoding
+        query = (set_features + set_pos) if pos is not None else set_features
+
+        # reshape to 4D, since the conv-based offset network expects (N, C, H, W)
+        query = query.permute(0, 2, 1).unsqueeze(3)  # [win_num, d_model, set_size, 1]
+
+        # run deformable attention, then restore the [win_num, set_size, d_model]
+        # layout so rows of the flattened tensor align with voxel_inds
+        src2 = self.self_attn(query)
+        src2 = src2.squeeze(3).permute(0, 2, 1)
+
+        flatten_inds = voxel_inds.reshape(-1)
+        unique_flatten_inds, inverse = torch.unique(
+            flatten_inds, return_inverse=True)
+        perm = torch.arange(inverse.size(
+            0), dtype=inverse.dtype, device=inverse.device)
+        inverse, perm = inverse.flip([0]), perm.flip([0])
+        perm = inverse.new_empty(
+            unique_flatten_inds.size(0)).scatter_(0, inverse, perm)
+        src2 = src2.reshape(-1, self.d_model)[perm]
+
+        # split FFN: process lidar and image tokens separately
+        if self.layer_cfg.get('split_ffn', False):
+            src = src + self.dropout1(src2)
+            lidar_norm = self.lidar_norm1(src[:voxel_num])
+            image_norm = self.norm1(src[voxel_num:])
+            src = torch.cat([lidar_norm, image_norm], dim=0)
+
+            lidar_linear2 = self.lidar_linear2(self.lidar_dropout(self.activation(self.lidar_linear1(src[:voxel_num]))))
+            image_linear2 = self.linear2(self.dropout(self.activation(self.linear1(src[voxel_num:]))))
+            src2 = torch.cat([lidar_linear2, image_linear2], dim=0)
+
+            src = src + self.dropout2(src2)
+            lidar_norm2 = self.lidar_norm2(src[:voxel_num])
+            image_norm2 = self.norm2(src[voxel_num:])
+            src = torch.cat([lidar_norm2, image_norm2], dim=0)
+        else:
+            src = src + self.dropout1(src2)
+            src = self.norm1(src)
+            src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+            src = src + self.dropout2(src2)
+            src = self.norm2(src)
+
+        return src
+
+
 class SetAttention(nn.Module):
 
     def __init__(self, d_model, nhead, dropout, dim_feedforward=2048, activation="relu", batch_first=True, mlp_dropout=0, layer_cfg=dict()):
diff --git a/pcdet/models/model_utils/dsvt_utils.py b/pcdet/models/model_utils/dsvt_utils.py
index a364052..fd2967e 100644
--- a/pcdet/models/model_utils/dsvt_utils.py
+++ b/pcdet/models/model_utils/dsvt_utils.py
@@ -60,9 +60,12 @@ def get_window_coors(coors, sparse_shape, window_shape, do_shift, shift_list=Non
         shifted_coors_y = coors[:, 2] + shift_y
         shifted_coors_z = coors[:, 1] + shift_z
 
-    win_coors_x = shifted_coors_x // win_shape_x
-    win_coors_y = shifted_coors_y // win_shape_y
-    win_coors_z = shifted_coors_z // win_shape_z
+    # win_coors_x = shifted_coors_x // win_shape_x
+    # win_coors_y = shifted_coors_y // win_shape_y
+    # win_coors_z = shifted_coors_z // win_shape_z
+    win_coors_x = torch.div(shifted_coors_x, win_shape_x, rounding_mode='floor')
+    win_coors_y = torch.div(shifted_coors_y, win_shape_y, rounding_mode='floor')
+    win_coors_z = torch.div(shifted_coors_z, win_shape_z, rounding_mode='floor')
 
     if len(window_shape) == 2:
         assert (win_coors_z == 0).all()
@@ -97,9 +100,12 @@ def get_pooling_index(coors, sparse_shape, window_shape):
     coors_y = coors[:, 2]
     coors_z = coors[:, 1]
 
-    win_coors_x = coors_x // win_shape_x
-    win_coors_y = coors_y // win_shape_y
-    win_coors_z = coors_z // win_shape_z
+    # win_coors_x = coors_x // win_shape_x
+    # win_coors_y = coors_y // win_shape_y
+    # win_coors_z = coors_z // win_shape_z
+    win_coors_x = torch.div(coors_x, win_shape_x, rounding_mode='floor')
+    win_coors_y = torch.div(coors_y, win_shape_y, rounding_mode='floor')
+    win_coors_z = torch.div(coors_z, win_shape_z, rounding_mode='floor')
 
     batch_win_inds = coors[:, 0] * max_num_win_per_sample + \
         win_coors_x * max_num_win_y * max_num_win_z + \
diff --git a/pcdet/models/model_utils/sst_utils.py b/pcdet/models/model_utils/sst_utils.py
index 6674f4a..b03f497 100644
--- a/pcdet/models/model_utils/sst_utils.py
+++ b/pcdet/models/model_utils/sst_utils.py
@@ -230,9 +230,12 @@ def get_window_coors(coors, sparse_shape, window_shape, do_shift, shift_list=Non
         shifted_coors_y = coors[:, 2] + shift_y
         shifted_coors_z = coors[:, 1] + shift_z
 
-    win_coors_x = shifted_coors_x // win_shape_x
-    win_coors_y = shifted_coors_y // win_shape_y
-    win_coors_z = shifted_coors_z // win_shape_z
+    # win_coors_x = shifted_coors_x // win_shape_x
+    # win_coors_y = shifted_coors_y // win_shape_y
+    # win_coors_z = shifted_coors_z // win_shape_z
+    win_coors_x = torch.div(shifted_coors_x, win_shape_x, rounding_mode='floor')
+    win_coors_y = torch.div(shifted_coors_y, win_shape_y, rounding_mode='floor')
+    win_coors_z = torch.div(shifted_coors_z, win_shape_z, rounding_mode='floor')
 
     if len(window_shape) == 2:
         assert (win_coors_z == 0).all()
diff --git a/requirements.txt b/requirements.txt
index d9898c4..1abd5d6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,6 +13,6 @@ opencv-python
 pyquaternion
 spconv-cu113
 av2
-kornia
+kornia==0.6.9
 torch_scatter
-shapely==1.8.4
\ No newline at end of file
+shapely==1.8.4
diff --git a/sbatch.sh b/sbatch.sh
new file mode 100644
index 0000000..93d39ba
--- /dev/null
+++ b/sbatch.sh
@@ -0,0 +1,53 @@
+#!/usr/bin/bash
+
+#SBATCH -J UniTR-Train-temporal_deformable
+#SBATCH --gres=gpu:2
+#SBATCH --cpus-per-gpu=8
+#SBATCH --mem-per-gpu=32G
+#SBATCH -p batch_ugrad
+#SBATCH -t 3-0
+#SBATCH -o logs/slurm-%A.outs
+
+cat $0
+pwd
+which python
+hostname
+
+. /data/sw/spack/share/spack/setup-env.sh
+spack find
+spack load cuda@11.3.0
+nvcc -V
+
+# python -m pcdet.datasets.nuscenes.nuscenes_dataset --func create_nuscenes_infos \
+#     --cfg_file tools/cfgs/dataset_configs/nuscenes_dataset.yaml \
+#     --version v1.0-trainval \
+#     --with_cam \
+#     --with_cam_gt
+
+# multi-gpu training
+# note that we don't use image pretraining in BEV map segmentation
+cd tools
+
+## default
+# bash scripts/dist_train.sh 2 --cfg_file ./cfgs/nuscenes_models/unitr_map.yaml --sync_bn --eval_map --logger_iter_interval 1000 \
+#     --extra_tag default --ckpt ../output/cfgs/nuscenes_models/unitr_map/default/ckpt/latest_model.pth
+
+## temporal
+# bash scripts/dist_train.sh 2 --cfg_file ./cfgs/nuscenes_models/unitr_map.yaml --sync_bn --eval_map --logger_iter_interval 1000 \
+#     --extra_tag temporal --ckpt ../output/cfgs/nuscenes_models/unitr_map/temporal/ckpt/latest_model.pth
+
+## temporal deformable
+bash scripts/dist_train.sh 2 --cfg_file ./cfgs/nuscenes_models/unitr_map.yaml --sync_bn --eval_map --use_amp --logger_iter_interval 1000 --extra_tag temporal_deformable
+
+## add lss
+# bash scripts/dist_train.sh 2 --cfg_file ./cfgs/nuscenes_models/unitr_map+lss.yaml --sync_bn --eval_map --logger_iter_interval 1000
+
+# # multi-gpu testing
+# ## normal
+# bash scripts/dist_test.sh 2 --cfg_file ./cfgs/nuscenes_models/unitr_map.yaml --ckpt --eval_map
+
+# ## add LSS
+# bash scripts/dist_test.sh 2 --cfg_file ./cfgs/nuscenes_models/unitr_map+lss.yaml --ckpt --eval_map
+# # NOTE: evaluation results will not be logged in *.log, only printed in the terminal
+
+exit 0
diff --git a/tools/cfgs/nuscenes_models/unitr_map.yaml b/tools/cfgs/nuscenes_models/unitr_map.yaml
index cafae93..cd8224a 100644
--- a/tools/cfgs/nuscenes_models/unitr_map.yaml
+++ b/tools/cfgs/nuscenes_models/unitr_map.yaml
@@ -86,7 +86,7 @@ MODEL:
         hybrid_factor: [1, 1, 1] # x, y, z
         shifts_list: [[[0, 0, 0], [15, 15, 0]]]
         input_image: True
-
+
       LIDAR_INPUT_LAYER:
         sparse_shape: [256, 256, 1]
         d_model: [128]
@@ -102,7 +102,7 @@
         dropout: 0.0
         activation: gelu
         checkpoint_blocks: [0,1,2,3]
-        layer_cfg: {'use_bn': False, 'split_ffn': True, 'split_residual': True}
+        layer_cfg: {'use_bn': False, 'split_ffn': True, 'split_residual': True, 'deformable': True}
 
       # fuse backbone config
       FUSE_BACKBONE:
@@ -185,7 +185,7 @@ MODEL:
 
 
 OPTIMIZATION:
-    BATCH_SIZE_PER_GPU: 3
+    BATCH_SIZE_PER_GPU: 1
     NUM_EPOCHS: 20
 
     OPTIMIZER: adam_onecycle
diff --git a/tools/scripts/slurm_train_v2.sh b/tools/scripts/slurm_train_v2.sh
index d5ec0ed..643eef2 100644
--- a/tools/scripts/slurm_train_v2.sh
+++ b/tools/scripts/slurm_train_v2.sh
@@ -2,13 +2,12 @@
 set -x
 
-PARTITION=$1
-JOB_NAME=$2
-GPUS=$3
-PY_ARGS=${@:4}
+PARTITION=batch_ugrad
+JOB_NAME="UniTR-Train"
+GPUS=$1
+PY_ARGS=${@:2}
 
-GPUS_PER_NODE=${GPUS_PER_NODE:-8}
-CPUS_PER_TASK=${CPUS_PER_TASK:-40}
+GPUS_PER_NODE=$1
 SRUN_ARGS=${SRUN_ARGS:-""}
 
 while true
@@ -24,7 +23,8 @@ echo $PORT
 srun -p ${PARTITION} \
     --job-name=${JOB_NAME} \
     --gres=gpu:${GPUS_PER_NODE} \
-    --cpus-per-task=${CPUS_PER_TASK} \
+    --cpus-per-gpu=8 \
+    --mem-per-gpu=32G \
    --kill-on-bad-exit=1 \
     ${SRUN_ARGS} \
     python -m torch.distributed.launch --nproc_per_node=${GPUS_PER_NODE} train.py --launcher pytorch --tcp_port ${PORT} ${PY_ARGS}