diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..103964b --- /dev/null +++ b/.coveragerc @@ -0,0 +1,28 @@ +# .coveragerc to control coverage.py +[run] +branch = True +source = equiadapt +# omit = bad_file.py + +[paths] +source = + equiadapt/ + */site-packages/ + +[report] +# Regexes for lines to exclude from consideration +exclude_lines = + # Have to re-enable the standard pragma + pragma: no cover + + # Don't complain about missing debug-only code: + def __repr__ + if self\.debug + + # Don't complain if tests don't hit defensive assertion code: + raise AssertionError + raise NotImplementedError + + # Don't complain if non-runnable code isn't run: + if 0: + if __name__ == .__main__.: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..b95f5ab --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,107 @@ +# GitHub Actions configuration **EXAMPLE**, +# MODIFY IT ACCORDING TO YOUR NEEDS! +# Reference: https://docs.github.com/en/actions + +name: tests + +on: + push: + # Avoid using all the resources/limits available by checking only + # relevant branches and tags. Other branches can be checked via PRs. + branches: [main] + tags: ['v[0-9]*', '[0-9]+.[0-9]+*'] # Match tags that resemble a version + pull_request: # Run in every PR + workflow_dispatch: # Allow manually triggering the workflow + schedule: + # Run roughly every 15 days at 00:00 UTC + # (useful to check if updates on dependencies break the package) + - cron: '0 0 1,16 * *' + +permissions: + contents: read + +concurrency: + group: >- + ${{ github.workflow }}-${{ github.ref_type }}- + ${{ github.event.pull_request.number || github.sha }} + cancel-in-progress: true + +jobs: + prepare: + runs-on: ubuntu-latest + outputs: + wheel-distribution: ${{ steps.wheel-distribution.outputs.path }} + steps: + - uses: actions/checkout@v3 + with: {fetch-depth: 0} # deep clone for setuptools-scm + - uses: actions/setup-python@v4 + id: setup-python + with: {python-version: "3.10"} + - name: Run static analysis and format checkers + run: pipx run pre-commit run --all-files --show-diff-on-failure + - name: Build package distribution files + run: >- + pipx run --python '${{ steps.setup-python.outputs.python-path }}' + tox -e clean,build + - name: Record the path of wheel distribution + id: wheel-distribution + run: echo "path=$(ls dist/*.whl)" >> $GITHUB_OUTPUT + - name: Store the distribution files for use in other stages + # `tests` and `publish` will use the same pre-built distributions, + # so we make sure to release the exact same package that was tested + uses: actions/upload-artifact@v3 + with: + name: python-distribution-files + path: dist/ + retention-days: 1 + + test: + needs: prepare + strategy: + matrix: + python: + - "3.7" # oldest Python supported by PSF + - "3.10" # newest Python that is stable + platform: + - ubuntu-latest + # - macos-latest + # - windows-latest + runs-on: ${{ matrix.platform }} + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + id: setup-python + with: + python-version: ${{ matrix.python }} + - name: Retrieve pre-built distribution files + uses: actions/download-artifact@v3 + with: {name: python-distribution-files, path: dist/} + - name: Run tests + run: >- + pipx run --python '${{ steps.setup-python.outputs.python-path }}' + tox --installpkg '${{ needs.prepare.outputs.wheel-distribution }}' + -- -rFEx --durations 10 --color yes # pytest args + - name: Generate coverage report + run: pipx run coverage lcov -o coverage.lcov + 
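The `coverage.lcov` file produced by the last step of the `test` job is not consumed anywhere in the workflow as shown here. If coverage reporting is desired, an extra step could upload it, for example to Coveralls; the step below is an illustrative sketch only, and the action name, version, and inputs are assumptions to verify rather than part of this diff:

```yaml
      # Hypothetical follow-up step for the `test` job (not part of this PR)
      - name: Upload partial coverage report
        uses: coverallsapp/github-action@v2
        with:
          path-to-lcov: coverage.lcov
          github-token: ${{ secrets.GITHUB_TOKEN }}
          flag-name: ${{ matrix.python }} - ${{ matrix.platform }}
          parallel: true
```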
+ publish: + if: ${{ github.event_name == 'push' && contains(github.ref, 'refs/tags/') }} + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: {python-version: "3.10"} + - name: Retrieve pre-built distribution files + uses: actions/download-artifact@v3 + with: {name: python-distribution-files, path: dist/} + - name: Publish Package + env: + # TODO: Set your PYPI_TOKEN as a secret using GitHub UI + # - https://pypi.org/help/#apitoken + # - https://docs.github.com/en/actions/security-guides/encrypted-secrets + TWINE_REPOSITORY: pypi + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} + run: pipx run tox -e publish diff --git a/.gitignore b/.gitignore index f3f797a..f202880 100644 --- a/.gitignore +++ b/.gitignore @@ -104,11 +104,15 @@ dmypy.json # Ignore .vscode in all folders **/.vscode -# Ignore scripts to run experiments in mila +# Ignore scripts to run experiments in mila mila_scripts/ escnn *__pycache__/ rotmnist_sweep_output/ cifar10_sweep_output/ -wandb/ \ No newline at end of file +wandb/ + +# Docs +docs/api/ +docs/_build/ diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 0000000..777924e --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,3 @@ +[settings] +profile = black +known_first_party = equiadapt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..d1bbf82 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,74 @@ +exclude: '^docs/conf.py' + +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: trailing-whitespace + - id: check-added-large-files + - id: check-ast +# - id: check-json + - id: check-merge-conflict + - id: check-xml + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: requirements-txt-fixer + - id: mixed-line-ending + args: ['--fix=auto'] # replace 'auto' with 'lf' to enforce Linux/Mac line endings or 'crlf' for Windows + +## If you want to automatically "modernize" your Python code: +# - repo: https://github.com/asottile/pyupgrade +# rev: v3.7.0 +# hooks: +# - id: pyupgrade +# args: ['--py37-plus'] + +## If you want to avoid flake8 errors due to unused vars or imports: +# - repo: https://github.com/PyCQA/autoflake +# rev: v2.1.1 +# hooks: +# - id: autoflake +# args: [ +# --in-place, +# --remove-all-unused-imports, +# --remove-unused-variables, +# ] + +- repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + +- repo: https://github.com/psf/black + rev: 24.2.0 + hooks: + - id: black + language_version: python3 + +## If like to embrace black styles even in the docs: +# - repo: https://github.com/asottile/blacken-docs +# rev: v1.13.0 +# hooks: +# - id: blacken-docs +# additional_dependencies: [black] + +# - repo: https://github.com/PyCQA/flake8 +# rev: 7.0.0 +# hooks: +# - id: flake8 + ## You can add flake8 plugins via `additional_dependencies`: + # additional_dependencies: [flake8-bugbear] + +## Check for misspells in documentation files: +# - repo: https://github.com/codespell-project/codespell +# rev: v2.2.5 +# hooks: +# - id: codespell + +## Check for type errors with mypy: +# - repo: https://github.com/pre-commit/mirrors-mypy +# rev: 'v1.8.0' +# hooks: +# - id: mypy +# args: [--disallow-untyped-defs, --ignore-missing-imports] diff --git a/.readthedocs.yml b/.readthedocs.yml new file mode 100644 index 0000000..a2bcab3 --- /dev/null +++ b/.readthedocs.yml @@ -0,0 +1,27 @@ +# Read the Docs configuration file +# See 
https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs/conf.py + +# Build documentation with MkDocs +#mkdocs: +# configuration: mkdocs.yml + +# Optionally build your docs in additional formats such as PDF +formats: + - pdf + +build: + os: ubuntu-22.04 + tools: + python: "3.11" + +python: + install: + - requirements: docs/requirements.txt + - {path: ., method: pip} diff --git a/AUTHORS.md b/AUTHORS.md new file mode 100644 index 0000000..17eddad --- /dev/null +++ b/AUTHORS.md @@ -0,0 +1,3 @@ +# Contributors + +* Arnab Mondal [arnab.mondal@mila.quebec](mailto:arnab.mondal@mila.quebec)s diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..dd0325b --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,3 @@ +# Changelog + +## Version 0.1 (development) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..93ebb5e --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,371 @@ +```{todo} THIS IS SUPPOSED TO BE AN EXAMPLE. MODIFY IT ACCORDING TO YOUR NEEDS! + + The document assumes you are using a source repository service that promotes a + contribution model similar to [GitHub's fork and pull request workflow]. + While this is true for the majority of services (like GitHub, GitLab, + BitBucket), it might not be the case for private repositories (e.g., when + using Gerrit). + + Also notice that the code examples might refer to GitHub URLs or the text + might use GitHub specific terminology (e.g., *Pull Request* instead of *Merge + Request*). + + Please make sure to check the document having these assumptions in mind + and update things accordingly. +``` + +```{todo} Provide the correct links/replacements at the bottom of the document. +``` + +```{todo} You might want to have a look on [PyScaffold's contributor's guide], + + especially if your project is open source. The text should be very similar to + this template, but there are a few extra contents that you might decide to + also include, like mentioning labels of your issue tracker or automated + releases. +``` + +# Contributing + +Welcome to `equiadapt` contributor's guide. + +This document focuses on getting any potential contributor familiarized with +the development processes, but [other kinds of contributions] are also appreciated. + +If you are new to using [git] or have never collaborated in a project previously, +please have a look at [contribution-guide.org]. Other resources are also +listed in the excellent [guide created by FreeCodeCamp] [^contrib1]. + +Please notice, all users and contributors are expected to be **open, +considerate, reasonable, and respectful**. When in doubt, +[Python Software Foundation's Code of Conduct] is a good reference in terms of +behavior guidelines. + +## Issue Reports + +If you experience bugs or general issues with `equiadapt`, please have a look +on the [issue tracker]. +If you don't see anything useful there, please feel free to fire an issue report. + +:::{tip} +Please don't forget to include the closed issues in your search. +Sometimes a solution was already reported, and the problem is considered +**solved**. +::: + +New issue reports should include information about your programming environment +(e.g., operating system, Python version) and steps to reproduce the problem. +Please try also to simplify the reproduction steps to a very minimal example +that still illustrates the problem you are facing. 
By removing other factors, +you help us to identify the root cause of the issue. + +## Documentation Improvements + +You can help improve `equiadapt` docs by making them more readable and coherent, or +by adding missing information and correcting mistakes. + +`equiadapt` documentation uses [Sphinx] as its main documentation compiler. +This means that the docs are kept in the same repository as the project code, and +that any documentation update is done in the same way was a code contribution. + +```{todo} Don't forget to mention which markup language you are using. + + e.g., [reStructuredText] or [CommonMark] with [MyST] extensions. +``` + +```{todo} If your project is hosted on GitHub, you can also mention the following tip: + + :::{tip} + Please notice that the [GitHub web interface] provides a quick way of + propose changes in `equiadapt`'s files. While this mechanism can + be tricky for normal code contributions, it works perfectly fine for + contributing to the docs, and can be quite handy. + + If you are interested in trying this method out, please navigate to + the `docs` folder in the source [repository], find which file you + would like to propose changes and click in the little pencil icon at the + top, to open [GitHub's code editor]. Once you finish editing the file, + please write a message in the form at the bottom of the page describing + which changes have you made and what are the motivations behind them and + submit your proposal. + ::: +``` + +When working on documentation changes in your local machine, you can +compile them using [tox] : + +``` +tox -e docs +``` + +and use Python's built-in web server for a preview in your web browser +(`http://localhost:8000`): + +``` +python3 -m http.server --directory 'docs/_build/html' +``` + +## Code Contributions + +```{todo} Please include a reference or explanation about the internals of the project. + + An architecture description, design principles or at least a summary of the + main concepts will make it easy for potential contributors to get started + quickly. +``` + +### Submit an issue + +Before you work on any non-trivial code contribution it's best to first create +a report in the [issue tracker] to start a discussion on the subject. +This often provides additional considerations and avoids unnecessary work. + +### Create an environment + +Before you start coding, we recommend creating an isolated [virtual environment] +to avoid any problems with your installed Python packages. +This can easily be done via either [virtualenv]: + +``` +virtualenv +source /bin/activate +``` + +or [Miniconda]: + +``` +conda create -n equiadapt python=3 six virtualenv pytest pytest-cov +conda activate equiadapt +``` + +### Clone the repository + +1. Create an user account on GitHub if you do not already have one. + +2. Fork the project [repository]: click on the *Fork* button near the top of the + page. This creates a copy of the code under your account on GitHub. + +3. Clone this copy to your local disk: + + ``` + git clone git@github.com:YourLogin/equiadapt.git + cd equiadapt + ``` + +4. You should run: + + ``` + pip install -U pip setuptools -e . + ``` + + to be able to import the package under development in the Python REPL. + + ```{todo} if you are not using pre-commit, please remove the following item: + ``` + +5. Install [pre-commit]: + + ``` + pip install pre-commit + pre-commit install + ``` + + `equiadapt` comes with a lot of hooks configured to automatically help the + developer to check the code being written. 
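Once the hooks are installed, it can be useful to run them over the entire repository once, which mirrors what the CI workflow does in its `prepare` job (the exact set of hooks is whatever `.pre-commit-config.yaml` configures):

```
pre-commit run --all-files
```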
+ +### Implement your changes + +1. Create a branch to hold your changes: + + ``` + git checkout -b my-feature + ``` + + and start making changes. Never work on the main branch! + +2. Start your work on this branch. Don't forget to add [docstrings] to new + functions, modules and classes, especially if they are part of public APIs. + +3. Add yourself to the list of contributors in `AUTHORS.rst`. + +4. When you’re done editing, do: + + ``` + git add + git commit + ``` + + to record your changes in [git]. + + ```{todo} if you are not using pre-commit, please remove the following item: + ``` + + Please make sure to see the validation messages from [pre-commit] and fix + any eventual issues. + This should automatically use [flake8]/[black] to check/fix the code style + in a way that is compatible with the project. + + :::{important} + Don't forget to add unit tests and documentation in case your + contribution adds an additional feature and is not just a bugfix. + + Moreover, writing a [descriptive commit message] is highly recommended. + In case of doubt, you can check the commit history with: + + ``` + git log --graph --decorate --pretty=oneline --abbrev-commit --all + ``` + + to look for recurring communication patterns. + ::: + +5. Please check that your changes don't break any unit tests with: + + ``` + tox + ``` + + (after having installed [tox] with `pip install tox` or `pipx`). + + You can also use [tox] to run several other pre-configured tasks in the + repository. Try `tox -av` to see a list of the available checks. + +### Submit your contribution + +1. If everything works fine, push your local branch to the remote server with: + + ``` + git push -u origin my-feature + ``` + +2. Go to the web page of your fork and click "Create pull request" + to send your changes for review. + + ```{todo} if you are using GitHub, you can uncomment the following paragraph + + Find more detailed information in [creating a PR]. You might also want to open + the PR as a draft first and mark it as ready for review after the feedbacks + from the continuous integration (CI) system or any required fixes. + + ``` + +### Troubleshooting + +The following tips can be used when facing problems to build or test the +package: + +1. Make sure to fetch all the tags from the upstream [repository]. + The command `git describe --abbrev=0 --tags` should return the version you + are expecting. If you are trying to run CI scripts in a fork repository, + make sure to push all the tags. + You can also try to remove all the egg files or the complete egg folder, i.e., + `.eggs`, as well as the `*.egg-info` folders in the `src` folder or + potentially in the root of your project. + +2. Sometimes [tox] misses out when new dependencies are added, especially to + `setup.cfg` and `docs/requirements.txt`. If you find any problems with + missing dependencies when running a command with [tox], try to recreate the + `tox` environment using the `-r` flag. For example, instead of: + + ``` + tox -e docs + ``` + + Try running: + + ``` + tox -r -e docs + ``` + +3. Make sure to have a reliable [tox] installation that uses the correct + Python version (e.g., 3.7+). When in doubt you can run: + + ``` + tox --version + # OR + which tox + ``` + + If you have trouble and are seeing weird errors upon running [tox], you can + also try to create a dedicated [virtual environment] with a [tox] binary + freshly installed. For example: + + ``` + virtualenv .venv + source .venv/bin/activate + .venv/bin/pip install tox + .venv/bin/tox -e all + ``` + +4. 
[Pytest can drop you] in an interactive session in the case an error occurs. + In order to do that you need to pass a `--pdb` option (for example by + running `tox -- -k --pdb`). + You can also setup breakpoints manually instead of using the `--pdb` option. + +## Maintainer tasks + +### Releases + +```{todo} This section assumes you are using PyPI to publicly release your package. + + If instead you are using a different/private package index, please update + the instructions accordingly. +``` + +If you are part of the group of maintainers and have correct user permissions +on [PyPI], the following steps can be used to release a new version for +`equiadapt`: + +1. Make sure all unit tests are successful. +2. Tag the current commit on the main branch with a release tag, e.g., `v1.2.3`. +3. Push the new tag to the upstream [repository], + e.g., `git push upstream v1.2.3` +4. Clean up the `dist` and `build` folders with `tox -e clean` + (or `rm -rf dist build`) + to avoid confusion with old builds and Sphinx docs. +5. Run `tox -e build` and check that the files in `dist` have + the correct version (no `.dirty` or [git] hash) according to the [git] tag. + Also check the sizes of the distributions, if they are too big (e.g., > + 500KB), unwanted clutter may have been accidentally included. +6. Run `tox -e publish -- --repository pypi` and check that everything was + uploaded to [PyPI] correctly. + +[^contrib1]: Even though, these resources focus on open source projects and + communities, the general ideas behind collaborating with other developers + to collectively create software are general and can be applied to all sorts + of environments, including private companies and proprietary code bases. + + +[black]: https://pypi.org/project/black/ +[commonmark]: https://commonmark.org/ +[contribution-guide.org]: http://www.contribution-guide.org/ +[creating a pr]: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request +[descriptive commit message]: https://chris.beams.io/posts/git-commit +[docstrings]: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html +[first-contributions tutorial]: https://github.com/firstcontributions/first-contributions +[flake8]: https://flake8.pycqa.org/en/stable/ +[git]: https://git-scm.com +[github web interface]: https://docs.github.com/en/github/managing-files-in-a-repository/managing-files-on-github/editing-files-in-your-repository +[github's code editor]: https://docs.github.com/en/github/managing-files-in-a-repository/managing-files-on-github/editing-files-in-your-repository +[github's fork and pull request workflow]: https://guides.github.com/activities/forking/ +[guide created by freecodecamp]: https://github.com/freecodecamp/how-to-contribute-to-open-source +[miniconda]: https://docs.conda.io/en/latest/miniconda.html +[myst]: https://myst-parser.readthedocs.io/en/latest/syntax/syntax.html +[other kinds of contributions]: https://opensource.guide/how-to-contribute +[pre-commit]: https://pre-commit.com/ +[pypi]: https://pypi.org/ +[pyscaffold's contributor's guide]: https://pyscaffold.org/en/stable/contributing.html +[pytest can drop you]: https://docs.pytest.org/en/stable/usage.html#dropping-to-pdb-python-debugger-at-the-start-of-a-test +[python software foundation's code of conduct]: https://www.python.org/psf/conduct/ +[restructuredtext]: https://www.sphinx-doc.org/en/master/usage/restructuredtext/ +[sphinx]: https://www.sphinx-doc.org/en/master/ +[tox]: 
https://tox.readthedocs.io/en/stable/ +[virtual environment]: https://realpython.com/python-virtual-environments-a-primer/ +[virtualenv]: https://virtualenv.pypa.io/en/stable/ + + +```{todo} Please review and change the following definitions: +``` + +[repository]: https://github.com//equiadapt +[issue tracker]: https://github.com//equiadapt/issues diff --git a/README.md b/README.md index fad71b8..a9fe130 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Library to make any existing neural network architecture equivariant # Setup instructions -### Setup Conda environment +### Setup Conda environment To create a conda environment with the necessary packages: @@ -14,7 +14,7 @@ pip install -e . #### For Python 3.10 -Currently, everything works in Python 3.8. +Currently, everything works in Python 3.8. But to use Python 3.10, you need to remove `py3nj` from the `escnn` package requirements and install `escnn` from GitHub manually. ``` @@ -23,7 +23,7 @@ cd escnn (and go to setup.py and remove py3nj from the requirements) pip install -e . ``` -### Setup Hydra +### Setup Hydra - Create a `.env` file in the root of the project with the following content: ``` export HYDRA_JOBS="/path/to/your/hydra/jobs/directory" @@ -31,11 +31,11 @@ pip install -e . export WANDB_CACHE_DIR="/path/to/your/wandb/cache/directory" export DATA_PATH="/path/to/your/data/directory" export CHECKPOINT_PATH="/path/to/your/checkpoint/directory" - ``` + ``` # Running Instructions -For image classification: [here](/examples/images/classification/README.md) +For image classification: [here](/examples/images/classification/README.md) For image segmentation: [here](/examples/images/segmentation/README.md) @@ -43,7 +43,7 @@ For image segmentation: [here](/examples/images/segmentation/README.md) For more insights on this library refer to our original paper on the idea: [Equivariance with Learned Canonicalization Function (ICML 2023)](https://proceedings.mlr.press/v202/kaba23a.html) and how to extend it to make any existing large pre-trained model equivariant: [Equivariant Adaptation of Large Pretrained Models (NeurIPS 2023)](https://proceedings.neurips.cc/paper_files/paper/2023/hash/9d5856318032ef3630cb580f4e24f823-Abstract-Conference.html). -To learn more about this from a blog, check out: [How to make your foundation model equivariant](https://mila.quebec/en/article/how-to-make-your-foundation-model-equivariant/) +To learn more about this from a blog, check out: [How to make your foundation model equivariant](https://mila.quebec/en/article/how-to-make-your-foundation-model-equivariant/) # Citation If you find this library or the associated papers useful, please cite: @@ -70,7 +70,22 @@ If you find this library or the associated papers useful, please cite: # Contact -For questions related to this code, you can mail us at: +For questions related to this code, you can mail us at: ```arnab.mondal@mila.quebec``` ```siba-smarak.panigrahi@mila.quebec``` -```kabaseko@mila.quebec``` \ No newline at end of file +```kabaseko@mila.quebec``` + +# Contributing + +You can check out the [contributor's guide](CONTRIBUTING.md). 
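Before opening a pull request, the checks that CI runs can largely be reproduced locally. A short sketch, assuming `tox` is installed (e.g. via `pip install tox`) as described in the contributor's guide:

```
tox            # run the unit tests
tox -e docs    # build the documentation
```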
+ +This project uses `pre-commit`_, you can install it before making any +changes:: + + pip install pre-commit + cd equiadapt + pre-commit install + +It is a good idea to update the hooks to the latest version:: + + pre-commit autoupdate diff --git a/conda_env.yaml b/conda_env.yaml index 176e904..4dcc926 100644 --- a/conda_env.yaml +++ b/conda_env.yaml @@ -155,4 +155,3 @@ dependencies: - urllib3==1.26.18 - wheel==0.41.2 - yarl==1.9.4 - diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..31655dd --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,29 @@ +# Makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build +AUTODOCDIR = api + +# User-friendly check for sphinx-build +ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $?), 1) +$(error "The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from https://sphinx-doc.org/") +endif + +.PHONY: help clean Makefile + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +clean: + rm -rf $(BUILDDIR)/* $(AUTODOCDIR) + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_static/.gitignore b/docs/_static/.gitignore new file mode 100644 index 0000000..3c96363 --- /dev/null +++ b/docs/_static/.gitignore @@ -0,0 +1 @@ +# Empty directory diff --git a/docs/authors.md b/docs/authors.md new file mode 100644 index 0000000..ced47d0 --- /dev/null +++ b/docs/authors.md @@ -0,0 +1,4 @@ +```{include} ../AUTHORS.md +:relative-docs: docs/ +:relative-images: +``` diff --git a/docs/changelog.md b/docs/changelog.md new file mode 100644 index 0000000..6e2f0fb --- /dev/null +++ b/docs/changelog.md @@ -0,0 +1,4 @@ +```{include} ../CHANGELOG.md +:relative-docs: docs/ +:relative-images: +``` diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..7249aa9 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,304 @@ +# This file is execfile()d with the current directory set to its containing dir. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +import os +import sys +import shutil + +# -- Path setup -------------------------------------------------------------- + +__location__ = os.path.dirname(__file__) + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +sys.path.insert(0, os.path.join(__location__, "../src")) + +# -- Run sphinx-apidoc ------------------------------------------------------- +# This hack is necessary since RTD does not issue `sphinx-apidoc` before running +# `sphinx-build -b html . _build/html`. 
See Issue: +# https://github.com/readthedocs/readthedocs.org/issues/1139 +# DON'T FORGET: Check the box "Install your project inside a virtualenv using +# setup.py install" in the RTD Advanced Settings. +# Additionally it helps us to avoid running apidoc manually + +try: # for Sphinx >= 1.7 + from sphinx.ext import apidoc +except ImportError: + from sphinx import apidoc + +output_dir = os.path.join(__location__, "api") +module_dir = os.path.join(__location__, "../equiadapt") +try: + shutil.rmtree(output_dir) +except FileNotFoundError: + pass + +try: + import sphinx + + cmd_line = f"sphinx-apidoc --implicit-namespaces -f -o {output_dir} {module_dir}" + + args = cmd_line.split(" ") + if tuple(sphinx.__version__.split(".")) >= ("1", "7"): + # This is a rudimentary parse_version to avoid external dependencies + args = args[1:] + + apidoc.main(args) +except Exception as e: + print("Running `sphinx-apidoc` failed!\n{}".format(e)) + +# -- General configuration --------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be extensions +# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.autosummary", + "sphinx.ext.viewcode", + "sphinx.ext.coverage", + "sphinx.ext.doctest", + "sphinx.ext.ifconfig", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + + +# Enable markdown +extensions.append("myst_parser") + +# Configure MyST-Parser +myst_enable_extensions = [ + "amsmath", + "colon_fence", + "deflist", + "dollarmath", + "html_image", + "linkify", + "replacements", + "smartquotes", + "substitution", + "tasklist", +] + +# The suffix of source filenames. +source_suffix = [".rst", ".md"] + +# The encoding of source files. +# source_encoding = 'utf-8-sig' + +# The master toctree document. +master_doc = "index" + +# General information about the project. +project = "equiadapt" +copyright = "2024, Danielle Benesch" + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# version: The short X.Y version. +# release: The full version, including alpha/beta/rc tags. +# If you don’t need the separation provided between version and release, +# just set them both to the same value. +try: + from equiadapt import __version__ as version +except ImportError: + version = "" + +if not version or version.lower() == "unknown": + version = os.getenv("READTHEDOCS_VERSION", "unknown") # automatically set by RTD + +release = version + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# language = None + +# There are two options for replacing |today|: either, you set today to some +# non-false value, then it is used: +# today = '' +# Else, today_fmt is used as the format for a strftime call. +# today_fmt = '%B %d, %Y' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", ".venv"] + +# The reST default role (used for this markup: `text`) to use for all documents. 
+# default_role = None + +# If true, '()' will be appended to :func: etc. cross-reference text. +# add_function_parentheses = True + +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). +# add_module_names = True + +# If true, sectionauthor and moduleauthor directives will be shown in the +# output. They are ignored by default. +# show_authors = False + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" + +# A list of ignored prefixes for module index sorting. +# modindex_common_prefix = [] + +# If true, keep warnings as "system message" paragraphs in the built documents. +# keep_warnings = False + +# If this is True, todo emits a warning for each TODO entries. The default is False. +todo_emit_warnings = True + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +html_theme = "alabaster" + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +html_theme_options = { + "sidebar_width": "300px", + "page_width": "1200px" +} + +# Add any paths that contain custom themes here, relative to this directory. +# html_theme_path = [] + +# The name for this set of Sphinx documents. If None, it defaults to +# " v documentation". +# html_title = None + +# A shorter title for the navigation bar. Default is the same as html_title. +# html_short_title = None + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +# html_logo = "" + +# The name of an image file (within the static path) to use as favicon of the +# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +# html_favicon = None + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +# html_last_updated_fmt = '%b %d, %Y' + +# If true, SmartyPants will be used to convert quotes and dashes to +# typographically correct entities. +# html_use_smartypants = True + +# Custom sidebar templates, maps document names to template names. +# html_sidebars = {} + +# Additional templates that should be rendered to pages, maps page names to +# template names. +# html_additional_pages = {} + +# If false, no module index is generated. +# html_domain_indices = True + +# If false, no index is generated. +# html_use_index = True + +# If true, the index is split into individual pages for each letter. +# html_split_index = False + +# If true, links to the reST sources are added to the pages. +# html_show_sourcelink = True + +# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. +# html_show_sphinx = True + +# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. +# html_show_copyright = True + +# If true, an OpenSearch description file will be output, and all pages will +# contain a tag referring to it. The value of this option must be the +# base URL from which the finished HTML is served. +# html_use_opensearch = '' + +# This is the file name suffix for HTML files (e.g. 
".xhtml"). +# html_file_suffix = None + +# Output file base name for HTML help builder. +htmlhelp_basename = "equiadapt-doc" + + +# -- Options for LaTeX output ------------------------------------------------ + +latex_elements = { + # The paper size ("letterpaper" or "a4paper"). + # "papersize": "letterpaper", + # The font size ("10pt", "11pt" or "12pt"). + # "pointsize": "10pt", + # Additional stuff for the LaTeX preamble. + # "preamble": "", +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, author, documentclass [howto/manual]). +latex_documents = [ + ("index", "user_guide.tex", "equiadapt Documentation", "Danielle Benesch", "manual") +] + +# The name of an image file (relative to this directory) to place at the top of +# the title page. +# latex_logo = "" + +# For "manual" documents, if this is true, then toplevel headings are parts, +# not chapters. +# latex_use_parts = False + +# If true, show page references after internal links. +# latex_show_pagerefs = False + +# If true, show URL addresses after external links. +# latex_show_urls = False + +# Documents to append as an appendix to all manuals. +# latex_appendices = [] + +# If false, no module index is generated. +# latex_domain_indices = True + +# -- External mapping -------------------------------------------------------- +python_version = ".".join(map(str, sys.version_info[0:2])) +intersphinx_mapping = { + "sphinx": ("https://www.sphinx-doc.org/en/master", None), + "python": ("https://docs.python.org/" + python_version, None), + "matplotlib": ("https://matplotlib.org", None), + "numpy": ("https://numpy.org/doc/stable", None), + "sklearn": ("https://scikit-learn.org/stable", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), + "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), + "setuptools": ("https://setuptools.pypa.io/en/stable/", None), + "pyscaffold": ("https://pyscaffold.org/en/stable", None), +} + +print(f"loading configurations for {project} {version} ...", file=sys.stderr) \ No newline at end of file diff --git a/docs/contributing.md b/docs/contributing.md new file mode 100644 index 0000000..fc1b213 --- /dev/null +++ b/docs/contributing.md @@ -0,0 +1,4 @@ +```{include} ../CONTRIBUTING.md +:relative-docs: docs/ +:relative-images: +``` diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..e03ae46 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,39 @@ +# equiadapt + +Library that provides metrics to asses representation quality + + +## Note + +> This is the main page of your project's [Sphinx] documentation. It is +> formatted in [Markdown]. Add additional pages by creating md-files in +> `docs` or rst-files (formatted in [reStructuredText]) and adding links to +> them in the `Contents` section below. +> +> Please check [Sphinx] and [MyST] for more information +> about how to document your project and how to configure your preferences. 
+ + +## Contents + +```{toctree} +:maxdepth: 2 + +Overview +Contributions & Help +License +Authors +Changelog +Module Reference +``` + +## Indices and tables + +* {ref}`genindex` +* {ref}`modindex` +* {ref}`search` + +[Sphinx]: http://www.sphinx-doc.org/ +[Markdown]: https://daringfireball.net/projects/markdown/ +[reStructuredText]: http://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html +[MyST]: https://myst-parser.readthedocs.io/en/latest/ diff --git a/docs/license.md b/docs/license.md new file mode 100644 index 0000000..22567b6 --- /dev/null +++ b/docs/license.md @@ -0,0 +1,5 @@ +# License + +```{literalinclude} ../LICENSE +:language: text +``` diff --git a/docs/readme.md b/docs/readme.md new file mode 100644 index 0000000..2cb706b --- /dev/null +++ b/docs/readme.md @@ -0,0 +1,4 @@ +```{include} ../README.md +:relative-docs: docs/ +:relative-images: +``` diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..0990c2a --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,6 @@ +# Requirements file for ReadTheDocs, check .readthedocs.yml. +# To build the module reference correctly, make sure every external package +# under `install_requires` in `setup.cfg` is also listed here! +# sphinx_rtd_theme +myst-parser[linkify] +sphinx>=3.2.1 diff --git a/equiadapt/common/basecanonicalization.py b/equiadapt/common/basecanonicalization.py index b22154e..0c6143c 100644 --- a/equiadapt/common/basecanonicalization.py +++ b/equiadapt/common/basecanonicalization.py @@ -1,188 +1,210 @@ -import torch from abc import ABC, abstractmethod +import torch # Base skeleton for the canonicalization class -# DiscreteGroupCanonicalization and ContinuousGroupCanonicalization +# DiscreteGroupCanonicalization and ContinuousGroupCanonicalization # will inherit from this class + class BaseCanonicalization(torch.nn.Module): def __init__(self, canonicalization_network: torch.nn.Module): super().__init__() self.canonicalization_network = canonicalization_network self.canonicalization_info_dict = {} - + def forward(self, x: torch.Tensor, **kwargs): """ Forward method for the canonicalization which takes the input data and returns the canonicalized version of the data - + Args: x: input data **kwargs: additional arguments - + Returns: canonicalized_x: canonicalized version of the input data """ - + return self.canonicalize(x, **kwargs) - def canonicalize(self, x: torch.Tensor, **kwargs): """ - This method takes an input data and + This method takes an input data and returns its canonicalized version and a dictionary containing the information about the canonicalization """ raise NotImplementedError() - def invert_canonicalization(self, x: torch.Tensor, **kwargs): """ - This method takes the output of the canonicalized data + This method takes the output of the canonicalized data and returns the output for the original data orientation """ raise NotImplementedError() - + class IdentityCanonicalization(BaseCanonicalization): def __init__(self, canonicalization_network: torch.nn.Module = torch.nn.Identity()): super().__init__(canonicalization_network) - + def canonicalize(self, x: torch.Tensor, **kwargs): return x - + def invert_canonicalization(self, x: torch.Tensor, **kwargs): return x - + def get_prior_regularization_loss(self): return torch.tensor(0.0) - + def get_identity_metric(self): return torch.tensor(1.0) - + + class DiscreteGroupCanonicalization(BaseCanonicalization): - def __init__(self, - canonicalization_network: torch.nn.Module, - beta: float = 1.0, - 
gradient_trick: str = 'straight_through'): + def __init__( + self, + canonicalization_network: torch.nn.Module, + beta: float = 1.0, + gradient_trick: str = 'straight_through', + ): super().__init__(canonicalization_network) self.beta = beta self.gradient_trick = gradient_trick - + def groupactivations_to_groupelementonehot(self, group_activations: torch.Tensor): """ This method takes the activations for each group element as input and returns the group element in a differentiable manner - + Args: group_activations: activations for each group element - + Returns: group_element_onehot: one hot encoding of the group element """ group_activations_one_hot = torch.nn.functional.one_hot( - torch.argmax(group_activations, dim=-1), self.num_group).float() - group_activations_soft = torch.nn.functional.softmax(self.beta * group_activations, dim=-1) + torch.argmax(group_activations, dim=-1), self.num_group + ).float() + group_activations_soft = torch.nn.functional.softmax( + self.beta * group_activations, dim=-1 + ) if self.gradient_trick == 'straight_through': - if self.training: - group_element_onehot = (group_activations_one_hot + group_activations_soft - group_activations_soft.detach()) + if self.training: + group_element_onehot = ( + group_activations_one_hot + + group_activations_soft + - group_activations_soft.detach() + ) else: group_element_onehot = group_activations_one_hot elif self.gradient_trick == 'gumbel_softmax': - group_element_onehot = torch.nn.functional.gumbel_softmax(group_activations, tau=1, hard=True) + group_element_onehot = torch.nn.functional.gumbel_softmax( + group_activations, tau=1, hard=True + ) else: - raise ValueError(f'Gradient trick {self.gradient_trick} not implemented') - + raise ValueError(f'Gradient trick {self.gradient_trick} not implemented') + # return the group element one hot encoding return group_element_onehot - + def canonicalize(self, x: torch.Tensor, **kwargs): """ - This method takes an input data and + This method takes an input data and returns its canonicalized version and a dictionary containing the information about the canonicalization """ raise NotImplementedError() - def invert_canonicalization(self, x: torch.Tensor, **kwargs): """ - This method takes the output of the canonicalized data + This method takes the output of the canonicalized data and returns the output for the original data orientation """ raise NotImplementedError() - - + def get_prior_regularization_loss(self): group_activations = self.canonicalization_info_dict['group_activations'] - dataset_prior = torch.zeros((group_activations.shape[0],), dtype=torch.long).to(self.device) + dataset_prior = torch.zeros((group_activations.shape[0],), dtype=torch.long).to( + self.device + ) return torch.nn.CrossEntropyLoss()(group_activations, dataset_prior) - - + def get_identity_metric(self): group_activations = self.canonicalization_info_dict['group_activations'] return (group_activations.argmax(dim=-1) == 0).float().mean() - class ContinuousGroupCanonicalization(BaseCanonicalization): - def __init__(self, - canonicalization_network: torch.nn.Module, - beta: float = 1.0, - gradient_trick: str = 'straight_through'): + def __init__( + self, + canonicalization_network: torch.nn.Module, + beta: float = 1.0, + gradient_trick: str = 'straight_through', + ): super().__init__(canonicalization_network) self.beta = beta self.gradient_trick = gradient_trick - - def canonicalizationnetworkout_to_groupelement(self, group_activations: torch.Tensor): + + def canonicalizationnetworkout_to_groupelement( + 
self, group_activations: torch.Tensor + ): """ This method takes the as input and returns the group element in a differentiable manner - + Args: group_activations: activations for each group element - + Returns: group_element: group element """ raise NotImplementedError() - + def canonicalize(self, x: torch.Tensor, **kwargs): """ - This method takes an input data and + This method takes an input data and returns its canonicalized version and a dictionary containing the information about the canonicalization """ raise NotImplementedError() - def invert_canonicalization(self, x: torch.Tensor, **kwargs): """ - This method takes the output of the canonicalized data + This method takes the output of the canonicalized data and returns the output for the original data orientation """ raise NotImplementedError() - - + def get_prior_regularization_loss(self): - group_elements_rep = self.canonicalization_info_dict['group_element_matrix_representation'] # shape: (batch_size, group_rep_dim, group_rep_dim) + group_elements_rep = self.canonicalization_info_dict[ + 'group_element_matrix_representation' + ] # shape: (batch_size, group_rep_dim, group_rep_dim) # Set the dataset prior to identity matrix of size group_rep_dim and repeat it for batch_size - dataset_prior = torch.eye(group_elements_rep.shape[-1]).repeat( - group_elements_rep.shape[0], 1, 1).to(self.device) + dataset_prior = ( + torch.eye(group_elements_rep.shape[-1]) + .repeat(group_elements_rep.shape[0], 1, 1) + .to(self.device) + ) return torch.nn.MSELoss()(group_elements_rep, dataset_prior) - + def get_identity_metric(self): - group_elements_rep = self.canonicalization_info_dict['group_element_matrix_representation'] - identity_element = torch.eye(group_elements_rep.shape[-1]).repeat( - group_elements_rep.shape[0], 1, 1).to(self.device) - return 1.0 - torch.nn.functional.mse_loss(group_elements_rep, identity_element).mean() - - - + group_elements_rep = self.canonicalization_info_dict[ + 'group_element_matrix_representation' + ] + identity_element = ( + torch.eye(group_elements_rep.shape[-1]) + .repeat(group_elements_rep.shape[0], 1, 1) + .to(self.device) + ) + return ( + 1.0 + - torch.nn.functional.mse_loss(group_elements_rep, identity_element).mean() + ) + # Idea for the user interface: @@ -191,7 +213,7 @@ def get_identity_metric(self): # example: canonicalization_network = ESCNNEquivariantNetwork(in_shape, out_channels, kernel_size, group_type='rotation', num_rotations=4, num_layers=3) # canonicalizer = GroupEquivariantImageCanonicalization(canonicalization_network, beta=1.0) # -# +# # 2. The user uses this wrapper with their code to canonicalize the input data # example: model = ResNet18() # x_canonized = canonicalizer(x) @@ -204,5 +226,3 @@ def get_identity_metric(self): # loss = criterion(model_out, y) # loss = canonicalizer.add_prior_regularizer(loss) # loss.backward() - - \ No newline at end of file diff --git a/equiadapt/common/utils.py b/equiadapt/common/utils.py index e22d3d0..feea3c5 100644 --- a/equiadapt/common/utils.py +++ b/equiadapt/common/utils.py @@ -1,5 +1,6 @@ import torch + def gram_schmidt(vectors): """ Applies the Gram-Schmidt process to orthogonalize a set of vectors in a batch-wise manner. 
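As a quick, self-contained illustration of what the reformatted `gram_schmidt` (continued in the next hunk) computes, the snippet below builds random frames and checks that the output is orthonormal. It is not part of the diff; only the function name and module path are taken from it:

```python
import torch

from equiadapt.common.utils import gram_schmidt

torch.manual_seed(0)
# A batch of 8 frames, each made of two 2-D vectors: shape (batch, n_vectors, dim).
vectors = torch.randn(8, 2, 2)
frames = gram_schmidt(vectors)

# Every frame F should satisfy F @ F^T ~ identity after orthonormalization.
gram = torch.einsum('bij,bkj->bik', frames, frames)
assert torch.allclose(gram, torch.eye(2).expand(8, 2, 2), atol=1e-4)
```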
@@ -17,12 +18,17 @@ def gram_schmidt(vectors): for i in range(1, n_vectors): for j in range(i): # Project vector i on vector j, then subtract this projection from vector i - projection = (torch.sum(orthogonal_vectors[:, i] * orthogonal_vectors[:, j], dim=1, keepdim=True) / - torch.sum(orthogonal_vectors[:, j] * orthogonal_vectors[:, j], dim=1, keepdim=True)) + projection = torch.sum( + orthogonal_vectors[:, i] * orthogonal_vectors[:, j], dim=1, keepdim=True + ) / torch.sum( + orthogonal_vectors[:, j] * orthogonal_vectors[:, j], dim=1, keepdim=True + ) orthogonal_vectors[:, i] -= projection * orthogonal_vectors[:, j] # Normalize the vectors after orthogonalization is complete to ensure numerical stability - orthogonal_vectors = orthogonal_vectors / torch.norm(orthogonal_vectors, dim=2, keepdim=True) + orthogonal_vectors = orthogonal_vectors / torch.norm( + orthogonal_vectors, dim=2, keepdim=True + ) return orthogonal_vectors @@ -52,7 +58,13 @@ def get_son_bases(self): """ num_son_bases = self.group_dim * (self.group_dim - 1) // 2 son_bases = torch.zeros((num_son_bases, self.group_dim, self.group_dim)) - for counter, (i, j) in enumerate([(i, j) for i in range(self.group_dim) for j in range(i + 1, self.group_dim)]): + for counter, (i, j) in enumerate( + [ + (i, j) + for i in range(self.group_dim) + for j in range(i + 1, self.group_dim) + ] + ): son_bases[counter, i, j] = 1 son_bases[counter, j, i] = -1 return son_bases @@ -69,7 +81,7 @@ def get_son_rep(self, params: torch.Tensor): son_bases = self.get_son_bases().to(params.device) A = torch.einsum('bs,sij->bij', params, son_bases) return torch.matrix_exp(A) - + def get_on_rep(self, params: torch.Tensor, reflect_indicators: torch.Tensor): """ Computes the representation for O(n) group, optionally including reflections. @@ -82,15 +94,21 @@ def get_on_rep(self, params: torch.Tensor, reflect_indicators: torch.Tensor): torch.Tensor: The representation of shape (batch_size, rep_dim, rep_dim). """ son_rep = self.get_son_rep(params) - + # This is a simplified and conceptual approach; actual reflection handling # would need to determine how to reflect (e.g., across which axis or plane) # and this might not directly apply as-is. identity_matrix = torch.eye(self.group_dim) - reflection_matrix = torch.diag_embed(torch.tensor([1] * (self.group_dim - 1) + [-1])) - on_rep = torch.matmul(son_rep, reflect_indicators * reflection_matrix + (1 - reflect_indicators) * identity_matrix) + reflection_matrix = torch.diag_embed( + torch.tensor([1] * (self.group_dim - 1) + [-1]) + ) + on_rep = torch.matmul( + son_rep, + reflect_indicators * reflection_matrix + + (1 - reflect_indicators) * identity_matrix, + ) return on_rep - + def get_sen_rep(self, params: torch.Tensor): """Computes the representation for SEn group. @@ -101,14 +119,19 @@ def get_sen_rep(self, params: torch.Tensor): torch.Tensor: The representation of shape (batch_size, rep_dim, rep_dim). 
""" son_param_dim = self.group_dim * (self.group_dim - 1) // 2 - rho = torch.zeros(params.shape[0], self.group_dim + 1, - self.group_dim + 1, device=params.device) - rho[:, :self.group_dim, :self.group_dim] = self.get_son_rep( - params[:, :son_param_dim].unsqueeze(0)).squeeze(0) - rho[:, :self.group_dim, self.group_dim] = params[:, son_param_dim:] + rho = torch.zeros( + params.shape[0], + self.group_dim + 1, + self.group_dim + 1, + device=params.device, + ) + rho[:, : self.group_dim, : self.group_dim] = self.get_son_rep( + params[:, :son_param_dim].unsqueeze(0) + ).squeeze(0) + rho[:, : self.group_dim, self.group_dim] = params[:, son_param_dim:] rho[:, self.group_dim, self.group_dim] = 1 return rho - + def get_en_rep(self, params: torch.Tensor, reflect_indicators: torch.Tensor): """Computes the representation for E(n) group. @@ -134,19 +157,25 @@ def get_en_rep(self, params: torch.Tensor, reflect_indicators: torch.Tensor): # Separate rotation/reflection and translation parameters rotation_params = params[:, :rotation_param_dim] - translation_params = params[:, rotation_param_dim:rotation_param_dim + translation_param_dim] + translation_params = params[ + :, rotation_param_dim : rotation_param_dim + translation_param_dim + ] # Compute rotation/reflection representation rotoreflection_rep = self.get_on_rep(rotation_params, reflect_indicators) # Construct the E(n) representation matrix - en_rep = torch.zeros(params.shape[0], self.group_dim + 1, self.group_dim + 1, device=params.device) - en_rep[:, :self.group_dim, :self.group_dim] = rotoreflection_rep - en_rep[:, :self.group_dim, self.group_dim] = translation_params + en_rep = torch.zeros( + params.shape[0], + self.group_dim + 1, + self.group_dim + 1, + device=params.device, + ) + en_rep[:, : self.group_dim, : self.group_dim] = rotoreflection_rep + en_rep[:, : self.group_dim, self.group_dim] = translation_params en_rep[:, self.group_dim, self.group_dim] = 1 return en_rep - def get_group_rep(self, params): """Computes the representation for the specified Lie group. 
@@ -167,5 +196,3 @@ def get_group_rep(self, params): return self.get_en_rep(params) else: raise ValueError(f"Unsupported group type: {self.group_type}") - - diff --git a/equiadapt/images/canonicalization/continuous_group.py b/equiadapt/images/canonicalization/continuous_group.py index f1c0a49..c266c31 100644 --- a/equiadapt/images/canonicalization/continuous_group.py +++ b/equiadapt/images/canonicalization/continuous_group.py @@ -1,175 +1,212 @@ -import torch +import math + import kornia as K +import torch +from torch.nn import functional as F +from torchvision import transforms + from equiadapt.common.basecanonicalization import ContinuousGroupCanonicalization from equiadapt.common.utils import gram_schmidt from equiadapt.images.utils import get_action_on_image_features -from torchvision import transforms -import math -from torch.nn import functional as F + class ContinuousGroupImageCanonicalization(ContinuousGroupCanonicalization): - def __init__(self, - canonicalization_network: torch.nn.Module, - canonicalization_hyperparams: dict, - in_shape: tuple - ): + def __init__( + self, + canonicalization_network: torch.nn.Module, + canonicalization_hyperparams: dict, + in_shape: tuple, + ): super().__init__(canonicalization_network) - - assert len(in_shape) == 3, 'Input shape should be in the format (channels, height, width)' - + + assert ( + len(in_shape) == 3 + ), 'Input shape should be in the format (channels, height, width)' + # pad and crop the input image if it is not rotated MNIST - is_grayscale = (in_shape[0] == 1) - self.pad = torch.nn.Identity() if is_grayscale else transforms.Pad( - math.ceil(in_shape[-1] * 0.5), padding_mode='edge' + is_grayscale = in_shape[0] == 1 + self.pad = ( + torch.nn.Identity() + if is_grayscale + else transforms.Pad(math.ceil(in_shape[-1] * 0.5), padding_mode='edge') + ) + self.crop = ( + torch.nn.Identity() + if is_grayscale + else transforms.CenterCrop((in_shape[-2], in_shape[-1])) + ) + self.crop_canonization = ( + torch.nn.Identity() + if is_grayscale + else transforms.CenterCrop( + ( + math.ceil( + in_shape[-2] * canonicalization_hyperparams.input_crop_ratio + ), + math.ceil( + in_shape[-1] * canonicalization_hyperparams.input_crop_ratio + ), + ) + ) + ) + self.resize_canonization = ( + torch.nn.Identity() + if is_grayscale + else transforms.Resize(size=canonicalization_hyperparams.resize_shape) ) - self.crop = torch.nn.Identity() if is_grayscale else transforms.CenterCrop((in_shape[-2], in_shape[-1])) - self.crop_canonization = torch.nn.Identity() if is_grayscale else transforms.CenterCrop(( - math.ceil(in_shape[-2] * canonicalization_hyperparams.input_crop_ratio), - math.ceil(in_shape[-1] * canonicalization_hyperparams.input_crop_ratio) - )) - self.resize_canonization = torch.nn.Identity() if is_grayscale else transforms.Resize(size=canonicalization_hyperparams.resize_shape) self.group_info_dict = {} - + def get_groupelement(self, x: torch.Tensor): """ This method takes the input image and maps it to the group element - + Args: x: input image - + Returns: group_element: group element """ raise NotImplementedError('get_groupelement method is not implemented') - + def transformations_before_canonicalization_network_forward(self, x: torch.Tensor): """ - This method takes an image as input and - returns the pre-canonicalized image + This method takes an image as input and + returns the pre-canonicalized image """ x = self.crop_canonization(x) x = self.resize_canonization(x) return x - + def get_group_from_out_vectors(self, out_vectors: torch.Tensor): """ 
This method takes the output of the canonicalization network and returns the group element - + Args: out_vectors: output of the canonicalization network - + Returns: group_element_dict: group element group_element_representation: group element representation """ group_element_dict = {} - + if self.group_type == 'roto-reflection': # Apply Gram-Schmidt to get the rotation matrices/orthogonal frame from # a batch of two 2D vectors - rotoreflection_matrices = gram_schmidt(out_vectors) # (batch_size, 2, 2) - + rotoreflection_matrices = gram_schmidt(out_vectors) # (batch_size, 2, 2) + # Calculate the determinant to check for reflection - determinant = rotoreflection_matrices[:, 0, 0] * rotoreflection_matrices[:, 1, 1] - \ - rotoreflection_matrices[:, 0, 1] * rotoreflection_matrices[:, 1, 0] - + determinant = ( + rotoreflection_matrices[:, 0, 0] * rotoreflection_matrices[:, 1, 1] + - rotoreflection_matrices[:, 0, 1] * rotoreflection_matrices[:, 1, 0] + ) + reflect_indicator = (1 - determinant[:, None, None, None]) / 2 group_element_dict['reflection'] = reflect_indicator - + # Identify matrices with a reflection (negative determinant) reflection_indices = determinant < 0 # For matrices with a reflection, adjust to remove the reflection component # This example assumes flipping the sign of the second column as one way to adjust # Note: This method of adjustment is context-dependent and may vary based on your specific requirements - rotation_matrices = rotoreflection_matrices - rotation_matrices[reflection_indices, :, 1] *= -1 + rotation_matrices = rotoreflection_matrices + rotation_matrices[reflection_indices, :, 1] *= -1 else: # Pass the first vector to get the rotation matrix rotation_matrices = self.get_rotation_matrix_from_vector(out_vectors[:, 0]) - + group_element_dict['rotation'] = rotation_matrices - - return group_element_dict, rotoreflection_matrices if self.group_type == 'roto-reflection' else rotation_matrices - - + + return group_element_dict, ( + rotoreflection_matrices + if self.group_type == 'roto-reflection' + else rotation_matrices + ) + def canonicalize(self, x: torch.Tensor): """ - This method takes an image as input and - returns the canonicalized image - + This method takes an image as input and + returns the canonicalized image + Args: x: input image - + Returns: x_canonicalized: canonicalized image """ self.device = x.device - + # get the group element dictionary with keys as 'rotation' and 'reflection' - group_element_dict = self.get_groupelement(x) - + group_element_dict = self.get_groupelement(x) + rotation_matrices = group_element_dict['rotation'] rotation_matrices[:, [0, 1], [1, 0]] *= -1 - + if 'reflection' in group_element_dict: reflect_indicator = group_element_dict['reflection'] # Reflect the image conditionally x = (1 - reflect_indicator) * x + reflect_indicator * K.geometry.hflip(x) - - + # Apply padding before canonicalization x = self.pad(x) - + # Compute affine part for warp affine alpha, beta = rotation_matrices[:, 0, 0], rotation_matrices[:, 0, 1] cx, cy = x.shape[-2] // 2, x.shape[-1] // 2 - affine_part = torch.stack([(1 - alpha) * cx - beta * cy, beta * cx + (1 - alpha) * cy], dim=1) - + affine_part = torch.stack( + [(1 - alpha) * cx - beta * cy, beta * cx + (1 - alpha) * cy], dim=1 + ) + # Prepare affine matrices for warp affine, adjusting rotation matrix for Kornia compatibility - affine_matrices = torch.cat([rotation_matrices, affine_part.unsqueeze(-1)], dim=-1) - - # Apply warp affine, and then crop + affine_matrices = torch.cat( + [rotation_matrices, 
affine_part.unsqueeze(-1)], dim=-1 + ) + + # Apply warp affine, and then crop x = K.geometry.warp_affine(x, affine_matrices, dsize=(x.shape[-2], x.shape[-1])) x = self.crop(x) return x - - def invert_canonicalization(self, x_canonicalized_out: torch.Tensor, induced_rep_type: str = 'vector'): + def invert_canonicalization( + self, x_canonicalized_out: torch.Tensor, induced_rep_type: str = 'vector' + ): """ This method takes the output of canonicalized image as input and returns output of the original image - + """ - return get_action_on_image_features(feature_map = x_canonicalized_out, - group_info_dict = self.group_info_dict, - group_element_dict = self.canonicalization_info_dict['group_element'], - induced_rep_type = induced_rep_type) - + return get_action_on_image_features( + feature_map=x_canonicalized_out, + group_info_dict=self.group_info_dict, + group_element_dict=self.canonicalization_info_dict['group_element'], + induced_rep_type=induced_rep_type, + ) class SteerableImageCanonicalization(ContinuousGroupImageCanonicalization): - def __init__(self, - canonicalization_network: torch.nn.Module, - canonicalization_hyperparams: dict, - in_shape: tuple - ): - super().__init__(canonicalization_network, - canonicalization_hyperparams, - in_shape) + def __init__( + self, + canonicalization_network: torch.nn.Module, + canonicalization_hyperparams: dict, + in_shape: tuple, + ): + super().__init__( + canonicalization_network, canonicalization_hyperparams, in_shape + ) self.group_type = canonicalization_network.group_type - + def get_rotation_matrix_from_vector(self, vectors: torch.Tensor): ''' This method takes the input vector and returns the rotation matrix - + Args: vectors: input vector - + Returns: rotation_matrices: rotation matrices ''' @@ -177,57 +214,62 @@ def get_rotation_matrix_from_vector(self, vectors: torch.Tensor): v2 = torch.stack([-v1[:, 1], v1[:, 0]], dim=1) rotation_matrices = torch.stack([v1, v2], dim=1) return rotation_matrices - + def get_groupelement(self, x: torch.Tensor): """ This method takes the input image and maps it to the group element - + Args: x: input image - + Returns: group_element: group element """ - + group_element_dict = {} - + x = self.transformations_before_canonicalization_network_forward(x) - + # convert the group activations to one hot encoding of group element # this conversion is differentiable and will be used to select the group element out_vectors = self.canonicalization_network(x) - + # Check whether canonicalization_info_dict is already defined if not hasattr(self, 'canonicalization_info_dict'): self.canonicalization_info_dict = {} - group_element_dict, group_element_representation = self.get_group_from_out_vectors(out_vectors) - self.canonicalization_info_dict['group_element_matrix_representation'] = group_element_representation + group_element_dict, group_element_representation = ( + self.get_group_from_out_vectors(out_vectors) + ) + self.canonicalization_info_dict['group_element_matrix_representation'] = ( + group_element_representation + ) self.canonicalization_info_dict['group_element'] = group_element_dict - + return group_element_dict - + class OptimizedSteerableImageCanonicalization(ContinuousGroupImageCanonicalization): - def __init__(self, - canonicalization_network: torch.nn.Module, - canonicalization_hyperparams: dict, - in_shape: tuple - ): - super().__init__(canonicalization_network, - canonicalization_hyperparams, - in_shape) + def __init__( + self, + canonicalization_network: torch.nn.Module, + canonicalization_hyperparams: 
dict, + in_shape: tuple, + ): + super().__init__( + canonicalization_network, canonicalization_hyperparams, in_shape + ) self.group_type = canonicalization_hyperparams.group_type - + def get_rotation_matrix_from_vector(self, vectors: torch.Tensor): ''' This method takes the input vector and returns the rotation matrix - + Args: vectors: input vector - + Returns: rotation_matrices: rotation matrices ''' @@ -235,7 +277,7 @@ def get_rotation_matrix_from_vector(self, vectors: torch.Tensor): v2 = torch.stack([-v1[:, 1], v1[:, 0]], dim=1) rotation_matrices = torch.stack([v1, v2], dim=1) return rotation_matrices - + def group_augment(self, x): """ Augmentation of the input images by applying random rotations and, @@ -256,11 +298,15 @@ def group_augment(self, x): # Create tensors for rotation matrices rotation_matrices = torch.zeros(batch_size, 2, 3, device=self.device) - rotation_matrices[:, :2, :2] = torch.stack((cos_a, -sin_a, sin_a, cos_a)).reshape(-1, 2, 2) + rotation_matrices[:, :2, :2] = torch.stack( + (cos_a, -sin_a, sin_a, cos_a) + ).reshape(-1, 2, 2) if self.group_type == 'roto-reflection': # Generate reflection indicators (horizontal flip) with 50% probability - reflect = torch.randint(0, 2, (batch_size,), device=self.device).float() * 2 - 1 + reflect = ( + torch.randint(0, 2, (batch_size,), device=self.device).float() * 2 - 1 + ) # Adjust the rotation matrix for reflections rotation_matrices[:, 0, 0] *= reflect @@ -274,63 +320,88 @@ def group_augment(self, x): # Return augmented images and the transformation matrices used return augmented_images, rotation_matrices[:, :, :2] - def get_groupelement(self, x: torch.Tensor): """ This method takes the input image and maps it to the group element - + Args: x: input image - + Returns: group_element: group element """ - + group_element_dict = {} - + batch_size = x.shape[0] - + # randomly sample generate some agmentations of the input image using rotation and reflection - - x_augmented, group_element_representations_augmented_gt = self.group_augment(x) # size (batch_size * group_size, in_channels, height, width) - - x_all = torch.cat([x, x_augmented], dim=0) # size (batch_size * 2, in_channels, height, width) - + + x_augmented, group_element_representations_augmented_gt = self.group_augment( + x + ) # size (batch_size * group_size, in_channels, height, width) + + x_all = torch.cat( + [x, x_augmented], dim=0 + ) # size (batch_size * 2, in_channels, height, width) + x_all = self.transformations_before_canonicalization_network_forward(x_all) - - out_vectors_all = self.canonicalization_network(x_all) # size (batch_size * 2, out_vector_size) - - out_vectors_all = out_vectors_all.reshape(2 * batch_size, -1, 2) # size (batch_size * 2, num_vectors, 2) - + + out_vectors_all = self.canonicalization_network( + x_all + ) # size (batch_size * 2, out_vector_size) + + out_vectors_all = out_vectors_all.reshape( + 2 * batch_size, -1, 2 + ) # size (batch_size * 2, num_vectors, 2) + out_vectors, out_vectors_augmented = out_vectors_all.chunk(2, dim=0) - + # Check whether canonicalization_info_dict is already defined if not hasattr(self, 'canonicalization_info_dict'): - self.canonicalization_info_dict = {} - - group_element_dict, group_element_representations = self.get_group_from_out_vectors(out_vectors) + self.canonicalization_info_dict = {} + + group_element_dict, group_element_representations = ( + self.get_group_from_out_vectors(out_vectors) + ) # Store the matrix representation of the group element for regularization and identity metric - 
self.canonicalization_info_dict['group_element_matrix_representation'] = group_element_representations + self.canonicalization_info_dict['group_element_matrix_representation'] = ( + group_element_representations + ) self.canonicalization_info_dict['group_element'] = group_element_dict - - _, group_element_representations_augmented = self.get_group_from_out_vectors(out_vectors_augmented) - self.canonicalization_info_dict['group_element_matrix_representation_augmented'] = \ - group_element_representations_augmented - self.canonicalization_info_dict['group_element_matrix_representation_augmented_gt'] = \ - group_element_representations_augmented_gt - + + _, group_element_representations_augmented = self.get_group_from_out_vectors( + out_vectors_augmented + ) + self.canonicalization_info_dict[ + 'group_element_matrix_representation_augmented' + ] = group_element_representations_augmented + self.canonicalization_info_dict[ + 'group_element_matrix_representation_augmented_gt' + ] = group_element_representations_augmented_gt + return group_element_dict - + def get_optimization_specific_loss(self): """ This method returns the optimization specific loss - + Returns: loss: optimization specific loss """ - group_element_representations_augmented, group_element_representations_augmented_gt = \ - self.canonicalization_info_dict['group_element_matrix_representation_augmented'], \ - self.canonicalization_info_dict['group_element_matrix_representation_augmented_gt'] - return F.mse_loss(group_element_representations_augmented, group_element_representations_augmented_gt) - \ No newline at end of file + ( + group_element_representations_augmented, + group_element_representations_augmented_gt, + ) = ( + self.canonicalization_info_dict[ + 'group_element_matrix_representation_augmented' + ], + self.canonicalization_info_dict[ + 'group_element_matrix_representation_augmented_gt' + ], + ) + return F.mse_loss( + group_element_representations_augmented, + group_element_representations_augmented_gt, + ) diff --git a/equiadapt/images/canonicalization/discrete_group.py b/equiadapt/images/canonicalization/discrete_group.py index 231b961..b376a1f 100644 --- a/equiadapt/images/canonicalization/discrete_group.py +++ b/equiadapt/images/canonicalization/discrete_group.py @@ -1,208 +1,276 @@ -import torch -import kornia as K -from equiadapt.common.basecanonicalization import DiscreteGroupCanonicalization -from equiadapt.images.utils import flip_boxes, flip_masks, get_action_on_image_features, rotate_boxes, rotate_masks -from torchvision import transforms import math + +import kornia as K +import torch from torch.nn import functional as F +from torchvision import transforms + +from equiadapt.common.basecanonicalization import DiscreteGroupCanonicalization +from equiadapt.images.utils import ( + flip_boxes, + flip_masks, + get_action_on_image_features, + rotate_boxes, + rotate_masks, +) + class DiscreteGroupImageCanonicalization(DiscreteGroupCanonicalization): - def __init__(self, - canonicalization_network: torch.nn.Module, - canonicalization_hyperparams: dict, - in_shape: tuple - ): + def __init__( + self, + canonicalization_network: torch.nn.Module, + canonicalization_hyperparams: dict, + in_shape: tuple, + ): super().__init__(canonicalization_network) - + self.beta = canonicalization_hyperparams.beta - - assert len(in_shape) == 3, 'Input shape should be in the format (channels, height, width)' - + + assert ( + len(in_shape) == 3 + ), 'Input shape should be in the format (channels, height, width)' + # DEfine all the image 
transformations here which are used during canonicalization # pad and crop the input image if it is not rotated MNIST - is_grayscale = (in_shape[0] == 1) - - self.pad = torch.nn.Identity() if is_grayscale else transforms.Pad( - math.ceil(in_shape[-2] * 0.4), padding_mode='edge' + is_grayscale = in_shape[0] == 1 + + self.pad = ( + torch.nn.Identity() + if is_grayscale + else transforms.Pad(math.ceil(in_shape[-2] * 0.4), padding_mode='edge') + ) + self.crop = ( + torch.nn.Identity() + if is_grayscale + else transforms.CenterCrop((in_shape[-2], in_shape[-1])) + ) + self.crop_canonization = ( + torch.nn.Identity() + if is_grayscale + else transforms.CenterCrop( + ( + math.ceil( + in_shape[-2] * canonicalization_hyperparams.input_crop_ratio + ), + math.ceil( + in_shape[-1] * canonicalization_hyperparams.input_crop_ratio + ), + ) + ) ) - self.crop = torch.nn.Identity() if is_grayscale else transforms.CenterCrop((in_shape[-2], in_shape[-1])) - self.crop_canonization = torch.nn.Identity() if is_grayscale else transforms.CenterCrop(( - math.ceil(in_shape[-2] * canonicalization_hyperparams.input_crop_ratio), - math.ceil(in_shape[-1] * canonicalization_hyperparams.input_crop_ratio) - )) - - self.resize_canonization = torch.nn.Identity() if is_grayscale else transforms.Resize(size=canonicalization_hyperparams.resize_shape) - + + self.resize_canonization = ( + torch.nn.Identity() + if is_grayscale + else transforms.Resize(size=canonicalization_hyperparams.resize_shape) + ) + def groupactivations_to_groupelement(self, group_activations: torch.Tensor): """ This method takes the activations for each group element as input and returns the group element - + Args: group_activations: activations for each group element - + Returns: group_element: group element """ - + # convert the group activations to one hot encoding of group element # this conversion is differentiable and will be used to select the group element - group_elements_one_hot = self.groupactivations_to_groupelementonehot(group_activations) - - angles = torch.linspace(0., 360., self.num_rotations+1)[:self.num_rotations].to(self.device) - group_elements_rot_comp = torch.cat([angles, angles], dim=0) if self.group_type == 'roto-reflection' else angles - + group_elements_one_hot = self.groupactivations_to_groupelementonehot( + group_activations + ) + + angles = torch.linspace(0.0, 360.0, self.num_rotations + 1)[ + : self.num_rotations + ].to(self.device) + group_elements_rot_comp = ( + torch.cat([angles, angles], dim=0) + if self.group_type == 'roto-reflection' + else angles + ) + group_element_dict = {} - - group_element_rot_comp = torch.sum(group_elements_one_hot * group_elements_rot_comp, dim=-1) + + group_element_rot_comp = torch.sum( + group_elements_one_hot * group_elements_rot_comp, dim=-1 + ) group_element_dict['rotation'] = group_element_rot_comp if self.group_type == 'roto-reflection': - reflect_identifier_vector = torch.cat([torch.zeros(self.num_rotations), - torch.ones(self.num_rotations)], dim=0).to(self.device) - group_element_reflect_comp = torch.sum(group_elements_one_hot * reflect_identifier_vector, dim=-1) + reflect_identifier_vector = torch.cat( + [torch.zeros(self.num_rotations), torch.ones(self.num_rotations)], dim=0 + ).to(self.device) + group_element_reflect_comp = torch.sum( + group_elements_one_hot * reflect_identifier_vector, dim=-1 + ) group_element_dict['reflection'] = group_element_reflect_comp - + return group_element_dict - + def get_group_activations(self, x: torch.Tensor): """ - This method takes an image as input and 
+ This method takes an image as input and returns the group activations """ - raise NotImplementedError('get_group_activations is not implemented for' - 'the DiscreteGroupImageCanonicalization class') - - + raise NotImplementedError( + 'get_group_activations is not implemented for' + 'the DiscreteGroupImageCanonicalization class' + ) + def get_groupelement(self, x: torch.Tensor): """ This method takes the input image and maps it to the group element - + Args: x: input image - + Returns: group_element: group element """ group_activations = self.get_group_activations(x) group_element_dict = self.groupactivations_to_groupelement(group_activations) - + # Check whether canonicalization_info_dict is already defined if not hasattr(self, 'canonicalization_info_dict'): self.canonicalization_info_dict = {} self.canonicalization_info_dict['group_element'] = group_element_dict self.canonicalization_info_dict['group_activations'] = group_activations - + return group_element_dict - + def transformations_before_canonicalization_network_forward(self, x: torch.Tensor): """ - This method takes an image as input and - returns the pre-canonicalized image + This method takes an image as input and + returns the pre-canonicalized image """ x = self.crop_canonization(x) x = self.resize_canonization(x) return x - - + def canonicalize(self, x: torch.Tensor, targets: torch.Tensor = None): """ - This method takes an image as input and - returns the canonicalized image + This method takes an image as input and + returns the canonicalized image """ self.device = x.device group_element_dict = self.get_groupelement(x) - + x = self.pad(x) - + if 'reflection' in group_element_dict.keys(): - reflect_indicator = group_element_dict['reflection'][:,None,None,None] + reflect_indicator = group_element_dict['reflection'][:, None, None, None] x = (1 - reflect_indicator) * x + reflect_indicator * K.geometry.hflip(x) x = K.geometry.rotate(x, -group_element_dict['rotation']) - + x = self.crop(x) - + if targets: # canonicalize the targets (for instance segmentation, masks and boxes) image_width = x.shape[-1] - + if 'reflection' in group_element_dict.keys(): # flip masks and boxes for t in range(len(targets['boxes'])): targets[t]['boxes'] = flip_boxes(targets[t]['boxes'], image_width) targets[t]['masks'] = flip_masks(targets[t]['masks']) - + # rotate masks and boxes for t in range(len(targets['boxes'])): - targets[t]['boxes'] = rotate_boxes(targets[t]['boxes'], group_element_dict['rotation'], image_width) - targets[t]['masks'] = rotate_masks(targets[t]['masks'], -group_element_dict['rotation']) - + targets[t]['boxes'] = rotate_boxes( + targets[t]['boxes'], group_element_dict['rotation'], image_width + ) + targets[t]['masks'] = rotate_masks( + targets[t]['masks'], -group_element_dict['rotation'] + ) + return x, targets - + return x - - def invert_canonicalization(self, x_canonicalized_out: torch.Tensor, induced_rep_type: str = 'regular'): + + def invert_canonicalization( + self, x_canonicalized_out: torch.Tensor, induced_rep_type: str = 'regular' + ): """ This method takes the output of canonicalized image as input and returns output of the original image """ - return get_action_on_image_features(feature_map = x_canonicalized_out, - group_info_dict = self.group_info_dict, - group_element_dict = self.canonicalization_info_dict['group_element'], - induced_rep_type = induced_rep_type) - - - + return get_action_on_image_features( + feature_map=x_canonicalized_out, + group_info_dict=self.group_info_dict, + 
group_element_dict=self.canonicalization_info_dict['group_element'], + induced_rep_type=induced_rep_type, + ) + class GroupEquivariantImageCanonicalization(DiscreteGroupImageCanonicalization): - def __init__(self, - canonicalization_network: torch.nn.Module, - canonicalization_hyperparams: dict, - in_shape: tuple - ): - super().__init__(canonicalization_network, - canonicalization_hyperparams, - in_shape) + def __init__( + self, + canonicalization_network: torch.nn.Module, + canonicalization_hyperparams: dict, + in_shape: tuple, + ): + super().__init__( + canonicalization_network, canonicalization_hyperparams, in_shape + ) self.group_type = canonicalization_network.group_type self.num_rotations = canonicalization_network.num_rotations - self.num_group = self.num_rotations if self.group_type == 'rotation' else 2 * self.num_rotations - self.group_info_dict = {'num_rotations': self.num_rotations, - 'num_group': self.num_group} - + self.num_group = ( + self.num_rotations + if self.group_type == 'rotation' + else 2 * self.num_rotations + ) + self.group_info_dict = { + 'num_rotations': self.num_rotations, + 'num_group': self.num_group, + } + def get_group_activations(self, x: torch.Tensor): """ - This method takes an image as input and + This method takes an image as input and returns the group activations """ x = self.transformations_before_canonicalization_network_forward(x) group_activations = self.canonicalization_network(x) return group_activations - - - -class OptimizedGroupEquivariantImageCanonicalization(DiscreteGroupImageCanonicalization): - def __init__(self, - canonicalization_network: torch.nn.Module, - canonicalization_hyperparams: dict, - in_shape: tuple - ): - super().__init__(canonicalization_network, - canonicalization_hyperparams, - in_shape) + + +class OptimizedGroupEquivariantImageCanonicalization( + DiscreteGroupImageCanonicalization +): + def __init__( + self, + canonicalization_network: torch.nn.Module, + canonicalization_hyperparams: dict, + in_shape: tuple, + ): + super().__init__( + canonicalization_network, canonicalization_hyperparams, in_shape + ) self.group_type = canonicalization_hyperparams.group_type self.num_rotations = canonicalization_hyperparams.num_rotations - self.num_group = self.num_rotations if self.group_type == 'rotation' else 2 * self.num_rotations + self.num_group = ( + self.num_rotations + if self.group_type == 'rotation' + else 2 * self.num_rotations + ) self.out_vector_size = canonicalization_network.out_vector_size self.reference_vector = torch.nn.Parameter( torch.randn(1, self.out_vector_size), requires_grad=False ) - self.group_info_dict = {'num_rotations': self.num_rotations, - 'num_group': self.num_group} - - def rotate_and_maybe_reflect(self, x: torch.Tensor, degrees: torch.Tensor, reflect: bool = False): + self.group_info_dict = { + 'num_rotations': self.num_rotations, + 'num_group': self.num_group, + } + + def rotate_and_maybe_reflect( + self, x: torch.Tensor, degrees: torch.Tensor, reflect: bool = False + ): x_augmented_list = [] for degree in degrees: x_rot = self.pad(x) @@ -212,43 +280,46 @@ def rotate_and_maybe_reflect(self, x: torch.Tensor, degrees: torch.Tensor, refle x_rot = self.crop(x_rot) x_augmented_list.append(x_rot) return x_augmented_list - - - def group_augment(self, x : torch.Tensor): - + + def group_augment(self, x: torch.Tensor): + degrees = torch.linspace(0, 360, self.num_rotations + 1)[:-1].to(self.device) x_augmented_list = self.rotate_and_maybe_reflect(x, degrees) - + if self.group_type == 'roto-reflection': 
x_augmented_list += self.rotate_and_maybe_reflect(x, degrees, reflect=True) - + return torch.cat(x_augmented_list, dim=0) - def get_group_activations(self, x: torch.Tensor): """ - This method takes an image as input and + This method takes an image as input and returns the group activations """ - - x = self.transformations_before_canonicalization_network_forward(x) - x_augmented = self.group_augment(x) # size (batch_size * group_size, in_channels, height, width) - vector_out = self.canonicalization_network(x_augmented) # size (batch_size * group_size, reference_vector_size) + + x = self.transformations_before_canonicalization_network_forward(x) + x_augmented = self.group_augment( + x + ) # size (batch_size * group_size, in_channels, height, width) + vector_out = self.canonicalization_network( + x_augmented + ) # size (batch_size * group_size, reference_vector_size) self.canonicalization_info_dict = {'vector_out': vector_out} scalar_out = F.cosine_similarity( - self.reference_vector.repeat(vector_out.shape[0], 1), - vector_out - ) # size (batch_size * group_size, 1) - group_activations = scalar_out.reshape(self.num_group, -1).T # size (batch_size, group_size) + self.reference_vector.repeat(vector_out.shape[0], 1), vector_out + ) # size (batch_size * group_size, 1) + group_activations = scalar_out.reshape( + self.num_group, -1 + ).T # size (batch_size, group_size) return group_activations - - + def get_optimization_specific_loss(self): vectors = self.canonicalization_info_dict['vector_out'] - vectors = vectors.reshape(self.num_group, -1, self.out_vector_size).permute((1, 0, 2)) # (batch_size, group_size, vector_out_size) + vectors = vectors.reshape(self.num_group, -1, self.out_vector_size).permute( + (1, 0, 2) + ) # (batch_size, group_size, vector_out_size) distances = vectors @ vectors.permute((0, 2, 1)) - mask = 1.0 - torch.eye(self.num_group).to(self.device) # (group_size, group_size) + mask = 1.0 - torch.eye(self.num_group).to( + self.device + ) # (group_size, group_size) return torch.abs(distances * mask).mean() - - - \ No newline at end of file diff --git a/equiadapt/images/canonicalization_networks/__init__.py b/equiadapt/images/canonicalization_networks/__init__.py index a7441bb..44bf935 100644 --- a/equiadapt/images/canonicalization_networks/__init__.py +++ b/equiadapt/images/canonicalization_networks/__init__.py @@ -1,3 +1,3 @@ -from .escnn_networks import ESCNNEquivariantNetwork, ESCNNSteerableNetwork +from .custom_equivariant_networks import CustomEquivariantNetwork from .custom_nonequivariant_networks import ConvNetwork -from .custom_equivariant_networks import CustomEquivariantNetwork \ No newline at end of file +from .escnn_networks import ESCNNEquivariantNetwork, ESCNNSteerableNetwork diff --git a/equiadapt/images/canonicalization_networks/custom_equivariant_networks.py b/equiadapt/images/canonicalization_networks/custom_equivariant_networks.py index 4fbdaa8..3488ad1 100644 --- a/equiadapt/images/canonicalization_networks/custom_equivariant_networks.py +++ b/equiadapt/images/canonicalization_networks/custom_equivariant_networks.py @@ -1,14 +1,27 @@ +import math + +import kornia as K import torch import torch.nn as nn import torch.nn.functional as F -import kornia as K -import math + class RotationEquivariantConvLift(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, num_rotations=4, stride=1, padding=0, bias=True, - device='cuda'): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + num_rotations=4, + stride=1, + padding=0, + 
bias=True, + device='cuda', + ): super().__init__() - self.weights = nn.Parameter(torch.empty(out_channels, in_channels, kernel_size, kernel_size).to(device)) + self.weights = nn.Parameter( + torch.empty(out_channels, in_channels, kernel_size, kernel_size).to(device) + ) torch.nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5)) if bias: self.bias = nn.Parameter(torch.empty(out_channels).to(device)) @@ -27,10 +40,16 @@ def get_rotated_weights(self, weights, num_rotations=4): weights = weights.flatten(0, 1).unsqueeze(0).repeat(num_rotations, 1, 1, 1) rotated_weights = K.geometry.rotate( weights, - torch.linspace(0., 360., steps=num_rotations + 1, dtype=torch.float32)[:num_rotations].to(device), + torch.linspace(0.0, 360.0, steps=num_rotations + 1, dtype=torch.float32)[ + :num_rotations + ].to(device), ) rotated_weights = rotated_weights.reshape( - self.num_rotations, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size + self.num_rotations, + self.out_channels, + self.in_channels, + self.kernel_size, + self.kernel_size, ).transpose(0, 1) return rotated_weights.flatten(0, 1) @@ -43,18 +62,31 @@ def forward(self, x): rotated_weights = self.get_rotated_weights(self.weights, self.num_rotations) # shape (out_channels * num_rotations, in_channels, kernel_size, kernel_size) x = F.conv2d(x, rotated_weights, stride=self.stride, padding=self.padding) - x = x.reshape(batch_size, self.out_channels, self.num_rotations, x.shape[2], x.shape[3]) + x = x.reshape( + batch_size, self.out_channels, self.num_rotations, x.shape[2], x.shape[3] + ) if self.bias is not None: x = x + self.bias[None, :, None, None, None] return x class RotoReflectionEquivariantConvLift(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, num_rotations=4, stride=1, padding=0, bias=True, - device='cuda'): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + num_rotations=4, + stride=1, + padding=0, + bias=True, + device='cuda', + ): super().__init__() num_group_elements = 2 * num_rotations - self.weights = nn.Parameter(torch.empty(out_channels, in_channels, kernel_size, kernel_size).to(device)) + self.weights = nn.Parameter( + torch.empty(out_channels, in_channels, kernel_size, kernel_size).to(device) + ) torch.nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5)) if bias: self.bias = nn.Parameter(torch.empty(out_channels).to(device)) @@ -74,12 +106,18 @@ def get_rotoreflected_weights(self, weights, num_rotations=4): weights = weights.flatten(0, 1).unsqueeze(0).repeat(num_rotations, 1, 1, 1) rotated_weights = K.geometry.rotate( weights, - torch.linspace(0., 360., steps=num_rotations + 1, dtype=torch.float32)[:num_rotations].to(device), + torch.linspace(0.0, 360.0, steps=num_rotations + 1, dtype=torch.float32)[ + :num_rotations + ].to(device), ) reflected_weights = K.geometry.hflip(rotated_weights) rotoreflected_weights = torch.cat([rotated_weights, reflected_weights], dim=0) rotoreflected_weights = rotoreflected_weights.reshape( - self.num_group_elements, self.out_channels, self.in_channels, self.kernel_size, self.kernel_size + self.num_group_elements, + self.out_channels, + self.in_channels, + self.kernel_size, + self.kernel_size, ).transpose(0, 1) return rotoreflected_weights.flatten(0, 1) @@ -89,18 +127,41 @@ def forward(self, x): :return: (batch_size, out_channels, num_group_elements, height, width) """ batch_size = x.shape[0] - rotoreflected_weights = self.get_rotoreflected_weights(self.weights, self.num_rotations) + rotoreflected_weights = 
self.get_rotoreflected_weights( + self.weights, self.num_rotations + ) # shape (out_channels * num_group_elements, in_channels, kernel_size, kernel_size) x = F.conv2d(x, rotoreflected_weights, stride=self.stride, padding=self.padding) - x = x.reshape(batch_size, self.out_channels, self.num_group_elements, x.shape[2], x.shape[3]) + x = x.reshape( + batch_size, + self.out_channels, + self.num_group_elements, + x.shape[2], + x.shape[3], + ) if self.bias is not None: x = x + self.bias[None, :, None, None, None] return x + class RotationEquivariantConv(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, num_rotations=4, stride=1, padding=0, bias=True, device='cuda'): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + num_rotations=4, + stride=1, + padding=0, + bias=True, + device='cuda', + ): super().__init__() - self.weights = nn.Parameter(torch.empty(out_channels, in_channels, num_rotations, kernel_size, kernel_size).to(device)) + self.weights = nn.Parameter( + torch.empty( + out_channels, in_channels, num_rotations, kernel_size, kernel_size + ).to(device) + ) torch.nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5)) if bias: self.bias = nn.Parameter(torch.empty(out_channels).to(device)) @@ -113,26 +174,45 @@ def __init__(self, in_channels, out_channels, kernel_size, num_rotations=4, stri self.padding = padding self.num_rotations = num_rotations self.kernel_size = kernel_size - indices = torch.arange(num_rotations).view((1, 1, num_rotations, 1, 1)).repeat( - num_rotations, out_channels * in_channels, 1, kernel_size, kernel_size + indices = ( + torch.arange(num_rotations) + .view((1, 1, num_rotations, 1, 1)) + .repeat( + num_rotations, out_channels * in_channels, 1, kernel_size, kernel_size + ) ) self.permute_indices_along_group = ( - (indices - torch.arange(num_rotations)[:, None, None, None, None]) % num_rotations + (indices - torch.arange(num_rotations)[:, None, None, None, None]) + % num_rotations ).to(device) - self.angle_list = torch.linspace(0., 360., steps=num_rotations + 1, dtype=torch.float32)[:num_rotations].to(device) + self.angle_list = torch.linspace( + 0.0, 360.0, steps=num_rotations + 1, dtype=torch.float32 + )[:num_rotations].to(device) def get_rotated_permuted_weights(self, weights, num_rotations=4): device = weights.device weights = weights.flatten(0, 1).unsqueeze(0).repeat(num_rotations, 1, 1, 1, 1) permuted_weights = torch.gather(weights, 2, self.permute_indices_along_group) rotated_permuted_weights = K.geometry.rotate( - permuted_weights.flatten(1, 2), - self.angle_list, - ) - rotated_permuted_weights = rotated_permuted_weights.reshape( - self.num_rotations, self.out_channels, self.in_channels, self.num_rotations, self.kernel_size, self.kernel_size - ).transpose(0, 1).reshape( - self.out_channels * self.num_rotations, self.in_channels * self.num_rotations, self.kernel_size, self.kernel_size + permuted_weights.flatten(1, 2), + self.angle_list, + ) + rotated_permuted_weights = ( + rotated_permuted_weights.reshape( + self.num_rotations, + self.out_channels, + self.in_channels, + self.num_rotations, + self.kernel_size, + self.kernel_size, + ) + .transpose(0, 1) + .reshape( + self.out_channels * self.num_rotations, + self.in_channels * self.num_rotations, + self.kernel_size, + self.kernel_size, + ) ) return rotated_permuted_weights @@ -144,19 +224,40 @@ def forward(self, x): batch_size = x.shape[0] x = x.flatten(1, 2) # shape (batch_size, in_channels * num_rotations, height, width) - rotated_permuted_weights = 
self.get_rotated_permuted_weights(self.weights, self.num_rotations) + rotated_permuted_weights = self.get_rotated_permuted_weights( + self.weights, self.num_rotations + ) # shape (out_channels * num_rotations, in_channels * num_rotations, kernal_size, kernal_size) - x = F.conv2d(x, rotated_permuted_weights, stride=self.stride, padding=self.padding) - x = x.reshape(batch_size, self.out_channels, self.num_rotations, x.shape[2], x.shape[3]) + x = F.conv2d( + x, rotated_permuted_weights, stride=self.stride, padding=self.padding + ) + x = x.reshape( + batch_size, self.out_channels, self.num_rotations, x.shape[2], x.shape[3] + ) if self.bias is not None: x = x + self.bias[None, :, None, None, None] return x + class RotoReflectionEquivariantConv(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, num_rotations=4, stride=1, padding=0, bias=True, device='cuda'): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + num_rotations=4, + stride=1, + padding=0, + bias=True, + device='cuda', + ): super().__init__() num_group_elements = 2 * num_rotations - self.weights = nn.Parameter(torch.empty(out_channels, in_channels, num_group_elements, kernel_size, kernel_size).to(device)) + self.weights = nn.Parameter( + torch.empty( + out_channels, in_channels, num_group_elements, kernel_size, kernel_size + ).to(device) + ) torch.nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5)) if bias: self.bias = nn.Parameter(torch.empty(out_channels).to(device)) @@ -170,41 +271,80 @@ def __init__(self, in_channels, out_channels, kernel_size, num_rotations=4, stri self.num_rotations = num_rotations self.kernel_size = kernel_size self.num_group_elements = num_group_elements - indices = torch.arange(num_rotations).view((1, 1, num_rotations, 1, 1)).repeat( - num_rotations, out_channels * in_channels, 1, kernel_size, kernel_size + indices = ( + torch.arange(num_rotations) + .view((1, 1, num_rotations, 1, 1)) + .repeat( + num_rotations, out_channels * in_channels, 1, kernel_size, kernel_size + ) ) - self.permute_indices_along_group = (indices - torch.arange(num_rotations)[:, None, None, None, None]) % num_rotations - self.permute_indices_along_group_inverse = (indices + torch.arange(num_rotations)[:, None, None, None, None]) % num_rotations - self.permute_indices_upper_half = torch.cat([ - self.permute_indices_along_group, self.permute_indices_along_group_inverse + num_rotations - ], dim=2) - self.permute_indices_lower_half = torch.cat([ - self.permute_indices_along_group_inverse + num_rotations, self.permute_indices_along_group - ], dim=2) - self.permute_indices = torch.cat([ - self.permute_indices_upper_half, self.permute_indices_lower_half - ], dim=0).to(device) - self.angle_list = torch.cat([ - torch.linspace(0., 360., steps=num_rotations + 1, dtype=torch.float32)[:num_rotations], - torch.linspace(0., 360., steps=num_rotations + 1, dtype=torch.float32)[:num_rotations] - ]).to(device) + self.permute_indices_along_group = ( + indices - torch.arange(num_rotations)[:, None, None, None, None] + ) % num_rotations + self.permute_indices_along_group_inverse = ( + indices + torch.arange(num_rotations)[:, None, None, None, None] + ) % num_rotations + self.permute_indices_upper_half = torch.cat( + [ + self.permute_indices_along_group, + self.permute_indices_along_group_inverse + num_rotations, + ], + dim=2, + ) + self.permute_indices_lower_half = torch.cat( + [ + self.permute_indices_along_group_inverse + num_rotations, + self.permute_indices_along_group, + ], + dim=2, + ) + self.permute_indices 
= torch.cat( + [self.permute_indices_upper_half, self.permute_indices_lower_half], dim=0 + ).to(device) + self.angle_list = torch.cat( + [ + torch.linspace( + 0.0, 360.0, steps=num_rotations + 1, dtype=torch.float32 + )[:num_rotations], + torch.linspace( + 0.0, 360.0, steps=num_rotations + 1, dtype=torch.float32 + )[:num_rotations], + ] + ).to(device) def get_rotoreflected_permuted_weights(self, weights, num_rotations=4): - weights = weights.flatten(0, 1).unsqueeze(0).repeat(self.num_group_elements, 1, 1, 1, 1) + weights = ( + weights.flatten(0, 1) + .unsqueeze(0) + .repeat(self.num_group_elements, 1, 1, 1, 1) + ) # shape (num_group_elements, out_channels * in_channels, num_group_elements, kernel_size, kernel_size) permuted_weights = torch.gather(weights, 2, self.permute_indices) rotated_permuted_weights = K.geometry.rotate( - permuted_weights.flatten(1, 2), - self.angle_list - ) - rotoreflected_permuted_weights = torch.cat([ - rotated_permuted_weights[:self.num_rotations], - K.geometry.hflip(rotated_permuted_weights[self.num_rotations:]) - ]) - rotoreflected_permuted_weights = rotoreflected_permuted_weights.reshape( - self.num_group_elements, self.out_channels, self.in_channels, self.num_group_elements, self.kernel_size, self.kernel_size - ).transpose(0, 1).reshape( - self.out_channels * self.num_group_elements, self.in_channels * self.num_group_elements, self.kernel_size, self.kernel_size + permuted_weights.flatten(1, 2), self.angle_list + ) + rotoreflected_permuted_weights = torch.cat( + [ + rotated_permuted_weights[: self.num_rotations], + K.geometry.hflip(rotated_permuted_weights[self.num_rotations :]), + ] + ) + rotoreflected_permuted_weights = ( + rotoreflected_permuted_weights.reshape( + self.num_group_elements, + self.out_channels, + self.in_channels, + self.num_group_elements, + self.kernel_size, + self.kernel_size, + ) + .transpose(0, 1) + .reshape( + self.out_channels * self.num_group_elements, + self.in_channels * self.num_group_elements, + self.kernel_size, + self.kernel_size, + ) ) return rotoreflected_permuted_weights @@ -216,36 +356,65 @@ def forward(self, x): batch_size = x.shape[0] x = x.flatten(1, 2) # shape (batch_size, in_channels * num_group_elements, height, width) - rotoreflected_permuted_weights = self.get_rotoreflected_permuted_weights(self.weights, self.num_rotations) + rotoreflected_permuted_weights = self.get_rotoreflected_permuted_weights( + self.weights, self.num_rotations + ) # shape (out_channels * num_group_elements, in_channels * num_group_elements, kernel_size, kernel_size) - x = F.conv2d(x, rotoreflected_permuted_weights, stride=self.stride, padding=self.padding) - x = x.reshape(batch_size, self.out_channels, self.num_group_elements, x.shape[2], x.shape[3]) + x = F.conv2d( + x, rotoreflected_permuted_weights, stride=self.stride, padding=self.padding + ) + x = x.reshape( + batch_size, + self.out_channels, + self.num_group_elements, + x.shape[2], + x.shape[3], + ) if self.bias is not None: x = x + self.bias[None, :, None, None, None] return x - + + class CustomEquivariantNetwork(nn.Module): - def __init__(self, - in_shape, - out_channels, - kernel_size, - group_type='rotation', - num_rotations=4, - num_layers=1, - device='cuda' if torch.cuda.is_available() else 'cpu'): + def __init__( + self, + in_shape, + out_channels, + kernel_size, + group_type='rotation', + num_rotations=4, + num_layers=1, + device='cuda' if torch.cuda.is_available() else 'cpu', + ): super().__init__() - + if group_type == 'rotation': - layer_list = 
[RotationEquivariantConvLift(in_shape[0], out_channels, kernel_size, num_rotations, device=device)] + layer_list = [ + RotationEquivariantConvLift( + in_shape[0], out_channels, kernel_size, num_rotations, device=device + ) + ] for i in range(num_layers - 1): layer_list.append(nn.ReLU()) - layer_list.append(RotationEquivariantConv(out_channels, out_channels, 1, num_rotations, device=device)) + layer_list.append( + RotationEquivariantConv( + out_channels, out_channels, 1, num_rotations, device=device + ) + ) self.eqv_network = nn.Sequential(*layer_list) elif group_type == 'roto-reflection': - layer_list = [RotoReflectionEquivariantConvLift(in_shape[0], out_channels, kernel_size, num_rotations, device=device)] + layer_list = [ + RotoReflectionEquivariantConvLift( + in_shape[0], out_channels, kernel_size, num_rotations, device=device + ) + ] for i in range(num_layers - 1): layer_list.append(nn.ReLU()) - layer_list.append(RotoReflectionEquivariantConv(out_channels, out_channels, 1, num_rotations, device=device)) + layer_list.append( + RotoReflectionEquivariantConv( + out_channels, out_channels, 1, num_rotations, device=device + ) + ) self.eqv_network = nn.Sequential(*layer_list) else: raise ValueError('group_type must be rotation or roto-reflection for now.') @@ -257,5 +426,5 @@ def forward(self, x): """ feature_map = self.eqv_network(x) group_activatiobs = torch.mean(feature_map, dim=(1, 3, 4)) - + return group_activatiobs diff --git a/equiadapt/images/canonicalization_networks/custom_nonequivariant_networks.py b/equiadapt/images/canonicalization_networks/custom_nonequivariant_networks.py index 5d47147..39c33f7 100644 --- a/equiadapt/images/canonicalization_networks/custom_nonequivariant_networks.py +++ b/equiadapt/images/canonicalization_networks/custom_nonequivariant_networks.py @@ -1,13 +1,11 @@ import torch from torch import nn + class ConvNetwork(nn.Module): - def __init__(self, - in_shape, - out_channels, - kernel_size, - num_layers=2, - out_vector_size=128): + def __init__( + self, in_shape, out_channels, kernel_size, num_layers=2, out_vector_size=128 + ): super().__init__() in_channels = in_shape[0] @@ -16,7 +14,9 @@ def __init__(self, if i == 0: layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, 2)) elif i % 3 == 2: - layers.append(nn.Conv2d(out_channels, 2 * out_channels, kernel_size, 2, 1)) + layers.append( + nn.Conv2d(out_channels, 2 * out_channels, kernel_size, 2, 1) + ) out_channels *= 2 else: layers.append(nn.Conv2d(out_channels, out_channels, kernel_size, 2)) @@ -29,11 +29,11 @@ def __init__(self, # self.scalar_fc = nn.Linear(out_shape[1] * out_shape[2] * out_shape[3], 1) out_dim = out_shape[1] * out_shape[2] * out_shape[3] self.final_fc = nn.Sequential( - nn.BatchNorm1d(out_dim), - nn.Dropout1d(0.5), - nn.ReLU(), - nn.Linear(out_dim, out_vector_size) - ) + nn.BatchNorm1d(out_dim), + nn.Dropout1d(0.5), + nn.ReLU(), + nn.Linear(out_dim, out_vector_size), + ) self.out_vector_size = out_vector_size def forward(self, x): @@ -44,4 +44,4 @@ def forward(self, x): batch_size = x.shape[0] out = self.enc_network(x) out = out.reshape(batch_size, -1) - return self.final_fc(out) \ No newline at end of file + return self.final_fc(out) diff --git a/equiadapt/images/canonicalization_networks/escnn_networks.py b/equiadapt/images/canonicalization_networks/escnn_networks.py index 1e80dab..43badcd 100644 --- a/equiadapt/images/canonicalization_networks/escnn_networks.py +++ b/equiadapt/images/canonicalization_networks/escnn_networks.py @@ -1,15 +1,18 @@ -import torch import escnn 
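For context on how a plain (non-equivariant) encoder like the ConvNetwork above is turned into group activations, the optimized discrete canonicalization shown earlier (OptimizedGroupEquivariantImageCanonicalization) compares the network's output vector for every group-transformed copy of an input against a learned reference vector. The following is a minimal, self-contained sketch of that step; the batch size, number of rotations and vector size are illustrative assumptions, and the random tensor stands in for the output of a canonicalization network such as ConvNetwork.

import torch
import torch.nn.functional as F

num_group, batch_size, out_vector_size = 4, 8, 128  # assumed sizes

# Stand-in for canonicalization_network(x_augmented): one vector per
# (group element, sample) pair, flattened along the batch dimension.
vector_out = torch.randn(num_group * batch_size, out_vector_size)
reference_vector = torch.randn(1, out_vector_size)

# Cosine similarity with the reference vector gives one scalar per copy ...
scalar_out = F.cosine_similarity(
    reference_vector.repeat(vector_out.shape[0], 1), vector_out
)  # shape: (num_group * batch_size,)

# ... which is reshaped into per-sample group activations, exactly as in
# get_group_activations above.
group_activations = scalar_out.reshape(num_group, -1).T  # (batch_size, num_group)
print(group_activations.shape)  # torch.Size([8, 4])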
+import torch from escnn import gspaces + class ESCNNEquivariantNetwork(torch.nn.Module): - def __init__(self, - in_shape, - out_channels, - kernel_size, - group_type='rotation', - num_rotations=4, - num_layers=1): + def __init__( + self, + in_shape, + out_channels, + kernel_size, + group_type='rotation', + num_rotations=4, + num_layers=1, + ): super().__init__() self.in_channels = in_shape[0] @@ -24,13 +27,17 @@ def __init__(self, self.gspace = gspaces.flipRot2dOnR2(num_rotations) else: raise ValueError('group_type must be rotation or roto-reflection for now.') - + # If the group is roto-reflection, then the number of group elements is twice the number of rotations - self.num_group_elements = num_rotations if group_type == 'rotation' else 2 * num_rotations + self.num_group_elements = ( + num_rotations if group_type == 'rotation' else 2 * num_rotations + ) - r1 = escnn.nn.FieldType(self.gspace, [self.gspace.trivial_repr] * self.in_channels) + r1 = escnn.nn.FieldType( + self.gspace, [self.gspace.trivial_repr] * self.in_channels + ) r2 = escnn.nn.FieldType(self.gspace, [self.gspace.regular_repr] * out_channels) - + self.in_type = r1 self.out_type = r2 @@ -41,18 +48,28 @@ def __init__(self, escnn.nn.PointwiseDropout(self.out_type, p=0.5), ) for _ in range(num_layers - 2): - self.eqv_network.append(escnn.nn.R2Conv(self.out_type, self.out_type, kernel_size),) - self.eqv_network.append(escnn.nn.InnerBatchNorm(self.out_type, momentum=0.9),) - self.eqv_network.append(escnn.nn.ReLU(self.out_type, inplace=True),) - self.eqv_network.append(escnn.nn.PointwiseDropout(self.out_type, p=0.5),) - - self.eqv_network.append(escnn.nn.R2Conv(self.out_type, self.out_type, kernel_size),) - + self.eqv_network.append( + escnn.nn.R2Conv(self.out_type, self.out_type, kernel_size), + ) + self.eqv_network.append( + escnn.nn.InnerBatchNorm(self.out_type, momentum=0.9), + ) + self.eqv_network.append( + escnn.nn.ReLU(self.out_type, inplace=True), + ) + self.eqv_network.append( + escnn.nn.PointwiseDropout(self.out_type, p=0.5), + ) + + self.eqv_network.append( + escnn.nn.R2Conv(self.out_type, self.out_type, kernel_size), + ) + def forward(self, x): """ - The forward takes an image as input and returns the activations of + The forward takes an image as input and returns the activations of each group element as output. - + x shape: (batch_size, in_channels, height, width) :return: (batch_size, group_size) """ @@ -61,24 +78,29 @@ def forward(self, x): feature_map = out.tensor feature_map = feature_map.reshape( - feature_map.shape[0], self.out_channels, self.num_group_elements, - feature_map.shape[-2], feature_map.shape[-1] + feature_map.shape[0], + self.out_channels, + self.num_group_elements, + feature_map.shape[-2], + feature_map.shape[-1], ) - + group_activations = torch.mean(feature_map, dim=(1, 3, 4)) return group_activations - + class ESCNNSteerableNetwork(torch.nn.Module): - def __init__(self, - in_shape: tuple, - out_channels: int, - kernel_size: int = 9, - group_type: str = 'rotation', - num_layers: int = 1): + def __init__( + self, + in_shape: tuple, + out_channels: int, + kernel_size: int = 9, + group_type: str = 'rotation', + num_layers: int = 1, + ): super().__init__() - + self.group_type = group_type assert group_type == 'rotation', 'group_type must be rotation for now.' 
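The discrete canonicalization classes convert such activations into a group element via groupactivations_to_groupelementonehot, which lives in the base class and is not part of this diff. Purely as an assumed illustration, a common differentiable implementation is a straight-through one-hot (argmax in the forward pass, softmax gradients in the backward pass); the rotation angle is then read off the same way groupactivations_to_groupelement does above. The helper name and the temperature beta below are placeholders, not code from this repository.

import torch
import torch.nn.functional as F

def straight_through_onehot(group_activations: torch.Tensor, beta: float = 1.0):
    # Soft assignment used only for gradients (assumed: tempered softmax).
    soft = torch.softmax(beta * group_activations, dim=-1)
    # Hard selection used in the forward pass.
    hard = F.one_hot(
        group_activations.argmax(dim=-1), group_activations.shape[-1]
    ).to(soft.dtype)
    # Straight-through estimator: value equals `hard`, gradient flows via `soft`.
    return hard + soft - soft.detach()

num_rotations = 4
group_activations = torch.randn(8, num_rotations, requires_grad=True)
one_hot = straight_through_onehot(group_activations)  # (8, num_rotations)

# Recover the rotation angle per sample, mirroring groupactivations_to_groupelement.
angles = torch.linspace(0.0, 360.0, num_rotations + 1)[:num_rotations]
rotation = torch.sum(one_hot * angles, dim=-1)  # (8,) angles in degrees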
# TODO: Add support for roto-reflection group @@ -87,8 +109,10 @@ def __init__(self, self.gspace = gspaces.rot2dOnR2(N=-1) # The input image is a scalar field, corresponding to the trivial representation - in_type = escnn.nn.FieldType(self.gspace, in_shape[0] * [self.gspace.trivial_repr]) - + in_type = escnn.nn.FieldType( + self.gspace, in_shape[0] * [self.gspace.trivial_repr] + ) + # Store the input type for wrapping the images into a geometric tensor during the forward pass self.input_type = in_type @@ -97,25 +121,43 @@ def __init__(self, # Dynamically add layers based on num_layers for _ in range(num_layers): - activation = escnn.nn.FourierELU(self.gspace, out_channels, irreps=[(f,) for f in range(0, 5)], N=16, inplace=True) - modules.append(escnn.nn.R2Conv(in_type, activation.in_type, kernel_size=kernel_size, padding=0, bias=False)) + activation = escnn.nn.FourierELU( + self.gspace, + out_channels, + irreps=[(f,) for f in range(0, 5)], + N=16, + inplace=True, + ) + modules.append( + escnn.nn.R2Conv( + in_type, + activation.in_type, + kernel_size=kernel_size, + padding=0, + bias=False, + ) + ) modules.append(escnn.nn.IIDBatchNorm2d(activation.in_type)) modules.append(activation) in_type = activation.out_type # Update in_type for the next layer # Define the output layer - out_type = escnn.nn.FieldType(self.gspace, [self.gspace.irrep(1), self.gspace.irrep(1)]) - modules.append(escnn.nn.R2Conv(in_type, out_type, kernel_size=kernel_size, padding=0, bias=False)) + out_type = escnn.nn.FieldType( + self.gspace, [self.gspace.irrep(1), self.gspace.irrep(1)] + ) + modules.append( + escnn.nn.R2Conv( + in_type, out_type, kernel_size=kernel_size, padding=0, bias=False + ) + ) # Combine all modules into a SequentialModule self.block = escnn.nn.SequentialModule(*modules) - def forward(self, x : torch.Tensor): + def forward(self, x: torch.Tensor): x = self.input_type(x) # Wrap input images into a geometric tensor x = self.block(x) x = x.tensor # Extract tensor from geometric tensor x = torch.mean(x, dim=(-1, -2)) # Average over spatial dimensions x = x.reshape(x.shape[0], 2, 2) # Reshape to get vector/vectors of dimension 2 return x - - diff --git a/equiadapt/images/utils.py b/equiadapt/images/utils.py index 75d8874..9fb0bef 100644 --- a/equiadapt/images/utils.py +++ b/equiadapt/images/utils.py @@ -1,20 +1,30 @@ import math -import torch + import kornia as K +import torch from torchvision import transforms + def roll_by_gather(feature_map: torch.Tensor, shifts: torch.Tensor): device = shifts.device # assumes 2D array batch, channel, group, x_dim, y_dim = feature_map.shape - arange1 = torch.arange(group).view((1, 1, group, 1, 1)).repeat((batch, channel, 1, x_dim, y_dim)).to(device) - arange2 = (arange1 - shifts[:, None, None,None,None].long()) % group + arange1 = ( + torch.arange(group) + .view((1, 1, group, 1, 1)) + .repeat((batch, channel, 1, x_dim, y_dim)) + .to(device) + ) + arange2 = (arange1 - shifts[:, None, None, None, None].long()) % group return torch.gather(feature_map, 2, arange2) -def get_action_on_image_features(feature_map: torch.Tensor, - group_info_dict: dict, - group_element_dict: dict, - induced_rep_type: str ='regular'): + +def get_action_on_image_features( + feature_map: torch.Tensor, + group_info_dict: dict, + group_element_dict: dict, + induced_rep_type: str = 'regular', +): """ This function takes the feature map and the action and returns the feature map after the action has been applied @@ -27,20 +37,24 @@ def get_action_on_image_features(feature_map: torch.Tensor, assert 
feature_map.shape[1] % num_group == 0 angles = group_element_dict['group']['rotation'] x_out = K.geometry.rotate(feature_map, angles) - + if 'reflection' in group_element_dict['group']: - reflect_indicator = group_element_dict['group']['reflection'] + reflect_indicator = group_element_dict['group']['reflection'] x_out_reflected = K.geometry.hflip(x_out) - x_out = x_out * reflect_indicator[:,None,None,None] + \ - x_out_reflected * (1 - reflect_indicator[:,None,None,None]) - + x_out = x_out * reflect_indicator[:, None, None, None] + x_out_reflected * ( + 1 - reflect_indicator[:, None, None, None] + ) + x_out = x_out.reshape(batch_size, C // num_group, num_group, H, W) - shift = angles / 360. * num_rotations + shift = angles / 360.0 * num_rotations if 'reflection' in group_element_dict['group']: - x_out = torch.cat([ - roll_by_gather(x_out[:,:,:num_rotations], shift), - roll_by_gather(x_out[:,:,num_rotations:], -shift) - ], dim=2) + x_out = torch.cat( + [ + roll_by_gather(x_out[:, :, :num_rotations], shift), + roll_by_gather(x_out[:, :, num_rotations:], -shift), + ], + dim=2, + ) else: x_out = roll_by_gather(x_out, shift) x_out = x_out.reshape(batch_size, -1, H, W) @@ -49,10 +63,11 @@ def get_action_on_image_features(feature_map: torch.Tensor, angles = group_element_dict['group'][0] x_out = K.geometry.rotate(feature_map, angles) if 'reflection' in group_element_dict['group']: - reflect_indicator = group_element_dict['group']['reflection'] + reflect_indicator = group_element_dict['group']['reflection'] x_out_reflected = K.geometry.hflip(x_out) - x_out = x_out * reflect_indicator[:,None,None,None] + \ - x_out_reflected * (1 - reflect_indicator[:,None,None,None]) + x_out = x_out * reflect_indicator[:, None, None, None] + x_out_reflected * ( + 1 - reflect_indicator[:, None, None, None] + ) return x_out elif induced_rep_type == 'vector': # TODO: Implement the action for vector representation @@ -60,16 +75,20 @@ def get_action_on_image_features(feature_map: torch.Tensor, else: raise ValueError('induced_rep_type must be regular, scalar or vector') + def flip_boxes(boxes, width): boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - return boxes + return boxes + def flip_masks(masks): return masks.flip(-1) - + + def rotate_masks(masks, angle): return transforms.functional.rotate(masks, angle) + def rotate_points(origin, point, angle): ox, oy = origin px, py = point @@ -78,6 +97,7 @@ def rotate_points(origin, point, angle): qy = oy + math.sin(angle) * (px - ox) + math.cos(angle) * (py - oy) return qx, qy + def rotate_boxes(boxes, angle, width): # rotate points origin = [width / 2, width / 2] @@ -85,8 +105,12 @@ def rotate_boxes(boxes, angle, width): x_max_rot, y_max_rot = rotate_points(origin, boxes[:, 2:].T, torch.deg2rad(angle)) # rearrange the max and mins to get rotated boxes - x_min_rot, x_max_rot = torch.min(x_min_rot, x_max_rot), torch.max(x_min_rot, x_max_rot) - y_min_rot, y_max_rot = torch.min(y_min_rot, y_max_rot), torch.max(y_min_rot, y_max_rot) + x_min_rot, x_max_rot = torch.min(x_min_rot, x_max_rot), torch.max( + x_min_rot, x_max_rot + ) + y_min_rot, y_max_rot = torch.min(y_min_rot, y_max_rot), torch.max( + y_min_rot, y_max_rot + ) rotated_boxes = torch.stack([x_min_rot, y_min_rot, x_max_rot, y_max_rot], dim=-1) - return rotated_boxes \ No newline at end of file + return rotated_boxes diff --git a/examples/images/classification/README.md b/examples/images/classification/README.md index fdd92b1..054a480 100644 --- a/examples/images/classification/README.md +++ 
b/examples/images/classification/README.md @@ -6,11 +6,11 @@ python train.py canonicalization=group_equivariant experiment.training.loss.prior_weight=0 ``` ### For image classification (with prior regularization) -``` -python train.py canonicalization=group_equivariant +``` +python train.py canonicalization=group_equivariant ``` -**Note**: You can also run the `train.py` as follows from root directory of the project: +**Note**: You can also run the `train.py` as follows from root directory of the project: ``` python examples/images/classification/train.py canonicalization=group_equivariant ``` diff --git a/examples/images/classification/configs/canonicalization/group_equivariant.yaml b/examples/images/classification/configs/canonicalization/group_equivariant.yaml index 0bcda3b..731c152 100644 --- a/examples/images/classification/configs/canonicalization/group_equivariant.yaml +++ b/examples/images/classification/configs/canonicalization/group_equivariant.yaml @@ -8,4 +8,4 @@ network_hyperparams: num_rotations: 4 # Number of rotations for the canonization network beta: 1.0 # Beta parameter for the canonization network input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization -resize_shape: 32 # Resize shape for the input \ No newline at end of file +resize_shape: 32 # Resize shape for the input diff --git a/examples/images/classification/configs/canonicalization/identity.yaml b/examples/images/classification/configs/canonicalization/identity.yaml index 1598d17..513e776 100644 --- a/examples/images/classification/configs/canonicalization/identity.yaml +++ b/examples/images/classification/configs/canonicalization/identity.yaml @@ -1 +1 @@ -canonicalization_type: identity \ No newline at end of file +canonicalization_type: identity diff --git a/examples/images/classification/configs/canonicalization/opt_group_equivariant.yaml b/examples/images/classification/configs/canonicalization/opt_group_equivariant.yaml index 12f60e4..93110c5 100644 --- a/examples/images/classification/configs/canonicalization/opt_group_equivariant.yaml +++ b/examples/images/classification/configs/canonicalization/opt_group_equivariant.yaml @@ -9,4 +9,4 @@ group_type: "rotation" # Type of group for the canonization network num_rotations: 4 # Number of rotations for the canonization network beta: 1.0 # Beta parameter for the canonization network input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization -resize_shape: 32 # Resize shape for the input \ No newline at end of file +resize_shape: 32 # Resize shape for the input diff --git a/examples/images/classification/configs/canonicalization/opt_steerable.yaml b/examples/images/classification/configs/canonicalization/opt_steerable.yaml index 086cf49..47722db 100644 --- a/examples/images/classification/configs/canonicalization/opt_steerable.yaml +++ b/examples/images/classification/configs/canonicalization/opt_steerable.yaml @@ -6,4 +6,4 @@ network_hyperparams: num_layers: 3 # Number of layers in the canonization network out_vector_size: 4 # Dimension of the output vector group_type: "rotation" # Type of group for the canonization network -input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization \ No newline at end of file +input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization diff --git a/examples/images/classification/configs/canonicalization/steerable.yaml b/examples/images/classification/configs/canonicalization/steerable.yaml index 9a6f5f2..629d274 100644 --- 
a/examples/images/classification/configs/canonicalization/steerable.yaml +++ b/examples/images/classification/configs/canonicalization/steerable.yaml @@ -5,4 +5,4 @@ network_hyperparams: out_channels: 16 # Number of output channels for the canonization network num_layers: 3 # Number of layers in the canonization network group_type: "rotation" # Type of group for the canonization network -input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization \ No newline at end of file +input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization diff --git a/examples/images/classification/configs/checkpoint/default.yaml b/examples/images/classification/configs/checkpoint/default.yaml index 419f669..7398463 100644 --- a/examples/images/classification/configs/checkpoint/default.yaml +++ b/examples/images/classification/configs/checkpoint/default.yaml @@ -1,3 +1,3 @@ checkpoint_path: ${oc.env:CHECKPOINT_PATH} # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later -save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file +save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/classification/configs/dataset/default.yaml b/examples/images/classification/configs/dataset/default.yaml index 746b266..d8585cc 100644 --- a/examples/images/classification/configs/dataset/default.yaml +++ b/examples/images/classification/configs/dataset/default.yaml @@ -2,4 +2,4 @@ dataset_name: "cifar10" # Name of the dataset to use data_path: ${oc.env:DATA_PATH} # Path to the dataset augment: 1 # Whether to use data augmentation (1) or not (0) num_workers: 4 # Number of workers for data loading -batch_size: 128 # Number of samples per batch \ No newline at end of file +batch_size: 128 # Number of samples per batch diff --git a/examples/images/classification/configs/experiment/default.yaml b/examples/images/classification/configs/experiment/default.yaml index c571752..44b2d5d 100644 --- a/examples/images/classification/configs/experiment/default.yaml +++ b/examples/images/classification/configs/experiment/default.yaml @@ -1,5 +1,5 @@ run_mode: "train" # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune -seed: 0 # Seed for random number generation +seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) device: "cuda" # Device, can be cuda or cpu num_nodes: 1 @@ -17,4 +17,4 @@ training: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference diff --git a/examples/images/classification/configs/original_configs/group_equivariant/cifar10.yaml b/examples/images/classification/configs/original_configs/group_equivariant/cifar10.yaml index d94f0bf..a209727 100644 --- a/examples/images/classification/configs/original_configs/group_equivariant/cifar10.yaml +++ b/examples/images/classification/configs/original_configs/group_equivariant/cifar10.yaml @@ -9,7 +9,7 @@ dataset: experiment: run_mode: "train" # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune - seed: 0 # Seed for random number generation + seed: 0 # Seed for random number generation 
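These experiment configs expose an inference block (method: "group", num_rotations: 4) for checking robustness under rotations at test time. The snippet below is only one plausible reading of that option, not the project's evaluation code: it runs the model on each rotated copy of the batch using the same kornia rotation op that appears elsewhere in this diff, so the caller can compare per-rotation predictions or average them.

import kornia as K
import torch

def logits_under_rotations(model, x: torch.Tensor, num_rotations: int = 4):
    # Evaluate the model on each element of the cyclic rotation group C_n.
    angles = torch.linspace(0.0, 360.0, num_rotations + 1)[:num_rotations]
    outputs = []
    for angle in angles:
        angle_batch = torch.full((x.shape[0],), float(angle), device=x.device)
        outputs.append(model(K.geometry.rotate(x, angle_batch)))
    # (num_rotations, batch_size, num_classes): inspect per rotation, or
    # average over dim 0 for a simple rotation-ensembled prediction.
    return torch.stack(outputs)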
deterministic: false # Whether to set deterministic mode (true) or not (false) device: "cuda" # Device, can be cuda or cpu num_nodes: 1 @@ -25,7 +25,7 @@ experiment: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference prediction: prediction_network_architecture: "resnet50" # Architecture of the prediction network @@ -53,4 +53,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/s/siba-smarak.panigrahi/scratch/equiadapt/image/checkpoints" # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/classification/configs/original_configs/group_equivariant/rotmnist.yaml b/examples/images/classification/configs/original_configs/group_equivariant/rotmnist.yaml index 26b0414..afac33a 100644 --- a/examples/images/classification/configs/original_configs/group_equivariant/rotmnist.yaml +++ b/examples/images/classification/configs/original_configs/group_equivariant/rotmnist.yaml @@ -9,7 +9,7 @@ dataset: experiment: run_mode: "train" # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune - seed: 0 # Seed for random number generation + seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) device: "cuda" # Device, can be cuda or cpu num_nodes: 1 @@ -25,7 +25,7 @@ experiment: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference prediction: prediction_network_architecture: "resnet50" # Architecture of the prediction network @@ -53,4 +53,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/a/arnab.mondal/scratch/equiadapt/image/checkpoints" # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/classification/configs/original_configs/opt_equivariant/cifar10.yaml b/examples/images/classification/configs/original_configs/opt_equivariant/cifar10.yaml index c611b84..9c7afe6 100644 --- a/examples/images/classification/configs/original_configs/opt_equivariant/cifar10.yaml +++ b/examples/images/classification/configs/original_configs/opt_equivariant/cifar10.yaml @@ -9,7 +9,7 @@ dataset: experiment: run_mode: "train" # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune - seed: 0 # Seed for random number generation + seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) num_nodes: 1 num_gpus: 1 @@ -25,7 +25,7 @@ experiment: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group 
to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference prediction: prediction_network_architecture: "resnet50" # Architecture of the prediction network @@ -54,4 +54,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/a/arnab.mondal/scratch/equiadapt/image/checkpoints" # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/classification/configs/original_configs/opt_equivariant/rotmnist.yaml b/examples/images/classification/configs/original_configs/opt_equivariant/rotmnist.yaml index 5168bd1..1dc06c3 100644 --- a/examples/images/classification/configs/original_configs/opt_equivariant/rotmnist.yaml +++ b/examples/images/classification/configs/original_configs/opt_equivariant/rotmnist.yaml @@ -9,7 +9,7 @@ dataset: experiment: run_mode: "train" # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune - seed: 0 # Seed for random number generation + seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) num_nodes: 1 num_gpus: 1 @@ -25,7 +25,7 @@ experiment: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference prediction: prediction_network_architecture: "resnet50" # Architecture of the prediction network @@ -54,4 +54,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/a/arnab.mondal/scratch/equiadapt/image/checkpoints" # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/classification/configs/original_configs/steerable/cifar10.yaml b/examples/images/classification/configs/original_configs/steerable/cifar10.yaml index 36db32a..be68cf1 100644 --- a/examples/images/classification/configs/original_configs/steerable/cifar10.yaml +++ b/examples/images/classification/configs/original_configs/steerable/cifar10.yaml @@ -19,7 +19,7 @@ prediction: freeze_pretrained_encoder: 0 # Whether to freeze the pretrained encoder (1) or not (0) canonicalization: - network_type: 'escnn' # Options o canonization method 1) escnn + network_type: 'escnn' # Options o canonization method 1) escnn network_hyperparams: kernel_size: 3 # Kernel size for the canonization network out_channels: 16 # Number of output channels for the canonization network @@ -42,4 +42,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/a/arnab.mondal/scratch/equiadapt/image/checkpoints" # Path to save checkpoints deterministic: false # Whether to set deterministic mode (true) or not (false) - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff 
--git a/examples/images/classification/configs/wandb_sweep.yaml b/examples/images/classification/configs/wandb_sweep.yaml index 56117e9..2e1273f 100644 --- a/examples/images/classification/configs/wandb_sweep.yaml +++ b/examples/images/classification/configs/wandb_sweep.yaml @@ -27,4 +27,4 @@ command: - ${env} - python3 - ${program} - - ${args_no_hyphens} \ No newline at end of file + - ${args_no_hyphens} diff --git a/examples/images/classification/inference_utils.py b/examples/images/classification/inference_utils.py index 51ff778..d0a39ef 100644 --- a/examples/images/classification/inference_utils.py +++ b/examples/images/classification/inference_utils.py @@ -1,46 +1,55 @@ -import torch, math -import wandb - -from typing import Union, Dict +import math +from typing import Dict, Union +import torch +import wandb from torchvision import transforms -def get_inference_method(canonicalizer: torch.nn.Module, - prediction_network: torch.nn.Module, - num_classes: int, - inference_hyperparams: Union[Dict, wandb.Config], - in_shape: tuple = (3, 32, 32)): + +def get_inference_method( + canonicalizer: torch.nn.Module, + prediction_network: torch.nn.Module, + num_classes: int, + inference_hyperparams: Union[Dict, wandb.Config], + in_shape: tuple = (3, 32, 32), +): if inference_hyperparams.method == 'vanilla': return VanillaInference(canonicalizer, prediction_network, num_classes) elif inference_hyperparams.method == 'group': return GroupInference( - canonicalizer, prediction_network, num_classes, - inference_hyperparams, in_shape + canonicalizer, + prediction_network, + num_classes, + inference_hyperparams, + in_shape, ) else: raise ValueError(f'{inference_hyperparams.method} is not implemented for now.') + class VanillaInference: - def __init__(self, - canonicalizer: torch.nn.Module, - prediction_network: torch.nn.Module, - num_classes: int) -> None: + def __init__( + self, + canonicalizer: torch.nn.Module, + prediction_network: torch.nn.Module, + num_classes: int, + ) -> None: self.canonicalizer = canonicalizer self.prediction_network = prediction_network self.num_classes = num_classes - + def forward(self, x): # canonicalize the input data # For the vanilla model, the canonicalization is the identity transformation x_canonicalized = self.canonicalizer(x) - + # Forward pass through the prediction network as you'll normally do logits = self.prediction_network(x_canonicalized) return logits - + def get_inference_metrics(self, x: torch.Tensor, y: torch.Tensor): # Forward pass through the prediction network - logits = self.forward(x) + logits = self.forward(x) preds = logits.argmax(dim=-1) # Calculate the accuracy @@ -48,46 +57,56 @@ def get_inference_metrics(self, x: torch.Tensor, y: torch.Tensor): metrics = {"test/acc": acc} # Calculate accuracy per class - acc_per_class = [(preds[y == i] == y[y == i]).float().mean() for i in range(self.num_classes)] - + acc_per_class = [ + (preds[y == i] == y[y == i]).float().mean() for i in range(self.num_classes) + ] + # check if the accuracy per class is nan acc_per_class = [0.0 if math.isnan(acc) else acc for acc in acc_per_class] # Update metrics with accuracy per class - metrics.update({f'test/acc_class_{i}': max(acc, 0.0) for i, acc in enumerate(acc_per_class)}) + metrics.update( + { + f'test/acc_class_{i}': max(acc, 0.0) + for i, acc in enumerate(acc_per_class) + } + ) return metrics class GroupInference(VanillaInference): - def __init__(self, - canonicalizer: torch.nn.Module, - prediction_network: torch.nn.Module, - num_classes: int, - 
inference_hyperparams: Union[Dict, wandb.Config], - in_shape: tuple = (3, 32, 32)): - + def __init__( + self, + canonicalizer: torch.nn.Module, + prediction_network: torch.nn.Module, + num_classes: int, + inference_hyperparams: Union[Dict, wandb.Config], + in_shape: tuple = (3, 32, 32), + ): + super().__init__(canonicalizer, prediction_network, num_classes) self.group_type = inference_hyperparams.group_type self.num_rotations = inference_hyperparams.num_rotations - self.num_group_elements = self.num_rotations if self.group_type == 'rotation' else 2 * self.num_rotations - self.pad = transforms.Pad( - math.ceil(in_shape[-2] * 0.4), - padding_mode='edge' + self.num_group_elements = ( + self.num_rotations + if self.group_type == 'rotation' + else 2 * self.num_rotations ) + self.pad = transforms.Pad(math.ceil(in_shape[-2] * 0.4), padding_mode='edge') self.crop = transforms.CenterCrop((in_shape[-2], in_shape[-1])) def get_group_element_wise_logits(self, x: torch.Tensor): logits_dict = {} degrees = torch.linspace(0, 360, self.num_rotations + 1)[:-1] for rot, degree in enumerate(degrees): - + x_pad = self.pad(x) x_rot = transforms.functional.rotate(x_pad, int(degree)) x_rot = self.crop(x_rot) - + logits_dict[rot] = self.forward(x_rot) - + if self.group_type == 'roto-reflection': # Rotate the reflected images and get the logits for rot, degree in enumerate(degrees): @@ -100,17 +119,27 @@ def get_group_element_wise_logits(self, x: torch.Tensor): logits_dict[rot + len(degrees)] = self.forward(x_rotoreflect) return logits_dict - + def get_inference_metrics(self, x: torch.Tensor, y: torch.Tensor): - + logits_dict = self.get_group_element_wise_logits(x) - + # Use list comprehension to calculate accuracy for each group element - acc_per_group_element = torch.tensor([(logits.argmax(dim=-1) == y).float().mean() for logits in logits_dict.values()]) + acc_per_group_element = torch.tensor( + [ + (logits.argmax(dim=-1) == y).float().mean() + for logits in logits_dict.values() + ] + ) metrics = {"test/group_acc": torch.mean(acc_per_group_element)} - metrics.update({f'test/acc_group_element_{i}': max(acc_per_group_element[i], 0.0) for i in range(self.num_group_elements)}) - + metrics.update( + { + f'test/acc_group_element_{i}': max(acc_per_group_element[i], 0.0) + for i in range(self.num_group_elements) + } + ) + preds = logits_dict[0].argmax(dim=-1) # Calculate the accuracy @@ -118,13 +147,19 @@ def get_inference_metrics(self, x: torch.Tensor, y: torch.Tensor): metrics.update({"test/acc": acc}) # Calculate accuracy per class - acc_per_class = [(preds[y == i] == y[y == i]).float().mean() for i in range(self.num_classes)] - + acc_per_class = [ + (preds[y == i] == y[y == i]).float().mean() for i in range(self.num_classes) + ] + # check if the accuracy per class is nan acc_per_class = [0.0 if math.isnan(acc) else acc for acc in acc_per_class] # Update metrics with accuracy per class - metrics.update({f'test/acc_class_{i}': max(acc, 0.0) for i, acc in enumerate(acc_per_class)}) + metrics.update( + { + f'test/acc_class_{i}': max(acc, 0.0) + for i, acc in enumerate(acc_per_class) + } + ) return metrics - \ No newline at end of file diff --git a/examples/images/classification/model.py b/examples/images/classification/model.py index d0a9a1c..3568233 100644 --- a/examples/images/classification/model.py +++ b/examples/images/classification/model.py @@ -1,19 +1,20 @@ -import torch import pytorch_lightning as pl -from torch.optim.lr_scheduler import MultiStepLR - -from omegaconf import DictConfig - +import torch +from 
common.utils import get_canonicalization_network, get_canonicalizer from inference_utils import get_inference_method from model_utils import get_dataset_specific_info, get_prediction_network -from common.utils import get_canonicalization_network, get_canonicalizer +from omegaconf import DictConfig +from torch.optim.lr_scheduler import MultiStepLR + # define the LightningModule class ImageClassifierPipeline(pl.LightningModule): def __init__(self, hyperparams: DictConfig): super().__init__() - - self.loss, self.image_shape, self.num_classes = get_dataset_specific_info(hyperparams.dataset.dataset_name) + + self.loss, self.image_shape, self.num_classes = get_dataset_specific_info( + hyperparams.dataset.dataset_name + ) self.prediction_network = get_prediction_network( architecture=hyperparams.prediction.prediction_network_architecture, @@ -21,166 +22,178 @@ def __init__(self, hyperparams: DictConfig): use_pretrained=hyperparams.prediction.use_pretrained, freeze_encoder=hyperparams.prediction.freeze_pretrained_encoder, input_shape=self.image_shape, - num_classes=self.num_classes + num_classes=self.num_classes, ) canonicalization_network = get_canonicalization_network( - hyperparams.canonicalization_type, + hyperparams.canonicalization_type, hyperparams.canonicalization, self.image_shape, ) - + self.canonicalizer = get_canonicalizer( hyperparams.canonicalization_type, canonicalization_network, hyperparams.canonicalization, - self.image_shape - ) - + self.image_shape, + ) + self.hyperparams = hyperparams - + self.inference_method = get_inference_method( self.canonicalizer, self.prediction_network, self.num_classes, hyperparams.experiment.inference, - self.image_shape + self.image_shape, ) - + self.max_epochs = hyperparams.experiment.training.num_epochs - - self.save_hyperparameters() + self.save_hyperparameters() def training_step(self, batch: torch.Tensor): x, y = batch batch_size, num_channels, height, width = x.shape - + # assert that the input is in the right shape assert (num_channels, height, width) == self.image_shape training_metrics = {} loss = 0.0 - + # canonicalize the input data # For the vanilla model, the canonicalization is the identity transformation x_canonicalized = self.canonicalizer(x) - + # add group contrast loss while using optmization based canonicalization method if 'opt' in self.hyperparams.canonicalization_type: group_contrast_loss = self.canonicalizer.get_optimization_specific_loss() - loss += group_contrast_loss * self.hyperparams.experiment.training.loss.group_contrast_weight - training_metrics.update({"train/optimization_specific_loss": group_contrast_loss}) - - + loss += ( + group_contrast_loss + * self.hyperparams.experiment.training.loss.group_contrast_weight + ) + training_metrics.update( + {"train/optimization_specific_loss": group_contrast_loss} + ) + # calculate the task loss which is the cross-entropy loss for classification if self.hyperparams.experiment.training.loss.task_weight: # Forward pass through the prediction network as you'll normally do logits = self.prediction_network(x_canonicalized) - + task_loss = self.loss(logits, y) loss += self.hyperparams.experiment.training.loss.task_weight * task_loss - + # Get the predictions and calculate the accuracy preds = logits.argmax(dim=-1) acc = (preds == y).float().mean() - - training_metrics.update({ - "train/task_loss": task_loss, - "train/acc": acc - }) - + + training_metrics.update({"train/task_loss": task_loss, "train/acc": acc}) + # Add prior regularization loss if the prior weight is non-zero if 
self.hyperparams.experiment.training.loss.prior_weight: prior_loss = self.canonicalizer.get_prior_regularization_loss() loss += prior_loss * self.hyperparams.experiment.training.loss.prior_weight metric_identity = self.canonicalizer.get_identity_metric() - training_metrics.update({ - "train/prior_loss": prior_loss, - "train/identity_metric": metric_identity - }) - - training_metrics.update({ + training_metrics.update( + { + "train/prior_loss": prior_loss, + "train/identity_metric": metric_identity, + } + ) + + training_metrics.update( + { "train/loss": loss, - }) - + } + ) + # Log the training metrics self.log_dict(training_metrics, prog_bar=True) - + return {'loss': loss, 'acc': acc} - def validation_step(self, batch: torch.Tensor): x, y = batch - + batch_size, num_channels, height, width = x.shape - + # assert that the input is in the right shape assert (num_channels, height, width) == self.image_shape - + validation_metrics = {} - + # canonicalize the input data # For the vanilla model, the canonicalization is the identity transformation x_canonicalized = self.canonicalizer(x) - + # Forward pass through the prediction network as you'll normally do logits = self.prediction_network(x_canonicalized) - # Get the predictions and calculate the accuracy + # Get the predictions and calculate the accuracy preds = logits.argmax(dim=-1) acc = (preds == y).float().mean() - - # Log the identity metric if the prior weight is non-zero + + # Log the identity metric if the prior weight is non-zero if self.hyperparams.experiment.training.loss.prior_weight: metric_identity = self.canonicalizer.get_identity_metric() - validation_metrics.update({ - "train/identity_metric": metric_identity - }) - - + validation_metrics.update({"train/identity_metric": metric_identity}) + # Logging to TensorBoard by default - validation_metrics.update({ - "val/acc": acc - }) - + validation_metrics.update({"val/acc": acc}) + self.log_dict(validation_metrics, prog_bar=True) return {'acc': acc} - def test_step(self, batch: torch.Tensor): x, y = batch batch_size, num_channels, height, width = x.shape - + # assert that the input is in the right shape assert (num_channels, height, width) == self.image_shape test_metrics = self.inference_method.get_inference_metrics(x, y) - + # Log the test metrics self.log_dict(test_metrics, prog_bar=True) - - return test_metrics - + + return test_metrics def configure_optimizers(self): - if 'resnet' in self.hyperparams.prediction.prediction_network_architecture and 'mnist' not in self.hyperparams.dataset.dataset_name: + if ( + 'resnet' in self.hyperparams.prediction.prediction_network_architecture + and 'mnist' not in self.hyperparams.dataset.dataset_name + ): print('using SGD optimizer') optimizer = torch.optim.SGD( [ - {'params': self.prediction_network.parameters(), 'lr': self.hyperparams.experiment.training.prediction_lr}, - {'params': self.canonicalizer.parameters(), 'lr': self.hyperparams.experiment.training.canonicalization_lr}, - ], + { + 'params': self.prediction_network.parameters(), + 'lr': self.hyperparams.experiment.training.prediction_lr, + }, + { + 'params': self.canonicalizer.parameters(), + 'lr': self.hyperparams.experiment.training.canonicalization_lr, + }, + ], momentum=0.9, weight_decay=5e-4, ) - + if self.max_epochs > 100: - milestones = [self.trainer.max_epochs // 6, self.trainer.max_epochs // 3, self.trainer.max_epochs // 2] + milestones = [ + self.trainer.max_epochs // 6, + self.trainer.max_epochs // 3, + self.trainer.max_epochs // 2, + ] else: - milestones = 
[self.trainer.max_epochs // 3, self.trainer.max_epochs // 2] # for small training epochs - + milestones = [ + self.trainer.max_epochs // 3, + self.trainer.max_epochs // 2, + ] # for small training epochs + scheduler_dict = { "scheduler": MultiStepLR( optimizer, @@ -192,8 +205,16 @@ def configure_optimizers(self): return {"optimizer": optimizer, "lr_scheduler": scheduler_dict} else: print(f'using Adam optimizer') - optimizer = torch.optim.AdamW([ - {'params': self.prediction_network.parameters(), 'lr': self.hyperparams.experiment.training.prediction_lr}, - {'params': self.canonicalizer.parameters(), 'lr': self.hyperparams.experiment.training.canonicalization_lr}, - ]) - return optimizer \ No newline at end of file + optimizer = torch.optim.AdamW( + [ + { + 'params': self.prediction_network.parameters(), + 'lr': self.hyperparams.experiment.training.prediction_lr, + }, + { + 'params': self.canonicalizer.parameters(), + 'lr': self.hyperparams.experiment.training.canonicalization_lr, + }, + ] + ) + return optimizer diff --git a/examples/images/classification/model_utils.py b/examples/images/classification/model_utils.py index 5442d4d..7016353 100644 --- a/examples/images/classification/model_utils.py +++ b/examples/images/classification/model_utils.py @@ -1,13 +1,24 @@ import torch -import torchvision import torch.nn as nn - +import torchvision from omegaconf import DictConfig from equiadapt.common.basecanonicalization import IdentityCanonicalization -from equiadapt.images.canonicalization.discrete_group import GroupEquivariantImageCanonicalization, OptimizedGroupEquivariantImageCanonicalization -from equiadapt.images.canonicalization.continuous_group import SteerableImageCanonicalization, OptimizedSteerableImageCanonicalization -from equiadapt.images.canonicalization_networks import ESCNNEquivariantNetwork, ConvNetwork, CustomEquivariantNetwork, ESCNNSteerableNetwork +from equiadapt.images.canonicalization.continuous_group import ( + OptimizedSteerableImageCanonicalization, + SteerableImageCanonicalization, +) +from equiadapt.images.canonicalization.discrete_group import ( + GroupEquivariantImageCanonicalization, + OptimizedGroupEquivariantImageCanonicalization, +) +from equiadapt.images.canonicalization_networks import ( + ConvNetwork, + CustomEquivariantNetwork, + ESCNNEquivariantNetwork, + ESCNNSteerableNetwork, +) + class PredictionNetwork(nn.Module): def __init__(self, encoder: torch.nn.Module, feature_dim: int, num_classes: int): @@ -19,7 +30,8 @@ def forward(self, x): reps = self.encoder(x) reps = reps.view(x.shape[0], -1) return self.predictor(reps) - + + def get_dataset_specific_info(dataset_name): dataset_info = { 'rotated_mnist': (nn.CrossEntropyLoss(), (1, 28, 28), 10), @@ -35,43 +47,49 @@ def get_dataset_specific_info(dataset_name): raise ValueError('Dataset not implemented for now.') return dataset_info[dataset_name] - + def get_prediction_network( - architecture: str = 'resnet50', + architecture: str = 'resnet50', dataset_name: str = 'cifar10', use_pretrained: bool = False, freeze_encoder: bool = False, input_shape: tuple = (3, 32, 32), - num_classes: int = 10 + num_classes: int = 10, ): weights = 'DEFAULT' if use_pretrained else None model_dict = { 'resnet50': torchvision.models.resnet50, - 'vit': torchvision.models.vit_b_16 + 'vit': torchvision.models.vit_b_16, } if architecture not in model_dict: - raise ValueError(f'{architecture} is not implemented as prediction network for now.') + raise ValueError( + f'{architecture} is not implemented as prediction network for now.' 
+ ) encoder = model_dict[architecture](weights=weights) - if architecture == 'resnet50' and dataset_name in ('cifar10', 'cifar100', 'rotated_mnist'): + if architecture == 'resnet50' and dataset_name in ( + 'cifar10', + 'cifar100', + 'rotated_mnist', + ): if input_shape[-2:] == [32, 32] or dataset_name == 'rotated_mnist': - encoder.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False) + encoder.conv1 = nn.Conv2d( + input_shape[0], 64, kernel_size=3, stride=1, padding=1, bias=False + ) encoder.maxpool = nn.Identity() - + if freeze_encoder: for param in encoder.parameters(): param.requires_grad = False if dataset_name != 'ImageNet': feature_dim = encoder.fc.in_features - encoder.fc = nn.Identity() + encoder.fc = nn.Identity() prediction_network = PredictionNetwork(encoder, feature_dim, num_classes) else: prediction_network = encoder - - - return prediction_network \ No newline at end of file + return prediction_network diff --git a/examples/images/classification/prepare/__init__.py b/examples/images/classification/prepare/__init__.py index 67886d5..be3a8f1 100644 --- a/examples/images/classification/prepare/__init__.py +++ b/examples/images/classification/prepare/__init__.py @@ -1,6 +1,6 @@ -from .rotated_mnist_data import RotatedMNISTDataModule -from .cifar_data import CIFAR10DataModule, CIFAR100DataModule -from .stl10_data import STL10DataModule from .celeba_data import CelebADataModule +from .cifar_data import CIFAR10DataModule, CIFAR100DataModule from .flowers102_data import Flowers102DataModule -from .imagenet_data import ImageNetDataModule \ No newline at end of file +from .imagenet_data import ImageNetDataModule +from .rotated_mnist_data import RotatedMNISTDataModule +from .stl10_data import STL10DataModule diff --git a/examples/images/classification/prepare/celeba_data.py b/examples/images/classification/prepare/celeba_data.py index c3c064d..9ea5b33 100644 --- a/examples/images/classification/prepare/celeba_data.py +++ b/examples/images/classification/prepare/celeba_data.py @@ -1,28 +1,26 @@ +import os +import random import pytorch_lightning as pl from torch.utils.data import DataLoader, random_split from torchvision import transforms from torchvision.datasets import CelebA -import os -import random class CelebADataModule(pl.LightningDataModule): def __init__(self, hyperparams, download=False): super().__init__() self.data_path = hyperparams.data_path self.hyperparams = hyperparams - + if hyperparams.augment == 1: self.train_transform = transforms.Compose( [ transforms.Resize(224), transforms.Pad(4), transforms.RandomCrop(224), - transforms.RandomRotation(5), transforms.RandomHorizontalFlip(), - transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ] @@ -34,10 +32,8 @@ def __init__(self, hyperparams, download=False): transforms.Resize(224), transforms.Pad(4), transforms.RandomCrop(224), - - transforms.RandomRotation(180), # sampling from (-180, 180) + transforms.RandomRotation(180), # sampling from (-180, 180) transforms.RandomHorizontalFlip(), - transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ] @@ -48,7 +44,6 @@ def __init__(self, hyperparams, download=False): transforms.Resize(224), transforms.Pad(4), transforms.RandomCrop(224), - transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ] @@ -57,7 +52,6 @@ def __init__(self, hyperparams, download=False): [ transforms.Resize(224), transforms.RandomCrop(224), - transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 
0.5, 0.5)), ] @@ -66,10 +60,28 @@ def __init__(self, hyperparams, download=False): def setup(self, stage=None): if stage == "fit" or stage is None: - self.train_dataset = CelebA(self.data_path, split='train', target_type='attr', transform=self.train_transform, download=True) - self.valid_dataset = CelebA(self.data_path, split='valid', target_type='attr', transform=self.test_transform, download=True) + self.train_dataset = CelebA( + self.data_path, + split='train', + target_type='attr', + transform=self.train_transform, + download=True, + ) + self.valid_dataset = CelebA( + self.data_path, + split='valid', + target_type='attr', + transform=self.test_transform, + download=True, + ) if stage == "test": - self.test_dataset = CelebA(self.data_path, split='test', target_type='attr', transform=self.test_transform, download=True) + self.test_dataset = CelebA( + self.data_path, + split='test', + target_type='attr', + transform=self.test_transform, + download=True, + ) print('Test dataset size: ', len(self.test_dataset)) def train_dataloader(self): @@ -97,4 +109,4 @@ def test_dataloader(self): shuffle=False, num_workers=self.hyperparams.num_workers, ) - return test_loader \ No newline at end of file + return test_loader diff --git a/examples/images/classification/prepare/cifar_data.py b/examples/images/classification/prepare/cifar_data.py index 595ce48..d9eb88f 100644 --- a/examples/images/classification/prepare/cifar_data.py +++ b/examples/images/classification/prepare/cifar_data.py @@ -1,11 +1,11 @@ +import os +import random import pytorch_lightning as pl from torch.utils.data import DataLoader, random_split from torchvision import transforms from torchvision.datasets import CIFAR10, CIFAR100 -import os -import random class CustomRotationTransform: """Rotate by one of the given angles.""" @@ -17,61 +17,75 @@ def __call__(self, x): angle = random.choice(self.angles) return transforms.functional.rotate(x, angle) + class CIFAR10DataModule(pl.LightningDataModule): def __init__(self, hyperparams, download=False): super().__init__() self.data_path = hyperparams.data_path self.hyperparams = hyperparams if hyperparams.augment == 1: - self.train_transform = transforms.Compose([ + self.train_transform = transforms.Compose( + [ transforms.RandomCrop(32, padding=4), transforms.Resize(224), - transforms.RandomHorizontalFlip(), transforms.RandomRotation(5), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)), - ]) + transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261) + ), + ] + ) elif hyperparams.augment == 2: # all augmentations - self.train_transform = transforms.Compose([ + self.train_transform = transforms.Compose( + [ transforms.RandomCrop(32, padding=4), transforms.Resize(224), - transforms.RandomHorizontalFlip(), CustomRotationTransform([0, 45, 90, 135, 180, 225, 270, 315]), # transforms.RandomRotation(180), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)), - ]) + transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261) + ), + ] + ) elif hyperparams.augment == 3: # autoaugment - self.train_transform = transforms.Compose([ + self.train_transform = transforms.Compose( + [ transforms.RandomCrop(32, padding=4), transforms.Resize(224), - transforms.RandomHorizontalFlip(), - transforms.AutoAugment(policy=transforms.autoaugment.AutoAugmentPolicy.CIFAR10), - + transforms.AutoAugment( + policy=transforms.autoaugment.AutoAugmentPolicy.CIFAR10 + ), transforms.ToTensor(), - 
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)), - ]) + transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261) + ), + ] + ) else: - self.train_transform = transforms.Compose([ + self.train_transform = transforms.Compose( + [ transforms.RandomCrop(32, padding=4), transforms.Resize(224), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)), - ]) - self.test_transform = transforms.Compose([ + transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261) + ), + ] + ) + self.test_transform = transforms.Compose( + [ transforms.Resize(224), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)), - ]) + ] + ) os.makedirs(self.data_path, exist_ok=True) def setup(self, stage=None): @@ -82,10 +96,25 @@ def setup(self, stage=None): # print('Train dataset size: ', len(self.train_dataset)) # print('Valid dataset size: ', len(self.valid_dataset)) # Not a good strategy for splitting data but most papers use this - self.train_dataset = CIFAR10(self.data_path, train=True, transform=self.train_transform, download=True) - self.valid_dataset = CIFAR10(self.data_path, train=False, transform=self.test_transform, download=True) + self.train_dataset = CIFAR10( + self.data_path, + train=True, + transform=self.train_transform, + download=True, + ) + self.valid_dataset = CIFAR10( + self.data_path, + train=False, + transform=self.test_transform, + download=True, + ) if stage == "test": - self.test_dataset = CIFAR10(self.data_path, train=False, transform=self.test_transform, download=True) + self.test_dataset = CIFAR10( + self.data_path, + train=False, + transform=self.test_transform, + download=True, + ) print('Test dataset size: ', len(self.test_dataset)) def train_dataloader(self): @@ -115,61 +144,75 @@ def test_dataloader(self): ) return test_loader + class CIFAR100DataModule(pl.LightningDataModule): def __init__(self, hyperparams, download=False): super().__init__() self.data_path = hyperparams.data_path self.hyperparams = hyperparams if hyperparams.augment == 1: - self.train_transform = transforms.Compose([ + self.train_transform = transforms.Compose( + [ transforms.RandomCrop(32, padding=4), transforms.Resize(224), - transforms.RandomHorizontalFlip(), transforms.RandomRotation(5), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)), - ]) + transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261) + ), + ] + ) elif hyperparams.augment == 2: # all augmentations - self.train_transform = transforms.Compose([ + self.train_transform = transforms.Compose( + [ transforms.RandomCrop(32, padding=4), transforms.Resize(224), - transforms.RandomHorizontalFlip(), CustomRotationTransform([0, 45, 90, 135, 180, 225, 270, 315]), # transforms.RandomRotation(180), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)), - ]) + transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261) + ), + ] + ) elif hyperparams.augment == 3: # autoaugment - self.train_transform = transforms.Compose([ + self.train_transform = transforms.Compose( + [ transforms.RandomCrop(32, padding=4), transforms.Resize(224), - transforms.RandomHorizontalFlip(), - transforms.AutoAugment(policy=transforms.autoaugment.AutoAugmentPolicy.CIFAR10), - + transforms.AutoAugment( + policy=transforms.autoaugment.AutoAugmentPolicy.CIFAR10 + ), transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 
0.261)), - ]) + transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261) + ), + ] + ) else: - self.train_transform = transforms.Compose([ + self.train_transform = transforms.Compose( + [ transforms.RandomCrop(32, padding=4), transforms.Resize(224), - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)), - ]) - self.test_transform = transforms.Compose([ + transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261) + ), + ] + ) + self.test_transform = transforms.Compose( + [ transforms.Resize(224), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)), - ]) + ] + ) os.makedirs(self.data_path, exist_ok=True) def setup(self, stage=None): @@ -180,10 +223,25 @@ def setup(self, stage=None): # print('Train dataset size: ', len(self.train_dataset)) # print('Valid dataset size: ', len(self.valid_dataset)) # Not a good strategy for splitting data but most papers use this - self.train_dataset = CIFAR100(self.data_path, train=True, transform=self.train_transform, download=True) - self.valid_dataset = CIFAR100(self.data_path, train=False, transform=self.test_transform, download=True) + self.train_dataset = CIFAR100( + self.data_path, + train=True, + transform=self.train_transform, + download=True, + ) + self.valid_dataset = CIFAR100( + self.data_path, + train=False, + transform=self.test_transform, + download=True, + ) if stage == "test": - self.test_dataset = CIFAR100(self.data_path, train=False, transform=self.test_transform, download=True) + self.test_dataset = CIFAR100( + self.data_path, + train=False, + transform=self.test_transform, + download=True, + ) print('Test dataset size: ', len(self.test_dataset)) def train_dataloader(self): @@ -211,4 +269,4 @@ def test_dataloader(self): shuffle=False, num_workers=self.hyperparams.num_workers, ) - return test_loader \ No newline at end of file + return test_loader diff --git a/examples/images/classification/prepare/flowers102_data.py b/examples/images/classification/prepare/flowers102_data.py index 05e2c8e..141e7b3 100644 --- a/examples/images/classification/prepare/flowers102_data.py +++ b/examples/images/classification/prepare/flowers102_data.py @@ -1,11 +1,11 @@ +import os +import random import pytorch_lightning as pl from torch.utils.data import DataLoader, random_split from torchvision import transforms from torchvision.datasets import Flowers102 -import os -import random class Flowers102DataModule(pl.LightningDataModule): def __init__(self, hyperparams, download=False): @@ -33,10 +33,25 @@ def __init__(self, hyperparams, download=False): def setup(self, stage=None): if stage == "fit" or stage is None: - self.train_dataset = Flowers102(self.data_path, split='train', transform=self.train_transform, download=True) - self.valid_dataset = Flowers102(self.data_path, split='val', transform=self.test_transform, download=True) + self.train_dataset = Flowers102( + self.data_path, + split='train', + transform=self.train_transform, + download=True, + ) + self.valid_dataset = Flowers102( + self.data_path, + split='val', + transform=self.test_transform, + download=True, + ) if stage == "test": - self.test_dataset = Flowers102(self.data_path, split='test', transform=self.test_transform, download=True) + self.test_dataset = Flowers102( + self.data_path, + split='test', + transform=self.test_transform, + download=True, + ) print('Test dataset size: ', len(self.test_dataset)) def train_dataloader(self): @@ -64,4 +79,4 @@ def test_dataloader(self): shuffle=False, 
num_workers=self.hyperparams.num_workers, ) - return test_loader \ No newline at end of file + return test_loader diff --git a/examples/images/classification/prepare/imagenet_data.py b/examples/images/classification/prepare/imagenet_data.py index 3dc3572..1d893e9 100644 --- a/examples/images/classification/prepare/imagenet_data.py +++ b/examples/images/classification/prepare/imagenet_data.py @@ -1,17 +1,16 @@ +import os +import random +from typing import List import pytorch_lightning as pl - -import os import torch -import random import torchvision -from torch import nn -from typing import List -from PIL import Image, ImageOps import torchvision.transforms as transforms +from PIL import Image, ImageOps +from torch import nn +DEFAULT_CROP_RATIO = 224 / 256 -DEFAULT_CROP_RATIO = 224/256 class GaussianBlur(nn.Module): def __init__(self, p): @@ -26,6 +25,7 @@ def forward(self, img): else: return img + class Solarization(nn.Module): def __init__(self, p): super().__init__() @@ -37,6 +37,7 @@ def forward(self, img): else: return img + class CustomRotationTransform: """Rotate by one of the given angles.""" @@ -47,30 +48,43 @@ def __call__(self, x): angle = random.choice(self.angles) return transforms.functional.rotate(x, angle) + class Transform: def __init__(self, mode='train'): # these transformations are essential for reproducing the zero-shot performance if mode == 'train': - self.transform = transforms.Compose([ - transforms.RandomResizedCrop(224, interpolation=transforms.functional.InterpolationMode.BILINEAR), - transforms.RandomHorizontalFlip(0.5), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - ]) + self.transform = transforms.Compose( + [ + transforms.RandomResizedCrop( + 224, + interpolation=transforms.functional.InterpolationMode.BILINEAR, + ), + transforms.RandomHorizontalFlip(0.5), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) elif mode == 'val': - self.transform = transforms.Compose([ - transforms.Resize(256, interpolation=transforms.functional.InterpolationMode.BILINEAR), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - ]) - + self.transform = transforms.Compose( + [ + transforms.Resize( + 256, + interpolation=transforms.functional.InterpolationMode.BILINEAR, + ), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) def __call__(self, x): return self.transform(x) + class ImageNetDataModule(pl.LightningDataModule): def __init__(self, hyperparams): super().__init__() @@ -95,8 +109,9 @@ def val_dataloader(self): def test_dataloader(self): return self.loaders['val'] - - def get_imagenet_pytorch_dataloaders(self, data_dir=None, batch_size=None, num_workers=None): + def get_imagenet_pytorch_dataloaders( + self, data_dir=None, batch_size=None, num_workers=None + ): paths = { 'train': data_dir + '/train', 'val': data_dir + '/val', @@ -109,18 +124,21 @@ def get_imagenet_pytorch_dataloaders(self, data_dir=None, batch_size=None, num_w drop_last = True if name == 'train' else False shuffle = True if name == 'train' else False loader = torch.utils.data.DataLoader( - dataset, batch_size=batch_size, num_workers=num_workers, - pin_memory=True, shuffle=shuffle, drop_last=drop_last + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=True, + shuffle=shuffle, + 
drop_last=drop_last, ) loaders[name] = loader return loaders - def get_imagenet_pytorch_dataloaders_distributed(self, data_dir=None, batch_size=None, num_workers=None, world_size=None): - paths = { - 'train': data_dir + '/train', - 'val': data_dir + '/val' - } + def get_imagenet_pytorch_dataloaders_distributed( + self, data_dir=None, batch_size=None, num_workers=None, world_size=None + ): + paths = {'train': data_dir + '/train', 'val': data_dir + '/val'} loaders = {} @@ -130,8 +148,11 @@ def get_imagenet_pytorch_dataloaders_distributed(self, data_dir=None, batch_size assert batch_size % world_size == 0 per_device_batch_size = batch_size // world_size loader = torch.utils.data.DataLoader( - dataset, batch_size=per_device_batch_size, num_workers=num_workers, - pin_memory=True, sampler=sampler + dataset, + batch_size=per_device_batch_size, + num_workers=num_workers, + pin_memory=True, + sampler=sampler, ) loaders[name] = loader diff --git a/examples/images/classification/prepare/rotated_mnist_data.py b/examples/images/classification/prepare/rotated_mnist_data.py index a2e092a..9a3c7ca 100644 --- a/examples/images/classification/prepare/rotated_mnist_data.py +++ b/examples/images/classification/prepare/rotated_mnist_data.py @@ -1,11 +1,13 @@ -import os import argparse +import os import urllib.request as url_req import zipfile + import numpy as np +import pytorch_lightning as pl import torch from torch.utils.data import DataLoader, TensorDataset -import pytorch_lightning as pl + def obtain(dir_path): os.makedirs(dir_path, exist_ok=True) @@ -13,8 +15,10 @@ def obtain(dir_path): print('Downloading the dataset') ## Download the main zip file - url_req.urlretrieve('http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip', - os.path.join(dir_path, 'mnist_rotated.zip')) + url_req.urlretrieve( + 'http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip', + os.path.join(dir_path, 'mnist_rotated.zip'), + ) # Extract the zip file print('Extracting the dataset') @@ -31,8 +35,14 @@ def obtain(dir_path): test_file_path = os.path.join(dir_path, 'mnist_rotated_test.amat') # Rename train and test files - os.rename(os.path.join(dir_path, 'mnist_all_rotation_normalized_float_train_valid.amat'), train_file_path) - os.rename(os.path.join(dir_path, 'mnist_all_rotation_normalized_float_test.amat'), test_file_path) + os.rename( + os.path.join(dir_path, 'mnist_all_rotation_normalized_float_train_valid.amat'), + train_file_path, + ) + os.rename( + os.path.join(dir_path, 'mnist_all_rotation_normalized_float_test.amat'), + test_file_path, + ) # Split data in valid file and train file fp = open(train_file_path) @@ -67,6 +77,7 @@ def load_line(line): tokens = line.split() return np.array([float(i) for i in tokens[:-1]]), int(float(tokens[-1])) + def custom_load_data(file_path): fp = open(file_path) # Add the lines of the file into a list @@ -81,6 +92,7 @@ def custom_load_data(file_path): labels = torch.stack(label_list) return images, labels + def get_dataset(dir_path, split='train'): if split == 'train': file_path = os.path.join(dir_path, 'mnist_rotated_train.amat') @@ -139,6 +151,7 @@ def test_dataloader(self): ) return test_loader + # if __name__ == "__main__": # parser = argparse.ArgumentParser() # parser.add_argument( @@ -147,4 +160,4 @@ def test_dataloader(self): # help='Path to the dataset.' 
# ) # args = parser.parse_args() -# obtain(args.data_path) \ No newline at end of file +# obtain(args.data_path) diff --git a/examples/images/classification/prepare/stl10_data.py b/examples/images/classification/prepare/stl10_data.py index 7f16b57..13d9f77 100644 --- a/examples/images/classification/prepare/stl10_data.py +++ b/examples/images/classification/prepare/stl10_data.py @@ -1,11 +1,11 @@ +import os +import random import pytorch_lightning as pl from torch.utils.data import DataLoader, random_split from torchvision import transforms from torchvision.datasets import STL10 -import os -import random class CustomRotationTransform: """Rotate by one of the given angles.""" @@ -16,7 +16,8 @@ def __init__(self, angles): def __call__(self, x): angle = random.choice(self.angles) return transforms.functional.rotate(x, angle) - + + class STL10DataModule(pl.LightningDataModule): def __init__(self, hyperparams, download=False): super().__init__() @@ -28,10 +29,8 @@ def __init__(self, hyperparams, download=False): transforms.Pad(4), transforms.RandomCrop(96), transforms.Resize(224), - transforms.RandomRotation(5), transforms.RandomHorizontalFlip(), - transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ] @@ -43,11 +42,9 @@ def __init__(self, hyperparams, download=False): transforms.Pad(4), transforms.RandomCrop(96), transforms.Resize(224), - CustomRotationTransform([0, 45, 90, 135, 180, 225, 270, 315]), # transforms.RandomRotation(180), transforms.RandomHorizontalFlip(), - transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ] @@ -60,7 +57,6 @@ def __init__(self, hyperparams, download=False): transforms.RandomCrop(96), transforms.Resize(224), transforms.RandomHorizontalFlip(), - transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ] @@ -76,10 +72,25 @@ def __init__(self, hyperparams, download=False): def setup(self, stage=None): if stage == "fit" or stage is None: - self.train_dataset = STL10(self.data_path, split='train', transform=self.train_transform, download=True) - self.valid_dataset = STL10(self.data_path, split='test', transform=self.test_transform, download=True) + self.train_dataset = STL10( + self.data_path, + split='train', + transform=self.train_transform, + download=True, + ) + self.valid_dataset = STL10( + self.data_path, + split='test', + transform=self.test_transform, + download=True, + ) if stage == "test": - self.test_dataset = STL10(self.data_path, split='test', transform=self.test_transform, download=True) + self.test_dataset = STL10( + self.data_path, + split='test', + transform=self.test_transform, + download=True, + ) print('Test dataset size: ', len(self.test_dataset)) def train_dataloader(self): @@ -107,4 +118,4 @@ def test_dataloader(self): shuffle=False, num_workers=self.hyperparams.num_workers, ) - return test_loader \ No newline at end of file + return test_loader diff --git a/examples/images/classification/train.py b/examples/images/classification/train.py index 8834e1e..4a72b97 100644 --- a/examples/images/classification/train.py +++ b/examples/images/classification/train.py @@ -1,23 +1,34 @@ import os -import torch -import wandb import hydra import omegaconf -from omegaconf import DictConfig, OmegaConf - import pytorch_lightning as pl +import torch +import wandb +from omegaconf import DictConfig, OmegaConf from pytorch_lightning.loggers import WandbLogger - from train_utils import get_model_data_and_callbacks, get_trainer, load_envs + def train_images(hyperparams: DictConfig): - 
hyperparams['canonicalization_type'] = hyperparams['canonicalization']['canonicalization_type'] + hyperparams['canonicalization_type'] = hyperparams['canonicalization'][ + 'canonicalization_type' + ] hyperparams['device'] = 'cuda' if torch.cuda.is_available() else 'cpu' - hyperparams['dataset']['data_path'] = hyperparams['dataset']['data_path'] + "/" + hyperparams['dataset']['dataset_name'] - hyperparams['checkpoint']['checkpoint_path'] = hyperparams['checkpoint']['checkpoint_path'] + "/" + \ - hyperparams['dataset']['dataset_name'] + "/" + hyperparams['canonicalization_type'] \ - + "/" + hyperparams['prediction']['prediction_network_architecture'] + hyperparams['dataset']['data_path'] = ( + hyperparams['dataset']['data_path'] + + "/" + + hyperparams['dataset']['dataset_name'] + ) + hyperparams['checkpoint']['checkpoint_path'] = ( + hyperparams['checkpoint']['checkpoint_path'] + + "/" + + hyperparams['dataset']['dataset_name'] + + "/" + + hyperparams['canonicalization_type'] + + "/" + + hyperparams['prediction']['prediction_network_architecture'] + ) # set system environment variables for wandb if hyperparams['wandb']['use_wandb']: @@ -27,19 +38,30 @@ def train_images(hyperparams: DictConfig): print("Wandb disabled for logging...") os.environ["WANDB_MODE"] = "disabled" os.environ["WANDB_DIR"] = hyperparams['wandb']['wandb_dir'] - os.environ["WANDB_CACHE_DIR"] = hyperparams['wandb']['wandb_cache_dir'] - + os.environ["WANDB_CACHE_DIR"] = hyperparams['wandb']['wandb_cache_dir'] + # initialize wandb - wandb.init(config=OmegaConf.to_container(hyperparams, resolve=True), entity=hyperparams['wandb']['wandb_entity'], project=hyperparams['wandb']['wandb_project'], dir=hyperparams['wandb']['wandb_dir']) - wandb_logger = WandbLogger(project=hyperparams['wandb']['wandb_project'], log_model="all") + wandb.init( + config=OmegaConf.to_container(hyperparams, resolve=True), + entity=hyperparams['wandb']['wandb_entity'], + project=hyperparams['wandb']['wandb_project'], + dir=hyperparams['wandb']['wandb_dir'], + ) + wandb_logger = WandbLogger( + project=hyperparams['wandb']['wandb_project'], log_model="all" + ) # set seed pl.seed_everything(hyperparams.experiment.seed) - + # get model, callbacks, and image data model, image_data, callbacks = get_model_data_and_callbacks(hyperparams) - - if hyperparams.canonicalization_type in ("group_equivariant", "opt_equivariant", "steerable"): + + if hyperparams.canonicalization_type in ( + "group_equivariant", + "opt_equivariant", + "steerable", + ): wandb.watch(model.canonicalizer.canonicalization_network, log='all') # get trainer @@ -47,18 +69,21 @@ def train_images(hyperparams: DictConfig): if hyperparams.experiment.run_mode == "train": trainer.fit(model, datamodule=image_data) - + elif hyperparams.experiment.run_mode == "auto_tune": trainer.tune(model, datamodule=image_data) trainer.test(model, datamodule=image_data) + # load the variables from .env file load_envs() + @hydra.main(config_path=str("./configs/"), config_name="default") def main(cfg: omegaconf.DictConfig): train_images(cfg) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/images/classification/train_utils.py b/examples/images/classification/train_utils.py index 27f2ae3..25bb65c 100644 --- a/examples/images/classification/train_utils.py +++ b/examples/images/classification/train_utils.py @@ -1,77 +1,101 @@ -import dotenv -from omegaconf import DictConfig from typing import Dict, Optional +import dotenv import pytorch_lightning as pl -from 
pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping - from model import ImageClassifierPipeline -from prepare import RotatedMNISTDataModule, CIFAR10DataModule, CIFAR100DataModule, STL10DataModule, Flowers102DataModule, CelebADataModule, ImageNetDataModule - -def get_model_data_and_callbacks(hyperparams : DictConfig): - - # get image data +from omegaconf import DictConfig +from prepare import ( + CelebADataModule, + CIFAR10DataModule, + CIFAR100DataModule, + Flowers102DataModule, + ImageNetDataModule, + RotatedMNISTDataModule, + STL10DataModule, +) +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint + + +def get_model_data_and_callbacks(hyperparams: DictConfig): + + # get image data image_data = get_image_data(hyperparams.dataset) - + # checkpoint name hyperparams.checkpoint.checkpoint_name = get_checkpoint_name(hyperparams) - + # checkpoint callbacks callbacks = get_callbacks(hyperparams) - # get model pipeline + # get model pipeline model = get_model_pipeline(hyperparams) - - return model, image_data, callbacks + + return model, image_data, callbacks + def get_model_pipeline(hyperparams: DictConfig): if hyperparams.experiment.run_mode == "test": model = ImageClassifierPipeline.load_from_checkpoint( - checkpoint_path=hyperparams.checkpoint.checkpoint_path + "/" + \ - hyperparams.checkpoint.checkpoint_name + ".ckpt", - hyperparams=hyperparams + checkpoint_path=hyperparams.checkpoint.checkpoint_path + + "/" + + hyperparams.checkpoint.checkpoint_name + + ".ckpt", + hyperparams=hyperparams, ) model.freeze() model.eval() else: model = ImageClassifierPipeline(hyperparams) - + return model + def get_trainer( - hyperparams: DictConfig, - callbacks: list, - wandb_logger: pl.loggers.WandbLogger + hyperparams: DictConfig, callbacks: list, wandb_logger: pl.loggers.WandbLogger ): if hyperparams.experiment.run_mode == "auto_tune": trainer = pl.Trainer( - max_epochs=hyperparams.experiment.num_epochs, accelerator="auto", - auto_scale_batch_size=True, auto_lr_find=True, logger=wandb_logger, - callbacks=callbacks, deterministic=hyperparams.experiment.deterministic, - num_nodes=hyperparams.experiment.num_nodes, devices=hyperparams.experiment.num_gpus, - strategy='ddp' + max_epochs=hyperparams.experiment.num_epochs, + accelerator="auto", + auto_scale_batch_size=True, + auto_lr_find=True, + logger=wandb_logger, + callbacks=callbacks, + deterministic=hyperparams.experiment.deterministic, + num_nodes=hyperparams.experiment.num_nodes, + devices=hyperparams.experiment.num_gpus, + strategy='ddp', ) - + elif hyperparams.experiment.run_mode == "dryrun": trainer = pl.Trainer( - fast_dev_run=5, max_epochs=hyperparams.experiment.training.num_epochs, accelerator="auto", - limit_train_batches=5, limit_val_batches=5, logger=wandb_logger, - callbacks=callbacks, deterministic=hyperparams.experiment.deterministic + fast_dev_run=5, + max_epochs=hyperparams.experiment.training.num_epochs, + accelerator="auto", + limit_train_batches=5, + limit_val_batches=5, + logger=wandb_logger, + callbacks=callbacks, + deterministic=hyperparams.experiment.deterministic, ) else: trainer = pl.Trainer( - max_epochs=hyperparams.experiment.training.num_epochs, accelerator="auto", - logger=wandb_logger, callbacks=callbacks, deterministic=hyperparams.experiment.deterministic, - num_nodes=hyperparams.experiment.num_nodes, devices=hyperparams.experiment.num_gpus, - strategy='ddp' + max_epochs=hyperparams.experiment.training.num_epochs, + accelerator="auto", + logger=wandb_logger, + callbacks=callbacks, + 
deterministic=hyperparams.experiment.deterministic, + num_nodes=hyperparams.experiment.num_nodes, + devices=hyperparams.experiment.num_gpus, + strategy='ddp', ) return trainer - - + + def get_callbacks(hyperparams: DictConfig): - + checkpoint_callback = ModelCheckpoint( dirpath=hyperparams.checkpoint.checkpoint_path, filename=hyperparams.checkpoint.checkpoint_name, @@ -79,14 +103,17 @@ def get_callbacks(hyperparams: DictConfig): mode="max", save_on_train_epoch_end=False, ) - early_stop_metric_callback = EarlyStopping(monitor="val/acc", - min_delta=hyperparams.experiment.training.min_delta, - patience=hyperparams.experiment.training.patience, - verbose=True, - mode="max") - + early_stop_metric_callback = EarlyStopping( + monitor="val/acc", + min_delta=hyperparams.experiment.training.min_delta, + patience=hyperparams.experiment.training.patience, + verbose=True, + mode="max", + ) + return [checkpoint_callback, early_stop_metric_callback] + def get_recursive_hyperparams_identifier(hyperparams: Dict): # get the identifier for the canonicalization network hyperparameters # recursively go through the dictionary and get the values and concatenate them @@ -97,15 +124,21 @@ def get_recursive_hyperparams_identifier(hyperparams: Dict): else: identifier += f"_{key}_{value}_" return identifier - -def get_checkpoint_name(hyperparams : DictConfig): - - return f"{get_recursive_hyperparams_identifier(hyperparams.canonicalization)}".lstrip("_") + \ - f"__epochs_{hyperparams.experiment.training.num_epochs}_" + f"__seed_{hyperparams.experiment.seed}" - + + +def get_checkpoint_name(hyperparams: DictConfig): + + return ( + f"{get_recursive_hyperparams_identifier(hyperparams.canonicalization)}".lstrip( + "_" + ) + + f"__epochs_{hyperparams.experiment.training.num_epochs}_" + + f"__seed_{hyperparams.experiment.seed}" + ) + def get_image_data(dataset_hyperparams: DictConfig): - + dataset_classes = { "rotated_mnist": RotatedMNISTDataModule, "cifar10": CIFAR10DataModule, @@ -113,14 +146,15 @@ def get_image_data(dataset_hyperparams: DictConfig): "stl10": STL10DataModule, "celeba": CelebADataModule, "flowers102": Flowers102DataModule, - "imagenet": ImageNetDataModule + "imagenet": ImageNetDataModule, } - + if dataset_hyperparams.dataset_name not in dataset_classes: raise ValueError(f"{dataset_hyperparams.dataset_name} not implemented") - + return dataset_classes[dataset_hyperparams.dataset_name](dataset_hyperparams) + def load_envs(env_file: Optional[str] = None) -> None: """ Load all the environment variables defined in the `env_file`. @@ -131,4 +165,4 @@ def load_envs(env_file: Optional[str] = None) -> None: :param env_file: the file that defines the environment variables to use. If None it searches for a `.env` file in the project. 
""" - dotenv.load_dotenv(dotenv_path=env_file, override=True) \ No newline at end of file + dotenv.load_dotenv(dotenv_path=env_file, override=True) diff --git a/examples/images/common/utils.py b/examples/images/common/utils.py index 7fd7b6b..6ddf784 100644 --- a/examples/images/common/utils.py +++ b/examples/images/common/utils.py @@ -1,15 +1,27 @@ -import torch +import torch from omegaconf import DictConfig from equiadapt.common.basecanonicalization import IdentityCanonicalization -from equiadapt.images.canonicalization.discrete_group import GroupEquivariantImageCanonicalization, OptimizedGroupEquivariantImageCanonicalization -from equiadapt.images.canonicalization.continuous_group import SteerableImageCanonicalization, OptimizedSteerableImageCanonicalization -from equiadapt.images.canonicalization_networks import ESCNNEquivariantNetwork, ConvNetwork, CustomEquivariantNetwork, ESCNNSteerableNetwork +from equiadapt.images.canonicalization.continuous_group import ( + OptimizedSteerableImageCanonicalization, + SteerableImageCanonicalization, +) +from equiadapt.images.canonicalization.discrete_group import ( + GroupEquivariantImageCanonicalization, + OptimizedGroupEquivariantImageCanonicalization, +) +from equiadapt.images.canonicalization_networks import ( + ConvNetwork, + CustomEquivariantNetwork, + ESCNNEquivariantNetwork, + ESCNNSteerableNetwork, +) + def get_canonicalization_network( canonicalization_type: str, canonicalization_hyperparams: DictConfig, - in_shape: tuple + in_shape: tuple, ): """ The function returns the canonicalization network based on the canonicalization type @@ -20,7 +32,7 @@ def get_canonicalization_network( """ if canonicalization_type == 'identity': return torch.nn.Identity() - + canonicalization_network_dict = { 'group_equivariant': { 'escnn': ESCNNEquivariantNetwork, @@ -29,27 +41,28 @@ def get_canonicalization_network( 'steerable': { 'escnn': ESCNNSteerableNetwork, }, - 'opt_group_equivariant':{ + 'opt_group_equivariant': { 'cnn': ConvNetwork, }, 'opt_steerable': { 'cnn': ConvNetwork, - } + }, } - + if canonicalization_type not in canonicalization_network_dict: - raise ValueError(f'{canonicalization_type} is not implemented') - if canonicalization_hyperparams.network_type not in canonicalization_network_dict[canonicalization_type]: - raise ValueError(f'{canonicalization_hyperparams.network_type} is not implemented for {canonicalization_type}') - - canonicalization_network = \ - canonicalization_network_dict[canonicalization_type][ + raise ValueError(f'{canonicalization_type} is not implemented') + if ( canonicalization_hyperparams.network_type - ]( - in_shape = in_shape, - **canonicalization_hyperparams.network_hyperparams + not in canonicalization_network_dict[canonicalization_type] + ): + raise ValueError( + f'{canonicalization_hyperparams.network_type} is not implemented for {canonicalization_type}' ) - + + canonicalization_network = canonicalization_network_dict[canonicalization_type][ + canonicalization_hyperparams.network_type + ](in_shape=in_shape, **canonicalization_hyperparams.network_hyperparams) + return canonicalization_network @@ -57,7 +70,7 @@ def get_canonicalizer( canonicalization_type: str, canonicalization_network: torch.nn.Module, canonicalization_hyperparams: DictConfig, - in_shape: tuple + in_shape: tuple, ): """ The function returns the canonicalization network based on the canonicalization type @@ -68,21 +81,23 @@ def get_canonicalizer( """ if canonicalization_type == 'identity': return IdentityCanonicalization(canonicalization_network) - 
+ canonicalizer_dict = { 'group_equivariant': GroupEquivariantImageCanonicalization, 'steerable': SteerableImageCanonicalization, 'opt_group_equivariant': OptimizedGroupEquivariantImageCanonicalization, - 'opt_steerable': OptimizedSteerableImageCanonicalization + 'opt_steerable': OptimizedSteerableImageCanonicalization, } - + if canonicalization_type not in canonicalizer_dict: - raise ValueError(f'{canonicalization_type} needs a canonicalization network implementation.') - + raise ValueError( + f'{canonicalization_type} needs a canonicalization network implementation.' + ) + canonicalizer = canonicalizer_dict[canonicalization_type]( canonicalization_network=canonicalization_network, canonicalization_hyperparams=canonicalization_hyperparams, - in_shape=in_shape + in_shape=in_shape, ) - - return canonicalizer \ No newline at end of file + + return canonicalizer diff --git a/examples/images/segmentation/README.md b/examples/images/segmentation/README.md index c49cd9c..ee45f86 100644 --- a/examples/images/segmentation/README.md +++ b/examples/images/segmentation/README.md @@ -3,14 +3,14 @@ ## For COCO ### For instance segmentation (without prior regularization) ``` -python train.py canonicalization=group_equivariant experiment.training.loss.prior_weight=0 +python train.py canonicalization=group_equivariant experiment.training.loss.prior_weight=0 ``` ### For instance segmentation (with prior regularization) -``` -python train.py canonicalization=group_equivariant +``` +python train.py canonicalization=group_equivariant ``` -**Note**: You can also run the `train.py` as follows from root directory of the project: +**Note**: You can also run the `train.py` as follows from root directory of the project: ``` python examples/images/segmentation/train.py canonicalization=group_equivariant ``` diff --git a/examples/images/segmentation/configs/canonicalization/group_equivariant.yaml b/examples/images/segmentation/configs/canonicalization/group_equivariant.yaml index 0bcda3b..731c152 100644 --- a/examples/images/segmentation/configs/canonicalization/group_equivariant.yaml +++ b/examples/images/segmentation/configs/canonicalization/group_equivariant.yaml @@ -8,4 +8,4 @@ network_hyperparams: num_rotations: 4 # Number of rotations for the canonization network beta: 1.0 # Beta parameter for the canonization network input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization -resize_shape: 32 # Resize shape for the input \ No newline at end of file +resize_shape: 32 # Resize shape for the input diff --git a/examples/images/segmentation/configs/canonicalization/identity.yaml b/examples/images/segmentation/configs/canonicalization/identity.yaml index 1598d17..513e776 100644 --- a/examples/images/segmentation/configs/canonicalization/identity.yaml +++ b/examples/images/segmentation/configs/canonicalization/identity.yaml @@ -1 +1 @@ -canonicalization_type: identity \ No newline at end of file +canonicalization_type: identity diff --git a/examples/images/segmentation/configs/canonicalization/opt_group_equivariant.yaml b/examples/images/segmentation/configs/canonicalization/opt_group_equivariant.yaml index 12f60e4..93110c5 100644 --- a/examples/images/segmentation/configs/canonicalization/opt_group_equivariant.yaml +++ b/examples/images/segmentation/configs/canonicalization/opt_group_equivariant.yaml @@ -9,4 +9,4 @@ group_type: "rotation" # Type of group for the canonization network num_rotations: 4 # Number of rotations for the canonization network beta: 1.0 # Beta parameter for the 
canonization network input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization -resize_shape: 32 # Resize shape for the input \ No newline at end of file +resize_shape: 32 # Resize shape for the input diff --git a/examples/images/segmentation/configs/canonicalization/opt_steerable.yaml b/examples/images/segmentation/configs/canonicalization/opt_steerable.yaml index 086cf49..47722db 100644 --- a/examples/images/segmentation/configs/canonicalization/opt_steerable.yaml +++ b/examples/images/segmentation/configs/canonicalization/opt_steerable.yaml @@ -6,4 +6,4 @@ network_hyperparams: num_layers: 3 # Number of layers in the canonization network out_vector_size: 4 # Dimension of the output vector group_type: "rotation" # Type of group for the canonization network -input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization \ No newline at end of file +input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization diff --git a/examples/images/segmentation/configs/canonicalization/steerable.yaml b/examples/images/segmentation/configs/canonicalization/steerable.yaml index 9a6f5f2..629d274 100644 --- a/examples/images/segmentation/configs/canonicalization/steerable.yaml +++ b/examples/images/segmentation/configs/canonicalization/steerable.yaml @@ -5,4 +5,4 @@ network_hyperparams: out_channels: 16 # Number of output channels for the canonization network num_layers: 3 # Number of layers in the canonization network group_type: "rotation" # Type of group for the canonization network -input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization \ No newline at end of file +input_crop_ratio: 0.8 # Ratio at which we crop the input to the canonicalization diff --git a/examples/images/segmentation/configs/checkpoint/default.yaml b/examples/images/segmentation/configs/checkpoint/default.yaml index 419f669..7398463 100644 --- a/examples/images/segmentation/configs/checkpoint/default.yaml +++ b/examples/images/segmentation/configs/checkpoint/default.yaml @@ -1,3 +1,3 @@ checkpoint_path: ${oc.env:CHECKPOINT_PATH} # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later -save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file +save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/segmentation/configs/dataset/default.yaml b/examples/images/segmentation/configs/dataset/default.yaml index 341c2ee..6e5774f 100644 --- a/examples/images/segmentation/configs/dataset/default.yaml +++ b/examples/images/segmentation/configs/dataset/default.yaml @@ -1,6 +1,6 @@ dataset_name: coco # Name of the dataset to use root_dir: ${oc.env:DATA_PATH}/${dataset_name} # Root directory of the dataset ann_dir: ${root_dir}/annotations # Path to annotations -augment: flip # Whether to train with flip augmentation +augment: flip # Whether to train with flip augmentation num_workers: 4 # Number of workers for data loading -batch_size: 128 # Number of samples per batch \ No newline at end of file +batch_size: 128 # Number of samples per batch diff --git a/examples/images/segmentation/configs/experiment/default.yaml b/examples/images/segmentation/configs/experiment/default.yaml index 44c2568..6c25ac3 100644 --- a/examples/images/segmentation/configs/experiment/default.yaml +++ b/examples/images/segmentation/configs/experiment/default.yaml @@ -1,5 +1,5 @@ run_mode: "train" # Mode to run the model in, 
different run modes 1)dryrun 2)train 3)test 4)auto_tune -seed: 0 # Seed for random number generation +seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) device: "cuda" # Device, can be cuda or cpu num_nodes: 1 @@ -18,4 +18,4 @@ training: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference diff --git a/examples/images/segmentation/configs/original_configs/group_equivariant/cifar10.yaml b/examples/images/segmentation/configs/original_configs/group_equivariant/cifar10.yaml index d94f0bf..a209727 100644 --- a/examples/images/segmentation/configs/original_configs/group_equivariant/cifar10.yaml +++ b/examples/images/segmentation/configs/original_configs/group_equivariant/cifar10.yaml @@ -9,7 +9,7 @@ dataset: experiment: run_mode: "train" # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune - seed: 0 # Seed for random number generation + seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) device: "cuda" # Device, can be cuda or cpu num_nodes: 1 @@ -25,7 +25,7 @@ experiment: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference prediction: prediction_network_architecture: "resnet50" # Architecture of the prediction network @@ -53,4 +53,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/s/siba-smarak.panigrahi/scratch/equiadapt/image/checkpoints" # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/segmentation/configs/original_configs/group_equivariant/rotmnist.yaml b/examples/images/segmentation/configs/original_configs/group_equivariant/rotmnist.yaml index 26b0414..afac33a 100644 --- a/examples/images/segmentation/configs/original_configs/group_equivariant/rotmnist.yaml +++ b/examples/images/segmentation/configs/original_configs/group_equivariant/rotmnist.yaml @@ -9,7 +9,7 @@ dataset: experiment: run_mode: "train" # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune - seed: 0 # Seed for random number generation + seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) device: "cuda" # Device, can be cuda or cpu num_nodes: 1 @@ -25,7 +25,7 @@ experiment: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference prediction: prediction_network_architecture: "resnet50" # Architecture of the prediction network @@ -53,4 +53,4 @@ wandb: checkpoint: 
checkpoint_path: "/home/mila/a/arnab.mondal/scratch/equiadapt/image/checkpoints" # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/segmentation/configs/original_configs/opt_equivariant/cifar10.yaml b/examples/images/segmentation/configs/original_configs/opt_equivariant/cifar10.yaml index c611b84..9c7afe6 100644 --- a/examples/images/segmentation/configs/original_configs/opt_equivariant/cifar10.yaml +++ b/examples/images/segmentation/configs/original_configs/opt_equivariant/cifar10.yaml @@ -9,7 +9,7 @@ dataset: experiment: run_mode: "train" # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune - seed: 0 # Seed for random number generation + seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) num_nodes: 1 num_gpus: 1 @@ -25,7 +25,7 @@ experiment: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference prediction: prediction_network_architecture: "resnet50" # Architecture of the prediction network @@ -54,4 +54,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/a/arnab.mondal/scratch/equiadapt/image/checkpoints" # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/segmentation/configs/original_configs/opt_equivariant/rotmnist.yaml b/examples/images/segmentation/configs/original_configs/opt_equivariant/rotmnist.yaml index 5168bd1..1dc06c3 100644 --- a/examples/images/segmentation/configs/original_configs/opt_equivariant/rotmnist.yaml +++ b/examples/images/segmentation/configs/original_configs/opt_equivariant/rotmnist.yaml @@ -9,7 +9,7 @@ dataset: experiment: run_mode: "train" # Mode to run the model in, different run modes 1)dryrun 2)train 3)test 4)auto_tune - seed: 0 # Seed for random number generation + seed: 0 # Seed for random number generation deterministic: false # Whether to set deterministic mode (true) or not (false) num_nodes: 1 num_gpus: 1 @@ -25,7 +25,7 @@ experiment: inference: method: "group" # Type of inference options 1) vanilla 2) group group_type: "rotation" # Type of group to test during inference 1) Rotation 2) Roto-reflection - num_rotations: 4 # Number of rotations to check robustness during inference + num_rotations: 4 # Number of rotations to check robustness during inference prediction: prediction_network_architecture: "resnet50" # Architecture of the prediction network @@ -54,4 +54,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/a/arnab.mondal/scratch/equiadapt/image/checkpoints" # Path to save checkpoints checkpoint_name: "" # Model checkpoint name, should be left empty and dynamically allocated later - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # 
Whether to save canonized images (1) or not (0) diff --git a/examples/images/segmentation/configs/original_configs/steerable/cifar10.yaml b/examples/images/segmentation/configs/original_configs/steerable/cifar10.yaml index 36db32a..be68cf1 100644 --- a/examples/images/segmentation/configs/original_configs/steerable/cifar10.yaml +++ b/examples/images/segmentation/configs/original_configs/steerable/cifar10.yaml @@ -19,7 +19,7 @@ prediction: freeze_pretrained_encoder: 0 # Whether to freeze the pretrained encoder (1) or not (0) canonicalization: - network_type: 'escnn' # Options o canonization method 1) escnn + network_type: 'escnn' # Options o canonization method 1) escnn network_hyperparams: kernel_size: 3 # Kernel size for the canonization network out_channels: 16 # Number of output channels for the canonization network @@ -42,4 +42,4 @@ wandb: checkpoint: checkpoint_path: "/home/mila/a/arnab.mondal/scratch/equiadapt/image/checkpoints" # Path to save checkpoints deterministic: false # Whether to set deterministic mode (true) or not (false) - save_canonized_images: 0 # Whether to save canonized images (1) or not (0) \ No newline at end of file + save_canonized_images: 0 # Whether to save canonized images (1) or not (0) diff --git a/examples/images/segmentation/configs/prediction/default.yaml b/examples/images/segmentation/configs/prediction/default.yaml index e456f88..6121c65 100644 --- a/examples/images/segmentation/configs/prediction/default.yaml +++ b/examples/images/segmentation/configs/prediction/default.yaml @@ -2,4 +2,3 @@ prediction_network_architecture: "sam" # Architecture of the prediction network prediction_network_class: "vit_h" # Class of the prediction network use_pretrained: 1 # Whether to use pretrained weights (1) or not (0) freeze_encoder: 1 # Whether to freeze encoder (1) or not (0) - diff --git a/examples/images/segmentation/configs/wandb_sweep.yaml b/examples/images/segmentation/configs/wandb_sweep.yaml index 56117e9..2e1273f 100644 --- a/examples/images/segmentation/configs/wandb_sweep.yaml +++ b/examples/images/segmentation/configs/wandb_sweep.yaml @@ -27,4 +27,4 @@ command: - ${env} - python3 - ${program} - - ${args_no_hyphens} \ No newline at end of file + - ${args_no_hyphens} diff --git a/examples/images/segmentation/inference_utils.py b/examples/images/segmentation/inference_utils.py index 911d8a1..8545212 100644 --- a/examples/images/segmentation/inference_utils.py +++ b/examples/images/segmentation/inference_utils.py @@ -1,105 +1,146 @@ import copy -import torch, math -import wandb - -from typing import Union, Dict +import math +from typing import Dict, Union -from torchvision import transforms +import torch +import wandb from torchmetrics.detection.mean_ap import MeanAveragePrecision +from torchvision import transforms from equiadapt.images.utils import flip_boxes, flip_masks, rotate_boxes, rotate_masks -def get_inference_method(canonicalizer: torch.nn.Module, - prediction_network: torch.nn.Module, - num_classes: int, - inference_hyperparams: Union[Dict, wandb.Config], - in_shape: tuple = (3, 1024, 1024)): + +def get_inference_method( + canonicalizer: torch.nn.Module, + prediction_network: torch.nn.Module, + num_classes: int, + inference_hyperparams: Union[Dict, wandb.Config], + in_shape: tuple = (3, 1024, 1024), +): if inference_hyperparams.method == 'vanilla': return VanillaInference(canonicalizer, prediction_network, num_classes) elif inference_hyperparams.method == 'group': return GroupInference( - canonicalizer, prediction_network, num_classes, - 
inference_hyperparams, in_shape + canonicalizer, + prediction_network, + num_classes, + inference_hyperparams, + in_shape, ) else: raise ValueError(f'{inference_hyperparams.method} is not implemented for now.') + class VanillaInference: - def __init__(self, - canonicalizer: torch.nn.Module, - prediction_network: torch.nn.Module) -> None: + def __init__( + self, canonicalizer: torch.nn.Module, prediction_network: torch.nn.Module + ) -> None: self.canonicalizer = canonicalizer self.prediction_network = prediction_network - + def forward(self, x, targets): # canonicalize the input data # For the vanilla model, the canonicalization is the identity transformation x_canonicalized, targets_canonicalized = self.canonicalizer(x, targets) - + # Forward pass through the prediction network as you'll normally do # Finetuning maskrcnn model will return the losses which can be used to fine tune the model - # Meanwhile, Segment-Anything (SAM) can return boxes, ious, masks predictions + # Meanwhile, Segment-Anything (SAM) can return boxes, ious, masks predictions # For uniformity, we will ensure the prediction network returns both losses and predictions irrespective of the model return self.prediction_network(x_canonicalized, targets_canonicalized) - + def get_inference_metrics(self, x: torch.Tensor, targets: torch.Tensor): # Forward pass through the prediction network - _, _, _, outputs = self.forward(x) - + _, _, _, outputs = self.forward(x) + _map = MeanAveragePrecision(iou_type='segm') - targets = [dict(boxes=target['boxes'], labels=target['labels'], masks=target['masks']) for target in targets] - outputs = [dict(boxes=output['boxes'], labels=output['labels'], scores=output['scores'], masks=output['masks']) for output in outputs] + targets = [ + dict(boxes=target['boxes'], labels=target['labels'], masks=target['masks']) + for target in targets + ] + outputs = [ + dict( + boxes=output['boxes'], + labels=output['labels'], + scores=output['scores'], + masks=output['masks'], + ) + for output in outputs + ] _map.update(outputs, targets) _map_dict = _map.compute() - + metrics = {'test/map': _map_dict['map']} - + return metrics - + + class GroupInference(VanillaInference): - def __init__(self, - canonicalizer: torch.nn.Module, - prediction_network: torch.nn.Module, - inference_hyperparams: Union[Dict, wandb.Config], - in_shape: tuple = (3, 32, 32)): - + def __init__( + self, + canonicalizer: torch.nn.Module, + prediction_network: torch.nn.Module, + inference_hyperparams: Union[Dict, wandb.Config], + in_shape: tuple = (3, 32, 32), + ): + super().__init__(canonicalizer, prediction_network) self.group_type = inference_hyperparams.group_type self.num_rotations = inference_hyperparams.num_rotations - self.num_group_elements = self.num_rotations if self.group_type == 'rotation' else 2 * self.num_rotations - self.pad = transforms.Pad( - math.ceil(in_shape[-2] * 0.4), - padding_mode='edge' + self.num_group_elements = ( + self.num_rotations + if self.group_type == 'rotation' + else 2 * self.num_rotations ) + self.pad = transforms.Pad(math.ceil(in_shape[-2] * 0.4), padding_mode='edge') self.crop = transforms.CenterCrop((in_shape[-2], in_shape[-1])) def get_group_element_wise_maps(self, images: torch.Tensor, targets: torch.Tensor): map_dict = dict() image_width = images[0].shape[1] - + degrees = torch.linspace(0, 360, self.num_rotations + 1)[:-1] for rot, degree in enumerate(degrees): - + targets_transformed = copy.deepcopy(targets) - + # apply group element on images images_pad = self.pad(images) images_rot = 
transforms.functional.rotate(images_pad, int(degree)) images_rot = self.crop(images_rot) - + # apply group element on bounding boxes and masks for t in range(len(targets_transformed)): - targets_transformed[t]["boxes"] = rotate_boxes(targets_transformed[t]["boxes"], -degree, image_width) - targets_transformed[t]["masks"] = rotate_masks(targets_transformed[t]["masks"], degree) + targets_transformed[t]["boxes"] = rotate_boxes( + targets_transformed[t]["boxes"], -degree, image_width + ) + targets_transformed[t]["masks"] = rotate_masks( + targets_transformed[t]["masks"], degree + ) # get predictions for the transformed images _, _, _, outputs = self.forward(images_rot, targets_transformed) - + Map = MeanAveragePrecision(iou_type='segm') - targets = [dict(boxes=target['boxes'], labels=target['labels'], masks=target['masks']) for target in targets] - outputs = [dict(boxes=output['boxes'], labels=output['labels'], scores=output['scores'], masks=output['masks']) for output in outputs] + targets = [ + dict( + boxes=target['boxes'], + labels=target['labels'], + masks=target['masks'], + ) + for target in targets + ] + outputs = [ + dict( + boxes=output['boxes'], + labels=output['labels'], + scores=output['scores'], + masks=output['masks'], + ) + for output in outputs + ] Map.update(outputs, targets) - + map_dict[rot] = Map.compute() if self.group_type == 'roto-reflection': @@ -108,38 +149,66 @@ def get_group_element_wise_maps(self, images: torch.Tensor, targets: torch.Tenso images_pad = self.pad(images) images_reflect = transforms.functional.hflip(images_pad) - images_rotoreflect = transforms.functional.rotate(images_reflect, int(degree)) + images_rotoreflect = transforms.functional.rotate( + images_reflect, int(degree) + ) images_rotoreflect = self.crop(images_rotoreflect) # apply group element on bounding boxes and masks for t in range(len(targets_transformed)): - targets_transformed[t]["boxes"] = rotate_boxes(targets_transformed[t]["boxes"], -degree, image_width) - targets_transformed[t]["boxes"] = flip_boxes(targets_transformed[t]["boxes"], image_width) - - targets_transformed[t]["masks"] = rotate_masks(targets_transformed[t]["masks"], degree) - targets_transformed[t]["masks"] = flip_masks(targets_transformed[t]["masks"]) + targets_transformed[t]["boxes"] = rotate_boxes( + targets_transformed[t]["boxes"], -degree, image_width + ) + targets_transformed[t]["boxes"] = flip_boxes( + targets_transformed[t]["boxes"], image_width + ) + + targets_transformed[t]["masks"] = rotate_masks( + targets_transformed[t]["masks"], degree + ) + targets_transformed[t]["masks"] = flip_masks( + targets_transformed[t]["masks"] + ) # get predictions for the transformed images _, _, _, outputs = self.forward(images_rotoreflect, targets_transformed) - + Map = MeanAveragePrecision(iou_type='segm') - targets = [dict(boxes=target['boxes'], labels=target['labels'], masks=target['masks']) for target in targets] - outputs = [dict(boxes=output['boxes'], labels=output['labels'], scores=output['scores'], masks=output['masks']) for output in outputs] + targets = [ + dict( + boxes=target['boxes'], + labels=target['labels'], + masks=target['masks'], + ) + for target in targets + ] + outputs = [ + dict( + boxes=output['boxes'], + labels=output['labels'], + scores=output['scores'], + masks=output['masks'], + ) + for output in outputs + ] Map.update(outputs, targets) - + map_dict[rot + len(degrees)] = Map.compute() - + return map_dict - + def get_inference_metrics(self, images: torch.Tensor, targets: torch.Tensor): - + map_dict = 
self.get_group_element_wise_maps(images, targets) - + # Use list comprehension to calculate accuracy for each group element - metrics.update({ + metrics.update( + { f'test/map_group_element_{i}': max(map_dict[i]['map'], 0.0), f'test/map_small_group_element_{i}': max(map_dict[i]['map_small'], 0.0), - f'test/map_medium_group_element_{i}': max(map_dict[i]['map_medium'], 0.0), + f'test/map_medium_group_element_{i}': max( + map_dict[i]['map_medium'], 0.0 + ), f'test/map_large_group_element_{i}': max(map_dict[i]['map_large'], 0.0), f'test/map_50_group_element_{i}': max(map_dict[i]['map_50'], 0.0), f'test/map_75_group_element_{i}': max(map_dict[i]['map_75'], 0.0), @@ -147,17 +216,27 @@ def get_inference_metrics(self, images: torch.Tensor, targets: torch.Tensor): f'test/mar_10_group_element_{i}': max(map_dict[i]['mar_10'], 0.0), f'test/mar_100_group_element_{i}': max(map_dict[i]['mar_100'], 0.0), f'test/mar_small_group_element_{i}': max(map_dict[i]['mar_small'], 0.0), - f'test/mar_medium_group_element_{i}': max(map_dict[i]['mar_medium'], 0.0), + f'test/mar_medium_group_element_{i}': max( + map_dict[i]['mar_medium'], 0.0 + ), f'test/mar_large_group_element_{i}': max(map_dict[i]['mar_large'], 0.0), - } for i in range(self.num_group_elements)) - - map_per_group_element = torch.tensor([map_dict[i]['map'] for i in range(self.num_group_elements)]) + } + for i in range(self.num_group_elements) + ) + + map_per_group_element = torch.tensor( + [map_dict[i]['map'] for i in range(self.num_group_elements)] + ) metrics = {"test/group_map": torch.mean(map_per_group_element)} - metrics.update({f'test/map_group_element_{i}': max(map_per_group_element[i], 0.0) for i in range(self.num_group_elements)}) + metrics.update( + { + f'test/map_group_element_{i}': max(map_per_group_element[i], 0.0) + for i in range(self.num_group_elements) + } + ) # Calculate the overall map metrics.update({"test/map": max(map_dict[0]['map'], 0.0)}) return metrics - \ No newline at end of file diff --git a/examples/images/segmentation/model.py b/examples/images/segmentation/model.py index d3162b5..022c359 100644 --- a/examples/images/segmentation/model.py +++ b/examples/images/segmentation/model.py @@ -1,21 +1,22 @@ -import torch import pytorch_lightning as pl +import torch +from common.utils import get_canonicalization_network, get_canonicalizer +from inference_utils import get_inference_method +from model_utils import calc_iou, get_dataset_specific_info, get_prediction_network +from omegaconf import DictConfig from torch.optim.lr_scheduler import MultiStepLR from torchmetrics.detection.mean_ap import MeanAveragePrecision -from omegaconf import DictConfig - -from inference_utils import get_inference_method -from model_utils import get_dataset_specific_info, get_prediction_network, calc_iou -from common.utils import get_canonicalization_network, get_canonicalizer - # define the LightningModule class ImageSegmentationPipeline(pl.LightningModule): def __init__(self, hyperparams: DictConfig): super().__init__() - - self.loss, self.image_shape, self.num_classes = get_dataset_specific_info(hyperparams.dataset.dataset_name, hyperparams.prediction.prediction_network_architecture) + + self.loss, self.image_shape, self.num_classes = get_dataset_specific_info( + hyperparams.dataset.dataset_name, + hyperparams.prediction.prediction_network_architecture, + ) self.prediction_network = get_prediction_network( architecture=hyperparams.prediction.prediction_network_architecture, @@ -23,211 +24,261 @@ def __init__(self, hyperparams: DictConfig): 
dataset_name=hyperparams.dataset.dataset_name, use_pretrained=hyperparams.prediction.use_pretrained, freeze_encoder=hyperparams.prediction.freeze_encoder, - num_classes=self.num_classes + num_classes=self.num_classes, ) canonicalization_network = get_canonicalization_network( - hyperparams.canonicalization_type, + hyperparams.canonicalization_type, hyperparams.canonicalization, self.image_shape, ) - + self.canonicalizer = get_canonicalizer( hyperparams.canonicalization_type, canonicalization_network, hyperparams.canonicalization, - self.image_shape - ) - + self.image_shape, + ) + self.hyperparams = hyperparams - + self.inference_method = get_inference_method( self.canonicalizer, self.prediction_network, self.num_classes, hyperparams.experiment.inference, - self.image_shape + self.image_shape, ) - + self.max_epochs = hyperparams.experiment.training.num_epochs - + self.save_hyperparameters() - - def apply_loss(self, loss_dict: dict, pred_masks: torch.Tensor, targets_canonicalized: dict, iou_predictions: torch.Tensor = None): - assert self.loss or loss_dict, "Either pass a loss function or a dictionary of pre-computed losses for segmentation task loss" - + + def apply_loss( + self, + loss_dict: dict, + pred_masks: torch.Tensor, + targets_canonicalized: dict, + iou_predictions: torch.Tensor = None, + ): + assert ( + self.loss or loss_dict + ), "Either pass a loss function or a dictionary of pre-computed losses for segmentation task loss" + if loss_dict: # for maskrcnn model, the loss_dict will contain the losses - return sum(loss_dict.values()) - + return sum(loss_dict.values()) + num_masks = sum(len(pred_mask) for pred_mask in pred_masks) - - loss_focal = torch.tensor(0., device=self.hyperparams.device) - loss_dice = torch.tensor(0., device=self.hyperparams.device) - loss_iou = torch.tensor(0., device=self.hyperparams.device) - - for pred_mask, target, iou_prediction in zip(pred_masks, targets_canonicalized, iou_predictions): - + + loss_focal = torch.tensor(0.0, device=self.hyperparams.device) + loss_dice = torch.tensor(0.0, device=self.hyperparams.device) + loss_iou = torch.tensor(0.0, device=self.hyperparams.device) + + for pred_mask, target, iou_prediction in zip( + pred_masks, targets_canonicalized, iou_predictions + ): + # if gt_masks is larger then select the first len(pred_masks) masks - gt_mask = target['masks'][:len(pred_mask), :, :] - + gt_mask = target['masks'][: len(pred_mask), :, :] + for loss_func in self.loss: - assert hasattr(loss_func, 'forward'), "The loss function must have a forward method" - if loss_func.name == 'focal_loss': loss_focal += loss_func(pred_mask, gt_mask.float(), num_masks) - elif loss_func.name == 'dice_loss': loss_dice += loss_func(pred_mask, gt_mask, num_masks) - else: raise ValueError(f"Loss function {loss_func.name} is not supported") - - + assert hasattr( + loss_func, 'forward' + ), "The loss function must have a forward method" + if loss_func.name == 'focal_loss': + loss_focal += loss_func(pred_mask, gt_mask.float(), num_masks) + elif loss_func.name == 'dice_loss': + loss_dice += loss_func(pred_mask, gt_mask, num_masks) + else: + raise ValueError(f"Loss function {loss_func.name} is not supported") + if iou_predictions: batch_iou = calc_iou(pred_mask, gt_mask) - loss_iou += torch.nn.functional.mse_loss(iou_prediction, batch_iou, reduction='sum') / num_masks - - - return 20. 
* loss_focal + loss_dice + loss_iou + loss_iou += ( + torch.nn.functional.mse_loss( + iou_prediction, batch_iou, reduction='sum' + ) + / num_masks + ) + return 20.0 * loss_focal + loss_dice + loss_iou def training_step(self, batch: torch.Tensor): x, targets = batch x = torch.stack(x) batch_size, num_channels, height, width = x.shape - + # assert that the input is in the right shape assert (num_channels, height, width) == self.image_shape training_metrics = {} loss = 0.0 - + # canonicalize the input data # For the vanilla model, the canonicalization is the identity transformation x_canonicalized, targets_canonicalized = self.canonicalizer(x, targets) - + # add group contrast loss while using optmization based canonicalization method if 'opt' in self.hyperparams.canonicalization_type: group_contrast_loss = self.canonicalizer.get_optimization_specific_loss() - loss += group_contrast_loss * self.hyperparams.experiment.training.loss.group_contrast_weight - training_metrics.update({"train/optimization_specific_loss": group_contrast_loss}) - + loss += ( + group_contrast_loss + * self.hyperparams.experiment.training.loss.group_contrast_weight + ) + training_metrics.update( + {"train/optimization_specific_loss": group_contrast_loss} + ) + # calculate the task loss - # if finetuning is not required, set the weight for task loss to 0 + # if finetuning is not required, set the weight for task loss to 0 # it will avoid unnecessary forward pass through the prediction network - if self.hyperparams.experiment.training.loss.task_weight: - + if self.hyperparams.experiment.training.loss.task_weight: + # Forward pass through the prediction network as you'll normally do # Finetuning maskrcnn model will return the losses which can be used to fine tune the model - # Meanwhile, Segment-Anything (SAM) can return boxes, ious, masks predictions + # Meanwhile, Segment-Anything (SAM) can return boxes, ious, masks predictions # For uniformity, we will ensure the prediction network returns both losses and predictions irrespective of the model - loss_dict, pred_masks, iou_predictions, _ = self.prediction_network(x_canonicalized, targets_canonicalized) - + loss_dict, pred_masks, iou_predictions, _ = self.prediction_network( + x_canonicalized, targets_canonicalized + ) + # no requirement to invert canonicalization for the loss calculation # since we will compute the loss w.r.t canonicalized targets (to align with the loss computation in maskrcnn) - task_loss = self.apply_loss(loss_dict, pred_masks, targets_canonicalized, iou_predictions) + task_loss = self.apply_loss( + loss_dict, pred_masks, targets_canonicalized, iou_predictions + ) loss += self.hyperparams.experiment.training.loss.task_weight * task_loss - - training_metrics.update({ - "train/task_loss": task_loss, - }) - + + training_metrics.update( + { + "train/task_loss": task_loss, + } + ) + # Add prior regularization loss if the prior weight is non-zero if self.hyperparams.experiment.training.loss.prior_weight: prior_loss = self.canonicalizer.get_prior_regularization_loss() loss += prior_loss * self.hyperparams.experiment.training.loss.prior_weight metric_identity = self.canonicalizer.get_identity_metric() - training_metrics.update({ - "train/prior_loss": prior_loss, - "train/identity_metric": metric_identity - }) - - training_metrics.update({ + training_metrics.update( + { + "train/prior_loss": prior_loss, + "train/identity_metric": metric_identity, + } + ) + + training_metrics.update( + { "train/loss": loss, - }) - + } + ) + # Log the training metrics 
self.log_dict(training_metrics, prog_bar=True) - + return {'loss': loss} - + def validation_step(self, batch: torch.Tensor): x, targets = batch x = torch.stack(x) batch_size, num_channels, height, width = x.shape - + # assert that the input is in the right shape assert (num_channels, height, width) == self.image_shape validation_metrics = {} - + # canonicalize the input data # For the vanilla model, the canonicalization is the identity transformation x_canonicalized, targets_canonicalized = self.canonicalizer(x, targets) - + # Forward pass through the prediction network as you'll normally do # Finetuning maskrcnn model will return the losses which can be used to fine tune the model - # Meanwhile, Segment-Anything (SAM) can return boxes, ious, masks predictions + # Meanwhile, Segment-Anything (SAM) can return boxes, ious, masks predictions # For uniformity, we will ensure the prediction network returns both losses and predictions irrespective of the model - _, _, _, outputs = self.prediction_network(x_canonicalized, targets_canonicalized) - + _, _, _, outputs = self.prediction_network( + x_canonicalized, targets_canonicalized + ) + _map = MeanAveragePrecision(iou_type='segm') - targets = [dict(boxes=target['boxes'], labels=target['labels'], masks=target['masks']) for target in targets] - outputs = [dict(boxes=output['boxes'], labels=output['labels'], scores=output['scores'], masks=output['masks']) for output in outputs] + targets = [ + dict(boxes=target['boxes'], labels=target['labels'], masks=target['masks']) + for target in targets + ] + outputs = [ + dict( + boxes=output['boxes'], + labels=output['labels'], + scores=output['scores'], + masks=output['masks'], + ) + for output in outputs + ] _map.update(outputs, targets) _map_dict = _map.compute() - - validation_metrics.update({ - 'val/map': _map_dict['map'], - 'val/map_small': _map_dict['map_small'], - 'val/map_medium': _map_dict['map_medium'], - 'val/map_large': _map_dict['map_large'], - 'val/map_50': _map_dict['map_50'], - 'val/map_75': _map_dict['map_75'], - 'val/mar_1': _map_dict['mar_1'], - 'val/mar_10': _map_dict['mar_10'], - 'val/mar_100': _map_dict['mar_100'], - 'val/mar_small': _map_dict['mar_small'], - 'val/mar_medium': _map_dict['mar_medium'], - 'val/mar_large': _map_dict['mar_large'], - - }) - + + validation_metrics.update( + { + 'val/map': _map_dict['map'], + 'val/map_small': _map_dict['map_small'], + 'val/map_medium': _map_dict['map_medium'], + 'val/map_large': _map_dict['map_large'], + 'val/map_50': _map_dict['map_50'], + 'val/map_75': _map_dict['map_75'], + 'val/mar_1': _map_dict['mar_1'], + 'val/mar_10': _map_dict['mar_10'], + 'val/mar_100': _map_dict['mar_100'], + 'val/mar_small': _map_dict['mar_small'], + 'val/mar_medium': _map_dict['mar_medium'], + 'val/mar_large': _map_dict['mar_large'], + } + ) + # Log the validation metrics self.log_dict(validation_metrics, prog_bar=True) - - # Log the identity metric if the prior weight is non-zero + + # Log the identity metric if the prior weight is non-zero if self.hyperparams.experiment.training.loss.prior_weight: metric_identity = self.canonicalizer.get_identity_metric() - validation_metrics.update({ - "val/identity_metric": metric_identity - }) + validation_metrics.update({"val/identity_metric": metric_identity}) self.log_dict(validation_metrics, prog_bar=True) return {'map': _map_dict['map']} - def test_step(self, batch: torch.Tensor): images, targets = batch batch_size, num_channels, height, width = images.shape - + # assert that the input is in the right shape assert 
(num_channels, height, width) == self.image_shape test_metrics = self.inference_method.get_inference_metrics(images, targets) - + # Log the test metrics self.log_dict(test_metrics, prog_bar=True) - - return test_metrics - + + return test_metrics + + def configure_optimizers(self): # using SGD optimizer and MultiStepLR scheduler optimizer = torch.optim.SGD( - [ - {'params': self.prediction_network.parameters(), 'lr': self.hyperparams.experiment.training.prediction_lr}, - {'params': self.canonicalizer.parameters(), 'lr': self.hyperparams.experiment.training.canonicalization_lr}, - ], - momentum=0.9, - weight_decay=5e-4, - ) - + [ + { + 'params': self.prediction_network.parameters(), + 'lr': self.hyperparams.experiment.training.prediction_lr, + }, + { + 'params': self.canonicalizer.parameters(), + 'lr': self.hyperparams.experiment.training.canonicalization_lr, + }, + ], + momentum=0.9, + weight_decay=5e-4, + ) + scheduler_dict = { "scheduler": MultiStepLR( optimizer, @@ -236,4 +287,4 @@ def configure_optimizers(self): ), "interval": "epoch", } - return {"optimizer": optimizer, "lr_scheduler": scheduler_dict} \ No newline at end of file + return {"optimizer": optimizer, "lr_scheduler": scheduler_dict} diff --git a/examples/images/segmentation/model_utils.py b/examples/images/segmentation/model_utils.py index b9ad360..2596235 100644 --- a/examples/images/segmentation/model_utils.py +++ b/examples/images/segmentation/model_utils.py @@ -1,19 +1,20 @@ import torch import torch.nn as nn import torch.nn.functional as F - - from segment_anything import sam_model_registry from torchvision.models.detection import maskrcnn_resnet50_fpn_v2 ALPHA = 0.8 GAMMA = 2 + class MaskRCNNModel(nn.Module): def __init__(self, architecture_type, num_classes, weights='DEFAULT'): super().__init__() - - assert architecture_type in ['resnet50_fpn_v2'], NotImplementedError('Only `maskrcnn_resnet50_fpn_v2` is supported for now.') + + assert architecture_type in ['resnet50_fpn_v2'], NotImplementedError( + 'Only `maskrcnn_resnet50_fpn_v2` is supported for now.' 
+ ) if architecture_type == 'resnet50_fpn_v2': self.model = maskrcnn_resnet50_fpn_v2(weights='DEFAULT') @@ -35,7 +36,15 @@ def forward(self, images, targets): output[0]['masks'] = target['masks'] output[0]['scores'] = torch.ones(len(target['masks'])) ious.append(torch.ones(len(target['masks']), dtype=torch.float32)) - pred_masks.append(torch.ones(len(target['masks']), image.shape[-2], image.shape[-1], dtype=torch.float32, device=self.hyperparams.device)) + pred_masks.append( + torch.ones( + len(target['masks']), + image.shape[-2], + image.shape[-1], + dtype=torch.float32, + device=self.hyperparams.device, + ) + ) else: masks = output[0]['masks'] @@ -43,22 +52,33 @@ def forward(self, images, targets): pred_masks.append(masks.squeeze(1)) ious.append(iou_predictions) - output[0]['masks'] = torch.as_tensor(output[0]['masks'].squeeze(1) > 0.5, dtype=torch.uint8).squeeze(1) - output[0]['scores'] = torch.as_tensor(output[0]['scores'], dtype=torch.float32) - output[0]['labels'] = torch.as_tensor(output[0]['labels'], dtype=torch.int64) - output[0]['boxes'] = torch.as_tensor(output[0]['boxes'], dtype=torch.float32) + output[0]['masks'] = torch.as_tensor( + output[0]['masks'].squeeze(1) > 0.5, dtype=torch.uint8 + ).squeeze(1) + output[0]['scores'] = torch.as_tensor( + output[0]['scores'], dtype=torch.float32 + ) + output[0]['labels'] = torch.as_tensor( + output[0]['labels'], dtype=torch.int64 + ) + output[0]['boxes'] = torch.as_tensor( + output[0]['boxes'], dtype=torch.float32 + ) outputs.append(output[0]) return None, pred_masks, ious, outputs + class SAMModel(nn.Module): - def __init__(self, - architecture_type: str, - sam_pretrained_ckpt_path: str): + def __init__(self, architecture_type: str, sam_pretrained_ckpt_path: str): super().__init__() - assert sam_pretrained_ckpt_path is not None, ValueError('SAM requires a pretrained checkpoint path.') - self.model = sam_model_registry[architecture_type](checkpoint=sam_pretrained_ckpt_path) + assert sam_pretrained_ckpt_path is not None, ValueError( + 'SAM requires a pretrained checkpoint path.' 
+ ) + self.model = sam_model_registry[architecture_type]( + checkpoint=sam_pretrained_ckpt_path + ) def forward(self, images, targets): if type(images) == list: @@ -91,19 +111,20 @@ def forward(self, images, targets): mode="bilinear", align_corners=False, ) - pred_masks.append(masks.squeeze(1)) # bbox_length x H x W - ious.append(iou_predictions) # bbox_length x 1 + pred_masks.append(masks.squeeze(1)) # bbox_length x H x W + ious.append(iou_predictions) # bbox_length x 1 output = dict( - masks = torch.as_tensor(masks.squeeze(1) > 0.5, dtype=torch.uint8), - scores = torch.as_tensor(iou_predictions.squeeze(1), dtype=torch.float32), - labels = torch.as_tensor(target['labels'], dtype=torch.int64), - boxes = torch.as_tensor(target['boxes'], dtype=torch.float32) + masks=torch.as_tensor(masks.squeeze(1) > 0.5, dtype=torch.uint8), + scores=torch.as_tensor(iou_predictions.squeeze(1), dtype=torch.float32), + labels=torch.as_tensor(target['labels'], dtype=torch.int64), + boxes=torch.as_tensor(target['boxes'], dtype=torch.float32), ) outputs.append(output) return None, pred_masks, ious, outputs - + + class FocalLoss(nn.Module): def __init__(self, weight=None, size_average=True): @@ -113,17 +134,18 @@ def __init__(self, weight=None, size_average=True): def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1): inputs = F.sigmoid(inputs) - #flatten label and prediction tensors + # flatten label and prediction tensors inputs = inputs.view(-1) targets = targets.view(-1) - #first compute binary cross-entropy + # first compute binary cross-entropy BCE = F.binary_cross_entropy(inputs, targets, reduction='mean') BCE_EXP = torch.exp(-BCE) - focal_loss = alpha * (1 - BCE_EXP)**gamma * BCE + focal_loss = alpha * (1 - BCE_EXP) ** gamma * BCE return focal_loss + class DiceLoss(nn.Module): def __init__(self, weight=None, size_average=True): @@ -133,15 +155,16 @@ def __init__(self, weight=None, size_average=True): def forward(self, inputs, targets, smooth=1): inputs = F.sigmoid(inputs) - #flatten label and prediction tensors + # flatten label and prediction tensors inputs = inputs.view(-1) targets = targets.view(-1) intersection = (inputs * targets).sum() - dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth) + dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth) return 1 - dice - + + def get_dataset_specific_info(dataset_name, prediction_architecture_name): dataset_info = { 'coco': { @@ -157,6 +180,7 @@ def get_dataset_specific_info(dataset_name, prediction_architecture_name): return dataset_info[dataset_name][prediction_architecture_name] + def get_prediction_network( architecture: str = 'sam', architecture_type: str = 'vit_h', @@ -164,19 +188,21 @@ def get_prediction_network( use_pretrained: bool = False, freeze_encoder: bool = False, num_classes: int = 91, - sam_pretrained_ckpt_path=None + sam_pretrained_ckpt_path=None, ): weights = 'DEFAULT' if use_pretrained else None model_dict = { 'sam': SAMModel(architecture_type, sam_pretrained_ckpt_path), - 'maskrcnn': MaskRCNNModel(architecture_type, num_classes, weights) + 'maskrcnn': MaskRCNNModel(architecture_type, num_classes, weights), } if architecture not in model_dict: - raise ValueError(f'{architecture} is not implemented as prediction network for now.') + raise ValueError( + f'{architecture} is not implemented as prediction network for now.' 
+ ) prediction_network = model_dict[architecture](weights=weights) - + if freeze_encoder: for param in prediction_network.parameters(): param.requires_grad = False @@ -188,10 +214,13 @@ def get_prediction_network( return prediction_network + def calc_iou(pred_mask, gt_mask): pred_mask = (pred_mask >= 0.5).float() intersection = torch.sum(torch.mul(pred_mask, gt_mask), dim=(1, 2)) - union = torch.sum(pred_mask, dim=(1, 2)) + torch.sum(gt_mask, dim=(1, 2)) - intersection + union = ( + torch.sum(pred_mask, dim=(1, 2)) + torch.sum(gt_mask, dim=(1, 2)) - intersection + ) batch_iou = intersection / union batch_iou = batch_iou.unsqueeze(1) - return batch_iou \ No newline at end of file + return batch_iou diff --git a/examples/images/segmentation/prepare/__init__.py b/examples/images/segmentation/prepare/__init__.py index dc33087..ab3e630 100644 --- a/examples/images/segmentation/prepare/__init__.py +++ b/examples/images/segmentation/prepare/__init__.py @@ -1 +1 @@ -from .coco_data import COCODataModule \ No newline at end of file +from .coco_data import COCODataModule diff --git a/examples/images/segmentation/prepare/coco_data.py b/examples/images/segmentation/prepare/coco_data.py index dfbb5d9..cbfc886 100644 --- a/examples/images/segmentation/prepare/coco_data.py +++ b/examples/images/segmentation/prepare/coco_data.py @@ -1,15 +1,14 @@ import os -import torch import numpy as np -from PIL import Image import pytorch_lightning as pl +import torch import torchvision.transforms as transforms import vision_transforms as T -from torch.utils.data import DataLoader, Dataset - +from PIL import Image from pycocotools.coco import COCO from segment_anything.utils.transforms import ResizeLongestSide +from torch.utils.data import DataLoader, Dataset class ResizeAndPad: @@ -40,7 +39,10 @@ def __call__(self, image, target): # Adjust bounding boxes bboxes = self.transform.apply_boxes(bboxes, (og_h, og_w)) - bboxes = [[bbox[0] + pad_w, bbox[1] + pad_h, bbox[2] + pad_w, bbox[3] + pad_h] for bbox in bboxes] + bboxes = [ + [bbox[0] + pad_w, bbox[1] + pad_h, bbox[2] + pad_w, bbox[3] + pad_h] + for bbox in bboxes + ] target['masks'] = torch.stack(masks) target['boxes'] = torch.as_tensor(bboxes, dtype=torch.float32) @@ -51,7 +53,7 @@ class COCODataModule(pl.LightningDataModule): def __init__(self, hyperparams): super().__init__() self.hyperparams = hyperparams - + def get_transform(self, train=True): tr = [] tr.append(T.PILToTensor()) @@ -60,7 +62,7 @@ def get_transform(self, train=True): if train and self.hyperparams.augment == 'flip': tr.append(T.RandomHorizontalFlip(0.5)) return T.Compose(tr) - + def collate_fn(self, batch): images = [x[0] for x in batch] targets = [x[1] for x in batch] @@ -69,23 +71,29 @@ def collate_fn(self, batch): def setup(self, stage=None): if stage == "fit" or stage is None: self.train_dataset = COCODataset( - root_dir=os.path.join(self.hyperparams.root_dir, 'train2017'), - annotation_file=os.path.join(self.hyperparams.ann_dir, 'instances_train2017.json'), - transform=self.get_transform(train=True) - ) + root_dir=os.path.join(self.hyperparams.root_dir, 'train2017'), + annotation_file=os.path.join( + self.hyperparams.ann_dir, 'instances_train2017.json' + ), + transform=self.get_transform(train=True), + ) self.valid_dataset = COCODataset( - root_dir=os.path.join(self.hyperparams.root_dir, 'val2017'), - annotation_file=os.path.join(self.hyperparams.ann_dir, 'instances_val2017.json'), - transform=self.get_transform(train=False) - ) + root_dir=os.path.join(self.hyperparams.root_dir, 'val2017'), + 
annotation_file=os.path.join( + self.hyperparams.ann_dir, 'instances_val2017.json' + ), + transform=self.get_transform(train=False), + ) if stage == "test": self.test_dataset = COCODataset( - root_dir=os.path.join(self.hyperparams.root_dir, 'val2017'), - annotation_file=os.path.join(self.hyperparams.ann_dir, 'instances_val2017.json'), - transform=self.get_transform(train=False) - ) + root_dir=os.path.join(self.hyperparams.root_dir, 'val2017'), + annotation_file=os.path.join( + self.hyperparams.ann_dir, 'instances_val2017.json' + ), + transform=self.get_transform(train=False), + ) print('Test dataset size: ', len(self.test_dataset)) - + def train_dataloader(self): train_loader = DataLoader( self.train_dataset, @@ -116,6 +124,7 @@ def test_dataloader(self): ) return test_loader + class COCODataset(Dataset): def __init__(self, root_dir, annotation_file, transform=None, sam_transform=None): @@ -125,7 +134,11 @@ def __init__(self, root_dir, annotation_file, transform=None, sam_transform=None self.image_ids = list(self.coco.imgs.keys()) # Filter out image_ids without any annotations - self.image_ids = [image_id for image_id in self.image_ids if len(self.coco.getAnnIds(imgIds=image_id)) > 0] + self.image_ids = [ + image_id + for image_id in self.image_ids + if len(self.coco.getAnnIds(imgIds=image_id)) > 0 + ] def __len__(self): return len(self.image_ids) @@ -150,7 +163,7 @@ def __getitem__(self, idx): # there are degenerate boxes in the dataset, skip them if ann['area'] <= 0 or w < 1 or h < 1: continue - bboxes.append([x, y, x + w, y + h]) # NOTE: origin is left top corner + bboxes.append([x, y, x + w, y + h]) # NOTE: origin is left top corner labels.append(ann['category_id']) mask = self.coco.annToMask(ann) masks.append(mask) @@ -163,10 +176,12 @@ def __getitem__(self, idx): target['labels'] = torch.as_tensor(labels, dtype=torch.int64) target['masks'] = torch.as_tensor(np.array(masks), dtype=torch.uint8) target['image_id'] = torch.as_tensor(image_ids[0], dtype=torch.int64) - target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * (target['boxes'][:, 2] - target['boxes'][:, 0]) + target['area'] = (target['boxes'][:, 3] - target['boxes'][:, 1]) * ( + target['boxes'][:, 2] - target['boxes'][:, 0] + ) target['iscrowd'] = torch.as_tensor(iscrowds, dtype=torch.int64) if self.transform is not None: image, target = self.transform(image, target) - return image, target \ No newline at end of file + return image, target diff --git a/examples/images/segmentation/prepare/vision_transforms.py b/examples/images/segmentation/prepare/vision_transforms.py index 7e993e0..1949df9 100644 --- a/examples/images/segmentation/prepare/vision_transforms.py +++ b/examples/images/segmentation/prepare/vision_transforms.py @@ -1,9 +1,10 @@ from typing import Dict, Optional, Tuple import torch -from torch import nn, Tensor +from torch import Tensor, nn from torchvision import ops -from torchvision.transforms import functional as F, transforms as T +from torchvision.transforms import functional as F +from torchvision.transforms import transforms as T def _flip_coco_person_keypoints(kps, width): @@ -61,4 +62,4 @@ def forward( self, image: Tensor, target: Optional[Dict[str, Tensor]] = None ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: image = F.convert_image_dtype(image, self.dtype) - return image, target \ No newline at end of file + return image, target diff --git a/examples/images/segmentation/train.py b/examples/images/segmentation/train.py index 8834e1e..4a72b97 100644 --- 
a/examples/images/segmentation/train.py +++ b/examples/images/segmentation/train.py @@ -1,23 +1,34 @@ import os -import torch -import wandb import hydra import omegaconf -from omegaconf import DictConfig, OmegaConf - import pytorch_lightning as pl +import torch +import wandb +from omegaconf import DictConfig, OmegaConf from pytorch_lightning.loggers import WandbLogger - from train_utils import get_model_data_and_callbacks, get_trainer, load_envs + def train_images(hyperparams: DictConfig): - hyperparams['canonicalization_type'] = hyperparams['canonicalization']['canonicalization_type'] + hyperparams['canonicalization_type'] = hyperparams['canonicalization'][ + 'canonicalization_type' + ] hyperparams['device'] = 'cuda' if torch.cuda.is_available() else 'cpu' - hyperparams['dataset']['data_path'] = hyperparams['dataset']['data_path'] + "/" + hyperparams['dataset']['dataset_name'] - hyperparams['checkpoint']['checkpoint_path'] = hyperparams['checkpoint']['checkpoint_path'] + "/" + \ - hyperparams['dataset']['dataset_name'] + "/" + hyperparams['canonicalization_type'] \ - + "/" + hyperparams['prediction']['prediction_network_architecture'] + hyperparams['dataset']['data_path'] = ( + hyperparams['dataset']['data_path'] + + "/" + + hyperparams['dataset']['dataset_name'] + ) + hyperparams['checkpoint']['checkpoint_path'] = ( + hyperparams['checkpoint']['checkpoint_path'] + + "/" + + hyperparams['dataset']['dataset_name'] + + "/" + + hyperparams['canonicalization_type'] + + "/" + + hyperparams['prediction']['prediction_network_architecture'] + ) # set system environment variables for wandb if hyperparams['wandb']['use_wandb']: @@ -27,19 +38,30 @@ def train_images(hyperparams: DictConfig): print("Wandb disabled for logging...") os.environ["WANDB_MODE"] = "disabled" os.environ["WANDB_DIR"] = hyperparams['wandb']['wandb_dir'] - os.environ["WANDB_CACHE_DIR"] = hyperparams['wandb']['wandb_cache_dir'] - + os.environ["WANDB_CACHE_DIR"] = hyperparams['wandb']['wandb_cache_dir'] + # initialize wandb - wandb.init(config=OmegaConf.to_container(hyperparams, resolve=True), entity=hyperparams['wandb']['wandb_entity'], project=hyperparams['wandb']['wandb_project'], dir=hyperparams['wandb']['wandb_dir']) - wandb_logger = WandbLogger(project=hyperparams['wandb']['wandb_project'], log_model="all") + wandb.init( + config=OmegaConf.to_container(hyperparams, resolve=True), + entity=hyperparams['wandb']['wandb_entity'], + project=hyperparams['wandb']['wandb_project'], + dir=hyperparams['wandb']['wandb_dir'], + ) + wandb_logger = WandbLogger( + project=hyperparams['wandb']['wandb_project'], log_model="all" + ) # set seed pl.seed_everything(hyperparams.experiment.seed) - + # get model, callbacks, and image data model, image_data, callbacks = get_model_data_and_callbacks(hyperparams) - - if hyperparams.canonicalization_type in ("group_equivariant", "opt_equivariant", "steerable"): + + if hyperparams.canonicalization_type in ( + "group_equivariant", + "opt_equivariant", + "steerable", + ): wandb.watch(model.canonicalizer.canonicalization_network, log='all') # get trainer @@ -47,18 +69,21 @@ def train_images(hyperparams: DictConfig): if hyperparams.experiment.run_mode == "train": trainer.fit(model, datamodule=image_data) - + elif hyperparams.experiment.run_mode == "auto_tune": trainer.tune(model, datamodule=image_data) trainer.test(model, datamodule=image_data) + # load the variables from .env file load_envs() + @hydra.main(config_path=str("./configs/"), config_name="default") def main(cfg: omegaconf.DictConfig): 
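The expanded string concatenations in train_images above simply compose a checkpoint directory from config fields. A small sketch with placeholder values (not the project's actual Hydra defaults) shows the resulting path.

from omegaconf import OmegaConf

# Placeholder values; the real ones come from the Hydra configs under ./configs/.
hyperparams = OmegaConf.create(
    {
        "canonicalization": {"canonicalization_type": "group_equivariant"},
        "dataset": {"dataset_name": "coco2017"},
        "prediction": {"prediction_network_architecture": "sam"},
        "checkpoint": {"checkpoint_path": "/checkpoints"},
    }
)

hyperparams["canonicalization_type"] = hyperparams["canonicalization"]["canonicalization_type"]
hyperparams["checkpoint"]["checkpoint_path"] = (
    hyperparams["checkpoint"]["checkpoint_path"]
    + "/"
    + hyperparams["dataset"]["dataset_name"]
    + "/"
    + hyperparams["canonicalization_type"]
    + "/"
    + hyperparams["prediction"]["prediction_network_architecture"]
)
print(hyperparams["checkpoint"]["checkpoint_path"])
# /checkpoints/coco2017/group_equivariant/sam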
train_images(cfg) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/images/segmentation/train_utils.py b/examples/images/segmentation/train_utils.py index 954a296..75b0f1e 100644 --- a/examples/images/segmentation/train_utils.py +++ b/examples/images/segmentation/train_utils.py @@ -1,77 +1,93 @@ -import dotenv -from omegaconf import DictConfig from typing import Dict, Optional +import dotenv import pytorch_lightning as pl -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping - from model import ImageSegmentationPipeline +from omegaconf import DictConfig from prepare import COCODataModule - -def get_model_data_and_callbacks(hyperparams : DictConfig): - - # get image data +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint + + +def get_model_data_and_callbacks(hyperparams: DictConfig): + + # get image data image_data = get_image_data(hyperparams.dataset) - + # checkpoint name hyperparams.checkpoint.checkpoint_name = get_checkpoint_name(hyperparams) - + # checkpoint callbacks callbacks = get_callbacks(hyperparams) - # get model pipeline + # get model pipeline model = get_model_pipeline(hyperparams) - - return model, image_data, callbacks + + return model, image_data, callbacks + def get_model_pipeline(hyperparams: DictConfig): if hyperparams.experiment.run_mode == "test": model = ImageSegmentationPipeline.load_from_checkpoint( - checkpoint_path=hyperparams.checkpoint.checkpoint_path + "/" + \ - hyperparams.checkpoint.checkpoint_name + ".ckpt", - hyperparams=hyperparams + checkpoint_path=hyperparams.checkpoint.checkpoint_path + + "/" + + hyperparams.checkpoint.checkpoint_name + + ".ckpt", + hyperparams=hyperparams, ) model.freeze() model.eval() else: model = ImageSegmentationPipeline(hyperparams) - + return model + def get_trainer( - hyperparams: DictConfig, - callbacks: list, - wandb_logger: pl.loggers.WandbLogger + hyperparams: DictConfig, callbacks: list, wandb_logger: pl.loggers.WandbLogger ): if hyperparams.experiment.run_mode == "auto_tune": trainer = pl.Trainer( - max_epochs=hyperparams.experiment.num_epochs, accelerator="auto", - auto_scale_batch_size=True, auto_lr_find=True, logger=wandb_logger, - callbacks=callbacks, deterministic=hyperparams.experiment.deterministic, - num_nodes=hyperparams.experiment.num_nodes, devices=hyperparams.experiment.num_gpus, - strategy='ddp' + max_epochs=hyperparams.experiment.num_epochs, + accelerator="auto", + auto_scale_batch_size=True, + auto_lr_find=True, + logger=wandb_logger, + callbacks=callbacks, + deterministic=hyperparams.experiment.deterministic, + num_nodes=hyperparams.experiment.num_nodes, + devices=hyperparams.experiment.num_gpus, + strategy='ddp', ) - + elif hyperparams.experiment.run_mode == "dryrun": trainer = pl.Trainer( - fast_dev_run=5, max_epochs=hyperparams.experiment.training.num_epochs, accelerator="auto", - limit_train_batches=5, limit_val_batches=5, logger=wandb_logger, - callbacks=callbacks, deterministic=hyperparams.experiment.deterministic + fast_dev_run=5, + max_epochs=hyperparams.experiment.training.num_epochs, + accelerator="auto", + limit_train_batches=5, + limit_val_batches=5, + logger=wandb_logger, + callbacks=callbacks, + deterministic=hyperparams.experiment.deterministic, ) else: trainer = pl.Trainer( - max_epochs=hyperparams.experiment.training.num_epochs, accelerator="auto", - logger=wandb_logger, callbacks=callbacks, deterministic=hyperparams.experiment.deterministic, - num_nodes=hyperparams.experiment.num_nodes, 
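One way to read the reformatted get_trainer above is as a mapping from run mode to pl.Trainer keyword arguments. The sketch below condenses that idea into a hypothetical helper; it assumes the pre-2.0 PyTorch Lightning API that the diff itself relies on (auto_scale_batch_size and auto_lr_find were removed in Lightning 2.0).

import pytorch_lightning as pl


def trainer_kwargs_for(run_mode: str, num_epochs: int) -> dict:
    # Hypothetical condensation of the three branches in get_trainer.
    common = {"max_epochs": num_epochs, "accelerator": "auto"}
    if run_mode == "auto_tune":
        return {**common, "auto_scale_batch_size": True, "auto_lr_find": True}
    if run_mode == "dryrun":
        return {**common, "fast_dev_run": 5, "limit_train_batches": 5, "limit_val_batches": 5}
    return common


# A dry-run trainer; logger and callbacks are omitted here for brevity.
trainer = pl.Trainer(**trainer_kwargs_for("dryrun", num_epochs=1))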
devices=hyperparams.experiment.num_gpus, - strategy='ddp' + max_epochs=hyperparams.experiment.training.num_epochs, + accelerator="auto", + logger=wandb_logger, + callbacks=callbacks, + deterministic=hyperparams.experiment.deterministic, + num_nodes=hyperparams.experiment.num_nodes, + devices=hyperparams.experiment.num_gpus, + strategy='ddp', ) return trainer - - + + def get_callbacks(hyperparams: DictConfig): - + checkpoint_callback = ModelCheckpoint( dirpath=hyperparams.checkpoint.checkpoint_path, filename=hyperparams.checkpoint.checkpoint_name, @@ -79,14 +95,17 @@ def get_callbacks(hyperparams: DictConfig): mode="max", save_on_train_epoch_end=False, ) - early_stop_metric_callback = EarlyStopping(monitor="val/map", - min_delta=hyperparams.experiment.training.min_delta, - patience=hyperparams.experiment.training.patience, - verbose=True, - mode="max") - + early_stop_metric_callback = EarlyStopping( + monitor="val/map", + min_delta=hyperparams.experiment.training.min_delta, + patience=hyperparams.experiment.training.patience, + verbose=True, + mode="max", + ) + return [checkpoint_callback, early_stop_metric_callback] + def get_recursive_hyperparams_identifier(hyperparams: Dict): # get the identifier for the canonicalization network hyperparameters # recursively go through the dictionary and get the values and concatenate them @@ -97,24 +116,29 @@ def get_recursive_hyperparams_identifier(hyperparams: Dict): else: identifier += f"_{key}_{value}_" return identifier - -def get_checkpoint_name(hyperparams : DictConfig): - - return f"{get_recursive_hyperparams_identifier(hyperparams.canonicalization)}".lstrip("_") + \ - f"__epochs_{hyperparams.experiment.training.num_epochs}_" + f"__seed_{hyperparams.experiment.seed}" - + + +def get_checkpoint_name(hyperparams: DictConfig): + + return ( + f"{get_recursive_hyperparams_identifier(hyperparams.canonicalization)}".lstrip( + "_" + ) + + f"__epochs_{hyperparams.experiment.training.num_epochs}_" + + f"__seed_{hyperparams.experiment.seed}" + ) + def get_image_data(dataset_hyperparams: DictConfig): - - dataset_classes = { - "coco2017": COCODataModule - } - + + dataset_classes = {"coco2017": COCODataModule} + if dataset_hyperparams.dataset_name not in dataset_classes: raise ValueError(f"{dataset_hyperparams.dataset_name} not implemented") - + return dataset_classes[dataset_hyperparams.dataset_name](dataset_hyperparams) + def load_envs(env_file: Optional[str] = None) -> None: """ Load all the environment variables defined in the `env_file`. @@ -125,4 +149,4 @@ def load_envs(env_file: Optional[str] = None) -> None: :param env_file: the file that defines the environment variables to use. If None it searches for a `.env` file in the project. """ - dotenv.load_dotenv(dotenv_path=env_file, override=True) \ No newline at end of file + dotenv.load_dotenv(dotenv_path=env_file, override=True) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..a8d8739 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,15 @@ +[build-system] +# AVOID CHANGING REQUIRES: IT WILL BE UPDATED BY PYSCAFFOLD! 
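get_recursive_hyperparams_identifier is only partly visible in this hunk; the standalone re-statement below is a plausible reconstruction of the same recursion on a plain dict (keys and values are made up), shown only to illustrate how the checkpoint-name fragments are built before lstrip("_") is applied in get_checkpoint_name.

from typing import Dict


def recursive_identifier(hyperparams: Dict) -> str:
    # Nested dicts are flattened depth-first; each leaf contributes "_{key}_{value}_".
    identifier = ""
    for key, value in hyperparams.items():
        if isinstance(value, dict):
            identifier += recursive_identifier(value)
        else:
            identifier += f"_{key}_{value}_"
    return identifier


cfg = {"canonicalization_type": "group_equivariant", "network": {"type": "escnn", "num_layers": 2}}
print(recursive_identifier(cfg).lstrip("_"))
# canonicalization_type_group_equivariant__type_escnn__num_layers_2_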
+requires = ["setuptools>=46.1.0", "setuptools_scm[toml]>=5"] +build-backend = "setuptools.build_meta" + +[tool.setuptools_scm] +# For smarter version schemes and other configuration options, +# check out https://github.com/pypa/setuptools_scm +version_scheme = "no-guess-dev" + +[tool.black] +skip-string-normalization = true + +[tool.mypy] +exclude = ['docs'] diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..94d11bd --- /dev/null +++ b/setup.cfg @@ -0,0 +1,116 @@ +# This file is used to configure your project. +# Read more about the various options under: +# https://setuptools.pypa.io/en/latest/userguide/declarative_config.html +# https://setuptools.pypa.io/en/latest/references/keywords.html + +[metadata] +name = equiadapt +description = Library that provides metrics to assess representation quality +author = Arnab Mondal +author_email = arnab.mondal@mila.quebec +license = MIT +license_files = LICENSE +long_description = file: README.md +long_description_content_type = text/markdown; charset=UTF-8; variant=GFM +url = https://github.com/arnab39/EquivariantAdaptation/ +# Add here related links, for example: +project_urls = + Tracker = https://github.com/arnab39/EquivariantAdaptation/issues + Source = https://github.com/arnab39/EquivariantAdaptation/ + +# Change if running only on Windows, Mac or Linux (comma-separated) +platforms = Linux + +# Add here all kinds of additional classifiers as defined under +# https://pypi.org/classifiers/ +classifiers = + Programming Language :: Python :: 3 + License :: OSI Approved :: MIT License + Operating System :: Linux + +[options] +zip_safe = False +packages = find_namespace: +include_package_data = True + +# Require a min/specific Python version (comma-separated conditions) +python_requires = >=3.7 + +# Add here dependencies of your project (line-separated), e.g. requests>=2.2,<3.0. +# Version specifiers like >=2.2,<3.0 avoid problems due to API changes in +# new major versions. This works if the required packages follow Semantic Versioning. +# For more information, check out https://semver.org/. +install_requires = + torch + numpy + torchvision + kornia + escnn @ git+https://github.com/danibene/escnn.git@remove/py3nj_dep + +[options.packages.find] +exclude = + tests + +[options.extras_require] +# Add here additional requirements for extra features, to install with: +# `pip install equiadapt[PDF]` like: +# PDF = ReportLab; RXP + +# Add here test requirements (semicolon/line-separated) +testing = + setuptools + pytest + pytest-cov + +[options.entry_points] +# Add here console scripts like: +# console_scripts = +# script_name = equiadapt.module:function + +[tool:pytest] +# Specify command line options as you would do when invoking pytest directly. +# e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml +# in order to write a coverage file that can be read by Jenkins. +# CAUTION: --cov flags may prohibit setting breakpoints while debugging. +# Comment those flags to avoid this pytest issue. 
+addopts = + --cov equiadapt --cov-report term-missing + --verbose +norecursedirs = + dist + build + .tox +testpaths = tests +# Use pytest markers to select/deselect specific tests +# markers = +# slow: mark tests as slow (deselect with '-m "not slow"') +# system: mark end-to-end system tests + +[devpi:upload] +# Options for the devpi: PyPI server and packaging tool +# VCS export must be deactivated since we are using setuptools-scm +no_vcs = 1 +formats = bdist_wheel + +[flake8] +# Some sane defaults for the code style checker flake8 +max_line_length = 88 +extend_ignore = E203, W503 +# ^ Black-compatible +# E203 and W503 have edge cases handled by black +exclude = + .tox + build + dist + .eggs + docs/conf.py + +[pyscaffold] +# PyScaffold's parameters when the project was created. +# This will be used when updating. Do not change! +version = 4.5 +package = equiadapt +extensions = + github_actions + markdown + pre_commit diff --git a/setup.py b/setup.py index 8c21c04..477bbff 100644 --- a/setup.py +++ b/setup.py @@ -1,26 +1,22 @@ -from setuptools import setup, find_packages +""" + Setup file for equiadapt. + Use setup.cfg to configure your project. -setup( - name='equiadapt', # Replace with your package's name - version='0.1.0', # Package version - author='Arnab Mondal', # Replace with your name - author_email='arnab.mondal@mila.quebec', # Replace with your email - description='Library to make any existing neural network architecture equivariant', # Package summary - long_description=open('README.md').read(), - long_description_content_type='text/markdown', - url='https://github.com/arnab39/EquivariantAdaptation', # Replace with your repository URL - packages=find_packages(), - install_requires=[ - 'torch', # Specify your project's dependencies here - 'numpy', - 'torchvision', - 'kornia', - 'escnn @ git+https://github.com/danibene/escnn.git@remove/py3nj_dep' - ], - classifiers=[ - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: MIT License', - 'Operating System :: OS Independent', - ], - python_requires='>=3.7', # Minimum version requirement of Python -) + This file was generated with PyScaffold 4.5. + PyScaffold helps you to put up the scaffold of your new Python project. + Learn more under: https://pyscaffold.org/ +""" + +from setuptools import setup + +if __name__ == "__main__": + try: + setup(use_scm_version={"version_scheme": "no-guess-dev"}) + except: # noqa + print( + "\n\nAn error occurred while building the project, " + "please ensure you have the most updated version of setuptools, " + "setuptools_scm and wheel with:\n" + " pip install -U setuptools setuptools_scm wheel\n\n" + ) + raise diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..08c21bc --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,10 @@ +""" + Dummy conftest.py for equiadapt. + + If you don't know what this is for, just leave it empty. 
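The conftest.py added here is intentionally left empty. If shared fixtures become useful later, they would live in this file; a hypothetical example (not part of the diff) that keeps tensor-valued assertions reproducible:

import pytest
import torch


@pytest.fixture(autouse=True)
def fixed_seed() -> None:
    # Hypothetical fixture: reset the RNG before every test so tensor-valued
    # assertions (such as the gram_schmidt check below) stay deterministic.
    torch.manual_seed(0)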
+ Read more about conftest.py under: + - https://docs.pytest.org/en/stable/fixture.html + - https://docs.pytest.org/en/stable/writing_plugins.html +""" + +# import pytest diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..4913c70 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,12 @@ +import torch + +from equiadapt.common.utils import gram_schmidt + + +def test_gram_schmidt() -> None: + torch.manual_seed(0) + vectors = torch.randn(1, 3, 3) # batch of 1, 3 vectors of dimension 3 + + output = gram_schmidt(vectors) + + assert torch.allclose(output[0][0][0], torch.tensor(0.5740), atol=1e-4) diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..69f8159 --- /dev/null +++ b/tox.ini @@ -0,0 +1,93 @@ +# Tox configuration file +# Read more under https://tox.wiki/ +# THIS SCRIPT IS SUPPOSED TO BE AN EXAMPLE. MODIFY IT ACCORDING TO YOUR NEEDS! + +[tox] +minversion = 3.24 +envlist = default +isolated_build = True + + +[testenv] +description = Invoke pytest to run automated tests +setenv = + TOXINIDIR = {toxinidir} +passenv = + HOME + SETUPTOOLS_* +extras = + testing +commands = + pytest {posargs} + + +# # To run `tox -e lint` you need to make sure you have a +# # `.pre-commit-config.yaml` file. See https://pre-commit.com +# [testenv:lint] +# description = Perform static analysis and style checks +# skip_install = True +# deps = pre-commit +# passenv = +# HOMEPATH +# PROGRAMDATA +# SETUPTOOLS_* +# commands = +# pre-commit run --all-files {posargs:--show-diff-on-failure} + + +[testenv:{build,clean}] +description = + build: Build the package in isolation according to PEP517, see https://github.com/pypa/build + clean: Remove old distribution files and temporary build artifacts (./build and ./dist) +# https://setuptools.pypa.io/en/stable/build_meta.html#how-to-use-it +skip_install = True +changedir = {toxinidir} +deps = + build: build[virtualenv] +passenv = + SETUPTOOLS_* +commands = + clean: python -c 'import shutil; [shutil.rmtree(p, True) for p in ("build", "dist", "docs/_build")]' + clean: python -c 'import pathlib, shutil; [shutil.rmtree(p, True) for p in pathlib.Path("src").glob("*.egg-info")]' + build: python -m build {posargs} +# By default, both `sdist` and `wheel` are built. If your sdist is too big or you don't want +# to make it available, consider running: `tox -e build -- --wheel` + + +[testenv:{docs,doctests,linkcheck}] +description = + docs: Invoke sphinx-build to build the docs + doctests: Invoke sphinx-build to run doctests + linkcheck: Check for broken links in the documentation +passenv = + SETUPTOOLS_* +setenv = + DOCSDIR = {toxinidir}/docs + BUILDDIR = {toxinidir}/docs/_build + docs: BUILD = html + doctests: BUILD = doctest + linkcheck: BUILD = linkcheck +deps = + -r {toxinidir}/docs/requirements.txt + # ^ requirements.txt shared with Read The Docs +commands = + sphinx-build --color -b {env:BUILD} -d "{env:BUILDDIR}/doctrees" "{env:DOCSDIR}" "{env:BUILDDIR}/{env:BUILD}" {posargs} + + +[testenv:publish] +description = + Publish the package you have been developing to a package index server. + By default, it uses testpypi. If you really want to publish your package + to be publicly accessible in PyPI, use the `-- --repository pypi` option. 
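The new test pins a single component of gram_schmidt's output (0.5740 is consistent with the first input component divided by the first vector's norm, suggesting the routine returns normalized vectors). A complementary property-style check, sketched under the assumption that gram_schmidt maps a (B, 3, 3) batch of vectors to an orthonormal basis per batch element, could be:

import torch

from equiadapt.common.utils import gram_schmidt


def test_gram_schmidt_orthonormal() -> None:
    torch.manual_seed(0)
    vectors = torch.randn(2, 3, 3)  # batch of 2, 3 vectors of dimension 3

    basis = gram_schmidt(vectors)

    # Assumed property: each batch element forms an orthonormal basis, so B @ B^T = I.
    identity = torch.eye(3).expand(2, 3, 3)
    assert torch.allclose(basis @ basis.transpose(-1, -2), identity, atol=1e-5)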
+skip_install = True +changedir = {toxinidir} +passenv = + # See: https://twine.readthedocs.io/en/latest/ + TWINE_USERNAME + TWINE_PASSWORD + TWINE_REPOSITORY + TWINE_REPOSITORY_URL +deps = twine +commands = + python -m twine check dist/* + python -m twine upload {posargs:--repository {env:TWINE_REPOSITORY:testpypi}} dist/*