From 6bc3cf11c0daf700b23a7aa10e15ee40241cca75 Mon Sep 17 00:00:00 2001 From: sibasmarak Date: Wed, 13 Mar 2024 03:23:38 -0400 Subject: [PATCH 1/9] checkpoint names with wandb run --- .gitignore | 2 +- conda_env.yaml | 360 ++++++++++++------ examples/images/classification/README.md | 20 +- .../configs/checkpoint/default.yaml | 2 +- examples/images/classification/train.py | 30 +- examples/images/classification/train_utils.py | 30 -- examples/images/segmentation/README.md | 20 +- examples/images/segmentation/train.py | 32 +- examples/images/segmentation/train_utils.py | 20 - 9 files changed, 322 insertions(+), 194 deletions(-) diff --git a/.gitignore b/.gitignore index da67e6c..c88eeae 100644 --- a/.gitignore +++ b/.gitignore @@ -55,6 +55,7 @@ tutorials/images/utils/sam_vit_h_4b8939.pth tutorials/images/understanding_continuous_canonicalization.ipynb tutorials/images/understanding_discrete_canonicalization_old.ipynb old_conda_env.yaml +conda_env_from_history.yaml # pyenv .python-version @@ -112,7 +113,6 @@ dmypy.json # Ignore scripts to run experiments in mila mila_scripts/ -escnn *__pycache__/ *_output/ wandb/ diff --git a/conda_env.yaml b/conda_env.yaml index 7765c29..63b71a8 100644 --- a/conda_env.yaml +++ b/conda_env.yaml @@ -1,130 +1,240 @@ name: equiadapt channels: + - pytorch + - nvidia + - conda-forge - defaults dependencies: - - bzip2 - - mpfr - - cuda-nvrtc==12.1.105=0 - - appdirs - - jpeg - - pytorch-cuda==12.1=ha16c6d3_5 - - libunistring - - sqlite - - blas==1.0=mkl - - lightning-utilities - - lerc - - pytorch - - fsspec - - gitdb - - ld_impl_linux-64 - - libtasn1 - - mkl - - torchtriton - - pycparser - - libcusolver==11.4.4.55=0 - - cuda-cupti==12.1.105=0 - - pathtools - - libcufile==1.8.1.2=0 - - python=3.10 - - tzdata - - typing_extensions - - sentry-sdk - - tqdm - - gitpython - - libidn2 - - numpy-base - - tbb - - libjpeg-turbo - - tk - - _libgcc_mutex - - openjpeg - - six - - cuda-libraries==12.1.0=0 - - colorama - - mkl_fft - - libwebp-base - - libnpp==12.0.2.50=0 - - libuuid - - smmap - - libffi - - cuda-opencl==12.3.101=0 - - charset-normalizer - - freetype - - cuda-nvtx==12.1.105=0 - - cuda-cudart==12.1.105=0 - - libgcc-ng - - packaging - - click - - mkl_random - - libnvjpeg==12.1.1.14=0 - - zstd - - intel-openmp - - libiconv - - ncurses - - libcurand==10.3.4.107=0 - - libcusparse==12.0.2.55=0 - - libstdcxx-ng - - openh264 - - zlib - - openssl - - lame - - lz4-c - - libcufft==11.0.2.4=0 - - llvm-openmp - - libtiff - - ffmpeg - - brotli-python - - xz - - libpng - - wandb - - mpc - - _openmp_mutex - - ca-certificates - - gnutls - - pytorch-mutex==1.0=cuda - - lcms2 - - nettle - - libcublas==12.1.0.26=0 - - libdeflate - - gmp - - libgomp - - cuda-runtime==12.1.0=0 - - readline - - kornia - - libnvjitlink==12.1.105=0 - - docker-pycreds - - yaml - - libprotobuf - - psutil - - ptyprocess - - executing - - jupyter_core - - prompt_toolkit - - python_abi - - debugpy - - traitlets - - ipython - - libsodium - - matplotlib-inline - - pygments - - stack_data - - jedi - - jupyter_client - - prompt-toolkit - - python-dateutil - - pure_eval - - pexpect - - exceptiongroup - - zeromq - - parso - - decorator - - wcwidth - - tornado - - platformdirs - - pyzmq - - comm - - nest-asyncio - - asttokens - - ipykernel - - lightning + - _libgcc_mutex=0.1 + - _openmp_mutex=5.1 + - appdirs=1.4.4 + - asttokens=2.0.5 + - blas=1.0 + - brotli-python=1.0.9 + - bzip2=1.0.8 + - ca-certificates=2023.12.12 + - certifi=2024.2.2 + - charset-normalizer=2.0.4 + - click=7.1.2 + - colorama=0.4.6 + - cuda-cudart=12.1.105 + - cuda-cupti=12.1.105 + - cuda-libraries=12.1.0 + - cuda-nvrtc=12.1.105 + - cuda-nvtx=12.1.105 + - cuda-opencl=12.3.101 + - cuda-runtime=12.1.0 + - decorator=5.1.1 + - docker-pycreds=0.4.0 + - executing=0.8.3 + - ffmpeg=4.3 + - filelock=3.13.1 + - freetype=2.12.1 + - fsspec=2024.2.0 + - gitdb=4.0.11 + - gitpython=3.1.41 + - gmp=6.2.1 + - gnutls=3.6.15 + - intel-openmp=2023.1.0 + - itsdangerous=2.0.1 + - jinja2=3.1.3 + - jpeg=9e + - jupyter_client=8.6.0 + - jupyter_core=5.5.0 + - kornia=0.7.0 + - lame=3.100 + - lcms2=2.12 + - ld_impl_linux-64=2.38 + - lerc=3.0 + - libcublas=12.1.0.26 + - libcufft=11.0.2.4 + - libcufile=1.8.1.2 + - libcurand=10.3.4.107 + - libcusolver=11.4.4.55 + - libcusparse=12.0.2.55 + - libdeflate=1.17 + - libffi=3.4.4 + - libgcc-ng=11.2.0 + - libgomp=11.2.0 + - libiconv=1.16 + - libidn2=2.3.4 + - libjpeg-turbo=2.0.0 + - libnpp=12.0.2.50 + - libnvjitlink=12.1.105 + - libnvjpeg=12.1.1.14 + - libpng=1.6.39 + - libprotobuf=3.20.3 + - libsodium=1.0.18 + - libstdcxx-ng=11.2.0 + - libtasn1=4.19.0 + - libtiff=4.5.1 + - libunistring=0.9.10 + - libuuid=1.41.5 + - libwebp-base=1.3.2 + - lightning-utilities=0.10.1 + - llvm-openmp=14.0.6 + - lz4-c=1.9.4 + - mkl=2023.1.0 + - mkl_fft=1.3.8 + - mkl_random=1.2.4 + - mpc=1.1.0 + - mpfr=4.0.2 + - mpmath=1.3.0 + - ncurses=6.4 + - nettle=3.7.3 + - numpy-base=1.26.3 + - openh264=2.1.1 + - openjpeg=2.4.0 + - openssl=3.0.13 + - packaging=23.2 + - parso=0.8.3 + - pathtools=0.1.2 + - pexpect=4.8.0 + - prompt_toolkit=3.0.43 + - ptyprocess=0.7.0 + - pure_eval=0.2.2 + - pycparser=2.21 + - python=3.10.13 + - python-dateutil=2.8.2 + - python-editor=1.0.4 + - python_abi=3.10 + - pytorch=2.2.0 + - pytorch-cuda=12.1 + - pytorch-mutex=1.0 + - readline=8.2 + - requests=2.31.0 + - sentry-sdk=1.40.1 + - six=1.16.0 + - smmap=5.0.0 + - sqlite=3.41.2 + - stack_data=0.2.0 + - sympy=1.12 + - tbb=2021.8.0 + - tk=8.6.12 + - torchtriton=2.2.0 + - tqdm=4.66.1 + - typing_extensions=4.9.0 + - tzdata=2023d + - wandb=0.16.3 + - wcwidth=0.2.5 + - xz=5.4.5 + - yaml=0.2.5 + - zeromq=4.3.5 + - zlib=1.2.13 + - zstd=1.5.5 + - pip: + - aiohttp==3.9.3 + - aiosignal==1.3.1 + - antlr4-python3-runtime==4.9.3 + - anyio==4.2.0 + - arrow==1.2.3 + - async-timeout==4.0.3 + - attrs==23.2.0 + - autograd==1.6.2 + - backoff==2.2.1 + - beautifulsoup4==4.12.2 + - blessed==1.20.0 + - boto3==1.29.1 + - botocore==1.32.1 + - brotli==1.0.9 + - cffi==1.16.0 + - comm==0.1.2 + - configparser==6.0.0 + - contourpy==1.2.0 + - croniter==1.3.7 + - cryptography==41.0.7 + - cycler==0.12.1 + - dateutils==0.6.12 + - debugpy==1.6.7 + - deepdiff==6.7.1 + - exceptiongroup==1.2.0 + - fastapi==0.103.0 + - fonttools==4.48.1 + - frozenlist==1.4.1 + - future==0.18.3 + - gmpy2==2.1.2 + - h11==0.14.0 + - hydra-core==1.3.2 + - idna==3.4 + - inquirer==3.1.4 + - ipykernel==6.28.0 + - ipython==8.20.0 + - isort==5.13.2 + - jedi==0.18.1 + - jmespath==1.0.1 + - joblib==1.3.2 + - jupyter-client==8.6.0 + - jupyter-core==5.5.0 + - kiwisolver==1.4.5 + - lightning==2.0.9.post0 + - lightning-cloud==0.5.57 + - markdown-it-py==2.2.0 + - markupsafe==2.1.3 + - matplotlib==3.8.2 + - matplotlib-inline==0.1.6 + - mdurl==0.1.0 + - mkl-fft==1.3.8 + - mkl-random==1.2.4 + - mkl-service==2.4.0 + - multidict==6.0.5 + - mypy==1.8.0 + - mypy-extensions==1.0.0 + - nest-asyncio==1.5.6 + - networkx==3.1 + - numpy==1.25.2 + - omegaconf==2.3.0 + - opencv-python==4.9.0.80 + - opencv-python-headless==4.9.0.80 + - ordered-set==4.1.0 + - orjson==3.9.10 + - pillow==10.2.0 + - pip==24.0 + - platformdirs==3.10.0 + - promise==2.3 + - prompt-toolkit==3.0.43 + - protobuf==3.20.3 + - psutil==5.9.0 + - pycocotools==2.0.7 + - pydantic==1.10.12 + - pygments==2.15.1 + - pyjwt==2.4.0 + - pymanopt==2.2.0 + - pyopenssl==23.2.0 + - pyparsing==3.1.1 + - pysocks==1.7.1 + - python-dotenv==1.0.1 + - python-multipart==0.0.6 + - pytorch-lightning==2.2.0.post0 + - pytz==2023.3.post1 + - pyyaml==6.0.1 + - pyzmq==25.1.2 + - readchar==4.0.5 + - rich==13.3.5 + - s3transfer==0.7.0 + - scipy==1.9.3 + - segment-anything==1.0 + - segmentation-mask-overlay==0.4.4 + - setproctitle==1.2.2 + - setuptools==69.1.1 + - shortuuid==1.0.11 + - sniffio==1.3.0 + - soupsieve==2.5 + - starlette==0.27.0 + - starsessions==1.3.0 + - subprocess32==3.5.4 + - tomli==2.0.1 + - torch==2.2.0 + - torchmetrics==1.3.1 + - torchvision==0.17.0 + - tornado==6.3.3 + - traitlets==5.7.1 + - triton==2.2.0 + - typing-extensions==4.9.0 + - urllib3==1.26.18 + - uvicorn==0.20.0 + - websocket-client==0.58.0 + - websockets==10.4 + - wheel==0.41.2 + - yarl==1.9.4 +prefix: equiadapt diff --git a/examples/images/classification/README.md b/examples/images/classification/README.md index 28608f3..1763886 100644 --- a/examples/images/classification/README.md +++ b/examples/images/classification/README.md @@ -11,7 +11,25 @@ python train.py canonicalization=group_equivariant experiment.training.loss.prio **Note**: You can also run the `train.py` as follows from the root directory of the project: ``` -python examples/images/classification/train.py canonicalization=group_equivariant dataset.dataset_name=rotated_mnist +python examples/images/classification/train.py canonicalization=group_equivariant +``` + +### For testing checkpoints +``` +python train.py experiment.run_mode=test dataset.dataset_name=stl10 \ +checkpoint.checkpoint_path=/path/of/checkpoint/dir checkpoint.checkpoint_name= + +``` + +**Note**: +The final checkpoint that will be loaded during evaluation as follows, hence ensure that the combination: +``` + model = ImageClassifierPipeline.load_from_checkpoint( + checkpoint_path=hyperparams.checkpoint.checkpoint_path + "/" + \ + hyperparams.checkpoint.checkpoint_name + ".ckpt", + hyperparams=hyperparams + ) + ``` ## Important Hyperparameters diff --git a/examples/images/classification/configs/checkpoint/default.yaml b/examples/images/classification/configs/checkpoint/default.yaml index 419f669..8dedf2a 100644 --- a/examples/images/classification/configs/checkpoint/default.yaml +++ b/examples/images/classification/configs/checkpoint/default.yaml @@ -1,3 +1,3 @@ checkpoint_path: ${oc.env:CHECKPOINT_PATH} # Path to save checkpoints -checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later +checkpoint_name: "" # Model checkpoint name, should be left empty for training and dynamically allocated later save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file diff --git a/examples/images/classification/train.py b/examples/images/classification/train.py index ea9eafd..74b4afc 100644 --- a/examples/images/classification/train.py +++ b/examples/images/classification/train.py @@ -12,12 +12,25 @@ def train_images(hyperparams: DictConfig): - hyperparams['canonicalization_type'] = hyperparams['canonicalization']['canonicalization_type'] - hyperparams['device'] = 'cuda' if torch.cuda.is_available() else 'cpu' - hyperparams['dataset']['data_path'] = hyperparams['dataset']['data_path'] + "/" + hyperparams['dataset']['dataset_name'] - hyperparams['checkpoint']['checkpoint_path'] = hyperparams['checkpoint']['checkpoint_path'] + "/" + \ - hyperparams['dataset']['dataset_name'] + "/" + hyperparams['canonicalization_type'] \ - + "/" + hyperparams['prediction']['prediction_network_architecture'] + + if hyperparams['experiment']['run_mode'] == "test": + assert len(hyperparams['checkpoint']['checkpoint_name']) > 0, "checkpoint_name must be provided for test mode" + + existing_ckpt_path = hyperparams['checkpoint']['checkpoint_path'] + "/" + hyperparams['checkpoint']['checkpoint_name'] + ".ckpt" + existing_ckpt = torch.load(existing_ckpt_path) + conf = OmegaConf.create(existing_ckpt['hyper_parameters']['hyperparams']) + + hyperparams['canonicalization_type'] = conf['canonicalization_type'] + hyperparams['canonicalization'] = conf['canonicalization'] + hyperparams['prediction'] = conf['prediction'] + + else: + hyperparams['canonicalization_type'] = hyperparams['canonicalization']['canonicalization_type'] + hyperparams['device'] = 'cuda' if torch.cuda.is_available() else 'cpu' + hyperparams['dataset']['data_path'] = hyperparams['dataset']['data_path'] + "/" + hyperparams['dataset']['dataset_name'] + hyperparams['checkpoint']['checkpoint_path'] = hyperparams['checkpoint']['checkpoint_path'] + "/" + \ + hyperparams['dataset']['dataset_name'] + "/" + hyperparams['canonicalization_type'] \ + + "/" + hyperparams['prediction']['prediction_network_architecture'] # set system environment variables for wandb if hyperparams['wandb']['use_wandb']: @@ -30,9 +43,12 @@ def train_images(hyperparams: DictConfig): os.environ["WANDB_CACHE_DIR"] = hyperparams['wandb']['wandb_cache_dir'] # initialize wandb - wandb.init(config=OmegaConf.to_container(hyperparams, resolve=True), entity=hyperparams['wandb']['wandb_entity'], project=hyperparams['wandb']['wandb_project'], dir=hyperparams['wandb']['wandb_dir']) + wandb_run = wandb.init(config=OmegaConf.to_container(hyperparams, resolve=True), entity=hyperparams['wandb']['wandb_entity'], project=hyperparams['wandb']['wandb_project'], dir=hyperparams['wandb']['wandb_dir']) wandb_logger = WandbLogger(project=hyperparams['wandb']['wandb_project'], log_model="all") + if not hyperparams['experiment']['run_mode'] == "test": + hyperparams['checkpoint']['checkpoint_name'] = wandb_run.id + "_" + wandb_run.name + "_" + wandb_run.sweep_id + "_" + wandb_run.group + # set seed pl.seed_everything(hyperparams.experiment.seed) diff --git a/examples/images/classification/train_utils.py b/examples/images/classification/train_utils.py index 3e4d1b4..49aa57b 100644 --- a/examples/images/classification/train_utils.py +++ b/examples/images/classification/train_utils.py @@ -19,9 +19,6 @@ def get_model_data_and_callbacks(hyperparams : DictConfig): # get image data image_data = get_image_data(hyperparams.dataset) - # checkpoint name - hyperparams.checkpoint.checkpoint_name = get_checkpoint_name(hyperparams) - # checkpoint callbacks callbacks = get_callbacks(hyperparams) @@ -83,33 +80,6 @@ def get_callbacks(hyperparams: DictConfig): mode="max") return [checkpoint_callback, early_stop_metric_callback] - -def get_recursive_hyperparams_identifier(hyperparams: DictConfig): - # get the identifier for the canonicalization network hyperparameters - # recursively go through the dictionary and get the values and concatenate them - identifier = "" - for key, value in hyperparams.items(): - if isinstance(value, DictConfig): - identifier += f"{get_recursive_hyperparams_identifier(value)}" - # special manipulation for the keys (to avoid exceeding OS limit for file names) - elif key not in ["canonicalization_type", "beta", "input_crop_ratio"]: - if key == "network_type": - identifier += f"_net_type_{value}_" - elif key == "out_vector_size": - identifier += f"_out_vec_{value}_" - elif key in ["kernel_size", "resize_shape", "group_type", "artifact_err_wt"]: - identifier += f"_{key.split('_')[0]}_{value}_" - elif key in ["num_layers", "out_channels", "num_rotations"]: - identifier += f"_{key.split('_')[-1]}_{value}_" - else: - identifier += f"_{key}_{value}_" - return identifier - -def get_checkpoint_name(hyperparams : DictConfig): - return f"{get_recursive_hyperparams_identifier(hyperparams.canonicalization)}".lstrip("_") + \ - f"_loss_wts_{int(hyperparams.experiment.training.loss.task_weight)}_{int(hyperparams.experiment.training.loss.prior_weight)}_{int(hyperparams.experiment.training.loss.group_contrast_weight)}" + \ - f"_lrs_{hyperparams.experiment.training.prediction_lr}_{hyperparams.experiment.training.canonicalization_lr}" + \ - f"_seed_{hyperparams.experiment.seed}" def get_image_data(dataset_hyperparams: DictConfig): diff --git a/examples/images/segmentation/README.md b/examples/images/segmentation/README.md index f9fe68a..0fbdbff 100644 --- a/examples/images/segmentation/README.md +++ b/examples/images/segmentation/README.md @@ -19,11 +19,29 @@ experiment.training.loss.prior_weight=0 python examples/images/segmentation/train.py canonicalization=group_equivariant ``` +### For testing checkpoints +``` +python train.py experiment.run_mode=test dataset.dataset_name=stl10 \ +checkpoint.checkpoint_path=/path/of/checkpoint/dir checkpoint.checkpoint_name= + +``` + +**Note**: +The final checkpoint that will be loaded during evaluation as follows, hence ensure that the combination of `checkpoint.checkpoint_path` and `checkpoint.checkpoint_name` is correct: +``` + model = ImageSegmentationPipeline.load_from_checkpoint( + checkpoint_path=hyperparams.checkpoint.checkpoint_path + "/" + \ + hyperparams.checkpoint.checkpoint_name + ".ckpt", + hyperparams=hyperparams + ) + +``` + ## Important Hyperparameters We use `hydra` and `OmegaConf` to setup experiments and parse configs. All the config files are available in [`/configs`](configs), along with the meaning of the hyperparameters in each yaml file. Below, we highlight some important details: - Choose canonicalization type from [`here`](configs/canonicalization) and set with `canonicalizaton=group_equivariant` - Canonicalization network architecture and relevant hyperparameters are detailed within canonicalization configs -- Dataset settings can be found [`here`](configs/dataset) and set with `dataset.dataset_name=coco` +- Dataset settings can be found [`here`](configs/dataset) and set with `dataset.dataset_name=coco` (Ensure the annotations are placed in `root-dir/dataset-name/annotations` and images in `root-dir/dataset-name/`) - Experiment settings can be found [`here`](configs/experiment) and set with `experiment.inference.num_rotations=8` - Prediction architecture settings can be found [`here`](configs/prediction) and set with `prediction.prediction_network_architecture=maskrcnn` - Wandb logging settings can can be found [`here`](configs/wandb) and set with `wandb.use_wandb=1` \ No newline at end of file diff --git a/examples/images/segmentation/train.py b/examples/images/segmentation/train.py index 30ed622..c672923 100644 --- a/examples/images/segmentation/train.py +++ b/examples/images/segmentation/train.py @@ -12,13 +12,26 @@ def train_images(hyperparams: DictConfig): - hyperparams['canonicalization_type'] = hyperparams['canonicalization']['canonicalization_type'] - hyperparams['device'] = 'cuda' if torch.cuda.is_available() else 'cpu' - hyperparams['dataset']['root_dir'] = hyperparams['dataset']['root_dir'] + "/" + hyperparams['dataset']['dataset_name'] - hyperparams['dataset']['ann_dir'] = hyperparams['dataset']['root_dir'] + "/" + "annotations" - hyperparams['checkpoint']['checkpoint_path'] = hyperparams['checkpoint']['checkpoint_path'] + "/" + \ - hyperparams['dataset']['dataset_name'] + "/" + hyperparams['canonicalization_type'] \ - + "/" + hyperparams['prediction']['prediction_network_architecture'] + + if hyperparams['experiment']['run_mode'] == "test": + assert len(hyperparams['checkpoint']['checkpoint_name']) > 0, "checkpoint_name must be provided for test mode" + + existing_ckpt_path = hyperparams['checkpoint']['checkpoint_path'] + "/" + hyperparams['checkpoint']['checkpoint_name'] + ".ckpt" + existing_ckpt = torch.load(existing_ckpt_path) + conf = OmegaConf.create(existing_ckpt['hyper_parameters']['hyperparams']) + + hyperparams['canonicalization_type'] = conf['canonicalization_type'] + hyperparams['canonicalization'] = conf['canonicalization'] + hyperparams['prediction'] = conf['prediction'] + + else: + hyperparams['canonicalization_type'] = hyperparams['canonicalization']['canonicalization_type'] + hyperparams['device'] = 'cuda' if torch.cuda.is_available() else 'cpu' + hyperparams['dataset']['root_dir'] = hyperparams['dataset']['root_dir'] + "/" + hyperparams['dataset']['dataset_name'] + hyperparams['dataset']['ann_dir'] = hyperparams['dataset']['root_dir'] + "/" + "annotations" + hyperparams['checkpoint']['checkpoint_path'] = hyperparams['checkpoint']['checkpoint_path'] + "/" + \ + hyperparams['dataset']['dataset_name'] + "/" + hyperparams['canonicalization_type'] \ + + "/" + hyperparams['prediction']['prediction_network_architecture'] # set system environment variables for wandb if hyperparams['wandb']['use_wandb']: @@ -31,9 +44,12 @@ def train_images(hyperparams: DictConfig): os.environ["WANDB_CACHE_DIR"] = hyperparams['wandb']['wandb_cache_dir'] # initialize wandb - wandb.init(config=OmegaConf.to_container(hyperparams, resolve=True), entity=hyperparams['wandb']['wandb_entity'], project=hyperparams['wandb']['wandb_project'], dir=hyperparams['wandb']['wandb_dir']) + wandb_run = wandb.init(config=OmegaConf.to_container(hyperparams, resolve=True), entity=hyperparams['wandb']['wandb_entity'], project=hyperparams['wandb']['wandb_project'], dir=hyperparams['wandb']['wandb_dir']) wandb_logger = WandbLogger(project=hyperparams['wandb']['wandb_project'], log_model="all") + if not hyperparams['experiment']['run_mode'] == "test": + hyperparams['checkpoint']['checkpoint_name'] = wandb_run.id + "_" + wandb_run.name + "_" + wandb_run.sweep_id + "_" + wandb_run.group + # set seed pl.seed_everything(hyperparams.experiment.seed) diff --git a/examples/images/segmentation/train_utils.py b/examples/images/segmentation/train_utils.py index ab26704..c10105c 100644 --- a/examples/images/segmentation/train_utils.py +++ b/examples/images/segmentation/train_utils.py @@ -13,9 +13,6 @@ def get_model_data_and_callbacks(hyperparams : DictConfig): # get image data image_data = get_image_data(hyperparams.dataset) - # checkpoint name - hyperparams.checkpoint.checkpoint_name = get_checkpoint_name(hyperparams) - # checkpoint callbacks callbacks = get_callbacks(hyperparams) @@ -80,23 +77,6 @@ def get_callbacks(hyperparams: DictConfig): return [checkpoint_callback, early_stop_metric_callback] -def get_recursive_hyperparams_identifier(hyperparams: DictConfig): - # get the identifier for the canonicalization network hyperparameters - # recursively go through the dictionary and get the values and concatenate them - identifier = "" - for key, value in hyperparams.items(): - if isinstance(value, DictConfig): - identifier += f"_{get_recursive_hyperparams_identifier(value)}_" - else: - identifier += f"_{key}_{value}_" - return identifier - -def get_checkpoint_name(hyperparams : DictConfig): - - return f"{get_recursive_hyperparams_identifier(hyperparams.canonicalization)}".lstrip("_") + \ - f"__epochs_{hyperparams.experiment.training.num_epochs}_" + f"__seed_{hyperparams.experiment.seed}" - - def get_image_data(dataset_hyperparams: DictConfig): dataset_classes = { From fbf51162f3bc920e74f91c317e88fa92039cbc10 Mon Sep 17 00:00:00 2001 From: sibasmarak Date: Wed, 13 Mar 2024 04:47:44 -0400 Subject: [PATCH 2/9] fixed isort and removed mypy.ini --- .mypy.ini | 6 ---- .pre-commit-config.yaml | 32 +++++++++---------- AUTHORS.md | 3 +- README.md | 6 +--- conda_env.yaml | 2 +- docs/conf.py | 2 +- equiadapt/__init__.py | 2 +- equiadapt/common/__init__.py | 9 ++---- equiadapt/common/basecanonicalization.py | 2 +- equiadapt/images/__init__.py | 7 ++-- equiadapt/images/canonicalization/__init__.py | 4 +-- .../canonicalization/continuous_group.py | 4 +-- .../images/canonicalization/discrete_group.py | 4 +-- .../canonicalization_networks/__init__.py | 13 ++++---- .../custom_equivariant_networks.py | 8 +++-- .../custom_group_equivariant_layers.py | 5 +-- .../custom_nonequivariant_networks.py | 3 +- equiadapt/images/utils.py | 3 +- equiadapt/pointcloud/__init__.py | 4 +-- .../pointcloud/canonicalization/__init__.py | 1 - .../canonicalization/continuous_group.py | 5 +-- .../canonicalization_networks/__init__.py | 7 ++-- .../equivariant_networks.py | 8 +++-- .../vector_neuron_layers.py | 3 +- examples/images/common/utils.py | 2 +- .../images/segmentation/prepare/coco_data.py | 3 +- examples/pointcloud/classification/model.py | 10 +++--- .../pointcloud/classification/model_utils.py | 3 +- examples/pointcloud/classification/prepare.py | 11 +++---- examples/pointcloud/classification/train.py | 12 +++---- .../pointcloud/classification/train_utils.py | 7 ++-- examples/pointcloud/common/networks.py | 2 +- examples/pointcloud/common/utils.py | 5 +-- .../pointcloud/part_segmentation/model.py | 10 +++--- .../part_segmentation/model_utils.py | 1 + .../pointcloud/part_segmentation/prepare.py | 10 +++--- .../pointcloud/part_segmentation/train.py | 12 +++---- .../part_segmentation/train_utils.py | 7 ++-- setup.cfg | 4 +-- .../canonicalization/test_continuous_group.py | 4 +-- 40 files changed, 118 insertions(+), 128 deletions(-) delete mode 100644 .mypy.ini diff --git a/.mypy.ini b/.mypy.ini deleted file mode 100644 index 50c4add..0000000 --- a/.mypy.ini +++ /dev/null @@ -1,6 +0,0 @@ -[mypy] -ignore_missing_imports = True - -[mypy-equiadapt.*] -ignore_missing_imports = False -disallow_untyped_defs = True diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a8cb9a8..ebf2494 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,7 +26,7 @@ repos: ## If you want to avoid flake8 errors due to unused vars or imports: - repo: https://github.com/PyCQA/autoflake - rev: v2.1.1 + rev: v2.3.1 hooks: - id: autoflake args: [ @@ -34,10 +34,10 @@ repos: --remove-all-unused-imports, ] -# - repo: https://github.com/PyCQA/isort -# rev: 5.13.2 -# hooks: -# - id: isort +- repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort - repo: https://github.com/psf/black rev: 24.2.0 @@ -52,11 +52,11 @@ repos: # - id: blacken-docs # additional_dependencies: [black] -- repo: https://github.com/PyCQA/flake8 - rev: 7.0.0 - hooks: - - id: flake8 - additional_dependencies: [flake8-docstrings] +# - repo: https://github.com/PyCQA/flake8 +# rev: 7.0.0 +# hooks: +# - id: flake8 +# additional_dependencies: [flake8-docstrings] ## Check for misspells in documentation files: # - repo: https://github.com/codespell-project/codespell @@ -64,9 +64,9 @@ repos: # hooks: # - id: codespell -# Check for type errors with mypy: -- repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.8.0' - hooks: - - id: mypy - args: [--disallow-untyped-defs, --ignore-missing-imports] +# # Check for type errors with mypy: +# - repo: https://github.com/pre-commit/mirrors-mypy +# rev: 'v1.9.0' +# hooks: +# - id: mypy +# args: [--disallow-untyped-defs, --ignore-missing-imports] diff --git a/AUTHORS.md b/AUTHORS.md index 17eddad..4817c36 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -1,3 +1,4 @@ # Contributors -* Arnab Mondal [arnab.mondal@mila.quebec](mailto:arnab.mondal@mila.quebec)s +* Arnab Mondal [arnab.mondal@mila.quebec](mailto:arnab.mondal@mila.quebec) +* Siba Smarak Panigrahi [siba-smarak.panigrahi@mila.quebec](mailto:siba-smarak.panigrahi@mila.quebec) diff --git a/README.md b/README.md index 1fa1100..9716bf2 100644 --- a/README.md +++ b/README.md @@ -180,10 +180,6 @@ If you find this library or the associated papers useful, please cite the follow } ``` -# Contributing - -This repository is a work in progress. We are actively working on improving the codebase and adding more features. If you are interested in contributing, please raise an issue or submit a pull request. We will be happy to help you get started. - # Contact For questions related to this code, please raise an issue and you can mail us at: @@ -195,7 +191,7 @@ For questions related to this code, please raise an issue and you can mail us at You can check out the [contributor's guide](CONTRIBUTING.md). -This project uses `pre-commit`_, you can install it before making any +This project uses `pre-commit`, you can install it before making any changes:: pip install pre-commit diff --git a/conda_env.yaml b/conda_env.yaml index 63b71a8..6491c94 100644 --- a/conda_env.yaml +++ b/conda_env.yaml @@ -219,6 +219,7 @@ dependencies: - shortuuid==1.0.11 - sniffio==1.3.0 - soupsieve==2.5 + - sphinx==7.2.6 - starlette==0.27.0 - starsessions==1.3.0 - subprocess32==3.5.4 @@ -237,4 +238,3 @@ dependencies: - wheel==0.41.2 - yarl==1.9.4 prefix: equiadapt - diff --git a/docs/conf.py b/docs/conf.py index 758fcb3..03fb545 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -8,8 +8,8 @@ # serve to show the default. import os -import sys import shutil +import sys # -- Path setup -------------------------------------------------------------- diff --git a/equiadapt/__init__.py b/equiadapt/__init__.py index 7e91706..a8a72ee 100644 --- a/equiadapt/__init__.py +++ b/equiadapt/__init__.py @@ -14,9 +14,9 @@ DiscreteGroupImageCanonicalization, ESCNNEquivariantNetwork, ESCNNSteerableNetwork, - ESCNNWRNEquivariantNetwork, ESCNNWideBasic, ESCNNWideBottleneck, + ESCNNWRNEquivariantNetwork, GroupEquivariantImageCanonicalization, OptimizedGroupEquivariantImageCanonicalization, OptimizedSteerableImageCanonicalization, diff --git a/equiadapt/common/__init__.py b/equiadapt/common/__init__.py index c123f0d..c1418e5 100644 --- a/equiadapt/common/__init__.py +++ b/equiadapt/common/__init__.py @@ -1,16 +1,11 @@ -from equiadapt.common import basecanonicalization -from equiadapt.common import utils - +from equiadapt.common import basecanonicalization, utils from equiadapt.common.basecanonicalization import ( BaseCanonicalization, ContinuousGroupCanonicalization, DiscreteGroupCanonicalization, IdentityCanonicalization, ) -from equiadapt.common.utils import ( - LieParameterization, - gram_schmidt, -) +from equiadapt.common.utils import LieParameterization, gram_schmidt __all__ = [ "BaseCanonicalization", diff --git a/equiadapt/common/basecanonicalization.py b/equiadapt/common/basecanonicalization.py index c860067..baf3f98 100644 --- a/equiadapt/common/basecanonicalization.py +++ b/equiadapt/common/basecanonicalization.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple, Union, Optional +from typing import Any, Dict, List, Optional, Tuple, Union import torch diff --git a/equiadapt/images/__init__.py b/equiadapt/images/__init__.py index 1fbc483..b670541 100644 --- a/equiadapt/images/__init__.py +++ b/equiadapt/images/__init__.py @@ -1,7 +1,4 @@ -from equiadapt.images import canonicalization -from equiadapt.images import canonicalization_networks -from equiadapt.images import utils - +from equiadapt.images import canonicalization, canonicalization_networks, utils from equiadapt.images.canonicalization import ( ContinuousGroupImageCanonicalization, DiscreteGroupImageCanonicalization, @@ -17,9 +14,9 @@ CustomEquivariantNetwork, ESCNNEquivariantNetwork, ESCNNSteerableNetwork, - ESCNNWRNEquivariantNetwork, ESCNNWideBasic, ESCNNWideBottleneck, + ESCNNWRNEquivariantNetwork, ResNet18Network, RotationEquivariantConv, RotationEquivariantConvLift, diff --git a/equiadapt/images/canonicalization/__init__.py b/equiadapt/images/canonicalization/__init__.py index 7e40a89..2348770 100644 --- a/equiadapt/images/canonicalization/__init__.py +++ b/equiadapt/images/canonicalization/__init__.py @@ -1,6 +1,4 @@ -from equiadapt.images.canonicalization import continuous_group -from equiadapt.images.canonicalization import discrete_group - +from equiadapt.images.canonicalization import continuous_group, discrete_group from equiadapt.images.canonicalization.continuous_group import ( ContinuousGroupImageCanonicalization, OptimizedSteerableImageCanonicalization, diff --git a/equiadapt/images/canonicalization/continuous_group.py b/equiadapt/images/canonicalization/continuous_group.py index 03d6816..bc43461 100644 --- a/equiadapt/images/canonicalization/continuous_group.py +++ b/equiadapt/images/canonicalization/continuous_group.py @@ -1,9 +1,9 @@ import math -from omegaconf import DictConfig -from typing import Optional, Dict, Any, Union, Tuple, List +from typing import Any, Dict, List, Optional, Tuple, Union import kornia as K import torch +from omegaconf import DictConfig from torch.nn import functional as F from torchvision import transforms diff --git a/equiadapt/images/canonicalization/discrete_group.py b/equiadapt/images/canonicalization/discrete_group.py index 0d2ef00..67489d7 100644 --- a/equiadapt/images/canonicalization/discrete_group.py +++ b/equiadapt/images/canonicalization/discrete_group.py @@ -1,9 +1,9 @@ import math -from omegaconf import DictConfig -from typing import List, Tuple, Union, Optional, Any +from typing import Any, List, Optional, Tuple, Union import kornia as K import torch +from omegaconf import DictConfig from torch.nn import functional as F from torchvision import transforms diff --git a/equiadapt/images/canonicalization_networks/__init__.py b/equiadapt/images/canonicalization_networks/__init__.py index 5e94346..13b33d0 100644 --- a/equiadapt/images/canonicalization_networks/__init__.py +++ b/equiadapt/images/canonicalization_networks/__init__.py @@ -1,8 +1,9 @@ -from equiadapt.images.canonicalization_networks import custom_equivariant_networks -from equiadapt.images.canonicalization_networks import custom_group_equivariant_layers -from equiadapt.images.canonicalization_networks import custom_nonequivariant_networks -from equiadapt.images.canonicalization_networks import escnn_networks - +from equiadapt.images.canonicalization_networks import ( + custom_equivariant_networks, + custom_group_equivariant_layers, + custom_nonequivariant_networks, + escnn_networks, +) from equiadapt.images.canonicalization_networks.custom_equivariant_networks import ( CustomEquivariantNetwork, ) @@ -19,9 +20,9 @@ from equiadapt.images.canonicalization_networks.escnn_networks import ( ESCNNEquivariantNetwork, ESCNNSteerableNetwork, - ESCNNWRNEquivariantNetwork, ESCNNWideBasic, ESCNNWideBottleneck, + ESCNNWRNEquivariantNetwork, ) __all__ = [ diff --git a/equiadapt/images/canonicalization_networks/custom_equivariant_networks.py b/equiadapt/images/canonicalization_networks/custom_equivariant_networks.py index a58bc3b..4af1a26 100644 --- a/equiadapt/images/canonicalization_networks/custom_equivariant_networks.py +++ b/equiadapt/images/canonicalization_networks/custom_equivariant_networks.py @@ -1,12 +1,14 @@ +from typing import Tuple + import torch import torch.nn as nn + from .custom_group_equivariant_layers import ( - RotationEquivariantConvLift, RotationEquivariantConv, - RotoReflectionEquivariantConvLift, + RotationEquivariantConvLift, RotoReflectionEquivariantConv, + RotoReflectionEquivariantConvLift, ) -from typing import Tuple class CustomEquivariantNetwork(nn.Module): diff --git a/equiadapt/images/canonicalization_networks/custom_group_equivariant_layers.py b/equiadapt/images/canonicalization_networks/custom_group_equivariant_layers.py index e806348..7d0798e 100644 --- a/equiadapt/images/canonicalization_networks/custom_group_equivariant_layers.py +++ b/equiadapt/images/canonicalization_networks/custom_group_equivariant_layers.py @@ -1,8 +1,9 @@ +import math + +import kornia as K import torch import torch.nn as nn import torch.nn.functional as F -import kornia as K -import math class RotationEquivariantConvLift(nn.Module): diff --git a/equiadapt/images/canonicalization_networks/custom_nonequivariant_networks.py b/equiadapt/images/canonicalization_networks/custom_nonequivariant_networks.py index 8cffb20..4fd2080 100644 --- a/equiadapt/images/canonicalization_networks/custom_nonequivariant_networks.py +++ b/equiadapt/images/canonicalization_networks/custom_nonequivariant_networks.py @@ -1,6 +1,7 @@ from typing import List -import torch, torchvision +import torch +import torchvision from torch import nn diff --git a/equiadapt/images/utils.py b/equiadapt/images/utils.py index 3683424..0994c5f 100644 --- a/equiadapt/images/utils.py +++ b/equiadapt/images/utils.py @@ -1,7 +1,8 @@ +from typing import List, Tuple + import kornia as K import torch from torchvision import transforms -from typing import List, Tuple def roll_by_gather(feature_map: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor: diff --git a/equiadapt/pointcloud/__init__.py b/equiadapt/pointcloud/__init__.py index 8d9dc3d..34b2988 100644 --- a/equiadapt/pointcloud/__init__.py +++ b/equiadapt/pointcloud/__init__.py @@ -1,6 +1,4 @@ -from equiadapt.pointcloud import canonicalization -from equiadapt.pointcloud import canonicalization_networks - +from equiadapt.pointcloud import canonicalization, canonicalization_networks from equiadapt.pointcloud.canonicalization import ( ContinuousGroupPointcloudCanonicalization, EquivariantPointcloudCanonicalization, diff --git a/equiadapt/pointcloud/canonicalization/__init__.py b/equiadapt/pointcloud/canonicalization/__init__.py index 4d2aa4d..e16eea1 100644 --- a/equiadapt/pointcloud/canonicalization/__init__.py +++ b/equiadapt/pointcloud/canonicalization/__init__.py @@ -1,5 +1,4 @@ from equiadapt.pointcloud.canonicalization import continuous_group - from equiadapt.pointcloud.canonicalization.continuous_group import ( ContinuousGroupPointcloudCanonicalization, EquivariantPointcloudCanonicalization, diff --git a/equiadapt/pointcloud/canonicalization/continuous_group.py b/equiadapt/pointcloud/canonicalization/continuous_group.py index fb6146d..2299bf9 100644 --- a/equiadapt/pointcloud/canonicalization/continuous_group.py +++ b/equiadapt/pointcloud/canonicalization/continuous_group.py @@ -1,12 +1,13 @@ # Note that for now we have only implemented canonicalizatin for rotation in the pointcloud setting. # This is meant to be a proof of concept and we are happy to receive contribution to extend this to other group actions. -from omegaconf import DictConfig +from typing import Any, List, Optional, Tuple, Union + import torch +from omegaconf import DictConfig from equiadapt.common.basecanonicalization import ContinuousGroupCanonicalization from equiadapt.common.utils import gram_schmidt -from typing import Any, List, Tuple, Union, Optional class ContinuousGroupPointcloudCanonicalization(ContinuousGroupCanonicalization): diff --git a/equiadapt/pointcloud/canonicalization_networks/__init__.py b/equiadapt/pointcloud/canonicalization_networks/__init__.py index e2cae21..cf888f1 100644 --- a/equiadapt/pointcloud/canonicalization_networks/__init__.py +++ b/equiadapt/pointcloud/canonicalization_networks/__init__.py @@ -1,6 +1,7 @@ -from equiadapt.pointcloud.canonicalization_networks import equivariant_networks -from equiadapt.pointcloud.canonicalization_networks import vector_neuron_layers - +from equiadapt.pointcloud.canonicalization_networks import ( + equivariant_networks, + vector_neuron_layers, +) from equiadapt.pointcloud.canonicalization_networks.equivariant_networks import ( VNSmall, get_graph_feature_cross, diff --git a/equiadapt/pointcloud/canonicalization_networks/equivariant_networks.py b/equiadapt/pointcloud/canonicalization_networks/equivariant_networks.py index 0b147b2..54cb0c0 100644 --- a/equiadapt/pointcloud/canonicalization_networks/equivariant_networks.py +++ b/equiadapt/pointcloud/canonicalization_networks/equivariant_networks.py @@ -1,13 +1,15 @@ +from typing import Optional + import torch import torch.nn as nn +from omegaconf import DictConfig + from equiadapt.pointcloud.canonicalization_networks.vector_neuron_layers import ( + VNBatchNorm, VNLinearLeakyReLU, VNMaxPool, - VNBatchNorm, mean_pool, ) -from omegaconf import DictConfig -from typing import Optional def knn(x: torch.Tensor, k: int) -> torch.Tensor: diff --git a/equiadapt/pointcloud/canonicalization_networks/vector_neuron_layers.py b/equiadapt/pointcloud/canonicalization_networks/vector_neuron_layers.py index 4561c01..f42878b 100644 --- a/equiadapt/pointcloud/canonicalization_networks/vector_neuron_layers.py +++ b/equiadapt/pointcloud/canonicalization_networks/vector_neuron_layers.py @@ -2,9 +2,10 @@ # Taken from Vector Neurons: A General Framework for SO(3)-Equivariant Networks (https://arxiv.org/abs/2104.12229) paper and # their codebase https://github.com/FlyingGiraffe/vnn +from typing import Tuple + import torch import torch.nn as nn -from typing import Tuple EPS = 1e-6 diff --git a/examples/images/common/utils.py b/examples/images/common/utils.py index d48781b..b41aec6 100644 --- a/examples/images/common/utils.py +++ b/examples/images/common/utils.py @@ -12,11 +12,11 @@ ) from equiadapt.images.canonicalization_networks import ( ConvNetwork, - ResNet18Network, CustomEquivariantNetwork, ESCNNEquivariantNetwork, ESCNNSteerableNetwork, ESCNNWRNEquivariantNetwork, + ResNet18Network, ) diff --git a/examples/images/segmentation/prepare/coco_data.py b/examples/images/segmentation/prepare/coco_data.py index 8470aff..5c40e60 100644 --- a/examples/images/segmentation/prepare/coco_data.py +++ b/examples/images/segmentation/prepare/coco_data.py @@ -1,7 +1,6 @@ import os import numpy as np -import examples.images.segmentation.prepare.vision_transforms as T import pytorch_lightning as pl import torch import torchvision.transforms as transforms @@ -10,6 +9,8 @@ from segment_anything.utils.transforms import ResizeLongestSide from torch.utils.data import DataLoader, Dataset +import examples.images.segmentation.prepare.vision_transforms as T + class ResizeAndPad: def __init__(self, target_size): diff --git a/examples/pointcloud/classification/model.py b/examples/pointcloud/classification/model.py index c5d4729..d91033a 100644 --- a/examples/pointcloud/classification/model.py +++ b/examples/pointcloud/classification/model.py @@ -1,11 +1,12 @@ +import numpy as np +import pytorch_lightning as pl +import sklearn.metrics as metrics import torch import torch.nn.functional as F -import pytorch_lightning as pl -from pytorch3d.transforms import RotateAxisAngle, Rotate, random_rotations +from model_utils import get_prediction_network +from pytorch3d.transforms import Rotate, RotateAxisAngle, random_rotations from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR -import numpy as np -from model_utils import get_prediction_network from examples.pointcloud.common.utils import ( get_canonicalization_network, get_canonicalizer, @@ -13,7 +14,6 @@ random_scale_point_cloud, random_shift_point_cloud, ) -import sklearn.metrics as metrics class PointcloudClassificationPipeline(pl.LightningModule): diff --git a/examples/pointcloud/classification/model_utils.py b/examples/pointcloud/classification/model_utils.py index 7e40a72..af7ca64 100644 --- a/examples/pointcloud/classification/model_utils.py +++ b/examples/pointcloud/classification/model_utils.py @@ -1,5 +1,6 @@ from omegaconf import DictConfig -from examples.pointcloud.common.networks import PointNet, DGCNN + +from examples.pointcloud.common.networks import DGCNN, PointNet def get_prediction_network( diff --git a/examples/pointcloud/classification/prepare.py b/examples/pointcloud/classification/prepare.py index 09ed697..f477486 100644 --- a/examples/pointcloud/classification/prepare.py +++ b/examples/pointcloud/classification/prepare.py @@ -1,12 +1,11 @@ -import numpy as np -import warnings +import glob import os -from torch.utils.data import Dataset, DataLoader - +import warnings -import pytorch_lightning as pl import h5py -import glob +import numpy as np +import pytorch_lightning as pl +from torch.utils.data import DataLoader, Dataset warnings.filterwarnings("ignore") diff --git a/examples/pointcloud/classification/train.py b/examples/pointcloud/classification/train.py index 4ecc762..2a6ac9b 100644 --- a/examples/pointcloud/classification/train.py +++ b/examples/pointcloud/classification/train.py @@ -1,22 +1,22 @@ import os + import hydra import omegaconf +import pytorch_lightning as pl import torch -import wandb - from omegaconf import DictConfig, OmegaConf -import pytorch_lightning as pl -from pytorch_lightning.loggers import WandbLogger - from prepare import ModelNetDataModule +from pytorch_lightning.loggers import WandbLogger from train_utils import ( - get_model_pipeline, get_callbacks, get_checkpoint_name, + get_model_pipeline, get_trainer, load_envs, ) +import wandb + def train_pointcloud(hyperparams: DictConfig): hyperparams["canonicalization_type"] = hyperparams["canonicalization"][ diff --git a/examples/pointcloud/classification/train_utils.py b/examples/pointcloud/classification/train_utils.py index 99045cf..27b9ead 100644 --- a/examples/pointcloud/classification/train_utils.py +++ b/examples/pointcloud/classification/train_utils.py @@ -1,11 +1,10 @@ -import dotenv -from omegaconf import DictConfig from typing import Optional +import dotenv import pytorch_lightning as pl -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping - from model import PointcloudClassificationPipeline +from omegaconf import DictConfig +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint def get_model_pipeline(hyperparams: DictConfig): diff --git a/examples/pointcloud/common/networks.py b/examples/pointcloud/common/networks.py index acb58b4..045bc18 100644 --- a/examples/pointcloud/common/networks.py +++ b/examples/pointcloud/common/networks.py @@ -1,8 +1,8 @@ -from omegaconf import DictConfig import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.init as init +from omegaconf import DictConfig def knn(x, k): diff --git a/examples/pointcloud/common/utils.py b/examples/pointcloud/common/utils.py index 7c47184..292028f 100644 --- a/examples/pointcloud/common/utils.py +++ b/examples/pointcloud/common/utils.py @@ -1,11 +1,12 @@ +import numpy as np import torch from omegaconf import DictConfig -from equiadapt.pointcloud.canonicalization_networks import VNSmall + from equiadapt.common.basecanonicalization import IdentityCanonicalization from equiadapt.pointcloud.canonicalization.continuous_group import ( EquivariantPointcloudCanonicalization, ) -import numpy as np +from equiadapt.pointcloud.canonicalization_networks import VNSmall def get_canonicalization_network( diff --git a/examples/pointcloud/part_segmentation/model.py b/examples/pointcloud/part_segmentation/model.py index 9907d12..5205b32 100644 --- a/examples/pointcloud/part_segmentation/model.py +++ b/examples/pointcloud/part_segmentation/model.py @@ -1,11 +1,12 @@ +import numpy as np +import pytorch_lightning as pl +import sklearn.metrics as metrics import torch import torch.nn.functional as F -import pytorch_lightning as pl -from pytorch3d.transforms import RotateAxisAngle, Rotate, random_rotations +from model_utils import get_prediction_network +from pytorch3d.transforms import Rotate, RotateAxisAngle, random_rotations from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR -import numpy as np -from model_utils import get_prediction_network from examples.pointcloud.common.utils import ( get_canonicalization_network, get_canonicalizer, @@ -13,7 +14,6 @@ random_scale_point_cloud, random_shift_point_cloud, ) -import sklearn.metrics as metrics class_choices = [ "airplane", diff --git a/examples/pointcloud/part_segmentation/model_utils.py b/examples/pointcloud/part_segmentation/model_utils.py index 40ec7e0..39931c3 100644 --- a/examples/pointcloud/part_segmentation/model_utils.py +++ b/examples/pointcloud/part_segmentation/model_utils.py @@ -1,4 +1,5 @@ from omegaconf import DictConfig + from examples.pointcloud.common.networks import DGCNN_partseg diff --git a/examples/pointcloud/part_segmentation/prepare.py b/examples/pointcloud/part_segmentation/prepare.py index db6a6dd..be23276 100644 --- a/examples/pointcloud/part_segmentation/prepare.py +++ b/examples/pointcloud/part_segmentation/prepare.py @@ -1,11 +1,11 @@ -import numpy as np +import glob +import os import warnings -from torch.utils.data import Dataset, DataLoader -import pytorch_lightning as pl import h5py -import os -import glob +import numpy as np +import pytorch_lightning as pl +from torch.utils.data import DataLoader, Dataset warnings.filterwarnings("ignore") diff --git a/examples/pointcloud/part_segmentation/train.py b/examples/pointcloud/part_segmentation/train.py index c9e67df..4baad98 100644 --- a/examples/pointcloud/part_segmentation/train.py +++ b/examples/pointcloud/part_segmentation/train.py @@ -1,22 +1,22 @@ import os + import hydra import omegaconf +import pytorch_lightning as pl import torch -import wandb - from omegaconf import DictConfig, OmegaConf -import pytorch_lightning as pl -from pytorch_lightning.loggers import WandbLogger - from prepare import ShapeNetDataModule +from pytorch_lightning.loggers import WandbLogger from train_utils import ( - get_model_pipeline, get_callbacks, get_checkpoint_name, + get_model_pipeline, get_trainer, load_envs, ) +import wandb + def train_pointcloud(hyperparams: DictConfig): hyperparams["canonicalization_type"] = hyperparams["canonicalization"][ diff --git a/examples/pointcloud/part_segmentation/train_utils.py b/examples/pointcloud/part_segmentation/train_utils.py index f3c360c..0e7dfff 100644 --- a/examples/pointcloud/part_segmentation/train_utils.py +++ b/examples/pointcloud/part_segmentation/train_utils.py @@ -1,11 +1,10 @@ -import dotenv -from omegaconf import DictConfig from typing import Optional +import dotenv import pytorch_lightning as pl -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping - from model import PointcloudClassificationPipeline +from omegaconf import DictConfig +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.strategies import DDPStrategy diff --git a/setup.cfg b/setup.cfg index 6f245d2..5d51b6f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -6,8 +6,8 @@ [metadata] name = equiadapt description = Library that provides metrics to assess representation quality -author = Arnab Mondal -author_email = arnab.mondal@mila.quebec +author = Arnab Mondal, Siba Smarak Panigrahi +author_email = arnab.mondal@mila.quebec, siba-smarak.panigrahi@mila.quebec license = MIT license_files = LICENSE long_description = file: README.md diff --git a/tests/images/canonicalization/test_continuous_group.py b/tests/images/canonicalization/test_continuous_group.py index 9f22429..2dc8a29 100644 --- a/tests/images/canonicalization/test_continuous_group.py +++ b/tests/images/canonicalization/test_continuous_group.py @@ -4,9 +4,9 @@ import torch from omegaconf import DictConfig -from equiadapt import ( +from equiadapt import ( # Update with your actual import path ContinuousGroupImageCanonicalization, -) # Update with your actual import path +) @pytest.fixture From c43c7db93a1b424d244553fb6b3c189767b138df Mon Sep 17 00:00:00 2001 From: sibasmarak Date: Wed, 13 Mar 2024 04:52:09 -0400 Subject: [PATCH 3/9] mypy commit --- .pre-commit-config.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ebf2494..6c86418 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -64,9 +64,9 @@ repos: # hooks: # - id: codespell -# # Check for type errors with mypy: -# - repo: https://github.com/pre-commit/mirrors-mypy -# rev: 'v1.9.0' -# hooks: -# - id: mypy -# args: [--disallow-untyped-defs, --ignore-missing-imports] +# Check for type errors with mypy: +- repo: https://github.com/pre-commit/mirrors-mypy + rev: 'v1.9.0' + hooks: + - id: mypy + args: [--disallow-untyped-defs, --ignore-missing-imports] From 0fca34589407b141eaf616fc117dc128ad7e73a6 Mon Sep 17 00:00:00 2001 From: sibasmarak Date: Wed, 13 Mar 2024 08:20:38 -0400 Subject: [PATCH 4/9] dummy docs --- .pre-commit-config.yaml | 10 +- equiadapt/common/basecanonicalization.py | 223 ++++++++++++++++++++--- 2 files changed, 203 insertions(+), 30 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6c86418..9be26db 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -52,11 +52,11 @@ repos: # - id: blacken-docs # additional_dependencies: [black] -# - repo: https://github.com/PyCQA/flake8 -# rev: 7.0.0 -# hooks: -# - id: flake8 -# additional_dependencies: [flake8-docstrings] +- repo: https://github.com/PyCQA/flake8 + rev: 7.0.0 + hooks: + - id: flake8 + additional_dependencies: [flake8-docstrings] ## Check for misspells in documentation files: # - repo: https://github.com/codespell-project/codespell diff --git a/equiadapt/common/basecanonicalization.py b/equiadapt/common/basecanonicalization.py index baf3f98..44c24e8 100644 --- a/equiadapt/common/basecanonicalization.py +++ b/equiadapt/common/basecanonicalization.py @@ -7,6 +7,14 @@ class BaseCanonicalization(torch.nn.Module): + """ + This is the base class for canonicalization. + + This class is used as a base for all canonicalization methods. + Subclasses should implement the canonicalize method to define the specific canonicalization process. + + """ + def __init__(self, canonicalization_network: torch.nn.Module): super().__init__() self.canonicalization_network = canonicalization_network @@ -16,8 +24,7 @@ def forward( self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]: """ - Forward method for the canonicalization which takes the input data and - returns the canonicalized version of the data + Forward method for the canonicalization which takes the input data and returns the canonicalized version of the data Args: x: input data @@ -27,8 +34,8 @@ def forward( Returns: canonicalized_x: canonicalized version of the input data - """ + """ # call the canonicalize method to obtain canonicalized version of the input data return self.canonicalize(x, targets, **kwargs) @@ -37,6 +44,7 @@ def canonicalize( ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]: """ This method takes an input data with, optionally, targets that need to be canonicalized + Args: x: input data targets: (optional) additional targets that need to be canonicalized, @@ -52,8 +60,7 @@ def invert_canonicalization( self, x_canonicalized_out: torch.Tensor, **kwargs: Any ) -> torch.Tensor: """ - This method takes the output of the canonicalized data - and returns the output for the original data orientation + This method takes the output of the canonicalized data and returns the output for the original data orientation Args: canonicalized_outputs: output of the prediction network for canonicalized data @@ -67,12 +74,47 @@ def invert_canonicalization( class IdentityCanonicalization(BaseCanonicalization): + """ + This class represents an identity canonicalization method. + + Identity canonicalization is a no-op; it doesn't change the input data. It's useful as a default or placeholder + when no other canonicalization method is specified. + + Attributes: + canonicalization_network (torch.nn.Module): The network used for canonicalization. Defaults to torch.nn.Identity. + + Methods: + __init__: Initializes the IdentityCanonicalization instance. + canonicalize: Canonicalizes the input data. In this class, it returns the input data unchanged. + """ + def __init__(self, canonicalization_network: torch.nn.Module = torch.nn.Identity()): + """ + Initializes the IdentityCanonicalization instance. + + Args: + canonicalization_network (torch.nn.Module, optional): The network used for canonicalization. Defaults to torch.nn.Identity. + """ super().__init__(canonicalization_network) def canonicalize( self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]: + """ + Canonicalize the input data. + + This method takes the input data and returns it unchanged, along with the targets if provided. + It's a no-op in the IdentityCanonicalization class. + + Args: + x: The input data. + targets: (Optional) Additional targets that need to be canonicalized. + **kwargs: Additional arguments. + + Returns: + A tuple containing the unchanged input data and targets if targets are provided, + otherwise just the unchanged input data. + """ if targets: return x, targets return x @@ -80,22 +122,78 @@ def canonicalize( def invert_canonicalization( self, x_canonicalized_out: torch.Tensor, **kwargs: Any ) -> torch.Tensor: + """ + Inverts the canonicalization. + + For the IdentityCanonicalization class, this is a no-op and returns the input unchanged. + + Args: + x_canonicalized_out (torch.Tensor): The canonicalized output. + **kwargs: Additional arguments. + + Returns: + torch.Tensor: The unchanged x_canonicalized_out. + """ return x_canonicalized_out def get_prior_regularization_loss(self) -> torch.Tensor: + """ + Gets the prior regularization loss. + + For the IdentityCanonicalization class, this is always 0. + + Returns: + torch.Tensor: A tensor containing the value 0. + """ return torch.tensor(0.0) def get_identity_metric(self) -> torch.Tensor: + """ + Gets the identity metric. + + For the IdentityCanonicalization class, this is always 1. + + Returns: + torch.Tensor: A tensor containing the value 1. + """ return torch.tensor(1.0) class DiscreteGroupCanonicalization(BaseCanonicalization): + """ + This class represents a discrete group canonicalization method. + + Discrete group canonicalization is a method that transforms the input data into a canonical form using a discrete group. This class is a subclass of the BaseCanonicalization class and overrides its methods to provide the functionality for discrete group canonicalization. + + Attributes: + canonicalization_network (torch.nn.Module): The network used for canonicalization. + beta (float): A parameter for the softmax function. Defaults to 1.0. + gradient_trick (str): The method used for backpropagation through the discrete operation. Defaults to "straight_through". + + Methods: + __init__: Initializes the DiscreteGroupCanonicalization instance. + groupactivations_to_groupelementonehot: Converts group activations to one-hot encoded group elements in a differentiable manner. + canonicalize: Canonicalizes the input data. + invert_canonicalization: Inverts the canonicalization. + get_prior_regularization_loss: Gets the prior regularization loss. + get_identity_metric: Gets the identity metric. + + """ + def __init__( self, canonicalization_network: torch.nn.Module, beta: float = 1.0, gradient_trick: str = "straight_through", ): + """ + Initializes the DiscreteGroupCanonicalization instance. + + Args: + canonicalization_network (torch.nn.Module): The network used for canonicalization. + beta (float, optional): A parameter for the softmax function. Defaults to 1.0. + gradient_trick (str, optional): The method used for backpropagation through the discrete operation. Defaults to "straight_through". + """ super().__init__(canonicalization_network) self.beta = beta self.gradient_trick = gradient_trick @@ -104,14 +202,13 @@ def groupactivations_to_groupelementonehot( self, group_activations: torch.Tensor ) -> torch.Tensor: """ - This method takes the activations for each group element as input and - returns the group element in a differentiable manner + Converts group activations to one-hot encoded group elements in a differentiable manner. Args: - group_activations: activations for each group element + group_activations (torch.Tensor): The activations for each group element. Returns: - group_element_onehot: one hot encoding of the group element + torch.Tensor: The one-hot encoding of the group elements. """ group_activations_one_hot = torch.nn.functional.one_hot( torch.argmax(group_activations, dim=-1), self.num_group @@ -142,10 +239,16 @@ def canonicalize( self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]: """ - This method takes an input data and - returns its canonicalized version and - a dictionary containing the information - about the canonicalization + Canonicalizes the input data. + + Args: + x (torch.Tensor): The input data. + targets (List, optional): Additional targets that need to be canonicalized. + **kwargs: Additional arguments. + + Returns: + Union[torch.Tensor, Tuple[torch.Tensor, List]]: The canonicalized input data and targets. + Simultaneously, it updates a dictionary containing the information about the canonicalization. """ raise NotImplementedError() @@ -153,12 +256,24 @@ def invert_canonicalization( self, x_canonicalized_out: torch.Tensor, **kwargs: Any ) -> torch.Tensor: """ - This method takes the output of the canonicalized data - and returns the output for the original data orientation + Inverts the canonicalization. + + Args: + x_canonicalized_out (torch.Tensor): The canonicalized output. + **kwargs: Additional arguments. + + Returns: + torch.Tensor: The output for the original data orientation. """ raise NotImplementedError() def get_prior_regularization_loss(self) -> torch.Tensor: + """ + Gets the prior regularization loss. + + Returns: + torch.Tensor: The prior regularization loss. + """ group_activations = self.canonicalization_info_dict["group_activations"] dataset_prior = torch.zeros((group_activations.shape[0],), dtype=torch.long).to( self.device @@ -166,12 +281,43 @@ def get_prior_regularization_loss(self) -> torch.Tensor: return torch.nn.CrossEntropyLoss()(group_activations, dataset_prior) def get_identity_metric(self) -> torch.Tensor: + """ + Gets the identity metric. + + Returns: + torch.Tensor: The identity metric. + """ group_activations = self.canonicalization_info_dict["group_activations"] return (group_activations.argmax(dim=-1) == 0).float().mean() class ContinuousGroupCanonicalization(BaseCanonicalization): + """ + This class represents a continuous group canonicalization method. + + Continuous group canonicalization is a method that transforms the input data into a canonical form using a continuous group. This class is a subclass of the BaseCanonicalization class and overrides its methods to provide the functionality for continuous group canonicalization. + + Attributes: + canonicalization_network (torch.nn.Module): The network used for canonicalization. + beta (float): A parameter for the softmax function. Defaults to 1.0. + + Methods: + __init__: Initializes the ContinuousGroupCanonicalization instance. + canonicalizationnetworkout_to_groupelement: Converts the output of the canonicalization network to a group element in a differentiable manner. + canonicalize: Canonicalizes the input data. + invert_canonicalization: Inverts the canonicalization. + get_prior_regularization_loss: Gets the prior regularization loss. + get_identity_metric: Gets the identity metric. + """ + def __init__(self, canonicalization_network: torch.nn.Module, beta: float = 1.0): + """ + Initializes the ContinuousGroupCanonicalization instance. + + Args: + canonicalization_network (torch.nn.Module): The network used for canonicalization. + beta (float, optional): A parameter for the softmax function. Defaults to 1.0. + """ super().__init__(canonicalization_network) self.beta = beta @@ -179,14 +325,13 @@ def canonicalizationnetworkout_to_groupelement( self, group_activations: torch.Tensor ) -> torch.Tensor: """ - This method takes the as input and - returns the group element in a differentiable manner + Converts the output of the canonicalization network to a group element in a differentiable manner. Args: - group_activations: activations for each group element + group_activations (torch.Tensor): The activations for each group element. Returns: - group_element: group element + torch.Tensor: The group element. """ raise NotImplementedError() @@ -194,10 +339,16 @@ def canonicalize( self, x: torch.Tensor, targets: Optional[List] = None, **kwargs: Any ) -> Union[torch.Tensor, Tuple[torch.Tensor, List]]: """ - This method takes an input data and - returns its canonicalized version and - a dictionary containing the information - about the canonicalization + Canonicalizes the input data. + + Args: + x (torch.Tensor): The input data. + targets (List, optional): Additional targets that need to be canonicalized. + **kwargs: Additional arguments. + + Returns: + Union[torch.Tensor, Tuple[torch.Tensor, List]]: The canonicalized input data and targets. + Simultaneously, it updates a dictionary containing the information about the canonicalization. """ raise NotImplementedError() @@ -205,12 +356,26 @@ def invert_canonicalization( self, x_canonicalized_out: torch.Tensor, **kwargs: Any ) -> torch.Tensor: """ - This method takes the output of the canonicalized data - and returns the output for the original data orientation + Inverts the canonicalization. + + Args: + x_canonicalized_out (torch.Tensor): The canonicalized output. + **kwargs: Additional arguments. + + Returns: + torch.Tensor: The output for the original data orientation. """ raise NotImplementedError() def get_prior_regularization_loss(self) -> torch.Tensor: + """ + Gets the prior regularization loss. + + The prior regularization loss is calculated as the mean squared error between the group element matrix representation and the identity matrix. + + Returns: + torch.Tensor: The prior regularization loss. + """ group_elements_rep = self.canonicalization_info_dict[ "group_element_matrix_representation" ] # shape: (batch_size, group_rep_dim, group_rep_dim) @@ -223,6 +388,14 @@ def get_prior_regularization_loss(self) -> torch.Tensor: return torch.nn.MSELoss()(group_elements_rep, dataset_prior) def get_identity_metric(self) -> torch.Tensor: + """ + Gets the identity metric. + + The identity metric is calculated as 1 minus the mean of the mean squared error between the group element matrix representation and the identity matrix. + + Returns: + torch.Tensor: The identity metric. + """ group_elements_rep = self.canonicalization_info_dict[ "group_element_matrix_representation" ] From 35ef33dc410dafa078fedc7596f3c17c5d9b64f7 Mon Sep 17 00:00:00 2001 From: sibasmarak Date: Wed, 13 Mar 2024 08:28:01 -0400 Subject: [PATCH 5/9] docs for equiadapt/common/utils.py --- equiadapt/common/utils.py | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/equiadapt/common/utils.py b/equiadapt/common/utils.py index 030ce53..c095690 100644 --- a/equiadapt/common/utils.py +++ b/equiadapt/common/utils.py @@ -1,5 +1,23 @@ import torch +""" +This module contains utility functions and classes that are used for operations on Lie groups. + +The module includes a function for the Gram-Schmidt process, which is used to orthogonalize a set of vectors. This function is implemented in a batch-wise manner, meaning it can process multiple sets of vectors at once. + +The module also includes a class for parameterizing Lie groups and their representations. +This class supports several types of Lie groups, including the special orthogonal group (SO(n)), +the special Euclidean group (SE(n)), the orthogonal group (O(n)), and the Euclidean group (E(n)). +The class provides methods for generating the basis of the Lie group, as well as for computing the +group representation given a set of parameters. + +Functions: + gram_schmidt(vectors: torch.Tensor) -> torch.Tensor + +Classes: + LieParameterization +""" + def gram_schmidt(vectors: torch.Tensor) -> torch.Tensor: """ @@ -26,7 +44,8 @@ def gram_schmidt(vectors: torch.Tensor) -> torch.Tensor: class LieParameterization(torch.nn.Module): - """A class for parameterizing Lie groups and their representations for a single block. + """ + A class for parameterizing Lie groups and their representations for a single block. Args: group_type (str): The type of Lie group (e.g., 'SOn', 'SEn', 'On', 'En'). @@ -43,7 +62,8 @@ def __init__(self, group_type: str, group_dim: int): self.group_dim = group_dim def get_son_bases(self) -> torch.Tensor: - """Generates the basis of the Lie group of SOn. + """ + Generates the basis of the Lie group of SOn. Returns: torch.Tensor: The son basis of shape (num_params, group_dim, group_dim). @@ -62,7 +82,8 @@ def get_son_bases(self) -> torch.Tensor: return son_bases def get_son_rep(self, params: torch.Tensor) -> torch.Tensor: - """Computes the representation for SOn group. + """ + Computes the representation for SOn group. Args: params (torch.Tensor): Input parameters of shape (batch_size, param_dim). @@ -104,7 +125,8 @@ def get_on_rep( return on_rep def get_sen_rep(self, params: torch.Tensor) -> torch.Tensor: - """Computes the representation for SEn group. + """ + Computes the representation for SEn group. Args: params (torch.Tensor): Input parameters of shape (batch_size, param_dim). @@ -129,14 +151,6 @@ def get_sen_rep(self, params: torch.Tensor) -> torch.Tensor: def get_en_rep( self, params: torch.Tensor, reflect_indicators: torch.Tensor ) -> torch.Tensor: - """Computes the representation for E(n) group. - - Args: - params (torch.Tensor): Input parameters of shape (batch_size, param_dim). - - Returns: - torch.Tensor: The representation of shape (batch_size, rep_dim, rep_dim). - """ """Computes the representation for E(n) group, including both rotations and translations. Args: From 39f086b0f12ad7ef86aa59ad44e790f7f3025cd5 Mon Sep 17 00:00:00 2001 From: "D. Benesch" <34680344+danibene@users.noreply.github.com> Date: Wed, 13 Mar 2024 08:36:32 -0400 Subject: [PATCH 6/9] Exclude examples from flake8 checks --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 5d51b6f..8e5f3e3 100644 --- a/setup.cfg +++ b/setup.cfg @@ -105,6 +105,7 @@ exclude = dist .eggs docs/conf.py + examples [pyscaffold] # PyScaffold's parameters when the project was created. From 723c2fca7a8da403ee83a5f0599da212e1dbe247 Mon Sep 17 00:00:00 2001 From: "D. Benesch" <34680344+danibene@users.noreply.github.com> Date: Wed, 13 Mar 2024 08:38:46 -0400 Subject: [PATCH 7/9] have mypy exclude examples --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9f81bba..d42880c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,4 +9,4 @@ build-backend = "setuptools.build_meta" version_scheme = "no-guess-dev" [tool.mypy] -exclude = ['docs'] +exclude = ['docs', 'examples'] From c271901a05b250b734b859b228aef1c68b0dddec Mon Sep 17 00:00:00 2001 From: "D. Benesch" <34680344+danibene@users.noreply.github.com> Date: Wed, 13 Mar 2024 08:53:41 -0400 Subject: [PATCH 8/9] check if specifying dir in different way fixes exclude --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 8e5f3e3..2f95df6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -105,7 +105,7 @@ exclude = dist .eggs docs/conf.py - examples + examples/* [pyscaffold] # PyScaffold's parameters when the project was created. From 4c15e03149e1c4b596fbc357ca645278427b0d29 Mon Sep 17 00:00:00 2001 From: "D. Benesch" <34680344+danibene@users.noreply.github.com> Date: Wed, 13 Mar 2024 08:59:35 -0400 Subject: [PATCH 9/9] try moving mypy exclude to setup.cfg --- setup.cfg | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.cfg b/setup.cfg index 2f95df6..a5e78e6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -107,6 +107,9 @@ exclude = docs/conf.py examples/* +[mypy] +exclude = docs|examples + [pyscaffold] # PyScaffold's parameters when the project was created. # This will be used when updating. Do not change!