diff --git a/.devcontainer/.vscode/launch.json b/.devcontainer/.vscode/launch.json new file mode 100644 index 0000000000..f682b56388 --- /dev/null +++ b/.devcontainer/.vscode/launch.json @@ -0,0 +1,24 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Current File (just my code)", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": true + }, + { + "name": "Python: Current File (all)", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": false + } + ] +} diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 0000000000..414f2d0292 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,5 @@ +FROM mcr.microsoft.com/devcontainers/python:1-3.10-bookworm +RUN apt-get update \ + && export DEBIAN_FRONTEND=noninteractive && apt-get install -y libboost-dev \ + && apt-get clean && rm -rf /var/cache/apt/* && rm -rf /var/lib/apt/lists/* && rm -rf /tmp/* +RUN curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR="/bin" sh diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000000..7dc4b2f08c --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,49 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/python +{ + "name": "Python 3", + // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile + "build": { + "dockerfile": "Dockerfile" + }, + + // Features to add to the dev container. More info: https://containers.dev/features. 
+ // "features": {}, + + // Use 'forwardPorts' to make a list of ports inside the container available locally. + // "forwardPorts": [], + + // Use 'postCreateCommand' to run commands after the container is created. + "postCreateCommand": "bash .devcontainer/setup.sh", + + "containerEnv": { + "PRE_COMMIT_HOME": "/workspaces/gt4py/.caches/pre-commit" + }, + + // Configure tool-specific properties. + "customizations": { + // Configure properties specific to VS Code. + "vscode": { + // Set *default* container specific settings.json values on container create. + "settings": { + "python.formatting.provider": "ruff", + "python.testing.pytestEnabled": true, + "python.defaultInterpreterPath": "/workspaces/gt4py/.venv/bin/python", + "files.insertFinalNewline": true, + "python.terminal.activateEnvironment": true, + "cmake.ignoreCMakeListsMissing": true + }, + "extensions": [ + "charliermarsh.ruff", + "donjayamanne.githistory", + "github.vscode-github-actions", + "lextudio.restructuredtext", + "ms-python.python", + "ms-vsliveshare.vsliveshare", + "swyddfa.esbonio" + ] + } + } + // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root. + // "remoteUser": "root" +} diff --git a/.devcontainer/setup.sh b/.devcontainer/setup.sh new file mode 100755 index 0000000000..d23dda9dea --- /dev/null +++ b/.devcontainer/setup.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +ln -sfn /workspaces/gt4py/.devcontainer/.vscode /workspaces/gt4py/.vscode +uv venv .venv +source .venv/bin/activate +uv pip install -r requirements-dev.txt +uv pip install -e . 
+uv pip install -i https://test.pypi.org/simple/ atlas4py +pre-commit install --install-hooks +deactivate diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 7284a7df04..83304a9c62 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -15,7 +15,7 @@ Delete this comment and add a proper description of the changes contained in thi - test: Adding missing tests or correcting existing tests : cartesian | eve | next | storage - # ONLY if changes are limited to a specific subsytem + # ONLY if changes are limited to a specific subsystem - PR Description: @@ -27,7 +27,7 @@ Delete this comment and add a proper description of the changes contained in thi ## Requirements - [ ] All fixes and/or new features come with corresponding tests. -- [ ] Important design decisions have been documented in the approriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. +- [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. 
If this PR contains code authored by new contributors please make sure: diff --git a/.github/workflows/_disabled/gt4py-sphinx.yml b/.github/workflows/_disabled/gt4py-sphinx.yml index d862ab7321..2533b2a42d 100644 --- a/.github/workflows/_disabled/gt4py-sphinx.yml +++ b/.github/workflows/_disabled/gt4py-sphinx.yml @@ -4,11 +4,9 @@ on: push: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 concurrency: group: ${{ github.workflow }}-${{ github.ref }} @@ -22,7 +20,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v1 with: - python-version: 3.8 + python-version: "3.10" - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index 2137cd871a..10bb537e3e 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -4,24 +4,26 @@ on: push: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 jobs: code-quality: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 + with: + python-version-file: ".python-version" + + - name: Install uv + uses: astral-sh/setup-uv@v5 with: - python-version: "3.10" - cache: 'pip' - cache-dependency-path: | - **/pyproject.toml - **/constraints.txt - **/requirements-dev.txt - - uses: pre-commit/action@v3.0.0 + enable-cache: true + cache-dependency-glob: "uv.lock" + + - name: "Run pre-commit" + uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/daily-ci.yml b/.github/workflows/daily-ci.yml index 42f96659e0..a2a52ce1ff 100644 --- a/.github/workflows/daily-ci.yml +++ b/.github/workflows/daily-ci.yml @@ -5,7 +5,8 @@ on: - 
cron: '0 4 * * *' workflow_dispatch: - ## COMMENTED OUT: only for testing CI action changes + ## COMMENTED OUT: only for testing CI action changes. + ## It only works for PRs to `main` branch from branches in the upstream gt4py repo. # pull_request: # branches: # - main @@ -15,113 +16,87 @@ jobs: daily-ci: strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] - tox-module-factor: ["cartesian", "eve", "next", "storage"] - os: ["ubuntu-latest"] - requirements-file: ["requirements-dev.txt", "min-requirements-test.txt", "min-extra-requirements-test.txt"] + # dependencies-strategy -> The strategy that `uv lock` should use to select + # between the different compatible versions for a given package requirement + # [arg: --resolution, env: UV_RESOLUTION=] + dependencies-strategy: ["lowest-direct", "highest"] + gt4py-module: ["cartesian", "eve", "next", "storage"] + os: ["ubuntu-latest"] #, "macos-latest"] + python-version: ["3.10", "3.11"] fail-fast: false runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + - name: Install C++ libraries if: ${{ matrix.os == 'macos-latest' }} shell: bash - run: | - brew install boost + run: brew install boost + - name: Install C++ libraries if: ${{ matrix.os == 'ubuntu-latest' }} shell: bash - run: | - sudo apt install libboost-dev - wget https://boostorg.jfrog.io/artifactory/main/release/1.76.0/source/boost_1_76_0.tar.gz - echo 7bd7ddceec1a1dfdcbdb3e609b60d01739c38390a5f956385a12f3122049f0ca boost_1_76_0.tar.gz > boost_hash.txt - sha256sum -c boost_hash.txt - tar xzf boost_1_76_0.tar.gz - mkdir -p boost/include - mv boost_1_76_0/boost boost/include/ - echo "BOOST_ROOT=${PWD}/boost" >> $GITHUB_ENV - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + run: sudo apt install libboost-dev + + - name: Install uv and set the python version + uses: astral-sh/setup-uv@v5 with: + enable-cache: true + cache-dependency-glob: "uv.lock" python-version: ${{ 
matrix.python-version }} - cache: 'pip' - cache-dependency-path: | - **/pyproject.toml - **/constraints.txt - **/requirements-dev.txt - - name: Install tox - run: | - python -m pip install -c ./constraints.txt pip setuptools wheel tox - python -m pip list - - name: Update requirements - run: | - pyversion=${{ matrix.python-version }} - pyversion_no_dot=${pyversion//./} - tox run -e requirements-py${pyversion_no_dot} - # TODO(egparedes): add notification for dependencies updates - # - name: Check for updated requirements - # id: update-requirements - # continue-on-error: true - # if: ${{ matrix.python-version == '3.8' && matrix.tox-module-factor == 'cartesian' }} - # shell: bash - # run: | - # if diff -q constraints.txt CURRENT-constraints.txt; then - # echo "REQS_DIFF=''" >> $GITHUB_OUTPUT - # else - # diff --changed-group-format='%<' --unchanged-group-format='' constraints.txt CURRENT-constraints.txt | tr '\n' ' ' > constraints.txt.diff - # echo "REQS_DIFF='$(cat constraints.txt.diff)'" >> $GITHUB_OUTPUT - # fi - # echo "REQS_DIFF_TEST="FOOOOOOOO" >> $GITHUB_OUTPUT - # - name: Notify updated requirements (if any) - # if: ${{ steps.update-requirements.outputs.REQS_DIFF }} - # env: - # SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} - # uses: slackapi/slack-github-action@v1.23.0 - # with: - # channel-id: ${{ vars.SLACK_BOT_CHANNEL }} - # payload: | - # { - # "text": "TEXT", - # "blocks": [ - # { - # "type": "section", - # "text": { - # "type": "plain_text", - # "text": "@channel: AA/${{ steps.update-requirements.outputs.REQS_DIFF }}/BB/ ${{ steps.update-requirements.outputs.REQS_DIFF_TEST }} /CC" - # } - # }, - # { - # "type": "section", - # "text": { - # "type": "mrkdwn", - # "text": "@channel: AA/${{ steps.update-requirements.outputs.REQS_DIFF }}/BB/ ${{ steps.update-requirements.outputs.REQS_DIFF_TEST }} /CC" - # } - # } - # ] - # } - - name: Run tests + + - name: Run CPU tests for '${{ matrix.gt4py-module }}' with '${{ matrix.dependencies-strategy }}' resolution 
strategy env: NUM_PROCESSES: auto - ENV_REQUIREMENTS_FILE: ${{ matrix.requirements-file }} - run: | - tox run --skip-missing-interpreters -m test-${{ matrix.tox-module-factor }}-cpu + UV_RESOLUTION: ${{ matrix.dependencies-strategy }} + run: uv run nox -s 'test_${{ matrix.gt4py-module }}-${{ matrix.python-version }}' -t 'cpu' + - name: Notify slack if: ${{ failure() }} env: SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} uses: slackapi/slack-github-action@v1.23.0 with: - channel-id: ${{ vars.SLACK_BOT_CHANNEL }} + channel-id: ${{ vars.SLACK_BOT_CHANNEL }} # Use SLACK_BOT_CHANNEL_TEST for testing + payload: | + { + "text": "Failed tests for ${{ github.workflow }} (dependencies-strategy=${{ matrix.dependencies-strategy }}, python=${{ matrix.python-version }}, component=${{ matrix.gt4py-module }}) [https://github.com/GridTools/gt4py/actions/runs/${{ github.run_id }}].", + "blocks": [ + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": "Failed tests: " + } + } + ] + } + + weekly-reminder: + runs-on: ubuntu-latest + steps: + - id: get_day_of_the_week + name: Get day of the week + run: echo "day_of_week=$(date +'%u')" >> $GITHUB_OUTPUT + + - name: Weekly notification + if: ${{ env.DAY_OF_WEEK == 1 }} + env: + DAY_OF_WEEK: ${{ steps.get_day_of_the_week.outputs.day_of_week }} + SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} + uses: slackapi/slack-github-action@v1.23.0 + with: + channel-id: ${{ vars.SLACK_BOT_CHANNEL }} # Use SLACK_BOT_CHANNEL_TEST for testing payload: | { - "text": "${{ github.workflow }}: `test-${{ matrix.tox-module-factor }}-cpu (python${{ matrix.python-version }})`>: *Failed tests!*", + "text": "Weekly reminder to check the latest runs of the GT4Py Daily CI workflow at the GitHub Actions dashboard [https://github.com/GridTools/gt4py/actions/workflows/daily-ci.yml].", "blocks": [ { "type": "section", "text": { "type": "mrkdwn", - "text": ": *Failed tests!*" + "text": "Weekly reminder to check the latest runs of the workflow at the GitHub 
Actions dashboard." } } ] diff --git a/.github/workflows/deploy-release.yml b/.github/workflows/deploy-release.yml index 048a6f73e1..7a7505caa5 100644 --- a/.github/workflows/deploy-release.yml +++ b/.github/workflows/deploy-release.yml @@ -14,9 +14,9 @@ jobs: name: Build Python distribution runs-on: ubuntu-latest steps: - - uses: actions/checkout@master + - uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v3 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install pypa/build @@ -26,7 +26,7 @@ jobs: run: | python -m build --sdist --wheel --outdir dist/ - name: Upload artifact - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: gt4py-dist path: ./dist/** @@ -42,7 +42,7 @@ jobs: id-token: write steps: - name: Download wheel - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: gt4py-dist path: dist @@ -60,7 +60,7 @@ jobs: id-token: write steps: - name: Download wheel - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: name: gt4py-dist path: dist diff --git a/.github/workflows/test-cartesian-fallback.yml b/.github/workflows/test-cartesian-fallback.yml index 45bbdf271a..a846af2e7b 100644 --- a/.github/workflows/test-cartesian-fallback.yml +++ b/.github/workflows/test-cartesian-fallback.yml @@ -4,7 +4,6 @@ on: pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 paths: # Inverse of corresponding workflow - "src/gt4py/next/**" - "tests/next_tests/**" @@ -14,11 +13,11 @@ on: jobs: test-cartesian: - runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] - tox-factor: [internal, dace] - + codegen-factor: [internal, dace] + os: ["ubuntu-latest"] + python-version: ["3.10", "3.11"] + runs-on: ${{ matrix.os }} steps: - run: 'echo "No build required"' diff --git a/.github/workflows/test-cartesian.yml b/.github/workflows/test-cartesian.yml index 
5d23577bc9..ea6b7940a3 100644 --- a/.github/workflows/test-cartesian.yml +++ b/.github/workflows/test-cartesian.yml @@ -4,12 +4,10 @@ on: push: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 - paths-ignore: # Skip if only gt4py.next and irrelevant doc files have been updated + paths-ignore: # Skip when only gt4py.next or doc files have been updated - "src/gt4py/next/**" - "tests/next_tests/**" - "examples/**" @@ -22,41 +20,36 @@ concurrency: jobs: test-cartesian: - runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] - tox-factor: [internal, dace] + codegen-factor: [internal, dace] + os: ["ubuntu-latest"] + python-version: ["3.10", "3.11"] + fail-fast: false + + runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v2 - - name: Install boost + - uses: actions/checkout@v4 + + - name: Install C++ libraries + if: ${{ matrix.os == 'macos-latest' }} + shell: bash + run: brew install boost + + - name: Install C++ libraries + if: ${{ matrix.os == 'ubuntu-latest' }} shell: bash - run: | - wget https://boostorg.jfrog.io/artifactory/main/release/1.76.0/source/boost_1_76_0.tar.gz - echo 7bd7ddceec1a1dfdcbdb3e609b60d01739c38390a5f956385a12f3122049f0ca boost_1_76_0.tar.gz > boost_hash.txt - sha256sum -c boost_hash.txt - tar xzf boost_1_76_0.tar.gz - mkdir -p boost/include - mv boost_1_76_0/boost boost/include/ - echo "BOOST_ROOT=${PWD}/boost" >> $GITHUB_ENV - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + run: sudo apt install libboost-dev + + - name: Install uv and set the python version + uses: astral-sh/setup-uv@v5 with: + enable-cache: true + cache-dependency-glob: "uv.lock" python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: | - **/pyproject.toml - **/constraints.txt - **/requirements-dev.txt - - name: Install python 
dependencies - run: | - python -m pip install -c ./constraints.txt pip setuptools wheel - python -m pip install -r ./requirements-dev.txt - - name: Test with tox + + - name: Run CPU 'cartesian' tests with nox env: NUM_PROCESSES: auto shell: bash - run: | - pyversion=${{ matrix.python-version }} - pyversion_no_dot=${pyversion//./} - tox run -e cartesian-py${pyversion_no_dot}-${{ matrix.tox-factor }}-cpu + run: uv run nox -s 'test_cartesian-${{ matrix.python-version }}(${{ matrix.codegen-factor }}, cpu)' diff --git a/.github/workflows/test-eve-fallback.yml b/.github/workflows/test-eve-fallback.yml index 661118e71d..f3dbb58acf 100644 --- a/.github/workflows/test-eve-fallback.yml +++ b/.github/workflows/test-eve-fallback.yml @@ -1,16 +1,17 @@ name: "Fallback: Test Eve" on: + push: + branches: + - main pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 paths-ignore: # Inverse of corresponding workflow - "src/gt4py/eve/**" - "tests/eve_tests/**" - - "workflows/**" - - "*.cfg" - - "*.ini" + - ".github/workflows/**" + - "*.lock" - "*.toml" - "*.yml" @@ -18,8 +19,8 @@ jobs: test-eve: strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] os: ["ubuntu-latest"] + python-version: ["3.10", "3.11"] runs-on: ${{ matrix.os }} steps: diff --git a/.github/workflows/test-eve.yml b/.github/workflows/test-eve.yml index 061f7cd484..aad3971ad0 100644 --- a/.github/workflows/test-eve.yml +++ b/.github/workflows/test-eve.yml @@ -4,17 +4,14 @@ on: push: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 paths: # Run when gt4py.eve files (or package settings) are changed - "src/gt4py/eve/**" - "tests/eve_tests/**" - - "workflows/**" - - "*.cfg" - - "*.ini" + - ".github/workflows/**" + - "*.lock" - "*.toml" - "*.yml" @@ -22,51 +19,23 @@ jobs: test-eve: strategy: matrix: - python-version: ["3.8", "3.9", 
"3.10", "3.11"] os: ["ubuntu-latest"] + python-version: ["3.10", "3.11"] fail-fast: false runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + + - name: Install uv and set the python version + uses: astral-sh/setup-uv@v5 with: + enable-cache: true + cache-dependency-glob: "uv.lock" python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: | - **/pyproject.toml - **/constraints.txt - **/requirements-dev.txt - - name: Install python dependencies - run: | - python -m pip install -c ./constraints.txt pip setuptools wheel - python -m pip install -r ./requirements-dev.txt - - name: Run tox tests + + - name: Run 'eve' tests with nox env: NUM_PROCESSES: auto shell: bash - run: | - pyversion=${{ matrix.python-version }} - pyversion_no_dot=${pyversion//./} - tox run -e eve-py${pyversion_no_dot} - # mv coverage.json coverage-py${{ matrix.python-version }}-${{ matrix.os }}.json - # - name: Upload coverage.json artifact - # uses: actions/upload-artifact@v3 - # with: - # name: coverage-py${{ matrix.python-version }}-${{ matrix.os }} - # path: coverage-py${{ matrix.python-version }}-${{ matrix.os }}.json - # - name: Gather info - # run: | - # echo ${{ github.ref_type }} >> info.txt - # echo ${{ github.ref }} >> info.txt - # echo ${{ github.sha }} >> info.txt - # echo ${{ github.event.number }} >> info.txt - # echo ${{ github.event.pull_request.head.ref }} >> info.txt - # echo ${{ github.event.pull_request.head.sha }} >> info.txt - # echo ${{ github.run_id }} >> info.txt - # - name: Upload info artifact - # uses: actions/upload-artifact@v3 - # with: - # name: info-py${{ matrix.python-version }}-${{ matrix.os }} - # path: info.txt + run: uv run nox -s test_eve-${{ matrix.python-version }} diff --git a/.github/workflows/test-examples.yml b/.github/workflows/test-examples.yml new file mode 100644 index 0000000000..836af45dd1 --- 
/dev/null +++ b/.github/workflows/test-examples.yml @@ -0,0 +1,44 @@ +name: "Test examples in documentation" + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + test-notebooks: + strategy: + matrix: + os: ["ubuntu-latest"] + python-version: ["3.10", "3.11"] + fail-fast: false + + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + + - name: Install C++ libraries + if: ${{ matrix.os == 'macos-latest' }} + shell: bash + run: brew install boost + + - name: Install C++ libraries + if: ${{ matrix.os == 'ubuntu-latest' }} + shell: bash + run: sudo apt install libboost-dev + + - name: Install uv and set the python version + uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + python-version: ${{ matrix.python-version }} + + - name: Run 'docs' nox session + env: + NUM_PROCESSES: auto + shell: bash + run: uv run nox -s 'test_examples-${{ matrix.python-version }}' diff --git a/.github/workflows/test-next-fallback.yml b/.github/workflows/test-next-fallback.yml index b8c39dc0e6..ef8be3df5f 100644 --- a/.github/workflows/test-next-fallback.yml +++ b/.github/workflows/test-next-fallback.yml @@ -4,7 +4,6 @@ on: pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 paths: # Inverse of corresponding workflow - "src/gt4py/cartesian/**" - "tests/cartesian_tests/**" @@ -16,9 +15,10 @@ jobs: test-next: strategy: matrix: - python-version: ["3.10", "3.11"] - tox-factor: ["nomesh", "atlas"] + codegen-factor: [internal, dace] + mesh-factor: [nomesh, atlas] os: ["ubuntu-latest"] + python-version: ["3.10", "3.11"] runs-on: ${{ matrix.os }} steps: diff --git a/.github/workflows/test-next.yml b/.github/workflows/test-next.yml index 8e05bbc86a..068377c6c7 100644 --- a/.github/workflows/test-next.yml +++ b/.github/workflows/test-next.yml @@ -4,12 +4,10 @@ on: push: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 pull_request: 
branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 - paths-ignore: # Skip if only gt4py.cartesian and irrelevant doc files have been updated + paths-ignore: # Skip when only gt4py.cartesian or doc files have been updated - "src/gt4py/cartesian/**" - "tests/cartesian_tests/**" - "examples/**" @@ -20,63 +18,35 @@ jobs: test-next: strategy: matrix: - python-version: ["3.10", "3.11"] - tox-factor: ["nomesh", "atlas"] + codegen-factor: [internal, dace] + mesh-factor: [nomesh, atlas] os: ["ubuntu-latest"] + python-version: ["3.10", "3.11"] fail-fast: false runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 + - name: Install C++ libraries if: ${{ matrix.os == 'macos-latest' }} shell: bash - run: | - brew install boost + run: brew install boost + - name: Install C++ libraries if: ${{ matrix.os == 'ubuntu-latest' }} shell: bash - run: | - sudo apt install libboost-dev - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + run: sudo apt install libboost-dev + + - name: Install uv and set the python version + uses: astral-sh/setup-uv@v5 with: + enable-cache: true + cache-dependency-glob: "uv.lock" python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: | - **/pyproject.toml - **/constraints.txt - **/requirements-dev.txt - - name: Install python dependencies - shell: bash - run: | - python -m pip install -c ./constraints.txt pip setuptools wheel - python -m pip install -r ./requirements-dev.txt - - name: Run tox tests + + - name: Run CPU 'next' tests with nox env: NUM_PROCESSES: auto shell: bash - run: | - pyversion=${{ matrix.python-version }} - pyversion_no_dot=${pyversion//./} - tox run -e next-py${pyversion_no_dot}-${{ matrix.tox-factor }}-cpu - # mv coverage.json coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu.json - # - name: Upload coverage.json artifact - # uses: actions/upload-artifact@v3 - # 
with: - # name: coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu - # path: coverage-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu.json - # - name: Gather info - # run: | - # echo ${{ github.ref_type }} >> info.txt - # echo ${{ github.ref }} >> info.txt - # echo ${{ github.sha }} >> info.txt - # echo ${{ github.event.number }} >> info.txt - # echo ${{ github.event.pull_request.head.ref }} >> info.txt - # echo ${{ github.event.pull_request.head.sha }} >> info.txt - # echo ${{ github.run_id }} >> info.txt - # - name: Upload info artifact - # uses: actions/upload-artifact@v3 - # with: - # name: info-py${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.tox-env-factor }}-cpu - # path: info.txt + run: uv run nox -s 'test_next-${{ matrix.python-version }}(${{ matrix.codegen-factor }}, cpu, ${{ matrix.mesh-factor }})' diff --git a/.github/workflows/test-notebooks.yml b/.github/workflows/test-notebooks.yml deleted file mode 100644 index 39298b5427..0000000000 --- a/.github/workflows/test-notebooks.yml +++ /dev/null @@ -1,44 +0,0 @@ -name: "Test Jupyter Notebooks" - -on: - push: - branches: - - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 - pull_request: - branches: - - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 - -jobs: - test-notebooks: - strategy: - matrix: - python-version: ["3.10", "3.11"] - os: ["ubuntu-latest"] - fail-fast: false - - runs-on: ${{ matrix.os }} - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: | - **/pyproject.toml - **/constraints.txt - **/requirements-dev.txt - - name: Install python dependencies - run: | - python -m pip install -c ./constraints.txt pip setuptools wheel - python -m pip install -r ./requirements-dev.txt - - name: Run tox tests - env: - 
NUM_PROCESSES: auto - shell: bash - run: | - pyversion=${{ matrix.python-version }} - pyversion_no_dot=${pyversion//./} - tox run -e notebooks-py${pyversion_no_dot} diff --git a/.github/workflows/test-storage-fallback.yml b/.github/workflows/test-storage-fallback.yml index df861c6468..c913529a1c 100644 --- a/.github/workflows/test-storage-fallback.yml +++ b/.github/workflows/test-storage-fallback.yml @@ -1,17 +1,18 @@ name: "Fallback: Test Storage (CPU)" on: + push: + branches: + - main pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 paths-ignore: # Inverse of corresponding workflow - "src/gt4py/storage/**" - "src/gt4py/cartesian/backend/**" # For DaCe storages - "tests/storage_tests/**" - - "workflows/**" - - "*.cfg" - - "*.ini" + - ".github/workflows/**" + - "*.lock" - "*.toml" - "*.yml" @@ -19,9 +20,8 @@ jobs: test-storage: strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] - tox-factor: [internal, dace] os: ["ubuntu-latest"] + python-version: ["3.10", "3.11"] runs-on: ${{ matrix.os }} steps: diff --git a/.github/workflows/test-storage.yml b/.github/workflows/test-storage.yml index e76526c296..b2bb09dfcc 100644 --- a/.github/workflows/test-storage.yml +++ b/.github/workflows/test-storage.yml @@ -4,18 +4,15 @@ on: push: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 pull_request: branches: - main - - gtir # TODO(tehrengruber): remove after GTIR refactoring #1582 paths: # Run when gt4py.storage files (or package settings) are changed - "src/gt4py/storage/**" - "src/gt4py/cartesian/backend/**" # For DaCe storages - "tests/storage_tests/**" - - "workflows/**" - - "*.cfg" - - "*.ini" + - ".github/workflows/**" + - "*.lock" - "*.toml" - "*.yml" @@ -23,52 +20,23 @@ jobs: test-storage: strategy: matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] - tox-factor: [internal, dace] os: ["ubuntu-latest"] + python-version: ["3.10", "3.11"] fail-fast: false runs-on: ${{ 
matrix.os }} steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + + - name: Install uv and set the python version + uses: astral-sh/setup-uv@v5 with: + enable-cache: true + cache-dependency-glob: "uv.lock" python-version: ${{ matrix.python-version }} - cache: 'pip' - cache-dependency-path: | - **/pyproject.toml - **/constraints.txt - **/requirements-dev.txt - - name: Install python dependencies - run: | - python -m pip install -c ./constraints.txt pip setuptools wheel - python -m pip install -r ./requirements-dev.txt - - name: Run tox tests + + - name: Run CPU 'storage' tests with nox env: NUM_PROCESSES: auto shell: bash - run: | - pyversion=${{ matrix.python-version }} - pyversion_no_dot=${pyversion//./} - tox run -e storage-py${pyversion_no_dot}-${{ matrix.tox-factor }}-cpu - # mv coverage.json coverage-py${{ matrix.python-version }}-${{ matrix.os }}.json - # - name: Upload coverage.json artifact - # uses: actions/upload-artifact@v3 - # with: - # name: coverage-py${{ matrix.python-version }}-${{ matrix.os }} - # path: coverage-py${{ matrix.python-version }}-${{ matrix.os }}.json - # - name: Gather info - # run: | - # echo ${{ github.ref_type }} >> info.txt - # echo ${{ github.ref }} >> info.txt - # echo ${{ github.sha }} >> info.txt - # echo ${{ github.event.number }} >> info.txt - # echo ${{ github.event.pull_request.head.ref }} >> info.txt - # echo ${{ github.event.pull_request.head.sha }} >> info.txt - # echo ${{ github.run_id }} >> info.txt - # - name: Upload info artifact - # uses: actions/upload-artifact@v3 - # with: - # name: info-py${{ matrix.python-version }}-${{ matrix.os }} - # path: info.txt + run: uv run nox -s 'test_storage-${{ matrix.python-version }}(cpu)' diff --git a/.gitignore b/.gitignore index 5792b8a9b7..ebbbfaebeb 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ _local /src/__init__.py /tests/__init__.py .gt_cache/ +.gt4py_cache/ 
.gt_cache_pytest*/ # DaCe @@ -159,5 +160,5 @@ venv.bak/ ### Others ### .obsidian - coverage.json +.caches diff --git a/.gitpod.Dockerfile b/.gitpod.Dockerfile index 967ae36f2e..5d02a0f436 100644 --- a/.gitpod.Dockerfile +++ b/.gitpod.Dockerfile @@ -1,8 +1,6 @@ -FROM gitpod/workspace-python +FROM gitpod/workspace-python-3.11 USER root RUN apt-get update \ && apt-get install -y libboost-dev \ && apt-get clean && rm -rf /var/cache/apt/* && rm -rf /var/lib/apt/lists/* && rm -rf /tmp/* USER gitpod -RUN pyenv install 3.10.2 -RUN pyenv global 3.10.2 diff --git a/.gitpod/.vscode/launch.json b/.gitpod/.vscode/launch.json index f682b56388..b25a182648 100644 --- a/.gitpod/.vscode/launch.json +++ b/.gitpod/.vscode/launch.json @@ -6,7 +6,7 @@ "configurations": [ { "name": "Python: Current File (just my code)", - "type": "python", + "type": "debugpy", "request": "launch", "program": "${file}", "console": "integratedTerminal", @@ -14,11 +14,20 @@ }, { "name": "Python: Current File (all)", - "type": "python", + "type": "debugpy", "request": "launch", "program": "${file}", "console": "integratedTerminal", "justMyCode": false + }, + { + "name": "Python: Debug Tests", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "purpose": ["debug-test"], + "console": "integratedTerminal", + "justMyCode": true } ] } diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5e0314bca3..173997849a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,21 +1,23 @@ -# ----------------------------------------------------------------------- -# This file contains 'cog' snippets (https://nedbatchelder.com/code/cog/) -# to keep version numbers in sync with 'constraints.txt' -# ----------------------------------------------------------------------- - default_language_version: python: python3.10 +minimum_pre_commit_version: 3.8.0 repos: # - repo: meta # hooks: # - id: check-hooks-apply # - id: check-useless-excludes + +- repo: 
https://github.com/astral-sh/uv-pre-commit + # uv version. + rev: 0.5.25 + hooks: + - id: uv-lock + - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks rev: v2.6.0 hooks: - id: pretty-format-ini args: [--autofix] - exclude: tox.ini - id: pretty-format-toml args: [--autofix] exclude: tach.toml @@ -43,89 +45,27 @@ repos: - id: check-merge-conflict - id: check-toml - id: check-yaml - - id: debug-statements - repo: https://github.com/astral-sh/ruff-pre-commit - ##[[[cog - ## import re - ## version = re.search('ruff==([0-9\.]*)', open("constraints.txt").read())[1] - ## print(f"rev: v{version}") - ##]]] - rev: v0.6.4 - ##[[[end]]] + rev: v0.8.6 hooks: - # Run the linter. - # TODO: include tests here - id: ruff - files: ^src/ + files: ^src/ # TODO(egparedes): also add the `tests` folder here args: [--fix] - # Run the formatter. - id: ruff-format - repo: https://github.com/gauge-sh/tach-pre-commit - rev: v0.10.7 + rev: v0.23.0 hooks: - id: tach -- repo: https://github.com/pre-commit/mirrors-mypy - ##[[[cog - ## import re - ## version = re.search('mypy==([0-9\.]*)', open("constraints.txt").read())[1] - ## print(f"#========= FROM constraints.txt: v{version} =========") - ##]]] - #========= FROM constraints.txt: v1.11.2 ========= - ##[[[end]]] - rev: v1.11.2 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) +- repo: local hooks: - id: mypy - additional_dependencies: # versions from constraints.txt - ##[[[cog - ## import re, sys - ## if sys.version_info >= (3, 11): - ## import tomllib - ## else: - ## import tomli as tomllib - ## constraints = open("constraints.txt").read() - ## project = tomllib.loads(open("pyproject.toml").read()) - ## packages = [re.match('^([\w-][\w\d-]*)', r)[1] for r in project["project"]["dependencies"] if r.strip()] - ## for pkg in packages: - ## print(f"- {pkg}==" + str(re.search(f'\n{pkg}==([0-9\.]*)', constraints)[1])) - ##]]] - - astunparse==1.6.3 - - attrs==24.2.0 - - black==24.8.0 - - 
boltons==24.0.0 - - cached-property==1.5.2 - - click==8.1.7 - - cmake==3.30.3 - - cytoolz==0.12.3 - - deepdiff==8.0.1 - - devtools==0.12.2 - - factory-boy==3.3.1 - - frozendict==2.4.4 - - gridtools-cpp==2.3.4 - - importlib-resources==6.4.5 - - jinja2==3.1.4 - - lark==1.2.2 - - mako==1.3.5 - - nanobind==2.1.0 - - ninja==1.11.1.1 - - numpy==1.24.4 - - packaging==24.1 - - pybind11==2.13.5 - - setuptools==74.1.2 - - tabulate==0.9.0 - - typing-extensions==4.12.2 - - xxhash==3.0.0 - ##[[[end]]] - - types-tabulate - - types-typed-ast - args: [--no-install-types] - exclude: | - (?x)^( - setup.py | - build/.* | - ci/.* | - docs/.* | - tests/.* - )$ + name: mypy static type checker + entry: uv run --frozen mypy --no-install-types src/ + language: system + types_or: [python, pyi] + pass_filenames: false + require_serial: true + stages: [pre-commit] diff --git a/.python-version b/.python-version new file mode 100644 index 0000000000..c8cfe39591 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.10 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 15e139a53e..28134a61b9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -55,7 +55,7 @@ Ready to start contributing? We use a [fork and pull request](https://www.atlass 3. Follow instructions in the [README.md](README.md) file to set up an environment for local development. For example: ```bash - $ tox --devenv .venv + $ uv sync --extra all $ source .venv/bin/activate ``` @@ -67,11 +67,11 @@ Ready to start contributing? We use a [fork and pull request](https://www.atlass Now you can make your changes locally. Make sure you follow the project code style documented in [CODING_GUIDELINES.md](CODING_GUIDELINES.md). -5. When you're done making changes, check that your code complies with the project code style and other quality assurance (QA) practices using `pre-commit`. Additionally, make sure that unit and regression tests pass for all supported Python versions by running `tox`: +5. 
When you're done making changes, check that your code complies with the project code style and other quality assurance (QA) practices using `pre-commit`. Additionally, make sure that unit and regression tests pass for all supported Python versions by running `nox`: ```bash $ pre-commit run - $ tox + $ nox ``` Read [Testing](#testing) section below for further details. @@ -143,21 +143,21 @@ pytest -v -l -s tests/ Check `pytest` documentation (`pytest --help`) for all the options to select and execute tests. -We recommended you to use `tox` for most development-related tasks, like running the complete test suite in different environments. `tox` runs the package installation script in properly isolated environments to run tests (or other tasks) in a reproducible way. A simple way to start with tox could be: +We recommend you to use `nox` for running the test suite in different environments. `nox` runs the package installation script in properly isolated environments to run tests in a reproducible way. A simple way to start with `nox` would be: ```bash -# List all the available task environments -tox list +# List all available sessions +nox --list -# Run a specific task environment -tox run -e cartesian-py38-internal-cpu +# Run a specific session +nox -s "test_cartesian-3.10(internal, cpu)" ``` -Check `tox` documentation (`tox --help`) for the complete reference. +Check `nox` documentation (`nox --help`) for the complete reference. +Additionally, `nox` is configured to generate HTML test coverage reports in `tests/_reports/coverage_html/` at the end. --> ## Pull Requests (PRs) and Merge Guidelines @@ -175,27 +175,29 @@ Before submitting a pull request, check that it meets the following criteria: As mentioned above, we use several tools to help us write high-quality code. New tools could be added in the future, especially if they do not add a large overhead to our workflow and they bring extra benefits to keep our codebase in shape.
The most important ones which we currently rely on are: -- [ruff][ruff] for style enforcement and code linting. +- [nox][nox] for testing and task automation with different environments. - [pre-commit][pre-commit] for automating the execution of QA tools. - [pytest][pytest] for writing readable tests, extended with: - [Coverage.py][coverage] and [pytest-cov][pytest-cov] for test coverage reports. - [pytest-xdist][pytest-xdist] for running tests in parallel. -- [tox][tox] for testing and task automation with different environments. +- [ruff][ruff] for style enforcement and code linting. - [sphinx][sphinx] for generating documentation, extended with: - [sphinx-autodoc][sphinx-autodoc] and [sphinx-napoleon][sphinx-napoleon] for extracting API documentation from docstrings. - [jupytext][jupytext] for writing new user documentation with code examples. +- [uv][uv] for managing dependencies and environments. [conventional-commits]: https://www.conventionalcommits.org/en/v1.0.0/#summary [coverage]: https://coverage.readthedocs.io/ -[ruff]: https://astral.sh/ruff [jupytext]: https://jupytext.readthedocs.io/ +[nox]: https://nox.thea.codes/en/stable/ [pre-commit]: https://pre-commit.com/ [pytest]: https://docs.pytest.org/ [pytest-cov]: https://pypi.org/project/pytest-cov/ [pytest-xdist]: https://pytest-xdist.readthedocs.io/en/latest/ +[ruff]: https://astral.sh/ruff [sphinx]: https://www.sphinx-doc.org [sphinx-autodoc]: https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html [sphinx-napoleon]: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/index.html -[tox]: https://tox.wiki/en/latest/ +[uv]: https://docs.astral.sh/uv/ diff --git a/LICENSE.txt b/LICENSE similarity index 100% rename from LICENSE.txt rename to LICENSE diff --git a/README.md b/README.md index b782e20f63..f778c4f54b 100644 --- a/README.md +++ b/README.md @@ -10,11 +10,14 @@ ![test-eve](https://github.com/GridTools/gt4py/actions/workflows/test-eve.yml/badge.svg?branch=main) 
![qa](https://github.com/GridTools/gt4py/actions/workflows/code-quality.yml/badge.svg?branch=main) +[![uv](https://img.shields.io/badge/-uv-261230.svg?logo=uv)](https://github.com/astral-sh/uv) +[![Nox](https://img.shields.io/badge/%F0%9F%A6%8A-Nox-D85E00.svg)](https://github.com/wntrblm/nox) + # GT4Py: GridTools for Python GT4Py is a Python library for generating high performance implementations of stencil kernels from a high-level definition using regular Python functions. GT4Py is part of the GridTools framework, a set of libraries and utilities to develop performance portable applications in the area of weather and climate modeling. -**NOTE:** The `gt4py.next` subpackage contains a new version of GT4Py which is not compatible with the current _stable_ version defined in `gt4py.cartesian`. The new version is highly experimental, it only works with unstructured meshes and it requires `python >= 3.10`. +**NOTE:** The `gt4py.next` subpackage contains a new version of GT4Py which is not compatible with the current _stable_ version defined in `gt4py.cartesian`. The new version is still experimental. ## 📃 Description @@ -36,18 +39,18 @@ The following backends are supported: ## 🚜 Installation -GT4Py can be installed as a regular Python package using `pip` (or any other PEP-517 frontend). As usual, we strongly recommended to create a new virtual environment to work on this project. +GT4Py can be installed as a regular Python package using [uv](https://docs.astral.sh/uv/), [pip](https://pip.pypa.io/en/stable/) or any other PEP-517 compatible frontend. We strongly recommend to use `uv` to create and manage virtual environments for your own projects. The performance backends also require the [Boost](https://www.boost.org) library, a dependency of [GridTools C++](https://github.com/GridTools/gridtools), which needs to be installed by the user.
## ⚙ Configuration -If GridTools or Boost are not found in the compiler's standard include path, or a custom version is desired, then a couple configuration environment variables will allow the compiler to use them: +To explicitly set the [GridTools-C++](https://gridtools.github.io/gridtools) or [Boost](https://www.boost.org) versions used by the code generation backends, the following environment variables can be used: - `GT_INCLUDE_PATH`: Path to the GridTools installation. - `BOOST_ROOT`: Path to a boost installation. -Other commonly used environment variables are: +Other useful available environment variables are: - `CUDA_ARCH`: Set the compute capability of the NVIDIA GPU if it is not detected automatically by `cupy`. - `CXX`: Set the C++ compiler. @@ -56,67 +59,68 @@ Other commonly used environment variables are: More options and details are available in [`config.py`](https://github.com/GridTools/gt4py/blob/main/src/gt4py/cartesian/config.py). -## 📖 Documentation +## 🛠 Development Instructions -GT4Py uses Sphinx documentation. To build the documentation install the dependencies in `requirements-dev.txt` +Follow the installation instructions below to initialize a development virtual environment containing an _editable_ installation of the GT4Py package. Make sure you read the [CONTRIBUTING.md](CONTRIBUTING.md) and [CODING_GUIDELINES.md](CODING_GUIDELINES.md) documents before you start working on the project. -```bash -pip install -r ./gt4py/requirements-dev.txt -``` +### Development Environment Installation using `uv` -and then build the docs with +GT4Py uses the [`uv`](https://docs.astral.sh/uv/) project manager for the development workflow. `uv` is a versatile tool that consolidates functionality usually distributed across different applications into subcommands. 
-```bash -cd gt4py/docs/user/cartesian -make html # run 'make help' for a list of targets -``` +- The `uv pip` subcommand provides a _fast_ Python package manager, emulating [`pip`](https://pip.pypa.io/en/stable/). +- The `uv export | lock | sync` subcommands manage dependency versions in a manner similar to the [`pip-tools`](https://pip-tools.readthedocs.io/en/stable/) command suite. +- The `uv init | add | remove | build | publish | ...` subcommands facilitate project development workflows, akin to [`hatch`](https://hatch.pypa.io/latest/). +- The `uv tool` subcommand serves as a runner for Python applications in isolation, similar to [`pipx`](https://pipx.pypa.io/stable/). +- The `uv python` subcommands manage different Python installations and versions, much like [`pyenv`](https://github.com/pyenv/pyenv). -## 🛠 Development Instructions +`uv` can be installed in various ways (see its [installation instructions](https://docs.astral.sh/uv/getting-started/installation/)), with the recommended method being the standalone installer: -Follow the installation instructions below to initialize a development virtual environment containing an _editable_ installation of the GT4Py package. Make sure you read the [CONTRIBUTING.md](CONTRIBUTING.md) and [CODING_GUIDELINES.md](CODING_GUIDELINES.md) documents before you start working on the project. - -### Recommended Installation using `tox` +```bash +$ curl -LsSf https://astral.sh/uv/install.sh | sh +``` -If [tox](https://tox.wiki/en/latest/) is already installed in your system (`tox` is available in PyPI and many other package managers), the easiest way to create a virtual environment ready for development is: +Once `uv` is installed in your system, it is enough to clone this repository and let `uv` handle the installation of the development environment.
```bash # Clone the repository git clone https://github.com/gridtools/gt4py.git cd gt4py -# Create the development environment in any location (usually `.venv`) -# selecting one of the following templates: -# dev-py310 -> base environment -# dev-py310-atlas -> base environment + atlas4py bindings -tox devenv -e dev-py310 .venv +# Let uv create the development environment at `.venv`. +# The `--extra all` option tells uv to install all the optional +# dependencies of gt4py, and thus it is not strictly necessary. +# Note that if no dependency groups are provided as an option, +# uv uses `--group dev` by default so the development dependencies +# are installed. +uv sync --extra all -# Finally, activate the environment +# Finally, activate the virtual environment and start writing code! source .venv/bin/activate ``` -### Manual Installation +The newly created _venv_ is a standard Python virtual environment preconfigured with all necessary runtime and development dependencies. Additionally, the `gt4py` package is installed in editable mode, allowing for seamless development and testing. To install new packages in this environment, use the `uv pip` subcommand which emulates the `pip` interface and is generally much faster than the original `pip` tool (which is also available within the venv although its use is discouraged). -Alternatively, a development environment can be created from scratch installing the frozen dependencies packages : +The `pyproject.toml` file contains both the definition of the `gt4py` Python distribution package and the settings of the development tools used in this project, most notably `uv`, `ruff`, and `mypy`. It also contains _dependency groups_ (see [PEP 735](https://peps.python.org/pep-0735/) for further reference) with the development requirements listed in different groups (`build`, `docs`, `lint`, `test`, `typing`, ...) and collected together in the general `dev` group, which gets installed by default by `uv` as mentioned above. 
-```bash -# Clone the repository -git clone https://github.com/gridtools/gt4py.git -cd gt4py +### Development Tasks (`dev-tasks.py`) -# Create a (Python 3.10) virtual environment (usually at `.venv`) -python3.10 -m venv .venv +Recurrent development tasks like bumping versions of used development tools or required third party dependencies have been collected as different subcommands in the [`dev-tasks.py`](./dev-tasks.py) script. Read the tool help for a brief description of every task and always use this tool to update the versions and sync the version configuration across different files (e.g. `pyproject.toml` and `.pre-commit-config.yaml`). -# Activate the virtual environment and update basic packages -source .venv/bin/activate -pip install --upgrade wheel setuptools pip +## 📖 Documentation + +GT4Py uses the Sphinx tool for the documentation. To build browseable HTML documentation, install the required tools provided in the `docs` dependency group: -# Install the required development tools -pip install -r requirements-dev.txt -# Install GT4Py project in editable mode -pip install -e . +```bash +uv sync --group docs --extra all # or --group dev +``` + +(Note that most likely these tools are already installed in your development environment, since the `docs` group is included in the `dev` group, which is installed by default by `uv sync` if no dependency groups are specified.)
+ +Once the requirements are already installed, then build the docs using: -# Optionally, install atlas4py bindings directly from the repo -# pip install git+https://github.com/GridTools/atlas4py#egg=atlas4py +```bash +cd gt4py/docs/user/cartesian +make html # run 'make help' for a list of targets ``` ## ⚖️ License diff --git a/ci/base.Dockerfile b/ci/base.Dockerfile index d20d9ca6ef..1ad9aefa03 100644 --- a/ci/base.Dockerfile +++ b/ci/base.Dockerfile @@ -1,5 +1,6 @@ -ARG CUDA_VERSION=12.5.0 -FROM docker.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 +ARG CUDA_VERSION=12.6.2 +ARG UBUNTU_VERSION=22.04 +FROM docker.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} ENV LANG C.UTF-8 ENV LC_ALL C.UTF-8 @@ -22,7 +23,7 @@ RUN apt-get update -qq && apt-get install -qq -y --no-install-recommends \ tk-dev \ libffi-dev \ liblzma-dev \ - python-openssl \ + $( [ "${UBUNTU_VERSION}" = "20.04" ] && echo "python-openssl" || echo "python3-openssl" ) \ libreadline-dev \ git \ rustc \ @@ -55,4 +56,5 @@ RUN pyenv update && \ ENV PATH="/root/.pyenv/shims:${PATH}" ARG CUPY_PACKAGE=cupy-cuda12x -RUN pip install --upgrade pip setuptools wheel tox ${CUPY_PACKAGE}==12.3.0 +ARG CUPY_VERSION=13.3.0 +RUN pip install --upgrade pip setuptools wheel uv nox ${CUPY_PACKAGE}==${CUPY_VERSION} diff --git a/ci/cscs-ci.yml b/ci/cscs-ci.yml index 7fcd65106d..712c2450d6 100644 --- a/ci/cscs-ci.yml +++ b/ci/cscs-ci.yml @@ -4,18 +4,10 @@ include: .py311: &py311 PYVERSION_PREFIX: py311 PYVERSION: 3.11.9 - .py310: &py310 PYVERSION_PREFIX: py310 PYVERSION: 3.10.9 -.py39: &py39 - PYVERSION_PREFIX: py39 - PYVERSION: 3.9.1 - -.py38: &py38 - PYVERSION_PREFIX: py38 - PYVERSION: 3.8.5 stages: - baseimage @@ -42,52 +34,40 @@ stages: DOCKERFILE: ci/base.Dockerfile # change to 'always' if you want to rebuild, even if target tag exists already (if-not-exists is the default, i.e. 
we could also skip the variable) CSCS_REBUILD_POLICY: if-not-exists - DOCKER_BUILD_ARGS: '["CUDA_VERSION=$CUDA_VERSION", "CUPY_PACKAGE=$CUPY_PACKAGE", "PYVERSION=$PYVERSION", "CI_PROJECT_DIR=$CI_PROJECT_DIR"]' + DOCKER_BUILD_ARGS: '["CUDA_VERSION=$CUDA_VERSION", "CUPY_PACKAGE=$CUPY_PACKAGE", "CUPY_VERSION=$CUPY_VERSION", "UBUNTU_VERSION=$UBUNTU_VERSION", "PYVERSION=$PYVERSION"]' .build_baseimage_x86_64: extends: [.container-builder-cscs-zen2, .build_baseimage] variables: - CUDA_VERSION: 11.2.2 + CUDA_VERSION: 11.4.3 CUPY_PACKAGE: cupy-cuda11x + CUPY_VERSION: 12.3.0 # latest version that supports cuda 11 + UBUNTU_VERSION: 20.04 # 22.04 hangs on daint in some tests for unknown reasons. .build_baseimage_aarch64: extends: [.container-builder-cscs-gh200, .build_baseimage] variables: - CUDA_VERSION: 12.4.1 + CUDA_VERSION: 12.6.2 CUPY_PACKAGE: cupy-cuda12x - # TODO: enable CI job when Todi is back in operational state - when: manual + CUPY_VERSION: 13.3.0 + UBUNTU_VERSION: 22.04 -build_py311_baseimage_x86_64: - extends: .build_baseimage_x86_64 - variables: - <<: *py311 +# build_py311_baseimage_x86_64: +# extends: .build_baseimage_x86_64 +# variables: +# <<: *py311 build_py311_baseimage_aarch64: extends: .build_baseimage_aarch64 variables: <<: *py311 -build_py310_baseimage_x86_64: - extends: .build_baseimage_x86_64 - variables: - <<: *py310 +# build_py310_baseimage_x86_64: +# extends: .build_baseimage_x86_64 +# variables: +# <<: *py310 build_py310_baseimage_aarch64: extends: .build_baseimage_aarch64 variables: <<: *py310 -build_py39_baseimage_x86_64: - extends: .build_baseimage_x86_64 - variables: - <<: *py39 -build_py39_baseimage_aarch64: - extends: .build_baseimage_aarch64 - variables: - <<: *py39 - -build_py38_baseimage_x86_64: - extends: .build_baseimage_x86_64 - variables: - <<: *py38 - .build_image: stage: image @@ -102,81 +82,68 @@ build_py38_baseimage_x86_64: .build_image_aarch64: extends: [.container-builder-cscs-gh200, .build_image] -build_py311_image_x86_64: - 
extends: .build_image_x86_64 - needs: [build_py311_baseimage_x86_64] - variables: - <<: *py311 +# build_py311_image_x86_64: +# extends: .build_image_x86_64 +# needs: [build_py311_baseimage_x86_64] +# variables: +# <<: *py311 build_py311_image_aarch64: extends: .build_image_aarch64 needs: [build_py311_baseimage_aarch64] variables: <<: *py311 -build_py310_image_x86_64: - extends: .build_image_x86_64 - needs: [build_py310_baseimage_x86_64] - variables: - <<: *py310 +# build_py310_image_x86_64: +# extends: .build_image_x86_64 +# needs: [build_py310_baseimage_x86_64] +# variables: +# <<: *py310 build_py310_image_aarch64: extends: .build_image_aarch64 needs: [build_py310_baseimage_aarch64] variables: <<: *py310 -build_py39_image_x86_64: - extends: .build_image_x86_64 - needs: [build_py39_baseimage_x86_64] - variables: - <<: *py39 -build_py39_image_aarch64: - extends: .build_image_aarch64 - needs: [build_py39_baseimage_aarch64] - variables: - <<: *py39 - -build_py38_image_x86_64: - extends: .build_image_x86_64 - needs: [build_py38_baseimage_x86_64] - variables: - <<: *py38 - .test_helper: stage: test image: $CSCS_REGISTRY_PATH/public/$ARCH/gt4py/gt4py-ci:$CI_COMMIT_SHA-$PYVERSION script: - cd /gt4py.src - - python -c "import cupy" - - tox run -e $SUBPACKAGE-$PYVERSION_PREFIX$VARIANT$SUBVARIANT + - NOX_SESSION_ARGS="${VARIANT:+($VARIANT}${SUBVARIANT:+, $SUBVARIANT}${DETAIL:+, $DETAIL}${VARIANT:+)}" + - nox -e "test_$SUBPACKAGE-${PYVERSION:0:4}$NOX_SESSION_ARGS" variables: CRAY_CUDA_MPS: 1 SLURM_JOB_NUM_NODES: 1 SLURM_TIMELIMIT: 15 NUM_PROCESSES: auto + PYENV_VERSION: $PYVERSION VIRTUALENV_SYSTEM_SITE_PACKAGES: 1 -.test_helper_x86_64: - extends: [.container-runner-daint-gpu, .test_helper] - parallel: - matrix: - - SUBPACKAGE: [cartesian, storage] - VARIANT: [-internal, -dace] - SUBVARIANT: [-cuda11x, -cpu] - - SUBPACKAGE: eve - - SUBPACKAGE: next - VARIANT: [-nomesh, -atlas] - SUBVARIANT: [-cuda11x, -cpu] +# .test_helper_x86_64: +# extends: [.container-runner-daint-gpu, 
.test_helper] +# parallel: +# matrix: +# - SUBPACKAGE: [cartesian, storage] +# VARIANT: [-internal, -dace] +# SUBVARIANT: [-cuda11x, -cpu] +# - SUBPACKAGE: eve +# - SUBPACKAGE: next +# VARIANT: [-nomesh, -atlas] +# SUBVARIANT: [-cuda11x, -cpu] .test_helper_aarch64: - extends: [.container-runner-todi-gh200, .test_helper] + extends: [.container-runner-daint-gh200, .test_helper] parallel: matrix: - - SUBPACKAGE: [cartesian, storage] - VARIANT: [-internal, -dace] - SUBVARIANT: [-cuda12x, -cpu] + - SUBPACKAGE: [cartesian] + VARIANT: ['internal', 'dace'] + SUBVARIANT: ['cuda12', 'cpu'] - SUBPACKAGE: eve - SUBPACKAGE: next - VARIANT: [-nomesh, -atlas] - SUBVARIANT: [-cuda12x, -cpu] + VARIANT: ['internal', 'dace'] + SUBVARIANT: ['cuda12', 'cpu'] + DETAIL: ['nomesh', 'atlas'] + - SUBPACKAGE: [storage] + VARIANT: ['cuda12', 'cpu'] variables: # Grace-Hopper gpu architecture is not enabled by default in CUDA build CUDAARCHS: "90" @@ -185,41 +152,24 @@ build_py38_image_x86_64: # when high test parallelism is used. 
NUM_PROCESSES: 16 -test_py311_x86_64: - extends: [.test_helper_x86_64] - needs: [build_py311_image_x86_64] - variables: - <<: *py311 +# test_py311_x86_64: +# extends: [.test_helper_x86_64] +# needs: [build_py311_image_x86_64] +# variables: +# <<: *py311 test_py311_aarch64: extends: [.test_helper_aarch64] needs: [build_py311_image_aarch64] variables: <<: *py311 -test_py310_x86_64: - extends: [.test_helper_x86_64] - needs: [build_py310_image_x86_64] - variables: - <<: *py310 +# test_py310_x86_64: +# extends: [.test_helper_x86_64] +# needs: [build_py310_image_x86_64] +# variables: +# <<: *py310 test_py310_aarch64: extends: [.test_helper_aarch64] needs: [build_py310_image_aarch64] variables: <<: *py310 - -test_py39_x86_64: - extends: [.test_helper_x86_64] - needs: [build_py39_image_x86_64] - variables: - <<: *py39 -test_py39_aarch64: - extends: [.test_helper_aarch64] - needs: [build_py39_image_aarch64] - variables: - <<: *py39 - -test_py38_x86_64: - extends: [.test_helper_x86_64] - needs: [build_py38_image_x86_64] - variables: - <<: *py38 diff --git a/constraints.txt b/constraints.txt deleted file mode 100644 index 5df3f58c60..0000000000 --- a/constraints.txt +++ /dev/null @@ -1,183 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.8 -# by the following command: -# -# "tox run -e requirements-base" -# -aenum==3.1.15 # via dace -alabaster==0.7.13 # via sphinx -annotated-types==0.7.0 # via pydantic -asttokens==2.4.1 # via devtools, stack-data -astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) -attrs==24.2.0 # via gt4py (pyproject.toml), hypothesis, jsonschema, referencing -babel==2.16.0 # via sphinx -backcall==0.2.0 # via ipython -black==24.8.0 # via gt4py (pyproject.toml) -boltons==24.0.0 # via gt4py (pyproject.toml) -bracex==2.5 # via wcmatch -build==1.2.2 # via pip-tools -bump-my-version==0.26.0 # via -r requirements-dev.in -cached-property==1.5.2 # via gt4py (pyproject.toml) -cachetools==5.5.0 # via tox 
-certifi==2024.8.30 # via requests -cfgv==3.4.0 # via pre-commit -chardet==5.2.0 # via tox -charset-normalizer==3.3.2 # via requests -clang-format==18.1.8 # via -r requirements-dev.in, gt4py (pyproject.toml) -click==8.1.7 # via black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.30.3 # via gt4py (pyproject.toml) -cogapp==3.4.1 # via -r requirements-dev.in -colorama==0.4.6 # via tox -comm==0.2.2 # via ipykernel -contourpy==1.1.1 # via matplotlib -coverage==7.6.1 # via -r requirements-dev.in, pytest-cov -cycler==0.12.1 # via matplotlib -cytoolz==0.12.3 # via gt4py (pyproject.toml) -dace==0.16.1 # via gt4py (pyproject.toml) -darglint==1.8.1 # via -r requirements-dev.in -debugpy==1.8.5 # via ipykernel -decorator==5.1.1 # via ipython -deepdiff==8.0.1 # via gt4py (pyproject.toml) -devtools==0.12.2 # via gt4py (pyproject.toml) -dill==0.3.8 # via dace -distlib==0.3.8 # via virtualenv -docutils==0.20.1 # via sphinx, sphinx-rtd-theme -eval-type-backport==0.2.0 # via tach -exceptiongroup==1.2.2 # via hypothesis, pytest -execnet==2.1.1 # via pytest-cache, pytest-xdist -executing==2.1.0 # via devtools, stack-data -factory-boy==3.3.1 # via gt4py (pyproject.toml), pytest-factoryboy -faker==28.4.1 # via factory-boy -fastjsonschema==2.20.0 # via nbformat -filelock==3.16.0 # via tox, virtualenv -fonttools==4.53.1 # via matplotlib -fparser==0.1.4 # via dace -frozendict==2.4.4 # via gt4py (pyproject.toml) -gitdb==4.0.11 # via gitpython -gitpython==3.1.43 # via tach -gridtools-cpp==2.3.4 # via gt4py (pyproject.toml) -hypothesis==6.112.0 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.6.0 # via pre-commit -idna==3.8 # via requests -imagesize==1.4.1 # via sphinx -importlib-metadata==8.5.0 # via build, jupyter-client, sphinx -importlib-resources==6.4.5 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications, matplotlib -inflection==0.5.1 # via pytest-factoryboy -iniconfig==2.0.0 # via pytest 
-ipykernel==6.29.5 # via nbmake -ipython==8.12.3 # via ipykernel -jedi==0.19.1 # via ipython -jinja2==3.1.4 # via dace, gt4py (pyproject.toml), sphinx -jsonschema==4.23.0 # via nbformat -jsonschema-specifications==2023.12.1 # via jsonschema -jupyter-client==8.6.2 # via ipykernel, nbclient -jupyter-core==5.7.2 # via ipykernel, jupyter-client, nbformat -jupytext==1.16.4 # via -r requirements-dev.in -kiwisolver==1.4.7 # via matplotlib -lark==1.2.2 # via gt4py (pyproject.toml) -mako==1.3.5 # via gt4py (pyproject.toml) -markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins, rich -markupsafe==2.1.5 # via jinja2, mako -matplotlib==3.7.5 # via -r requirements-dev.in -matplotlib-inline==0.1.7 # via ipykernel, ipython -mdit-py-plugins==0.4.2 # via jupytext -mdurl==0.1.2 # via markdown-it-py -mpmath==1.3.0 # via sympy -mypy==1.11.2 # via -r requirements-dev.in -mypy-extensions==1.0.0 # via black, mypy -nanobind==2.1.0 # via gt4py (pyproject.toml) -nbclient==0.6.8 # via nbmake -nbformat==5.10.4 # via jupytext, nbclient, nbmake -nbmake==1.5.4 # via -r requirements-dev.in -nest-asyncio==1.6.0 # via ipykernel, nbclient -networkx==3.1 # via dace, tach -ninja==1.11.1.1 # via gt4py (pyproject.toml) -nodeenv==1.9.1 # via pre-commit -numpy==1.24.4 # via contourpy, dace, gt4py (pyproject.toml), matplotlib, scipy -orderly-set==5.2.2 # via deepdiff -packaging==24.1 # via black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox -parso==0.8.4 # via jedi -pathspec==0.12.1 # via black -pexpect==4.9.0 # via ipython -pickleshare==0.7.5 # via ipython -pillow==10.4.0 # via matplotlib -pip-tools==7.4.1 # via -r requirements-dev.in -pipdeptree==2.23.3 # via -r requirements-dev.in -pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==4.3.2 # via black, jupyter-core, tox, virtualenv -pluggy==1.5.0 # via pytest, tox -ply==3.11 # via dace -pre-commit==3.5.0 # via -r requirements-dev.in 
-prompt-toolkit==3.0.36 # via ipython, questionary, tach -psutil==6.0.0 # via -r requirements-dev.in, ipykernel, pytest-xdist -ptyprocess==0.7.0 # via pexpect -pure-eval==0.2.3 # via stack-data -pybind11==2.13.5 # via gt4py (pyproject.toml) -pydantic==2.9.1 # via bump-my-version, pydantic-settings, tach -pydantic-core==2.23.3 # via pydantic -pydantic-settings==2.5.2 # via bump-my-version -pydot==2.0.0 # via tach -pygments==2.18.0 # via -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx -pyparsing==3.1.4 # via matplotlib, pydot -pyproject-api==1.7.1 # via tox -pyproject-hooks==1.1.0 # via build, pip-tools -pytest==8.3.3 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist -pytest-cache==1.0 # via -r requirements-dev.in -pytest-cov==5.0.0 # via -r requirements-dev.in -pytest-custom-exit-code==0.3.0 # via -r requirements-dev.in -pytest-factoryboy==2.7.0 # via -r requirements-dev.in -pytest-instafail==0.5.0 # via -r requirements-dev.in -pytest-xdist==3.6.1 # via -r requirements-dev.in -python-dateutil==2.9.0.post0 # via faker, jupyter-client, matplotlib -python-dotenv==1.0.1 # via pydantic-settings -pytz==2024.2 # via babel -pyyaml==6.0.2 # via dace, jupytext, pre-commit, tach -pyzmq==26.2.0 # via ipykernel, jupyter-client -questionary==2.0.1 # via bump-my-version -referencing==0.35.1 # via jsonschema, jsonschema-specifications -requests==2.32.3 # via sphinx -rich==13.8.1 # via bump-my-version, rich-click, tach -rich-click==1.8.3 # via bump-my-version -rpds-py==0.20.0 # via jsonschema, referencing -ruff==0.6.4 # via -r requirements-dev.in -scipy==1.10.1 # via gt4py (pyproject.toml) -setuptools-scm==8.1.0 # via fparser -six==1.16.0 # via asttokens, astunparse, python-dateutil -smmap==5.0.1 # via gitdb -snowballstemmer==2.2.0 # via sphinx -sortedcontainers==2.4.0 # via hypothesis -sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, 
sphinxcontrib-jquery -sphinx-rtd-theme==2.0.0 # via -r requirements-dev.in -sphinxcontrib-applehelp==1.0.4 # via sphinx -sphinxcontrib-devhelp==1.0.2 # via sphinx -sphinxcontrib-htmlhelp==2.0.1 # via sphinx -sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme -sphinxcontrib-jsmath==1.0.1 # via sphinx -sphinxcontrib-qthelp==1.0.3 # via sphinx -sphinxcontrib-serializinghtml==1.1.5 # via sphinx -stack-data==0.6.3 # via ipython -stdlib-list==0.10.0 # via tach -sympy==1.12.1 # via dace, gt4py (pyproject.toml) -tabulate==0.9.0 # via gt4py (pyproject.toml) -tach==0.10.7 # via -r requirements-dev.in -tomli==2.0.1 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tox -tomli-w==1.0.0 # via tach -tomlkit==0.13.2 # via bump-my-version -toolz==0.12.1 # via cytoolz -tornado==6.4.1 # via ipykernel, jupyter-client -tox==4.18.1 # via -r requirements-dev.in -traitlets==5.14.3 # via comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat -types-tabulate==0.9.0.20240106 # via -r requirements-dev.in -typing-extensions==4.12.2 # via annotated-types, black, gt4py (pyproject.toml), ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, setuptools-scm -urllib3==2.2.3 # via requests -virtualenv==20.26.4 # via pre-commit, tox -wcmatch==9.0 # via bump-my-version -wcwidth==0.2.13 # via prompt-toolkit -websockets==13.0.1 # via dace -wheel==0.44.0 # via astunparse, pip-tools -xxhash==3.0.0 # via gt4py (pyproject.toml) -zipp==3.20.1 # via importlib-metadata, importlib-resources - -# The following packages are considered to be unsafe in a requirements file: -pip==24.2 # via pip-tools, pipdeptree -setuptools==74.1.2 # via gt4py (pyproject.toml), pip-tools, setuptools-scm diff --git a/dev-tasks.py b/dev-tasks.py new file mode 100755 index 0000000000..437d107807 --- /dev/null +++ b/dev-tasks.py @@ -0,0 +1,97 @@ +#! 
/usr/bin/env -S uv run -q +# +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +# +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "typer>=0.12.3", +# ] +# [tool.uv] +# exclude-newer = "2025-01-31T00:00:00Z" +# /// + + +"""Script for running recurrent development tasks.""" + +from __future__ import annotations + +import pathlib +import subprocess +from typing import Final + +import typer + +ROOT_DIR: Final = pathlib.Path(__file__).parent + + +# -- Helpers -- +def gather_versions() -> dict[str, str]: + with subprocess.Popen( + [*"uv export --frozen --no-hashes --project".split(), ROOT_DIR], stdout=subprocess.PIPE + ) as proc: + return dict( + line.split("==") + for line in proc.stdout.read().decode().splitlines() + if not any(line.startswith(c) for c in ["-", "#"]) + ) + + +# -- CLI -- +app = typer.Typer(no_args_is_help=True) + + +@app.command() +def sync_precommit() -> None: + """Sync versions of tools used in pre-commit hooks with the project versions.""" + versions = gather_versions() + # Update ruff version in pre-commit config + subprocess.run( + f"""uvx -q --from 'yamlpath' yaml-set --mustexist --change='repos[.repo%https://github.com/astral-sh/ruff-pre-commit].rev' --value='v{versions["ruff"]}' .pre-commit-config.yaml""", + cwd=ROOT_DIR, + shell=True, + check=True, + ) + + # Update tach version in pre-commit config + subprocess.run( + f"""uvx -q --from 'yamlpath' yaml-set --mustexist --change='repos[.repo%https://github.com/gauge-sh/tach-pre-commit].rev' --value='v{versions["tach"]}' .pre-commit-config.yaml""", + cwd=ROOT_DIR, + shell=True, + check=True, + ) + + # Format yaml files + subprocess.run( + f"uv run --project {ROOT_DIR} pre-commit run pretty-format-yaml --all-files", shell=True + ) + + +@app.command() +def update_precommit() -> None: + """Update and sync pre-commit hooks with the 
latest compatible versions.""" + subprocess.run(f"uv run --project {ROOT_DIR} pre-commit autoupdate", shell=True) + sync_precommit() + + +@app.command() +def update_versions() -> None: + """Update all project dependencies to their latest compatible versions.""" + subprocess.run("uv lock --upgrade", cwd=ROOT_DIR, shell=True, check=True) + + +@app.command() +def update_all() -> None: + """Update all project dependencies and pre-commit hooks.""" + update_versions() + update_precommit() + + +if __name__ == "__main__": + app() diff --git a/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md b/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md index a1ee8575d2..1ce83431ee 100644 --- a/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md +++ b/docs/development/ADRs/0008-Mapping_Domain_to_Cpp-Backend.md @@ -20,7 +20,7 @@ The Python embedded execution for Iterator IR keeps track of the current locatio ### Python side -On the Python side, we label dimensions of fields with the location type, e.g. `Edge` or `Vertex`. The domain uses `named_ranges` that uses the same location types to express _where_ to iterate, e.g. `named_range(Vertex, range(0, 100))` is an iteration over the `Vertex` dimension, no order in the domain is required. Additionally, the `Connectivity` (aka `NeighborTableOffsetProvider` in the current implementation) describes the mapping between location types. +On the Python side, we label dimensions of fields with the location type, e.g. `Edge` or `Vertex`. The domain uses `named_ranges` that uses the same location types to express _where_ to iterate, e.g. `named_range(Vertex, range(0, 100))` is an iteration over the `Vertex` dimension, no order in the domain is required. Additionally, the `Connectivity` describes the mapping between location types. 
### C++ side diff --git a/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md b/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md index 18b9c1f878..69e09c7fae 100644 --- a/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md +++ b/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md @@ -7,6 +7,7 @@ tags: [backend, dace, optimization] - **Status**: valid - **Authors**: Philip Müller (@philip-paul-mueller) - **Created**: 2024-08-27 +- **Updated**: 2025-01-15 In the context of the implementation of the new DaCe fieldview we decided about a particular form of the SDFG. Their main intent is to reduce the complexity of the GT4Py specific transformations. @@ -22,6 +23,12 @@ In the pipeline we distinguish between: The current (GT4Py) pipeline mainly focus on intrastate optimization and relays on DaCe, especially its simplify pass, for interstate optimizations. +## Changelog + +#### 2025-01-15: + +- Made the rules clearer. Specifically, made a restriction on global memory more explicit. + ## Decision The canonical form is defined by several rules that affect different aspects of an SDFG and what a transformation can assume. @@ -38,20 +45,24 @@ The following rules especially affects transformations and how they operate: - [Note 2]: It is allowed for an _intrastate_ transformation to act in a way that allows state fusion by later intrastate transformations. - [Note 3]: The DaCe simplification pass violates this rule, for that reason this pass must always be called on its own, see also rule 2. -2. It is invalid to call the simplification pass directly, i.e. the usage of `SDFG.simplify()` is not allowed. The only valid way to call _simplify()_ is to call the `gt_simplify()` function provided by GT4Py. +2. It is invalid to call DaCe's simplification pass directly, i.e. the usage of `SDFG.simplify()` is not allowed. 
The only valid way to call _simplify()_ is to call the `gt_simplify()` function provided by GT4Py. + - [Rationale]: It was observed that some sub-passes in _simplify()_ have a negative impact and that additional passes might be needed in the future. By using a single function later modifications to _simplify()_ are easy. - [Note]: One issue is that the remove redundant array transformation is not able to handle all cases. #### Global Memory -The only restriction we impose on global memory is: +Global memory has to adhere to the same rules as transient memory. +However, the following rule takes precedence, i.e. if this rule is fulfilled then rules 6 to 10 may be violated. + +3. The same global memory is allowed to be used as input and output at the same time, either in the SDFG or in a state, if and only if the output depends _elementwise_ on the input. -3. The same global memory is allowed to be used as input and output at the same time, if and only if the output depends _elementwise_ on the input. - [Rationale 1]: This allows the removal of double buffering, that DaCe may not remove. See also rule 2. - [Rationale 2]: This formulation allows writing expressions such as `a += 1`, with only memory for `a`. Phrased more technically, using global memory for input and output is allowed if and only if the two computations `tmp = computation(global_memory); global_memory = tmp;` and `global_memory = computation(global_memory);` are equivalent. - - [Note]: In the long term this rule will be changed to: Global memory (an array) is either used as input (only read from) or as output (only written to) but never for both. + - [Note 1]: This rule also forbids expressions such as `A[0:10] = A[1:11]`, where `A` refers to a global memory. + - [Note 2]: In the long term this rule will be changed to: Global memory (an array) is either used as input (only read from) or as output (only written to) but never for both. 
#### State Machine @@ -63,6 +74,7 @@ For the SDFG state machine we assume that: - [Note]: Running _simplify()_ might actually result in the violation of this rule, see note of rule 9. 5. The state graph does not contain any cycles, i.e. the implementation of a for/while loop using states is not allowed, the new loop construct or serial maps must be used in that case. + - [Rationale]: This is a simplification that makes it much simpler to define what "later in the computation" means, as we will never have a cycle. - [Note]: Currently the code generator does not support the `LoopRegion` construct and it is transformed to a state machine. @@ -93,7 +105,7 @@ It is important to note that these rules only have to be met after _simplify()_ 8. No two access nodes in a state can refer to the same array. - [Rationale]: Together with rule 5 this guarantees SSA style. - - [Note]: An SDFG can still be constructed using different access node for the same underlying data; _simplify()_ will combine them. + - [Note]: An SDFG can still be constructed using different access node for the same underlying data in the same state; _simplify()_ will combine them. 9. Every access node that reads from an array (having an outgoing edge) that was not written to in the same state must be a source node. @@ -103,6 +115,7 @@ It is important to note that these rules only have to be met after _simplify()_ Excess interstate transients, that will be kept alive that way, will be removed by later calls to _simplify()_. 10. Every AccessNode within a map scope must refer to a data descriptor whose lifetime must be `dace.dtypes.AllocationLifetime.Scope` and its storage class should either be `dace.dtypes.StorageType.Default` or _preferably_ `dace.dtypes.StorageType.Register`. + - [Rationale 1]: This makes optimizations operating inside maps/kernels simpler, as it guarantees that the AccessNode does not propagate outside. 
- [Rationale 2]: The storage type avoids the need to dynamically allocate memory inside a kernel. @@ -120,6 +133,7 @@ For maps we assume the following: - [Rationale]: Without this rule it is very hard to tell which map variable does what, this way we can transmit information from GT4Py to DaCe, see also rule 12. 12. Two map ranges, i.e. the pair map/iteration variable and range, can only be fused if they have the same name _and_ cover the same range. + - [Rationale 1]: Because of rule 11, we will only fuse maps that actually makes sense to fuse. - [Rationale 2]: This allows fusing maps without renaming the map variables. - [Note]: This rule might be dropped in the future. diff --git a/docs/development/ADRs/0019-Connectivities.md b/docs/development/ADRs/0019-Connectivities.md new file mode 100644 index 0000000000..76e85e49a6 --- /dev/null +++ b/docs/development/ADRs/0019-Connectivities.md @@ -0,0 +1,55 @@ +--- +tags: [] +--- + +# [Connectivities] + +- **Status**: valid +- **Authors**: Hannes Vogt (@havogt) +- **Created**: 2024-11-08 +- **Updated**: 2024-11-08 + +The representation of Connectivities (neighbor tables, `NeighborTableOffsetProvider`) and their identifier (offset tag, `FieldOffset`, etc.) was extended and modified based on the needs of different parts of the toolchain. Here we outline the ideas for consolidating the different closely-related concepts. + +## History + +In the early days of Iterator IR (ITIR), an `offset` was a literal in the IR. Its meaning was only provided at execution time by a mapping from `offset` tag to an entity that we labelled `OffsetProvider`. We had mainly 2 kinds of `OffsetProvider`: a `Dimension` representing a Cartesian shift and a `NeighborTableOffsetProvider` for unstructured shifts. Since the type of `offset` needs to be known for compilation (strided for Cartesian, lookup-table for unstructured), this prevents a clean interface for ahead-of-time compilation. 
+For the frontend type-checking we later introduce a `FieldOffset` which contained type information of the mapped dimensions. +For (field-view) embedded we introduced a `ConnectivityField` (now `Connectivity`) which could be generated from the OffsetProvider information. + +These different concepts had overlap but were not 1-to-1 replacements. + +## Decision + +We update and introduce the following concepts + +### Conceptual definitions + +**Connectivity** is a mapping from index (or product of indices) to index. It covers 1-to-1 mappings, e.g. Cartesian shifts, NeighborTables (2D mappings) and dynamic Cartesian shifts. + +**NeighborConnectivity** is a 2D mapping of the N neighbors of a Location A to a Location B. + +**NeighborTable** is a _NeighborConnectivity_ backed by a buffer. + +**ConnectivityType**, **NeighborConnectivityType** contains all information that is needed for compilation. + +### Full definitions + +See `next.common` module + +Note: Currently, the compiled backends supports only `NeighborConnectivity`s that are `NeighborTable`s. We do not yet encode this in the type and postpone discussion to the point where we support alternative implementations (e.g. `StridedNeighborConnectivity`). + +## Which parts of the toolchain use which concept? + +### Embedded + +Embedded execution of field-view supports any kind of `Connectivity`. +Embedded execution of iterator (local) view supports only `NeighborConnectivity`s. + +### IR transformations and compiled backends + +All transformations and code-generation should use `ConnectivityType`, not the `Connectivity` which contains the runtime mapping. + +Note, currently the `global_tmps` pass uses runtime information, therefore this is not strictly enforced. + +The only supported `Connectivity`s in compiled backends (currently) are `NeighborTable`s. 
diff --git a/docs/development/ADRs/Index.md b/docs/development/ADRs/README.md similarity index 100% rename from docs/development/ADRs/Index.md rename to docs/development/ADRs/README.md diff --git a/docs/development/tools/ci-infrastructure.md b/docs/development/tools/ci-infrastructure.md index 242bea50bd..e76cb7d608 100644 --- a/docs/development/tools/ci-infrastructure.md +++ b/docs/development/tools/ci-infrastructure.md @@ -1,6 +1,6 @@ # CI infrastructure -Any test job that runs on CI is encoded in automation tools like **tox** and **pre-commit** and can be run locally instead. +Any test job that runs on CI is encoded in automation tools like **nox** and **pre-commit** and can be run locally instead. ## GitHub Workflows diff --git a/docs/development/tools/requirements.md b/docs/development/tools/requirements.md deleted file mode 100644 index 010f317493..0000000000 --- a/docs/development/tools/requirements.md +++ /dev/null @@ -1,27 +0,0 @@ -# Requirements - -The specification of required third-party packages is scattered and partially duplicated across several configuration files used by several tools. Keeping all package requirements in sync manually is challenging and error-prone. Therefore, in this project we use [pip-tools](https://pip-tools.readthedocs.io/en/latest/) and the [cog](https://nedbatchelder.com/code/cog/) file generation tool to avoid inconsistencies. - -The following files in this repository contain information about required third-party packages: - -- `pyproject.toml`: GT4Py [package configuration](https://peps.python.org/pep-0621/) used by the build backend (`setuptools`). Install dependencies are specified in the _project.dependencies_ and _project.optional-dependencies_ tables. -- `requirements-dev.in`: [requirements file](https://pip.pypa.io/en/stable/reference/requirements-file-format/) used by **pip**. It contains a list of packages required only for the development of GT4Py. -- `requirements-dev.txt`: requirements file used by **pip**. 
It contains a completely frozen list of all packages required for installing and developing GT4Py. It is used by **pip** and **tox** to initialize the standard development and testing environments. It is automatically generated automatically from `requirements-dev.in` by **pip-compile**, when running the **tox** environment to update requirements. -- `constraints.txt`: [constraints file](https://pip.pypa.io/en/stable/user_guide/#constraints-files) used by **pip** and **tox** to initialize a subset of the standard development environment making sure that if other packages are installed, transitive dependencies are taken from the frozen package list. It is generated automatically from `requirements-dev.in` using **pip-compile**. -- `min-requirements-test.txt`: requirements file used by **pip**. It contains the minimum list of requirements to run GT4Py tests with the oldest compatible versions of all dependencies. It is generated automatically from `pyproject.toml` using **cog**. -- `min-extra-requirements-test.txt`: requirements file used by **pip**. It contains the minimum list of requirements to run GT4Py tests with the oldest compatible versions of all dependencies, additionally including all GT4Py extras. It is generated automatically from `pyproject.toml` using **cog**. -- `.pre-commit-config.yaml`: **pre-commit** configuration with settings for many linting and formatting tools. Part of its content is generated automatically from `pyproject.toml` using **cog**. - -The expected workflow to update GT4Py requirements is as follows: - -1. For changes in the GT4Py package dependencies, update the relevant table in `pyproject.toml`. When adding new tables to the _project.optional-dependencies_ section, make sure to add the new table as a dependency of the `all-` extra tables when possible. - -2. For changes in the development tools, update the `requirements-dev.in` file. 
Note that required project packages already appearing in `pyproject.toml` should not be duplicated here. - -3. Run the **tox** _requirements-base_ environment to update all files automatically with **pip-compile** and **cog**. Note that **pip-compile** will most likely update the versions of some unrelated tools if new versions are available in PyPI. - - ```bash - tox r -e requirements-base - ``` - -4. Check that the **mypy** mirror used by **pre-commit** (https://github.com/pre-commit/mirrors-mypy) in `.pre-commit-config.yaml` supports the same version as in `constraints.txt`, and manually update the `rev` version number. diff --git a/docs/user/next/QuickstartGuide.md b/docs/user/next/QuickstartGuide.md index 81604c7770..2cb6647519 100644 --- a/docs/user/next/QuickstartGuide.md +++ b/docs/user/next/QuickstartGuide.md @@ -155,8 +155,6 @@ This section approaches the pseudo-laplacian by introducing the required APIs pr - [Using reductions on connected mesh elements](#Using-reductions-on-connected-mesh-elements) - [Implementing the actual pseudo-laplacian](#Implementing-the-pseudo-laplacian) -+++ - #### Defining the mesh and its connectivities The examples related to unstructured meshes use the mesh below. The edges (in blue) and the cells (in red) are numbered with zero-based indices. @@ -237,7 +235,7 @@ E2C = gtx.FieldOffset("E2C", source=CellDim, target=(EdgeDim,E2CDim)) Note that the field offset does not contain the actual connectivity table, that's provided through an _offset provider_: ```{code-cell} ipython3 -E2C_offset_provider = gtx.NeighborTableOffsetProvider(edge_to_cell_table, EdgeDim, CellDim, 2) +E2C_offset_provider = gtx.as_connectivity([EdgeDim, E2CDim], codomain=CellDim, data=edge_to_cell_table, skip_value=-1) ``` The field operator `nearest_cell_to_edge` below shows an example of applying this transform. 
There is a little twist though: the subscript in `E2C[0]` means that only the value of the first connected cell is taken, the second (if exists) is ignored. @@ -385,7 +383,7 @@ As explained in the section outline, the pseudo-laplacian needs the cell-to-edge C2EDim = gtx.Dimension("C2E", kind=gtx.DimensionKind.LOCAL) C2E = gtx.FieldOffset("C2E", source=EdgeDim, target=(CellDim, C2EDim)) -C2E_offset_provider = gtx.NeighborTableOffsetProvider(cell_to_edge_table, CellDim, EdgeDim, 3) +C2E_offset_provider = gtx.as_connectivity([CellDim, C2EDim], codomain=EdgeDim, data=cell_to_edge_table, skip_value=-1) ``` **Weights of edge differences:** diff --git a/docs/user/next/advanced/HackTheToolchain.md b/docs/user/next/advanced/HackTheToolchain.md index 029833cb7d..358f6e8d0d 100644 --- a/docs/user/next/advanced/HackTheToolchain.md +++ b/docs/user/next/advanced/HackTheToolchain.md @@ -15,7 +15,7 @@ from gt4py import eve ```python cached_lowering_toolchain = gtx.backend.DEFAULT_TRANSFORMS.replace( - past_to_itir=gtx.ffront.past_to_itir.past_to_itir_factory(cached=False) + past_to_itir=gtx.ffront.past_to_itir.past_to_gtir_factory(cached=False) ) ``` diff --git a/docs/user/next/advanced/ToolchainWalkthrough.md b/docs/user/next/advanced/ToolchainWalkthrough.md index b82dea1a2f..a5a63cb56c 100644 --- a/docs/user/next/advanced/ToolchainWalkthrough.md +++ b/docs/user/next/advanced/ToolchainWalkthrough.md @@ -247,7 +247,7 @@ pprint.pprint(jit_args) ``` ```python -gtx.program_processors.runners.roundtrip.executor(pitir)(*jit_args.args, **jit_args.kwargs) +gtx.program_processors.runners.roundtrip.Roundtrip()(pitir)(*jit_args.args, **jit_args.kwargs) ``` ```python @@ -290,7 +290,7 @@ assert pitir2 == pitir #### Pass The result to the compile workflow and execute ```python -example_compiled = gtx.program_processors.runners.roundtrip.executor(pitir2) +example_compiled = gtx.program_processors.runners.roundtrip.Roundtrip()(pitir2) ``` ```python diff --git 
a/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb b/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb index 50349e52b0..b0a1980d0f 100644 --- a/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb +++ b/docs/user/next/workshop/exercises/2_divergence_exercise.ipynb @@ -81,7 +81,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "5dbd2f62", "metadata": {}, "outputs": [], @@ -113,7 +113,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " divergence_gt4py = gtx.zeros(cell_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb b/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb index 6baac2b8c0..573ee6a44e 100644 --- a/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/2_divergence_exercise_solution.ipynb @@ -86,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "5dbd2f62", "metadata": {}, "outputs": [], @@ -118,7 +118,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " divergence_gt4py = gtx.zeros(cell_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb b/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb index c8914120d3..2b422b1823 100644 --- a/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb +++ b/docs/user/next/workshop/exercises/3_gradient_exercise.ipynb @@ -80,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 11, + 
"execution_count": null, "id": "84b02762", "metadata": {}, "outputs": [], @@ -110,7 +110,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " gradient_gt4py_x = gtx.zeros(cell_domain, allocator=backend)\n", " gradient_gt4py_y = gtx.zeros(cell_domain, allocator=backend)\n", diff --git a/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb b/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb index 5e940a4b71..85044b989f 100644 --- a/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/3_gradient_exercise_solution.ipynb @@ -93,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "84b02762", "metadata": {}, "outputs": [], @@ -123,7 +123,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", "\n", " gradient_gt4py_x = gtx.zeros(cell_domain, allocator=backend)\n", " gradient_gt4py_y = gtx.zeros(cell_domain, allocator=backend)\n", diff --git a/docs/user/next/workshop/exercises/4_curl_exercise.ipynb b/docs/user/next/workshop/exercises/4_curl_exercise.ipynb index 4a6b37baf7..dc321f1bdd 100644 --- a/docs/user/next/workshop/exercises/4_curl_exercise.ipynb +++ b/docs/user/next/workshop/exercises/4_curl_exercise.ipynb @@ -102,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "5b6ffc9e", "metadata": {}, "outputs": [], @@ -134,7 +134,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", + " v2e_connectivity = 
gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", "\n", " curl_gt4py = gtx.zeros(vertex_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb b/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb index 065cf02de7..251fe8239a 100644 --- a/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/4_curl_exercise_solution.ipynb @@ -107,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "5b6ffc9e", "metadata": {}, "outputs": [], @@ -139,7 +139,7 @@ " edge_orientation.asnumpy(),\n", " )\n", "\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", "\n", " curl_gt4py = gtx.zeros(vertex_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb b/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb index 832375a86b..30f568de6f 100644 --- a/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb +++ b/docs/user/next/workshop/exercises/5_vector_laplace_exercise.ipynb @@ -228,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "f9cfc097", "metadata": {}, "outputs": [], @@ -272,10 +272,10 @@ " edge_orientation_cell.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", - " e2v_connectivity = gtx.NeighborTableOffsetProvider(e2v_table, E, V, 2, has_skip_values=False)\n", - " e2c_connectivity = gtx.NeighborTableOffsetProvider(e2c_table, E, C, 2, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", + " 
v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", + " e2v_connectivity = gtx.as_connectivity([E, E2VDim], codomain=V, data=e2v_table)\n", + " e2c_connectivity = gtx.as_connectivity([E, E2CDim], codomain=C, data=e2c_table)\n", "\n", " laplacian_gt4py = gtx.zeros(edge_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb b/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb index be846d199d..eaeb8c7b02 100644 --- a/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/5_vector_laplace_exercise_solution.ipynb @@ -249,7 +249,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "f9cfc097", "metadata": {}, "outputs": [], @@ -293,10 +293,10 @@ " edge_orientation_cell.asnumpy(),\n", " )\n", "\n", - " c2e_connectivity = gtx.NeighborTableOffsetProvider(c2e_table, C, E, 3, has_skip_values=False)\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", - " e2v_connectivity = gtx.NeighborTableOffsetProvider(e2v_table, E, V, 2, has_skip_values=False)\n", - " e2c_connectivity = gtx.NeighborTableOffsetProvider(e2c_table, E, C, 2, has_skip_values=False)\n", + " c2e_connectivity = gtx.as_connectivity([C, C2EDim], codomain=E, data=c2e_table)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", + " e2v_connectivity = gtx.as_connectivity([E, E2VDim], codomain=V, data=e2v_table)\n", + " e2c_connectivity = gtx.as_connectivity([E, E2CDim], codomain=C, data=e2c_table)\n", "\n", " laplacian_gt4py = gtx.zeros(edge_domain, allocator=backend)\n", "\n", diff --git a/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb b/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb index d4bcdb33d5..b278cee26d 100644 --- 
a/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb +++ b/docs/user/next/workshop/exercises/8_diffusion_exercise_solution.ipynb @@ -118,7 +118,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": null, "id": "f9cfc097", "metadata": {}, "outputs": [], @@ -156,10 +156,8 @@ " dt,\n", " )\n", "\n", - " e2c2v_connectivity = gtx.NeighborTableOffsetProvider(\n", - " e2c2v_table, E, V, 4, has_skip_values=False\n", - " )\n", - " v2e_connectivity = gtx.NeighborTableOffsetProvider(v2e_table, V, E, 6, has_skip_values=False)\n", + " e2c2v_connectivity = gtx.as_connectivity([E, E2C2VDim], codomain=V, data=e2c2v_table)\n", + " v2e_connectivity = gtx.as_connectivity([V, V2EDim], codomain=E, data=v2e_table)\n", "\n", " diffusion_step(\n", " u,\n", diff --git a/docs/user/next/workshop/slides/slides_2.ipynb b/docs/user/next/workshop/slides/slides_2.ipynb index 1e8925087f..c6967df4b2 100644 --- a/docs/user/next/workshop/slides/slides_2.ipynb +++ b/docs/user/next/workshop/slides/slides_2.ipynb @@ -281,17 +281,19 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "6d30a5e1", "metadata": {}, "outputs": [], "source": [ - "E2C_offset_provider = gtx.NeighborTableOffsetProvider(e2c_table, Edge, Cell, 2)" + "E2C_offset_provider = gtx.as_connectivity(\n", + " [Edge, E2CDim], codomain=Cell, data=e2c_table, skip_value=-1\n", + ")" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "d62f6c98", "metadata": {}, "outputs": [ @@ -311,7 +313,7 @@ " return cell_field(E2C[0]) # 0th index to isolate edge dimension\n", "\n", "\n", - "@gtx.program # uses skip_values, therefore we cannot use embedded\n", + "@gtx.program\n", "def run_nearest_cell_to_edge(\n", " cell_field: gtx.Field[Dims[Cell], float64], edge_field: gtx.Field[Dims[Edge], float64]\n", "):\n", diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt deleted file mode 100644 index 10d70397c6..0000000000 --- 
a/min-extra-requirements-test.txt +++ /dev/null @@ -1,111 +0,0 @@ -# -# Generated automatically by cog from pyproject.toml and requirements-dev.in -# Run: -# tox r -e requirements-common -# - -##[[[cog -## import copy, sys -## from packaging import requirements as reqs, specifiers as specs -## if sys.version_info >= (3, 11): -## import tomllib -## else: -## import tomli as tomllib -## -## def make_min_req(r: reqs.Requirement) -> reqs.Requirement: -## for s in r.specifier: -## if (ss := str(s)).startswith(">"): -## assert ss.startswith(">="), f"'{r!s}' requires a '>=' constraint" -## min_spec = specs.SpecifierSet(f"=={ss[2:]}") -## break -## min_r = copy.deepcopy(r) -## min_r.specifier = min_spec -## return min_r -## -## project = tomllib.loads(open("pyproject.toml").read()) -## all_cpu_extra = project["project"]["optional-dependencies"]["all-cpu"] -## assert len(all_cpu_extra) == 1 and all_cpu_extra[0].startswith("gt4py[") -## opt_req_versions = { -## reqs.Requirement(r).name: reqs.Requirement(r) -## for e in reqs.Requirement(all_cpu_extra[0]).extras -## for r in project["project"]["optional-dependencies"][e] -## } -## requirements = [ -## reqs.Requirement(rr) -## for r in (project["project"]["dependencies"] + open("requirements-dev.in").readlines()) -## if (rr := (r[: r.find("#")] if "#" in r else r)) -## ] -## processed = set() -## result = [] -## for r in requirements: -## assert r.name not in processed -## processed.add(r.name) -## if not r.specifier: -## assert r.name in opt_req_versions, f"Missing contraints for '{r.name}'" -## r = opt_req_versions[r.name] -## result.append(str(make_min_req(r))) -## for r_name, r in opt_req_versions.items(): -## if r_name not in processed: -## result.append(str(make_min_req(r))) -## print("\n".join(sorted(result))) -##]]] -astunparse==1.6.3; python_version < "3.9" -attrs==21.3 -black==22.3 -boltons==20.1 -bump-my-version==0.12.0 -cached-property==1.5.1 -clang-format==9.0 -click==8.0.0 -cmake==3.22 -cogapp==3.3 
-coverage[toml]==5.0 -cytoolz==0.12.1 -dace==0.16.1 -darglint==1.6 -deepdiff==5.6.0 -devtools==0.6 -factory-boy==3.3.0 -frozendict==2.3 -gridtools-cpp==2.3.4 -hypothesis==6.0.0 -importlib-resources==5.0; python_version < "3.9" -jax[cpu]==0.4.18; python_version >= "3.10" -jinja2==3.0.0 -jupytext==1.14 -lark==1.1.2 -mako==1.1 -matplotlib==3.3 -mypy==1.0 -nanobind==1.4.0 -nbmake==1.4.6 -ninja==1.10 -numpy==1.23.3 -packaging==20.0 -pip-tools==6.10 -pipdeptree==2.3 -pre-commit==2.17 -psutil==5.0 -pybind11==2.10.1 -pygments==2.7.3 -pytest-cache==1.0 -pytest-cov==2.8 -pytest-custom-exit-code==0.3.0 -pytest-factoryboy==2.0.3 -pytest-instafail==0.5.0 -pytest-xdist[psutil]==2.4 -pytest==7.0 -ruff==0.2.0 -scipy==1.9.2 -setuptools==65.5.0 -sphinx==4.4 -sphinx_rtd_theme==1.0 -sympy==1.9 -tabulate==0.8.10 -tach==0.10.7 -tomli==2.0.1; python_version < "3.11" -tox==3.2.0 -types-tabulate==0.8.10 -typing-extensions==4.10.0 -xxhash==1.4.4 -##[[[end]]] diff --git a/min-requirements-test.txt b/min-requirements-test.txt deleted file mode 100644 index 01b21dc1f2..0000000000 --- a/min-requirements-test.txt +++ /dev/null @@ -1,104 +0,0 @@ -# -# Generated automatically by cog from pyproject.toml and requirements-dev.in -# Run: -# tox r -e requirements-common -# - -##[[[cog -## import copy, sys -## from packaging import requirements as reqs, specifiers as specs -## if sys.version_info >= (3, 11): -## import tomllib -## else: -## import tomli as tomllib -## -## def make_min_req(r: reqs.Requirement) -> reqs.Requirement: -## for s in r.specifier: -## if (ss := str(s)).startswith(">"): -## assert ss.startswith(">="), f"'{r!s}' requires a '>=' constraint" -## min_spec = specs.SpecifierSet(f"=={ss[2:]}") -## break -## min_r = copy.deepcopy(r) -## min_r.specifier = min_spec -## return min_r -## -## project = tomllib.loads(open("pyproject.toml").read()) -## all_cpu_extra = project["project"]["optional-dependencies"]["all-cpu"] -## assert len(all_cpu_extra) == 1 and 
all_cpu_extra[0].startswith("gt4py[") -## opt_req_versions = { -## reqs.Requirement(r).name: reqs.Requirement(r) -## for e in reqs.Requirement(all_cpu_extra[0]).extras -## for r in project["project"]["optional-dependencies"][e] -## } -## requirements = [ -## reqs.Requirement(rr) -## for r in (project["project"]["dependencies"] + open("requirements-dev.in").readlines()) -## if (rr := (r[: r.find("#")] if "#" in r else r)) -## ] -## processed = set() -## result = [] -## for r in requirements: -## assert r.name not in processed -## processed.add(r.name) -## if not r.specifier: -## assert r.name in opt_req_versions, f"Missing contraints for '{r.name}'" -## r = opt_req_versions[r.name] -## result.append(str(make_min_req(r))) -## print("\n".join(sorted(result))) -##]]] -astunparse==1.6.3; python_version < "3.9" -attrs==21.3 -black==22.3 -boltons==20.1 -bump-my-version==0.12.0 -cached-property==1.5.1 -clang-format==9.0 -click==8.0.0 -cmake==3.22 -cogapp==3.3 -coverage[toml]==5.0 -cytoolz==0.12.1 -darglint==1.6 -deepdiff==5.6.0 -devtools==0.6 -factory-boy==3.3.0 -frozendict==2.3 -gridtools-cpp==2.3.4 -hypothesis==6.0.0 -importlib-resources==5.0; python_version < "3.9" -jinja2==3.0.0 -jupytext==1.14 -lark==1.1.2 -mako==1.1 -matplotlib==3.3 -mypy==1.0 -nanobind==1.4.0 -nbmake==1.4.6 -ninja==1.10 -numpy==1.23.3 -packaging==20.0 -pip-tools==6.10 -pipdeptree==2.3 -pre-commit==2.17 -psutil==5.0 -pybind11==2.10.1 -pygments==2.7.3 -pytest-cache==1.0 -pytest-cov==2.8 -pytest-custom-exit-code==0.3.0 -pytest-factoryboy==2.0.3 -pytest-instafail==0.5.0 -pytest-xdist[psutil]==2.4 -pytest==7.0 -ruff==0.2.0 -setuptools==65.5.0 -sphinx==4.4 -sphinx_rtd_theme==1.0 -tabulate==0.8.10 -tach==0.10.7 -tomli==2.0.1; python_version < "3.11" -tox==3.2.0 -types-tabulate==0.8.10 -typing-extensions==4.10.0 -xxhash==1.4.4 -##[[[end]]] diff --git a/noxfile.py b/noxfile.py new file mode 100644 index 0000000000..3aad565837 --- /dev/null +++ b/noxfile.py @@ -0,0 +1,255 @@ +# GT4Py - GridTools Framework +# 
+# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import os +import pathlib +import types +from collections.abc import Sequence +from typing import Final, Literal, TypeAlias + +import nox + +#: This should just be `pytest.ExitCode.NO_TESTS_COLLECTED` but `pytest` +#: is not guaranteed to be available in the venv where `nox` is running. +NO_TESTS_COLLECTED_EXIT_CODE: Final = 5 + +# -- nox configuration -- +nox.options.default_venv_backend = "uv" +nox.options.sessions = [ + "test_cartesian-3.10(internal, cpu)", + "test_cartesian-3.10(dace, cpu)", + "test_cartesian-3.11(internal, cpu)", + "test_cartesian-3.11(dace, cpu)", + "test_eve-3.10", + "test_eve-3.11", + "test_next-3.10(internal, cpu, nomesh)", + "test_next-3.10(dace, cpu, nomesh)", + "test_next-3.11(internal, cpu, nomesh)", + "test_next-3.11(dace, cpu, nomesh)", + "test_storage-3.10(cpu)", + "test_storage-3.11(cpu)", +] + +# -- Parameter sets -- +DeviceOption: TypeAlias = Literal["cpu", "cuda11", "cuda12", "rocm4_3", "rocm5_0"] +DeviceNoxParam: Final = types.SimpleNamespace( + **{device: nox.param(device, id=device, tags=[device]) for device in DeviceOption.__args__} +) +DeviceTestSettings: Final[dict[str, dict[str, Sequence]]] = { + "cpu": {"extras": [], "markers": ["not requires_gpu"]}, + **{ + device: {"extras": [device], "markers": ["requires_gpu"]} + for device in ["cuda11", "cuda12", "rocm4_3", "rocm5_0"] + }, +} + +CodeGenOption: TypeAlias = Literal["internal", "dace"] +CodeGenNoxParam: Final = types.SimpleNamespace( + **{ + codegen: nox.param(codegen, id=codegen, tags=[codegen]) + for codegen in CodeGenOption.__args__ + } +) +CodeGenTestSettings: Final[dict[str, dict[str, Sequence]]] = { + "internal": {"extras": [], "markers": ["not requires_dace"]}, + "dace": {"extras": ["dace"], "markers": ["requires_dace"]}, +} +# Use dace-next for GT4Py-next, 
to install a different dace version than in cartesian +CodeGenNextTestSettings = CodeGenTestSettings | { + "dace": {"extras": ["dace-next"], "markers": ["requires_dace"]}, +} + + +# -- nox sessions -- +@nox.session(python=["3.10", "3.11"], tags=["cartesian"]) +@nox.parametrize("device", [DeviceNoxParam.cpu, DeviceNoxParam.cuda12]) +@nox.parametrize("codegen", [CodeGenNoxParam.internal, CodeGenNoxParam.dace]) +def test_cartesian( + session: nox.Session, + codegen: CodeGenOption, + device: DeviceOption, +) -> None: + """Run selected 'gt4py.cartesian' tests.""" + + codegen_settings = CodeGenTestSettings[codegen] + device_settings = DeviceTestSettings[device] + + _install_session_venv( + session, + extras=["performance", "testing", *codegen_settings["extras"], *device_settings["extras"]], + groups=["test"], + ) + + num_processes = os.environ.get("NUM_PROCESSES", "auto") + markers = " and ".join(codegen_settings["markers"] + device_settings["markers"]) + + session.run( + *f"pytest --cache-clear -sv -n {num_processes} --dist loadgroup".split(), + *("-m", f"{markers}"), + str(pathlib.Path("tests") / "cartesian_tests"), + *session.posargs, + ) + session.run( + *"pytest --doctest-modules --doctest-ignore-import-errors -sv".split(), + str(pathlib.Path("src") / "gt4py" / "cartesian"), + ) + + +@nox.session(python=["3.10", "3.11"]) +def test_examples(session: nox.Session) -> None: + """Run and test documentation workflows.""" + + _install_session_venv(session, extras=["testing"], groups=["docs", "test"]) + + session.run(*"jupytext docs/user/next/QuickstartGuide.md --to .ipynb".split()) + session.run(*"jupytext docs/user/next/advanced/*.md --to .ipynb".split()) + + num_processes = os.environ.get("NUM_PROCESSES", "auto") + for notebook, extra_args in [ + ("docs/user/next/workshop/slides", None), + ("docs/user/next/workshop/exercises", ["-k", "solutions"]), + ("docs/user/next/QuickstartGuide.ipynb", None), + ("docs/user/next/advanced", None), + ("examples", (None)), + ]: + 
session.run( + *f"pytest --nbmake {notebook} -sv -n {num_processes}".split(), + *(extra_args or []), + ) + + +@nox.session(python=["3.10", "3.11"], tags=["cartesian", "next", "cpu"]) +def test_eve(session: nox.Session) -> None: + """Run 'gt4py.eve' tests.""" + + _install_session_venv(session, groups=["test"]) + + num_processes = os.environ.get("NUM_PROCESSES", "auto") + + session.run( + *f"pytest --cache-clear -sv -n {num_processes}".split(), + str(pathlib.Path("tests") / "eve_tests"), + *session.posargs, + ) + session.run( + *"pytest --doctest-modules -sv".split(), + str(pathlib.Path("src") / "gt4py" / "eve"), + ) + + +@nox.session(python=["3.10", "3.11"], tags=["next"]) +@nox.parametrize( + "meshlib", + [ + nox.param("nomesh", id="nomesh", tags=["nomesh"]), + nox.param("atlas", id="atlas", tags=["atlas"]), + ], +) +@nox.parametrize("device", [DeviceNoxParam.cpu, DeviceNoxParam.cuda12]) +@nox.parametrize("codegen", [CodeGenNoxParam.internal, CodeGenNoxParam.dace]) +def test_next( + session: nox.Session, + codegen: CodeGenOption, + device: DeviceOption, + meshlib: Literal["nomesh", "atlas"], +) -> None: + """Run selected 'gt4py.next' tests.""" + + codegen_settings = CodeGenNextTestSettings[codegen] + device_settings = DeviceTestSettings[device] + groups: list[str] = ["test"] + mesh_markers: list[str] = [] + + match meshlib: + case "nomesh": + mesh_markers.append("not requires_atlas") + case "atlas": + mesh_markers.append("requires_atlas") + groups.append("frameworks") + + _install_session_venv( + session, + extras=["performance", "testing", *codegen_settings["extras"], *device_settings["extras"]], + groups=groups, + ) + + num_processes = os.environ.get("NUM_PROCESSES", "auto") + markers = " and ".join(codegen_settings["markers"] + device_settings["markers"] + mesh_markers) + + session.run( + *f"pytest --cache-clear -sv -n {num_processes}".split(), + *("-m", f"{markers}"), + str(pathlib.Path("tests") / "next_tests"), + *session.posargs, + success_codes=[0, 
NO_TESTS_COLLECTED_EXIT_CODE], + ) + session.run( + *"pytest --doctest-modules --doctest-ignore-import-errors -sv".split(), + str(pathlib.Path("src") / "gt4py" / "next"), + success_codes=[0, NO_TESTS_COLLECTED_EXIT_CODE], + ) + + +@nox.session(python=["3.10", "3.11"], tags=["cartesian", "next"]) +@nox.parametrize("device", [DeviceNoxParam.cpu, DeviceNoxParam.cuda12]) +def test_storage( + session: nox.Session, + device: DeviceOption, +) -> None: + """Run selected 'gt4py.storage' tests.""" + + device_settings = DeviceTestSettings[device] + + _install_session_venv( + session, extras=["performance", "testing", *device_settings["extras"]], groups=["test"] + ) + + num_processes = os.environ.get("NUM_PROCESSES", "auto") + markers = " and ".join(device_settings["markers"]) + + session.run( + *f"pytest --cache-clear -sv -n {num_processes}".split(), + *("-m", f"{markers}"), + str(pathlib.Path("tests") / "storage_tests"), + *session.posargs, + ) + session.run( + *"pytest --doctest-modules -sv".split(), + str(pathlib.Path("src") / "gt4py" / "storage"), + success_codes=[0, NO_TESTS_COLLECTED_EXIT_CODE], + ) + + +# -- utils -- +def _install_session_venv( + session: nox.Session, + *args: str | Sequence[str], + extras: Sequence[str] = (), + groups: Sequence[str] = (), +) -> None: + """Install session packages using uv.""" + session.run_install( + "uv", + "sync", + *("--python", session.python), + "--no-dev", + *(f"--extra={e}" for e in extras), + *(f"--group={g}" for g in groups), + env={key: value for key, value in os.environ.items()} + | {"UV_PROJECT_ENVIRONMENT": session.virtualenv.location}, + ) + for item in args: + session.run_install( + "uv", + "pip", + "install", + *((item,) if isinstance(item, str) else item), + env={"UV_PROJECT_ENVIRONMENT": session.virtualenv.location}, + ) diff --git a/pyproject.toml b/pyproject.toml index 1bb05c11c5..e182b23878 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,68 @@ +# -- Build system requirements (PEP 518) -- + 
[build-system] build-backend = 'setuptools.build_meta' -requires = ['setuptools>=65.5.0', 'wheel>=0.33.6', 'cython>=0.29.13'] +requires = ['setuptools>=70.0.0', 'wheel>=0.33.6', 'cython>=3.0.0'] + +# -- Dependency groups -- +[dependency-groups] +build = [ + 'bump-my-version>=0.16.0', + 'cython>=3.0.0', + 'pip>=22.1.1', + 'setuptools>=70.0.0', + 'wheel>=0.33.6' +] +dev = [ + {include-group = 'build'}, + {include-group = 'docs'}, + {include-group = 'frameworks'}, + {include-group = 'lint'}, + {include-group = 'test'}, + {include-group = 'typing'} +] +docs = [ + 'esbonio>=0.16.0', + 'jupytext>=1.14', + 'matplotlib>=3.8.4', + 'myst-parser>=4.0.0', + 'pygments>=2.7.3', + 'sphinx>=7.3.7', + 'sphinx-rtd-theme>=3.0.1', + 'sphinx-toolbox>=3.8.1' +] +frameworks = [ + # 3rd party frameworks with some interoperability with gt4py + 'atlas4py>=0.35' +] +lint = [ + 'pre-commit>=4.0.1', + 'ruff>=0.8.0', + 'tach>=0.16.0' +] +test = [ + 'coverage[toml]>=7.5.0', + 'hypothesis>=6.0.0', + 'nbmake>=1.4.6', + 'nox>=2024.10.9', + 'pytest>=8.0.1', + 'pytest-benchmark>=5.0.0', + 'pytest-cache>=1.0', + 'pytest-cov>=5.0.0', + 'pytest-factoryboy>=2.6.1', + 'pytest-instafail>=0.5.0', + 'pytest-xdist[psutil]>=3.5.0' +] +typing = [ + 'mypy[faster-cache]>=1.13.0', + 'types-tabulate>=0.8.10', + 'types-PyYAML>=6.0.10', + 'types-decorator>=5.1.8', + 'types-docutils>=0.21.0', + 'types-pytz>=2024.2.0' +] -# ---- Project description ---- -# -- Standard options (PEP 621) -- +# -- Standard project description options (PEP 621) -- [project] authors = [{name = 'ETH Zurich', email = 'gridtools@cscs.ch'}] classifiers = [ @@ -14,8 +73,6 @@ classifiers = [ 'License :: OSI Approved :: BSD License', 'Operating System :: POSIX', 'Programming Language :: Python', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: Implementation :: CPython', @@ -24,7 +81,6 @@ 
classifiers = [ 'Topic :: Scientific/Engineering :: Physics' ] dependencies = [ - "astunparse>=1.6.3;python_version<'3.9'", 'attrs>=21.3', 'black>=22.3', 'boltons>=20.1', @@ -34,10 +90,11 @@ dependencies = [ 'cytoolz>=0.12.1', 'deepdiff>=5.6.0', 'devtools>=0.6', + 'diskcache>=5.6.3', 'factory-boy>=3.3.0', + 'filelock>=3.16.1', 'frozendict>=2.3', - 'gridtools-cpp>=2.3.4,==2.*', - "importlib-resources>=5.0;python_version<'3.9'", + 'gridtools-cpp>=2.3.8,==2.*', 'jinja2>=3.0.0', 'lark>=1.1.2', 'mako>=1.1', @@ -46,9 +103,10 @@ dependencies = [ 'numpy>=1.23.3', 'packaging>=20.0', 'pybind11>=2.10.1', - 'setuptools>=65.5.0', + 'setuptools>=70.0.0', 'tabulate>=0.8.10', - 'typing-extensions>=4.10.0', + 'toolz>=0.12.1', + 'typing-extensions>=4.11.0', 'xxhash>=1.4.4,<3.1.0' ] description = 'Python library for generating high-performance implementations of stencil kernels for weather and climate modeling from a domain-specific language (DSL)' @@ -62,27 +120,26 @@ keywords = [ 'portable', 'hpc' ] -license = {file = 'LICENSE'} +license = {text = 'BSD-3 License'} # TODO: waiting for PEP 639 being implemented by setuptools (https://github.com/codecov/codecov-cli/issues/605) name = 'gt4py' readme = 'README.md' -requires-python = '>=3.8' +requires-python = '>=3.10, <3.12' [project.optional-dependencies] -# Bundles -all-cpu = ['gt4py[dace,formatting,jax-cpu,performance,testing]'] -all-cuda11 = ['gt4py[cuda11,dace,formatting,jax-cuda11,performance,testing]'] -all-cuda12 = ['gt4py[cuda12,dace,formatting,jax-cuda12,performance,testing]'] -# Other extras +# bundles +all = ['gt4py[dace,formatting,jax,performance,testing]'] +# device-specific extras cuda11 = ['cupy-cuda11x>=12.0'] cuda12 = ['cupy-cuda12x>=12.0'] -dace = ['dace>=0.16.1', 'sympy>=1.9,<1.13'] # see https://github.com/spcl/dace/pull/1620 +# features +dace = ['dace>=1.0.1,<1.1.0'] # v1.x will contain breaking changes, see https://github.com/spcl/dace/milestone/4 +dace-next = ['dace'] # pull dace latest version from the git 
repository formatting = ['clang-format>=9.0'] -gpu = ['cupy>=12.0'] -jax-cpu = ['jax[cpu]>=0.4.18; python_version>="3.10"'] -jax-cuda11 = ['jax[cuda11_pip]>=0.4.18; python_version>="3.10"'] -jax-cuda12 = ['jax[cuda12_pip]>=0.4.18; python_version>="3.10"'] +jax = ['jax>=0.4.26'] +jax-cuda12 = ['jax[cuda12_local]>=0.4.26', 'gt4py[cuda12]'] performance = ['scipy>=1.9.2'] -rocm-43 = ['cupy-rocm-4-3'] +rocm4_3 = ['cupy-rocm-4-3>=13.3.0'] +rocm5_0 = ['cupy-rocm-5-0>=13.3.0'] testing = ['hypothesis>=6.0.0', 'pytest>=7.0'] [project.scripts] @@ -91,7 +148,7 @@ gtpyc = 'gt4py.cartesian.cli:gtpyc' [project.urls] Documentation = 'https://gridtools.github.io/gt4py' Homepage = 'https://gridtools.github.io/' -Source = 'https://github.com/GridTools/gt4py' +Repository = 'https://github.com/GridTools/gt4py' # ---- Other tools ---- # -- bump-my-version -- @@ -99,7 +156,7 @@ Source = 'https://github.com/GridTools/gt4py' allow_dirty = false commit = false commit_args = '' -current_version = "1.0.4" +current_version = '1.0.4' ignore_missing_version = false message = 'Bump version: {current_version} → {new_version}' parse = '(?P\d+)\.(?P\d+)(\.(?P\d+))?' 
@@ -113,7 +170,7 @@ tag_message = 'Bump version: {current_version} → {new_version}' tag_name = 'v{new_version}' [[tool.bumpversion.files]] -filename = "src/gt4py/__about__.py" +filename = 'src/gt4py/__about__.py' # -- coverage -- [tool.coverage] @@ -122,7 +179,7 @@ filename = "src/gt4py/__about__.py" directory = 'tests/_reports/coverage_html' [tool.coverage.paths] -source = ['src/', '.tox/py*/lib/python3.*/site-packages/'] +source = ['src/', '.nox/py*/lib/python3.*/site-packages/'] [tool.coverage.report] # Regexes for lines to exclude from consideration @@ -174,7 +231,6 @@ allow_incomplete_defs = true allow_untyped_defs = true follow_imports = 'silent' module = 'gt4py.cartesian.*' -warn_unused_ignores = false [[tool.mypy.overrides]] ignore_errors = true @@ -238,20 +294,24 @@ markers = [ 'requires_atlas: tests that require `atlas4py` bindings package', 'requires_dace: tests that require `dace` package', 'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)', - 'starts_from_gtir_program: tests that require backend to start lowering from GTIR program', 'uses_applied_shifts: tests that require backend support for applied-shifts', + 'uses_can_deref: tests that require backend support for can_deref builtin function', + 'uses_composite_shifts: tests that use composite shifts in unstructured domain', 'uses_constant_fields: tests that require backend support for constant fields', 'uses_dynamic_offsets: tests that require backend support for dynamic offsets', 'uses_floordiv: tests that require backend support for floor division', 'uses_if_stmts: tests that require backend support for if-statements', 'uses_index_fields: tests that require backend support for index fields', - 'uses_lift_expressions: tests that require backend support for lift expressions', + 'uses_ir_if_stmts', + 'uses_lift: tests that require backend support for lift builtin function', 'uses_negative_modulo: tests that require backend support for modulo on negative numbers', 
'uses_origin: tests that require backend support for domain origin', - 'uses_reduction_over_lift_expressions: tests that require backend support for reduction over lift expressions', + 'uses_reduce_with_lambda: tests that use lambdas as reduce functions', 'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields', + 'uses_scalar_in_domain_and_fo', 'uses_scan: tests that uses scan', 'uses_scan_in_field_operator: tests that require backend support for scan in field operator', + 'uses_scan_in_stencil: tests that require backend support for scan in stencil', 'uses_scan_without_field_args: tests that require calls to scan that do not have any fields as arguments', 'uses_scan_nested: tests that use nested scans', 'uses_scan_requiring_projector: tests need a projector implementation in gtfn', @@ -259,6 +319,8 @@ markers = [ 'uses_sparse_fields_as_output: tests that require backend support for writing sparse fields', 'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset', 'uses_tuple_args: tests that require backend support for tuple arguments', + 'uses_tuple_args_with_different_but_promotable_dims: test that requires backend support for tuple args with different but promotable dims', + 'uses_tuple_iterator: tests that require backend support to deref tuple iterators', 'uses_tuple_returns: tests that require backend support for tuple results', 'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields', 'uses_cartesian_shift: tests that use a Cartesian connectivity', @@ -269,6 +331,7 @@ markers = [ ] norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*'] testpaths = 'tests' +xfail_strict = true # -- ruff -- [tool.ruff] @@ -276,7 +339,7 @@ line-length = 100 # It should be the same as in `tool.black.line-length` above respect-gitignore = true show-fixes = true # show-source = true -target-version = 'py38' +target-version = 'py310' 
[tool.ruff.format] docstring-code-format = true @@ -293,12 +356,16 @@ docstring-code-format = true # NPY: NumPy-specific rules # RUF: Ruff-specific rules ignore = [ - 'E501' # [line-too-long] + 'E501', # [line-too-long] + 'B905' # [zip-without-explicit-strict] # TODO(egparedes): Reevaluate this rule ] select = ['E', 'F', 'I', 'B', 'A', 'T10', 'ERA', 'NPY', 'RUF'] typing-modules = ['gt4py.eve.extended_typing'] unfixable = [] +[tool.ruff.lint.flake8-builtins] +builtins-allowed-modules = ['builtins'] + [tool.ruff.lint.isort] combine-as-imports = true # force-wrap-aliases = true @@ -364,3 +431,31 @@ version = {attr = 'gt4py.__about__.__version__'} [tool.setuptools.packages] find = {namespaces = false, where = ['src']} + +# -- uv: packages & workspace -- +[tool.uv] +conflicts = [ + [ + {extra = 'cuda11'}, + {extra = 'jax-cuda12'}, + {extra = 'rocm4_3'}, + {extra = 'rocm5_0'} + ], + [ + {extra = 'dace'}, + {extra = 'dace-next'} + ], + [ + {extra = 'all'}, + {extra = 'dace-next'} + ] +] + +[[tool.uv.index]] +explicit = true +name = 'test.pypi' +url = 'https://test.pypi.org/simple/' + +[tool.uv.sources] +atlas4py = {index = "test.pypi"} +dace = {git = "https://github.com/spcl/dace", branch = "main", extra = "dace-next"} diff --git a/requirements-dev.in b/requirements-dev.in deleted file mode 100644 index 1697051d25..0000000000 --- a/requirements-dev.in +++ /dev/null @@ -1,36 +0,0 @@ -# -# Constraints should specify the minimum required version (>=). -# -# Packages also required in the extra `gt4py['all-cpu']` configuration -# should be added here without constraints, so they will use the -# constraints defined in `pyproject.toml`. 
-# -bump-my-version>=0.12.0 -clang-format>=9.0 -cogapp>=3.3 -coverage[toml]>=5.0 -darglint>=1.6 -hypothesis # constraints in gt4py['testing'] -jupytext>=1.14 -mypy>=1.0 -matplotlib>=3.3 -nbmake>=1.4.6 -pipdeptree>=2.3 -pip-tools>=6.10 -pre-commit>=2.17 -psutil>=5.0 -pygments>=2.7.3 -pytest # constraints in gt4py['testing'] -pytest-cache>=1.0 -pytest-cov>=2.8 -pytest-custom-exit-code>=0.3.0 -pytest-factoryboy>=2.0.3 -pytest-xdist[psutil]>=2.4 -pytest-instafail>=0.5.0 -ruff>=0.2.0 -sphinx>=4.4 -sphinx_rtd_theme>=1.0 -tach>=0.10.7 -tomli>=2.0.1;python_version<'3.11' -tox>=3.2.0 -types-tabulate>=0.8.10 diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index 0b7baec1bc..0000000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,182 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.8 -# by the following command: -# -# "tox run -e requirements-base" -# -aenum==3.1.15 # via -c constraints.txt, dace -alabaster==0.7.13 # via -c constraints.txt, sphinx -annotated-types==0.7.0 # via -c constraints.txt, pydantic -asttokens==2.4.1 # via -c constraints.txt, devtools, stack-data -astunparse==1.6.3 ; python_version < "3.9" # via -c constraints.txt, dace, gt4py (pyproject.toml) -attrs==24.2.0 # via -c constraints.txt, gt4py (pyproject.toml), hypothesis, jsonschema, referencing -babel==2.16.0 # via -c constraints.txt, sphinx -backcall==0.2.0 # via -c constraints.txt, ipython -black==24.8.0 # via -c constraints.txt, gt4py (pyproject.toml) -boltons==24.0.0 # via -c constraints.txt, gt4py (pyproject.toml) -bracex==2.5 # via -c constraints.txt, wcmatch -build==1.2.2 # via -c constraints.txt, pip-tools -bump-my-version==0.26.0 # via -c constraints.txt, -r requirements-dev.in -cached-property==1.5.2 # via -c constraints.txt, gt4py (pyproject.toml) -cachetools==5.5.0 # via -c constraints.txt, tox -certifi==2024.8.30 # via -c constraints.txt, requests -cfgv==3.4.0 # via -c constraints.txt, pre-commit -chardet==5.2.0 # via -c constraints.txt, 
tox -charset-normalizer==3.3.2 # via -c constraints.txt, requests -clang-format==18.1.8 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) -click==8.1.7 # via -c constraints.txt, black, bump-my-version, gt4py (pyproject.toml), pip-tools, rich-click -cmake==3.30.3 # via -c constraints.txt, gt4py (pyproject.toml) -cogapp==3.4.1 # via -c constraints.txt, -r requirements-dev.in -colorama==0.4.6 # via -c constraints.txt, tox -comm==0.2.2 # via -c constraints.txt, ipykernel -contourpy==1.1.1 # via -c constraints.txt, matplotlib -coverage[toml]==7.6.1 # via -c constraints.txt, -r requirements-dev.in, pytest-cov -cycler==0.12.1 # via -c constraints.txt, matplotlib -cytoolz==0.12.3 # via -c constraints.txt, gt4py (pyproject.toml) -dace==0.16.1 # via -c constraints.txt, gt4py (pyproject.toml) -darglint==1.8.1 # via -c constraints.txt, -r requirements-dev.in -debugpy==1.8.5 # via -c constraints.txt, ipykernel -decorator==5.1.1 # via -c constraints.txt, ipython -deepdiff==8.0.1 # via -c constraints.txt, gt4py (pyproject.toml) -devtools==0.12.2 # via -c constraints.txt, gt4py (pyproject.toml) -dill==0.3.8 # via -c constraints.txt, dace -distlib==0.3.8 # via -c constraints.txt, virtualenv -docutils==0.20.1 # via -c constraints.txt, sphinx, sphinx-rtd-theme -eval-type-backport==0.2.0 # via -c constraints.txt, tach -exceptiongroup==1.2.2 # via -c constraints.txt, hypothesis, pytest -execnet==2.1.1 # via -c constraints.txt, pytest-cache, pytest-xdist -executing==2.1.0 # via -c constraints.txt, devtools, stack-data -factory-boy==3.3.1 # via -c constraints.txt, gt4py (pyproject.toml), pytest-factoryboy -faker==28.4.1 # via -c constraints.txt, factory-boy -fastjsonschema==2.20.0 # via -c constraints.txt, nbformat -filelock==3.16.0 # via -c constraints.txt, tox, virtualenv -fonttools==4.53.1 # via -c constraints.txt, matplotlib -fparser==0.1.4 # via -c constraints.txt, dace -frozendict==2.4.4 # via -c constraints.txt, gt4py (pyproject.toml) -gitdb==4.0.11 # via -c 
constraints.txt, gitpython -gitpython==3.1.43 # via -c constraints.txt, tach -gridtools-cpp==2.3.4 # via -c constraints.txt, gt4py (pyproject.toml) -hypothesis==6.112.0 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.6.0 # via -c constraints.txt, pre-commit -idna==3.8 # via -c constraints.txt, requests -imagesize==1.4.1 # via -c constraints.txt, sphinx -importlib-metadata==8.5.0 # via -c constraints.txt, build, jupyter-client, sphinx -importlib-resources==6.4.5 ; python_version < "3.9" # via -c constraints.txt, gt4py (pyproject.toml), jsonschema, jsonschema-specifications, matplotlib -inflection==0.5.1 # via -c constraints.txt, pytest-factoryboy -iniconfig==2.0.0 # via -c constraints.txt, pytest -ipykernel==6.29.5 # via -c constraints.txt, nbmake -ipython==8.12.3 # via -c constraints.txt, ipykernel -jedi==0.19.1 # via -c constraints.txt, ipython -jinja2==3.1.4 # via -c constraints.txt, dace, gt4py (pyproject.toml), sphinx -jsonschema==4.23.0 # via -c constraints.txt, nbformat -jsonschema-specifications==2023.12.1 # via -c constraints.txt, jsonschema -jupyter-client==8.6.2 # via -c constraints.txt, ipykernel, nbclient -jupyter-core==5.7.2 # via -c constraints.txt, ipykernel, jupyter-client, nbformat -jupytext==1.16.4 # via -c constraints.txt, -r requirements-dev.in -kiwisolver==1.4.7 # via -c constraints.txt, matplotlib -lark==1.2.2 # via -c constraints.txt, gt4py (pyproject.toml) -mako==1.3.5 # via -c constraints.txt, gt4py (pyproject.toml) -markdown-it-py==3.0.0 # via -c constraints.txt, jupytext, mdit-py-plugins, rich -markupsafe==2.1.5 # via -c constraints.txt, jinja2, mako -matplotlib==3.7.5 # via -c constraints.txt, -r requirements-dev.in -matplotlib-inline==0.1.7 # via -c constraints.txt, ipykernel, ipython -mdit-py-plugins==0.4.2 # via -c constraints.txt, jupytext -mdurl==0.1.2 # via -c constraints.txt, markdown-it-py -mpmath==1.3.0 # via -c constraints.txt, sympy -mypy==1.11.2 # via -c constraints.txt, -r 
requirements-dev.in -mypy-extensions==1.0.0 # via -c constraints.txt, black, mypy -nanobind==2.1.0 # via -c constraints.txt, gt4py (pyproject.toml) -nbclient==0.6.8 # via -c constraints.txt, nbmake -nbformat==5.10.4 # via -c constraints.txt, jupytext, nbclient, nbmake -nbmake==1.5.4 # via -c constraints.txt, -r requirements-dev.in -nest-asyncio==1.6.0 # via -c constraints.txt, ipykernel, nbclient -networkx==3.1 # via -c constraints.txt, dace, tach -ninja==1.11.1.1 # via -c constraints.txt, gt4py (pyproject.toml) -nodeenv==1.9.1 # via -c constraints.txt, pre-commit -numpy==1.24.4 # via -c constraints.txt, contourpy, dace, gt4py (pyproject.toml), matplotlib -orderly-set==5.2.2 # via -c constraints.txt, deepdiff -packaging==24.1 # via -c constraints.txt, black, build, gt4py (pyproject.toml), ipykernel, jupytext, matplotlib, pipdeptree, pyproject-api, pytest, pytest-factoryboy, setuptools-scm, sphinx, tox -parso==0.8.4 # via -c constraints.txt, jedi -pathspec==0.12.1 # via -c constraints.txt, black -pexpect==4.9.0 # via -c constraints.txt, ipython -pickleshare==0.7.5 # via -c constraints.txt, ipython -pillow==10.4.0 # via -c constraints.txt, matplotlib -pip-tools==7.4.1 # via -c constraints.txt, -r requirements-dev.in -pipdeptree==2.23.3 # via -c constraints.txt, -r requirements-dev.in -pkgutil-resolve-name==1.3.10 # via -c constraints.txt, jsonschema -platformdirs==4.3.2 # via -c constraints.txt, black, jupyter-core, tox, virtualenv -pluggy==1.5.0 # via -c constraints.txt, pytest, tox -ply==3.11 # via -c constraints.txt, dace -pre-commit==3.5.0 # via -c constraints.txt, -r requirements-dev.in -prompt-toolkit==3.0.36 # via -c constraints.txt, ipython, questionary, tach -psutil==6.0.0 # via -c constraints.txt, -r requirements-dev.in, ipykernel, pytest-xdist -ptyprocess==0.7.0 # via -c constraints.txt, pexpect -pure-eval==0.2.3 # via -c constraints.txt, stack-data -pybind11==2.13.5 # via -c constraints.txt, gt4py (pyproject.toml) -pydantic==2.9.1 # via -c 
constraints.txt, bump-my-version, pydantic-settings, tach -pydantic-core==2.23.3 # via -c constraints.txt, pydantic -pydantic-settings==2.5.2 # via -c constraints.txt, bump-my-version -pydot==2.0.0 # via -c constraints.txt, tach -pygments==2.18.0 # via -c constraints.txt, -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx -pyparsing==3.1.4 # via -c constraints.txt, matplotlib, pydot -pyproject-api==1.7.1 # via -c constraints.txt, tox -pyproject-hooks==1.1.0 # via -c constraints.txt, build, pip-tools -pytest==8.3.3 # via -c constraints.txt, -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-custom-exit-code, pytest-factoryboy, pytest-instafail, pytest-xdist -pytest-cache==1.0 # via -c constraints.txt, -r requirements-dev.in -pytest-cov==5.0.0 # via -c constraints.txt, -r requirements-dev.in -pytest-custom-exit-code==0.3.0 # via -c constraints.txt, -r requirements-dev.in -pytest-factoryboy==2.7.0 # via -c constraints.txt, -r requirements-dev.in -pytest-instafail==0.5.0 # via -c constraints.txt, -r requirements-dev.in -pytest-xdist[psutil]==3.6.1 # via -c constraints.txt, -r requirements-dev.in -python-dateutil==2.9.0.post0 # via -c constraints.txt, faker, jupyter-client, matplotlib -python-dotenv==1.0.1 # via -c constraints.txt, pydantic-settings -pytz==2024.2 # via -c constraints.txt, babel -pyyaml==6.0.2 # via -c constraints.txt, dace, jupytext, pre-commit, tach -pyzmq==26.2.0 # via -c constraints.txt, ipykernel, jupyter-client -questionary==2.0.1 # via -c constraints.txt, bump-my-version -referencing==0.35.1 # via -c constraints.txt, jsonschema, jsonschema-specifications -requests==2.32.3 # via -c constraints.txt, sphinx -rich==13.8.1 # via -c constraints.txt, bump-my-version, rich-click, tach -rich-click==1.8.3 # via -c constraints.txt, bump-my-version -rpds-py==0.20.0 # via -c constraints.txt, jsonschema, referencing -ruff==0.6.4 # via -c constraints.txt, -r requirements-dev.in -setuptools-scm==8.1.0 # via -c 
constraints.txt, fparser -six==1.16.0 # via -c constraints.txt, asttokens, astunparse, python-dateutil -smmap==5.0.1 # via -c constraints.txt, gitdb -snowballstemmer==2.2.0 # via -c constraints.txt, sphinx -sortedcontainers==2.4.0 # via -c constraints.txt, hypothesis -sphinx==7.1.2 # via -c constraints.txt, -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==2.0.0 # via -c constraints.txt, -r requirements-dev.in -sphinxcontrib-applehelp==1.0.4 # via -c constraints.txt, sphinx -sphinxcontrib-devhelp==1.0.2 # via -c constraints.txt, sphinx -sphinxcontrib-htmlhelp==2.0.1 # via -c constraints.txt, sphinx -sphinxcontrib-jquery==4.1 # via -c constraints.txt, sphinx-rtd-theme -sphinxcontrib-jsmath==1.0.1 # via -c constraints.txt, sphinx -sphinxcontrib-qthelp==1.0.3 # via -c constraints.txt, sphinx -sphinxcontrib-serializinghtml==1.1.5 # via -c constraints.txt, sphinx -stack-data==0.6.3 # via -c constraints.txt, ipython -stdlib-list==0.10.0 # via -c constraints.txt, tach -sympy==1.12.1 # via -c constraints.txt, dace, gt4py (pyproject.toml) -tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) -tach==0.10.7 # via -c constraints.txt, -r requirements-dev.in -tomli==2.0.1 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tox -tomli-w==1.0.0 # via -c constraints.txt, tach -tomlkit==0.13.2 # via -c constraints.txt, bump-my-version -toolz==0.12.1 # via -c constraints.txt, cytoolz -tornado==6.4.1 # via -c constraints.txt, ipykernel, jupyter-client -tox==4.18.1 # via -c constraints.txt, -r requirements-dev.in -traitlets==5.14.3 # via -c constraints.txt, comm, ipykernel, ipython, jupyter-client, jupyter-core, matplotlib-inline, nbclient, nbformat -types-tabulate==0.9.0.20240106 # via -c constraints.txt, -r requirements-dev.in -typing-extensions==4.12.2 # via -c constraints.txt, annotated-types, black, gt4py (pyproject.toml), 
ipython, mypy, pydantic, pydantic-core, pytest-factoryboy, rich, rich-click, setuptools-scm -urllib3==2.2.3 # via -c constraints.txt, requests -virtualenv==20.26.4 # via -c constraints.txt, pre-commit, tox -wcmatch==9.0 # via -c constraints.txt, bump-my-version -wcwidth==0.2.13 # via -c constraints.txt, prompt-toolkit -websockets==13.0.1 # via -c constraints.txt, dace -wheel==0.44.0 # via -c constraints.txt, astunparse, pip-tools -xxhash==3.0.0 # via -c constraints.txt, gt4py (pyproject.toml) -zipp==3.20.1 # via -c constraints.txt, importlib-metadata, importlib-resources - -# The following packages are considered to be unsafe in a requirements file: -pip==24.2 # via -c constraints.txt, pip-tools, pipdeptree -setuptools==74.1.2 # via -c constraints.txt, gt4py (pyproject.toml), pip-tools, setuptools-scm diff --git a/src/gt4py/__init__.py b/src/gt4py/__init__.py index 1b88285475..c0bf9580b3 100644 --- a/src/gt4py/__init__.py +++ b/src/gt4py/__init__.py @@ -27,6 +27,6 @@ if _sys.version_info >= (3, 10): - from . import next + from . import next # noqa: A004 shadowing a Python builtin __all__ += ["next"] diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 9d07b2eb79..41a592c3d4 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -39,15 +39,21 @@ ) -if TYPE_CHECKING: +try: import cupy as cp +except ImportError: + cp = None +if TYPE_CHECKING: CuPyNDArray: TypeAlias = cp.ndarray import jax.numpy as jnp JaxNDArray: TypeAlias = jnp.ndarray +# The actual assignment happens after the definition of `DeviceType` enum. 
+CUPY_DEVICE_TYPE: Literal[None, DeviceType.CUDA, DeviceType.ROCM] +"""Type of the GPU accelerator device, if present.""" # -- Scalar types supported by GT4Py -- bool_ = np.bool_ @@ -373,41 +379,34 @@ class DeviceType(enum.IntEnum): CPU = 1 CUDA = 2 - CPU_PINNED = 3 - OPENCL = 4 - VULKAN = 7 - METAL = 8 - VPI = 9 + # CPU_PINNED = 3 # noqa: ERA001 + # OPENCL = 4 # noqa: ERA001 + # VULKAN = 7 # noqa: ERA001 + # METAL = 8 # noqa: ERA001 + # VPI = 9 # noqa: ERA001 ROCM = 10 - CUDA_MANAGED = 13 - ONE_API = 14 + # CUDA_MANAGED = 13 # noqa: ERA001 + # ONE_API = 14 # noqa: ERA001 CPUDeviceTyping: TypeAlias = Literal[DeviceType.CPU] CUDADeviceTyping: TypeAlias = Literal[DeviceType.CUDA] -CPUPinnedDeviceTyping: TypeAlias = Literal[DeviceType.CPU_PINNED] -OpenCLDeviceTyping: TypeAlias = Literal[DeviceType.OPENCL] -VulkanDeviceTyping: TypeAlias = Literal[DeviceType.VULKAN] -MetalDeviceTyping: TypeAlias = Literal[DeviceType.METAL] -VPIDeviceTyping: TypeAlias = Literal[DeviceType.VPI] ROCMDeviceTyping: TypeAlias = Literal[DeviceType.ROCM] -CUDAManagedDeviceTyping: TypeAlias = Literal[DeviceType.CUDA_MANAGED] -OneApiDeviceTyping: TypeAlias = Literal[DeviceType.ONE_API] DeviceTypeT = TypeVar( "DeviceTypeT", CPUDeviceTyping, CUDADeviceTyping, - CPUPinnedDeviceTyping, - OpenCLDeviceTyping, - VulkanDeviceTyping, - MetalDeviceTyping, - VPIDeviceTyping, ROCMDeviceTyping, ) +CUPY_DEVICE_TYPE = ( + None if not cp else (DeviceType.ROCM if cp.cuda.runtime.is_hip else DeviceType.CUDA) +) + + @dataclasses.dataclass(frozen=True) class Device(Generic[DeviceTypeT]): """ @@ -439,13 +438,21 @@ def ndim(self) -> int: ... @property def shape(self) -> tuple[int, ...]: ... + @property + def strides(self) -> tuple[int, ...]: ... + @property def dtype(self) -> Any: ... + @property + def itemsize(self) -> int: ... + def item(self) -> Any: ... def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: ... + def any(self) -> bool: ... + def __getitem__(self, item: Any) -> NDArrayObject: ... 
def __abs__(self) -> NDArrayObject: ... @@ -496,4 +503,4 @@ def __and__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __xor(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... + def __xor__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... diff --git a/src/gt4py/cartesian/__init__.py b/src/gt4py/cartesian/__init__.py index c03ef15105..90df315d5c 100644 --- a/src/gt4py/cartesian/__init__.py +++ b/src/gt4py/cartesian/__init__.py @@ -27,7 +27,7 @@ __all__ = [ - "typing", + "StencilObject", "caching", "cli", "config", @@ -39,5 +39,5 @@ "stencil_builder", "stencil_object", "type_hints", - "StencilObject", + "typing", ] diff --git a/src/gt4py/cartesian/backend/__init__.py b/src/gt4py/cartesian/backend/__init__.py index 7a6f877295..4296e3b389 100644 --- a/src/gt4py/cartesian/backend/__init__.py +++ b/src/gt4py/cartesian/backend/__init__.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from warnings import warn + from .base import ( REGISTRY, Backend, @@ -16,13 +18,6 @@ from_name, register, ) - - -try: - from .dace_backend import DaceCPUBackend, DaceGPUBackend -except ImportError: - pass - from .cuda_backend import CudaBackend from .gtcpp_backend import GTCpuIfirstBackend, GTCpuKfirstBackend, GTGpuBackend from .module_generator import BaseModuleGenerator @@ -37,9 +32,9 @@ "BasePyExtBackend", "CLIBackendMixin", "CudaBackend", - "GTGpuBackend", "GTCpuIfirstBackend", "GTCpuKfirstBackend", + "GTGpuBackend", "NumpyBackend", "PurePythonBackendCLIMixin", "from_name", @@ -47,5 +42,12 @@ ] -if "DaceCPUBackend" in globals(): +try: + from .dace_backend import DaceCPUBackend, DaceGPUBackend + __all__ += ["DaceCPUBackend", "DaceGPUBackend"] +except ImportError: + warn( + "GT4Py was unable to load DaCe. 
DaCe backends (`dace:cpu` and `dace:gpu`) will not be available.", + stacklevel=2, + ) diff --git a/src/gt4py/cartesian/backend/base.py b/src/gt4py/cartesian/backend/base.py index 5bab0453a9..571f86b527 100644 --- a/src/gt4py/cartesian/backend/base.py +++ b/src/gt4py/cartesian/backend/base.py @@ -172,9 +172,9 @@ def generate_computation(self) -> Dict[str, Union[str, Dict]]: Returns ------- Dict[str, str | Dict] of source file names / directories -> contents: - If a key's value is a string it is interpreted as a file name and the value as the - source code of that file - If a key's value is a Dict, it is interpreted as a directory name and it's + If a key's value is a string, it is interpreted as a file name and its value as the + source code of that file. + If a key's value is a Dict, it is interpreted as a directory name and its value as a nested file hierarchy to which the same rules are applied recursively. The root path is relative to the build directory. @@ -222,7 +222,7 @@ def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]: Returns ------- - Analog to :py:meth:`generate_computation` but containing bindings source code, The + Analog to :py:meth:`generate_computation` but containing bindings source code. The dictionary contains a tree of directories with leaves being a mapping from filename to source code pairs, relative to the build directory. 
diff --git a/src/gt4py/cartesian/backend/cuda_backend.py b/src/gt4py/cartesian/backend/cuda_backend.py index f0238e309b..9646383c0f 100644 --- a/src/gt4py/cartesian/backend/cuda_backend.py +++ b/src/gt4py/cartesian/backend/cuda_backend.py @@ -136,12 +136,12 @@ class CudaBackend(BaseGTBackend, CLIBackendMixin): } languages = {"computation": "cuda", "bindings": ["python"]} storage_info = gt_storage.layout.CUDALayout - PYEXT_GENERATOR_CLASS = CudaExtGenerator # type: ignore + PYEXT_GENERATOR_CLASS = CudaExtGenerator MODULE_GENERATOR_CLASS = CUDAPyExtModuleGenerator GT_BACKEND_T = "gpu" def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: - return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=True) + return self.make_extension(uses_cuda=True) def generate(self) -> Type[StencilObject]: self.check_options(self.builder.options) diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index f49895a435..81775ade1e 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -12,7 +12,6 @@ import os import pathlib import re -import textwrap from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import dace @@ -20,6 +19,7 @@ from dace.sdfg.utils import inline_sdfgs from gt4py import storage as gt_storage +from gt4py._core import definitions as core_defs from gt4py.cartesian import config as gt_config from gt4py.cartesian.backend.base import CLIBackendMixin, register from gt4py.cartesian.backend.gtc_common import ( @@ -32,10 +32,10 @@ ) from gt4py.cartesian.backend.module_generator import make_args_data_from_gtir from gt4py.cartesian.gtc import common, gtir +from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.nodes import StencilComputation from gt4py.cartesian.gtc.dace.oir_to_dace import OirSDFGBuilder from gt4py.cartesian.gtc.dace.transformations import ( - InlineThreadLocalTransients, 
NoEmptyEdgeTrivialMapElimination, nest_sequential_map_scopes, ) @@ -56,17 +56,17 @@ def _specialize_transient_strides(sdfg: dace.SDFG, layout_map): - repldict = replace_strides( + replacement_dictionary = replace_strides( [array for array in sdfg.arrays.values() if array.transient], layout_map ) - sdfg.replace_dict(repldict) + sdfg.replace_dict(replacement_dictionary) for state in sdfg.nodes(): for node in state.nodes(): if isinstance(node, dace.nodes.NestedSDFG): - for k, v in repldict.items(): + for k, v in replacement_dictionary.items(): if k in node.symbol_mapping: node.symbol_mapping[k] = v - for k in repldict.keys(): + for k in replacement_dictionary.keys(): if k in sdfg.symbols: sdfg.remove_symbol(k) @@ -120,8 +120,6 @@ def _set_expansion_orders(sdfg: dace.SDFG): def _set_tile_sizes(sdfg: dace.SDFG): - import gt4py.cartesian.gtc.dace.daceir as dcir # avoid circular import - for node, _ in filter( lambda n: isinstance(n[0], StencilComputation), sdfg.all_nodes_recursive() ): @@ -143,7 +141,7 @@ def _to_device(sdfg: dace.SDFG, device: str) -> None: node.device = dace.DeviceType.GPU -def _pre_expand_trafos(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map): +def _pre_expand_transformations(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map): args_data = make_args_data_from_gtir(gtir_pipeline) # stencils without effect @@ -152,10 +150,6 @@ def _pre_expand_trafos(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map) sdfg.add_state(gtir_pipeline.gtir.name) return sdfg - for array in sdfg.arrays.values(): - if array.transient: - array.lifetime = dace.AllocationLifetime.Persistent - sdfg.simplify(validate=False) _set_expansion_orders(sdfg) @@ -164,7 +158,7 @@ def _pre_expand_trafos(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map) return sdfg -def _post_expand_trafos(sdfg: dace.SDFG): +def _post_expand_transformations(sdfg: dace.SDFG): # DaCe "standard" clean-up transformations sdfg.simplify(validate=False) @@ -179,7 +173,8 @@ def 
_post_expand_trafos(sdfg: dace.SDFG): if node.schedule == dace.ScheduleType.CPU_Multicore and len(node.range) <= 1: node.schedule = dace.ScheduleType.Sequential - sdfg.apply_transformations_repeated(InlineThreadLocalTransients, validate=False) + # To be re-evaluated with https://github.com/GridTools/gt4py/issues/1896 + # sdfg.apply_transformations_repeated(InlineThreadLocalTransients, validate=False) # noqa: ERA001 sdfg.simplify(validate=False) nest_sequential_map_scopes(sdfg) for sd in sdfg.all_sdfgs_recursive(): @@ -355,7 +350,7 @@ def _unexpanded_sdfg(self): sdfg = OirSDFGBuilder().visit(oir_node) _to_device(sdfg, self.builder.backend.storage_info["device"]) - _pre_expand_trafos( + _pre_expand_transformations( self.builder.gtir_pipeline, sdfg, self.builder.backend.storage_info["layout_map"], @@ -371,7 +366,7 @@ def unexpanded_sdfg(self): def _expanded_sdfg(self): sdfg = self._unexpanded_sdfg() sdfg.expand_library_nodes() - _post_expand_trafos(sdfg) + _post_expand_transformations(sdfg) return sdfg def expanded_sdfg(self): @@ -432,20 +427,20 @@ def __call__(self, stencil_ir: gtir.Stencil) -> Dict[str, Dict[str, str]]: class DaCeComputationCodegen: template = as_mako( - """ - auto ${name}(const std::array& domain) { - return [domain](${",".join(functor_args)}) { - const int __I = domain[0]; - const int __J = domain[1]; - const int __K = domain[2]; - ${name}${state_suffix} dace_handle; - ${backend_specifics} - auto allocator = gt::sid::cached_allocator(&${allocator}); - ${"\\n".join(tmp_allocs)} - __program_${name}(${",".join(["&dace_handle", *dace_args])}); - }; - } - """ + """\ +auto ${name}(const std::array& domain) { + return [domain](${",".join(functor_args)}) { + const int __I = domain[0]; + const int __J = domain[1]; + const int __K = domain[2]; + ${name}${state_suffix} dace_handle; + ${backend_specifics} + auto allocator = gt::sid::cached_allocator(&${allocator}); + ${"\\n".join(tmp_allocs)} + __program_${name}(${",".join(["&dace_handle", *dace_args])}); + 
}; +} +""" ) def generate_tmp_allocs(self, sdfg): @@ -511,7 +506,7 @@ def _postprocess_dace_code(code_objects, is_gpu, builder): lines = lines[0:i] + cuda_code.split("\n") + lines[i + 1 :] break - def keep_line(line): + def keep_line(line: str) -> bool: line = line.strip() if line == '#include "../../include/hash.h"': return False @@ -521,11 +516,7 @@ def keep_line(line): return False return True - lines = filter(keep_line, lines) - generated_code = "\n".join(lines) - if builder.options.format_source: - generated_code = codegen.format_source("cpp", generated_code, style="LLVM") - return generated_code + return "\n".join(filter(keep_line, lines)) @classmethod def apply(cls, stencil_ir: gtir.Stencil, builder: StencilBuilder, sdfg: dace.SDFG): @@ -533,7 +524,7 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: StencilBuilder, sdfg: dace.SDF with dace.config.temporary_config(): # To prevent conflict with 3rd party usage of DaCe config always make sure that any # changes be under the temporary_config manager - if gt_config.GT4PY_USE_HIP: + if core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.ROCM: dace.config.Config.set("compiler", "cuda", "backend", value="hip") dace.config.Config.set("compiler", "cuda", "max_concurrent_streams", value=-1) dace.config.Config.set( @@ -563,17 +554,18 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: StencilBuilder, sdfg: dace.SDF allocator="gt::cuda_util::cuda_malloc" if is_gpu else "std::make_unique", state_suffix=dace.Config.get("compiler.codegen_state_struct_suffix"), ) - generated_code = textwrap.dedent( - f"""#include - #include - #include - {"#include " if is_gpu else omp_header} - namespace gt = gridtools; - {computations} - - {interface} - """ - ) + generated_code = f"""\ +#include +#include +#include +{"#include " if is_gpu else omp_header} +namespace gt = gridtools; + +{computations} + +{interface} +""" + if builder.options.format_source: generated_code = codegen.format_source("cpp", generated_code, style="LLVM") @@ 
-760,7 +752,7 @@ class DaCeCUDAPyExtModuleGenerator(DaCePyExtModuleGenerator, CUDAPyExtModuleGene class BaseDaceBackend(BaseGTBackend, CLIBackendMixin): GT_BACKEND_T = "dace" - PYEXT_GENERATOR_CLASS = DaCeExtGenerator # type: ignore + PYEXT_GENERATOR_CLASS = DaCeExtGenerator def generate(self) -> Type[StencilObject]: self.check_options(self.builder.options) @@ -794,7 +786,7 @@ class DaceCPUBackend(BaseDaceBackend): options = BaseGTBackend.GT_BACKEND_OPTS def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: - return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=False) + return self.make_extension(uses_cuda=False) @register @@ -815,4 +807,4 @@ class DaceGPUBackend(BaseDaceBackend): options = {**BaseGTBackend.GT_BACKEND_OPTS, "device_sync": {"versioning": True, "type": bool}} def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: - return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=True) + return self.make_extension(uses_cuda=True) diff --git a/src/gt4py/cartesian/backend/gtc_common.py b/src/gt4py/cartesian/backend/gtc_common.py index abc4baede1..348e85de92 100644 --- a/src/gt4py/cartesian/backend/gtc_common.py +++ b/src/gt4py/cartesian/backend/gtc_common.py @@ -236,19 +236,15 @@ def generate(self) -> Type[StencilObject]: def generate_computation(self) -> Dict[str, Union[str, Dict]]: dir_name = f"{self.builder.options.name}_src" - src_files = self.make_extension_sources(stencil_ir=self.builder.gtir) + src_files = self._make_extension_sources() return {dir_name: src_files["computation"]} - def generate_bindings( - self, language_name: str, *, stencil_ir: Optional[gtir.Stencil] = None - ) -> Dict[str, Union[str, Dict]]: - if not stencil_ir: - stencil_ir = self.builder.gtir - assert stencil_ir is not None + def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]: if language_name != "python": return super().generate_bindings(language_name) + dir_name = f"{self.builder.options.name}_src" - src_files 
= self.make_extension_sources(stencil_ir=stencil_ir) + src_files = self._make_extension_sources() return {dir_name: src_files["bindings"]} @abc.abstractmethod @@ -260,32 +256,26 @@ def generate_extension(self, **kwargs: Any) -> Tuple[str, str]: """ pass - def make_extension( - self, *, stencil_ir: Optional[gtir.Stencil] = None, uses_cuda: bool = False - ) -> Tuple[str, str]: + def make_extension(self, *, uses_cuda: bool = False) -> Tuple[str, str]: build_info = self.builder.options.build_info if build_info is not None: start_time = time.perf_counter() - if not stencil_ir: - stencil_ir = self.builder.gtir - assert stencil_ir is not None - # Generate source gt_pyext_files: Dict[str, Any] gt_pyext_sources: Dict[str, Any] - if not self.builder.options._impl_opts.get("disable-code-generation", False): - gt_pyext_files = self.make_extension_sources(stencil_ir=stencil_ir) - gt_pyext_sources = { - **gt_pyext_files["computation"], - **gt_pyext_files["bindings"], - } - else: + if self.builder.options._impl_opts.get("disable-code-generation", False): # Pass NOTHING to the self.builder means try to reuse the source code files gt_pyext_files = {} gt_pyext_sources = { key: gt_utils.NOTHING for key in self.PYEXT_GENERATOR_CLASS.TEMPLATE_FILES.keys() } + else: + gt_pyext_files = self._make_extension_sources() + gt_pyext_sources = { + **gt_pyext_files["computation"], + **gt_pyext_files["bindings"], + } if build_info is not None: next_time = time.perf_counter() @@ -317,10 +307,11 @@ def make_extension( return result - def make_extension_sources(self, *, stencil_ir: gtir.Stencil) -> Dict[str, Dict[str, str]]: + def _make_extension_sources(self) -> Dict[str, Dict[str, str]]: """Generate the source for the stencil independently from use case.""" if "computation_src" in self.builder.backend_data: return self.builder.backend_data["computation_src"] + class_name = self.pyext_class_name if self.builder.stencil_id else self.builder.options.name module_name = ( self.pyext_module_name @@ 
-328,7 +319,7 @@ def make_extension_sources(self, *, stencil_ir: gtir.Stencil) -> Dict[str, Dict[ else f"{self.builder.options.name}_pyext" ) gt_pyext_generator = self.PYEXT_GENERATOR_CLASS(class_name, module_name, self) - gt_pyext_sources = gt_pyext_generator(stencil_ir) + gt_pyext_sources = gt_pyext_generator(self.builder.gtir) final_ext = ".cu" if self.languages and self.languages["computation"] == "cuda" else ".cpp" comp_src = gt_pyext_sources["computation"] for key in [k for k in comp_src.keys() if k.endswith(".src")]: diff --git a/src/gt4py/cartesian/backend/gtcpp_backend.py b/src/gt4py/cartesian/backend/gtcpp_backend.py index 5d3fd623d9..96f5672ae4 100644 --- a/src/gt4py/cartesian/backend/gtcpp_backend.py +++ b/src/gt4py/cartesian/backend/gtcpp_backend.py @@ -126,10 +126,10 @@ def apply(cls, root, *, module_name="stencil", **kwargs) -> str: class GTBaseBackend(BaseGTBackend, CLIBackendMixin): options = BaseGTBackend.GT_BACKEND_OPTS - PYEXT_GENERATOR_CLASS = GTExtGenerator # type: ignore + PYEXT_GENERATOR_CLASS = GTExtGenerator def _generate_extension(self, uses_cuda: bool) -> Tuple[str, str]: - return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=uses_cuda) + return self.make_extension(uses_cuda=uses_cuda) def generate(self) -> Type[StencilObject]: self.check_options(self.builder.options) diff --git a/src/gt4py/cartesian/backend/module_generator.py b/src/gt4py/cartesian/backend/module_generator.py index e2266b709c..8cc63ae34e 100644 --- a/src/gt4py/cartesian/backend/module_generator.py +++ b/src/gt4py/cartesian/backend/module_generator.py @@ -62,8 +62,6 @@ def parameter_names(self) -> Set[str]: def make_args_data_from_gtir(pipeline: GtirPipeline) -> ModuleData: """ Compute module data containing information about stencil arguments from gtir. - - This is no longer compatible with the legacy backends. 
""" if pipeline.stencil_id in _args_data_cache: return _args_data_cache[pipeline.stencil_id] @@ -142,7 +140,7 @@ def __call__( """ Generate source code for a Python module containing a StencilObject. - A possible reaosn for extending is processing additional kwargs, + A possible reason for extending is processing additional kwargs, using a different template might require completely overriding. """ if builder: diff --git a/src/gt4py/cartesian/backend/pyext_builder.py b/src/gt4py/cartesian/backend/pyext_builder.py index 8f49ce6f22..8875e3e3af 100644 --- a/src/gt4py/cartesian/backend/pyext_builder.py +++ b/src/gt4py/cartesian/backend/pyext_builder.py @@ -18,6 +18,7 @@ from setuptools import distutils from setuptools.command.build_ext import build_ext +from gt4py._core import definitions as core_defs from gt4py.cartesian import config as gt_config @@ -51,6 +52,7 @@ def get_gt_pyext_build_opts( ) -> Dict[str, Union[str, List[str], Dict[str, Any]]]: include_dirs = [gt_config.build_settings["boost_include_path"]] extra_compile_args_from_config = gt_config.build_settings["extra_compile_args"] + is_rocm_gpu = core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.ROCM if uses_cuda: compute_capability = get_cuda_compute_capability() @@ -68,8 +70,6 @@ def get_gt_pyext_build_opts( gt_include_path = gt_config.build_settings["gt_include_path"] - import os - extra_compile_args = dict( cxx=[ "-std=c++17", @@ -93,7 +93,7 @@ def get_gt_pyext_build_opts( "-DBOOST_OPTIONAL_USE_OLD_DEFINITION_OF_NONE", *extra_compile_args_from_config["cuda"], ] - if gt_config.GT4PY_USE_HIP: + if is_rocm_gpu: extra_compile_args["cuda"] += [ "-isystem{}".format(gt_include_path), "-isystem{}".format(gt_config.build_settings["boost_include_path"]), @@ -125,7 +125,7 @@ def get_gt_pyext_build_opts( extra_compile_args["cxx"].append( "-isystem{}".format(os.path.join(dace_path, "runtime/include")) ) - if gt_config.GT4PY_USE_HIP: + if is_rocm_gpu: extra_compile_args["cuda"].append( 
"-isystem{}".format(os.path.join(dace_path, "runtime/include")) ) @@ -158,7 +158,7 @@ def get_gt_pyext_build_opts( if uses_cuda: cuda_flags = [] for cpp_flag in cpp_flags: - if gt_config.GT4PY_USE_HIP: + if is_rocm_gpu: cuda_flags.extend([cpp_flag]) else: cuda_flags.extend(["--compiler-options", cpp_flag]) @@ -309,7 +309,7 @@ def build_pybind_cuda_ext( library_dirs = library_dirs or [] library_dirs = [*library_dirs, gt_config.build_settings["cuda_library_path"]] libraries = libraries or [] - if gt_config.GT4PY_USE_HIP: + if core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.ROCM: libraries = [*libraries, "hiprtc"] else: libraries = [*libraries, "cudart"] @@ -363,7 +363,7 @@ def cuda_compile(obj, src, ext, cc_args, extra_postargs, pp_opts): cflags = copy.deepcopy(extra_postargs) try: if os.path.splitext(src)[-1] == ".cu": - if gt_config.GT4PY_USE_HIP: + if core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.ROCM: cuda_exec = os.path.join(gt_config.build_settings["cuda_bin_path"], "hipcc") else: cuda_exec = os.path.join(gt_config.build_settings["cuda_bin_path"], "nvcc") diff --git a/src/gt4py/cartesian/caching.py b/src/gt4py/cartesian/caching.py index 20c0b49fae..2df2589ded 100644 --- a/src/gt4py/cartesian/caching.py +++ b/src/gt4py/cartesian/caching.py @@ -61,7 +61,7 @@ def generate_cache_info(self) -> Dict[str, Any]: """ Generate the cache info dict. - Backend specific additions can be added via a hook propery on the backend instance. + Backend specific additions can be added via a hook properly on the backend instance. Override :py:meth:`gt4py.backend.base.Backend.extra_cache_info` to store extra info. """ diff --git a/src/gt4py/cartesian/cli.py b/src/gt4py/cartesian/cli.py index 23f8791ca7..4ea5e44074 100644 --- a/src/gt4py/cartesian/cli.py +++ b/src/gt4py/cartesian/cli.py @@ -90,7 +90,7 @@ def backend_table(cls) -> str: ", ".join(backend.languages["bindings"]) if backend and backend.languages else "?" 
for backend in backends ] - enabled = [backend is not None and "Yes" or "No" for backend in backends] + enabled = [(backend is not None and "Yes") or "No" for backend in backends] data = zip(names, comp_langs, binding_langs, enabled) return tabulate.tabulate(data, headers=headers) @@ -138,6 +138,8 @@ def convert( self, value: str, param: Optional[click.Parameter], ctx: Optional[click.Context] ) -> Tuple[str, Any]: backend = ctx.params["backend"] if ctx else gt4pyc.backend.from_name("numpy") + assert isinstance(backend, type) + assert issubclass(backend, gt4pyc.backend.Backend) name, value = self._try_split(value) if name.strip() not in backend.options: self.fail(f"Backend {backend.name} received unknown option: {name}!") diff --git a/src/gt4py/cartesian/config.py b/src/gt4py/cartesian/config.py index 5aa32506b7..a48f612c84 100644 --- a/src/gt4py/cartesian/config.py +++ b/src/gt4py/cartesian/config.py @@ -12,6 +12,8 @@ import gridtools_cpp +from gt4py._core import definitions as core_defs + GT4PY_INSTALLATION_PATH: str = os.path.dirname(os.path.abspath(__file__)) @@ -26,18 +28,6 @@ CUDA_HOST_CXX: Optional[str] = os.environ.get("CUDA_HOST_CXX", None) -if "GT4PY_USE_HIP" in os.environ: - GT4PY_USE_HIP: bool = bool(int(os.environ["GT4PY_USE_HIP"])) -else: - # Autodetect cupy with ROCm/HIP support - try: - import cupy as _cp - - GT4PY_USE_HIP = _cp.cuda.get_hipcc_path() is not None - del _cp - except Exception: - GT4PY_USE_HIP = False - GT_INCLUDE_PATH: str = os.path.abspath(gridtools_cpp.get_include_dir()) GT_CPP_TEMPLATE_DEPTH: int = 1024 @@ -66,7 +56,7 @@ "parallel_jobs": multiprocessing.cpu_count(), "cpp_template_depth": os.environ.get("GT_CPP_TEMPLATE_DEPTH", GT_CPP_TEMPLATE_DEPTH), } -if GT4PY_USE_HIP: +if core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.ROCM: build_settings["cuda_library_path"] = os.path.join(CUDA_ROOT, "lib") else: build_settings["cuda_library_path"] = os.path.join(CUDA_ROOT, "lib64") diff --git a/src/gt4py/cartesian/frontend/__init__.py 
b/src/gt4py/cartesian/frontend/__init__.py index 6988fb6aab..f1e0f9a775 100644 --- a/src/gt4py/cartesian/frontend/__init__.py +++ b/src/gt4py/cartesian/frontend/__init__.py @@ -10,4 +10,4 @@ from .base import REGISTRY, Frontend, from_name, register -__all__ = ["gtscript_frontend", "REGISTRY", "Frontend", "from_name", "register"] +__all__ = ["REGISTRY", "Frontend", "from_name", "gtscript_frontend", "register"] diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index e2aa98f3cf..4d8ac98529 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -459,7 +459,8 @@ def visit_Call(self, node: ast.Call, *, target_node=None): # Cyclomatic complex call_args[name] = ast.Constant(value=arg_infos[name]) except Exception as ex: raise GTScriptSyntaxError( - message="Invalid call signature", loc=nodes.Location.from_ast_node(node) + message=f"Invalid call signature when calling {call_name}", + loc=nodes.Location.from_ast_node(node), ) from ex # Rename local names in subroutine to avoid conflicts with caller context names @@ -1450,14 +1451,6 @@ def visit_Assign(self, node: ast.Assign) -> list: message="Assignment to non-zero offsets in K is not available in PARALLEL. Choose FORWARD or BACKWARD.", loc=nodes.Location.from_ast_node(t), ) - if self.backend_name in ["gt:gpu", "dace:gpu"]: - import cupy as cp - - if cp.cuda.runtime.runtimeGetVersion() < 12000: - raise GTScriptSyntaxError( - message=f"Assignment to non-zero offsets in K is not available in {self.backend_name} for CUDA<12. 
Please update CUDA.", - loc=nodes.Location.from_ast_node(t), - ) if not self._is_known(name): if name in self.temp_decls: diff --git a/src/gt4py/cartesian/frontend/nodes.py b/src/gt4py/cartesian/frontend/nodes.py index f84577e7b5..2ca9e8fe1f 100644 --- a/src/gt4py/cartesian/frontend/nodes.py +++ b/src/gt4py/cartesian/frontend/nodes.py @@ -130,7 +130,6 @@ parameters: List[VarDecl], computations: List[ComputationBlock], [externals: Dict[str, Any], sources: Dict[str, str]]) - """ from __future__ import annotations diff --git a/src/gt4py/cartesian/gtc/common.py b/src/gt4py/cartesian/gtc/common.py index bfe434e7f3..60236a3e97 100644 --- a/src/gt4py/cartesian/gtc/common.py +++ b/src/gt4py/cartesian/gtc/common.py @@ -38,14 +38,14 @@ class GTCPreconditionError(eve.exceptions.EveError, RuntimeError): message_template = "GTC pass precondition error: [{info}]" def __init__(self, *, expected: str, **kwargs: Any) -> None: - super().__init__(expected=expected, **kwargs) # type: ignore + super().__init__(expected=expected, **kwargs) class GTCPostconditionError(eve.exceptions.EveError, RuntimeError): message_template = "GTC pass postcondition error: [{info}]" def __init__(self, *, expected: str, **kwargs: Any) -> None: - super().__init__(expected=expected, **kwargs) # type: ignore + super().__init__(expected=expected, **kwargs) class AssignmentKind(eve.StrEnum): @@ -60,7 +60,7 @@ class AssignmentKind(eve.StrEnum): @enum.unique class UnaryOperator(eve.StrEnum): - """Unary operator indentifier.""" + """Unary operator identifier.""" POS = "+" NEG = "-" @@ -118,7 +118,7 @@ def isbool(self): return self == self.BOOL def isinteger(self): - return self in (self.INT8, self.INT32, self.INT64) + return self in (self.INT8, self.INT16, self.INT32, self.INT64) def isfloat(self): return self in (self.FLOAT32, self.FLOAT64) @@ -229,8 +229,8 @@ class LevelMarker(eve.StrEnum): @enum.unique class ExprKind(eve.IntEnum): - SCALAR: ExprKind = typing.cast("ExprKind", enum.auto()) - FIELD: ExprKind = 
typing.cast("ExprKind", enum.auto()) + SCALAR = typing.cast("ExprKind", enum.auto()) + FIELD = typing.cast("ExprKind", enum.auto()) class LocNode(eve.Node): @@ -267,7 +267,7 @@ def verify_and_get_common_dtype( ) -> Optional[DataType]: assert len(exprs) > 0 if all(e.dtype is not DataType.AUTO for e in exprs): - dtypes: List[DataType] = [e.dtype for e in exprs] # type: ignore # guaranteed to be not None + dtypes: List[DataType] = [e.dtype for e in exprs] # guaranteed to be not None dtype = dtypes[0] if strict: if all(dt == dtype for dt in dtypes): @@ -311,7 +311,7 @@ class CartesianOffset(eve.Node): k: int @classmethod - def zero(cls) -> "CartesianOffset": + def zero(cls) -> CartesianOffset: return cls(i=0, j=0, k=0) def to_dict(self) -> Dict[str, int]: @@ -908,7 +908,7 @@ def op_to_ufunc( @functools.lru_cache(maxsize=None) def typestr_to_data_type(typestr: str) -> DataType: if not isinstance(typestr, str) or len(typestr) < 3 or not typestr[2:].isnumeric(): - return DataType.INVALID # type: ignore + return DataType.INVALID table = { ("b", 1): DataType.BOOL, ("i", 1): DataType.INT8, @@ -919,4 +919,4 @@ def typestr_to_data_type(typestr: str) -> DataType: ("f", 8): DataType.FLOAT64, } key = (typestr[1], int(typestr[2:])) - return table.get(key, DataType.INVALID) # type: ignore + return table.get(key, DataType.INVALID) diff --git a/src/gt4py/cartesian/gtc/cuir/cuir.py b/src/gt4py/cartesian/gtc/cuir/cuir.py index 62c3c520ac..fb6d28d071 100644 --- a/src/gt4py/cartesian/gtc/cuir/cuir.py +++ b/src/gt4py/cartesian/gtc/cuir/cuir.py @@ -32,11 +32,11 @@ class Stmt(common.Stmt): pass -class Literal(common.Literal, Expr): # type: ignore +class Literal(common.Literal, Expr): pass -class ScalarAccess(common.ScalarAccess, Expr): # type: ignore +class ScalarAccess(common.ScalarAccess, Expr): pass @@ -44,7 +44,7 @@ class VariableKOffset(common.VariableKOffset[Expr]): pass -class FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr): # type: ignore +class 
FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr): pass @@ -113,7 +113,7 @@ class TernaryOp(common.TernaryOp[Expr], Expr): _dtype_propagation = common.ternary_op_dtype_propagation(strict=True) -class Cast(common.Cast[Expr], Expr): # type: ignore +class Cast(common.Cast[Expr], Expr): pass diff --git a/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py b/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py index 76f076874a..96149a1723 100644 --- a/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py +++ b/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py @@ -592,7 +592,7 @@ def ctype(symbol: str) -> str: @classmethod def apply(cls, root: LeafNode, **kwargs: Any) -> str: if not isinstance(root, cuir.Program): - raise ValueError("apply() requires gtcpp.Progam root node") + raise ValueError("apply() requires gtcpp.Program root node") generated_code = super().apply(root, **kwargs) if kwargs.get("format_source", True): generated_code = codegen.format_source("cpp", generated_code, style="LLVM") diff --git a/src/gt4py/cartesian/gtc/dace/daceir.py b/src/gt4py/cartesian/gtc/dace/daceir.py index 0ecb02b50f..e07ae9e52b 100644 --- a/src/gt4py/cartesian/gtc/dace/daceir.py +++ b/src/gt4py/cartesian/gtc/dace/daceir.py @@ -17,6 +17,7 @@ from gt4py import eve from gt4py.cartesian.gtc import common, oir from gt4py.cartesian.gtc.common import LocNode +from gt4py.cartesian.gtc.dace import prefix from gt4py.cartesian.gtc.dace.symbol_utils import ( get_axis_bound_dace_symbol, get_axis_bound_diff_str, @@ -51,11 +52,11 @@ def tile_symbol(self) -> eve.SymbolRef: return eve.SymbolRef("__tile_" + self.lower()) @staticmethod - def dims_3d() -> Generator["Axis", None, None]: + def dims_3d() -> Generator[Axis, None, None]: yield from [Axis.I, Axis.J, Axis.K] @staticmethod - def dims_horizontal() -> Generator["Axis", None, None]: + def dims_horizontal() -> Generator[Axis, None, None]: yield from [Axis.I, Axis.J] def to_idx(self) -> int: @@ -357,7 +358,7 @@ def free_symbols(self) -> Set[eve.SymbolRef]: class 
GridSubset(eve.Node): - intervals: Dict[Axis, Union[DomainInterval, TileInterval, IndexWithExtent]] + intervals: Dict[Axis, Union[DomainInterval, IndexWithExtent, TileInterval]] def __iter__(self): for axis in Axis.dims_3d(): @@ -429,10 +430,10 @@ def from_gt4py_extent(cls, extent: gt4py.cartesian.gtc.definitions.Extent): @classmethod def from_interval( cls, - interval: Union[oir.Interval, TileInterval, DomainInterval, IndexWithExtent], + interval: Union[DomainInterval, IndexWithExtent, oir.Interval, TileInterval], axis: Axis, ): - res_interval: Union[IndexWithExtent, TileInterval, DomainInterval] + res_interval: Union[DomainInterval, IndexWithExtent, TileInterval] if isinstance(interval, (DomainInterval, oir.Interval)): res_interval = DomainInterval( start=AxisBound( @@ -441,7 +442,7 @@ def from_interval( end=AxisBound(level=interval.end.level, offset=interval.end.offset, axis=Axis.K), ) else: - assert isinstance(interval, (TileInterval, IndexWithExtent)) + assert isinstance(interval, (IndexWithExtent, TileInterval)) res_interval = interval return cls(intervals={axis: res_interval}) @@ -464,7 +465,7 @@ def full_domain(cls, axes=None): return GridSubset(intervals=res_subsets) def tile(self, tile_sizes: Dict[Axis, int]): - res_intervals: Dict[Axis, Union[DomainInterval, TileInterval, IndexWithExtent]] = {} + res_intervals: Dict[Axis, Union[DomainInterval, IndexWithExtent, TileInterval]] = {} for axis, interval in self.intervals.items(): if isinstance(interval, DomainInterval) and axis in tile_sizes: if axis == Axis.K: @@ -505,15 +506,15 @@ def union(self, other): intervals[axis] = interval1.union(interval2) else: assert ( - isinstance(interval2, (TileInterval, DomainInterval)) - and isinstance(interval1, (IndexWithExtent, DomainInterval)) + isinstance(interval2, (DomainInterval, TileInterval)) + and isinstance(interval1, (DomainInterval, IndexWithExtent)) ) or ( - isinstance(interval1, (TileInterval, DomainInterval)) + isinstance(interval1, (DomainInterval, 
TileInterval)) and isinstance(interval2, IndexWithExtent) ) intervals[axis] = ( interval1 - if isinstance(interval1, (TileInterval, DomainInterval)) + if isinstance(interval1, (DomainInterval, TileInterval)) else interval2 ) return GridSubset(intervals=intervals) @@ -525,10 +526,6 @@ class FieldAccessInfo(eve.Node): dynamic_access: bool = False variable_offset_axes: List[Axis] = eve.field(default_factory=list) - @property - def is_dynamic(self) -> bool: - return self.dynamic_access or len(self.variable_offset_axes) > 0 - def axes(self): yield from self.grid_subset.axes() @@ -713,7 +710,7 @@ def axes(self): @property def is_dynamic(self) -> bool: - return self.access_info.is_dynamic + return self.access_info.dynamic_access def with_set_access_info(self, access_info: FieldAccessInfo) -> FieldDecl: return FieldDecl( @@ -730,18 +727,30 @@ class Literal(common.Literal, Expr): class ScalarAccess(common.ScalarAccess, Expr): - name: eve.Coerced[eve.SymbolRef] + is_target: bool + original_name: Optional[str] = None class VariableKOffset(common.VariableKOffset[Expr]): - pass + @datamodels.validator("k") + def no_casts_in_offset_expression(self, _: datamodels.Attribute, expression: Expr) -> None: + for part in expression.walk_values(): + if isinstance(part, Cast): + raise ValueError( + "DaCe backends are currently missing support for casts in variable k offsets. See issue https://github.com/GridTools/gt4py/issues/1881." 
+ ) class IndexAccess(common.FieldAccess, Expr): - offset: Optional[Union[common.CartesianOffset, VariableKOffset]] + # ScalarAccess used for indirect addressing + offset: Optional[common.CartesianOffset | Literal | ScalarAccess | VariableKOffset] + is_target: bool + + explicit_indices: Optional[list[Literal | ScalarAccess | VariableKOffset]] = None + """Used to access as a full field with explicit indices""" -class AssignStmt(common.AssignStmt[Union[ScalarAccess, IndexAccess], Expr], Stmt): +class AssignStmt(common.AssignStmt[Union[IndexAccess, ScalarAccess], Expr], Stmt): _dtype_validation = common.assign_stmt_dtype_validation(strict=True) @@ -771,7 +780,7 @@ class TernaryOp(common.TernaryOp[Expr], Expr): _dtype_propagation = common.ternary_op_dtype_propagation(strict=True) -class Cast(common.Cast[Expr], Expr): # type: ignore +class Cast(common.Cast[Expr], Expr): pass @@ -836,11 +845,103 @@ class IterationNode(eve.Node): grid_subset: GridSubset +class Condition(eve.Node): + condition: Tasklet + true_states: list[ComputationState | Condition | WhileLoop] + + # Currently unused due to how `if` statements are parsed in `gtir_to_oir`, see + # https://github.com/GridTools/gt4py/issues/1898 + false_states: list[ComputationState | Condition | WhileLoop] = eve.field(default_factory=list) + + @datamodels.validator("condition") + def condition_has_boolean_expression( + self, attribute: datamodels.Attribute, tasklet: Tasklet + ) -> None: + assert isinstance(tasklet, Tasklet) + assert len(tasklet.stmts) == 1 + assert isinstance(tasklet.stmts[0], AssignStmt) + assert isinstance(tasklet.stmts[0].left, ScalarAccess) + if tasklet.stmts[0].left.original_name is None: + raise ValueError( + f"Original node name not found for {tasklet.stmts[0].left.name}. DaCe IR error." 
+ ) + assert isinstance(tasklet.stmts[0].right, Expr) + if tasklet.stmts[0].right.dtype != common.DataType.BOOL: + raise ValueError("Condition must be a boolean expression.") + + class Tasklet(ComputationNode, IterationNode, eve.SymbolTableTrait): - decls: List[LocalScalarDecl] + label: str stmts: List[Stmt] grid_subset: GridSubset = GridSubset.single_gridpoint() + @datamodels.validator("stmts") + def non_empty_list(self, attribute: datamodels.Attribute, v: list[Stmt]) -> None: + if len(v) < 1: + raise ValueError("Tasklet must contain at least one statement.") + + @datamodels.validator("stmts") + def read_after_write(self, attribute: datamodels.Attribute, statements: list[Stmt]) -> None: + def _remove_prefix(name: eve.SymbolRef) -> str: + return name.removeprefix(prefix.TASKLET_OUT).removeprefix(prefix.TASKLET_IN) + + class ReadAfterWriteChecker(eve.NodeVisitor): + def visit_IndexAccess(self, node: IndexAccess, writes: set[str]) -> None: + if node.is_target: + # Keep track of writes + writes.add(_remove_prefix(node.name)) + return + + # Check reads + if ( + node.name.startswith(prefix.TASKLET_OUT) + and _remove_prefix(node.name) not in writes + ): + raise ValueError(f"Reading undefined '{node.name}'. DaCe IR error.") + + if _remove_prefix(node.name) in writes and not node.name.startswith( + prefix.TASKLET_OUT + ): + raise ValueError( + f"Read after write of '{node.name}' not connected to out connector. DaCe IR error." 
+ ) + + def visit_ScalarAccess(self, node: ScalarAccess, writes: set[str]) -> None: + # Handle stencil parameters differently because they are always available + if not node.name.startswith(prefix.TASKLET_IN) and not node.name.startswith( + prefix.TASKLET_OUT + ): + return + + # Keep track of writes + if node.is_target: + writes.add(_remove_prefix(node.name)) + return + + # Make sure we don't read uninitialized memory + if ( + node.name.startswith(prefix.TASKLET_OUT) + and _remove_prefix(node.name) not in writes + ): + raise ValueError(f"Reading undefined '{node.name}'. DaCe IR error.") + + if _remove_prefix(node.name) in writes and not node.name.startswith( + prefix.TASKLET_OUT + ): + raise ValueError( + f"Read after write of '{node.name}' not connected to out connector. DaCe IR error." + ) + + def visit_AssignStmt(self, node: AssignStmt, writes: Set[eve.SymbolRef]) -> None: + # Visiting order matters because `writes` must not contain the symbols from the left visit + self.visit(node.right, writes=writes) + self.visit(node.left, writes=writes) + + writes: set[str] = set() + checker = ReadAfterWriteChecker() + for statement in statements: + checker.visit(statement, writes=writes) + class DomainMap(ComputationNode, IterationNode): index_ranges: List[Range] @@ -852,17 +953,38 @@ class ComputationState(IterationNode): computations: List[Union[Tasklet, DomainMap]] -class DomainLoop(IterationNode, ComputationNode): +class DomainLoop(ComputationNode, IterationNode): axis: Axis index_range: Range - loop_states: List[Union[ComputationState, DomainLoop]] + loop_states: list[ComputationState | Condition | DomainLoop | WhileLoop] + + +class WhileLoop(eve.Node): + condition: Tasklet + body: list[ComputationState | Condition | WhileLoop] + + @datamodels.validator("condition") + def condition_has_boolean_expression( + self, attribute: datamodels.Attribute, tasklet: Tasklet + ) -> None: + assert isinstance(tasklet, Tasklet) + assert len(tasklet.stmts) == 1 + assert 
isinstance(tasklet.stmts[0], AssignStmt) + assert isinstance(tasklet.stmts[0].left, ScalarAccess) + if tasklet.stmts[0].left.original_name is None: + raise ValueError( + f"Original node name not found for {tasklet.stmts[0].left.name}. DaCe IR error." + ) + assert isinstance(tasklet.stmts[0].right, Expr) + if tasklet.stmts[0].right.dtype != common.DataType.BOOL: + raise ValueError("Condition must be a boolean expression.") class NestedSDFG(ComputationNode, eve.SymbolTableTrait): label: eve.Coerced[eve.SymbolRef] field_decls: List[FieldDecl] symbol_decls: List[SymbolDecl] - states: List[Union[DomainLoop, ComputationState]] + states: list[ComputationState | Condition | DomainLoop | WhileLoop] # There are circular type references with string placeholders. These statements let datamodels resolve those. diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index a8a3a3cb54..f05a89c5fa 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -10,7 +10,8 @@ import dataclasses from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Union, cast +from enum import Enum +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union, cast import dace import dace.data @@ -18,11 +19,11 @@ import dace.subsets from gt4py import eve -from gt4py.cartesian.gtc import common, oir +from gt4py.cartesian.gtc import common, oir, utils as gtc_utils from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.expansion_specification import Loop, Map, Sections, Stages from gt4py.cartesian.gtc.dace.utils import ( - compute_dcir_access_infos, + compute_tasklet_access_infos, flatten_list, get_tasklet_symbol, make_dace_subset, @@ -39,54 +40,96 @@ from gt4py.cartesian.gtc.dace.nodes import StencilComputation -def _access_iter(node: oir.HorizontalExecution, 
get_outputs: bool): - if get_outputs: - iterator = filter( - lambda node: isinstance(node, oir.FieldAccess), - node.walk_values().if_isinstance(oir.AssignStmt).getattr("left"), +class AccessType(Enum): + READ = 0 + WRITE = 1 + + +def _field_access_iterator( + code_block: oir.CodeBlock | oir.MaskStmt | oir.While, access_type: AccessType +): + if access_type == AccessType.WRITE: + return ( + code_block.walk_values() + .if_isinstance(oir.AssignStmt) + .getattr("left") + .if_isinstance(oir.FieldAccess) ) - else: - def _iterator(): - for n in node.walk_values(): - if isinstance(n, oir.AssignStmt): - yield from n.right.walk_values().if_isinstance(oir.FieldAccess) - elif isinstance(n, oir.While): - yield from n.cond.walk_values().if_isinstance(oir.FieldAccess) - elif isinstance(n, oir.MaskStmt): - yield from n.mask.walk_values().if_isinstance(oir.FieldAccess) + def read_access_iterator(): + for node in code_block.walk_values(): + if isinstance(node, oir.AssignStmt): + yield from node.right.walk_values().if_isinstance(oir.FieldAccess) + elif isinstance(node, oir.While): + yield from node.cond.walk_values().if_isinstance(oir.FieldAccess) + elif isinstance(node, oir.MaskStmt): + yield from node.mask.walk_values().if_isinstance(oir.FieldAccess) - iterator = _iterator() + return read_access_iterator() + + +def _mapped_access_iterator( + node: oir.CodeBlock | oir.MaskStmt | oir.While, access_type: AccessType +): + iterator = _field_access_iterator(node, access_type) + write_access = access_type == AccessType.WRITE yield from ( eve.utils.xiter(iterator).map( lambda acc: ( acc.name, acc.offset, - get_tasklet_symbol(acc.name, acc.offset, is_target=get_outputs), + get_tasklet_symbol(acc.name, offset=acc.offset, is_target=write_access), ) ) ).unique(key=lambda x: x[2]) def _get_tasklet_inout_memlets( - node: oir.HorizontalExecution, + node: oir.CodeBlock | oir.MaskStmt | oir.While, + access_type: AccessType, *, - get_outputs: bool, global_ctx: DaCeIRBuilder.GlobalContext, - 
**kwargs, -): - access_infos = compute_dcir_access_infos( + horizontal_extent, + k_interval, + grid_subset: dcir.GridSubset, + dcir_statements: list[dcir.Stmt], +) -> list[dcir.Memlet]: + access_infos = compute_tasklet_access_infos( node, - block_extents=global_ctx.library_node.get_extents, - oir_decls=global_ctx.library_node.declarations, - collect_read=not get_outputs, - collect_write=get_outputs, - **kwargs, + collect_read=access_type == AccessType.READ, + collect_write=access_type == AccessType.WRITE, + declarations=global_ctx.library_node.declarations, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + grid_subset=grid_subset, ) - res = list() - for name, offset, tasklet_symbol in _access_iter(node, get_outputs=get_outputs): + names = [ + access.name + for statement in dcir_statements + for access in statement.walk_values().if_isinstance(dcir.ScalarAccess, dcir.IndexAccess) + ] + + memlets: list[dcir.Memlet] = [] + for name, offset, tasklet_symbol in _mapped_access_iterator(node, access_type): + # Avoid adding extra inputs/outputs to the tasklet + if name not in access_infos: + continue + + # Find `tasklet_symbol` in dcir_statements because we can't know (from the oir statements) + # where the tasklet boundaries will be. Consider + # + # with computation(PARALLEL), interval(...): + # statement1 + # if condition: + # statement2 + # statement3 + # + # statements 1 and 3 will end up in the same CodeBlock but aren't in the same tasklet. 
+ if tasklet_symbol not in names: + continue + access_info = access_infos[name] if not access_info.variable_offset_axes: offset_dict = offset.to_dict() @@ -95,26 +138,27 @@ def _get_tasklet_inout_memlets( axis, extent=(offset_dict[axis.lower()], offset_dict[axis.lower()]) ) - memlet = dcir.Memlet( - field=name, - connector=tasklet_symbol, - access_info=access_info, - is_read=not get_outputs, - is_write=get_outputs, + memlets.append( + dcir.Memlet( + field=name, + connector=tasklet_symbol, + access_info=access_info, + is_read=access_type == AccessType.READ, + is_write=access_type == AccessType.WRITE, + ) ) - res.append(memlet) - return res + return memlets -def _all_stmts_same_region(scope_nodes, axis: dcir.Axis, interval): - def all_statements_in_region(scope_nodes): +def _all_stmts_same_region(scope_nodes, axis: dcir.Axis, interval: Any) -> bool: + def all_statements_in_region(scope_nodes: List[eve.Node]) -> bool: return all( isinstance(stmt, dcir.HorizontalRestriction) for tasklet in eve.walk_values(scope_nodes).if_isinstance(dcir.Tasklet) for stmt in tasklet.stmts ) - def all_regions_same(scope_nodes): + def all_regions_same(scope_nodes: List[eve.Node]) -> bool: return ( len( set( @@ -179,11 +223,11 @@ def _get_dcir_decl( oir_decl: oir.Decl = self.library_node.declarations[field] assert isinstance(oir_decl, oir.FieldDecl) dace_array = self.arrays[field] - for s in dace_array.strides: - for sym in dace.symbolic.symlist(s).values(): - symbol_collector.add_symbol(str(sym)) - for sym in access_info.grid_subset.free_symbols: - symbol_collector.add_symbol(sym) + for stride in dace_array.strides: + for symbol in dace.symbolic.symlist(stride).values(): + symbol_collector.add_symbol(str(symbol)) + for symbol in access_info.grid_subset.free_symbols: + symbol_collector.add_symbol(symbol) return dcir.FieldDecl( name=field, @@ -236,11 +280,7 @@ def push_expansion_item(self, item: Union[Map, Loop]) -> DaCeIRBuilder.Iteration if not isinstance(item, (Map, Loop)): raise 
ValueError - if isinstance(item, Map): - iterations = item.iterations - else: - iterations = [item] - + iterations = item.iterations if isinstance(item, Map) else [item] grid_subset = self.grid_subset for it in iterations: axis = it.axis @@ -267,13 +307,13 @@ def pop(self) -> DaCeIRBuilder.IterationContext: class SymbolCollector: symbol_decls: Dict[str, dcir.SymbolDecl] = dataclasses.field(default_factory=dict) - def add_symbol(self, name: str, dtype: common.DataType = common.DataType.INT32): + def add_symbol(self, name: str, dtype: common.DataType = common.DataType.INT32) -> None: if name not in self.symbol_decls: self.symbol_decls[name] = dcir.SymbolDecl(name=name, dtype=dtype) else: assert self.symbol_decls[name].dtype == dtype - def remove_symbol(self, name: eve.SymbolRef): + def remove_symbol(self, name: eve.SymbolRef) -> None: if name in self.symbol_decls: del self.symbol_decls[name] @@ -304,11 +344,20 @@ def visit_HorizontalRestriction( symbol_collector.add_symbol(axis.iteration_symbol()) if bound.level == common.LevelMarker.END: symbol_collector.add_symbol(axis.domain_symbol()) + return dcir.HorizontalRestriction( - mask=node.mask, body=self.visit(node.body, symbol_collector=symbol_collector, **kwargs) + mask=node.mask, + body=self.visit( + node.body, + symbol_collector=symbol_collector, + inside_horizontal_region=True, + **kwargs, + ), ) - def visit_VariableKOffset(self, node: oir.VariableKOffset, **kwargs): + def visit_VariableKOffset( + self, node: oir.VariableKOffset, **kwargs: Any + ) -> dcir.VariableKOffset: return dcir.VariableKOffset(k=self.visit(node.k, **kwargs)) def visit_LocalScalar(self, node: oir.LocalScalar, **kwargs: Any) -> dcir.LocalScalarDecl: @@ -319,168 +368,351 @@ def visit_FieldAccess( node: oir.FieldAccess, *, is_target: bool, - targets: Set[eve.SymbolRef], - var_offset_fields: Set[eve.SymbolRef], - K_write_with_offset: Set[eve.SymbolRef], + targets: list[oir.FieldAccess | oir.ScalarAccess], + var_offset_fields: set[eve.SymbolRef], + 
K_write_with_offset: set[eve.SymbolRef], **kwargs: Any, - ) -> Union[dcir.IndexAccess, dcir.ScalarAccess]: + ) -> dcir.IndexAccess | dcir.ScalarAccess: """Generate the relevant accessor to match the memlet that was previously setup. - When a Field is written in K, we force the usage of the OUT memlet throughout the stencil - to make sure all side effects are being properly resolved. Frontend checks ensure that no - parallel code issues sips here. + Args: + is_target (bool): true if we write to this FieldAccess """ - res: Union[dcir.IndexAccess, dcir.ScalarAccess] + # Distinguish between writing to a variable and reading a previously written variable. + # In the latter case (read after write), we need to read from the "gtOUT__" symbol. + is_write = is_target + is_target = is_target or ( + # read after write (within a code block) + any( + isinstance(t, oir.FieldAccess) and t.name == node.name and t.offset == node.offset + for t in targets + ) + ) + name = get_tasklet_symbol(node.name, offset=node.offset, is_target=is_target) + + access_node: dcir.IndexAccess | dcir.ScalarAccess if node.name in var_offset_fields.union(K_write_with_offset): - # If write in K, we consider the variable to always be a target - is_target = is_target or node.name in targets or node.name in K_write_with_offset - name = get_tasklet_symbol(node.name, node.offset, is_target=is_target) - res = dcir.IndexAccess( + access_node = dcir.IndexAccess( name=name, + is_target=is_target, offset=self.visit( node.offset, - is_target=is_target, + is_target=False, + targets=targets, + var_offset_fields=var_offset_fields, + K_write_with_offset=K_write_with_offset, + **kwargs, + ), + data_index=self.visit( + node.data_index, + is_target=False, targets=targets, var_offset_fields=var_offset_fields, K_write_with_offset=K_write_with_offset, **kwargs, ), - data_index=node.data_index, dtype=node.dtype, ) - else: - is_target = is_target or ( - node.name in targets and node.offset == common.CartesianOffset.zero() + 
elif node.data_index: + access_node = dcir.IndexAccess( + name=name, + offset=None, + is_target=is_target, + data_index=self.visit( + node.data_index, + is_target=False, + targets=targets, + var_offset_fields=var_offset_fields, + K_write_with_offset=K_write_with_offset, + **kwargs, + ), + dtype=node.dtype, ) - name = get_tasklet_symbol(node.name, node.offset, is_target=is_target) - if node.data_index: - res = dcir.IndexAccess( - name=name, offset=None, data_index=node.data_index, dtype=node.dtype - ) - else: - res = dcir.ScalarAccess(name=name, dtype=node.dtype) - if is_target: - targets.add(node.name) - return res + else: + access_node = dcir.ScalarAccess(name=name, dtype=node.dtype, is_target=is_write) + + if is_write and not any( + isinstance(t, oir.FieldAccess) and t.name == node.name and t.offset == node.offset + for t in targets + ): + targets.append(node) + return access_node def visit_ScalarAccess( self, node: oir.ScalarAccess, *, + is_target: bool, + targets: list[oir.FieldAccess | oir.ScalarAccess], global_ctx: DaCeIRBuilder.GlobalContext, symbol_collector: DaCeIRBuilder.SymbolCollector, - **kwargs: Any, + **_: Any, ) -> dcir.ScalarAccess: if node.name in global_ctx.library_node.declarations: + # Handle stencil parameters differently because they are always available symbol_collector.add_symbol(node.name, dtype=node.dtype) - return dcir.ScalarAccess(name=node.name, dtype=node.dtype) - - def visit_AssignStmt(self, node: oir.AssignStmt, *, targets, **kwargs: Any) -> dcir.AssignStmt: - # the visiting order matters here, since targets must not contain the target symbols from the left visit - right = self.visit(node.right, is_target=False, targets=targets, **kwargs) - left = self.visit(node.left, is_target=True, targets=targets, **kwargs) - return dcir.AssignStmt(left=left, right=right) - - def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs: Any) -> dcir.MaskStmt: - return dcir.MaskStmt( - mask=self.visit(node.mask, is_target=False, **kwargs), - 
body=self.visit(node.body, **kwargs), - ) - - def visit_While(self, node: oir.While, **kwargs: Any) -> dcir.While: - return dcir.While( - cond=self.visit(node.cond, is_target=False, **kwargs), - body=self.visit(node.body, **kwargs), + return dcir.ScalarAccess(name=node.name, dtype=node.dtype, is_target=is_target) + + # Distinguish between writing to a variable and reading a previously written variable. + # In the latter case (read after write), we need to read from the "gtOUT__" symbol. + is_write = is_target + is_target = is_target or ( + # read after write (within a code block) + any(isinstance(t, oir.ScalarAccess) and t.name == node.name for t in targets) ) - def visit_Cast(self, node: oir.Cast, **kwargs: Any) -> dcir.Cast: - return dcir.Cast(dtype=node.dtype, expr=self.visit(node.expr, **kwargs)) + if is_write and not any( + isinstance(t, oir.ScalarAccess) and t.name == node.name for t in targets + ): + targets.append(node) - def visit_NativeFuncCall(self, node: oir.NativeFuncCall, **kwargs: Any) -> dcir.NativeFuncCall: - return dcir.NativeFuncCall( - func=node.func, args=self.visit(node.args, **kwargs), dtype=node.dtype + # Rename local scalars inside tasklets such that we can pass them from one state + # to another (same as we do for index access). 
+ tasklet_name = get_tasklet_symbol(node.name, is_target=is_target) + return dcir.ScalarAccess( + name=tasklet_name, original_name=node.name, dtype=node.dtype, is_target=is_write ) - def visit_TernaryOp(self, node: oir.TernaryOp, **kwargs: Any) -> dcir.TernaryOp: - return dcir.TernaryOp( - cond=self.visit(node.cond, **kwargs), - true_expr=self.visit(node.true_expr, **kwargs), - false_expr=self.visit(node.false_expr, **kwargs), - dtype=node.dtype, - ) + def visit_AssignStmt(self, node: oir.AssignStmt, **kwargs: Any) -> dcir.AssignStmt: + # Visiting order matters because targets must not contain the target symbols from the left visit + right = self.visit(node.right, is_target=False, **kwargs) + left = self.visit(node.left, is_target=True, **kwargs) + return dcir.AssignStmt(left=left, right=right, loc=node.loc) - def visit_HorizontalExecution( + def _condition_tasklet( self, - node: oir.HorizontalExecution, + node: oir.MaskStmt | oir.While, *, global_ctx: DaCeIRBuilder.GlobalContext, - iteration_ctx: DaCeIRBuilder.IterationContext, symbol_collector: DaCeIRBuilder.SymbolCollector, - loop_order, + horizontal_extent, k_interval, - **kwargs, - ): - # skip type checking due to https://github.com/python/mypy/issues/5485 - extent = global_ctx.library_node.get_extents(node) # type: ignore - decls = [self.visit(decl, **kwargs) for decl in node.declarations] - targets: Set[str] = set() - stmts = [ - self.visit( - stmt, + iteration_ctx: DaCeIRBuilder.IterationContext, + targets: list[oir.FieldAccess | oir.ScalarAccess], + **kwargs: Any, + ) -> dcir.Tasklet: + condition_expression = node.mask if isinstance(node, oir.MaskStmt) else node.cond + prefix = "if" if isinstance(node, oir.MaskStmt) else "while" + tmp_name = f"{prefix}_expression_{id(node)}" + + # Reset the set of targets (used for detecting read after write inside a tasklet) + targets.clear() + + statement = dcir.AssignStmt( + right=self.visit( + condition_expression, + is_target=False, targets=targets, 
global_ctx=global_ctx, symbol_collector=symbol_collector, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + iteration_ctx=iteration_ctx, **kwargs, - ) - for stmt in node.body - ] + ), + left=dcir.ScalarAccess( + name=get_tasklet_symbol(tmp_name, is_target=True), + original_name=tmp_name, + dtype=common.DataType.BOOL, + loc=node.loc, + is_target=True, + ), + loc=node.loc, + ) - stages_idx = next( - idx - for idx, item in enumerate(global_ctx.library_node.expansion_specification) - if isinstance(item, Stages) + read_memlets: list[dcir.Memlet] = _get_tasklet_inout_memlets( + node, + AccessType.READ, + global_ctx=global_ctx, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + grid_subset=iteration_ctx.grid_subset, + dcir_statements=[statement], ) - expansion_items = global_ctx.library_node.expansion_specification[stages_idx + 1 :] - iteration_ctx = iteration_ctx.push_axes_extents( - {k: v for k, v in zip(dcir.Axis.dims_horizontal(), extent)} + tasklet = dcir.Tasklet( + label=f"eval_{prefix}_{id(node)}", + stmts=[statement], + read_memlets=read_memlets, + write_memlets=[], + ) + # See notes inside the function + self._fix_memlet_array_access( + tasklet=tasklet, + memlets=read_memlets, + global_context=global_ctx, + symbol_collector=symbol_collector, + targets=targets, + **kwargs, ) - iteration_ctx = iteration_ctx.push_expansion_items(expansion_items) - assert iteration_ctx.grid_subset == dcir.GridSubset.single_gridpoint() + return tasklet + + def visit_MaskStmt( + self, + node: oir.MaskStmt, + global_ctx: DaCeIRBuilder.GlobalContext, + iteration_ctx: DaCeIRBuilder.IterationContext, + symbol_collector: DaCeIRBuilder.SymbolCollector, + horizontal_extent, + k_interval, + targets: list[oir.FieldAccess | oir.ScalarAccess], + inside_horizontal_region: bool = False, + **kwargs: Any, + ) -> dcir.MaskStmt | dcir.Condition: + if inside_horizontal_region: + # inside horizontal regions, we use old-style mask statements that + # might translate to if 
statements inside the tasklet + return dcir.MaskStmt( + mask=self.visit( + node.mask, + is_target=False, + global_ctx=global_ctx, + iteration_ctx=iteration_ctx, + symbol_collector=symbol_collector, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + inside_horizontal_region=inside_horizontal_region, + targets=targets, + **kwargs, + ), + body=self.visit( + node.body, + global_ctx=global_ctx, + iteration_ctx=iteration_ctx, + symbol_collector=symbol_collector, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + inside_horizontal_region=inside_horizontal_region, + targets=targets, + **kwargs, + ), + ) - read_memlets = _get_tasklet_inout_memlets( + tasklet = self._condition_tasklet( node, - get_outputs=False, global_ctx=global_ctx, - grid_subset=iteration_ctx.grid_subset, + symbol_collector=symbol_collector, + horizontal_extent=horizontal_extent, k_interval=k_interval, + iteration_ctx=iteration_ctx, + targets=targets, + **kwargs, + ) + code_block = self.visit( + oir.CodeBlock(body=node.body, loc=node.loc, label=f"condition_{id(node)}"), + global_ctx=global_ctx, + symbol_collector=symbol_collector, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + iteration_ctx=iteration_ctx, + targets=targets, + **kwargs, ) + targets.clear() + return dcir.Condition(condition=tasklet, true_states=gtc_utils.listify(code_block)) - write_memlets = _get_tasklet_inout_memlets( + def visit_While( + self, + node: oir.While, + global_ctx: DaCeIRBuilder.GlobalContext, + iteration_ctx: DaCeIRBuilder.IterationContext, + symbol_collector: DaCeIRBuilder.SymbolCollector, + horizontal_extent, + k_interval, + targets: list[oir.FieldAccess | oir.ScalarAccess], + inside_horizontal_region: bool = False, + **kwargs: Any, + ) -> dcir.While | dcir.WhileLoop: + if inside_horizontal_region: + # inside horizontal regions, we use old-style while statements that + # might translate to while statements inside the tasklet + return dcir.While( + cond=self.visit( + 
node.cond, + is_target=False, + global_ctx=global_ctx, + iteration_ctx=iteration_ctx, + symbol_collector=symbol_collector, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + inside_horizontal_region=inside_horizontal_region, + targets=targets, + **kwargs, + ), + body=self.visit( + node.body, + global_ctx=global_ctx, + iteration_ctx=iteration_ctx, + symbol_collector=symbol_collector, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + inside_horizontal_region=inside_horizontal_region, + targets=targets, + **kwargs, + ), + ) + + tasklet = self._condition_tasklet( node, - get_outputs=True, global_ctx=global_ctx, - grid_subset=iteration_ctx.grid_subset, + symbol_collector=symbol_collector, + iteration_ctx=iteration_ctx, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + targets=targets, + **kwargs, + ) + code_block = self.visit( + oir.CodeBlock(body=node.body, loc=node.loc, label=f"while_{id(node)}"), + global_ctx=global_ctx, + symbol_collector=symbol_collector, + iteration_ctx=iteration_ctx, + horizontal_extent=horizontal_extent, k_interval=k_interval, + targets=targets, + **kwargs, + ) + targets.clear() + return dcir.WhileLoop(condition=tasklet, body=code_block) + + def visit_Cast(self, node: oir.Cast, **kwargs: Any) -> dcir.Cast: + return dcir.Cast(dtype=node.dtype, expr=self.visit(node.expr, **kwargs)) + + def visit_NativeFuncCall(self, node: oir.NativeFuncCall, **kwargs: Any) -> dcir.NativeFuncCall: + return dcir.NativeFuncCall( + func=node.func, args=self.visit(node.args, **kwargs), dtype=node.dtype ) - dcir_node = dcir.Tasklet( - decls=decls, stmts=stmts, read_memlets=read_memlets, write_memlets=write_memlets + def visit_TernaryOp(self, node: oir.TernaryOp, **kwargs: Any) -> dcir.TernaryOp: + return dcir.TernaryOp( + cond=self.visit(node.cond, **kwargs), + true_expr=self.visit(node.true_expr, **kwargs), + false_expr=self.visit(node.false_expr, **kwargs), + dtype=node.dtype, ) - for memlet in [*read_memlets, 
*write_memlets]: + def _fix_memlet_array_access( + self, + *, + tasklet: dcir.Tasklet, + memlets: list[dcir.Memlet], + global_context: DaCeIRBuilder.GlobalContext, + symbol_collector: DaCeIRBuilder.SymbolCollector, + **kwargs: Any, + ) -> None: + for memlet in memlets: """ This loop handles the special case of a tasklet performing array access. The memlet should pass the full array shape (no tiling) and the tasklet expression for array access should use all explicit indexes. """ - array_ndims = len(global_ctx.arrays[memlet.field].shape) - field_decl = global_ctx.library_node.field_decls[memlet.field] + array_ndims = len(global_context.arrays[memlet.field].shape) + field_decl = global_context.library_node.field_decls[memlet.field] # calculate array subset on original memlet memlet_subset = make_dace_subset( - global_ctx.library_node.access_infos[memlet.field], + global_context.library_node.access_infos[memlet.field], memlet.access_info, field_decl.data_dims, ) @@ -492,22 +724,171 @@ def visit_HorizontalExecution( ] if len(memlet_data_index) < array_ndims: reshape_memlet = False - for access_node in dcir_node.walk_values().if_isinstance(dcir.IndexAccess): + for access_node in tasklet.walk_values().if_isinstance(dcir.IndexAccess): if access_node.data_index and access_node.name == memlet.connector: - access_node.data_index = memlet_data_index + access_node.data_index - assert len(access_node.data_index) == array_ndims + # Order matters! 
+ # Resolve first the cartesian dimensions packed in memlet_data_index + access_node.explicit_indices = [] + for data_index in memlet_data_index: + access_node.explicit_indices.append( + self.visit( + data_index, + symbol_collector=symbol_collector, + global_ctx=global_context, + **kwargs, + ) + ) + # Separate between case where K is offset or absolute and + # where it's a regular offset (should be dealt with the above memlet_data_index) + if access_node.offset: + access_node.explicit_indices.append(access_node.offset) + # Add any remaining data dimensions indexing + for data_index in access_node.data_index: + access_node.explicit_indices.append( + self.visit( + data_index, + symbol_collector=symbol_collector, + global_ctx=global_context, + is_target=False, + **kwargs, + ) + ) + assert len(access_node.explicit_indices) == array_ndims reshape_memlet = True if reshape_memlet: # ensure that memlet symbols used for array indexing are defined in context for sym in memlet.access_info.grid_subset.free_symbols: symbol_collector.add_symbol(sym) # set full shape on memlet - memlet.access_info = global_ctx.library_node.access_infos[memlet.field] + memlet.access_info = global_context.library_node.access_infos[memlet.field] + + def visit_CodeBlock( + self, + node: oir.CodeBlock, + *, + global_ctx: DaCeIRBuilder.GlobalContext, + iteration_ctx: DaCeIRBuilder.IterationContext, + symbol_collector: DaCeIRBuilder.SymbolCollector, + horizontal_extent, + k_interval, + targets: list[oir.FieldAccess | oir.ScalarAccess], + **kwargs: Any, + ): + # Reset the set of targets (used for detecting read after write inside a tasklet) + targets.clear() + statements = [ + self.visit( + statement, + targets=targets, + global_ctx=global_ctx, + symbol_collector=symbol_collector, + iteration_ctx=iteration_ctx, + k_interval=k_interval, + horizontal_extent=horizontal_extent, + **kwargs, + ) + for statement in node.body + ] + + # Gather all statements that aren't control flow (e.g. 
everything except Condition and WhileLoop), + # put them in a tasklet, and call "to_state" on it. + # Then, return a new list with types that are either ComputationState, Condition, or WhileLoop. + dace_nodes: list[dcir.ComputationState | dcir.Condition | dcir.WhileLoop] = [] + current_block: list[dcir.Stmt] = [] + for index, statement in enumerate(statements): + is_control_flow = isinstance(statement, (dcir.Condition, dcir.WhileLoop)) + if not is_control_flow: + current_block.append(statement) + + last_statement = index == len(statements) - 1 + if (is_control_flow or last_statement) and len(current_block) > 0: + read_memlets: list[dcir.Memlet] = _get_tasklet_inout_memlets( + node, + AccessType.READ, + global_ctx=global_ctx, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + grid_subset=iteration_ctx.grid_subset, + dcir_statements=current_block, + ) + write_memlets: list[dcir.Memlet] = _get_tasklet_inout_memlets( + node, + AccessType.WRITE, + global_ctx=global_ctx, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + grid_subset=iteration_ctx.grid_subset, + dcir_statements=current_block, + ) + tasklet = dcir.Tasklet( + label=node.label, + stmts=current_block, + read_memlets=read_memlets, + write_memlets=write_memlets, + ) + # See notes inside the function + self._fix_memlet_array_access( + tasklet=tasklet, + memlets=[*read_memlets, *write_memlets], + global_context=global_ctx, + symbol_collector=symbol_collector, + targets=targets, + **kwargs, + ) + + dace_nodes.append(*self.to_state(tasklet, grid_subset=iteration_ctx.grid_subset)) + + # reset block scope + current_block = [] + + # append control flow statement after new tasklet (if applicable) + if is_control_flow: + dace_nodes.append(statement) + + return dace_nodes + + def visit_HorizontalExecution( + self, + node: oir.HorizontalExecution, + *, + global_ctx: DaCeIRBuilder.GlobalContext, + iteration_ctx: DaCeIRBuilder.IterationContext, + symbol_collector: DaCeIRBuilder.SymbolCollector, 
+ k_interval, + **kwargs: Any, + ): + extent = global_ctx.library_node.get_extents(node) + + stages_idx = next( + idx + for idx, item in enumerate(global_ctx.library_node.expansion_specification) + if isinstance(item, Stages) + ) + expansion_items = global_ctx.library_node.expansion_specification[stages_idx + 1 :] + + iteration_ctx = iteration_ctx.push_axes_extents( + {k: v for k, v in zip(dcir.Axis.dims_horizontal(), extent)} + ) + iteration_ctx = iteration_ctx.push_expansion_items(expansion_items) + assert iteration_ctx.grid_subset == dcir.GridSubset.single_gridpoint() + + code_block = oir.CodeBlock(body=node.body, loc=node.loc, label=f"he_{id(node)}") + targets: list[oir.FieldAccess | oir.ScalarAccess] = [] + dcir_nodes = self.visit( + code_block, + global_ctx=global_ctx, + iteration_ctx=iteration_ctx, + symbol_collector=symbol_collector, + horizontal_extent=global_ctx.library_node.get_extents(node), + k_interval=k_interval, + targets=targets, + **kwargs, + ) for item in reversed(expansion_items): iteration_ctx = iteration_ctx.pop() - dcir_node = self._process_iteration_item( - [dcir_node], + dcir_nodes = self._process_iteration_item( + dcir_nodes, item, global_ctx=global_ctx, iteration_ctx=iteration_ctx, @@ -516,13 +897,13 @@ def visit_HorizontalExecution( ) # pop stages context (pushed with push_grid_subset) iteration_ctx.pop() - return dcir_node + + return dcir_nodes def visit_VerticalLoopSection( self, node: oir.VerticalLoopSection, *, - loop_order, iteration_ctx: DaCeIRBuilder.IterationContext, global_ctx: DaCeIRBuilder.GlobalContext, symbol_collector: DaCeIRBuilder.SymbolCollector, @@ -546,7 +927,6 @@ def visit_VerticalLoopSection( iteration_ctx=iteration_ctx, global_ctx=global_ctx, symbol_collector=symbol_collector, - loop_order=loop_order, k_interval=node.interval, **kwargs, ) @@ -581,7 +961,10 @@ def to_dataflow( nodes = flatten_list(nodes) if all(isinstance(n, (dcir.NestedSDFG, dcir.DomainMap, dcir.Tasklet)) for n in nodes): return nodes - elif not 
all(isinstance(n, (dcir.ComputationState, dcir.DomainLoop)) for n in nodes): + if not all( + isinstance(n, (dcir.ComputationState, dcir.Condition, dcir.DomainLoop, dcir.WhileLoop)) + for n in nodes + ): raise ValueError("Can't mix dataflow and state nodes on same level.") read_memlets, write_memlets, field_memlets = union_inout_memlets(nodes) @@ -598,6 +981,7 @@ def to_dataflow( write_memlets = [ memlet.remove_read() for memlet in field_memlets if memlet.field in write_fields ] + return [ dcir.NestedSDFG( label=global_ctx.library_node.label, @@ -613,12 +997,15 @@ def to_dataflow( def to_state(self, nodes, *, grid_subset: dcir.GridSubset): nodes = flatten_list(nodes) - if all(isinstance(n, (dcir.ComputationState, dcir.DomainLoop)) for n in nodes): + if all( + isinstance(n, (dcir.ComputationState, dcir.Condition, dcir.DomainLoop, dcir.WhileLoop)) + for n in nodes + ): return nodes - elif all(isinstance(n, (dcir.NestedSDFG, dcir.DomainMap, dcir.Tasklet)) for n in nodes): + if all(isinstance(n, (dcir.DomainMap, dcir.NestedSDFG, dcir.Tasklet)) for n in nodes): return [dcir.ComputationState(computations=nodes, grid_subset=grid_subset)] - else: - raise ValueError("Can't mix dataflow and state nodes on same level.") + + raise ValueError("Can't mix dataflow and state nodes on same level.") def _process_map_item( self, @@ -628,8 +1015,8 @@ def _process_map_item( global_ctx: DaCeIRBuilder.GlobalContext, iteration_ctx: DaCeIRBuilder.IterationContext, symbol_collector: DaCeIRBuilder.SymbolCollector, - **kwargs, - ): + **kwargs: Any, + ) -> List[dcir.DomainMap]: grid_subset = iteration_ctx.grid_subset read_memlets, write_memlets, _ = union_inout_memlets(list(scope_nodes)) scope_nodes = self.to_dataflow( @@ -723,11 +1110,10 @@ def _process_loop_item( scope_nodes, item: Loop, *, - global_ctx, iteration_ctx: DaCeIRBuilder.IterationContext, symbol_collector: DaCeIRBuilder.SymbolCollector, - **kwargs, - ): + **kwargs: Any, + ) -> List[dcir.DomainLoop]: grid_subset = 
union_node_grid_subsets(list(scope_nodes)) read_memlets, write_memlets, _ = union_inout_memlets(list(scope_nodes)) scope_nodes = self.to_state(scope_nodes, grid_subset=grid_subset) @@ -793,14 +1179,14 @@ def _process_loop_item( def _process_iteration_item(self, scope, item, **kwargs): if isinstance(item, Map): return self._process_map_item(scope, item, **kwargs) - elif isinstance(item, Loop): + if isinstance(item, Loop): return self._process_loop_item(scope, item, **kwargs) - else: - raise ValueError("Invalid expansion specification set.") + + raise ValueError("Invalid expansion specification set.") def visit_VerticalLoop( - self, node: oir.VerticalLoop, *, global_ctx: DaCeIRBuilder.GlobalContext, **kwargs - ): + self, node: oir.VerticalLoop, *, global_ctx: DaCeIRBuilder.GlobalContext, **kwargs: Any + ) -> dcir.NestedSDFG: overall_extent = Extent.zeros(2) for he in node.walk_values().if_isinstance(oir.HorizontalExecution): overall_extent = overall_extent.union(global_ctx.library_node.get_extents(he)) @@ -840,7 +1226,6 @@ def visit_VerticalLoop( sections = flatten_list( self.generic_visit( node.sections, - loop_order=node.loop_order, global_ctx=global_ctx, iteration_ctx=iteration_ctx, symbol_collector=symbol_collector, @@ -870,6 +1255,7 @@ def visit_VerticalLoop( read_fields = set(memlet.field for memlet in read_memlets) write_fields = set(memlet.field for memlet in write_memlets) + return dcir.NestedSDFG( label=global_ctx.library_node.label, states=self.to_state(computations, grid_subset=iteration_ctx.grid_subset), diff --git a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py b/src/gt4py/cartesian/gtc/dace/expansion/expansion.py index 055bf64015..06ef69dcf4 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/expansion.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/expansion.py @@ -17,7 +17,7 @@ import dace.subsets import sympy -from gt4py.cartesian.gtc.dace import daceir as dcir +from gt4py.cartesian.gtc.dace import daceir as dcir, prefix from 
gt4py.cartesian.gtc.dace.expansion.daceir_builder import DaCeIRBuilder from gt4py.cartesian.gtc.dace.expansion.sdfg_builder import StencilComputationSDFGBuilder @@ -74,15 +74,14 @@ def _fix_context( * change connector names to match inner array name (before expansion prefixed to satisfy uniqueness) * change in- and out-edges' subsets so that they have the same shape as the corresponding array inside * determine the domain size based on edges to StencilComputation - """ # change connector names for in_edge in parent_state.in_edges(node): - assert in_edge.dst_conn.startswith("__in_") - in_edge.dst_conn = in_edge.dst_conn[len("__in_") :] + assert in_edge.dst_conn.startswith(prefix.CONNECTOR_IN) + in_edge.dst_conn = in_edge.dst_conn.removeprefix(prefix.CONNECTOR_IN) for out_edge in parent_state.out_edges(node): - assert out_edge.src_conn.startswith("__out_") - out_edge.src_conn = out_edge.src_conn[len("__out_") :] + assert out_edge.src_conn.startswith(prefix.CONNECTOR_OUT) + out_edge.src_conn = out_edge.src_conn.removeprefix(prefix.CONNECTOR_OUT) # union input and output subsets subsets = {} @@ -120,15 +119,27 @@ def _fix_context( if key in nsdfg.symbol_mapping: del nsdfg.symbol_mapping[key] + for edge in parent_state.in_edges(node): + if edge.dst_conn not in nsdfg.in_connectors: + # Drop connection if connector is not found in the expansion of the library node + parent_state.remove_edge(edge) + if parent_state.in_degree(edge.src) + parent_state.out_degree(edge.src) == 0: + # Remove node if it is now isolated + parent_state.remove_node(edge.src) + @staticmethod def _get_parent_arrays( node: StencilComputation, parent_state: dace.SDFGState, parent_sdfg: dace.SDFG ) -> Dict[str, dace.data.Data]: parent_arrays: Dict[str, dace.data.Data] = {} for edge in (e for e in parent_state.in_edges(node) if e.dst_conn is not None): - parent_arrays[edge.dst_conn[len("__in_") :]] = parent_sdfg.arrays[edge.data.data] + parent_arrays[edge.dst_conn.removeprefix(prefix.CONNECTOR_IN)] = 
parent_sdfg.arrays[ + edge.data.data + ] for edge in (e for e in parent_state.out_edges(node) if e.src_conn is not None): - parent_arrays[edge.src_conn[len("__out_") :]] = parent_sdfg.arrays[edge.data.data] + parent_arrays[edge.src_conn.removeprefix(prefix.CONNECTOR_OUT)] = parent_sdfg.arrays[ + edge.data.data + ] return parent_arrays @staticmethod diff --git a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py index 9d64464377..c199891c13 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py @@ -18,11 +18,14 @@ import dace.subsets from gt4py import eve -from gt4py.cartesian.gtc.dace import daceir as dcir +from gt4py.cartesian.gtc.dace import daceir as dcir, prefix from gt4py.cartesian.gtc.dace.expansion.tasklet_codegen import TaskletCodegen -from gt4py.cartesian.gtc.dace.expansion.utils import get_dace_debuginfo from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass -from gt4py.cartesian.gtc.dace.utils import make_dace_subset +from gt4py.cartesian.gtc.dace.utils import get_dace_debuginfo, make_dace_subset + + +def node_name_from_connector(connector: str) -> str: + return connector.removeprefix(prefix.TASKLET_OUT).removeprefix(prefix.TASKLET_IN) class StencilComputationSDFGBuilder(eve.VisitorWithSymbolTableTrait): @@ -37,18 +40,17 @@ class SDFGContext: state: dace.SDFGState state_stack: List[dace.SDFGState] = dataclasses.field(default_factory=list) - def add_state(self): - new_state = self.sdfg.add_state() + def add_state(self, label: Optional[str] = None) -> None: + new_state = self.sdfg.add_state(label=label) for edge in self.sdfg.out_edges(self.state): self.sdfg.remove_edge(edge) self.sdfg.add_edge(new_state, edge.dst, edge.data) self.sdfg.add_edge(self.state, new_state, dace.InterstateEdge()) self.state = new_state - return self - def add_loop(self, index_range: dcir.Range): - loop_state = 
self.sdfg.add_state() - after_state = self.sdfg.add_state() + def add_loop(self, index_range: dcir.Range) -> None: + loop_state = self.sdfg.add_state("loop_state") + after_state = self.sdfg.add_state("loop_after") for edge in self.sdfg.out_edges(self.state): self.sdfg.remove_edge(edge) self.sdfg.add_edge(after_state, edge.dst, edge.data) @@ -76,9 +78,126 @@ def add_loop(self, index_range: dcir.Range): self.state_stack.append(after_state) self.state = loop_state - return self - def pop_loop(self): + def pop_loop(self) -> None: + self._pop_last("loop_after") + + def add_condition(self, node: dcir.Condition) -> None: + """Inserts a condition after the current self.state. + + The condition consists of an initial state connected to a guard state, which branches + to a true_state and a false_state based on the given condition. Both states then merge + into a merge_state. + + self.state is set to init_state and the other states are pushed on the stack to be + popped with `pop_condition_*()` methods. 
+ """ + # Data model validators enforce this to exist + assert isinstance(node.condition.stmts[0], dcir.AssignStmt) + assert isinstance(node.condition.stmts[0].left, dcir.ScalarAccess) + condition_name = node.condition.stmts[0].left.original_name + + merge_state = self.sdfg.add_state("condition_after") + for edge in self.sdfg.out_edges(self.state): + self.sdfg.remove_edge(edge) + self.sdfg.add_edge(merge_state, edge.dst, edge.data) + + # Evaluate node condition + init_state = self.sdfg.add_state("condition_init") + self.sdfg.add_edge(self.state, init_state, dace.InterstateEdge()) + + # Promote condition (from init_state) to symbol + condition_state = self.sdfg.add_state("condition_guard") + self.sdfg.add_edge( + init_state, + condition_state, + dace.InterstateEdge(assignments=dict(if_condition=condition_name)), + ) + + true_state = self.sdfg.add_state("condition_true") + self.sdfg.add_edge( + condition_state, true_state, dace.InterstateEdge(condition="if_condition") + ) + self.sdfg.add_edge(true_state, merge_state, dace.InterstateEdge()) + + false_state = self.sdfg.add_state("condition_false") + self.sdfg.add_edge( + condition_state, false_state, dace.InterstateEdge(condition="not if_condition") + ) + self.sdfg.add_edge(false_state, merge_state, dace.InterstateEdge()) + + self.state_stack.append(merge_state) + self.state_stack.append(false_state) + self.state_stack.append(true_state) + self.state_stack.append(condition_state) + self.state = init_state + + def pop_condition_guard(self) -> None: + self._pop_last("condition_guard") + + def pop_condition_true(self) -> None: + self._pop_last("condition_true") + + def pop_condition_false(self) -> None: + self._pop_last("condition_false") + + def pop_condition_after(self) -> None: + self._pop_last("condition_after") + + def add_while(self, node: dcir.WhileLoop) -> None: + """Inserts a while loop after the current state.""" + # Data model validators enforce this to exist + assert isinstance(node.condition.stmts[0], 
dcir.AssignStmt) + assert isinstance(node.condition.stmts[0].left, dcir.ScalarAccess) + condition_name = node.condition.stmts[0].left.original_name + + after_state = self.sdfg.add_state("while_after") + for edge in self.sdfg.out_edges(self.state): + self.sdfg.remove_edge(edge) + self.sdfg.add_edge(after_state, edge.dst, edge.data) + + # Evaluate loop condition + init_state = self.sdfg.add_state("while_init") + self.sdfg.add_edge(self.state, init_state, dace.InterstateEdge()) + + # Promote condition (from init_state) to symbol + guard_state = self.sdfg.add_state("while_guard") + self.sdfg.add_edge( + init_state, + guard_state, + dace.InterstateEdge(assignments=dict(loop_condition=condition_name)), + ) + + loop_state = self.sdfg.add_state("while_loop") + self.sdfg.add_edge( + guard_state, loop_state, dace.InterstateEdge(condition="loop_condition") + ) + # Loop back to init_state to re-evaluate the loop condition + self.sdfg.add_edge(loop_state, init_state, dace.InterstateEdge()) + + # Exit the loop + self.sdfg.add_edge( + guard_state, after_state, dace.InterstateEdge(condition="not loop_condition") + ) + + self.state_stack.append(after_state) + self.state_stack.append(loop_state) + self.state_stack.append(guard_state) + self.state = init_state + + def pop_while_guard(self) -> None: + self._pop_last("while_guard") + + def pop_while_loop(self) -> None: + self._pop_last("while_loop") + + def pop_while_after(self) -> None: + self._pop_last("while_after") + + def _pop_last(self, node_label: str | None = None) -> None: + if node_label is not None: + assert self.state_stack[-1].label.startswith(node_label) + self.state = self.state_stack[-1] del self.state_stack[-1] @@ -89,7 +208,7 @@ def visit_Memlet( scope_node: dcir.ComputationNode, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, node_ctx: StencilComputationSDFGBuilder.NodeContext, - connector_prefix="", + connector_prefix: str = "", symtable: ChainMap[eve.SymbolRef, dcir.Decl], ) -> None: field_decl = 
symtable[node.field] @@ -132,6 +251,91 @@ def _add_empty_edges( exit_node, None, *node_ctx.output_node_and_conns[None], dace.Memlet() ) + def visit_WhileLoop( + self, + node: dcir.WhileLoop, + *, + sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, + node_ctx: StencilComputationSDFGBuilder.NodeContext, + symtable: ChainMap[eve.SymbolRef, dcir.Decl], + **kwargs: Any, + ) -> None: + sdfg_ctx.add_while(node) + assert sdfg_ctx.state.label.startswith("while_init") + + read_acc_and_conn: dict[Optional[str], tuple[dace.nodes.Node, Optional[str]]] = {} + write_acc_and_conn: dict[Optional[str], tuple[dace.nodes.Node, Optional[str]]] = {} + for memlet in node.condition.read_memlets: + if memlet.field not in read_acc_and_conn: + read_acc_and_conn[memlet.field] = ( + sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), + None, + ) + for memlet in node.condition.write_memlets: + if memlet.field not in write_acc_and_conn: + write_acc_and_conn[memlet.field] = ( + sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), + None, + ) + eval_node_ctx = StencilComputationSDFGBuilder.NodeContext( + input_node_and_conns=read_acc_and_conn, output_node_and_conns=write_acc_and_conn + ) + self.visit( + node.condition, sdfg_ctx=sdfg_ctx, node_ctx=eval_node_ctx, symtable=symtable, **kwargs + ) + + sdfg_ctx.pop_while_guard() + sdfg_ctx.pop_while_loop() + + for state in node.body: + self.visit(state, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx, symtable=symtable, **kwargs) + + sdfg_ctx.pop_while_after() + + def visit_Condition( + self, + node: dcir.Condition, + *, + sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, + node_ctx: StencilComputationSDFGBuilder.NodeContext, + symtable: ChainMap[eve.SymbolRef, dcir.Decl], + **kwargs: Any, + ) -> None: + sdfg_ctx.add_condition(node) + assert sdfg_ctx.state.label.startswith("condition_init") + + read_acc_and_conn: dict[Optional[str], tuple[dace.nodes.Node, Optional[str]]] = {} + write_acc_and_conn: dict[Optional[str], 
tuple[dace.nodes.Node, Optional[str]]] = {} + for memlet in node.condition.read_memlets: + if memlet.field not in read_acc_and_conn: + read_acc_and_conn[memlet.field] = ( + sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), + None, + ) + for memlet in node.condition.write_memlets: + if memlet.field not in write_acc_and_conn: + write_acc_and_conn[memlet.field] = ( + sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), + None, + ) + eval_node_ctx = StencilComputationSDFGBuilder.NodeContext( + input_node_and_conns=read_acc_and_conn, output_node_and_conns=write_acc_and_conn + ) + self.visit( + node.condition, sdfg_ctx=sdfg_ctx, node_ctx=eval_node_ctx, symtable=symtable, **kwargs + ) + + sdfg_ctx.pop_condition_guard() + sdfg_ctx.pop_condition_true() + for state in node.true_states: + self.visit(state, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx, symtable=symtable, **kwargs) + + sdfg_ctx.pop_condition_false() + for state in node.false_states: + self.visit(state, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx, symtable=symtable, **kwargs) + + sdfg_ctx.pop_condition_after() + def visit_Tasklet( self, node: dcir.Tasklet, @@ -139,24 +343,107 @@ def visit_Tasklet( sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, node_ctx: StencilComputationSDFGBuilder.NodeContext, symtable: ChainMap[eve.SymbolRef, dcir.Decl], - **kwargs, + **kwargs: Any, ) -> None: code = TaskletCodegen.apply_codegen( node, read_memlets=node.read_memlets, write_memlets=node.write_memlets, - sdfg_ctx=sdfg_ctx, symtable=symtable, + sdfg=sdfg_ctx.sdfg, ) + # We are breaking up vertical loops inside stencils in multiple Tasklets + # It might thus happen that we write a "local" scalar in one Tasklet and + # read it in another Tasklet (downstream). + # We thus create output connectors for all writes to scalar variables + # inside Tasklets. And input connectors for all scalar reads unless + # previously written in the same Tasklet. 
DaCe's simplify pipeline will get + # rid of any dead dataflow introduced with this general approach. + scalar_inputs: set[str] = set() + scalar_outputs: set[str] = set() + + # Gather scalar writes in this Tasklet + for access_node in node.walk_values().if_isinstance(dcir.AssignStmt): + target_name = access_node.left.name + + field_access = ( + len( + set( + memlet.connector + for memlet in [*node.write_memlets] + if memlet.connector == target_name + ) + ) + > 0 + ) + if field_access or target_name in scalar_outputs: + continue + + assert isinstance(access_node.left, dcir.ScalarAccess) + assert ( + access_node.left.original_name is not None + ), "Original name not found for '{access_nodes.left.name}'. DaCeIR error." + + original_name = access_node.left.original_name + scalar_outputs.add(target_name) + if original_name not in sdfg_ctx.sdfg.arrays: + sdfg_ctx.sdfg.add_scalar( + original_name, + dtype=data_type_to_dace_typeclass(access_node.left.dtype), + transient=True, + ) + + # Gather scalar reads in this Tasklet + for access_node in node.walk_values().if_isinstance(dcir.ScalarAccess): + read_name = access_node.name + field_access = ( + len( + set( + memlet.connector + for memlet in [*node.read_memlets, *node.write_memlets] + if memlet.connector == read_name + ) + ) + > 0 + ) + defined_symbol = any(read_name in symbol_map for symbol_map in symtable.maps) + + if ( + not field_access + and not defined_symbol + and not access_node.is_target + and read_name.startswith(prefix.TASKLET_IN) + and read_name not in scalar_inputs + ): + scalar_inputs.add(read_name) + + inputs = set(memlet.connector for memlet in node.read_memlets).union(scalar_inputs) + outputs = set(memlet.connector for memlet in node.write_memlets).union(scalar_outputs) + tasklet = sdfg_ctx.state.add_tasklet( - name=f"{sdfg_ctx.sdfg.label}_Tasklet", + name=node.label, code=code, - inputs=set(memlet.connector for memlet in node.read_memlets), - outputs=set(memlet.connector for memlet in node.write_memlets), 
+ inputs=inputs, + outputs=outputs, debuginfo=get_dace_debuginfo(node), ) + # Add memlets for scalars access (read/write) + for connector in scalar_outputs: + original_name = node_name_from_connector(connector) + access_node = sdfg_ctx.state.add_write(original_name) + sdfg_ctx.state.add_memlet_path( + tasklet, access_node, src_conn=connector, memlet=dace.Memlet(data=original_name) + ) + for connector in scalar_inputs: + original_name = node_name_from_connector(connector) + access_node = sdfg_ctx.state.add_read(original_name) + sdfg_ctx.state.add_memlet_path( + access_node, tasklet, dst_conn=connector, memlet=dace.Memlet(data=original_name) + ) + + # Add memlets for field access (read/write) self.visit( node.read_memlets, scope_node=tasklet, @@ -173,11 +460,8 @@ def visit_Tasklet( symtable=symtable, **kwargs, ) - StencilComputationSDFGBuilder._add_empty_edges( - tasklet, tasklet, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx - ) - def visit_Range(self, node: dcir.Range, **kwargs) -> Dict[str, str]: + def visit_Range(self, node: dcir.Range, **kwargs: Any) -> Dict[str, str]: start, end = node.interval.to_dace_symbolic() return {node.var: str(dace.subsets.Range([(start, end - 1, node.stride)]))} @@ -187,7 +471,7 @@ def visit_DomainMap( *, node_ctx: StencilComputationSDFGBuilder.NodeContext, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - **kwargs, + **kwargs: Any, ) -> None: ndranges = { k: v @@ -206,13 +490,13 @@ def visit_DomainMap( input_node_and_conns: Dict[Optional[str], Tuple[dace.nodes.Node, Optional[str]]] = {} output_node_and_conns: Dict[Optional[str], Tuple[dace.nodes.Node, Optional[str]]] = {} for field in set(memlet.field for memlet in scope_node.read_memlets): - map_entry.add_in_connector("IN_" + field) - map_entry.add_out_connector("OUT_" + field) - input_node_and_conns[field] = (map_entry, "OUT_" + field) + map_entry.add_in_connector(f"{prefix.PASSTHROUGH_IN}{field}") + map_entry.add_out_connector(f"{prefix.PASSTHROUGH_OUT}{field}") + 
input_node_and_conns[field] = (map_entry, f"{prefix.PASSTHROUGH_OUT}{field}") for field in set(memlet.field for memlet in scope_node.write_memlets): - map_exit.add_in_connector("IN_" + field) - map_exit.add_out_connector("OUT_" + field) - output_node_and_conns[field] = (map_exit, "IN_" + field) + map_exit.add_in_connector(f"{prefix.PASSTHROUGH_IN}{field}") + map_exit.add_out_connector(f"{prefix.PASSTHROUGH_OUT}{field}") + output_node_and_conns[field] = (map_exit, f"{prefix.PASSTHROUGH_IN}{field}") if not input_node_and_conns: input_node_and_conns[None] = (map_entry, None) if not output_node_and_conns: @@ -228,7 +512,7 @@ def visit_DomainMap( scope_node=map_entry, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx, - connector_prefix="IN_", + connector_prefix=prefix.PASSTHROUGH_IN, **kwargs, ) self.visit( @@ -236,7 +520,7 @@ def visit_DomainMap( scope_node=map_exit, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx, - connector_prefix="OUT_", + connector_prefix=prefix.PASSTHROUGH_OUT, **kwargs, ) StencilComputationSDFGBuilder._add_empty_edges( @@ -248,9 +532,9 @@ def visit_DomainLoop( node: dcir.DomainLoop, *, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - **kwargs, + **kwargs: Any, ) -> None: - sdfg_ctx = sdfg_ctx.add_loop(node.index_range) + sdfg_ctx.add_loop(node.index_range) self.visit(node.loop_states, sdfg_ctx=sdfg_ctx, **kwargs) sdfg_ctx.pop_loop() @@ -259,9 +543,16 @@ def visit_ComputationState( node: dcir.ComputationState, *, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - **kwargs, + **kwargs: Any, ) -> None: sdfg_ctx.add_state() + + # node_ctx is used to keep track of memlets per ComputationState. Conditions and WhileLoops + # will (recursively) introduce more than one compute state per vertical loop. We thus drop + # any node_ctx that is potentially passed down and instead create a new one for each + # ComputationState that we encounter. 
+ kwargs.pop("node_ctx", None) + read_acc_and_conn: Dict[Optional[str], Tuple[dace.nodes.Node, Optional[str]]] = {} write_acc_and_conn: Dict[Optional[str], Tuple[dace.nodes.Node, Optional[str]]] = {} for computation in node.computations: @@ -269,13 +560,13 @@ def visit_ComputationState( for memlet in computation.read_memlets: if memlet.field not in read_acc_and_conn: read_acc_and_conn[memlet.field] = ( - sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), + sdfg_ctx.state.add_access(memlet.field), None, ) for memlet in computation.write_memlets: if memlet.field not in write_acc_and_conn: write_acc_and_conn[memlet.field] = ( - sdfg_ctx.state.add_access(memlet.field, debuginfo=dace.DebugInfo(0)), + sdfg_ctx.state.add_access(memlet.field), None, ) node_ctx = StencilComputationSDFGBuilder.NodeContext( @@ -289,7 +580,7 @@ def visit_FieldDecl( *, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, non_transients: Set[eve.SymbolRef], - **kwargs, + **kwargs: Any, ) -> None: assert len(node.strides) == len(node.shape) sdfg_ctx.sdfg.add_array( @@ -299,7 +590,7 @@ def visit_FieldDecl( dtype=data_type_to_dace_typeclass(node.dtype), storage=node.storage.to_dace_storage(), transient=node.name not in non_transients, - debuginfo=dace.DebugInfo(0), + debuginfo=get_dace_debuginfo(node), ) def visit_SymbolDecl( @@ -307,7 +598,7 @@ def visit_SymbolDecl( node: dcir.SymbolDecl, *, sdfg_ctx: StencilComputationSDFGBuilder.SDFGContext, - **kwargs, + **kwargs: Any, ) -> None: if node.name not in sdfg_ctx.sdfg.symbols: sdfg_ctx.sdfg.add_symbol(node.name, stype=data_type_to_dace_typeclass(node.dtype)) @@ -319,11 +610,11 @@ def visit_NestedSDFG( sdfg_ctx: Optional[StencilComputationSDFGBuilder.SDFGContext] = None, node_ctx: Optional[StencilComputationSDFGBuilder.NodeContext] = None, symtable: ChainMap[eve.SymbolRef, Any], - **kwargs, + **kwargs: Any, ) -> dace.nodes.NestedSDFG: sdfg = dace.SDFG(node.label) inner_sdfg_ctx = StencilComputationSDFGBuilder.SDFGContext( - 
sdfg=sdfg, state=sdfg.add_state(is_start_state=True) + sdfg=sdfg, state=sdfg.add_state(is_start_block=True) ) self.visit( node.field_decls, @@ -335,7 +626,13 @@ def visit_NestedSDFG( symbol_mapping = {decl.name: decl.to_dace_symbol() for decl in node.symbol_decls} for computation_state in node.states: - self.visit(computation_state, sdfg_ctx=inner_sdfg_ctx, symtable=symtable, **kwargs) + self.visit( + computation_state, + sdfg_ctx=inner_sdfg_ctx, + node_ctx=node_ctx, + symtable=symtable, + **kwargs, + ) if sdfg_ctx is not None and node_ctx is not None: nsdfg = sdfg_ctx.state.add_nested_sdfg( @@ -344,7 +641,6 @@ def visit_NestedSDFG( inputs=node.input_connectors, outputs=node.output_connectors, symbol_mapping=symbol_mapping, - debuginfo=dace.DebugInfo(0), ) self.visit( node.read_memlets, @@ -365,13 +661,12 @@ def visit_NestedSDFG( StencilComputationSDFGBuilder._add_empty_edges( nsdfg, nsdfg, sdfg_ctx=sdfg_ctx, node_ctx=node_ctx ) - else: - nsdfg = dace.nodes.NestedSDFG( - label=sdfg.label, - sdfg=sdfg, - inputs={memlet.connector for memlet in node.read_memlets}, - outputs={memlet.connector for memlet in node.write_memlets}, - symbol_mapping=symbol_mapping, - ) + return nsdfg - return nsdfg + return dace.nodes.NestedSDFG( + label=sdfg.label, + sdfg=sdfg, + inputs={memlet.connector for memlet in node.read_memlets}, + outputs={memlet.connector for memlet in node.write_memlets}, + symbol_mapping=symbol_mapping, + ) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py index 696dc27387..29104b2a6e 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py @@ -30,21 +30,17 @@ def _visit_offset( node: Union[dcir.VariableKOffset, common.CartesianOffset], *, access_info: dcir.FieldAccessInfo, - decl: dcir.FieldDecl, - **kwargs, + **kwargs: Any, ) -> str: int_sizes: List[Optional[int]] = [] for i, axis in 
enumerate(access_info.axes()): memlet_shape = access_info.shape - if ( - str(memlet_shape[i]).isnumeric() - and axis not in decl.access_info.variable_offset_axes - ): + if str(memlet_shape[i]).isnumeric() and axis not in access_info.variable_offset_axes: int_sizes.append(int(memlet_shape[i])) else: int_sizes.append(None) sym_offsets = [ - dace.symbolic.pystr_to_symbolic(self.visit(off, **kwargs)) + dace.symbolic.pystr_to_symbolic(self.visit(off, access_info=access_info, **kwargs)) for off in (node.to_dict()["i"], node.to_dict()["j"], node.k) ] for axis in access_info.variable_offset_axes: @@ -60,27 +56,44 @@ def _visit_offset( res = dace.subsets.Range([r for i, r in enumerate(ranges.ranges) if int_sizes[i] != 1]) return str(res) - def visit_CartesianOffset(self, node: common.CartesianOffset, **kwargs): + def _explicit_indexing( + self, node: common.CartesianOffset | dcir.VariableKOffset, **kwargs: Any + ) -> str: + """If called from the explicit pass we need to be add manually the relative indexing""" + return f"__k+{self.visit(node.k, **kwargs)}" + + def visit_CartesianOffset( + self, node: common.CartesianOffset, explicit=False, **kwargs: Any + ) -> str: + if explicit: + return self._explicit_indexing(node, **kwargs) + return self._visit_offset(node, **kwargs) - def visit_VariableKOffset(self, node: common.CartesianOffset, **kwargs): + def visit_VariableKOffset( + self, node: dcir.VariableKOffset, explicit=False, **kwargs: Any + ) -> str: + if explicit: + return self._explicit_indexing(node, **kwargs) + return self._visit_offset(node, **kwargs) def visit_IndexAccess( self, node: dcir.IndexAccess, *, - is_target, - sdfg_ctx, + is_target: bool, + sdfg: dace.SDFG, symtable: ChainMap[eve.SymbolRef, dcir.Decl], - **kwargs, - ): + **kwargs: Any, + ) -> str: if is_target: memlets = kwargs["write_memlets"] else: # if this node is not a target, it will still use the symbol of the write memlet if the # field was previously written in the same memlet. 
memlets = kwargs["read_memlets"] + kwargs["write_memlets"] + try: memlet = next(mem for mem in memlets if mem.connector == node.name) except StopIteration: @@ -88,25 +101,49 @@ def visit_IndexAccess( "Memlet connector and tasklet variable mismatch, DaCe IR error." ) from None - index_strs = [] - if node.offset is not None: - index_strs.append( - self.visit( - node.offset, - decl=symtable[memlet.field], - access_info=memlet.access_info, - symtable=symtable, - in_idx=True, - **kwargs, + index_strs: list[str] = [] + if node.explicit_indices: + # Full array access with every dimensions accessed in full. + # Everything was packed in `explicit_indices` in `DaCeIRBuilder._fix_memlet_array_access()` + # along the `reshape_memlet=True` code path. + assert len(node.explicit_indices) == len(sdfg.arrays[memlet.field].shape) + for idx in node.explicit_indices: + index_strs.append( + self.visit( + idx, + symtable=symtable, + in_idx=True, + explicit=True, + **kwargs, + ) + ) + else: + # Grid-point access, I & J are unitary, K can be offsetted with variable + # Resolve K offset (also resolves I & J) + if node.offset is not None: + index_strs.append( + self.visit( + node.offset, + access_info=memlet.access_info, + symtable=symtable, + in_idx=True, + **kwargs, + ) ) + # Add any data dimensions + index_strs.extend( + self.visit(idx, symtable=symtable, in_idx=True, **kwargs) for idx in node.data_index ) - index_strs.extend( - self.visit(idx, sdfg_ctx=sdfg_ctx, symtable=symtable, in_idx=True, **kwargs) - for idx in node.data_index + # Filter empty strings + non_empty_indices = list(filter(None, index_strs)) + return ( + f"{node.name}[{','.join(non_empty_indices)}]" + if len(non_empty_indices) > 0 + else node.name ) - return f"{node.name}[{','.join(index_strs)}]" - def visit_AssignStmt(self, node: dcir.AssignStmt, **kwargs): + def visit_AssignStmt(self, node: dcir.AssignStmt, **kwargs: Any) -> str: + # Visiting order matters because targets must not contain the target symbols from the 
left visit right = self.visit(node.right, is_target=False, **kwargs) left = self.visit(node.left, is_target=True, **kwargs) return f"{left} = {right}" @@ -120,18 +157,18 @@ def visit_AssignStmt(self, node: dcir.AssignStmt, **kwargs): def visit_BuiltInLiteral(self, builtin: common.BuiltInLiteral, **kwargs: Any) -> str: if builtin == common.BuiltInLiteral.TRUE: return "True" - elif builtin == common.BuiltInLiteral.FALSE: + if builtin == common.BuiltInLiteral.FALSE: return "False" raise NotImplementedError("Not implemented BuiltInLiteral encountered.") - def visit_Literal(self, literal: dcir.Literal, *, in_idx=False, **kwargs): + def visit_Literal(self, literal: dcir.Literal, *, in_idx=False, **kwargs: Any) -> str: value = self.visit(literal.value, in_idx=in_idx, **kwargs) if in_idx: return str(value) - else: - return "{dtype}({value})".format( - dtype=self.visit(literal.dtype, in_idx=in_idx, **kwargs), value=value - ) + + return "{dtype}({value})".format( + dtype=self.visit(literal.dtype, in_idx=in_idx, **kwargs), value=value + ) Cast = as_fmt("{dtype}({expr})") @@ -178,26 +215,26 @@ def visit_NativeFuncCall(self, call: common.NativeFuncCall, **kwargs: Any) -> st def visit_DataType(self, dtype: common.DataType, **kwargs: Any) -> str: if dtype == common.DataType.BOOL: return "dace.bool_" - elif dtype == common.DataType.INT8: + if dtype == common.DataType.INT8: return "dace.int8" - elif dtype == common.DataType.INT16: + if dtype == common.DataType.INT16: return "dace.int16" - elif dtype == common.DataType.INT32: + if dtype == common.DataType.INT32: return "dace.int32" - elif dtype == common.DataType.INT64: + if dtype == common.DataType.INT64: return "dace.int64" - elif dtype == common.DataType.FLOAT32: + if dtype == common.DataType.FLOAT32: return "dace.float32" - elif dtype == common.DataType.FLOAT64: + if dtype == common.DataType.FLOAT64: return "dace.float64" raise NotImplementedError("Not implemented DataType encountered.") def visit_UnaryOperator(self, op: 
common.UnaryOperator, **kwargs: Any) -> str: if op == common.UnaryOperator.NOT: return " not " - elif op == common.UnaryOperator.NEG: + if op == common.UnaryOperator.NEG: return "-" - elif op == common.UnaryOperator.POS: + if op == common.UnaryOperator.POS: return "+" raise NotImplementedError("Not implemented UnaryOperator encountered.") @@ -205,18 +242,16 @@ def visit_UnaryOperator(self, op: common.UnaryOperator, **kwargs: Any) -> str: Param = as_fmt("{name}") - LocalScalarDecl = as_fmt("{name}: {dtype}") - - def visit_Tasklet(self, node: dcir.Tasklet, **kwargs): - return "\n".join(self.visit(node.decls, **kwargs) + self.visit(node.stmts, **kwargs)) + def visit_Tasklet(self, node: dcir.Tasklet, **kwargs: Any) -> str: + return "\n".join(self.visit(node.stmts, **kwargs)) def _visit_conditional( self, cond: Optional[Union[dcir.Expr, common.HorizontalMask]], body: List[dcir.Stmt], - keyword, - **kwargs, - ): + keyword: str, + **kwargs: Any, + ) -> str: mask_str = "" indent = "" if cond is not None and (cond_str := self.visit(cond, is_target=False, **kwargs)): @@ -226,16 +261,16 @@ def _visit_conditional( body_code = [indent + b for b in body_code] return "\n".join([mask_str, *body_code]) - def visit_MaskStmt(self, node: dcir.MaskStmt, **kwargs): + def visit_MaskStmt(self, node: dcir.MaskStmt, **kwargs: Any) -> str: return self._visit_conditional(cond=node.mask, body=node.body, keyword="if", **kwargs) - def visit_HorizontalRestriction(self, node: dcir.HorizontalRestriction, **kwargs): + def visit_HorizontalRestriction(self, node: dcir.HorizontalRestriction, **kwargs: Any) -> str: return self._visit_conditional(cond=node.mask, body=node.body, keyword="if", **kwargs) - def visit_While(self, node: dcir.While, **kwargs): + def visit_While(self, node: dcir.While, **kwargs: Any) -> Any: return self._visit_conditional(cond=node.cond, body=node.body, keyword="while", **kwargs) - def visit_HorizontalMask(self, node: common.HorizontalMask, **kwargs): + def 
visit_HorizontalMask(self, node: common.HorizontalMask, **kwargs: Any) -> str: clauses: List[str] = [] for axis, interval in zip(dcir.Axis.dims_horizontal(), node.intervals): diff --git a/src/gt4py/cartesian/gtc/dace/expansion/utils.py b/src/gt4py/cartesian/gtc/dace/expansion/utils.py index 919ec02996..637b348a03 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/utils.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/utils.py @@ -10,11 +10,6 @@ from typing import TYPE_CHECKING, List -import dace -import dace.data -import dace.library -import dace.subsets - from gt4py import eve from gt4py.cartesian.gtc import common, oir from gt4py.cartesian.gtc.dace import daceir as dcir @@ -25,15 +20,6 @@ from gt4py.cartesian.gtc.dace.nodes import StencilComputation -def get_dace_debuginfo(node: common.LocNode): - if node.loc is not None: - return dace.dtypes.DebugInfo( - node.loc.line, node.loc.column, node.loc.line, node.loc.column, node.loc.filename - ) - else: - return dace.dtypes.DebugInfo(0) - - class HorizontalIntervalRemover(eve.NodeTranslator): def visit_HorizontalMask(self, node: common.HorizontalMask, *, axis: dcir.Axis): mask_attrs = dict(i=node.i, j=node.j) @@ -54,8 +40,8 @@ def visit_Tasklet(self, node: dcir.Tasklet): else: res_body.append(newstmt) return dcir.Tasklet( + label=f"he_remover_{id(node)}", stmts=res_body, - decls=node.decls, read_memlets=node.read_memlets, write_memlets=node.write_memlets, ) diff --git a/src/gt4py/cartesian/gtc/dace/expansion_specification.py b/src/gt4py/cartesian/gtc/dace/expansion_specification.py index c716f1a103..af9a814843 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion_specification.py +++ b/src/gt4py/cartesian/gtc/dace/expansion_specification.py @@ -107,7 +107,8 @@ def get_expansion_order_index(expansion_order, axis): for idx, item in enumerate(expansion_order): if isinstance(item, Iteration) and item.axis == axis: return idx - elif isinstance(item, Map): + + if isinstance(item, Map): for it in item.iterations: if it.kind == 
"contiguous" and it.axis == axis: return idx @@ -136,7 +137,9 @@ def _choose_loop_or_map(node, eo): return eo -def _order_as_spec(computation_node, expansion_order): +def _order_as_spec( + computation_node: StencilComputation, expansion_order: Union[List[str], List[ExpansionItem]] +) -> List[ExpansionItem]: expansion_order = list(_choose_loop_or_map(computation_node, eo) for eo in expansion_order) expansion_specification = [] for item in expansion_order: @@ -170,7 +173,7 @@ def _order_as_spec(computation_node, expansion_order): return expansion_specification -def _populate_strides(node, expansion_specification): +def _populate_strides(node: StencilComputation, expansion_specification: List[ExpansionItem]): """Fill in `stride` attribute of `Iteration` and `Loop` dataclasses. For loops, stride is set to either -1 or 1, based on iteration order. @@ -185,10 +188,7 @@ def _populate_strides(node, expansion_specification): for it in iterations: if isinstance(it, Loop): if it.stride is None: - if node.oir_node.loop_order == common.LoopOrder.BACKWARD: - it.stride = -1 - else: - it.stride = 1 + it.stride = -1 if node.oir_node.loop_order == common.LoopOrder.BACKWARD else 1 else: if it.stride is None: if it.kind == "tiling": @@ -204,7 +204,7 @@ def _populate_strides(node, expansion_specification): it.stride = 1 -def _populate_storages(self, expansion_specification): +def _populate_storages(expansion_specification: List[ExpansionItem]): assert all(isinstance(es, ExpansionItem) for es in expansion_specification) innermost_axes = set(dcir.Axis.dims_3d()) tiled_axes = set() @@ -222,7 +222,7 @@ def _populate_storages(self, expansion_specification): tiled_axes.remove(it.axis) -def _populate_cpu_schedules(self, expansion_specification): +def _populate_cpu_schedules(expansion_specification: List[ExpansionItem]): is_outermost = True for es in expansion_specification: if isinstance(es, Map): @@ -234,7 +234,7 @@ def _populate_cpu_schedules(self, expansion_specification): es.schedule = 
dace.ScheduleType.Default -def _populate_gpu_schedules(self, expansion_specification): +def _populate_gpu_schedules(expansion_specification: List[ExpansionItem]): # On GPU if any dimension is tiled and has a contiguous map in the same axis further in # pick those two maps as Device/ThreadBlock maps. If not, Make just device map with # default blocksizes @@ -267,16 +267,16 @@ def _populate_gpu_schedules(self, expansion_specification): es.schedule = dace.ScheduleType.Default -def _populate_schedules(self, expansion_specification): +def _populate_schedules(node: StencilComputation, expansion_specification: List[ExpansionItem]): assert all(isinstance(es, ExpansionItem) for es in expansion_specification) - assert hasattr(self, "_device") - if self.device == dace.DeviceType.GPU: - _populate_gpu_schedules(self, expansion_specification) + assert hasattr(node, "_device") + if node.device == dace.DeviceType.GPU: + _populate_gpu_schedules(expansion_specification) else: - _populate_cpu_schedules(self, expansion_specification) + _populate_cpu_schedules(expansion_specification) -def _collapse_maps_gpu(self, expansion_specification): +def _collapse_maps_gpu(expansion_specification: List[ExpansionItem]) -> List[ExpansionItem]: def _union_map_items(last_item, next_item): if last_item.schedule == next_item.schedule: return ( @@ -307,7 +307,7 @@ def _union_map_items(last_item, next_item): ), ) - res_items = [] + res_items: List[ExpansionItem] = [] for item in expansion_specification: if isinstance(item, Map): if not res_items or not isinstance(res_items[-1], Map): @@ -324,8 +324,8 @@ def _union_map_items(last_item, next_item): return res_items -def _collapse_maps_cpu(self, expansion_specification): - res_items = [] +def _collapse_maps_cpu(expansion_specification: List[ExpansionItem]) -> List[ExpansionItem]: + res_items: List[ExpansionItem] = [] for item in expansion_specification: if isinstance(item, Map): if ( @@ -360,12 +360,12 @@ def _collapse_maps_cpu(self, 
expansion_specification): return res_items -def _collapse_maps(self, expansion_specification): - assert hasattr(self, "_device") - if self.device == dace.DeviceType.GPU: - res_items = _collapse_maps_gpu(self, expansion_specification) +def _collapse_maps(node: StencilComputation, expansion_specification: List[ExpansionItem]): + assert hasattr(node, "_device") + if node.device == dace.DeviceType.GPU: + res_items = _collapse_maps_gpu(expansion_specification) else: - res_items = _collapse_maps_cpu(self, expansion_specification) + res_items = _collapse_maps_cpu(expansion_specification) expansion_specification.clear() expansion_specification.extend(res_items) @@ -387,7 +387,7 @@ def make_expansion_order( _populate_strides(node, expansion_specification) _populate_schedules(node, expansion_specification) _collapse_maps(node, expansion_specification) - _populate_storages(node, expansion_specification) + _populate_storages(expansion_specification) return expansion_specification diff --git a/src/gt4py/cartesian/gtc/dace/nodes.py b/src/gt4py/cartesian/gtc/dace/nodes.py index 34401e18b9..13fb6ecc6e 100644 --- a/src/gt4py/cartesian/gtc/dace/nodes.py +++ b/src/gt4py/cartesian/gtc/dace/nodes.py @@ -23,12 +23,12 @@ from gt4py.cartesian.gtc import common, oir from gt4py.cartesian.gtc.dace import daceir as dcir from gt4py.cartesian.gtc.dace.expansion.expansion import StencilComputationExpansion +from gt4py.cartesian.gtc.dace.expansion.utils import HorizontalExecutionSplitter +from gt4py.cartesian.gtc.dace.expansion_specification import ExpansionItem, make_expansion_order +from gt4py.cartesian.gtc.dace.utils import get_dace_debuginfo from gt4py.cartesian.gtc.definitions import Extent from gt4py.cartesian.gtc.oir import Decl, FieldDecl, VerticalLoop, VerticalLoopSection -from .expansion.utils import HorizontalExecutionSplitter, get_dace_debuginfo -from .expansion_specification import ExpansionItem, make_expansion_order - def _set_expansion_order( node: StencilComputation, 
expansion_order: Union[List[ExpansionItem], List[str]] diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py index 3555d555f9..d80e14296b 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py @@ -17,10 +17,14 @@ import gt4py.cartesian.gtc.oir as oir from gt4py import eve -from gt4py.cartesian.gtc.dace import daceir as dcir +from gt4py.cartesian.gtc.dace import daceir as dcir, prefix from gt4py.cartesian.gtc.dace.nodes import StencilComputation from gt4py.cartesian.gtc.dace.symbol_utils import data_type_to_dace_typeclass -from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos, make_dace_subset +from gt4py.cartesian.gtc.dace.utils import ( + compute_dcir_access_infos, + get_dace_debuginfo, + make_dace_subset, +) from gt4py.cartesian.gtc.definitions import Extent from gt4py.cartesian.gtc.passes.oir_optimizations.utils import ( AccessCollector, @@ -36,10 +40,11 @@ class SDFGContext: decls: Dict[str, oir.Decl] block_extents: Dict[int, Extent] access_infos: Dict[str, dcir.FieldAccessInfo] + loop_counter: int = 0 def __init__(self, stencil: oir.Stencil): self.sdfg = dace.SDFG(stencil.name) - self.last_state = self.sdfg.add_state(is_start_state=True) + self.last_state = self.sdfg.add_state(is_start_block=True) self.decls = {decl.name: decl for decl in stencil.params + stencil.declarations} self.block_extents = compute_horizontal_block_extents(stencil) @@ -93,16 +98,21 @@ def _make_dace_subset(self, local_access_info, field): global_access_info, local_access_info, self.decls[field].data_dims ) - def visit_VerticalLoop( - self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFGContext, **kwargs - ): + def _vloop_name(self, node: oir.VerticalLoop, ctx: OirSDFGBuilder.SDFGContext) -> str: + sdfg_name = ctx.sdfg.name + counter = ctx.loop_counter + ctx.loop_counter += 1 + + return f"{sdfg_name}_vloop_{counter}_{node.loop_order}_{id(node)}" + + def 
visit_VerticalLoop(self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFGContext): declarations = { acc.name: ctx.decls[acc.name] for acc in node.walk_values().if_isinstance(oir.FieldAccess, oir.ScalarAccess) if acc.name in ctx.decls } library_node = StencilComputation( - name=f"{ctx.sdfg.name}_computation_{id(node)}", + name=self._vloop_name(node, ctx), extents=ctx.block_extents, declarations=declarations, oir_node=node, @@ -117,23 +127,24 @@ def visit_VerticalLoop( access_collection = AccessCollector.apply(node) for field in access_collection.read_fields(): - access_node = state.add_access(field, debuginfo=dace.DebugInfo(0)) - library_node.add_in_connector("__in_" + field) + access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field])) + connector_name = f"{prefix.CONNECTOR_IN}{field}" + library_node.add_in_connector(connector_name) subset = ctx.make_input_dace_subset(node, field) state.add_edge( - access_node, None, library_node, "__in_" + field, dace.Memlet(field, subset=subset) + access_node, None, library_node, connector_name, dace.Memlet(field, subset=subset) ) + for field in access_collection.write_fields(): - access_node = state.add_access(field, debuginfo=dace.DebugInfo(0)) - library_node.add_out_connector("__out_" + field) + access_node = state.add_access(field, debuginfo=get_dace_debuginfo(declarations[field])) + connector_name = f"{prefix.CONNECTOR_OUT}{field}" + library_node.add_out_connector(connector_name) subset = ctx.make_output_dace_subset(node, field) state.add_edge( - library_node, "__out_" + field, access_node, None, dace.Memlet(field, subset=subset) + library_node, connector_name, access_node, None, dace.Memlet(field, subset=subset) ) - return - - def visit_Stencil(self, node: oir.Stencil, **kwargs): + def visit_Stencil(self, node: oir.Stencil): ctx = OirSDFGBuilder.SDFGContext(stencil=node) for param in node.params: if isinstance(param, oir.FieldDecl): @@ -149,7 +160,7 @@ def visit_Stencil(self, node: 
oir.Stencil, **kwargs): ], dtype=data_type_to_dace_typeclass(param.dtype), transient=False, - debuginfo=dace.DebugInfo(0), + debuginfo=get_dace_debuginfo(param), ) else: ctx.sdfg.add_symbol(param.name, stype=data_type_to_dace_typeclass(param.dtype)) @@ -167,8 +178,9 @@ def visit_Stencil(self, node: oir.Stencil, **kwargs): ], dtype=data_type_to_dace_typeclass(decl.dtype), transient=True, - debuginfo=dace.DebugInfo(0), + lifetime=dace.AllocationLifetime.Persistent, + debuginfo=get_dace_debuginfo(decl), ) - self.generic_visit(node, ctx=ctx) + self.visit(node.vertical_loops, ctx=ctx) ctx.sdfg.validate() return ctx.sdfg diff --git a/src/gt4py/cartesian/gtc/dace/prefix.py b/src/gt4py/cartesian/gtc/dace/prefix.py new file mode 100644 index 0000000000..1da9eb95f3 --- /dev/null +++ b/src/gt4py/cartesian/gtc/dace/prefix.py @@ -0,0 +1,23 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + + +from typing import Final + + +# DaCe passthrough prefixes +PASSTHROUGH_IN: Final[str] = "IN_" +PASSTHROUGH_OUT: Final[str] = "OUT_" + +# StencilComputation in/out connector prefixes +CONNECTOR_IN: Final[str] = "__in_" +CONNECTOR_OUT: Final[str] = "__out_" + +# Tasklet in/out connector prefixes +TASKLET_IN: Final[str] = "gtIN__" +TASKLET_OUT: Final[str] = "gtOUT__" diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index b5c23d2735..4ef48ebcd9 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -19,10 +19,19 @@ from gt4py import eve from gt4py.cartesian.gtc import common, oir from gt4py.cartesian.gtc.common import CartesianOffset, VariableKOffset -from gt4py.cartesian.gtc.dace import daceir as dcir +from gt4py.cartesian.gtc.dace import daceir as dcir, prefix from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_horizontal_block_extents +def get_dace_debuginfo(node: common.LocNode) -> dace.dtypes.DebugInfo: + if node.loc is None: + return dace.dtypes.DebugInfo(0) + + return dace.dtypes.DebugInfo( + node.loc.line, node.loc.column, node.loc.line, node.loc.column, node.loc.filename + ) + + def array_dimensions(array: dace.data.Array): dims = [ any( @@ -40,7 +49,7 @@ def array_dimensions(array: dace.data.Array): return dims -def replace_strides(arrays, get_layout_map): +def replace_strides(arrays: List[dace.data.Array], get_layout_map) -> Dict[str, str]: symbol_mapping = {} for array in arrays: dims = array_dimensions(array) @@ -58,22 +67,25 @@ def replace_strides(arrays, get_layout_map): def get_tasklet_symbol( - name: eve.SymbolRef, offset: Union[CartesianOffset, VariableKOffset], is_target: bool + name: str, + *, + offset: Optional[CartesianOffset | VariableKOffset] = None, + is_target: bool, ): - if is_target: - return f"gtOUT__{name}" - - acc_name = f"gtIN__{name}" - if offset is not None: - offset_strs = [] - for axis 
in dcir.Axis.dims_3d(): - off = offset.to_dict()[axis.lower()] - if off is not None and off != 0: - offset_strs.append(axis.lower() + ("m" if off < 0 else "p") + f"{abs(off):d}") - suffix = "_".join(offset_strs) - if suffix != "": - acc_name += suffix - return acc_name + access_name = f"{prefix.TASKLET_OUT}{name}" if is_target else f"{prefix.TASKLET_IN}{name}" + if offset is None: + return access_name + + # add (per axis) offset markers, e.g. gtIN__A_km1 for A[0, 0, -1] + offset_strings = [] + for axis in dcir.Axis.dims_3d(): + axis_offset = offset.to_dict()[axis.lower()] + if axis_offset is not None and axis_offset != 0: + offset_strings.append( + axis.lower() + ("m" if axis_offset < 0 else "p") + f"{abs(axis_offset):d}" + ) + + return access_name + "_".join(offset_strings) def axes_list_from_flags(flags): @@ -187,7 +199,8 @@ def visit_MaskStmt(self, node: oir.MaskStmt, *, is_conditional=False, **kwargs): self.visit(node.body, is_conditional=True, **kwargs) def visit_While(self, node: oir.While, *, is_conditional=False, **kwargs): - self.generic_visit(node, is_conditional=True, **kwargs) + self.visit(node.cond, is_conditional=is_conditional, **kwargs) + self.visit(node.body, is_conditional=True, **kwargs) @staticmethod def _global_grid_subset( @@ -233,12 +246,8 @@ def _make_access_info( is_write, ) -> dcir.FieldAccessInfo: # Check we have expression offsets in K - # OR write offsets in K offset = [offset_node.to_dict()[k] for k in "ijk"] - if isinstance(offset_node, oir.VariableKOffset) or (offset[2] != 0 and is_write): - variable_offset_axes = [dcir.Axis.K] - else: - variable_offset_axes = [] + variable_offset_axes = [dcir.Axis.K] if isinstance(offset_node, oir.VariableKOffset) else [] global_subset = self._global_grid_subset(region, he_grid, offset) intervals = {} @@ -257,7 +266,6 @@ def _make_access_info( return dcir.FieldAccessInfo( grid_subset=grid_subset, global_grid_subset=global_subset, - dynamic_access=len(variable_offset_axes) > 0 or is_conditional or 
region is not None, variable_offset_axes=variable_offset_axes, ) @@ -333,10 +341,173 @@ def compute_dcir_access_infos( global_grid_subset=access_info.global_grid_subset, ) ) + return res + + return ctx.access_infos + + +class TaskletAccessInfoCollector(eve.NodeVisitor): + @dataclass + class Context: + axes: dict[str, list[dcir.Axis]] + access_infos: dict[str, dcir.FieldAccessInfo] = field(default_factory=dict) + + def __init__( + self, collect_read: bool, collect_write: bool, *, horizontal_extent, k_interval, grid_subset + ): + self.collect_read: bool = collect_read + self.collect_write: bool = collect_write + + self.ij_grid = dcir.GridSubset.from_gt4py_extent(horizontal_extent) + self.he_grid = self.ij_grid.set_interval(dcir.Axis.K, k_interval) + self.grid_subset = grid_subset + + def visit_CodeBlock(self, _node: oir.CodeBlock, **_kwargs): + raise RuntimeError("We shouldn't reach code blocks anymore") + + def visit_AssignStmt(self, node: oir.AssignStmt, **kwargs): + self.visit(node.right, is_write=False, **kwargs) + self.visit(node.left, is_write=True, **kwargs) + + def visit_MaskStmt(self, node: oir.MaskStmt, **kwargs): + self.visit(node.mask, is_write=False, **kwargs) + self.visit(node.body, **kwargs) + + def visit_While(self, node: oir.While, **kwargs): + self.visit(node.cond, is_write=False, **kwargs) + self.visit(node.body, **kwargs) + + def visit_HorizontalRestriction(self, node: oir.HorizontalRestriction, **kwargs): + self.visit(node.mask, is_write=False, **kwargs) + self.visit(node.body, region=node.mask, **kwargs) + + def _global_grid_subset( + self, + region: Optional[common.HorizontalMask], + offset: list[Optional[int]], + ): + res: dict[dcir.Axis, dcir.DomainInterval | dcir.IndexWithExtent | dcir.TileInterval] = {} + if region is not None: + for axis, oir_interval in zip(dcir.Axis.dims_horizontal(), region.intervals): + he_grid_interval = self.he_grid.intervals[axis] + assert isinstance(he_grid_interval, dcir.DomainInterval) + start = ( + 
oir_interval.start if oir_interval.start is not None else he_grid_interval.start + ) + end = oir_interval.end if oir_interval.end is not None else he_grid_interval.end + dcir_interval = dcir.DomainInterval( + start=dcir.AxisBound.from_common(axis, start), + end=dcir.AxisBound.from_common(axis, end), + ) + res[axis] = dcir.DomainInterval.union(dcir_interval, res.get(axis, dcir_interval)) + if dcir.Axis.K in self.he_grid.intervals: + off = offset[dcir.Axis.K.to_idx()] or 0 + he_grid_k_interval = self.he_grid.intervals[dcir.Axis.K] + assert not isinstance(he_grid_k_interval, dcir.TileInterval) + res[dcir.Axis.K] = he_grid_k_interval.shifted(off) + for axis in dcir.Axis.dims_horizontal(): + iteration_interval = self.he_grid.intervals[axis] + mask_interval = res.get(axis, iteration_interval) + res[axis] = dcir.DomainInterval.intersection( + axis, iteration_interval, mask_interval + ).shifted(offset[axis.to_idx()]) + return dcir.GridSubset(intervals=res) + + def _make_access_info( + self, + offset_node: CartesianOffset | VariableKOffset, + axes, + region: Optional[common.HorizontalMask], + ) -> dcir.FieldAccessInfo: + # Check we have expression offsets in K + offset = [offset_node.to_dict()[k] for k in "ijk"] + variable_offset_axes = [dcir.Axis.K] if isinstance(offset_node, VariableKOffset) else [] + + global_subset = self._global_grid_subset(region, offset) + intervals = {} + for axis in axes: + extent = ( + (0, 0) + if axis in variable_offset_axes + else (offset[axis.to_idx()], offset[axis.to_idx()]) + ) + intervals[axis] = dcir.IndexWithExtent( + axis=axis, value=axis.iteration_symbol(), extent=extent + ) + + return dcir.FieldAccessInfo( + grid_subset=dcir.GridSubset(intervals=intervals), + global_grid_subset=global_subset, + # Field access inside horizontal regions might or might not happen + dynamic_access=region is not None, + variable_offset_axes=variable_offset_axes, + ) + + def visit_FieldAccess( + self, + node: oir.FieldAccess, + *, + is_write: bool, + region: 
Optional[common.HorizontalMask] = None, + ctx: TaskletAccessInfoCollector.Context, + **kwargs, + ): + self.visit(node.offset, ctx=ctx, is_write=False, region=region, **kwargs) + + if (is_write and not self.collect_write) or (not is_write and not self.collect_read): + return + + access_info = self._make_access_info( + node.offset, + axes=ctx.axes[node.name], + region=region, + ) + ctx.access_infos[node.name] = access_info.union( + ctx.access_infos.get(node.name, access_info) + ) + + +def compute_tasklet_access_infos( + node: oir.CodeBlock | oir.MaskStmt | oir.While, + *, + collect_read: bool = True, + collect_write: bool = True, + declarations: dict[str, oir.Decl], + horizontal_extent, + k_interval, + grid_subset, +): + """ + Compute access information needed to build Memlets for the Tasklet + associated with the given `node`. + """ + axes = { + name: axes_list_from_flags(declaration.dimensions) + for name, declaration in declarations.items() + if isinstance(declaration, oir.FieldDecl) + } + ctx = TaskletAccessInfoCollector.Context(axes=axes, access_infos=dict()) + collector = TaskletAccessInfoCollector( + collect_read=collect_read, + collect_write=collect_write, + horizontal_extent=horizontal_extent, + k_interval=k_interval, + grid_subset=grid_subset, + ) + if isinstance(node, oir.CodeBlock): + collector.visit(node.body, ctx=ctx) + elif isinstance(node, oir.MaskStmt): + # node.mask is a simple expression. + # Pass `is_write` explicitly since we don't automatically set it in `visit_AssignStmt()` + collector.visit(node.mask, ctx=ctx, is_write=False) + elif isinstance(node, oir.While): + # node.cond is a simple expression. 
+ # Pass `is_write` explicitly since we don't automatically set it in `visit_AssignStmt()` + collector.visit(node.cond, ctx=ctx, is_write=False) else: - res = ctx.access_infos + raise ValueError("Unexpected node type.") - return res + return ctx.access_infos def make_dace_subset( @@ -349,7 +520,7 @@ def make_dace_subset( for axis in access_info.axes(): if axis in access_info.variable_offset_axes: clamped_access_info = clamped_access_info.clamp_full_axis(axis) - if axis in clamped_context_info.variable_offset_axes: + if axis in context_info.variable_offset_axes: clamped_context_info = clamped_context_info.clamp_full_axis(axis) res_ranges = [] diff --git a/src/gt4py/cartesian/gtc/definitions.py b/src/gt4py/cartesian/gtc/definitions.py index 16c7fbc46a..467ba04f99 100644 --- a/src/gt4py/cartesian/gtc/definitions.py +++ b/src/gt4py/cartesian/gtc/definitions.py @@ -118,7 +118,7 @@ def __add__(self, other): return self._apply(self._broadcast(other), operator.add) def __sub__(self, other): - """Element-wise substraction.""" + """Element-wise subtraction.""" return self._apply(self._broadcast(other), operator.sub) def __mul__(self, other): @@ -335,7 +335,7 @@ def __add__(self, other): return self._apply(self._broadcast(other), lambda a, b: a + b) def __sub__(self, other): - """Element-wise substraction.""" + """Element-wise subtraction.""" return self._apply(self._broadcast(other), lambda a, b: a - b) def __and__(self, other): diff --git a/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py b/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py index 0d19814b9c..5ca766c272 100644 --- a/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py +++ b/src/gt4py/cartesian/gtc/gtcpp/gtcpp.py @@ -31,11 +31,11 @@ class Offset(common.CartesianOffset): pass -class Literal(common.Literal, Expr): # type: ignore +class Literal(common.Literal, Expr): pass -class LocalAccess(common.ScalarAccess, Expr): # type: ignore +class LocalAccess(common.ScalarAccess, Expr): pass @@ -43,7 +43,7 @@ class 
VariableKOffset(common.VariableKOffset[Expr]): pass -class AccessorRef(common.FieldAccess[Expr, VariableKOffset], Expr): # type: ignore +class AccessorRef(common.FieldAccess[Expr, VariableKOffset], Expr): pass @@ -88,7 +88,7 @@ class NativeFuncCall(common.NativeFuncCall[Expr], Expr): _dtype_propagation = common.native_func_call_dtype_propagation(strict=True) -class Cast(common.Cast[Expr], Expr): # type: ignore +class Cast(common.Cast[Expr], Expr): pass diff --git a/src/gt4py/cartesian/gtc/gtir.py b/src/gt4py/cartesian/gtc/gtir.py index c9f58de2da..0ee4f7ebe1 100644 --- a/src/gt4py/cartesian/gtc/gtir.py +++ b/src/gt4py/cartesian/gtc/gtir.py @@ -43,7 +43,7 @@ class BlockStmt(common.BlockStmt[Stmt], Stmt): pass -class Literal(common.Literal, Expr): # type: ignore +class Literal(common.Literal, Expr): pass @@ -51,11 +51,11 @@ class VariableKOffset(common.VariableKOffset[Expr]): pass -class ScalarAccess(common.ScalarAccess, Expr): # type: ignore +class ScalarAccess(common.ScalarAccess, Expr): pass -class FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr): # type: ignore +class FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr): pass @@ -163,7 +163,7 @@ class TernaryOp(common.TernaryOp[Expr], Expr): _dtype_propagation = common.ternary_op_dtype_propagation(strict=False) -class Cast(common.Cast[Expr], Expr): # type: ignore +class Cast(common.Cast[Expr], Expr): pass diff --git a/src/gt4py/cartesian/gtc/gtir_to_oir.py b/src/gt4py/cartesian/gtc/gtir_to_oir.py index 560cbf96cf..96f8077ec4 100644 --- a/src/gt4py/cartesian/gtc/gtir_to_oir.py +++ b/src/gt4py/cartesian/gtc/gtir_to_oir.py @@ -7,11 +7,11 @@ # SPDX-License-Identifier: BSD-3-Clause from dataclasses import dataclass, field -from typing import Any, List, Optional, Set, Union +from typing import Any, List, Set, Union from gt4py import eve -from gt4py.cartesian.gtc import common, gtir, oir, utils -from gt4py.cartesian.gtc.common import CartesianOffset, DataType, LogicalOperator, UnaryOperator +from 
gt4py.cartesian.gtc import gtir, oir, utils +from gt4py.cartesian.gtc.common import CartesianOffset, DataType, UnaryOperator from gt4py.cartesian.gtc.passes.oir_optimizations.utils import compute_fields_extents @@ -22,7 +22,6 @@ def validate_stencil_memory_accesses(node: oir.Stencil) -> oir.Stencil: at the OIR level. This is similar to the check at the gtir level for read-with-offset and writes, but more complete because it involves extent analysis, so it catches indirect read-with-offset through temporaries. - """ def _writes(node: oir.Stencil) -> Set[str]: @@ -118,15 +117,8 @@ def visit_NativeFuncCall(self, node: gtir.NativeFuncCall) -> oir.NativeFuncCall: ) # --- Statements --- - def visit_ParAssignStmt( - self, node: gtir.ParAssignStmt, *, mask: Optional[oir.Expr] = None, **kwargs: Any - ) -> Union[oir.AssignStmt, oir.MaskStmt]: - statement = oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right)) - if mask is None: - return statement - - # Wrap inside MaskStmt - return oir.MaskStmt(body=[statement], mask=mask, loc=node.loc) + def visit_ParAssignStmt(self, node: gtir.ParAssignStmt, **kwargs: Any) -> oir.AssignStmt: + return oir.AssignStmt(left=self.visit(node.left), right=self.visit(node.right)) def visit_HorizontalRestriction( self, node: gtir.HorizontalRestriction, **kwargs: Any @@ -138,24 +130,19 @@ def visit_HorizontalRestriction( return oir.HorizontalRestriction(mask=node.mask, body=body) - def visit_While( - self, node: gtir.While, *, mask: Optional[oir.Expr] = None, **kwargs: Any - ) -> oir.While: + def visit_While(self, node: gtir.While, **kwargs: Any) -> oir.While: body: List[oir.Stmt] = [] for statement in node.body: oir_statement = self.visit(statement, **kwargs) body.extend(utils.flatten_list(utils.listify(oir_statement))) condition: oir.Expr = self.visit(node.cond) - if mask: - condition = oir.BinaryOp(op=common.LogicalOperator.AND, left=mask, right=condition) return oir.While(cond=condition, body=body, loc=node.loc) def 
visit_FieldIfStmt( self, node: gtir.FieldIfStmt, *, - mask: Optional[oir.Expr] = None, ctx: Context, **kwargs: Any, ) -> List[Union[oir.AssignStmt, oir.MaskStmt]]: @@ -182,26 +169,17 @@ def visit_FieldIfStmt( loc=node.loc, ) - combined_mask: oir.Expr = condition - if mask: - combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc - ) body = utils.flatten_list( [self.visit(statement, ctx=ctx, **kwargs) for statement in node.true_branch.body] ) - statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) + statements.append(oir.MaskStmt(body=body, mask=condition, loc=node.loc)) if node.false_branch: - combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition) - if mask: - combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, left=mask, right=combined_mask, loc=node.loc - ) + negated_condition = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc) body = utils.flatten_list( [self.visit(statement, ctx=ctx, **kwargs) for statement in node.false_branch.body] ) - statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) + statements.append(oir.MaskStmt(body=body, mask=negated_condition, loc=node.loc)) return statements @@ -211,31 +189,21 @@ def visit_ScalarIfStmt( self, node: gtir.ScalarIfStmt, *, - mask: Optional[oir.Expr] = None, ctx: Context, **kwargs: Any, ) -> List[oir.MaskStmt]: condition = self.visit(node.cond) - combined_mask = condition - if mask: - combined_mask = oir.BinaryOp( - op=LogicalOperator.AND, left=mask, right=condition, loc=node.loc - ) - body = utils.flatten_list( [self.visit(statement, ctx=ctx, **kwargs) for statement in node.true_branch.body] ) statements = [oir.MaskStmt(body=body, mask=condition, loc=node.loc)] if node.false_branch: - combined_mask = oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc) - if mask: - combined_mask = oir.BinaryOp(op=LogicalOperator.AND, left=mask, right=combined_mask) - + negated_condition = 
oir.UnaryOp(op=UnaryOperator.NOT, expr=condition, loc=node.loc) body = utils.flatten_list( [self.visit(statement, ctx=ctx, **kwargs) for statement in node.false_branch.body] ) - statements.append(oir.MaskStmt(body=body, mask=combined_mask, loc=node.loc)) + statements.append(oir.MaskStmt(body=body, mask=negated_condition, loc=node.loc)) return statements diff --git a/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py b/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py index ed573ebfff..b6aeb49823 100644 --- a/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py +++ b/src/gt4py/cartesian/gtc/numpy/oir_to_npir.py @@ -157,13 +157,12 @@ def visit_AssignStmt( def visit_While( self, node: oir.While, *, mask: Optional[npir.Expr] = None, **kwargs: Any ) -> npir.While: - cond = self.visit(node.cond, mask=mask, **kwargs) + cond_expr = self.visit(node.cond, **kwargs) if mask: - mask = npir.VectorLogic(op=common.LogicalOperator.AND, left=mask, right=cond) - else: - mask = cond + cond_expr = npir.VectorLogic(op=common.LogicalOperator.AND, left=mask, right=cond_expr) + return npir.While( - cond=cond, body=utils.flatten_list(self.visit(node.body, mask=mask, **kwargs)) + cond=cond_expr, body=utils.flatten_list(self.visit(node.body, mask=cond_expr, **kwargs)) ) def visit_HorizontalRestriction( diff --git a/src/gt4py/cartesian/gtc/oir.py b/src/gt4py/cartesian/gtc/oir.py index df71ef26cf..1ba36b5077 100644 --- a/src/gt4py/cartesian/gtc/oir.py +++ b/src/gt4py/cartesian/gtc/oir.py @@ -33,11 +33,15 @@ class Stmt(common.Stmt): pass -class Literal(common.Literal, Expr): # type: ignore +class CodeBlock(common.BlockStmt[Stmt], Stmt): + label: str + + +class Literal(common.Literal, Expr): pass -class ScalarAccess(common.ScalarAccess, Expr): # type: ignore +class ScalarAccess(common.ScalarAccess, Expr): pass @@ -45,7 +49,7 @@ class VariableKOffset(common.VariableKOffset[Expr]): pass -class FieldAccess(common.FieldAccess[Expr, VariableKOffset], Expr): # type: ignore +class FieldAccess(common.FieldAccess[Expr, 
VariableKOffset], Expr): pass @@ -88,7 +92,7 @@ class TernaryOp(common.TernaryOp[Expr], Expr): _dtype_propagation = common.ternary_op_dtype_propagation(strict=True) -class Cast(common.Cast[Expr], Expr): # type: ignore +class Cast(common.Cast[Expr], Expr): pass diff --git a/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py b/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py index 96cec5b6d4..40c31dca53 100644 --- a/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py +++ b/src/gt4py/cartesian/gtc/passes/gtir_k_boundary.py @@ -41,20 +41,21 @@ def visit_FieldAccess( node: gtir.FieldAccess, vloop: gtir.VerticalLoop, field_boundaries: Dict[str, Tuple[Union[float, int], Union[float, int]]], - include_center_interval: bool, - **kwargs: Any, + **_: Any, ): boundary = field_boundaries[node.name] interval = vloop.interval if not isinstance(node.offset, gtir.VariableKOffset): - if interval.start.level == LevelMarker.START and ( - include_center_interval or interval.end.level == LevelMarker.START - ): - boundary = (max(-interval.start.offset - node.offset.k, boundary[0]), boundary[1]) - if ( - include_center_interval or interval.start.level == LevelMarker.END - ) and interval.end.level == LevelMarker.END: - boundary = (boundary[0], max(interval.end.offset + node.offset.k, boundary[1])) + if interval.start.level == LevelMarker.START: + boundary = ( + max(-interval.start.offset - node.offset.k, boundary[0]), + boundary[1], + ) + if interval.end.level == LevelMarker.END: + boundary = ( + boundary[0], + max(interval.end.offset + node.offset.k, boundary[1]), + ) if node.name in [decl.name for decl in vloop.temporaries] and ( boundary[0] > 0 or boundary[1] > 0 ): @@ -63,24 +64,35 @@ def visit_FieldAccess( field_boundaries[node.name] = boundary -def compute_k_boundary( - node: gtir.Stencil, include_center_interval=True -) -> Dict[str, Tuple[int, int]]: +def compute_k_boundary(node: gtir.Stencil) -> Dict[str, Tuple[int, int]]: # loop from START to END is not considered as it might be 
empty. additional check possible in the future - return KBoundaryVisitor().visit(node, include_center_interval=include_center_interval) + return KBoundaryVisitor().visit(node) -def compute_min_k_size(node: gtir.Stencil, include_center_interval=True) -> int: +def compute_min_k_size(node: gtir.Stencil) -> int: """Compute the required number of k levels to run a stencil.""" + min_size_start = 0 min_size_end = 0 + biggest_offset = 0 for vloop in node.vertical_loops: - if vloop.interval.start.level == LevelMarker.START and ( - include_center_interval or vloop.interval.end.level == LevelMarker.START + if ( + vloop.interval.start.level == LevelMarker.START + and vloop.interval.end.level == LevelMarker.END ): - min_size_start = max(min_size_start, vloop.interval.end.offset) + if not (vloop.interval.start.offset == 0 and vloop.interval.end.offset == 0): + biggest_offset = max( + biggest_offset, + vloop.interval.start.offset - vloop.interval.end.offset + 1, + ) elif ( - include_center_interval or vloop.interval.start.level == LevelMarker.END - ) and vloop.interval.end.level == LevelMarker.END: + vloop.interval.start.level == LevelMarker.START + and vloop.interval.end.level == LevelMarker.START + ): + min_size_start = max(min_size_start, vloop.interval.end.offset) + biggest_offset = max(biggest_offset, vloop.interval.end.offset) + else: min_size_end = max(min_size_end, -vloop.interval.start.offset) - return min_size_start + min_size_end + biggest_offset = max(biggest_offset, -vloop.interval.start.offset) + minimal_size = max(min_size_start + min_size_end, biggest_offset) + return minimal_size diff --git a/src/gt4py/cartesian/gtc/passes/gtir_upcaster.py b/src/gt4py/cartesian/gtc/passes/gtir_upcaster.py index 41fa127d6d..94c3d6cd78 100644 --- a/src/gt4py/cartesian/gtc/passes/gtir_upcaster.py +++ b/src/gt4py/cartesian/gtc/passes/gtir_upcaster.py @@ -24,7 +24,7 @@ def _upcast_node(target_dtype: DataType, node: Expr) -> Expr: def _upcast_nodes(*exprs: Expr, upcasting_rule: 
Callable) -> Iterator[Expr]: assert all(e.dtype for e in exprs) - dtypes: List[DataType] = [e.dtype for e in exprs] # type: ignore # guaranteed to be not None + dtypes: List[DataType] = [e.dtype for e in exprs] # guaranteed to be not None target_dtypes = upcasting_rule(*dtypes) return iter(_upcast_node(target_dtype, arg) for target_dtype, arg in zip(target_dtypes, exprs)) diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/caches.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/caches.py index f6c864aaba..fd09017720 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/caches.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/caches.py @@ -39,7 +39,6 @@ Note that filling and flushing k-caches can always be replaced by a local (non-filling or flushing) k-cache plus additional filling and flushing statements. - """ @@ -261,7 +260,7 @@ class FillFlushToLocalKCaches(eve.NodeTranslator, eve.VisitorWithSymbolTableTrai For each cached field, the following actions are performed: 1. A new locally-k-cached temporary is introduced. 2. All accesses to the original field are replaced by accesses to this temporary. - 3. Loop sections are split where necessary to allow single-level loads whereever possible. + 3. Loop sections are split where necessary to allow single-level loads wherever possible. 3. Fill statements from the original field to the temporary are introduced. 4. Flush statements from the temporary to the original field are introduced. 
""" diff --git a/src/gt4py/cartesian/gtscript.py b/src/gt4py/cartesian/gtscript.py index 643ecba010..59f3ef37c2 100644 --- a/src/gt4py/cartesian/gtscript.py +++ b/src/gt4py/cartesian/gtscript.py @@ -657,10 +657,8 @@ def __str__(self) -> str: class _FieldDescriptorMaker: @staticmethod def _is_axes_spec(spec) -> bool: - return ( - isinstance(spec, Axis) - or isinstance(spec, collections.abc.Collection) - and all(isinstance(i, Axis) for i in spec) + return isinstance(spec, Axis) or ( + isinstance(spec, collections.abc.Collection) and all(isinstance(i, Axis) for i in spec) ) def __getitem__(self, field_spec): diff --git a/src/gt4py/cartesian/gtscript_imports.py b/src/gt4py/cartesian/gtscript_imports.py index 109f19759e..6fe49f18dd 100644 --- a/src/gt4py/cartesian/gtscript_imports.py +++ b/src/gt4py/cartesian/gtscript_imports.py @@ -23,13 +23,12 @@ gtscript_imports.enable( search_path=[, , ...], # for allowing only in search_path generate_path=, # for generating python modules in a specific dir - in_source=False, # set True to generate python modules next to gtscfipt files + in_source=False, # set True to generate python modules next to gtscript files ) # scoped usage with gtscript_imports.enabled(): import ... 
- """ import importlib diff --git a/src/gt4py/cartesian/stencil_builder.py b/src/gt4py/cartesian/stencil_builder.py index c0f58c0bc9..6ca2c673a1 100644 --- a/src/gt4py/cartesian/stencil_builder.py +++ b/src/gt4py/cartesian/stencil_builder.py @@ -58,10 +58,7 @@ def __init__( frontend: Optional[Type[FrontendType]] = None, ): self._definition = definition_func - # type ignore explanation: Attribclass generated init not recognized by mypy - self.options = options or BuildOptions( # type: ignore - **self.default_options_dict(definition_func) - ) + self.options = options or BuildOptions(**self.default_options_dict(definition_func)) backend = backend or "numpy" backend = gt4pyc.backend.from_name(backend) if isinstance(backend, str) else backend if backend is None: diff --git a/src/gt4py/cartesian/stencil_object.py b/src/gt4py/cartesian/stencil_object.py index b76415e17f..5e5976e3e5 100644 --- a/src/gt4py/cartesian/stencil_object.py +++ b/src/gt4py/cartesian/stencil_object.py @@ -513,7 +513,7 @@ def _normalize_origins( *((0,) * len(field_info.data_dims)), ) elif (info_origin := getattr(array_infos.get(name), "origin", None)) is not None: - origin[name] = info_origin # type: ignore + origin[name] = info_origin else: origin[name] = (0,) * field_info.ndim diff --git a/src/gt4py/cartesian/testing/__init__.py b/src/gt4py/cartesian/testing/__init__.py index 288d7b1d2d..0753b4175e 100644 --- a/src/gt4py/cartesian/testing/__init__.py +++ b/src/gt4py/cartesian/testing/__init__.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -__all__ = ["field", "global_name", "none", "parameter", "StencilTestSuite"] +__all__ = ["StencilTestSuite", "field", "global_name", "none", "parameter"] try: from .input_strategies import field, global_name, none, parameter from .suites import StencilTestSuite diff --git a/src/gt4py/cartesian/testing/suites.py b/src/gt4py/cartesian/testing/suites.py index 48bead86e2..423f834f51 100644 --- a/src/gt4py/cartesian/testing/suites.py +++ b/src/gt4py/cartesian/testing/suites.py @@ -167,7 +167,7 @@ def get_globals_combinations(dtypes): generation_strategy=composite_strategy_factory( d, generation_strategy_factories ), - implementations=[], + implementation=None, test_id=len(cls_dict["tests"]), definition=annotate_function( function=cls_dict["definition"], @@ -199,14 +199,19 @@ def hyp_wrapper(test_hyp, hypothesis_data): for test in cls_dict["tests"]: if test["suite"] == cls_name: - marks = test["marks"] - if gt4pyc.backend.from_name(test["backend"]).storage_info["device"] == "gpu": - marks.append(pytest.mark.requires_gpu) name = test["backend"] name += "".join(f"_{key}_{value}" for key, value in test["constants"].items()) name += "".join( "_{}_{}".format(key, value.name) for key, value in test["dtypes"].items() ) + + marks = test["marks"].copy() + if gt4pyc.backend.from_name(test["backend"]).storage_info["device"] == "gpu": + marks.append(pytest.mark.requires_gpu) + # Run generation and implementation tests in the same group to ensure + # (thread-) safe parallelization of stencil tests. 
+ marks.append(pytest.mark.xdist_group(name=f"{cls_name}_{name}")) + param = pytest.param(test, marks=marks, id=name) pytest_params.append(param) @@ -228,14 +233,19 @@ def hyp_wrapper(test_hyp, hypothesis_data): runtime_pytest_params = [] for test in cls_dict["tests"]: if test["suite"] == cls_name: - marks = test["marks"] - if gt4pyc.backend.from_name(test["backend"]).storage_info["device"] == "gpu": - marks.append(pytest.mark.requires_gpu) name = test["backend"] name += "".join(f"_{key}_{value}" for key, value in test["constants"].items()) name += "".join( "_{}_{}".format(key, value.name) for key, value in test["dtypes"].items() ) + + marks = test["marks"].copy() + if gt4pyc.backend.from_name(test["backend"]).storage_info["device"] == "gpu": + marks.append(pytest.mark.requires_gpu) + # Run generation and implementation tests in the same group to ensure + # (thread-) safe parallelization of stencil tests. + marks.append(pytest.mark.xdist_group(name=f"{cls_name}_{name}")) + runtime_pytest_params.append( pytest.param( test, @@ -434,8 +444,11 @@ class StencilTestSuite(metaclass=SuiteMeta): def _test_generation(cls, test, externals_dict): """Test source code generation for all *backends* and *stencil suites*. - The generated implementations are cached in a :class:`utils.ImplementationsDB` - instance, to avoid duplication of (potentially expensive) compilations. + The generated implementation is cached in the test context, to avoid duplication + of (potentially expensive) compilation. + Note: This caching introduces a dependency between tests, which is captured by an + `xdist_group` marker in combination with `--dist loadgroup` to ensure safe parallel + test execution. 
""" backend_slug = gt_utils.slugify(test["backend"], valid_symbols="") implementation = gtscript.stencil( @@ -461,7 +474,8 @@ def _test_generation(cls, test, externals_dict): or ax == "K" or field_info.boundary[i] >= cls.global_boundaries[name][i] ) - test["implementations"].append(implementation) + assert test["implementation"] is None + test["implementation"] = implementation @classmethod def _run_test_implementation(cls, parameters_dict, implementation): # too complex @@ -534,7 +548,7 @@ def _run_test_implementation(cls, parameters_dict, implementation): # too compl # call implementation implementation(**test_values, origin=origin, domain=domain, exec_info=exec_info) - # for validation data, data is cropped to actually touched domain, so that origin offseting + # for validation data, data is cropped to actually touched domain, so that origin offsetting # does not have to be implemented for every test suite. This is done based on info # specified in test suite cropped_validation_values = {} @@ -585,16 +599,16 @@ def _run_test_implementation(cls, parameters_dict, implementation): # too compl def _test_implementation(cls, test, parameters_dict): """Test computed values for implementations generated for all *backends* and *stencil suites*. - The generated implementations are reused from previous tests by means of a - :class:`utils.ImplementationsDB` instance shared at module scope. + The generated implementation was cached in the test context, to avoid duplication + of (potentially expensive) compilation. + Note: This caching introduces a dependency between tests, which is captured by an + `xdist_group` marker in combination with `--dist loadgroup` to ensure safe parallel + test execution. """ - implementation_list = test["implementations"] - if not implementation_list: - pytest.skip( - "Cannot perform validation tests, since there are no valid implementations." 
- ) - for implementation in implementation_list: - if not isinstance(implementation, StencilObject): - raise RuntimeError("Wrong function got from implementations_db cache!") + implementation = test["implementation"] + assert ( + implementation is not None + ), "Stencil implementation not found. This usually means code generation failed." + assert isinstance(implementation, StencilObject) - cls._run_test_implementation(parameters_dict, implementation) + cls._run_test_implementation(parameters_dict, implementation) diff --git a/src/gt4py/cartesian/utils/__init__.py b/src/gt4py/cartesian/utils/__init__.py index 3c0bdb3fc3..626d29b167 100644 --- a/src/gt4py/cartesian/utils/__init__.py +++ b/src/gt4py/cartesian/utils/__init__.py @@ -37,7 +37,7 @@ ) -__all__ = [ +__all__ = [ # noqa: RUF022 `__all__` is not sorted # Modules "attrib", "meta", diff --git a/src/gt4py/cartesian/utils/base.py b/src/gt4py/cartesian/utils/base.py index d5d43a4103..35184a3f7b 100644 --- a/src/gt4py/cartesian/utils/base.py +++ b/src/gt4py/cartesian/utils/base.py @@ -63,10 +63,8 @@ def flatten_iter(nested_iterables, filter_none=False, *, skip_types=(str, bytes) def get_member(instance, item_name): try: - if ( - isinstance(instance, collections.abc.Mapping) - or isinstance(instance, collections.abc.Sequence) - and isinstance(item_name, int) + if isinstance(instance, collections.abc.Mapping) or ( + isinstance(instance, collections.abc.Sequence) and isinstance(item_name, int) ): return instance[item_name] else: diff --git a/src/gt4py/eve/.gitignore b/src/gt4py/eve/.gitignore deleted file mode 100644 index 050cda3ca5..0000000000 --- a/src/gt4py/eve/.gitignore +++ /dev/null @@ -1 +0,0 @@ -_version.py diff --git a/src/gt4py/eve/__init__.py b/src/gt4py/eve/__init__.py index 0b8cfa7d62..e294108011 100644 --- a/src/gt4py/eve/__init__.py +++ b/src/gt4py/eve/__init__.py @@ -21,11 +21,9 @@ 7. visitors 8. traits 9. 
codegen - """ -from __future__ import annotations # isort:skip - +from __future__ import annotations from .concepts import ( AnnexManager, @@ -72,7 +70,7 @@ from .visitors import NodeTranslator, NodeVisitor -__all__ = [ +__all__ = [ # noqa: RUF022 `__all__` is not sorted # version "__version__", "__version_info__", @@ -89,15 +87,6 @@ "SymbolRef", "VType", "register_annex_user", - "# datamodels" "Coerced", - "DataModel", - "FrozenModel", - "GenericDataModel", - "Unchecked", - "concretize", - "datamodel", - "field", - "frozenmodel", # datamodels "Coerced", "DataModel", @@ -122,7 +111,7 @@ "pre_walk_values", "walk_items", "walk_values", - "# type_definition", + # type_definitions "NOTHING", "ConstrainedStr", "Enum", diff --git a/src/gt4py/eve/codegen.py b/src/gt4py/eve/codegen.py index 15fda4f3b4..3869ff313b 100644 --- a/src/gt4py/eve/codegen.py +++ b/src/gt4py/eve/codegen.py @@ -347,7 +347,7 @@ def __str__(self) -> str: class Template(Protocol): """Protocol (abstract base class) defining the Template interface. - Direct subclassess of this base class only need to implement the + Direct subclasses of this base class only need to implement the abstract methods to adapt different template engines to this interface. @@ -654,8 +654,8 @@ def apply( # redefinition of symbol Args: root: An IR node. - node_templates (optiona): see :class:`NodeDumper`. - dump_function (optiona): see :class:`NodeDumper`. + node_templates (optional): see :class:`NodeDumper`. + dump_function (optional): see :class:`NodeDumper`. ``**kwargs`` (optional): custom extra parameters forwarded to `visit_NODE_TYPE_NAME()`. Returns: diff --git a/src/gt4py/eve/datamodels/__init__.py b/src/gt4py/eve/datamodels/__init__.py index 68ddea2510..5f6806c5dd 100644 --- a/src/gt4py/eve/datamodels/__init__.py +++ b/src/gt4py/eve/datamodels/__init__.py @@ -11,7 +11,7 @@ Data Models can be considered as enhanced `attrs `_ / `dataclasses `_ providing additional features like automatic run-time type validation. 
Values assigned to fields -at initialization can be validated with automatic type checkings using the +at initialization can be validated with automatic type checking using the field type definition. Custom field validation methods can also be added with the :func:`validator` decorator, and global instance validation methods with :func:`root_validator`. @@ -33,7 +33,7 @@ 1. ``__init__()``. a. If a custom ``__init__`` already exists in the class, it will not be overwritten. - It is your responsability to call ``__auto_init__`` from there to obtain + It is your responsibility to call ``__auto_init__`` from there to obtain the described behavior. b. If there is not custom ``__init__``, the one generated by datamodels will be called first. @@ -104,7 +104,6 @@ >>> CustomModel(3, 2) Instance 1 == 1.5 CustomModel(value=1.5) - """ from . import core as core, validators as validators # imported but unused diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index d596f59cfb..31e63bdf9f 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -16,6 +16,7 @@ import dataclasses import functools import sys +import types import typing import warnings @@ -24,7 +25,7 @@ try: - # For perfomance reasons, try to use cytoolz when possible (using cython) + # For performance reasons, try to use cytoolz when possible (using cython) import cytoolz as toolz except ModuleNotFoundError: # Fall back to pure Python toolz @@ -270,7 +271,7 @@ def datamodel( @overload -def datamodel( # redefinion of unused symbol +def datamodel( # redefinition of unused symbol cls: Type[_T], /, *, @@ -289,7 +290,7 @@ def datamodel( # redefinion of unused symbol # TODO(egparedes): Use @dataclass_transform(eq_default=True, field_specifiers=("field",)) -def datamodel( # redefinion of unused symbol +def datamodel( # redefinition of unused symbol cls: Optional[Type[_T]] = None, /, *, @@ -867,7 +868,7 @@ def _substitute_typevars( def 
_make_counting_attr_from_attribute( field_attrib: Attribute, *, include_type: bool = False, **kwargs: Any -) -> Any: # attr.s lies a bit in some typing definitons +) -> Any: # attr.s lies a bit in some typing definitions args = [ "default", "validator", @@ -965,7 +966,7 @@ def _type_converter(value: Any) -> _T: return value if isinstance(value, type_annotation) else type_annotation(value) except Exception as error: raise TypeError( - f"Error during coertion of given value '{value}' for field '{name}'." + f"Error during coercion of given value '{value}' for field '{name}'." ) from error return _type_converter @@ -996,7 +997,7 @@ def _type_converter(value: Any) -> _T: return _make_type_converter(origin_type, name) raise exceptions.EveTypeError( - f"Automatic type coertion for {type_annotation} types is not supported." + f"Automatic type coercion for {type_annotation} types is not supported." ) @@ -1085,7 +1086,7 @@ def _make_datamodel( ) else: - # Create field converter if automatic coertion is enabled + # Create field converter if automatic coercion is enabled converter: TypeConverter = cast( TypeConverter, _make_type_converter(type_hint, qualified_field_name) if coerce_field else None, @@ -1099,7 +1100,7 @@ def _make_datamodel( if isinstance(attr_value_in_cls, _KNOWN_MUTABLE_TYPES): warnings.warn( f"'{attr_value_in_cls.__class__.__name__}' value used as default in '{cls.__name__}.{key}'.\n" - "Mutable types should not defbe normally used as field defaults (use 'default_factory' instead).", + "Mutable types should not be used as field defaults (use 'default_factory' instead).", stacklevel=_stacklevel_offset + 2, ) setattr( @@ -1254,8 +1255,11 @@ def _make_concrete_with_cache( if not is_generic_datamodel_class(datamodel_cls): raise TypeError(f"'{datamodel_cls.__name__}' is not a generic model class.") for t in type_args: + _accepted_types: tuple[type, ...] 
= (type, type(None), xtyping.StdGenericAliasType) + if sys.version_info >= (3, 10): + _accepted_types = (*_accepted_types, types.UnionType) if not ( - isinstance(t, (type, type(None), xtyping.StdGenericAliasType)) + isinstance(t, _accepted_types) or (getattr(type(t), "__module__", None) in ("typing", "typing_extensions")) ): raise TypeError( diff --git a/src/gt4py/eve/datamodels/validators.py b/src/gt4py/eve/datamodels/validators.py index 119410460c..4ce6f94c5e 100644 --- a/src/gt4py/eve/datamodels/validators.py +++ b/src/gt4py/eve/datamodels/validators.py @@ -42,7 +42,7 @@ from .core import DataModelTP, FieldValidator -__all__ = [ +__all__ = [ # noqa: RUF022 `__all__` is not sorted # reexported from attrs "and_", "deep_iterable", diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index e276f3bccf..bf44824b49 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -14,12 +14,8 @@ from __future__ import annotations -import abc as _abc import array as _array -import collections.abc as _collections_abc -import ctypes as _ctypes import dataclasses as _dataclasses -import enum as _enum import functools as _functools import inspect as _inspect import mmap as _mmap diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py index 27f19d2670..8a3cc30f4b 100644 --- a/src/gt4py/eve/trees.py +++ b/src/gt4py/eve/trees.py @@ -31,14 +31,6 @@ from .type_definitions import Enum -try: - # For perfomance reasons, try to use cytoolz when possible (using cython) - import cytoolz as toolz -except ModuleNotFoundError: - # Fall back to pure Python toolz - import toolz # noqa: F401 [unused-import] - - TreeKey = Union[int, str] diff --git a/src/gt4py/eve/type_validation.py b/src/gt4py/eve/type_validation.py index 613eca40b2..695ab69dc3 100644 --- a/src/gt4py/eve/type_validation.py +++ b/src/gt4py/eve/type_validation.py @@ -14,6 +14,8 @@ import collections.abc import dataclasses import functools +import sys +import types import 
typing from . import exceptions, extended_typing as xtyping, utils @@ -193,6 +195,12 @@ def __call__( if type_annotation is None: type_annotation = type(None) + if sys.version_info >= (3, 10): + if isinstance( + type_annotation, types.UnionType + ): # see https://github.com/python/cpython/issues/105499 + type_annotation = typing.Union[type_annotation.__args__] + # Non-generic types if xtyping.is_actual_type(type_annotation): assert not xtyping.get_args(type_annotation) @@ -277,6 +285,7 @@ def __call__( if issubclass(origin_type, (collections.abc.Sequence, collections.abc.Set)): assert len(type_args) == 1 + make_recursive(type_args[0]) if (member_validator := make_recursive(type_args[0])) is None: raise exceptions.EveValueError( f"{type_args[0]} type annotation is not supported." @@ -311,7 +320,7 @@ def __call__( # ... # # Since this can be an arbitrary type (not something regular like a collection) there is - # no way to check if the type parameter is verifed in the actual instance. + # no way to check if the type parameter is verified in the actual instance. # The only check can be done at run-time is to verify that the value is an instance of # the original type, completely ignoring the annotation. Ideally, the static type checker # can do a better job to try figure out if the type parameter is ok ... diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 8cb68845d7..96e41a7bd8 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -69,7 +69,7 @@ try: - # For perfomance reasons, try to use cytoolz when possible (using cython) + # For performance reasons, try to use cytoolz when possible (using cython) import cytoolz as toolz except ModuleNotFoundError: # Fall back to pure Python toolz @@ -440,8 +440,8 @@ def content_hash(*args: Any, hash_algorithm: str | xtyping.HashlibAlgorithm | No return result -ddiff = deepdiff.DeepDiff -"""Shortcut for deepdiff.DeepDiff. +ddiff = deepdiff.diff.DeepDiff +"""Shortcut for deepdiff.diff.DeepDiff. 
Check https://zepworks.com/deepdiff/current/diff.html for more info. """ @@ -458,13 +458,13 @@ def dhash(obj: Any, **kwargs: Any) -> str: def pprint_ddiff( old: Any, new: Any, *, pprint_opts: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> None: - """Pretty printing of deepdiff.DeepDiff objects. + """Pretty printing of deepdiff.diff.DeepDiff objects. Keyword Arguments: pprint_opts: kwargs dict with options for pprint.pprint. """ pprint_opts = pprint_opts or {"indent": 2} - pprint.pprint(deepdiff.DeepDiff(old, new, **kwargs), **pprint_opts) + pprint.pprint(deepdiff.diff.DeepDiff(old, new, **kwargs), **pprint_opts) AnyWordsIterable = Union[str, Iterable[str]] diff --git a/src/gt4py/eve/visitors.py b/src/gt4py/eve/visitors.py index 28d1e2acf6..59b4ef0881 100644 --- a/src/gt4py/eve/visitors.py +++ b/src/gt4py/eve/visitors.py @@ -45,7 +45,7 @@ class NodeVisitor: 3. ``self.generic_visit()``. This dispatching mechanism is implemented in the main :meth:`visit` - method and can be overriden in subclasses. Therefore, a simple way to extend + method and can be overridden in subclasses. Therefore, a simple way to extend the behavior of a visitor is by inheriting from lightweight `trait` classes with a custom ``visit()`` method, which wraps the call to the superclass' ``visit()`` and adds extra pre and post visit logic. Check :mod:`eve.traits` @@ -82,7 +82,7 @@ def apply(cls, tree, init_var, foo, bar=5, **kwargs): Notes: If you want to apply changes to nodes during the traversal, - use the :class:`NodeMutator` subclass, which handles correctly + use the :class:`NodeTranslator` subclass, which handles correctly structural modifications of the visited tree. """ diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index 80bb276c70..4fa9215706 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -20,6 +20,7 @@ from . 
import common, ffront, iterator, program_processors from .common import ( + Connectivity, Dimension, DimensionKind, Dims, @@ -39,8 +40,7 @@ from .ffront.fbuiltins import * # noqa: F403 [undefined-local-with-import-star] explicitly reexport all from fbuiltins.__all__ from .ffront.fbuiltins import FieldOffset from .iterator.embedded import ( - NeighborTableOffsetProvider, - StridedNeighborOffsetProvider, + NeighborTableOffsetProvider, # TODO(havogt): deprecated index_field, np_as_located_field, ) @@ -61,6 +61,7 @@ "Dimension", "DimensionKind", "Field", + "Connectivity", "GridType", "domain", "Domain", @@ -75,7 +76,6 @@ "as_connectivity", # from iterator "NeighborTableOffsetProvider", - "StridedNeighborOffsetProvider", "index_field", "np_as_located_field", # from ffront diff --git a/src/gt4py/next/allocators.py b/src/gt4py/next/allocators.py index 864f8c1b09..dae2e9d021 100644 --- a/src/gt4py/next/allocators.py +++ b/src/gt4py/next/allocators.py @@ -18,7 +18,6 @@ Any, Callable, Final, - Literal, Optional, Protocol, Sequence, @@ -28,19 +27,6 @@ ) -try: - import cupy as cp -except ImportError: - cp = None - - -CUPY_DEVICE: Final[Literal[None, core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]] = ( - None - if not cp - else (core_defs.DeviceType.ROCM if cp.cuda.runtime.is_hip else core_defs.DeviceType.CUDA) -) - - FieldLayoutMapper: TypeAlias = Callable[ [Sequence[common.Dimension]], core_allocators.BufferLayoutMap ] @@ -180,7 +166,7 @@ def __gt_allocate__( def horizontal_first_layout_mapper( dims: Sequence[common.Dimension], ) -> core_allocators.BufferLayoutMap: - """Map dimensions to a buffer layout making horizonal dims change the slowest (i.e. larger strides).""" + """Map dimensions to a buffer layout making horizontal dims change the slowest (i.e. 
larger strides).""" def pos_of_kind(kind: common.DimensionKind) -> list[int]: return [i for i, dim in enumerate(dims) if dim.kind == kind] @@ -246,11 +232,11 @@ def __gt_allocate__( raise self.exception -if CUPY_DEVICE is not None: +if core_defs.CUPY_DEVICE_TYPE is not None: assert isinstance(core_allocators.cupy_array_utils, core_allocators.ArrayUtils) cupy_array_utils = core_allocators.cupy_array_utils - if CUPY_DEVICE is core_defs.DeviceType.CUDA: + if core_defs.CUPY_DEVICE_TYPE is core_defs.DeviceType.CUDA: class CUDAFieldBufferAllocator(BaseFieldBufferAllocator[core_defs.CUDADeviceTyping]): def __init__(self) -> None: @@ -278,7 +264,7 @@ def __init__(self) -> None: else: - class InvalidGPUFielBufferAllocator(InvalidFieldBufferAllocator[core_defs.CUDADeviceTyping]): + class InvalidGPUFieldBufferAllocator(InvalidFieldBufferAllocator[core_defs.CUDADeviceTyping]): def __init__(self) -> None: super().__init__( device_type=core_defs.DeviceType.CUDA, @@ -288,7 +274,9 @@ def __init__(self) -> None: StandardGPUFieldBufferAllocator: Final[type[FieldBufferAllocatorProtocol]] = cast( type[FieldBufferAllocatorProtocol], - type(device_allocators[CUPY_DEVICE]) if CUPY_DEVICE else InvalidGPUFielBufferAllocator, + type(device_allocators[core_defs.CUPY_DEVICE_TYPE]) + if core_defs.CUPY_DEVICE_TYPE + else InvalidGPUFieldBufferAllocator, ) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 0340d61f89..e075422ca3 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -15,7 +15,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import allocators as next_allocators from gt4py.next.ffront import ( - foast_to_itir, + foast_to_gtir, foast_to_past, func_to_foast, func_to_past, @@ -40,7 +40,7 @@ ARGS: typing.TypeAlias = arguments.JITArgs CARG: typing.TypeAlias = arguments.CompileTimeArgs -IT_PRG: typing.TypeAlias = itir.FencilDefinition | itir.Program +IT_PRG: typing.TypeAlias = itir.Program INPUT_DATA: typing.TypeAlias = 
DSL_FOP | FOP | DSL_PRG | PRG | IT_PRG @@ -76,7 +76,7 @@ class Transforms(workflow.MultiWorkflow[INPUT_PAIR, stages.CompilableProgram]): ) foast_to_itir: workflow.Workflow[AOT_FOP, itir.Expr] = dataclasses.field( - default_factory=foast_to_itir.adapted_foast_to_itir_factory + default_factory=foast_to_gtir.adapted_foast_to_gtir_factory ) field_view_op_to_prog: workflow.Workflow[AOT_FOP, AOT_PRG] = dataclasses.field( @@ -92,7 +92,7 @@ class Transforms(workflow.MultiWorkflow[INPUT_PAIR, stages.CompilableProgram]): ) past_to_itir: workflow.Workflow[AOT_PRG, stages.CompilableProgram] = dataclasses.field( - default_factory=past_to_itir.past_to_itir_factory + default_factory=past_to_itir.past_to_gtir_factory ) def step_order(self, inp: INPUT_PAIR) -> list[str]: @@ -125,7 +125,7 @@ def step_order(self, inp: INPUT_PAIR) -> list[str]: ) case PRG(): steps.extend(["past_lint", "field_view_prog_args_transform", "past_to_itir"]) - case itir.FencilDefinition() | itir.Program(): + case itir.Program(): pass case _: raise ValueError("Unexpected input.") diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 4aa0dd03aa..f615833045 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -18,7 +18,6 @@ from collections.abc import Mapping, Sequence import numpy as np -import numpy.typing as npt from gt4py._core import definitions as core_defs from gt4py.eve import utils @@ -70,6 +69,9 @@ def __str__(self) -> str: return self.value +_DIM_KIND_ORDER = {DimensionKind.HORIZONTAL: 1, DimensionKind.VERTICAL: 2, DimensionKind.LOCAL: 3} + + def dimension_to_implicit_offset(dim: str) -> str: """ Return name of offset implicitly defined by a dimension. @@ -95,7 +97,7 @@ def __str__(self) -> str: def __call__(self, val: int) -> NamedIndex: return NamedIndex(self, val) - def __add__(self, offset: int) -> ConnectivityField: + def __add__(self, offset: int) -> Connectivity: # TODO(sf-n): just to avoid circular import. Move or refactor the FieldOffset to avoid this. 
from gt4py.next.ffront import fbuiltins @@ -104,7 +106,7 @@ def __add__(self, offset: int) -> ConnectivityField: dimension_to_implicit_offset(self.value), source=self, target=(self,) )[offset] - def __sub__(self, offset: int) -> ConnectivityField: + def __sub__(self, offset: int) -> Connectivity: return self + (-offset) @@ -575,6 +577,12 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: return Domain(dims=dims, ranges=ranges) + def __getstate__(self) -> dict[str, Any]: + state = self.__dict__.copy() + # remove cached property + state.pop("slice_at", None) + return state + FiniteDomain: TypeAlias = Domain[FiniteUnitRange] @@ -678,6 +686,9 @@ def codomain(self) -> type[core_defs.ScalarT] | Dimension: ... @property def dtype(self) -> core_defs.DType[core_defs.ScalarT]: ... + # TODO(havogt) + # This property is wrong, because for a function field we would not know to which NDArrayObject we want to convert + # at the very least, we need to take an allocator and rename this to `as_ndarray`. @property def ndarray(self) -> core_defs.NDArrayObject: ... @@ -688,7 +699,7 @@ def __str__(self) -> str: def asnumpy(self) -> np.ndarray: ... @abc.abstractmethod - def premap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... + def premap(self, index_field: Connectivity | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod def restrict(self, item: AnyIndexSpec) -> Self: ... @@ -700,8 +711,8 @@ def as_scalar(self) -> core_defs.ScalarT: ... @abc.abstractmethod def __call__( self, - index_field: ConnectivityField | fbuiltins.FieldOffset, - *args: ConnectivityField | fbuiltins.FieldOffset, + index_field: Connectivity | fbuiltins.FieldOffset, + *args: Connectivity | fbuiltins.FieldOffset, ) -> Field: ... 
@abc.abstractmethod @@ -811,12 +822,64 @@ def remapping(cls) -> ConnectivityKind: return cls.ALTER_DIMS | cls.ALTER_STRUCT +@dataclasses.dataclass(frozen=True) +class ConnectivityType: # TODO(havogt): would better live in type_specifications but would have to solve a circular import + domain: tuple[Dimension, ...] + codomain: Dimension + skip_value: Optional[core_defs.IntegralScalar] + dtype: core_defs.DType + + @property + def has_skip_values(self) -> bool: + return self.skip_value is not None + + +@dataclasses.dataclass(frozen=True) +class NeighborConnectivityType(ConnectivityType): + # TODO(havogt): refactor towards encoding this information in the local dimensions of the ConnectivityType.domain + max_neighbors: int + + @property + def source_dim(self) -> Dimension: + return self.domain[0] + + @property + def neighbor_dim(self) -> Dimension: + return self.domain[1] + + @runtime_checkable # type: ignore[misc] # DimT should be covariant, but then it breaks in other places -class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): +class Connectivity(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): @property @abc.abstractmethod - def codomain(self) -> DimT: ... + def codomain(self) -> DimT: + """ + The `codomain` is the set of all indices in a certain `Dimension`. + + We use the `Dimension` itself to describe the (infinite) set of all indices. + + Note: + We could restrict the infinite codomain to only the indices that are actually contained in the mapping. + Currently, this would just complicate implementation as we do not use this information. 
+ """ + + def __gt_type__(self) -> ConnectivityType: + if is_neighbor_connectivity(self): + return NeighborConnectivityType( + domain=self.domain.dims, + codomain=self.codomain, + dtype=self.dtype, + skip_value=self.skip_value, + max_neighbors=self.ndarray.shape[1], + ) + else: + return ConnectivityType( + domain=self.domain.dims, + codomain=self.codomain, + dtype=self.dtype, + skip_value=self.skip_value, + ) @property def kind(self) -> ConnectivityKind: @@ -831,61 +894,61 @@ def skip_value(self) -> Optional[core_defs.IntegralScalar]: ... # Operators def __abs__(self) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __neg__(self) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __invert__(self) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __eq__(self, other: Any) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __ne__(self, other: Any) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __add__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __radd__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __sub__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise 
TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rsub__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __mul__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rmul__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __truediv__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rtruediv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __floordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __rfloordiv__(self, other: Field | core_defs.IntegralScalar) -> Never: # type: ignore[misc] # Forward operator not callalbe - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __pow__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + 
raise TypeError("'Connectivity' does not support this operation.") def __and__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __or__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") def __xor__(self, other: Field | core_defs.IntegralScalar) -> Never: - raise TypeError("'ConnectivityField' does not support this operation.") + raise TypeError("'Connectivity' does not support this operation.") # Utility function to construct a `Field` from different buffer representations. @@ -911,38 +974,58 @@ def _connectivity( domain: Optional[DomainLike] = None, dtype: Optional[core_defs.DType] = None, skip_value: Optional[core_defs.IntegralScalar] = None, -) -> ConnectivityField: +) -> Connectivity: raise NotImplementedError -@runtime_checkable -class Connectivity(Protocol): - max_neighbors: int - has_skip_values: bool - origin_axis: Dimension - neighbor_axis: Dimension - index_type: type[int] | type[np.int32] | type[np.int64] +class NeighborConnectivity(Connectivity, Protocol): + # TODO(havogt): work towards encoding this properly in the type + def __gt_type__(self) -> NeighborConnectivityType: ... 
+ - def mapped_index( - self, cur_index: int | np.integer, neigh_index: int | np.integer - ) -> Optional[int | np.integer]: - """Return neighbor index.""" +def is_neighbor_connectivity(obj: Any) -> TypeGuard[NeighborConnectivity]: + if not isinstance(obj, Connectivity): + return False + domain_dims = obj.domain.dims + return ( + len(domain_dims) == 2 + and domain_dims[0].kind is DimensionKind.HORIZONTAL + and domain_dims[1].kind is DimensionKind.LOCAL + ) -@runtime_checkable -class NeighborTable(Connectivity, Protocol): - table: npt.NDArray +class NeighborTable( + NeighborConnectivity, Protocol +): # TODO(havogt): try to express by inheriting from NdArrayConnectivityField (but this would require a protocol to move it out of `embedded.nd_array_field`) + @property + def ndarray(self) -> core_defs.NDArrayObject: + # Note that this property is currently already there from inheriting from `Field`, + # however this seems wrong, therefore we explicitly introduce it here (or it should come + # implicitly from the `NdArrayConnectivityField` protocol). + ... 
+ +def is_neighbor_table(obj: Any) -> TypeGuard[NeighborTable]: + return is_neighbor_connectivity(obj) and hasattr(obj, "ndarray") -OffsetProviderElem: TypeAlias = Dimension | Connectivity + +OffsetProviderElem: TypeAlias = Dimension | NeighborConnectivity +OffsetProviderTypeElem: TypeAlias = Dimension | NeighborConnectivityType OffsetProvider: TypeAlias = Mapping[Tag, OffsetProviderElem] +OffsetProviderType: TypeAlias = Mapping[Tag, OffsetProviderTypeElem] + + +def offset_provider_to_type(offset_provider: OffsetProvider) -> OffsetProviderType: + return { + k: v.__gt_type__() if isinstance(v, Connectivity) else v for k, v in offset_provider.items() + } DomainDimT = TypeVar("DomainDimT", bound="Dimension") @dataclasses.dataclass(frozen=True, eq=False) -class CartesianConnectivity(ConnectivityField[Dims[DomainDimT], DimT]): +class CartesianConnectivity(Connectivity[Dims[DomainDimT], DimT]): domain_dim: DomainDimT codomain: DimT offset: int = 0 @@ -981,7 +1064,7 @@ def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: return core_defs.Int32DType() # type: ignore[return-value] # This is a workaround to make this class concrete, since `codomain` is an - # abstract property of the `ConnectivityField` Protocol. + # abstract property of the `Connectivity` Protocol. 
if not TYPE_CHECKING: @functools.cached_property @@ -1024,9 +1107,9 @@ def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRa def premap( self, - index_field: ConnectivityField | fbuiltins.FieldOffset, - *args: ConnectivityField | fbuiltins.FieldOffset, - ) -> ConnectivityField: + index_field: Connectivity | fbuiltins.FieldOffset, + *args: Connectivity | fbuiltins.FieldOffset, + ) -> Connectivity: raise NotImplementedError() __call__ = premap @@ -1043,84 +1126,56 @@ class GridType(StrEnum): UNSTRUCTURED = "unstructured" +def _ordered_dims(dims: list[Dimension] | set[Dimension]) -> list[Dimension]: + return sorted(dims, key=lambda dim: (_DIM_KIND_ORDER[dim.kind], dim.value)) + + +def check_dims(dims: list[Dimension]) -> None: + if sum(1 for dim in dims if dim.kind == DimensionKind.LOCAL) > 1: + raise ValueError("There are more than one dimension with DimensionKind 'LOCAL'.") + + if dims != _ordered_dims(dims): + raise ValueError( + f"Dimensions '{', '.join(map(str, dims))}' are not ordered correctly, expected '{', '.join(map(str, _ordered_dims(dims)))}'." + ) + + def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: """ - Find a unique ordering of multiple (individually ordered) lists of dimensions. - - The resulting list of dimensions contains all dimensions of the arguments - in the order they originally appear. If no unique order exists or a - contradicting order is found an exception is raised. + Find an ordering of multiple lists of dimensions. - A modified version (ensuring uniqueness of the order) of - `Kahn's algorithm `_ - is used to topologically sort the arguments. + The resulting list contains all unique dimensions from the input lists, + sorted first by dims_kind_order, i.e., `Dimension.kind` (`HORIZONTAL` < `VERTICAL` < `LOCAL`) and then + lexicographically by `Dimension.value`. 
Examples: >>> from gt4py.next.common import Dimension - >>> I, J, K = (Dimension(value=dim) for dim in ["I", "J", "K"]) - >>> promote_dims([I, J], [I, J, K]) == [I, J, K] + >>> I = Dimension("I", DimensionKind.HORIZONTAL) + >>> J = Dimension("J", DimensionKind.HORIZONTAL) + >>> K = Dimension("K", DimensionKind.VERTICAL) + >>> E2V = Dimension("E2V", kind=DimensionKind.LOCAL) + >>> E2C = Dimension("E2C", kind=DimensionKind.LOCAL) + >>> promote_dims([J, K], [I, K]) == [I, J, K] True - - >>> promote_dims([I, J], [K]) # doctest: +ELLIPSIS + >>> promote_dims([K, J], [I, K]) Traceback (most recent call last): ... - ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. - - >>> promote_dims([I, J], [J, I]) # doctest: +ELLIPSIS + ValueError: Dimensions 'K[vertical], J[horizontal]' are not ordered correctly, expected 'J[horizontal], K[vertical]'. + >>> promote_dims([I, K], [J, E2V]) == [I, J, K, E2V] + True + >>> promote_dims([I, E2C], [K, E2V]) Traceback (most recent call last): ... - ValueError: Dimensions can not be promoted. The following dimensions appear in contradicting order: I, J. + ValueError: There are more than one dimension with DimensionKind 'LOCAL'. """ - # build a graph with the vertices being dimensions and edges representing - # the order between two dimensions. The graph is encoded as a dictionary - # mapping dimensions to their predecessors, i.e. a dictionary containing - # adjacency lists. Since graphlib.TopologicalSorter uses predecessors - # (contrary to successors) we also use this directionality here. 
- graph: dict[Dimension, set[Dimension]] = {} + for dims in dims_list: - if len(dims) == 0: - continue - # create a vertex for each dimension - for dim in dims: - graph.setdefault(dim, set()) - # add edges - predecessor = dims[0] - for dim in dims[1:]: - graph[dim].add(predecessor) - predecessor = dim - - # modified version of Kahn's algorithm - topologically_sorted_list: list[Dimension] = [] - - # compute in-degree for each vertex - in_degree = {v: 0 for v in graph.keys()} - for v1 in graph: - for v2 in graph[v1]: - in_degree[v2] += 1 - - # process vertices with in-degree == 0 - # TODO(tehrengruber): avoid recomputation of zero_in_degree_vertex_list - while zero_in_degree_vertex_list := [v for v, d in in_degree.items() if d == 0]: - if len(zero_in_degree_vertex_list) != 1: - raise ValueError( - f"Dimensions can not be promoted. Could not determine " - f"order of the following dimensions: " - f"{', '.join((dim.value for dim in zero_in_degree_vertex_list))}." - ) - v = zero_in_degree_vertex_list[0] - del in_degree[v] - topologically_sorted_list.insert(0, v) - # update in-degree - for predecessor in graph[v]: - in_degree[predecessor] -= 1 - - if len(in_degree.items()) > 0: - raise ValueError( - f"Dimensions can not be promoted. The following dimensions " - f"appear in contradicting order: {', '.join((dim.value for dim in in_degree.keys()))}." 
- ) + check_dims(list(dims)) + unique_dims = {dim for dims in dims_list for dim in dims} - return topologically_sorted_list + promoted_dims = _ordered_dims(unique_dims) + check_dims(promoted_dims) + return promoted_dims class FieldBuiltinFuncRegistry: diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py index ed244c2932..7a19f3eb9d 100644 --- a/src/gt4py/next/config.py +++ b/src/gt4py/next/config.py @@ -11,7 +11,6 @@ import enum import os import pathlib -import tempfile from typing import Final @@ -51,25 +50,22 @@ def env_flag_to_bool(name: str, default: bool) -> bool: ) -_PREFIX: Final[str] = "GT4PY" - #: Master debug flag #: Changes defaults for all the other options to be as helpful for debugging as possible. #: Does not override values set in environment variables. -DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=False) +DEBUG: Final[bool] = env_flag_to_bool("GT4PY_DEBUG", default=False) #: Verbose flag for DSL compilation errors VERBOSE_EXCEPTIONS: bool = env_flag_to_bool( - f"{_PREFIX}_VERBOSE_EXCEPTIONS", default=True if DEBUG else False + "GT4PY_VERBOSE_EXCEPTIONS", default=True if DEBUG else False ) #: Where generated code projects should be persisted. 
#: Only active if BUILD_CACHE_LIFETIME is set to PERSISTENT BUILD_CACHE_DIR: pathlib.Path = ( - pathlib.Path(os.environ.get(f"{_PREFIX}_BUILD_CACHE_DIR", tempfile.gettempdir())) - / "gt4py_cache" + pathlib.Path(os.environ.get("GT4PY_BUILD_CACHE_DIR", pathlib.Path.cwd())) / ".gt4py_cache" ) @@ -77,11 +73,11 @@ def env_flag_to_bool(name: str, default: bool) -> bool: #: - SESSION: generated code projects get destroyed when the interpreter shuts down #: - PERSISTENT: generated code projects are written to BUILD_CACHE_DIR and persist between runs BUILD_CACHE_LIFETIME: BuildCacheLifetime = BuildCacheLifetime[ - os.environ.get(f"{_PREFIX}_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper() + os.environ.get("GT4PY_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper() ] #: Build type to be used when CMake is used to compile generated code. #: Might have no effect when CMake is not used as part of the toolchain. CMAKE_BUILD_TYPE: CMakeBuildType = CMakeBuildType[ - os.environ.get(f"{_PREFIX}_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper() + os.environ.get("GT4PY_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper() ] diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index dd52559e85..7b39511674 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -290,22 +290,24 @@ def as_connectivity( *, allocator: Optional[next_allocators.FieldBufferAllocatorProtocol] = None, device: Optional[core_defs.Device] = None, - skip_value: Optional[core_defs.IntegralScalar] = None, + skip_value: core_defs.IntegralScalar | eve.NothingType | None = eve.NOTHING, # TODO: copy=False -) -> common.ConnectivityField: +) -> common.Connectivity: """ - Construct a connectivity field from the given domain, codomain, and data. + Construct a `Connectivity` from the given domain, codomain, and data. Arguments: - domain: The domain of the connectivity field. 
It can be either a `common.DomainLike` object or a + domain: The domain of the connectivity. It can be either a `common.DomainLike` object or a sequence of `common.Dimension` objects. - codomain: The codomain dimension of the connectivity field. + codomain: The codomain dimension of the connectivity. data: The data used to construct the connectivity field. - dtype: The data type of the connectivity field. If not provided, it will be inferred from the data. - allocator: The allocator used to allocate the buffer for the connectivity field. If not provided, + dtype: The data type of the connectivity. If not provided, it will be inferred from the data. + allocator: The allocator used to allocate the buffer for the connectivity. If not provided, a default allocator will be used. - device: The device on which the connectivity field will be allocated. If not provided, the default + device: The device on which the connectivity will be allocated. If not provided, the default device will be used. + skip_value: The value that signals missing entries in the neighbor table. Defaults to the default + skip value if it is found in data, otherwise to `None` (= no skip value). Returns: The constructed connectivity field. @@ -313,9 +315,15 @@ def as_connectivity( Raises: ValueError: If the domain or codomain is invalid, or if the shape of the data does not match the domain shape. 
""" + if skip_value is eve.NOTHING: + skip_value = ( + common._DEFAULT_SKIP_VALUE if (data == common._DEFAULT_SKIP_VALUE).any() else None + ) + assert ( skip_value is None or skip_value == common._DEFAULT_SKIP_VALUE ) # TODO(havogt): not yet configurable + skip_value = cast(Optional[core_defs.IntegralScalar], skip_value) if isinstance(domain, Sequence) and all(isinstance(dim, common.Dimension) for dim in domain): domain = cast(Sequence[common.Dimension], domain) if len(domain) != data.ndim: diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 655a1137e8..537482508b 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -36,7 +36,6 @@ exceptions as embedded_exceptions, ) from gt4py.next.ffront import experimental, fbuiltins -from gt4py.next.iterator import embedded as itir_embedded try: @@ -148,7 +147,8 @@ def as_scalar(self) -> core_defs.ScalarT: raise ValueError( f"'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'." ) - return self.ndarray.item() + # note: `.item()` will return a Python type, therefore we use indexing with an empty tuple + return self.asnumpy()[()] # type: ignore[return-value] # should be ensured by the 0-d check @property def codomain(self) -> type[core_defs.ScalarT]: @@ -188,10 +188,10 @@ def from_array( def premap( self: NdArrayField, - *connectivities: common.ConnectivityField | fbuiltins.FieldOffset, + *connectivities: common.Connectivity | fbuiltins.FieldOffset, ) -> NdArrayField: """ - Rearrange the field content using the provided connectivity fields as index mappings. + Rearrange the field content using the provided connectivities (index mappings). This operation is conceptually equivalent to a regular composition of mappings `f∘c`, being `c` the `connectivity` argument and `f` the `self` data field. 
@@ -205,7 +205,7 @@ def premap( argument used in the right hand side of the operator should therefore have the same product of dimensions `c: S × T → A × B`. Such a mapping can also be expressed as a pair of mappings `c1: S × T → A` and `c2: S × T → B`, and this - is actually the only supported form in GT4Py because `ConnectivityField` instances + is actually the only supported form in GT4Py because `Connectivity` instances can only deal with a single dimension in its codomain. This approach makes connectivities reusable for any combination of dimensions in a field domain and matches the NumPy advanced indexing API, which basically is a @@ -260,15 +260,15 @@ def premap( """ # noqa: RUF002 # TODO(egparedes): move docstring to the `premap` builtin function when it exists - conn_fields: list[common.ConnectivityField] = [] + conn_fields: list[common.Connectivity] = [] codomains_counter: collections.Counter[common.Dimension] = collections.Counter() for connectivity in connectivities: - # For neighbor reductions, a FieldOffset is passed instead of an actual ConnectivityField - if not isinstance(connectivity, common.ConnectivityField): + # For neighbor reductions, a FieldOffset is passed instead of an actual Connectivity + if not isinstance(connectivity, common.Connectivity): assert isinstance(connectivity, fbuiltins.FieldOffset) connectivity = connectivity.as_connectivity_field() - assert isinstance(connectivity, common.ConnectivityField) + assert isinstance(connectivity, common.Connectivity) # Current implementation relies on skip_value == -1: # if we assume the indexed array has at least one element, @@ -317,8 +317,8 @@ def premap( def __call__( self, - index_field: common.ConnectivityField | fbuiltins.FieldOffset, - *args: common.ConnectivityField | fbuiltins.FieldOffset, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: common.Connectivity | fbuiltins.FieldOffset, ) -> common.Field: return functools.reduce( lambda field, current_index_field: 
field.premap(current_index_field), @@ -459,7 +459,7 @@ def _dace_descriptor(self) -> Any: @dataclasses.dataclass(frozen=True) class NdArrayConnectivityField( # type: ignore[misc] # for __ne__, __eq__ - common.ConnectivityField[common.DimsT, common.DimT], + common.Connectivity[common.DimsT, common.DimT], NdArrayField[common.DimsT, core_defs.IntegralScalar], ): _codomain: common.DimT @@ -578,7 +578,7 @@ def restrict(self, index: common.AnyIndexSpec) -> NdArrayConnectivityField: __getitem__ = restrict -def _domain_premap(data: NdArrayField, *connectivities: common.ConnectivityField) -> NdArrayField: +def _domain_premap(data: NdArrayField, *connectivities: common.Connectivity) -> NdArrayField: """`premap` implementation transforming only the field domain not the data (i.e. translation and relocation).""" new_domain = data.domain for connectivity in connectivities: @@ -657,7 +657,10 @@ def _reshuffling_premap( conn_map[dim] = _identity_connectivity(new_domain, dim, cls=type(connectivity)) # Take data - take_indices = tuple(conn_map[dim].ndarray for dim in data.domain.dims) + take_indices = tuple( + conn_map[dim].ndarray - data.domain[dim].unit_range.start # shift to 0-based indexing + for dim in data.domain.dims + ) new_buffer = data._ndarray.__getitem__(take_indices) return data.__class__.from_array( @@ -667,7 +670,7 @@ def _reshuffling_premap( ) -def _remapping_premap(data: NdArrayField, connectivity: common.ConnectivityField) -> NdArrayField: +def _remapping_premap(data: NdArrayField, connectivity: common.Connectivity) -> NdArrayField: new_dims = {*connectivity.domain.dims} - {connectivity.codomain} if repeated_dims := (new_dims & {*data.domain.dims}): raise ValueError(f"Remapped field will contain repeated dimensions '{repeated_dims}'.") @@ -692,7 +695,7 @@ def _remapping_premap(data: NdArrayField, connectivity: common.ConnectivityField if restricted_connectivity_domain != connectivity.domain else connectivity ) - assert isinstance(restricted_connectivity, 
common.ConnectivityField) + assert isinstance(restricted_connectivity, common.Connectivity) # 2- then compute the index array new_idx_array = xp.asarray(restricted_connectivity.ndarray) - current_range.start @@ -970,7 +973,7 @@ def _concat_where( return cls_.from_array(result_array, domain=result_domain) -NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[has-type] +NdArrayField.register_builtin_func(experimental.concat_where, _concat_where) # type: ignore[arg-type] def _make_reduction( @@ -995,15 +998,15 @@ def _builtin_op( offset_definition = current_offset_provider[ axis.value ] # assumes offset and local dimension have same name - assert isinstance(offset_definition, itir_embedded.NeighborTableOffsetProvider) + assert common.is_neighbor_table(offset_definition) new_domain = common.Domain(*[nr for nr in field.domain if nr.dim != axis]) broadcast_slice = tuple( - slice(None) if d in [axis, offset_definition.origin_axis] else xp.newaxis + slice(None) if d in [axis, offset_definition.domain.dims[0]] else xp.newaxis for d in field.domain.dims ) masked_array = xp.where( - xp.asarray(offset_definition.table[broadcast_slice]) != common._DEFAULT_SKIP_VALUE, + xp.asarray(offset_definition.ndarray[broadcast_slice]) != common._DEFAULT_SKIP_VALUE, field.ndarray, initial_value_op(field), ) diff --git a/src/gt4py/next/errors/__init__.py b/src/gt4py/next/errors/__init__.py index 89f78a45e4..9febe098a4 100644 --- a/src/gt4py/next/errors/__init__.py +++ b/src/gt4py/next/errors/__init__.py @@ -23,9 +23,9 @@ __all__ = [ "DSLError", "InvalidParameterAnnotationError", + "MissingArgumentError", "MissingAttributeError", "MissingParameterAnnotationError", - "MissingArgumentError", "UndefinedSymbolError", "UnsupportedPythonFeatureError", ] diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 52fe8d8116..ecaf1a76b4 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ 
-30,10 +30,10 @@ embedded as next_embedded, errors, ) -from gt4py.next.common import Connectivity, Dimension, GridType from gt4py.next.embedded import operators as embedded_operators from gt4py.next.ffront import ( field_operator_ast as foast, + foast_to_gtir, past_process_args, signature, stages as ffront_stages, @@ -80,15 +80,17 @@ class Program: definition_stage: ffront_stages.ProgramDefinition backend: Optional[next_backend.Backend] - connectivities: Optional[dict[str, Connectivity]] + connectivities: Optional[common.OffsetProvider] = ( + None # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information + ) @classmethod def from_function( cls, definition: types.FunctionType, backend: Optional[next_backend], - grid_type: Optional[GridType] = None, - connectivities: Optional[dict[str, Connectivity]] = None, + grid_type: Optional[common.GridType] = None, + connectivities: Optional[common.OffsetProviderType] = None, ) -> Program: program_def = ffront_stages.ProgramDefinition(definition=definition, grid_type=grid_type) return cls(definition_stage=program_def, backend=backend, connectivities=connectivities) @@ -138,10 +140,10 @@ def _frontend_transforms(self) -> next_backend.Transforms: def with_backend(self, backend: next_backend.Backend) -> Program: return dataclasses.replace(self, backend=backend) - def with_connectivities(self, connectivities: dict[str, Connectivity]) -> Program: + def with_connectivities(self, connectivities: common.OffsetProviderType) -> Program: return dataclasses.replace(self, connectivities=connectivities) - def with_grid_type(self, grid_type: GridType) -> Program: + def with_grid_type(self, grid_type: common.GridType) -> Program: return dataclasses.replace( self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type) ) @@ -185,7 +187,7 @@ def _all_closure_vars(self) -> dict[str, Any]: return 
transform_utils._get_closure_vars_recursively(self.past_stage.closure_vars) @functools.cached_property - def itir(self) -> itir.FencilDefinition: + def gtir(self) -> itir.Program: no_args_past = toolchain.CompilableProgram( data=ffront_stages.PastProgramDefinition( past_node=self.past_stage.past_node, @@ -197,7 +199,7 @@ def itir(self) -> itir.FencilDefinition: return self._frontend_transforms.past_to_itir(no_args_past).data @functools.cached_property - def _implicit_offset_provider(self) -> dict[common.Tag, common.OffsetProviderElem]: + def _implicit_offset_provider(self) -> dict[str, common.Dimension]: """ Add all implicit offset providers. @@ -224,14 +226,12 @@ def _implicit_offset_provider(self) -> dict[common.Tag, common.OffsetProviderEle ) return implicit_offset_provider - def __call__( - self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any - ) -> None: + def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any) -> None: offset_provider = offset_provider | self._implicit_offset_provider if self.backend is None: warnings.warn( UserWarning( - f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a perfomance backend." + f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a performance backend." 
), stacklevel=2, ) @@ -285,19 +285,17 @@ def definition(self) -> str: def with_backend(self, backend: next_backend.Backend) -> FrozenProgram: return self.__class__(program=self.program, backend=backend) - def with_grid_type(self, grid_type: GridType) -> FrozenProgram: + def with_grid_type(self, grid_type: common.GridType) -> FrozenProgram: return self.__class__( program=dataclasses.replace(self.program, grid_type=grid_type), backend=self.backend ) def jit( - self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any + self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any ) -> stages.CompiledProgram: return self.backend.jit(self.program, *args, offset_provider=offset_provider, **kwargs) - def __call__( - self, *args: Any, offset_provider: dict[str, Dimension | Connectivity], **kwargs: Any - ) -> None: + def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any) -> None: args, kwargs = signature.convert_to_positional(self.program, *args, **kwargs) if not self._compiled_program: @@ -308,7 +306,7 @@ def __call__( try: - from gt4py.next.program_processors.runners.dace_iterator import Program + from gt4py.next.program_processors.runners.dace.program import Program except ImportError: pass @@ -326,7 +324,7 @@ class ProgramFromPast(Program): past_stage: ffront_stages.PastProgramDefinition - def __call__(self, *args: Any, offset_provider: dict[str, Dimension], **kwargs: Any) -> None: + def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: Any) -> None: if self.backend is None: raise NotImplementedError( "Programs created from a PAST node (without a function definition) can not be executed in embedded mode" @@ -348,7 +346,7 @@ def __post_init__(self): class ProgramWithBoundArgs(Program): bound_args: dict[str, typing.Union[float, int, bool]] = None - def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs): + def __call__(self, *args, offset_provider: 
common.OffsetProvider, **kwargs): type_ = self.past_stage.past_node.type new_type = ts_ffront.ProgramType( definition=ts.FunctionType( @@ -434,7 +432,7 @@ def program( *, # `NOTHING` -> default backend, `None` -> no backend (embedded execution) backend: next_backend.Backend | eve.NOTHING = eve.NOTHING, - grid_type: Optional[GridType] = None, + grid_type: Optional[common.GridType] = None, frozen: bool = False, ) -> Program | FrozenProgram | Callable[[types.FunctionType], Program | FrozenProgram]: """ @@ -504,7 +502,7 @@ def from_function( cls, definition: types.FunctionType, backend: Optional[next_backend.Backend], - grid_type: Optional[GridType] = None, + grid_type: Optional[common.GridType] = None, *, operator_node_cls: type[OperatorNodeT] = foast.FieldOperator, operator_attributes: Optional[dict[str, Any]] = None, @@ -555,15 +553,20 @@ def __gt_type__(self) -> ts.CallableType: def with_backend(self, backend: next_backend.Backend) -> FieldOperator: return dataclasses.replace(self, backend=backend) - def with_grid_type(self, grid_type: GridType) -> FieldOperator: + def with_grid_type(self, grid_type: common.GridType) -> FieldOperator: return dataclasses.replace( self, definition_stage=dataclasses.replace(self.definition_stage, grid_type=grid_type) ) + # TODO(tehrengruber): We can not use transforms from `self.backend` since this can be + # a different backend than the one of the program that calls this field operator. Just use + # the hard-coded lowering until this is cleaned up. 
def __gt_itir__(self) -> itir.FunctionDefinition: - return self._frontend_transforms.foast_to_itir( - toolchain.CompilableProgram(self.foast_stage, arguments.CompileTimeArgs.empty()) - ) + return foast_to_gtir.foast_to_gtir(self.foast_stage) + + # FIXME[#1582](tehrengruber): remove after refactoring to GTIR + def __gt_gtir__(self) -> itir.FunctionDefinition: + return foast_to_gtir.foast_to_gtir(self.foast_stage) def __gt_closure_vars__(self) -> dict[str, Any]: return self.foast_stage.closure_vars @@ -591,6 +594,10 @@ def __call__(self, *args, **kwargs) -> None: if "out" not in kwargs: raise errors.MissingArgumentError(None, "out", True) out = kwargs.pop("out") + if "domain" in kwargs: + domain = common.domain(kwargs.pop("domain")) + out = out[domain] + args, kwargs = type_info.canonicalize_arguments( self.foast_stage.foast_node.type, args, kwargs ) @@ -681,33 +688,33 @@ def field_operator_inner(definition: types.FunctionType) -> FieldOperator[foast. def scan_operator( definition: types.FunctionType, *, - axis: Dimension, + axis: common.Dimension, forward: bool, init: core_defs.Scalar, backend: Optional[str], - grid_type: GridType, + grid_type: common.GridType, ) -> FieldOperator[foast.ScanOperator]: ... @typing.overload def scan_operator( *, - axis: Dimension, + axis: common.Dimension, forward: bool, init: core_defs.Scalar, backend: Optional[str], - grid_type: GridType, + grid_type: common.GridType, ) -> Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]: ... 
def scan_operator( definition: Optional[types.FunctionType] = None, *, - axis: Dimension, + axis: common.Dimension, forward: bool = True, init: core_defs.Scalar = 0.0, backend=eve.NOTHING, - grid_type: GridType = None, + grid_type: common.GridType = None, ) -> ( FieldOperator[foast.ScanOperator] | Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]] diff --git a/src/gt4py/next/ffront/dialect_parser.py b/src/gt4py/next/ffront/dialect_parser.py index 23b719abb7..79d188cdf2 100644 --- a/src/gt4py/next/ffront/dialect_parser.py +++ b/src/gt4py/next/ffront/dialect_parser.py @@ -106,11 +106,11 @@ def get_location(self, node: ast.AST) -> SourceLocation: # `FixMissingLocations` ensures that all nodes have the location attributes assert hasattr(node, "lineno") - line = node.lineno + line_offset if node.lineno is not None else None + line = node.lineno + line_offset assert hasattr(node, "end_lineno") end_line = node.end_lineno + line_offset if node.end_lineno is not None else None assert hasattr(node, "col_offset") - column = 1 + node.col_offset + col_offset if node.col_offset is not None else None + column = 1 + node.col_offset + col_offset assert hasattr(node, "end_col_offset") end_column = ( 1 + node.end_col_offset + col_offset if node.end_col_offset is not None else None diff --git a/src/gt4py/next/ffront/experimental.py b/src/gt4py/next/ffront/experimental.py index 8a94c20832..bd22aebe57 100644 --- a/src/gt4py/next/ffront/experimental.py +++ b/src/gt4py/next/ffront/experimental.py @@ -14,7 +14,7 @@ @BuiltInFunction -def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.ConnectivityField: +def as_offset(offset_: FieldOffset, field: common.Field, /) -> common.Connectivity: raise NotImplementedError() diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 3b711212a3..ee14006b22 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -10,13 +10,13 @@ import functools import 
inspect import math -from builtins import bool, float, int, tuple +import operator +from builtins import bool, float, int, tuple # noqa: A004 shadowing a Python built-in from typing import Any, Callable, Final, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast import numpy as np -from numpy import float32, float64, int32, int64 +from numpy import float32, float64, int8, int16, int32, int64, uint8, uint16, uint32, uint64 -import gt4py.next as gtx from gt4py._core import definitions as core_defs from gt4py.next import common from gt4py.next.common import Dimension, Field # noqa: F401 [unused-import] for TYPE_BUILTINS @@ -30,12 +30,19 @@ TYPE_BUILTINS = [ common.Field, common.Dimension, + int8, + uint8, + int16, + uint16, int32, + uint32, int64, + uint64, float32, float64, *PYTHON_TYPE_BUILTINS, -] +] # TODO(tehrengruber): validate matches iterator.builtins.TYPE_BUILTINS? + TYPE_BUILTIN_NAMES = [t.__name__ for t in TYPE_BUILTINS] # Be aware: Type aliases are not fully supported in the frontend yet, e.g. 
`IndexType(1)` will not @@ -55,7 +62,7 @@ def _type_conversion_helper(t: type) -> type[ts.TypeSpec] | tuple[type[ts.TypeSp return ts.DimensionType elif t is FieldOffset: return ts.OffsetType - elif t is common.ConnectivityField: + elif t is common.Connectivity: return ts.OffsetType elif t is core_defs.ScalarT: return ts.ScalarType @@ -197,7 +204,7 @@ def astype( return core_defs.dtype(type_).scalar_type(value) -_UNARY_MATH_NUMBER_BUILTIN_IMPL: Final = {"abs": abs} +_UNARY_MATH_NUMBER_BUILTIN_IMPL: Final = {"abs": abs, "neg": operator.neg} UNARY_MATH_NUMBER_BUILTIN_NAMES: Final = [*_UNARY_MATH_NUMBER_BUILTIN_IMPL.keys()] _UNARY_MATH_FP_BUILTIN_IMPL: Final = { @@ -245,7 +252,7 @@ def impl(value: common.Field | core_defs.ScalarT, /) -> common.Field | core_defs value ) # default implementation for scalars, Fields are handled via dispatch - return _math_builtin(value) + return cast(common.Field | core_defs.ScalarT, _math_builtin(value)) # type: ignore[operator] # calling a function of unknown type impl.__name__ = name globals()[name] = BuiltInFunction(impl) @@ -321,7 +328,7 @@ def __post_init__(self) -> None: def __gt_type__(self) -> ts.OffsetType: return ts.OffsetType(source=self.source, target=self.target) - def __getitem__(self, offset: int) -> common.ConnectivityField: + def __getitem__(self, offset: int) -> common.Connectivity: """Serve as a connectivity factory.""" from gt4py.next import embedded # avoid circular import @@ -330,22 +337,19 @@ def __getitem__(self, offset: int) -> common.ConnectivityField: assert current_offset_provider is not None offset_definition = current_offset_provider[self.value] - connectivity: common.ConnectivityField + connectivity: common.Connectivity if isinstance(offset_definition, common.Dimension): connectivity = common.CartesianConnectivity(offset_definition, offset) - elif isinstance( - offset_definition, (gtx.NeighborTableOffsetProvider, common.ConnectivityField) - ): - unrestricted_connectivity = self.as_connectivity_field() - 
assert unrestricted_connectivity.domain.ndim > 1 + elif isinstance(offset_definition, common.Connectivity): + assert common.is_neighbor_connectivity(offset_definition) named_index = common.NamedIndex(self.target[-1], offset) - connectivity = unrestricted_connectivity[named_index] + connectivity = offset_definition[named_index] else: raise NotImplementedError() return connectivity - def as_connectivity_field(self) -> common.ConnectivityField: + def as_connectivity_field(self) -> common.Connectivity: """Convert to connectivity field using the offset providers in current embedded execution context.""" from gt4py.next import embedded # avoid circular import @@ -356,18 +360,8 @@ def as_connectivity_field(self) -> common.ConnectivityField: cache_key = id(offset_definition) if (connectivity := self._cache.get(cache_key, None)) is None: - if isinstance(offset_definition, common.ConnectivityField): + if isinstance(offset_definition, common.Connectivity): connectivity = offset_definition - elif isinstance(offset_definition, gtx.NeighborTableOffsetProvider): - connectivity = gtx.as_connectivity( - domain=self.target, - codomain=self.source, - data=offset_definition.table, - dtype=offset_definition.index_type, - skip_value=( - common._DEFAULT_SKIP_VALUE if offset_definition.has_skip_values else None - ), - ) else: raise NotImplementedError() diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py index 4693fed1a0..4f547aae14 100644 --- a/src/gt4py/next/ffront/field_operator_ast.py +++ b/src/gt4py/next/ffront/field_operator_ast.py @@ -180,7 +180,7 @@ class IfStmt(Stmt): @datamodels.root_validator @classmethod def _collect_common_symbols(cls: type[IfStmt], instance: IfStmt) -> None: - common_symbol_names = ( + common_symbol_names = sorted( # sort is required to get stable results across runs instance.true_branch.annex.symtable.keys() & instance.false_branch.annex.symtable.keys() ) instance.annex.propagated_symbols = { diff --git 
a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index d334487ae1..26bcadaef1 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Optional, TypeVar, cast +from typing import Any, Optional, TypeAlias, TypeVar, cast import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits @@ -48,7 +48,7 @@ def with_altered_scalar_kind( if isinstance(type_spec, ts.FieldType): return ts.FieldType( dims=type_spec.dims, - dtype=ts.ScalarType(kind=new_scalar_kind, shape=type_spec.dtype.shape), + dtype=with_altered_scalar_kind(type_spec.dtype, new_scalar_kind), ) elif isinstance(type_spec, ts.ScalarType): return ts.ScalarType(kind=new_scalar_kind, shape=type_spec.shape) @@ -68,13 +68,18 @@ def construct_tuple_type( >>> mask_type = ts.FieldType( ... dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL) ... ) - >>> true_branch_types = [ts.ScalarType(kind=ts.ScalarKind), ts.ScalarType(kind=ts.ScalarKind)] + >>> true_branch_types = [ + ... ts.ScalarType(kind=ts.ScalarKind.FLOAT64), + ... ts.ScalarType(kind=ts.ScalarKind.FLOAT64), + ... ] >>> false_branch_types = [ - ... ts.FieldType(dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind)), - ... ts.ScalarType(kind=ts.ScalarKind), + ... ts.FieldType( + ... dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ... ), + ... ts.ScalarType(kind=ts.ScalarKind.FLOAT64), ... 
] >>> print(construct_tuple_type(true_branch_types, false_branch_types, mask_type)) - [FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)), FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None))] + [FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)), FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None))] """ element_types_new = true_branch_types for i, element in enumerate(true_branch_types): @@ -105,16 +110,16 @@ def promote_to_mask_type( >>> I, J = (Dimension(value=dim) for dim in ["I", "J"]) >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) >>> dtype = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - >>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), ts.ScalarType(kind=dtype)) - FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=ScalarType(kind=, shape=None), shape=None)) + >>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), dtype) + FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) >>> promote_to_mask_type( ... ts.FieldType(dims=[I, J], dtype=bool_type), ts.FieldType(dims=[I], dtype=dtype) ... ) - FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) + FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) >>> promote_to_mask_type( ... ts.FieldType(dims=[I], dtype=bool_type), ts.FieldType(dims=[I, J], dtype=dtype) ... 
) - FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) + FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None)) """ if isinstance(input_type, ts.ScalarType) or not all( item in input_type.dims for item in mask_type.dims @@ -360,7 +365,7 @@ def visit_Assign(self, node: foast.Assign, **kwargs: Any) -> foast.Assign: def visit_TupleTargetAssign( self, node: foast.TupleTargetAssign, **kwargs: Any ) -> foast.TupleTargetAssign: - TargetType = list[foast.Starred | foast.Symbol] + TargetType: TypeAlias = list[foast.Starred | foast.Symbol] values = self.visit(node.value, **kwargs) if isinstance(values.type, ts.TupleType): @@ -374,7 +379,7 @@ def visit_TupleTargetAssign( ) new_targets: TargetType = [] - new_type: ts.TupleType | ts.DataType + new_type: ts.DataType for i, index in enumerate(indices): old_target = targets[i] @@ -391,7 +396,8 @@ def visit_TupleTargetAssign( location=old_target.location, ) else: - new_type = values.type.types[index] + new_type = values.type.types[index] # type: ignore[assignment] # see check in next line + assert isinstance(new_type, ts.DataType) new_target = self.visit( old_target, refine_type=new_type, location=old_target.location, **kwargs ) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 948a8481d7..f884ec555d 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -53,8 +53,8 @@ def adapted_foast_to_gtir_factory(**kwargs: Any) -> workflow.Workflow[AOT_FOP, i return toolchain.StripArgsAdapter(foast_to_gtir_factory(**kwargs)) -def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]: - if not type_info.contains_local_field(node.type): +def promote_to_list(node_type: ts.TypeSpec) -> Callable[[itir.Expr], itir.Expr]: + if not type_info.contains_local_field(node_type): return lambda x: 
im.op_as_fieldop("make_const_list")(x) return lambda x: x @@ -116,7 +116,31 @@ def visit_FieldOperator( def visit_ScanOperator( self, node: foast.ScanOperator, **kwargs: Any ) -> itir.FunctionDefinition: - raise NotImplementedError("TODO") + # note: we don't need the axis here as this is handled by the program + # decorator + assert isinstance(node.type, ts_ffront.ScanOperatorType) + + # We are lowering node.forward and node.init to iterators, but here we expect values -> `deref`. + # In iterator IR we didn't properly specify if this is legal, + # however after lift-inlining the expressions are transformed back to literals. + forward = self.visit(node.forward, **kwargs) + init = self.visit(node.init, **kwargs) + + # lower definition function + func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) + new_body = func_definition.expr + + stencil_args: list[itir.Expr] = [] + assert not node.type.definition.pos_only_args and not node.type.definition.kw_only_args + for param in func_definition.params[1:]: + new_body = im.let(param.id, im.deref(param.id))(new_body) + stencil_args.append(im.ref(param.id)) + + definition = itir.Lambda(params=func_definition.params, expr=new_body) + + body = im.as_fieldop(im.scan(definition, forward, init))(*stencil_args) + + return itir.FunctionDefinition(id=node.id, params=definition.params[1:], expr=body) def visit_Stmt(self, node: foast.Stmt, **kwargs: Any) -> Never: raise AssertionError("Statements must always be visited in the context of a function.") @@ -212,19 +236,20 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators dtype = type_info.extract_dtype(node.type) + assert isinstance(dtype, ts.ScalarType) if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: if dtype.kind != ts.ScalarKind.BOOL: raise 
NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") - return self._map("not_", node.operand) - - return self._map( - node.op.value, - foast.Constant(value="0", type=dtype, location=node.location), - node.operand, - ) + return self._lower_and_map("not_", node.operand) + if node.op in [dialect_ast_enums.UnaryOperator.USUB]: + return self._lower_and_map("neg", node.operand) + if node.op in [dialect_ast_enums.UnaryOperator.UADD]: + return self.visit(node.operand) + else: + raise NotImplementedError(f"Unary operator '{node.op}' is not supported.") def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: - return self._map(node.op.value, node.left, node.right) + return self._lower_and_map(node.op.value, node.left, node.right) def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall: assert ( @@ -236,7 +261,7 @@ def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunC ) def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: - return self._map(node.op.value, node.left, node.right) + return self._lower_and_map(node.op.value, node.left, node.right) def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: current_expr = self.visit(node.func, **kwargs) @@ -324,10 +349,6 @@ def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: *lowered_args, *lowered_kwargs.values() ) - # scan operators return an iterator of tuples, transform into tuples of iterator again - if isinstance(node.func.type, ts_ffront.ScanOperatorType): - raise NotImplementedError("TODO") - return result raise AssertionError( @@ -338,34 +359,37 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) obj, new_type = self.visit(node.args[0], **kwargs), node.args[1].id - def create_cast(expr: itir.Expr, t: ts.TypeSpec) -> itir.FunCall: - if isinstance(t, ts.FieldType): - return im.as_fieldop( - 
im.lambda_("__val")(im.call("cast_")(im.deref("__val"), str(new_type))) - )(expr) - else: - assert isinstance(t, ts.ScalarType) - return im.call("cast_")(expr, str(new_type)) + def create_cast(expr: itir.Expr, t: tuple[ts.TypeSpec]) -> itir.FunCall: + return _map(im.lambda_("val")(im.cast_("val", str(new_type))), (expr,), t) if not isinstance(node.type, ts.TupleType): # to keep the IR simpler - return create_cast(obj, node.type) + return create_cast(obj, (node.args[0].type,)) - return lowering_utils.process_elements(create_cast, obj, node.type, with_type=True) + return lowering_utils.process_elements( + create_cast, obj, node.type, arg_types=(node.args[0].type,) + ) def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: if not isinstance(node.type, ts.TupleType): # to keep the IR simpler - return im.op_as_fieldop("if_")(*self.visit(node.args)) + return self._lower_and_map("if_", *node.args) cond_ = self.visit(node.args[0]) cond_symref_name = f"__cond_{eve_utils.content_hash(cond_)}" - def create_if(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall: - return im.op_as_fieldop("if_")(im.ref(cond_symref_name), true_, false_) + def create_if( + true_: itir.Expr, false_: itir.Expr, arg_types: tuple[ts.TypeSpec, ts.TypeSpec] + ) -> itir.FunCall: + return _map( + "if_", + (im.ref(cond_symref_name), true_, false_), + (node.args[0].type, *arg_types), + ) result = lowering_utils.process_elements( create_if, (self.visit(node.args[1]), self.visit(node.args[2])), node.type, + arg_types=(node.args[1].type, node.args[2].type), ) return im.let(cond_symref_name, cond_)(result) @@ -373,10 +397,11 @@ def create_if(true_: itir.Expr, false_: itir.Expr) -> itir.FunCall: _visit_concat_where = _visit_where # TODO(havogt): upgrade concat_where def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self.visit(node.args[0], **kwargs) + expr = self.visit(node.args[0], **kwargs) + return im.as_fieldop(im.ref("deref"))(expr) def 
_visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self._map(self.visit(node.func, **kwargs), *node.args) + return self._lower_and_map(self.visit(node.func, **kwargs), *node.args) def _make_reduction_expr( self, node: foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: Any @@ -384,7 +409,7 @@ def _make_reduction_expr( # TODO(havogt): deal with nested reductions of the form neighbor_sum(neighbor_sum(field(off1)(off2))) it = self.visit(node.args[0], **kwargs) assert isinstance(node.kwargs["axis"].type, ts.DimensionType) - val = im.call(im.call("reduce")(op, init_expr)) + val = im.reduce(op, init_expr) return im.op_as_fieldop(val)(it) def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr: @@ -393,12 +418,14 @@ def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr: def _visit_max_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: dtype = type_info.extract_dtype(node.type) + assert isinstance(dtype, ts.ScalarType) min_value, _ = type_info.arithmetic_bounds(dtype) init_expr = self._make_literal(str(min_value), dtype) return self._make_reduction_expr(node, "maximum", init_expr, **kwargs) def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: dtype = type_info.extract_dtype(node.type) + assert isinstance(dtype, ts.ScalarType) _, max_value = type_info.arithmetic_bounds(dtype) init_expr = self._make_literal(str(max_value), dtype) return self._make_reduction_expr(node, "minimum", init_expr, **kwargs) @@ -435,19 +462,34 @@ def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr: return self._make_literal(node.value, node.type) - def _map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: - lowered_args = [self.visit(arg, **kwargs) for arg in args] - if all( - isinstance(t, ts.ScalarType) - for arg in args - for t in 
type_info.primitive_constituents(arg.type) - ): - return im.call(op)(*lowered_args) # scalar operation - if any(type_info.contains_local_field(arg.type) for arg in args): - lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] - op = im.call("map_")(op) + def _lower_and_map(self, op: itir.Lambda | str, *args: Any, **kwargs: Any) -> itir.FunCall: + return _map( + op, tuple(self.visit(arg, **kwargs) for arg in args), tuple(arg.type for arg in args) + ) + + +def _map( + op: itir.Lambda | str, + lowered_args: tuple, + original_arg_types: tuple[ts.TypeSpec, ...], +) -> itir.FunCall: + """ + Mapping includes making the operation an `as_fieldop` (first kind of mapping), but also `itir.map_`ing lists. + """ + if all( + isinstance(t, ts.ScalarType) + for arg_type in original_arg_types + for t in type_info.primitive_constituents(arg_type) + ): + return im.call(op)(*lowered_args) # scalar operation + if any(type_info.contains_local_field(arg_type) for arg_type in original_arg_types): + lowered_args = tuple( + promote_to_list(arg_type)(larg) + for arg_type, larg in zip(original_arg_types, lowered_args) + ) + op = im.map_(op) - return im.op_as_fieldop(im.call(op))(*lowered_args) + return im.op_as_fieldop(op)(*lowered_args) class FieldOperatorLoweringError(Exception): ... diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py deleted file mode 100644 index 7936eda1cf..0000000000 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ /dev/null @@ -1,492 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -# FIXME[#1582](havogt): remove after refactoring to GTIR - -import dataclasses -from typing import Any, Callable, Optional - -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.eve.extended_typing import Never -from gt4py.eve.utils import UIDGenerator -from gt4py.next import common -from gt4py.next.ffront import ( - dialect_ast_enums, - fbuiltins, - field_operator_ast as foast, - lowering_utils, - stages as ffront_stages, - type_specifications as ts_ffront, -) -from gt4py.next.ffront.experimental import EXPERIMENTAL_FUN_BUILTIN_NAMES -from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, MATH_BUILTIN_NAMES, TYPE_BUILTIN_NAMES -from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind -from gt4py.next.ffront.stages import AOT_FOP, FOP -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_info, type_specifications as ts - - -def foast_to_itir(inp: FOP) -> itir.Expr: - """ - Lower a FOAST field operator node to Iterator IR. - - See the docstring of `FieldOperatorLowering` for details. 
- """ - return FieldOperatorLowering.apply(inp.foast_node) - - -def foast_to_itir_factory(cached: bool = True) -> workflow.Workflow[FOP, itir.Expr]: - """Wrap `foast_to_itir` into a chainable and, optionally, cached workflow step.""" - wf = foast_to_itir - if cached: - wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage) - return wf - - -def adapted_foast_to_itir_factory(**kwargs: Any) -> workflow.Workflow[AOT_FOP, itir.Expr]: - """Wrap the `foast_to_itir` workflow step into an adapter to fit into backend transform workflows.""" - return toolchain.StripArgsAdapter(foast_to_itir_factory(**kwargs)) - - -def promote_to_list(node: foast.Symbol | foast.Expr) -> Callable[[itir.Expr], itir.Expr]: - if not type_info.contains_local_field(node.type): - return lambda x: im.promote_to_lifted_stencil("make_const_list")(x) - return lambda x: x - - -@dataclasses.dataclass -class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): - """ - Lower FieldOperator AST (FOAST) to Iterator IR (ITIR). - - The strategy is to lower every expression to lifted stencils, - i.e. taking iterators and returning iterator. - - Examples - -------- - >>> from gt4py.next.ffront.func_to_foast import FieldOperatorParser - >>> from gt4py.next import Field, Dimension, float64 - >>> - >>> IDim = Dimension("IDim") - >>> def fieldop(inp: Field[[IDim], "float64"]): - ... 
return inp - >>> - >>> parsed = FieldOperatorParser.apply_to_function(fieldop) - >>> lowered = FieldOperatorLowering.apply(parsed) - >>> type(lowered) - - >>> lowered.id - SymbolName('fieldop') - >>> lowered.params # doctest: +ELLIPSIS - [Sym(id=SymbolName('inp'))] - """ - - uid_generator: UIDGenerator = dataclasses.field(default_factory=UIDGenerator) - - @classmethod - def apply(cls, node: foast.LocatedNode) -> itir.Expr: - return cls().visit(node) - - def visit_FunctionDefinition( - self, node: foast.FunctionDefinition, **kwargs: Any - ) -> itir.FunctionDefinition: - params = self.visit(node.params) - return itir.FunctionDefinition( - id=node.id, params=params, expr=self.visit_BlockStmt(node.body, inner_expr=None) - ) # `expr` is a lifted stencil - - def visit_FieldOperator( - self, node: foast.FieldOperator, **kwargs: Any - ) -> itir.FunctionDefinition: - func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) - - new_body = func_definition.expr - - return itir.FunctionDefinition( - id=func_definition.id, params=func_definition.params, expr=new_body - ) - - def visit_ScanOperator( - self, node: foast.ScanOperator, **kwargs: Any - ) -> itir.FunctionDefinition: - # note: we don't need the axis here as this is handled by the program - # decorator - assert isinstance(node.type, ts_ffront.ScanOperatorType) - - # We are lowering node.forward and node.init to iterators, but here we expect values -> `deref`. - # In iterator IR we didn't properly specify if this is legal, - # however after lift-inlining the expressions are transformed back to literals. 
- forward = im.deref(self.visit(node.forward, **kwargs)) - init = lowering_utils.process_elements( - im.deref, self.visit(node.init, **kwargs), node.init.type - ) - - # lower definition function - func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) - new_body = im.let( - func_definition.params[0].id, - # promote carry to iterator of tuples - # (this is the only place in the lowering were a variable is captured in a lifted lambda) - lowering_utils.to_tuples_of_iterator( - im.promote_to_const_iterator(func_definition.params[0].id), - [*node.type.definition.pos_or_kw_args.values()][0], # noqa: RUF015 [unnecessary-iterable-allocation-for-first-element] - ), - )( - # the function itself returns a tuple of iterators, deref element-wise - lowering_utils.process_elements( - im.deref, func_definition.expr, node.type.definition.returns - ) - ) - - stencil_args: list[itir.Expr] = [] - assert not node.type.definition.pos_only_args and not node.type.definition.kw_only_args - for param, arg_type in zip( - func_definition.params[1:], - [*node.type.definition.pos_or_kw_args.values()][1:], - strict=True, - ): - if isinstance(arg_type, ts.TupleType): - # convert into iterator of tuples - stencil_args.append(lowering_utils.to_iterator_of_tuples(param.id, arg_type)) - - new_body = im.let( - param.id, lowering_utils.to_tuples_of_iterator(param.id, arg_type) - )(new_body) - else: - stencil_args.append(im.ref(param.id)) - - definition = itir.Lambda(params=func_definition.params, expr=new_body) - - body = im.lift(im.call("scan")(definition, forward, init))(*stencil_args) - - return itir.FunctionDefinition(id=node.id, params=definition.params[1:], expr=body) - - def visit_Stmt(self, node: foast.Stmt, **kwargs: Any) -> Never: - raise AssertionError("Statements must always be visited in the context of a function.") - - def visit_Return( - self, node: foast.Return, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - return self.visit(node.value, 
**kwargs) - - def visit_BlockStmt( - self, node: foast.BlockStmt, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - for stmt in reversed(node.stmts): - inner_expr = self.visit(stmt, inner_expr=inner_expr, **kwargs) - assert inner_expr - return inner_expr - - def visit_IfStmt( - self, node: foast.IfStmt, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - # the lowered if call doesn't need to be lifted as the condition can only originate - # from a scalar value (and not a field) - assert ( - isinstance(node.condition.type, ts.ScalarType) - and node.condition.type.kind == ts.ScalarKind.BOOL - ) - - cond = self.visit(node.condition, **kwargs) - - return_kind: StmtReturnKind = deduce_stmt_return_kind(node) - - common_symbols: dict[str, foast.Symbol] = node.annex.propagated_symbols - - if return_kind is StmtReturnKind.NO_RETURN: - # pack the common symbols into a tuple - common_symrefs = im.make_tuple(*(im.ref(sym) for sym in common_symbols.keys())) - - # apply both branches and extract the common symbols through the prepared tuple - true_branch = self.visit(node.true_branch, inner_expr=common_symrefs, **kwargs) - false_branch = self.visit(node.false_branch, inner_expr=common_symrefs, **kwargs) - - # unpack the common symbols' tuple for `inner_expr` - for i, sym in enumerate(common_symbols.keys()): - inner_expr = im.let(sym, im.tuple_get(i, im.ref("__if_stmt_result")))(inner_expr) - - # here we assume neither branch returns - return im.let("__if_stmt_result", im.if_(im.deref(cond), true_branch, false_branch))( - inner_expr - ) - elif return_kind is StmtReturnKind.CONDITIONAL_RETURN: - common_syms = tuple(im.sym(sym) for sym in common_symbols.keys()) - common_symrefs = tuple(im.ref(sym) for sym in common_symbols.keys()) - - # wrap the inner expression in a lambda function. note that this increases the - # operation count if both branches are evaluated. 
- inner_expr_name = self.uid_generator.sequential_id(prefix="__inner_expr") - inner_expr_evaluator = im.lambda_(*common_syms)(inner_expr) - inner_expr = im.call(inner_expr_name)(*common_symrefs) - - true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) - false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - - return im.let(inner_expr_name, inner_expr_evaluator)( - im.if_(im.deref(cond), true_branch, false_branch) - ) - - assert return_kind is StmtReturnKind.UNCONDITIONAL_RETURN - - # note that we do not duplicate `inner_expr` here since if both branches - # return, `inner_expr` is ignored. - true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) - false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - - return im.if_(im.deref(cond), true_branch, false_branch) - - def visit_Assign( - self, node: foast.Assign, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - return im.let(self.visit(node.target, **kwargs), self.visit(node.value, **kwargs))( - inner_expr - ) - - def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym: - return im.sym(node.id) - - def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef: - return im.ref(node.id) - - def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: - return im.tuple_get(node.index, self.visit(node.value, **kwargs)) - - def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: - return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) - - def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: - # TODO(tehrengruber): extend iterator ir to support unary operators - dtype = type_info.extract_dtype(node.type) - if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: - if dtype.kind != ts.ScalarKind.BOOL: - raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") - 
return self._map("not_", node.operand) - - return self._map( - node.op.value, - foast.Constant(value="0", type=dtype, location=node.location), - node.operand, - ) - - def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: - return self._map(node.op.value, node.left, node.right) - - def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall: - op = "if_" - args = (node.condition, node.true_expr, node.false_expr) - lowered_args: list[itir.Expr] = [ - lowering_utils.to_iterator_of_tuples(self.visit(arg, **kwargs), arg.type) - for arg in args - ] - if any(type_info.contains_local_field(arg.type) for arg in args): - lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] - op = im.call("map_")(op) - - return lowering_utils.to_tuples_of_iterator( - im.promote_to_lifted_stencil(im.call(op))(*lowered_args), node.type - ) - - def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: - return self._map(node.op.value, node.left, node.right) - - def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - current_expr = self.visit(node.func, **kwargs) - - for arg in node.args: - match arg: - # `field(Off[idx])` - case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)): - current_expr = im.lift( - im.lambda_("it")(im.deref(im.shift(offset_name, offset_index)("it"))) - )(current_expr) - # `field(Dim + idx)` - case foast.BinOp( - op=dialect_ast_enums.BinaryOperator.ADD - | dialect_ast_enums.BinaryOperator.SUB, - left=foast.Name(id=dimension), - right=foast.Constant(value=offset_index), - ): - if arg.op == dialect_ast_enums.BinaryOperator.SUB: - offset_index *= -1 - current_expr = im.lift( - # TODO(SF-N): we rely on the naming-convention that the cartesian dimensions - # are passed suffixed with `off`, e.g. the `K` is passed as `Koff` in the - # offset provider. This is a rather unclean solution and should be - # improved. 
- im.lambda_("it")( - im.deref( - im.shift( - common.dimension_to_implicit_offset(dimension), offset_index - )("it") - ) - ) - )(current_expr) - # `field(Off)` - case foast.Name(id=offset_name): - # only a single unstructured shift is supported so returning here is fine even though we - # are in a loop. - assert len(node.args) == 1 and len(arg.type.target) > 1 # type: ignore[attr-defined] # ensured by pattern - return im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs)) - # `field(as_offset(Off, offset_field))` - case foast.Call(func=foast.Name(id="as_offset")): - func_args = arg - # TODO(tehrengruber): Use type system to deduce the offset dimension instead of - # (e.g. to allow aliasing) - offset_dim = func_args.args[0] - assert isinstance(offset_dim, foast.Name) - offset_it = self.visit(func_args.args[1], **kwargs) - current_expr = im.lift( - im.lambda_("it", "offset")( - im.deref(im.shift(offset_dim.id, im.deref("offset"))("it")) - ) - )(current_expr, offset_it) - case _: - raise FieldOperatorLoweringError("Unexpected shift arguments!") - - return current_expr - - def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - if type_info.type_class(node.func.type) is ts.FieldType: - return self._visit_shift(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES: - return self._visit_math_built_in(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in ( - FUN_BUILTIN_NAMES + EXPERIMENTAL_FUN_BUILTIN_NAMES - ): - visitor = getattr(self, f"_visit_{node.func.id}") - return visitor(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in TYPE_BUILTIN_NAMES: - return self._visit_type_constr(node, **kwargs) - elif isinstance( - node.func.type, - (ts.FunctionType, ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType), - ): - # ITIR has no support for keyword arguments. 
Instead, we concatenate both positional - # and keyword arguments and use the unique order as given in the function signature. - lowered_args, lowered_kwargs = type_info.canonicalize_arguments( - node.func.type, - self.visit(node.args, **kwargs), - self.visit(node.kwargs, **kwargs), - use_signature_ordering=True, - ) - result = im.call(self.visit(node.func, **kwargs))( - *lowered_args, *lowered_kwargs.values() - ) - - # scan operators return an iterator of tuples, transform into tuples of iterator again - if isinstance(node.func.type, ts_ffront.ScanOperatorType): - result = lowering_utils.to_tuples_of_iterator( - result, node.func.type.definition.returns - ) - - return result - - raise AssertionError( - f"Call to object of type '{type(node.func.type).__name__}' not understood." - ) - - def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) - obj, new_type = node.args[0], node.args[1].id - return lowering_utils.process_elements( - lambda x: im.promote_to_lifted_stencil( - im.lambda_("it")(im.call("cast_")("it", str(new_type))) - )(x), - self.visit(obj, **kwargs), - obj.type, - ) - - def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - condition, true_value, false_value = node.args - - lowered_condition = self.visit(condition, **kwargs) - return lowering_utils.process_elements( - lambda tv, fv: im.promote_to_lifted_stencil("if_")(lowered_condition, tv, fv), - [self.visit(true_value, **kwargs), self.visit(false_value, **kwargs)], - node.type, - ) - - _visit_concat_where = _visit_where - - def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self.visit(node.args[0], **kwargs) - - def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self._map(self.visit(node.func, **kwargs), *node.args) - - def _make_reduction_expr( - self, node: foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: 
Any - ) -> itir.Expr: - # TODO(havogt): deal with nested reductions of the form neighbor_sum(neighbor_sum(field(off1)(off2))) - it = self.visit(node.args[0], **kwargs) - assert isinstance(node.kwargs["axis"].type, ts.DimensionType) - val = im.call(im.call("reduce")(op, im.deref(init_expr))) - return im.promote_to_lifted_stencil(val)(it) - - def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - dtype = type_info.extract_dtype(node.type) - return self._make_reduction_expr(node, "plus", self._make_literal("0", dtype), **kwargs) - - def _visit_max_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - dtype = type_info.extract_dtype(node.type) - min_value, _ = type_info.arithmetic_bounds(dtype) - init_expr = self._make_literal(str(min_value), dtype) - return self._make_reduction_expr(node, "maximum", init_expr, **kwargs) - - def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - dtype = type_info.extract_dtype(node.type) - _, max_value = type_info.arithmetic_bounds(dtype) - init_expr = self._make_literal(str(max_value), dtype) - return self._make_reduction_expr(node, "minimum", init_expr, **kwargs) - - def _visit_type_constr(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - el = node.args[0] - node_kind = self.visit(node.type).kind.name.lower() - source_type = {**fbuiltins.BUILTINS, "string": str}[el.type.__str__().lower()] - target_type = fbuiltins.BUILTINS[node_kind] - - if isinstance(el, foast.Constant): - val = source_type(el.value) - elif isinstance(el, foast.UnaryOp) and isinstance(el.operand, foast.Constant): - operand = source_type(el.operand.value) - val = eval(f"lambda arg: {el.op}arg")(operand) - else: - raise FieldOperatorLoweringError( - f"Type cast only supports literal arguments, {node.type} not supported." 
- ) - val = target_type(val) - - return im.promote_to_const_iterator(im.literal(str(val), node_kind)) - - def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: - # TODO(havogt): lifted nullary lambdas are not supported in iterator.embedded due to an implementation detail; - # the following constructs work if they are removed by inlining. - if isinstance(type_, ts.TupleType): - return im.make_tuple( - *(self._make_literal(val, type_) for val, type_ in zip(val, type_.types)) - ) - elif isinstance(type_, ts.ScalarType): - typename = type_.kind.name.lower() - return im.promote_to_const_iterator(im.literal(str(val), typename)) - raise ValueError(f"Unsupported literal type '{type_}'.") - - def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr: - return self._make_literal(node.value, node.type) - - def _map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: - lowered_args = [self.visit(arg, **kwargs) for arg in args] - if any(type_info.contains_local_field(arg.type) for arg in args): - lowered_args = [promote_to_list(arg)(larg) for arg, larg in zip(args, lowered_args)] - op = im.call("map_")(op) - - return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) - - -class FieldOperatorLoweringError(Exception): ... 
diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 0844f63286..330bc79809 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -12,7 +12,7 @@ from gt4py.eve import utils as eve_utils from gt4py.next.ffront import ( dialect_ast_enums, - foast_to_itir, + foast_to_gtir, program_ast as past, stages as ffront_stages, type_specifications as ts_ffront, @@ -45,6 +45,11 @@ def __gt_type__(self) -> ts.CallableType: def __gt_itir__(self) -> itir.Expr: return self.foast_to_itir(self.definition) + # FIXME[#1582](tehrengruber): remove after refactoring to GTIR + def __gt_gtir__(self) -> itir.Expr: + # backend should have self.foast_to_itir set to foast_to_gtir + return self.foast_to_itir(self.definition) + @dataclasses.dataclass(frozen=True) class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]): @@ -63,7 +68,7 @@ class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]): ... def copy(a: gtx.Field[[IDim], gtx.float32]) -> gtx.Field[[IDim], gtx.float32]: ... return a - >>> op_to_prog = OperatorToProgram(foast_to_itir.adapted_foast_to_itir_factory()) + >>> op_to_prog = OperatorToProgram(foast_to_gtir.adapted_foast_to_gtir_factory()) >>> compile_time_args = arguments.CompileTimeArgs( ... 
args=tuple(param.type for param in copy.foast_stage.foast_node.definition.params), @@ -164,7 +169,7 @@ def operator_to_program_factory( ) -> workflow.Workflow[AOT_FOP, AOT_PRG]: """Optionally wrap `OperatorToProgram` in a `CachedStep`.""" wf: workflow.Workflow[AOT_FOP, AOT_PRG] = OperatorToProgram( - foast_to_itir_step or foast_to_itir.adapted_foast_to_itir_factory() + foast_to_itir_step or foast_to_gtir.adapted_foast_to_gtir_factory() ) if cached: wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index ebe12d3a8b..ef20b99d91 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -60,8 +60,9 @@ def func_to_foast(inp: DSL_FOP) -> FOP: >>> print(foast_definition.foast_node.id) dsl_operator - >>> print(foast_definition.closure_vars) - {'const': 2.0} + >>> foast_closure_vars = {k: str(v) for k, v in foast_definition.closure_vars.items()} + >>> print(foast_closure_vars) + {'const': '2.0'} """ source_def = source_utils.SourceDefinition.from_function(inp.definition) closure_vars = source_utils.get_closure_vars_from_function(inp.definition) diff --git a/src/gt4py/next/ffront/func_to_past.py b/src/gt4py/next/ffront/func_to_past.py index f415c95b63..09f53be600 100644 --- a/src/gt4py/next/ffront/func_to_past.py +++ b/src/gt4py/next/ffront/func_to_past.py @@ -64,7 +64,7 @@ def func_to_past(inp: DSL_PRG) -> PRG: ) -def func_to_past_factory(cached: bool = False) -> workflow.Workflow[DSL_PRG, PRG]: +def func_to_past_factory(cached: bool = True) -> workflow.Workflow[DSL_PRG, PRG]: """ Wrap `func_to_past` in a chainable and optionally cached workflow step. 
diff --git a/src/gt4py/next/ffront/gtcallable.py b/src/gt4py/next/ffront/gtcallable.py index beaebb3a5a..cdfb23910e 100644 --- a/src/gt4py/next/ffront/gtcallable.py +++ b/src/gt4py/next/ffront/gtcallable.py @@ -52,6 +52,16 @@ def __gt_itir__(self) -> itir.FunctionDefinition: """ ... + # FIXME[#1582](tehrengruber): remove after refactoring to GTIR + @abc.abstractmethod + def __gt_gtir__(self) -> itir.FunctionDefinition: + """ + Return iterator IR function definition representing the callable. + Used internally by the Program decorator to populate the function + definitions of the iterator IR. + """ + ... + # TODO(tehrengruber): For embedded execution a `__call__` method and for # "truly" embedded execution arguably also a `from_function` method is # required. Since field operators currently have a `__gt_type__` with a diff --git a/src/gt4py/next/ffront/lowering_utils.py b/src/gt4py/next/ffront/lowering_utils.py index a52581edb0..7049f70021 100644 --- a/src/gt4py/next/ffront/lowering_utils.py +++ b/src/gt4py/next/ffront/lowering_utils.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause from collections.abc import Iterable -from typing import Any, Callable, TypeVar +from typing import Any, Callable, Optional, TypeVar from gt4py.eve import utils as eve_utils from gt4py.next.ffront import type_info as ti_ffront @@ -102,7 +102,7 @@ def process_elements( process_func: Callable[..., itir.Expr], objs: itir.Expr | Iterable[itir.Expr], current_el_type: ts.TypeSpec, - with_type: bool = False, + arg_types: Optional[Iterable[ts.TypeSpec]] = None, ) -> itir.FunCall: """ Recursively applies a processing function to all primitive constituents of a tuple. @@ -113,9 +113,9 @@ def process_elements( objs: The object whose elements are to be transformed. current_el_type: A type with the same structure as the elements of `objs`. The leaf-types are not used and thus not relevant. - current_el_type: A type with the same structure as the elements of `objs`. 
Unless `with_type=True` - the leaf-types are not used and thus not relevant. - with_type: If True, the last argument passed to `process_func` will be its type. + arg_types: If provided, a tuple of the type of each argument is passed to `process_func` as last argument. + Note, that `arg_types` might coincide with `(current_el_type,)*len(objs)`, but not necessarily, + in case of implicit broadcasts. """ if isinstance(objs, itir.Expr): objs = (objs,) @@ -125,7 +125,7 @@ def process_elements( process_func, tuple(im.ref(let_id) for let_id in let_ids), current_el_type, - with_type=with_type, + arg_types=arg_types, ) return im.let(*(zip(let_ids, objs, strict=True)))(body) @@ -138,7 +138,7 @@ def _process_elements_impl( process_func: Callable[..., itir.Expr], _current_el_exprs: Iterable[T], current_el_type: ts.TypeSpec, - with_type: bool, + arg_types: Optional[Iterable[ts.TypeSpec]], ) -> itir.Expr: if isinstance(current_el_type, ts.TupleType): result = im.make_tuple( @@ -149,16 +149,16 @@ def _process_elements_impl( im.tuple_get(i, current_el_expr) for current_el_expr in _current_el_exprs ), current_el_type.types[i], - with_type=with_type, + arg_types=tuple(arg_t.types[i] for arg_t in arg_types) # type: ignore[attr-defined] # guaranteed by the requirement that `current_el_type` and each element of `arg_types` have the same tuple structure + if arg_types is not None + else None, ) for i in range(len(current_el_type.types)) ) ) - elif type_info.contains_local_field(current_el_type): - raise NotImplementedError("Processing fields with local dimension is not implemented.") else: - if with_type: - result = process_func(*_current_el_exprs, current_el_type) + if arg_types is not None: + result = process_func(*_current_el_exprs, arg_types) else: result = process_func(*_current_el_exprs) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 92f7327218..9355273588 100644 --- 
a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -104,6 +104,15 @@ def visit_Program(self, node: past.Program, **kwargs: Any) -> past.Program: location=node.location, ) + def visit_Slice(self, node: past.Slice, **kwargs: Any) -> past.Slice: + return past.Slice( + lower=self.visit(node.lower, **kwargs), + upper=self.visit(node.upper, **kwargs), + step=self.visit(node.step, **kwargs), + type=ts.DeferredType(constraint=None), + location=node.location, + ) + def visit_Subscript(self, node: past.Subscript, **kwargs: Any) -> past.Subscript: value = self.visit(node.value, **kwargs) return past.Subscript( diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index 7958b7a8d3..ea4a2995e0 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -83,40 +83,47 @@ def _process_args( # TODO(tehrengruber): Previously this function was called with the actual arguments # not their type. The check using the shape here is not functional anymore and # should instead be placed in a proper location. - shapes_and_dims = [*_field_constituents_shape_and_dims(args[param_idx], param.type)] + ranges_and_dims = [*_field_constituents_range_and_dims(args[param_idx], param.type)] # check that all non-scalar like constituents have the same shape and dimension, e.g. # for `(scalar, (field1, field2))` the two fields need to have the same shape and # dimension - if shapes_and_dims: - shape, dims = shapes_and_dims[0] + if ranges_and_dims: + range_, dims = ranges_and_dims[0] if not all( - el_shape == shape and el_dims == dims for (el_shape, el_dims) in shapes_and_dims + el_range == range_ and el_dims == dims + for (el_range, el_dims) in ranges_and_dims ): raise ValueError( "Constituents of composite arguments (e.g. the elements of a" " tuple) need to have the same shape and dimensions." 
) + index_type = ts.ScalarType(kind=ts.ScalarKind.INT32) size_args.extend( - shape if shape else [ts.ScalarType(kind=ts.ScalarKind.INT32)] * len(dims) # type: ignore[arg-type] # shape is always empty + range_ if range_ else [ts.TupleType(types=[index_type, index_type])] * len(dims) # type: ignore[arg-type] # shape is always empty ) return tuple(rewritten_args), tuple(size_args), kwargs -def _field_constituents_shape_and_dims( +def _field_constituents_range_and_dims( arg: Any, # TODO(havogt): improve typing arg_type: ts.DataType, -) -> Iterator[tuple[tuple[int, ...], list[common.Dimension]]]: +) -> Iterator[tuple[tuple[tuple[int, int], ...], list[common.Dimension]]]: match arg_type: case ts.TupleType(): for el, el_type in zip(arg, arg_type.types): - yield from _field_constituents_shape_and_dims(el, el_type) + assert isinstance(el_type, ts.DataType) + yield from _field_constituents_range_and_dims(el, el_type) case ts.FieldType(): dims = type_info.extract_dims(arg_type) if isinstance(arg, ts.TypeSpec): # TODO yield (tuple(), dims) elif dims: - assert hasattr(arg, "shape") and len(arg.shape) == len(dims) - yield (arg.shape, dims) + assert ( + hasattr(arg, "domain") + and isinstance(arg.domain, common.Domain) + and len(arg.domain.dims) == len(dims) + ) + yield (tuple((r.start, r.stop) for r in arg.domain.ranges), dims) else: yield from [] # ignore 0-dim fields case ts.ScalarType(): diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index a20c517cce..4bc1dfb2f8 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -9,7 +9,6 @@ from __future__ import annotations import dataclasses -import functools from typing import Any, Optional, cast import devtools @@ -19,23 +18,21 @@ from gt4py.next.ffront import ( fbuiltins, gtcallable, - lowering_utils, program_ast as past, stages as ffront_stages, transform_utils, type_specifications as ts_ffront, ) from gt4py.next.ffront.stages import AOT_PRG 
-from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.otf import stages, workflow from gt4py.next.type_system import type_info, type_specifications as ts -# FIXME[#1582](havogt): remove `to_gtir` arg after refactoring to GTIR # FIXME[#1582](tehrengruber): This should only depend on the program not the arguments. Remove # dependency as soon as column axis can be deduced from ITIR in consumers of the CompilableProgram. -def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgram: +def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: """ Lower a PAST program definition to Iterator IR. @@ -59,7 +56,7 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra ... column_axis=None, ... ) - >>> itir_copy = past_to_itir( + >>> itir_copy = past_to_gtir( ... toolchain.CompilableProgram(copy_program.past_stage, compile_time_args) ... ) @@ -67,7 +64,7 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra copy_program >>> print(type(itir_copy.data)) - + """ all_closure_vars = transform_utils._get_closure_vars_recursively(inp.data.closure_vars) offsets_and_dimensions = transform_utils._filter_closure_vars_by_type( @@ -80,14 +77,18 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra gt_callables = transform_utils._filter_closure_vars_by_type( all_closure_vars, gtcallable.GTCallable ).values() + + # FIXME[#1582](tehrengruber): remove after refactoring to GTIR # TODO(ricoh): The following calls to .__gt_itir__, which will use whatever - # backend is set for each of these field operators (GTCallables). Instead - # we should use the current toolchain to lower these to ITIR. This will require - # making this step aware of the toolchain it is called by (it can be part of multiple). 
- lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] + # backend is set for each of these field operators (GTCallables). Instead + # we should use the current toolchain to lower these to ITIR. This will require + # making this step aware of the toolchain it is called by (it can be part of multiple). + lowered_funcs = [] + for gt_callable in gt_callables: + lowered_funcs.append(gt_callable.__gt_gtir__()) itir_program = ProgramLowering.apply( - inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type, to_gtir=to_gtir + inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type ) if config.DEBUG or inp.data.debug: @@ -99,11 +100,10 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra ) -# FIXME[#1582](havogt): remove `to_gtir` arg after refactoring to GTIR -def past_to_itir_factory( - cached: bool = True, to_gtir: bool = False +def past_to_gtir_factory( + cached: bool = True, ) -> workflow.Workflow[AOT_PRG, stages.CompilableProgram]: - wf = workflow.make_step(functools.partial(past_to_itir, to_gtir=to_gtir)) + wf = workflow.make_step(past_to_gtir) if cached: wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) return wf @@ -138,8 +138,8 @@ def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension] return iter(scanops_per_axis.keys()).__next__() -def _size_arg_from_field(field_name: str, dim: int) -> str: - return f"__{field_name}_size_{dim}" +def _range_arg_from_field(field_name: str, dim: int) -> str: + return f"__{field_name}_{dim}_range" def _flatten_tuple_expr(node: past.Expr) -> list[past.Name | past.Subscript]: @@ -183,7 +183,7 @@ class ProgramLowering( ... parsed, [fieldop_def], grid_type=common.GridType.CARTESIAN ... 
) # doctest: +SKIP >>> type(lowered) # doctest: +SKIP - + >>> lowered.id # doctest: +SKIP SymbolName('program') >>> lowered.params # doctest: +SKIP @@ -191,7 +191,6 @@ class ProgramLowering( """ grid_type: common.GridType - to_gtir: bool = False # FIXME[#1582](havogt): remove after refactoring to GTIR # TODO(tehrengruber): enable doctests again. For unknown / obscure reasons # the above doctest fails when executed using `pytest --doctest-modules`. @@ -202,11 +201,8 @@ def apply( node: past.Program, function_definitions: list[itir.FunctionDefinition], grid_type: common.GridType, - to_gtir: bool = False, # FIXME[#1582](havogt): remove after refactoring to GTIR - ) -> itir.FencilDefinition: - return cls(grid_type=grid_type, to_gtir=to_gtir).visit( - node, function_definitions=function_definitions - ) + ) -> itir.Program: + return cls(grid_type=grid_type).visit(node, function_definitions=function_definitions) def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: """Generate symbols for each field param and dimension.""" @@ -221,13 +217,14 @@ def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: ) if len(fields_dims) > 0: # otherwise `param` has no constituent which is of `FieldType` assert all(field_dims == fields_dims[0] for field_dims in fields_dims) + index_type = ts.ScalarType( + kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) + ) for dim_idx in range(len(fields_dims[0])): size_params.append( itir.Sym( - id=_size_arg_from_field(param.id, dim_idx), - type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) - ), + id=_range_arg_from_field(param.id, dim_idx), + type=ts.TupleType(types=[index_type, index_type]), ) ) @@ -239,7 +236,7 @@ def visit_Program( *, function_definitions: list[itir.FunctionDefinition], **kwargs: Any, - ) -> itir.FencilDefinition | itir.Program: + ) -> itir.Program: # The ITIR does not support dynamically getting the size of a field. 
As # a workaround we add additional arguments to the fencil definition # containing the size of all fields. The caller of a program is (e.g. @@ -252,27 +249,17 @@ def visit_Program( params = params + self._gen_size_params_from_program(node) implicit_domain = True - if self.to_gtir: - set_ats = [self._visit_stencil_call_as_set_at(stmt, **kwargs) for stmt in node.body] - return itir.Program( - id=node.id, - function_definitions=function_definitions, - params=params, - declarations=[], - body=set_ats, - implicit_domain=implicit_domain, - ) - else: - closures = [self._visit_stencil_call_as_closure(stmt, **kwargs) for stmt in node.body] - return itir.FencilDefinition( - id=node.id, - function_definitions=function_definitions, - params=params, - closures=closures, - implicit_domain=implicit_domain, - ) + set_ats = [self._visit_field_operator_call(stmt, **kwargs) for stmt in node.body] + return itir.Program( + id=node.id, + function_definitions=function_definitions, + params=params, + declarations=[], + body=set_ats, + implicit_domain=implicit_domain, + ) - def _visit_stencil_call_as_set_at(self, node: past.Call, **kwargs: Any) -> itir.SetAt: + def _visit_field_operator_call(self, node: past.Call, **kwargs: Any) -> itir.SetAt: assert isinstance(node.kwargs["out"].type, ts.TypeSpec) assert type_info.is_type_or_tuple_of_type(node.kwargs["out"].type, ts.FieldType) @@ -296,61 +283,12 @@ def _visit_stencil_call_as_set_at(self, node: past.Call, **kwargs: Any) -> itir. 
target=output, ) - # FIXME[#1582](havogt): remove after refactoring to GTIR - def _visit_stencil_call_as_closure(self, node: past.Call, **kwargs: Any) -> itir.StencilClosure: - assert isinstance(node.kwargs["out"].type, ts.TypeSpec) - assert type_info.is_type_or_tuple_of_type(node.kwargs["out"].type, ts.FieldType) - - node_kwargs = {**node.kwargs} - domain = node_kwargs.pop("domain", None) - output, lowered_domain = self._visit_stencil_call_out_arg( - node_kwargs.pop("out"), domain, **kwargs - ) - - assert isinstance(node.func.type, (ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType)) - - args, node_kwargs = type_info.canonicalize_arguments( - node.func.type, node.args, node_kwargs, use_signature_ordering=True - ) - - lowered_args, lowered_kwargs = self.visit(args, **kwargs), self.visit(node_kwargs, **kwargs) - - stencil_params = [] - stencil_args: list[itir.Expr] = [] - for i, arg in enumerate([*args, *node_kwargs]): - stencil_params.append(f"__stencil_arg{i}") - if isinstance(arg.type, ts.TupleType): - # convert into tuple of iterators - stencil_args.append( - lowering_utils.to_tuples_of_iterator(f"__stencil_arg{i}", arg.type) - ) - else: - stencil_args.append(im.ref(f"__stencil_arg{i}")) - - if isinstance(node.func.type, ts_ffront.ScanOperatorType): - # scan operators return an iterator of tuples, just deref directly - stencil_body = im.deref(im.call(node.func.id)(*stencil_args)) - else: - # field operators return a tuple of iterators, deref element-wise - stencil_body = lowering_utils.process_elements( - im.deref, - im.call(node.func.id)(*stencil_args), - node.func.type.definition.returns, - ) - - return itir.StencilClosure( - domain=lowered_domain, - stencil=im.lambda_(*stencil_params)(stencil_body), - inputs=[*lowered_args, *lowered_kwargs.values()], - output=output, - location=node.location, - ) - def _visit_slice_bound( self, slice_bound: Optional[past.Constant], default_value: itir.Expr, - dim_size: itir.Expr, + start_idx: itir.Expr, + stop_idx: 
itir.Expr, **kwargs: Any, ) -> itir.Expr: if slice_bound is None: @@ -360,11 +298,9 @@ def _visit_slice_bound( slice_bound.type ) if slice_bound.value < 0: - lowered_bound = itir.FunCall( - fun=itir.SymRef(id="plus"), args=[dim_size, self.visit(slice_bound, **kwargs)] - ) + lowered_bound = im.plus(stop_idx, self.visit(slice_bound, **kwargs)) else: - lowered_bound = self.visit(slice_bound, **kwargs) + lowered_bound = im.plus(start_idx, self.visit(slice_bound, **kwargs)) else: raise AssertionError("Expected 'None' or 'past.Constant'.") if slice_bound: @@ -412,8 +348,9 @@ def _construct_itir_domain_arg( domain_args = [] domain_args_kind = [] for dim_i, dim in enumerate(out_dims): - # an expression for the size of a dimension - dim_size = itir.SymRef(id=_size_arg_from_field(out_field.id, dim_i)) + # an expression for the range of a dimension + dim_range = itir.SymRef(id=_range_arg_from_field(out_field.id, dim_i)) + dim_start, dim_stop = im.tuple_get(0, dim_range), im.tuple_get(1, dim_range) # bounds lower: itir.Expr upper: itir.Expr @@ -423,11 +360,15 @@ def _construct_itir_domain_arg( else: lower = self._visit_slice_bound( slices[dim_i].lower if slices else None, - im.literal("0", itir.INTEGER_INDEX_BUILTIN), - dim_size, + dim_start, + dim_start, + dim_stop, ) upper = self._visit_slice_bound( - slices[dim_i].upper if slices else None, dim_size, dim_size + slices[dim_i].upper if slices else None, + dim_stop, + dim_start, + dim_stop, ) if dim.kind == common.DimensionKind.LOCAL: diff --git a/src/gt4py/next/ffront/signature.py b/src/gt4py/next/ffront/signature.py index 9752ceaf32..4a58d56f57 100644 --- a/src/gt4py/next/ffront/signature.py +++ b/src/gt4py/next/ffront/signature.py @@ -6,20 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. 
-# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - # TODO(ricoh): This overlaps with `canonicalize_arguments`, solutions: # - merge the two # - extract the signature gathering functionality from canonicalize_arguments diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index bf3bee4b56..834536ff59 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -100,6 +100,7 @@ def add_content_to_fingerprint(obj: Any, hasher: xtyping.HashlibAlgorithm) -> No @add_content_to_fingerprint.register(FieldOperatorDefinition) @add_content_to_fingerprint.register(FoastOperatorDefinition) +@add_content_to_fingerprint.register(ProgramDefinition) @add_content_to_fingerprint.register(PastProgramDefinition) @add_content_to_fingerprint.register(toolchain.CompilableProgram) @add_content_to_fingerprint.register(arguments.CompileTimeArgs) @@ -121,10 +122,14 @@ def add_func_to_fingerprint(obj: types.FunctionType, hasher: xtyping.HashlibAlgo for item in sourcedef: add_content_to_fingerprint(item, hasher) + closure_vars = source_utils.get_closure_vars_from_function(obj) + for item in sorted(closure_vars.items(), key=lambda x: x[0]): + add_content_to_fingerprint(item, hasher) + @add_content_to_fingerprint.register def add_dict_to_fingerprint(obj: dict, hasher: xtyping.HashlibAlgorithm) -> None: - for key, value in obj.items(): + for key, value in sorted(obj.items()): add_content_to_fingerprint(key, hasher) add_content_to_fingerprint(value, hasher) @@ -148,4 +153,3 @@ def add_foast_located_node_to_fingerprint( ) -> None: add_content_to_fingerprint(obj.location, hasher) add_content_to_fingerprint(str(obj), 
hasher) - add_content_to_fingerprint(str(obj), hasher) diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index 8160a2c42d..80ba93e187 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -169,9 +169,11 @@ def _scan_param_promotion(param: ts.TypeSpec, arg: ts.TypeSpec) -> ts.FieldType -------- >>> _scan_param_promotion( ... ts.ScalarType(kind=ts.ScalarKind.INT64), - ... ts.FieldType(dims=[common.Dimension("I")], dtype=ts.ScalarKind.FLOAT64), + ... ts.FieldType( + ... dims=[common.Dimension("I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ... ), ... ) - FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)) + FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)) """ def _as_field(dtype: ts.TypeSpec, path: tuple[int, ...]) -> ts.FieldType: @@ -252,8 +254,8 @@ def function_signature_incompatibilities_scanop( # build a function type to leverage the already existing signature checking capabilities function_type = ts.FunctionType( pos_only_args=[], - pos_or_kw_args=promoted_params, # type: ignore[arg-type] # dict is invariant, but we don't care here. - kw_only_args=promoted_kwparams, # type: ignore[arg-type] # same as above + pos_or_kw_args=promoted_params, + kw_only_args=promoted_kwparams, returns=ts.DeferredType(constraint=None), ) diff --git a/src/gt4py/next/ffront/type_specifications.py b/src/gt4py/next/ffront/type_specifications.py index e4f6c826fe..b76a116297 100644 --- a/src/gt4py/next/ffront/type_specifications.py +++ b/src/gt4py/next/ffront/type_specifications.py @@ -6,23 +6,19 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from dataclasses import dataclass import gt4py.next.type_system.type_specifications as ts -from gt4py.next import common as func_common +from gt4py.next import common -@dataclass(frozen=True) class ProgramType(ts.TypeSpec, ts.CallableType): definition: ts.FunctionType -@dataclass(frozen=True) class FieldOperatorType(ts.TypeSpec, ts.CallableType): definition: ts.FunctionType -@dataclass(frozen=True) class ScanOperatorType(ts.TypeSpec, ts.CallableType): - axis: func_common.Dimension + axis: common.Dimension definition: ts.FunctionType diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 264ac2685c..8e5f7addca 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -22,6 +22,11 @@ def as_fieldop(*args): raise BackendNotSelectedError() +@builtin_dispatch +def index(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def deref(*args): raise BackendNotSelectedError() @@ -287,6 +292,11 @@ def trunc(*args): raise BackendNotSelectedError() +@builtin_dispatch +def neg(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def isfinite(*args): raise BackendNotSelectedError() @@ -332,16 +342,46 @@ def int(*args): # noqa: A001 [builtin-variable-shadowing] raise BackendNotSelectedError() +@builtin_dispatch +def int8(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def uint8(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def int16(*args): + raise BackendNotSelectedError() + + +@builtin_dispatch +def uint16(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def int32(*args): raise BackendNotSelectedError() +@builtin_dispatch +def uint32(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def int64(*args): raise BackendNotSelectedError() +@builtin_dispatch +def uint64(*args): + raise BackendNotSelectedError() + + @builtin_dispatch def float(*args): # noqa: A001 
[builtin-variable-shadowing] raise BackendNotSelectedError() @@ -362,7 +402,8 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] raise BackendNotSelectedError() -UNARY_MATH_NUMBER_BUILTINS = {"abs"} +UNARY_MATH_NUMBER_BUILTINS = {"abs", "neg"} +UNARY_LOGICAL_BUILTINS = {"not_"} UNARY_MATH_FP_BUILTINS = { "sin", "cos", @@ -386,51 +427,69 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing] "trunc", } UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"} -BINARY_MATH_NUMBER_BUILTINS = {"minimum", "maximum", "fmod", "power"} -TYPEBUILTINS = {"int32", "int64", "float32", "float64", "bool"} -MATH_BUILTINS = ( - UNARY_MATH_NUMBER_BUILTINS - | UNARY_MATH_FP_BUILTINS - | UNARY_MATH_FP_PREDICATE_BUILTINS - | BINARY_MATH_NUMBER_BUILTINS - | TYPEBUILTINS -) +BINARY_MATH_NUMBER_BUILTINS = { + "plus", + "minus", + "multiplies", + "divides", + "mod", + "floordiv", # TODO see https://github.com/GridTools/gt4py/issues/1136 + "minimum", + "maximum", + "fmod", +} +BINARY_MATH_COMPARISON_BUILTINS = {"eq", "less", "greater", "greater_equal", "less_equal", "not_eq"} +BINARY_LOGICAL_BUILTINS = {"and_", "or_", "xor_"} + + +#: builtin / dtype used to construct integer indices, like domain bounds +INTEGER_INDEX_BUILTIN = "int32" +INTEGER_TYPE_BUILTINS = { + "int8", + "uint8", + "int16", + "uint16", + "int32", + "uint32", + "int64", + "uint64", +} +FLOATING_POINT_TYPE_BUILTINS = {"float32", "float64"} +TYPE_BUILTINS = {*INTEGER_TYPE_BUILTINS, *FLOATING_POINT_TYPE_BUILTINS, "bool"} + +ARITHMETIC_BUILTINS = { + *UNARY_MATH_NUMBER_BUILTINS, + *UNARY_LOGICAL_BUILTINS, + *UNARY_MATH_FP_BUILTINS, + *UNARY_MATH_FP_PREDICATE_BUILTINS, + *BINARY_MATH_NUMBER_BUILTINS, + "power", + *BINARY_MATH_COMPARISON_BUILTINS, + *BINARY_LOGICAL_BUILTINS, +} + BUILTINS = { - "deref", + "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) "can_deref", + "cartesian_domain", + "cast_", + 
"deref", + "if_", + "index", # `index(dim)` creates a dim-field that has the current index at each point "shift", - "neighbors", "list_get", + "lift", "make_const_list", + "make_tuple", "map_", - "lift", + "named_range", + "neighbors", "reduce", - "plus", - "minus", - "multiplies", - "divides", - "floordiv", - "mod", - "make_tuple", - "tuple_get", - "if_", - "cast_", - "greater", - "less", - "less_equal", - "greater_equal", - "eq", - "not_eq", - "not_", - "and_", - "or_", - "xor_", "scan", - "cartesian_domain", + "tuple_get", "unstructured_domain", - "named_range", - "as_fieldop", - *MATH_BUILTINS, + *ARITHMETIC_BUILTINS, + *TYPE_BUILTINS, } __all__ = [*BUILTINS] diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index afe0cec402..da0516d26b 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -15,7 +15,7 @@ import dataclasses import itertools import math -import sys +import operator import warnings import numpy as np @@ -85,84 +85,132 @@ TupleAxis: TypeAlias = type[None] Axis: TypeAlias = Union[FieldAxis, TupleAxis] Scalar: TypeAlias = ( - SupportsInt | SupportsFloat | np.int32 | np.int64 | np.float32 | np.float64 | np.bool_ + SupportsInt + | SupportsFloat + | np.int8 + | np.uint8 + | np.int16 + | np.uint16 + | np.int32 + | np.uint32 + | np.int64 + | np.uint64 + | np.float32 + | np.float64 + | np.bool_ ) class SparseTag(Tag): ... 
-class NeighborTableOffsetProvider: +@xtyping.deprecated("Use a 'Connectivity' instead.") +def NeighborTableOffsetProvider( + table: core_defs.NDArrayObject, + origin_axis: common.Dimension, + neighbor_axis: common.Dimension, + max_neighbors: int, + has_skip_values=True, +) -> common.Connectivity: + return common._connectivity( + table, + codomain=neighbor_axis, + domain={ + origin_axis: table.shape[0], + common.Dimension( + value="_DummyLocalDim", kind=common.DimensionKind.LOCAL + ): max_neighbors, + }, + skip_value=common._DEFAULT_SKIP_VALUE if has_skip_values else None, + ) + + +# TODO(havogt): complete implementation and make available for fieldview embedded +@dataclasses.dataclass(frozen=True) +class StridedConnectivityField(common.Connectivity): + domain_dims: tuple[common.Dimension, common.Dimension] + codomain_dim: common.Dimension + _max_neighbors: int + def __init__( self, - table: core_defs.NDArrayObject, - origin_axis: common.Dimension, - neighbor_axis: common.Dimension, + domain_dims: Sequence[common.Dimension], + codomain_dim: common.Dimension, max_neighbors: int, - has_skip_values=True, - ) -> None: - self.table = table - self.origin_axis = origin_axis - self.neighbor_axis = neighbor_axis - assert not hasattr(table, "shape") or table.shape[1] == max_neighbors - self.max_neighbors = max_neighbors - self.has_skip_values = has_skip_values - self.index_type = table.dtype - - def mapped_index( - self, primary: common.IntIndex, neighbor_idx: common.IntIndex - ) -> common.IntIndex: - res = self.table[(primary, neighbor_idx)] - assert common.is_int_index(res) - return res + ): + object.__setattr__(self, "domain_dims", tuple(domain_dims)) + object.__setattr__(self, "codomain_dim", codomain_dim) + object.__setattr__(self, "_max_neighbors", max_neighbors) - if dace: - # Extension of NeighborTableOffsetProvider adding SDFGConvertible support in GT4Py Programs - def _dace_data_ptr(self) -> int: - obj = self.table - if dace.dtypes.is_array(obj): - if hasattr(obj, 
"__array_interface__"): - return obj.__array_interface__["data"][0] - if hasattr(obj, "__cuda_array_interface__"): - return obj.__cuda_array_interface__["data"][0] - raise ValueError("Unsupported data container.") - - def _dace_descriptor(self) -> dace.data.Data: - return dace.data.create_datadescriptor(self.table) - else: + @property + def __gt_origin__(self) -> xtyping.Never: + raise NotImplementedError + + def __gt_type__(self) -> common.NeighborConnectivityType: + return common.NeighborConnectivityType( + domain=self.domain_dims, + codomain=self.codomain_dim, + max_neighbors=self._max_neighbors, + skip_value=self.skip_value, + dtype=self.dtype, + ) - def _dace_data_ptr(self) -> NoReturn: # type: ignore[misc] - raise NotImplementedError( - "data_ptr is only supported when the 'dace' module is available." - ) + @property + def domain(self) -> common.Domain: + return common.Domain( + dims=self.domain_dims, + ranges=(common.UnitRange.infinite(), common.unit_range(self._max_neighbors)), + ) - def _dace_descriptor(self) -> NoReturn: # type: ignore[misc] - raise NotImplementedError( - "__descriptor__ is only supported when the 'dace' module is available." 
- ) + @property + def codomain(self) -> common.Dimension: + return self.codomain_dim - data_ptr = _dace_data_ptr - __descriptor__ = _dace_descriptor + @property + def dtype(self) -> core_defs.DType[core_defs.IntegralScalar]: + return core_defs.Int32DType() # type: ignore[return-value] + @property + def ndarray(self) -> core_defs.NDArrayObject: + raise NotImplementedError -class StridedNeighborOffsetProvider: - def __init__( + def asnumpy(self) -> np.ndarray: + raise NotImplementedError + + def premap(self, index_field: common.Connectivity | fbuiltins.FieldOffset) -> common.Field: + raise NotImplementedError + + def restrict( # type: ignore[override] self, - origin_axis: common.Dimension, - neighbor_axis: common.Dimension, - max_neighbors: int, - has_skip_values=True, - ) -> None: - self.origin_axis = origin_axis - self.neighbor_axis = neighbor_axis - self.max_neighbors = max_neighbors - self.has_skip_values = has_skip_values - self.index_type = int + item: common.AnyIndexSpec, + ) -> common.Field: + if not isinstance(item, tuple) or (isinstance(item, tuple) and not len(item) == 2): + raise NotImplementedError() # TODO(havogt): add proper slicing + index = item[0] * self._max_neighbors + item[1] # type: ignore[operator, call-overload] + return ConstantField(index) + + def as_scalar(self) -> xtyping.Never: + raise NotImplementedError() + + def __call__( + self, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: common.Connectivity | fbuiltins.FieldOffset, + ) -> common.Field: + raise NotImplementedError() - def mapped_index( - self, primary: common.IntIndex, neighbor_idx: common.IntIndex - ) -> common.IntIndex: - return primary * self.max_neighbors + neighbor_idx + __getitem__ = restrict # type: ignore[assignment] + + def inverse_image( + self, image_range: common.UnitRange | common.NamedRange + ) -> Sequence[common.NamedRange]: + raise NotImplementedError + + @property + def skip_value( + self, + ) -> None: + return None # Offsets @@ -186,6 +234,12 
@@ def mapped_index( NamedFieldIndices: TypeAlias = Mapping[Tag, FieldIndex | SparsePositionEntry] +# Magic local dimension for the result of a `make_const_list`. +# A clean implementation will probably involve to tag the `make_const_list` +# with the neighborhood it is meant to be used with. +_CONST_DIM = common.Dimension(value="_CONST_DIM", kind=common.DimensionKind.LOCAL) + + @runtime_checkable class ItIterator(Protocol): """ @@ -227,6 +281,12 @@ class MutableLocatedField(LocatedField, Protocol): def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: ... +def _numpy_structured_value_to_tuples(value: Any) -> Any: + if _elem_dtype(value).names is not None: + return tuple(_numpy_structured_value_to_tuples(v) for v in value) + return value + + class Column(np.lib.mixins.NDArrayOperatorsMixin): """Represents a column when executed in column mode (`column_axis != None`). @@ -247,6 +307,10 @@ def dtype(self) -> np.dtype: # not directly dtype of `self.data` as that might be a structured type containing `None` return _elem_dtype(self.data[self.kstart]) + def __gt_type__(self) -> ts.TypeSpec: + elem = self.data[self.kstart] + return type_translation.from_value(_numpy_structured_value_to_tuples(elem)) + def __getitem__(self, i: int) -> Any: result = self.data[i - self.kstart] # numpy type @@ -327,6 +391,13 @@ def not_(a): return not a +@builtins.neg.register(EMBEDDED) +def neg(a): + if isinstance(a, Column): + return np.negative(a) + return np.negative(a) + + @builtins.gamma.register(EMBEDDED) def gamma(a): gamma_ = np.vectorize(math.gamma) @@ -337,27 +408,6 @@ def gamma(a): return res.item() -@builtins.and_.register(EMBEDDED) -def and_(a, b): - if isinstance(a, Column): - return np.logical_and(a, b) - return a and b - - -@builtins.or_.register(EMBEDDED) -def or_(a, b): - if isinstance(a, Column): - return np.logical_or(a, b) - return a or b - - -@builtins.xor_.register(EMBEDDED) -def xor_(a, b): - if isinstance(a, Column): - return np.logical_xor(a, b) - 
return a ^ b - - @builtins.tuple_get.register(EMBEDDED) def tuple_get(i, tup): if isinstance(tup, Column): @@ -445,66 +495,6 @@ def named_range(tag: Tag | common.Dimension, start: int, end: int) -> NamedRange return (tag, range(start, end)) -@builtins.minus.register(EMBEDDED) -def minus(first, second): - return first - second - - -@builtins.plus.register(EMBEDDED) -def plus(first, second): - return first + second - - -@builtins.multiplies.register(EMBEDDED) -def multiplies(first, second): - return first * second - - -@builtins.divides.register(EMBEDDED) -def divides(first, second): - return first / second - - -@builtins.floordiv.register(EMBEDDED) -def floordiv(first, second): - return first // second - - -@builtins.mod.register(EMBEDDED) -def mod(first, second): - return first % second - - -@builtins.eq.register(EMBEDDED) -def eq(first, second): - return first == second - - -@builtins.greater.register(EMBEDDED) -def greater(first, second): - return first > second - - -@builtins.less.register(EMBEDDED) -def less(first, second): - return first < second - - -@builtins.less_equal.register(EMBEDDED) -def less_equal(first, second): - return first <= second - - -@builtins.greater_equal.register(EMBEDDED) -def greater_equal(first, second): - return first >= second - - -@builtins.not_eq.register(EMBEDDED) -def not_eq(first, second): - return first != second - - CompositeOfScalarOrField: TypeAlias = Scalar | common.Field | tuple["CompositeOfScalarOrField", ...] 
@@ -533,11 +523,32 @@ def promote_scalars(val: CompositeOfScalarOrField): ) -for math_builtin_name in builtins.MATH_BUILTINS: - python_builtins = {"int": int, "float": float, "bool": bool, "str": str} +for math_builtin_name in builtins.ARITHMETIC_BUILTINS | builtins.TYPE_BUILTINS: + python_builtins: dict[str, Callable] = { + "int": int, + "float": float, + "bool": bool, + "str": str, + "plus": operator.add, + "minus": operator.sub, + "multiplies": operator.mul, + "divides": operator.truediv, + "mod": operator.mod, + "floordiv": operator.floordiv, + "eq": operator.eq, + "less": operator.lt, + "greater": operator.gt, + "greater_equal": operator.ge, + "less_equal": operator.le, + "not_eq": operator.ne, + "and_": operator.and_, + "or_": operator.or_, + "xor_": operator.xor, + "neg": operator.neg, + } decorator = getattr(builtins, math_builtin_name).register(EMBEDDED) impl: Callable - if math_builtin_name == "gamma": + if math_builtin_name in ["gamma", "not_"]: continue # treated explicitly elif math_builtin_name in python_builtins: # TODO: Should potentially use numpy fixed size types to be consistent @@ -576,17 +587,21 @@ def execute_shift( for i, p in reversed(list(enumerate(new_entry))): # first shift applies to the last sparse dimensions of that axis type if p is None: - offset_implementation = offset_provider[tag] - assert isinstance(offset_implementation, common.Connectivity) - cur_index = pos[offset_implementation.origin_axis.value] - assert common.is_int_index(cur_index) - if offset_implementation.mapped_index(cur_index, index) in [ - None, - common._DEFAULT_SKIP_VALUE, - ]: - return None - - new_entry[i] = index + if tag == _CONST_DIM.value: + new_entry[i] = 0 + else: + offset_implementation = offset_provider[tag] + assert common.is_neighbor_connectivity(offset_implementation) + source_dim = offset_implementation.__gt_type__().source_dim + cur_index = pos[source_dim.value] + assert common.is_int_index(cur_index) + if offset_implementation[cur_index, 
index].as_scalar() in [ + None, + common._DEFAULT_SKIP_VALUE, + ]: + return None + + new_entry[i] = index break # the assertions above confirm pos is incomplete casting here to avoid duplicating work in a type guard return cast(IncompletePosition, pos) | {tag: new_entry} @@ -600,22 +615,22 @@ def execute_shift( else: raise AssertionError() return new_pos - else: - assert isinstance(offset_implementation, common.Connectivity) - assert offset_implementation.origin_axis.value in pos + elif common.is_neighbor_connectivity(offset_implementation): + source_dim = offset_implementation.__gt_type__().source_dim + assert source_dim.value in pos new_pos = pos.copy() - new_pos.pop(offset_implementation.origin_axis.value) - cur_index = pos[offset_implementation.origin_axis.value] + new_pos.pop(source_dim.value) + cur_index = pos[source_dim.value] assert common.is_int_index(cur_index) - if offset_implementation.mapped_index(cur_index, index) in [ + if offset_implementation[cur_index, index].as_scalar() in [ None, common._DEFAULT_SKIP_VALUE, ]: return None else: - new_index = offset_implementation.mapped_index(cur_index, index) + new_index = offset_implementation[cur_index, index].as_scalar() assert new_index is not None - new_pos[offset_implementation.neighbor_axis.value] = int(new_index) + new_pos[offset_implementation.codomain.value] = int(new_index) return new_pos @@ -661,7 +676,7 @@ def __float__(self): return np.nan def __int__(self): - return sys.maxsize + return np.iinfo(np.int32).max def __repr__(self): return "_UNDEFINED" @@ -920,9 +935,9 @@ def deref(self) -> Any: return _make_tuple(self.field, position, column_axis=self.column_axis) -def _get_sparse_dimensions(axes: Sequence[common.Dimension]) -> list[Tag]: +def _get_sparse_dimensions(axes: Sequence[common.Dimension]) -> list[common.Dimension]: return [ - axis.value + axis for axis in axes if isinstance(axis, common.Dimension) and axis.kind == common.DimensionKind.LOCAL ] @@ -945,7 +960,7 @@ def make_in_iterator( 
new_pos: Position = pos.copy() for sparse_dim in set(sparse_dimensions): init = [None] * sparse_dimensions.count(sparse_dim) - new_pos[sparse_dim] = init # type: ignore[assignment] # looks like mypy is confused + new_pos[sparse_dim.value] = init # type: ignore[assignment] # looks like mypy is confused if column_dimension is not None: column_range = embedded_context.closure_column_range.get().unit_range # if we deal with column stencil the column position is just an offset by which the whole column needs to be shifted @@ -956,7 +971,7 @@ def make_in_iterator( ) if len(sparse_dimensions) >= 1: if len(sparse_dimensions) == 1: - return SparseListIterator(it, sparse_dimensions[0]) + return SparseListIterator(it, sparse_dimensions[0].value) else: raise NotImplementedError( f"More than one local dimension is currently not supported, got {sparse_dimensions}." @@ -1004,7 +1019,17 @@ def field_getitem(self, named_indices: NamedFieldIndices) -> Any: def field_setitem(self, named_indices: NamedFieldIndices, value: Any): if isinstance(self._ndarrayfield, common.MutableField): - self._ndarrayfield[self._translate_named_indices(named_indices)] = value + if isinstance(value, _List): + for i, v in enumerate(value): # type:ignore[var-annotated, arg-type] + self._ndarrayfield[ + self._translate_named_indices({**named_indices, value.offset.value: i}) # type: ignore[dict-item] + ] = v + elif isinstance(value, _ConstList): + self._ndarrayfield[ + self._translate_named_indices({**named_indices, _CONST_DIM.value: 0}) + ] = value.value + else: + self._ndarrayfield[self._translate_named_indices(named_indices)] = value else: raise RuntimeError("Assigment into a non-mutable Field is not allowed.") @@ -1166,20 +1191,22 @@ def as_scalar(self) -> core_defs.IntegralScalar: def premap( self, - index_field: common.ConnectivityField | fbuiltins.FieldOffset, - *args: common.ConnectivityField | fbuiltins.FieldOffset, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: 
common.Connectivity | fbuiltins.FieldOffset, ) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() def restrict(self, item: common.AnyIndexSpec) -> Self: if isinstance(item, Sequence) and all(isinstance(e, common.NamedIndex) for e in item): + assert len(item) == 1 assert isinstance(item[0], common.NamedIndex) # for mypy errors on multiple lines below d, r = item[0] assert d == self._dimension assert isinstance(r, core_defs.INTEGRAL_TYPES) + # TODO(tehrengruber): Use a regular zero dimensional field instead. return self.__class__(self._dimension, r) - # TODO set a domain... + # TODO: set a domain... raise NotImplementedError() __call__ = premap @@ -1290,8 +1317,8 @@ def asnumpy(self) -> np.ndarray: def premap( self, - index_field: common.ConnectivityField | fbuiltins.FieldOffset, - *args: common.ConnectivityField | fbuiltins.FieldOffset, + index_field: common.Connectivity | fbuiltins.FieldOffset, + *args: common.Connectivity | fbuiltins.FieldOffset, ) -> common.Field: # TODO can be implemented by constructing and ndarray (but do we know of which kind?) raise NotImplementedError() @@ -1383,7 +1410,25 @@ def impl(it: ItIterator) -> ItIterator: DT = TypeVar("DT") -class _List(tuple, Generic[DT]): ... +@dataclasses.dataclass(frozen=True) +class _List(Generic[DT]): + values: tuple[DT, ...] 
+ offset: runtime.Offset + + def __getitem__(self, i: int): + return self.values[i] + + def __gt_type__(self) -> ts.ListType: + offset_tag = self.offset.value + assert isinstance(offset_tag, str) + element_type = type_translation.from_value(self.values[0]) + assert isinstance(element_type, ts.DataType) + offset_provider = embedded_context.offset_provider.get() + assert offset_provider is not None + connectivity = offset_provider[offset_tag] + assert common.is_neighbor_connectivity(connectivity) + local_dim = connectivity.__gt_type__().neighbor_dim + return ts.ListType(element_type=element_type, offset_type=local_dim) @dataclasses.dataclass(frozen=True) @@ -1393,6 +1438,14 @@ class _ConstList(Generic[DT]): def __getitem__(self, _): return self.value + def __gt_type__(self) -> ts.ListType: + element_type = type_translation.from_value(self.value) + assert isinstance(element_type, ts.DataType) + return ts.ListType( + element_type=element_type, + offset_type=_CONST_DIM, + ) + @builtins.neighbors.register(EMBEDDED) def neighbors(offset: runtime.Offset, it: ItIterator) -> _List: @@ -1401,11 +1454,14 @@ def neighbors(offset: runtime.Offset, it: ItIterator) -> _List: offset_provider = embedded_context.offset_provider.get() assert offset_provider is not None connectivity = offset_provider[offset_str] - assert isinstance(connectivity, common.Connectivity) + assert common.is_neighbor_connectivity(connectivity) return _List( - shifted.deref() - for i in range(connectivity.max_neighbors) - if (shifted := it.shift(offset_str, i)).can_deref() + values=tuple( + shifted.deref() + for i in range(connectivity.__gt_type__().max_neighbors) + if (shifted := it.shift(offset_str, i)).can_deref() + ), + offset=offset, ) @@ -1414,10 +1470,23 @@ def list_get(i, lst: _List[Optional[DT]]) -> Optional[DT]: return lst[i] +def _get_offset(*lists: _List | _ConstList) -> Optional[runtime.Offset]: + offsets = set((lst.offset for lst in lists if hasattr(lst, "offset"))) + if len(offsets) == 0: + 
return None + if len(offsets) == 1: + return offsets.pop() + raise AssertionError("All lists must have the same offset.") + + @builtins.map_.register(EMBEDDED) def map_(op): def impl_(*lists): - return _List(map(lambda x: op(*x), zip(*lists))) + offset = _get_offset(*lists) + if offset is None: + return _ConstList(value=op(*[lst.value for lst in lists])) + else: + return _List(values=tuple(map(lambda x: op(*x), zip(*lists))), offset=offset) return impl_ @@ -1438,7 +1507,7 @@ def sten(*lists): break # we can check a single argument for length, # because all arguments share the same pattern - n = len(lst) + n = len(lst.values) res = init for i in range(n): res = fun(res, *(lst[i] for lst in lists)) @@ -1454,14 +1523,23 @@ class SparseListIterator: offsets: Sequence[OffsetPart] = dataclasses.field(default_factory=list, kw_only=True) def deref(self) -> Any: + if self.list_offset == _CONST_DIM.value: + return _ConstList( + value=self.it.shift(*self.offsets, SparseTag(self.list_offset), 0).deref() + ) offset_provider = embedded_context.offset_provider.get() assert offset_provider is not None connectivity = offset_provider[self.list_offset] - assert isinstance(connectivity, common.Connectivity) + assert common.is_neighbor_connectivity(connectivity) return _List( - shifted.deref() - for i in range(connectivity.max_neighbors) - if (shifted := self.it.shift(*self.offsets, SparseTag(self.list_offset), i)).can_deref() + values=tuple( + shifted.deref() + for i in range(connectivity.__gt_type__().max_neighbors) + if ( + shifted := self.it.shift(*self.offsets, SparseTag(self.list_offset), i) + ).can_deref() + ), + offset=runtime.Offset(value=self.list_offset), ) def can_deref(self) -> bool: @@ -1586,13 +1664,15 @@ def impl(*iters: ItIterator): return impl -def _dimension_to_tag(domain: Domain) -> dict[Tag, range]: - return {k.value if isinstance(k, common.Dimension) else k: v for k, v in domain.items()} +def _dimension_to_tag( + domain: runtime.CartesianDomain | 
runtime.UnstructuredDomain, +) -> dict[Tag, range]: + return {k.value: v for k, v in domain.items()} -def _validate_domain(domain: Domain, offset_provider: OffsetProvider) -> None: +def _validate_domain(domain: Domain, offset_provider_type: common.OffsetProviderType) -> None: if isinstance(domain, runtime.CartesianDomain): - if any(isinstance(o, common.Connectivity) for o in offset_provider.values()): + if any(isinstance(o, common.ConnectivityType) for o in offset_provider_type.values()): raise RuntimeError( "Got a 'CartesianDomain', but found a 'Connectivity' in 'offset_provider', expected 'UnstructuredDomain'." ) @@ -1654,16 +1734,6 @@ def _extract_column_range(domain) -> common.NamedRange | eve.NothingType: return eve.NOTHING -def _structured_dtype_to_typespec(structured_dtype: np.dtype) -> ts.ScalarType | ts.TupleType: - if structured_dtype.names is None: - return type_translation.from_dtype(core_defs.dtype(structured_dtype)) - return ts.TupleType( - types=[ - _structured_dtype_to_typespec(structured_dtype[name]) for name in structured_dtype.names - ] - ) - - def _get_output_type( fun: Callable, domain_: runtime.CartesianDomain | runtime.UnstructuredDomain, @@ -1682,8 +1752,29 @@ def _get_output_type( with embedded_context.new_context(closure_column_range=col_range) as ctx: single_pos_result = ctx.run(_compute_at_position, fun, args, pos_in_domain, col_dim) assert single_pos_result is not _UNDEFINED, "Stencil contains an Out-Of-Bound access." 
- dtype = _elem_dtype(single_pos_result) - return _structured_dtype_to_typespec(dtype) + return type_translation.from_value(single_pos_result) + + +def _fieldspec_list_to_value( + domain: common.Domain, type_: ts.TypeSpec +) -> tuple[common.Domain, ts.TypeSpec]: + """Translate the list element type into the domain.""" + if isinstance(type_, ts.ListType): + if type_.offset_type == _CONST_DIM: + return domain.insert( + len(domain), common.named_range((_CONST_DIM, 1)) + ), type_.element_type + else: + offset_provider = embedded_context.offset_provider.get() + offset_type = type_.offset_type + assert isinstance(offset_type, common.Dimension) + connectivity = offset_provider[offset_type.value] + assert common.is_neighbor_connectivity(connectivity) + return domain.insert( + len(domain), + common.named_range((offset_type, connectivity.__gt_type__().max_neighbors)), + ), type_.element_type + return domain, type_ @builtins.as_fieldop.register(EMBEDDED) @@ -1691,26 +1782,32 @@ def as_fieldop(fun: Callable, domain: runtime.CartesianDomain | runtime.Unstruct def impl(*args): xp = field_utils.get_array_ns(*args) type_ = _get_output_type(fun, domain, [promote_scalars(arg) for arg in args]) - out = field_utils.field_from_typespec(type_, common.domain(domain), xp) + + new_domain, type_ = _fieldspec_list_to_value(common.domain(domain), type_) + out = field_utils.field_from_typespec(type_, new_domain, xp) # TODO(havogt): after updating all tests to use the new program, # we should get rid of closure and move the implementation to this function - closure(_dimension_to_tag(domain), fun, out, list(args)) + closure(domain, fun, out, list(args)) return out return impl -@runtime.closure.register(EMBEDDED) +@builtins.index.register(EMBEDDED) +def index(axis: common.Dimension) -> common.Field: + return IndexField(axis) + + def closure( - domain_: Domain, + domain_: runtime.CartesianDomain | runtime.UnstructuredDomain, sten: Callable[..., Any], out, #: MutableLocatedField, ins: 
list[common.Field | Scalar | tuple[common.Field | Scalar | tuple, ...]], ) -> None: assert embedded_context.within_valid_context() offset_provider = embedded_context.offset_provider.get() - _validate_domain(domain_, offset_provider) + _validate_domain(domain_, common.offset_provider_to_type(offset_provider)) domain: dict[Tag, range] = _dimension_to_tag(domain_) if not (isinstance(out, common.Field) or is_tuple_of_field(out)): raise TypeError("'Out' needs to be a located field.") diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index b2a549501f..ea5cf84d86 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -9,20 +9,17 @@ from typing import ClassVar, List, Optional, Union import gt4py.eve as eve -from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels +from gt4py.eve import Coerced, SymbolName, SymbolRef from gt4py.eve.concepts import SourceLocation from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.eve.utils import noninstantiable from gt4py.next import common +from gt4py.next.iterator.builtins import BUILTINS from gt4py.next.type_system import type_specifications as ts DimensionKind = common.DimensionKind -# TODO(havogt): -# After completion of refactoring to GTIR, FencilDefinition and StencilClosure should be removed everywhere. -# During transition, we lower to FencilDefinitions and apply a transformation to GTIR-style afterwards. 
- @noninstantiable class Node(eve.Node): @@ -37,10 +34,14 @@ def __str__(self) -> str: return pformat(self) def __hash__(self) -> int: - return hash(type(self)) ^ hash( - tuple( - hash(tuple(v)) if isinstance(v, list) else hash(v) - for v in self.iter_children_values() + return hash( + ( + type(self), + *( + tuple(v) if isinstance(v, list) else v + for (k, v) in self.iter_children_items() + if k not in ["location", "type"] + ), ) ) @@ -93,114 +94,6 @@ class FunctionDefinition(Node, SymbolTableTrait): expr: Expr -class StencilClosure(Node): - domain: FunCall - stencil: Expr - output: Union[SymRef, FunCall] - inputs: List[SymRef] - - @datamodels.validator("output") - def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): - if isinstance(value, FunCall) and value.fun != SymRef(id="make_tuple"): - raise ValueError("Only FunCall to 'make_tuple' allowed.") - - -UNARY_MATH_NUMBER_BUILTINS = {"abs"} -UNARY_LOGICAL_BUILTINS = {"not_"} -UNARY_MATH_FP_BUILTINS = { - "sin", - "cos", - "tan", - "arcsin", - "arccos", - "arctan", - "sinh", - "cosh", - "tanh", - "arcsinh", - "arccosh", - "arctanh", - "sqrt", - "exp", - "log", - "gamma", - "cbrt", - "floor", - "ceil", - "trunc", -} -UNARY_MATH_FP_PREDICATE_BUILTINS = {"isfinite", "isinf", "isnan"} -BINARY_MATH_NUMBER_BUILTINS = { - "minimum", - "maximum", - "fmod", - "plus", - "minus", - "multiplies", - "divides", - "mod", - "floordiv", # TODO see https://github.com/GridTools/gt4py/issues/1136 -} -BINARY_MATH_COMPARISON_BUILTINS = {"eq", "less", "greater", "greater_equal", "less_equal", "not_eq"} -BINARY_LOGICAL_BUILTINS = {"and_", "or_", "xor_"} - -ARITHMETIC_BUILTINS = { - *UNARY_MATH_NUMBER_BUILTINS, - *UNARY_LOGICAL_BUILTINS, - *UNARY_MATH_FP_BUILTINS, - *UNARY_MATH_FP_PREDICATE_BUILTINS, - *BINARY_MATH_NUMBER_BUILTINS, - "power", - *BINARY_MATH_COMPARISON_BUILTINS, - *BINARY_LOGICAL_BUILTINS, -} - -#: builtin / dtype used to construct integer indices, like domain bounds 
-INTEGER_INDEX_BUILTIN = "int32" -INTEGER_BUILTINS = {"int32", "int64"} -FLOATING_POINT_BUILTINS = {"float32", "float64"} -TYPEBUILTINS = {*INTEGER_BUILTINS, *FLOATING_POINT_BUILTINS, "bool"} - -BUILTINS = { - "tuple_get", - "cast_", - "cartesian_domain", - "unstructured_domain", - "make_tuple", - "shift", - "neighbors", - "named_range", - "list_get", - "map_", - "make_const_list", - "lift", - "reduce", - "deref", - "can_deref", - "scan", - "if_", - *ARITHMETIC_BUILTINS, - *TYPEBUILTINS, -} - -# only used in `Program`` not `FencilDefinition` -# TODO(havogt): restructure after refactoring to GTIR -GTIR_BUILTINS = { - *BUILTINS, - "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) -} - - -class FencilDefinition(Node, ValidatedSymbolTableTrait): - id: Coerced[SymbolName] - function_definitions: List[FunctionDefinition] - params: List[Sym] - closures: List[StencilClosure] - implicit_domain: bool = False - - _NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in BUILTINS] - - class Stmt(Node): ... 
@@ -230,7 +123,9 @@ class Program(Node, ValidatedSymbolTableTrait): body: List[Stmt] implicit_domain: bool = False - _NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in GTIR_BUILTINS] + _NODE_SYMBOLS_: ClassVar[List[Sym]] = [ + Sym(id=name) for name in sorted(BUILTINS) + ] # sorted for serialization stability # TODO(fthaler): just use hashable types in nodes (tuples instead of lists) @@ -244,8 +139,6 @@ class Program(Node, ValidatedSymbolTableTrait): Lambda.__hash__ = Node.__hash__ # type: ignore[method-assign] FunCall.__hash__ = Node.__hash__ # type: ignore[method-assign] FunctionDefinition.__hash__ = Node.__hash__ # type: ignore[method-assign] -StencilClosure.__hash__ = Node.__hash__ # type: ignore[method-assign] -FencilDefinition.__hash__ = Node.__hash__ # type: ignore[method-assign] Program.__hash__ = Node.__hash__ # type: ignore[method-assign] SetAt.__hash__ = Node.__hash__ # type: ignore[method-assign] IfStmt.__hash__ = Node.__hash__ # type: ignore[method-assign] diff --git a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py index 4aea7ef149..c16b9f2b48 100644 --- a/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py +++ b/src/gt4py/next/iterator/ir_utils/common_pattern_matcher.py @@ -7,9 +7,10 @@ # SPDX-License-Identifier: BSD-3-Clause from collections.abc import Iterable -from typing import TypeGuard +from typing import Any, TypeGuard from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: @@ -22,6 +23,16 @@ def is_applied_lift(arg: itir.Node) -> TypeGuard[itir.FunCall]: ) +def is_applied_map(arg: itir.Node) -> TypeGuard[itir.FunCall]: + """Match expressions of the form `map(λ(...) 
→ ...)(...)`.""" + return ( + isinstance(arg, itir.FunCall) + and isinstance(arg.fun, itir.FunCall) + and isinstance(arg.fun.fun, itir.SymRef) + and arg.fun.fun.id == "map_" + ) + + def is_applied_reduce(arg: itir.Node) -> TypeGuard[itir.FunCall]: """Match expressions of the form `reduce(λ(...) → ...)(...)`.""" return ( @@ -52,12 +63,16 @@ def is_let(node: itir.Node) -> TypeGuard[itir.FunCall]: return isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda) -def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunCall]: +def is_call_to(node: Any, fun: str | Iterable[str]) -> TypeGuard[itir.FunCall]: """ Match call expression to a given function. + If the `node` argument is not an `itir.Node` the function does not error, but just returns + `False`. This is useful in visitors, where sometimes we pass a list of nodes or a leaf + attribute which can be anything. + >>> from gt4py.next.iterator.ir_utils import ir_makers as im - >>> node = im.call("plus")(1, 2) + >>> node = im.plus(1, 2) >>> is_call_to(node, "plus") True >>> is_call_to(node, "minus") @@ -65,6 +80,7 @@ def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunC >>> is_call_to(node, ("plus", "minus")) True """ + assert not isinstance(fun, itir.Node) # to avoid accidentally passing the fun as first argument if isinstance(fun, (list, tuple, set, Iterable)) and not isinstance(fun, str): return any((is_call_to(node, f) for f in fun)) return ( @@ -74,3 +90,28 @@ def is_call_to(node: itir.Node, fun: str | Iterable[str]) -> TypeGuard[itir.FunC def is_ref_to(node, ref: str): return isinstance(node, itir.SymRef) and node.id == ref + + +def is_identity_as_fieldop(node: itir.Expr): + """ + Match field operators implementing element-wise copy of a field argument, + that is expressions of the form `as_fieldop(stencil)(*args)` + + >>> from gt4py.next.iterator.ir_utils import ir_makers as im + >>> node = im.as_fieldop(im.lambda_("__arg0")(im.deref("__arg0")))("a") + 
>>> is_identity_as_fieldop(node) + True + >>> node = im.as_fieldop("deref")("a") + >>> is_identity_as_fieldop(node) + False + """ + if not is_applied_as_fieldop(node): + return False + stencil = node.fun.args[0] # type: ignore[attr-defined] + if ( + isinstance(stencil, itir.Lambda) + and len(stencil.params) == 1 + and stencil == im.lambda_(stencil.params[0])(im.deref(stencil.params[0].id)) + ): + return True + return False diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index 8eec405136..17df4f2ec5 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -10,33 +10,32 @@ import dataclasses import functools -from typing import Any, Literal, Mapping +from typing import Any, Literal, Mapping, Optional -import gt4py.next as gtx from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import trace_shifts +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> dict[str, int]: """ Extract horizontal domain sizes from an `offset_provider`. - Considers the shape of the neighbor table to get the size of each `origin_axis` and the maximum - value inside the neighbor table to get the size of each `neighbor_axis`. + Considers the shape of the neighbor table to get the size of each `source_dim` and the maximum + value inside the neighbor table to get the size of each `codomain`. 
""" sizes = dict[str, int]() for provider in offset_provider.values(): - if isinstance(provider, gtx.NeighborTableOffsetProvider): - assert provider.origin_axis.kind == gtx.DimensionKind.HORIZONTAL - assert provider.neighbor_axis.kind == gtx.DimensionKind.HORIZONTAL - sizes[provider.origin_axis.value] = max( - sizes.get(provider.origin_axis.value, 0), provider.table.shape[0] + if common.is_neighbor_connectivity(provider): + conn_type = provider.__gt_type__() + sizes[conn_type.source_dim.value] = max( + sizes.get(conn_type.source_dim.value, 0), provider.ndarray.shape[0] ) - sizes[provider.neighbor_axis.value] = max( - sizes.get(provider.neighbor_axis.value, 0), - provider.table.max() + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject + sizes[conn_type.codomain.value] = max( + sizes.get(conn_type.codomain.value, 0), + provider.ndarray.max() + 1, # type: ignore[attr-defined] # TODO(havogt): improve typing for NDArrayObject ) return sizes @@ -80,7 +79,7 @@ def from_expr(cls, node: itir.Node) -> SymbolicDomain: return cls(node.fun.id, ranges) # type: ignore[attr-defined] # ensure by assert above def as_expr(self) -> itir.FunCall: - converted_ranges: dict[common.Dimension | str, tuple[itir.Expr, itir.Expr]] = { + converted_ranges: dict[common.Dimension, tuple[itir.Expr, itir.Expr]] = { key: (value.start, value.stop) for key, value in self.ranges.items() } return im.domain(self.grid_type, converted_ranges) @@ -93,6 +92,9 @@ def translate( ..., ], offset_provider: common.OffsetProvider, + #: A dictionary mapping axes names to their length. See + #: func:`gt4py.next.iterator.transforms.infer_domain.infer_expr` for more details. 
+ symbolic_domain_sizes: Optional[dict[str, str]] = None, ) -> SymbolicDomain: dims = list(self.ranges.keys()) new_ranges = {dim: self.ranges[dim] for dim in dims} @@ -111,7 +113,7 @@ def translate( new_ranges[current_dim] = SymbolicRange.translate( self.ranges[current_dim], val.value ) - elif isinstance(nbt_provider, common.Connectivity): + elif common.is_neighbor_connectivity(nbt_provider): # unstructured shift assert ( isinstance(val, itir.OffsetLiteral) and isinstance(val.value, int) @@ -119,18 +121,24 @@ def translate( trace_shifts.Sentinel.ALL_NEIGHBORS, trace_shifts.Sentinel.VALUE, ] - # note: ugly but cheap re-computation, but should disappear - horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider) - - old_dim = nbt_provider.origin_axis - new_dim = nbt_provider.neighbor_axis + horizontal_sizes: dict[str, itir.Expr] + if symbolic_domain_sizes is not None: + horizontal_sizes = {k: im.ref(v) for k, v in symbolic_domain_sizes.items()} + else: + # note: ugly but cheap re-computation, but should disappear + horizontal_sizes = { + k: im.literal(str(v), builtins.INTEGER_INDEX_BUILTIN) + for k, v in _max_domain_sizes_by_location_type(offset_provider).items() + } + + old_dim = nbt_provider.__gt_type__().source_dim + new_dim = nbt_provider.__gt_type__().codomain assert new_dim not in new_ranges or old_dim == new_dim - # TODO(tehrengruber): Do we need symbolic sizes, e.g., for ICON? 
new_range = SymbolicRange( - im.literal("0", itir.INTEGER_INDEX_BUILTIN), - im.literal(str(horizontal_sizes[new_dim.value]), itir.INTEGER_INDEX_BUILTIN), + im.literal("0", builtins.INTEGER_INDEX_BUILTIN), + horizontal_sizes[new_dim.value], ) new_ranges = dict( (dim, range_) if dim != old_dim else (new_dim, new_range) @@ -140,7 +148,9 @@ def translate( raise AssertionError() return SymbolicDomain(self.grid_type, new_ranges) elif len(shift) > 2: - return self.translate(shift[0:2], offset_provider).translate(shift[2:], offset_provider) + return self.translate(shift[0:2], offset_provider, symbolic_domain_sizes).translate( + shift[2:], offset_provider, symbolic_domain_sizes + ) else: raise AssertionError("Number of shifts must be a multiple of 2.") @@ -152,13 +162,15 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: assert all(domain.ranges.keys() == domains[0].ranges.keys() for domain in domains) for dim in domains[0].ranges.keys(): start = functools.reduce( - lambda current_expr, el_expr: im.call("minimum")(current_expr, el_expr), + lambda current_expr, el_expr: im.minimum(current_expr, el_expr), [domain.ranges[dim].start for domain in domains], ) stop = functools.reduce( - lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), + lambda current_expr, el_expr: im.maximum(current_expr, el_expr), [domain.ranges[dim].stop for domain in domains], ) + # constant fold expression to keep the tree small + start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr new_domain_ranges[dim] = SymbolicRange(start, stop) return SymbolicDomain(domains[0].grid_type, new_domain_ranges) diff --git a/src/gt4py/next/iterator/ir_utils/ir_makers.py b/src/gt4py/next/iterator/ir_utils/ir_makers.py index b2662fa278..9d77ca4686 100644 --- a/src/gt4py/next/iterator/ir_utils/ir_makers.py +++ b/src/gt4py/next/iterator/ir_utils/ir_makers.py @@ -10,9 +10,8 @@ from typing import Callable, Optional, Union 
from gt4py._core import definitions as core_defs -from gt4py.eve.extended_typing import Dict, Tuple from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.type_system import type_specifications as ts, type_translation @@ -30,7 +29,7 @@ def sym(sym_or_name: Union[str, itir.Sym], type_: str | ts.TypeSpec | None = Non >>> a = sym("a", "float32") >>> a.id, a.type - (SymbolName('a'), ScalarType(kind=, shape=None)) + (SymbolName('a'), ScalarType(kind=, shape=None)) """ if isinstance(sym_or_name, itir.Sym): assert not type_ @@ -54,7 +53,7 @@ def ref( >>> a = ref("a", "float32") >>> a.id, a.type - (SymbolRef('a'), ScalarType(kind=, shape=None)) + (SymbolRef('a'), ScalarType(kind=, shape=None)) """ if isinstance(ref_or_name, itir.SymRef): assert not type_ @@ -72,7 +71,7 @@ def ensure_expr(literal_or_expr: Union[str, core_defs.Scalar, itir.Expr]) -> iti SymRef(id=SymbolRef('a')) >>> ensure_expr(3) - Literal(value='3', type=ScalarType(kind=, shape=None)) + Literal(value='3', type=ScalarType(kind=, shape=None)) >>> ensure_expr(itir.OffsetLiteral(value="i")) OffsetLiteral(value='i') @@ -135,7 +134,7 @@ class call: Examples -------- >>> call("plus")(1, 1) - FunCall(fun=SymRef(id=SymbolRef('plus')), args=[Literal(value='1', type=ScalarType(kind=, shape=None)), Literal(value='1', type=ScalarType(kind=, shape=None))]) + FunCall(fun=SymRef(id=SymbolRef('plus')), args=[Literal(value='1', type=ScalarType(kind=, shape=None)), Literal(value='1', type=ScalarType(kind=, shape=None))]) """ def __init__(self, expr): @@ -170,18 +169,6 @@ def divides_(left, right): return call("divides")(left, right) -def floordiv_(left, right): - """Create a floor division FunCall, shorthand for ``call("floordiv")(left, right)``.""" - # TODO(tehrengruber): Use int(floor(left/right)) as soon as we support integer casting - # and remove the `floordiv` builtin again. 
- return call("floordiv")(left, right) - - -def mod(left, right): - """Create a modulo FunCall, shorthand for ``call("mod")(left, right)``.""" - return call("mod")(left, right) - - def and_(left, right): """Create an and_ FunCall, shorthand for ``call("and_")(left, right)``.""" return call("and_")(left, right) @@ -239,7 +226,7 @@ def make_tuple(*args): def tuple_get(index: str | int, tuple_expr): """Create a tuple_get FunCall, shorthand for ``call("tuple_get")(index, tuple_expr)``.""" - return call("tuple_get")(literal(str(index), itir.INTEGER_INDEX_BUILTIN), tuple_expr) + return call("tuple_get")(literal(str(index), builtins.INTEGER_INDEX_BUILTIN), tuple_expr) def if_(cond, true_val, false_val): @@ -303,7 +290,10 @@ def shift(offset, value=None): offset = ensure_offset(offset) args = [offset] if value is not None: - value = ensure_offset(value) + if isinstance(value, int): + value = ensure_offset(value) + elif isinstance(value, str): + value = ref(value) args.append(value) return call(call("shift")(*args)) @@ -317,11 +307,11 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: Make a literal node from a value. 
>>> literal_from_value(1.0) - Literal(value='1.0', type=ScalarType(kind=, shape=None)) + Literal(value='1.0', type=ScalarType(kind=, shape=None)) >>> literal_from_value(1) - Literal(value='1', type=ScalarType(kind=, shape=None)) + Literal(value='1', type=ScalarType(kind=, shape=None)) >>> literal_from_value(2147483648) - Literal(value='2147483648', type=ScalarType(kind=, shape=None)) + Literal(value='2147483648', type=ScalarType(kind=, shape=None)) >>> literal_from_value(True) Literal(value='True', type=ScalarType(kind=, shape=None)) """ @@ -336,7 +326,7 @@ def literal_from_value(val: core_defs.Scalar) -> itir.Literal: assert isinstance(type_spec, ts.ScalarType) typename = type_spec.kind.name.lower() - assert typename in itir.TYPEBUILTINS + assert typename in builtins.TYPE_BUILTINS return literal(str(val), typename) @@ -412,23 +402,15 @@ def _impl(*its: itir.Expr) -> itir.FunCall: def domain( grid_type: Union[common.GridType, str], - ranges: Dict[Union[common.Dimension, str], Tuple[itir.Expr, itir.Expr]], + ranges: dict[common.Dimension, tuple[itir.Expr, itir.Expr]], ) -> itir.FunCall: """ - >>> str( - ... domain( - ... common.GridType.CARTESIAN, - ... { - ... common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL): (0, 10), - ... common.Dimension(value="JDim", kind=common.DimensionKind.HORIZONTAL): (0, 20), - ... }, - ... ) - ... 
) - 'c⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩' - >>> str(domain(common.GridType.CARTESIAN, {"IDim": (0, 10), "JDim": (0, 20)})) - 'c⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩' - >>> str(domain(common.GridType.UNSTRUCTURED, {"IDim": (0, 10), "JDim": (0, 20)})) - 'u⟨ IDimₕ: [0, 10), JDimₕ: [0, 20) ⟩' + >>> IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) + >>> JDim = common.Dimension(value="JDim", kind=common.DimensionKind.HORIZONTAL) + >>> str(domain(common.GridType.CARTESIAN, {IDim: (0, 10), JDim: (0, 20)})) + 'c⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩' + >>> str(domain(common.GridType.UNSTRUCTURED, {IDim: (0, 10), JDim: (0, 20)})) + 'u⟨ IDimₕ: [0, 10[, JDimₕ: [0, 20[ ⟩' """ if isinstance(grid_type, common.GridType): grid_type = f"{grid_type!s}_domain" @@ -446,7 +428,7 @@ def domain( ) -def as_fieldop(expr: itir.Expr, domain: Optional[itir.FunCall] = None) -> call: +def as_fieldop(expr: itir.Expr | str, domain: Optional[itir.Expr] = None) -> Callable: """ Create an `as_fieldop` call. @@ -455,7 +437,9 @@ def as_fieldop(expr: itir.Expr, domain: Optional[itir.FunCall] = None) -> call: >>> str(as_fieldop(lambda_("it1", "it2")(plus(deref("it1"), deref("it2"))))("field1", "field2")) '(⇑(λ(it1, it2) → ·it1 + ·it2))(field1, field2)' """ - return call( + from gt4py.next.iterator.ir_utils import domain_utils + + result = call( call("as_fieldop")( *( ( @@ -468,9 +452,17 @@ def as_fieldop(expr: itir.Expr, domain: Optional[itir.FunCall] = None) -> call: ) ) + def _populate_domain_annex_wrapper(*args, **kwargs): + node = result(*args, **kwargs) + if domain: + node.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) + return node + + return _populate_domain_annex_wrapper + def op_as_fieldop( - op: str | itir.SymRef | Callable, domain: Optional[itir.FunCall] = None + op: str | itir.SymRef | itir.Lambda | Callable, domain: Optional[itir.FunCall] = None ) -> Callable[..., itir.FunCall]: """ Promotes a function `op` to a field_operator. 
@@ -498,6 +490,79 @@ def _impl(*its: itir.Expr) -> itir.FunCall: return _impl +def cast_as_fieldop(type_: str, domain: Optional[itir.FunCall] = None): + """ + Promotes the function `cast_` to a field_operator. + + Args: + type_: the target type to be passed as argument to `cast_` function. + domain: the domain of the returned field. + + Returns: + A function from Fields to Field. + + Examples: + >>> str(cast_as_fieldop("float32")("a")) + '(⇑(λ(__arg0) → cast_(·__arg0, float32)))(a)' + """ + + def _impl(it: itir.Expr) -> itir.FunCall: + return op_as_fieldop(lambda v: call("cast_")(v, type_), domain)(it) + + return _impl + + +def index(dim: common.Dimension) -> itir.FunCall: + """ + Create a call to the `index` builtin, shorthand for `call("index")(axis)`, + after converting the given dimension to `itir.AxisLiteral`. + + Args: + dim: the dimension corresponding to the index axis. + + Returns: + A function that constructs a Field of indices in the given dimension. + """ + return call("index")(itir.AxisLiteral(value=dim.value, kind=dim.kind)) + + def map_(op): """Create a `map_` call.""" return call(call("map_")(op)) + + +def reduce(op, expr): + """Create a `reduce` call.""" + return call(call("reduce")(op, expr)) + + +def scan(expr, forward, init): + """Create a `scan` call.""" + return call("scan")(expr, forward, init) + + +def list_get(list_idx, list_): + """Create a `list_get` call.""" + return call("list_get")(list_idx, list_) + + +def maximum(expr1, expr2): + """Create a `maximum` call.""" + return call("maximum")(expr1, expr2) + + +def minimum(expr1, expr2): + """Create a `minimum` call.""" + return call("minimum")(expr1, expr2) + + +def cast_(expr, dtype: ts.ScalarType | str): + """Create a `cast_` call.""" + if isinstance(dtype, ts.ScalarType): + dtype = dtype.kind.name.lower() + return call("cast_")(expr, dtype) + + +def can_deref(expr): + """Create a `can_deref` call.""" + return call("can_deref")(expr) diff --git a/src/gt4py/next/iterator/pretty_parser.py 
b/src/gt4py/next/iterator/pretty_parser.py index 08459a9423..a077b39911 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -31,9 +31,9 @@ INT_LITERAL: SIGNED_INT FLOAT_LITERAL: SIGNED_FLOAT OFFSET_LITERAL: ( INT_LITERAL | CNAME ) "ₒ" - _literal: INT_LITERAL | FLOAT_LITERAL | OFFSET_LITERAL + AXIS_LITERAL: CNAME ("ᵥ" | "ₕ") + _literal: INT_LITERAL | FLOAT_LITERAL | OFFSET_LITERAL | AXIS_LITERAL ID_NAME: CNAME - AXIS_NAME: CNAME ("ᵥ" | "ₕ") ?prec0: prec1 | "λ(" ( SYM "," )* SYM? ")" "→" prec0 -> lam @@ -84,7 +84,7 @@ else_branch_seperator: "else" if_stmt: "if" "(" prec0 ")" "{" ( stmt )* "}" else_branch_seperator "{" ( stmt )* "}" - named_range: AXIS_NAME ":" "[" prec0 "," prec0 ")" + named_range: AXIS_LITERAL ":" "[" prec0 "," prec0 "[" function_definition: ID_NAME "=" "λ(" ( SYM "," )* SYM? ")" "→" prec0 ";" declaration: ID_NAME "=" "temporary(" "domain=" prec0 "," "dtype=" TYPE_LITERAL ")" ";" stencil_closure: prec0 "←" "(" prec0 ")" "(" ( SYM_REF ", " )* SYM_REF ")" "@" prec0 ";" @@ -128,7 +128,7 @@ def OFFSET_LITERAL(self, value: lark_lexer.Token) -> ir.OffsetLiteral: def ID_NAME(self, value: lark_lexer.Token) -> str: return value.value - def AXIS_NAME(self, value: lark_lexer.Token) -> ir.AxisLiteral: + def AXIS_LITERAL(self, value: lark_lexer.Token) -> ir.AxisLiteral: name = value.value[:-1] kind = ir.DimensionKind.HORIZONTAL if value.value[-1] == "ₕ" else ir.DimensionKind.VERTICAL return ir.AxisLiteral(value=name, kind=kind) @@ -216,10 +216,6 @@ def function_definition(self, *args: ir.Node) -> ir.FunctionDefinition: fid, *params, expr = args return ir.FunctionDefinition(id=fid, params=params, expr=expr) - def stencil_closure(self, *args: ir.Expr) -> ir.StencilClosure: - output, stencil, *inputs, domain = args - return ir.StencilClosure(domain=domain, stencil=stencil, output=output, inputs=inputs) - def if_stmt(self, cond: ir.Expr, *args): found_else_seperator = False true_branch = [] @@ -249,23 +245,6 @@ def 
set_at(self, *args: ir.Expr) -> ir.SetAt: target, domain, expr = args return ir.SetAt(expr=expr, domain=domain, target=target) - # TODO(havogt): remove after refactoring. - def fencil_definition(self, fid: str, *args: ir.Node) -> ir.FencilDefinition: - params = [] - function_definitions = [] - closures = [] - for arg in args: - if isinstance(arg, ir.Sym): - params.append(arg) - elif isinstance(arg, ir.FunctionDefinition): - function_definitions.append(arg) - else: - assert isinstance(arg, ir.StencilClosure) - closures.append(arg) - return ir.FencilDefinition( - id=fid, function_definitions=function_definitions, params=params, closures=closures - ) - def program(self, fid: str, *args: ir.Node) -> ir.Program: params = [] function_definitions = [] diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 99287f8a11..7acbf5d23d 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -190,7 +190,9 @@ def visit_FunCall(self, node: ir.FunCall, *, prec: int) -> list[str]: if fun_name == "named_range" and len(node.args) == 3: # named_range(dim, start, stop) → dim: [star, stop) dim, start, end = self.visit(node.args, prec=0) - res = self._hmerge(dim, [": ["], start, [", "], end, [")"]) + res = self._hmerge( + dim, [": ["], start, [", "], end, ["["] + ) # to get matching parenthesis of functions return self._prec_parens(res, prec, PRECEDENCE["__call__"]) if fun_name == "cartesian_domain" and len(node.args) >= 1: # cartesian_domain(x, y, ...) → c{ x × y × ... 
} # noqa: RUF003 [ambiguous-unicode-character-comment] @@ -248,28 +250,6 @@ def visit_FunctionDefinition(self, node: ir.FunctionDefinition, prec: int) -> li vbody = self._vmerge(params, self._indent(expr)) return self._optimum(hbody, vbody) - def visit_StencilClosure(self, node: ir.StencilClosure, *, prec: int) -> list[str]: - assert prec == 0 - domain = self.visit(node.domain, prec=0) - stencil = self.visit(node.stencil, prec=0) - output = self.visit(node.output, prec=0) - inputs = self.visit(node.inputs, prec=0) - - hinputs = self._hmerge(["("], *self._hinterleave(inputs, ", "), [")"]) - vinputs = self._vmerge(["("], *self._hinterleave(inputs, ",", indent=True), [")"]) - inputs = self._optimum(hinputs, vinputs) - - head = self._hmerge(output, [" ← "]) - foot = self._hmerge(inputs, [" @ "], domain, [";"]) - - h = self._hmerge(head, ["("], stencil, [")"], foot) - v = self._vmerge( - self._hmerge(head, ["("]), - self._indent(self._indent(stencil)), - self._indent(self._hmerge([")"], foot)), - ) - return self._optimum(h, v) - def visit_Temporary(self, node: ir.Temporary, *, prec: int) -> list[str]: start, end = [node.id + " = temporary("], [");"] args = [] @@ -312,25 +292,6 @@ def visit_IfStmt(self, node: ir.IfStmt, *, prec: int) -> list[str]: head, self._indent(true_branch), ["} else {"], self._indent(false_branch), ["}"] ) - def visit_FencilDefinition(self, node: ir.FencilDefinition, *, prec: int) -> list[str]: - assert prec == 0 - function_definitions = self.visit(node.function_definitions, prec=0) - closures = self.visit(node.closures, prec=0) - params = self.visit(node.params, prec=0) - - hparams = self._hmerge([node.id + "("], *self._hinterleave(params, ", "), [") {"]) - vparams = self._vmerge( - [node.id + "("], *self._hinterleave(params, ",", indent=True), [") {"] - ) - params = self._optimum(hparams, vparams) - - function_definitions = self._vmerge(*function_definitions) - closures = self._vmerge(*closures) - - return self._vmerge( - params, 
self._indent(function_definitions), self._indent(closures), ["}"] - ) - def visit_Program(self, node: ir.Program, *, prec: int) -> list[str]: assert prec == 0 function_definitions = self.visit(node.function_definitions, prec=0) diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index ad85d154cb..c9a5b15de7 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -12,7 +12,7 @@ import functools import types from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Union import devtools @@ -26,7 +26,7 @@ # TODO(tehrengruber): remove cirular dependency and import unconditionally from gt4py.next import backend as next_backend -__all__ = ["offset", "fundef", "fendef", "closure", "set_at", "if_stmt"] +__all__ = ["fendef", "fundef", "if_stmt", "offset", "set_at"] @dataclass(frozen=True) @@ -127,7 +127,9 @@ def fendef( ) -def _deduce_domain(domain: dict[common.Dimension, range], offset_provider: dict[str, Any]): +def _deduce_domain( + domain: dict[common.Dimension, range], offset_provider_type: common.OffsetProviderType +): if isinstance(domain, UnstructuredDomain): domain_builtin = builtins.unstructured_domain elif isinstance(domain, CartesianDomain): @@ -135,7 +137,7 @@ def _deduce_domain(domain: dict[common.Dimension, range], offset_provider: dict[ else: domain_builtin = ( builtins.unstructured_domain - if any(isinstance(o, common.Connectivity) for o in offset_provider.values()) + if any(isinstance(o, common.ConnectivityType) for o in offset_provider_type.values()) else builtins.cartesian_domain ) @@ -160,8 +162,8 @@ def impl(out, *inps): elif isinstance(dom, dict): # if passed as a dict, we need to convert back to builtins for interpretation by the backends assert offset_provider is not None - dom = _deduce_domain(dom, offset_provider) - closure(dom, self.fundef_dispatcher, out, [*inps]) + dom = 
_deduce_domain(dom, common.offset_provider_to_type(offset_provider)) + set_at(builtins.as_fieldop(self.fundef_dispatcher, dom)(*inps), dom, out) return impl @@ -206,11 +208,6 @@ def fundef(fun): return FundefDispatcher(fun) -@builtin_dispatch -def closure(*args): # TODO remove - return BackendNotSelectedError() - - @builtin_dispatch def set_at(*args): return BackendNotSelectedError() diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 6772d4b507..12c86680b5 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -23,7 +23,6 @@ Lambda, NoneLiteral, OffsetLiteral, - StencilClosure, Sym, SymRef, ) @@ -202,9 +201,6 @@ def __bool__(self): class TracerContext: fundefs: ClassVar[List[FunctionDefinition]] = [] - closures: ClassVar[ - List[StencilClosure] - ] = [] # TODO(havogt): remove after refactoring to `Program` is complete, currently handles both programs and fencils body: ClassVar[List[itir.Stmt]] = [] @classmethod @@ -212,10 +208,6 @@ def add_fundef(cls, fun): if fun not in cls.fundefs: cls.fundefs.append(fun) - @classmethod - def add_closure(cls, closure): - cls.closures.append(closure) - @classmethod def add_stmt(cls, stmt): cls.body.append(stmt) @@ -225,23 +217,10 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_traceback): type(self).fundefs = [] - type(self).closures = [] type(self).body = [] iterator.builtins.builtin_dispatch.pop_key() -@iterator.runtime.closure.register(TRACING) -def closure(domain, stencil, output, inputs): - if hasattr(stencil, "__name__") and stencil.__name__ in iterator.builtins.__all__: - stencil = _s(stencil.__name__) - else: - stencil(*(_s(param) for param in inspect.signature(stencil).parameters)) - stencil = make_node(stencil) - TracerContext.add_closure( - StencilClosure(domain=domain, stencil=stencil, output=output, inputs=inputs) - ) - - @iterator.runtime.set_at.register(TRACING) def set_at(expr: itir.Expr, domain: itir.Expr, target: 
itir.Expr) -> None: TracerContext.add_stmt(itir.SetAt(expr=expr, domain=domain, target=target)) @@ -279,7 +258,7 @@ def _contains_tuple_dtype_field(arg): return isinstance(arg, common.Field) and any(dim is None for dim in arg.domain.dims) -def _make_fencil_params(fun, args) -> list[Sym]: +def _make_program_params(fun, args) -> list[Sym]: params: list[Sym] = [] param_infos = list(inspect.signature(fun).parameters.values()) @@ -314,33 +293,22 @@ def _make_fencil_params(fun, args) -> list[Sym]: return params -def trace_fencil_definition( - fun: typing.Callable, args: typing.Iterable -) -> itir.FencilDefinition | itir.Program: +def trace_fencil_definition(fun: typing.Callable, args: typing.Iterable) -> itir.Program: """ - Transform fencil given as a callable into `itir.FencilDefinition` using tracing. + Transform fencil given as a callable into `itir.Program` using tracing. Arguments: - fun: The fencil / callable to trace. + fun: The program / callable to trace. args: A list of arguments, e.g. fields, scalars, composites thereof, or directly a type. 
""" with TracerContext() as _: - params = _make_fencil_params(fun, args) + params = _make_program_params(fun, args) trace_function_call(fun, args=(_s(param.id) for param in params)) - if TracerContext.closures: - return itir.FencilDefinition( - id=fun.__name__, - function_definitions=TracerContext.fundefs, - params=params, - closures=TracerContext.closures, - ) - else: - assert TracerContext.body - return itir.Program( - id=fun.__name__, - function_definitions=TracerContext.fundefs, - params=params, - declarations=[], # TODO - body=TracerContext.body, - ) + return itir.Program( + id=fun.__name__, + function_definitions=TracerContext.fundefs, + params=params, + declarations=[], # TODO + body=TracerContext.body, + ) diff --git a/src/gt4py/next/iterator/transforms/__init__.py b/src/gt4py/next/iterator/transforms/__init__.py index 58678cfc9c..1d91254ee8 100644 --- a/src/gt4py/next/iterator/transforms/__init__.py +++ b/src/gt4py/next/iterator/transforms/__init__.py @@ -6,7 +6,11 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from gt4py.next.iterator.transforms.pass_manager import LiftMode, apply_common_transforms +from gt4py.next.iterator.transforms.pass_manager import ( + GTIRTransform, + apply_common_transforms, + apply_fieldview_transforms, +) -__all__ = ["apply_common_transforms", "LiftMode"] +__all__ = ["GTIRTransform", "apply_common_transforms", "apply_fieldview_transforms"] diff --git a/src/gt4py/next/iterator/transforms/collapse_list_get.py b/src/gt4py/next/iterator/transforms/collapse_list_get.py index f8a3c08e8f..b0a0c1e1dc 100644 --- a/src/gt4py/next/iterator/transforms/collapse_list_get.py +++ b/src/gt4py/next/iterator/transforms/collapse_list_get.py @@ -7,7 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py import eve -from gt4py.next.iterator import ir +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im class CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator): @@ -18,32 +19,29 @@ class CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator): - `list_get(i, make_const_list(e))` -> `e` """ - def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: + def visit_FunCall(self, node: itir.FunCall, **kwargs) -> itir.Node: node = self.generic_visit(node) - if node.fun == ir.SymRef(id="list_get"): - if isinstance(node.args[1], ir.FunCall): - if node.args[1].fun == ir.SymRef(id="neighbors"): - offset_tag = node.args[1].args[0] - offset_index = ( - ir.OffsetLiteral(value=int(node.args[0].value)) - if isinstance(node.args[0], ir.Literal) - else node.args[ - 0 - ] # else-branch: e.g. 
SymRef from unroll_reduce, TODO(havogt): remove when we replace unroll_reduce by list support in gtfn - ) - it = node.args[1].args[1] - return ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="shift"), args=[offset_tag, offset_index] - ), - args=[it], - ) - ], - ) - if node.args[1].fun == ir.SymRef(id="make_const_list"): - return node.args[1].args[0] + if cpm.is_call_to(node, "list_get"): + if cpm.is_call_to(node.args[1], "if_"): + list_idx = node.args[0] + cond, true_val, false_val = node.args[1].args + return im.if_( + cond, + self.visit(im.list_get(list_idx, true_val)), + self.visit(im.list_get(list_idx, false_val)), + ) + if cpm.is_call_to(node.args[1], "neighbors"): + offset_tag = node.args[1].args[0] + offset_index = ( + itir.OffsetLiteral(value=int(node.args[0].value)) + if isinstance(node.args[0], itir.Literal) + else node.args[ + 0 + ] # else-branch: e.g. SymRef from unroll_reduce, TODO(havogt): remove when we replace unroll_reduce by list support in gtfn + ) + it = node.args[1].args[1] + return im.deref(im.shift(offset_tag, offset_index)(it)) + if cpm.is_call_to(node.args[1], "make_const_list"): + return node.args[1].args[0] return node diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 40d98208dd..462f87b600 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -16,21 +16,24 @@ from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, ir_makers as im, misc as ir_misc, ) +from gt4py.next.iterator.transforms import fixed_point_transformation from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda from gt4py.next.iterator.type_system import inference as itir_type_inference from 
gt4py.next.type_system import type_info, type_specifications as ts -def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr): +def _with_altered_arg(node: ir.FunCall, arg_idx: int, new_arg: ir.Expr | str): """Given a itir.FunCall return a new call with one of its argument replaced.""" return ir.FunCall( - fun=node.fun, args=[arg if i != arg_idx else new_arg for i, arg in enumerate(node.args)] + fun=node.fun, + args=[arg if i != arg_idx else im.ensure_expr(new_arg) for i, arg in enumerate(node.args)], ) @@ -46,13 +49,48 @@ def _is_trivial_make_tuple_call(node: ir.Expr): return True +def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool: + """ + Return `true` if the expr is a trivial expression (`SymRef` or `Literal`) or tuple thereof. + + Let forms with trivial body and args as well as `if` calls with trivial branches are also + considered trivial. + + >>> _is_trivial_or_tuple_thereof_expr(im.make_tuple("a", "b")) + True + >>> _is_trivial_or_tuple_thereof_expr(im.tuple_get(1, "a")) + True + >>> _is_trivial_or_tuple_thereof_expr( + ... im.let("t", im.make_tuple("a", "b"))(im.tuple_get(1, "t")) + ... ) + True + """ + if isinstance(node, (ir.SymRef, ir.Literal)): + return True + if cpm.is_call_to(node, "make_tuple"): + return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args) + if cpm.is_call_to(node, "tuple_get"): + return _is_trivial_or_tuple_thereof_expr(node.args[1]) + # This will duplicate the condition and increase the size of the tree, but this is probably + # acceptable. 
+ if cpm.is_call_to(node, "if_"): + return all(_is_trivial_or_tuple_thereof_expr(arg) for arg in node.args[1:]) + if cpm.is_let(node): + return _is_trivial_or_tuple_thereof_expr(node.fun.expr) and all( # type: ignore[attr-defined] # ensured by is_let + _is_trivial_or_tuple_thereof_expr(arg) for arg in node.args + ) + return False + + # TODO(tehrengruber): Conceptually the structure of this pass makes sense: Visit depth first, # transform each node until no transformations apply anymore, whenever a node is to be transformed # go through all available transformation and apply them. However the final result here still # reads a little convoluted and is also different to how we write other transformations. We # should revisit the pattern here and try to find a more general mechanism. -@dataclasses.dataclass(frozen=True) -class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator): +@dataclasses.dataclass(frozen=True, kw_only=True) +class CollapseTuple( + fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor +): """ Simplifies `make_tuple`, `tuple_get` calls. @@ -63,7 +101,7 @@ class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator): # TODO(tehrengruber): This Flag mechanism is a little low level. What we actually want # is something like a pass manager, where for each pattern we have a corresponding # transformation, etc. - class Flag(enum.Flag): + class Transformation(enum.Flag): #: `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t` COLLAPSE_MAKE_TUPLE_TUPLE_GET = enum.auto() #: `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` @@ -75,27 +113,41 @@ class Flag(enum.Flag): #: `let(tup, {trivial_expr1, trivial_expr2})(foo(tup))` #: -> `foo({trivial_expr1, trivial_expr2})` INLINE_TRIVIAL_MAKE_TUPLE = enum.auto() + #: Similar as `PROPAGATE_TO_IF_ON_TUPLES`, but propagates in the opposite direction, i.e. 
+ #: into the tree, allowing removal of tuple expressions across `if_` calls without + #: increasing the size of the tree. This is particularly important for `if` statements + #: in the frontend, where outwards propagation can have devastating effects on the tree + #: size, without any gained optimization potential. For example + #: ``` + #: complex_lambda(if cond1 + #: if cond2 + #: {...} + #: else: + #: {...} + #: else + #: {...}) + #: ``` + #: is problematic, since `PROPAGATE_TO_IF_ON_TUPLES` would propagate and hence duplicate + #: `complex_lambda` three times, while we only want to get rid of the tuple expressions + #: inside of the `if_`s. + #: Note that this transformation is not mutually exclusive to `PROPAGATE_TO_IF_ON_TUPLES`. + PROPAGATE_TO_IF_ON_TUPLES_CPS = enum.auto() #: `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` PROPAGATE_TO_IF_ON_TUPLES = enum.auto() #: `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` PROPAGATE_NESTED_LET = enum.auto() - #: `let(a, 1)(a)` -> `1` + #: `let(a, 1)(a)` -> `1` or `let(a, b)(f(a))` -> `f(a)` INLINE_TRIVIAL_LET = enum.auto() @classmethod - def all(self) -> CollapseTuple.Flag: + def all(self) -> CollapseTuple.Transformation: return functools.reduce(operator.or_, self.__members__.values()) + uids: eve_utils.UIDGenerator ignore_tuple_size: bool - flags: Flag = Flag.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] + enabled_transformations: Transformation = Transformation.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] - PRESERVED_ANNEX_ATTRS = ("type",) - - # we use one UID generator per instance such that the generated ids are - # stable across multiple runs (required for caching to properly work) - _letify_make_tuple_uids: eve_utils.UIDGenerator = dataclasses.field( - init=False, repr=False, default_factory=lambda: eve_utils.UIDGenerator(prefix="_tuple_el") - ) + PRESERVED_ANNEX_ATTRS = ("type", "domain") @classmethod def apply( @@ 
-104,11 +156,13 @@ def apply( *, ignore_tuple_size: bool = False, remove_letified_make_tuple_elements: bool = True, - offset_provider=None, - # manually passing flags is mostly for allowing separate testing of the modes - flags=None, + offset_provider_type: Optional[common.OffsetProviderType] = None, + within_stencil: Optional[bool] = None, + # manually passing enabled transformations is mostly for allowing separate testing of the modes + enabled_transformations: Optional[Transformation] = None, # allow sym references without a symbol declaration, mostly for testing allow_undeclared_symbols: bool = False, + uids: Optional[eve_utils.UIDGenerator] = None, ) -> ir.Node: """ Simplifies `make_tuple`, `tuple_get` calls. @@ -123,20 +177,29 @@ def apply( to remove left-overs from `LETIFY_MAKE_TUPLE_ELEMENTS` transformation. `(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` -> {1, 2}` """ - flags = flags or cls.flags - offset_provider = offset_provider or {} + enabled_transformations = enabled_transformations or cls.enabled_transformations + offset_provider_type = offset_provider_type or {} + uids = uids or eve_utils.UIDGenerator() + + if isinstance(node, ir.Program): + within_stencil = False + assert within_stencil in [ + True, + False, + ], "Parameter 'within_stencil' mandatory if node is not a 'Program'." if not ignore_tuple_size: node = itir_type_inference.infer( node, - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, allow_undeclared_symbols=allow_undeclared_symbols, ) new_node = cls( ignore_tuple_size=ignore_tuple_size, - flags=flags, - ).visit(node) + enabled_transformations=enabled_transformations, + uids=uids, + ).visit(node, within_stencil=within_stencil) # inline to remove left-overs from LETIFY_MAKE_TUPLE_ELEMENTS. 
this is important # as otherwise two equal expressions containing a tuple will not be equal anymore @@ -150,36 +213,17 @@ def apply( return new_node - def visit_FunCall(self, node: ir.FunCall) -> ir.Node: - node = self.generic_visit(node) - return self.fp_transform(node) - - def fp_transform(self, node: ir.Node) -> ir.Node: - while True: - new_node = self.transform(node) - if new_node is None: - break - assert new_node != node - node = new_node - return node - - def transform(self, node: ir.Node) -> Optional[ir.Node]: - if not isinstance(node, ir.FunCall): - return None + def visit(self, node, **kwargs): + if cpm.is_call_to(node, "as_fieldop"): + kwargs = {**kwargs, "within_stencil": True} - for transformation in self.Flag: - if self.flags & transformation: - assert isinstance(transformation.name, str) - method = getattr(self, f"transform_{transformation.name.lower()}") - result = method(node) - if result is not None: - return result - return None + return super().visit(node, **kwargs) - def transform_collapse_make_tuple_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: - if node.fun == ir.SymRef(id="make_tuple") and all( - isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get") - for arg in node.args + def transform_collapse_make_tuple_tuple_get( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: + if cpm.is_call_to(node, "make_tuple") and all( + cpm.is_call_to(arg, "tuple_get") for arg in node.args ): # `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t` assert isinstance(node.args[0], ir.FunCall) @@ -192,20 +236,27 @@ def transform_collapse_make_tuple_tuple_get(self, node: ir.FunCall) -> Optional[ # tuple argument differs, just continue with the rest of the tree return None - assert self.ignore_tuple_size or isinstance(first_expr.type, ts.TupleType) - if self.ignore_tuple_size or len(first_expr.type.types) == len(node.args): # type: ignore[union-attr] # ensured by assert above + 
itir_type_inference.reinfer(first_expr) # type is needed so reinfer on-demand + assert self.ignore_tuple_size or isinstance( + first_expr.type, (ts.TupleType, ts.DeferredType) + ) + if self.ignore_tuple_size or ( + isinstance(first_expr.type, ts.TupleType) + and len(first_expr.type.types) == len(node.args) + ): return first_expr return None - def transform_collapse_tuple_get_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_collapse_tuple_get_make_tuple( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: if ( - node.fun == ir.SymRef(id="tuple_get") - and isinstance(node.args[1], ir.FunCall) - and node.args[1].fun == ir.SymRef(id="make_tuple") + cpm.is_call_to(node, "tuple_get") and isinstance(node.args[0], ir.Literal) + and cpm.is_call_to(node.args[1], "make_tuple") ): # `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i` - assert type_info.is_integer(node.args[0].type) + assert not node.args[0].type or type_info.is_integer(node.args[0].type) make_tuple_call = node.args[1] idx = int(node.args[0].value) assert idx < len( @@ -214,8 +265,8 @@ def transform_collapse_tuple_get_make_tuple(self, node: ir.FunCall) -> Optional[ return node.args[1].args[idx] return None - def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: - if node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[0], ir.Literal): + def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + if cpm.is_call_to(node, "tuple_get") and isinstance(node.args[0], ir.Literal): # TODO(tehrengruber): extend to general symbols as long as the tail call in the let # does not capture # `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))` @@ -223,7 +274,7 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: idx, let_expr = node.args return im.call( im.lambda_(*let_expr.fun.params)( # type: ignore[attr-defined] # ensured by is_let - 
self.fp_transform(im.tuple_get(idx.value, let_expr.fun.expr)) # type: ignore[attr-defined] # ensured by is_let + self.fp_transform(im.tuple_get(idx.value, let_expr.fun.expr), **kwargs) # type: ignore[attr-defined] # ensured by is_let ) )( *let_expr.args # type: ignore[attr-defined] # ensured by is_let @@ -233,41 +284,47 @@ def transform_propagate_tuple_get(self, node: ir.FunCall) -> Optional[ir.Node]: cond, true_branch, false_branch = node.args[1].args return im.if_( cond, - self.fp_transform(im.tuple_get(idx.value, true_branch)), - self.fp_transform(im.tuple_get(idx.value, false_branch)), + self.fp_transform(im.tuple_get(idx.value, true_branch), **kwargs), + self.fp_transform(im.tuple_get(idx.value, false_branch), **kwargs), ) return None - def transform_letify_make_tuple_elements(self, node: ir.FunCall) -> Optional[ir.Node]: - if node.fun == ir.SymRef(id="make_tuple"): + def transform_letify_make_tuple_elements(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: + if cpm.is_call_to(node, "make_tuple"): # `make_tuple(expr1, expr1)` # -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))` - bound_vars: dict[str, ir.Expr] = {} + bound_vars: dict[ir.Sym, ir.Expr] = {} new_args: list[ir.Expr] = [] for arg in node.args: if cpm.is_call_to(node, "make_tuple") and not _is_trivial_make_tuple_call(node): - el_name = self._letify_make_tuple_uids.sequential_id() - new_args.append(im.ref(el_name)) - bound_vars[el_name] = arg + el_name = self.uids.sequential_id(prefix="__ct_el") + new_args.append(im.ref(el_name, arg.type)) + bound_vars[im.sym(el_name, arg.type)] = arg else: new_args.append(arg) if bound_vars: - return self.fp_transform(im.let(*bound_vars.items())(im.call(node.fun)(*new_args))) + return self.fp_transform( + im.let(*bound_vars.items())(im.call(node.fun)(*new_args)), **kwargs + ) return None - def transform_inline_trivial_make_tuple(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_inline_trivial_make_tuple(self, 
node: ir.Node, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node): # `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))` # -> `foo(make_tuple(trivial_expr1, trivial_expr2))` eligible_params = [_is_trivial_make_tuple_call(arg) for arg in node.args] if any(eligible_params): - return self.visit(inline_lambda(node, eligible_params=eligible_params)) + return self.visit(inline_lambda(node, eligible_params=eligible_params), **kwargs) return None - def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.Node]: - if not cpm.is_call_to(node, "if_"): - # TODO(tehrengruber): This significantly increases the size of the tree. Revisit. + def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + if kwargs["within_stencil"]: + # TODO(tehrengruber): This significantly increases the size of the tree. Skip transformation + # in local-view for now. Revisit. + return None + + if isinstance(node, ir.FunCall) and not cpm.is_call_to(node, "if_"): # TODO(tehrengruber): Only inline if type of branch value is a tuple. 
# Examples: # `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]` @@ -276,12 +333,112 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall) -> Optional[ir.N for i, arg in enumerate(node.args): if cpm.is_call_to(arg, "if_"): cond, true_branch, false_branch = arg.args - new_true_branch = self.fp_transform(_with_altered_arg(node, i, true_branch)) - new_false_branch = self.fp_transform(_with_altered_arg(node, i, false_branch)) + new_true_branch = self.fp_transform( + _with_altered_arg(node, i, true_branch), **kwargs + ) + new_false_branch = self.fp_transform( + _with_altered_arg(node, i, false_branch), **kwargs + ) return im.if_(cond, new_true_branch, new_false_branch) return None - def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: + def transform_propagate_to_if_on_tuples_cps( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: + # The basic idea of this transformation is to remove tuples across if-stmts by rewriting + # the expression in continuation passing style, e.g. something like a tuple reordering + # ``` + # let t = if True then {1, 2} else {3, 4} in + # {t[1], t[0]}) + # end + # ``` + # is rewritten into: + # ``` + # let cont = λ(el0, el1) → {el1, el0} in + # if True then cont(1, 2) else cont(3, 4) + # end + # ``` + # Note how the `make_tuple` call argument of the `if` disappears. Since lambda functions + # are currently inlined (due to limitations of the domain inference) we will only + # gain something compared `PROPAGATE_TO_IF_ON_TUPLES` if the continuation `cont` is trivial, + # e.g. a `make_tuple` call like in the example. In that case we can inline the trivial + # continuation and end up with an only moderately larger tree, e.g. + # `if True then {2, 1} else {4, 3}`. The examples in the comments below all refer to this + # tuple reordering example here. 
+ + if not isinstance(node, ir.FunCall) or cpm.is_call_to(node, "if_"): + return None + + # The first argument that is eligible also transforms all remaining args (They will be + # part of the continuation which is recursively transformed). + for i, arg in enumerate(node.args): + if cpm.is_call_to(arg, "if_"): + itir_type_inference.reinfer(arg) + + cond, true_branch, false_branch = arg.args # e.g. `True`, `{1, 2}`, `{3, 4}` + if not any( + isinstance(branch.type, ts.TupleType) for branch in [true_branch, false_branch] + ): + continue + tuple_type: ts.TupleType = true_branch.type # type: ignore[assignment] # type ensured above + tuple_len = len(tuple_type.types) + + # build and simplify continuation, e.g. λ(el0, el1) → {el1, el0} + itir_type_inference.reinfer(node) + assert node.type + f_type = ts.FunctionType( # type of continuation in order to keep full type info + pos_only_args=tuple_type.types, + pos_or_kw_args={}, + kw_only_args={}, + returns=node.type, + ) + f_params = [ + im.sym(self.uids.sequential_id(prefix="__ct_el_cps"), type_) + for type_ in tuple_type.types + ] + f_args = [im.ref(param.id, param.type) for param in f_params] + f_body = _with_altered_arg(node, i, im.make_tuple(*f_args)) + # simplify, e.g., inline trivial make_tuple args + new_f_body = self.fp_transform(f_body, **kwargs) + # if the continuation did not simplify there is nothing to gain. Skip + # transformation of this argument. + if new_f_body is f_body: + continue + # if the function is not trivial the transformation we would create a larger tree + # after inlining so we skip transformation this argument. + if not _is_trivial_or_tuple_thereof_expr(new_f_body): + continue + f = im.lambda_(*f_params)(new_f_body) + + # this is the symbol refering to the tuple value inside the two branches of the + # if, e.g. a symbol refering to `{1, 2}` and `{3, 4}` respectively + tuple_var = self.uids.sequential_id(prefix="__ct_tuple_cps") + # this is the symbol refering to our continuation, e.g. 
`cont` in our example. + f_var = self.uids.sequential_id(prefix="__ct_cont") + new_branches = [] + for branch in [true_branch, false_branch]: + new_branch = im.let(tuple_var, branch)( + im.call(im.ref(f_var, f_type))( # call to the continuation + *( + im.tuple_get(i, im.ref(tuple_var, branch.type)) + for i in range(tuple_len) + ) + ) + ) + new_branches.append(self.fp_transform(new_branch, **kwargs)) + + # assemble everything together + new_node = im.let(f_var, f)(im.if_(cond, *new_branches)) + new_node = inline_lambda(new_node, eligible_params=[True]) + assert cpm.is_call_to(new_node, "if_") + new_node = im.if_( + cond, *(self.fp_transform(branch, **kwargs) for branch in new_node.args[1:]) + ) + return new_node + + return None + + def transform_propagate_nested_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if cpm.is_let(node): # `let((a, let(b, 1)(a_val)))(a)`-> `let(b, 1)(let(a, a_val)(a))` outer_vars = {} @@ -299,15 +456,22 @@ def transform_propagate_nested_let(self, node: ir.FunCall) -> Optional[ir.Node]: if outer_vars: return self.fp_transform( im.let(*outer_vars.items())( - self.fp_transform(im.let(*inner_vars.items())(original_inner_expr)) - ) + self.fp_transform( + im.let(*inner_vars.items())(original_inner_expr), **kwargs + ) + ), + **kwargs, ) return None - def transform_inline_trivial_let(self, node: ir.FunCall) -> Optional[ir.Node]: - if cpm.is_let(node) and isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let - # `let(a, 1)(a)` -> `1` - for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let - if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let - return arg + def transform_inline_trivial_let(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + if cpm.is_let(node): + if isinstance(node.fun.expr, ir.SymRef): # type: ignore[attr-defined] # ensured by is_let + # `let(a, 1)(a)` -> `1` + 
for arg_sym, arg in zip(node.fun.params, node.args): # type: ignore[attr-defined] # ensured by is_let + if isinstance(node.fun.expr, ir.SymRef) and node.fun.expr.id == arg_sym.id: # type: ignore[attr-defined] # ensured by is_let + return arg + if any(trivial_args := [isinstance(arg, (ir.SymRef, ir.Literal)) for arg in node.args]): + return inline_lambda(node, eligible_params=trivial_args) + return None diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index 2084ab2518..fdbfec99ca 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -6,52 +6,226 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import embedded, ir -from gt4py.next.iterator.ir_utils import ir_makers as im +from __future__ import annotations +import dataclasses +import enum +import functools +import operator +from typing import Optional + +from gt4py import eve +from gt4py.next.iterator import builtins, embedded, ir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.transforms import fixed_point_transformation + + +def _value_from_literal(literal: ir.Literal): + return getattr(embedded, str(literal.type))(literal.value) + + +class UndoCanonicalizeMinus(eve.NodeTranslator): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + + def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: + node = super().generic_visit(node, **kwargs) + # `a + (-b)` -> `a - b` , `-a + b` -> `b - a`, `-a + (-b)` -> `-a - b` + if cpm.is_call_to(node, "plus"): + a, b = node.args + if cpm.is_call_to(b, "neg"): + return im.minus(a, b.args[0]) + if isinstance(b, ir.Literal) and _value_from_literal(b) < 0: + return im.minus(a, -_value_from_literal(b)) + if 
cpm.is_call_to(a, "neg"): + return im.minus(b, a.args[0]) + if isinstance(a, ir.Literal) and _value_from_literal(a) < 0: + return im.minus(b, -_value_from_literal(a)) + return node + + +_COMMUTATIVE_OPS = ("plus", "multiplies", "minimum", "maximum") + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class ConstantFolding( + fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor +): + PRESERVED_ANNEX_ATTRS = ( + "type", + "domain", + ) + + class Transformation(enum.Flag): + # `1 + a` -> `a + 1`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP + # `1 + f(...)` -> `f(...) + 1`, prerequisite for FOLD_FUNCALL_LITERAL, FOLD_NEUTRAL_OP + # `f(...) + (expr1 + expr2)` -> `(expr1 + expr2) + f(...)`, for `s[0] + (s[0] + 1)`, prerequisite for FOLD_MIN_MAX_PLUS + CANONICALIZE_OP_FUNCALL_SYMREF_LITERAL = enum.auto() + + # `a - b` -> `a + (-b)`, prerequisite for FOLD_MIN_MAX_PLUS + CANONICALIZE_MINUS = enum.auto() + + # `maximum(a, maximum(...))` -> `maximum(maximum(...), a)`, prerequisite for FOLD_MIN_MAX + CANONICALIZE_MIN_MAX = enum.auto() + + # `(a + 1) + 1` -> `a + (1 + 1)` + FOLD_FUNCALL_LITERAL = enum.auto() + + # `maximum(maximum(a, 1), a)` -> `maximum(a, 1)` + # `maximum(maximum(a, 1), 1)` -> `maximum(a, 1)` + FOLD_MIN_MAX = enum.auto() + + # `maximum(a + 1), a)` -> `a + 1` + # `maximum(a + 1, a + (-1))` -> `a + maximum(1, -1)` + FOLD_MIN_MAX_PLUS = enum.auto() + + # `a + 0` -> `a`, `a * 1` -> `a` + FOLD_NEUTRAL_OP = enum.auto() + + # `1 + 1` -> `2` + FOLD_ARITHMETIC_BUILTINS = enum.auto() + + # `minimum(a, a)` -> `a` + FOLD_MIN_MAX_LITERALS = enum.auto() + + # `if_(True, true_branch, false_branch)` -> `true_branch` + FOLD_IF = enum.auto() + + @classmethod + def all(self) -> ConstantFolding.Transformation: + return functools.reduce(operator.or_, self.__members__.values()) + + enabled_transformations: Transformation = Transformation.all() # noqa: RUF009 [function-call-in-dataclass-default-argument] -class 
ConstantFolding(PreserveLocationVisitor, NodeTranslator): @classmethod def apply(cls, node: ir.Node) -> ir.Node: - return cls().visit(node) + node = cls().visit(node) + return UndoCanonicalizeMinus().visit(node) + + def transform_canonicalize_op_funcall_symref_literal( + self, node: ir.FunCall, **kwargs + ) -> Optional[ir.Node]: + # `op(literal, symref|funcall)` -> `op(symref|funcall, literal)` + # `op1(funcall, op2(...))` -> `op1(op2(...), funcall)` for `s[0] + (s[0] + 1)` + if cpm.is_call_to(node, _COMMUTATIVE_OPS): + a, b = node.args + if (isinstance(a, ir.Literal) and not isinstance(b, ir.Literal)) or ( + not cpm.is_call_to(a, _COMMUTATIVE_OPS) and cpm.is_call_to(b, _COMMUTATIVE_OPS) + ): + return im.call(node.fun)(b, a) + return None + + def transform_canonicalize_minus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `a - b` -> `a + (-b)` + if cpm.is_call_to(node, "minus"): + return im.plus(node.args[0], self.fp_transform(im.call("neg")(node.args[1]))) + return None - def visit_FunCall(self, node: ir.FunCall): - # visit depth-first such that nested constant expressions (e.g. 
`(1+2)+3`) are properly folded - new_node = self.generic_visit(node) + def transform_canonicalize_min_max(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `maximum(a, maximum(...))` -> `maximum(maximum(...), a)` + if cpm.is_call_to(node, ("maximum", "minimum")): + op = node.fun.id # type: ignore[attr-defined] # assured by if above + if cpm.is_call_to(node.args[1], op) and not cpm.is_call_to(node.args[0], op): + return im.call(op)(node.args[1], node.args[0]) + return None + def transform_fold_funcall_literal(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `(a + 1) + 1` -> `a + (1 + 1)` + if cpm.is_call_to(node, "plus"): + if cpm.is_call_to(node.args[0], "plus") and isinstance(node.args[1], ir.Literal): + (expr, lit1), lit2 = node.args[0].args, node.args[1] + if isinstance(expr, (ir.SymRef, ir.FunCall)) and isinstance(lit1, ir.Literal): + return im.plus( + expr, + self.fp_transform(im.plus(lit1, lit2)), + ) + return None + + def transform_fold_min_max(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `maximum(maximum(a, 1), a)` -> `maximum(a, 1)` + # `maximum(maximum(a, 1), 1)` -> `maximum(a, 1)` + if cpm.is_call_to(node, ("minimum", "maximum")): + op = node.fun.id # type: ignore[attr-defined] # assured by if above + if cpm.is_call_to(node.args[0], op): + fun_call, arg1 = node.args + if arg1 in fun_call.args: # type: ignore[attr-defined] # assured by if above + return fun_call + return None + + def transform_fold_min_max_plus(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: if ( - isinstance(new_node.fun, ir.SymRef) - and new_node.fun.id in ["minimum", "maximum"] - and new_node.args[0] == new_node.args[1] - ): # `minimum(a, a)` -> `a` - return new_node.args[0] + isinstance(node, ir.FunCall) + and isinstance(node.fun, ir.SymRef) + and cpm.is_call_to(node, ("minimum", "maximum")) + ): + arg0, arg1 = node.args + # `maximum(a + 1, a)` -> `a + 1` + if cpm.is_call_to(arg0, "plus"): + if arg0.args[0] == arg1: + return im.plus( + 
arg0.args[0], self.fp_transform(im.call(node.fun.id)(arg0.args[1], 0)) + ) + # `maximum(a + 1, a + (-1))` -> `a + maximum(1, -1)` + if cpm.is_call_to(arg0, "plus") and cpm.is_call_to(arg1, "plus"): + if arg0.args[0] == arg1.args[0]: + return im.plus( + arg0.args[0], + self.fp_transform(im.call(node.fun.id)(arg0.args[1], arg1.args[1])), + ) + return None + + def transform_fold_neutral_op(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `a + 0` -> `a`, `a * 1` -> `a` if ( - isinstance(new_node.fun, ir.SymRef) - and new_node.fun.id == "if_" - and isinstance(new_node.args[0], ir.Literal) - ): # `if_(True, true_branch, false_branch)` -> `true_branch` - if new_node.args[0].value == "True": - new_node = new_node.args[1] - else: - new_node = new_node.args[2] + cpm.is_call_to(node, "plus") + and isinstance(node.args[1], ir.Literal) + and node.args[1].value.isdigit() + and int(node.args[1].value) == 0 + ) or ( + cpm.is_call_to(node, "multiplies") + and isinstance(node.args[1], ir.Literal) + and node.args[1].value.isdigit() + and int(node.args[1].value) == 1 + ): + return node.args[0] + return None + @classmethod + def transform_fold_arithmetic_builtins(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `1 + 1` -> `2` if ( - isinstance(new_node, ir.FunCall) - and isinstance(new_node.fun, ir.SymRef) - and len(new_node.args) > 0 - and all(isinstance(arg, ir.Literal) for arg in new_node.args) - ): # `1 + 1` -> `2` + isinstance(node, ir.FunCall) + and isinstance(node.fun, ir.SymRef) + and len(node.args) > 0 + and all(isinstance(arg, ir.Literal) for arg in node.args) + ): try: - if new_node.fun.id in ir.ARITHMETIC_BUILTINS: - fun = getattr(embedded, str(new_node.fun.id)) + if node.fun.id in builtins.ARITHMETIC_BUILTINS: + fun = getattr(embedded, str(node.fun.id)) arg_values = [ - getattr(embedded, str(arg.type))(arg.value) # type: ignore[attr-defined] # arg type already established in if condition - for arg in new_node.args + _value_from_literal(arg) # type: 
ignore[arg-type] # arg type already established in if condition + for arg in node.args ] - new_node = im.literal_from_value(fun(*arg_values)) + return im.literal_from_value(fun(*arg_values)) except ValueError: pass # happens for inf and neginf + return None + + def transform_fold_min_max_literals(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `minimum(a, a)` -> `a` + if cpm.is_call_to(node, ("minimum", "maximum")): + if node.args[0] == node.args[1]: + return node.args[0] + return None - return new_node + def transform_fold_if(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]: + # `if_(True, true_branch, false_branch)` -> `true_branch` + if cpm.is_call_to(node, "if_") and isinstance(node.args[0], ir.Literal): + if node.args[0].value == "True": + return node.args[1] + else: + return node.args[2] + return None diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 1a89adbb20..cc1ffc3c50 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -14,6 +14,7 @@ import operator from typing import Callable, Iterable, TypeVar, Union, cast +import gt4py.next.iterator.ir_utils.ir_makers as im from gt4py.eve import ( NodeTranslator, NodeVisitor, @@ -30,9 +31,24 @@ from gt4py.next.type_system import type_info, type_specifications as ts +def _is_trivial_tuple_expr(node: itir.Expr): + """Return if node is a `make_tuple` call with all elements `SymRef`s, `Literal`s or tuples thereof.""" + if cpm.is_call_to(node, "make_tuple") and all( + isinstance(arg, (itir.SymRef, itir.Literal)) or _is_trivial_tuple_expr(arg) + for arg in node.args + ): + return True + if cpm.is_call_to(node, "tuple_get") and ( + isinstance(node.args[1], (itir.SymRef, itir.Literal)) + or _is_trivial_tuple_expr(node.args[1]) + ): + return True + return False + + @dataclasses.dataclass class _NodeReplacer(PreserveLocationVisitor, NodeTranslator): - PRESERVED_ANNEX_ATTRS = ("type",) + 
PRESERVED_ANNEX_ATTRS = ("type", "domain") expr_map: dict[int, itir.SymRef] @@ -43,15 +59,16 @@ def visit_Expr(self, node: itir.Node) -> itir.Node: def visit_FunCall(self, node: itir.FunCall) -> itir.Node: node = cast(itir.FunCall, self.visit_Expr(node)) + # TODO(tehrengruber): Use symbol name from the inner let, to increase readability of IR # If we encounter an expression like: # (λ(_cs_1) → (λ(a) → a+a)(_cs_1))(outer_expr) # (non-recursively) inline the lambda to obtain: # (λ(_cs_1) → _cs_1+_cs_1)(outer_expr) - # This allows identifying more common subexpressions later on + # In the CSE this allows identifying more common subexpressions later on. Other users + # of `extract_subexpression` (e.g. temporary extraction) can also rely on this to avoid + # the need to handle this artificial let-statements. if isinstance(node, itir.FunCall) and isinstance(node.fun, itir.Lambda): - eligible_params = [] - for arg in node.args: - eligible_params.append(isinstance(arg, itir.SymRef) and arg.id.startswith("_cs")) + eligible_params = [isinstance(arg, itir.SymRef) for arg in node.args] if any(eligible_params): # note: the inline is opcount preserving anyway so avoid the additional # effort in the inliner by disabling opcount preservation. @@ -65,11 +82,13 @@ def _is_collectable_expr(node: itir.Node) -> bool: if isinstance(node, itir.FunCall): # do not collect (and thus deduplicate in CSE) shift(offsets…) calls. Node must still be # visited, to ensure symbol dependencies are recognized correctly. 
- # do also not collect reduce nodes if they are left in the it at this point, this may lead to + # do also not collect reduce nodes if they are left in the IR at this point, this may lead to # conceptual problems (other parts of the tool chain rely on the arguments being present directly # on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend # backend (single pass eager depth first visit approach) - if isinstance(node.fun, itir.SymRef) and node.fun.id in ["lift", "shift", "reduce"]: + # do also not collect lifts or applied lifts as they become invisible to the lift inliner + # otherwise + if cpm.is_call_to(node, ("lift", "shift", "reduce", "map_")) or cpm.is_applied_lift(node): return False return True elif isinstance(node, itir.Lambda): @@ -240,7 +259,6 @@ def extract_subexpression( Examples: Default case for `(x+y) + ((x+y)+z)`: - >>> import gt4py.next.iterator.ir_utils.ir_makers as im >>> from gt4py.eve.utils import UIDGenerator >>> expr = im.plus(im.plus("x", "y"), im.plus(im.plus("x", "y"), "z")) >>> predicate = lambda subexpr, num_occurences: num_occurences > 1 @@ -319,7 +337,7 @@ def extract_subexpression( subexprs = CollectSubexpressions.apply(node) # collect multiple occurrences and map them to fresh symbols - expr_map = dict[int, itir.SymRef]() + expr_map: dict[int, itir.SymRef] = {} ignored_ids = set() for expr, subexpr_entry in ( subexprs.items() if not deepest_expr_first else reversed(subexprs.items()) @@ -360,7 +378,7 @@ def extract_subexpression( return _NodeReplacer(expr_map).visit(node), extracted, ignored_children -ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.FencilDefinition | itir.Expr) +ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.Expr) @dataclasses.dataclass(frozen=True) @@ -372,7 +390,7 @@ class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator): >>> x = itir.SymRef(id="x") >>> plus = lambda a, b: 
itir.FunCall(fun=itir.SymRef(id=("plus")), args=[a, b]) >>> expr = plus(plus(x, x), plus(x, x)) - >>> print(CommonSubexpressionElimination.apply(expr, is_local_view=True)) + >>> print(CommonSubexpressionElimination.apply(expr, within_stencil=True)) (λ(_cs_1) → _cs_1 + _cs_1)(x + x) The pass visits the tree top-down starting from the root node, e.g. an itir.Program. @@ -394,53 +412,61 @@ class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator): def apply( cls, node: ProgramOrExpr, - is_local_view: bool | None = None, - offset_provider: common.OffsetProvider | None = None, + within_stencil: bool | None = None, + offset_provider_type: common.OffsetProviderType | None = None, ) -> ProgramOrExpr: - is_program = isinstance(node, (itir.Program, itir.FencilDefinition)) + is_program = isinstance(node, itir.Program) if is_program: - assert is_local_view is None - is_local_view = False + assert within_stencil is None + within_stencil = False else: assert ( - is_local_view is not None - ), "The expression's context must be specified using `is_local_view`." + within_stencil is not None + ), "The expression's context must be specified using `within_stencil`." 
- offset_provider = offset_provider or {} + offset_provider_type = offset_provider_type or {} node = itir_type_inference.infer( - node, offset_provider=offset_provider, allow_undeclared_symbols=not is_program + node, offset_provider_type=offset_provider_type, allow_undeclared_symbols=not is_program ) - return cls().visit(node, is_local_view=is_local_view) + return cls().visit(node, within_stencil=within_stencil) def generic_visit(self, node, **kwargs): - if cpm.is_call_to("as_fieldop", node): - assert not kwargs.get("is_local_view") - is_local_view = cpm.is_call_to("as_fieldop", node) or kwargs.get("is_local_view") + if cpm.is_call_to(node, "as_fieldop"): + assert not kwargs.get("within_stencil") + within_stencil = cpm.is_call_to(node, "as_fieldop") or kwargs.get("within_stencil") - return super().generic_visit(node, **(kwargs | {"is_local_view": is_local_view})) + return super().generic_visit(node, **(kwargs | {"within_stencil": within_stencil})) def visit_FunCall(self, node: itir.FunCall, **kwargs): - is_local_view = kwargs["is_local_view"] + within_stencil = kwargs["within_stencil"] if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")): return node def predicate(subexpr: itir.Expr, num_occurences: int): # note: be careful here with the syntatic context: the expression might be in local - # view, even though the syntactic context `node` is in field view. + # view, even though the syntactic context of `node` is in field view. # note: what is extracted is sketched in the docstring above. keep it updated. if num_occurences > 1: - if is_local_view: + if within_stencil: + # TODO(tehrengruber): Lists must not be extracted to avoid errors in partial + # shift detection of UnrollReduce pass. Solve there. See #1795. 
+ if isinstance(subexpr.type, ts.ListType): + return False return True - else: + # condition is only necessary since typing on lambdas is not preserved during + # the transformation + elif not isinstance(subexpr, itir.Lambda): # only extract fields outside of `as_fieldop` # `as_fieldop(...)(field_expr, field_expr)` # -> `(λ(_cs_1) → as_fieldop(...)(_cs_1, _cs_1))(field_expr)` + # only extract if subexpression is not a trivial tuple expressions, e.g., + # `make_tuple(a, b)`, as this would result in a more costly temporary. assert isinstance(subexpr.type, ts.TypeSpec) if all( isinstance(stype, ts.FieldType) for stype in type_info.primitive_constituents(subexpr.type) - ): + ) and not _is_trivial_tuple_expr(subexpr): return True return False @@ -450,10 +476,8 @@ def predicate(subexpr: itir.Expr, num_occurences: int): return self.generic_visit(node, **kwargs) # apply remapping - result = itir.FunCall( - fun=itir.Lambda(params=list(extracted.keys()), expr=new_expr), - args=list(extracted.values()), - ) + result = im.let(*extracted.items())(new_expr) + itir_type_inference.copy_type(from_=node, to=result, allow_untyped=True) # if the node id is ignored (because its parent is eliminated), but it occurs # multiple times then we want to visit the final result once more. diff --git a/src/gt4py/next/iterator/transforms/extractors.py b/src/gt4py/next/iterator/transforms/extractors.py new file mode 100644 index 0000000000..04c2b09139 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/extractors.py @@ -0,0 +1,72 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from gt4py import eve +from gt4py.next.iterator import ir as itir +from gt4py.next.type_system import type_specifications as ts + + +class SymbolNameSetExtractor(eve.NodeVisitor): + """Extract a set of symbol names""" + + def visit_Literal(self, node: itir.Literal) -> set[str]: + return set() + + def generic_visitor(self, node: itir.Node) -> set[str]: + input_fields: set[str] = set() + for child in eve.trees.iter_children_values(node): + input_fields |= self.visit(child) + return input_fields + + def visit_Node(self, node: itir.Node) -> set[str]: + return set() + + def visit_Program(self, node: itir.Program) -> set[str]: + names = set() + for stmt in node.body: + names |= self.visit(stmt) + return names + + def visit_IfStmt(self, node: itir.IfStmt) -> set[str]: + names = set() + for stmt in node.true_branch + node.false_branch: + names |= self.visit(stmt) + return names + + def visit_Temporary(self, node: itir.Temporary) -> set[str]: + return set() + + def visit_SymRef(self, node: itir.SymRef) -> set[str]: + return {str(node.id)} + + @classmethod + def only_fields(cls, program: itir.Program) -> set[str]: + field_param_names = [ + str(param.id) for param in program.params if isinstance(param.type, ts.FieldType) + ] + return {name for name in cls().visit(program) if name in field_param_names} + + +class InputNamesExtractor(SymbolNameSetExtractor): + """Extract the set of symbol names passed into field operators within a program.""" + + def visit_SetAt(self, node: itir.SetAt) -> set[str]: + return self.visit(node.expr) + + def visit_FunCall(self, node: itir.FunCall) -> set[str]: + input_fields = set() + for arg in node.args: + input_fields |= self.visit(arg) + return input_fields + + +class OutputNamesExtractor(SymbolNameSetExtractor): + """Extract the set of symbol names written to within a program""" + + def visit_SetAt(self, node: itir.SetAt) -> set[str]: + return self.visit(node.target) diff --git 
a/src/gt4py/next/iterator/transforms/fencil_to_program.py b/src/gt4py/next/iterator/transforms/fencil_to_program.py deleted file mode 100644 index db0b81a837..0000000000 --- a/src/gt4py/next/iterator/transforms/fencil_to_program.py +++ /dev/null @@ -1,44 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py import eve -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms import global_tmps - - -class FencilToProgram(eve.NodeTranslator): - @classmethod - def apply( - cls, node: itir.FencilDefinition | global_tmps.FencilWithTemporaries | itir.Program - ) -> itir.Program: - return cls().visit(node) - - def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.SetAt: - as_fieldop = im.call(im.call("as_fieldop")(node.stencil, node.domain))(*node.inputs) - return itir.SetAt(expr=as_fieldop, domain=node.domain, target=node.output) - - def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program: - return itir.Program( - id=node.id, - function_definitions=node.function_definitions, - params=node.params, - declarations=[], - body=self.visit(node.closures), - implicit_domain=node.implicit_domain, - ) - - def visit_FencilWithTemporaries(self, node: global_tmps.FencilWithTemporaries) -> itir.Program: - return itir.Program( - id=node.fencil.id, - function_definitions=node.fencil.function_definitions, - params=node.params, - declarations=node.tmps, - body=self.visit(node.fencil.closures), - implicit_domain=node.fencil.implicit_domain, - ) diff --git a/src/gt4py/next/iterator/transforms/fixed_point_transformation.py b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py new file mode 100644 index 0000000000..f1176b4bef --- /dev/null +++ 
b/src/gt4py/next/iterator/transforms/fixed_point_transformation.py @@ -0,0 +1,67 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import dataclasses +import enum +from typing import ClassVar, Optional, Type + +from gt4py import eve +from gt4py.next.iterator import ir +from gt4py.next.iterator.type_system import inference as itir_type_inference + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class FixedPointTransformation(eve.NodeTranslator): + """ + Transformation pass that transforms until no transformation is applicable anymore. + """ + + #: Enum of all transformation (names). The transformations need to be defined as methods + #: named `transform_`. + Transformation: ClassVar[Type[enum.Flag]] + + #: All transformations enabled in this instance, e.g. `Transformation.T1 & Transformation.T2`. + #: Usually the default value is chosen to be all transformations. + enabled_transformations: enum.Flag + + def visit(self, node, **kwargs): + node = super().visit(node, **kwargs) + return self.fp_transform(node, **kwargs) if isinstance(node, ir.Node) else node + + def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: + """ + Transform node until a fixed point is reached, e.g. no transformation is applicable anymore. + """ + while True: + new_node = self.transform(node, **kwargs) + if new_node is None: + break + assert new_node != node + node = new_node + return node + + def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: + """ + Transform node once. + + Execute transformations until one is applicable. As soon as a transformation occured + the function will return the transformed node. Note that the transformation itself + may call other transformations on child nodes again. 
+ """ + for transformation in self.Transformation: + if self.enabled_transformations & transformation: + assert isinstance(transformation.name, str) + method = getattr(self, f"transform_{transformation.name.lower()}") + result = method(node, **kwargs) + if result is not None: + assert ( + result is not node + ), f"Transformation {transformation.name.lower()} should have returned None, since nothing changed." + itir_type_inference.reinfer(result) + return result + return None diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py new file mode 100644 index 0000000000..81633dfb87 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -0,0 +1,488 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + +import dataclasses +import enum +import functools +import operator +from typing import Optional + +from gt4py import eve +from gt4py.eve import utils as eve_utils +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) +from gt4py.next.iterator.transforms import ( + fixed_point_transformation, + inline_center_deref_lift_vars, + inline_lambdas, + inline_lifts, + merge_let, + trace_shifts, +) +from gt4py.next.iterator.type_system import inference as type_inference +from gt4py.next.type_system import type_info, type_specifications as ts + + +def _merge_arguments( + args1: dict[str, itir.Expr], arg2: dict[str, itir.Expr] +) -> dict[str, itir.Expr]: + new_args = {**args1} + for stencil_param, stencil_arg in arg2.items(): + if stencil_param not in new_args: + new_args[stencil_param] = stencil_arg + else: + assert new_args[stencil_param] == stencil_arg + return new_args + + 
+def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: + """ + Canonicalize applied `as_fieldop`s. + + In case the stencil argument is a `deref` wrap it into a lambda such that we have a unified + format to work with (e.g. each parameter has a name without the need to special case). + """ + assert cpm.is_applied_as_fieldop(expr) + + stencil = expr.fun.args[0] # type: ignore[attr-defined] + domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] + if cpm.is_ref_to(stencil, "deref"): + stencil = im.lambda_("arg")(im.deref("arg")) + new_expr = im.as_fieldop(stencil, domain)(*expr.args) + + return new_expr + + return expr + + +def _is_tuple_expr_of_literals(expr: itir.Expr): + if cpm.is_call_to(expr, "make_tuple"): + return all(_is_tuple_expr_of_literals(arg) for arg in expr.args) + if cpm.is_call_to(expr, "tuple_get"): + return _is_tuple_expr_of_literals(expr.args[1]) + return isinstance(expr, itir.Literal) + + +def _inline_as_fieldop_arg( + arg: itir.Expr, *, uids: eve_utils.UIDGenerator +) -> tuple[itir.Expr, dict[str, itir.Expr]]: + assert cpm.is_applied_as_fieldop(arg) + arg = _canonicalize_as_fieldop(arg) + + stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` + inner_args: list[itir.Expr] = arg.args + extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg + + stencil_params: list[itir.Sym] = [] + stencil_body: itir.Expr = stencil.expr + + for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): + if isinstance(inner_arg, itir.SymRef): + if inner_arg.id in extracted_args: + assert extracted_args[inner_arg.id] == inner_arg + alias = stencil_params[list(extracted_args.keys()).index(inner_arg.id)] + stencil_body = im.let(inner_param, im.ref(alias.id))(stencil_body) + else: + stencil_params.append(inner_param) + extracted_args[inner_arg.id] = inner_arg + elif isinstance(inner_arg, itir.Literal): + # note: only literals, not all 
scalar expressions are required as it doesn't make sense + # for them to be computed per grid point. + stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( + stencil_body + ) + else: + # a scalar expression, a previously not inlined `as_fieldop` call or an opaque + # expression e.g. containing a tuple + stencil_params.append(inner_param) + new_outer_stencil_param = uids.sequential_id(prefix="__iasfop") + extracted_args[new_outer_stencil_param] = inner_arg + + return im.lift(im.lambda_(*stencil_params)(stencil_body))( + *extracted_args.keys() + ), extracted_args + + +def _unwrap_scan(stencil: itir.Lambda | itir.FunCall): + """ + If given a scan, extract stencil part of its scan pass and a back-transformation into a scan. + + If a regular stencil is given the stencil is left as-is and the back-transformation is the + identity function. This function allows treating a scan stencil like a regular stencil during + a transformation avoiding the complexity introduced by the different IR format. + + >>> scan = im.call("scan")( + ... im.lambda_("state", "arg")(im.plus("state", im.deref("arg"))), True, 0.0 + ... 
) + >>> stencil, back_trafo = _unwrap_scan(scan) + >>> str(stencil) + 'λ(arg) → state + ·arg' + >>> str(back_trafo(stencil)) + 'scan(λ(state, arg) → (λ(arg) → state + ·arg)(arg), True, 0.0)' + + In case a regular stencil is given it is returned as-is: + + >>> deref_stencil = im.lambda_("it")(im.deref("it")) + >>> stencil, back_trafo = _unwrap_scan(deref_stencil) + >>> assert stencil == deref_stencil + """ + if cpm.is_call_to(stencil, "scan"): + scan_pass, direction, init = stencil.args + assert isinstance(scan_pass, itir.Lambda) + # remove scan pass state to be used by caller + state_param = scan_pass.params[0] + stencil_like = im.lambda_(*scan_pass.params[1:])(scan_pass.expr) + + def restore_scan(transformed_stencil_like: itir.Lambda): + new_scan_pass = im.lambda_(state_param, *transformed_stencil_like.params)( + im.call(transformed_stencil_like)( + *(param.id for param in transformed_stencil_like.params) + ) + ) + return im.call("scan")(new_scan_pass, direction, init) + + return stencil_like, restore_scan + + assert isinstance(stencil, itir.Lambda) + return stencil, lambda s: s + + +def fuse_as_fieldop( + expr: itir.Expr, eligible_args: list[bool], *, uids: eve_utils.UIDGenerator +) -> itir.Expr: + assert cpm.is_applied_as_fieldop(expr) + + stencil: itir.Lambda = expr.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + assert isinstance(expr.fun.args[0], itir.Lambda) or cpm.is_call_to(stencil, "scan") # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + stencil, restore_scan = _unwrap_scan(stencil) + + domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + + args: list[itir.Expr] = expr.args + + new_args: dict[str, itir.Expr] = {} + new_stencil_body: itir.Expr = stencil.expr + + for eligible, stencil_param, arg in zip(eligible_args, stencil.params, args, strict=True): + if eligible: + if cpm.is_applied_as_fieldop(arg): + pass + elif cpm.is_call_to(arg, 
"if_"): + # transform scalar `if` into per-grid-point `if` + # TODO(tehrengruber): revisit if we want to inline if_ + arg = im.op_as_fieldop("if_")(*arg.args) + elif _is_tuple_expr_of_literals(arg): + arg = im.op_as_fieldop(im.lambda_()(arg))() + else: + raise NotImplementedError() + + inline_expr, extracted_args = _inline_as_fieldop_arg(arg, uids=uids) + + new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body) + + new_args = _merge_arguments(new_args, extracted_args) + else: + # just a safety check if typing information is available + type_inference.reinfer(arg) + if arg.type and not isinstance(arg.type, ts.DeferredType): + assert isinstance(arg.type, ts.TypeSpec) + dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) + assert not isinstance(dtype, ts.ListType) + new_param: str + if isinstance( + arg, itir.SymRef + ): # use name from outer scope (optional, just to get a nice IR) + new_param = arg.id + new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body) + else: + new_param = stencil_param.id + new_args = _merge_arguments(new_args, {new_param: arg}) + + stencil = im.lambda_(*new_args.keys())(new_stencil_body) + stencil = restore_scan(stencil) + + # simplify stencil directly to keep the tree small + new_stencil = inline_lambdas.InlineLambdas.apply( + stencil, opcount_preserving=True, force_inline_lift_args=False + ) + new_stencil = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( + new_stencil, is_stencil=True, uids=uids + ) # to keep the tree small + new_stencil = merge_let.MergeLet().visit(new_stencil) + new_stencil = inline_lambdas.InlineLambdas.apply( + new_stencil, opcount_preserving=True, force_inline_lift_args=True + ) + new_stencil = inline_lifts.InlineLifts().visit(new_stencil) + + new_node = im.as_fieldop(new_stencil, domain)(*new_args.values()) + + return new_node + + +def _arg_inline_predicate(node: itir.Expr, shifts: set[tuple[itir.OffsetLiteral, ...]]) -> bool: + if 
_is_tuple_expr_of_literals(node): + return True + + if ( + is_applied_fieldop := cpm.is_applied_as_fieldop(node) + and not cpm.is_call_to(node.fun.args[0], "scan") # type: ignore[attr-defined] # ensured by cpm.is_applied_as_fieldop + ) or cpm.is_call_to(node, "if_"): + # always inline arg if it is an applied fieldop with only a single arg + if is_applied_fieldop and len(node.args) == 1: + return True + # argument is never used, will be removed when inlined + if len(shifts) == 0: + return True + # applied fieldop with list return type must always be inlined as no backend supports this + type_inference.reinfer(node) + assert isinstance(node.type, ts.TypeSpec) + dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, node.type) + if isinstance(dtype, ts.ListType): + return True + # only accessed at the center location + if shifts in [set(), {()}]: + return True + # TODO(tehrengruber): Disabled as the InlineCenterDerefLiftVars does not support this yet + # and it would increase the size of the tree otherwise. + # if len(shifts) == 1 and not any( + # trace_shifts.Sentinel.ALL_NEIGHBORS in access for access in shifts + # ): + # return True # noqa: ERA001 [commented-out-code] + + return False + + +def _make_tuple_element_inline_predicate(node: itir.Expr): + if cpm.is_applied_as_fieldop(node): # field, or tuple of fields + return True + if isinstance(node.type, ts.FieldType) and isinstance(node, itir.SymRef): + return True + return False + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class FuseAsFieldOp( + fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor +): + """ + Merge multiple `as_fieldop` calls into one. 
+ + >>> from gt4py import next as gtx + >>> from gt4py.next.iterator.ir_utils import ir_makers as im + >>> IDim = gtx.Dimension("IDim") + >>> field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + >>> d = im.domain("cartesian_domain", {IDim: (0, 1)}) + >>> nested_as_fieldop = im.op_as_fieldop("plus", d)( + ... im.op_as_fieldop("multiplies", d)( + ... im.ref("inp1", field_type), im.ref("inp2", field_type) + ... ), + ... im.ref("inp3", field_type), + ... ) + >>> print(nested_as_fieldop) + as_fieldop(λ(__arg0, __arg1) → ·__arg0 + ·__arg1, c⟨ IDimₕ: [0, 1[ ⟩)( + as_fieldop(λ(__arg0, __arg1) → ·__arg0 × ·__arg1, c⟨ IDimₕ: [0, 1[ ⟩)(inp1, inp2), inp3 + ) + >>> print( + ... FuseAsFieldOp.apply( + ... nested_as_fieldop, offset_provider_type={}, allow_undeclared_symbols=True + ... ) + ... ) + as_fieldop(λ(inp1, inp2, inp3) → ·inp1 × ·inp2 + ·inp3, c⟨ IDimₕ: [0, 1[ ⟩)(inp1, inp2, inp3) + """ # noqa: RUF002 # ignore ambiguous multiplication character + + class Transformation(enum.Flag): + #: Let `f_expr` be an expression with list dtype then + #: `let(f, f_expr) -> as_fieldop(...)(f)` -> `as_fieldop(...)(f_expr)` + FUSE_MAKE_TUPLE = enum.auto() + #: `as_fieldop(...)(as_fieldop(...)(a, b), c)` + #: -> as_fieldop(fused_stencil)(a, b, c) + FUSE_AS_FIELDOP = enum.auto() + INLINE_LET_VARS_OPCOUNT_PRESERVING = enum.auto() + + @classmethod + def all(self) -> FuseAsFieldOp.Transformation: + return functools.reduce(operator.or_, self.__members__.values()) + + PRESERVED_ANNEX_ATTRS = ("domain",) + + enabled_transformations = Transformation.all() + + uids: eve_utils.UIDGenerator + + @classmethod + def apply( + cls, + node: itir.Program, + *, + offset_provider_type: common.OffsetProviderType, + uids: Optional[eve_utils.UIDGenerator] = None, + allow_undeclared_symbols=False, + within_set_at_expr: Optional[bool] = None, + enabled_transformations: Optional[Transformation] = None, + ): + enabled_transformations = enabled_transformations or 
cls.enabled_transformations + + node = type_inference.infer( + node, + offset_provider_type=offset_provider_type, + allow_undeclared_symbols=allow_undeclared_symbols, + ) + + if within_set_at_expr is None: + within_set_at_expr = not isinstance(node, itir.Program) + + if not uids: + uids = eve_utils.UIDGenerator() + + return cls(uids=uids, enabled_transformations=enabled_transformations).visit( + node, within_set_at_expr=within_set_at_expr + ) + + def transform_fuse_make_tuple(self, node: itir.Node, **kwargs): + if not cpm.is_call_to(node, "make_tuple"): + return None + + for arg in node.args: + type_inference.reinfer(arg) + assert not isinstance(arg.type, ts.FieldType) or ( + hasattr(arg.annex, "domain") + and isinstance(arg.annex.domain, domain_utils.SymbolicDomain) + ) + + eligible_els = [_make_tuple_element_inline_predicate(arg) for arg in node.args] + field_args = [arg for i, arg in enumerate(node.args) if eligible_els[i]] + distinct_domains = set(arg.annex.domain.as_expr() for arg in field_args) + if len(distinct_domains) != len(field_args): + new_els: list[itir.Expr | None] = [None for _ in node.args] + field_args_by_domain: dict[itir.FunCall, list[tuple[int, itir.Expr]]] = {} + for i, arg in enumerate(node.args): + if eligible_els[i]: + assert isinstance(arg.annex.domain, domain_utils.SymbolicDomain) + domain = arg.annex.domain.as_expr() + field_args_by_domain.setdefault(domain, []) + field_args_by_domain[domain].append((i, arg)) + else: + new_els[i] = arg # keep as is + + if len(field_args_by_domain) == 1 and all(eligible_els): + # if we only have a single domain covering all args we don't need to create an + # unnecessary let + ((domain, inner_field_args),) = field_args_by_domain.items() + new_node = im.op_as_fieldop(lambda *args: im.make_tuple(*args), domain)( + *(arg for _, arg in inner_field_args) + ) + new_node = self.visit(new_node, **{**kwargs, "recurse": False}) + else: + let_vars = {} + for domain, inner_field_args in field_args_by_domain.items(): 
+ if len(inner_field_args) > 1: + var = self.uids.sequential_id(prefix="__fasfop") + fused_args = im.op_as_fieldop(lambda *args: im.make_tuple(*args), domain)( + *(arg for _, arg in inner_field_args) + ) + type_inference.reinfer(arg) + # don't recurse into nested args, but only consider newly created `as_fieldop` + # note: this will always inline (as we inline center accessed) + let_vars[var] = self.visit(fused_args, **{**kwargs, "recurse": False}) + for outer_tuple_idx, (inner_tuple_idx, _) in enumerate(inner_field_args): + new_el = im.tuple_get(outer_tuple_idx, var) + new_el.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) + new_els[inner_tuple_idx] = new_el + else: + i, arg = inner_field_args[0] + new_els[i] = arg + assert not any(el is None for el in new_els) + assert let_vars + new_node = im.let(*let_vars.items())(im.make_tuple(*new_els)) + new_node = inline_lambdas.inline_lambda(new_node, opcount_preserving=True) + return new_node + return None + + def transform_fuse_as_fieldop(self, node: itir.Node, **kwargs): + if cpm.is_applied_as_fieldop(node): + node = _canonicalize_as_fieldop(node) + stencil = node.fun.args[0] # type: ignore[attr-defined] # ensure cpm.is_applied_as_fieldop + assert isinstance(stencil, itir.Lambda) or cpm.is_call_to(stencil, "scan") + args: list[itir.Expr] = node.args + shifts = trace_shifts.trace_stencil(stencil, num_args=len(args)) + + eligible_els = [ + _arg_inline_predicate(arg, arg_shifts) + for arg, arg_shifts in zip(args, shifts, strict=True) + ] + if any(eligible_els): + return self.visit( + fuse_as_fieldop(node, eligible_els, uids=self.uids), + **{**kwargs, "recurse": False}, + ) + return None + + def transform_inline_let_vars_opcount_preserving(self, node: itir.Node, **kwargs): + # when multiple `as_fieldop` calls are fused that use the same argument, this argument + # might become referenced once only. In order to be able to continue fusing such arguments + # try inlining here. 
+ if cpm.is_let(node): + new_node = inline_lambdas.inline_lambda(node, opcount_preserving=True) + if new_node is not node: # nothing has been inlined + return self.visit(new_node, **kwargs) + + return None + + def generic_visit(self, node, **kwargs): + if cpm.is_applied_as_fieldop(node): # don't descend in stencil + return im.as_fieldop(*node.fun.args)(*self.visit(node.args, **kwargs)) + + # TODO(tehrengruber): This is a common pattern that should be absorbed in + # `FixedPointTransformation`. + if kwargs.get("recurse", True): + return super().generic_visit(node, **kwargs) + else: + return node + + def visit(self, node, **kwargs): + if isinstance(node, itir.SetAt): + return itir.SetAt( + expr=self.visit(node.expr, **kwargs | {"within_set_at_expr": True}), + # rest doesn't need to be visited + domain=node.domain, + target=node.target, + ) + + # don't execute transformations unless inside `SetAt` node + if not kwargs.get("within_set_at_expr"): + return self.generic_visit(node, **kwargs) + + # inline all fields with list dtype. This needs to happen before the children are visited + # such that the `as_fieldop` can be fused. + # TODO(tehrengruber): what should we do in case the field with list dtype is a let itself? + # This could duplicate other expressions which we did not intend to duplicate. + # TODO(tehrengruber): This should be moved into a `transform_` method, but + # `FixedPointTransformation` does not support pre-order transformations yet. 
+ if cpm.is_let(node): + for arg in node.args: + type_inference.reinfer(arg) + eligible_els = [ + isinstance(arg.type, ts.FieldType) and isinstance(arg.type.dtype, ts.ListType) + for arg in node.args + ] + if any(eligible_els): + node = inline_lambdas.inline_lambda(node, eligible_params=eligible_els) + return self.visit(node, **kwargs) + + node = super().visit(node, **kwargs) + + if isinstance(node, itir.Expr) and hasattr(node.annex, "domain"): + node.annex.domain = node.annex.domain + + return node diff --git a/src/gt4py/next/iterator/transforms/fuse_maps.py b/src/gt4py/next/iterator/transforms/fuse_maps.py index 430d794880..8d27178682 100644 --- a/src/gt4py/next/iterator/transforms/fuse_maps.py +++ b/src/gt4py/next/iterator/transforms/fuse_maps.py @@ -7,7 +7,6 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses -from typing import TypeGuard from gt4py.eve import NodeTranslator, traits from gt4py.eve.utils import UIDGenerator @@ -16,14 +15,6 @@ from gt4py.next.iterator.transforms import inline_lambdas -def _is_map(node: ir.Node) -> TypeGuard[ir.FunCall]: - return ( - isinstance(node, ir.FunCall) - and isinstance(node.fun, ir.FunCall) - and node.fun.fun == ir.SymRef(id="map_") - ) - - @dataclasses.dataclass(frozen=True) class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ @@ -58,10 +49,10 @@ def _as_lambda(self, fun: ir.SymRef | ir.Lambda, param_count: int) -> ir.Lambda: def visit_FunCall(self, node: ir.FunCall, **kwargs): node = self.generic_visit(node) - if _is_map(node) or cpm.is_applied_reduce(node): - if any(_is_map(arg) for arg in node.args): + if cpm.is_applied_map(node) or cpm.is_applied_reduce(node): + if any(cpm.is_applied_map(arg) for arg in node.args): first_param = ( - 0 if _is_map(node) else 1 + 0 if cpm.is_applied_map(node) else 1 ) # index of the first param of op that maps to args (0 for map, 1 for reduce) assert isinstance(node.fun, ir.FunCall) assert isinstance(node.fun.args[0], 
(ir.Lambda, ir.SymRef)) @@ -76,7 +67,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): new_params.append(outer_op.params[0]) for i in range(len(node.args)): - if _is_map(node.args[i]): + if cpm.is_applied_map(node.args[i]): map_call = node.args[i] assert isinstance(map_call, ir.FunCall) assert isinstance(map_call.fun, ir.FunCall) @@ -102,7 +93,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): new_body ) # removes one level of nesting (the recursive inliner could simplify more, however this can also be done on the full tree later) new_op = ir.Lambda(params=new_params, expr=new_body) - if _is_map(node): + if cpm.is_applied_map(node): return ir.FunCall( fun=ir.FunCall(fun=ir.SymRef(id="map_"), args=[new_op]), args=new_args ) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 5a6873f916..ac7fcb8f1c 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -8,582 +8,202 @@ from __future__ import annotations -import copy -import dataclasses -from collections.abc import Mapping -from typing import Any, Callable, Final, Iterable, Literal, Optional, Sequence - -import gt4py.next as gtx -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.eve.traits import SymbolTableTrait -from gt4py.eve.utils import UIDGenerator -from gt4py.next import common -from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.ir_utils.domain_utils import ( - SymbolicDomain, - SymbolicRange, - _max_domain_sizes_by_location_type, - domain_union, -) -from gt4py.next.iterator.pretty_printer import PrettyPrinter -from gt4py.next.iterator.transforms import trace_shifts -from gt4py.next.iterator.transforms.cse import extract_subexpression -from gt4py.next.iterator.transforms.eta_reduction import EtaReduction -from 
gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas -from gt4py.next.iterator.transforms.prune_closure_inputs import PruneClosureInputs -from gt4py.next.iterator.transforms.symbol_ref_utils import collect_symbol_refs -from gt4py.next.iterator.type_system import ( - inference as itir_type_inference, - type_specifications as it_ts, +import functools +from typing import Callable, Optional + +from gt4py.eve import utils as eve_utils +from gt4py.next import common, utils as next_utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, ) -from gt4py.next.type_system import type_specifications as ts - - -"""Iterator IR extension for global temporaries. - -Replaces lifted function calls by temporaries using the following steps: -1. Split closures by popping up lifted function calls to the top of the expression tree, (that is, - to stencil arguments) and then extracting them as new closures. -2. Introduces a new fencil-scope variable (the temporary) for each output of newly created closures. - The domain size is set to a new symbol `_gtmp_auto_domain`. -3. Infer the domain sizes for the new closures by analysing the accesses/shifts within all closures - and replace all occurrences of `_gtmp_auto_domain` by concrete domain sizes. -4. Infer the data type and size of the temporary buffers. -""" - - -AUTO_DOMAIN: Final = ir.FunCall(fun=ir.SymRef(id="_gtmp_auto_domain"), args=[]) - - -# Iterator IR extension nodes - - -class FencilWithTemporaries( - ir.Node, SymbolTableTrait -): # TODO(havogt): remove and use new `itir.Program` instead. 
- """Iterator IR extension: declaration of a fencil with temporary buffers.""" - - fencil: ir.FencilDefinition - params: list[ir.Sym] - tmps: list[ir.Temporary] - - -# Extensions for `PrettyPrinter` for easier debugging - - -def pformat_FencilWithTemporaries( - printer: PrettyPrinter, node: FencilWithTemporaries, *, prec: int -) -> list[str]: - assert prec == 0 - params = printer.visit(node.params, prec=0) - fencil = printer.visit(node.fencil, prec=0) - tmps = printer.visit(node.tmps, prec=0) - args = params + [[tmp.id] for tmp in node.tmps] - - hparams = printer._hmerge([node.fencil.id + "("], *printer._hinterleave(params, ", "), [") {"]) - vparams = printer._vmerge( - [node.fencil.id + "("], *printer._hinterleave(params, ",", indent=True), [") {"] +from gt4py.next.iterator.transforms import cse, infer_domain, inline_lambdas +from gt4py.next.iterator.type_system import inference as type_inference +from gt4py.next.type_system import type_info, type_specifications as ts + + +def _transform_if( + stmt: itir.Stmt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator +) -> Optional[list[itir.Stmt]]: + if isinstance(stmt, itir.SetAt) and cpm.is_call_to(stmt.expr, "if_"): + cond, true_val, false_val = stmt.expr.args + return [ + itir.IfStmt( + cond=cond, + true_branch=_transform_stmt( + itir.SetAt(target=stmt.target, expr=true_val, domain=stmt.domain), + declarations, + uids, + ), + false_branch=_transform_stmt( + itir.SetAt(target=stmt.target, expr=false_val, domain=stmt.domain), + declarations, + uids, + ), + ) + ] + return None + + +def _transform_by_pattern( + stmt: itir.Stmt, + predicate: Callable[[itir.Expr, int], bool], + declarations: list[itir.Temporary], + uids: eve_utils.UIDGenerator, +) -> Optional[list[itir.Stmt]]: + if not isinstance(stmt, itir.SetAt): + return None + + new_expr, extracted_fields, _ = cse.extract_subexpression( + stmt.expr, + predicate=predicate, + uid_generator=eve_utils.UIDGenerator(prefix="__tmp_subexpr"), + # 
TODO(tehrengruber): extracting the deepest expression first would allow us to fuse + # the extracted expressions resulting in fewer kernel calls & better data-locality. + # Extracting multiple expressions deepest-first is however not supported right now. + # deepest_expr_first=True # noqa: ERA001 ) - params = printer._optimum(hparams, vparams) - - hargs = printer._hmerge(*printer._hinterleave(args, ", ")) - vargs = printer._vmerge(*printer._hinterleave(args, ",")) - args = printer._optimum(hargs, vargs) - - fencil = printer._hmerge(fencil, [";"]) - - hcall = printer._hmerge([node.fencil.id + "("], args, [");"]) - vcall = printer._vmerge(printer._hmerge([node.fencil.id + "("]), printer._indent(args), [");"]) - call = printer._optimum(hcall, vcall) - - body = printer._vmerge(*tmps, fencil, call) - return printer._vmerge(params, printer._indent(body), ["}"]) - - -PrettyPrinter.visit_FencilWithTemporaries = pformat_FencilWithTemporaries # type: ignore - - -# Main implementation -def canonicalize_applied_lift(closure_params: list[str], node: ir.FunCall) -> ir.FunCall: - """ - Canonicalize applied lift expressions. - - Transform lift such that the arguments to the applied lift are only symbols. 
- >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) - >>> it_type = it_ts.IteratorType(position_dims=[], defined_dims=[], element_type=bool_type) - >>> expr = im.lift(im.lambda_("a")(im.deref("a")))(im.lift("deref")(im.ref("inp", it_type))) - >>> print(expr) - (↑(λ(a) → ·a))((↑deref)(inp)) - >>> print(canonicalize_applied_lift(["inp"], expr)) - (↑(λ(inp) → (λ(a) → ·a)((↑deref)(inp))))(inp) - """ - assert cpm.is_applied_lift(node) - stencil = node.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied lift - it_args = node.args - if any(not isinstance(it_arg, ir.SymRef) for it_arg in it_args): - closure_param_refs = collect_symbol_refs(node, as_ref=True) - assert not ({str(ref.id) for ref in closure_param_refs} - set(closure_params)) - new_node = im.lift( - im.lambda_(*[im.sym(param.id) for param in closure_param_refs])( - im.call(stencil)(*it_args) + if extracted_fields: + tmp_stmts: list[itir.Stmt] = [] + + # for each extracted expression generate: + # - one or more `Temporary` declarations (depending on whether the expression is a field + # or a tuple thereof) + # - one `SetAt` statement that materializes the expression into the temporary + for tmp_sym, tmp_expr in extracted_fields.items(): + domain: infer_domain.DomainAccess = tmp_expr.annex.domain + + # TODO(tehrengruber): Implement. This happens when the expression is a combination + # of an `if_` call with a tuple, e.g., `if_(cond, {a, b}, {c, d})`. As long as we are + # able to eliminate all tuples, e.g., by propagating the scalar ifs to the top-level + # of a SetAt, the CollapseTuple pass will eliminate most of this cases. + if isinstance(domain, tuple): + flattened_domains: tuple[domain_utils.SymbolicDomain] = ( + next_utils.flatten_nested_tuple(domain) # type: ignore[assignment] # mypy not smart enough + ) + if not all(d == flattened_domains[0] for d in flattened_domains): + raise NotImplementedError( + "Tuple expressions with different domains is not supported yet." 
+ ) + domain = flattened_domains[0] + assert isinstance(domain, domain_utils.SymbolicDomain) + domain_expr = domain.as_expr() + + assert isinstance(tmp_expr.type, ts.TypeSpec) + tmp_names: str | tuple[str | tuple, ...] = type_info.apply_to_primitive_constituents( + lambda x: uids.sequential_id(), + tmp_expr.type, + tuple_constructor=lambda *elements: tuple(elements), + ) + tmp_dtypes: ( + ts.ScalarType | ts.ListType | tuple[ts.ScalarType | ts.ListType | tuple, ...] + ) = type_info.apply_to_primitive_constituents( + type_info.extract_dtype, + tmp_expr.type, + tuple_constructor=lambda *elements: tuple(elements), ) - )(*closure_param_refs) - # ensure all types are inferred - return itir_type_inference.infer( - new_node, inplace=True, allow_undeclared_symbols=True, offset_provider={} - ) - return node - - -@dataclasses.dataclass(frozen=True) -class TemporaryExtractionPredicate: - """ - Construct a callable that determines if a lift expr can and should be extracted to a temporary. - - The class optionally takes a heuristic that can restrict the extraction. - """ - - heuristics: Optional[Callable[[ir.Expr], bool]] = None - - def __call__(self, expr: ir.Expr, num_occurences: int) -> bool: - """Determine if `expr` is an applied lift that should be extracted as a temporary.""" - if not cpm.is_applied_lift(expr): - return False - # do not extract when the result is a list (i.e. 
a lift expression used in a `reduce` call) - # as we can not create temporaries for these stencils - assert isinstance(expr.type, it_ts.IteratorType) - if isinstance(expr.type.element_type, it_ts.ListType): - return False - if self.heuristics and not self.heuristics(expr): - return False - stencil = expr.fun.args[0] # type: ignore[attr-defined] # ensured by `is_applied_lift` - # do not extract when the stencil is capturing - used_symbols = collect_symbol_refs(stencil) - if used_symbols: - return False - return True - - -@dataclasses.dataclass(frozen=True) -class SimpleTemporaryExtractionHeuristics: - """ - Heuristic that extracts only if a lift expr is derefed in more than one position. - - Note that such expression result in redundant computations if inlined instead of being - placed into a temporary. - """ - - closure: ir.StencilClosure - - def __post_init__(self) -> None: - trace_shifts.trace_stencil( - self.closure.stencil, num_args=len(self.closure.inputs), save_to_annex=True - ) - - def __call__(self, expr: ir.Expr) -> bool: - shifts = expr.annex.recorded_shifts - if len(shifts) > 1: - return True - return False - - -def _closure_parameter_argument_mapping(closure: ir.StencilClosure) -> dict[str, ir.Expr]: - """ - Create a mapping from the closures parameters to the closure arguments. - - E.g. for the closure `out ← (λ(param) → ...)(arg) @ u⟨ ... ⟩;` we get a mapping from `param` - to `arg`. In case the stencil is a scan, a mapping from closure inputs to scan pass (i.e. first - arg is ignored) is returned. 
- """ - is_scan = cpm.is_call_to(closure.stencil, "scan") - - if is_scan: - stencil = closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan - return { - param.id: arg for param, arg in zip(stencil.params[1:], closure.inputs, strict=True) - } - else: - assert isinstance(closure.stencil, ir.Lambda) - return { - param.id: arg for param, arg in zip(closure.stencil.params, closure.inputs, strict=True) - } - - -def _ensure_expr_does_not_capture(expr: ir.Expr, whitelist: list[ir.Sym]) -> None: - used_symbol_refs = collect_symbol_refs(expr) - assert not (set(used_symbol_refs) - {param.id for param in whitelist}) - - -def split_closures( - node: ir.FencilDefinition, - offset_provider: common.OffsetProvider, - *, - extraction_heuristics: Optional[ - Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]] - ] = None, -) -> FencilWithTemporaries: - """Split closures on lifted function calls and introduce new temporary buffers for return values. - - Newly introduced temporaries will have the symbolic size of `AUTO_DOMAIN`. A symbol with the - same name is also added as a fencil argument (to be replaced at a later stage). - - For each closure, follows these steps: - 1. Pops up lifted function calls to the top of the expression tree. - 2. Introduce new temporary for the output. - 3. Extract lifted function class as new closures with the previously created temporary as output. - The closures are processed in reverse order to properly respect the dependencies. 
- """ - if not extraction_heuristics: - # extract all (eligible) lifts - def always_extract_heuristics(_: ir.StencilClosure) -> Callable[[ir.Expr], bool]: - return lambda _: True - - extraction_heuristics = always_extract_heuristics - - uid_gen_tmps = UIDGenerator(prefix="_tmp") - - node = itir_type_inference.infer(node, offset_provider=offset_provider) - - tmps: list[tuple[str, ts.DataType]] = [] - closures: list[ir.StencilClosure] = [] - for closure in reversed(node.closures): - closure_stack: list[ir.StencilClosure] = [closure] - while closure_stack: - current_closure: ir.StencilClosure = closure_stack.pop() + # allocate temporary for all tuple elements + def allocate_temporary(tmp_name: str, dtype: ts.ScalarType): + declarations.append(itir.Temporary(id=tmp_name, domain=domain_expr, dtype=dtype)) # noqa: B023 # function only used inside loop - if ( - isinstance(current_closure.stencil, ir.SymRef) - and current_closure.stencil.id == "deref" - ): - closures.append(current_closure) - continue + next_utils.tree_map(allocate_temporary)(tmp_names, tmp_dtypes) - is_scan: bool = cpm.is_call_to(current_closure.stencil, "scan") - current_closure_stencil = ( - current_closure.stencil if not is_scan else current_closure.stencil.args[0] # type: ignore[attr-defined] # ensured by is_scan - ) + # if the expr is a field this just gives a simple `itir.SymRef`, otherwise we generate a + # `make_tuple` expression. + target_expr: itir.Expr = next_utils.tree_map( + lambda x: im.ref(x), result_collection_constructor=lambda els: im.make_tuple(*els) + )(tmp_names) # type: ignore[assignment] # typing of tree_map does not reflect action of `result_collection_constructor` yet - extraction_predicate = TemporaryExtractionPredicate( - extraction_heuristics(current_closure) + # note: the let would be removed automatically by the `cse.extract_subexpression`, but + # we remove it here for readability & debuggability. 
+ new_expr = inline_lambdas.inline_lambda( + im.let(tmp_sym, target_expr)(new_expr), opcount_preserving=False ) - stencil_body, extracted_lifts, _ = extract_subexpression( - current_closure_stencil.expr, - extraction_predicate, - uid_gen_tmps, - once_only=True, - deepest_expr_first=True, + # TODO(tehrengruber): _transform_stmt not needed if deepest_expr_first=True + tmp_stmts.extend( + _transform_stmt( + itir.SetAt(target=target_expr, domain=domain_expr, expr=tmp_expr), + declarations, + uids, + ) ) - if extracted_lifts: - for tmp_sym, lift_expr in extracted_lifts.items(): - # make sure the applied lift is not capturing anything except of closure params - _ensure_expr_does_not_capture(lift_expr, current_closure_stencil.params) - - assert isinstance(lift_expr, ir.FunCall) and isinstance( - lift_expr.fun, ir.FunCall - ) - - # make sure the arguments to the applied lift are only symbols - if not all(isinstance(arg, ir.SymRef) for arg in lift_expr.args): - lift_expr = canonicalize_applied_lift( - [str(param.id) for param in current_closure_stencil.params], lift_expr - ) - assert all(isinstance(arg, ir.SymRef) for arg in lift_expr.args) - - # create a mapping from the closures parameters to the closure arguments - closure_param_arg_mapping = _closure_parameter_argument_mapping(current_closure) + return [*tmp_stmts, itir.SetAt(target=stmt.target, domain=stmt.domain, expr=new_expr)] + return None - # usually an ir.Lambda or scan - stencil: ir.Node = lift_expr.fun.args[0] # type: ignore[attr-defined] # ensured by canonicalize_applied_lift - # allocate a new temporary - assert isinstance(stencil.type, ts.FunctionType) - assert isinstance(stencil.type.returns, ts.DataType) - tmps.append((tmp_sym.id, stencil.type.returns)) - - # create a new closure that executes the stencil of the applied lift and - # writes the result to the newly created temporary - closure_stack.append( - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=stencil, - output=im.ref(tmp_sym.id), - inputs=[ - 
closure_param_arg_mapping[param.id] # type: ignore[attr-defined] - for param in lift_expr.args - ], - location=current_closure.location, - ) - ) - - new_stencil: ir.Lambda | ir.FunCall - # create a new stencil where all applied lifts that have been extracted are - # replaced by references to the respective temporary - new_stencil = ir.Lambda( - params=current_closure_stencil.params + list(extracted_lifts.keys()), - expr=stencil_body, - ) - # if we are extracting from an applied scan we have to wrap the scan pass again, - # i.e. transform `λ(state, ...) → ...` into `scan(λ(state, ...) → ..., ...)` - if is_scan: - new_stencil = im.call("scan")(new_stencil, current_closure.stencil.args[1:]) # type: ignore[attr-defined] # ensure by is_scan - # inline such that let statements which are just rebinding temporaries disappear - new_stencil = InlineLambdas.apply( - new_stencil, opcount_preserving=True, force_inline_lift_args=False - ) - # we're done with the current closure, add it back to the stack for further - # extraction. 
- closure_stack.append( - ir.StencilClosure( - domain=current_closure.domain, - stencil=new_stencil, - output=current_closure.output, - inputs=current_closure.inputs - + [ir.SymRef(id=sym.id) for sym in extracted_lifts.keys()], - location=current_closure.location, - ) - ) - else: - closures.append(current_closure) +def _transform_stmt( + stmt: itir.Stmt, declarations: list[itir.Temporary], uids: eve_utils.UIDGenerator +) -> list[itir.Stmt]: + unprocessed_stmts: list[itir.Stmt] = [stmt] + stmts: list[itir.Stmt] = [] - return FencilWithTemporaries( - fencil=ir.FencilDefinition( - id=node.id, - function_definitions=node.function_definitions, - params=node.params + [im.sym(name) for name, _ in tmps] + [im.sym(AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # value is a global constant - closures=list(reversed(closures)), - location=node.location, - implicit_domain=node.implicit_domain, + transforms: list[Callable] = [ + # transform `if_` call into `IfStmt` + _transform_if, + # extract applied `as_fieldop` to top-level + functools.partial( + _transform_by_pattern, predicate=lambda expr, _: cpm.is_applied_as_fieldop(expr) ), - params=node.params, - tmps=[ir.Temporary(id=name, dtype=type_) for name, type_ in tmps], - ) - - -def prune_unused_temporaries(node: FencilWithTemporaries) -> FencilWithTemporaries: - """Remove temporaries that are never read.""" - unused_tmps = {tmp.id for tmp in node.tmps} - for closure in node.fencil.closures: - unused_tmps -= {inp.id for inp in closure.inputs} - - if not unused_tmps: - return node - - closures = [ - closure - for closure in node.fencil.closures - if not (isinstance(closure.output, ir.SymRef) and closure.output.id in unused_tmps) - ] - return FencilWithTemporaries( - fencil=ir.FencilDefinition( - id=node.fencil.id, - function_definitions=node.fencil.function_definitions, - params=[p for p in node.fencil.params if p.id not in unused_tmps], - closures=closures, - location=node.fencil.location, + # extract if_ call to the 
top-level + functools.partial( + _transform_by_pattern, predicate=lambda expr, _: cpm.is_call_to(expr, "if_") ), - params=node.params, - tmps=[tmp for tmp in node.tmps if tmp.id not in unused_tmps], - ) - - -def _group_offsets( - offset_literals: Sequence[ir.OffsetLiteral], -) -> Sequence[tuple[str, int | Literal[trace_shifts.Sentinel.ALL_NEIGHBORS]]]: - tags = [tag.value for tag in offset_literals[::2]] - offsets = [ - offset.value if isinstance(offset, ir.OffsetLiteral) else offset - for offset in offset_literals[1::2] ] - assert all(isinstance(tag, str) for tag in tags) - assert all( - isinstance(offset, int) or offset == trace_shifts.Sentinel.ALL_NEIGHBORS - for offset in offsets - ) - return zip(tags, offsets, strict=True) # type: ignore[return-value] # mypy doesn't infer literal correctly - - -def update_domains( - node: FencilWithTemporaries, - offset_provider: Mapping[str, Any], - symbolic_sizes: Optional[dict[str, str]], -) -> FencilWithTemporaries: - horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider) - closures: list[ir.StencilClosure] = [] - domains = dict[str, ir.FunCall]() - for closure in reversed(node.fencil.closures): - if closure.domain == AUTO_DOMAIN: - # every closure with auto domain should have a single out field - assert isinstance(closure.output, ir.SymRef) - if closure.output.id not in domains: - raise NotImplementedError(f"Closure output '{closure.output.id}' is never used.") + while unprocessed_stmts: + stmt = unprocessed_stmts.pop(0) - domain = domains[closure.output.id] + did_transform = False + for transform in transforms: + transformed_stmts = transform(stmt=stmt, declarations=declarations, uids=uids) + if transformed_stmts: + unprocessed_stmts = [*transformed_stmts, *unprocessed_stmts] + did_transform = True + break - closure = ir.StencilClosure( - domain=copy.deepcopy(domain), - stencil=closure.stencil, - output=closure.output, - inputs=closure.inputs, - location=closure.location, - ) - else: - domain = 
closure.domain - - closures.append(closure) - - local_shifts = trace_shifts.trace_stencil(closure.stencil, num_args=len(closure.inputs)) - for param_sym, shift_chains in zip(closure.inputs, local_shifts): - param = param_sym.id - assert isinstance(param, str) - consumed_domains: list[SymbolicDomain] = ( - [SymbolicDomain.from_expr(domains[param])] if param in domains else [] - ) - for shift_chain in shift_chains: - consumed_domain = SymbolicDomain.from_expr(domain) - for offset_name, offset in _group_offsets(shift_chain): - if isinstance(offset_provider[offset_name], gtx.Dimension): - # cartesian shift - dim = offset_provider[offset_name] - assert offset is not trace_shifts.Sentinel.ALL_NEIGHBORS - consumed_domain.ranges[dim] = consumed_domain.ranges[dim].translate(offset) - elif isinstance(offset_provider[offset_name], common.Connectivity): - # unstructured shift - nbt_provider = offset_provider[offset_name] - old_axis = nbt_provider.origin_axis - new_axis = nbt_provider.neighbor_axis + # no transformation occurred + if not did_transform: + stmts.append(stmt) - assert new_axis not in consumed_domain.ranges or old_axis == new_axis + return stmts - if symbolic_sizes is None: - new_range = SymbolicRange( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.literal( - str(horizontal_sizes[new_axis.value]), ir.INTEGER_INDEX_BUILTIN - ), - ) - else: - new_range = SymbolicRange( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(symbolic_sizes[new_axis.value]), - ) - consumed_domain.ranges = dict( - (axis, range_) if axis != old_axis else (new_axis, new_range) - for axis, range_ in consumed_domain.ranges.items() - ) - # TODO(tehrengruber): Revisit. Somehow the order matters so preserve it. 
- consumed_domain.ranges = dict( - (axis, range_) if axis != old_axis else (new_axis, new_range) - for axis, range_ in consumed_domain.ranges.items() - ) - else: - raise NotImplementedError() - consumed_domains.append(consumed_domain) - # compute the bounds of all consumed domains - if consumed_domains: - if all( - consumed_domain.ranges.keys() == consumed_domains[0].ranges.keys() - for consumed_domain in consumed_domains - ): # scalar otherwise - domains[param] = domain_union(*consumed_domains).as_expr() +def create_global_tmps( + program: itir.Program, + offset_provider: common.OffsetProvider, + *, + uids: Optional[eve_utils.UIDGenerator] = None, +) -> itir.Program: + """ + Given an `itir.Program` create temporaries for intermediate values. - return FencilWithTemporaries( - fencil=ir.FencilDefinition( - id=node.fencil.id, - function_definitions=node.fencil.function_definitions, - params=node.fencil.params[:-1], # remove `_gtmp_auto_domain` param again - closures=list(reversed(closures)), - location=node.fencil.location, - implicit_domain=node.fencil.implicit_domain, - ), - params=node.params, - tmps=node.tmps, + This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its + arguments into temporaries. 
+ """ + program = infer_domain.infer_program(program, offset_provider=offset_provider) + program = type_inference.infer( + program, offset_provider_type=common.offset_provider_to_type(offset_provider) ) - -def _tuple_constituents(node: ir.Expr) -> Iterable[ir.Expr]: - if cpm.is_call_to(node, "make_tuple"): - for arg in node.args: - yield from _tuple_constituents(arg) - else: - yield node - - -def collect_tmps_info( - node: FencilWithTemporaries, *, offset_provider: common.OffsetProvider -) -> FencilWithTemporaries: - """Perform type inference for finding the types of temporaries and sets the temporary size.""" - tmps = {tmp.id for tmp in node.tmps} - domains: dict[str, ir.Expr] = {} - for closure in node.fencil.closures: - for output_field in _tuple_constituents(closure.output): - assert isinstance(output_field, ir.SymRef) - if output_field.id not in tmps: - continue - - assert output_field.id not in domains or domains[output_field.id] == closure.domain - domains[output_field.id] = closure.domain - - new_node = FencilWithTemporaries( - fencil=node.fencil, - params=node.params, - tmps=[ - ir.Temporary(id=tmp.id, domain=domains[tmp.id], dtype=tmp.dtype) for tmp in node.tmps - ], + if not uids: + uids = eve_utils.UIDGenerator(prefix="__tmp") + declarations = program.declarations.copy() + new_body = [] + + for stmt in program.body: + assert isinstance(stmt, itir.SetAt) + new_body.extend(_transform_stmt(stmt, uids=uids, declarations=declarations)) + + return itir.Program( + id=program.id, + function_definitions=program.function_definitions, + params=program.params, + declarations=declarations, + body=new_body, ) - # TODO(tehrengruber): type inference is only really needed to infer the types of the temporaries - # and write them to the params of the inner fencil. This should be cleaned up after we - # refactored the IR. 
- return itir_type_inference.infer(new_node, offset_provider=offset_provider) - - -def validate_no_dynamic_offsets(node: ir.Node) -> None: - """Vaidate we have no dynamic offsets, e.g. `shift(Ioff, deref(...))(...)`""" - for call_node in node.walk_values().if_isinstance(ir.FunCall): - assert isinstance(call_node, ir.FunCall) - if cpm.is_call_to(call_node, "shift"): - if any(not isinstance(arg, ir.OffsetLiteral) for arg in call_node.args): - raise NotImplementedError("Dynamic offsets not supported in temporary pass.") - - -# TODO(tehrengruber): Add support for dynamic shifts (e.g. the distance is a symbol). This can be -# tricky: For every lift statement that is dynamically shifted we can not compute bounds anymore -# and hence also not extract as a temporary. -class CreateGlobalTmps(PreserveLocationVisitor, NodeTranslator): - """Main entry point for introducing global temporaries. - - Transforms an existing iterator IR fencil into a fencil with global temporaries. - """ - - def visit_FencilDefinition( - self, - node: ir.FencilDefinition, - *, - offset_provider: Mapping[str, Any], - extraction_heuristics: Optional[ - Callable[[ir.StencilClosure], Callable[[ir.Expr], bool]] - ] = None, - symbolic_sizes: Optional[dict[str, str]], - ) -> FencilWithTemporaries: - # Vaidate we have no dynamic offsets, e.g. 
`shift(Ioff, deref(...))(...)` - validate_no_dynamic_offsets(node) - # Split closures on lifted function calls and introduce temporaries - res = split_closures( - node, offset_provider=offset_provider, extraction_heuristics=extraction_heuristics - ) - # Prune unreferences closure inputs introduced in the previous step - res = PruneClosureInputs().visit(res) - # Prune unused temporaries possibly introduced in the previous step - res = prune_unused_temporaries(res) - # Perform an eta-reduction which should put all calls at the highest level of a closure - res = EtaReduction().visit(res) - # Perform a naive extent analysis to compute domain sizes of closures and temporaries - res = update_domains(res, offset_provider, symbolic_sizes) - # Use type inference to determine the data type of the temporaries - return collect_tmps_info(res, offset_provider=offset_provider) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index c1a743af1c..f3c3185225 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -10,22 +10,62 @@ import itertools import typing -from typing import Callable, TypeAlias +from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.eve.extended_typing import Callable, Optional, TypeAlias, Unpack from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ( common_pattern_matcher as cpm, domain_utils, ir_makers as im, ) from gt4py.next.iterator.transforms import trace_shifts -from gt4py.next.utils import tree_map +from gt4py.next.utils import flatten_nested_tuple, tree_map -DOMAIN: TypeAlias = domain_utils.SymbolicDomain | None | tuple["DOMAIN", ...] -ACCESSED_DOMAINS: TypeAlias = dict[str, DOMAIN] +class DomainAccessDescriptor(eve.StrEnum): + """ + Descriptor for domains that could not be inferred. 
+ """ + + # TODO(tehrengruber): Revisit this concept. It is strange that we don't have a descriptor + # `KNOWN`, but since we don't need it, it wasn't added. + + #: The access is unknown because of a dynamic shift, whose extent is not known. + #: E.g.: `(⇑(λ(arg0, arg1) → ·⟪Ioffₒ, ·arg1⟫(arg0)))(in_field1, in_field2)` + UNKNOWN = "unknown" + #: The domain is never accessed. + #: E.g.: `{in_field1, in_field2}[0]` + NEVER = "never" + + +NonTupleDomainAccess: TypeAlias = domain_utils.SymbolicDomain | DomainAccessDescriptor +#: The domain can also be a tuple of domains, usually this only occurs for scan operators returning +#: a tuple since other occurrences for tuples are removed before domain inference. This is +#: however not a requirement of the pass and `make_tuple(vertex_field, edge_field)` infers just +#: fine to a tuple of a vertex and an edge domain. +DomainAccess: TypeAlias = NonTupleDomainAccess | tuple["DomainAccess", ...] +AccessedDomains: TypeAlias = dict[str, DomainAccess] + + +class InferenceOptions(typing.TypedDict): + offset_provider: common.OffsetProvider + symbolic_domain_sizes: Optional[dict[str, str]] + allow_uninferred: bool + + +class DomainAnnexDebugger(eve.NodeVisitor): + """ + Small utility class to debug missing domain attribute in annex. + """ + + def visit_Node(self, node: itir.Node): + if cpm.is_applied_as_fieldop(node): + if not hasattr(node.annex, "domain"): + breakpoint() # noqa: T100 + return self.generic_visit(node) def _split_dict_by_key(pred: Callable, d: dict): @@ -44,43 +84,58 @@ def _split_dict_by_key(pred: Callable, d: dict): # TODO(tehrengruber): Revisit whether we want to move this behaviour to `domain_utils.domain_union`. 
-def _domain_union_with_none( - *domains: domain_utils.SymbolicDomain | None, -) -> domain_utils.SymbolicDomain | None: - filtered_domains: list[domain_utils.SymbolicDomain] = [d for d in domains if d is not None] +def _domain_union( + *domains: domain_utils.SymbolicDomain | DomainAccessDescriptor, +) -> domain_utils.SymbolicDomain | DomainAccessDescriptor: + if any(d == DomainAccessDescriptor.UNKNOWN for d in domains): + return DomainAccessDescriptor.UNKNOWN + + filtered_domains: list[domain_utils.SymbolicDomain] = [ + d # type: ignore[misc] # domain can never be unknown as these cases are filtered above + for d in domains + if d != DomainAccessDescriptor.NEVER + ] if len(filtered_domains) == 0: - return None + return DomainAccessDescriptor.NEVER return domain_utils.domain_union(*filtered_domains) -def _canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMAIN]: +def _canonicalize_domain_structure( + d1: DomainAccess, d2: DomainAccess +) -> tuple[DomainAccess, DomainAccess]: """ Given two domains or composites thereof, canonicalize their structure. If one of the arguments is a tuple the other one will be promoted to a tuple of same structure - unless it already is a tuple. Missing values are replaced by None, meaning no domain is - specified. + unless it already is a tuple. Missing values are filled by :ref:`DomainAccessDescriptor.NEVER`. >>> domain = im.domain(common.GridType.CARTESIAN, {}) >>> _canonicalize_domain_structure((domain,), (domain, domain)) == ( - ... (domain, None), + ... (domain, DomainAccessDescriptor.NEVER), ... (domain, domain), ... ) True - >>> _canonicalize_domain_structure((domain, None), None) == ((domain, None), (None, None)) + >>> _canonicalize_domain_structure( + ... (domain, DomainAccessDescriptor.NEVER), DomainAccessDescriptor.NEVER + ... ) == ( + ... (domain, DomainAccessDescriptor.NEVER), + ... (DomainAccessDescriptor.NEVER, DomainAccessDescriptor.NEVER), + ... 
) True """ - if d1 is None and isinstance(d2, tuple): - return _canonicalize_domain_structure((None,) * len(d2), d2) - if d2 is None and isinstance(d1, tuple): - return _canonicalize_domain_structure(d1, (None,) * len(d1)) + if d1 is DomainAccessDescriptor.NEVER and isinstance(d2, tuple): + return _canonicalize_domain_structure((DomainAccessDescriptor.NEVER,) * len(d2), d2) + if d2 is DomainAccessDescriptor.NEVER and isinstance(d1, tuple): + return _canonicalize_domain_structure(d1, (DomainAccessDescriptor.NEVER,) * len(d1)) if isinstance(d1, tuple) and isinstance(d2, tuple): return tuple( zip( *( _canonicalize_domain_structure(el1, el2) - for el1, el2 in itertools.zip_longest(d1, d2, fillvalue=None) + for el1, el2 in itertools.zip_longest( + d1, d2, fillvalue=DomainAccessDescriptor.NEVER + ) ) ) ) # type: ignore[return-value] # mypy not smart enough @@ -88,16 +143,16 @@ def _canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMA def _merge_domains( - original_domains: ACCESSED_DOMAINS, - additional_domains: ACCESSED_DOMAINS, -) -> ACCESSED_DOMAINS: + original_domains: AccessedDomains, + additional_domains: AccessedDomains, +) -> AccessedDomains: new_domains = {**original_domains} for key, domain in additional_domains.items(): original_domain, domain = _canonicalize_domain_structure( - original_domains.get(key, None), domain + original_domains.get(key, DomainAccessDescriptor.NEVER), domain ) - new_domains[key] = tree_map(_domain_union_with_none)(original_domain, domain) + new_domains[key] = tree_map(_domain_union)(original_domain, domain) return new_domains @@ -105,37 +160,52 @@ def _merge_domains( def _extract_accessed_domains( stencil: itir.Expr, input_ids: list[str], - target_domain: domain_utils.SymbolicDomain, + target_domain: NonTupleDomainAccess, offset_provider: common.OffsetProvider, -) -> ACCESSED_DOMAINS: - accessed_domains: dict[str, domain_utils.SymbolicDomain | None] = {} + symbolic_domain_sizes: Optional[dict[str, str]], +) -> 
dict[str, NonTupleDomainAccess]: + accessed_domains: dict[str, NonTupleDomainAccess] = {} shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids)) for in_field_id, shifts_list in zip(input_ids, shifts_results, strict=True): + # TODO(tehrengruber): Dynamic shifts are not supported by `SymbolicDomain.translate`. Use + # special `UNKNOWN` marker for them until we have implemented a proper solution. + if any(s == trace_shifts.Sentinel.VALUE for shift in shifts_list for s in shift): + accessed_domains[in_field_id] = DomainAccessDescriptor.UNKNOWN + continue + new_domains = [ - domain_utils.SymbolicDomain.translate(target_domain, shift, offset_provider) + domain_utils.SymbolicDomain.translate( + target_domain, shift, offset_provider, symbolic_domain_sizes + ) + if not isinstance(target_domain, DomainAccessDescriptor) + else target_domain for shift in shifts_list ] - # `None` means field is never accessed - accessed_domains[in_field_id] = _domain_union_with_none( - accessed_domains.get(in_field_id, None), *new_domains + accessed_domains[in_field_id] = _domain_union( + accessed_domains.get(in_field_id, DomainAccessDescriptor.NEVER), *new_domains ) - return typing.cast(ACCESSED_DOMAINS, accessed_domains) + return accessed_domains -def infer_as_fieldop( +def _infer_as_fieldop( applied_fieldop: itir.FunCall, - target_domain: DOMAIN, + target_domain: DomainAccess, + *, offset_provider: common.OffsetProvider, -) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: + symbolic_domain_sizes: Optional[dict[str, str]], + allow_uninferred: bool, +) -> tuple[itir.FunCall, AccessedDomains]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") - if target_domain is None: - raise ValueError("'target_domain' cannot be 'None'.") - if not isinstance(target_domain, domain_utils.SymbolicDomain): - raise ValueError("'target_domain' needs to be a 'domain_utils.SymbolicDomain'.") + if not allow_uninferred and target_domain is 
DomainAccessDescriptor.NEVER: + raise ValueError("'target_domain' cannot be 'NEVER' unless `allow_uninferred=True`.") + # FIXME[#1582](tehrengruber): Temporary solution for `tuple_get` on scan result. See `test_solve_triag`. + if isinstance(target_domain, tuple): + target_domain = _domain_union(*flatten_nested_tuple(target_domain)) # type: ignore[arg-type] # mypy not smart enough + assert isinstance(target_domain, (domain_utils.SymbolicDomain, DomainAccessDescriptor)) # `as_fieldop(stencil)(inputs...)` stencil, inputs = applied_fieldop.fun.args[0], applied_fieldop.args @@ -157,23 +227,30 @@ def infer_as_fieldop( raise ValueError(f"Unsupported expression of type '{type(in_field)}'.") input_ids.append(id_) - accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains( - stencil, input_ids, target_domain, offset_provider + inputs_accessed_domains: dict[str, NonTupleDomainAccess] = _extract_accessed_domains( + stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes ) # Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s + accessed_domains: AccessedDomains = {} transformed_inputs: list[itir.Expr] = [] for in_field_id, in_field in zip(input_ids, inputs): transformed_input, accessed_domains_tmp = infer_expr( - in_field, accessed_domains[in_field_id], offset_provider + in_field, + inputs_accessed_domains[in_field_id], + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, ) transformed_inputs.append(transformed_input) accessed_domains = _merge_domains(accessed_domains, accessed_domains_tmp) - transformed_call = im.as_fieldop(stencil, domain_utils.SymbolicDomain.as_expr(target_domain))( - *transformed_inputs - ) + if not isinstance(target_domain, DomainAccessDescriptor): + target_domain_expr = domain_utils.SymbolicDomain.as_expr(target_domain) + else: + target_domain_expr = None + transformed_call = im.as_fieldop(stencil, target_domain_expr)(*transformed_inputs) 
accessed_domains_without_tmp = { k: v @@ -184,18 +261,17 @@ def infer_as_fieldop( return transformed_call, accessed_domains_without_tmp -def infer_let( +def _infer_let( let_expr: itir.FunCall, - input_domain: DOMAIN, - offset_provider: common.OffsetProvider, -) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: + input_domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.FunCall, AccessedDomains]: assert cpm.is_let(let_expr) assert isinstance(let_expr.fun, itir.Lambda) # just to make mypy happy - transformed_calls_expr, accessed_domains = infer_expr( - let_expr.fun.expr, input_domain, offset_provider - ) - let_params = {param_sym.id for param_sym in let_expr.fun.params} + + transformed_calls_expr, accessed_domains = infer_expr(let_expr.fun.expr, input_domain, **kwargs) + accessed_domains_let_args, accessed_domains_outer = _split_dict_by_key( lambda k: k in let_params, accessed_domains ) @@ -206,9 +282,9 @@ def infer_let( arg, accessed_domains_let_args.get( param.id, - None, + DomainAccessDescriptor.NEVER, ), - offset_provider, + **kwargs, ) accessed_domains_outer = _merge_domains(accessed_domains_outer, accessed_domains_arg) transformed_calls_args.append(transformed_calls_arg) @@ -223,14 +299,14 @@ def infer_let( return transformed_call, accessed_domains_outer -def infer_make_tuple( +def _infer_make_tuple( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "make_tuple") infered_args_expr = [] - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} if not isinstance(domain, tuple): # promote domain to a tuple of domains such that it has the same structure as # the expression @@ -238,105 +314,174 @@ def infer_make_tuple( # out @ c⟨ IDimₕ: [0, __out_size_0) ⟩ ← {__sym_1, __sym_2}; domain = (domain,) * len(expr.args) assert 
len(expr.args) >= len(domain) - # There may be less domains than tuple args, pad the domain with `None` in that case. - # e.g. `im.tuple_get(0, im.make_tuple(a, b), domain=domain)` - domain = (*domain, *(None for _ in range(len(expr.args) - len(domain)))) + # There may be fewer domains than tuple args, pad the domain with `NEVER` + # in that case. + # e.g. `im.tuple_get(0, im.make_tuple(a, b), domain=domain)` + domain = (*domain, *(DomainAccessDescriptor.NEVER for _ in range(len(expr.args) - len(domain)))) for i, arg in enumerate(expr.args): - infered_arg_expr, actual_domains_arg = infer_expr(arg, domain[i], offset_provider) + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain[i], **kwargs) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) - return im.call(expr.fun)(*infered_args_expr), actual_domains + result_expr = im.call(expr.fun)(*infered_args_expr) + return result_expr, actual_domains -def infer_tuple_get( +def _infer_tuple_get( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "tuple_get") - actual_domains: ACCESSED_DOMAINS = {} - idx, tuple_arg = expr.args - assert isinstance(idx, itir.Literal) - child_domain = tuple(None if i != int(idx.value) else domain for i in range(int(idx.value) + 1)) - infered_arg_expr, actual_domains_arg = infer_expr(tuple_arg, child_domain, offset_provider) + actual_domains: AccessedDomains = {} + idx_expr, tuple_arg = expr.args + assert isinstance(idx_expr, itir.Literal) + idx = int(idx_expr.value) + tuple_domain = tuple( + DomainAccessDescriptor.NEVER if i != idx else domain for i in range(idx + 1) + ) + infered_arg_expr, actual_domains_arg = infer_expr(tuple_arg, tuple_domain, **kwargs) - infered_args_expr = im.tuple_get(idx.value, infered_arg_expr) + 
infered_args_expr = im.tuple_get(idx, infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) return infered_args_expr, actual_domains -def infer_if( +def _infer_if( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "if_") infered_args_expr = [] - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} cond, true_val, false_val = expr.args for arg in [true_val, false_val]: - infered_arg_expr, actual_domains_arg = infer_expr(arg, domain, offset_provider) + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain, **kwargs) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) - return im.call(expr.fun)(cond, *infered_args_expr), actual_domains + result_expr = im.call(expr.fun)(cond, *infered_args_expr) + return result_expr, actual_domains -def infer_expr( +def _infer_expr( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: if isinstance(expr, itir.SymRef): return expr, {str(expr.id): domain} elif isinstance(expr, itir.Literal): return expr, {} elif cpm.is_applied_as_fieldop(expr): - return infer_as_fieldop(expr, domain, offset_provider) + return _infer_as_fieldop(expr, domain, **kwargs) elif cpm.is_let(expr): - return infer_let(expr, domain, offset_provider) + return _infer_let(expr, domain, **kwargs) elif cpm.is_call_to(expr, "make_tuple"): - return infer_make_tuple(expr, domain, offset_provider) + return _infer_make_tuple(expr, domain, **kwargs) elif cpm.is_call_to(expr, "tuple_get"): - return infer_tuple_get(expr, domain, offset_provider) + return _infer_tuple_get(expr, domain, **kwargs) 
elif cpm.is_call_to(expr, "if_"): - return infer_if(expr, domain, offset_provider) + return _infer_if(expr, domain, **kwargs) elif ( - cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) - or cpm.is_call_to(expr, itir.TYPEBUILTINS) - or cpm.is_call_to(expr, "cast_") + cpm.is_call_to(expr, builtins.ARITHMETIC_BUILTINS) + or cpm.is_call_to(expr, builtins.TYPE_BUILTINS) + or cpm.is_call_to(expr, ("cast_", "index", "unstructured_domain", "cartesian_domain")) ): return expr, {} else: raise ValueError(f"Unsupported expression: {expr}") +def infer_expr( + expr: itir.Expr, + domain: DomainAccess, + *, + offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]] = None, + allow_uninferred: bool = False, +) -> tuple[itir.Expr, AccessedDomains]: + """ + Infer the domain of all field subexpressions of `expr`. + + Given an expression `expr` and the domain it is accessed at, back-propagate the domain of all + (field-typed) subexpression. + + Arguments: + - expr: The expression to be inferred. + - domain: The domain `expr` is read at. + - symbolic_domain_sizes: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol + name that evaluates to the length of that axis. + - allow_uninferred: Allow `as_fieldop` expressions whose domain is either unknown (e.g. + because of a dynamic shift) or never accessed. + + Returns: + A tuple containing the inferred expression with all applied `as_fieldop` (that are accessed) + having a domain argument now, and a dictionary mapping symbol names referenced in `expr` to + domain they are accessed at. 
+ """ + expr, accessed_domains = _infer_expr( + expr, + domain, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, + ) + expr.annex.domain = domain + + return expr, accessed_domains + + +def _infer_stmt( + stmt: itir.Stmt, + **kwargs: Unpack[InferenceOptions], +): + if isinstance(stmt, itir.SetAt): + transformed_call, _ = infer_expr( + stmt.expr, domain_utils.SymbolicDomain.from_expr(stmt.domain), **kwargs + ) + + return itir.SetAt( + expr=transformed_call, + domain=stmt.domain, + target=stmt.target, + ) + elif isinstance(stmt, itir.IfStmt): + return itir.IfStmt( + cond=stmt.cond, + true_branch=[_infer_stmt(c, **kwargs) for c in stmt.true_branch], + false_branch=[_infer_stmt(c, **kwargs) for c in stmt.false_branch], + ) + raise ValueError(f"Unsupported stmt: {stmt}") + + def infer_program( program: itir.Program, + *, offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]] = None, + allow_uninferred: bool = False, ) -> itir.Program: - transformed_set_ats: list[itir.SetAt] = [] + """ + Infer the domain of all field subexpressions inside a program. + + See :func:`infer_expr` for more details. + """ assert ( not program.function_definitions ), "Domain propagation does not support function definitions." 
- for set_at in program.body: - assert isinstance(set_at, itir.SetAt) - - transformed_call, _unused_domain = infer_expr( - set_at.expr, domain_utils.SymbolicDomain.from_expr(set_at.domain), offset_provider - ) - transformed_set_ats.append( - itir.SetAt( - expr=transformed_call, - domain=set_at.domain, - target=set_at.target, - ), - ) - return itir.Program( id=program.id, function_definitions=program.function_definitions, params=program.params, declarations=program.declarations, - body=transformed_set_ats, + body=[ + _infer_stmt( + stmt, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, + ) + for stmt in program.body + ], ) diff --git a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py index 95c761d7ba..c0a8c9f1b7 100644 --- a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py +++ b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses -from typing import ClassVar, Optional +from typing import ClassVar, Optional, TypeVar import gt4py.next.iterator.ir_utils.common_pattern_matcher as cpm from gt4py import eve @@ -23,6 +23,9 @@ def is_center_derefed_only(node: itir.Node) -> bool: return hasattr(node.annex, "recorded_shifts") and node.annex.recorded_shifts in [set(), {()}] +T = TypeVar("T", bound=itir.Program | itir.Lambda) + + @dataclasses.dataclass class InlineCenterDerefLiftVars(eve.NodeTranslator): """ @@ -33,26 +36,40 @@ class InlineCenterDerefLiftVars(eve.NodeTranslator): `let(var, (↑stencil)(it))(·var + ·var)` Directly inlining `var` would increase the size of the tree and duplicate the calculation. - Instead, this pass computes the value at the current location once and replaces all previous - references to `var` by an applied lift which captures this value. 
+ Instead this pass, first takes the iterator `(↑stencil)(it)` and transforms it into a + 0-ary function that evaluates to the value at the current location. + + `λ() → ·(↑stencil)(it)` - `let(_icdlv_1, stencil(it))(·(↑(λ() → _icdlv_1) + ·(↑(λ() → _icdlv_1))` + Then all previous occurences of `var` are replaced by this function. + + `let(_icdlv_1, λ() → ·(↑stencil)(it))(·(↑(λ() → _icdlv_1()) + ·(↑(λ() → _icdlv_1()))` The lift inliner can then later easily transform this into a nice expression: - `let(_icdlv_1, stencil(it))(_icdlv_1 + _icdlv_1)` + `let(_icdlv_1, λ() → stencil(it))(_icdlv_1() + _icdlv_1())` + + Finally, recomputation is avoided by using the common subexpression elimination and lamba + inlining (can be configured opcount preserving). Both is up to the caller to do later. + + `λ(_cs_1) → _cs_1 + _cs_1)(stencil(it))` - Note: This pass uses and preserves the `recorded_shifts` annex. + Note: This pass uses and preserves the `domain` and `recorded_shifts` annex. """ - PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("recorded_shifts",) + PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("domain", "recorded_shifts") uids: eve_utils.UIDGenerator @classmethod - def apply(cls, node: itir.Program, uids: Optional[eve_utils.UIDGenerator] = None): + def apply( + cls, node: T, *, is_stencil=False, uids: Optional[eve_utils.UIDGenerator] = None + ) -> T: if not uids: uids = eve_utils.UIDGenerator() + if is_stencil: + assert isinstance(node, itir.Expr) + trace_shifts.trace_stencil(node, save_to_annex=True) return cls(uids=uids).visit(node) def visit_FunCall(self, node: itir.FunCall, **kwargs): @@ -70,20 +87,25 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): assert isinstance(node.fun, itir.Lambda) # to make mypy happy eligible_params = [False] * len(node.fun.params) new_args = [] - bound_scalars: dict[str, itir.Expr] = {} + # values are 0-ary lambda functions that evaluate to the derefed argument. 
We don't put + # the values themselves here as they might be inside of an if to protected from an oob + # access + evaluators: dict[str, itir.Expr] = {} for i, (param, arg) in enumerate(zip(node.fun.params, node.args)): if cpm.is_applied_lift(arg) and is_center_derefed_only(param): eligible_params[i] = True - bound_arg_name = self.uids.sequential_id(prefix="_icdlv") - capture_lift = im.promote_to_const_iterator(bound_arg_name) + bound_arg_evaluator = self.uids.sequential_id(prefix="_icdlv") + capture_lift = im.promote_to_const_iterator(im.call(bound_arg_evaluator)()) trace_shifts.copy_recorded_shifts(from_=param, to=capture_lift) new_args.append(capture_lift) # since we deref an applied lift here we can (but don't need to) immediately # inline - bound_scalars[bound_arg_name] = InlineLifts( - flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT - ).visit(im.deref(arg), recurse=False) + evaluators[bound_arg_evaluator] = im.lambda_()( + InlineLifts(flags=InlineLifts.Flag.INLINE_DEREF_LIFT).visit( + im.deref(arg), recurse=False + ) + ) else: new_args.append(arg) @@ -92,6 +114,6 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): im.call(node.fun)(*new_args), eligible_params=eligible_params ) # TODO(tehrengruber): propagate let outwards - return im.let(*bound_scalars.items())(new_node) + return im.let(*evaluators.items())(new_node) return node diff --git a/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py new file mode 100644 index 0000000000..0af9d9dab9 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py @@ -0,0 +1,73 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +import dataclasses +from typing import Optional + +import gt4py.next.iterator.ir_utils.common_pattern_matcher as cpm +from gt4py import eve +from gt4py.eve import utils as eve_utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.transforms import fuse_as_fieldop, inline_lambdas, trace_shifts +from gt4py.next.iterator.transforms.symbol_ref_utils import collect_symbol_refs + + +def _dynamic_shift_args(node: itir.Expr) -> None | list[bool]: + if not cpm.is_applied_as_fieldop(node): + return None + params_shifts = trace_shifts.trace_stencil( + node.fun.args[0], # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + num_args=len(node.args), + save_to_annex=True, + ) + dynamic_shifts = [ + any(trace_shifts.Sentinel.VALUE in shifts for shifts in param_shifts) + for param_shifts in params_shifts + ] + return dynamic_shifts + + +@dataclasses.dataclass +class InlineDynamicShifts(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): + uids: eve_utils.UIDGenerator + + @classmethod + def apply(cls, node: itir.Program, uids: Optional[eve_utils.UIDGenerator] = None): + if not uids: + uids = eve_utils.UIDGenerator() + + return cls(uids=uids).visit(node) + + def visit_FunCall(self, node: itir.FunCall, **kwargs): + node = self.generic_visit(node, **kwargs) + + if cpm.is_let(node) and ( + dynamic_shift_args := _dynamic_shift_args(let_body := node.fun.expr) # type: ignore[attr-defined] # ensured by is_let + ): + inline_let_params = {p.id: False for p in node.fun.params} # type: ignore[attr-defined] # ensured by is_let + + for inp, is_dynamic_shift_arg in zip(let_body.args, dynamic_shift_args, strict=True): + for ref in collect_symbol_refs(inp): + if ref in inline_let_params and is_dynamic_shift_arg: + inline_let_params[ref] = True + + if any(inline_let_params): + node = inline_lambdas.inline_lambda( + node, eligible_params=list(inline_let_params.values()) + ) + + if dynamic_shift_args := 
_dynamic_shift_args(node): + assert len(node.fun.args) in [1, 2] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop in _dynamic_shift_args + fuse_args = [ + not isinstance(inp, itir.SymRef) and dynamic_shift_arg + for inp, dynamic_shift_arg in zip(node.args, dynamic_shift_args, strict=True) + ] + if any(fuse_args): + return fuse_as_fieldop.fuse_as_fieldop(node, fuse_args, uids=self.uids) + + return node diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index a2188030a1..03b20d14fe 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -36,12 +36,12 @@ def prune_unreferenced_fundefs(program: itir.Program) -> itir.Program: >>> fun1 = itir.FunctionDefinition( ... id="fun1", ... params=[im.sym("a")], - ... expr=im.call("deref")("a"), + ... expr=im.deref("a"), ... ) >>> fun2 = itir.FunctionDefinition( ... id="fun2", ... params=[im.sym("a")], - ... expr=im.call("deref")("a"), + ... expr=im.deref("a"), ... ) >>> program = itir.Program( ... id="testee", @@ -59,7 +59,7 @@ def prune_unreferenced_fundefs(program: itir.Program) -> itir.Program: >>> print(prune_unreferenced_fundefs(program)) testee(inp, out) { fun1 = λ(a) → ·a; - out @ c⟨ IDimₕ: [0, 10) ⟩ ← fun1(inp); + out @ c⟨ IDimₕ: [0, 10[ ⟩ ← fun1(inp); } """ fun_names = [fun.id for fun in program.function_definitions] diff --git a/src/gt4py/next/iterator/transforms/inline_into_scan.py b/src/gt4py/next/iterator/transforms/inline_into_scan.py index f899da73b1..33e36bfa4b 100644 --- a/src/gt4py/next/iterator/transforms/inline_into_scan.py +++ b/src/gt4py/next/iterator/transforms/inline_into_scan.py @@ -5,7 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause - +# FIXME[#1582](tehrengruber): This transformation is not used anymore. Decide on its fate. 
from typing import Sequence, TypeGuard from gt4py import eve diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 920d628166..9053214b39 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -14,6 +14,7 @@ from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs +from gt4py.next.iterator.type_system import inference as itir_inference # TODO(tehrengruber): Reduce complexity of the function by removing the different options here @@ -96,9 +97,8 @@ def new_name(name): if all(eligible_params): new_expr.location = node.location - return new_expr else: - return ir.FunCall( + new_expr = ir.FunCall( fun=ir.Lambda( params=[ param @@ -110,6 +110,11 @@ def new_name(name): args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], location=node.location, ) + for attr in ("type", "recorded_shifts", "domain"): + if hasattr(node.annex, attr): + setattr(new_expr.annex, attr, getattr(node.annex, attr)) + itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) + return new_expr @dataclasses.dataclass @@ -117,10 +122,10 @@ class InlineLambdas(PreserveLocationVisitor, NodeTranslator): """ Inline lambda calls by substituting every argument by its value. - Note: This pass preserves, but doesn't use the `type` and `recorded_shifts` annex. + Note: This pass preserves, but doesn't use the `type` `recorded_shifts`, `domain` annex. 
""" - PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts") + PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain") opcount_preserving: bool diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index f27dbbb74c..166324486a 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -8,7 +8,7 @@ import dataclasses import enum -from typing import Callable, Optional +from typing import Callable, ClassVar, Optional import gt4py.eve as eve from gt4py.eve import NodeTranslator, traits @@ -80,6 +80,7 @@ def _transform_and_extract_lift_args( new_args = [] for i, arg in enumerate(node.args): if isinstance(arg, ir.SymRef): + # TODO(tehrengruber): Is it possible to reinfer the type if it is not inherited here? sym = ir.Sym(id=arg.id) assert sym not in extracted_args or extracted_args[sym] == arg extracted_args[sym] = arg @@ -92,6 +93,7 @@ def _transform_and_extract_lift_args( ) assert new_symbol not in extracted_args extracted_args[new_symbol] = arg + # TODO(tehrengruber): Is it possible to reinfer the type if it is not inherited here? new_args.append(ir.SymRef(id=new_symbol.id)) itir_node = im.lift(inner_stencil)(*new_args) @@ -112,6 +114,8 @@ class InlineLifts( function nodes. """ + PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("domain",) + class Flag(enum.IntEnum): #: `shift(...)(lift(f)(args...))` -> `lift(f)(shift(...)(args)...)` PROPAGATE_SHIFT = 1 @@ -157,6 +161,9 @@ def visit_FunCall( if self.flags & self.Flag.PROPAGATE_SHIFT and _is_shift_lift(node): shift = node.fun + # This transformation does not preserve the type (the position dims of the iterator + # change). Delete type to avoid errors. 
+ shift.type = None assert len(node.args) == 1 lift_call = node.args[0] new_args = [ @@ -197,11 +204,11 @@ def visit_FunCall( if len(args) == 0: return im.literal_from_value(True) - res = ir.FunCall(fun=ir.SymRef(id="can_deref"), args=[args[0]]) + res = im.can_deref(args[0]) for arg in args[1:]: res = ir.FunCall( fun=ir.SymRef(id="and_"), - args=[res, ir.FunCall(fun=ir.SymRef(id="can_deref"), args=[arg])], + args=[res, im.can_deref(arg)], ) return res elif ( diff --git a/src/gt4py/next/iterator/transforms/inline_scalar.py b/src/gt4py/next/iterator/transforms/inline_scalar.py new file mode 100644 index 0000000000..b424074b5c --- /dev/null +++ b/src/gt4py/next/iterator/transforms/inline_scalar.py @@ -0,0 +1,39 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py import eve +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.transforms import inline_lambdas +from gt4py.next.iterator.type_system import inference as itir_inference +from gt4py.next.type_system import type_specifications as ts + + +class InlineScalar(eve.NodeTranslator): + PRESERVED_ANNEX_ATTRS = ("domain",) + + @classmethod + def apply(cls, program: itir.Program, offset_provider_type: common.OffsetProviderType): + program = itir_inference.infer(program, offset_provider_type=offset_provider_type) + return cls().visit(program) + + def generic_visit(self, node, **kwargs): + if cpm.is_call_to(node, "as_fieldop"): + return node + + return super().generic_visit(node, **kwargs) + + def visit_Expr(self, node: itir.Expr): + node = self.generic_visit(node) + + if cpm.is_let(node): + eligible_params = [isinstance(arg.type, ts.ScalarType) for arg in node.args] + node = inline_lambdas.inline_lambda(node, eligible_params=eligible_params) 
+ return node + return node diff --git a/src/gt4py/next/iterator/transforms/merge_let.py b/src/gt4py/next/iterator/transforms/merge_let.py index 0e7d74e594..9c0c25bd49 100644 --- a/src/gt4py/next/iterator/transforms/merge_let.py +++ b/src/gt4py/next/iterator/transforms/merge_let.py @@ -5,6 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from typing import ClassVar import gt4py.eve as eve from gt4py.next.iterator import ir as itir @@ -26,6 +27,8 @@ class MergeLet(eve.PreserveLocationVisitor, eve.NodeTranslator): This can significantly reduce the depth of the tree and its readability. """ + PRESERVED_ANNEX_ATTRS: ClassVar[tuple[str, ...]] = ("domain",) + def visit_FunCall(self, node: itir.FunCall): node = self.generic_visit(node) if ( diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 8dd76b289b..4023950dfb 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -6,127 +6,115 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -import enum -from typing import Callable, Optional +from typing import Optional, Protocol from gt4py.eve import utils as eve_utils +from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import fencil_to_program, inline_fundefs +from gt4py.next.iterator.transforms import ( + fuse_as_fieldop, + global_tmps, + infer_domain, + inline_dynamic_shifts, + inline_fundefs, + inline_lifts, +) from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.iterator.transforms.constant_folding import ConstantFolding from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination -from gt4py.next.iterator.transforms.eta_reduction import EtaReduction from gt4py.next.iterator.transforms.fuse_maps import FuseMaps -from gt4py.next.iterator.transforms.global_tmps import CreateGlobalTmps, FencilWithTemporaries -from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars -from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas -from gt4py.next.iterator.transforms.inline_lifts import InlineLifts +from gt4py.next.iterator.transforms.inline_scalar import InlineScalar from gt4py.next.iterator.transforms.merge_let import MergeLet from gt4py.next.iterator.transforms.normalize_shifts import NormalizeShifts -from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref -from gt4py.next.iterator.transforms.scan_eta_reduction import ScanEtaReduction from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce +from gt4py.next.iterator.type_system.inference import infer -@enum.unique -class LiftMode(enum.Enum): - FORCE_INLINE = enum.auto() - USE_TEMPORARIES = enum.auto() - - -def _inline_lifts(ir, lift_mode): - if lift_mode == 
LiftMode.FORCE_INLINE: - return InlineLifts().visit(ir) - elif lift_mode == LiftMode.USE_TEMPORARIES: - return InlineLifts( - flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT - | InlineLifts.Flag.INLINE_DEREF_LIFT # some tuple exprs found in FVM don't work yet. - ).visit(ir) - else: - raise ValueError() - - return ir - - -def _inline_into_scan(ir, *, max_iter=10): - for _ in range(10): - # in case there are multiple levels of lambdas around the scan we have to do multiple iterations - inlined = InlineIntoScan().visit(ir) - inlined = InlineLambdas.apply(inlined, opcount_preserving=True, force_inline_lift_args=True) - if inlined == ir: - break - ir = inlined - else: - raise RuntimeError(f"Inlining into 'scan' did not converge within {max_iter} iterations.") - return ir +class GTIRTransform(Protocol): + def __call__( + self, _: itir.Program, *, offset_provider: common.OffsetProvider + ) -> itir.Program: ... # TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward -# `lift_mode` and `temporary_extraction_heuristics` which is inconvenient. +# `extract_temporaries` and `temporary_extraction_heuristics` which is inconvenient. def apply_common_transforms( - ir: itir.Node, + ir: itir.Program, *, - lift_mode=None, - offset_provider=None, + offset_provider=None, # TODO(havogt): should be replaced by offset_provider_type, but global_tmps currently relies on runtime info + extract_temporaries=False, unroll_reduce=False, common_subexpression_elimination=True, force_inline_lambda_args=False, unconditionally_collapse_tuples=False, - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None, + #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for + #: more details. 
symbolic_domain_sizes: Optional[dict[str, str]] = None, + offset_provider_type: Optional[common.OffsetProviderType] = None, ) -> itir.Program: - if isinstance(ir, (itir.FencilDefinition, FencilWithTemporaries)): - ir = fencil_to_program.FencilToProgram().apply( - ir - ) # FIXME[#1582](havogt): should be removed after refactoring to combined IR - else: - assert isinstance(ir, itir.Program) - # FIXME[#1582](havogt): note: currently the case when using the roundtrip backend - pass + # TODO(havogt): if the runtime `offset_provider` is not passed, we cannot run global_tmps + if offset_provider_type is None: + offset_provider_type = common.offset_provider_to_type(offset_provider) + + assert isinstance(ir, itir.Program) - icdlv_uids = eve_utils.UIDGenerator() + tmp_uids = eve_utils.UIDGenerator(prefix="__tmp") + mergeasfop_uids = eve_utils.UIDGenerator() + collapse_tuple_uids = eve_utils.UIDGenerator() - if lift_mode is None: - lift_mode = LiftMode.FORCE_INLINE - assert isinstance(lift_mode, LiftMode) ir = MergeLet().visit(ir) ir = inline_fundefs.InlineFundefs().visit(ir) - ir = inline_fundefs.prune_unreferenced_fundefs(ir) # type: ignore[arg-type] # all previous passes return itir.Program - ir = PropagateDeref.apply(ir) + ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = NormalizeShifts().visit(ir) + # TODO(tehrengruber): Many iterator test contain lifts that need to be inlined, e.g. + # test_can_deref. We didn't notice previously as FieldOpFusion did this implicitly everywhere. + ir = inline_lifts.InlineLifts().visit(ir) + + # note: this increases the size of the tree + # Inline. The domain inference can not handle "user" functions, e.g. `let f = λ(...) → ... in f(...)` + ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) + # required in order to get rid of expressions without a domain (e.g. 
when a tuple element is never accessed) + ir = CollapseTuple.apply( + ir, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, + uids=collapse_tuple_uids, + offset_provider_type=offset_provider_type, + ) # type: ignore[assignment] # always an itir.Program + ir = inline_dynamic_shifts.InlineDynamicShifts.apply( + ir + ) # domain inference does not support dynamic offsets yet + ir = infer_domain.infer_program( + ir, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + ) + for _ in range(10): inlined = ir - inlined = InlineCenterDerefLiftVars.apply(inlined, uids=icdlv_uids) # type: ignore[arg-type] # always a fencil - inlined = _inline_lifts(inlined, lift_mode) - - inlined = InlineLambdas.apply( - inlined, - opcount_preserving=True, - force_inline_lift_args=(lift_mode == LiftMode.FORCE_INLINE), - # If trivial lifts are not inlined we might create temporaries for constants. In all - # other cases we want it anyway. - force_inline_trivial_lift_args=True, - ) - inlined = ConstantFolding.apply(inlined) + inlined = InlineLambdas.apply(inlined, opcount_preserving=True) + inlined = ConstantFolding.apply(inlined) # type: ignore[assignment] # always an itir.Program # This pass is required to be in the loop such that when an `if_` call with tuple arguments # is constant-folded the surrounding tuple_get calls can be removed. 
inlined = CollapseTuple.apply( inlined, - offset_provider=offset_provider, - # TODO(tehrengruber): disabled since it increases compile-time too much right now - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, + uids=collapse_tuple_uids, + offset_provider_type=offset_provider_type, + ) # type: ignore[assignment] # always an itir.Program + inlined = InlineScalar.apply(inlined, offset_provider_type=offset_provider_type) + + # This pass is required to run after CollapseTuple as otherwise we can not inline + # expressions like `tuple_get(make_tuple(as_fieldop(stencil)(...)))` where stencil returns + # a list. Such expressions must be inlined however because no backend supports such + # field operators right now. + inlined = fuse_as_fieldop.FuseAsFieldOp.apply( + inlined, uids=mergeasfop_uids, offset_provider_type=offset_provider_type ) - # This pass is required such that a deref outside of a - # `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the - # `tuple_get` is removed by the `CollapseTuple` pass. 
- inlined = PropagateDeref.apply(inlined) if inlined == ir: break @@ -134,32 +122,15 @@ def apply_common_transforms( else: raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") - if lift_mode != LiftMode.FORCE_INLINE: - # FIXME[#1582](tehrengruber): implement new temporary pass here - raise NotImplementedError() - assert offset_provider is not None - ir = CreateGlobalTmps().visit( - ir, - offset_provider=offset_provider, - extraction_heuristics=temporary_extraction_heuristics, - symbolic_sizes=symbolic_domain_sizes, - ) - - for _ in range(10): - inlined = InlineLifts().visit(ir) - inlined = InlineLambdas.apply( - inlined, opcount_preserving=True, force_inline_lift_args=True - ) - if inlined == ir: - break - ir = inlined - else: - raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") + # breaks in test_zero_dim_tuple_arg as trivial tuple_get is not inlined + if common_subexpression_elimination: + ir = CommonSubexpressionElimination.apply(ir, offset_provider_type=offset_provider_type) + ir = MergeLet().visit(ir) + ir = InlineLambdas.apply(ir, opcount_preserving=True) - # If after creating temporaries, the scan is not at the top, we inline. - # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. 
- # λ(inp) → scan(λ(state, k, kp) → state + ·k + ·kp, True, 0.0)(inp, ⟪Koffₒ, 1ₒ⟫(inp))` - ir = _inline_into_scan(ir) + if extract_temporaries: + ir = infer(ir, inplace=True, offset_provider_type=offset_provider_type) + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can @@ -168,13 +139,10 @@ def apply_common_transforms( ir = CollapseTuple.apply( ir, ignore_tuple_size=True, - offset_provider=offset_provider, - # TODO(tehrengruber): disabled since it increases compile-time too much right now - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, - ) - - if lift_mode == LiftMode.FORCE_INLINE: - ir = _inline_into_scan(ir) + uids=collapse_tuple_uids, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, + offset_provider_type=offset_provider_type, + ) # type: ignore[assignment] # always an itir.Program ir = NormalizeShifts().visit(ir) @@ -183,27 +151,40 @@ def apply_common_transforms( if unroll_reduce: for _ in range(10): - unrolled = UnrollReduce.apply(ir, offset_provider=offset_provider) + unrolled = UnrollReduce.apply(ir, offset_provider_type=offset_provider_type) if unrolled == ir: break - ir = unrolled + ir = unrolled # type: ignore[assignment] # still a `itir.Program` ir = CollapseListGet().visit(ir) ir = NormalizeShifts().visit(ir) - ir = _inline_lifts(ir, LiftMode.FORCE_INLINE) + # this is required as nested neighbor reductions can contain lifts, e.g., + # `neighbors(V2Eₒ, ↑f(...))` + ir = inline_lifts.InlineLifts().visit(ir) ir = NormalizeShifts().visit(ir) else: raise RuntimeError("Reduction unrolling failed.") - ir = EtaReduction().visit(ir) - ir = ScanEtaReduction().visit(ir) - - if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider=offset_provider) # type: 
ignore[type-var] # always an itir.Program - ir = MergeLet().visit(ir) - ir = InlineLambdas.apply( ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args ) assert isinstance(ir, itir.Program) return ir + + +def apply_fieldview_transforms( + ir: itir.Program, *, offset_provider: common.OffsetProvider +) -> itir.Program: + ir = inline_fundefs.InlineFundefs().visit(ir) + ir = inline_fundefs.prune_unreferenced_fundefs(ir) + ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True) + ir = CollapseTuple.apply( + ir, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, + offset_provider_type=common.offset_provider_to_type(offset_provider), + ) # type: ignore[assignment] # type is still `itir.Program` + ir = inline_dynamic_shifts.InlineDynamicShifts.apply( + ir + ) # domain inference does not support dynamic offsets yet + ir = infer_domain.infer_program(ir, offset_provider=offset_provider) + return ir diff --git a/src/gt4py/next/iterator/transforms/program_to_fencil.py b/src/gt4py/next/iterator/transforms/program_to_fencil.py deleted file mode 100644 index 4411dda74f..0000000000 --- a/src/gt4py/next/iterator/transforms/program_to_fencil.py +++ /dev/null @@ -1,31 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm - - -def program_to_fencil(node: itir.Program) -> itir.FencilDefinition: - assert not node.declarations - closures = [] - for stmt in node.body: - assert isinstance(stmt, itir.SetAt) - assert isinstance(stmt.expr, itir.FunCall) and cpm.is_call_to(stmt.expr.fun, "as_fieldop") - stencil, domain = stmt.expr.fun.args - inputs = stmt.expr.args - assert all(isinstance(inp, itir.SymRef) for inp in inputs) - closures.append( - itir.StencilClosure(domain=domain, stencil=stencil, output=stmt.target, inputs=inputs) - ) - - return itir.FencilDefinition( - id=node.id, - function_definitions=node.function_definitions, - params=node.params, - closures=closures, - ) diff --git a/src/gt4py/next/iterator/transforms/prune_casts.py b/src/gt4py/next/iterator/transforms/prune_casts.py new file mode 100644 index 0000000000..3276f47042 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/prune_casts.py @@ -0,0 +1,50 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from gt4py import eve +from gt4py.next.iterator import builtins, ir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.type_system import type_specifications as ts + + +class PruneCasts(eve.NodeTranslator): + """ + Removes cast expressions where the argument is already in the target type. + + This transformation requires the IR to be fully type-annotated, + therefore it should be applied after type-inference. 
+ """ + + PRESERVED_ANNEX_ATTRS = ("domain",) + + def visit_FunCall(self, node: ir.FunCall) -> ir.Node: + node = self.generic_visit(node) + + if cpm.is_call_to(node, "cast_"): + value, type_constructor = node.args + + assert ( + value.type + and isinstance(type_constructor, ir.SymRef) + and (type_constructor.id in builtins.TYPE_BUILTINS) + ) + dtype = ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper())) + + if value.type == dtype: + return value + + elif cpm.is_identity_as_fieldop(node): + # pruning of cast expressions may leave some trivial `as_fieldop` expressions + # with form '(⇑(λ(__arg) → ·__arg))(a)' + return node.args[0] + + return node + + @classmethod + def apply(cls, node: ir.Node) -> ir.Node: + return cls().visit(node) diff --git a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py deleted file mode 100644 index 5058a91216..0000000000 --- a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py +++ /dev/null @@ -1,44 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import ir - - -class PruneClosureInputs(PreserveLocationVisitor, NodeTranslator): - """Removes all unused input arguments from a stencil closure.""" - - def visit_StencilClosure(self, node: ir.StencilClosure) -> ir.StencilClosure: - if not isinstance(node.stencil, ir.Lambda): - return node - - unused: set[str] = {p.id for p in node.stencil.params} - expr = self.visit(node.stencil.expr, unused=unused, shadowed=set[str]()) - params = [] - inputs = [] - for param, inp in zip(node.stencil.params, node.inputs): - if param.id not in unused: - params.append(param) - inputs.append(inp) - - return ir.StencilClosure( - domain=node.domain, - stencil=ir.Lambda(params=params, expr=expr), - output=node.output, - inputs=inputs, - ) - - def visit_SymRef(self, node: ir.SymRef, *, unused: set[str], shadowed: set[str]) -> ir.SymRef: - if node.id not in shadowed: - unused.discard(node.id) - return node - - def visit_Lambda(self, node: ir.Lambda, *, unused: set[str], shadowed: set[str]) -> ir.Lambda: - return self.generic_visit( - node, unused=unused, shadowed=shadowed | {p.id for p in node.params} - ) diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index 02180a3699..fb909dc5d0 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -10,11 +10,12 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor, SymbolTableTrait from gt4py.next.iterator import ir +from gt4py.next.iterator.type_system import inference as type_inference class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator): - # This pass preserves, but doesn't use the `type` and `recorded_shifts` annex. - PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts") + # This pass preserves, but doesn't use the `type`, `recorded_shifts`, `domain` annex. 
+ PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain") def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]): return symbol_map.get(str(node.id), node) @@ -32,8 +33,8 @@ def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override] class RenameSymbols(PreserveLocationVisitor, NodeTranslator): - # This pass preserves, but doesn't use the `type` and `recorded_shifts` annex. - PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts") + # This pass preserves, but doesn't use the `type`, `recorded_shifts`, `domain` annex. + PRESERVED_ANNEX_ATTRS = ("type", "recorded_shifts", "domain") def visit_Sym( self, node: ir.Sym, *, name_map: Dict[str, str], active: Optional[Set[str]] = None @@ -46,7 +47,9 @@ def visit_SymRef( self, node: ir.SymRef, *, name_map: Dict[str, str], active: Optional[Set[str]] = None ): if active and node.id in active: - return ir.SymRef(id=name_map.get(node.id, node.id)) + new_ref = ir.SymRef(id=name_map.get(node.id, node.id)) + type_inference.copy_type(from_=node, to=new_ref, allow_untyped=True) + return new_ref return node def generic_visit( # type: ignore[override] diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 05163a3630..2903201083 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -69,7 +69,7 @@ def apply( Counter({SymRef(id=SymbolRef('x')): 2, SymRef(id=SymbolRef('y')): 2, SymRef(id=SymbolRef('z')): 1}) """ if ignore_builtins: - inactive_refs = {str(n.id) for n in itir.FencilDefinition._NODE_SYMBOLS_} + inactive_refs = {str(n.id) for n in itir.Program._NODE_SYMBOLS_} else: inactive_refs = set() @@ -140,6 +140,4 @@ def collect_symbol_refs( def get_user_defined_symbols(symtable: dict[eve.SymbolName, itir.Sym]) -> set[str]: - return {str(sym) for sym in symtable.keys()} - { - str(n.id) for n in itir.FencilDefinition._NODE_SYMBOLS_ - } + return 
{str(sym) for sym in symtable.keys()} - {str(n.id) for n in itir.Program._NODE_SYMBOLS_} diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 68346b6622..0648df8363 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -13,16 +13,15 @@ from gt4py import eve from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import ir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift +from gt4py.next.iterator import builtins, ir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im class ValidateRecordedShiftsAnnex(eve.NodeVisitor): """Ensure every applied lift and its arguments have the `recorded_shifts` annex populated.""" def visit_FunCall(self, node: ir.FunCall): - if is_applied_lift(node): + if cpm.is_applied_lift(node): assert hasattr(node.annex, "recorded_shifts") if len(node.annex.recorded_shifts) == 0: @@ -278,9 +277,9 @@ def visit_Literal(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: def visit_SymRef(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: if node.id in ctx: return ctx[node.id] - elif node.id in ir.TYPEBUILTINS: + elif node.id in builtins.TYPE_BUILTINS: return Sentinel.TYPE - elif node.id in (ir.ARITHMETIC_BUILTINS | {"list_get", "make_const_list", "cast_"}): + elif node.id in (builtins.ARITHMETIC_BUILTINS | {"list_get", "make_const_list", "cast_"}): return _combine raise ValueError(f"Undefined symbol {node.id}") @@ -329,13 +328,16 @@ def fun(*args): @classmethod def trace_stencil( cls, stencil: ir.Expr, *, num_args: Optional[int] = None, save_to_annex: bool = False - ): + ) -> list[set[tuple[ir.OffsetLiteral, ...]]]: # If we get a lambda we can deduce the number of arguments. 
if isinstance(stencil, ir.Lambda): assert num_args is None or num_args == len(stencil.params) num_args = len(stencil.params) + elif cpm.is_call_to(stencil, "scan"): + assert isinstance(stencil.args[0], ir.Lambda) + num_args = len(stencil.args[0].params) - 1 if not isinstance(num_args, int): - raise ValueError("Stencil must be an 'itir.Lambda' or `num_args` is given.") + raise ValueError("Stencil must be an 'itir.Lambda', scan, or `num_args` is given.") assert isinstance(num_args, int) args = [im.ref(f"__arg{i}") for i in range(num_args)] diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 700b8571a5..6e993a2ed7 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -14,7 +14,7 @@ from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift @@ -30,7 +30,14 @@ def _is_neighbors_or_lifted_and_neighbors(arg: itir.Expr) -> TypeGuard[itir.FunC def _get_neighbors_args(reduce_args: Iterable[itir.Expr]) -> Iterator[itir.FunCall]: - return filter(_is_neighbors_or_lifted_and_neighbors, reduce_args) + flat_reduce_args: list[itir.Expr] = [] + for arg in reduce_args: + if cpm.is_call_to(arg, "if_"): + flat_reduce_args.extend(_get_neighbors_args(arg.args[1:3])) + else: + flat_reduce_args.append(arg) + + return filter(_is_neighbors_or_lifted_and_neighbors, flat_reduce_args) def _is_list_of_funcalls(lst: list) -> TypeGuard[list[itir.FunCall]]: @@ -57,16 +64,16 @@ def _get_partial_offset_tags(reduce_args: Iterable[itir.Expr]) -> Iterable[str]: def _get_connectivity( applied_reduce_node: itir.FunCall, - offset_provider: dict[str, common.Dimension | 
common.Connectivity], -) -> common.Connectivity: + offset_provider_type: common.OffsetProviderType, +) -> common.NeighborConnectivityType: """Return single connectivity that is compatible with the arguments of the reduce.""" if not cpm.is_applied_reduce(applied_reduce_node): raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") - connectivities: list[common.Connectivity] = [] + connectivities: list[common.NeighborConnectivityType] = [] for o in _get_partial_offset_tags(applied_reduce_node.args): - conn = offset_provider[o] - assert isinstance(conn, common.Connectivity) + conn = offset_provider_type[o] + assert isinstance(conn, common.NeighborConnectivityType) connectivities.append(conn) if not connectivities: @@ -78,34 +85,6 @@ def _get_connectivity( return connectivities[0] -def _make_shift(offsets: list[itir.Expr], iterator: itir.Expr) -> itir.FunCall: - return itir.FunCall( - fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=offsets), - args=[iterator], - location=iterator.location, - ) - - -def _make_deref(iterator: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="deref"), args=[iterator], location=iterator.location) - - -def _make_can_deref(iterator: itir.Expr) -> itir.FunCall: - return itir.FunCall( - fun=itir.SymRef(id="can_deref"), args=[iterator], location=iterator.location - ) - - -def _make_if(cond: itir.Expr, true_expr: itir.Expr, false_expr: itir.Expr) -> itir.FunCall: - return itir.FunCall( - fun=itir.SymRef(id="if_"), args=[cond, true_expr, false_expr], location=cond.location - ) - - -def _make_list_get(offset: itir.Expr, expr: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="list_get"), args=[offset, expr], location=expr.location) - - @dataclasses.dataclass(frozen=True) class UnrollReduce(PreserveLocationVisitor, NodeTranslator): # we use one UID generator per instance such that the generated ids are @@ -113,37 +92,35 @@ class UnrollReduce(PreserveLocationVisitor, 
NodeTranslator): uids: UIDGenerator = dataclasses.field(init=False, repr=False, default_factory=UIDGenerator) @classmethod - def apply(cls, node: itir.Node, **kwargs) -> itir.Node: - return cls().visit(node, **kwargs) + def apply(cls, node: itir.Node, offset_provider_type: common.OffsetProviderType) -> itir.Node: + return cls().visit(node, offset_provider_type=offset_provider_type) - def _visit_reduce(self, node: itir.FunCall, **kwargs) -> itir.Expr: - offset_provider = kwargs["offset_provider"] - assert offset_provider is not None - connectivity = _get_connectivity(node, offset_provider) - max_neighbors = connectivity.max_neighbors - has_skip_values = connectivity.has_skip_values + def _visit_reduce( + self, node: itir.FunCall, offset_provider_type: common.OffsetProviderType + ) -> itir.Expr: + connectivity_type = _get_connectivity(node, offset_provider_type) + max_neighbors = connectivity_type.max_neighbors + has_skip_values = connectivity_type.has_skip_values - acc = itir.SymRef(id=self.uids.sequential_id(prefix="_acc")) - offset = itir.SymRef(id=self.uids.sequential_id(prefix="_i")) - step = itir.SymRef(id=self.uids.sequential_id(prefix="_step")) + acc: str = self.uids.sequential_id(prefix="_acc") + offset: str = self.uids.sequential_id(prefix="_i") + step: str = self.uids.sequential_id(prefix="_step") assert isinstance(node.fun, itir.FunCall) fun, init = node.fun.args - elems = [_make_list_get(offset, arg) for arg in node.args] - step_fun: itir.Expr = itir.FunCall(fun=fun, args=[acc, *elems]) + elems = [im.list_get(offset, arg) for arg in node.args] + step_fun: itir.Expr = im.call(fun)(acc, *elems) if has_skip_values: check_arg = next(_get_neighbors_args(node.args)) offset_tag, it = check_arg.args - can_deref = _make_can_deref(_make_shift([offset_tag, offset], it)) - step_fun = _make_if(can_deref, step_fun, acc) - step_fun = itir.Lambda(params=[itir.Sym(id=acc.id), itir.Sym(id=offset.id)], expr=step_fun) + can_deref = im.can_deref(im.shift(offset_tag, 
offset)(it)) + step_fun = im.if_(can_deref, step_fun, acc) + step_fun = im.lambda_(acc, offset)(step_fun) expr = init for i in range(max_neighbors): - expr = itir.FunCall(fun=step, args=[expr, itir.OffsetLiteral(value=i)]) - expr = itir.FunCall( - fun=itir.Lambda(params=[itir.Sym(id=step.id)], expr=expr), args=[step_fun] - ) + expr = im.call(step)(expr, itir.OffsetLiteral(value=i)) + expr = im.let(step, step_fun)(expr) return expr diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index bc1095dfb8..d6faefc372 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -17,9 +17,8 @@ from gt4py.eve import concepts from gt4py.eve.extended_typing import Any, Callable, Optional, TypeVar, Union from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_call_to -from gt4py.next.iterator.transforms import global_tmps from gt4py.next.iterator.type_system import type_specifications as it_ts, type_synthesizer from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.type_system.type_info import primitive_constituents @@ -33,69 +32,35 @@ def _is_representable_as_int(s: int | str) -> bool: return False -def _is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec): - """ - Predicate to determine if two types are compatible. - - This function gracefully handles: - - iterators with unknown positions which are considered compatible to any other positions - of another iterator. - - iterators which are defined everywhere, i.e. empty defined dimensions - Beside that this function simply checks for equality of types. - - >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) - >>> IDim = common.Dimension(value="IDim") - >>> type_on_i_of_i_it = it_ts.IteratorType( - ... 
position_dims=[IDim], defined_dims=[IDim], element_type=bool_type - ... ) - >>> type_on_undefined_of_i_it = it_ts.IteratorType( - ... position_dims="unknown", defined_dims=[IDim], element_type=bool_type - ... ) - >>> _is_compatible_type(type_on_i_of_i_it, type_on_undefined_of_i_it) - True - - >>> JDim = common.Dimension(value="JDim") - >>> type_on_j_of_j_it = it_ts.IteratorType( - ... position_dims=[JDim], defined_dims=[JDim], element_type=bool_type - ... ) - >>> _is_compatible_type(type_on_i_of_i_it, type_on_j_of_j_it) - False - """ - is_compatible = True - - if isinstance(type_a, it_ts.IteratorType) and isinstance(type_b, it_ts.IteratorType): - if not any(el_type.position_dims == "unknown" for el_type in [type_a, type_b]): - is_compatible &= type_a.position_dims == type_b.position_dims - if type_a.defined_dims and type_b.defined_dims: - is_compatible &= type_a.defined_dims == type_b.defined_dims - is_compatible &= type_a.element_type == type_b.element_type - elif isinstance(type_a, ts.TupleType) and isinstance(type_b, ts.TupleType): - for el_type_a, el_type_b in zip(type_a.types, type_b.types, strict=True): - is_compatible &= _is_compatible_type(el_type_a, el_type_b) - elif isinstance(type_a, ts.FunctionType) and isinstance(type_b, ts.FunctionType): - for arg_a, arg_b in zip(type_a.pos_only_args, type_b.pos_only_args, strict=True): - is_compatible &= _is_compatible_type(arg_a, arg_b) - for arg_a, arg_b in zip( - type_a.pos_or_kw_args.values(), type_b.pos_or_kw_args.values(), strict=True - ): - is_compatible &= _is_compatible_type(arg_a, arg_b) - for arg_a, arg_b in zip( - type_a.kw_only_args.values(), type_b.kw_only_args.values(), strict=True - ): - is_compatible &= _is_compatible_type(arg_a, arg_b) - is_compatible &= _is_compatible_type(type_a.returns, type_b.returns) - else: - is_compatible &= type_a == type_b - - return is_compatible - - def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: if node.type: - assert _is_compatible_type(node.type, 
type_), "Node already has a type which differs." + assert type_info.is_compatible_type( + node.type, type_ + ), "Node already has a type which differs." + # Also populate the type of the parameters of a lambda. That way the one can access the type + # of a parameter by a lookup in the symbol table. As long as `_set_node_type` is used + # exclusively, the information stays consistent with the types stored in the `FunctionType` + # of the lambda itself. + if isinstance(node, itir.Lambda): + assert isinstance(type_, ts.FunctionType) + for param, param_type in zip(node.params, type_.pos_only_args): + _set_node_type(param, param_type) node.type = type_ +def copy_type(from_: itir.Node, to: itir.Node, allow_untyped: bool = False) -> None: + """ + Copy type from one node to another. + + This function mainly exists for readability reasons. + """ + assert allow_untyped is not None or isinstance(from_.type, ts.TypeSpec) + if from_.type is None: + assert allow_untyped + return + _set_node_type(to, from_.type) + + def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: """ Execute `callback` as soon as all `args` have a type. @@ -135,7 +100,7 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): >>> float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) >>> int_type = ts.ScalarType(kind=ts.ScalarKind.INT64) >>> power(float_type, int_type) - ScalarType(kind=, shape=None) + ScalarType(kind=, shape=None) Now, consider a simple lambda function that squares its argument using the power builtin. A type synthesizer for this function is simple to formulate, but merely gives us the return @@ -146,8 +111,8 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): >>> square_func_type_synthesizer = type_synthesizer.TypeSynthesizer( ... type_synthesizer=lambda base: power(base, int_type) ... 
) - >>> square_func_type_synthesizer(float_type, offset_provider={}) - ScalarType(kind=, shape=None) + >>> square_func_type_synthesizer(float_type, offset_provider_type={}) + ScalarType(kind=, shape=None) Note that without a corresponding call the function itself can not be fully typed and as such the type inference algorithm has to defer typing until then. This task is handled transparently @@ -160,8 +125,8 @@ class ObservableTypeSynthesizer(type_synthesizer.TypeSynthesizer): ... node=square_func, ... store_inferred_type_in_node=True, ... ) - >>> o_type_synthesizer(float_type, offset_provider={}) - ScalarType(kind=, shape=None) + >>> o_type_synthesizer(float_type, offset_provider_type={}) + ScalarType(kind=, shape=None) >>> square_func.type == ts.FunctionType( ... pos_only_args=[float_type], pos_or_kw_args={}, kw_only_args={}, returns=float_type ... ) @@ -216,13 +181,15 @@ def on_type_ready(self, cb: Callable[[ts.TypeSpec], None]) -> None: def __call__( self, *args: type_synthesizer.TypeOrTypeSynthesizer, - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> Union[ts.TypeSpec, ObservableTypeSynthesizer]: assert all( isinstance(arg, (ts.TypeSpec, ObservableTypeSynthesizer)) for arg in args ), "ObservableTypeSynthesizer can only be used with arguments that are TypeSpec or ObservableTypeSynthesizer" - return_type_or_synthesizer = self.type_synthesizer(*args, offset_provider=offset_provider) + return_type_or_synthesizer = self.type_synthesizer( + *args, offset_provider_type=offset_provider_type + ) # return type is a typing rule by itself if isinstance(return_type_or_synthesizer, type_synthesizer.TypeSynthesizer): @@ -241,18 +208,18 @@ def __call__( def _get_dimensions_from_offset_provider( - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, ) -> dict[str, common.Dimension]: dimensions: dict[str, common.Dimension] = {} - for offset_name, provider in offset_provider.items(): + for 
offset_name, provider in offset_provider_type.items(): dimensions[offset_name] = common.Dimension( value=offset_name, kind=common.DimensionKind.LOCAL ) if isinstance(provider, common.Dimension): dimensions[provider.value] = provider - elif isinstance(provider, common.Connectivity): - dimensions[provider.origin_axis.value] = provider.origin_axis - dimensions[provider.neighbor_axis.value] = provider.neighbor_axis + elif isinstance(provider, common.NeighborConnectivityType): + dimensions[provider.source_dim.value] = provider.source_dim + dimensions[provider.codomain.value] = provider.codomain return dimensions @@ -261,8 +228,8 @@ def _get_dimensions(obj: Any): if isinstance(obj, common.Dimension): yield obj elif isinstance(obj, ts.TypeSpec): - for field in dataclasses.fields(obj.__class__): - yield from _get_dimensions(getattr(obj, field.name)) + for field in obj.__datamodel_fields__.keys(): + yield from _get_dimensions(getattr(obj, field)) elif isinstance(obj, collections.abc.Mapping): for el in obj.values(): yield from _get_dimensions(el) @@ -278,10 +245,14 @@ def type_synthesizer(*args, **kwargs): assert type_info.accepts_args(fun_type, with_args=list(args), with_kwargs=kwargs) return fun_type.returns - return type_synthesizer + return ObservableTypeSynthesizer( + type_synthesizer=type_synthesizer, store_inferred_type_in_node=False + ) class SanitizeTypes(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): + PRESERVED_ANNEX_ATTRS = ("domain",) + def visit_Node(self, node: itir.Node, *, symtable: dict[str, itir.Node]) -> itir.Node: node = self.generic_visit(node) # We only want to sanitize types that have been inferred previously such that we don't run @@ -296,6 +267,15 @@ def visit_Node(self, node: itir.Node, *, symtable: dict[str, itir.Node]) -> itir T = TypeVar("T", bound=itir.Node) +_INITIAL_CONTEXT = { + name: ObservableTypeSynthesizer( + type_synthesizer=type_synthesizer.builtin_type_synthesizers[name], + # builtin functions are polymorphic + 
store_inferred_type_in_node=False, + ) + for name in type_synthesizer.builtin_type_synthesizers.keys() +} + @dataclasses.dataclass class ITIRTypeInference(eve.NodeTranslator): @@ -305,18 +285,22 @@ class ITIRTypeInference(eve.NodeTranslator): See :method:ITIRTypeInference.apply for more details. """ - offset_provider: common.OffsetProvider + PRESERVED_ANNEX_ATTRS = ("domain",) + + offset_provider_type: Optional[common.OffsetProviderType] #: Mapping from a dimension name to the actual dimension instance. - dimensions: dict[str, common.Dimension] + dimensions: Optional[dict[str, common.Dimension]] #: Allow sym refs to symbols that have not been declared. Mostly used in testing. allow_undeclared_symbols: bool + #: Reinference-mode skipping already typed nodes. + reinfer: bool @classmethod def apply( cls, node: T, *, - offset_provider: common.OffsetProvider, + offset_provider_type: common.OffsetProviderType, inplace: bool = False, allow_undeclared_symbols: bool = False, ) -> T: @@ -327,14 +311,14 @@ def apply( node: The :class:`itir.Node` to infer the types of. Keyword Arguments: - offset_provider: Offset provider dictionary. + offset_provider_type: Offset provider dictionary. inplace: Write types directly to the given ``node`` instead of returning a copy. allow_undeclared_symbols: Allow references to symbols that don't have a corresponding declaration. This is useful for testing or inference on partially inferred sub-nodes. Preconditions: - All parameters in :class:`itir.Program` and :class:`itir.FencilDefinition` must have a type + All parameters in :class:`itir.Program` must have a type defined, as they are the starting point for type propagation. Design decisions: @@ -383,16 +367,16 @@ def apply( # parts of a program. 
node = SanitizeTypes().visit(node) - if isinstance(node, (itir.FencilDefinition, itir.Program)): + if isinstance(node, itir.Program): assert all(isinstance(param.type, ts.DataType) for param in node.params), ( - "All parameters in 'itir.Program' and 'itir.FencilDefinition' must have a type " + "All parameters in 'itir.Program' must have a type " "defined, as they are the starting point for type propagation.", ) instance = cls( - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, dimensions=( - _get_dimensions_from_offset_provider(offset_provider) + _get_dimensions_from_offset_provider(offset_provider_type) | _get_dimensions_from_types( node.pre_walk_values() .if_isinstance(itir.Node) @@ -402,28 +386,49 @@ def apply( ) ), allow_undeclared_symbols=allow_undeclared_symbols, + reinfer=False, ) if not inplace: node = copy.deepcopy(node) - instance.visit( - node, - ctx={ - name: ObservableTypeSynthesizer( - type_synthesizer=type_synthesizer.builtin_type_synthesizers[name], - # builtin functions are polymorphic - store_inferred_type_in_node=False, - ) - for name in type_synthesizer.builtin_type_synthesizers.keys() - }, + instance.visit(node, ctx=_INITIAL_CONTEXT) + return node + + @classmethod + def apply_reinfer(cls, node: T) -> T: + """ + Given a partially typed node infer the type of ``node`` and its sub-nodes. + + Contrary to the regular inference, this method does not descend into already typed sub-nodes + and can be used as a lightweight way to restore type information during a pass. + + Note that this function alters the input node, which is usually desired, and more + performant. + + Arguments: + node: The :class:`itir.Node` to infer the types of. 
+ """ + if node.type: # already inferred + return node + + instance = cls( + offset_provider_type=None, dimensions=None, allow_undeclared_symbols=True, reinfer=True ) + instance.visit(node, ctx=_INITIAL_CONTEXT) return node def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: + # we found a node that is typed, do not descend into children + if self.reinfer and isinstance(node, itir.Node) and node.type: + if isinstance(node.type, ts.FunctionType): + return _type_synthesizer_from_function_type(node.type) + return node.type + result = super().visit(node, **kwargs) + if isinstance(node, itir.Node): if isinstance(result, ts.TypeSpec): - if node.type: - assert _is_compatible_type(node.type, result) + if node.type and not isinstance(node.type, ts.DeferredType): + assert type_info.is_compatible_type(node.type, result) node.type = result elif isinstance(result, ObservableTypeSynthesizer) or result is None: pass @@ -442,42 +447,6 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: ) return result - # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere - def visit_FencilDefinition(self, node: itir.FencilDefinition, *, ctx) -> it_ts.FencilType: - params: dict[str, ts.DataType] = {} - for param in node.params: - assert isinstance(param.type, ts.DataType) - params[param.id] = param.type - - function_definitions: dict[str, type_synthesizer.TypeSynthesizer] = {} - for fun_def in node.function_definitions: - function_definitions[fun_def.id] = self.visit(fun_def, ctx=ctx | function_definitions) - - closures = self.visit(node.closures, ctx=ctx | params | function_definitions) - return it_ts.FencilType(params=params, closures=closures) - - # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere - def visit_FencilWithTemporaries( - self, node: global_tmps.FencilWithTemporaries, *, ctx - ) -> it_ts.FencilType: - # TODO(tehrengruber): This implementation is not very appealing. 
Since we are about to - # refactor the IR anyway this is fine for now. - params: dict[str, ts.DataType] = {} - for param in node.params: - assert isinstance(param.type, ts.DataType) - params[param.id] = param.type - # infer types of temporary declarations - tmps: dict[str, ts.FieldType] = {} - for tmp_node in node.tmps: - tmps[tmp_node.id] = self.visit(tmp_node, ctx=ctx | params) - # and store them in the inner fencil - for fencil_param in node.fencil.params: - if fencil_param.id in tmps: - fencil_param.type = tmps[fencil_param.id] - self.visit(node.fencil, ctx=ctx) - assert isinstance(node.fencil.type, it_ts.FencilType) - return node.fencil.type - def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType: params: dict[str, ts.DataType] = {} for param in node.params: @@ -494,9 +463,11 @@ def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType: def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.TupleType: domain = self.visit(node.domain, ctx=ctx) assert isinstance(domain, it_ts.DomainType) + assert domain.dims != "unknown" assert node.dtype return type_info.apply_to_primitive_constituents( - lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), node.dtype + lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), + node.dtype, ) def visit_IfStmt(self, node: itir.IfStmt, *, ctx) -> None: @@ -514,67 +485,46 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: # the target can have fewer elements than the expr in which case the output from the # expression is simply discarded. 
expr_type = functools.reduce( - lambda tuple_type, i: tuple_type.types[i], # type: ignore[attr-defined] # format ensured by primitive_constituents + lambda tuple_type, i: tuple_type.types[i] # type: ignore[attr-defined] # format ensured by primitive_constituents + # `ts.DeferredType` only occurs for scans returning a tuple + if not isinstance(tuple_type, ts.DeferredType) + else ts.DeferredType(constraint=None), path, node.expr.type, ) - assert isinstance(target_type, ts.FieldType) - assert isinstance(expr_type, ts.FieldType) + assert isinstance(target_type, (ts.FieldType, ts.DeferredType)) + assert isinstance(expr_type, (ts.FieldType, ts.DeferredType)) # TODO(tehrengruber): The lowering emits domains that always have the horizontal domain # first. Since the expr inherits the ordering from the domain this can lead to a mismatch # between the target and expr (e.g. when the target has dimension K, Vertex). We should # probably just change the behaviour of the lowering. Until then we do this more # complicated comparison. 
- assert ( - set(expr_type.dims) == set(target_type.dims) - and target_type.dtype == expr_type.dtype - ) - - # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere - def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.StencilClosureType: - domain: it_ts.DomainType = self.visit(node.domain, ctx=ctx) - inputs: list[ts.FieldType] = self.visit(node.inputs, ctx=ctx) - output: ts.FieldType = self.visit(node.output, ctx=ctx) - - assert isinstance(domain, it_ts.DomainType) - for output_el in type_info.primitive_constituents(output): - assert isinstance(output_el, ts.FieldType) - - stencil_type_synthesizer = self.visit(node.stencil, ctx=ctx) - stencil_args = [ - type_synthesizer._convert_as_fieldop_input_to_iterator(domain, input_) - for input_ in inputs - ] - stencil_returns = stencil_type_synthesizer( - *stencil_args, offset_provider=self.offset_provider - ) - - return it_ts.StencilClosureType( - domain=domain, - stencil=ts.FunctionType( - pos_only_args=stencil_args, - pos_or_kw_args={}, - kw_only_args={}, - returns=stencil_returns, - ), - output=output, - inputs=inputs, - ) + if isinstance(target_type, ts.FieldType) and isinstance(expr_type, ts.FieldType): + assert ( + set(expr_type.dims).issubset(set(target_type.dims)) + and target_type.dtype == expr_type.dtype + ) def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs) -> ts.DimensionType: - assert ( - node.value in self.dimensions - ), f"Dimension {node.value} not present in offset provider." - return ts.DimensionType(dim=self.dimensions[node.value]) + return ts.DimensionType(dim=common.Dimension(value=node.value, kind=node.kind)) # TODO: revisit what we want to do with OffsetLiterals as we already have an Offset type in # the frontend. 
- def visit_OffsetLiteral(self, node: itir.OffsetLiteral, **kwargs) -> it_ts.OffsetLiteralType: + def visit_OffsetLiteral( + self, node: itir.OffsetLiteral, **kwargs + ) -> it_ts.OffsetLiteralType | ts.DeferredType: + # `self.dimensions` not available in re-inference mode. Skip since we don't care anyway. + if self.reinfer: + return ts.DeferredType(constraint=it_ts.OffsetLiteralType) + if _is_representable_as_int(node.value): return it_ts.OffsetLiteralType( - value=ts.ScalarType(kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper())) + value=ts.ScalarType( + kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper()) + ) ) else: + assert isinstance(self.dimensions, dict) assert isinstance(node.value, str) and node.value in self.dimensions return it_ts.OffsetLiteralType(value=self.dimensions[node.value]) @@ -621,7 +571,7 @@ def visit_FunCall( self.visit(value, ctx=ctx) # ensure types in value are also inferred assert ( isinstance(type_constructor, itir.SymRef) - and type_constructor.id in itir.TYPEBUILTINS + and type_constructor.id in builtins.TYPE_BUILTINS ) return ts.ScalarType(kind=getattr(ts.ScalarKind, type_constructor.id.upper())) @@ -630,13 +580,15 @@ def visit_FunCall( self.visit(tuple_, ctx=ctx) # ensure tuple is typed assert isinstance(index_literal, itir.Literal) index = int(index_literal.value) + if isinstance(tuple_.type, ts.DeferredType): + return ts.DeferredType(constraint=None) assert isinstance(tuple_.type, ts.TupleType) return tuple_.type.types[index] fun = self.visit(node.fun, ctx=ctx) args = self.visit(node.args, ctx=ctx) - result = fun(*args, offset_provider=self.offset_provider) + result = fun(*args, offset_provider_type=self.offset_provider_type) if isinstance(result, ObservableTypeSynthesizer): assert not result.node @@ -649,3 +601,5 @@ def visit_Node(self, node: itir.Node, **kwargs): infer = ITIRTypeInference.apply + +reinfer = ITIRTypeInference.apply_reinfer diff --git 
a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index cfe3987b8c..7825bf1c98 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -6,64 +6,29 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import dataclasses from typing import Literal from gt4py.next import common from gt4py.next.type_system import type_specifications as ts -@dataclasses.dataclass(frozen=True) class NamedRangeType(ts.TypeSpec): dim: common.Dimension -@dataclasses.dataclass(frozen=True) class DomainType(ts.DataType): - dims: list[common.Dimension] + dims: list[common.Dimension] | Literal["unknown"] -@dataclasses.dataclass(frozen=True) class OffsetLiteralType(ts.TypeSpec): value: ts.ScalarType | common.Dimension -@dataclasses.dataclass(frozen=True) -class ListType(ts.DataType): - element_type: ts.DataType - - -@dataclasses.dataclass(frozen=True) class IteratorType(ts.DataType, ts.CallableType): position_dims: list[common.Dimension] | Literal["unknown"] defined_dims: list[common.Dimension] element_type: ts.DataType -@dataclasses.dataclass(frozen=True) -class StencilClosureType(ts.TypeSpec): - domain: DomainType - stencil: ts.FunctionType - output: ts.FieldType | ts.TupleType - inputs: list[ts.FieldType] - - def __post_init__(self): - # local import to avoid importing type_info from a type_specification module - from gt4py.next.type_system import type_info - - for i, el_type in enumerate(type_info.primitive_constituents(self.output)): - assert isinstance( - el_type, ts.FieldType - ), f"All constituent types must be field types, but the {i}-th element is of type '{el_type}'." 
- - -# TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere -@dataclasses.dataclass(frozen=True) -class FencilType(ts.TypeSpec): - params: dict[str, ts.DataType] - closures: list[StencilClosureType] - - -@dataclasses.dataclass(frozen=True) class ProgramType(ts.TypeSpec): params: dict[str, ts.DataType] diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 77cd39389a..131b773dd2 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -14,7 +14,7 @@ from gt4py.eve.extended_typing import Callable, Iterable, Optional, Union from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.utils import tree_map @@ -35,20 +35,20 @@ class TypeSynthesizer: - isinstance checks to determine if an object is actually (meant to be) a type synthesizer and not just any callable. - writing simple type synthesizers without cluttering the signature with the additional - offset_provider argument that is only needed by some. + offset_provider_type argument that is only needed by some. 
""" type_synthesizer: Callable[..., TypeOrTypeSynthesizer] def __post_init__(self): - if "offset_provider" not in inspect.signature(self.type_synthesizer).parameters: + if "offset_provider_type" not in inspect.signature(self.type_synthesizer).parameters: synthesizer = self.type_synthesizer - self.type_synthesizer = lambda *args, offset_provider: synthesizer(*args) + self.type_synthesizer = lambda *args, offset_provider_type: synthesizer(*args) def __call__( - self, *args: TypeOrTypeSynthesizer, offset_provider: common.OffsetProvider + self, *args: TypeOrTypeSynthesizer, offset_provider_type: common.OffsetProviderType ) -> TypeOrTypeSynthesizer: - return self.type_synthesizer(*args, offset_provider=offset_provider) + return self.type_synthesizer(*args, offset_provider_type=offset_provider_type) TypeOrTypeSynthesizer = Union[ts.TypeSpec, TypeSynthesizer] @@ -81,7 +81,7 @@ def _register_builtin_type_synthesizer( @_register_builtin_type_synthesizer( - fun_names=itir.UNARY_MATH_NUMBER_BUILTINS | itir.UNARY_MATH_FP_BUILTINS + fun_names=builtins.UNARY_MATH_NUMBER_BUILTINS | builtins.UNARY_MATH_FP_BUILTINS ) def _(val: ts.ScalarType) -> ts.ScalarType: return val @@ -92,21 +92,25 @@ def power(base: ts.ScalarType, exponent: ts.ScalarType) -> ts.ScalarType: return base -@_register_builtin_type_synthesizer(fun_names=itir.BINARY_MATH_NUMBER_BUILTINS) +@_register_builtin_type_synthesizer(fun_names=builtins.BINARY_MATH_NUMBER_BUILTINS) def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType: + if isinstance(lhs, ts.DeferredType): + return rhs + if isinstance(rhs, ts.DeferredType): + return lhs assert lhs == rhs return lhs @_register_builtin_type_synthesizer( - fun_names=itir.UNARY_MATH_FP_PREDICATE_BUILTINS | itir.UNARY_LOGICAL_BUILTINS + fun_names=builtins.UNARY_MATH_FP_PREDICATE_BUILTINS | builtins.UNARY_LOGICAL_BUILTINS ) def _(arg: ts.ScalarType) -> ts.ScalarType: return ts.ScalarType(kind=ts.ScalarKind.BOOL) @_register_builtin_type_synthesizer( - 
fun_names=itir.BINARY_MATH_COMPARISON_BUILTINS | itir.BINARY_LOGICAL_BUILTINS + fun_names=builtins.BINARY_MATH_COMPARISON_BUILTINS | builtins.BINARY_LOGICAL_BUILTINS ) def _(lhs: ts.ScalarType, rhs: ts.ScalarType) -> ts.ScalarType | ts.TupleType: return ts.ScalarType(kind=ts.ScalarKind.BOOL) @@ -137,7 +141,9 @@ def can_deref(it: it_ts.IteratorType | ts.DeferredType) -> ts.ScalarType: @_register_builtin_type_synthesizer -def if_(pred: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType) -> ts.DataType: +def if_( + pred: ts.ScalarType | ts.DeferredType, true_branch: ts.DataType, false_branch: ts.DataType +) -> ts.DataType: if isinstance(true_branch, ts.TupleType) and isinstance(false_branch, ts.TupleType): return tree_map( collection_type=ts.TupleType, @@ -145,7 +151,9 @@ def if_(pred: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType )(functools.partial(if_, pred))(true_branch, false_branch) assert not isinstance(true_branch, ts.TupleType) and not isinstance(false_branch, ts.TupleType) - assert isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL + assert isinstance(pred, ts.DeferredType) or ( + isinstance(pred, ts.ScalarType) and pred.kind == ts.ScalarKind.BOOL + ) # TODO(tehrengruber): Enable this or a similar check. In case the true- and false-branch are # iterators defined on different positions this fails. For the GTFN backend we also don't # want this, but for roundtrip it is totally fine. 
@@ -155,18 +163,18 @@ def if_(pred: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType @_register_builtin_type_synthesizer -def make_const_list(scalar: ts.ScalarType) -> it_ts.ListType: +def make_const_list(scalar: ts.ScalarType) -> ts.ListType: assert isinstance(scalar, ts.ScalarType) - return it_ts.ListType(element_type=scalar) + return ts.ListType(element_type=scalar) @_register_builtin_type_synthesizer -def list_get(index: ts.ScalarType | it_ts.OffsetLiteralType, list_: it_ts.ListType) -> ts.DataType: +def list_get(index: ts.ScalarType | it_ts.OffsetLiteralType, list_: ts.ListType) -> ts.DataType: if isinstance(index, it_ts.OffsetLiteralType): assert isinstance(index.value, ts.ScalarType) index = index.value assert isinstance(index, ts.ScalarType) and type_info.is_integral(index) - assert isinstance(list_, it_ts.ListType) + assert isinstance(list_, ts.ListType) return list_.element_type @@ -190,21 +198,29 @@ def make_tuple(*args: ts.DataType) -> ts.TupleType: @_register_builtin_type_synthesizer -def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> it_ts.ListType: +def index(arg: ts.DimensionType) -> ts.FieldType: + return ts.FieldType( + dims=[arg.dim], + dtype=ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())), + ) + + +@_register_builtin_type_synthesizer +def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> ts.ListType: assert ( isinstance(offset_literal, it_ts.OffsetLiteralType) and isinstance(offset_literal.value, common.Dimension) and offset_literal.value.kind == common.DimensionKind.LOCAL ) assert isinstance(it, it_ts.IteratorType) - return it_ts.ListType(element_type=it.element_type) + return ts.ListType(element_type=it.element_type) @_register_builtin_type_synthesizer def lift(stencil: TypeSynthesizer) -> TypeSynthesizer: @TypeSynthesizer def apply_lift( - *its: it_ts.IteratorType, offset_provider: common.OffsetProvider + *its: it_ts.IteratorType, 
offset_provider_type: common.OffsetProviderType ) -> it_ts.IteratorType: assert all(isinstance(it, it_ts.IteratorType) for it in its) stencil_args = [ @@ -216,7 +232,7 @@ def apply_lift( ) for it in its ] - stencil_return_type = stencil(*stencil_args, offset_provider=offset_provider) + stencil_return_type = stencil(*stencil_args, offset_provider_type=offset_provider_type) assert isinstance(stencil_return_type, ts.DataType) position_dims = its[0].position_dims if its else [] @@ -262,7 +278,7 @@ def _convert_as_fieldop_input_to_iterator( else: defined_dims.append(dim) if is_nb_field: - element_type = it_ts.ListType(element_type=element_type) + element_type = ts.ListType(element_type=element_type) return it_ts.IteratorType( position_dims=domain.dims, defined_dims=defined_dims, element_type=element_type @@ -271,17 +287,43 @@ def _convert_as_fieldop_input_to_iterator( @_register_builtin_type_synthesizer def as_fieldop( - stencil: TypeSynthesizer, domain: it_ts.DomainType, offset_provider: common.OffsetProvider + stencil: TypeSynthesizer, + domain: Optional[it_ts.DomainType] = None, + *, + offset_provider_type: common.OffsetProviderType, ) -> TypeSynthesizer: + # In case we don't have a domain argument to `as_fieldop` we can not infer the exact result + # type. In order to still allow some passes which don't need this information to run before the + # domain inference, we continue with a dummy domain. One example is the CollapseTuple pass + # which only needs information about the structure, e.g. how many tuple elements does this node + # have, but not the dimensions of a field. + # Note that it might appear as if using the TraceShift pass would allow us to deduce the return + # type of `as_fieldop` without a domain, but this is not the case, since we don't have + # information on the ordering of dimensions. In this example + # `as_fieldop(it1, it2 -> deref(it1) + deref(it2))(i_field, j_field)` + # it is unclear if the result has dimension I, J or J, I. 
+ if domain is None: + domain = it_ts.DomainType(dims="unknown") + @TypeSynthesizer - def applied_as_fieldop(*fields) -> ts.FieldType: + def applied_as_fieldop(*fields) -> ts.FieldType | ts.DeferredType: + if any( + isinstance(el, ts.DeferredType) + for f in fields + for el in type_info.primitive_constituents(f) + ): + return ts.DeferredType(constraint=None) + stencil_return = stencil( *(_convert_as_fieldop_input_to_iterator(domain, field) for field in fields), - offset_provider=offset_provider, + offset_provider_type=offset_provider_type, ) assert isinstance(stencil_return, ts.DataType) return type_info.apply_to_primitive_constituents( - lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type), stencil_return + lambda el_type: ts.FieldType(dims=domain.dims, dtype=el_type) + if domain.dims != "unknown" + else ts.DeferredType(constraint=ts.FieldType), + stencil_return, ) return applied_as_fieldop @@ -294,8 +336,10 @@ def scan( assert isinstance(direction, ts.ScalarType) and direction.kind == ts.ScalarKind.BOOL @TypeSynthesizer - def apply_scan(*its: it_ts.IteratorType, offset_provider: common.OffsetProvider) -> ts.DataType: - result = scan_pass(init, *its, offset_provider=offset_provider) + def apply_scan( + *its: it_ts.IteratorType, offset_provider_type: common.OffsetProviderType + ) -> ts.DataType: + result = scan_pass(init, *its, offset_provider_type=offset_provider_type) assert isinstance(result, ts.DataType) return result @@ -306,14 +350,14 @@ def apply_scan(*its: it_ts.IteratorType, offset_provider: common.OffsetProvider) def map_(op: TypeSynthesizer) -> TypeSynthesizer: @TypeSynthesizer def applied_map( - *args: it_ts.ListType, offset_provider: common.OffsetProvider - ) -> it_ts.ListType: + *args: ts.ListType, offset_provider_type: common.OffsetProviderType + ) -> ts.ListType: assert len(args) > 0 - assert all(isinstance(arg, it_ts.ListType) for arg in args) + assert all(isinstance(arg, ts.ListType) for arg in args) arg_el_types = [arg.element_type for 
arg in args] - el_type = op(*arg_el_types, offset_provider=offset_provider) + el_type = op(*arg_el_types, offset_provider_type=offset_provider_type) assert isinstance(el_type, ts.DataType) - return it_ts.ListType(element_type=el_type) + return ts.ListType(element_type=el_type) return applied_map @@ -321,15 +365,17 @@ def applied_map( @_register_builtin_type_synthesizer def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer: @TypeSynthesizer - def applied_reduce(*args: it_ts.ListType, offset_provider: common.OffsetProvider): - assert all(isinstance(arg, it_ts.ListType) for arg in args) - return op(init, *(arg.element_type for arg in args), offset_provider=offset_provider) + def applied_reduce(*args: ts.ListType, offset_provider_type: common.OffsetProviderType): + assert all(isinstance(arg, ts.ListType) for arg in args) + return op( + init, *(arg.element_type for arg in args), offset_provider_type=offset_provider_type + ) return applied_reduce @_register_builtin_type_synthesizer -def shift(*offset_literals, offset_provider) -> TypeSynthesizer: +def shift(*offset_literals, offset_provider_type: common.OffsetProviderType) -> TypeSynthesizer: @TypeSynthesizer def apply_shift( it: it_ts.IteratorType | ts.DeferredType, @@ -339,25 +385,30 @@ def apply_shift( assert isinstance(it, it_ts.IteratorType) if it.position_dims == "unknown": # nothing to do here return it - new_position_dims = [*it.position_dims] - assert len(offset_literals) % 2 == 0 - for offset_axis, _ in zip(offset_literals[:-1:2], offset_literals[1::2], strict=True): - assert isinstance(offset_axis, it_ts.OffsetLiteralType) and isinstance( - offset_axis.value, common.Dimension - ) - provider = offset_provider[offset_axis.value.value] # TODO: naming - if isinstance(provider, common.Dimension): - pass - elif isinstance(provider, common.Connectivity): - found = False - for i, dim in enumerate(new_position_dims): - if dim.value == provider.origin_axis.value: - assert not found - new_position_dims[i] 
= provider.neighbor_axis - found = True - assert found - else: - raise NotImplementedError() + new_position_dims: list[common.Dimension] | str + if offset_provider_type: + new_position_dims = [*it.position_dims] + assert len(offset_literals) % 2 == 0 + for offset_axis, _ in zip(offset_literals[:-1:2], offset_literals[1::2], strict=True): + assert isinstance(offset_axis, it_ts.OffsetLiteralType) and isinstance( + offset_axis.value, common.Dimension + ) + type_ = offset_provider_type[offset_axis.value.value] + if isinstance(type_, common.Dimension): + pass + elif isinstance(type_, common.NeighborConnectivityType): + found = False + for i, dim in enumerate(new_position_dims): + if dim.value == type_.source_dim.value: + assert not found + new_position_dims[i] = type_.codomain + found = True + assert found + else: + raise NotImplementedError(f"{type_} is not a supported Connectivity type.") + else: + # during re-inference we don't have an offset provider type + new_position_dims = "unknown" return it_ts.IteratorType( position_dims=new_position_dims, defined_dims=it.defined_dims, diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index 802ad2155f..a9b52a49d0 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -6,27 +6,12 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . 
-# -# SPDX-License-Identifier: GPL-3.0-or-later - from __future__ import annotations import dataclasses import typing from typing import Any, Iterable, Iterator, Optional -import numpy as np from typing_extensions import Self from gt4py.next import common @@ -49,47 +34,19 @@ def from_signature(cls, *args: Any, **kwargs: Any) -> Self: return cls(args=args, kwargs=kwargs) -@dataclasses.dataclass(frozen=True) -class CompileTimeConnectivity(common.Connectivity): - """Compile-time standin for a GTX connectivity, retaining everything except the connectivity tables.""" - - max_neighbors: int - has_skip_values: bool - origin_axis: common.Dimension - neighbor_axis: common.Dimension - index_type: type[int] | type[np.int32] | type[np.int64] - - def mapped_index( - self, cur_index: int | np.integer, neigh_index: int | np.integer - ) -> Optional[int | np.integer]: - raise NotImplementedError( - "A CompileTimeConnectivity instance should not call `mapped_index`." - ) - - @classmethod - def from_connectivity(cls, connectivity: common.Connectivity) -> Self: - return cls( - max_neighbors=connectivity.max_neighbors, - has_skip_values=connectivity.has_skip_values, - origin_axis=connectivity.origin_axis, - neighbor_axis=connectivity.neighbor_axis, - index_type=connectivity.index_type, - ) - - @property - def table(self) -> None: - return None - - @dataclasses.dataclass(frozen=True) class CompileTimeArgs: """Compile-time standins for arguments to a GTX program to be used in ahead-of-time compilation.""" args: tuple[ts.TypeSpec, ...] 
kwargs: dict[str, ts.TypeSpec] - offset_provider: dict[str, common.Connectivity | common.Dimension] + offset_provider: common.OffsetProvider # TODO(havogt): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information column_axis: Optional[common.Dimension] + @property + def offset_provider_type(self) -> common.OffsetProviderType: + return common.offset_provider_to_type(self.offset_provider) + @classmethod def from_concrete_no_size(cls, *args: Any, **kwargs: Any) -> Self: """Convert concrete GTX program arguments into their compile-time counterparts.""" @@ -98,8 +55,7 @@ def from_concrete_no_size(cls, *args: Any, **kwargs: Any) -> Self: offset_provider = kwargs_copy.pop("offset_provider", {}) return cls( args=compile_args, - offset_provider=offset_provider, # TODO(ricoh): replace with the line below once the temporaries pass is AOT-ready. If unsure, just try it and run the tests. - # offset_provider={k: connectivity_or_dimension(v) for k, v in offset_provider.items()}, # noqa: ERA001 [commented-out-code] + offset_provider=offset_provider, column_axis=kwargs_copy.pop("column_axis", None), kwargs={ k: type_translation.from_value(v) for k, v in kwargs_copy.items() if v is not None @@ -138,18 +94,6 @@ def adapted_jit_to_aot_args_factory() -> ( return toolchain.ArgsOnlyAdapter(jit_to_aot_args) -def connectivity_or_dimension( - some_offset_provider: common.Connectivity | common.Dimension, -) -> CompileTimeConnectivity | common.Dimension: - match some_offset_provider: - case common.Dimension(): - return some_offset_provider - case common.Connectivity(): - return CompileTimeConnectivity.from_connectivity(some_offset_provider) - case _: - raise ValueError - - def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]: for element in tuple_arg: match element: @@ -164,7 +108,7 @@ def find_first_field(tuple_arg: tuple[Any, ...]) -> Optional[common.Field]: return None -def iter_size_args(args: tuple[Any, ...]) -> 
Iterator[int]: +def iter_size_args(args: tuple[Any, ...]) -> Iterator[tuple[int, int]]: """ Yield the size of each field argument in each dimension. @@ -178,7 +122,9 @@ def iter_size_args(args: tuple[Any, ...]) -> Iterator[int]: if first_field: yield from iter_size_args((first_field,)) case common.Field(): - yield from arg.ndarray.shape + for range_ in arg.domain.ranges: + assert isinstance(range_, common.UnitRange) + yield (range_.start, range_.stop) case _: pass @@ -198,6 +144,7 @@ def iter_size_compile_args( ) if field_constituents: # we only need the first field, because all fields in a tuple must have the same dims and sizes + index_type = ts.ScalarType(kind=ts.ScalarKind.INT32) yield from [ - ts.ScalarType(kind=ts.ScalarKind.INT32) for dim in field_constituents[0].dims + ts.TupleType(types=[index_type, index_type]) for dim in field_constituents[0].dims ] diff --git a/src/gt4py/next/otf/binding/cpp_interface.py b/src/gt4py/next/otf/binding/cpp_interface.py index d112a9c256..17eee4d5c6 100644 --- a/src/gt4py/next/otf/binding/cpp_interface.py +++ b/src/gt4py/next/otf/binding/cpp_interface.py @@ -8,7 +8,7 @@ from typing import Final, Sequence -from gt4py.next.otf import languages +from gt4py.next.otf import cpp_utils, languages from gt4py.next.otf.binding import interface from gt4py.next.type_system import type_info as ti, type_specifications as ts @@ -18,32 +18,12 @@ ) -def render_scalar_type(scalar_type: ts.ScalarType) -> str: - match scalar_type.kind: - case ts.ScalarKind.BOOL: - return "bool" - case ts.ScalarKind.INT32: - return "std::int32_t" - case ts.ScalarKind.INT64: - return "std::int64_t" - case ts.ScalarKind.FLOAT32: - return "float" - case ts.ScalarKind.FLOAT64: - return "double" - case ts.ScalarKind.STRING: - return "std::string" - case _: - raise AssertionError( - f"Scalar kind '{scalar_type}' is not implemented when it should be." 
- ) - - def render_function_declaration(function: interface.Function, body: str) -> str: template_params: list[str] = [] rendered_params: list[str] = [] for index, param in enumerate(function.parameters): if isinstance(param.type_, ts.ScalarType): - rendered_params.append(f"{render_scalar_type(param.type_)} {param.name}") + rendered_params.append(f"{cpp_utils.pytype_to_cpptype(param.type_)} {param.name}") elif ti.is_type_or_tuple_of_type(param.type_, (ts.FieldType, ts.ScalarType)): template_param = f"ArgT{index}" template_params.append(f"class {template_param}") diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 24913a1365..a2cf480d7f 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -14,7 +14,7 @@ import gt4py.eve as eve from gt4py.eve.codegen import JinjaTemplate as as_jinja, TemplatedGenerator -from gt4py.next.otf import languages, stages, workflow +from gt4py.next.otf import cpp_utils, languages, stages, workflow from gt4py.next.otf.binding import cpp_interface, interface from gt4py.next.type_system import type_specifications as ts @@ -86,13 +86,15 @@ def _type_string(type_: ts.TypeSpec) -> str: return f"std::tuple<{','.join(_type_string(t) for t in type_.types)}>" elif isinstance(type_, ts.FieldType): ndims = len(type_.dims) - dtype = cpp_interface.render_scalar_type(type_.dtype) + # cannot be ListType: the concept is represented as Field with local Dimension in this interface + assert isinstance(type_.dtype, ts.ScalarType) + dtype = cpp_utils.pytype_to_cpptype(type_.dtype) shape = f"nanobind::shape<{', '.join(['gridtools::nanobind::dynamic_size'] * ndims)}>" buffer_t = f"nanobind::ndarray<{dtype}, {shape}>" origin_t = f"std::tuple<{', '.join(['ptrdiff_t'] * ndims)}>" return f"std::pair<{buffer_t}, {origin_t}>" elif isinstance(type_, ts.ScalarType): - return cpp_interface.render_scalar_type(type_) + return cpp_utils.pytype_to_cpptype(type_) else: raise 
ValueError(f"Type '{type_}' is not supported in nanobind interfaces.") diff --git a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py index 0533adac81..23c80793c7 100644 --- a/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py +++ b/src/gt4py/next/otf/compilation/build_systems/cmake_lists.py @@ -88,9 +88,15 @@ def visit_FindDependency(self, dep: FindDependency) -> str: # Instead, design this to be extensible (refer to ADR-0016). match dep.name: case "nanobind": + import sys + import nanobind - py = "find_package(Python COMPONENTS Interpreter Development REQUIRED)" + py = f""" + set(Python_EXECUTABLE {sys.executable}) + + find_package(Python COMPONENTS Interpreter Development REQUIRED) + """ nb = f"find_package(nanobind CONFIG REQUIRED PATHS {nanobind.cmake_dir()} NO_DEFAULT_PATHS)" return py + "\n" + nb case "gridtools_cpu" | "gridtools_gpu": diff --git a/src/gt4py/next/otf/cpp_utils.py b/src/gt4py/next/otf/cpp_utils.py new file mode 100644 index 0000000000..8b2af40eb5 --- /dev/null +++ b/src/gt4py/next/otf/cpp_utils.py @@ -0,0 +1,32 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + + +from gt4py.next.type_system import type_specifications as ts + + +def pytype_to_cpptype(t: ts.ScalarType | str) -> str: + if isinstance(t, ts.ScalarType): + t = t.kind.name.lower() + try: + return { + "float32": "float", + "float64": "double", + "int8": "std::int8_t", + "uint8": "std::uint8_t", + "int16": "std::int16_t", + "uint16": "std::uint16_t", + "int32": "std::int32_t", + "uint32": "std::uint32_t", + "int64": "std::int64_t", + "uint64": "std::uint64_t", + "bool": "bool", + "string": "string", + }[t] + except KeyError: + raise TypeError(f"Unsupported type '{t}'.") from None diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 85838d9c76..ff4285d72d 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -11,6 +11,8 @@ import dataclasses from typing import Any, Generic, Optional, Protocol, TypeAlias, TypeVar +from gt4py.eve import utils +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.otf import arguments, languages, toolchain from gt4py.next.otf.binding import interface @@ -26,9 +28,45 @@ SettingT_co = TypeVar("SettingT_co", bound=languages.LanguageSettings, covariant=True) -CompilableProgram: TypeAlias = toolchain.CompilableProgram[ - itir.FencilDefinition | itir.Program, arguments.CompileTimeArgs -] +CompilableProgram: TypeAlias = toolchain.CompilableProgram[itir.Program, arguments.CompileTimeArgs] + + +def compilation_hash(otf_closure: CompilableProgram) -> int: + """Given closure compute a hash uniquely determining if we need to recompile.""" + offset_provider = otf_closure.args.offset_provider + return hash( + ( + otf_closure.data, + # As the frontend types contain lists they are not hashable. As a workaround we just + # use content_hash here. 
+ utils.content_hash(tuple(arg for arg in otf_closure.args.args)), + # Directly using the `id` of the offset provider is not possible as the decorator adds + # the implicitly defined ones (i.e. to allow the `TDim + 1` syntax) resulting in a + # different `id` every time. Instead use the `id` of each individual offset provider. + tuple((k, id(v)) for (k, v) in offset_provider.items()) if offset_provider else None, + otf_closure.args.column_axis, + ) + ) + + +def fingerprint_compilable_program(inp: CompilableProgram) -> str: + """ + Generates a unique hash string for a stencil source program representing + the program, sorted offset_provider, and column_axis. + """ + program: itir.Program = inp.data + offset_provider: common.OffsetProvider = inp.args.offset_provider + column_axis: Optional[common.Dimension] = inp.args.column_axis + + program_hash = utils.content_hash( + ( + program, + sorted(offset_provider.items(), key=lambda el: el[0]), + column_axis, + ) + ) + + return program_hash @dataclasses.dataclass(frozen=True) diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index a63801c97e..ef3a4083b9 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -12,6 +12,7 @@ import dataclasses import functools import typing +from collections.abc import MutableMapping from typing import Any, Callable, Generic, Protocol, TypeVar from typing_extensions import Self @@ -253,16 +254,15 @@ class CachedStep( step: Workflow[StartT, EndT] hash_function: Callable[[StartT], HashT] = dataclasses.field(default=hash) # type: ignore[assignment] - - _cache: dict[HashT, EndT] = dataclasses.field(repr=False, init=False, default_factory=dict) + cache: MutableMapping[HashT, EndT] = dataclasses.field(repr=False, default_factory=dict) def __call__(self, inp: StartT) -> EndT: """Run the step only if the input is not cached, else return from cache.""" hash_ = self.hash_function(inp) try: - result = self._cache[hash_] + result = self.cache[hash_] 
except KeyError: - result = self._cache[hash_] = self.step(inp) + result = self.cache[hash_] = self.step(inp) return result diff --git a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py index 92dbcedeaa..969e203689 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/codegen.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/codegen.py @@ -11,8 +11,8 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import common +from gt4py.next.otf import cpp_utils from gt4py.next.program_processors.codegens.gtfn import gtfn_im_ir, gtfn_ir, gtfn_ir_common -from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import pytype_to_cpptype class GTFNCodegen(codegen.TemplatedGenerator): @@ -23,6 +23,7 @@ class GTFNCodegen(codegen.TemplatedGenerator): _builtins_mapping: Final = { "abs": "std::abs", + "neg": "std::negate<>{}", "sin": "std::sin", "cos": "std::cos", "tan": "std::tan", @@ -52,24 +53,30 @@ class GTFNCodegen(codegen.TemplatedGenerator): "power": "std::pow", "float32": "float", "float64": "double", + "int8": "std::int8_t", + "uint8": "std::uint8_t", + "int16": "std::int16_t", + "uint16": "std::uint16_t", "int32": "std::int32_t", + "uint32": "std::uint32_t", "int64": "std::int64_t", + "uint64": "std::uint64_t", "bool": "bool", - "plus": "std::plus{}", - "minus": "std::minus{}", - "multiplies": "std::multiplies{}", - "divides": "std::divides{}", - "eq": "std::equal_to{}", - "not_eq": "std::not_equal_to{}", - "less": "std::less{}", - "less_equal": "std::less_equal{}", - "greater": "std::greater{}", - "greater_equal": "std::greater_equal{}", - "and_": "std::logical_and{}", - "or_": "std::logical_or{}", - "xor_": "std::bit_xor{}", - "mod": "std::modulus{}", - "not_": "std::logical_not{}", + "plus": "std::plus<>{}", + "minus": "std::minus<>{}", + "multiplies": "std::multiplies<>{}", + "divides": 
"std::divides<>{}", + "eq": "std::equal_to<>{}", + "not_eq": "std::not_equal_to<>{}", + "less": "std::less<>{}", + "less_equal": "std::less_equal<>{}", + "greater": "std::greater<>{}", + "greater_equal": "std::greater_equal<>{}", + "and_": "std::logical_and<>{}", + "or_": "std::logical_or<>{}", + "xor_": "std::bit_xor<>{}", + "mod": "std::modulus<>{}", + "not_": "std::logical_not<>{}", } Sym = as_fmt("{id}") @@ -92,8 +99,11 @@ def asfloat(value: str) -> str: return value def visit_Literal(self, node: gtfn_ir.Literal, **kwargs: Any) -> str: + if node.type == "axis_literal": + return node.value + # TODO(tehrengruber): isn't this wrong and int32 should be casted to an actual int32? - match pytype_to_cpptype(node.type): + match cpp_utils.pytype_to_cpptype(node.type): case "float": return self.asfloat(node.value) + "f" case "double": @@ -101,6 +111,7 @@ def visit_Literal(self, node: gtfn_ir.Literal, **kwargs: Any) -> str: case "bool": return node.value.lower() case _: + # TODO(tehrengruber): we should probably shouldn't just allow anything here. Revisit. 
return node.value IntegralConstant = as_fmt("{value}_c") @@ -260,7 +271,7 @@ def visit_Program(self, node: gtfn_ir.Program, **kwargs: Any) -> Union[str, Coll #include #include #include - + namespace generated{ namespace gtfn = ::gridtools::fn; diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py index 1995e4de0b..831694791a 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir.py @@ -8,12 +8,12 @@ from __future__ import annotations -from typing import ClassVar, Optional, Union +from typing import Callable, ClassVar, Optional, Union from gt4py.eve import Coerced, SymbolName, datamodels from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.next import common -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins from gt4py.next.program_processors.codegens.gtfn.gtfn_im_ir import ImperativeFunctionDefinition from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Node, Sym, SymRef @@ -96,25 +96,23 @@ class Backend(Node): domain: Union[SymRef, CartesianDomain, UnstructuredDomain] -def _is_ref_literal_or_tuple_expr_of_ref(expr: Expr) -> bool: +def _is_tuple_expr_of(pred: Callable[[Expr], bool], expr: Expr) -> bool: if ( isinstance(expr, FunCall) and isinstance(expr.fun, SymRef) and expr.fun.id == "tuple_get" and len(expr.args) == 2 - and _is_ref_literal_or_tuple_expr_of_ref(expr.args[1]) + and _is_tuple_expr_of(pred, expr.args[1]) ): return True if ( isinstance(expr, FunCall) and isinstance(expr.fun, SymRef) and expr.fun.id == "make_tuple" - and all(_is_ref_literal_or_tuple_expr_of_ref(arg) for arg in expr.args) + and all(_is_tuple_expr_of(pred, arg) for arg in expr.args) ): return True - if isinstance(expr, (SymRef, Literal)): - return True - return False + return pred(expr) class SidComposite(Expr): @@ -126,14 +124,32 @@ def 
_values_validator( ) -> None: if not all( isinstance(el, (SidFromScalar, SidComposite)) - or _is_ref_literal_or_tuple_expr_of_ref(el) + or _is_tuple_expr_of(lambda expr: isinstance(expr, (SymRef, Literal)), el) for el in value ): raise ValueError( - "Only 'SymRef', tuple expr of 'SymRef', 'SidFromScalar', or 'SidComposite' allowed." + "Only 'SymRef', 'Literal', tuple expr thereof, 'SidFromScalar', or 'SidComposite' allowed." ) +def _might_be_scalar_expr(expr: Expr) -> bool: + if isinstance(expr, BinaryExpr): + return all(_is_tuple_expr_of(_might_be_scalar_expr, arg) for arg in (expr.lhs, expr.rhs)) + if isinstance(expr, UnaryExpr): + return _is_tuple_expr_of(_might_be_scalar_expr, expr.expr) + if ( + isinstance(expr, FunCall) + and isinstance(expr.fun, SymRef) + and expr.fun.id in ARITHMETIC_BUILTINS + ): + return all(_might_be_scalar_expr(arg) for arg in expr.args) + if isinstance(expr, CastExpr): + return _might_be_scalar_expr(expr.obj_expr) + if _is_tuple_expr_of(lambda e: isinstance(e, (SymRef, Literal)), expr): + return True + return False + + class SidFromScalar(Expr): arg: Expr @@ -141,8 +157,10 @@ class SidFromScalar(Expr): def _arg_validator( self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: Expr ) -> None: - if not _is_ref_literal_or_tuple_expr_of_ref(value): - raise ValueError("Only 'SymRef' or tuple expr of 'SymRef' allowed.") + if not _might_be_scalar_expr(value): + raise ValueError( + "Only 'SymRef', 'Literal', arithmetic op or tuple expr thereof allowed." 
+ ) class Stmt(Node): @@ -153,7 +171,25 @@ class StencilExecution(Stmt): backend: Backend stencil: SymRef output: Union[SymRef, SidComposite] - inputs: list[Union[SymRef, SidComposite, SidFromScalar]] + inputs: list[Union[SymRef, SidComposite, SidFromScalar, FunCall]] + + @datamodels.validator("inputs") + def _arg_validator( + self: datamodels.DataModelTP, attribute: datamodels.Attribute, inputs: list[Expr] + ) -> None: + for inp in inputs: + if not _is_tuple_expr_of( + lambda expr: isinstance(expr, (SymRef, SidComposite, SidFromScalar)) + or ( + isinstance(expr, FunCall) + and isinstance(expr.fun, SymRef) + and expr.fun.id == "index" + ), + inp, + ): + raise ValueError( + "Only 'SymRef', 'SidComposite', 'SidFromScalar', 'index' call or tuple expr thereof allowed." + ) class Scan(Node): @@ -192,9 +228,10 @@ class TemporaryAllocation(Node): "unstructured_domain", "named_range", "reduce", + "index", ] -ARITHMETIC_BUILTINS = itir.ARITHMETIC_BUILTINS -TYPEBUILTINS = itir.TYPEBUILTINS +ARITHMETIC_BUILTINS = builtins.ARITHMETIC_BUILTINS +TYPEBUILTINS = builtins.TYPE_BUILTINS BUILTINS = {*GTFN_BUILTINS, *ARITHMETIC_BUILTINS, *TYPEBUILTINS} diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py index cc57c137bf..b2aea05641 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_to_gtfn_im_ir.py @@ -12,7 +12,6 @@ import gt4py.eve as eve from gt4py.eve import NodeTranslator, concepts from gt4py.eve.utils import UIDGenerator -from gt4py.next import common from gt4py.next.program_processors.codegens.gtfn import gtfn_ir, gtfn_ir_common from gt4py.next.program_processors.codegens.gtfn.gtfn_im_ir import ( AssignStmt, @@ -84,54 +83,9 @@ def _is_reduce(node: gtfn_ir.FunCall) -> TypeGuard[gtfn_ir.FunCall]: ) -def _get_connectivity( - applied_reduce_node: gtfn_ir.FunCall, - 
offset_provider: dict[str, common.Dimension | common.Connectivity], -) -> common.Connectivity: - """Return single connectivity that is compatible with the arguments of the reduce.""" - if not _is_reduce(applied_reduce_node): - raise ValueError("Expected a call to a 'reduce' object, i.e. 'reduce(...)(...)'.") - - connectivities: list[common.Connectivity] = [] - for o in _get_partial_offset_tags(applied_reduce_node.args): - conn = offset_provider[o] - assert isinstance(conn, common.Connectivity) - connectivities.append(conn) - - if not connectivities: - raise RuntimeError("Couldn't detect partial shift in any arguments of 'reduce'.") - - if len({(c.max_neighbors, c.has_skip_values) for c in connectivities}) != 1: - # The condition for this check is required but not sufficient: the actual neighbor tables could still be incompatible. - raise RuntimeError("Arguments to 'reduce' have incompatible partial shifts.") - return connectivities[0] - - # TODO: end of code clone -def _make_dense_acess( - shift_call: gtfn_ir.FunCall, nbh_iter: gtfn_ir_common.SymRef -) -> gtfn_ir.FunCall: - return gtfn_ir.FunCall( - fun=gtfn_ir_common.SymRef(id="deref"), - args=[ - gtfn_ir.FunCall( - fun=gtfn_ir_common.SymRef(id="shift"), args=[*shift_call.args, nbh_iter] - ) - ], - ) - - -def _make_sparse_acess( - field_ref: gtfn_ir_common.SymRef, nbh_iter: gtfn_ir_common.SymRef -) -> gtfn_ir.FunCall: - return gtfn_ir.FunCall( - fun=gtfn_ir_common.SymRef(id="tuple_get"), - args=[nbh_iter, gtfn_ir.FunCall(fun=gtfn_ir_common.SymRef(id="deref"), args=[field_ref])], - ) - - class PlugInCurrentIdx(NodeTranslator): def visit_SymRef( self, node: gtfn_ir_common.SymRef @@ -225,32 +179,6 @@ def _expand_symref( ) self.imp_list_ir.append(AssignStmt(lhs=gtfn_ir_common.SymRef(id=red_idx), rhs=rhs)) - def handle_Reduction(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common.SymRef: - offset_provider = kwargs["offset_provider"] - assert offset_provider is not None - - connectivity = 
_get_connectivity(node, offset_provider) - - args = node.args - # do the following transformations to the node arguments - # dense fields: shift(dense_f, X2Y) -> deref(shift(dense_f, X2Y, nbh_iterator) - # sparse_fields: sparse_f -> tuple_get(nbh_iterator, deref(sparse_f))) - new_args = [] - nbh_iter = gtfn_ir_common.SymRef(id="nbh_iter") - for arg in args: - if isinstance(arg, gtfn_ir.FunCall) and arg.fun.id == "shift": # type: ignore - new_args.append(_make_dense_acess(arg, nbh_iter)) - if isinstance(arg, gtfn_ir_common.SymRef): - new_args.append(_make_sparse_acess(arg, nbh_iter)) - - red_idx = self.uids.sequential_id(prefix="red") - if isinstance(node.fun.args[0], gtfn_ir.Lambda): # type: ignore - self._expand_lambda(node, new_args, red_idx, connectivity.max_neighbors, **kwargs) - elif isinstance(node.fun.args[0], gtfn_ir_common.SymRef): # type: ignore - self._expand_symref(node, new_args, red_idx, connectivity.max_neighbors, **kwargs) - - return gtfn_ir_common.SymRef(id=red_idx) - def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common.Expr: if any(isinstance(arg, gtfn_ir.Lambda) for arg in node.args): # do not try to lower constructs that take lambdas as argument to something more readable @@ -278,7 +206,9 @@ def visit_FunCall(self, node: gtfn_ir.FunCall, **kwargs: Any) -> gtfn_ir_common. self.imp_list_ir.append(InitStmt(lhs=gtfn_ir_common.Sym(id=f"{lam_idx}"), rhs=expr)) return gtfn_ir_common.SymRef(id=f"{lam_idx}") if _is_reduce(node): - return self.handle_Reduction(node, **kwargs) + raise AssertionError( + "Not implemented. The code-path was removed as it was not actively used and tested." 
+ ) if isinstance(node.fun, gtfn_ir_common.SymRef) and node.fun.id == "make_tuple": tupl_id = self.uids.sequential_id(prefix="tupl") tuple_fun = self.commit_args(node, tupl_id, "make_tuple", **kwargs) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index d729a5ba2f..48f15acffb 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -10,7 +10,7 @@ import dataclasses import functools -from typing import Any, Callable, Final, Optional +from typing import Any, Final, Optional import factory import numpy as np @@ -18,10 +18,9 @@ from gt4py._core import definitions as core_defs from gt4py.eve import codegen from gt4py.next import common -from gt4py.next.common import Connectivity, Dimension from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import LiftMode, fencil_to_program, pass_manager +from gt4py.next.iterator.transforms import pass_manager from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import cpp_interface, interface from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen @@ -52,12 +51,8 @@ class GTFNTranslationStep( # TODO replace by more general mechanism, see https://github.com/GridTools/gt4py/issues/1135 enable_itir_transforms: bool = True use_imperative_backend: bool = False - lift_mode: Optional[LiftMode] = None device_type: core_defs.DeviceType = core_defs.DeviceType.CPU symbolic_domain_sizes: Optional[dict[str, str]] = None - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings: match self.device_type: @@ -82,9 +77,9 @@ def _default_language_settings(self) -> 
languages.LanguageWithHeaderFilesSetting def _process_regular_arguments( self, - program: itir.FencilDefinition | itir.Program, + program: itir.Program, arg_types: tuple[ts.TypeSpec, ...], - offset_provider: dict[str, Connectivity | Dimension], + offset_provider_type: common.OffsetProviderType, ) -> tuple[list[interface.Parameter], list[str]]: parameters: list[interface.Parameter] = [] arg_exprs: list[str] = [] @@ -106,22 +101,22 @@ def _process_regular_arguments( ): # translate sparse dimensions to tuple dtype dim_name = dim.value - connectivity = offset_provider[dim_name] - assert isinstance(connectivity, Connectivity) + connectivity = offset_provider_type[dim_name] + assert isinstance(connectivity, common.NeighborConnectivityType) size = connectivity.max_neighbors arg = f"gridtools::sid::dimension_to_tuple_like({arg})" arg_exprs.append(arg) return parameters, arg_exprs def _process_connectivity_args( - self, offset_provider: dict[str, Connectivity | Dimension] + self, offset_provider_type: common.OffsetProviderType ) -> tuple[list[interface.Parameter], list[str]]: parameters: list[interface.Parameter] = [] arg_exprs: list[str] = [] - for name, connectivity in offset_provider.items(): - if isinstance(connectivity, Connectivity): - if connectivity.index_type not in [np.int32, np.int64]: + for name, connectivity_type in offset_provider_type.items(): + if isinstance(connectivity_type, common.NeighborConnectivityType): + if connectivity_type.dtype.scalar_type not in [np.int32, np.int64]: raise ValueError( "Neighbor table indices must be of type 'np.int32' or 'np.int64'." 
) @@ -131,10 +126,8 @@ def _process_connectivity_args( interface.Parameter( name=GENERATED_CONNECTIVITY_PARAM_PREFIX + name.lower(), type_=ts.FieldType( - dims=[connectivity.origin_axis, Dimension(name)], - dtype=ts.ScalarType( - type_translation.get_scalar_kind(connectivity.index_type) - ), + dims=list(connectivity_type.domain), + dtype=type_translation.from_dtype(connectivity_type.dtype), ), ) ) @@ -142,41 +135,36 @@ def _process_connectivity_args( # connectivity argument expression nbtbl = ( f"gridtools::fn::sid_neighbor_table::as_neighbor_table<" - f"generated::{connectivity.origin_axis.value}_t, " - f"generated::{name}_t, {connectivity.max_neighbors}" + f"generated::{connectivity_type.domain[0].value}_t, " + f"generated::{connectivity_type.domain[1].value}_t, " + f"{connectivity_type.max_neighbors}" f">(std::forward({GENERATED_CONNECTIVITY_PARAM_PREFIX}{name.lower()}))" ) arg_exprs.append( f"gridtools::hymap::keys::make_values({nbtbl})" ) - elif isinstance(connectivity, Dimension): + elif isinstance(connectivity_type, common.Dimension): pass else: raise AssertionError( - f"Expected offset provider '{name}' to be a 'Connectivity' or 'Dimension', " - f"got '{type(connectivity).__name__}'." + f"Expected offset provider type '{name}' to be a 'NeighborConnectivityType' or 'Dimension', " + f"got '{type(connectivity_type).__name__}'." 
) return parameters, arg_exprs def _preprocess_program( self, - program: itir.FencilDefinition | itir.Program, - offset_provider: dict[str, Connectivity | Dimension], + program: itir.Program, + offset_provider: common.OffsetProvider, ) -> itir.Program: - if isinstance(program, itir.FencilDefinition) and not self.enable_itir_transforms: - return fencil_to_program.FencilToProgram().apply( - program - ) # FIXME[#1582](tehrengruber): should be removed after refactoring to combined IR - apply_common_transforms = functools.partial( pass_manager.apply_common_transforms, - lift_mode=self.lift_mode, + extract_temporaries=True, offset_provider=offset_provider, # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements unconditionally_collapse_tuples=True, symbolic_domain_sizes=self.symbolic_domain_sizes, - temporary_extraction_heuristics=self.temporary_extraction_heuristics, ) new_program = apply_common_transforms( @@ -195,13 +183,20 @@ def _preprocess_program( def generate_stencil_source( self, - program: itir.FencilDefinition | itir.Program, - offset_provider: dict[str, Connectivity | Dimension], + program: itir.Program, + offset_provider: common.OffsetProvider, column_axis: Optional[common.Dimension], ) -> str: - new_program = self._preprocess_program(program, offset_provider) + if self.enable_itir_transforms: + new_program = self._preprocess_program(program, offset_provider) + else: + assert isinstance(program, itir.Program) + new_program = program + gtfn_ir = GTFN_lowering.apply( - new_program, offset_provider=offset_provider, column_axis=column_axis + new_program, + offset_provider_type=common.offset_provider_to_type(offset_provider), + column_axis=column_axis, ) if self.use_imperative_backend: @@ -209,24 +204,25 @@ def generate_stencil_source( generated_code = GTFNIMCodegen.apply(gtfn_im_ir) else: generated_code = GTFNCodegen.apply(gtfn_ir) + return codegen.format_source("cpp", generated_code, style="LLVM") def __call__( 
self, inp: stages.CompilableProgram ) -> stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings]: """Generate GTFN C++ code from the ITIR definition.""" - program: itir.FencilDefinition | itir.Program = inp.data + program: itir.Program = inp.data # handle regular parameters and arguments of the program (i.e. what the user defined in # the program) regular_parameters, regular_args_expr = self._process_regular_arguments( - program, inp.args.args, inp.args.offset_provider + program, inp.args.args, inp.args.offset_provider_type ) # handle connectivity parameters and arguments (i.e. what the user provided in the offset # provider) connectivity_parameters, connectivity_args_expr = self._process_connectivity_args( - inp.args.offset_provider + inp.args.offset_provider_type ) # combine into a format that is aligned with what the backend expects diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 3bd96d14d7..104e2eccc1 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -15,8 +15,9 @@ from gt4py.eve.concepts import SymbolName from gt4py.next import common from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.type_system import inference as itir_type_inference +from gt4py.next.otf import cpp_utils from gt4py.next.program_processors.codegens.gtfn.gtfn_ir import ( Backend, BinaryExpr, @@ -47,26 +48,31 @@ from gt4py.next.type_system import type_info, type_specifications as ts -def pytype_to_cpptype(t: ts.ScalarType | str) -> Optional[str]: - if isinstance(t, ts.ScalarType): - t = t.kind.name.lower() - try: - return { - "float32": "float", - "float64": "double", - "int32": 
"std::int32_t", - "int64": "std::int64_t", - "bool": "bool", - "axis_literal": None, # TODO: domain? - }[t] - except KeyError: - raise TypeError(f"Unsupported type '{t}'.") from None - - _vertical_dimension = "gtfn::unstructured::dim::vertical" _horizontal_dimension = "gtfn::unstructured::dim::horizontal" +def _is_tuple_of_ref_or_literal(expr: itir.Expr) -> bool: + if ( + isinstance(expr, itir.FunCall) + and isinstance(expr.fun, itir.SymRef) + and expr.fun.id == "tuple_get" + and len(expr.args) == 2 + and _is_tuple_of_ref_or_literal(expr.args[1]) + ): + return True + if ( + isinstance(expr, itir.FunCall) + and isinstance(expr.fun, itir.SymRef) + and expr.fun.id == "make_tuple" + and all(_is_tuple_of_ref_or_literal(arg) for arg in expr.args) + ): + return True + if isinstance(expr, (itir.SymRef, itir.Literal)): + return True + return False + + def _get_domains(nodes: Iterable[itir.Stmt]) -> Iterable[itir.FunCall]: result = set() for node in nodes: @@ -87,7 +93,7 @@ def _get_gridtype(body: list[itir.Stmt]) -> common.GridType: grid_types = {_extract_grid_type(d) for d in domains} if len(grid_types) != 1: raise ValueError( - f"Found 'StencilClosures' with more than one 'GridType': '{grid_types}'. This is currently not supported." + f"Found 'set_at' with more than one 'GridType': '{grid_types}'. This is currently not supported." 
) return grid_types.pop() @@ -138,7 +144,7 @@ def _collect_dimensions_from_domain( def _collect_offset_definitions( node: itir.Node, grid_type: common.GridType, - offset_provider: dict[str, common.Dimension | common.Connectivity], + offset_provider_type: common.OffsetProviderType, ) -> dict[str, TagDefinition]: used_offset_tags: set[itir.OffsetLiteral] = ( node.walk_values() @@ -146,13 +152,13 @@ def _collect_offset_definitions( .filter(lambda offset_literal: isinstance(offset_literal.value, str)) .getattr("value") ).to_set() - if not used_offset_tags.issubset(set(offset_provider.keys())): + if not used_offset_tags.issubset(set(offset_provider_type.keys())): raise AssertionError("ITIR contains an offset tag without a corresponding offset provider.") offset_definitions = {} - for offset_name, dim_or_connectivity in offset_provider.items(): - if isinstance(dim_or_connectivity, common.Dimension): - dim: common.Dimension = dim_or_connectivity + for offset_name, dim_or_connectivity_type in offset_provider_type.items(): + if isinstance(dim_or_connectivity_type, common.Dimension): + dim: common.Dimension = dim_or_connectivity_type if grid_type == common.GridType.CARTESIAN: # create alias from offset to dimension offset_definitions[dim.value] = TagDefinition(name=Sym(id=dim.value)) @@ -177,15 +183,23 @@ def _collect_offset_definitions( "Mapping an offset to a horizontal dimension in unstructured is not allowed." 
) # create alias from vertical offset to vertical dimension + offset_definitions[dim.value] = TagDefinition( + name=Sym(id=dim.value), alias=_vertical_dimension + ) offset_definitions[offset_name] = TagDefinition( name=Sym(id=offset_name), alias=SymRef(id=dim.value) ) - elif isinstance(dim_or_connectivity, common.Connectivity): + elif isinstance( + connectivity_type := dim_or_connectivity_type, common.NeighborConnectivityType + ): assert grid_type == common.GridType.UNSTRUCTURED offset_definitions[offset_name] = TagDefinition(name=Sym(id=offset_name)) + if offset_name != connectivity_type.neighbor_dim.value: + offset_definitions[connectivity_type.neighbor_dim.value] = TagDefinition( + name=Sym(id=connectivity_type.neighbor_dim.value) + ) - connectivity: common.Connectivity = dim_or_connectivity - for dim in [connectivity.origin_axis, connectivity.neighbor_axis]: + for dim in [connectivity_type.source_dim, connectivity_type.codomain]: if dim.kind != common.DimensionKind.HORIZONTAL: raise NotImplementedError() offset_definitions[dim.value] = TagDefinition( @@ -302,7 +316,7 @@ class GTFN_lowering(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): } _unary_op_map: ClassVar[dict[str, str]] = {"not_": "!"} - offset_provider: dict + offset_provider_type: common.OffsetProviderType column_axis: Optional[common.Dimension] grid_type: common.GridType @@ -317,18 +331,18 @@ def apply( cls, node: itir.Program, *, - offset_provider: dict, + offset_provider_type: common.OffsetProviderType, column_axis: Optional[common.Dimension], ) -> Program: if not isinstance(node, itir.Program): raise TypeError(f"Expected a 'Program', got '{type(node).__name__}'.") - node = itir_type_inference.infer(node, offset_provider=offset_provider) + node = itir_type_inference.infer(node, offset_provider_type=offset_provider_type) grid_type = _get_gridtype(node.body) if grid_type == common.GridType.UNSTRUCTURED: node = _CannonicalizeUnstructuredDomain.apply(node) return cls( - 
offset_provider=offset_provider, column_axis=column_axis, grid_type=grid_type + offset_provider_type=offset_provider_type, column_axis=column_axis, grid_type=grid_type ).visit(node) def visit_Sym(self, node: itir.Sym, **kwargs: Any) -> Sym: @@ -463,8 +477,8 @@ def _visit_unstructured_domain(self, node: itir.FunCall, **kwargs: Any) -> Node: if "stencil" in kwargs: shift_offsets = self._collect_offset_or_axis_node(itir.OffsetLiteral, kwargs["stencil"]) for o in shift_offsets: - if o in self.offset_provider and isinstance( - self.offset_provider[o], common.Connectivity + if o in self.offset_provider_type and isinstance( + self.offset_provider_type[o], common.NeighborConnectivityType ): connectivities.append(SymRef(id=o)) return UnstructuredDomain( @@ -587,6 +601,9 @@ def visit_IfStmt(self, node: itir.IfStmt, **kwargs: Any) -> IfStmt: def visit_SetAt( self, node: itir.SetAt, *, extracted_functions: list, **kwargs: Any ) -> Union[StencilExecution, ScanExecution]: + if _is_tuple_of_ref_or_literal(node.expr): + node.expr = im.as_fieldop("deref", node.domain)(node.expr) + assert cpm.is_applied_as_fieldop(node.expr) stencil = node.expr.fun.args[0] # type: ignore[attr-defined] # checked in assert domain = node.domain @@ -611,7 +628,6 @@ def convert_el_to_sid(el_expr: Expr, el_type: ts.ScalarType | ts.FieldType) -> E tuple_constructor=lambda *elements: SidComposite(values=list(elements)), ) - assert isinstance(lowered_input_as_sid, (SidComposite, SidFromScalar, SymRef)) lowered_inputs.append(lowered_input_as_sid) backend = Backend(domain=self.visit(domain, stencil=stencil, **kwargs)) @@ -656,7 +672,7 @@ def visit_Program(self, node: itir.Program, **kwargs: Any) -> Program: function_definitions = self.visit(node.function_definitions) + extracted_functions offset_definitions = { **_collect_dimensions_from_domain(node.body), - **_collect_offset_definitions(node, self.grid_type, self.offset_provider), + **_collect_offset_definitions(node, self.grid_type, 
self.offset_provider_type), } return Program( id=SymbolName(node.id), @@ -674,9 +690,9 @@ def visit_Temporary( def dtype_to_cpp(x: ts.DataType) -> str: if isinstance(x, ts.TupleType): assert all(isinstance(i, ts.ScalarType) for i in x.types) - return "::gridtools::tuple<" + ", ".join(dtype_to_cpp(i) for i in x.types) + ">" + return "::gridtools::tuple<" + ", ".join(dtype_to_cpp(i) for i in x.types) + ">" # type: ignore[arg-type] # ensured by assert assert isinstance(x, ts.ScalarType) - res = pytype_to_cpptype(x) + res = cpp_utils.pytype_to_cpptype(x) assert isinstance(res, str) return res diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index db1242e2a4..5f32eaa2bb 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -15,7 +15,7 @@ @program_formatter.program_formatter -def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: +def format_cpp(program: itir.Program, *args: Any, **kwargs: Any) -> str: # TODO(tehrengruber): This is a little ugly. Revisit. gtfn_translation = gtfn.GTFNBackendFactory().executor.translation assert isinstance(gtfn_translation, GTFNTranslationStep) diff --git a/src/gt4py/next/program_processors/formatters/lisp.py b/src/gt4py/next/program_processors/formatters/lisp.py deleted file mode 100644 index c477795c34..0000000000 --- a/src/gt4py/next/program_processors/formatters/lisp.py +++ /dev/null @@ -1,69 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -from typing import Any - -from gt4py.eve.codegen import FormatTemplate as as_fmt, TemplatedGenerator -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import apply_common_transforms -from gt4py.next.program_processors import program_formatter - - -class ToLispLike(TemplatedGenerator): - Sym = as_fmt("{id}") - FunCall = as_fmt("({fun} {' '.join(args)})") - Literal = as_fmt("{value}") - OffsetLiteral = as_fmt("{value}") - SymRef = as_fmt("{id}") - StencilClosure = as_fmt( - """( - :domain {domain} - :stencil {stencil} - :output {output} - :inputs {' '.join(inputs)} - ) - """ - ) - FencilDefinition = as_fmt( - """ - ({' '.join(function_definitions)}) - (defen {id}({' '.join(params)}) - {''.join(closures)}) - """ - ) - FunctionDefinition = as_fmt( - """(defun {id}({' '.join(params)}) - {expr} - ) - -""" - ) - Lambda = as_fmt( - """(lambda ({' '.join(params)}) - {expr} - )""" - ) - - @classmethod - def apply(cls, root: itir.Node, **kwargs: Any) -> str: # type: ignore[override] - transformed = apply_common_transforms( - root, lift_mode=kwargs.get("lift_mode"), offset_provider=kwargs["offset_provider"] - ) - generated_code = super().apply(transformed, **kwargs) - try: - from yasi import indent_code - - indented = indent_code(generated_code, "--dialect lisp") - return "".join(indented["indented_code"]) - except ImportError: - return generated_code - - -@program_formatter.program_formatter -def format_lisp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: - return ToLispLike.apply(program, **kwargs) diff --git a/src/gt4py/next/program_processors/formatters/pretty_print.py b/src/gt4py/next/program_processors/formatters/pretty_print.py index f14ac5653f..cbf9fd1978 100644 --- a/src/gt4py/next/program_processors/formatters/pretty_print.py +++ b/src/gt4py/next/program_processors/formatters/pretty_print.py @@ -15,7 +15,7 @@ @program_formatter.program_formatter -def 
format_itir_and_check(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: +def format_itir_and_check(program: itir.Program, *args: Any, **kwargs: Any) -> str: pretty = pretty_printer.pformat(program) parsed = pretty_parser.pparse(pretty) assert parsed == program diff --git a/src/gt4py/next/program_processors/program_formatter.py b/src/gt4py/next/program_processors/program_formatter.py index f77e7f32ee..321c09668c 100644 --- a/src/gt4py/next/program_processors/program_formatter.py +++ b/src/gt4py/next/program_processors/program_formatter.py @@ -10,7 +10,7 @@ Interface for program processors. Program processors are functions which operate on a program paired with the input -arguments for the program. Programs are represented by an ``iterator.ir.itir.FencilDefinition`` +arguments for the program. Programs are represented by an ``iterator.ir.Program`` node. Program processors that execute the program with the given arguments (possibly by generating code along the way) are program executors. Those that generate any kind of string based on the program and (optionally) input values are program formatters. @@ -30,14 +30,14 @@ class ProgramFormatter: @abc.abstractmethod - def __call__(self, program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: ... + def __call__(self, program: itir.Program, *args: Any, **kwargs: Any) -> str: ... @dataclasses.dataclass(frozen=True) class WrappedProgramFormatter(ProgramFormatter): formatter: Callable[..., str] - def __call__(self, program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: + def __call__(self, program: itir.Program, *args: Any, **kwargs: Any) -> str: return self.formatter(program, *args, **kwargs) @@ -47,7 +47,7 @@ def program_formatter(func: Callable[..., str]) -> ProgramFormatter: Examples: >>> @program_formatter - ... def format_foo(fencil: itir.FencilDefinition, *args, **kwargs) -> str: + ... def format_foo(fencil: itir.Program, *args, **kwargs) -> str: ... 
'''A very useless fencil formatter.''' ... return "foo" diff --git a/src/gt4py/next/program_processors/runners/dace.py b/src/gt4py/next/program_processors/runners/dace.py deleted file mode 100644 index 2db8e98804..0000000000 --- a/src/gt4py/next/program_processors/runners/dace.py +++ /dev/null @@ -1,56 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import factory - -from gt4py.next import allocators as next_allocators, backend -from gt4py.next.ffront import foast_to_gtir, past_to_itir -from gt4py.next.program_processors.runners.dace_fieldview import workflow as dace_fieldview_workflow -from gt4py.next.program_processors.runners.dace_iterator import workflow as dace_iterator_workflow -from gt4py.next.program_processors.runners.gtfn import GTFNBackendFactory - - -class DaCeIteratorBackendFactory(GTFNBackendFactory): - class Params: - otf_workflow = factory.SubFactory( - dace_iterator_workflow.DaCeWorkflowFactory, - device_type=factory.SelfAttribute("..device_type"), - use_field_canonical_representation=factory.SelfAttribute( - "..use_field_canonical_representation" - ), - ) - auto_optimize = factory.Trait( - otf_workflow__translation__auto_optimize=True, name_temps="_opt" - ) - use_field_canonical_representation: bool = False - - name = factory.LazyAttribute( - lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}" - ) - - transforms = backend.DEFAULT_TRANSFORMS - - -run_dace_cpu = DaCeIteratorBackendFactory(cached=True, auto_optimize=True) -run_dace_cpu_noopt = DaCeIteratorBackendFactory(cached=True, auto_optimize=False) - -run_dace_gpu = DaCeIteratorBackendFactory(gpu=True, cached=True, auto_optimize=True) -run_dace_gpu_noopt = DaCeIteratorBackendFactory(gpu=True, cached=True, auto_optimize=False) - -itir_cpu = run_dace_cpu -itir_gpu = run_dace_gpu - -gtir_cpu = 
backend.Backend( - name="dace.gtir.cpu", - executor=dace_fieldview_workflow.DaCeWorkflowFactory(), - allocator=next_allocators.StandardCPUFieldBufferAllocator(), - transforms=backend.Transforms( - past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True), - foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(cached=True), - ), -) diff --git a/src/gt4py/next/program_processors/runners/dace/__init__.py b/src/gt4py/next/program_processors/runners/dace/__init__.py new file mode 100644 index 0000000000..8540585494 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/__init__.py @@ -0,0 +1,27 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + + +from gt4py.next.program_processors.runners.dace.gtir_sdfg import build_sdfg_from_gtir +from gt4py.next.program_processors.runners.dace.sdfg_callable import get_sdfg_args +from gt4py.next.program_processors.runners.dace.workflow.backend import ( + run_dace_cpu, + run_dace_cpu_noopt, + run_dace_gpu, + run_dace_gpu_noopt, +) + + +__all__ = [ + "build_sdfg_from_gtir", + "get_sdfg_args", + "run_dace_cpu", + "run_dace_cpu_noopt", + "run_dace_gpu", + "run_dace_gpu_noopt", +] diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py new file mode 100644 index 0000000000..6b2a32c063 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/gtir_builtin_translators.py @@ -0,0 +1,954 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import abc +import dataclasses +from typing import TYPE_CHECKING, Any, Final, Iterable, Optional, Protocol, Sequence, TypeAlias + +import dace +from dace import subsets as dace_subsets + +from gt4py.next import common as gtx_common, utils as gtx_utils +from gt4py.next.ffront import fbuiltins as gtx_fbuiltins +from gt4py.next.iterator import ir as gtir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) +from gt4py.next.program_processors.runners.dace import ( + gtir_dataflow, + gtir_python_codegen, + gtir_sdfg, + gtir_sdfg_utils, + utils as gtx_dace_utils, +) +from gt4py.next.program_processors.runners.dace.gtir_scan_translator import translate_scan +from gt4py.next.type_system import type_info as ti, type_specifications as ts + + +if TYPE_CHECKING: + from gt4py.next.program_processors.runners.dace import gtir_sdfg + + +def get_domain_indices( + dims: Sequence[gtx_common.Dimension], origin: Optional[Sequence[dace.symbolic.SymExpr]] +) -> dace_subsets.Indices: + """ + Helper function to construct the list of indices for a field domain, applying + an optional origin in each dimension as start index. + + Args: + dims: The field dimensions. + origin: The domain start index in each dimension. If set to `None`, assume all zeros. + + Returns: + A list of indices for field access in dace arrays. As this list is returned + as `dace.subsets.Indices`, it should be converted to `dace.subsets.Range` before + being used in memlet subset because ranges are better supported throughout DaCe. 
+ """ + index_variables = [ + dace.symbolic.pystr_to_symbolic(gtir_sdfg_utils.get_map_variable(dim)) for dim in dims + ] + origin = [0] * len(index_variables) if origin is None else origin + return dace_subsets.Indices( + [index - start_index for index, start_index in zip(index_variables, origin, strict=True)] + ) + + +@dataclasses.dataclass(frozen=True) +class FieldopData: + """ + Abstraction to represent data (scalars, arrays) during the lowering to SDFG. + + Attribute 'local_offset' must always be set for `FieldType` data with a local + dimension generated from neighbors access in unstructured domain, and indicates + the name of the offset provider used to generate the list of neighbor values. + + Args: + dc_node: DaCe access node to the data storage. + gt_type: GT4Py type definition, which includes the field domain information. + origin: Tuple of start indices, in each dimension, for `FieldType` data. + Pass an empty tuple for `ScalarType` data or zero-dimensional fields. + """ + + dc_node: dace.nodes.AccessNode + gt_type: ts.FieldType | ts.ScalarType + origin: tuple[dace.symbolic.SymbolicType, ...] + + def __post_init__(self) -> None: + """Implements a sanity check on the constructed data type.""" + assert ( + len(self.origin) == 0 + if isinstance(self.gt_type, ts.ScalarType) + else len(self.origin) == len(self.gt_type.dims) + ) + + def map_to_parent_sdfg( + self, + sdfg_builder: gtir_sdfg.SDFGBuilder, + inner_sdfg: dace.SDFG, + outer_sdfg: dace.SDFG, + outer_sdfg_state: dace.SDFGState, + symbol_mapping: dict[str, dace.symbolic.SymbolicType], + ) -> FieldopData: + """ + Make the data descriptor which 'self' refers to, and which is located inside + a NestedSDFG, available in its parent SDFG. + + Thus, it turns 'self' into a non-transient array and creates a new data + descriptor inside the parent SDFG, with same shape and strides. 
+ """ + inner_desc = self.dc_node.desc(inner_sdfg) + assert inner_desc.transient + inner_desc.transient = False + + if isinstance(self.gt_type, ts.ScalarType): + outer, outer_desc = sdfg_builder.add_temp_scalar(outer_sdfg, inner_desc.dtype) + outer_origin = [] + else: + outer, outer_desc = sdfg_builder.add_temp_array_like(outer_sdfg, inner_desc) + # We cannot use a copy of the inner data descriptor directly, we have to apply the symbol mapping. + dace.symbolic.safe_replace( + symbol_mapping, + lambda m: dace.sdfg.replace_properties_dict(outer_desc, m), + ) + # Same applies to the symbols used as field origin (the domain range start) + outer_origin = [ + gtx_dace_utils.safe_replace_symbolic(val, symbol_mapping) for val in self.origin + ] + + outer_node = outer_sdfg_state.add_access(outer) + return FieldopData(outer_node, self.gt_type, tuple(outer_origin)) + + def get_local_view( + self, domain: FieldopDomain + ) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: + """Helper method to access a field in local view, given the compute domain of a field operator.""" + if isinstance(self.gt_type, ts.ScalarType): + return gtir_dataflow.MemletExpr( + dc_node=self.dc_node, + gt_dtype=self.gt_type, + subset=dace_subsets.Range.from_string("0"), + ) + + if isinstance(self.gt_type, ts.FieldType): + domain_dims = [dim for dim, _, _ in domain] + domain_indices = get_domain_indices(domain_dims, origin=None) + it_indices: dict[gtx_common.Dimension, gtir_dataflow.DataExpr] = { + dim: gtir_dataflow.SymbolExpr(index, INDEX_DTYPE) + for dim, index in zip(domain_dims, domain_indices) + } + field_origin = [ + (dim, dace.symbolic.SymExpr(0) if self.origin is None else self.origin[i]) + for i, dim in enumerate(self.gt_type.dims) + ] + # The property below is ensured by calling `make_field()` to construct `FieldopData`. 
+ # The `make_field` constructor ensures that any local dimension, if present, is converted + # to `ListType` element type, while the field domain consists of all global dimensions. + assert all(dim != gtx_common.DimensionKind.LOCAL for dim in self.gt_type.dims) + return gtir_dataflow.IteratorExpr( + self.dc_node, self.gt_type.dtype, field_origin, it_indices + ) + + raise NotImplementedError(f"Node type {type(self.gt_type)} not supported.") + + def get_symbol_mapping( + self, dataname: str, sdfg: dace.SDFG + ) -> dict[str, dace.symbolic.SymExpr]: + """ + Helper method to create the symbol mapping for array storage in a nested SDFG. + + Args: + dataname: Name of the data container insiode the nested SDFG. + sdfg: The parent SDFG where the `FieldopData` object lives. + + Returns: + Mapping from symbols in nested SDFG to the corresponding symbolic values + in the parent SDFG. This includes the range start and stop symbols (used + to calculate the array shape as range 'stop - start') and the strides. 
+ """ + if isinstance(self.gt_type, ts.ScalarType): + return {} + ndims = len(self.gt_type.dims) + outer_desc = self.dc_node.desc(sdfg) + assert isinstance(outer_desc, dace.data.Array) + # origin and size of the local dimension, in case of a field with `ListType` data, + # are assumed to be compiled-time values (not symbolic), therefore the start and + # stop range symbols of the inner field only extend over the global dimensions + return ( + {gtx_dace_utils.range_start_symbol(dataname, i): (self.origin[i]) for i in range(ndims)} + | { + gtx_dace_utils.range_stop_symbol(dataname, i): ( + self.origin[i] + outer_desc.shape[i] + ) + for i in range(ndims) + } + | { + gtx_dace_utils.field_stride_symbol_name(dataname, i): stride + for i, stride in enumerate(outer_desc.strides) + } + ) + + +FieldopDomain: TypeAlias = list[ + tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] +] +""" +Domain of a field operator represented as a list of tuples with 3 elements: + - dimension definition + - symbolic expression for lower bound (inclusive) + - symbolic expression for upper bound (exclusive) +""" + + +FieldopResult: TypeAlias = FieldopData | tuple[FieldopData | tuple, ...] +"""Result of a field operator, can be either a field or a tuple fields.""" + + +INDEX_DTYPE: Final[dace.typeclass] = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType) +"""Data type used for field indexing.""" + + +def get_arg_symbol_mapping( + dataname: str, arg: FieldopResult, sdfg: dace.SDFG +) -> dict[str, dace.symbolic.SymExpr]: + """ + Helper method to build the mapping from inner to outer SDFG of all symbols + used for storage of a field or a tuple of fields. + + Args: + dataname: The storage name inside the nested SDFG. + arg: The argument field in the parent SDFG. + sdfg: The parent SDFG where the argument field lives. + + Returns: + A mapping from inner symbol names to values or symbolic definitions + in the parent SDFG. 
+ """ + if isinstance(arg, FieldopData): + return arg.get_symbol_mapping(dataname, sdfg) + + symbol_mapping: dict[str, dace.symbolic.SymExpr] = {} + for i, elem in enumerate(arg): + dataname_elem = f"{dataname}_{i}" + symbol_mapping |= get_arg_symbol_mapping(dataname_elem, elem, sdfg) + + return symbol_mapping + + +def get_tuple_type(data: tuple[FieldopResult, ...]) -> ts.TupleType: + """ + Compute the `ts.TupleType` corresponding to the tuple structure of `FieldopResult`. + """ + return ts.TupleType( + types=[get_tuple_type(d) if isinstance(d, tuple) else d.gt_type for d in data] + ) + + +def flatten_tuples(name: str, arg: FieldopResult) -> list[tuple[str, FieldopData]]: + """ + Visit a `FieldopResult`, potentially containing nested tuples, and construct a list + of pairs `(str, FieldopData)` containing the symbol name of each tuple field and + the corresponding `FieldopData`. + """ + if isinstance(arg, tuple): + tuple_type = get_tuple_type(arg) + tuple_symbols = gtir_sdfg_utils.flatten_tuple_fields(name, tuple_type) + tuple_data_fields = gtx_utils.flatten_nested_tuple(arg) + return [ + (str(tsym.id), tfield) + for tsym, tfield in zip(tuple_symbols, tuple_data_fields, strict=True) + ] + else: + return [(name, arg)] + + +class PrimitiveTranslator(Protocol): + @abc.abstractmethod + def __call__( + self, + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, + ) -> FieldopResult: + """Creates the dataflow subgraph representing a GTIR primitive function. + + This method is used by derived classes to build a specialized subgraph + for a specific GTIR primitive function. + + Args: + node: The GTIR node describing the primitive to be lowered + sdfg: The SDFG where the primitive subgraph should be instantiated + state: The SDFG state where the result of the primitive function should be made available + sdfg_builder: The object responsible for visiting child nodes of the primitive node. 
+ + Returns: + A list of data access nodes and the associated GT4Py data type, which provide + access to the result of the primitive subgraph. The GT4Py data type is useful + in the case the returned data is an array, because the type provdes the domain + information (e.g. order of dimensions, dimension types). + """ + + +def _parse_fieldop_arg( + node: gtir.Expr, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, + domain: FieldopDomain, +) -> ( + gtir_dataflow.IteratorExpr + | gtir_dataflow.MemletExpr + | tuple[gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr | tuple[Any, ...], ...] +): + """Helper method to visit an expression passed as argument to a field operator.""" + + arg = sdfg_builder.visit(node, sdfg=sdfg, head_state=state) + + if isinstance(arg, FieldopData): + return arg.get_local_view(domain) + else: + # handle tuples of fields + return gtx_utils.tree_map(lambda targ: targ.get_local_view(domain))(arg) + + +def get_field_layout( + domain: FieldopDomain, +) -> tuple[list[gtx_common.Dimension], list[dace.symbolic.SymExpr], list[dace.symbolic.SymExpr]]: + """ + Parse the field operator domain and generates the shape of the result field. + + It should be enough to allocate an array with shape (upper_bound - lower_bound) + but this would require to use array offset for compensate for the start index. + Suppose that a field operator executes on domain [2,N-2], the dace array to store + the result only needs size (N-4), but this would require to compensate all array + accesses with offset -2 (which corresponds to -lower_bound). Instead, we choose + to allocate (N-2), leaving positions [0:2] unused. The reason is that array offset + is known to cause issues to SDFG inlining. Besides, map fusion will in any case + eliminate most of transient arrays. + + Args: + domain: The field operator domain. 
+ + Returns: + A tuple of three lists containing: + - the domain dimensions + - the domain origin, that is the start indices in all dimensions + - the domain size in each dimension + """ + domain_dims, domain_lbs, domain_ubs = zip(*domain) + domain_sizes = [(ub - lb) for lb, ub in zip(domain_lbs, domain_ubs)] + return list(domain_dims), list(domain_lbs), domain_sizes + + +def _create_field_operator_impl( + sdfg_builder: gtir_sdfg.SDFGBuilder, + sdfg: dace.SDFG, + state: dace.SDFGState, + domain: FieldopDomain, + output_edge: gtir_dataflow.DataflowOutputEdge, + output_type: ts.FieldType, + map_exit: dace.nodes.MapExit, +) -> FieldopData: + """ + Helper method to allocate a temporary array that stores one field computed + by a field operator. + + This method is called by `_create_field_operator()`. + + Args: + sdfg_builder: The object used to build the map scope in the provided SDFG. + sdfg: The SDFG that represents the scope of the field data. + state: The SDFG state where to create an access node to the field data. + domain: The domain of the field operator that computes the field. + output_edge: The dataflow write edge representing the output data. + output_type: The GT4Py field type descriptor. + map_exit: The `MapExit` node of the field operator map scope. + + Returns: + The field data descriptor, which includes the field access node in the + given `state` and the field domain offset. + """ + dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) + + # the memory layout of the output field follows the field operator compute domain + field_dims, field_origin, field_shape = get_field_layout(domain) + field_indices = get_domain_indices(field_dims, field_origin) + field_subset = dace_subsets.Range.from_indices(field_indices) + + if isinstance(output_edge.result.gt_dtype, ts.ScalarType): + if output_edge.result.gt_dtype != output_type.dtype: + raise TypeError( + f"Type mismatch, expected {output_type.dtype} got {output_edge.result.gt_dtype}." 
+ ) + assert isinstance(dataflow_output_desc, dace.data.Scalar) + else: + assert isinstance(output_type.dtype, ts.ListType) + assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) + if output_edge.result.gt_dtype.element_type != output_type.dtype.element_type: + raise TypeError( + f"Type mismatch, expected {output_type.dtype.element_type} got {output_edge.result.gt_dtype.element_type}." + ) + assert isinstance(dataflow_output_desc, dace.data.Array) + assert len(dataflow_output_desc.shape) == 1 + # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) + assert output_edge.result.gt_dtype.offset_type is not None + field_shape = [*field_shape, dataflow_output_desc.shape[0]] + field_subset = field_subset + dace_subsets.Range.from_array(dataflow_output_desc) + + # allocate local temporary storage + field_name, _ = sdfg_builder.add_temp_array(sdfg, field_shape, dataflow_output_desc.dtype) + field_node = state.add_access(field_name) + + # and here the edge writing the dataflow result data through the map exit node + output_edge.connect(map_exit, field_node, field_subset) + + return FieldopData( + field_node, ts.FieldType(field_dims, output_edge.result.gt_dtype), tuple(field_origin) + ) + + +def _create_field_operator( + sdfg: dace.SDFG, + state: dace.SDFGState, + domain: FieldopDomain, + node_type: ts.FieldType | ts.TupleType, + sdfg_builder: gtir_sdfg.SDFGBuilder, + input_edges: Iterable[gtir_dataflow.DataflowInputEdge], + output_tree: tuple[gtir_dataflow.DataflowOutputEdge | tuple[Any, ...], ...], +) -> FieldopResult: + """ + Helper method to build the output of a field operator, which can consist of + a single field or a tuple of fields. + + A tuple of fields is returned when one stencil computes a grid point on multiple + fields: for each field, this method will call `_create_field_operator_impl()`. + + Args: + sdfg: The SDFG that represents the scope of the field data. 
+ state: The SDFG state where to create an access node to the field data. + domain: The domain of the field operator that computes the field. + node_type: The GT4Py type of the IR node that produces this field. + sdfg_builder: The object used to build the map scope in the provided SDFG. + input_edges: List of edges to pass input data into the dataflow. + output_tree: A tree representation of the dataflow output data. + + Returns: + The descriptor of the field operator result, which can be either a single + field or a tuple fields. + """ + + # create map range corresponding to the field operator domain + map_entry, map_exit = sdfg_builder.add_map( + "fieldop", + state, + ndrange={ + gtir_sdfg_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" + for dim, lower_bound, upper_bound in domain + }, + ) + + # here we setup the edges passing through the map entry node + for edge in input_edges: + edge.connect(map_entry) + + if isinstance(node_type, ts.FieldType): + assert len(output_tree) == 1 and isinstance( + output_tree[0], gtir_dataflow.DataflowOutputEdge + ) + output_edge = output_tree[0] + return _create_field_operator_impl( + sdfg_builder, sdfg, state, domain, output_edge, node_type, map_exit + ) + else: + # handle tuples of fields + output_symbol_tree = gtir_sdfg_utils.make_symbol_tree("x", node_type) + return gtx_utils.tree_map( + lambda output_edge, output_sym: _create_field_operator_impl( + sdfg_builder, sdfg, state, domain, output_edge, output_sym.type, map_exit + ) + )(output_tree, output_symbol_tree) + + +def extract_domain(node: gtir.Node) -> FieldopDomain: + """ + Visits the domain of a field operator and returns a list of dimensions and + the corresponding lower and upper bounds. 
The returned lower bound is inclusive, + the upper bound is exclusive: [lower_bound, upper_bound[ + """ + + domain = [] + + def parse_range_boundary(expr: gtir.Expr) -> str: + return dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(expr)) + + if cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")): + for named_range in node.args: + assert cpm.is_call_to(named_range, "named_range") + assert len(named_range.args) == 3 + axis = named_range.args[0] + assert isinstance(axis, gtir.AxisLiteral) + lower_bound, upper_bound = (parse_range_boundary(arg) for arg in named_range.args[1:3]) + dim = gtx_common.Dimension(axis.value, axis.kind) + domain.append((dim, lower_bound, upper_bound)) + + elif isinstance(node, domain_utils.SymbolicDomain): + assert str(node.grid_type) in {"cartesian_domain", "unstructured_domain"} + for dim, drange in node.ranges.items(): + domain.append( + (dim, parse_range_boundary(drange.start), parse_range_boundary(drange.stop)) + ) + + else: + raise ValueError(f"Invalid domain {node}.") + + return domain + + +def translate_as_fieldop( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> FieldopResult: + """ + Generates the dataflow subgraph for the `as_fieldop` builtin function. + + Expects a `FunCall` node with two arguments: + 1. a lambda function representing the stencil, which is lowered to a dataflow subgraph + 2. the domain of the field operator, which is used as map range + + The dataflow can be as simple as a single tasklet, or implement a local computation + as a composition of tasklets and even include a map to range on local dimensions (e.g. + neighbors and map builtins). + The stencil dataflow is instantiated inside a map scope, which applies the stencil + over the field domain. 
+ """ + assert isinstance(node, gtir.FunCall) + assert cpm.is_call_to(node.fun, "as_fieldop") + assert isinstance(node.type, (ts.FieldType, ts.TupleType)) + + fun_node = node.fun + assert len(fun_node.args) == 2 + fieldop_expr, domain_expr = fun_node.args + + if cpm.is_call_to(fieldop_expr, "scan"): + return translate_scan(node, sdfg, state, sdfg_builder) + + if cpm.is_ref_to(fieldop_expr, "deref"): + # Special usage of 'deref' as argument to fieldop expression, to pass a scalar + # value to 'as_fieldop' function. It results in broadcasting the scalar value + # over the field domain. + assert isinstance(node.type, ts.FieldType) + stencil_expr = im.lambda_("a")(im.deref("a")) + stencil_expr.expr.type = node.type.dtype + elif isinstance(fieldop_expr, gtir.Lambda): + # Default case, handled below: the argument expression is a lambda function + # representing the stencil operation to be computed over the field domain. + stencil_expr = fieldop_expr + else: + raise NotImplementedError( + f"Expression type '{type(fieldop_expr)}' not supported as argument to 'as_fieldop' node." 
+ ) + + # parse the domain of the field operator + domain = extract_domain(domain_expr) + + # visit the list of arguments to be passed to the lambda expression + fieldop_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] + + # represent the field operator as a mapped tasklet graph, which will range over the field domain + input_edges, output_edges = gtir_dataflow.translate_lambda_to_dataflow( + sdfg, state, sdfg_builder, stencil_expr, fieldop_args + ) + + return _create_field_operator( + sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edges + ) + + +def translate_if( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> FieldopResult: + """Generates the dataflow subgraph for the `if_` builtin function.""" + assert cpm.is_call_to(node, "if_") + assert len(node.args) == 3 + cond_expr, true_expr, false_expr = node.args + + # expect condition as first argument + if_stmt = gtir_python_codegen.get_source(cond_expr) + + # use current head state to terminate the dataflow, and add a entry state + # to connect the true/false branch states as follows: + # + # ------------ + # === | cond | === + # || ------------ || + # \/ \/ + # ------------ ------------- + # | true | | false | + # ------------ ------------- + # || || + # || ------------ || + # ==> | head | <== + # ------------ + # + cond_state = sdfg.add_state_before(state, state.label + "_cond") + sdfg.remove_edge(sdfg.out_edges(cond_state)[0]) + + # expect true branch as second argument + true_state = sdfg.add_state(state.label + "_true_branch") + sdfg.add_edge(cond_state, true_state, dace.InterstateEdge(condition=f"{if_stmt}")) + sdfg.add_edge(true_state, state, dace.InterstateEdge()) + + # and false branch as third argument + false_state = sdfg.add_state(state.label + "_false_branch") + sdfg.add_edge(cond_state, false_state, dace.InterstateEdge(condition=(f"not ({if_stmt})"))) + sdfg.add_edge(false_state, state, 
dace.InterstateEdge()) + + true_br_args = sdfg_builder.visit( + true_expr, + sdfg=sdfg, + head_state=true_state, + ) + false_br_args = sdfg_builder.visit( + false_expr, + sdfg=sdfg, + head_state=false_state, + ) + + def construct_output(inner_data: FieldopData) -> FieldopData: + inner_desc = inner_data.dc_node.desc(sdfg) + outer, _ = sdfg_builder.add_temp_array_like(sdfg, inner_desc) + outer_node = state.add_access(outer) + + return FieldopData(outer_node, inner_data.gt_type, inner_data.origin) + + result_temps = gtx_utils.tree_map(construct_output)(true_br_args) + + fields: Iterable[tuple[FieldopData, FieldopData, FieldopData]] = zip( + gtx_utils.flatten_nested_tuple((true_br_args,)), + gtx_utils.flatten_nested_tuple((false_br_args,)), + gtx_utils.flatten_nested_tuple((result_temps,)), + strict=True, + ) + + for true_br, false_br, temp in fields: + if true_br.gt_type != false_br.gt_type: + raise ValueError( + f"Different type of result fields on if-branches '{true_br.gt_type}' vs '{false_br.gt_type}'." + ) + true_br_node = true_br.dc_node + false_br_node = false_br.dc_node + + temp_name = temp.dc_node.data + true_br_output_node = true_state.add_access(temp_name) + true_state.add_nedge( + true_br_node, + true_br_output_node, + sdfg.make_array_memlet(temp_name), + ) + + false_br_output_node = false_state.add_access(temp_name) + false_state.add_nedge( + false_br_node, + false_br_output_node, + sdfg.make_array_memlet(temp_name), + ) + + return result_temps + + +def translate_index( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> FieldopResult: + """ + Lowers the `index` builtin function to a mapped tasklet that writes the dimension + index values to a transient array. The extent of the index range is taken from + the domain information that should be present in the node annex. 
+ """ + assert cpm.is_call_to(node, "index") + assert isinstance(node.type, ts.FieldType) + + assert "domain" in node.annex + domain = extract_domain(node.annex.domain) + assert len(domain) == 1 + dim, _, _ = domain[0] + dim_index = gtir_sdfg_utils.get_map_variable(dim) + + index_data, _ = sdfg_builder.add_temp_scalar(sdfg, INDEX_DTYPE) + index_node = state.add_access(index_data) + index_value = gtir_dataflow.ValueExpr( + dc_node=index_node, + gt_dtype=gtx_dace_utils.as_itir_type(INDEX_DTYPE), + ) + index_write_tasklet = sdfg_builder.add_tasklet( + "index", + state, + inputs={}, + outputs={"__val"}, + code=f"__val = {dim_index}", + ) + state.add_edge( + index_write_tasklet, + "__val", + index_node, + None, + dace.Memlet(data=index_data, subset="0"), + ) + + input_edges = [ + gtir_dataflow.EmptyInputEdge(state, index_write_tasklet), + ] + output_edge = gtir_dataflow.DataflowOutputEdge(state, index_value) + return _create_field_operator( + sdfg, state, domain, node.type, sdfg_builder, input_edges, (output_edge,) + ) + + +def _get_data_nodes( + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, + data_name: str, + data_type: ts.DataType, +) -> FieldopResult: + if isinstance(data_type, ts.FieldType): + data_node = state.add_access(data_name) + return sdfg_builder.make_field(data_node, data_type) + + elif isinstance(data_type, ts.ScalarType): + if data_name in sdfg.symbols: + data_node = _get_symbolic_value( + sdfg, state, sdfg_builder, data_name, data_type, temp_name=f"__{data_name}" + ) + else: + data_node = state.add_access(data_name) + return sdfg_builder.make_field(data_node, data_type) + + elif isinstance(data_type, ts.TupleType): + symbol_tree = gtir_sdfg_utils.make_symbol_tree(data_name, data_type) + return gtx_utils.tree_map( + lambda sym: _get_data_nodes(sdfg, state, sdfg_builder, sym.id, sym.type) + )(symbol_tree) + + else: + raise NotImplementedError(f"Symbol type {type(data_type)} not supported.") + + +def _get_symbolic_value( 
+ sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, + symbolic_expr: dace.symbolic.SymExpr, + scalar_type: ts.ScalarType, + temp_name: Optional[str] = None, +) -> dace.nodes.AccessNode: + tasklet_node = sdfg_builder.add_tasklet( + "get_value", + state, + {}, + {"__out"}, + f"__out = {symbolic_expr}", + ) + temp_name, _ = sdfg.add_scalar( + temp_name or sdfg.temp_data_name(), + gtx_dace_utils.as_dace_type(scalar_type), + find_new_name=True, + transient=True, + ) + data_node = state.add_access(temp_name) + state.add_edge( + tasklet_node, + "__out", + data_node, + None, + dace.Memlet(data=temp_name, subset="0"), + ) + return data_node + + +def translate_literal( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> FieldopResult: + """Generates the dataflow subgraph for a `ir.Literal` node.""" + assert isinstance(node, gtir.Literal) + + data_type = node.type + data_node = _get_symbolic_value(sdfg, state, sdfg_builder, node.value, data_type) + + return FieldopData(data_node, data_type, origin=()) + + +def translate_make_tuple( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> FieldopResult: + assert cpm.is_call_to(node, "make_tuple") + return tuple( + sdfg_builder.visit( + arg, + sdfg=sdfg, + head_state=state, + ) + for arg in node.args + ) + + +def translate_tuple_get( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> FieldopResult: + assert cpm.is_call_to(node, "tuple_get") + assert len(node.args) == 2 + + if not isinstance(node.args[0], gtir.Literal): + raise ValueError("Tuple can only be subscripted with compile-time constants.") + assert ti.is_integral(node.args[0].type) + index = int(node.args[0].value) + + data_nodes = sdfg_builder.visit( + node.args[1], + sdfg=sdfg, + head_state=state, + ) + if isinstance(data_nodes, FieldopData): + raise ValueError(f"Invalid 
tuple expression {node}") + unused_arg_nodes: Iterable[FieldopData] = gtx_utils.flatten_nested_tuple( + tuple(arg for i, arg in enumerate(data_nodes) if i != index) + ) + state.remove_nodes_from( + [arg.dc_node for arg in unused_arg_nodes if state.degree(arg.dc_node) == 0] + ) + return data_nodes[index] + + +def translate_scalar_expr( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> FieldopResult: + assert isinstance(node, gtir.FunCall) + assert isinstance(node.type, ts.ScalarType) + + args = [] + connectors = [] + scalar_expr_args = [] + + for i, arg_expr in enumerate(node.args): + visit_expr = True + if isinstance(arg_expr, gtir.SymRef): + try: + # check if symbol is defined in the GT4Py program, throws `KeyError` exception if undefined + sdfg_builder.get_symbol_type(arg_expr.id) + except KeyError: + # all `SymRef` should refer to symbols defined in the program, except in case of non-variable argument, + # e.g. the type name `float64` used in casting expressions like `cast_(variable, float64)` + visit_expr = False + + if visit_expr: + # we visit the argument expression and obtain the access node to + # a scalar data container, which will be connected to the tasklet + arg = sdfg_builder.visit( + arg_expr, + sdfg=sdfg, + head_state=state, + ) + if not (isinstance(arg, FieldopData) and isinstance(node.type, ts.ScalarType)): + raise ValueError(f"Invalid argument to scalar expression {arg_expr}.") + param = f"__arg{i}" + args.append(arg.dc_node) + connectors.append(param) + scalar_expr_args.append(gtir.SymRef(id=param)) + else: + assert isinstance(arg_expr, gtir.SymRef) + scalar_expr_args.append(arg_expr) + + # we visit the scalar expression replacing the input arguments with the corresponding data connectors + scalar_node = gtir.FunCall(fun=node.fun, args=scalar_expr_args) + python_code = gtir_python_codegen.get_source(scalar_node) + tasklet_node = sdfg_builder.add_tasklet( + name="scalar_expr", + 
state=state, + inputs=set(connectors), + outputs={"__out"}, + code=f"__out = {python_code}", + ) + # create edges for the input data connectors + for arg_node, conn in zip(args, connectors, strict=True): + state.add_edge( + arg_node, + None, + tasklet_node, + conn, + dace.Memlet(data=arg_node.data, subset="0"), + ) + # finally, create temporary for the result value + temp_name, _ = sdfg_builder.add_temp_scalar(sdfg, gtx_dace_utils.as_dace_type(node.type)) + temp_node = state.add_access(temp_name) + state.add_edge( + tasklet_node, + "__out", + temp_node, + None, + dace.Memlet(data=temp_name, subset="0"), + ) + + return FieldopData(temp_node, node.type, origin=()) + + +def translate_symbol_ref( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> FieldopResult: + """Generates the dataflow subgraph for a `ir.SymRef` node.""" + assert isinstance(node, gtir.SymRef) + + symbol_name = str(node.id) + # we retrieve the type of the symbol in the GT4Py prgram + gt_symbol_type = sdfg_builder.get_symbol_type(symbol_name) + + # Create new access node in current state. It is possible that multiple + # access nodes are created in one state for the same data container. + # We rely on the dace simplify pass to remove duplicated access nodes. 
+ return _get_data_nodes(sdfg, state, sdfg_builder, symbol_name, gt_symbol_type) + + +if TYPE_CHECKING: + # Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol + __primitive_translators: list[PrimitiveTranslator] = [ + translate_as_fieldop, + translate_if, + translate_index, + translate_literal, + translate_make_tuple, + translate_tuple_get, + translate_scalar_expr, + translate_scan, + translate_symbol_ref, + ] diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py new file mode 100644 index 0000000000..43e7c6354d --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py @@ -0,0 +1,1902 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import abc +import dataclasses +from typing import ( + Any, + Dict, + Final, + Iterable, + List, + Optional, + Protocol, + Sequence, + Set, + Tuple, + TypeAlias, + Union, +) + +import dace +from dace import subsets as dace_subsets + +from gt4py import eve +from gt4py.next import common as gtx_common, utils as gtx_utils +from gt4py.next.iterator import builtins as gtir_builtins, ir as gtir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im +from gt4py.next.iterator.transforms import symbol_ref_utils +from gt4py.next.program_processors.runners.dace import ( + gtir_python_codegen, + gtir_sdfg, + gtir_sdfg_utils, + utils as gtx_dace_utils, +) +from gt4py.next.type_system import type_info as ti, type_specifications as ts + + +# Magic local dimension for the result of a `make_const_list`. +# A clean implementation will probably involve to tag the `make_const_list` +# with the neighborhood it is meant to be used with. 
+_CONST_DIM = gtx_common.Dimension(value="_CONST_DIM", kind=gtx_common.DimensionKind.LOCAL) + + +@dataclasses.dataclass(frozen=True) +class ValueExpr: + """ + Local storage for the values returned by dataflow computation. + + This type is used in the context in a dataflow, that is a stencil expression. + Therefore, it contains either a scalar value (single elements in the fields) or + a list of values in a local dimension. + This is different from `gtir_builtin_translators.FieldopData` which represents + the result of a field operator, basically the data storage outside a global map. + + Args: + dc_node: Access node to the data container, can be either a scalar or a local list. + gt_dtype: GT4Py data type, which includes the `offset_type` local dimension for lists. + """ + + dc_node: dace.nodes.AccessNode + gt_dtype: ts.ListType | ts.ScalarType + + +@dataclasses.dataclass(frozen=True) +class MemletExpr: + """ + Scalar or array data access through a memlet. + + Args: + dc_node: Access node to the data container, can be either a scalar or a local list. + gt_dtype: GT4Py data type, which includes the `offset_type` local dimension for lists. + subset: Represents the subset to use in memlet to access the above data. + """ + + dc_node: dace.nodes.AccessNode + gt_dtype: ts.ListType | ts.ScalarType + subset: dace_subsets.Range + + +@dataclasses.dataclass(frozen=True) +class SymbolExpr: + """Any symbolic expression that is constant in the context of current SDFG.""" + + value: dace.symbolic.SymExpr + dc_dtype: dace.typeclass + + +DataExpr: TypeAlias = ValueExpr | MemletExpr | SymbolExpr + + +@dataclasses.dataclass(frozen=True) +class IteratorExpr: + """ + Iterator for field access to be consumed by `deref` or `shift` builtin functions. + + Args: + field: Access node to the field this iterator operates on. + gt_dtype: GT4Py data type, which includes the `offset_type` local dimension for lists. 
+ field_domain: Field domain represented as a sorted list of dimensions and offset values, + used to find the position of a map index variable in the memlet subset. The offset + value is either the start index of dimension range or the compile-time value of + a shift expression, or a composition of both, and it must be subtracted to the index + variable when constructing the memlet subset range. + indices: Maps each dimension to an index value, which could be either a symbolic value + or the result of a tasklet computation like neighbors connectivity or dynamic offset. + """ + + field: dace.nodes.AccessNode + gt_dtype: ts.ListType | ts.ScalarType + field_domain: list[tuple[gtx_common.Dimension, dace.symbolic.SymbolicType]] + indices: dict[gtx_common.Dimension, DataExpr] + + def get_field_type(self) -> ts.FieldType: + return ts.FieldType([dim for dim, _ in self.field_domain], self.gt_dtype) + + def get_memlet_subset(self, sdfg: dace.SDFG) -> dace_subsets.Range: + if not all(isinstance(self.indices[dim], SymbolExpr) for dim, _ in self.field_domain): + raise ValueError(f"Cannot deref iterator {self}.") + + field_desc = self.field.desc(sdfg) + if isinstance(self.gt_dtype, ts.ListType): + assert len(field_desc.shape) == len(self.field_domain) + 1 + assert self.gt_dtype.offset_type is not None + field_domain = [*self.field_domain, (self.gt_dtype.offset_type, 0)] + else: + assert len(field_desc.shape) == len(self.field_domain) + field_domain = self.field_domain + + return dace_subsets.Range.from_string( + ",".join( + str(self.indices[dim].value - offset) # type: ignore[union-attr] + if dim in self.indices + else f"0:{size}" + for (dim, offset), size in zip(field_domain, field_desc.shape, strict=True) + ) + ) + + +class DataflowInputEdge(Protocol): + """ + This protocol describes how to concretize a data edge to read data from a source node + into the dataflow. + + It provides the `connect` method to setup an input edge from an external data source. 
+ The most common case is that the dataflow represents a stencil, which is instantied + inside a map scope and whose inputs and outputs are connected to external data nodes + by means of memlets that traverse the map entry and exit nodes. + The dataflow can also be instatiated without a map, in which case the `map_entry` + argument is set to `None`. + """ + + @abc.abstractmethod + def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: ... + + +@dataclasses.dataclass(frozen=True) +class MemletInputEdge(DataflowInputEdge): + """ + Allows to setup an input memlet through a map entry node. + + The edge source has to be a data access node, while the destination node can either + be a tasklet, in which case the connector name is also required, or an access node. + """ + + state: dace.SDFGState + source: dace.nodes.AccessNode + subset: dace_subsets.Range + dest: dace.nodes.AccessNode | dace.nodes.Tasklet + dest_conn: Optional[str] + + def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: + memlet = dace.Memlet(data=self.source.data, subset=self.subset) + if map_entry is None: + self.state.add_edge(self.source, None, self.dest, self.dest_conn, memlet) + else: + self.state.add_memlet_path( + self.source, + map_entry, + self.dest, + dst_conn=self.dest_conn, + memlet=memlet, + ) + + +@dataclasses.dataclass(frozen=True) +class EmptyInputEdge(DataflowInputEdge): + """ + Allows to setup an edge from a map entry node to a tasklet with no arguments. + + The reason behind this kind of connection is that all nodes inside a map scope + must have an in/out path that traverses the entry and exit nodes. 
+ """ + + state: dace.SDFGState + node: dace.nodes.Tasklet + + def connect(self, map_entry: Optional[dace.nodes.MapEntry]) -> None: + if map_entry is None: + # outside of a map scope it is possible to instantiate a tasklet node + # without input connectors + return + self.state.add_nedge(map_entry, self.node, dace.Memlet()) + + +@dataclasses.dataclass(frozen=True) +class DataflowOutputEdge: + """ + Allows to setup an output memlet through a map exit node. + + The result of a dataflow subgraph needs to be written to an external data node. + The most common case is that the dataflow represents a stencil and the dataflow + is computed over a field domain, therefore the dataflow is instatiated inside + a map scope. The `connect` method creates a memlet that writes the dataflow + result to the external array passing through the `map_exit` node. + The dataflow can also be instatiated without a map, in which case the `map_exit` + argument is set to `None`. + """ + + state: dace.SDFGState + result: ValueExpr + + def connect( + self, + map_exit: Optional[dace.nodes.MapExit], + dest: dace.nodes.AccessNode, + subset: dace_subsets.Range, + ) -> None: + write_edge = self.state.in_edges(self.result.dc_node)[0] + write_size = write_edge.data.dst_subset.num_elements() + # check the kind of node which writes the result + if isinstance(write_edge.src, dace.nodes.Tasklet): + # The temporary data written by a tasklet can be safely deleted + assert write_size.is_constant() + remove_last_node = True + elif isinstance(write_edge.src, dace.nodes.NestedSDFG): + if write_size.is_constant(): + # Temporary data with compile-time size is allocated on the stack + # and therefore is safe to keep. We decide to keep it as a workaround + # for a dace issue with memlet propagation in combination with + # nested SDFGs containing conditional blocks. 
The output memlet + # of such blocks will be marked as dynamic because dace is not able + # to detect the exact size of a conditional branch dataflow, even + # in case of if-else expressions with exact same output data. + remove_last_node = False + else: + # In case the output data has runtime size it is necessary to remove + # it in order to avoid dynamic memory allocation inside a parallel + # map scope. Otherwise, the memory allocation will for sure lead + # to performance degradation, and eventually illegal memory issues + # when the gpu runs out of local memory. + remove_last_node = True + else: + remove_last_node = False + + if remove_last_node: + last_node = write_edge.src + last_node_connector = write_edge.src_conn + self.state.remove_node(self.result.dc_node) + else: + last_node = self.result.dc_node + last_node_connector = None + + if map_exit is None: + self.state.add_edge( + last_node, + last_node_connector, + dest, + None, + dace.Memlet(data=dest.data, subset=subset), + ) + else: + self.state.add_memlet_path( + last_node, + map_exit, + dest, + src_conn=last_node_connector, + memlet=dace.Memlet(data=dest.data, subset=subset), + ) + + +DACE_REDUCTION_MAPPING: dict[str, dace.dtypes.ReductionType] = { + "minimum": dace.dtypes.ReductionType.Min, + "maximum": dace.dtypes.ReductionType.Max, + "plus": dace.dtypes.ReductionType.Sum, + "multiplies": dace.dtypes.ReductionType.Product, + "and_": dace.dtypes.ReductionType.Logical_And, + "or_": dace.dtypes.ReductionType.Logical_Or, + "xor_": dace.dtypes.ReductionType.Logical_Xor, + "minus": dace.dtypes.ReductionType.Sub, + "divides": dace.dtypes.ReductionType.Div, +} + + +def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: + assert isinstance(node.type, ts.ScalarType) + dc_dtype = gtx_dace_utils.as_dace_type(node.type) + + assert isinstance(node.fun, gtir.FunCall) + assert len(node.fun.args) == 2 + assert isinstance(node.fun.args[0], gtir.SymRef) + op_name = str(node.fun.args[0]) + 
assert isinstance(node.fun.args[1], gtir.Literal) + assert node.fun.args[1].type == node.type + reduce_init = SymbolExpr(node.fun.args[1].value, dc_dtype) + + if op_name not in DACE_REDUCTION_MAPPING: + raise RuntimeError(f"Reduction operation '{op_name}' not supported.") + identity_value = dace.dtypes.reduction_identity(dc_dtype, DACE_REDUCTION_MAPPING[op_name]) + reduce_identity = SymbolExpr(identity_value, dc_dtype) + + return op_name, reduce_init, reduce_identity + + +def get_tuple_type( + data: tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...], +) -> ts.TupleType: + """ + Compute the `ts.TupleType` corresponding to the tuple structure of input data expressions. + """ + data_types: list[ts.DataType] = [] + for dataitem in data: + if isinstance(dataitem, tuple): + data_types.append(get_tuple_type(dataitem)) + elif isinstance(dataitem, IteratorExpr): + data_types.append(dataitem.get_field_type()) + elif isinstance(dataitem, MemletExpr): + data_types.append(dataitem.gt_dtype) + else: + data_types.append(dataitem.gt_dtype) + return ts.TupleType(data_types) + + +@dataclasses.dataclass(frozen=True) +class LambdaToDataflow(eve.NodeVisitor): + """ + Visitor class to translate a `Lambda` expression to a dataflow graph. + + This visitor should be applied by calling `apply()` method on a `Lambda` IR. + The dataflow graph generated here typically represents the stencil function + of a field operator. It only computes single elements or pure local fields, + in case of neighbor values. In case of local fields, the dataflow contains + inner maps with fixed literal size (max number of neighbors). + Once the lambda expression has been lowered to a dataflow, the dataflow graph + needs to be instantiated, that is we have to connect all in/out edges to + external source/destination data nodes. Since the lambda expression is used + in GTIR as argument to a field operator, the dataflow is instatiated inside + a map scope and applied on the field domain. 
Therefore, all in/out edges + must traverse the entry/exit map nodes. + """ + + sdfg: dace.SDFG + state: dace.SDFGState + subgraph_builder: gtir_sdfg.DataflowBuilder + input_edges: list[DataflowInputEdge] = dataclasses.field(default_factory=lambda: []) + symbol_map: dict[ + str, + IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...], + ] = dataclasses.field(default_factory=dict) + + def _add_input_data_edge( + self, + src: dace.nodes.AccessNode, + src_subset: dace_subsets.Range, + dst_node: dace.nodes.Node, + dst_conn: Optional[str] = None, + src_offset: Optional[list[dace.symbolic.SymExpr]] = None, + ) -> None: + input_subset = ( + src_subset + if src_offset is None + else dace_subsets.Range( + (start - off, stop - off, step) + for (start, stop, step), off in zip(src_subset, src_offset, strict=True) + ) + ) + edge = MemletInputEdge(self.state, src, input_subset, dst_node, dst_conn) + self.input_edges.append(edge) + + def _add_edge( + self, + src_node: dace.Node, + src_node_connector: Optional[str], + dst_node: dace.Node, + dst_node_connector: Optional[str], + memlet: dace.Memlet, + ) -> None: + """Helper method to add an edge in current state.""" + self.state.add_edge(src_node, src_node_connector, dst_node, dst_node_connector, memlet) + + def _add_map( + self, + name: str, + ndrange: Union[ + Dict[str, Union[str, dace.subsets.Subset]], + List[Tuple[str, Union[str, dace.subsets.Subset]]], + ], + **kwargs: Any, + ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]: + """ + Helper method to add a map in current state. + + The subgraph builder ensures that the map receives a unique name, + by adding a unique suffix to the provided name. 
+ """ + return self.subgraph_builder.add_map(name, self.state, ndrange, **kwargs) + + def _add_tasklet( + self, + name: str, + inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + code: str, + **kwargs: Any, + ) -> dace.nodes.Tasklet: + """ + Helper method to add a tasklet in current state. + + The subgraph builder ensures that the tasklet receives a unique name, + by adding a unique suffix to the provided name. + """ + tasklet_node = self.subgraph_builder.add_tasklet( + name, self.state, inputs, outputs, code, **kwargs + ) + if len(inputs) == 0: + # All nodes inside a map scope must have an in/out path that traverses + # the entry and exit nodes. Therefore, a tasklet node with no arguments + # still needs an (empty) input edge from map entry node. + edge = EmptyInputEdge(self.state, tasklet_node) + self.input_edges.append(edge) + return tasklet_node + + def _add_mapped_tasklet( + self, + name: str, + map_ranges: Dict[str, str | dace.subsets.Subset] + | List[Tuple[str, str | dace.subsets.Subset]], + inputs: Dict[str, dace.Memlet], + code: str, + outputs: Dict[str, dace.Memlet], + **kwargs: Any, + ) -> tuple[dace.nodes.Tasklet, dace.nodes.MapEntry, dace.nodes.MapExit]: + """ + Helper method to add a mapped tasklet in current state. + + The subgraph builder ensures that the tasklet receives a unique name, + by adding a unique suffix to the provided name. 
+ """ + return self.subgraph_builder.add_mapped_tasklet( + name, self.state, map_ranges, inputs, code, outputs, **kwargs + ) + + def unique_nsdfg_name(self, prefix: str) -> str: + """Utility function to generate a unique name for a nested SDFG, starting with the given prefix.""" + return self.subgraph_builder.unique_nsdfg_name(self.sdfg, prefix) + + def _construct_local_view(self, field: MemletExpr | ValueExpr) -> ValueExpr: + if isinstance(field, MemletExpr): + desc = field.dc_node.desc(self.sdfg) + local_dim_indices = [i for i, size in enumerate(field.subset.size()) if size != 1] + if len(local_dim_indices) == 0: + # we are accessing a single-element array with shape (1,) + view_shape = (1,) + view_strides = (1,) + else: + view_shape = tuple(desc.shape[i] for i in local_dim_indices) + view_strides = tuple(desc.strides[i] for i in local_dim_indices) + view, _ = self.sdfg.add_view( + f"view_{field.dc_node.data}", + view_shape, + desc.dtype, + strides=view_strides, + find_new_name=True, + ) + local_view_node = self.state.add_access(view) + self._add_input_data_edge(field.dc_node, field.subset, local_view_node) + + return ValueExpr(local_view_node, desc.dtype) + + else: + return field + + def _construct_tasklet_result( + self, + dc_dtype: dace.typeclass, + src_node: dace.nodes.Tasklet, + src_connector: str, + use_array: bool = False, + ) -> ValueExpr: + data_type = gtx_dace_utils.as_itir_type(dc_dtype) + if use_array: + # In some cases, such as result data with list-type annotation, we want + # that output data is represented as an array (single-element 1D array) + # in order to allow for composition of array shape in external memlets. 
+ temp_name, _ = self.subgraph_builder.add_temp_array(self.sdfg, (1,), dc_dtype) + else: + temp_name, _ = self.subgraph_builder.add_temp_scalar(self.sdfg, dc_dtype) + + temp_node = self.state.add_access(temp_name) + self._add_edge( + src_node, + src_connector, + temp_node, + None, + dace.Memlet(data=temp_name, subset="0"), + ) + return ValueExpr( + dc_node=temp_node, + gt_dtype=( + ts.ListType(element_type=data_type, offset_type=_CONST_DIM) + if use_array + else data_type + ), + ) + + def _visit_deref(self, node: gtir.FunCall) -> DataExpr: + """ + Visit a `deref` node, which represents dereferencing of an iterator. + The iterator is the argument of this node. + + The iterator contains the information for accessing a field, that is the + sorted list of dimensions in the field domain and the index values for + each dimension. The index values can be either symbol values, that is + literal values or scalar arguments which are constant in the SDFG scope; + or they can be the result of some expression, that computes a dynamic + index offset or gets an neighbor index from a connectivity table. + In case all indexes are symbol values, the `deref` node is lowered to a + memlet; otherwise dereferencing is a runtime operation represented in + the SDFG as a tasklet node. 
+ """ + # format used for field index tasklet connector + IndexConnectorFmt: Final = "__index_{dim}" + + if isinstance(node.type, ts.TupleType): + raise NotImplementedError("Tuple deref not supported.") + + assert len(node.args) == 1 + arg_expr = self.visit(node.args[0]) + + if not isinstance(arg_expr, IteratorExpr): + # dereferencing a scalar or a literal node results in the node itself + return arg_expr + + field_desc = arg_expr.field.desc(self.sdfg) + if isinstance(field_desc, dace.data.Scalar): + # deref a zero-dimensional field + assert len(arg_expr.field_domain) == 0 + assert isinstance(node.type, ts.ScalarType) + return MemletExpr(arg_expr.field, arg_expr.gt_dtype, subset="0") + + # handle default case below: deref a field with one or more dimensions + + # when the indices are all dace symbolic expressions, the deref is lowered + # to a memlet, where the index is the memlet subset + if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()): + # when all indices are symbolic expressions, we can perform direct field access through a memlet + field_subset = arg_expr.get_memlet_subset(self.sdfg) + return MemletExpr(arg_expr.field, arg_expr.gt_dtype, field_subset) + + # when any of the indices is a runtime value (either a dynamic cartesian + # offset or a connectivity offset), the deref is lowered to a tasklet + assert all(dim in arg_expr.indices for dim, _ in arg_expr.field_domain) + assert len(field_desc.shape) == len(arg_expr.field_domain) + field_indices = [(dim, arg_expr.indices[dim]) for dim, _ in arg_expr.field_domain] + index_connectors = [ + IndexConnectorFmt.format(dim=dim.value) + for dim, index in field_indices + if not isinstance(index, SymbolExpr) + ] + # here `internals` refer to the names used as index in the tasklet code string: + # an index can be either a connector name (for dynamic/indirect indices) + # or a symbol value (for literal values and scalar arguments). 
+ index_internals = ",".join( + str(index.value) + if isinstance(index, SymbolExpr) + else IndexConnectorFmt.format(dim=dim.value) + for dim, index in field_indices + ) + deref_node = self._add_tasklet( + "deref", + {"field"} | set(index_connectors), + {"val"}, + code=f"val = field[{index_internals}]", + ) + # add new termination point for the field parameter + self._add_input_data_edge( + arg_expr.field, + dace_subsets.Range.from_array(field_desc), + deref_node, + "field", + src_offset=[offset for (_, offset) in arg_expr.field_domain], + ) + + for dim, index_expr in field_indices: + # add termination points for the dynamic iterator indices + deref_connector = IndexConnectorFmt.format(dim=dim.value) + if isinstance(index_expr, MemletExpr): + self._add_input_data_edge( + index_expr.dc_node, + index_expr.subset, + deref_node, + deref_connector, + ) + + elif isinstance(index_expr, ValueExpr): + self._add_edge( + index_expr.dc_node, + None, + deref_node, + deref_connector, + dace.Memlet(data=index_expr.dc_node.data, subset="0"), + ) + else: + assert isinstance(index_expr, SymbolExpr) + + return self._construct_tasklet_result(field_desc.dtype, deref_node, "val") + + def _visit_if_branch_arg( + self, + if_sdfg: dace.SDFG, + if_branch_state: dace.SDFGState, + param_name: str, + arg: IteratorExpr | DataExpr, + deref_on_input_memlet: bool, + if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr], + ) -> IteratorExpr | ValueExpr: + """ + Helper method to be called by `_visit_if_branch()` to visit the input arguments. + + Args: + if_sdfg: The nested SDFG where the if expression is lowered. + if_branch_state: The state inside the nested SDFG where the if branch is lowered. + param_name: The parameter name of the input argument. + arg: The input argument expression. + deref_on_input_memlet: When True, the given iterator argument can be dereferenced on the input memlet. 
+ if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function. + """ + use_full_shape = False + if isinstance(arg, (MemletExpr, ValueExpr)): + arg_desc = arg.dc_node.desc(self.sdfg) + arg_expr = arg + elif isinstance(arg, IteratorExpr): + arg_desc = arg.field.desc(self.sdfg) + if deref_on_input_memlet: + # If the iterator is just dereferenced inside the branch state, + # we can access the array outside the nested SDFG and pass the + # local data. This approach makes the data dependencies of nested + # structures more explicit and thus makes it easier for MapFusion + # to correctly infer the data dependencies. + memlet_subset = arg.get_memlet_subset(self.sdfg) + arg_expr = MemletExpr(arg.field, arg.gt_dtype, memlet_subset) + else: + # In order to shift the iterator inside the branch dataflow, + # we have to pass the full array shape. + arg_expr = MemletExpr( + arg.field, arg.gt_dtype, dace_subsets.Range.from_array(arg_desc) + ) + use_full_shape = True + else: + raise TypeError(f"Unexpected {arg} as input argument.") + + if use_full_shape: + inner_desc = arg_desc.clone() + inner_desc.transient = False + elif isinstance(arg.gt_dtype, ts.ScalarType): + inner_desc = dace.data.Scalar(arg_desc.dtype) + else: + # for list of values, we retrieve the local size from the corresponding offset + assert arg.gt_dtype.offset_type is not None + offset_provider_type = self.subgraph_builder.get_offset_provider_type( + arg.gt_dtype.offset_type.value + ) + assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) + inner_desc = dace.data.Array(arg_desc.dtype, [offset_provider_type.max_neighbors]) + + if param_name in if_sdfg.arrays: + # the data desciptor was added by the visitor of the other branch expression + assert if_sdfg.data(param_name) == inner_desc + else: + if_sdfg.add_datadesc(param_name, inner_desc) + if_sdfg_input_memlets[param_name] = arg_expr + + inner_node = if_branch_state.add_access(param_name) 
+ if isinstance(arg, IteratorExpr) and use_full_shape: + return IteratorExpr(inner_node, arg.gt_dtype, arg.field_domain, arg.indices) + else: + return ValueExpr(inner_node, arg.gt_dtype) + + def _visit_if_branch( + self, + if_sdfg: dace.SDFG, + if_branch_state: dace.SDFGState, + expr: gtir.Expr, + if_sdfg_input_memlets: dict[str, MemletExpr | ValueExpr], + direct_deref_iterators: Iterable[str], + ) -> tuple[ + list[DataflowInputEdge], + tuple[DataflowOutputEdge | tuple[Any, ...], ...], + ]: + """ + Helper method to visit an if-branch expression and lower it to a dataflow inside the given nested SDFG and state. + + This function is called by `_visit_if()` for each if-branch. + + Args: + if_sdfg: The nested SDFG where the if expression is lowered. + if_branch_state: The state inside the nested SDFG where the if branch is lowered. + expr: The if branch expression to lower. + if_sdfg_input_memlets: The memlets that provide input data to the nested SDFG, will be update inside this function. + direct_deref_iterators: Fields that are accessed with direct iterator deref, without any shift. + + Returns: + A tuple containing: + - the list of input edges for the parent dataflow + - the tree representation of output data, in the form of a tuple of data edges. 
+ """ + assert if_branch_state in if_sdfg.states() + + lambda_args = [] + lambda_params = [] + for pname in symbol_ref_utils.collect_symbol_refs(expr, self.symbol_map.keys()): + arg = self.symbol_map[pname] + if isinstance(arg, tuple): + ptype = get_tuple_type(arg) # type: ignore[arg-type] + psymbol = im.sym(pname, ptype) + psymbol_tree = gtir_sdfg_utils.make_symbol_tree(pname, ptype) + deref_on_input_memlet = pname in direct_deref_iterators + inner_arg = gtx_utils.tree_map( + lambda tsym, + targ, + deref_on_input_memlet=deref_on_input_memlet: self._visit_if_branch_arg( + if_sdfg, + if_branch_state, + tsym.id, + targ, + deref_on_input_memlet, + if_sdfg_input_memlets, + ) + )(psymbol_tree, arg) + else: + psymbol = im.sym(pname, arg.gt_dtype) # type: ignore[union-attr] + deref_on_input_memlet = pname in direct_deref_iterators + inner_arg = self._visit_if_branch_arg( + if_sdfg, + if_branch_state, + pname, + arg, + deref_on_input_memlet, + if_sdfg_input_memlets, + ) + lambda_args.append(inner_arg) + lambda_params.append(psymbol) + + # visit each branch of the if-statement as if it was a Lambda node + lambda_node = gtir.Lambda(params=lambda_params, expr=expr) + input_edges, output_tree = translate_lambda_to_dataflow( + if_sdfg, if_branch_state, self.subgraph_builder, lambda_node, lambda_args + ) + + for data_node in if_branch_state.data_nodes(): + # In case of tuple arguments, isolated access nodes might be left in the state, + # because not all tuple fields are necessarily used inside the lambda scope + if if_branch_state.degree(data_node) == 0: + assert not data_node.desc(if_sdfg).transient + if_branch_state.remove_node(data_node) + + return input_edges, output_tree + + def _visit_if_branch_result( + self, sdfg: dace.SDFG, state: dace.SDFGState, edge: DataflowOutputEdge, sym: gtir.Sym + ) -> ValueExpr: + """ + Helper function to be called by `_visit_if` to create an output connector + on the nested SDFG that will write the result to the parent SDFG. 
+ The result data inside the nested SDFG must have the same name as the connector. + """ + output_data = str(sym.id) + if output_data in sdfg.arrays: + output_desc = sdfg.data(output_data) + assert not output_desc.transient + else: + # If the result is currently written to a transient node, inside the nested SDFG, + # we need to allocate a non-transient data node. + result_desc = edge.result.dc_node.desc(sdfg) + output_desc = result_desc.clone() + output_desc.transient = False + output_data = sdfg.add_datadesc(output_data, output_desc, find_new_name=True) + output_node = state.add_access(output_data) + state.add_nedge( + edge.result.dc_node, + output_node, + dace.Memlet.from_array(output_data, output_desc), + ) + return ValueExpr(output_node, edge.result.gt_dtype) + + def _visit_if(self, node: gtir.FunCall) -> ValueExpr | tuple[ValueExpr | tuple[Any, ...], ...]: + """ + Lowers an if-expression with exclusive branch execution into a nested SDFG, + in which each branch is lowered into a dataflow in a separate state and + the if-condition is represented as the inter-state edge condition. 
+ """ + + def write_output_of_nested_sdfg_to_temporary(inner_value: ValueExpr) -> ValueExpr: + # Each output connector of the nested SDFG writes to a transient node in the parent SDFG + inner_data = inner_value.dc_node.data + inner_desc = inner_value.dc_node.desc(nsdfg) + assert not inner_desc.transient + output, output_desc = self.subgraph_builder.add_temp_array_like(self.sdfg, inner_desc) + output_node = self.state.add_access(output) + self.state.add_edge( + nsdfg_node, + inner_data, + output_node, + None, + dace.Memlet.from_array(output, output_desc), + ) + return ValueExpr(output_node, inner_value.gt_dtype) + + assert len(node.args) == 3 + + # evaluate the if-condition that will write to a boolean scalar node + condition_value = self.visit(node.args[0]) + assert ( + ( + isinstance(condition_value.gt_dtype, ts.ScalarType) + and condition_value.gt_dtype.kind == ts.ScalarKind.BOOL + ) + if isinstance(condition_value, (MemletExpr, ValueExpr)) + else (condition_value.dc_dtype == dace.dtypes.bool_) + ) + + nsdfg = dace.SDFG(self.unique_nsdfg_name(prefix="if_stmt")) + nsdfg.debuginfo = gtir_sdfg_utils.debug_info(node, default=self.sdfg.debuginfo) + + # create states inside the nested SDFG for the if-branches + if_region = dace.sdfg.state.ConditionalBlock("if") + nsdfg.add_node(if_region) + entry_state = nsdfg.add_state("entry", is_start_block=True) + nsdfg.add_edge(entry_state, if_region, dace.InterstateEdge()) + + then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=nsdfg) + tstate = then_body.add_state("true_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock("__cond"), then_body) + + else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=nsdfg) + fstate = else_body.add_state("false_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock("not (__cond)"), else_body) + + input_memlets: dict[str, MemletExpr | ValueExpr] = {} + nsdfg_symbols_mapping: Optional[dict[str, dace.symbol]] = None + + # 
define scalar or symbol for the condition value inside the nested SDFG + if isinstance(condition_value, SymbolExpr): + nsdfg.add_symbol("__cond", dace.dtypes.bool) + else: + nsdfg.add_scalar("__cond", dace.dtypes.bool) + input_memlets["__cond"] = condition_value + + # Collect all field iterators that are shifted inside any of the then/else + # branch expressions. Iterator shift expressions require the field argument + # as iterator, therefore the corresponding array has to be passed with full + # shape into the nested SDFG where the if_ expression is lowered. When the + # branch expression simply does `deref` on the iterator, without any shifting, + # it corresponds to a direct element access. Such `deref` expressions can + # be lowered outside the nested SDFG, so that just the local value (a scalar + # or a list of values) is passed as input to the nested SDFG. + shifted_iterator_symbols = set() + for branch_expr in node.args[1:3]: + for shift_node in eve.walk_values(branch_expr).filter( + lambda x: cpm.is_applied_shift(x) + ): + shifted_iterator_symbols |= ( + eve.walk_values(shift_node) + .if_isinstance(gtir.SymRef) + .map(lambda x: str(x.id)) + .filter(lambda x: isinstance(self.symbol_map.get(x, None), IteratorExpr)) + .to_set() + ) + iterator_symbols = { + sym_name + for sym_name, sym_type in self.symbol_map.items() + if isinstance(sym_type, IteratorExpr) + } + direct_deref_iterators = ( + set(symbol_ref_utils.collect_symbol_refs(node.args[1:3], iterator_symbols)) + - shifted_iterator_symbols + ) + + for nstate, arg in zip([tstate, fstate], node.args[1:3]): + # visit each if-branch in the corresponding state of the nested SDFG + in_edges, output_tree = self._visit_if_branch( + nsdfg, nstate, arg, input_memlets, direct_deref_iterators + ) + for edge in in_edges: + edge.connect(map_entry=None) + + if isinstance(node.type, ts.TupleType): + out_symbol_tree = gtir_sdfg_utils.make_symbol_tree("__output", node.type) + outer_value = gtx_utils.tree_map( + lambda x, y, 
nstate=nstate: self._visit_if_branch_result(nsdfg, nstate, x, y) + )(output_tree, out_symbol_tree) + else: + assert isinstance(node.type, ts.FieldType | ts.ScalarType) + assert len(output_tree) == 1 and isinstance(output_tree[0], DataflowOutputEdge) + output_edge = output_tree[0] + outer_value = self._visit_if_branch_result( + nsdfg, nstate, output_edge, im.sym("__output", node.type) + ) + # Isolated access node will make validation fail. + # Isolated access nodes can be found in `make_tuple` expressions that + # construct tuples from input arguments. + for data_node in nstate.data_nodes(): + if nstate.degree(data_node) == 0: + assert not data_node.desc(nsdfg).transient + nsdfg.remove_node(data_node) + else: + result = outer_value + + outputs = {outval.dc_node.data for outval in gtx_utils.flatten_nested_tuple((result,))} + + # all free symbols are mapped to the symbols available in parent SDFG + nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols} + if isinstance(condition_value, SymbolExpr): + nsdfg_symbols_mapping["__cond"] = condition_value.value + nsdfg_node = self.state.add_nested_sdfg( + nsdfg, + self.sdfg, + inputs=set(input_memlets.keys()), + outputs=outputs, + symbol_mapping=nsdfg_symbols_mapping, + ) + + for inner, input_expr in input_memlets.items(): + if isinstance(input_expr, MemletExpr): + self._add_input_data_edge(input_expr.dc_node, input_expr.subset, nsdfg_node, inner) + else: + self._add_edge( + input_expr.dc_node, + None, + nsdfg_node, + inner, + self.sdfg.make_array_memlet(input_expr.dc_node.data), + ) + + return ( + gtx_utils.tree_map(write_output_of_nested_sdfg_to_temporary)(result) + if isinstance(result, tuple) + else write_output_of_nested_sdfg_to_temporary(result) + ) + + def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: + assert isinstance(node.type, ts.ListType) + assert len(node.args) == 2 + + assert isinstance(node.args[0], gtir.OffsetLiteral) + offset = node.args[0].value + assert isinstance(offset, str) + 
        offset_provider = self.subgraph_builder.get_offset_provider_type(offset)
        assert isinstance(offset_provider, gtx_common.NeighborConnectivityType)

        it = self.visit(node.args[1])
        assert isinstance(it, IteratorExpr)
        assert any(dim == offset_provider.codomain for dim, _ in it.field_domain)
        assert offset_provider.source_dim in it.indices
        origin_index = it.indices[offset_provider.source_dim]
        assert isinstance(origin_index, SymbolExpr)
        assert all(isinstance(index, SymbolExpr) for index in it.indices.values())

        field_desc = it.field.desc(self.sdfg)
        connectivity = gtx_dace_utils.connectivity_identifier(offset)
        # initially, the storage for the connectivity tables is created as transient;
        # when the tables are used, the storage is changed to non-transient,
        # as the corresponding arrays are supposed to be allocated by the SDFG caller
        connectivity_desc = self.sdfg.arrays[connectivity]
        connectivity_desc.transient = False

        # The visitor is constructing a list of input connections that will be handled
        # by `translate_as_fieldop` (the primitive translator), that is responsible
        # of creating the map for the field domain. For each input connection, it will
        # create a memlet that will write to a node specified by the third attribute
        # in the `InputConnection` tuple (either a tasklet, or a view node, or a library
        # node). For the specific case of `neighbors` we need to nest the neighbors map
        # inside the field map and the memlets will traverse the external map and write
        # to the view nodes. The simplify pass will remove the redundant access nodes.
        field_slice = self._construct_local_view(
            MemletExpr(
                dc_node=it.field,
                gt_dtype=node.type,
                subset=dace_subsets.Range.from_string(
                    ",".join(
                        # NOTE: `offset` here is the per-dimension origin offset from
                        # `field_domain` (the tuple unpacking shadows the outer
                        # `offset` string); it translates the global index into
                        # array-local coordinates.
                        str(it.indices[dim].value - offset)  # type: ignore[union-attr]
                        if dim != offset_provider.codomain
                        else f"0:{size}"
                        for (dim, offset), size in zip(
                            it.field_domain, field_desc.shape, strict=True
                        )
                    )
                ),
            )
        )
        connectivity_slice = self._construct_local_view(
            MemletExpr(
                dc_node=self.state.add_access(connectivity),
                gt_dtype=node.type,
                subset=dace_subsets.Range.from_string(
                    f"{origin_index.value}, 0:{offset_provider.max_neighbors}"
                ),
            )
        )

        neighbors_temp, _ = self.subgraph_builder.add_temp_array(
            self.sdfg, (offset_provider.max_neighbors,), field_desc.dtype
        )
        neighbors_node = self.state.add_access(neighbors_temp)
        offset_type = gtx_common.Dimension(offset, gtx_common.DimensionKind.LOCAL)
        neighbor_idx = gtir_sdfg_utils.get_map_variable(offset_type)

        index_connector = "__index"
        output_connector = "__val"
        tasklet_expression = f"{output_connector} = __field[{index_connector}]"
        input_memlets = {
            "__field": self.sdfg.make_array_memlet(field_slice.dc_node.data),
            index_connector: dace.Memlet(data=connectivity_slice.dc_node.data, subset=neighbor_idx),
        }
        input_nodes = {
            field_slice.dc_node.data: field_slice.dc_node,
            connectivity_slice.dc_node.data: connectivity_slice.dc_node,
        }

        if offset_provider.has_skip_values:
            # in case of skip value we can write any dummy value
            skip_value = (
                "math.nan"
                if ti.is_floating_point(node.type.element_type)
                else str(dace.dtypes.max_value(field_desc.dtype))
            )
            tasklet_expression += (
                f" if {index_connector} != {gtx_common._DEFAULT_SKIP_VALUE} else {skip_value}"
            )

        self._add_mapped_tasklet(
            name=f"{offset}_neighbors",
            map_ranges={neighbor_idx: f"0:{offset_provider.max_neighbors}"},
            code=tasklet_expression,
            inputs=input_memlets,
            input_nodes=input_nodes,
            outputs={
                output_connector: dace.Memlet(data=neighbors_temp, subset=neighbor_idx),
            },
            output_nodes={neighbors_temp: neighbors_node},
            external_edges=True,
        )

        return ValueExpr(
            dc_node=neighbors_node, gt_dtype=ts.ListType(node.type.element_type, offset_type)
        )

    def _visit_list_get(self, node: gtir.FunCall) -> ValueExpr:
        """
        Lowers a `list_get(index, list)` expression: extracts one element from a
        1D local (list) value into a scalar temporary. A compile-time index is
        lowered to a direct memlet copy; a dynamic index goes through a tasklet.
        """
        assert len(node.args) == 2
        index_arg = self.visit(node.args[0])
        list_arg = self.visit(node.args[1])
        assert isinstance(list_arg, ValueExpr)
        assert isinstance(list_arg.gt_dtype, ts.ListType)
        assert isinstance(list_arg.gt_dtype.element_type, ts.ScalarType)

        list_desc = list_arg.dc_node.desc(self.sdfg)
        assert len(list_desc.shape) == 1

        result_dtype = gtx_dace_utils.as_dace_type(list_arg.gt_dtype.element_type)
        result, _ = self.subgraph_builder.add_temp_scalar(self.sdfg, result_dtype)
        result_node = self.state.add_access(result)

        if isinstance(index_arg, SymbolExpr):
            assert index_arg.dc_dtype in dace.dtypes.INTEGER_TYPES
            self._add_edge(
                list_arg.dc_node,
                None,
                result_node,
                None,
                dace.Memlet(data=list_arg.dc_node.data, subset=index_arg.value),
            )
        elif isinstance(index_arg, ValueExpr):
            tasklet_node = self._add_tasklet(
                "list_get", inputs={"index", "list"}, outputs={"value"}, code="value = list[index]"
            )
            self._add_edge(
                index_arg.dc_node,
                None,
                tasklet_node,
                "index",
                dace.Memlet(data=index_arg.dc_node.data, subset="0"),
            )
            self._add_edge(
                list_arg.dc_node,
                None,
                tasklet_node,
                "list",
                self.sdfg.make_array_memlet(list_arg.dc_node.data),
            )
            self._add_edge(
                tasklet_node, "value", result_node, None, dace.Memlet(data=result, subset="0")
            )
        else:
            raise TypeError(f"Unexpected value {index_arg} as index argument.")

        return ValueExpr(dc_node=result_node, gt_dtype=list_arg.gt_dtype.element_type)

    def _visit_map(self, node: gtir.FunCall) -> ValueExpr:
        """
        A map node defines an operation to be mapped on all elements of input arguments.

        The map operation is applied on the local dimension of input fields.
        In the example below, the local dimension consists of a list of neighbor
        values as the first argument, and a list of constant values `1.0`:
        `map_(plus)(neighbors(V2E, it), make_const_list(1.0))`

        The `plus` operation is lowered to a tasklet inside a map that computes
        the domain of the local dimension (in this example, max neighbors in V2E).

        The result is a 1D local field, with same size as the input local dimension.
        In above example, the result would be an array with size V2E.max_neighbors,
        containing the V2E neighbor values incremented by 1.0.
        """
        assert isinstance(node.type, ts.ListType)
        assert isinstance(node.fun, gtir.FunCall)
        assert len(node.fun.args) == 1  # the operation to be mapped on the arguments

        assert isinstance(node.type.element_type, ts.ScalarType)
        dc_dtype = gtx_dace_utils.as_dace_type(node.type.element_type)

        input_connectors = [f"__arg{i}" for i in range(len(node.args))]
        output_connector = "__out"

        # Here we build the body of the tasklet
        fun_node = im.call(node.fun.args[0])(*input_connectors)
        fun_python_code = gtir_python_codegen.get_source(fun_node)
        tasklet_expression = f"{output_connector} = {fun_python_code}"

        input_args = [self.visit(arg) for arg in node.args]
        # Collect the connectivity type of each non-constant input list, keyed by
        # its local (offset) dimension; used below to determine the local size.
        input_connectivity_types: dict[
            gtx_common.Dimension, gtx_common.NeighborConnectivityType
        ] = {}
        for input_arg in input_args:
            assert isinstance(input_arg.gt_dtype, ts.ListType)
            assert input_arg.gt_dtype.offset_type is not None
            offset_type = input_arg.gt_dtype.offset_type
            if offset_type == _CONST_DIM:
                # this input argument is the result of `make_const_list`
                continue
            offset_provider_t = self.subgraph_builder.get_offset_provider_type(offset_type.value)
            assert isinstance(offset_provider_t, gtx_common.NeighborConnectivityType)
            input_connectivity_types[offset_type] = offset_provider_t

        if len(input_connectivity_types) == 0:
            raise ValueError(f"Missing information on local dimension for map node {node}.")

        # GT4Py guarantees that all connectivities used to generate lists of neighbors
        # have the same length, that is the same value of 'max_neighbors'.
        if (
            len(
                set(
                    (conn.has_skip_values, conn.max_neighbors)
                    for conn in input_connectivity_types.values()
                )
            )
            != 1
        ):
            raise ValueError("Unexpected arguments to map expression with different neighborhood.")
        offset_type, offset_provider_type = next(iter(input_connectivity_types.items()))
        local_size = offset_provider_type.max_neighbors
        map_index = gtir_sdfg_utils.get_map_variable(offset_type)

        # The dataflow we build in this class has some loose connections on input edges.
        # These edges are described as set of nodes, that will have to be connected to
        # external data source nodes passing through the map entry node of the field map.
        # Similarly to `neighbors` expressions, the `map_` input edges terminate on view
        # nodes (see `_construct_local_view` in the for-loop below), because it is simpler
        # than representing map-to-map edges (which require memlets with 2 pass-nodes).
        input_memlets = {}
        input_nodes = {}
        for conn, input_arg in zip(input_connectors, input_args):
            input_node = self._construct_local_view(input_arg).dc_node
            input_desc = input_node.desc(self.sdfg)
            # we assume that there is a single local dimension
            if len(input_desc.shape) != 1:
                raise ValueError(f"More than one local dimension in map expression {node}.")
            input_size = input_desc.shape[0]
            if input_size == 1:
                # broadcast of a `make_const_list` scalar over the local dimension
                assert input_arg.gt_dtype.offset_type == _CONST_DIM
                input_memlets[conn] = dace.Memlet(data=input_node.data, subset="0")
            elif input_size == local_size:
                input_memlets[conn] = dace.Memlet(data=input_node.data, subset=map_index)
            else:
                raise ValueError(
                    f"Argument to map node with local size {input_size}, expected {local_size}."
                )
            input_nodes[input_node.data] = input_node

        result, _ = self.subgraph_builder.add_temp_array(self.sdfg, (local_size,), dc_dtype)
        result_node = self.state.add_access(result)

        if offset_provider_type.has_skip_values:
            # In case the `map_` input expressions contain skip values, we use
            # the connectivity-based offset provider as mask for map computation.
            connectivity = gtx_dace_utils.connectivity_identifier(offset_type.value)
            connectivity_desc = self.sdfg.arrays[connectivity]
            connectivity_desc.transient = False

            origin_map_index = gtir_sdfg_utils.get_map_variable(offset_provider_type.source_dim)

            connectivity_slice = self._construct_local_view(
                MemletExpr(
                    dc_node=self.state.add_access(connectivity),
                    gt_dtype=ts.ListType(
                        element_type=node.type.element_type, offset_type=offset_type
                    ),
                    subset=dace_subsets.Range.from_string(
                        f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}"
                    ),
                )
            )

            input_memlets["__neighbor_idx"] = dace.Memlet(
                data=connectivity_slice.dc_node.data, subset=map_index
            )
            input_nodes[connectivity_slice.dc_node.data] = connectivity_slice.dc_node

            # in case of skip value we can write any dummy value
            skip_value = (
                "math.nan"
                if ti.is_floating_point(node.type.element_type)
                else str(dace.dtypes.max_value(dc_dtype))
            )
            tasklet_expression += (
                f" if __neighbor_idx != {gtx_common._DEFAULT_SKIP_VALUE} else {skip_value}"
            )

        self._add_mapped_tasklet(
            name="map",
            map_ranges={map_index: f"0:{local_size}"},
            code=tasklet_expression,
            inputs=input_memlets,
            input_nodes=input_nodes,
            outputs={
                output_connector: dace.Memlet(data=result, subset=map_index),
            },
            output_nodes={result: result_node},
            external_edges=True,
        )

        return ValueExpr(
            dc_node=result_node,
            gt_dtype=ts.ListType(node.type.element_type, offset_type),
        )

    def _make_reduce_with_skip_values(
        self,
        input_expr: ValueExpr | MemletExpr,
        offset_provider_type: gtx_common.NeighborConnectivityType,
reduce_init: SymbolExpr, + reduce_identity: SymbolExpr, + reduce_wcr: str, + result_node: dace.nodes.AccessNode, + ) -> None: + """ + Helper method to lower reduction on a local field containing skip values. + + The reduction is implemented as a nested SDFG containing 2 states. In first + state, the result (a scalar data node passed as argumet) is initialized. + In second state, a mapped tasklet uses a write-conflict resolution (wcr) + memlet to update the result. + We use the offset provider as a mask to identify skip values: the value + that is written to the result node is either the input value, when the + corresponding neighbor index in the connectivity table is valid, or the + identity value if the neighbor index is missing. + """ + origin_map_index = gtir_sdfg_utils.get_map_variable(offset_provider_type.source_dim) + + assert ( + isinstance(input_expr.gt_dtype, ts.ListType) + and input_expr.gt_dtype.offset_type is not None + ) + offset_type = input_expr.gt_dtype.offset_type + connectivity = gtx_dace_utils.connectivity_identifier(offset_type.value) + connectivity_node = self.state.add_access(connectivity) + connectivity_desc = connectivity_node.desc(self.sdfg) + connectivity_desc.transient = False + + desc = input_expr.dc_node.desc(self.sdfg) + if isinstance(input_expr, MemletExpr): + local_dim_indices = [i for i, size in enumerate(input_expr.subset.size()) if size != 1] + else: + local_dim_indices = list(range(len(desc.shape))) + + if len(local_dim_indices) != 1: + raise NotImplementedError( + f"Found {len(local_dim_indices)} local dimensions in reduce expression, expected one." 
+ ) + local_dim_index = local_dim_indices[0] + assert desc.shape[local_dim_index] == offset_provider_type.max_neighbors + + # we lower the reduction map with WCR out memlet in a nested SDFG + nsdfg = dace.SDFG(name=self.unique_nsdfg_name("reduce_with_skip_values")) + nsdfg.add_array( + "values", + (desc.shape[local_dim_index],), + desc.dtype, + strides=(desc.strides[local_dim_index],), + ) + nsdfg.add_array( + "neighbor_indices", + (connectivity_desc.shape[1],), + connectivity_desc.dtype, + strides=(connectivity_desc.strides[1],), + ) + nsdfg.add_scalar("acc", desc.dtype) + st_init = nsdfg.add_state(f"{nsdfg.label}_init") + st_init.add_edge( + st_init.add_tasklet( + "init_acc", + {}, + {"__val"}, + f"__val = {reduce_init.dc_dtype}({reduce_init.value})", + ), + "__val", + st_init.add_access("acc"), + None, + dace.Memlet(data="acc", subset="0"), + ) + st_reduce = nsdfg.add_state_after(st_init, f"{nsdfg.label}_reduce") + # Fill skip values in local dimension with the reduce identity value + skip_value = f"{reduce_identity.dc_dtype}({reduce_identity.value})" + # Since this map operates on a pure local dimension, we explicitly set sequential + # schedule and we set the flag 'wcr_nonatomic=True' on the write memlet. + # TODO(phimuell): decide if auto-optimizer should reset `wcr_nonatomic` properties, as DaCe does. 
+ st_reduce.add_mapped_tasklet( + name="reduce_with_skip_values", + map_ranges={"i": f"0:{offset_provider_type.max_neighbors}"}, + inputs={ + "__val": dace.Memlet(data="values", subset="i"), + "__neighbor_idx": dace.Memlet(data="neighbor_indices", subset="i"), + }, + code=f"__out = __val if __neighbor_idx != {gtx_common._DEFAULT_SKIP_VALUE} else {skip_value}", + outputs={ + "__out": dace.Memlet(data="acc", subset="0", wcr=reduce_wcr, wcr_nonatomic=True), + }, + external_edges=True, + schedule=dace.dtypes.ScheduleType.Sequential, + ) + + nsdfg_node = self.state.add_nested_sdfg( + nsdfg, self.sdfg, inputs={"values", "neighbor_indices"}, outputs={"acc"} + ) + + if isinstance(input_expr, MemletExpr): + self._add_input_data_edge(input_expr.dc_node, input_expr.subset, nsdfg_node, "values") + else: + self.state.add_edge( + input_expr.dc_node, + None, + nsdfg_node, + "values", + self.sdfg.make_array_memlet(input_expr.dc_node.data), + ) + self._add_input_data_edge( + connectivity_node, + dace_subsets.Range.from_string( + f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" + ), + nsdfg_node, + "neighbor_indices", + ) + self.state.add_edge( + nsdfg_node, + "acc", + result_node, + None, + dace.Memlet(data=result_node.data, subset="0"), + ) + + def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: + assert isinstance(node.type, ts.ScalarType) + op_name, reduce_init, reduce_identity = get_reduce_params(node) + reduce_wcr = "lambda x, y: " + gtir_python_codegen.format_builtin(op_name, "x", "y") + + result, _ = self.subgraph_builder.add_temp_scalar(self.sdfg, reduce_identity.dc_dtype) + result_node = self.state.add_access(result) + + input_expr = self.visit(node.args[0]) + assert isinstance(input_expr, (MemletExpr, ValueExpr)) + assert ( + isinstance(input_expr.gt_dtype, ts.ListType) + and input_expr.gt_dtype.offset_type is not None + ) + offset_type = input_expr.gt_dtype.offset_type + offset_provider_type = 
self.subgraph_builder.get_offset_provider_type(offset_type.value) + assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) + + if offset_provider_type.has_skip_values: + self._make_reduce_with_skip_values( + input_expr, + offset_provider_type, + reduce_init, + reduce_identity, + reduce_wcr, + result_node, + ) + + else: + reduce_node = self.state.add_reduce(reduce_wcr, axes=None, identity=reduce_init.value) + if isinstance(input_expr, MemletExpr): + self._add_input_data_edge(input_expr.dc_node, input_expr.subset, reduce_node) + else: + self.state.add_nedge( + input_expr.dc_node, + reduce_node, + self.sdfg.make_array_memlet(input_expr.dc_node.data), + ) + self.state.add_nedge(reduce_node, result_node, dace.Memlet(data=result, subset="0")) + + return ValueExpr(result_node, node.type) + + def _split_shift_args( + self, args: list[gtir.Expr] + ) -> tuple[tuple[gtir.Expr, gtir.Expr], Optional[list[gtir.Expr]]]: + """ + Splits the arguments to `shift` builtin function as pairs, each pair containing + the offset provider and the offset expression in one dimension. 
+ """ + nargs = len(args) + assert nargs >= 2 and nargs % 2 == 0 + return (args[-2], args[-1]), args[: nargs - 2] if nargs > 2 else None + + def _visit_shift_multidim( + self, iterator: gtir.Expr, shift_args: list[gtir.Expr] + ) -> tuple[gtir.Expr, gtir.Expr, IteratorExpr]: + """Transforms a multi-dimensional shift into recursive shift calls, each in a single dimension.""" + (offset_provider_arg, offset_value_arg), tail = self._split_shift_args(shift_args) + if tail: + node = gtir.FunCall( + fun=gtir.FunCall(fun=gtir.SymRef(id="shift"), args=tail), + args=[iterator], + ) + it = self.visit(node) + else: + it = self.visit(iterator) + + assert isinstance(it, IteratorExpr) + return offset_provider_arg, offset_value_arg, it + + def _make_cartesian_shift( + self, it: IteratorExpr, offset_dim: gtx_common.Dimension, offset_expr: DataExpr + ) -> IteratorExpr: + """Implements cartesian shift along one dimension.""" + assert any(dim == offset_dim for dim, _ in it.field_domain) + new_index: SymbolExpr | ValueExpr + index_expr = it.indices[offset_dim] + if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr): + # purely symbolic expression which can be interpreted at compile time + new_index = SymbolExpr( + index_expr.value + offset_expr.value, + index_expr.dc_dtype, + ) + else: + # the offset needs to be calculated by means of a tasklet (i.e. 
dynamic offset) + new_index_connector = "shifted_index" + if isinstance(index_expr, SymbolExpr): + dynamic_offset_tasklet = self._add_tasklet( + "dynamic_offset", + {"offset"}, + {new_index_connector}, + f"{new_index_connector} = {index_expr.value} + offset", + ) + elif isinstance(offset_expr, SymbolExpr): + dynamic_offset_tasklet = self._add_tasklet( + "dynamic_offset", + {"index"}, + {new_index_connector}, + f"{new_index_connector} = index + {offset_expr}", + ) + else: + dynamic_offset_tasklet = self._add_tasklet( + "dynamic_offset", + {"index", "offset"}, + {new_index_connector}, + f"{new_index_connector} = index + offset", + ) + for input_expr, input_connector in [(index_expr, "index"), (offset_expr, "offset")]: + if isinstance(input_expr, MemletExpr): + self._add_input_data_edge( + input_expr.dc_node, + input_expr.subset, + dynamic_offset_tasklet, + input_connector, + ) + elif isinstance(input_expr, ValueExpr): + self._add_edge( + input_expr.dc_node, + None, + dynamic_offset_tasklet, + input_connector, + dace.Memlet(data=input_expr.dc_node.data, subset="0"), + ) + + if isinstance(index_expr, SymbolExpr): + dc_dtype = index_expr.dc_dtype + else: + dc_dtype = index_expr.dc_node.desc(self.sdfg).dtype + + new_index = self._construct_tasklet_result( + dc_dtype, dynamic_offset_tasklet, new_index_connector + ) + + # a new iterator with a shifted index along one dimension + shifted_indices = { + dim: (new_index if dim == offset_dim else index) for dim, index in it.indices.items() + } + return IteratorExpr(it.field, it.gt_dtype, it.field_domain, shifted_indices) + + def _make_dynamic_neighbor_offset( + self, + offset_expr: MemletExpr | ValueExpr, + offset_table_node: dace.nodes.AccessNode, + origin_index: SymbolExpr, + ) -> ValueExpr: + """ + Implements access to neighbor connectivity table by means of a tasklet node. + + It requires a dynamic offset value, either obtained from a field/scalar argument (`MemletExpr`) + or computed by another tasklet (`DataExpr`). 
+ """ + new_index_connector = "neighbor_index" + tasklet_node = self._add_tasklet( + "dynamic_neighbor_offset", + {"table", "offset"}, + {new_index_connector}, + f"{new_index_connector} = table[{origin_index.value}, offset]", + ) + self._add_input_data_edge( + offset_table_node, + dace_subsets.Range.from_array(offset_table_node.desc(self.sdfg)), + tasklet_node, + "table", + ) + if isinstance(offset_expr, MemletExpr): + self._add_input_data_edge( + offset_expr.dc_node, + offset_expr.subset, + tasklet_node, + "offset", + ) + else: + self._add_edge( + offset_expr.dc_node, + None, + tasklet_node, + "offset", + dace.Memlet(data=offset_expr.dc_node.data, subset="0"), + ) + + dc_dtype = offset_table_node.desc(self.sdfg).dtype + return self._construct_tasklet_result(dc_dtype, tasklet_node, new_index_connector) + + def _make_unstructured_shift( + self, + it: IteratorExpr, + connectivity: gtx_common.NeighborConnectivityType, + offset_table_node: dace.nodes.AccessNode, + offset_expr: DataExpr, + ) -> IteratorExpr: + """Implements shift in unstructured domain by means of a neighbor table.""" + assert any(dim == connectivity.codomain for dim, _ in it.field_domain) + neighbor_dim = connectivity.codomain + origin_dim = connectivity.source_dim + origin_index = it.indices[origin_dim] + assert isinstance(origin_index, SymbolExpr) + + shifted_indices = {dim: idx for dim, idx in it.indices.items() if dim != origin_dim} + if isinstance(offset_expr, SymbolExpr): + # use memlet to retrieve the neighbor index + shifted_indices[neighbor_dim] = MemletExpr( + dc_node=offset_table_node, + gt_dtype=it.gt_dtype, + subset=dace_subsets.Range.from_string(f"{origin_index.value}, {offset_expr.value}"), + ) + else: + # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node + shifted_indices[neighbor_dim] = self._make_dynamic_neighbor_offset( + offset_expr, offset_table_node, origin_index + ) + + return IteratorExpr(it.field, it.gt_dtype, it.field_domain, 
shifted_indices) + + def _visit_shift(self, node: gtir.FunCall) -> IteratorExpr: + # convert builtin-index type to dace type + IndexDType: Final = gtx_dace_utils.as_dace_type( + ts.ScalarType(kind=getattr(ts.ScalarKind, gtir_builtins.INTEGER_INDEX_BUILTIN.upper())) + ) + + assert isinstance(node.fun, gtir.FunCall) + # the iterator to be shifted is the node argument, while the shift arguments + # are provided by the nested function call; the shift arguments consist of + # the offset provider and the offset value in each dimension to be shifted + offset_provider_arg, offset_value_arg, it = self._visit_shift_multidim( + node.args[0], node.fun.args + ) + + # first argument of the shift node is the offset provider + assert isinstance(offset_provider_arg, gtir.OffsetLiteral) + offset = offset_provider_arg.value + assert isinstance(offset, str) + offset_provider_type = self.subgraph_builder.get_offset_provider_type(offset) + # second argument should be the offset value, which could be a symbolic expression or a dynamic offset + offset_expr = ( + SymbolExpr(offset_value_arg.value, IndexDType) + if isinstance(offset_value_arg, gtir.OffsetLiteral) + else self.visit(offset_value_arg) + ) + + if isinstance(offset_provider_type, gtx_common.Dimension): + return self._make_cartesian_shift(it, offset_provider_type, offset_expr) + else: + # initially, the storage for the connectivity tables is created as transient; + # when the tables are used, the storage is changed to non-transient, + # so the corresponding arrays are supposed to be allocated by the SDFG caller + offset_table = gtx_dace_utils.connectivity_identifier(offset) + self.sdfg.arrays[offset_table].transient = False + offset_table_node = self.state.add_access(offset_table) + + return self._make_unstructured_shift( + it, offset_provider_type, offset_table_node, offset_expr + ) + + def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: + """ + Generic handler called by `visit_FunCall()` when it encounters + a 
builtin function that does not match any other specific handler. + """ + node_internals = [] + node_connections: dict[str, MemletExpr | ValueExpr] = {} + for i, arg in enumerate(node.args): + arg_expr = self.visit(arg) + if isinstance(arg_expr, MemletExpr | ValueExpr): + # the argument value is the result of a tasklet node or direct field access + connector = f"__arg{i}" + node_connections[connector] = arg_expr + node_internals.append(connector) + else: + assert isinstance(arg_expr, SymbolExpr) + # use the argument value without adding any connector + node_internals.append(arg_expr.value) + + assert isinstance(node.fun, gtir.SymRef) + builtin_name = str(node.fun.id) + # use tasklet connectors as expression arguments + code = gtir_python_codegen.format_builtin(builtin_name, *node_internals) + + out_connector = "result" + tasklet_node = self._add_tasklet( + builtin_name, + set(node_connections.keys()), + {out_connector}, + "{} = {}".format(out_connector, code), + ) + + for connector, arg_expr in node_connections.items(): + if isinstance(arg_expr, ValueExpr): + self._add_edge( + arg_expr.dc_node, + None, + tasklet_node, + connector, + dace.Memlet(data=arg_expr.dc_node.data, subset="0"), + ) + else: + self._add_input_data_edge( + arg_expr.dc_node, + arg_expr.subset, + tasklet_node, + connector, + ) + + if isinstance(node.type, ts.ListType): + # The only builtin function (so far) handled here that returns a list + # is 'make_const_list'. There are other builtin functions (map_, neighbors) + # that return a list but they are handled in specialized visit methods. + # This method (the generic visitor for builtin functions) always returns + # a single value. This is also the case of 'make_const_list' expression: + # it simply broadcasts a scalar on the local domain of another expression, + # for example 'map_(plus)(neighbors(V2Eₒ, it), make_const_list(1.0))'. 
+ # Therefore we handle `ListType` as a single-element array with shape (1,) + # that will be accessed in a map expression on a local domain. + assert isinstance(node.type.element_type, ts.ScalarType) + dc_dtype = gtx_dace_utils.as_dace_type(node.type.element_type) + # In order to ease the lowring of the parent expression on local dimension, + # we represent the scalar value as a single-element 1D array. + use_array = True + else: + assert isinstance(node.type, ts.ScalarType) + dc_dtype = gtx_dace_utils.as_dace_type(node.type) + use_array = False + + return self._construct_tasklet_result(dc_dtype, tasklet_node, "result", use_array=use_array) + + def _visit_make_tuple(self, node: gtir.FunCall) -> tuple[IteratorExpr | DataExpr]: + assert cpm.is_call_to(node, "make_tuple") + return tuple(self.visit(arg) for arg in node.args) + + def _visit_tuple_get( + self, node: gtir.FunCall + ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr]: + assert cpm.is_call_to(node, "tuple_get") + assert len(node.args) == 2 + + if not isinstance(node.args[0], gtir.Literal): + raise ValueError("Tuple can only be subscripted with compile-time constants.") + assert ti.is_integral(node.args[0].type) + index = int(node.args[0].value) + + tuple_fields = self.visit(node.args[1]) + return tuple_fields[index] + + def visit_FunCall( + self, node: gtir.FunCall + ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]: + if cpm.is_call_to(node, "deref"): + return self._visit_deref(node) + + elif cpm.is_call_to(node, "if_"): + return self._visit_if(node) + + elif cpm.is_call_to(node, "neighbors"): + return self._visit_neighbors(node) + + elif cpm.is_call_to(node, "list_get"): + return self._visit_list_get(node) + + elif cpm.is_call_to(node, "make_tuple"): + return self._visit_make_tuple(node) + + elif cpm.is_call_to(node, "tuple_get"): + return self._visit_tuple_get(node) + + elif cpm.is_applied_map(node): + return self._visit_map(node) + + elif 
    def visit_Lambda(
        self, node: gtir.Lambda
    ) -> DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...]:
        """
        Lowers the body of a stencil lambda and normalizes each result to a
        `DataflowOutputEdge` (applied leaf-wise over tuple results).
        """

        def _visit_Lambda_impl(
            output_expr: DataflowOutputEdge | ValueExpr | MemletExpr | SymbolExpr,
        ) -> DataflowOutputEdge:
            # wrap one leaf result of the lambda body into an output edge
            if isinstance(output_expr, DataflowOutputEdge):
                return output_expr
            if isinstance(output_expr, ValueExpr):
                return DataflowOutputEdge(self.state, output_expr)

            if isinstance(output_expr, MemletExpr):
                # special case where the field operator is simply copying data from source to destination node
                output_dtype = output_expr.dc_node.desc(self.sdfg).dtype
                tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp")
                self._add_input_data_edge(
                    output_expr.dc_node,
                    output_expr.subset,
                    tasklet_node,
                    "__inp",
                )
            else:
                # even simpler case, where a constant value is written to destination node
                output_dtype = output_expr.dc_dtype
                tasklet_node = self._add_tasklet(
                    "write", {}, {"__out"}, f"__out = {output_expr.value}"
                )

            output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out")
            return DataflowOutputEdge(self.state, output_expr)

        result = self.visit(node.expr)

        return (
            gtx_utils.tree_map(_visit_Lambda_impl)(result)
            if isinstance(result, tuple)
            else _visit_Lambda_impl(result)
        )

    def visit_Literal(self, node: gtir.Literal) -> SymbolExpr:
        """Lowers a literal to a symbolic expression with its dace dtype."""
        dc_dtype = gtx_dace_utils.as_dace_type(node.type)
        return SymbolExpr(node.value, dc_dtype)
    def visit_SymRef(
        self, node: gtir.SymRef
    ) -> IteratorExpr | DataExpr | tuple[IteratorExpr | DataExpr | tuple[Any, ...], ...]:
        """
        Resolves a symbol reference, either from the lambda symbol map or as the
        name of a math builtin function.
        """
        param = str(node.id)
        if param in self.symbol_map:
            return self.symbol_map[param]
        # if not in the lambda symbol map, this must be a symref to a builtin function
        assert param in gtir_python_codegen.MATH_BUILTINS_MAPPING
        return SymbolExpr(param, dace.string)

    def visit_let(
        self,
        node: gtir.Lambda,
        args: Sequence[
            IteratorExpr
            | MemletExpr
            | ValueExpr
            | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...]
        ],
    ) -> DataflowOutputEdge | tuple[DataflowOutputEdge | tuple[Any, ...], ...]:
        """
        Maps lambda arguments to internal parameters.

        This method is responsible to recognize the usage of the `Lambda` node,
        which can be either a let-statement or the stencil expression in local view.
        The usage of a `Lambda` as let-statement corresponds to computing some results
        and making them available inside the lambda scope, represented as a nested SDFG.
        All let-statements, if any, are supposed to be encountered before the stencil
        expression. In other words, the `Lambda` node representing the stencil expression
        is always the innermost node.
        Therefore, the lowering of let-statements results in recursive calls to
        `visit_let()` until the stencil expression is found. At that point, it falls
        back to the `visit()` function.
        """

        # lambda arguments are mapped to symbols defined in lambda scope.
        for p, arg in zip(node.params, args, strict=True):
            self.symbol_map[str(p.id)] = arg

        if cpm.is_let(node.expr):
            # recurse into the nested let-lambda with its freshly visited arguments
            let_node = node.expr
            let_args = [self.visit(arg) for arg in let_node.args]
            assert isinstance(let_node.fun, gtir.Lambda)
            return self.visit_let(let_node.fun, args=let_args)
        else:
            # this lambda node is not a let-statement, but a stencil expression
            return self.visit(node)
def translate_lambda_to_dataflow(
    sdfg: dace.SDFG,
    state: dace.SDFGState,
    sdfg_builder: gtir_sdfg.DataflowBuilder,
    node: gtir.Lambda,
    args: Sequence[
        IteratorExpr
        | MemletExpr
        | ValueExpr
        | tuple[IteratorExpr | MemletExpr | ValueExpr | tuple[Any, ...], ...]
    ],
) -> tuple[
    list[DataflowInputEdge],
    tuple[DataflowOutputEdge | tuple[Any, ...], ...],
]:
    """
    Entry point to visit a `Lambda` node and lower it to a dataflow graph,
    that can be instantiated inside a map scope implementing the field operator.

    It calls `LambdaToDataflow.visit_let()` to map the lambda arguments to internal
    parameters and visit the let-statements (if any), which always appear as outermost
    nodes. Finally, the visitor returns the output edge of the dataflow.

    Args:
        sdfg: The SDFG where the dataflow graph will be instantiated.
        state: The SDFG state where the dataflow graph will be instantiated.
        sdfg_builder: Helper class to build the dataflow inside the given SDFG.
        node: Lambda node to visit.
        args: Arguments passed to lambda node.

    Returns:
        A tuple of two elements:
        - List of connections for data inputs to the dataflow.
        - Tree representation of output data connections.
    """
    dataflow_builder = LambdaToDataflow(sdfg, state, sdfg_builder)
    result = dataflow_builder.visit_let(node, args)

    # normalize a single output edge to a 1-tuple, so callers always get a tree
    if isinstance(result, DataflowOutputEdge):
        output_tree: tuple[DataflowOutputEdge | tuple[Any, ...], ...] = (result,)
    else:
        output_tree = result
    return dataflow_builder.input_edges, output_tree
def builtin_tuple_get(index: str, tuple_name: str) -> str:
    # Tuple fields are flattened into scalar symbols named '<tuple>_<index>'.
    return f"{tuple_name}_{index}"


def make_const_list(arg: str) -> str:
    """
    Takes a single scalar argument and broadcasts this value on the local dimension
    of map expression. In a dataflow, we represent it as a tasklet that writes
    a value to a scalar node.
    """
    return arg


# Dispatch table for builtins that require dedicated code generation,
# as opposed to the plain format strings in `MATH_BUILTINS_MAPPING`.
GENERAL_BUILTIN_MAPPING: dict[str, Callable[..., str]] = {
    "cast_": builtin_cast,
    "if_": builtin_if,
    "make_const_list": make_const_list,
    "tuple_get": builtin_tuple_get,
}
    def visit_FunCall(self, node: gtir.FunCall, args_map: dict[str, gtir.Node]) -> str:
        """
        Emits Python code for a function call; lambdas are inlined by extending
        `args_map` with their parameter bindings.
        """
        if isinstance(node.fun, gtir.Lambda):
            # update the mapping from lambda parameters to corresponding argument expressions
            lambda_args_map = args_map | {
                p.id: arg for p, arg in zip(node.fun.params, node.args, strict=True)
            }
            return self.visit(node.fun.expr, args_map=lambda_args_map)
        elif cpm.is_call_to(node, "deref"):
            assert len(node.args) == 1
            if not isinstance(node.args[0], gtir.SymRef):
                # shift expressions are not expected in this visitor context
                raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.")
            return self.visit(node.args[0], args_map=args_map)
        elif isinstance(node.fun, gtir.SymRef):
            args = self.visit(node.args, args_map=args_map)
            builtin_name = str(node.fun.id)
            return format_builtin(builtin_name, *args)
        raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).")

    def visit_SymRef(self, node: gtir.SymRef, args_map: dict[str, gtir.Node]) -> str:
        """Resolves a symref: lambda parameters are replaced by their bound expressions."""
        symbol = str(node.id)
        if symbol in args_map:
            return self.visit(args_map[symbol], args_map=args_map)
        return symbol


def get_source(node: gtir.Node) -> str:
    """
    Specialized visit method for symbolic expressions.

    The visitor uses `args_map` to map lambda parameters to the corresponding argument expressions.

    Returns:
        A string containing the Python code corresponding to a symbolic expression
    """
    return PythonCodegen.apply(node, args_map={})
def _parse_scan_fieldop_arg(
    node: gtir.Expr,
    sdfg: dace.SDFG,
    state: dace.SDFGState,
    sdfg_builder: gtir_sdfg.SDFGBuilder,
    domain: gtir_translators.FieldopDomain,
) -> gtir_dataflow.MemletExpr | tuple[gtir_dataflow.MemletExpr | tuple[Any, ...], ...]:
    """Helper method to visit an expression passed as argument to a scan field operator.

    On the innermost level, a scan operator is lowered to a loop region which computes
    column elements in the vertical dimension.

    It differs from the corresponding argument helper in `gtir_builtin_translators`
    in that field arguments are passed in full shape along the vertical dimension,
    rather than as iterator.
    """

    def _parse_fieldop_arg_impl(
        arg: gtir_translators.FieldopData,
    ) -> gtir_dataflow.MemletExpr:
        # convert one field argument to a by-value memlet expression
        arg_expr = arg.get_local_view(domain)
        if isinstance(arg_expr, gtir_dataflow.MemletExpr):
            return arg_expr
        # In scan field operator, the arguments to the vertical stencil are passed by value.
        # Therefore, the full field shape is passed as `MemletExpr` rather than `IteratorExpr`.
        return gtir_dataflow.MemletExpr(
            arg_expr.field, arg_expr.gt_dtype, arg_expr.get_memlet_subset(sdfg)
        )

    arg = sdfg_builder.visit(node, sdfg=sdfg, head_state=state)

    if isinstance(arg, gtir_translators.FieldopData):
        return _parse_fieldop_arg_impl(arg)
    else:
        # handle tuples of fields
        return gtx_utils.tree_map(lambda x: _parse_fieldop_arg_impl(x))(arg)
def _create_scan_field_operator_impl(
    sdfg_builder: gtir_sdfg.SDFGBuilder,
    sdfg: dace.SDFG,
    state: dace.SDFGState,
    domain: gtir_translators.FieldopDomain,
    output_edge: gtir_dataflow.DataflowOutputEdge,
    output_type: ts.FieldType,
    # `map_exit` is None for a 1d scan field operator, where no horizontal map
    # scope is created (see `_create_scan_field_operator()`).
    map_exit: dace.nodes.MapExit | None,
) -> gtir_translators.FieldopData:
    """
    Helper method to allocate a temporary array that stores one field computed
    by the scan field operator.

    This method is called by `_create_scan_field_operator()`.

    Similar to `gtir_builtin_translators._create_field_operator_impl()` but
    for scan field operators. It differs in that the scan loop region produces
    a field along the vertical dimension, rather than a single point.
    Therefore, the memlet subset will write a slice into the result array, that
    corresponds to the full vertical shape for each horizontal grid point.

    Refer to `gtir_builtin_translators._create_field_operator_impl()` for
    the description of function arguments and return values.
    """
    dataflow_output_desc = output_edge.result.dc_node.desc(sdfg)
    assert isinstance(dataflow_output_desc, dace.data.Array)

    # the memory layout of the output field follows the field operator compute domain
    field_dims, field_origin, field_shape = gtir_translators.get_field_layout(domain)
    field_indices = gtir_translators.get_domain_indices(field_dims, field_origin)
    field_subset = dace_subsets.Range.from_indices(field_indices)

    # the vertical dimension used as scan column is computed by the `LoopRegion`
    # inside the map scope, therefore it is excluded from the map range
    scan_dim_index = [sdfg_builder.is_column_axis(dim) for dim in field_dims].index(True)

    # the map scope writes the full-shape dimension corresponding to the scan column
    field_subset = (
        dace_subsets.Range(field_subset[:scan_dim_index])
        + dace_subsets.Range.from_string(f"0:{dataflow_output_desc.shape[0]}")
        + dace_subsets.Range(field_subset[scan_dim_index + 1 :])
    )

    if isinstance(output_edge.result.gt_dtype, ts.ScalarType):
        assert isinstance(output_type.dtype, ts.ScalarType)
        if output_edge.result.gt_dtype != output_type.dtype:
            raise TypeError(
                f"Type mismatch, expected {output_type.dtype} got {output_edge.result.gt_dtype}."
            )
        field_dtype = output_edge.result.gt_dtype
        # the scan field operator computes a column of scalar values
        assert len(dataflow_output_desc.shape) == 1
    else:
        assert isinstance(output_type.dtype, ts.ListType)
        assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType)
        field_dtype = output_edge.result.gt_dtype.element_type
        if field_dtype != output_type.dtype.element_type:
            raise TypeError(
                f"Type mismatch, expected {output_type.dtype.element_type} got {field_dtype}."
            )
        # the scan field operator computes a list of scalar values for each column level
        # 1st dim: column level, 2nd dim: list of scalar values (e.g. `neighbors`)
        assert len(dataflow_output_desc.shape) == 2
        # the lines below extend the array with the local dimension added by the field operator
        assert output_edge.result.gt_dtype.offset_type is not None
        field_shape = [*field_shape, dataflow_output_desc.shape[1]]
        field_subset = field_subset + dace_subsets.Range.from_string(
            f"0:{dataflow_output_desc.shape[1]}"
        )

    # allocate local temporary storage
    field_name, field_desc = sdfg_builder.add_temp_array(
        sdfg, field_shape, dataflow_output_desc.dtype
    )
    # the inner and outer strides have to match
    scan_output_stride = field_desc.strides[scan_dim_index]
    # also consider the stride of the local dimension, in case the scan field operator computes a list
    local_strides = field_desc.strides[len(field_dims) :]
    assert len(local_strides) == (1 if isinstance(output_edge.result.gt_dtype, ts.ListType) else 0)
    new_inner_strides = [scan_output_stride, *local_strides]
    dataflow_output_desc.set_shape(dataflow_output_desc.shape, new_inner_strides)

    # and here the edge writing the dataflow result data through the map exit node
    field_node = state.add_access(field_name)
    output_edge.connect(map_exit, field_node, field_subset)

    return gtir_translators.FieldopData(
        field_node, ts.FieldType(field_dims, output_edge.result.gt_dtype), tuple(field_origin)
    )
def _create_scan_field_operator(
    sdfg: dace.SDFG,
    state: dace.SDFGState,
    domain: gtir_translators.FieldopDomain,
    node_type: ts.FieldType | ts.TupleType,
    sdfg_builder: gtir_sdfg.SDFGBuilder,
    input_edges: Iterable[gtir_dataflow.DataflowInputEdge],
    output_tree: gtir_dataflow.DataflowOutputEdge
    | tuple[gtir_dataflow.DataflowOutputEdge | tuple[Any, ...], ...],
) -> gtir_translators.FieldopResult:
    """
    Helper method to build the output of a field operator, which can consist of
    a single field or a tuple of fields.

    Similar to `gtir_builtin_translators._create_field_operator()` but for scan
    field operators. The main difference is that the scan vertical dimension is
    excluded from the map range. This because the vertical dimension is traversed
    by a loop region in a mapped nested SDFG.

    Refer to `gtir_builtin_translators._create_field_operator()` for the
    description of function arguments and return values.
    """
    domain_dims, _, _ = gtir_translators.get_field_layout(domain)

    # create a map scope to execute the `LoopRegion` over the horizontal domain
    if len(domain_dims) == 1:
        # We construct the scan field operator on the horizontal domain, while the
        # vertical dimension (the column axis) is computed by the loop region.
        # If the field operator computes only the column axis (a 1d scan field operator),
        # there is no horizontal domain, therefore the map scope is not needed.
        # This case currently produces wrong CUDA code because of a DaCe issue
        # (see https://github.com/GridTools/gt4py/issues/1136).
        # The corresponding GT4Py tests are disabled (pytest marker `uses_scan_1d_field`).
        map_entry, map_exit = (None, None)
    else:
        # create map range corresponding to the field operator domain
        map_entry, map_exit = sdfg_builder.add_map(
            "fieldop",
            state,
            ndrange={
                gtir_sdfg_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}"
                for dim, lower_bound, upper_bound in domain
                if not sdfg_builder.is_column_axis(dim)
            },
        )

    # here we setup the edges passing through the map entry node
    for edge in input_edges:
        edge.connect(map_entry)

    if isinstance(node_type, ts.FieldType):
        assert isinstance(output_tree, gtir_dataflow.DataflowOutputEdge)
        return _create_scan_field_operator_impl(
            sdfg_builder, sdfg, state, domain, output_tree, node_type, map_exit
        )
    else:
        # handle tuples of fields
        # the symbol name 'x' in the call below is not used, we only need
        # the tree structure of the `TupleType` definition to pass to `tree_map()`
        output_symbol_tree = gtir_sdfg_utils.make_symbol_tree("x", node_type)
        return gtx_utils.tree_map(
            lambda output_edge, output_sym: (
                _create_scan_field_operator_impl(
                    sdfg_builder,
                    sdfg,
                    state,
                    domain,
                    output_edge,
                    output_sym.type,
                    map_exit,
                )
            )
        )(output_tree, output_symbol_tree)
+ map_entry, map_exit = (None, None) + else: + # create map range corresponding to the field operator domain + map_entry, map_exit = sdfg_builder.add_map( + "fieldop", + state, + ndrange={ + gtir_sdfg_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" + for dim, lower_bound, upper_bound in domain + if not sdfg_builder.is_column_axis(dim) + }, + ) + + # here we setup the edges passing through the map entry node + for edge in input_edges: + edge.connect(map_entry) + + if isinstance(node_type, ts.FieldType): + assert isinstance(output_tree, gtir_dataflow.DataflowOutputEdge) + return _create_scan_field_operator_impl( + sdfg_builder, sdfg, state, domain, output_tree, node_type, map_exit + ) + else: + # handle tuples of fields + # the symbol name 'x' in the call below is not used, we only need + # the tree structure of the `TupleType` definition to pass to `tree_map()` + output_symbol_tree = gtir_sdfg_utils.make_symbol_tree("x", node_type) + return gtx_utils.tree_map( + lambda output_edge, output_sym: ( + _create_scan_field_operator_impl( + sdfg_builder, + sdfg, + state, + domain, + output_edge, + output_sym.type, + map_exit, + ) + ) + )(output_tree, output_symbol_tree) + + +def _scan_input_name(input_name: str) -> str: + """ + Helper function to make naming of input connectors in the scan nested SDFG + consistent throughut this module scope. + """ + return f"__gtir_scan_input_{input_name}" + + +def _scan_output_name(input_name: str) -> str: + """ + Same as above, but for the output connecters in the scan nested SDFG. 
+ """ + return f"__gtir_scan_output_{input_name}" + + +def _lower_lambda_to_nested_sdfg( + lambda_node: gtir.Lambda, + sdfg: dace.SDFG, + sdfg_builder: gtir_sdfg.SDFGBuilder, + domain: gtir_translators.FieldopDomain, + init_data: gtir_translators.FieldopResult, + lambda_symbols: dict[str, ts.DataType], + scan_forward: bool, + scan_carry_symbol: gtir.Sym, +) -> tuple[dace.SDFG, gtir_translators.FieldopResult]: + """ + Helper method to lower the lambda node representing the scan stencil dataflow + inside a separate SDFG. + + In regular field operators, where the computation of a grid point is independent + from other points, therefore the stencil can be lowered to a mapped tasklet + dataflow, and the map range is defined on the full domain. + The scan field operator has to carry an intermediate result while the stencil + is applied on vertical levels, which is input to the computation of next level + (an accumulator function, for example). Therefore, the points on the vertical + dimension are computed inside a `LoopRegion` construct. + This function creates the `LoopRegion` inside a nested SDFG, which will be + mapped by the caller to the horizontal domain in the field operator context. + + Args: + lambda_node: The lambda representing the stencil expression on the horizontal level. + sdfg: The SDFG where the scan field operator is translated. + sdfg_builder: The SDFG builder object to access the field operator context. + domain: The field operator domain, with all horizontal and vertical dimensions. + init_data: The data produced in the field operator context that is used + to initialize the scan carry value. + lambda_symbols: List of symbols used as parameters of the stencil expressions. + scan_forward: When True, the loop should range starting from the origin; + when False, traverse towards origin. + scan_carry_symbol: The symbol used in the stencil expression to carry the + intermediate result along the vertical dimension. 
+ + Returns: + A tuple of two elements: + - An SDFG containing the `LoopRegion` computation along the vertical + dimension, to be instantied as a nested SDFG in the field operator context. + - The inner fields, that is 1d arrays with vertical shape containing + the output of the stencil computation. These fields will have to be + mapped to outer arrays by the caller. The caller is responsible to ensure + that inner and outer arrays use the same strides. + """ + + # the lambda expression, i.e. body of the scan, will be created inside a nested SDFG. + nsdfg = dace.SDFG(sdfg_builder.unique_nsdfg_name(sdfg, "scan")) + nsdfg.debuginfo = gtir_sdfg_utils.debug_info(lambda_node, default=sdfg.debuginfo) + # We set `using_explicit_control_flow=True` because the vertical scan is lowered to a `LoopRegion`. + # This property is used by pattern matching in SDFG transformation framework + # to skip those transformations that do not yet support control flow blocks. + nsdfg.using_explicit_control_flow = True + lambda_translator = sdfg_builder.setup_nested_context(lambda_node, nsdfg, lambda_symbols) + + # use the vertical dimension in the domain as scan dimension + scan_domain = [ + (dim, lower_bound, upper_bound) + for dim, lower_bound, upper_bound in domain + if sdfg_builder.is_column_axis(dim) + ] + assert len(scan_domain) == 1 + scan_dim, scan_lower_bound, scan_upper_bound = scan_domain[0] + + # extract the scan loop range + scan_loop_var = gtir_sdfg_utils.get_map_variable(scan_dim) + + # in case the scan operator computes a list (not a scalar), we need to add an extra dimension + def get_scan_output_shape( + scan_init_data: gtir_translators.FieldopData, + ) -> list[dace.symbolic.SymExpr]: + scan_column_size = scan_upper_bound - scan_lower_bound + if isinstance(scan_init_data.gt_type, ts.ScalarType): + return [scan_column_size] + assert isinstance(scan_init_data.gt_type, ts.ListType) + assert scan_init_data.gt_type.offset_type + offset_type = scan_init_data.gt_type.offset_type + 
offset_provider_type = sdfg_builder.get_offset_provider_type(offset_type.value) + assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) + list_size = offset_provider_type.max_neighbors + return [scan_column_size, dace.symbolic.SymExpr(list_size)] + + if isinstance(init_data, tuple): + lambda_result_shape = gtx_utils.tree_map(get_scan_output_shape)(init_data) + else: + lambda_result_shape = get_scan_output_shape(init_data) + + # Create the body of the initialization state + # This dataflow will write the initial value of the scan carry variable. + init_state = nsdfg.add_state("scan_init", is_start_block=True) + scan_carry_input = ( + gtir_sdfg_utils.make_symbol_tree(scan_carry_symbol.id, scan_carry_symbol.type) + if isinstance(scan_carry_symbol.type, ts.TupleType) + else scan_carry_symbol + ) + + def init_scan_carry(sym: gtir.Sym) -> None: + scan_carry_dataname = str(sym.id) + scan_carry_desc = nsdfg.data(scan_carry_dataname) + input_scan_carry_dataname = _scan_input_name(scan_carry_dataname) + input_scan_carry_desc = scan_carry_desc.clone() + nsdfg.add_datadesc(input_scan_carry_dataname, input_scan_carry_desc) + scan_carry_desc.transient = True + init_state.add_nedge( + init_state.add_access(input_scan_carry_dataname), + init_state.add_access(scan_carry_dataname), + nsdfg.make_array_memlet(input_scan_carry_dataname), + ) + + if isinstance(scan_carry_input, tuple): + gtx_utils.tree_map(init_scan_carry)(scan_carry_input) + else: + init_scan_carry(scan_carry_input) + + # Create a loop region over the vertical dimension corresponding to the scan column + if scan_forward: + scan_loop = dace.sdfg.state.LoopRegion( + label="scan", + condition_expr=f"{scan_loop_var} < {scan_upper_bound}", + loop_var=scan_loop_var, + initialize_expr=f"{scan_loop_var} = {scan_lower_bound}", + update_expr=f"{scan_loop_var} = {scan_loop_var} + 1", + inverted=False, + ) + else: + scan_loop = dace.sdfg.state.LoopRegion( + label="scan", + condition_expr=f"{scan_loop_var} >= 
{scan_lower_bound}", + loop_var=scan_loop_var, + initialize_expr=f"{scan_loop_var} = {scan_upper_bound} - 1", + update_expr=f"{scan_loop_var} = {scan_loop_var} - 1", + inverted=False, + ) + nsdfg.add_node(scan_loop) + nsdfg.add_edge(init_state, scan_loop, dace.InterstateEdge()) + + # Inside the loop region, create a 'compute' and an 'update' state. + # The body of the 'compute' state implements the stencil expression for one vertical level. + # The 'update' state writes the value computed by the stencil into the scan carry variable, + # in order to make it available to the next vertical level. + compute_state = scan_loop.add_state("scan_compute") + update_state = scan_loop.add_state_after(compute_state, "scan_update") + + # inside the 'compute' state, visit the list of arguments to be passed to the stencil + stencil_args = [ + _parse_scan_fieldop_arg(im.ref(p.id), nsdfg, compute_state, lambda_translator, domain) + for p in lambda_node.params + ] + # stil inside the 'compute' state, generate the dataflow representing the stencil + # to be applied on the horizontal domain + lambda_input_edges, lambda_result = gtir_dataflow.translate_lambda_to_dataflow( + nsdfg, compute_state, lambda_translator, lambda_node, stencil_args + ) + # connect the dataflow input directly to the source data nodes, without passing through a map node; + # the reason is that the map for horizontal domain is outside the scan loop region + for edge in lambda_input_edges: + edge.connect(map_entry=None) + # connect the dataflow output nodes, called 'scan_result' below, to a global field called 'output' + output_column_index = dace.symbolic.pystr_to_symbolic(scan_loop_var) - scan_lower_bound + + def connect_scan_output( + scan_output_edge: gtir_dataflow.DataflowOutputEdge, + scan_output_shape: list[dace.symbolic.SymExpr], + scan_carry_sym: gtir.Sym, + ) -> gtir_translators.FieldopData: + scan_result = scan_output_edge.result + if isinstance(scan_result.gt_dtype, ts.ScalarType): + assert 
scan_result.gt_dtype == scan_carry_sym.type + # the scan field operator computes a column of scalar values + assert len(scan_output_shape) == 1 + output_subset = dace_subsets.Range.from_string(str(output_column_index)) + else: + assert isinstance(scan_carry_sym.type, ts.ListType) + assert scan_result.gt_dtype.element_type == scan_carry_sym.type.element_type + # the scan field operator computes a list of scalar values for each column level + assert len(scan_output_shape) == 2 + output_subset = dace_subsets.Range.from_string( + f"{output_column_index}, 0:{scan_output_shape[1]}" + ) + scan_result_data = scan_result.dc_node.data + scan_result_desc = scan_result.dc_node.desc(nsdfg) + + # `sym` represents the global output data, that is the nested-SDFG output connector + scan_carry_data = str(scan_carry_sym.id) + output = _scan_output_name(scan_carry_data) + nsdfg.add_array(output, scan_output_shape, scan_result_desc.dtype) + output_node = compute_state.add_access(output) + + # in the 'compute' state, we write the current vertical level data to the output field + # (the output field is mapped to an external array) + compute_state.add_nedge( + scan_result.dc_node, output_node, dace.Memlet(data=output, subset=output_subset) + ) + + # in the 'update' state, the value of the current vertical level is written + # to the scan carry variable for the next loop iteration + update_state.add_nedge( + update_state.add_access(scan_result_data), + update_state.add_access(scan_carry_data), + dace.Memlet.from_array(scan_result_data, scan_result_desc), + ) + + output_type = ts.FieldType(dims=[scan_dim], dtype=scan_result.gt_dtype) + return gtir_translators.FieldopData(output_node, output_type, origin=(scan_lower_bound,)) + + # write the stencil result (value on one vertical level) into a 1D field + # with full vertical shape representing one column + if isinstance(scan_carry_input, tuple): + assert isinstance(lambda_result_shape, tuple) + lambda_output = 
gtx_utils.tree_map(connect_scan_output)( + lambda_result, lambda_result_shape, scan_carry_input + ) + else: + assert isinstance(lambda_result[0], gtir_dataflow.DataflowOutputEdge) + assert isinstance(lambda_result_shape, list) + lambda_output = connect_scan_output(lambda_result[0], lambda_result_shape, scan_carry_input) + + # in case tuples are passed as argument, isolated access nodes might be left in the state, + # because not all tuple fields are necessarily accessed inside the lambda scope + for data_node in compute_state.data_nodes(): + data_desc = data_node.desc(nsdfg) + if compute_state.degree(data_node) == 0: + # By construction there should never be isolated transient nodes. + # Therefore, the assert below implements a sanity check, that allows + # the exceptional case (encountered in one GT4Py test) where the carry + # variable is not used, so not a scan indeed because no data dependency. + assert (not data_desc.transient) or data_node.data.startswith(scan_carry_symbol.id) + compute_state.remove_node(data_node) + + return nsdfg, lambda_output + + +def _connect_nested_sdfg_output_to_temporaries( + sdfg: dace.SDFG, + nsdfg: dace.SDFG, + nsdfg_node: dace.nodes.NestedSDFG, + outer_state: dace.SDFGState, + inner_data: gtir_translators.FieldopData, +) -> gtir_dataflow.DataflowOutputEdge: + """ + Helper function to create the edges to write output data from the nested SDFG + to temporary arrays in the parent SDFG, denoted as outer context. + + Args: + sdfg: The SDFG representing the outer context, where the field operator is translated. + nsdfg: The SDFG where the scan `LoopRegion` is translated. + nsdfg_node: The nested SDFG node in the outer context. + outer_state: The state in outer context where the field operator is translated. + inner_data: The data produced by the scan `LoopRegion` in the inner context. + + Returns: + An object representing the output data connection of this field operator. 
+ """ + assert isinstance(inner_data.gt_type, ts.FieldType) + inner_dataname = inner_data.dc_node.data + inner_desc = nsdfg.data(inner_dataname) + outer_dataname, outer_desc = sdfg.add_temp_transient_like(inner_desc) + outer_node = outer_state.add_access(outer_dataname) + outer_state.add_edge( + nsdfg_node, + inner_dataname, + outer_node, + None, + dace.Memlet.from_array(outer_dataname, outer_desc), + ) + output_expr = gtir_dataflow.ValueExpr(outer_node, inner_data.gt_type.dtype) + return gtir_dataflow.DataflowOutputEdge(outer_state, output_expr) + + +def translate_scan( + node: gtir.Node, + sdfg: dace.SDFG, + state: dace.SDFGState, + sdfg_builder: gtir_sdfg.SDFGBuilder, +) -> gtir_translators.FieldopResult: + """ + Generates the dataflow subgraph for the `as_fieldop` builtin with a scan operator. + + It differs from `translate_as_fieldop()` in that the horizontal domain is lowered + to a map scope, while the scan column computation is lowered to a `LoopRegion` + on the vertical dimension, that is inside the horizontal map. + The current design choice is to keep the map scope on the outer level, and + the `LoopRegion` inside. This choice follows the GTIR representation where + the `scan` operator is called inside the `as_fieldop` node. + + Implements the `PrimitiveTranslator` protocol. 
+ """ + assert isinstance(node, gtir.FunCall) + assert cpm.is_call_to(node.fun, "as_fieldop") + assert isinstance(node.type, (ts.FieldType, ts.TupleType)) + + fun_node = node.fun + assert len(fun_node.args) == 2 + scan_expr, domain_expr = fun_node.args + assert cpm.is_call_to(scan_expr, "scan") + + # parse the domain of the scan field operator + domain = gtir_translators.extract_domain(domain_expr) + + # parse scan parameters + assert len(scan_expr.args) == 3 + stencil_expr = scan_expr.args[0] + assert isinstance(stencil_expr, gtir.Lambda) + + # params[0]: the lambda parameter to propagate the scan carry on the vertical dimension + scan_carry = str(stencil_expr.params[0].id) + + # params[1]: boolean flag for forward/backward scan + assert isinstance(scan_expr.args[1], gtir.Literal) and ti.is_logical(scan_expr.args[1].type) + scan_forward = scan_expr.args[1].value == "True" + + # params[2]: the expression that computes the value for scan initialization + init_expr = scan_expr.args[2] + # visit the initialization value of the scan expression + init_data = sdfg_builder.visit(init_expr, sdfg=sdfg, head_state=state) + # extract type definition of the scan carry + scan_carry_type = ( + init_data.gt_type + if isinstance(init_data, gtir_translators.FieldopData) + else gtir_translators.get_tuple_type(init_data) + ) + + # define the set of symbols available in the lambda context, which consists of + # the carry argument and all lambda function arguments + lambda_arg_types = [scan_carry_type] + [ + arg.type for arg in node.args if isinstance(arg.type, ts.DataType) + ] + lambda_symbols = { + str(p.id): arg_type + for p, arg_type in zip(stencil_expr.params, lambda_arg_types, strict=True) + } + + # lower the scan stencil expression in a separate SDFG context + nsdfg, lambda_output = _lower_lambda_to_nested_sdfg( + stencil_expr, + sdfg, + sdfg_builder, + domain, + init_data, + lambda_symbols, + scan_forward, + im.sym(scan_carry, scan_carry_type), + ) + + # visit the arguments to 
be passed to the lambda expression + # this must be executed before visiting the lambda expression, in order to populate + # the data descriptor with the correct field domain offsets for field arguments + lambda_args = [sdfg_builder.visit(arg, sdfg=sdfg, head_state=state) for arg in node.args] + lambda_args_mapping = [ + (im.sym(_scan_input_name(scan_carry), scan_carry_type), init_data), + ] + [ + (im.sym(param.id, arg.gt_type), arg) + for param, arg in zip(stencil_expr.params[1:], lambda_args, strict=True) + ] + + lambda_arg_nodes = dict( + itertools.chain( + *[gtir_translators.flatten_tuples(psym.id, arg) for psym, arg in lambda_args_mapping] + ) + ) + + # parse the dataflow output symbols + if isinstance(scan_carry_type, ts.TupleType): + lambda_flat_outs = { + str(sym.id): sym.type + for sym in gtir_sdfg_utils.flatten_tuple_fields( + _scan_output_name(scan_carry), scan_carry_type + ) + } + else: + lambda_flat_outs = {_scan_output_name(scan_carry): scan_carry_type} + + # build the mapping of symbols from nested SDFG to field operator context + nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols} + for psym, arg in lambda_args_mapping: + nsdfg_symbols_mapping |= gtir_translators.get_arg_symbol_mapping(psym.id, arg, sdfg) + + # the scan nested SDFG is ready: it is instantiated in the field operator context + # where the map scope over the horizontal domain lives + nsdfg_node = state.add_nested_sdfg( + nsdfg, + sdfg, + inputs=set(lambda_arg_nodes.keys()), + outputs=set(lambda_flat_outs.keys()), + symbol_mapping=nsdfg_symbols_mapping, + ) + + lambda_input_edges = [] + for input_connector, outer_arg in lambda_arg_nodes.items(): + arg_desc = outer_arg.dc_node.desc(sdfg) + input_subset = dace_subsets.Range.from_array(arg_desc) + input_edge = gtir_dataflow.MemletInputEdge( + state, outer_arg.dc_node, input_subset, nsdfg_node, input_connector + ) + lambda_input_edges.append(input_edge) + + # for output connections, we create temporary arrays that contain 
the computation + # results of a column slice for each point in the horizontal domain + lambda_output_tree = gtx_utils.tree_map( + lambda lambda_output_data: _connect_nested_sdfg_output_to_temporaries( + sdfg, nsdfg, nsdfg_node, state, lambda_output_data + ) + )(lambda_output) + + # we call a helper method to create a map scope that will compute the entire field + return _create_scan_field_operator( + sdfg, state, domain, node.type, sdfg_builder, lambda_input_edges, lambda_output_tree + ) diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py new file mode 100644 index 0000000000..a4c0194849 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg.py @@ -0,0 +1,1004 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +""" +Contains visitors to lower GTIR to DaCe SDFG. + +Note: this module covers the fieldview flavour of GTIR. 
+""" + +from __future__ import annotations + +import abc +import dataclasses +import itertools +import operator +from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union + +import dace + +from gt4py import eve +from gt4py.eve import concepts +from gt4py.next import common as gtx_common, utils as gtx_utils +from gt4py.next.iterator import ir as gtir +from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm +from gt4py.next.iterator.transforms import prune_casts as ir_prune_casts, symbol_ref_utils +from gt4py.next.iterator.type_system import inference as gtir_type_inference +from gt4py.next.program_processors.runners.dace import ( + gtir_builtin_translators, + gtir_sdfg_utils, + transformations as gtx_transformations, + utils as gtx_dace_utils, +) +from gt4py.next.type_system import type_specifications as ts, type_translation as tt + + +class DataflowBuilder(Protocol): + """Visitor interface to build a dataflow subgraph.""" + + @abc.abstractmethod + def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderTypeElem: ... + + @abc.abstractmethod + def unique_nsdfg_name(self, sdfg: dace.SDFG, prefix: str) -> str: ... + + @abc.abstractmethod + def unique_map_name(self, name: str) -> str: ... + + @abc.abstractmethod + def unique_tasklet_name(self, name: str) -> str: ... 
+ + def add_temp_array( + self, sdfg: dace.SDFG, shape: Sequence[Any], dtype: dace.dtypes.typeclass + ) -> tuple[str, dace.data.Scalar]: + """Add a temporary array to the SDFG.""" + return sdfg.add_temp_transient(shape, dtype) + + def add_temp_array_like( + self, sdfg: dace.SDFG, datadesc: dace.data.Array + ) -> tuple[str, dace.data.Scalar]: + """Add a temporary array to the SDFG.""" + return sdfg.add_temp_transient_like(datadesc) + + def add_temp_scalar( + self, sdfg: dace.SDFG, dtype: dace.dtypes.typeclass + ) -> tuple[str, dace.data.Scalar]: + """Add a temporary scalar to the SDFG.""" + temp_name = sdfg.temp_data_name() + return sdfg.add_scalar(temp_name, dtype, transient=True) + + def add_map( + self, + name: str, + state: dace.SDFGState, + ndrange: Union[ + Dict[str, Union[str, dace.subsets.Subset]], + List[Tuple[str, Union[str, dace.subsets.Subset]]], + ], + **kwargs: Any, + ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]: + """Wrapper of `dace.SDFGState.add_map` that assigns unique name.""" + unique_name = self.unique_map_name(name) + return state.add_map(unique_name, ndrange, **kwargs) + + def add_tasklet( + self, + name: str, + state: dace.SDFGState, + inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], + code: str, + **kwargs: Any, + ) -> dace.nodes.Tasklet: + """Wrapper of `dace.SDFGState.add_tasklet` that assigns unique name.""" + unique_name = self.unique_tasklet_name(name) + return state.add_tasklet(unique_name, inputs, outputs, code, **kwargs) + + def add_mapped_tasklet( + self, + name: str, + state: dace.SDFGState, + map_ranges: Dict[str, str | dace.subsets.Subset] + | List[Tuple[str, str | dace.subsets.Subset]], + inputs: Dict[str, dace.Memlet], + code: str, + outputs: Dict[str, dace.Memlet], + **kwargs: Any, + ) -> tuple[dace.nodes.Tasklet, dace.nodes.MapEntry, dace.nodes.MapExit]: + """Wrapper of `dace.SDFGState.add_mapped_tasklet` that assigns unique name.""" + unique_name 
= self.unique_tasklet_name(name) + return state.add_mapped_tasklet(unique_name, map_ranges, inputs, code, outputs, **kwargs) + + +class SDFGBuilder(DataflowBuilder, Protocol): + """Visitor interface available to GTIR-primitive translators.""" + + @abc.abstractmethod + def make_field( + self, + data_node: dace.nodes.AccessNode, + data_type: ts.FieldType | ts.ScalarType, + ) -> gtir_builtin_translators.FieldopData: + """Retrieve the field data descriptor including the domain offset information.""" + ... + + @abc.abstractmethod + def get_symbol_type(self, symbol_name: str) -> ts.DataType: + """Retrieve the GT4Py type of a symbol used in the SDFG.""" + ... + + @abc.abstractmethod + def is_column_axis(self, dim: gtx_common.Dimension) -> bool: + """Check if the given dimension is the column axis.""" + ... + + @abc.abstractmethod + def setup_nested_context( + self, + expr: gtir.Expr, + sdfg: dace.SDFG, + global_symbols: dict[str, ts.DataType], + ) -> SDFGBuilder: + """ + Create an SDFG context to translate a nested expression, indipendent + from the current context where the parent expression is being translated. + + This method will setup the global symbols, that correspond to the parameters + of the expression to be lowered, as well as the set of symbolic arguments, + that is scalar values used in internal domain expressions. + + Args: + expr: The nested expresson to be lowered. + sdfg: The SDFG where to lower the nested expression. + global_symbols: Mapping from symbol name to GTIR data type. + + Returns: + A visitor object implementing the `SDFGBuilder` protocol. + """ + ... + + @abc.abstractmethod + def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: + """Visit a node of the GT4Py IR.""" + ... + + +def _collect_symbols_in_domain_expressions( + ir: gtir.Node, ir_params: Sequence[gtir.Sym] +) -> set[str]: + """ + Collect symbols accessed in domain expressions that also appear in the paremeter list. 
+ + This function is used to identify all parameters that are accessed in domain + expressions. They have to be passed to the SDFG call as DaCe symbols (instead + of scalars) such that they can be used as bounds in map ranges. + + Args: + ir: GTIR node to be traversed and where to search for domain expressions. + ir_params: List of parameters to search for in domain expressions. + + Returns: + A set of names corresponding to the parameters found in domain expressions. + """ + params = {str(sym.id) for sym in ir_params} + return set( + eve.walk_values(ir) + .filter(lambda node: cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain"))) + .map( + lambda domain: eve.walk_values(domain) + .if_isinstance(gtir.SymRef) + .map(lambda symref: str(symref.id)) + .filter(lambda sym: sym in params) + .to_list() + ) + .reduce(operator.add, init=[]) + ) + + +def _make_access_index_for_field( + domain: gtir_builtin_translators.FieldopDomain, data: gtir_builtin_translators.FieldopData +) -> dace.subsets.Range: + """Helper method to build a memlet subset of a field over the given domain.""" + # convert domain expression to dictionary to ease access to the dimensions, + # since the access indices have to follow the order of dimensions in field domain + if isinstance(data.gt_type, ts.FieldType) and len(data.gt_type.dims) != 0: + assert data.origin is not None + domain_ranges = {dim: (lb, ub) for dim, lb, ub in domain} + return dace.subsets.Range( + (domain_ranges[dim][0] - origin, domain_ranges[dim][1] - origin - 1, 1) + for dim, origin in zip(data.gt_type.dims, data.origin, strict=True) + ) + else: + assert len(domain) == 0 + return dace.subsets.Range.from_string("0") + + +@dataclasses.dataclass(frozen=True) +class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): + """Provides translation capability from a GTIR program to a DaCe SDFG. 
+ + This class is responsible for translation of `ir.Program`, that is the top level representation + of a GT4Py program as a sequence of `ir.Stmt` (aka statement) expressions. + Each statement is translated to a taskgraph inside a separate state. Statement states are chained + one after the other: concurrency between states should be extracted by means of SDFG analysis. + The translator will extend the SDFG while preserving the property of single exit state: + branching is allowed within the context of one statement, but in that case the statement should + terminate with a join state; the join state will represent the head state for next statement, + from where to continue building the SDFG. + """ + + offset_provider_type: gtx_common.OffsetProviderType + column_axis: Optional[gtx_common.Dimension] + global_symbols: dict[str, ts.DataType] + map_uids: eve.utils.UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") + ) + tasklet_uids: eve.utils.UIDGenerator = dataclasses.field( + init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="tlet") + ) + + def get_offset_provider_type(self, offset: str) -> gtx_common.OffsetProviderTypeElem: + return self.offset_provider_type[offset] + + def make_field( + self, + data_node: dace.nodes.AccessNode, + data_type: ts.FieldType | ts.ScalarType, + ) -> gtir_builtin_translators.FieldopData: + """ + Helper method to build the field data type associated with a data access node. + + In case of `ScalarType` data, the `FieldopData` is constructed with `origin=None`. + In case of `FieldType` data, the field origin is added to the data descriptor. + Besides, if the `FieldType` contains a local dimension, the descriptor is converted + to a canonical form where the field domain consists of all global dimensions + (the grid axes) and the field data type is `ListType`, with `offset_type` equal + to the field local dimension. 
+ + TODO(edoapo): consider refactoring this method and moving it to a type module + close to the `FieldopData` type declaration. + + Args: + data_node: The access node to the SDFG data storage. + data_type: The GT4Py data descriptor, which can either come from a field parameter + of an expression node, or from an intermediate field in a previous expression. + + Returns: + The descriptor associated with the SDFG data storage, filled with field origin. + """ + if isinstance(data_type, ts.ScalarType): + return gtir_builtin_translators.FieldopData(data_node, data_type, origin=()) + local_dims = [dim for dim in data_type.dims if dim.kind == gtx_common.DimensionKind.LOCAL] + if len(local_dims) == 0: + # do nothing: the field domain consists of all global dimensions + field_type = data_type + elif len(local_dims) == 1: + local_dim = local_dims[0] + local_dim_index = data_type.dims.index(local_dim) + # the local dimension is converted into `ListType` data element + if not isinstance(data_type.dtype, ts.ScalarType): + raise ValueError(f"Invalid field type {data_type}.") + if local_dim_index != (len(data_type.dims) - 1): + raise ValueError( + f"Invalid field domain: expected the local dimension to be at the end, found at position {local_dim_index}." + ) + if local_dim.value not in self.offset_provider_type: + raise ValueError( + f"The provided local dimension {local_dim} does not match any offset provider type." + ) + local_type = ts.ListType(element_type=data_type.dtype, offset_type=local_dim) + field_type = ts.FieldType(dims=data_type.dims[:local_dim_index], dtype=local_type) + else: + raise NotImplementedError( + "Fields with more than one local dimension are not supported." 
+ ) + field_origin = tuple( + dace.symbolic.pystr_to_symbolic(gtx_dace_utils.range_start_symbol(data_node.data, axis)) + for axis in range(len(field_type.dims)) + ) + return gtir_builtin_translators.FieldopData(data_node, field_type, field_origin) + + def get_symbol_type(self, symbol_name: str) -> ts.DataType: + return self.global_symbols[symbol_name] + + def is_column_axis(self, dim: gtx_common.Dimension) -> bool: + assert self.column_axis + return dim == self.column_axis + + def setup_nested_context( + self, + expr: gtir.Expr, + sdfg: dace.SDFG, + global_symbols: dict[str, ts.DataType], + ) -> SDFGBuilder: + nsdfg_builder = GTIRToSDFG(self.offset_provider_type, self.column_axis, global_symbols) + nsdfg_params = [ + gtir.Sym(id=p_name, type=p_type) for p_name, p_type in global_symbols.items() + ] + domain_symbols = _collect_symbols_in_domain_expressions(expr, nsdfg_params) + nsdfg_builder._add_sdfg_params( + sdfg, node_params=nsdfg_params, symbolic_arguments=domain_symbols + ) + return nsdfg_builder + + def unique_nsdfg_name(self, sdfg: dace.SDFG, prefix: str) -> str: + nsdfg_list = [ + nsdfg.label for nsdfg in sdfg.all_sdfgs_recursive() if nsdfg.label.startswith(prefix) + ] + return f"{prefix}_{len(nsdfg_list)}" + + def unique_map_name(self, name: str) -> str: + return f"{self.map_uids.sequential_id()}_{name}" + + def unique_tasklet_name(self, name: str) -> str: + return f"{self.tasklet_uids.sequential_id()}_{name}" + + def _make_array_shape_and_strides( + self, name: str, dims: Sequence[gtx_common.Dimension] + ) -> tuple[list[dace.symbolic.SymbolicType], list[dace.symbolic.SymbolicType]]: + """ + Parse field dimensions and allocate symbols for array shape and strides. + + For local dimensions, the size is known at compile-time and therefore + the corresponding array shape dimension is set to an integer literal value. + + This method is only called for non-transient arrays, which require symbolic + memory layout. 
The memory layout of transient arrays, used for temporary + fields, is left to the DaCe default (row major, not necessarily the optimal + one) and might be changed during optimization. + + Returns: + Two lists of symbols, one for the shape and the other for the strides of the array. + """ + neighbor_table_types = gtx_dace_utils.filter_connectivity_types(self.offset_provider_type) + shape = [] + for i, dim in enumerate(dims): + if dim.kind == gtx_common.DimensionKind.LOCAL: + # for local dimension, the size is taken from the associated connectivity type + shape.append(neighbor_table_types[dim.value].max_neighbors) + elif gtx_dace_utils.is_connectivity_identifier(name, self.offset_provider_type): + # we use symbolic size for the global dimension of a connectivity + shape.append( + dace.symbolic.pystr_to_symbolic(gtx_dace_utils.field_size_symbol_name(name, i)) + ) + else: + # the size of global dimensions for a regular field is the symbolic + # expression of domain range 'stop - start' + shape.append( + dace.symbolic.pystr_to_symbolic( + "{} - {}".format( + gtx_dace_utils.range_stop_symbol(name, i), + gtx_dace_utils.range_start_symbol(name, i), + ) + ) + ) + strides = [ + dace.symbolic.pystr_to_symbolic(gtx_dace_utils.field_stride_symbol_name(name, i)) + for i in range(len(dims)) + ] + return shape, strides + + def _add_storage( + self, + sdfg: dace.SDFG, + symbolic_arguments: set[str], + name: str, + gt_type: ts.DataType, + transient: bool = True, + ) -> list[tuple[str, ts.DataType]]: + """ + Add storage in the SDFG for a given GT4Py data symbol. + + GT4Py fields are allocated as DaCe arrays. GT4Py scalars are represented + as DaCe scalar objects in the SDFG; the exception are the symbols passed as + `symbolic_arguments`, e.g. symbols used in domain expressions, and those used + for symbolic array shape and strides. 
+ + The fields used as temporary arrays, when `transient = True`, are allocated + and exist only within the SDFG; when `transient = False`, the fields have + to be allocated outside and have to be passed as arguments to the SDFG call. + + Args: + sdfg: The SDFG where storage needs to be allocated. + symbolic_arguments: Set of GT4Py scalars that must be represented as SDFG symbols. + name: Symbol Name to be allocated. + gt_type: GT4Py symbol type. + transient: True when the data symbol has to be allocated as internal storage. + + Returns: + List of tuples '(data_name, gt_type)' where 'data_name' is the name of + the data container used as storage in the SDFG and 'gt_type' is the + corresponding GT4Py type. In case the storage has to be allocated for + a tuple symbol the list contains a flattened version of the tuple, + otherwise the list will contain a single entry. + """ + if isinstance(gt_type, ts.TupleType): + tuple_fields = [] + for sym in gtir_sdfg_utils.flatten_tuple_fields(name, gt_type): + assert isinstance(sym.type, ts.DataType) + tuple_fields.extend( + self._add_storage(sdfg, symbolic_arguments, sym.id, sym.type, transient) + ) + return tuple_fields + + elif isinstance(gt_type, ts.FieldType): + if len(gt_type.dims) == 0: + # represent zero-dimensional fields as scalar arguments + return self._add_storage(sdfg, symbolic_arguments, name, gt_type.dtype, transient) + if not isinstance(gt_type.dtype, ts.ScalarType): + raise ValueError(f"Field type '{gt_type.dtype}' not supported.") + # handle default case: field with one or more dimensions + dc_dtype = gtx_dace_utils.as_dace_type(gt_type.dtype) + # Use symbolic shape, which allows to invoke the program with fields of different size; + # and symbolic strides, which enables decoupling the memory layout from generated code. 
+ sym_shape, sym_strides = self._make_array_shape_and_strides(name, gt_type.dims) + sdfg.add_array(name, sym_shape, dc_dtype, strides=sym_strides, transient=transient) + return [(name, gt_type)] + + elif isinstance(gt_type, ts.ScalarType): + dc_dtype = gtx_dace_utils.as_dace_type(gt_type) + if gtx_dace_utils.is_field_symbol(name) or name in symbolic_arguments: + if name in sdfg.symbols: + # Sometimes, when the field domain is implicitly derived from the + # field domain, the gt4py lowering adds the field size as a scalar + # argument to the program IR. Suppose a field '__sym', then gt4py + # will add '__sym_size_0'. + # Therefore, here we check whether the shape symbol was already + # created by `_make_array_shape_and_strides()`, when allocating + # storage for field arguments. We assume that the scalar argument + # for field size, if present, always follows the field argument. + assert gtx_dace_utils.is_field_symbol(name) + if sdfg.symbols[name].dtype != dc_dtype: + raise ValueError( + f"Type mismatch on argument {name}: got {dc_dtype}, expected {sdfg.symbols[name].dtype}." + ) + else: + sdfg.add_symbol(name, dc_dtype) + else: + sdfg.add_scalar(name, dc_dtype, transient=transient) + + return [(name, gt_type)] + + raise RuntimeError(f"Data type '{type(gt_type)}' not supported.") + + def _add_storage_for_temporary(self, temp_decl: gtir.Temporary) -> dict[str, str]: + """ + Add temporary storage (aka transient) for data containers used as GTIR temporaries. + + Assume all temporaries to be fields, therefore represented as dace arrays. + """ + raise NotImplementedError("Temporaries not supported yet by GTIR DaCe backend.") + + def _visit_expression( + self, node: gtir.Expr, sdfg: dace.SDFG, head_state: dace.SDFGState, use_temp: bool = True + ) -> list[gtir_builtin_translators.FieldopData]: + """ + Specialized visit method for fieldview expressions. + + This method represents the entry point to visit `ir.Stmt` expressions. 
+ As such, it must preserve the property of single exit state in the SDFG. + + Returns: + A list of array nodes containing the result fields. + """ + result = self.visit(node, sdfg=sdfg, head_state=head_state) + + # sanity check: each statement should preserve the property of single exit state (aka head state), + # i.e. eventually only introduce internal branches, and keep the same head state + sink_states = sdfg.sink_nodes() + assert len(sink_states) == 1 + assert sink_states[0] == head_state + + def make_temps( + field: gtir_builtin_translators.FieldopData, + ) -> gtir_builtin_translators.FieldopData: + desc = sdfg.arrays[field.dc_node.data] + if desc.transient or not use_temp: + return field + else: + temp, _ = self.add_temp_array_like(sdfg, desc) + temp_node = head_state.add_access(temp) + head_state.add_nedge( + field.dc_node, temp_node, sdfg.make_array_memlet(field.dc_node.data) + ) + return gtir_builtin_translators.FieldopData(temp_node, field.gt_type, field.origin) + + temp_result = gtx_utils.tree_map(make_temps)(result) + return list(gtx_utils.flatten_nested_tuple((temp_result,))) + + def _add_sdfg_params( + self, + sdfg: dace.SDFG, + node_params: Sequence[gtir.Sym], + symbolic_arguments: set[str], + ) -> list[str]: + """ + Helper function to add storage for node parameters and connectivity tables. + + GT4Py field arguments will be translated to `dace.data.Array` objects. + GT4Py scalar arguments will be translated to `dace.data.Scalar` objects, + except when they are listed in 'symbolic_arguments', in which case they + will be represented in the SDFG as DaCe symbols. 
+ """ + + # add non-transient arrays and/or SDFG symbols for the program arguments + sdfg_args = [] + for param in node_params: + pname = str(param.id) + assert isinstance(param.type, (ts.DataType)) + sdfg_args += self._add_storage( + sdfg, symbolic_arguments, pname, param.type, transient=False + ) + + # add SDFG storage for connectivity tables + for offset, connectivity_type in gtx_dace_utils.filter_connectivity_types( + self.offset_provider_type + ).items(): + scalar_type = tt.from_dtype(connectivity_type.dtype) + gt_type = ts.FieldType( + [connectivity_type.source_dim, connectivity_type.neighbor_dim], scalar_type + ) + # We store all connectivity tables as transient arrays here; later, while building + # the field operator expressions, we change to non-transient (i.e. allocated externally) + # the tables that are actually used. This way, we avoid adding SDFG arguments for + # the connectivity tables that are not used. The remaining unused transient arrays + # are removed by the dace simplify pass. + self._add_storage( + sdfg, + symbolic_arguments, + gtx_dace_utils.connectivity_identifier(offset), + gt_type, + ) + + # the list of all sdfg arguments (aka non-transient arrays) which include tuple-element fields + return [arg_name for arg_name, _ in sdfg_args] + + def visit_Program(self, node: gtir.Program) -> dace.SDFG: + """Translates `ir.Program` to `dace.SDFG`. + + First, it will allocate field and scalar storage for global data. The storage + represents global data, available everywhere in the SDFG, either containing + external data (aka non-transient data) or temporary data (aka transient data). + The temporary data is global, therefore available everywhere in the SDFG + but not outside. Then, all statements are translated, one after the other. 
+ """ + sdfg = dace.SDFG(node.id) + sdfg.debuginfo = gtir_sdfg_utils.debug_info(node) + + # start block of the stateful graph + entry_state = sdfg.add_state("program_entry", is_start_block=True) + + # declarations of temporaries result in transient array definitions in the SDFG + if node.declarations: + temp_symbols: dict[str, str] = {} + for decl in node.declarations: + temp_symbols |= self._add_storage_for_temporary(decl) + + # define symbols for shape and offsets of temporary arrays as interstate edge symbols + head_state = sdfg.add_state_after(entry_state, "init_temps", assignments=temp_symbols) + else: + head_state = entry_state + + domain_symbols = _collect_symbols_in_domain_expressions(node, node.params) + sdfg_arg_names = self._add_sdfg_params(sdfg, node.params, symbolic_arguments=domain_symbols) + + # visit one statement at a time and expand the SDFG from the current head state + for i, stmt in enumerate(node.body): + # include `debuginfo` only for `ir.Program` and `ir.Stmt` nodes: finer granularity would be too messy + head_state = sdfg.add_state_after(head_state, f"stmt_{i}") + head_state._debuginfo = gtir_sdfg_utils.debug_info(stmt, default=sdfg.debuginfo) + head_state = self.visit(stmt, sdfg=sdfg, state=head_state) + + # remove unused connectivity tables (by design, arrays are marked as non-transient when they are used) + for nsdfg in sdfg.all_sdfgs_recursive(): + unused_connectivities = [ + data + for data, datadesc in nsdfg.arrays.items() + if gtx_dace_utils.is_connectivity_identifier(data, self.offset_provider_type) + and datadesc.transient + ] + for data in unused_connectivities: + assert isinstance(nsdfg.arrays[data], dace.data.Array) + nsdfg.arrays.pop(data) + + # Create the call signature for the SDFG. + # Only the arguments required by the GT4Py program, i.e. `node.params`, are added + # as positional arguments. 
The implicit arguments, such as the offset providers or + # the arguments created by the translation process, must be passed as keyword arguments. + sdfg.arg_names = sdfg_arg_names + + sdfg.validate() + return sdfg + + def visit_SetAt( + self, stmt: gtir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState + ) -> dace.SDFGState: + """Visits a `SetAt` statement expression and writes the local result to some external storage. + + Each statement expression results in some sort of dataflow gragh writing to temporary storage. + The translation of `SetAt` ensures that the result is written back to the target external storage. + + Returns: + The SDFG head state, eventually updated if the target write requires a new state. + """ + + source_fields = self._visit_expression(stmt.expr, sdfg, state) + + # the target expression could be a `SymRef` to an output node or a `make_tuple` expression + # in case the statement returns more than one field + target_fields = self._visit_expression(stmt.target, sdfg, state, use_temp=False) + + # visit the domain expression + domain = gtir_builtin_translators.extract_domain(stmt.domain) + + expr_input_args = { + sym_id + for sym in eve.walk_values(stmt.expr).if_isinstance(gtir.SymRef) + if (sym_id := str(sym.id)) in sdfg.arrays + } + state_input_data = { + node.data + for node in state.data_nodes() + if node.data in expr_input_args and state.degree(node) != 0 + } + + target_state: Optional[dace.SDFGState] = None + for source, target in zip(source_fields, target_fields, strict=True): + target_desc = sdfg.arrays[target.dc_node.data] + assert not target_desc.transient + + assert source.gt_type == target.gt_type + source_subset = _make_access_index_for_field(domain, source) + target_subset = _make_access_index_for_field(domain, target) + + if target.dc_node.data in state_input_data: + # if inout argument, write the result in separate next state + # this is needed to avoid undefined behavior for expressions like: X, Y = X + 1, X + if not target_state: + 
target_state = sdfg.add_state_after(state, f"post_{state.label}") + # create new access nodes in the target state + target_state.add_nedge( + target_state.add_access(source.dc_node.data), + target_state.add_access(target.dc_node.data), + dace.Memlet( + data=target.dc_node.data, subset=target_subset, other_subset=source_subset + ), + ) + # remove isolated access node + state.remove_node(target.dc_node) + else: + state.add_nedge( + source.dc_node, + target.dc_node, + dace.Memlet( + data=target.dc_node.data, subset=target_subset, other_subset=source_subset + ), + ) + + return target_state or state + + def visit_FunCall( + self, + node: gtir.FunCall, + sdfg: dace.SDFG, + head_state: dace.SDFGState, + ) -> gtir_builtin_translators.FieldopResult: + # use specialized dataflow builder classes for each builtin function + if cpm.is_call_to(node, "if_"): + return gtir_builtin_translators.translate_if(node, sdfg, head_state, self) + elif cpm.is_call_to(node, "index"): + return gtir_builtin_translators.translate_index(node, sdfg, head_state, self) + elif cpm.is_call_to(node, "make_tuple"): + return gtir_builtin_translators.translate_make_tuple(node, sdfg, head_state, self) + elif cpm.is_call_to(node, "tuple_get"): + return gtir_builtin_translators.translate_tuple_get(node, sdfg, head_state, self) + elif cpm.is_applied_as_fieldop(node): + return gtir_builtin_translators.translate_as_fieldop(node, sdfg, head_state, self) + elif isinstance(node.fun, gtir.Lambda): + lambda_args = [ + self.visit( + arg, + sdfg=sdfg, + head_state=head_state, + ) + for arg in node.args + ] + + return self.visit( + node.fun, + sdfg=sdfg, + head_state=head_state, + args=lambda_args, + ) + elif isinstance(node.type, ts.ScalarType): + return gtir_builtin_translators.translate_scalar_expr(node, sdfg, head_state, self) + else: + raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") + + def visit_Lambda( + self, + node: gtir.Lambda, + sdfg: dace.SDFG, + head_state: dace.SDFGState, + args: 
Sequence[gtir_builtin_translators.FieldopResult], + ) -> gtir_builtin_translators.FieldopResult: + """ + Translates a `Lambda` node to a nested SDFG in the current state. + + All arguments to lambda functions are fields (i.e. `as_fieldop`, field or scalar `gtir.SymRef`, + nested let-lambdas thereof). The reason for creating a nested SDFG is to define local symbols + (the lambda paremeters) that map to parent fields, either program arguments or temporary fields. + + If the lambda has a parameter whose name is already present in `GTIRToSDFG.global_symbols`, + i.e. a lambda parameter with the same name as a symbol in scope, the parameter will shadow + the previous symbol during traversal of the lambda expression. + """ + lambda_arg_nodes = dict( + itertools.chain( + *[ + gtir_builtin_translators.flatten_tuples(psym.id, arg) + for psym, arg in zip(node.params, args, strict=True) + ] + ) + ) + + # inherit symbols from parent scope but eventually override with local symbols + lambda_symbols = { + sym: self.global_symbols[sym] + for sym in symbol_ref_utils.collect_symbol_refs(node.expr, self.global_symbols.keys()) + } | { + psym.id: gtir_builtin_translators.get_tuple_type(arg) + if isinstance(arg, tuple) + else arg.gt_type + for psym, arg in zip(node.params, args, strict=True) + } + + # lower let-statement lambda node as a nested SDFG + nsdfg = dace.SDFG(name=self.unique_nsdfg_name(sdfg, "lambda")) + nsdfg.debuginfo = gtir_sdfg_utils.debug_info(node, default=sdfg.debuginfo) + lambda_translator = self.setup_nested_context(node.expr, nsdfg, lambda_symbols) + + nstate = nsdfg.add_state("lambda") + lambda_result = lambda_translator.visit( + node.expr, + sdfg=nsdfg, + head_state=nstate, + ) + + # Process lambda inputs + # + # All input arguments are passed as parameters to the nested SDFG, therefore + # we they are stored as non-transient array and scalar objects. 
+ # + connectivity_arrays = { + gtx_dace_utils.connectivity_identifier(offset) + for offset in gtx_dace_utils.filter_connectivity_types(self.offset_provider_type) + } + + input_memlets = {} + for nsdfg_dataname, nsdfg_datadesc in nsdfg.arrays.items(): + if nsdfg_datadesc.transient: + continue + + if nsdfg_dataname in lambda_arg_nodes: + src_node = lambda_arg_nodes[nsdfg_dataname].dc_node + dataname = src_node.data + datadesc = src_node.desc(sdfg) + else: + dataname = nsdfg_dataname + datadesc = sdfg.arrays[nsdfg_dataname] + + # ensure that connectivity tables are non-transient arrays in parent SDFG + if dataname in connectivity_arrays: + datadesc.transient = False + + input_memlets[nsdfg_dataname] = sdfg.make_array_memlet(dataname) + + # Process lambda outputs + # + # The output arguments do not really exist, so they are not allocated before + # visiting the lambda expression. Therefore, the result appears inside the + # nested SDFG as transient array/scalar storage. The exception is given by + # input arguments that are just passed through and returned by the lambda, + # e.g. when the lambda is constructing a tuple: in this case, the result + # data is non-transient, because it corresponds to an input node. + # The transient storage of the lambda result in nested-SDFG is corrected + # below by the call to `make_temps()`: this function ensures that the result + # transient nodes are changed to non-transient and the corresponding output + # connecters on the nested SDFG are connected to new data nodes in parent SDFG. + # + lambda_output_data: Iterable[gtir_builtin_translators.FieldopData] = ( + gtx_utils.flatten_nested_tuple(lambda_result) + ) + # The output connectors only need to be setup for the actual result of the + # internal dataflow that writes to transient nodes. + # We filter out the non-transient nodes because they are already available + # in the current context. 
Later these nodes will eventually be removed + # from the nested SDFG because they are isolated (see `make_temps()`). + lambda_outputs = { + output_data.dc_node.data + for output_data in lambda_output_data + if output_data.dc_node.desc(nsdfg).transient + } + + # map free symbols to parent SDFG + nsdfg_symbols_mapping = {str(sym): sym for sym in nsdfg.free_symbols} + for sym, arg in zip(node.params, args, strict=True): + nsdfg_symbols_mapping |= gtir_builtin_translators.get_arg_symbol_mapping( + sym.id, arg, sdfg + ) + + nsdfg_node = head_state.add_nested_sdfg( + nsdfg, + parent=sdfg, + inputs=set(input_memlets.keys()), + outputs=lambda_outputs, + symbol_mapping=nsdfg_symbols_mapping, + debuginfo=gtir_sdfg_utils.debug_info(node, default=sdfg.debuginfo), + ) + + for connector, memlet in input_memlets.items(): + if connector in lambda_arg_nodes: + src_node = lambda_arg_nodes[connector].dc_node + else: + src_node = head_state.add_access(memlet.data) + + head_state.add_edge(src_node, None, nsdfg_node, connector, memlet) + + def construct_output_for_nested_sdfg( + inner_data: gtir_builtin_translators.FieldopData, + ) -> gtir_builtin_translators.FieldopData: + """ + This function makes a data container that lives inside a nested SDFG, denoted by `inner_data`, + available in the parent SDFG. + In order to achieve this, the data container inside the nested SDFG is marked as non-transient + (in other words, externally allocated - a requirement of the SDFG IR) and a new data container + is created within the parent SDFG, with the same properties (shape, stride, etc.) of `inner_data` + but appropriatly remapped using the symbol mapping table. + For lambda arguments that are simply returned by the lambda, the `inner_data` was already mapped + to a parent SDFG data container, therefore it can be directly accessed in the parent SDFG. 
+ The same happens to symbols available in the lambda context but not explicitly passed as lambda + arguments, that are simply returned by the lambda: it can be directly accessed in the parent SDFG. + """ + inner_desc = inner_data.dc_node.desc(nsdfg) + inner_dataname = inner_data.dc_node.data + if inner_desc.transient: + # Transient data nodes only exist within the nested SDFG. In order to return some result data, + # the corresponding data container inside the nested SDFG has to be changed to non-transient, + # that is externally allocated, as required by the SDFG IR. An output edge will write the result + # from the nested-SDFG to a new intermediate data container allocated in the parent SDFG. + outer_data = inner_data.map_to_parent_sdfg( + self, nsdfg, sdfg, head_state, nsdfg_symbols_mapping + ) + head_state.add_edge( + nsdfg_node, + inner_dataname, + outer_data.dc_node, + None, + sdfg.make_array_memlet(outer_data.dc_node.data), + ) + elif inner_dataname in lambda_arg_nodes: + # This if branch and the next one handle the non-transient result nodes. + # Non-transient nodes are just input nodes that are immediately returned + # by the lambda expression. Therefore, these nodes are already available + # in the parent context and can be directly accessed there. + outer_data = lambda_arg_nodes[inner_dataname] + else: + # This must be a symbol captured from the lambda parent scope. + outer_node = head_state.add_access(inner_dataname) + outer_data = gtir_builtin_translators.FieldopData( + outer_node, inner_data.gt_type, inner_data.origin + ) + # Isolated access node will make validation fail. + # Isolated access nodes can be found in the join-state of an if-expression + # or in lambda expressions that just construct tuples from input arguments. 
+ if nstate.degree(inner_data.dc_node) == 0: + nstate.remove_node(inner_data.dc_node) + return outer_data + + return gtx_utils.tree_map(construct_output_for_nested_sdfg)(lambda_result) + + def visit_Literal( + self, + node: gtir.Literal, + sdfg: dace.SDFG, + head_state: dace.SDFGState, + ) -> gtir_builtin_translators.FieldopResult: + return gtir_builtin_translators.translate_literal(node, sdfg, head_state, self) + + def visit_SymRef( + self, + node: gtir.SymRef, + sdfg: dace.SDFG, + head_state: dace.SDFGState, + ) -> gtir_builtin_translators.FieldopResult: + return gtir_builtin_translators.translate_symbol_ref(node, sdfg, head_state, self) + + +def _remove_field_origin_symbols(ir: gtir.Program, sdfg: dace.SDFG) -> None: + """ + Helper function to remove the origin symbols used in program field arguments, + that is only for non-transient data descriptors in the top-level SDFG. + The start symbol of field domain range is set to constant value 0, thus removing + the corresponding free symbol. These values are propagated to all nested SDFGs. + + This function is only used by `build_sdfg_from_gtir()` when the option flag + `disable_field_origin_on_program_arguments` is set to True. 
+ """ + + # collect symbols used as range start for all program arguments + range_start_symbols: dict[str, dace.symbolic.SymExpr] = {} + for p in ir.params: + if isinstance(p.type, ts.TupleType): + psymbols = [ + sym + for sym in gtir_sdfg_utils.flatten_tuple_fields(p.id, p.type) + if isinstance(sym.type, ts.FieldType) + ] + elif isinstance(p.type, ts.FieldType): + psymbols = [p] + else: + psymbols = [] + for psymbol in psymbols: + assert isinstance(psymbol.type, ts.FieldType) + if len(psymbol.type.dims) == 0: + # zero-dimensional field + continue + dataname = str(psymbol.id) + # set all range start symbols to constant value 0 + range_start_symbols |= { + gtx_dace_utils.range_start_symbol(dataname, i): 0 + for i in range(len(psymbol.type.dims)) + } + # we set all range start symbols to 0 in the top-level SDFG and proagate them to nested SDFGs + gtx_transformations.gt_substitute_compiletime_symbols(sdfg, range_start_symbols, validate=True) + + +def build_sdfg_from_gtir( + ir: gtir.Program, + offset_provider_type: gtx_common.OffsetProviderType, + column_axis: Optional[gtx_common.Dimension] = None, + disable_field_origin_on_program_arguments: bool = False, +) -> dace.SDFG: + """ + Receives a GTIR program and lowers it to a DaCe SDFG. + + The lowering to SDFG requires that the program node is type-annotated, therefore this function + runs type ineference as first step. + + Args: + ir: The GTIR program node to be lowered to SDFG + offset_provider_type: The definitions of offset providers used by the program node + column_axis: Vertical dimension used for column scan expressions. 
+ disable_field_origin_on_program_arguments: When True, the field range in all dimensions is assumed to start from 0 + + Returns: + An SDFG in the DaCe canonical form (simplified) + """ + + if ir.function_definitions: + raise NotImplementedError("Functions expected to be inlined as lambda calls.") + + ir = gtir_type_inference.infer(ir, offset_provider_type=offset_provider_type) + ir = ir_prune_casts.PruneCasts().visit(ir) + + # DaCe requires C-compatible strings for the names of data containers, + # such as arrays and scalars. GT4Py uses a unicode symbols ('ᐞ') as name + # separator in the SSA pass, which generates invalid symbols for DaCe. + # Here we find new names for invalid symbols present in the IR. + ir = gtir_sdfg_utils.replace_invalid_symbols(ir) + + global_symbols = {str(p.id): p.type for p in ir.params if isinstance(p.type, ts.DataType)} + sdfg_genenerator = GTIRToSDFG(offset_provider_type, column_axis, global_symbols) + sdfg = sdfg_genenerator.visit(ir) + assert isinstance(sdfg, dace.SDFG) + + if disable_field_origin_on_program_arguments: + _remove_field_origin_symbols(ir, sdfg) + + return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace/gtir_sdfg_utils.py b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg_utils.py new file mode 100644 index 0000000000..9a27cad21c --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/gtir_sdfg_utils.py @@ -0,0 +1,129 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from typing import Dict, Optional, TypeVar + +import dace + +from gt4py import eve +from gt4py.next import common as gtx_common, utils as gtx_utils +from gt4py.next.iterator import ir as gtir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.type_system import type_specifications as ts + + +def debug_info( + node: gtir.Node, *, default: Optional[dace.dtypes.DebugInfo] = None +) -> Optional[dace.dtypes.DebugInfo]: + """Include the GT4Py node location as debug information in the corresponding SDFG nodes.""" + location = node.location + if location: + return dace.dtypes.DebugInfo( + start_line=location.line, + start_column=location.column if location.column else 0, + end_line=location.end_line if location.end_line else -1, + end_column=location.end_column if location.end_column else 0, + filename=location.filename, + ) + return default + + +def get_map_variable(dim: gtx_common.Dimension) -> str: + """ + Format map variable name based on the naming convention for application-specific SDFG transformations. + """ + suffix = "dim" if dim.kind == gtx_common.DimensionKind.LOCAL else "" + return f"i_{dim.value}_gtx_{dim.kind}{suffix}" + + +def make_symbol_tree(tuple_name: str, tuple_type: ts.TupleType) -> tuple[gtir.Sym, ...]: + """ + Creates a tree representation of the symbols corresponding to the tuple fields. + The constructed tree preserves the nested nature of the tuple type, if any. + + Examples + -------- + >>> sty = ts.ScalarType(kind=ts.ScalarKind.INT32) + >>> fty = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) + >>> t = ts.TupleType(types=[sty, ts.TupleType(types=[fty, sty])]) + >>> assert make_symbol_tree("a", t) == ( + ... im.sym("a_0", sty), + ... (im.sym("a_1_0", fty), im.sym("a_1_1", sty)), + ... 
) + """ + assert all(isinstance(t, ts.DataType) for t in tuple_type.types) + fields = [(f"{tuple_name}_{i}", field_type) for i, field_type in enumerate(tuple_type.types)] + return tuple( + make_symbol_tree(field_name, field_type) # type: ignore[misc] + if isinstance(field_type, ts.TupleType) + else im.sym(field_name, field_type) + for field_name, field_type in fields + ) + + +def flatten_tuple_fields(tuple_name: str, tuple_type: ts.TupleType) -> list[gtir.Sym]: + """ + Creates a list of symbols, annotated with the data type, for all elements of the given tuple. + + Examples + -------- + >>> sty = ts.ScalarType(kind=ts.ScalarKind.INT32) + >>> fty = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) + >>> t = ts.TupleType(types=[sty, ts.TupleType(types=[fty, sty])]) + >>> assert flatten_tuple_fields("a", t) == [ + ... im.sym("a_0", sty), + ... im.sym("a_1_0", fty), + ... im.sym("a_1_1", sty), + ... ] + """ + symbol_tree = make_symbol_tree(tuple_name, tuple_type) + return list(gtx_utils.flatten_nested_tuple(symbol_tree)) + + +def replace_invalid_symbols(ir: gtir.Program) -> gtir.Program: + """ + Ensure that all symbols used in the program IR are valid strings (e.g. no unicode-strings). + + If any invalid symbol present, this function returns a copy of the input IR where + the invalid symbols have been replaced with new names. If all symbols are valid, + the input IR is returned without copying it. 
+ """ + + class ReplaceSymbols(eve.PreserveLocationVisitor, eve.NodeTranslator): + T = TypeVar("T", gtir.Sym, gtir.SymRef) + + def _replace_sym(self, node: T, symtable: Dict[str, str]) -> T: + sym = str(node.id) + return type(node)(id=symtable.get(sym, sym), type=node.type) + + def visit_Sym(self, node: gtir.Sym, *, symtable: Dict[str, str]) -> gtir.Sym: + return self._replace_sym(node, symtable) + + def visit_SymRef(self, node: gtir.SymRef, *, symtable: Dict[str, str]) -> gtir.SymRef: + return self._replace_sym(node, symtable) + + # program arguments are checked separetely, because they cannot be replaced + if not all(dace.dtypes.validate_name(str(sym.id)) for sym in ir.params): + raise ValueError("Invalid symbol in program parameters.") + + ir_sym_ids = {str(sym.id) for sym in eve.walk_values(ir).if_isinstance(gtir.Sym).to_set()} + ir_ssa_uuid = eve.utils.UIDGenerator(prefix="gtir_tmp") + + invalid_symbols_mapping = { + sym_id: ir_ssa_uuid.sequential_id() + for sym_id in ir_sym_ids + if not dace.dtypes.validate_name(sym_id) + } + if len(invalid_symbols_mapping) == 0: + return ir + + # assert that the new symbol names are not used in the IR + assert ir_sym_ids.isdisjoint(invalid_symbols_mapping.values()) + return ReplaceSymbols().visit(ir, symtable=invalid_symbols_mapping) diff --git a/src/gt4py/next/program_processors/runners/dace/program.py b/src/gt4py/next/program_processors/runners/dace/program.py new file mode 100644 index 0000000000..a381346a1e --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/program.py @@ -0,0 +1,250 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +import collections +import dataclasses +import itertools +import typing +from typing import Any, ClassVar, Optional, Sequence + +import dace +import numpy as np + +from gt4py.next import backend as next_backend, common +from gt4py.next.ffront import decorator +from gt4py.next.iterator import ir as itir, transforms as itir_transforms +from gt4py.next.iterator.transforms import extractors as extractors +from gt4py.next.otf import arguments, recipes, toolchain +from gt4py.next.program_processors.runners.dace import utils as gtx_dace_utils +from gt4py.next.type_system import type_specifications as ts + + +@dataclasses.dataclass(frozen=True) +class Program(decorator.Program, dace.frontend.python.common.SDFGConvertible): + """Extension of GT4Py Program implementing the SDFGConvertible interface via GTIR.""" + + sdfg_closure_cache: dict[str, Any] = dataclasses.field(default_factory=dict) + # Being a ClassVar ensures that in an SDFG with multiple nested GT4Py Programs, + # there is no name mangling of the connectivity tables used across the nested SDFGs + # since they share the same memory address. + connectivity_tables_data_descriptors: ClassVar[ + dict[str, dace.data.Array] + ] = {} # symbolically defined + + def __sdfg__(self, *args: Any, **kwargs: Any) -> dace.sdfg.sdfg.SDFG: + if (self.backend is None) or "dace" not in self.backend.name.lower(): + raise ValueError("The SDFG can be generated only for the DaCe backend.") + + offset_provider: common.OffsetProvider = { + **(self.connectivities or {}), + **self._implicit_offset_provider, + } + column_axis = kwargs.get("column_axis", None) + + # TODO(ricoh): connectivity tables required here for now. 
+ gtir_stage = typing.cast(next_backend.Transforms, self.backend.transforms).past_to_itir( + toolchain.CompilableProgram( + data=self.past_stage, + args=arguments.CompileTimeArgs( + args=tuple(p.type for p in self.past_stage.past_node.params), + kwargs={}, + column_axis=column_axis, + offset_provider=offset_provider, + ), + ) + ) + program = gtir_stage.data + program = itir_transforms.apply_fieldview_transforms( # run the transforms separately because they require the runtime info + program, offset_provider=offset_provider + ) + object.__setattr__( + gtir_stage, + "data", + program, + ) + object.__setattr__( + gtir_stage.args, "offset_provider", gtir_stage.args.offset_provider_type + ) # TODO(ricoh): currently this is circumventing the frozenness of CompileTimeArgs + # in order to isolate DaCe from the runtime tables in connectivities.offset_provider. + # These are needed at the time of writing for mandatory GTIR passes. + # Remove this as soon as Program does not expect connectivity tables anymore. + + _crosscheck_dace_parsing( + dace_parsed_args=[*args, *kwargs.values()], + gt4py_program_args=[p.type for p in program.params], + ) + + compile_workflow = typing.cast( + recipes.OTFCompileWorkflow, + self.backend.executor + if not hasattr(self.backend.executor, "step") + else self.backend.executor.step, + ) # We know which backend we are using, but we don't know if the compile workflow is cached. + # TODO(ricoh): switch 'disable_itir_transforms=True' because we ran them separately previously + # and so we can ensure the SDFG does not know any runtime info it shouldn't know. Remove with + # the other parts of the workaround when possible. + sdfg = dace.SDFG.from_json( + compile_workflow.translation.replace( + disable_itir_transforms=True, disable_field_origin_on_program_arguments=True + )(gtir_stage).source_code + ) + + self.sdfg_closure_cache["arrays"] = sdfg.arrays + + # Halo exchange related metadata, i.e. 
gt4py_program_input_fields, gt4py_program_output_fields, + # offset_providers_per_input_field. Add them as dynamic attributes to the SDFG + field_params = { + str(param.id): param for param in program.params if isinstance(param.type, ts.FieldType) + } + + def single_horizontal_dim_per_field( + fields: typing.Iterable[itir.Sym], + ) -> typing.Iterator[tuple[str, common.Dimension]]: + for field in fields: + assert isinstance(field.type, ts.FieldType) + horizontal_dims = [ + dim for dim in field.type.dims if dim.kind is common.DimensionKind.HORIZONTAL + ] + # do nothing for fields with multiple horizontal dimensions + # or without horizontal dimensions + # this is only meant for use with unstructured grids + if len(horizontal_dims) == 1: + yield str(field.id), horizontal_dims[0] + + input_fields = ( + field_params[name] for name in extractors.InputNamesExtractor.only_fields(program) + ) + sdfg.gt4py_program_input_fields = dict(single_horizontal_dim_per_field(input_fields)) + + output_fields = ( + field_params[name] for name in extractors.OutputNamesExtractor.only_fields(program) + ) + sdfg.gt4py_program_output_fields = dict(single_horizontal_dim_per_field(output_fields)) + + # TODO (ricoh): bring back sdfg.offset_providers_per_input_field. + # A starting point would be to use the "trace_shifts" pass on GTIR + # and associate the extracted shifts with each input field. + # Analogous to the version in `runners.dace_iterator.__init__`, which + # was removed when merging #1742. + + return sdfg + + def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[str, Any]: + """ + Return the closure arrays of the SDFG represented by this object + as a mapping between array name and the corresponding value. + + The connectivity tables are defined symbolically, i.e. table sizes & strides are DaCe symbols. 
+ The need to define the connectivity tables in the `__sdfg_closure__` arises from the fact that + the offset providers are not part of GT4Py Program's arguments. + Keep in mind, that `__sdfg_closure__` is called after `__sdfg__` method. + """ + closure_dict: dict[str, Any] = {} + + if self.connectivities: + symbols = {} + with_table = [ + name for name, conn in self.connectivities.items() if common.is_neighbor_table(conn) + ] + in_arrays_with_id = [ + (name, conn_id) + for name in with_table + if (conn_id := gtx_dace_utils.connectivity_identifier(name)) + in self.sdfg_closure_cache["arrays"] + ] + in_arrays = (name for name, _ in in_arrays_with_id) + name_axis = list(itertools.product(in_arrays, [0, 1])) + + def size_symbol_name(name: str, axis: int) -> str: + return gtx_dace_utils.field_size_symbol_name( + gtx_dace_utils.connectivity_identifier(name), axis + ) + + connectivity_tables_size_symbols = { + (sname := size_symbol_name(name, axis)): dace.symbol(sname) + for name, axis in name_axis + } + + def stride_symbol_name(name: str, axis: int) -> str: + return gtx_dace_utils.field_stride_symbol_name( + gtx_dace_utils.connectivity_identifier(name), axis + ) + + connectivity_table_stride_symbols = { + (sname := stride_symbol_name(name, axis)): dace.symbol(sname) + for name, axis in name_axis + } + + symbols = connectivity_tables_size_symbols | connectivity_table_stride_symbols + + # Define the storage location (e.g. 
CPU, GPU) of the connectivity tables + if "storage" not in self.connectivity_tables_data_descriptors: + for _, conn_id in in_arrays_with_id: + self.connectivity_tables_data_descriptors["storage"] = self.sdfg_closure_cache[ + "arrays" + ][conn_id].storage + break + + # Build the closure dictionary + for name, conn_id in in_arrays_with_id: + if conn_id not in self.connectivity_tables_data_descriptors: + conn = self.connectivities[name] + assert common.is_neighbor_table(conn) + self.connectivity_tables_data_descriptors[conn_id] = dace.data.Array( + dtype=dace.dtypes.dtype_to_typeclass(conn.dtype.dtype.type), + shape=[ + symbols[gtx_dace_utils.field_size_symbol_name(conn_id, 0)], + symbols[gtx_dace_utils.field_size_symbol_name(conn_id, 1)], + ], + strides=[ + symbols[gtx_dace_utils.field_stride_symbol_name(conn_id, 0)], + symbols[gtx_dace_utils.field_stride_symbol_name(conn_id, 1)], + ], + storage=Program.connectivity_tables_data_descriptors["storage"], + ) + closure_dict[conn_id] = self.connectivity_tables_data_descriptors[conn_id] + + return closure_dict + + def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]: + return [p.id for p in self.past_stage.past_node.params], [] + + +def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: list[Any]) -> None: + for dace_parsed_arg, gt4py_program_arg in zip( + dace_parsed_args, + gt4py_program_args, + strict=False, # dace does not see implicit size args + ): + match dace_parsed_arg: + case dace.data.Scalar(): + assert dace_parsed_arg.dtype == gtx_dace_utils.as_dace_type(gt4py_program_arg) + case bool() | np.bool_(): + assert isinstance(gt4py_program_arg, ts.ScalarType) + assert gt4py_program_arg.kind == ts.ScalarKind.BOOL + case int() | np.integer(): + assert isinstance(gt4py_program_arg, ts.ScalarType) + assert gt4py_program_arg.kind in [ts.ScalarKind.INT32, ts.ScalarKind.INT64] + case float() | np.floating(): + assert isinstance(gt4py_program_arg, ts.ScalarType) + assert 
gt4py_program_arg.kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64] + case str() | np.str_(): + assert isinstance(gt4py_program_arg, ts.ScalarType) + assert gt4py_program_arg.kind == ts.ScalarKind.STRING + case dace.data.Array(): + assert isinstance(gt4py_program_arg, ts.FieldType) + assert isinstance(gt4py_program_arg.dtype, ts.ScalarType) + assert len(dace_parsed_arg.shape) == len(gt4py_program_arg.dims) + assert dace_parsed_arg.dtype == gtx_dace_utils.as_dace_type(gt4py_program_arg.dtype) + case dace.data.Structure() | dict() | collections.OrderedDict(): + # offset provider + pass + case _: + raise ValueError( + f"Unresolved case for {dace_parsed_arg} (==, !=) {gt4py_program_arg}" + ) diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py similarity index 52% rename from src/gt4py/next/program_processors/runners/dace_common/dace_backend.py rename to src/gt4py/next/program_processors/runners/dace/sdfg_callable.py index 6039c82fdb..09720ddf3c 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace/sdfg_callable.py @@ -7,14 +7,15 @@ # SPDX-License-Identifier: BSD-3-Clause import warnings from collections.abc import Mapping, Sequence -from typing import Any, Iterable +from typing import Any, Optional import dace import numpy as np -from gt4py.next import common as gtx_common, utils as gtx_utils +from gt4py._core import definitions as core_defs +from gt4py.next import common as gtx_common -from . import utility as dace_utils +from . 
import utils as gtx_dace_utils try: @@ -23,47 +24,45 @@ cp = None -def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: bool) -> Any: +def _convert_arg(arg: Any) -> tuple[Any, Optional[gtx_common.Domain]]: if not isinstance(arg, gtx_common.Field): - return arg - # field domain offsets are not supported - non_zero_offsets = [ - (dim, dim_range) - for dim, dim_range in zip(arg.domain.dims, arg.domain.ranges, strict=True) - if dim_range.start != 0 - ] - if non_zero_offsets: - dim, dim_range = non_zero_offsets[0] - raise RuntimeError( - f"Field '{sdfg_param}' passed as array slice with offset {dim_range.start} on dimension {dim.value}." - ) - if not use_field_canonical_representation: - return arg.ndarray - # the canonical representation requires alphabetical ordering of the dimensions in field domain definition - sorted_dims = dace_utils.get_sorted_dims(arg.domain.dims) - ndim = len(sorted_dims) - dim_indices = [dim_index for dim_index, _ in sorted_dims] - if isinstance(arg.ndarray, np.ndarray): - return np.moveaxis(arg.ndarray, range(ndim), dim_indices) - else: - assert cp is not None and isinstance(arg.ndarray, cp.ndarray) - return cp.moveaxis(arg.ndarray, range(ndim), dim_indices) - - -def _get_args( - sdfg: dace.SDFG, args: Sequence[Any], use_field_canonical_representation: bool -) -> dict[str, Any]: + return arg, None + if len(arg.domain.dims) == 0: + # Pass zero-dimensional fields as scalars. 
+ return arg.as_scalar(), None + return arg.ndarray, arg.domain + + +def _get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: sdfg_params: Sequence[str] = sdfg.arg_names - flat_args: Iterable[Any] = gtx_utils.flatten_nested_tuple(tuple(args)) - return { - sdfg_param: _convert_arg(arg, sdfg_param, use_field_canonical_representation) - for sdfg_param, arg in zip(sdfg_params, flat_args, strict=True) - } + sdfg_arguments = {} + range_symbols: dict[str, int] = {} + for sdfg_param, arg in zip(sdfg_params, args, strict=True): + sdfg_arg, domain = _convert_arg(arg) + sdfg_arguments[sdfg_param] = sdfg_arg + if domain: + assert gtx_common.Domain.is_finite(domain) + range_symbols |= { + gtx_dace_utils.range_start_symbol(sdfg_param, i): r.start + for i, r in enumerate(domain.ranges) + } + range_symbols |= { + gtx_dace_utils.range_stop_symbol(sdfg_param, i): r.stop + for i, r in enumerate(domain.ranges) + } + # sanity check in case range symbols are passed as explicit program arguments + for range_symbol, value in range_symbols.items(): + if (sdfg_arg := sdfg_arguments.get(range_symbol, None)) is not None: + if sdfg_arg != value: + raise ValueError( + f"Received program argument {range_symbol} with value {sdfg_arg}, expected {value}." 
+ ) + return sdfg_arguments | range_symbols def _ensure_is_on_device( - connectivity_arg: np.typing.NDArray, device: dace.dtypes.DeviceType -) -> np.typing.NDArray: + connectivity_arg: core_defs.NDArrayObject, device: dace.dtypes.DeviceType +) -> core_defs.NDArrayObject: if device == dace.dtypes.DeviceType.GPU: if not isinstance(connectivity_arg, cp.ndarray): warnings.warn( @@ -75,7 +74,7 @@ def _ensure_is_on_device( def _get_shape_args( - arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray] + arrays: Mapping[str, dace.data.Array], args: Mapping[str, core_defs.NDArrayObject] ) -> dict[str, int]: shape_args: dict[str, int] = {} for name, value in args.items(): @@ -84,14 +83,14 @@ def _get_shape_args( assert sym.name not in shape_args shape_args[sym.name] = size elif sym != size: - raise RuntimeError( + raise ValueError( f"Expected shape {arrays[name].shape} for arg {name}, got {value.shape}." ) return shape_args def _get_stride_args( - arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray] + arrays: Mapping[str, dace.data.Array], args: Mapping[str, core_defs.NDArrayObject] ) -> dict[str, int]: stride_args = {} for name, value in args.items(): @@ -103,9 +102,9 @@ def _get_stride_args( ) if isinstance(sym, dace.symbol): assert sym.name not in stride_args - stride_args[str(sym)] = stride + stride_args[sym.name] = stride elif sym != stride: - raise RuntimeError( + raise ValueError( f"Expected stride {arrays[name].strides} for arg {name}, got {value.strides}." ) return stride_args @@ -115,7 +114,7 @@ def get_sdfg_conn_args( sdfg: dace.SDFG, offset_provider: gtx_common.OffsetProvider, on_gpu: bool, -) -> dict[str, np.typing.NDArray]: +) -> dict[str, core_defs.NDArrayObject]: """ Extracts the connectivity tables that are used in the sdfg and ensures that the memory buffers are allocated for the target device. 
@@ -123,46 +122,52 @@ def get_sdfg_conn_args( device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU connectivity_args = {} - for offset, connectivity in dace_utils.filter_connectivities(offset_provider).items(): - assert isinstance(connectivity, gtx_common.NeighborTable) - param = dace_utils.connectivity_identifier(offset) - if param in sdfg.arrays: - connectivity_args[param] = _ensure_is_on_device(connectivity.table, device) + for offset, connectivity in offset_provider.items(): + if gtx_common.is_neighbor_table(connectivity): + param = gtx_dace_utils.connectivity_identifier(offset) + if param in sdfg.arrays: + connectivity_args[param] = _ensure_is_on_device(connectivity.ndarray, device) return connectivity_args def get_sdfg_args( sdfg: dace.SDFG, + offset_provider: gtx_common.OffsetProvider, *args: Any, check_args: bool = False, on_gpu: bool = False, - use_field_canonical_representation: bool = True, - **kwargs: Any, ) -> dict[str, Any]: """Extracts the arguments needed to call the SDFG. - This function can handle the same arguments that are passed to dace runner. + This function can handle the arguments that are passed to the dace runner + and that end up in the decoration stage of the dace backend workflow. Args: sdfg: The SDFG for which we want to get the arguments. + offset_provider: The offset provider. + args: The list of arguments passed to the dace runner. + check_args: If True, return only the arguments that are expected + according to the SDFG signature. + on_gpu: If True, this method ensures that the arrays for the + connectivity tables are allocated in GPU memory. + + Returns: + A dictionary of keyword arguments to be passed in the SDFG call. 
""" - offset_provider = kwargs["offset_provider"] - dace_args = _get_args(sdfg, args, use_field_canonical_representation) + dace_args = _get_args(sdfg, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} + dace_field_strides = _get_stride_args(sdfg.arrays, dace_field_args) dace_conn_args = get_sdfg_conn_args(sdfg, offset_provider, on_gpu) - dace_shapes = _get_shape_args(sdfg.arrays, dace_field_args) dace_conn_shapes = _get_shape_args(sdfg.arrays, dace_conn_args) - dace_strides = _get_stride_args(sdfg.arrays, dace_field_args) dace_conn_strides = _get_stride_args(sdfg.arrays, dace_conn_args) all_args = { **dace_args, **dace_conn_args, - **dace_shapes, **dace_conn_shapes, - **dace_strides, **dace_conn_strides, + **dace_field_strides, } if check_args: diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py new file mode 100644 index 0000000000..6157704857 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py @@ -0,0 +1,87 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Transformation and optimization pipeline for the DaCe backend in GT4Py. + +Please also see [ADR0018](https://github.com/GridTools/gt4py/tree/main/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md) +that explains the general structure and requirements on the SDFGs. 
+""" + +from .auto_optimize import gt_auto_optimize +from .gpu_utils import ( + GPUSetBlockSize, + gt_gpu_transform_non_standard_memlet, + gt_gpu_transformation, + gt_set_gpu_blocksize, +) +from .local_double_buffering import gt_create_local_double_buffering +from .loop_blocking import LoopBlocking +from .map_fusion import MapFusion, MapFusionParallel, MapFusionSerial +from .map_orderer import MapIterationOrder, gt_set_iteration_order +from .map_promoter import SerialMapPromoter +from .redundant_array_removers import ( + CopyChainRemover, + MultiStateGlobalSelfCopyElimination, + SingleStateGlobalSelfCopyElimination, + gt_multi_state_global_self_copy_elimination, + gt_remove_copy_chain, +) +from .simplify import ( + GT_SIMPLIFY_DEFAULT_SKIP_SET, + GT4PyMapBufferElimination, + GT4PyMoveTaskletIntoMap, + gt_inline_nested_sdfg, + gt_reduce_distributed_buffering, + gt_simplify, + gt_substitute_compiletime_symbols, +) +from .strides import ( + gt_change_transient_strides, + gt_map_strides_to_dst_nested_sdfg, + gt_map_strides_to_src_nested_sdfg, + gt_propagate_strides_from_access_node, + gt_propagate_strides_of, +) +from .utils import gt_find_constant_arguments, gt_make_transients_persistent + + +__all__ = [ + "GT_SIMPLIFY_DEFAULT_SKIP_SET", + "CopyChainRemover", + "GPUSetBlockSize", + "GT4PyMapBufferElimination", + "GT4PyMoveTaskletIntoMap", + "LoopBlocking", + "MapFusion", + "MapFusionParallel", + "MapFusionSerial", + "MapIterationOrder", + "MultiStateGlobalSelfCopyElimination", + "SerialMapPromoter", + "SerialMapPromoterGPU", + "SingleStateGlobalSelfCopyElimination", + "gt_auto_optimize", + "gt_change_transient_strides", + "gt_create_local_double_buffering", + "gt_find_constant_arguments", + "gt_gpu_transform_non_standard_memlet", + "gt_gpu_transformation", + "gt_inline_nested_sdfg", + "gt_make_transients_persistent", + "gt_map_strides_to_dst_nested_sdfg", + "gt_map_strides_to_src_nested_sdfg", + "gt_multi_state_global_self_copy_elimination", + 
"gt_propagate_strides_from_access_node", + "gt_propagate_strides_of", + "gt_reduce_distributed_buffering", + "gt_remove_copy_chain", + "gt_set_gpu_blocksize", + "gt_set_iteration_order", + "gt_simplify", + "gt_substitute_compiletime_symbols", +] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py similarity index 65% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py rename to src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 37cc89aa2b..739fe39584 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -8,167 +8,35 @@ """Fast access to the auto optimization on DaCe.""" -from typing import Any, Final, Iterable, Optional, Sequence +from typing import Any, Optional, Sequence, Union import dace -from dace.transformation import dataflow as dace_dataflow, passes as dace_passes +from dace.transformation import dataflow as dace_dataflow from dace.transformation.auto import auto_optimize as dace_aoptimize +from dace.transformation.passes import analysis as dace_analysis from gt4py.next import common as gtx_common -from gt4py.next.program_processors.runners.dace_fieldview import ( - transformations as gtx_transformations, -) - - -GT_SIMPLIFY_DEFAULT_SKIP_SET: Final[set[str]] = {"ScalarToSymbolPromotion", "ConstantPropagation"} -"""Set of simplify passes `gt_simplify()` skips by default. - -The following passes are included: -- `ScalarToSymbolPromotion`: The lowering has sometimes to turn a scalar into a - symbol or vice versa and at a later point to invert this again. However, this - pass has some problems with this pattern so for the time being it is disabled. -- `ConstantPropagation`: Same reasons as `ScalarToSymbolPromotion`. 
-""" - - -def gt_simplify( - sdfg: dace.SDFG, - validate: bool = True, - validate_all: bool = False, - skip: Optional[Iterable[str]] = None, -) -> Any: - """Performs simplifications on the SDFG in place. - - Instead of calling `sdfg.simplify()` directly, you should use this function, - as it is specially tuned for GridTool based SDFGs. - - This function runs the DaCe simplification pass, but the following passes are - replaced: - - `InlineSDFGs`: Instead `gt_inline_nested_sdfg()` will be called. - - Furthermore, by default, or if `None` is passed fro `skip` the passes listed in - `GT_SIMPLIFY_DEFAULT_SKIP_SET` will be skipped. - - Args: - sdfg: The SDFG to optimize. - validate: Perform validation after the pass has run. - validate_all: Perform extensive validation. - skip: List of simplify passes that should not be applied, defaults - to `GT_SIMPLIFY_DEFAULT_SKIP_SET`. - """ - # Ensure that `skip` is a `set` - skip = GT_SIMPLIFY_DEFAULT_SKIP_SET if skip is None else set(skip) - - if "InlineSDFGs" not in skip: - gt_inline_nested_sdfg( - sdfg=sdfg, - multistate=True, - permissive=False, - validate=validate, - validate_all=validate_all, - ) - - return dace_passes.SimplifyPass( - validate=validate, - validate_all=validate_all, - verbose=False, - skip=(skip | {"InlineSDFGs"}), - ).apply_pass(sdfg, {}) - - -def gt_set_iteration_order( - sdfg: dace.SDFG, - leading_dim: gtx_common.Dimension, - validate: bool = True, - validate_all: bool = False, -) -> Any: - """Set the iteration order of the Maps correctly. - - Modifies the order of the Map parameters such that `leading_dim` - is the fastest varying one, the order of the other dimensions in - a Map is unspecific. `leading_dim` should be the dimensions were - the stride is one. - - Args: - sdfg: The SDFG to process. - leading_dim: The leading dimensions. - validate: Perform validation during the steps. - validate_all: Perform extensive validation. 
- """ - return sdfg.apply_transformations_once_everywhere( - gtx_transformations.MapIterationOrder( - leading_dim=leading_dim, - ), - validate=validate, - validate_all=validate_all, - ) - - -def gt_inline_nested_sdfg( - sdfg: dace.SDFG, - multistate: bool = True, - permissive: bool = False, - validate: bool = True, - validate_all: bool = False, -) -> dace.SDFG: - """Perform inlining of nested SDFG into their parent SDFG. - - The function uses DaCe's `InlineSDFG` transformation, the same used in simplify. - However, before the inline transformation is run the function will run some - cleaning passes that allows inlining nested SDFGs. - As a side effect, the function will split stages into more states. - - Args: - sdfg: The SDFG that should be processed, will be modified in place and returned. - multistate: Allow inlining of multistate nested SDFG, defaults to `True`. - permissive: Be less strict on the accepted SDFGs. - validate: Perform validation after the transformation has finished. - validate_all: Performs extensive validation. 
- """ - first_iteration = True - i = 0 - while True: - print(f"ITERATION: {i}") - nb_preproccess = sdfg.apply_transformations_repeated( - [dace_dataflow.PruneSymbols, dace_dataflow.PruneConnectors], - validate=False, - validate_all=validate_all, - ) - if (nb_preproccess == 0) and (not first_iteration): - break - - # Create and configure the inline pass - inline_sdfg = dace_passes.InlineSDFGs() - inline_sdfg.progress = False - inline_sdfg.permissive = permissive - inline_sdfg.multistate = multistate - - # Apply the inline pass - nb_inlines = inline_sdfg.apply_pass(sdfg, {}) - - # Check result, if needed and test if we can stop - if validate_all or validate: - sdfg.validate() - if nb_inlines == 0: - break - first_iteration = False - - return sdfg +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations def gt_auto_optimize( sdfg: dace.SDFG, gpu: bool, - leading_dim: Optional[gtx_common.Dimension] = None, + leading_dim: Optional[ + Union[str, gtx_common.Dimension, list[Union[str, gtx_common.Dimension]]] + ] = None, aggressive_fusion: bool = True, max_optimization_rounds_p2: int = 100, make_persistent: bool = True, gpu_block_size: Optional[Sequence[int | str] | str] = None, blocking_dim: Optional[gtx_common.Dimension] = None, blocking_size: int = 10, + blocking_only_if_independent_nodes: Optional[bool] = None, reuse_transients: bool = False, gpu_launch_bounds: Optional[int | str] = None, gpu_launch_factor: Optional[int] = None, + constant_symbols: Optional[dict[str, Any]] = None, + assume_pointwise: bool = True, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -184,6 +52,9 @@ def gt_auto_optimize( different aspects of the SDFG. The initial SDFG is assumed to have a very large number of rather simple Maps. + Note, because of how `gt_auto_optimizer()` works it is not save to call + it twice on the same SDFG. + 1. Some general simplification transformations, beyond classical simplify, are applied to the SDFG. 2. 
Tries to create larger kernels by fusing smaller ones, see @@ -195,10 +66,11 @@ def gt_auto_optimize( one with stride one. 5. If requested the function will now apply loop blocking, on the dimension indicated by `leading_dim`. - 6. If requested the SDFG will be transformed to GPU. For this the + 6. The strides of temporaries are set to match the compute order. + 7. If requested the SDFG will be transformed to GPU. For this the `gt_gpu_transformation()` function is used, that might apply several other optimizations. - 7. Afterwards some general transformations to the SDFG are applied. + 8. Afterwards some general transformations to the SDFG are applied. This includes: - Use fast implementation for library nodes. - Move small transients to stack. @@ -219,24 +91,37 @@ def gt_auto_optimize( one for all. blocking_dim: On which dimension blocking should be applied. blocking_size: How many elements each block should process. + blocking_only_if_independent_nodes: If `True` only apply loop blocking if + there are independent nodes in the Map, see the `require_independent_nodes` + option of the `LoopBlocking` transformation. reuse_transients: Run the `TransientReuse` transformation, might reduce memory footprint. gpu_launch_bounds: Use this value as `__launch_bounds__` for _all_ GPU Maps. gpu_launch_factor: Use the number of threads times this value as `__launch_bounds__` for _all_ GPU Maps. + constant_symbols: Symbols listed in this `dict` will be replaced by the + respective value inside the SDFG. This might increase performance. + assume_pointwise: Assume that the SDFG has no risk for race condition in + global data access. See the `GT4PyMapBufferElimination` transformation for more. validate: Perform validation during the steps. validate_all: Perform extensive validation. + Note: + For identifying symbols that can be treated as compile time constants + `gt_find_constant_arguments()` function can be used. 
+ Todo: - - Make sure that `SDFG.simplify()` is not called indirectly, by temporarily - overwriting it with `gt_simplify()`. + - Update the description. The Phases are nice, but they have lost their + link to reality a little bit. + - Improve the determination of the strides and iteration order of the + transients. + - Set padding of transients, i.e. alignment, the DaCe datadescriptor + can do that. + - Handle nested SDFGs better. - Specify arguments to set the size of GPU thread blocks depending on the dimensions. I.e. be able to use a different size for 1D than 2D Maps. - - Add a parallel version of Map fusion. - Implement some model to further guide to determine what we want to fuse. Something along the line "Fuse if operational intensity goes up, but not if we have too much internal space (register pressure). - - Create a custom array elimination pass that honors rule 1. - - Check if a pipeline could be used to speed up some computations. """ device = dace.DeviceType.GPU if gpu else dace.DeviceType.CPU @@ -249,17 +134,25 @@ def gt_auto_optimize( # to internal serial maps, such that they do not block fusion? # Phase 1: Initial Cleanup - gt_simplify( + gtx_transformations.gt_simplify( sdfg=sdfg, validate=validate, validate_all=validate_all, ) + gtx_transformations.gt_reduce_distributed_buffering(sdfg) + + if constant_symbols: + gtx_transformations.gt_substitute_compiletime_symbols( + sdfg=sdfg, + repl=constant_symbols, + validate=validate, + validate_all=validate_all, + ) + gtx_transformations.gt_simplify(sdfg) + sdfg.apply_transformations_repeated( [ dace_dataflow.TrivialMapElimination, - # TODO(phimuell): Investigate if these two are appropriate. - dace_dataflow.MapReduceFusion, - dace_dataflow.MapWCRFusion, ], validate=validate, validate_all=validate_all, @@ -275,40 +168,81 @@ def gt_auto_optimize( validate_all=validate_all, ) - # Phase 3: Optimizing the kernels, i.e. the larger maps, themselves. - # Currently this only applies fusion inside Maps. 
+ # After we have created big kernels, we will perform some post cleanup. + gtx_transformations.gt_reduce_distributed_buffering(sdfg) sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion( - only_inner_maps=True, - ), + [ + gtx_transformations.GT4PyMoveTaskletIntoMap, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=assume_pointwise), + ], validate=validate, validate_all=validate_all, ) - gt_simplify(sdfg) + + # TODO(phimuell): The `MapReduceFusion` transformation is interesting as + # it moves the initialization of the accumulator at the top, which allows + # further fusing of the accumulator loop. However the transformation has + # a bug, so we can not use it. Furthermore, I have looked at the assembly + # and the compiler is already doing that. + # https://chat.spcl.inf.ethz.ch/spcl/pl/8mtgtqjb378hfy7h9a96sy3nhc + + # After we have created large kernels we run `dace_dataflow.MapReduceFusion`. + + # Phase 3: Optimizing the kernels, i.e. the larger maps, themselves. + # Currently this only applies fusion inside Maps. + gtx_transformations.gt_simplify(sdfg) + while True: + nb_applied = sdfg.apply_transformations_repeated( + [ + gtx_transformations.MapFusionSerial( + only_inner_maps=True, + ), + gtx_transformations.MapFusionParallel( + only_inner_maps=True, + only_if_common_ancestor=False, # TODO(phimuell): Should we? + ), + ], + validate=validate, + validate_all=validate_all, + ) + if not nb_applied: + break + gtx_transformations.gt_simplify(sdfg) # Phase 4: Iteration Space # This essentially ensures that the stride 1 dimensions are handled # by the inner most loop nest (CPU) or x-block (GPU) if leading_dim is not None: - gt_set_iteration_order( + gtx_transformations.gt_set_iteration_order( sdfg=sdfg, leading_dim=leading_dim, validate=validate, validate_all=validate_all, ) + # We now ensure that point wise computations are properly double buffered. + # The main reason is to ensure that rule 3 of ADR18 is maintained. 
+ gtx_transformations.gt_create_local_double_buffering(sdfg) + # Phase 5: Apply blocking if blocking_dim is not None: sdfg.apply_transformations_once_everywhere( gtx_transformations.LoopBlocking( blocking_size=blocking_size, blocking_parameter=blocking_dim, + require_independent_nodes=blocking_only_if_independent_nodes, ), validate=validate, validate_all=validate_all, ) - # Phase 6: Going to GPU + # Phase 6: Setting the strides of transients + # It is important that we set the strides before the GPU transformation. + # Because this transformation will also apply `CopyToMap` for the Memlets + # that the DaCe runtime can not handle. + gtx_transformations.gt_change_transient_strides(sdfg, gpu=gpu) + + # Phase 7: Going to GPU if gpu: # TODO(phimuell): The GPU function might modify the map iteration order. # This is because how it is implemented (promotion and @@ -324,7 +258,7 @@ def gt_auto_optimize( try_removing_trivial_maps=True, ) - # Phase 7: General Optimizations + # Phase 8: General Optimizations # The following operations apply regardless if we have a GPU or CPU. # The DaCe auto optimizer also uses them. Note that the reuse transient # is not done by DaCe. @@ -339,9 +273,21 @@ def gt_auto_optimize( dace_aoptimize.set_fast_implementations(sdfg, device) # TODO(phimuell): Fix the bug, it uses the tile value and not the stack array value. dace_aoptimize.move_small_arrays_to_stack(sdfg) + if make_persistent: - # TODO(phimuell): Allow to also to set the lifetime to `SDFG`. - dace_aoptimize.make_transients_persistent(sdfg, device) + gtx_transformations.gt_make_transients_persistent(sdfg=sdfg, device=device) + + if device == dace.DeviceType.GPU: + # NOTE: For unknown reasons the counterpart of the + # `gt_make_transients_persistent()` function in DaCe, resets the + # `wcr_nonatomic` property of every memlet, i.e. makes it atomic. + # However, it does this only for edges on the top level and on GPU. 
+ # For compatibility with DaCe (and until we found out why) the GT4Py + # auto optimizer will emulate this behaviour. + for state in sdfg.states(): + assert isinstance(state, dace.SDFGState) + for edge in state.edges(): + edge.data.wcr_nonatomic = False return sdfg @@ -387,14 +333,28 @@ def gt_auto_fuse_top_level_maps( # after the other, thus new opportunities might arise in the next round. # We use the hash of the SDFG to detect if we have reached a fix point. for _ in range(max_optimization_rounds): - # Use map fusion to reduce their number and to create big kernels # TODO(phimuell): Use a cost measurement to decide if fusion should be done. # TODO(phimuell): Add parallel fusion transformation. Should it run after # or with the serial one? + # TODO(phimuell): Switch to `FullMapFusion` once DaCe has parallel map fusion + # and [issue#1911](https://github.com/spcl/dace/issues/1911) has been solved. + + # First we do scan the entire SDFG to figure out which data is only + # used once and can be deleted. MapFusion could do this on its own but + # it is more efficient to do it once and then reuse it. + find_single_use_data = dace_analysis.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + + fusion_transformation = gtx_transformations.MapFusion( + only_toplevel_maps=True, + allow_parallel_map_fusion=True, + allow_serial_map_fusion=True, + only_if_common_ancestor=False, + ) + fusion_transformation._single_use_data = single_use_data + sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion( - only_toplevel_maps=True, - ), + fusion_transformation, validate=validate, validate_all=validate_all, ) @@ -434,7 +394,7 @@ def gt_auto_fuse_top_level_maps( # The SDFG was modified by the transformations above. The SDFG was # modified. Call Simplify and try again to further optimize. 
- gt_simplify(sdfg, validate=validate, validate_all=validate_all) + gtx_transformations.gt_simplify(sdfg, validate=validate, validate_all=validate_all) else: raise RuntimeWarning("Optimization of the SDFG did not converge.") diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py new file mode 100644 index 0000000000..e1f105f0ef --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -0,0 +1,833 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Functions for turning an SDFG into a GPU SDFG.""" + +from __future__ import annotations + +import copy +from typing import Any, Callable, Final, Optional, Sequence, Union + +import dace +from dace import ( + dtypes as dace_dtypes, + properties as dace_properties, + transformation as dace_transformation, +) +from dace.codegen.targets import cpp as dace_cpp +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations + + +def gt_gpu_transformation( + sdfg: dace.SDFG, + try_removing_trivial_maps: bool = True, + use_gpu_storage: bool = True, + gpu_block_size: Optional[Sequence[int | str] | str] = None, + gpu_launch_bounds: Optional[int | str] = None, + gpu_launch_factor: Optional[int] = None, + validate: bool = True, + validate_all: bool = False, + **kwargs: Any, +) -> dace.SDFG: + """Transform an SDFG into a GPU SDFG. + + The transformation expects a rather optimized SDFG and turn it into an SDFG + capable of running on the GPU. + The function performs the following steps: + - If requested, modify the storage location of the non transient arrays such + that they reside in GPU memory. 
+ - Call the normal GPU transform function followed by simplify. + - If requested try to remove trivial kernels. + - If specified, set the `gpu_block_size` parameters of the Maps to the given value. + + Args: + sdfg: The SDFG that should be processed. + try_removing_trivial_maps: Try to get rid of trivial maps by incorporating them. + use_gpu_storage: Assume that the non global memory is already on the GPU. This + will avoid the data copy from host to GPU memory. + gpu_block_size: The size of a thread block on the GPU. + gpu_launch_bounds: Use this value as `__launch_bounds__` for _all_ GPU Maps. + Will only take effect if `gpu_block_size` is specified. + gpu_launch_factor: Use the number of threads times this value as `__launch_bounds__` + Will only take effect if `gpu_block_size` is specified. + validate: Perform validation during the steps. + validate_all: Perform extensive validation. + + Notes: + The function might modify the order of the iteration variables of some + maps. + In addition it might fuse Maps together that should not be fused. To prevent + that you should set `try_removing_trivial_maps` to `False`. + + Todo: + - Solve the fusing problem. + - Currently only one block size for all maps is given, add more options. + """ + assert ( + len(kwargs) == 0 + ), f"gt_gpu_transformation(): found unknown arguments: {', '.join(arg for arg in kwargs.keys())}" + + # Turn all global arrays (which we identify as input) into GPU memory. + # This way the GPU transformation will not create this copying stuff. 
+ if use_gpu_storage: + for desc in sdfg.arrays.values(): + if isinstance(desc, dace.data.Array) and not desc.transient: + desc.storage = dace.dtypes.StorageType.GPU_Global + + # Now turn it into a GPU SDFG + sdfg.apply_gpu_transformations( + validate=validate, + validate_all=validate_all, + simplify=False, + ) + + # The documentation recommends to run simplify afterwards + gtx_transformations.gt_simplify(sdfg) + + if try_removing_trivial_maps: + gt_remove_trivial_gpu_maps( + sdfg=sdfg, + validate=validate, + validate_all=validate_all, + ) + gtx_transformations.gt_simplify(sdfg, validate=validate, validate_all=validate_all) + + # TODO(phimuell): Fixing the stride problem. + sdfg = gt_gpu_transform_non_standard_memlet( + sdfg=sdfg, + map_postprocess=True, + validate=validate, + validate_all=validate_all, + ) + + # Set the GPU block size if it is known. + if gpu_block_size is not None: + gt_set_gpu_blocksize( + sdfg=sdfg, + block_size=gpu_block_size, + launch_bounds=gpu_launch_bounds, + launch_factor=gpu_launch_factor, + ) + + if validate_all or validate: + sdfg.validate() + + return sdfg + + +def gt_gpu_transform_non_standard_memlet( + sdfg: dace.SDFG, + map_postprocess: bool, + validate: bool = True, + validate_all: bool = False, +) -> dace.SDFG: + """Transform some non standard Melets to Maps. + + The GPU code generator is not able to handle certain sets of Memlets. To + handle them, the code generator transforms them into copy Maps. The main + issue is that this transformation happens after the auto optimizer, thus + the copy-Maps will most likely have the wrong iteration order. + + This function allows to perform the preprocessing step before the actual + code generation. The function will perform the expansion. If + `map_postprocess` is `True` then the function will also apply MapFusion, + to these newly created copy-Maps and set their iteration order correctly. 
+ + A user should not call this function directly, instead this function is + called by the `gt_gpu_transformation()` function. + + Args: + sdfg: The SDFG that we process. + map_postprocess: Enable post processing of the maps that are created. + See the Note section below. + validate: Perform validation at the end of the function. + validate_all: Perform validation also on intermediate steps. + + Note: + - Currently the function applies some crude heuristic to determine the + correct loop order. + - This function should be called after `gt_set_iteration_order()` has run. + """ + + # Expand all non standard memlets and get the new MapEntries. + new_maps: set[dace_nodes.MapEntry] = _gt_expand_non_standard_memlets(sdfg) + + # If there are no Memlets that are translated to copy-Maps, then we have nothing to do. + if len(new_maps) == 0: + return sdfg + + # This function allows to restrict any fusion operation to the maps + # that we have just created. + def restrict_fusion_to_newly_created_maps( + self: gtx_transformations.MapFusion, + map_entry_1: dace_nodes.MapEntry, + map_entry_2: dace_nodes.MapEntry, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + permissive: bool, + ) -> bool: + return any(new_entry in new_maps for new_entry in [map_entry_1, map_entry_2]) + + # Using the callback to restrict the fusing + sdfg.apply_transformations_repeated( + [ + gtx_transformations.MapFusionSerial( + only_toplevel_maps=True, + apply_fusion_callback=restrict_fusion_to_newly_created_maps, + ), + gtx_transformations.MapFusionParallel( + only_toplevel_maps=True, + apply_fusion_callback=restrict_fusion_to_newly_created_maps, + ), + ], + validate=validate, + validate_all=validate_all, + ) + + # Now we have to find the maps that were not fused. We rely here on the fact + # that at least one of the map that is involved in fusing still exists. 
+ maps_to_modify: set[dace_nodes.MapEntry] = set() + for nsdfg in sdfg.all_sdfgs_recursive(): + for state in nsdfg.states(): + for map_entry in state.nodes(): + if not isinstance(map_entry, dace_nodes.MapEntry): + continue + if map_entry in new_maps: + maps_to_modify.add(map_entry) + assert 0 < len(maps_to_modify) <= len(new_maps) + + # This is a gross hack, but it is needed, for the following reasons: + # - The transients have C order while the non-transients have (most + # likely) FORTRAN order. So there is not an unique stride dimension. + # - The newly created maps have names that does not reflect GT4Py dimensions, + # thus we can not use `gt_set_iteration_order()`. + # For these reasons we do the simplest thing, which is assuming that the maps + # are created in C order and we must make them in FORTRAN order, which means + # just swapping the order of the map parameters. + # TODO(phimuell): Do it properly. + for me_to_modify in maps_to_modify: + map_to_modify: dace_nodes.Map = me_to_modify.map + map_to_modify.params = list(reversed(map_to_modify.params)) + map_to_modify.range = dace.subsets.Range( + (r1, r2, r3, t) + for (r1, r2, r3), t in zip( + reversed(map_to_modify.range.ranges), reversed(map_to_modify.range.tile_sizes) + ) + ) + + return sdfg + + +def _gt_expand_non_standard_memlets( + sdfg: dace.SDFG, +) -> set[dace_nodes.MapEntry]: + """Finds all non standard Memlet in the SDFG and expand them. + + The function is used by `gt_gpu_transform_non_standard_memlet()` and performs + the actual expansion of the Memlet, i.e. turning all Memlets that can not be + expressed as a `memcpy()` into a Map, copy kernel. + The function will return the MapEntries of all expanded. + + The function will process the SDFG recursively. 
+ """ + new_maps: set[dace_nodes.MapEntry] = set() + for nsdfg in sdfg.all_sdfgs_recursive(): + new_maps.update(_gt_expand_non_standard_memlets_sdfg(nsdfg)) + return new_maps + + +def _gt_expand_non_standard_memlets_sdfg( + sdfg: dace.SDFG, +) -> set[dace_nodes.MapEntry]: + """Implementation of `_gt_expand_non_standard_memlets()` that process a single SDFG.""" + new_maps: set[dace_nodes.MapEntry] = set() + # The implementation is based on DaCe's code generator. + for state in sdfg.states(): + for e in state.edges(): + # We are only interested in edges that connects two access nodes of GPU memory. + if not ( + isinstance(e.src, dace_nodes.AccessNode) + and isinstance(e.dst, dace_nodes.AccessNode) + and e.src.desc(sdfg).storage == dace_dtypes.StorageType.GPU_Global + and e.dst.desc(sdfg).storage == dace_dtypes.StorageType.GPU_Global + ): + continue + + a: dace_nodes.AccessNode = e.src + b: dace_nodes.AccessNode = e.dst + copy_shape, src_strides, dst_strides, _, _ = dace_cpp.memlet_copy_to_absolute_strides( + None, sdfg, state, e, a, b + ) + dims = len(copy_shape) + if dims == 1: + continue + elif dims == 2: + if src_strides[-1] != 1 or dst_strides[-1] != 1: + try: + is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1] + is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1] + except (TypeError, ValueError): + is_src_cont = False + is_dst_cont = False + if is_src_cont and is_dst_cont: + continue + else: + continue + elif dims > 2: + if not (src_strides[-1] != 1 or dst_strides[-1] != 1): + continue + + # For identifying the new map, we first store all neighbors of `a`. 
+ old_neighbors_of_a: list[dace_nodes.AccessNode] = [ + edge.dst for edge in state.out_edges(a) + ] + + # Turn unsupported copy to a map + try: + dace_transformation.dataflow.CopyToMap.apply_to( + sdfg, + save=False, + annotate=False, + a=a, + b=b, + options={ + "ignore_strides": True + }, # apply 'CopyToMap' even if src/dst strides are different + ) + except ValueError: # If transformation doesn't match, continue normally + continue + + # We find the new map by comparing the new neighborhood of `a` with the old one. + new_nodes: set[dace_nodes.MapEntry] = { + edge.dst for edge in state.out_edges(a) if edge.dst not in old_neighbors_of_a + } + assert any(isinstance(new_node, dace_nodes.MapEntry) for new_node in new_nodes) + assert len(new_nodes) == 1 + new_maps.update(new_nodes) + return new_maps + + +def gt_set_gpu_blocksize( + sdfg: dace.SDFG, + block_size: Optional[Sequence[int | str] | str], + launch_bounds: Optional[int | str] = None, + launch_factor: Optional[int] = None, + **kwargs: Any, +) -> Any: + """Set the block size related properties of _all_ Maps. + + It supports the same arguments as `GPUSetBlockSize`, however it also has + versions without `_Xd`, these are used as default for the other maps. + If a version with `_Xd` is specified then it takes precedence. + + Args: + sdfg: The SDFG to process. + block_size: The size of a thread block on the GPU. + launch_bounds: The value for the launch bound that should be used. + launch_factor: If no `launch_bounds` was given use the number of threads + in a block multiplied by this number. 
+ """ + for dim in [1, 2, 3]: + for arg, val in { + "block_size": block_size, + "launch_bounds": launch_bounds, + "launch_factor": launch_factor, + }.items(): + if f"{arg}_{dim}d" not in kwargs: + kwargs[f"{arg}_{dim}d"] = val + return sdfg.apply_transformations_once_everywhere(GPUSetBlockSize(**kwargs)) + + +def _make_gpu_block_parser_for( + dim: int, +) -> Callable[["GPUSetBlockSize", Any], None]: + """Generates a parser for GPU blocks for dimension `dim`. + + The returned function can be used as parser for the `GPUSetBlockSize.block_size_*d` + properties. + """ + + def _gpu_block_parser( + self: GPUSetBlockSize, + val: Any, + ) -> None: + """Used by the setter of `GPUSetBlockSize.block_size`.""" + org_val = val + if isinstance(val, (tuple | list)): + pass + elif isinstance(val, str): + val = tuple(x.strip() for x in val.split(",")) + elif isinstance(val, int): + val = (val,) + else: + raise TypeError( + f"Does not know how to transform '{type(org_val).__name__}' into a proper GPU block size." + ) + if len(val) < dim: + raise ValueError( + f"The passed block size only covers {len(val)} dimensions, but dimension was {dim}." + ) + if 0 < len(val) <= 3: + val = [*val, *([1] * (3 - len(val)))] + else: + raise ValueError(f"Can not parse block size '{org_val}': wrong length") + try: + val = [int(x) for x in val] + except ValueError: + raise TypeError( + f"Currently only block sizes convertible to int are supported, you passed '{val}'." + ) from None + + # Remove over specification. 
+        for i in range(dim, 3):
+            val[i] = 1
+        setattr(self, f"_block_size_{dim}d", tuple(val))
+
+    return _gpu_block_parser
+
+
+def _make_gpu_block_getter_for(
+    dim: int,
+) -> Callable[["GPUSetBlockSize"], tuple[int, int, int]]:
+    """Makes the getter for the block size of dimension `dim`."""
+
+    def _gpu_block_getter(
+        self: "GPUSetBlockSize",
+    ) -> tuple[int, int, int]:
+        """Used as getter in the `GPUSetBlockSize.block_size` property."""
+        return getattr(self, f"_block_size_{dim}d")
+
+    return _gpu_block_getter
+
+
+def _gpu_launch_bound_parser(
+    block_size: tuple[int, int, int],
+    launch_bounds: int | str | None,
+    launch_factor: int | None = None,
+) -> str | None:
+    """Used by the `GPUSetBlockSize.__init__()` method to parse the launch bounds."""
+    if launch_bounds is None and launch_factor is None:
+        return None
+    elif launch_bounds is None and launch_factor is not None:
+        return str(int(launch_factor) * block_size[0] * block_size[1] * block_size[2])
+    elif launch_bounds is not None and launch_factor is None:
+        assert isinstance(launch_bounds, (str, int))
+        return str(launch_bounds)
+    else:
+        raise ValueError("Specified both `launch_bounds` and `launch_factor`.")
+
+
+@dace_properties.make_properties
+class GPUSetBlockSize(dace_transformation.SingleStateTransformation):
+    """Sets the GPU block size on GPU Maps.
+
+    The `block_size` is either a sequence, of up to three integers or a string
+    of up to three numbers, separated by comma (`,`). The first number is the size
+    of the block in `x` direction, the second for the `y` direction and the third
+    for the `z` direction. Missing values will be filled with `1`.
+
+    A different value for the GPU block size and launch bound can be specified for
+    maps of dimension 1, 2 or 3 (all maps with higher dimensions are considered
+    three dimensional). If no value is specified then the block size `(32, 1, 1)`
+    will be used and no launch bound will be emitted.
+ + Args: + block_size_Xd: The size of a thread block on the GPU for `X` dimensional maps. + launch_bounds_Xd: The value for the launch bound that should be used for `X` + dimensional maps. + launch_factor_Xd: If no `launch_bounds` was given use the number of threads + in a block multiplied by this number, for maps of dimension `X`. + + Note: + - You should use the `gt_set_gpu_blocksize()` function. + - "Over specification" is ignored, i.e. if `(32, 3, 1)` is passed as block + size for 1 dimensional maps, then it is changed to `(32, 1, 1)`. + """ + + _block_size_default: Final[tuple[int, int, int]] = (32, 1, 1) + + block_size_1d = dace_properties.Property( + dtype=tuple[int, int, int], + default=_block_size_default, + setter=_make_gpu_block_parser_for(1), + getter=_make_gpu_block_getter_for(1), + desc="Block size for 1 dimensional GPU maps.", + ) + launch_bounds_1d = dace_properties.Property( + dtype=str, + allow_none=True, + default=None, + desc="Set the launch bound property for 1 dimensional map.", + ) + block_size_2d = dace_properties.Property( + dtype=tuple[int, int, int], + default=_block_size_default, + setter=_make_gpu_block_parser_for(2), + getter=_make_gpu_block_getter_for(2), + desc="Block size for 2 dimensional GPU maps.", + ) + launch_bounds_2d = dace_properties.Property( + dtype=str, + allow_none=True, + default=None, + desc="Set the launch bound property for 2 dimensional map.", + ) + block_size_3d = dace_properties.Property( + dtype=tuple[int, int, int], + default=_block_size_default, + setter=_make_gpu_block_parser_for(3), + getter=_make_gpu_block_getter_for(3), + desc="Block size for 3 dimensional GPU maps.", + ) + launch_bounds_3d = dace_properties.Property( + dtype=str, + allow_none=True, + default=None, + desc="Set the launch bound property for 3 dimensional map.", + ) + + # Pattern matching + map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) + + def __init__( + self, + block_size_1d: Sequence[int | str] | str | None = None, + 
block_size_2d: Sequence[int | str] | str | None = None, + block_size_3d: Sequence[int | str] | str | None = None, + launch_bounds_1d: int | str | None = None, + launch_bounds_2d: int | str | None = None, + launch_bounds_3d: int | str | None = None, + launch_factor_1d: int | None = None, + launch_factor_2d: int | None = None, + launch_factor_3d: int | None = None, + ) -> None: + super().__init__() + if block_size_1d is not None: + self.block_size_1d = block_size_1d + if block_size_2d is not None: + self.block_size_2d = block_size_2d + if block_size_3d is not None: + self.block_size_3d = block_size_3d + self.launch_bounds_1d = _gpu_launch_bound_parser( + self.block_size_1d, launch_bounds_1d, launch_factor_1d + ) + self.launch_bounds_2d = _gpu_launch_bound_parser( + self.block_size_2d, launch_bounds_2d, launch_factor_2d + ) + self.launch_bounds_3d = _gpu_launch_bound_parser( + self.block_size_3d, launch_bounds_3d, launch_factor_3d + ) + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.map_entry)] + + def can_be_applied( + self, + graph: Union[dace.SDFGState, dace.SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Test if the block size can be set. + + The function tests: + - If the block size of the map is already set. + - If the map is at global scope. + - If if the schedule of the map is correct. 
+ """ + scope = graph.scope_dict() + if scope[self.map_entry] is not None: + return False + if self.map_entry.map.schedule not in dace.dtypes.GPU_SCHEDULES: + return False + if self.map_entry.map.gpu_block_size is not None: + return False + return True + + def apply( + self, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + ) -> None: + """Modify the map as requested.""" + gpu_map: dace_nodes.Map = self.map_entry.map + if len(gpu_map.params) == 1: + block_size = self.block_size_1d + launch_bounds = self.launch_bounds_1d + elif len(gpu_map.params) == 2: + block_size = self.block_size_2d + launch_bounds = self.launch_bounds_2d + else: + block_size = self.block_size_3d + launch_bounds = self.launch_bounds_3d + gpu_map.gpu_block_size = block_size + if launch_bounds is not None: # Note: empty string has a meaning in DaCe + gpu_map.gpu_launch_bounds = launch_bounds + + +def gt_remove_trivial_gpu_maps( + sdfg: dace.SDFG, + validate: bool = True, + validate_all: bool = False, +) -> dace.SDFG: + """Removes trivial maps that were created by the GPU transformation. + + The main problem is that a Tasklet outside of a Map cannot write into an + _array_ that is on GPU. `sdfg.apply_gpu_transformations()` will wrap such + Tasklets in a Map. The `GT4PyMoveTaskletIntoMap` pass, that runs before, + but only works if the tasklet is adjacent to a map. + + It first tries to promote them such that they can be fused in other non-trivial + maps, it will then also perform fusion on them, to reduce the number of kernel + calls. + + Args: + sdfg: The SDFG that we process. + validate: Perform validation at the end of the function. + validate_all: Perform validation also on intermediate steps. + """ + + # First we try to promote and fuse them with other non-trivial maps. 
+ sdfg.apply_transformations_once_everywhere( + TrivialGPUMapElimination( + do_not_fuse=False, + only_gpu_maps=True, + ), + validate=False, + validate_all=False, + ) + gtx_transformations.gt_simplify(sdfg, validate=validate, validate_all=validate_all) + + # Now we try to fuse them together, however, we restrict the fusion to trivial + # GPU map. + def restrict_to_trivial_gpu_maps( + self: gtx_transformations.MapFusion, + map_entry_1: dace_nodes.MapEntry, + map_entry_2: dace_nodes.MapEntry, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + permissive: bool, + ) -> bool: + for map_entry in [map_entry_1, map_entry_2]: + _map = map_entry.map + if len(_map.params) != 1: + return False + if _map.range[0][0] != _map.range[0][1]: + return False + if _map.schedule not in [ + dace.dtypes.ScheduleType.GPU_Device, + dace.dtypes.ScheduleType.GPU_Default, + ]: + return False + return True + + sdfg.apply_transformations_repeated( + [ + gtx_transformations.MapFusionSerial( + only_toplevel_maps=True, + apply_fusion_callback=restrict_to_trivial_gpu_maps, + ), + gtx_transformations.MapFusionParallel( + only_toplevel_maps=True, + apply_fusion_callback=restrict_to_trivial_gpu_maps, + ), + ], + validate=validate, + validate_all=validate_all, + ) + + return sdfg + + +@dace_properties.make_properties +class TrivialGPUMapElimination(dace_transformation.SingleStateTransformation): + """Eliminate certain kind of trivial GPU maps. + + A tasklet outside of map can not write to GPU memory, this can only be done + from within a map (a scalar is possible). For that reason DaCe's GPU + transformation wraps such tasklets in trivial maps. + Under certain condition the transformation will fuse the trivial tasklet with + a downstream (serial) map. + + Args: + do_not_fuse: If `True` then the maps are not fused together. + only_gpu_maps: Only apply to GPU maps; `True` by default. 
+ + Note: + - This transformation should not be run on its own, instead it + is run within the context of `gt_gpu_transformation()`. + - This transformation must be run after the GPU Transformation. + """ + + only_gpu_maps = dace_properties.Property( + dtype=bool, + default=True, + desc="Only promote maps that are GPU maps (debug option).", + ) + do_not_fuse = dace_properties.Property( + dtype=bool, + default=False, + desc="Only perform the promotion, do not fuse.", + ) + + # Pattern Matching + trivial_map_exit = dace_transformation.PatternNode(dace_nodes.MapExit) + access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + second_map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) + + def __init__( + self, + do_not_fuse: Optional[bool] = None, + only_gpu_maps: Optional[bool] = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + if only_gpu_maps is not None: + self.only_gpu_maps = only_gpu_maps + if do_not_fuse is not None: + self.do_not_fuse = do_not_fuse + + @classmethod + def expressions(cls) -> Any: + return [ + dace.sdfg.utils.node_path_graph( + cls.trivial_map_exit, cls.access_node, cls.second_map_entry + ) + ] + + def can_be_applied( + self, + graph: Union[dace.SDFGState, dace.SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Tests if the promotion is possible. + + The tests includes: + - Schedule of the maps. + - If the map is trivial. + - Tests if the maps can be fused. + """ + trivial_map_exit: dace_nodes.MapExit = self.trivial_map_exit + trivial_map: dace_nodes.Map = trivial_map_exit.map + trivial_map_entry: dace_nodes.MapEntry = graph.entry_node(trivial_map_exit) + second_map: dace_nodes.Map = self.second_map_entry.map + + # The kind of maps we are interested only have one parameter. 
+        if len(trivial_map.params) != 1:
+            return False
+        for rng in trivial_map.range.ranges:
+            if rng[0] != rng[1]:
+                return False
+
+        # If we do not fuse, then the second map can not be trivial.
+        # If we would not prevent that case then we would match these two
+        # maps again and again.
+        if self.do_not_fuse and len(second_map.params) <= 1:
+            for rng in second_map.range.ranges:
+                if rng[0] == rng[1]:
+                    return False
+
+        # We now check that the Memlets do not depend on the map parameter.
+        # This is important for the `can_be_applied_to()` check we do below
+        # because we can avoid calling the replace function.
+        scope = graph.scope_subgraph(trivial_map_entry)
+        trivial_map_param: str = trivial_map.params[0]
+        for edge in scope.edges():
+            if trivial_map_param in edge.data.free_symbols:
+                return False
+
+        # Check if only GPU maps are involved (this is more a testing debug feature).
+        if self.only_gpu_maps:
+            for map_to_check in [trivial_map, second_map]:
+                if map_to_check.schedule not in [
+                    dace.dtypes.ScheduleType.GPU_Device,
+                    dace.dtypes.ScheduleType.GPU_Default,
+                ]:
+                    return False
+
+        # Now we check if the two maps can be fused together. For that we have to
+        # do a temporary promotion, it is important that we do not perform the
+        # renaming. If the old symbol is still used, it is used inside a tasklet
+        # so it would show up (temporarily) as free symbol.
+ org_trivial_map_params = copy.deepcopy(trivial_map.params) + org_trivial_map_range = copy.deepcopy(trivial_map.range) + try: + self._promote_map(graph, replace_trivail_map_parameter=False) + if not gtx_transformations.MapFusionSerial.can_be_applied_to( + sdfg=sdfg, + first_map_exit=trivial_map_exit, + array=self.access_node, + second_map_entry=self.second_map_entry, + ): + return False + finally: + trivial_map.params = org_trivial_map_params + trivial_map.range = org_trivial_map_range + + return True + + def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: + """Performs the Map Promoting. + + The function will first perform the promotion of the trivial map and then + perform the merging of the two maps in one go. + """ + trivial_map_exit: dace_nodes.MapExit = self.trivial_map_exit + second_map_entry: dace_nodes.MapEntry = self.second_map_entry + access_node: dace_nodes.AccessNode = self.access_node + + # Promote the maps. + self._promote_map(graph) + + # Perform the fusing if requested. + if not self.do_not_fuse: + gtx_transformations.MapFusionSerial.apply_to( + sdfg=sdfg, + first_map_exit=trivial_map_exit, + array=access_node, + second_map_entry=second_map_entry, + verify=True, + ) + + def _promote_map( + self, + state: dace.SDFGState, + replace_trivail_map_parameter: bool = True, + ) -> None: + """Performs the map promoting. + + Essentially this function will copy the parameters and the range from + the non trivial map (`self.second_map_entry.map`) to the trivial map + (`self.trivial_map_exit.map`). + + If `replace_trivail_map_parameter` is `True` (the default value), then the + function will also remove the trivial map parameter with its value. 
+ """ + assert isinstance(self.trivial_map_exit, dace_nodes.MapExit) + assert isinstance(self.second_map_entry, dace_nodes.MapEntry) + assert isinstance(self.access_node, dace_nodes.AccessNode) + + trivial_map_exit: dace_nodes.MapExit = self.trivial_map_exit + trivial_map: dace_nodes.Map = self.trivial_map_exit.map + trivial_map_entry: dace_nodes.MapEntry = state.entry_node(trivial_map_exit) + second_map: dace_nodes.Map = self.second_map_entry.map + + # If requested then replace the map variable with its value. + if replace_trivail_map_parameter: + scope = state.scope_subgraph(trivial_map_entry) + scope.replace(trivial_map.params[0], trivial_map.range[0][0]) + + # Now copy parameter and the ranges from the second to the trivial map. + trivial_map.params = copy.deepcopy(second_map.params) + trivial_map.range = copy.deepcopy(second_map.range) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/local_double_buffering.py b/src/gt4py/next/program_processors/runners/dace/transformations/local_double_buffering.py new file mode 100644 index 0000000000..5201748e12 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/local_double_buffering.py @@ -0,0 +1,391 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + + +import copy + +import dace +from dace import ( + data as dace_data, + dtypes as dace_dtypes, + symbolic as dace_symbolic, + transformation as dace_transformation, +) +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations + + +def gt_create_local_double_buffering( + sdfg: dace.SDFG, +) -> int: + """Modifies the SDFG such that point wise data dependencies are stable. 
+ + Rule 3 of the ADR18, guarantees that if data is input and output to a map, + then it must be a non transient array and it must only have point wise + dependency. This means that every index that is read is also written by + the same thread and no other thread reads or writes to the same location. + However, because the dataflow inside a map is partially asynchronous + it might happen if something is read multiple times, i.e. Tasklets, + the data might already be overwritten. + This function will scan the SDFG for potential cases and insert an + access node to cache this read. This is essentially a double buffer, but + it is not needed that the whole data is stored, but only the working set + of a single thread. + """ + processed_maps = 0 + for nsdfg in sdfg.all_sdfgs_recursive(): + processed_maps += _create_local_double_buffering_non_recursive(nsdfg) + return processed_maps + + +def _create_local_double_buffering_non_recursive( + sdfg: dace.SDFG, +) -> int: + """Implementation of the point wise transformation. + + This function does not handle nested SDFGs. + """ + # First we call `EdgeConsolidation`, because of that we know that + # every incoming edge of a `MapEntry` refers to distinct data. + # We do this to simplify our implementation. 
+    edge_consolidation = dace_transformation.passes.ConsolidateEdges()
+    edge_consolidation.apply_pass(sdfg, None)
+
+    processed_maps = 0
+    for state in sdfg.states():
+        assert isinstance(state, dace.SDFGState)
+        scope_dict = state.scope_dict()
+        for node in state.nodes():
+            if not isinstance(node, dace_nodes.MapEntry):
+                continue
+            if scope_dict[node] is not None:
+                continue
+            inout_nodes = _check_if_map_must_be_handled(
+                map_entry=node,
+                state=state,
+                sdfg=sdfg,
+            )
+            if inout_nodes is not None:
+                processed_maps += _add_local_double_buffering_to(
+                    map_entry=node,
+                    inout_nodes=inout_nodes,
+                    state=state,
+                    sdfg=sdfg,
+                )
+    return processed_maps
+
+
+def _add_local_double_buffering_to(
+    inout_nodes: dict[str, tuple[dace_nodes.AccessNode, dace_nodes.AccessNode]],
+    map_entry: dace_nodes.MapEntry,
+    state: dace.SDFGState,
+    sdfg: dace.SDFG,
+) -> int:
+    """Adds the double buffering to `map_entry` for `inout_nodes`.
+
+    The function assumes that there is only one incoming edge per data
+    descriptor at the map entry. If the data is needed multiple times,
+    then the distribution must be done inside the map.
+
+    The function will now channel all reads to the data descriptor
+    through an access node, this ensures that the read happens
+    before the write.
+ """ + processed_maps = 0 + for inout_node in inout_nodes.values(): + _add_local_double_buffering_to_single_data( + map_entry=map_entry, + inout_node=inout_node, + state=state, + sdfg=sdfg, + ) + processed_maps += 1 + return processed_maps + + +def _add_local_double_buffering_to_single_data( + inout_node: tuple[dace_nodes.AccessNode, dace_nodes.AccessNode], + map_entry: dace_nodes.MapEntry, + state: dace.SDFGState, + sdfg: dace.SDFG, +) -> None: + """Adds the local double buffering for a single data.""" + map_exit: dace_nodes.MapExit = state.exit_node(map_entry) + input_node, output_node = inout_node + input_edges = state.edges_between(input_node, map_entry) + output_edges = state.edges_between(map_exit, output_node) + assert len(input_edges) == 1 + assert len(output_edges) == 1 + inner_read_edges = _get_inner_edges(input_edges[0], map_entry, state, False) + inner_write_edges = _get_inner_edges(output_edges[0], map_exit, state, True) + + # For now we assume that all read the same, which is checked below. + new_double_inner_buff_shape_raw = dace_symbolic.overapproximate( + inner_read_edges[0].data.get_src_subset(inner_read_edges[0], state).size() + ) + + # Over approximation will leave us with some unneeded size one dimensions. + # If they are removed some dace transformations (especially auto optimization) + # will have problems. + squeezed_dims: list[int] = [] # These are the dimensions we removed. + new_double_inner_buff_shape: list[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate( + zip(new_double_inner_buff_shape_raw, input_node.desc(sdfg).shape) + ): + if full_dim_size == 1: # Must be kept! + new_double_inner_buff_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. 
+ squeezed_dims.append(dim) + else: + new_double_inner_buff_shape.append(proposed_dim_size) + + new_double_inner_buff_name: str = f"__inner_double_buffer_for_{input_node.data}" + # Now generate the intermediate data container. + if len(new_double_inner_buff_shape) == 0: + new_double_inner_buff_name, new_double_inner_buff_desc = sdfg.add_scalar( + new_double_inner_buff_name, + dtype=input_node.desc(sdfg).dtype, + transient=True, + storage=dace_dtypes.StorageType.Register, + find_new_name=True, + ) + else: + new_double_inner_buff_name, new_double_inner_buff_desc = sdfg.add_transient( + new_double_inner_buff_name, + shape=new_double_inner_buff_shape, + dtype=input_node.desc(sdfg).dtype, + find_new_name=True, + storage=dace_dtypes.StorageType.Register, + ) + new_double_inner_buff_node = state.add_access(new_double_inner_buff_name) + + # Now reroute the data flow through the new access node. + for old_inner_read_edge in inner_read_edges: + # To do handle the case the memlet is "fancy" + state.add_edge( + new_double_inner_buff_node, + None, + old_inner_read_edge.dst, + old_inner_read_edge.dst_conn, + dace.Memlet( + data=new_double_inner_buff_name, + subset=dace.subsets.Range.from_array(new_double_inner_buff_desc), + other_subset=copy.deepcopy( + old_inner_read_edge.data.get_dst_subset(old_inner_read_edge, state) + ), + ), + ) + state.remove_edge(old_inner_read_edge) + + # Now create a connection between the map entry and the intermediate node. + state.add_edge( + map_entry, + inner_read_edges[0].src_conn, + new_double_inner_buff_node, + None, + dace.Memlet( + data=input_node.data, + subset=copy.deepcopy( + inner_read_edges[0].data.get_src_subset(inner_read_edges[0], state) + ), + other_subset=dace.subsets.Range.from_array(new_double_inner_buff_desc), + ), + ) + + # To really ensure that a read happens before a write, we have to sequence + # the read first. 
We do this by connecting the double buffer node with + # empty Memlets to the last row of nodes that writes to the global buffer. + # This is needed to handle the case that some other data path performs the + # write. + # TODO(phimuell): Add a test that only performs this when it is really needed. + for inner_write_edge in inner_write_edges: + state.add_nedge( + new_double_inner_buff_node, + inner_write_edge.src, + dace.Memlet(), + ) + + +def _check_if_map_must_be_handled_classify_adjacent_access_node( + data_node: dace_nodes.AccessNode, + sdfg: dace.SDFG, + known_nodes: dict[str, dace_nodes.AccessNode], +) -> bool: + """Internal function used by `_check_if_map_must_be_handled()` to classify nodes. + + If the function returns `True` it means that the input/output, does not + violates an internal constraint, i.e. can be handled by + `_ensure_that_map_is_pointwise()`. If appropriate the function will add the + node to `known_nodes`. I.e. in case of a transient the function will return + `True` but will not add it to `known_nodes`. + """ + + # This case is indicating that the `ConsolidateEdges` has not fully worked. + # Currently the transformation implementation assumes that this is the + # case, so we can not handle this case. + # TODO(phimuell): Implement this case. + if data_node.data in known_nodes: + return False + data_desc: dace_data.Data = data_node.desc(sdfg) + + # The conflict can only occur for global data, because transients + # are only written once. + if data_desc.transient: + return False + + # Currently we do not handle view, as they need to be traced. + # TODO(phimuell): Implement + if gtx_transformations.utils.is_view(data_desc, sdfg): + return False + + # TODO(phimuell): Check if there is a access node on the inner side, then we do not have to do it. + + # Now add the node to the list. 
+ assert all(data_node is not known_node for known_node in known_nodes.values()) + known_nodes[data_node.data] = data_node + return True + + +def _get_inner_edges( + outer_edge: dace.sdfg.graph.MultiConnectorEdge, + scope_node: dace_nodes.MapExit | dace_nodes.MapEntry, + state: dace.SDFG, + outgoing_edge: bool, +) -> list[dace.sdfg.graph.MultiConnectorEdge]: + """Gets the edges on the inside of a map.""" + if outgoing_edge: + assert isinstance(scope_node, dace_nodes.MapExit) + conn_name = outer_edge.src_conn[4:] + return list(state.in_edges_by_connector(scope_node, connector="IN_" + conn_name)) + else: + assert isinstance(scope_node, dace_nodes.MapEntry) + conn_name = outer_edge.dst_conn[3:] + return list(state.out_edges_by_connector(scope_node, connector="OUT_" + conn_name)) + + +def _check_if_map_must_be_handled( + map_entry: dace_nodes.MapEntry, + state: dace.SDFGState, + sdfg: dace.SDFG, +) -> None | dict[str, tuple[dace_nodes.AccessNode, dace_nodes.AccessNode]]: + """Check if the map should be processed to uphold rule 3. + + Essentially the function will check if there is a potential read-write + conflict. The function assumes that `ConsolidateEdges` has already run. + + If there is a possible data race the function will return a `dict`, that + maps the name of the data to the access nodes that are used as input and + output to the Map. + + Otherwise, the function returns `None`. It is, however, important that + `None` does not means that there is no possible race condition. It could + also means that the function that implements the buffering, i.e. + `_ensure_that_map_is_pointwise()`, is unable to handle this case. + + Todo: + Improve the function + """ + map_exit: dace_nodes.MapExit = state.exit_node(map_entry) + + # Find all the data that is accessed. Views are resolved. + input_datas: dict[str, dace_nodes.AccessNode] = {} + output_datas: dict[str, dace_nodes.AccessNode] = {} + + # Determine which nodes are possible conflicting. 
+ for in_edge in state.in_edges(map_entry): + if in_edge.data.is_empty(): + continue + if not isinstance(in_edge.src, dace_nodes.AccessNode): + # TODO(phiumuell): Figuring out what this case means + continue + if in_edge.dst_conn and not in_edge.dst_conn.startswith("IN_"): + # TODO(phimuell): It is very unlikely that a Dynamic Map Range causes + # this particular data race, so we ignore it for the time being. + continue + if not _check_if_map_must_be_handled_classify_adjacent_access_node( + data_node=in_edge.src, + sdfg=sdfg, + known_nodes=input_datas, + ): + continue + for out_edge in state.out_edges(map_exit): + if out_edge.data.is_empty(): + continue + if not isinstance(out_edge.dst, dace_nodes.AccessNode): + # TODO(phiumuell): Figuring out what this case means + continue + if not _check_if_map_must_be_handled_classify_adjacent_access_node( + data_node=out_edge.dst, + sdfg=sdfg, + known_nodes=output_datas, + ): + continue + + # Double buffering is only needed if there inout arguments. + inout_datas: dict[str, tuple[dace_nodes.AccessNode, dace_nodes.AccessNode]] = { + dname: (input_datas[dname], output_datas[dname]) + for dname in input_datas + if dname in output_datas + } + if len(inout_datas) == 0: + return None + + # TODO(phimuell): What about the case that some data descriptor needs double + # buffering, but some do not? + for inout_data_name in list(inout_datas.keys()): + input_node, output_node = inout_datas[inout_data_name] + input_edges = state.edges_between(input_node, map_entry) + output_edges = state.edges_between(map_exit, output_node) + assert ( + len(input_edges) == 1 + ), f"Expected a single connection between input node and map entry, but found {len(input_edges)}." + assert ( + len(output_edges) == 1 + ), f"Expected a single connection between map exit and write back node, but found {len(output_edges)}." + + # If there is only one edge on the inside of the map, that goes into an + # AccessNode, then we assume it is double buffered. 
+ inner_read_edges = _get_inner_edges(input_edges[0], map_entry, state, False) + if ( + len(inner_read_edges) == 1 + and isinstance(inner_read_edges[0].dst, dace_nodes.AccessNode) + and not gtx_transformations.utils.is_view(inner_read_edges[0].dst, sdfg) + ): + inout_datas.pop(inout_data_name) + continue + + inner_read_subsets = [ + inner_read_edge.data.get_src_subset(inner_read_edge, state) + for inner_read_edge in inner_read_edges + ] + assert all(inner_read_subset is not None for inner_read_subset in inner_read_subsets) + inner_write_subsets = [ + inner_write_edge.data.get_dst_subset(inner_write_edge, state) + for inner_write_edge in _get_inner_edges(output_edges[0], map_exit, state, True) + ] + # TODO(phimuell): Also implement a check that the volume equals the size of the subset. + assert all(inner_write_subset is not None for inner_write_subset in inner_write_subsets) + + # For being point wise the subsets must be compatible. The correct check would be: + # - The write sets are unique. + # - For every read subset there exists one matching write subset. It could + # be that there are many equivalent read subsets. + # - For every write subset there exists at least one matching read subset. + # The current implementation only checks if all are the same. + # TODO(phimuell): Implement the real check. 
+ all_inner_subsets = inner_read_subsets + inner_write_subsets + if not all( + all_inner_subsets[0] == all_inner_subsets[i] for i in range(1, len(all_inner_subsets)) + ): + return None + + if len(inout_datas) == 0: + return None + + return inout_datas diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py similarity index 64% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py rename to src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py index 7acd997a0d..826b5949f2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py @@ -19,7 +19,7 @@ from dace.transformation import helpers as dace_helpers from gt4py.next import common as gtx_common -from gt4py.next.program_processors.runners.dace_fieldview import utility as gtx_dace_fieldview_util +from gt4py.next.program_processors.runners.dace import gtir_sdfg_utils @dace_properties.make_properties @@ -36,12 +36,16 @@ class LoopBlocking(dace_transformation.SingleStateTransformation): What makes this transformation different from simple blocking, is that the inner map will not just be inserted right after the outer Map. Instead the transformation will first identify all nodes that does not depend - on the blocking parameter `I` and relocate them between the outer and inner map. - Thus these operations will only be performed once, per inner loop. + on the blocking parameter `I`, called independent nodes and relocate them + between the outer and inner map. Note that an independent node must be connected + to the MapEntry or another independent node. + Thus these operations will only be performed once, per outer loop iteration. Args: blocking_size: The size of the block, denoted as `B` above. 
blocking_parameter: On which parameter should we block. + require_independent_nodes: If `True` only apply loop blocking if the Map + actually contains independent nodes. Defaults to `False`. Todo: - Modify the inner map such that it always starts at zero. @@ -59,37 +63,35 @@ class LoopBlocking(dace_transformation.SingleStateTransformation): desc="Name of the iteration variable on which to block (must be an exact match);" " 'I' in the above description.", ) - independent_nodes = dace_properties.Property( - dtype=set, - allow_none=True, - default=None, - optional=True, - optional_condition=lambda _: False, - desc="Set of nodes that are independent of the blocking parameter.", - ) - dependent_nodes = dace_properties.Property( - dtype=set, - allow_none=True, - default=None, - optional=True, - optional_condition=lambda _: False, - desc="Set of nodes that are dependent on the blocking parameter.", + require_independent_nodes = dace_properties.Property( + dtype=bool, + default=False, + desc="If 'True' then blocking is only applied if there are independent nodes.", ) - outer_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + # Set of nodes that are independent of the blocking parameter. 
+ _independent_nodes: Optional[set[dace_nodes.AccessNode]] + _dependent_nodes: Optional[set[dace_nodes.AccessNode]] + + outer_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) def __init__( self, blocking_size: Optional[int] = None, blocking_parameter: Optional[Union[gtx_common.Dimension, str]] = None, + require_independent_nodes: Optional[bool] = None, ) -> None: super().__init__() if isinstance(blocking_parameter, gtx_common.Dimension): - blocking_parameter = gtx_dace_fieldview_util.get_map_variable(blocking_parameter) + blocking_parameter = gtir_sdfg_utils.get_map_variable(blocking_parameter) if blocking_parameter is not None: self.blocking_parameter = blocking_parameter if blocking_size is not None: self.blocking_size = blocking_size + if require_independent_nodes is not None: + self.require_independent_nodes = require_independent_nodes + self._independent_nodes = None + self._dependent_nodes = None @classmethod def expressions(cls) -> Any: @@ -129,6 +131,8 @@ def can_be_applied( return False if not self.partition_map_output(graph, sdfg): return False + self._independent_nodes = None + self._dependent_nodes = None return True @@ -141,7 +145,6 @@ def apply( Performs the operation described in the doc string. """ - # Now compute the partitions of the nodes. 
self.partition_map_output(graph, sdfg) @@ -157,10 +160,8 @@ def apply( state=graph, sdfg=sdfg, ) - - # Clear the old partitions - self.independent_nodes = None - self.dependent_nodes = None + self._independent_nodes = None + self._dependent_nodes = None def _prepare_inner_outer_maps( self, @@ -203,9 +204,9 @@ def _prepare_inner_outer_maps( inner_label = f"inner_{outer_map.label}" inner_range = { self.blocking_parameter: dace_subsets.Range.from_string( - f"({coarse_block_var} * {self.blocking_size} + {rng_start})" + f"(({rng_start}) + ({coarse_block_var}) * ({self.blocking_size}))" + ":" - + f"min(({rng_start} + {coarse_block_var} + 1) * {self.blocking_size}, {rng_stop} + 1)" + + f"min(({rng_start}) + ({coarse_block_var} + 1) * ({self.blocking_size}), ({rng_stop}) + 1)" ) } inner_entry, inner_exit = state.add_map( @@ -218,7 +219,7 @@ def _prepare_inner_outer_maps( # Now we modify the properties of the outer map. coarse_block_range = dace_subsets.Range.from_string( - f"0:int_ceil(({rng_stop} + 1) - {rng_start}, {self.blocking_size})" + f"0:int_ceil((({rng_stop}) + 1) - ({rng_start}), ({self.blocking_size}))" ).ranges[0] outer_map.params[blocking_parameter_dim] = coarse_block_var outer_map.range[blocking_parameter_dim] = coarse_block_range @@ -262,6 +263,9 @@ def partition_map_output( member variables are updated. If the partition does not exists the function will return `False` and the respective member variables will be `None`. + The function will honor `self.require_independent_nodes`. Thus if no independent + nodes were found the function behaves as if the partition does not exist. + Args: state: The state on which we operate. sdfg: The SDFG in which we operate on. @@ -273,8 +277,8 @@ def partition_map_output( """ # Clear the previous partition. - self.independent_nodes = set() - self.dependent_nodes = None + self._independent_nodes = set() + self._dependent_nodes = None while True: # Find all the nodes that we have to classify in this iteration. 
@@ -283,9 +287,9 @@ def partition_map_output( nodes_to_classify: set[dace_nodes.Node] = { edge.dst for edge in state.out_edges(self.outer_entry) } - for independent_node in self.independent_nodes: + for independent_node in self._independent_nodes: nodes_to_classify.update({edge.dst for edge in state.out_edges(independent_node)}) - nodes_to_classify.difference_update(self.independent_nodes) + nodes_to_classify.difference_update(self._independent_nodes) # Now classify each node found_new_independent_node = False @@ -298,7 +302,7 @@ def partition_map_output( # Check if the partition exists. if class_res is None: - self.independent_nodes = None + self._independent_nodes = None return False if class_res is True: found_new_independent_node = True @@ -307,12 +311,16 @@ def partition_map_output( if not found_new_independent_node: break + if self.require_independent_nodes and len(self._independent_nodes) == 0: + self._independent_nodes = None + return False + # After the independent set is computed compute the set of dependent nodes # as the set of all nodes adjacent to `outer_entry` that are not dependent. - self.dependent_nodes = { + self._dependent_nodes = { edge.dst for edge in state.out_edges(self.outer_entry) - if edge.dst not in self.independent_nodes + if edge.dst not in self._independent_nodes } return True @@ -337,7 +345,7 @@ def _classify_node( Returns: The function returns `True` if `node_to_classify` is considered independent. - In this case the function will add the node to `self.independent_nodes`. + In this case the function will add the node to `self._independent_nodes`. If the function returns `False` the node was classified as a dependent node. The function will return `None` if the node can not be classified, in this case the partition does not exist. @@ -347,23 +355,68 @@ def _classify_node( state: The state containing the map. sdfg: The SDFG that is processed. 
""" + assert self._independent_nodes is not None # silence MyPy outer_entry: dace_nodes.MapEntry = self.outer_entry # for caching. + outer_exit: dace_nodes.MapExit = state.exit_node(outer_entry) + + # The node needs to have an input and output. + if state.in_degree(node_to_classify) == 0 or state.out_degree(node_to_classify) == 0: + return None # We are only able to handle certain kind of nodes, so screening them. if isinstance(node_to_classify, dace_nodes.Tasklet): if node_to_classify.side_effects: - # TODO(phimuell): Think of handling it. return None + + # A Tasklet must write to an AccessNode, because otherwise there would + # be nothing that could be used to cache anything. Furthermore, this + # AccessNode must be outside of the inner loop, i.e. be independent. + # TODO: Make this check stronger to ensure that there is always an + # AccessNode that is independent. + if not all( + isinstance(out_edge.dst, dace_nodes.AccessNode) + for out_edge in state.out_edges(node_to_classify) + if not out_edge.data.is_empty() + ): + return False + + # Test if the body of the Tasklet depends on the block variable. + if self.blocking_parameter in node_to_classify.free_symbols: + return False + + elif isinstance(node_to_classify, dace.nodes.NestedSDFG): + # Same check as for Tasklets applies to the outputs of a nested SDFG node + if not all( + isinstance(out_edge.dst, dace_nodes.AccessNode) + for out_edge in state.out_edges(node_to_classify) + if not out_edge.data.is_empty() + ): + return False + + # Additionally, test if the symbol mapping depends on the block variable. + for v in node_to_classify.symbol_mapping.values(): + if self.blocking_parameter in v.free_symbols: + return False + elif isinstance(node_to_classify, dace_nodes.AccessNode): # AccessNodes need to have some special properties. node_desc: dace.data.Data = node_to_classify.desc(sdfg) - if isinstance(node_desc, dace.data.View): # Views are forbidden. 
return None - if node_desc.lifetime != dace.dtypes.AllocationLifetime.Scope: - # The access node has to life fully within the scope. + + # The access node inside either has scope lifetime or is a scalar. + if isinstance(node_desc, dace.data.Scalar): + pass + elif node_desc.lifetime != dace.dtypes.AllocationLifetime.Scope: return None + + elif isinstance(node_to_classify, dace_nodes.MapEntry): + # We classify `MapEntries` as dependent nodes, we could now start + # looking if the whole map is independent, but it is currently an + # overkill. + return False + else: # Any other node type we can not handle, so the partition can not exist. # TODO(phimuell): Try to handle certain kind of library nodes. @@ -380,44 +433,17 @@ def _classify_node( # for these classification to make sense the partition has to exist in the # first place. - # Either all incoming edges of a node are empty or none of them. If it has - # empty edges, they are only allowed to come from the map entry. - found_empty_edges, found_nonempty_edges = False, False - for in_edge in in_edges: - if in_edge.data.is_empty(): - found_empty_edges = True - if in_edge.src is not outer_entry: - # TODO(phimuell): Lift this restriction. - return None - else: - found_nonempty_edges = True - - # Test if we found a mixture of empty and nonempty edges. - if found_empty_edges and found_nonempty_edges: - return None - assert ( - found_empty_edges or found_nonempty_edges - ), f"Node '{node_to_classify}' inside '{outer_entry}' without an input connection." - - # Requiring that all output Memlets are non empty implies, because we are - # inside a scope, that there exists an output. - if any(out_edge.data.is_empty() for out_edge in state.out_edges(node_to_classify)): - return None - - # Now we have ensured that the partition exists, thus we will now evaluate - # if the node is independent or dependent. - - # Test if the body of the Tasklet depends on the block variable. 
- if ( - isinstance(node_to_classify, dace_nodes.Tasklet) - and self.blocking_parameter in node_to_classify.free_symbols - ): - return False + # There are some very small requirements that we impose on the output edges. + for out_edge in state.out_edges(node_to_classify): + # We consider nodes that are directly connected to the outer map exit as + # dependent. This is an implementation detail to avoid some hard cases. + if out_edge.dst is outer_exit: + return False # Now we have to look at incoming edges individually. # We will inspect the subset of the Memlet to see if they depend on the # block variable. If this loop ends normally, then we classify the node - # as independent and the node is added to `independent_nodes`. + # as independent and the node is added to `_independent_nodes`. for in_edge in in_edges: memlet: dace.Memlet = in_edge.data src_subset: dace_subsets.Subset | None = memlet.src_subset @@ -440,11 +466,11 @@ def _classify_node( # The edge must either originate from `outer_entry` or from an independent # node if not it is dependent. - if not (in_edge.src is outer_entry or in_edge.src in self.independent_nodes): + if not (in_edge.src is outer_entry or in_edge.src in self._independent_nodes): return False # Loop ended normally, thus we classify the node as independent. - self.independent_nodes.add(node_to_classify) + self._independent_nodes.add(node_to_classify) return True def _rewire_map_scope( @@ -471,116 +497,138 @@ def _rewire_map_scope( state: The state of the map. sdfg: The SDFG we operate on. """ + assert self._independent_nodes is not None and self._dependent_nodes is not None # Contains the nodes that are already have been handled. relocated_nodes: set[dace_nodes.Node] = set() # We now handle all independent nodes, this means that all of their - # _output_ edges have to go through the new inner map and the Memlets need - # modifications, because of the block parameter. 
- for independent_node in self.independent_nodes: - for out_edge in state.out_edges(independent_node): + # _output_ edges have to go through the new inner map and the Memlets + # need modifications, because of the block parameter. + for independent_node in self._independent_nodes: + for out_edge in list(state.out_edges(independent_node)): edge_dst: dace_nodes.Node = out_edge.dst relocated_nodes.add(edge_dst) # If destination of this edge is also independent we do not need # to handle it, because that node will also be before the new # inner serial map. - if edge_dst in self.independent_nodes: + if edge_dst in self._independent_nodes: continue # Now split `out_edge` such that it passes through the new inner entry. # We do not need to modify the subsets, i.e. replacing the variable # on which we block, because the node is independent and the outgoing # new inner map entry iterate over the blocked variable. - new_map_conn = inner_entry.next_connector() - dace_helpers.redirect_edge( - state=state, - edge=out_edge, - new_dst=inner_entry, - new_dst_conn="IN_" + new_map_conn, + if out_edge.data.is_empty(): + # `out_edge` is an empty Memlet that ensures its source, which is + # independent, is sequenced before its destination, which is + # dependent. We now have to split it into two. + # TODO(phimuell): Can we remove this edge? Is the map enough to + # ensure proper sequencing? + new_in_conn = None + new_out_conn = None + new_memlet_outside = dace.Memlet() + + elif not isinstance(independent_node, dace_nodes.AccessNode): + # For syntactical reasons there must be an access node on the + # outside of the (inner) scope, that acts as cache. The + # classification and this preconditions on SDFG should ensure + # that, but there are a few super hard edge cases. + # TODO(phimuell): Add an intermediate here in this case + raise NotImplementedError() + + else: + # NOTE: This creates more connections that are ultimately + # necessary. 
However, figuring out which one to use and if + # it would be valid, is very complicated, so we don't do it. + new_map_conn = inner_entry.next_connector(try_name=out_edge.data.data) + new_in_conn = "IN_" + new_map_conn + new_out_conn = "OUT_" + new_map_conn + new_memlet_outside = dace.Memlet.from_array( + out_edge.data.data, sdfg.arrays[out_edge.data.data] + ) + inner_entry.add_in_connector(new_in_conn) + inner_entry.add_out_connector(new_out_conn) + + state.add_edge( + out_edge.src, + out_edge.src_conn, + inner_entry, + new_in_conn, + new_memlet_outside, ) - # TODO(phimuell): Check if there might be a subset error. state.add_edge( inner_entry, - "OUT_" + new_map_conn, + new_out_conn, out_edge.dst, out_edge.dst_conn, copy.deepcopy(out_edge.data), ) - inner_entry.add_in_connector("IN_" + new_map_conn) - inner_entry.add_out_connector("OUT_" + new_map_conn) + state.remove_edge(out_edge) # Now we handle the dependent nodes, they differ from the independent nodes - # in that they _after_ the new inner map entry. Thus, we will modify incoming edges. - for dependent_node in self.dependent_nodes: + # in that they _after_ the new inner map entry. Thus, we have to modify + # their incoming edges. + for dependent_node in self._dependent_nodes: for in_edge in state.in_edges(dependent_node): edge_src: dace_nodes.Node = in_edge.src - # Since the independent nodes were already processed, and they process - # their output we have to check for this. We do this by checking if - # the source of the edge is the new inner map entry. + # The incoming edge of a dependent node (before any processing) either + # starts at: + # - The outer map. + # - An other dependent node. + # - An independent node. + # The last case was already handled by the loop above. if edge_src is inner_entry: + # Edge originated originally at an independent node, but was + # already handled by the loop above. 
assert dependent_node in relocated_nodes - continue - # A dependent node has at least one connection to the outer map entry. - # And these are the only connections that we must handle, since other - # connections come from independent nodes, and were already handled - # or are inner nodes. - if edge_src is not outer_entry: - continue - - # If we encounter an empty Memlet we just just attach it to the - # new inner map entry. Note the partition function ensures that - # either all edges are empty or non. - if in_edge.data.is_empty(): - assert ( - edge_src is outer_entry - ), f"Found an empty edge that does not go to the outer map entry, but to '{edge_src}'." + elif edge_src is not outer_entry: + # Edge originated at an other dependent node. There is nothing + # that we have to do. + # NOTE: We can not test if `edge_src` is in `self._dependent_nodes` + # because it only contains the dependent nodes that are directly + # connected to the map entry. + assert edge_src not in self._independent_nodes + + elif in_edge.data.is_empty(): + # The dependent node has an empty Memlet to the other map. + # Since the inner map is sequenced after the outer map, + # we will simply reconnect the edge to the inner map. + # TODO(phimuell): Are there situations where this makes problems. dace_helpers.redirect_edge(state=state, edge=in_edge, new_src=inner_entry) - continue - # Because of the definition of a dependent node and the processing - # order, their incoming edges either point to the outer map or - # are already handled. - assert ( - edge_src is outer_entry - ), f"Expected to find source '{outer_entry}' but found '{edge_src}'." - edge_conn: str = in_edge.src_conn[4:] - - # Must be before the handling of the modification below - # Note that this will remove the original edge from the SDFG. 
- dace_helpers.redirect_edge( - state=state, - edge=in_edge, - new_src=inner_entry, - new_src_conn="OUT_" + edge_conn, - ) - - # In a valid SDFG only one edge can go into an input connector of a Map. - if "IN_" + edge_conn in inner_entry.in_connectors: - # We have found this edge multiple times already. - # To ensure that there is no error, we will create a new - # Memlet that reads the whole array. - piping_edge = next(state.in_edges_by_connector(inner_entry, "IN_" + edge_conn)) - data_name = piping_edge.data.data - piping_edge.data = dace.Memlet.from_array( - data_name, sdfg.arrays[data_name], piping_edge.data.wcr + elif edge_src is outer_entry: + # This dependent node originated at the outer map. Thus we have to + # split the edge, such that it now passes through the inner map. + new_map_conn = inner_entry.next_connector(try_name=in_edge.src_conn[4:]) + new_in_conn = "IN_" + new_map_conn + new_out_conn = "OUT_" + new_map_conn + new_memlet_inner = dace.Memlet.from_array( + in_edge.data.data, sdfg.arrays[in_edge.data.data] + ) + state.add_edge( + in_edge.src, + in_edge.src_conn, + inner_entry, + new_in_conn, + new_memlet_inner, ) - - else: - # This is the first time we found this connection. - # so we just create the edge. state.add_edge( - outer_entry, - "OUT_" + edge_conn, inner_entry, - "IN_" + edge_conn, + new_out_conn, + in_edge.dst, + in_edge.dst_conn, copy.deepcopy(in_edge.data), ) - inner_entry.add_in_connector("IN_" + edge_conn) - inner_entry.add_out_connector("OUT_" + edge_conn) + inner_entry.add_in_connector(new_in_conn) + inner_entry.add_out_connector(new_out_conn) + state.remove_edge(in_edge) + + else: + raise NotImplementedError("Unknown node configuration.") # In certain cases it might happen that we need to create an empty # Memlet between the outer map entry and the inner one. @@ -597,7 +645,7 @@ def _rewire_map_scope( # This is simple reconnecting, there would be possibilities for improvements # but we do not use them for now. 
for in_edge in state.in_edges(outer_exit): - edge_conn = in_edge.dst_conn[3:] + edge_conn = inner_exit.next_connector(in_edge.dst_conn[3:]) dace_helpers.redirect_edge( state=state, edge=in_edge, @@ -614,5 +662,9 @@ def _rewire_map_scope( inner_exit.add_in_connector("IN_" + edge_conn) inner_exit.add_out_connector("OUT_" + edge_conn) + # There is an invalid cache state in the SDFG, that makes the memlet + # propagation fail, to clear the cache we call the hash function. + # See: https://github.com/spcl/dace/issues/1703 + _ = sdfg.hash_sdfg() # TODO(phimuell): Use a less expensive method. dace.sdfg.propagation.propagate_memlets_state(sdfg, state) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion.py new file mode 100644 index 0000000000..0f1dabf0d2 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion.py @@ -0,0 +1,189 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +"""An interface between DaCe's MapFusion and the one of GT4Py.""" + +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. + +from typing import Any, Callable, Optional, TypeAlias, TypeVar, Union + +import dace +from dace import nodes as dace_nodes, properties as dace_properties + +from gt4py.next.program_processors.runners.dace.transformations import ( + map_fusion_dace as dace_map_fusion, +) + + +_MapFusionType = TypeVar("_MapFusionType", bound="dace_map_fusion.MapFusion") + +FusionTestCallback: TypeAlias = Callable[ + [_MapFusionType, dace_nodes.MapEntry, dace_nodes.MapEntry, dace.SDFGState, dace.SDFG, int], bool +] +"""Callback for the map fusion transformation to check if a fusion should be performed. 
+ +The callback returns `True` if the fusion should be performed and `False` if it +should be rejected. See also the description of GT4Py's MapFusion transformation for +more information. + +The arguments are as follows: +- The transformation object that is active. +- The MapEntry node of the first map; exact meaning depends on if parallel or + serial map fusion is performed. +- The MapEntry node of the second map; exact meaning depends on if parallel or + serial map fusion is performed. +- The SDFGState that that contains the data flow. +- The SDFG that is processed. +- The expression index, see `expr_index` in `can_be_applied()` it is `0` for + serial map fusion and `1` for parallel map fusion. +""" + + +@dace_properties.make_properties +class MapFusion(dace_map_fusion.MapFusion): + """GT4Py's MapFusion transformation. + + It is a wrapper that adds some functionality to the transformation that is not + present in the DaCe version of this transformation. + There are three important differences when compared with DaCe's MapFusion: + - In DaCe strict data flow is enabled by default, in GT4Py it is disabled by default. + - In DaCe `MapFusion` only performs the fusion of serial maps by default. In GT4Py + `MapFusion` will also perform parallel map fusion by default. + - GT4Py accepts an additional argument `apply_fusion_callback`. This is a + function that is called by the transformation, at the _beginning_ of + `self.can_be_applied()`, i.e. before the transformation does any check if + the maps can be fused. If this function returns `False`, `self.can_be_applied()` + ends and returns `False`. In case the callback returns `True` the transformation + will perform the usual steps to check if the transformation can apply or not. + For the signature see `FusionTestCallback`. + + Args: + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. 
+ strict_dataflow: Strict dataflow mode should be used, it is disabled by default. + assume_always_shared: Assume that all intermediates are shared. + allow_serial_map_fusion: Allow serial map fusion, by default `True`. + allow_parallel_fusion: Allow to merge parallel maps, by default `True`. + only_if_common_ancestor: In parallel map fusion mode, only fuse if both maps + have a common direct ancestor. + apply_fusion_callback: The callback function that is used. + + Todo: + Investigate ways of how to remove this intermediate layer. The main reason + why we need it is the callback functionality, but it is not needed often + and in these cases it might be solved differently. + """ + + _apply_fusion_callback: Optional[FusionTestCallback] + + def __init__( + self, + strict_dataflow: bool = False, + allow_serial_map_fusion: bool = True, + allow_parallel_map_fusion: bool = True, + apply_fusion_callback: Optional[FusionTestCallback] = None, + **kwargs: Any, + ) -> None: + self._apply_fusion_callback = None + super().__init__( + strict_dataflow=strict_dataflow, + allow_serial_map_fusion=allow_serial_map_fusion, + allow_parallel_map_fusion=allow_parallel_map_fusion, + **kwargs, + ) + if apply_fusion_callback is not None: + self._apply_fusion_callback = apply_fusion_callback + + def can_be_applied( + self, + graph: Union[dace.SDFGState, dace.SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Performs basic checks if the maps can be fused. + + Args: + map_entry_1: The entry of the first (in serial case the top) map. + map_exit_2: The entry of the second (in serial case the bottom) map. + graph: The SDFGState in which the maps are located. + sdfg: The SDFG itself. + permissive: Currently unused. + """ + assert expr_index in [0, 1] + + # If the call back is given then proceed with it. + if self._apply_fusion_callback is not None: + if expr_index == 0: # Serial MapFusion. 
+ first_map_entry: dace_nodes.MapEntry = graph.entry_node(self.first_map_exit) + second_map_entry: dace_nodes.MapEntry = self.second_map_entry + elif expr_index == 1: # Parallel MapFusion + first_map_entry = self.first_parallel_map_entry + second_map_entry = self.second_parallel_map_entry + else: + raise NotImplementedError(f"Not implemented expression: {expr_index}") + + # Apply the call back. + if not self._apply_fusion_callback( + self, + first_map_entry, + second_map_entry, + graph, + sdfg, + expr_index, + ): + return False + + # Now forward to the underlying implementation. + return super().can_be_applied( + graph=graph, + expr_index=expr_index, + sdfg=sdfg, + permissive=permissive, + ) + + +@dace_properties.make_properties +class MapFusionSerial(MapFusion): + """Wrapper around `MapFusion` that only supports serial map fusion. + + Note: + This class exists only for the transition period. + """ + + def __init__( + self, + **kwargs: Any, + ) -> None: + assert "allow_serial_map_fusion" not in kwargs + assert "allow_parallel_map_fusion" not in kwargs + super().__init__( + allow_serial_map_fusion=True, + allow_parallel_map_fusion=False, + **kwargs, + ) + + +@dace_properties.make_properties +class MapFusionParallel(MapFusion): + """Wrapper around `MapFusion` that only supports parallel map fusion. + + Note: + This class exists only for the transition period. 
+ """ + + def __init__( + self, + **kwargs: Any, + ) -> None: + assert "allow_serial_map_fusion" not in kwargs + assert "allow_parallel_map_fusion" not in kwargs + super().__init__( + allow_serial_map_fusion=False, + allow_parallel_map_fusion=True, + **kwargs, + ) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_dace.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_dace.py new file mode 100644 index 0000000000..67fbe4182d --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_dace.py @@ -0,0 +1,2100 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +"""Implements Helper functionaliyies for map fusion + +THIS FILE WAS COPIED FROM DACE TO FACILITATE DEVELOPMENT UNTIL THE PR#1625 IN +DACE IS MERGED AND THE VERSION WAS UPGRADED. +""" + +import copy +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union + +import dace +from dace import data, properties, subsets, symbolic, transformation +from dace.sdfg import SDFG, SDFGState, graph, nodes, validation +from dace.transformation import helpers + + +@properties.make_properties +class MapFusion(transformation.SingleStateTransformation): + """Implements the MapFusion transformation. + + From a high level perspective it will remove the MapExit node of the first and the MapEntry node of + the second Map. It will then rewire and modify the Memlets such that the data flow bypasses the + intermediate node. For this a new intermediate node will be created, which is much smaller because + it has no longer to store the whole output of the first map, but only the data that is produced by + a single iteration of the first map. 
    The transformation will then remove the old intermediate.
    Thus by merging the two Maps together the transformation will reduce the memory
    footprint. It is important that it is not always possible to fully remove the
    intermediate node. For example the data might be used somewhere else. In this
    case the intermediate will become an output of the Map.

    An example would be the following:
    ```python
    for i in range(N):
        T[i] = foo(A[i])
    for j in range(N):
        B[j] = bar(T[j])
    ```
    which would be translated into:
    ```python
    for i in range(N):
        temp: scalar = foo(A[i])
        B[i] = bar(temp)
    ```

    The checks that two Maps can be fused are quite involved, however, they essentially check:
    * If the two Maps cover the same iteration space, i.e. essentially have the same
        start, stop and iteration step, see `find_parameter_remapping()`.
    * Furthermore, they verify that the new fused Map did not introduce read write conflicts,
        essentially it tests if the data is pointwise, i.e. what is read is also written,
        see `has_read_write_dependency()`.
    * Then it will examine the intermediate data. This will essentially test if the data that
        is needed by a single iteration of the second Map is produced by a single iteration of
        the first Map, see `partition_first_outputs()`.

    By default `strict_dataflow` is enabled. In this mode the transformation is more conservative.
    The main difference is, that it will not adjust the subsets of the intermediate, i.e. turning
    an array with shape `(1, 1, 1, 1)` into a scalar. Furthermore, shared intermediates, see
    `partition_first_outputs()`, will only be created if the data is not referred to downstream
    in the dataflow.

    In order to determine if an intermediate can be removed or has to be kept, it is in general
    necessary to scan the whole SDFG, which is the default behaviour. There are two ways to
    speed this up. The first way is to set `assume_always_shared` to `True`. In this case the
    transformation will not perform the scan, but assume that the data is shared, i.e. used
    somewhere else. This might lead to dead data flow.
    The second way is to use the transformation inside a pipeline, which includes the
    `FindSingleUseData` analysis pass. If the result of this pass is present then the
    transformation will use it instead to determine if an intermediate can be removed.
    Note that `assume_always_shared` takes precedence.
    For this pattern the `FullMapFusion` pass is provided, that combines the analysis
    pass and `MapFusion`.

    By default this transformation only handles the case where two maps are right after each
    other, separated by an intermediate array. However, by setting `allow_parallel_map_fusion`
    to `True`, the transformation will _in addition_ also be able to handle the case where the
    Maps are parallel (parallel here means that neither of the two Maps can be reached from the
    other; see `is_parallel()`). If you only want to perform parallel map fusion you also have
    to set `allow_serial_map_fusion` to `False`.

    :param only_inner_maps: Only match Maps that are internal, i.e. inside another Map.
    :param only_toplevel_maps: Only consider Maps that are at the top.
    :param strict_dataflow: Which dataflow mode should be used, see above.
    :param assume_always_shared: Assume that all intermediates are shared.
    :param allow_serial_map_fusion: Allow serial map fusion, by default `True`.
    :param allow_parallel_map_fusion: Allow to merge parallel maps, by default `False`.
    :param only_if_common_ancestor: In parallel map fusion mode, only fuse if both maps
        have a common direct ancestor.

    :note: This transformation modifies more nodes than it matches.
    :note: If `assume_always_shared` is `True` then the transformation will assume that
        all intermediates are shared. This avoids the problems mentioned above with
        the cache at the expense of the creation of dead dataflow.
    """

    # Pattern Nodes: For the serial map fusion.
    # NOTE: Can only be accessed in the `can_serial_map_fusion_be_applied()` and the
    #   `apply_serial_map_fusion()` functions.
    first_map_exit = transformation.transformation.PatternNode(nodes.MapExit)
    array = transformation.transformation.PatternNode(nodes.AccessNode)
    second_map_entry = transformation.transformation.PatternNode(nodes.MapEntry)

    # Pattern Nodes: For the parallel map fusion.
    # NOTE: Can only be used in the `can_parallel_map_fusion_be_applied()` and the
    #   `apply_parallel_map_fusion()` functions.
    first_parallel_map_entry = transformation.transformation.PatternNode(nodes.MapEntry)
    second_parallel_map_entry = transformation.transformation.PatternNode(nodes.MapEntry)

    # Settings
    only_toplevel_maps = properties.Property(
        dtype=bool,
        default=False,
        desc="Only perform fusing if the Maps are in the top level.",
    )
    only_inner_maps = properties.Property(
        dtype=bool,
        default=False,
        desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.",
    )
    strict_dataflow = properties.Property(
        dtype=bool,
        default=True,
        desc="If `True` then the transformation will ensure a more stricter data flow.",
    )
    assume_always_shared = properties.Property(
        dtype=bool,
        default=False,
        desc="If `True` then all intermediates will be classified as shared.",
    )

    allow_serial_map_fusion = properties.Property(
        dtype=bool,
        default=True,
        desc="If `True`, the default, then allow serial map fusion.",
    )

    allow_parallel_map_fusion = properties.Property(
        dtype=bool,
        default=False,
        desc="If `True` then also perform parallel map fusion, disabled by default.",
    )
    only_if_common_ancestor = properties.Property(
        dtype=bool,
        default=False,
        desc="If `True` restrict parallel map fusion to maps that have a direct common ancestor.",
    )

    def __init__(
        self,
        only_inner_maps: Optional[bool] = None,
        only_toplevel_maps: Optional[bool] = None,
        strict_dataflow: Optional[bool] = None,
        assume_always_shared: Optional[bool] = None,
        allow_serial_map_fusion: Optional[bool] = None,
        allow_parallel_map_fusion: Optional[bool] = None,
        only_if_common_ancestor: Optional[bool] = None,
        **kwargs: Any,
    ) -> None:
        """Initializes the transformation; arguments left as `None` keep the property defaults."""
        super().__init__(**kwargs)
        if only_toplevel_maps is not None:
            self.only_toplevel_maps = only_toplevel_maps
        if only_inner_maps is not None:
            self.only_inner_maps = only_inner_maps
        if strict_dataflow is not None:
            self.strict_dataflow = strict_dataflow
        if assume_always_shared is not None:
            self.assume_always_shared = assume_always_shared
        if allow_serial_map_fusion is not None:
            self.allow_serial_map_fusion = allow_serial_map_fusion
        if allow_parallel_map_fusion is not None:
            self.allow_parallel_map_fusion = allow_parallel_map_fusion
        if only_if_common_ancestor is not None:
            self.only_if_common_ancestor = only_if_common_ancestor

        # See comment in `is_shared_data()` for more information.
        self._single_use_data: Optional[Dict[dace.SDFG, Set[str]]] = None

    @classmethod
    def expressions(cls) -> Any:
        """Get the match expression.

        The function returns a list of two expressions.

        The first, index `0`, is used by the serial map fusion. It consists of the
        exit node of the first map, `first_map_exit`, the intermediate array, `array`,
        and the map entry node of the second map, `second_map_entry`. An important note
        is, that the transformation operates not just on the matched nodes, but more
        or less on anything that has an incoming connection from the first Map or an
        outgoing connection to the second Map entry.

        The second expression, index `1`, is used by parallel map fusion. It matches
        any two map entries, `first_parallel_map_entry` and `second_parallel_map_entry`,
        in a state.
        """
        map_fusion_serial_match = dace.sdfg.utils.node_path_graph(
            cls.first_map_exit, cls.array, cls.second_map_entry
        )

        map_fusion_parallel_match = graph.OrderedMultiDiConnectorGraph()
        map_fusion_parallel_match.add_nodes_from(
            [cls.first_parallel_map_entry, cls.second_parallel_map_entry]
        )

        return [map_fusion_serial_match, map_fusion_parallel_match]

    def can_be_applied(
        self,
        graph: Union[SDFGState, SDFG],
        expr_index: int,
        sdfg: dace.SDFG,
        permissive: bool = False,
    ) -> bool:
        """Checks if the map fusion can be applied.

        Depending on the value of `expr_index` the function will dispatch the call
        either to `can_serial_map_fusion_be_applied()` or
        `can_parallel_map_fusion_be_applied()`, see there for more information.
        """
        # Perform some checks of the deferred configuration data.
        if not (self.allow_parallel_map_fusion or self.allow_serial_map_fusion):
            raise ValueError("Disabled serial and parallel map fusion.")
        assert expr_index == self.expr_index
        assert self.expr_index in [0, 1], f"Found invalid 'expr_index' {self.expr_index}"

        # To ensure that the `{src,dst}_subset` are properly set, run initialization.
        # See [issue 1703](https://github.com/spcl/dace/issues/1703)
        for edge in graph.edges():
            edge.data.try_initialize(sdfg, graph, edge)

        # Now perform the dispatch.
        if self.allow_serial_map_fusion and expr_index == 0:
            return self.can_serial_map_fusion_be_applied(
                graph=graph,
                sdfg=sdfg,
            )

        elif self.allow_parallel_map_fusion and expr_index == 1:
            return self.can_parallel_map_fusion_be_applied(
                graph=graph,
                sdfg=sdfg,
            )

        # None of the cases applied.
        return False

    def apply(
        self,
        graph: Union[dace.SDFGState, dace.SDFG],
        sdfg: dace.SDFG,
    ) -> None:
        """Apply the map fusion.

        Depending on the settings the function will either dispatch to
        `apply_serial_map_fusion()` or to `apply_parallel_map_fusion()`.
        """
        # Perform some checks of the deferred configuration data.
        if not (self.allow_parallel_map_fusion or self.allow_serial_map_fusion):
            raise ValueError("Disabled serial and parallel map fusion.")
        assert self.expr_index in [0, 1]

        # To ensure that the `{src,dst}_subset` are properly set, run initialization.
        # See [issue 1703](https://github.com/spcl/dace/issues/1703)
        for edge in graph.edges():
            edge.data.try_initialize(sdfg, graph, edge)

        # Now perform the dispatch.
        if self.expr_index == 0:
            assert self.allow_serial_map_fusion
            return self.apply_serial_map_fusion(
                graph=graph,
                sdfg=sdfg,
            )

        elif self.expr_index == 1:
            assert self.allow_parallel_map_fusion
            return self.apply_parallel_map_fusion(
                graph=graph,
                sdfg=sdfg,
            )

        else:
            raise NotImplementedError(f"Encountered unknown expression index {self.expr_index}")

    def can_parallel_map_fusion_be_applied(
        self,
        graph: Union[SDFGState, SDFG],
        sdfg: dace.SDFG,
    ) -> bool:
        """Check if the matched Maps can be fused in parallel."""

        # NOTE: After this point it is not legal to access the matched nodes.
        first_map_entry: nodes.MapEntry = self.first_parallel_map_entry
        second_map_entry: nodes.MapEntry = self.second_parallel_map_entry

        assert self.expr_index == 1
        assert isinstance(first_map_entry, nodes.MapEntry)
        assert isinstance(second_map_entry, nodes.MapEntry)

        # We will now check if the two maps are parallel.
        if not self.is_parallel(graph=graph, node1=first_map_entry, node2=second_map_entry):
            return False

        # Check the structural properties of the Maps. The function will return
        # the `dict` that describes how the parameters must be renamed (for caching)
        # or `None` if the maps can not be structurally fused.
        param_repl = self.can_topologically_be_fused(
            first_map_entry=first_map_entry,
            second_map_entry=second_map_entry,
            graph=graph,
            sdfg=sdfg,
        )
        if param_repl is None:
            return False

        # Test if they share a node as direct ancestor.
        if self.only_if_common_ancestor:
            # TODO(phimuell): Improve this such that different AccessNodes that refer
            #   to the same data are also considered the same; probably an overkill.
            first_ancestors: Set[nodes.Node] = {e1.src for e1 in graph.in_edges(first_map_entry)}
            if not any(e2.src in first_ancestors for e2 in graph.in_edges(second_map_entry)):
                return False

        return True

    def can_serial_map_fusion_be_applied(
        self,
        graph: Union[SDFGState, SDFG],
        sdfg: dace.SDFG,
    ) -> bool:
        """Tests if the matched Maps can be merged serially.

        The two Maps are mergeable iff:
        * The general requirements hold, see `can_topologically_be_fused()`.
        * There are no read write dependencies.
        * The decomposition of the first Map's outputs exists.
        """
        # NOTE: After this point it is not legal to access the matched nodes.
        first_map_entry: nodes.MapEntry = graph.entry_node(self.first_map_exit)
        first_map_exit: nodes.MapExit = self.first_map_exit
        second_map_entry: nodes.MapEntry = self.second_map_entry

        assert self.expr_index == 0
        assert isinstance(first_map_exit, nodes.MapExit)
        assert isinstance(second_map_entry, nodes.MapEntry)
        assert isinstance(self.array, nodes.AccessNode)

        # Check the structural properties of the Maps. The function will return
        # the `dict` that describes how the parameters must be renamed (for caching)
        # or `None` if the maps can not be structurally fused.
        param_repl = self.can_topologically_be_fused(
            first_map_entry=first_map_entry,
            second_map_entry=second_map_entry,
            graph=graph,
            sdfg=sdfg,
        )
        if param_repl is None:
            return False

        # Tests if there are read write dependencies that are caused by the bodies
        # of the Maps, such as referring to the same data. Note that these tests are
        # different from the ones performed by `has_read_write_dependency()`, which
        # only checks the data dependencies that go through the scope nodes.
        if self.has_inner_read_write_dependency(
            first_map_entry=first_map_entry,
            second_map_entry=second_map_entry,
            state=graph,
            sdfg=sdfg,
        ):
            return False

        # Tests for read write conflicts of the two maps, this is only checking
        # the data that goes through the scope nodes. `has_inner_read_write_dependency()`
        # is used to check if there are internal dependencies.
        if self.has_read_write_dependency(
            first_map_entry=first_map_entry,
            second_map_entry=second_map_entry,
            param_repl=param_repl,
            state=graph,
            sdfg=sdfg,
        ):
            return False

        # Two maps can be serially fused if the node decomposition exists and
        # at least one of the intermediate output sets is not empty. The state
        # of the pure outputs is irrelevant for serial map fusion.
        output_partition = self.partition_first_outputs(
            state=graph,
            sdfg=sdfg,
            first_map_exit=first_map_exit,
            second_map_entry=second_map_entry,
            param_repl=param_repl,
        )
        if output_partition is None:
            return False
        _, exclusive_outputs, shared_outputs = output_partition
        if not (exclusive_outputs or shared_outputs):
            return False

        return True

    def apply_parallel_map_fusion(
        self,
        graph: Union[dace.SDFGState, dace.SDFG],
        sdfg: dace.SDFG,
    ) -> None:
        """Performs parallel map fusion.

        Essentially this function will move all input connectors from one map,
        i.e. its MapEntry and MapExit nodes, to the other map.
        """

        # NOTE: After this point it is not legal to access the matched nodes.
        first_map_entry: nodes.MapEntry = self.first_parallel_map_entry
        first_map_exit: nodes.MapExit = graph.exit_node(first_map_entry)
        second_map_entry: nodes.MapEntry = self.second_parallel_map_entry
        second_map_exit: nodes.MapExit = graph.exit_node(second_map_entry)

        # Before we do anything we perform the renaming, i.e. we will rename the
        # parameters of the second map such that they match the ones of the first map.
        self.rename_map_parameters(
            first_map=first_map_entry.map,
            second_map=second_map_entry.map,
            second_map_entry=second_map_entry,
            state=graph,
        )

        # Now we relocate all connectors from the second to the first map and remove
        # the respective node of the second map.
        for to_node, from_node in [
            (first_map_entry, second_map_entry),
            (first_map_exit, second_map_exit),
        ]:
            self.relocate_nodes(
                from_node=from_node,
                to_node=to_node,
                state=graph,
                sdfg=sdfg,
            )
            # The relocate function does not remove the node, so we must do it.
            graph.remove_node(from_node)

    def apply_serial_map_fusion(
        self,
        graph: Union[dace.SDFGState, dace.SDFG],
        sdfg: dace.SDFG,
    ) -> None:
        """Performs the serial Map fusing.

        The function first computes the map decomposition and then handles the
        three sets. The pure outputs are handled by `relocate_nodes()` while
        the two intermediate sets are handled by `handle_intermediate_set()`.

        By assumption we do not have to rename anything.

        :param graph: The SDFG state we are operating on.
        :param sdfg: The SDFG we are operating on.
        """
        assert self.expr_index == 0

        # NOTE: After this point it is not legal to access the matched nodes.
        first_map_exit: nodes.MapExit = self.first_map_exit
        second_map_entry: nodes.MapEntry = self.second_map_entry
        second_map_exit: nodes.MapExit = graph.exit_node(self.second_map_entry)
        first_map_entry: nodes.MapEntry = graph.entry_node(self.first_map_exit)

        # Before we do anything we perform the renaming.
        self.rename_map_parameters(
            first_map=first_map_exit.map,
            second_map=second_map_entry.map,
            second_map_entry=second_map_entry,
            state=graph,
        )

        # Now compute the partition. Because we have already renamed the parameters
        # of the second Map, there is no need to perform any renaming, thus we can
        # pass an empty `dict`.
        output_partition = self.partition_first_outputs(
            state=graph,
            sdfg=sdfg,
            first_map_exit=first_map_exit,
            second_map_entry=second_map_entry,
            param_repl=dict(),
        )
        assert output_partition is not None  # Make MyPy happy.
        pure_outputs, exclusive_outputs, shared_outputs = output_partition

        # Now perform the actual rewiring, we handle each partition separately.
        if len(exclusive_outputs) != 0:
            self.handle_intermediate_set(
                intermediate_outputs=exclusive_outputs,
                state=graph,
                sdfg=sdfg,
                first_map_exit=first_map_exit,
                second_map_entry=second_map_entry,
                second_map_exit=second_map_exit,
                is_exclusive_set=True,
            )
        if len(shared_outputs) != 0:
            self.handle_intermediate_set(
                intermediate_outputs=shared_outputs,
                state=graph,
                sdfg=sdfg,
                first_map_exit=first_map_exit,
                second_map_entry=second_map_entry,
                second_map_exit=second_map_exit,
                is_exclusive_set=False,
            )
        assert pure_outputs == set(graph.out_edges(first_map_exit))
        if len(pure_outputs) != 0:
            self.relocate_nodes(
                from_node=first_map_exit,
                to_node=second_map_exit,
                state=graph,
                sdfg=sdfg,
            )

        # Now move the input of the second map, that has no connection to the first
        # map, to the first map. This is needed because we will later delete the
        # exit of the first map (which we have essentially handled above). Now
        # we must handle the input of the second map (that has no connection to the
        # first map) to the input of the first map.
        self.relocate_nodes(
            from_node=second_map_entry,
            to_node=first_map_entry,
            state=graph,
            sdfg=sdfg,
        )

        for node_to_remove in [first_map_exit, second_map_entry]:
            assert graph.degree(node_to_remove) == 0
            graph.remove_node(node_to_remove)

        # Now turn the second output node into the output node of the first Map.
        second_map_exit.map = first_map_entry.map

    def partition_first_outputs(
        self,
        state: SDFGState,
        sdfg: SDFG,
        first_map_exit: nodes.MapExit,
        second_map_entry: nodes.MapEntry,
        param_repl: Dict[str, str],
    ) -> Union[
        Tuple[
            Set[graph.MultiConnectorEdge[dace.Memlet]],
            Set[graph.MultiConnectorEdge[dace.Memlet]],
            Set[graph.MultiConnectorEdge[dace.Memlet]],
        ],
        None,
    ]:
        """Partition the output edges of `first_map_exit` for serial map fusion.

        The output edges of the first map are partitioned into three distinct sets,
        defined as follows:
        * Pure Output Set `\mathbb{P}`:
            These edges exit the first map and do not enter the second map. These
            outputs will simply be moved to the output of the second map.
        * Exclusive Intermediate Set `\mathbb{E}`:
            Edges in this set leave the first map exit, enter an access node, from
            where a Memlet then leads immediately to the second map. The memory
            referenced by this access node is not used anywhere else, thus it can
            be removed.
        * Shared Intermediate Set `\mathbb{S}`:
            These edges are very similar to the ones in `\mathbb{E}` except that they
            are used somewhere else, thus they can not be removed and have to be
            recreated as output of the second map.

        If strict data flow mode is enabled the function is rather strict if an
        output can be added to either intermediate set and might fail to compute
        the partition, even if it would exist.

        :return: If such a decomposition exists the function will return the three sets
            mentioned above in the same order. In case the decomposition does not exist,
            i.e. the maps can not be fused, the function returns `None`.

        :param state: The state in which the two maps are located.
        :param sdfg: The full SDFG in which we operate.
        :param first_map_exit: The exit node of the first map.
        :param second_map_entry: The entry node of the second map.
        :param param_repl: Use this map to rename the parameters of the second Map, such
            that they match the ones of the first map.
        """
        # The three output sets.
        pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set()
        exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set()
        shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set()

        # Set of intermediate nodes that we have already processed.
        processed_inter_nodes: Set[nodes.Node] = set()

        # Now scan all output edges of the first exit and classify them.
        for out_edge in state.out_edges(first_map_exit):
            intermediate_node: nodes.Node = out_edge.dst

            # We already processed the node, this should indicate that we should
            # run simplify again, or we should start implementing this case.
            # TODO(phimuell): Handle this case, already partially handled here.
            if intermediate_node in processed_inter_nodes:
                return None
            processed_inter_nodes.add(intermediate_node)

            # The intermediate can only have one incoming degree. It might be possible
            # to handle multiple incoming edges, if they all come from the top map.
            # However, the resulting SDFG might be invalid.
            # NOTE: Allow this to happen (under certain cases) if the only producer
            #   is the top map.
            if state.in_degree(intermediate_node) != 1:
                return None

            # If the second map is not reachable from the intermediate node, then
            # the output is pure and we can end here.
            if not self.is_node_reachable_from(
                graph=state,
                begin=intermediate_node,
                end=second_map_entry,
            ):
                pure_outputs.add(out_edge)
                continue

            # The following tests are _after_ we have determined if we have a pure
            # output node, because this allows us to handle more exotic pure node
            # cases, as handling them is essentially rerouting an edge, whereas
            # handling intermediate nodes is much more complicated.

            # Empty Memlets are only allowed if they are in `\mathbb{P}`, which
            # is also the only place they really make sense (for a map exit).
            # Thus if we now found an empty Memlet we reject it.
            if out_edge.data.is_empty():
                return None

            # For us an intermediate node must always be an access node, because
            # everything else we do not know how to handle. It is important that
            # we do not test for non transient data here, because they can be
            # handled as shared intermediates.
            if not isinstance(intermediate_node, nodes.AccessNode):
                return None
            intermediate_desc: dace.data.Data = intermediate_node.desc(sdfg)
            if self.is_view(intermediate_desc, sdfg):
                return None

            # It can happen that multiple edges converge at the `IN_` connector
            # of the first map exit, but there is only one edge leaving the exit.
            # It is complicated to handle this, so for now we ignore it.
            # TODO(phimuell): Handle this case properly.
            #   To handle this we need to associate a consumer edge (the outgoing edges
            #   of the second map) with exactly one producer.
            producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list(
                state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:])
            )
            if len(producer_edges) > 1:
                return None

            # Now check the constraints we have on the producers:
            #   - The source of the producer can not be a view (we do not handle this).
            #   - The edge shall also not be a reduction edge.
            #   - Defined location to where they write.
            #   - No dynamic Memlets.
            # Furthermore, we will also extract the subsets, i.e. the location they
            # modify inside the intermediate array.
            # Since we do not allow for WCR, we do not check if the producer subsets intersect.
            producer_subsets: List[subsets.Subset] = []
            for producer_edge in producer_edges:
                if isinstance(producer_edge.src, nodes.AccessNode) and self.is_view(
                    producer_edge.src, sdfg
                ):
                    return None
                if producer_edge.data.dynamic:
                    # TODO(phimuell): Find out if this restriction could be lifted, but it is unlikely.
                    return None
                if producer_edge.data.wcr is not None:
                    return None
                if producer_edge.data.dst_subset is None:
                    return None
                producer_subsets.append(producer_edge.data.dst_subset)

            # Check that the producers do not intersect.
            if len(producer_subsets) == 1:
                pass
            elif len(producer_subsets) == 2:
                if producer_subsets[0].intersects(producer_subsets[1]):
                    return None
            else:
                for i, psbs1 in enumerate(producer_subsets):
                    for j, psbs2 in enumerate(producer_subsets):
                        if i == j:
                            continue
                        if psbs1.intersects(psbs2):
                            return None

            # Now we determine the consumer nodes. For this we are using the edges
            # that leave the second map entry. It is not necessary to find the actual
            # consumer nodes, as they might depend on symbols of nested Maps.
            # For the covering test we only need their subsets, but we will perform
            # some scan and filtering on them.
            found_second_map = False
            consumer_subsets: List[subsets.Subset] = []
            for intermediate_node_out_edge in state.out_edges(intermediate_node):
                # If the second map entry is not immediately reachable from the intermediate
                # node, then ensure that there is no path that goes to it.
                if intermediate_node_out_edge.dst is not second_map_entry:
                    if self.is_node_reachable_from(
                        graph=state, begin=intermediate_node_out_edge.dst, end=second_map_entry
                    ):
                        return None
                    continue

                # Ensure that the second map is found exactly once.
                # TODO(phimuell): Lift this restriction.
                if found_second_map:
                    return None
                found_second_map = True

                # The output of the top map can not define a dynamic map range in the
                # second map.
                if not intermediate_node_out_edge.dst_conn.startswith("IN_"):
                    return None

                # Now we look at all edges that leave the second map entry, i.e. the
                # edges that feed the consumer and define what is read inside the map.
                # We do not check them, but collect them and inspect them.
                # NOTE1: The subset still uses the old iteration variables.
                # NOTE2: In case of consumer Memlets we explicitly allow dynamic Memlets.
                #   This is different compared to the producer Memlets. The reason is
                #   because in a consumer the data is conditionally read, so the data
                #   has to exist anyway.
                for inner_consumer_edge in state.out_edges_by_connector(
                    second_map_entry, "OUT_" + intermediate_node_out_edge.dst_conn[3:]
                ):
                    if inner_consumer_edge.data.src_subset is None:
                        return None
                    consumer_subsets.append(inner_consumer_edge.data.src_subset)
            assert (
                found_second_map
            ), f"Found '{intermediate_node}' which looked like a pure node, but is not one."
            assert len(consumer_subsets) != 0

            # The consumer still uses the original symbols of the second map, so we must rename them.
            if param_repl:
                consumer_subsets = copy.deepcopy(consumer_subsets)
                for consumer_subset in consumer_subsets:
                    symbolic.safe_replace(
                        mapping=param_repl, replace_callback=consumer_subset.replace
                    )

            # Now we are checking if a single iteration of the first (top) map
            # can satisfy all data requirements of the second (bottom) map.
            # For this we look if the producer covers the consumer. A consumer must
            # be covered by exactly one producer.
            for consumer_subset in consumer_subsets:
                nb_coverings = sum(
                    producer_subset.covers(consumer_subset) for producer_subset in producer_subsets
                )
                if nb_coverings != 1:
                    return None

            # After we have ensured coverage, we have to decide if the intermediate
            # node can be removed (`\mathbb{E}`) or has to be restored (`\mathbb{S}`).
            # Note that "removed" here means that it is reconstructed by a new
            # output of the second map.
            if self.is_shared_data(data=intermediate_node, state=state, sdfg=sdfg):
                # The intermediate data is used somewhere else, either in this or another state.
                # NOTE: If the intermediate is shared, then we will turn it into a
                #   sink node attached to the combined map exit. Technically this
                #   should be enough, even if the same data appears again in the
                #   dataflow downstream. However, some DaCe transformations,
                #   I am looking at you `auto_optimizer()`, do not like that. Thus
                #   if the intermediate is used further down in the same dataflow,
                #   then we consider that the maps can not be fused. But we only
                #   do this in the strict data flow mode.
                if self.strict_dataflow:
                    if self._is_data_accessed_downstream(
                        data=intermediate_node.data,
                        graph=state,
                        begin=intermediate_node,  # is ignored itself.
                    ):
                        return None
                shared_outputs.add(out_edge)
            else:
                # The intermediate can be removed, as it is not used anywhere else.
                exclusive_outputs.add(out_edge)

        assert len(processed_inter_nodes) == sum(
            len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs]
        )
        return (pure_outputs, exclusive_outputs, shared_outputs)

    def relocate_nodes(
        self,
        from_node: Union[nodes.MapExit, nodes.MapEntry],
        to_node: Union[nodes.MapExit, nodes.MapEntry],
        state: SDFGState,
        sdfg: SDFG,
    ) -> None:
        """Move the connectors and edges from `from_node` to the `to_node` node.

        This function will only rewire the edges, it does not remove the nodes
        themselves. Furthermore, this function should be called twice per Map,
        once for the entry and then for the exit.
        While it does not remove the nodes themselves it guarantees that the
        `from_node` has degree zero.
        The function assumes that the parameter renaming was already done.

        :param from_node: Node from which the edges should be removed.
        :param to_node: Node to which the edges should reconnect.
        :param state: The state in which the operation happens.
        :param sdfg: The SDFG that is modified.
        """

        # Now we relocate empty Memlets, from the `from_node` to the `to_node`.
        for empty_edge in list(filter(lambda e: e.data.is_empty(), state.out_edges(from_node))):
            helpers.redirect_edge(state, empty_edge, new_src=to_node)
        for empty_edge in list(filter(lambda e: e.data.is_empty(), state.in_edges(from_node))):
            helpers.redirect_edge(state, empty_edge, new_dst=to_node)

        # We now ensure that there is only one empty Memlet from the `to_node` to any other node.
        # Although it is allowed, we try to prevent it.
        empty_targets: Set[nodes.Node] = set()
        for empty_edge in list(filter(lambda e: e.data.is_empty(), state.all_edges(to_node))):
            if empty_edge.dst in empty_targets:
                state.remove_edge(empty_edge)
            empty_targets.add(empty_edge.dst)

        # We now determine which edges we have to migrate, for this we are looking at
        # the incoming edges, because this allows us also to detect dynamic map ranges.
        # TODO(phimuell): If there is already a connection to the node, reuse this.
        for edge_to_move in list(state.in_edges(from_node)):
            assert isinstance(edge_to_move.dst_conn, str)

            if not edge_to_move.dst_conn.startswith("IN_"):
                # Dynamic Map Range:
                # The connector name simply defines a variable name that is used
                # inside the Map scope to define a variable. We handle it directly.
                dmr_symbol = edge_to_move.dst_conn

                # TODO(phimuell): Check if the symbol is really unused in the target scope.
                if dmr_symbol in to_node.in_connectors:
                    raise NotImplementedError(
                        f"Tried to move the dynamic map range '{dmr_symbol}' from {from_node}'"
                        f" to '{to_node}', but the symbol is already known there, but the"
                        " renaming is not implemented."
                    )
                if not to_node.add_in_connector(dmr_symbol, force=False):
                    raise RuntimeError(  # Might fail because of out connectors.
                        f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'."
                    )
                helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node)
                from_node.remove_in_connector(dmr_symbol)

            else:
                # We have a Passthrough connection, i.e. there exists a matching `OUT_`.
                old_conn = edge_to_move.dst_conn[3:]  # The connection name without prefix
                new_conn = to_node.next_connector(old_conn)

                to_node.add_in_connector("IN_" + new_conn)
                for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)):
                    helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn)
                to_node.add_out_connector("OUT_" + new_conn)
                for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)):
                    helpers.redirect_edge(state, e, new_src=to_node, new_src_conn="OUT_" + new_conn)
                from_node.remove_in_connector("IN_" + old_conn)
                from_node.remove_out_connector("OUT_" + old_conn)

        # Check if we succeeded.
        if state.out_degree(from_node) != 0:
            raise validation.InvalidSDFGError(
                f"Failed to relocate the outgoing edges from `{from_node}`, there are still `{state.out_edges(from_node)}`",
                sdfg,
                sdfg.node_id(state),
            )
        if state.in_degree(from_node) != 0:
            raise validation.InvalidSDFGError(
                f"Failed to relocate the incoming edges from `{from_node}`, there are still `{state.in_edges(from_node)}`",
                sdfg,
                sdfg.node_id(state),
            )
        assert len(from_node.in_connectors) == 0
        assert len(from_node.out_connectors) == 0

    def handle_intermediate_set(
        self,
        intermediate_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]],
        state: SDFGState,
        sdfg: SDFG,
        first_map_exit: nodes.MapExit,
        second_map_entry: nodes.MapEntry,
        second_map_exit: nodes.MapExit,
        is_exclusive_set: bool,
    ) -> None:
        """This function handles the intermediate sets.

        The function is able to handle both the shared and exclusive intermediate
        output set, see `partition_first_outputs()`. The main difference is that
        in exclusive mode the intermediate nodes will be fully removed from
        the SDFG.
While in shared mode the intermediate node will be preserved. + The function assumes that the parameter renaming was already done. + + :param intermediate_outputs: The set of outputs, that should be processed. + :param state: The state in which the map is processed. + :param sdfg: The SDFG that should be optimized. + :param first_map_exit: The exit of the first/top map. + :param second_map_entry: The entry of the second map. + :param second_map_exit: The exit of the second map. + :param is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. + + :note: Before the transformation the `state` does not have to be valid and + after this function has run the state is (most likely) invalid. + """ + + map_params = first_map_exit.map.params.copy() + + # Now we will iterate over all intermediate edges and process them. + # If not stated otherwise the comments assume that we run in exclusive mode. + for out_edge in intermediate_outputs: + # This is the intermediate node that, that we want to get rid of. + # In shared mode we want to recreate it after the second map. + inter_node: nodes.AccessNode = out_edge.dst + inter_name = inter_node.data + inter_desc = inter_node.desc(sdfg) + + # Now we will determine the shape of the new intermediate. This size of + # this temporary is given by the Memlet that goes into the first map exit. + pre_exit_edges = list( + state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:]) + ) + if len(pre_exit_edges) != 1: + raise NotImplementedError() + pre_exit_edge = pre_exit_edges[0] + + (new_inter_shape_raw, new_inter_shape, squeezed_dims) = ( + self.compute_reduced_intermediate( + producer_subset=pre_exit_edge.data.dst_subset, + inter_desc=inter_desc, + ) + ) + + # This is the name of the new "intermediate" node that we will create. + # It will only have the shape `new_inter_shape` which is basically its + # output within one Map iteration. + # NOTE: The insertion process might generate a new name. 
+ new_inter_name: str = f"__s{self.state_id}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" + + # Now generate the intermediate data container. + if len(new_inter_shape) == 0: + assert pre_exit_edge.data.subset.num_elements() == 1 + is_scalar = True + new_inter_name, new_inter_desc = sdfg.add_scalar( + new_inter_name, + dtype=inter_desc.dtype, + transient=True, + find_new_name=True, + ) + + else: + assert (pre_exit_edge.data.subset.num_elements() > 1) or all( + x == 1 for x in new_inter_shape + ) + is_scalar = False + new_inter_name, new_inter_desc = sdfg.add_transient( + new_inter_name, + shape=new_inter_shape, + dtype=inter_desc.dtype, + find_new_name=True, + ) + new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) + + # Get the subset that defined into which part of the old intermediate + # the old output edge wrote to. We need that to adjust the producer + # Memlets, since they now write into the new (smaller) intermediate. + producer_offset = self.compute_offset_subset( + original_subset=pre_exit_edge.data.dst_subset, + intermediate_desc=inter_desc, + map_params=map_params, + producer_offset=None, + ) + + # Memlets have a lot of additional informations, to ensure that we get + # all of them, we have to do it this way. The main reason for this is + # to handle the case were the "Memlet reverse direction", i.e. `data` + # refers to the other end of the connection than before. 
+ assert pre_exit_edge.data.dst_subset is not None + new_pre_exit_memlet_src_subset = copy.deepcopy(pre_exit_edge.data.src_subset) + new_pre_exit_memlet_dst_subset = subsets.Range.from_array(new_inter_desc) + + new_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + new_pre_exit_memlet.data = new_inter_name + + new_pre_exit_edge = state.add_edge( + pre_exit_edge.src, + pre_exit_edge.src_conn, + new_inter_node, + None, + new_pre_exit_memlet, + ) + + # We can update `{src, dst}_subset` only after we have inserted the + # edge, this is because the direction of the Memlet might change. + new_pre_exit_edge.data.src_subset = new_pre_exit_memlet_src_subset + new_pre_exit_edge.data.dst_subset = new_pre_exit_memlet_dst_subset + + # We now handle the MemletTree defined by this edge. + # The newly created edge, only handled the last collection step. + for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children( + include_self=False + ): + producer_edge = producer_tree.edge + + # In order to preserve the intrinsic direction of Memlets we only have to change + # the `.data` attribute of the producer Memlet if it refers to the old intermediate. + # If it refers to something different we keep it. Note that this case can only + # occur if the producer is an AccessNode. + if producer_edge.data.data == inter_name: + producer_edge.data.data = new_inter_name + + # Regardless of the intrinsic direction of the Memlet, the subset we care about + # is always `dst_subset`. + if is_scalar: + producer_edge.data.dst_subset = "0" + elif producer_edge.data.dst_subset is not None: + # Since we now write into a smaller memory patch, we must + # compensate for that. We do this by substracting where the write + # originally had begun. + producer_edge.data.dst_subset.offset(producer_offset, negative=True) + producer_edge.data.dst_subset.pop(squeezed_dims) + + # Now after we have handled the input of the new intermediate node, + # we must handle its output. 
For this we have to "inject" the newly + # created intermediate into the second map. We do this by finding + # the input connectors on the map entry, such that we know where we + # have to reroute inside the Map. + # NOTE: Assumes that map (if connected is the direct neighbour). + conn_names: Set[str] = set() + for inter_node_out_edge in state.out_edges(inter_node): + if inter_node_out_edge.dst == second_map_entry: + assert inter_node_out_edge.dst_conn.startswith("IN_") + conn_names.add(inter_node_out_edge.dst_conn) + else: + # If we found another target than the second map entry from the + # intermediate node it means that the node _must_ survive, + # i.e. we are not in exclusive mode. + assert not is_exclusive_set + + # Now we will reroute the connections inside the second map, i.e. + # instead of consuming the old intermediate node, they will now + # consume the new intermediate node. + for in_conn_name in conn_names: + out_conn_name = "OUT_" + in_conn_name[3:] + + for inner_edge in state.out_edges_by_connector(second_map_entry, out_conn_name): + # As for the producer side, we now read from a smaller array, + # So we must offset them, we use the original edge for this. + assert inner_edge.data.src_subset is not None + consumer_offset = self.compute_offset_subset( + original_subset=inner_edge.data.src_subset, + intermediate_desc=inter_desc, + map_params=map_params, + producer_offset=producer_offset, + ) + + # Now create the memlet for the new consumer. To make sure that we get all attributes + # of the Memlet we make a deep copy of it. There is a tricky part here, we have to + # access `src_subset` however, this is only correctly set once it is put inside the + # SDFG. Furthermore, we have to make sure that the Memlet does not change its direction. + # i.e. that the association of `subset` and `other_subset` does not change. For this + # reason we only modify `.data` attribute of the Memlet if its name refers to the old + # intermediate. 
Furthermore, to play it safe, we only access the subset, `src_subset` + # after we have inserted it to the SDFG. + new_inner_memlet = copy.deepcopy(inner_edge.data) + if inner_edge.data.data == inter_name: + new_inner_memlet.data = new_inter_name + + # Now we replace the edge from the SDFG. + state.remove_edge(inner_edge) + new_inner_edge = state.add_edge( + new_inter_node, + None, + inner_edge.dst, + inner_edge.dst_conn, + new_inner_memlet, + ) + + # Now modifying the Memlet, we do it after the insertion to make + # sure that the Memlet was properly initialized. + if is_scalar: + new_inner_memlet.subset = "0" + elif new_inner_memlet.src_subset is not None: + # TODO(phimuell): Figuring out if `src_subset` is None is an error. + new_inner_memlet.src_subset.offset(consumer_offset, negative=True) + new_inner_memlet.src_subset.pop(squeezed_dims) + + # Now we have to make sure that all consumers are properly updated. + for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children( + include_self=False + ): + consumer_edge = consumer_tree.edge + + # We only modify the data if the Memlet refers to the old intermediate data. + # We can not do this unconditionally, because it might change the intrinsic + # direction of a Memlet and then `src_subset` would at the next `try_initialize` + # be wrong. Note that this case only occurs if the destination is an AccessNode. + if consumer_edge.data.data == inter_name: + consumer_edge.data.data = new_inter_name + + # Now we have to adapt the subsets. + if is_scalar: + consumer_edge.data.src_subset = "0" + elif consumer_edge.data.src_subset is not None: + # TODO(phimuell): Figuring out if `src_subset` is None is an error. + consumer_edge.data.src_subset.offset(consumer_offset, negative=True) + consumer_edge.data.src_subset.pop(squeezed_dims) + + # The edge that leaves the second map entry was already deleted. We now delete + # the edges that connected the intermediate node with the second map entry. 
+ for edge in list(state.in_edges_by_connector(second_map_entry, in_conn_name)): + assert edge.src == inter_node + state.remove_edge(edge) + second_map_entry.remove_in_connector(in_conn_name) + second_map_entry.remove_out_connector(out_conn_name) + + if is_exclusive_set: + # In exclusive mode the old intermediate node is no longer needed. + # This will also remove `out_edge` from the SDFG. + assert state.degree(inter_node) == 1 + state.remove_edge_and_connectors(out_edge) + state.remove_node(inter_node) + + state.remove_edge(pre_exit_edge) + first_map_exit.remove_in_connector(pre_exit_edge.dst_conn) + first_map_exit.remove_out_connector(out_edge.src_conn) + del sdfg.arrays[inter_name] + + else: + # TODO(phimuell): Lift this restriction + assert pre_exit_edge.data.data == inter_name + + # This is the shared mode, so we have to recreate the intermediate + # node, but this time it is at the exit of the second map. + state.remove_edge(pre_exit_edge) + first_map_exit.remove_in_connector(pre_exit_edge.dst_conn) + + # This is the Memlet that goes from the map internal intermediate + # temporary node to the Map output. This will essentially restore + # or preserve the output for the intermediate node. It is important + # that we use the data that `preExitEdge` was used. 
+ final_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + final_pre_exit_memlet.other_subset = subsets.Range.from_array(new_inter_desc) + + new_pre_exit_conn = second_map_exit.next_connector() + state.add_edge( + new_inter_node, + None, + second_map_exit, + "IN_" + new_pre_exit_conn, + final_pre_exit_memlet, + ) + state.add_edge( + second_map_exit, + "OUT_" + new_pre_exit_conn, + inter_node, + out_edge.dst_conn, + copy.deepcopy(out_edge.data), + ) + second_map_exit.add_in_connector("IN_" + new_pre_exit_conn) + second_map_exit.add_out_connector("OUT_" + new_pre_exit_conn) + + first_map_exit.remove_out_connector(out_edge.src_conn) + state.remove_edge(out_edge) + + def compute_reduced_intermediate( + self, + producer_subset: subsets.Range, + inter_desc: dace.data.Data, + ) -> Tuple[Tuple[int, ...], Tuple[int, ...], List[int]]: + """Compute the size of the new (reduced) intermediate. + + `MapFusion` does not only fuses map, but, depending on the situation, also + eliminates intermediate arrays between the two maps. To transmit data between + the two maps a new, but much smaller intermediate is needed. + + :return: The function returns a tuple with three values with the following meaning: + * The raw shape of the reduced intermediate. + * The cleared shape of the reduced intermediate, essentially the raw shape + with all shape 1 dimensions removed. + * Which dimensions of the raw shape have been removed to get the cleared shape. + + :param producer_subset: The subset that was used to write into the intermediate. + :param inter_desc: The data descriptor for the intermediate. + """ + assert producer_subset is not None + + # Over approximation will leave us with some unneeded size one dimensions. + # If they are removed some dace transformations (especially auto optimization) + # will have problems. 
+ new_inter_shape_raw = symbolic.overapproximate(producer_subset.size()) + inter_shape = inter_desc.shape + if not self.strict_dataflow: + squeezed_dims: List[int] = [] # These are the dimensions we removed. + new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate( + zip(new_inter_shape_raw, inter_shape) + ): + if full_dim_size == 1: # Must be kept! + new_inter_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. + squeezed_dims.append(dim) + else: + new_inter_shape.append(proposed_dim_size) + else: + squeezed_dims = [] + new_inter_shape = list(new_inter_shape_raw) + + return (tuple(new_inter_shape_raw), tuple(new_inter_shape), squeezed_dims) + + def compute_offset_subset( + self, + original_subset: subsets.Range, + intermediate_desc: data.Data, + map_params: List[str], + producer_offset: Union[subsets.Range, None], + ) -> subsets.Range: + """Computes the memlet to correct read and writes of the intermediate. + + This is the value that must be substracted from the memlets to adjust, i.e + (`memlet_to_adjust(correction, negative=True)`). If `producer_offset` is + `None` then the function computes the correction that should be applied to + the producer memlets, i.e. the memlets of the tree converging at + `intermediate_node`. If `producer_offset` is given, it should be the output + of the previous call to this function, with `producer_offset=None`. In this + case the function computes the correction for the consumer side, i.e. the + memlet tree that originates at `intermediate_desc`. + + :param original_subset: The original subset that was used to write into the + intermediate, must be renamed to the final map parameter. + :param intermediate_desc: The original intermediate data descriptor. + :param map_params: The parameter of the final map. + :param producer_offset: The correction that was applied to the producer side. 
+ """ + assert not isinstance(intermediate_desc, data.View) + final_offset: subsets.Range = None + if isinstance(intermediate_desc, data.Scalar): + # If the intermediate was a scalar, then it will remain a scalar. + # Thus there is no correction that we must apply. + return subsets.Range.from_string("0") + + elif isinstance(intermediate_desc, data.Array): + basic_offsets = original_subset.min_element() + offset_list = [] + for d in range(original_subset.dims()): + d_range = subsets.Range([original_subset[d]]) + if d_range.free_symbols.intersection(map_params): + offset_list.append(d_range[0]) + else: + offset_list.append((basic_offsets[d], basic_offsets[d], 1)) + final_offset = subsets.Range(offset_list) + + else: + raise TypeError( + f"Does not know how to compute the subset offset for '{type(intermediate_desc).__name__}'." + ) + + if producer_offset is not None: + # Here we are correcting some parts that over approximate (which partially + # does under approximate) might screw up. Consider two maps, the first + # map only writes the subset `[:, 2:6]`, thus the new intermediate will + # have shape `(1, 4)`. Now also imagine that the second map only reads + # the elements `[:, 3]`. From this we see that we can only correct the + # consumer side if we also take the producer side into consideration! + # See also the `transformations/mapfusion_test.py::test_offset_correction_*` + # tests for more. + final_offset.offset( + final_offset.offset_new( + producer_offset, + negative=True, + ), + negative=True, + ) + return final_offset + + def can_topologically_be_fused( + self, + first_map_entry: nodes.MapEntry, + second_map_entry: nodes.MapEntry, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + permissive: bool = False, + ) -> Optional[Dict[str, str]]: + """Performs basic checks if the maps can be fused. + + This function only checks constrains that are common between serial and + parallel map fusion process, which includes: + * The scope of the maps. 
+ * The scheduling of the maps. + * The map parameters. + + :return: If the maps can not be topologically fused the function returns `None`. + If they can be fused the function returns `dict` that describes parameter + replacement, see `find_parameter_remapping()` for more. + + :param first_map_entry: The entry of the first (in serial case the top) map. + :param second_map_entry: The entry of the second (in serial case the bottom) map. + :param graph: The SDFGState in which the maps are located. + :param sdfg: The SDFG itself. + :param permissive: Currently unused. + """ + if self.only_inner_maps and self.only_toplevel_maps: + raise ValueError( + "Only one of `only_inner_maps` and `only_toplevel_maps` is allowed per MapFusion instance." + ) + + # Ensure that both have the same schedule + if first_map_entry.map.schedule != second_map_entry.map.schedule: + return None + + # Fusing is only possible if the two entries are in the same scope. + scope = graph.scope_dict() + if scope[first_map_entry] != scope[second_map_entry]: + return None + elif self.only_inner_maps: + if scope[first_map_entry] is None: + return None + elif self.only_toplevel_maps: + if scope[first_map_entry] is not None: + return None + + # We will now check if we can rename the Map parameter of the second Map such that they + # match the one of the first Map. + param_repl = self.find_parameter_remapping( + first_map=first_map_entry.map, second_map=second_map_entry.map + ) + return param_repl + + def is_parallel( + self, + graph: SDFGState, + node1: nodes.Node, + node2: nodes.Node, + ) -> bool: + """Tests if `node1` and `node2` are parallel in the data flow graph. + + The function considers two nodes parallel in the data flow graph, if `node2` + can not be reached from `node1` and vice versa. + + :param graph: The state on which we operate. + :param node1: The first node to check. + :param node2: The second node to check. + """ + # In order to be parallel they must be in the same scope. 
+ scope = graph.scope_dict() + if scope[node1] != scope[node2]: + return False + + # The `all_nodes_between()` function traverses the graph and returns `None` if + # `end` was not found. We have to call it twice, because we do not know + # which node is upstream if they are not parallel. + if self.is_node_reachable_from(graph=graph, begin=node1, end=node2): + return False + elif self.is_node_reachable_from(graph=graph, begin=node2, end=node1): + return False + return True + + def has_inner_read_write_dependency( + self, + first_map_entry: nodes.MapEntry, + second_map_entry: nodes.MapEntry, + state: SDFGState, + sdfg: SDFG, + ) -> bool: + """This function tests if there are dependencies inside the Maps. + + The function will scan and analyze the body of the two Maps and look for + inconsistencies. To detect them the function will scan the body of the maps + and examine all AccessNodes and apply the following rules: + * If an AccessNode refers to a View, it is ignored. Because the source is + either on the outside, in which case `has_read_write_dependency()` + takes care of it, or the data source is inside the Map body itself. + * An inconsistency is detected, if in each body there exists an AccessNode + that refers to the same data. + * An inconsistency is detected, if there exists an AccessNode that refers + to non transient data. This is an implementation detail and could be + lifted. + + Note that some of the restrictions of this function could be relaxed by + performing more analysis. + + :return: The function returns `True` if an inconsistency has been found. + + :param first_map_entry: The entry node of the first map. + :param second_map_entry: The entry node of the second map. + :param state: The state on which we operate. + :param sdfg: The SDFG on which we operate. 
+ """ + first_map_body = state.scope_subgraph(first_map_entry, False, False) + second_map_body = state.scope_subgraph(second_map_entry, False, False) + + # Find the data that is internally referenced. Because of the first rule above, + # we filter all views above. + first_map_body_data, second_map_body_data = [ + { + dnode.data + for dnode in map_body.nodes() + if isinstance(dnode, nodes.AccessNode) and not self.is_view(dnode, sdfg) + } + for map_body in [first_map_body, second_map_body] + ] + + # If there is data that is referenced in both, then we consider this as an error + # this is the second rule above. + if not first_map_body_data.isdisjoint(second_map_body_data): + return True + + # We consider it as a problem if any map refers to non-transient data. + # This is an implementation detail and could be dropped if we do further + # analysis. + if any( + not sdfg.arrays[data].transient + for data in first_map_body_data.union(second_map_body_data) + ): + return True + + return False + + def has_read_write_dependency( + self, + first_map_entry: nodes.MapEntry, + second_map_entry: nodes.MapEntry, + param_repl: Dict[str, str], + state: SDFGState, + sdfg: SDFG, + ) -> bool: + """Test if there is a read write dependency between the two maps to be fused. + + The function checks three different things. + * The function will make sure that there is no read write dependency between + the input and output of the fused maps. For that it will inspect the + respective subsets of the inputs of the MapEntry of the first and the + outputs of the MapExit node of the second map. + * The second part partially checks the intermediate nodes, it mostly ensures + that there are not views and that they are not used as output of the + combined map. Note that it is allowed that an intermediate node is also + an input to the first map. 
+ * In case an intermediate node, is also used as input node of the first map, + it is forbidden that the data is used as output of the second map, the + function will do additional checks. This is needed as the partition function + only checks the data consumption of the second map can be satisfied by the + data production of the first map, it ignores any potential reads made by + the first map's MapEntry. + + :return: `True` if there is a conflict between the maps that can not be handled. + If there is no conflict or if the conflict can be handled `False` is returned. + + :param first_map_entry: The entry node of the first map. + :param second_map_entry: The entry node of the second map. + :param param_repl: Dict that describes how to rename the parameters of the second Map. + :param state: The state on which we operate. + :param sdfg: The SDFG on which we operate. + """ + first_map_exit: nodes.MapExit = state.exit_node(first_map_entry) + second_map_exit: nodes.MapExit = state.exit_node(second_map_entry) + + # Get the read and write sets of the different maps, note that Views + # are not resolved yet. + access_sets: List[Dict[str, nodes.AccessNode]] = [] + for scope_node in [first_map_entry, first_map_exit, second_map_entry, second_map_exit]: + access_set: Set[nodes.AccessNode] = self.get_access_set(scope_node, state) + access_sets.append({node.data: node for node in access_set}) + # If two different access nodes of the same scoping node refers to the + # same data, then we consider this as a dependency we can not handle. + # It is only a problem for the intermediate nodes and might be possible + # to handle, but doing so is hard, so we just forbid it. + if len(access_set) != len(access_sets[-1]): + return True + read_map_1, write_map_1, read_map_2, write_map_2 = access_sets + + # It might be possible that there are views, so we have to resolve them. + # We also already get the name of the data container. 
+ # Note that `len(real_read_map_1) <= len(read_map_1)` holds because of Views. + resolved_sets: List[Set[str]] = [] + for unresolved_set in [read_map_1, write_map_1, read_map_2, write_map_2]: + resolved_sets.append( + { + self.track_view(node, state, sdfg).data + if self.is_view(node, sdfg) + else node.data + for node in unresolved_set.values() + } + ) + # If the resolved and unresolved names do not have the same length. + # Then different views point to the same location, which we forbid + if len(unresolved_set) != len(resolved_sets[-1]): + return False + real_read_map_1, real_write_map_1, real_read_map_2, real_write_map_2 = resolved_sets + + # We do not allow that the first and second map each write to the same data. + # This essentially ensures that an intermediate can not be used as output of + # the second map at the same time. It is actually stronger as it does not + # take their role into account. + if not real_write_map_1.isdisjoint(real_write_map_2): + return True + + # These are the names (unresolved) and the access nodes of the data that is used + # to transmit information between the maps. The partition function ensures that + # these nodes are directly connected to the two maps. + exchange_names: Set[str] = set(write_map_1.keys()).intersection(read_map_2.keys()) + exchange_nodes: Set[nodes.AccessNode] = set(write_map_1.values()).intersection( + read_map_2.values() + ) + + # If the number are different then a data is accessed through different + # AccessNodes. We could analyse this, but we will consider this as a data race. + if len(exchange_names) != len(exchange_nodes): + return True + assert all(exchange_node.data in exchange_names for exchange_node in exchange_nodes) + + # For simplicity we assume that the nodes used for exchange are not views. + if any(self.is_view(exchange_node, sdfg) for exchange_node in exchange_nodes): + return True + + # This is the names of the node that are used as input of the first map and + # as output of the second map. 
We have to ensure that there is no data + # dependency between these nodes. + # NOTE: This set is not required to be empty. It might look as if this would + # create a data race, but it is safe. The reason is because all data has + # to pass through the intermediate we create, this will separate the reads + # from the writes. + fused_inout_data_names: Set[str] = set(read_map_1.keys()).intersection(write_map_2.keys()) + + # If a data container is used as input and output then it can not be a view (simplicity) + if any(self.is_view(read_map_1[name], sdfg) for name in fused_inout_data_names): + return True + + # A data container can not be used as output (of the second as well as the + # combined map) and as intermediate. If we would allow that, the map would + # have two output nodes one the original one and the second is the created + # node that is created because the intermediate is shared. + # TODO(phimuell): Handle this case. + if not fused_inout_data_names.isdisjoint(exchange_names): + return True + + # While it is forbidden that a data container, used as intermediate, is also + # used as output of the second map, it is allowed that the data container + # is used as intermediate and as input of the first map. The partition only + # checks that the data dependencies are met, i.e. what is read by the second + # map is also computed (written to the intermediate) it does not take into + # account the first map's read to the data container. + # To make an example: The partition function will make sure that if the + # second map reads index `i` from the intermediate that the first map writes + # to that index. But it will not care if the first map reads (through its + # MapEntry) index `i + 1`. In order to be valid we must ensure that the first + # map's reads and writes to the intermediate are pointwise. + # Note that we only have to make this check if it is also an intermediate node. 
+ # Because if it is not read by the second map it is not a problem as the node + # will end up as an pure output node anyway. + read_write_map_1 = set(read_map_1.keys()).intersection(write_map_1.keys()) + datas_to_inspect = read_write_map_1.intersection(exchange_names) + for data_to_inspect in datas_to_inspect: + # Now get all subsets of the data container that the first map reads + # from or writes to and check if they are pointwise. + all_subsets: List[subsets.Subset] = [] + all_subsets.extend( + self.find_subsets( + node=read_map_1[data_to_inspect], + scope_node=first_map_entry, + state=state, + sdfg=sdfg, + param_repl=None, + ) + ) + all_subsets.extend( + self.find_subsets( + node=write_map_1[data_to_inspect], + scope_node=first_map_exit, + state=state, + sdfg=sdfg, + param_repl=None, + ) + ) + if not self.test_if_subsets_are_point_wise(all_subsets): + return True + del all_subsets + + # If there is no intersection between the input and output data, then we can + # we have nothing to check. + if len(fused_inout_data_names) == 0: + return False + + # Now we inspect if there is a read write dependency, between data that is + # used as input and output of the fused map. There is no problem is they + # are pointwise, i.e. in each iteration the same locations are accessed. + # Essentially they all boil down to `a += 1`. + for inout_data_name in fused_inout_data_names: + all_subsets = [] + # The subsets that define reading are given by the first map's entry node + all_subsets.extend( + self.find_subsets( + node=read_map_1[inout_data_name], + scope_node=first_map_entry, + state=state, + sdfg=sdfg, + param_repl=None, + ) + ) + # While the subsets defining writing are given by the second map's exit + # node, there we also have to apply renaming. 
+ all_subsets.extend( + self.find_subsets( + node=write_map_2[inout_data_name], + scope_node=second_map_exit, + state=state, + sdfg=sdfg, + param_repl=param_repl, + ) + ) + # Now we can test if these subsets are point wise + if not self.test_if_subsets_are_point_wise(all_subsets): + return True + del all_subsets + + # No read write dependency was found. + return False + + def test_if_subsets_are_point_wise(self, subsets_to_check: List[subsets.Subset]) -> bool: + """Point wise means that they are all the same. + + If a series of subsets are point wise it means that all Memlets, access + the same data. This is an important property because the whole map fusion + is build upon this. + If the subsets originates from different maps, then they must have been + renamed. + + :param subsets_to_check: The list of subsets that should be checked. + """ + assert len(subsets_to_check) > 1 + + # We will check everything against the master subset. + master_subset = subsets_to_check[0] + for ssidx in range(1, len(subsets_to_check)): + subset = subsets_to_check[ssidx] + if isinstance(subset, subsets.Indices): + subset = subsets.Range.from_indices(subset) + # Do we also need the reverse? See below why. + if any(r != (0, 0, 1) for r in subset.offset_new(master_subset, negative=True)): + return False + else: + # The original code used `Range.offset` here, but that one had trouble + # for `r1 = 'j, 0:10'` and `r2 = 'j, 0`. The solution would be to test + # symmetrically, i.e. `r1 - r2` and `r2 - r1`. However, if we would + # have `r2_1 = 'j, 0:10'` it consider it as failing, which is not + # what we want. Thus we will use symmetric cover. + if not master_subset.covers(subset): + return False + if not subset.covers(master_subset): + return False + + # All subsets are equal to the master subset, thus they are equal to each other. 
+ # This means that the data accesses, described by this transformation is + # point wise + return True + + def is_shared_data( + self, + data: nodes.AccessNode, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> bool: + """Tests if `data` is shared data, i.e. it can not be removed from the SDFG. + + Depending on the situation, the function will not perform a scan of the whole SDFG: + 1) If `assume_always_shared` was set to `True`, the function will return `True` unconditionally. + 2) If `data` is non transient then the function will return `True`, as non transient data + must be reconstructed always. + 3) If the AccessNode `data` has more than one outgoing edge or more than one incoming edge + it is classified as shared. + 2) If `FindSingleUseData` is in the pipeline it will be used and no scan will be performed. + 3) The function will perform a scan. + + :param data: The transient that should be checked. + :param state: The state in which the fusion is performed. + :param sdfg: The SDFG in which we want to perform the fusing. + + """ + # `assume_always_shared` takes precedence. + if self.assume_always_shared: + return True + + # If `data` is non transient then return `True` as the intermediate can not be removed. + if not data.desc(sdfg).transient: + return True + + # This means the data is consumed by multiple Maps, through the same AccessNode, in this state + # Note currently multiple incoming edges are not handled, but in the spirit of this function + # we consider such AccessNodes as shared, because we can not remove the intermediate. + if state.out_degree(data) > 1: + return True + if state.in_degree(data) > 1: + return True + + # NOTE: Actually, if this transformation is run through the `FullMapFusion` pass, it should + # read the results from `FindSingelUseData`, that was computed because it is a dependent + # pass through the `self._pipeline_results` which is set by the `SingleStateTransformation`. 
+        # However, this member is only set when `apply()` is called, but not during
+        # `can_be_applied()`, see [issue#1911](https://github.com/spcl/dace/issues/1911).
+        # Because the whole goal of this separation of scanning and fusion was to make the
+        # transformation stateless, the member `_single_use_data` was introduced. If it is set
+        # then we use it, otherwise we use the scanner.
+        # This value is set for example by the `FullMapFusion` pass.
+        # TODO(phimuell): Change this once the issue is resolved.
+        if self._single_use_data is not None:
+            assert (
+                sdfg in self._single_use_data
+            ), f"`_single_use_data` was set, but does not contain information about the SDFG '{sdfg.name}'."
+            single_use_data: Set[str] = self._single_use_data[sdfg]
+            return data.data not in single_use_data
+
+        # We have to perform the full scan of the SDFG.
+        return self._scan_sdfg_if_data_is_shared(data=data, state=state, sdfg=sdfg)
+
+    def _scan_sdfg_if_data_is_shared(
+        self,
+        data: nodes.AccessNode,
+        state: dace.SDFGState,
+        sdfg: dace.SDFG,
+    ) -> bool:
+        """Scans `sdfg` to determine if `data` is shared.
+
+        Essentially, this function determines if the intermediate AccessNode `data`
+        can be removed or if it has to be restored as output of the Map.
+        A data descriptor is classified as shared if any of the following is true:
+        - `data` is non transient data.
+        - `data` has more than one incoming and/or outgoing edge.
+        - There are other AccessNodes beside `data` that refer to the same data.
+        - The data is accessed on an interstate edge.
+
+        This function should not be called directly. Instead it is called indirectly
+        by `is_shared_data()` if there is no short cut.
+
+        :param data: The AccessNode that should be checked if it is shared.
+        :param sdfg: The SDFG for which the set of shared data should be computed.
+        """
+        if not data.desc(sdfg).transient:
+            return True
+
+        # See description in `is_shared_data()` for more.
+ if state.out_degree(data) > 1: + return True + if state.in_degree(data) > 1: + return True + + data_name: str = data.data + for state in sdfg.states(): + for dnode in state.data_nodes(): + if dnode is data: + # We have found the `data` AccessNode, which we must ignore. + continue + if dnode.data == data_name: + # We found a different AccessNode that refers to the same data + # as `data`. Thus `data` is shared. + return True + + # Test if the data is referenced in the interstate edges. + for edge in sdfg.edges(): + if data_name in edge.data.free_symbols: + # The data is used in the inter state edges. So it is shared. + return True + + # Test if the data is referenced inside a control flow, such as a conditional + # block or loop condition. + for cfr in sdfg.all_control_flow_regions(): + if data_name in cfr.used_symbols(all_symbols=True, with_contents=False): + return True + + # The `data` is not used anywhere else, thus `data` is not shared. + return False + + def find_parameter_remapping( + self, first_map: nodes.Map, second_map: nodes.Map + ) -> Optional[Dict[str, str]]: + """Computes the parameter remapping for the parameters of the _second_ map. + + The returned `dict` maps the parameters of the second map (keys) to parameter + names of the first map (values). Because of how the replace function works + the `dict` describes how to replace the parameters of the second map + with parameters of the first map. + Parameters that already have the correct name and compatible range, are not + included in the return value, thus the keys and values are always different. + If no renaming at is _needed_, i.e. all parameter have the same name and range, + then the function returns an empty `dict`. + If no remapping exists, then the function will return `None`. + + :param first_map: The first map (these parameters will be replaced). + :param second_map: The second map, these parameters acts as source. + + :note: This function currently fails if the renaming is not unique. 
Consider the + case were the first map has the structure `for i, j in map[0:20, 0:20]` and it + writes `T[i, j]`, while the second map is equivalent to + `for l, k in map[0:20, 0:20]` which reads `T[l, k]`. For this case we have + the following valid remappings `{l: i, k: j}` and `{l: j, k: i}` but + only the first one allows to fuse the map. This is because if the second + one is used the second map will read `T[j, i]` which leads to a data + dependency that can not be satisfied. + To avoid this issue the renaming algorithm will process them in order, i.e. + assuming that the order of the parameters in the map matches. But this is + not perfect, the only way to really solve this is by trying possible + remappings. At least the algorithm used here is deterministic. + """ + + # The parameter names + first_params: List[str] = first_map.params + second_params: List[str] = second_map.params + + if len(first_params) != len(second_params): + return None + + # The ranges, however, we apply some post processing to them. + simp = lambda e: symbolic.simplify_ext(symbolic.simplify(e)) # noqa: E731 [lambda-assignment] + first_rngs: Dict[str, Tuple[Any, Any, Any]] = { + param: tuple(simp(r) for r in rng) for param, rng in zip(first_params, first_map.range) + } + second_rngs: Dict[str, Tuple[Any, Any, Any]] = { + param: tuple(simp(r) for r in rng) + for param, rng in zip(second_params, second_map.range) + } + + # Parameters of the second map that have not yet been matched to a parameter + # of the first map and the parameters of the first map that are still free. + # That we use a `list` instead of a `set` is intentional, because it counter + # acts the issue that is described in the doc string. Using a list ensures + # that they indexes are matched in order. This assume that in real world + # code the order of the loop is not arbitrary but kind of matches. 
+ unmapped_second_params: List[str] = list(second_params) + unused_first_params: List[str] = list(first_params) + + # This is the result (`second_param -> first_param`), note that if no renaming + # is needed then the parameter is not present in the mapping. + final_mapping: Dict[str, str] = {} + + # First we identify the parameters that already have the correct name. + for param in set(first_params).intersection(second_params): + first_rng = first_rngs[param] + second_rng = second_rngs[param] + + if first_rng == second_rng: + # They have the same name and the same range, this is already a match. + # Because the names are already the same, we do not have to enter them + # in the `final_mapping` + unmapped_second_params.remove(param) + unused_first_params.remove(param) + + # Check if no remapping is needed. + if len(unmapped_second_params) == 0: + return {} + + # Now we go through all the parameters that we have not mapped yet. + # All of them will result in a remapping. + for unmapped_second_param in unmapped_second_params: + second_rng = second_rngs[unmapped_second_param] + assert unmapped_second_param not in final_mapping + + # Now look in all not yet used parameters of the first map which to use. + for candidate_param in list(unused_first_params): + candidate_rng = first_rngs[candidate_param] + if candidate_rng == second_rng: + final_mapping[unmapped_second_param] = candidate_param + unused_first_params.remove(candidate_param) + break + else: + # We did not find a candidate, so the remapping does not exist + return None + + assert len(unused_first_params) == 0 + assert len(final_mapping) == len(unmapped_second_params) + return final_mapping + + def rename_map_parameters( + self, + first_map: nodes.Map, + second_map: nodes.Map, + second_map_entry: nodes.MapEntry, + state: SDFGState, + ) -> None: + """Replaces the map parameters of the second map with names from the first. + + The replacement is done in a safe way, thus `{'i': 'j', 'j': 'i'}` is + handled correct. 
The function assumes that a proper replacement exists. + The replacement is computed by calling `self.find_parameter_remapping()`. + + :param first_map: The first map (these are the final parameter). + :param second_map: The second map, this map will be replaced. + :param second_map_entry: The entry node of the second map. + :param state: The SDFGState on which we operate. + """ + # Compute the replacement dict. + repl_dict: Dict[str, str] = self.find_parameter_remapping( # type: ignore[assignment] # Guaranteed to be not `None`. + first_map=first_map, second_map=second_map + ) + + if repl_dict is None: + raise RuntimeError("The replacement does not exist") + if len(repl_dict) == 0: + return + + second_map_scope = state.scope_subgraph(entry_node=second_map_entry) + # Why is this thing is symbolic and not in replace? + symbolic.safe_replace( + mapping=repl_dict, + replace_callback=second_map_scope.replace_dict, + ) + + # For some odd reason the replace function does not modify the range and + # parameter of the map, so we will do it the hard way. + second_map.params = copy.deepcopy(first_map.params) + second_map.range = copy.deepcopy(first_map.range) + + def is_node_reachable_from( + self, + graph: dace.SDFGState, + begin: nodes.Node, + end: nodes.Node, + ) -> bool: + """Test if the node `end` can be reached from `begin`. + + Essentially the function starts a DFS at `begin`. If an edge is found that lead + to `end` the function returns `True`. If the node is never found `False` is + returned. + + :param graph: The graph to operate on. + :param begin: The start of the DFS. + :param end: The node that should be located. 
+ """ + + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + return (edge.dst for edge in graph.out_edges(node)) + + to_visit: List[nodes.Node] = [begin] + seen: Set[nodes.Node] = set() + + while len(to_visit) > 0: + node: nodes.Node = to_visit.pop() + if node == end: + return True + elif node not in seen: + to_visit.extend(next_nodes(node)) + seen.add(node) + + # We never found `end` + return False + + def _is_data_accessed_downstream( + self, + data: str, + graph: dace.SDFGState, + begin: nodes.Node, + ) -> bool: + """Tests if there is an AccessNode for `data` downstream of `begin`. + + Essentially, this function starts a DFS at `begin` and checks every + AccessNode that is reachable from it. If it finds such a node it will + check if it refers to `data` and if so, it will return `True`. + If no such node is found it will return `False`. + Note that the node `begin` will be ignored. + + :param data: The name of the data to look for. + :param graph: The graph to explore. + :param begin: The node to start exploration; The node itself is ignored. + """ + + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + return (edge.dst for edge in graph.out_edges(node)) + + # Dataflow graph is acyclic, so we do not need to keep a list of + # what we have visited. + to_visit: List[nodes.Node] = list(next_nodes(begin)) + while len(to_visit) > 0: + node = to_visit.pop() + if isinstance(node, nodes.AccessNode) and node.data == data: + return True + to_visit.extend(next_nodes(node)) + + return False + + def get_access_set( + self, + scope_node: Union[nodes.MapEntry, nodes.MapExit], + state: SDFGState, + ) -> Set[nodes.AccessNode]: + """Computes the access set of a "scope node". + + If `scope_node` is a `MapEntry` it will operate on the set of incoming edges + and if it is an `MapExit` on the set of outgoing edges. 
The function will + then determine all access nodes that have a connection through these edges + to the scope nodes (edges that does not lead to access nodes are ignored). + The function returns a set that contains all access nodes that were found. + It is important that this set will also contain views. + + :param scope_node: The scope node that should be evaluated. + :param state: The state in which we operate. + """ + if isinstance(scope_node, nodes.MapEntry): + get_edges = lambda node: state.in_edges(node) # noqa: E731 [lambda-assignment] + other_node = lambda e: e.src # noqa: E731 [lambda-assignment] + else: + get_edges = lambda node: state.out_edges(node) # noqa: E731 [lambda-assignment] + other_node = lambda e: e.dst # noqa: E731 [lambda-assignment] + access_set: Set[nodes.AccessNode] = { + node + for node in map(other_node, get_edges(scope_node)) + if isinstance(node, nodes.AccessNode) + } + + return access_set + + def find_subsets( + self, + node: nodes.AccessNode, + scope_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, + param_repl: Optional[Dict[str, str]], + ) -> List[subsets.Subset]: + """Finds all subsets that access `node` within `scope_node`. + + The function will not start a search for all consumer/producers. + Instead it will locate the edges which is immediately inside the + map scope. + + :param node: The access node that should be examined. + :param scope_node: We are only interested in data that flows through this node. + :param state: The state in which we operate. + :param sdfg: The SDFG object. + :param param_repl: `dict` that describes the parameter renaming that should be + performed. Can be `None` to skip the processing. + """ + # Is the node used for reading or for writing. + # This influences how we have to proceed. 
+ if isinstance(scope_node, nodes.MapEntry): + outer_edges_to_inspect = [e for e in state.in_edges(scope_node) if e.src == node] + get_subset = lambda e: e.data.src_subset # noqa: E731 [lambda-assignment] + get_inner_edges = ( # noqa: E731 [lambda-assignment] + lambda e: state.out_edges_by_connector(scope_node, "OUT_" + e.dst_conn[3:]) + ) + else: + outer_edges_to_inspect = [e for e in state.out_edges(scope_node) if e.dst == node] + get_subset = lambda e: e.data.dst_subset # noqa: E731 [lambda-assignment] + get_inner_edges = ( # noqa: E731 [lambda-assignment] + lambda e: state.in_edges_by_connector(scope_node, "IN_" + e.src_conn[4:]) + ) + + found_subsets: List[subsets.Subset] = [] + for edge in outer_edges_to_inspect: + found_subsets.extend(get_subset(e) for e in get_inner_edges(edge)) + assert len(found_subsets) > 0, "Could not find any subsets." + assert not any(subset is None for subset in found_subsets) + + found_subsets = copy.deepcopy(found_subsets) + if param_repl: + for subset in found_subsets: + # Replace happens in place + symbolic.safe_replace(param_repl, subset.replace) + + return found_subsets + + def is_view( + self, + node: Union[nodes.AccessNode, data.Data], + sdfg: SDFG, + ) -> bool: + """Tests if `node` points to a view or not.""" + node_desc: data.Data = node if isinstance(node, data.Data) else node.desc(sdfg) + return isinstance(node_desc, data.View) + + def track_view( + self, + view: nodes.AccessNode, + state: SDFGState, + sdfg: SDFG, + ) -> nodes.AccessNode: + """Find the original data of a View. + + Given the View `view`, the function will trace the view back to the original + access node. For convenience, if `view` is not a `View` the argument will be + returned. + + :param view: The view that should be traced. + :param state: The state in which we operate. + :param sdfg: The SDFG on which we operate. + """ + + # Test if it is a view at all, if not return the passed node as source. 
+ if not self.is_view(view, sdfg): + return view + + # This is the node that defines the view. + defining_node = dace.sdfg.utils.get_last_view_node(state, view) + assert isinstance(defining_node, nodes.AccessNode) + assert not self.is_view(defining_node, sdfg) + return defining_node diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_orderer.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_orderer.py new file mode 100644 index 0000000000..23dcbf8ef7 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_orderer.py @@ -0,0 +1,201 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Any, Optional, Sequence, Union + +import dace +from dace import properties as dace_properties, transformation as dace_transformation +from dace.sdfg import nodes as dace_nodes + +from gt4py.next import common as gtx_common +from gt4py.next.program_processors.runners.dace import gtir_sdfg_utils + + +def gt_set_iteration_order( + sdfg: dace.SDFG, + leading_dim: Optional[ + Union[str, gtx_common.Dimension, list[Union[str, gtx_common.Dimension]]] + ] = None, + validate: bool = True, + validate_all: bool = False, +) -> Any: + """Set the iteration order of the Maps correctly. + + Modifies the order of the Map parameters such that `leading_dim` + is the fastest varying one, the order of the other dimensions in + a Map is unspecific. `leading_dim` should be the dimensions were + the stride is one. + + Args: + sdfg: The SDFG to process. + leading_dim: The leading dimensions. + validate: Perform validation at the end of the function. + validate_all: Perform validation also on intermediate steps. 
+ """ + return sdfg.apply_transformations_once_everywhere( + MapIterationOrder( + leading_dims=leading_dim, + ), + validate=validate, + validate_all=validate_all, + ) + + +@dace_properties.make_properties +class MapIterationOrder(dace_transformation.SingleStateTransformation): + """Modify the order of the iteration variables. + + The iteration order, while irrelevant from an SDFG point of view, is highly + relevant in code and the fastest varying index ("inner most loop" in CPU or + "x block dimension" in GPU) should be associated with the stride 1 dimension + of the array. + This transformation will reorder the map indexes such that this is the case. + + While the place of the leading dimension is clearly defined, the order of the + other loop indexes, after this transformation is unspecified. + + The transformation accepts either a single dimension or a list of dimensions. + In case a list is passed this is interpreted as priorities. + Assuming we have the `leading_dim=[EdgeDim, VertexDim]`, then we have the + following: + - `Map[EdgeDim, KDim, VertexDim]` -> `Map[KDim, VertexDim, EdgeDim]`. + - `Map[VertexDim, KDim]` -> `Map[KDim, VertexDim]`. + - `Map[EdgeDim, KDim]` -> `Map[KDim, EdgeDim]`. + - `Map[CellDim, KDim]` -> `Map[CellDim, KDim]` (no modification). + + Args: + leading_dim: GT4Py dimensions that are associated with the dimension that is + supposed to have stride 1. If it is a list it is used as a ranking. + + Note: + The transformation does follow the rules outlines in + [ADR0018](https://github.com/GridTools/gt4py/tree/main/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md) + especially rule 11, regarding the names. + + Todo: + - Extend that different dimensions can be specified to be leading + dimensions, with some priority mechanism. + - Maybe also process the parameters to bring them in a canonical order. 
+ """ + + leading_dims = dace_properties.ListProperty( + element_type=str, + allow_none=True, + default=None, + desc="Dimensions that should become the leading dimension.", + ) + + map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) + + def __init__( + self, + leading_dims: Optional[ + Union[str, gtx_common.Dimension, list[Union[str, gtx_common.Dimension]]] + ] = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + if isinstance(leading_dims, (gtx_common.Dimension, str)): + leading_dims = [leading_dims] + if isinstance(leading_dims, list): + self.leading_dims = [ + leading_dim + if isinstance(leading_dim, str) + else gtir_sdfg_utils.get_map_variable(leading_dim) + for leading_dim in leading_dims + ] + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.map_entry)] + + def can_be_applied( + self, + graph: Union[dace.SDFGState, dace.SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Test if the map can be reordered. + + Essentially the function checks if the selected dimension is inside the map, + and if so, if it is on the right place. + """ + if self.leading_dims is None: + return False + map_entry: dace_nodes.MapEntry = self.map_entry + map_params: Sequence[str] = map_entry.map.params + processed_dims: set[str] = set(self.leading_dims) + + if not any(map_param in processed_dims for map_param in map_params): + return False + if self.compute_map_param_order() is None: + return False + return True + + def apply( + self, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + ) -> None: + """Performs the actual parameter reordering. + + The function will make the map variable, that corresponds to + `self.leading_dim` the last map variable (this is given by the structure of + DaCe's code generator). 
+ """ + map_object: dace_nodes.Map = self.map_entry.map + new_map_params_order: list[int] = self.compute_map_param_order() # type: ignore[assignment] # Guaranteed to be not `None`. + + def reorder(what: list[Any]) -> list[Any]: + assert isinstance(what, list) + return [what[new_pos] for new_pos in new_map_params_order] + + map_object.params = reorder(map_object.params) + map_object.range.ranges = reorder(map_object.range.ranges) + map_object.range.tile_sizes = reorder(map_object.range.tile_sizes) + + def compute_map_param_order(self) -> Optional[list[int]]: + """Computes the new iteration order of the matched map. + + The function returns a list, the value at index `i` indicates the old dimension + that should be put at the new location. If the order is already correct then + `None` is returned. + """ + map_entry: dace_nodes.MapEntry = self.map_entry + map_params: list[str] = map_entry.map.params + org_mapping: dict[str, int] = {map_param: i for i, map_param in enumerate(map_params)} + leading_dims: list[str] = self.leading_dims + + # We divide the map parameters into two groups, the one we care and the others. + map_params_to_order: set[str] = { + map_param for map_param in map_params if map_param in leading_dims + } + + # If there is nothing to order, then we are done. + if not map_params_to_order: + return None + + # We start with all parameters that we ignore/do not care about. + new_map_params: list[str] = [ + map_param for map_param in map_params if map_param not in leading_dims + ] + + # Because how code generation works the leading dimension must be the most + # left one. Because this is also `self.leading_dims[0]` we have to process + # then in reverse order. 
+ for map_param_to_check in reversed(leading_dims): + if map_param_to_check in map_params_to_order: + new_map_params.append(map_param_to_check) + assert len(map_params) == len(new_map_params) + + if map_params == new_map_params: + return None + + return [org_mapping[new_map_param] for new_map_param in new_map_params] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_promoter.py similarity index 91% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py rename to src/gt4py/next/program_processors/runners/dace/transformations/map_promoter.py index 19818fd3d1..14f5f56689 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_promoter.py @@ -17,9 +17,7 @@ ) from dace.sdfg import nodes as dace_nodes -from gt4py.next.program_processors.runners.dace_fieldview import ( - transformations as gtx_transformations, -) +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations @dace_properties.make_properties @@ -299,9 +297,9 @@ class SerialMapPromoter(BaseMapPromoter): """ # Pattern Matching - exit_first_map = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) - access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) - entry_second_map = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + exit_first_map = dace_transformation.PatternNode(dace_nodes.MapExit) + access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + entry_second_map = dace_transformation.PatternNode(dace_nodes.MapEntry) @classmethod def expressions(cls) -> Any: @@ -346,17 +344,11 @@ def _test_if_promoted_maps_can_be_fused( ) -> bool: """This function checks if the promoted maps can be fused by map fusion. 
- This function assumes that `self.can_be_applied()` returned `True`. + This function assumes that `super().self.can_be_applied()` returned `True`. Args: state: The state in which we operate. sdfg: The SDFG we process. - - Note: - The current implementation uses a very hacky way to test this. - - Todo: - Find a better way of doing it. """ first_map_exit: dace_nodes.MapExit = self.exit_first_map access_node: dace_nodes.AccessNode = self.access_node @@ -373,23 +365,17 @@ def _test_if_promoted_maps_can_be_fused( # This will lead to a promotion of the map, this is needed that # Map fusion can actually inspect them. self.apply(graph=state, sdfg=sdfg) - - # Now create the map fusion object that we can then use to check if - # the fusion is possible or not. - serial_fuser = gtx_transformations.SerialMapFusion( - only_inner_maps=self.only_inner_maps, - only_toplevel_maps=self.only_toplevel_maps, - ) - candidate = { - type(serial_fuser).map_exit1: first_map_exit, - type(serial_fuser).access_node: access_node, - type(serial_fuser).map_entry2: second_map_entry, - } - state_id = sdfg.node_id(state) - serial_fuser.setup_match(sdfg, sdfg.cfg_id, state_id, candidate, 0, override=True) - - # Now use the serial fuser to see if fusion would succeed - if not serial_fuser.can_be_applied(state, 0, sdfg): + if not gtx_transformations.MapFusionSerial.can_be_applied_to( + sdfg=sdfg, + expr_index=0, + options={ + "only_inner_maps": self.only_inner_maps, + "only_toplevel_maps": self.only_toplevel_maps, + }, + first_map_exit=first_map_exit, + array=access_node, + second_map_entry=second_map_entry, + ): return False finally: diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py b/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py new file mode 100644 index 0000000000..8919d2bc0f --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/redundant_array_removers.py @@ -0,0 +1,1067 
@@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Any, Optional, Sequence + +import dace +from dace import ( + data as dace_data, + properties as dace_properties, + subsets as dace_sbs, + symbolic as dace_sym, + transformation as dace_transformation, +) +from dace.sdfg import graph as dace_graph, nodes as dace_nodes +from dace.transformation import pass_pipeline as dace_ppl +from dace.transformation.passes import analysis as dace_analysis + +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations + + +def gt_multi_state_global_self_copy_elimination( + sdfg: dace.SDFG, + validate: bool = False, +) -> Optional[dict[dace.SDFG, set[str]]]: + """Runs `MultiStateGlobalSelfCopyElimination` on the SDFG recursively. + + For the return value see `MultiStateGlobalSelfCopyElimination.apply_pass()`. + """ + pipeline = dace_ppl.Pipeline([gtx_transformations.MultiStateGlobalSelfCopyElimination()]) + res = pipeline.apply_pass(sdfg, {}) + + if validate: + sdfg.validate() + + if "MultiStateGlobalSelfCopyElimination" not in res: + return None + return res["MultiStateGlobalSelfCopyElimination"][sdfg] + + +def gt_remove_copy_chain( + sdfg: dace.SDFG, + validate: bool = False, + validate_all: bool = False, + single_use_data: Optional[dict[dace.SDFG, set[str]]] = None, +) -> Optional[int]: + """Applies the `CopyChainRemover` transformation to the SDFG. + + The transformation returns the number of removed data containers or `None` + if nothing was done. + + Args: + sdfg: The SDFG to process. + validate: Perform validation after the pass has run. + validate_all: Perform extensive validation. + single_use_data: Which data descriptors are used only once. + If not passed the function will run `FindSingleUseData`. 
+ """ + + # To ensures that the `{src,dst}_subset` are properly set, run initialization. + # See [issue 1703](https://github.com/spcl/dace/issues/1703) + for state in sdfg.states(): + for edge in state.edges(): + edge.data.try_initialize(sdfg, state, edge) + + if single_use_data is None: + find_single_use_data = dace_analysis.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + + result: int = sdfg.apply_transformations_repeated( + CopyChainRemover(single_use_data=single_use_data), + validate=validate, + validate_all=validate_all, + ) + return result if result != 0 else None + + +@dace_properties.make_properties +class MultiStateGlobalSelfCopyElimination(dace_transformation.Pass): + """Removes self copying across different states. + + This transformation is very similar to `SingleStateGlobalSelfCopyElimination`, but + addresses a slightly different case. Assume we have the pattern `(G) -> (T)` + in one state, i.e. the global data `G` is copied into a transient. In another + state, we have the pattern `(T) -> (G)`, i.e. the data is written back. + + If the following conditions are satisfied, this transformation will remove all + writes to `G`: + - The only write access to `G` happens in the `(T) -> (G)` pattern. ADR-18 + guarantees, that if `G` is used as an input and output it must be pointwise. + Thus there is no weird shifting. + + If the only usage of `T` is to write into `G` then the transient `T` will be + removed. + + Note that this transformation does not consider the subsets of the writes from + `T` to `G` because ADR-18 guarantees to us, that _if_ `G` is a genuine input + and output, then the `G` read and write subsets have the exact same range. + If `G` is not an output then any mutating changes to `G` would be invalid. + + Todo: + - Implement the pattern `(G) -> (T) -> (G)` which is handled currently by + `SingleStateGlobalSelfCopyElimination`, see `_classify_candidate()` and + `_remove_writes_to_global()` for more. 
+ - Make it more efficient such that the SDFG is not scanned multiple times. + """ + + def modifies(self) -> dace_ppl.Modifies: + return dace_ppl.Modifies.Memlets | dace_ppl.Modifies.AccessNodes + + def should_reapply(self, modified: dace_ppl.Modifies) -> bool: + return modified & (dace_ppl.Modifies.Memlets | dace_ppl.Modifies.AccessNodes) + + def depends_on(self) -> set[type[dace_transformation.Pass]]: + return { + dace_transformation.passes.FindAccessStates, + } + + def apply_pass( + self, sdfg: dace.SDFG, pipeline_results: dict[str, Any] + ) -> Optional[dict[dace.SDFG, set[str]]]: + """Applies the pass. + + The function will return a `dict` that contains for every SDFG, the name + of the processed data descriptors. If a name refers to a global memory, + then it means that all write backs, i.e. `(T) -> (G)` patterns, have + been removed for that `G`. If the name refers to a data descriptor that no + longer exists, then it means that the write `(G) -> (T)` was also eliminated. + Currently there is no possibility to identify which transient name belonged + to a global name. + """ + assert "FindAccessStates" in pipeline_results + + result: dict[dace.SDFG, set[str]] = dict() + for nsdfg in sdfg.all_sdfgs_recursive(): + single_level_res: set[str] = self._process_sdfg(nsdfg, pipeline_results) + if single_level_res: + result[nsdfg] = single_level_res + + return result if result else None + + def _process_sdfg( + self, + sdfg: dace.SDFG, + pipeline_results: dict[str, Any], + ) -> set[str]: + """Apply the pass to a single level of an SDFG, i.e. do not handle nested SDFG. + + The return value of this function is the same as for `apply_pass()`, but + only for the SDFG that was passed. 
+ """ + t_mapping = self._find_candidates(sdfg, pipeline_results) + if len(t_mapping) == 0: + return set() + self._remove_writes_to_globals(sdfg, t_mapping, pipeline_results) + removed_transients = self._remove_transient_buffers_if_possible( + sdfg, t_mapping, pipeline_results + ) + + return removed_transients | t_mapping.keys() + + def _find_candidates( + self, + sdfg: dace.SDFG, + pipeline_results: dict[str, Any], + ) -> dict[str, set[str]]: + """The function searches for all candidates of that must be processed. + + The function returns a `dict` that maps the name of a global memory, `G` in + the above pattern, to the name of the buffer transient, `T` in the above + pattern. + """ + access_states: dict[str, set[dace.SDFGState]] = pipeline_results["FindAccessStates"][ + sdfg.cfg_id + ] + global_data = [ + aname + for aname, desc in sdfg.arrays.items() + if not desc.transient + and isinstance(desc, dace_data.Array) + and not isinstance(desc, dace_data.View) + ] + + candidates: dict[str, set[str]] = dict() + for gname in global_data: + candidate_tnames = self._classify_candidate(sdfg, gname, access_states) + if candidate_tnames is not None: + assert len(candidate_tnames) > 0 + candidates[gname] = candidate_tnames + + return candidates + + def _classify_candidate( + self, + sdfg: dace.SDFG, + gname: str, + access_states: dict[str, set[dace.SDFGState]], + ) -> Optional[set[str]]: + """The function tests if the global data `gname` can be handled. + + It essentially checks all conditions above, which is that the global is + only written through transients that are fully defined by the data itself. + writes to it are through transients that are fully defined by the data + itself. + + The function returns `None` if `gname` can not be handled by the function. + If `gname` can be handled the function returns a set of all data descriptors + that act as distributed buffers. + """ + # The set of access nodes that reads from the global, i.e. 
`gname`, essentially + # the set of candidates of `T` defined through the way it is defined. + # And the same set, but this time defined through who writes into the global. + reads_from_g: set[str] = set() + writes_to_g: set[str] = set() + + # In a first step we will identify the possible `T` only from the angle of + # how they interact with `G`. At a later point we will look at the `T` again, + # because in case of branches there might be multiple definitions of `T`. + for state in access_states[gname]: + for dnode in state.data_nodes(): + if dnode.data != gname: + continue + + # Note that we allow that `G` can be written to by multiple `T` at + # once. However, we require that all this data, is fully defined by + # a read to `G` itself. + for iedge in state.in_edges(dnode): + possible_t = iedge.src + + # If `G` is a pseudo output, see definition above, then it is only + # allowed that access nodes writes to them. Note, that here we + # will only collect which nodes writes to `G`, if these are + # valid `T`s will be checked later, after we cllected all of them. + if not isinstance(possible_t, dace_nodes.AccessNode): + return None + + possible_t_desc = possible_t.desc(sdfg) + if not possible_t_desc.transient: + return None # we must write into a transient. + if isinstance(possible_t_desc, dace_data.View): + return None # The global data must be written to from an array + if not isinstance(possible_t_desc, dace_data.Array): + return None + writes_to_g.add(possible_t.data) + + # Let's look who reads from `g` this will contribute to the `reads_from_g` set. + for oedge in state.out_edges(dnode): + possible_t = oedge.dst + # `T` must be an AccessNode. Note that it is not important + # what also reads from `G`. We just have to find everything that + # can act as `T`. + if not isinstance(possible_t, dace_nodes.AccessNode): + continue + + # It is important that only `G` defines `T`, so it must have + # an incoming degree of one, since we have SSA. 
+ if state.in_degree(possible_t) != 1: + continue + + # `T` must fulfil some condition, like that it is transient. + possible_t_desc = possible_t.desc(sdfg) + if not possible_t_desc.transient: + continue # we must write into a transient. + if isinstance(possible_t_desc, dace_data.View): + continue # We must write into an array and not a view. + if not isinstance(possible_t_desc, dace_data.Array): + continue + + # Currently we do not handle the pattern `(T) -> (G) -> (T)`, + # see `_remove_writes_to_global()` for more, thus we filter + # this pattern here. + if any( + tnode_oedge.dst.data == gname + for tnode_oedge in state.out_edges(possible_t) + if isinstance(tnode_oedge.dst, dace_nodes.AccessNode) + ): + return None + + # Now add the data to the list of data that reads from `G`. + reads_from_g.add(possible_t.data) + + if len(writes_to_g) == 0: + return None + + # Now every write to `G` necessarily comes from an access node that was created + # by a direct read from `G`. We ensure this by checking that `writes_to_g` is + # a subset of `reads_to_g`. + # Note that the `T` nodes might not be unique, which happens in case + # of separate memlets for different subsets. + # of different subsets, are contained in ` + if not writes_to_g.issubset(reads_from_g): + return None + + # If we have branches, it might be that different data is written to `T` depending + # on which branch is selected, i.e. `T = G if cond else foo(A)`. For that + # reason we must now check that `G` is the only data source of `T`, but this + # time we must do the check on `T`. Note we only have to remove the particular access node + # to `T` where `G` is the only data source, while we keep the other access nodes. + # `T`. + for tname in list(writes_to_g): + for state in access_states[tname]: + for dnode in state.data_nodes(): + if dnode.data != tname: + continue + if state.in_degree(dnode) == 0: + continue # We are only interested at definitions. + + # Now ensures that only `gname` defines `T`. 
+ for iedge in state.in_edges(dnode): + t_def_node = iedge.src + if not isinstance(t_def_node, dace_nodes.AccessNode): + writes_to_g.discard(tname) + break + if t_def_node.data != gname: + writes_to_g.discard(tname) + break + if tname not in writes_to_g: + break + + return None if len(writes_to_g) == 0 else writes_to_g + + def _remove_writes_to_globals( + self, + sdfg: dace.SDFG, + t_mapping: dict[str, set[str]], + pipeline_results: dict[str, Any], + ) -> None: + """Remove all writes to the global data defined through `t_mapping`. + + The function does not handle reads from the global to the transients. + + Args: + sdfg: The SDFG on which we should process. + t_mapping: Maps the name of the global data to the transient data. + This set was computed by the `_find_candidates()` function. + pipeline_results: The results of the pipeline. + """ + access_states: dict[str, set[dace.SDFGState]] = pipeline_results["FindAccessStates"][ + sdfg.cfg_id + ] + for gname, tnames in t_mapping.items(): + self._remove_writes_to_global( + sdfg=sdfg, gname=gname, tnames=tnames, access_states=access_states + ) + + def _remove_writes_to_global( + self, + sdfg: dace.SDFG, + gname: str, + tnames: set[str], + access_states: dict[str, set[dace.SDFGState]], + ) -> None: + """Remove writes to the global data `gname`. + + The function is the same as `_remove_writes_to_globals()` but only processes + one global data descriptor. + """ + # Here we delete the `T` node that writes into `G`, this might turn the `G` + # node into an isolated node. + # It is important that this code does not handle the `(G) -> (T) -> (G)` + # pattern, which is difficult to handle. The issue is that by removing `(T)`, + # what this function does, it also removes the definition `(T)`. However, + # it can only do that if it ensures that `T` is not used anywhere else. + # This is currently handle by the `SingleStateGlobalSelfCopyElimination` pass + # and the classifier rejects this pattern. 
+ for state in access_states[gname]: + for dnode in list(state.data_nodes()): + if dnode.data != gname: + continue + for iedge in list(state.in_edges(dnode)): + tnode = iedge.src + if not isinstance(tnode, dace_nodes.AccessNode): + continue + if tnode.data in tnames: + state.remove_node(tnode) + + # It might be that the `dnode` has become isolated so remove it. + if state.degree(dnode) == 0: + state.remove_node(dnode) + + def _remove_transient_buffers_if_possible( + self, + sdfg: dace.SDFG, + t_mapping: dict[str, set[str]], + pipeline_results: dict[str, Any], + ) -> set[str]: + """Remove the transient data if it is possible, listed in `t_mapping`. + + Essentially the function will look if there is a read to any data that is + mentioned in `tnames`. If there isn't it will remove the write to it and + remove it from the registry. + The function must run after `_remove_writes_to_globals()`. + + The function returns the list of transients that were eliminated. + """ + access_states: dict[str, set[dace.SDFGState]] = pipeline_results["FindAccessStates"][ + sdfg.cfg_id + ] + result: set[str] = set() + for gname, tnames in t_mapping.items(): + result.update( + self._remove_transient_buffer_if_possible( + sdfg=sdfg, + gname=gname, + tnames=tnames, + access_states=access_states, + ) + ) + return result + + def _remove_transient_buffer_if_possible( + self, + sdfg: dace.SDFG, + gname: str, + tnames: set[str], + access_states: dict[str, set[dace.SDFGState]], + ) -> set[str]: + obsolete_ts: set[str] = set() + for tname in tnames: + # We can remove the (defining) write to `T` only if it is not read + # anywhere else. + if self._has_read_access_for(sdfg, tname, access_states): + continue + # Now we look for all writes to `tname` and remove them, since there + # are no reads. + for state in access_states[tname]: + neighbourhood: set[dace_nodes.Node] = set() + for dnode in list(state.data_nodes()): + if dnode.data == tname: + # We have to store potential sources nodes, which is `G`. 
+ # This is because the local `G` node could become isolated. + # We do not need to consider the outgoing edges, because + # they are reads which we have handled above. + for iedge in state.in_edges(dnode): + assert ( + isinstance(iedge.src, dace_nodes.AccessNode) + and iedge.src.data == gname + ) + neighbourhood.add(iedge.src) + state.remove_node(dnode) + obsolete_ts.add(dnode.data) + + # We now have to check if an node has become isolated. + for nh_node in neighbourhood: + if state.degree(nh_node) == 0: + state.remove_node(nh_node) + + for tname in obsolete_ts: + sdfg.remove_data(tname, validate=False) + + return obsolete_ts + + def _has_read_access_for( + self, + sdfg: dace.SDFG, + dname: str, + access_states: dict[str, set[dace.SDFGState]], + ) -> bool: + """Checks if there is a read access on `dname`.""" + for state in access_states[dname]: + for dnode in state.data_nodes(): + if state.out_degree(dnode) == 0: + continue # We are only interested in read accesses + if dnode.data == dname: + return True + return False + + +@dace_properties.make_properties +class SingleStateGlobalSelfCopyElimination(dace_transformation.SingleStateTransformation): + """Remove global self copy. + + This transformation matches the following case `(G) -> (T) -> (G)`, i.e. `G` + is read from and written too at the same time, however, in between is `T` + used as a buffer. In the example above `G` is a global memory and `T` is a + temporary. This situation is generated by the lowering if the data node is + not needed (because the computation on it is only conditional). + + In case `G` refers to global memory rule 3 of ADR-18 guarantees that we can + only have a point wise dependency of the output on the input. + This transformation will remove the write into `G`, i.e. we thus only have + `(G) -> (T)`. The read of `G` and the definition of `T`, will only be removed + if `T` is not used downstream. If it is used `T` will be maintained. 
+ """ + + node_read_g = dace_transformation.PatternNode(dace_nodes.AccessNode) + node_tmp = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) + node_write_g = dace_transformation.PatternNode(dace_nodes.AccessNode) + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.node_read_g, cls.node_tmp, cls.node_write_g)] + + def can_be_applied( + self, + graph: dace.SDFGState | dace.SDFG, + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + read_g = self.node_read_g + write_g = self.node_write_g + tmp_node = self.node_tmp + g_desc = read_g.desc(sdfg) + tmp_desc = tmp_node.desc(sdfg) + + # NOTE: We do not check if `G` is read downstream. + if read_g.data != write_g.data: + return False + if g_desc.transient: + return False + if not tmp_desc.transient: + return False + if graph.in_degree(read_g) != 0: + return False + if graph.out_degree(read_g) != 1: + return False + if graph.degree(tmp_node) != 2: + return False + if graph.in_degree(write_g) != 1: + return False + if graph.out_degree(write_g) != 0: + return False + if graph.scope_dict()[read_g] is not None: + return False + + return True + + def _is_read_downstream( + self, + start_state: dace.SDFGState, + sdfg: dace.SDFG, + data_to_look: str, + ) -> bool: + """Scans for reads to `data_to_look`. + + The function will go through states that are reachable from `start_state` + (including) and test if there is a read to the data container `data_to_look`. + It will return `True` the first time it finds such a node. + It is important that the matched nodes, i.e. `self.node_{read_g, write_g, tmp}` + are ignored. + + Args: + start_state: The state where the scanning starts. + sdfg: The SDFG on which we operate. + data_to_look: The data that we want to look for. + + Todo: + Port this function to use DaCe pass pipeline. 
+ """ + read_g: dace_nodes.AccessNode = self.node_read_g + write_g: dace_nodes.AccessNode = self.node_write_g + tmp_node: dace_nodes.AccessNode = self.node_tmp + + # TODO(phimuell): Run the `StateReachability` pass in a pipeline and use + # the `_pipeline_results` member to access the data. + return gtx_transformations.utils.is_accessed_downstream( + start_state=start_state, + sdfg=sdfg, + reachable_states=None, + data_to_look=data_to_look, + nodes_to_ignore={read_g, write_g, tmp_node}, + ) + + def apply( + self, + graph: dace.SDFGState | dace.SDFG, + sdfg: dace.SDFG, + ) -> None: + read_g: dace_nodes.AccessNode = self.node_read_g + write_g: dace_nodes.AccessNode = self.node_write_g + tmp_node: dace_nodes.AccessNode = self.node_tmp + + # We first check if `T`, the intermediate is not used downstream. In this + # case we can remove the read to `G` and `T` itself from the SDFG. + # We have to do this check before, because the matching is not fully stable. + is_tmp_used_downstream = self._is_read_downstream( + start_state=graph, sdfg=sdfg, data_to_look=tmp_node.data + ) + + # The write to `G` can always be removed. + graph.remove_node(write_g) + + # Also remove the read to `G` and `T` from the SDFG if possible. + if not is_tmp_used_downstream: + graph.remove_node(read_g) + graph.remove_node(tmp_node) + # It could still be used in a parallel branch. + try: + sdfg.remove_data(tmp_node.data, validate=True) + except ValueError as e: + if not str(e).startswith(f"Cannot remove data descriptor {tmp_node.data}:"): + raise + + +@dace_properties.make_properties +class CopyChainRemover(dace_transformation.SingleStateTransformation): + """Removes chain of redundant copies, mostly related to `concat_where`. + + `concat_where`, especially when nested, will build "chains" of AccessNodes, + this transformation will remove them. It should be called repeatedly until a + fix point is reached and should be seen as an addition to the array removal passes + that ship with DaCe. 
+ The transformation will look for the pattern `(A1) -> (A2)`, i.e. a data container + is copied into another one. The transformation will then remove `A1` and rewire + the edges such that they now refer to `A2`. Another, and probably better way, is to + consider the transformation as fusion transformation for AccessNodes. + + The transformation builds on ADR-18 and imposes the following additional + requirements before it can be applied: + - Through the merging of `A1` and `A2` no cycles are created. + - `A1` can not be used anywhere else. + - `A1` is fully read by `A2`. + - `A1` is a transient and must have the same dimensionality than `A2`. + + Notes: + - The transformation assumes that the domain inference adjusted the ranges of + the maps such that, in case they write into a transient, the full shape of the transient array is written. + has the same size, i.e. there is not padding, or data that is not written + to. + + Args: + single_use_data: List of data containers that are used only at one place. + Will be stored internally and not updated. + + Todo: + - Extend such that not the full array must be read. + - Try to allow more than one connection between `A1` and `A2`. + - Modify it such that also `A2` can be removed. + """ + + node_a1 = dace_transformation.PatternNode(dace_nodes.AccessNode) + node_a2 = dace_transformation.PatternNode(dace_nodes.AccessNode) + + # Name of all data that is used at only one place. Is computed by the + # `FindSingleUseData` pass and be passed at construction time. Needed until + # [issue#1911](https://github.com/spcl/dace/issues/1911) has been solved. 
+ _single_use_data: dict[dace.SDFG, set[str]] + + def __init__( + self, + *args: Any, + single_use_data: dict[dace.SDFG, set[str]], + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self._single_use_data = single_use_data + + @classmethod + def expressions(cls) -> Any: + return [ + dace.sdfg.utils.node_path_graph( + cls.node_a1, + cls.node_a2, + ) + ] + + def can_be_applied( + self, + graph: dace.SDFGState, + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + a1: dace_nodes.AccessNode = self.node_a1 + a2: dace_nodes.AccessNode = self.node_a2 + + a1_desc = a1.desc(sdfg) + a2_desc = a2.desc(sdfg) + + # We remove `a1` so it must be a transient and used only once. + if not a1_desc.transient: + return False + if not self.is_single_use_data(sdfg, a1): + return False + + # This avoids that we have to modify the subsets in a fancy way. + if len(a1_desc.shape) != len(a2_desc.shape): + return False + + # For simplicity we assume that neither of `a1` nor `a2` are views. + # TODO(phimuell): Implement some of the cases. + if gtx_transformations.utils.is_view(a1_desc, None): + return False + if gtx_transformations.utils.is_view(a2_desc, None): + return False + + # We only allow that we operate on the top level scope. + if graph.scope_dict()[a1] is not None: + return False + + # TODO(phimuell): Relax this to only prevent host-device copies. + if a1_desc.storage != a2_desc.storage: + return False + + # There shall only be one edge connecting `a1` and `a2`. + # We even strengthen this requirement by not checking for the node `a2`, + # but for the data. + connecting_edges = [ + oedge + for oedge in graph.out_edges(a1) + if isinstance(oedge.dst, dace_nodes.AccessNode) and (oedge.dst.data == a2.data) + ] + if len(connecting_edges) != 1: + return False + + # The full array `a1` is copied into `a2`. 
Note that it is allowed, that + # `a2` is bigger than `a1`, it is just important that everything that was + # written into `a1` is also accessed. + connecting_edge = connecting_edges[0] + assert connecting_edge.dst is a2 + connecting_memlet = connecting_edge.data + + # If the destination or the source subset of the connection is not fully + # specified, we do not apply. + src_subset = connecting_memlet.get_src_subset(connecting_edge, graph) + if src_subset is None: + return False + dst_subset = connecting_memlet.get_dst_subset(connecting_edge, graph) + if dst_subset is None: + return False + + # NOTE: The main benefit of requiring that the whole array is read is + # that we do not have to adjust maps. + a1_range = dace_sbs.Range.from_array(a1_desc) + if not src_subset.covers(a1_range): + return False + + # We have to ensure that no cycle is created through the removal of `a1`. + # For this we have to ensure that there is no connection, beside the direct + # one between `a1` and `a2`. + # NOTE: We only check the outgoing edges of `a1`, it is not needed to also + # check the incoming edges, because this will not create a cycle. + if gtx_transformations.utils.is_reachable( + start=[oedge.dst for oedge in graph.out_edges(a1) if oedge.dst is not a2], + target=a2, + state=graph, + ): + return False + + # NOTE: In case `a2` is a non transient we do not have to check if it is read + # or written to somewhere else in this state. The reason is that ADR18 + # guarantees us that everything is point wise, therefore `a1` is never + # used as double buffer. 
+ return True + + def is_single_use_data( + self, + sdfg: dace.SDFG, + data: str | dace_nodes.AccessNode, + ) -> bool: + """Checks if `data` is a single use data.""" + assert sdfg in self._single_use_data + if isinstance(data, dace_nodes.AccessNode): + data = data.data + return data in self._single_use_data[sdfg] + + def apply( + self, + graph: dace.SDFGState | dace.SDFG, + sdfg: dace.SDFG, + ) -> None: + a1: dace_nodes.AccessNode = self.node_a1 + a2: dace_nodes.AccessNode = self.node_a2 + a1_to_a2_edge: dace_graph.MultiConnectorEdge = next( + oedge for oedge in graph.out_edges(a1) if oedge.dst is a2 + ) + a1_to_a2_memlet: dace.Memlet = a1_to_a2_edge.data + a1_to_a2_dst_subset: dace_sbs.Range = a1_to_a2_memlet.get_dst_subset(a1_to_a2_edge, graph) + + # Note that it is possible that `a1` is connected to the same node multiple + # times, although through different edges. We have to modify the data + # flow there, since the offsets and the data have changed. However, we must + # do this only once. Note that only matching the node is not enough, a + # counter example would be a Map with different connector names. + reconfigured_neighbour: set[tuple[dace_nodes.Node, Optional[str]]] = set() + + # Now we compose the new subset. + # We build on the fact that we have ensured that the whole array `a1` is + # copied into `a2`. Thus the destination of the original source, i.e. + # whatever write into `a1`, is just offset by the beginning of the range + # `a1` writes into `a2`. + # (s1) ------[c:d]-> (A1) -[0:N]------[a:b]-> (A2) + # (s1) ---------[(a + c):(a + c + (d - c))]-> (A2) + # Thus the offset is simply given by `a`, the start where `a1` is written into + # `a2`. + # NOTE: If we ever allow the that `a1` is not fully read, then we would have + # to modify this computation slightly. + a2_offsets: Sequence[dace_sym.SymExpr] = a1_to_a2_dst_subset.min_element() + + # Handle the producer side of things. 
+ for producer_edge in list(graph.in_edges(a1)): + producer: dace_nodes.Node = producer_edge.src + producer_conn = producer_edge.src_conn + new_producer_edge = self._reroute_edge( + is_producer_edge=True, + current_edge=producer_edge, + a2_offsets=a2_offsets, + state=graph, + sdfg=sdfg, + a1=a1, + a2=a2, + ) + if (producer, producer_conn) not in reconfigured_neighbour: + self._reconfigure_dataflow( + is_producer_edge=True, + new_edge=new_producer_edge, + sdfg=sdfg, + state=graph, + a2_offsets=a2_offsets, + a1=a1, + a2=a2, + ) + reconfigured_neighbour.add((producer, producer_conn)) + + # Handle the consumer side of things, as they now have to read from `a2`. + # It is important that the offset is still the same. + for consumer_edge in list(graph.out_edges(a1)): + consumer: dace_nodes.Node = consumer_edge.dst + consumer_conn = consumer_edge.dst_conn + if consumer is a2: + assert consumer_edge is a1_to_a2_edge + continue + new_consumer_edge = self._reroute_edge( + is_producer_edge=False, + current_edge=consumer_edge, + a2_offsets=a2_offsets, + state=graph, + sdfg=sdfg, + a1=a1, + a2=a2, + ) + if (consumer, consumer_conn) not in reconfigured_neighbour: + self._reconfigure_dataflow( + is_producer_edge=False, + new_edge=new_consumer_edge, + sdfg=sdfg, + state=graph, + a2_offsets=a2_offsets, + a1=a1, + a2=a2, + ) + reconfigured_neighbour.add((consumer, consumer_conn)) + + # After the rerouting we have to delete the `a1` data node and descriptor, + # this will also remove all the old edges. + graph.remove_node(a1) + sdfg.remove_data(a1.data, validate=False) + + # We will now propagate the strides starting from the access nodes `a2`. + # Essentially, this will replace the strides from `a1` with the ones of + # `a2`. We do it outside to make sure that we do not forget a case and + # that we propagate the change into every NestedSDFG only once. 
+ gtx_transformations.gt_propagate_strides_from_access_node( + sdfg=sdfg, + state=graph, + outer_node=a2, + ) + + def _reroute_edge( + self, + is_producer_edge: bool, + current_edge: dace_graph.MultiConnectorEdge, + a2_offsets: Sequence[dace_sym.SymExpr], + state: dace.SDFGState, + sdfg: dace.SDFG, + a1: dace_nodes.AccessNode, + a2: dace_nodes.AccessNode, + ) -> dace_graph.MultiConnectorEdge: + """Performs the rerouting of the edge. + + Essentially the function creates new edges that account for the fact that + `a1` will be replaced with `a2`. Depending on the value of `is_producer_edge` + the behaviour is slightly different. + + If `is_producer_edge` is `True` then the function assumes that `current_edge` + ends at `a1`. It will then create a new edge that has the same start and a + similar Memlet but ends at `a2`. + If `is_producer_edge` is `False` then the function assumes that `current_edge` + starts at `a1`. It will then create a new edge that starts at `a2` but has the + same destination and a similar Memlet. + In both cases the Memlet and the corresponding subset, will be modified such + that they account that `a1` was replaced with `a2`. + + It is important that the function will **not** do the following things: + - Remove the old edge, i.e. `producer_edge`. + - Modify the data flow at the other side of the edge. + + The function returns the new edge. + + Args: + is_producer_edge: Indicates how to interpret `current_edge`. + current_edge: The current edge that should be replaced. + a2_offsets: Offset that describes how much to shift writes and reads, + that were previously associated with `a1`. + state: The state in which we operate. + sdfg: The SDFG on which we operate on. + a1: The `a1` node. + a2: The `a2` node. + + """ + current_memlet: dace.Memlet = current_edge.data + if is_producer_edge: + # NOTE: See the note in `_reconfigure_dataflow()` why it is not save to + # use the `get_{dst, src}_subset()` function, although it would be more + # appropriate. 
+ current_subset: dace_sbs.Range = current_memlet.dst_subset + new_src = current_edge.src + new_src_conn = current_edge._src_conn + new_dst = a2 + new_dst_conn = None + assert current_edge.dst_conn is None + else: + current_subset = current_memlet.src_subset + new_src = a2 + new_src_conn = None + new_dst = current_edge.dst + new_dst_conn = current_edge.dst_conn + assert current_edge.src_conn is None + + # If the subset we care about, which is always on the `a1` side, was not + # specified we assume that the whole `a1` has been written. + # TODO(edopao): Fix lowering that this does not happens, it happens for example + # in `tests/next_tests/integration_tests/feature_tests/ffront_tests/ + # test_execution.py::test_docstring`. + if current_subset is None: + current_subset = dace_sbs.Range.from_array(a1.desc(sdfg)) + + # This is the new Memlet, that we will use. We copy it from the original + # Memlet and modify it later. + new_memlet: dace.Memlet = dace.Memlet.from_memlet(current_memlet) + + # Because we operate on the `subset` and `other_subset` properties directly + # we do not need to distinguish between the different directions. Also + # in both cases the offset is the same. + if new_memlet.data == a1.data: + new_memlet.data = a2.data + new_subset = current_subset.offset_new(a2_offsets, negative=False) + new_memlet.subset = new_subset + else: + new_subset = current_subset.offset_new(a2_offsets, negative=False) + new_memlet.other_subset = new_subset + + new_edge = state.add_edge( + new_src, + new_src_conn, + new_dst, + new_dst_conn, + new_memlet, + ) + assert ( # Ensure that the edge has the right direction. 
+ new_subset is new_edge.data.dst_subset + if is_producer_edge + else new_subset is new_edge.data.src_subset + ) + return new_edge + + def _reconfigure_dataflow( + self, + is_producer_edge: bool, + new_edge: dace_graph.MultiConnectorEdge, + a2_offsets: Sequence[dace_sym.SymExpr], + state: dace.SDFGState, + sdfg: dace.SDFG, + a1: dace_nodes.AccessNode, + a2: dace_nodes.AccessNode, + ) -> None: + """Modify the data flow associated to `new_edge`. + + The `_reroute_edge()` function creates a new edge, but it does not modify + the data flow at the other side, of the connection, this is done by this + function. + + Depending on the value of `is_producer_edge` the function will either modify + the source of `new_edge` (`True`) or it will modify the data flow associated + to the destination of `new_edge` (`False`). + Furthermore, the specific actions depends on what kind of node is on the other + side. However, essentially the function will modify it to account for the + change from `a1` to `a2`. + + It is important that it is the caller's responsibility to ensure that this + function is not called multiple times on the same producer target. + + It is important that this function will not propagate the new strides. This + must be done from the outside. + + Args: + is_producer_edge: If `True` then the source of `new_edge` is processed, + if `False` then the destination part of `new_edge` is processed. + new_edge: The newly created edge, essentially the return value of + `self._reroute__edge()`. + a2_offsets: Offset that describes how much to shift subsets associated + to `a1` to account that they are now associated to `a2`. + state: The state in which we operate. + sdfg: The SDFG on which we operate. + a1: The `a1` node. + a2: The `a2` node. + """ + other_node = new_edge.src if is_producer_edge else new_edge.dst + + if isinstance(other_node, dace_nodes.AccessNode): + # There is nothing here to do. 
+ pass + + elif isinstance(other_node, dace_nodes.Tasklet): + # A very obscure case, but I think it might happen, but as in the AccessNode + # case there is nothing to do here. + pass + + elif isinstance(other_node, (dace_nodes.MapExit | dace_nodes.MapEntry)): + # Essentially, we have to propagate the change that everything that + # refers to `a1` should now refer to `a2`, In addition we also have to + # modify the subsets, depending on the direction of the new edge either + # the source or destination subset. + # NOTE: Because we assume that `a1` is read fully into `a2` we do not + # have to adjust the ranges of the Map. If we would drop this assumption + # then we would have to modify the ranges such that only the ranges we + # need are computed. + # NOTE: Also for this case we have to propagate the strides, for the case + # that a NestedSDFG is inside the map, but this is done externally. + assert ( + isinstance(other_node, dace_nodes.MapExit) + if is_producer_edge + else isinstance(other_node, dace_nodes.MapEntry) + ) + for memlet_tree in state.memlet_tree(new_edge).traverse_children(include_self=False): + edge_to_adjust = memlet_tree.edge + memlet_to_adjust = edge_to_adjust.data + + # NOTE: Actually we should use the `get_{src, dst}_subset()` functions, + # see https://github.com/spcl/dace/issues/1703. However, we can not + # do that because the SDFG is currently in an invalid state. So + # we have to call the properties and hope that it works. + subset_to_adjust = ( + memlet_to_adjust.dst_subset if is_producer_edge else memlet_to_adjust.src_subset + ) + + # If needed modify the association of the Memlet. + if memlet_to_adjust.data == a1.data: + memlet_to_adjust.data = a2.data + + assert subset_to_adjust is not None + subset_to_adjust.offset(a2_offsets, negative=False) + + elif isinstance(other_node, dace_nodes.NestedSDFG): + # We have obviously to adjust the strides, however, this is done outside + # this function. 
+ # TODO(phimuell): Look into the implication that we not necessarily pass + # the full array, but essentially slice a bit. + pass + + else: + # As we encounter them we should handle them case by case. + raise NotImplementedError( + f"The case for '{type(other_node).__name__}' has not been implemented." + ) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py new file mode 100644 index 0000000000..c2c5acf05f --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/simplify.py @@ -0,0 +1,1140 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""The GT4Py specific simplification pass.""" + +import collections +import copy +import uuid +from typing import Any, Final, Iterable, Optional, TypeAlias + +import dace +from dace import ( + data as dace_data, + properties as dace_properties, + subsets as dace_subsets, + transformation as dace_transformation, +) +from dace.sdfg import nodes as dace_nodes +from dace.transformation import ( + dataflow as dace_dataflow, + pass_pipeline as dace_ppl, + passes as dace_passes, +) + +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations + + +GT_SIMPLIFY_DEFAULT_SKIP_SET: Final[set[str]] = {"ScalarToSymbolPromotion", "ConstantPropagation"} +"""Set of simplify passes `gt_simplify()` skips by default. + +The following passes are included: +- `ScalarToSymbolPromotion`: The lowering has sometimes to turn a scalar into a + symbol or vice versa and at a later point to invert this again. However, this + pass has some problems with this pattern so for the time being it is disabled. +- `ConstantPropagation`: Same reasons as `ScalarToSymbolPromotion`. 
+""" + + +def gt_simplify( + sdfg: dace.SDFG, + validate: bool = True, + validate_all: bool = False, + skip: Optional[Iterable[str]] = None, +) -> Optional[dict[str, Any]]: + """Performs simplifications on the SDFG in place. + + Instead of calling `sdfg.simplify()` directly, you should use this function, + as it is specially tuned for GridTool based SDFGs. + + This function runs the DaCe simplification pass, but the following passes are + replaced: + - `InlineSDFGs`: Instead `gt_inline_nested_sdfg()` will be called. + + Further, the function will run the following passes in addition to DaCe simplify: + - `SingleStateGlobalSelfCopyElimination`: Special copy pattern that in the context + of GT4Py based SDFG behaves as a no op, i.e. `(G) -> (T) -> (G)`. + - `MultiStateGlobalSelfCopyElimination`: Very similar to + `SingleStateGlobalSelfCopyElimination`, with the exception that the write to + `T`, i.e. `(G) -> (T)` and the write back to `G`, i.e. `(T) -> (G)` might be + in different states. + - `CopyChainRemover`: Which removes some chains that are introduced by the + `concat_where` built-in function. + + Furthermore, by default, or if `None` is passed for `skip` the passes listed in + `GT_SIMPLIFY_DEFAULT_SKIP_SET` will be skipped. + + Args: + sdfg: The SDFG to optimize. + validate: Perform validation after the pass has run. + validate_all: Perform extensive validation. + skip: List of simplify passes that should not be applied, defaults + to `GT_SIMPLIFY_DEFAULT_SKIP_SET`. + + Note: + Currently DaCe does not provide a way to inject or exchange sub passes in + simplify. The custom inline pass is run at the beginning and the array + elimination at the end. The whole process is run inside a loop that ensures + that `gt_simplify()` results in a fix point. 
+ """ + # Ensure that `skip` is a `set` + skip = GT_SIMPLIFY_DEFAULT_SKIP_SET if skip is None else set(skip) + + result: Optional[dict[str, Any]] = None + + at_least_one_xtrans_run = True + + while at_least_one_xtrans_run: + at_least_one_xtrans_run = False + + # NOTE: See comment in `gt_inline_nested_sdfg()` for more. + sdfg.reset_cfg_list() + + # To mitigate DaCe issue 1959, we run the chain removal transformation here. + # TODO(phimuell): Remove as soon as we have a true solution. + if "CopyChainRemover" not in skip: + copy_chain_remover_result = gtx_transformations.gt_remove_copy_chain( + sdfg=sdfg, + validate=validate, + validate_all=validate_all, + ) + if copy_chain_remover_result is not None: + at_least_one_xtrans_run = True + result = result or {} + if "CopyChainRemover" not in result: + result["CopyChainRemover"] = 0 + result["CopyChainRemover"] += copy_chain_remover_result + + if "InlineSDFGs" not in skip: + inline_res = gt_inline_nested_sdfg( + sdfg=sdfg, + multistate=True, + permissive=False, + validate=validate, + validate_all=validate_all, + ) + if inline_res is not None: + at_least_one_xtrans_run = True + result = result or {} + result.update(inline_res) + + simplify_res = dace_passes.SimplifyPass( + validate=validate, + validate_all=validate_all, + verbose=False, + skip=(skip | {"InlineSDFGs"}), + ).apply_pass(sdfg, {}) + + if simplify_res is not None: + at_least_one_xtrans_run = True + result = result or {} + result.update(simplify_res) + + # This is the place were we actually want to apply the chain removal. 
+ if "CopyChainRemover" not in skip: + copy_chain_remover_result = gtx_transformations.gt_remove_copy_chain( + sdfg=sdfg, + validate=validate, + validate_all=validate_all, + ) + if copy_chain_remover_result is not None: + at_least_one_xtrans_run = True + result = result or {} + if "CopyChainRemover" not in result: + result["CopyChainRemover"] = 0 + result["CopyChainRemover"] += copy_chain_remover_result + + if "SingleStateGlobalSelfCopyElimination" not in skip: + self_copy_removal_result = sdfg.apply_transformations_repeated( + gtx_transformations.SingleStateGlobalSelfCopyElimination(), + validate=validate, + validate_all=validate_all, + ) + if self_copy_removal_result > 0: + at_least_one_xtrans_run = True + result = result or {} + if "SingleStateGlobalSelfCopyElimination" not in result: + result["SingleStateGlobalSelfCopyElimination"] = 0 + result["SingleStateGlobalSelfCopyElimination"] += self_copy_removal_result + + if "MultiStateGlobalSelfCopyElimination" not in skip: + distributed_self_copy_result = ( + gtx_transformations.gt_multi_state_global_self_copy_elimination( + sdfg, validate=validate_all + ) + ) + if distributed_self_copy_result is not None: + at_least_one_xtrans_run = True + result = result or {} + if "MultiStateGlobalSelfCopyElimination" not in result: + result["MultiStateGlobalSelfCopyElimination"] = set() + result["MultiStateGlobalSelfCopyElimination"].update(distributed_self_copy_result) + + return result + + +def gt_inline_nested_sdfg( + sdfg: dace.SDFG, + multistate: bool = True, + permissive: bool = False, + validate: bool = True, + validate_all: bool = False, +) -> Optional[dict[str, int]]: + """Perform inlining of nested SDFG into their parent SDFG. + + The function uses DaCe's `InlineSDFG` transformation, the same used in simplify. + However, before the inline transformation is run the function will run some + cleaning passes that allows inlining nested SDFGs. + As a side effect, the function will split stages into more states. 
+ + Args: + sdfg: The SDFG that should be processed, will be modified in place and returned. + multistate: Allow inlining of multistate nested SDFG, defaults to `True`. + permissive: Be less strict on the accepted SDFGs. + validate: Perform validation after the transformation has finished. + validate_all: Performs extensive validation. + """ + first_iteration = True + nb_preproccess_total = 0 + nb_inlines_total = 0 + while True: + # TODO(edopao): we call `reset_cfg_list()` as temporary workaround for a + # dace issue with pattern matching. Any time the SDFG's CFG-tree is modified, + # i.e. a loop is added/removed or something similar, the CFG list needs + # to be updated accordingly. Otherwise, all ID-based accesses are not going + # to work (which is what pattern matching attempts to do). + sdfg.reset_cfg_list() + nb_preproccess = sdfg.apply_transformations_repeated( + [dace_dataflow.PruneSymbols, dace_dataflow.PruneConnectors], + validate=False, + validate_all=validate_all, + ) + nb_preproccess_total += nb_preproccess + if (nb_preproccess == 0) and (not first_iteration): + break + + # Create and configure the inline pass + inline_sdfg = dace_passes.InlineSDFGs() + inline_sdfg.progress = False + inline_sdfg.permissive = permissive + inline_sdfg.multistate = multistate + + # Apply the inline pass + # The pass returns `None` no indicate "nothing was done" + nb_inlines = inline_sdfg.apply_pass(sdfg, {}) or 0 + nb_inlines_total += nb_inlines + + # Check result, if needed and test if we can stop + if validate_all or validate: + sdfg.validate() + if nb_inlines == 0: + break + first_iteration = False + + result: dict[str, int] = {} + if nb_inlines_total != 0: + result["InlineSDFGs"] = nb_inlines_total + if nb_preproccess_total != 0: + result["PruneSymbols|PruneConnectors"] = nb_preproccess_total + return result if result else None + + +def gt_substitute_compiletime_symbols( + sdfg: dace.SDFG, + repl: dict[str, Any], + validate: bool = False, + validate_all: bool = False, 
+) -> None: + """Substitutes symbols that are known at compile time with their value. + + Some symbols are known to have a constant value. This function will remove these + symbols from the SDFG and replace them with the value. + An example where this makes sense are strides that are known to be one. + + Args: + sdfg: The SDFG to process. + repl: Maps the name of the symbol to the value it should be replaced with. + validate: Perform validation at the end of the function. + validate_all: Perform validation also on intermediate steps. + + Todo: This function needs improvement. + """ + + # We will use the `replace` function of the top SDFG, however, lower levels + # are handled using ConstantPropagation. + sdfg.replace_dict(repl) + + # TODO(phimuell): Get rid of the `ConstantPropagation` + const_prop = dace_passes.ConstantPropagation() + const_prop.recursive = True + const_prop.progress = False + + const_prop.apply_pass( + sdfg=sdfg, + initial_symbols=repl, + _=None, + ) + gt_simplify( + sdfg=sdfg, + validate=validate, + validate_all=validate_all, + ) + dace.sdfg.propagation.propagate_memlets_sdfg(sdfg) + + +def gt_reduce_distributed_buffering( + sdfg: dace.SDFG, +) -> Optional[dict[dace.SDFG, dict[dace.SDFGState, set[str]]]]: + """Removes distributed write back buffers.""" + pipeline = dace_ppl.Pipeline([DistributedBufferRelocator()]) + all_result = {} + + for rsdfg in sdfg.all_sdfgs_recursive(): + ret = pipeline.apply_pass(sdfg, {}) + if ret is not None: + all_result[rsdfg] = ret + + if len(all_result) == 0: + return None + + return all_result + + +AccessLocation: TypeAlias = tuple[dace_nodes.AccessNode, dace.SDFGState] +"""Describes an access node and the state in which it is located. +""" + + +@dace_properties.make_properties +class DistributedBufferRelocator(dace_transformation.Pass): + """Moves the final write back of the results to where it is needed. 
+
+ In certain cases, especially in the case where we have an `if`, the result is computed
+ in each branch and then in the join state written back. Thus there is some
+ additional storage needed.
+ The transformation will look for the following situation:
+ - A transient data container, called `temp_storage`, is written into another
+ container, called `dest_storage`, which is not transient.
+ - The access node of `temp_storage` has an in degree of zero and an out degree of one.
+ - The access node of `dest_storage` has an in degree of one and an
+ out degree of zero (this might be lifted).
+ - `temp_storage` is not used afterwards.
+ - `dest_storage` is only used to implement the buffering.
+
+ The function will relocate the writing of `dest_storage` to where `temp_storage` is
+ written, which might be multiple locations.
+ It will also remove the writing back.
+ It is advised that after this transformation simplify is run again.
+
+ The relocation will not take place if it might create a data race. A necessary
+ but not sufficient condition for a data race is if `dest_storage` is present
+ in the state where `temp_storage` is defined. In addition at least one of the
+ following conditions has to be met:
+ - There are accesses to `dest_storage` that are not predecessor to the node where
+ the data is stored inside `temp_storage`. This check will ignore empty Memlets.
+ - There is a `dest_storage` access node, that has an output degree larger
+ than one.
+
+ Note:
+ - Essentially this transformation removes the double buffering of
+ `dest_storage`. Because we ensure that `dest_storage` is non
+ transient this is okay, as our rule guarantees this.
+
+ Todo:
+ - Allow that `dest_storage` can also be transient.
+ - Allow that `dest_storage` does not need to be a sink node, this is most
+ likely relevant if it is transient.
+ - Check if `dest_storage` is used between where we want to place it and
+ where it is currently used.
+ """ + + def modifies(self) -> dace_ppl.Modifies: + return dace_ppl.Modifies.Memlets | dace_ppl.Modifies.AccessNodes + + def should_reapply(self, modified: dace_ppl.Modifies) -> bool: + return modified & (dace_ppl.Modifies.Memlets | dace_ppl.Modifies.AccessNodes) + + def depends_on(self) -> set[type[dace_transformation.Pass]]: + return { + dace_transformation.passes.StateReachability, + dace_transformation.passes.FindAccessStates, + } + + def apply_pass( + self, sdfg: dace.SDFG, pipeline_results: dict[str, Any] + ) -> Optional[dict[dace.SDFGState, set[str]]]: + # NOTE: We can not use `AccessSets` because this pass operates on + # `ControlFlowBlock`s, which might consists of multiple states. Thus we are + # using `FindAccessStates` which has this `SDFGState` granularity. The downside + # is, however, that we have to determine if the access in that state is a + # write or not, which means we have to find it first. + access_states: dict[str, set[dace.SDFGState]] = pipeline_results["FindAccessStates"][ + sdfg.cfg_id + ] + + # For speeding up the `is_accessed_downstream()` calls. + reachable: dict[dace.SDFGState, set[dace.SDFGState]] = pipeline_results[ + "StateReachability" + ][sdfg.cfg_id] + + result: dict[dace.SDFGState, set[str]] = collections.defaultdict(set) + + to_relocate = self._find_candidates(sdfg, reachable, access_states) + if len(to_relocate) == 0: + return None + self._relocate_write_backs(sdfg, to_relocate) + + for (wb_an, wb_state), _ in to_relocate: + result[wb_state].add(wb_an.data) + + return result + + def _relocate_write_backs( + self, + sdfg: dace.SDFG, + to_relocate: list[tuple[AccessLocation, list[AccessLocation]]], + ) -> None: + """Perform the actual relocation.""" + for (wb_an, wb_state), def_locations in to_relocate: + # Get the memlet that we have to replicate. 
+ wb_edge = next(iter(wb_state.out_edges(wb_an))) + wb_memlet: dace.Memlet = wb_edge.data + final_dest_name: str = wb_edge.dst.data + + for def_an, def_state in def_locations: + def_state.add_edge( + def_an, + wb_edge.src_conn, + def_state.add_access(final_dest_name), + wb_edge.dst_conn, + copy.deepcopy(wb_memlet), + ) + + # Now remove the old node and if the old target become isolated + # remove that as well. + old_dst = wb_edge.dst + wb_state.remove_node(wb_an) + if wb_state.degree(old_dst) == 0: + wb_state.remove_node(old_dst) + + def _find_candidates( + self, + sdfg: dace.SDFG, + reachable: dict[dace.SDFGState, set[dace.SDFGState]], + access_states: dict[str, set[dace.SDFGState]], + ) -> list[tuple[AccessLocation, list[AccessLocation]]]: + """Determines all temporaries that have to be relocated. + + Returns: + A list of tuples. The first element element of the tuple is an + `AccessLocation` that describes where the temporary is read. + The second element is a list of `AccessLocation`s that describes + where the temporary is defined. + """ + # All nodes that are used as distributed buffers. + candidate_temp_storage: list[AccessLocation] = [] + + # Which `temp_storage` access node is written back to which global memory. + temp_storage_to_global: dict[dace_nodes.AccessNode, str] = {} + + for state in sdfg.states(): + # These are the possible targets we want to write into. 
+ candidate_dst_nodes: set[dace_nodes.AccessNode] = { + node + for node in state.sink_nodes() + if ( + isinstance(node, dace_nodes.AccessNode) + and state.in_degree(node) == 1 + and (not node.desc(sdfg).transient) + ) + } + if len(candidate_dst_nodes) == 0: + continue + + for temp_storage in state.data_nodes(): + if not temp_storage.desc(sdfg).transient: + continue + if state.out_degree(temp_storage) != 1: + continue + dst_candidate: dace_nodes.AccessNode = next( + iter(edge.dst for edge in state.out_edges(temp_storage)) + ) + if dst_candidate not in candidate_dst_nodes: + continue + candidate_temp_storage.append((temp_storage, state)) + temp_storage_to_global[temp_storage] = dst_candidate.data + + if len(candidate_temp_storage) == 0: + return [] + + # Now we have to find the places where the temporary sources are defined. + # I.e. This is also the location where the temporary source was initialized. + result_candidates: list[tuple[AccessLocation, list[AccessLocation]]] = [] + + def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]: + return { + src_state + for src_state in sdfg.states() + if dst_state in reachable[src_state] and dst_state is not src_state + } + + for temp_storage in candidate_temp_storage: + temp_storage_node, temp_storage_state = temp_storage + def_locations: list[AccessLocation] = [] + for upstream_state in find_upstream_states(temp_storage_state): + if self._is_written_to_in_state( + data=temp_storage_node.data, + state=upstream_state, + access_states=access_states, + ): + # NOTE: We do not impose any restriction on `temp_storage`. Thus + # It could be that we do read from it (we can never write to it) + # in this state or any other state later. + # TODO(phimuell): Should we require that `temp_storage` is a sink + # node? It might prevent or allow other optimizations. 
+ new_locations = [ + (data_node, upstream_state) + for data_node in upstream_state.data_nodes() + if data_node.data == temp_storage_node.data + ] + def_locations.extend(new_locations) + if len(def_locations) != 0: + result_candidates.append((temp_storage, def_locations)) + + # This transformation removes `temp_storage` by writing its content directly + # to `dest_storage`, at the point where it is defined. + # For this transformation to be valid the following conditions have to be met: + # - Between the definition of `temp_storage` and the write back to `dest_storage`, + # `dest_storage` can not be accessed. + # - Between the definitions of `temp_storage` and the point where it is written + # back, `temp_storage` can only be accessed in the range that is written back. + # - After the write back point, `temp_storage` shall not be accessed. This + # restriction could be lifted. + # + # To keep the implementation simple, we use the conditions: + # - `temp_storage` is only accessed were it is defined and at the write back + # point. + # - Between the definitions of `temp_storage` and the write back point, + # `dest_storage` is not used. + + result: list[tuple[AccessLocation, list[AccessLocation]]] = [] + + for wb_location, def_locations in result_candidates: + # Get the state and the location where the temporary is written back + # into the global data container. + wb_node, wb_state = wb_location + + for def_node, def_state in def_locations: + # Test if `temp_storage` is only accessed where it is defined and + # where it is written back. + if gtx_transformations.utils.is_accessed_downstream( + start_state=def_state, + sdfg=sdfg, + reachable_states=reachable, + data_to_look=wb_node.data, + nodes_to_ignore={def_node, wb_node}, + ): + break + + # Check if the global data is not used between the definition of + # `dest_storage` and where its written back. However, we ignore + # the state were `temp_storage` is defined. 
These
+ # checks are performed by the `_check_read_write_dependency()`
+ # function.
+ global_data_name = temp_storage_to_global[wb_node]
+ global_nodes_in_def_state = {
+ dnode for dnode in def_state.data_nodes() if dnode.data == global_data_name
+ }
+
+ # The `is_accessed_downstream()` function has some odd behaviour
+ # regarding `states_to_ignore`. Because of the special SDFGs we have
+ # this should not be an issue.
+ if gtx_transformations.utils.is_accessed_downstream(
+ start_state=def_state,
+ sdfg=sdfg,
+ reachable_states=reachable,
+ data_to_look=global_data_name,
+ nodes_to_ignore=global_nodes_in_def_state,
+ states_to_ignore={wb_state},
+ ):
+ break
+ if self._check_read_write_dependency(sdfg, wb_location, def_locations):
+ break
+ else:
+ result.append((wb_location, def_locations))
+
+ return result
+
+ def _is_written_to_in_state(
+ self,
+ data: str,
+ state: dace.SDFGState,
+ access_states: dict[str, set[dace.SDFGState]],
+ ) -> bool:
+ """This function determines if there is a write to data `data` in state `state`.
+
+ Args:
+ data: Name of the data descriptor that should be tested.
+ state: The state that should be examined.
+ access_states: The set of states that write to a specific data.
+ """
+ assert data in access_states, f"Did not find '{data}' in 'access_states'."
+
+ # According to `access_states` `data` is not accessed inside `state`.
+ # Therefore there is no write.
+ if state not in access_states[data]:
+ return False
+
+ # There is an AccessNode for `data` inside `state`. Now we have to find the
+ # node and determine if it is a write or not.
+ for dnode in state.data_nodes():
+ if dnode.data != data:
+ continue
+ if state.in_degree(dnode) > 0:
+ return True
+
+ return False
+
+ def _check_read_write_dependency(
+ self,
+ sdfg: dace.SDFG,
+ write_back_location: AccessLocation,
+ target_locations: list[AccessLocation],
+ ) -> bool:
+ """Tests if read-write conflicts would be created.
+ + This function ensures that the substitution of `write_back_location` into + `target_locations` will not create a read-write conflict. + The rules that are used for this are outlined in the class description. + + Args: + sdfg: The SDFG on which we operate. + write_back_location: Where currently the write back occurs. + target_locations: List of the locations where we would like to perform + the write back instead. + + Returns: + If a read-write dependency is detected then the function will return + `True` and if none was detected `False` will be returned. + """ + for target_location in target_locations: + if self._check_read_write_dependency_impl(sdfg, write_back_location, target_location): + return True + return False + + def _check_read_write_dependency_impl( + self, + sdfg: dace.SDFG, + write_back_location: AccessLocation, + target_location: AccessLocation, + ) -> bool: + """Tests if read-write conflict would be created for a single location. + + Args: + sdfg: The SDFG on which we operate. + write_back_location: Where currently the write back occurs. + target_locations: Location where the new write back should be performed. + + Todo: + Refine these checks later. + + Returns: + If a read-write dependency is detected then the function will return + `True` and if none was detected `False` will be returned. + """ + assert write_back_location[0].data == target_location[0].data + + # Get the state and the location where the temporary is written back + # into the global data container. Because `write_back_node` refers to + # the temporary we must query the graph to find the global node. 
+ write_back_node, write_back_state = write_back_location + write_back_edge = next(iter(write_back_state.out_edges(write_back_node))) + global_data_name = write_back_edge.dst.data + assert not sdfg.arrays[global_data_name].transient + assert write_back_state.out_degree(write_back_node) == 1 + assert write_back_state.in_degree(write_back_node) == 0 + + # Get the location and the state where the temporary is originally defined. + def_location_of_intermediate, state_to_inspect = target_location + + # These are all access nodes that refers to the global data, that we want + # to move into the state `state_to_inspect`. We need them to do the + # second test. + accesses_to_global_data: set[dace_nodes.AccessNode] = set() + + # In the first check we look for an access node, to the global data, that + # has an output degree larger than one. However, for this we ignore all + # empty Memlets. This is done because such Memlets are used to induce a + # schedule or order in the dataflow graph. + # As a byproduct, for the second test, we also collect all of these nodes. + # TODO(phimuell): Refine this such that it takes the location of the data + # into account. + for dnode in state_to_inspect.data_nodes(): + if dnode.data != global_data_name: + continue + dnode_degree = sum( + (1 for oedge in state_to_inspect.out_edges(dnode) if not oedge.data.is_empty()) + ) + if dnode_degree > 1: + return True + # TODO(phimuell): Maybe AccessNodes with zero input degree should be ignored. + accesses_to_global_data.add(dnode) + + # There is no reference to the global data, so no need to do more tests. + if len(accesses_to_global_data) == 0: + return False + + # For the second test we will explore the dataflow graph, in reverse order, + # starting from the definition of the temporary node. If we find an access + # to the global data we remove it from the `accesses_to_global_data` list. 
+ # If the list has not become empty, then we know that there is some kind of
+ # branch (or concurrent dataflow) in this state that accesses the global
+ # data and we will have read-write conflicts.
+ # It is, however, important to realize that passing this check does not
+ # imply that there are no read-write conflicts. We assume here that all accesses to
+ # the global data that was made before the write back were constructed in
+ # a correct way.
+ to_process: list[dace_nodes.Node] = [def_location_of_intermediate]
+ seen: set[dace_nodes.Node] = set()
+ while len(to_process) != 0:
+ node = to_process.pop()
+ seen.add(node)
+
+ if isinstance(node, dace_nodes.AccessNode):
+ if node.data == global_data_name:
+ accesses_to_global_data.discard(node)
+ if len(accesses_to_global_data) == 0:
+ return False
+
+ # Note that we only explore the ingoing edges, thus we will not necessarily
+ # explore the whole graph. However, this is fine, because we will see the
+ # relevant parts. To see that, assume that we would also have to check the
+ # outgoing edges, this would mean that there was some branching point,
+ # which is a serialization point, so the dataflow would have been invalid
+ # before.
+ to_process.extend(
+ iedge.src for iedge in state_to_inspect.in_edges(node) if iedge.src not in seen
+ )
+
+ assert len(accesses_to_global_data) > 0
+ return True
+
+
+@dace_properties.make_properties
+class GT4PyMoveTaskletIntoMap(dace_transformation.SingleStateTransformation):
+ """Moves a Tasklet with no inputs into a map.
+
+ Tasklets without inputs are mostly used to generate constants.
+ However, if they are outside a Map, then this constant value is an
+ argument to the kernel, and can not be used by the compiler.
+
+ This transformation moves such Tasklets into a Map scope.
+ """ + + tasklet = dace_transformation.PatternNode(dace_nodes.Tasklet) + access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.tasklet, cls.access_node, cls.map_entry)] + + def can_be_applied( + self, + graph: dace.SDFGState | dace.SDFG, + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + tasklet: dace_nodes.Tasklet = self.tasklet + access_node: dace_nodes.AccessNode = self.access_node + access_desc: dace_data.Data = access_node.desc(sdfg) + map_entry: dace_nodes.MapEntry = self.map_entry + + if graph.in_degree(tasklet) != 0: + return False + if graph.out_degree(tasklet) != 1: + return False + if tasklet.has_side_effects(sdfg): + return False + if tasklet.code_init.as_string: + return False + if tasklet.code_exit.as_string: + return False + if tasklet.code_global.as_string: + return False + if tasklet.state_fields: + return False + if not isinstance(access_desc, dace_data.Scalar): + return False + if not access_desc.transient: + return False + if not any( + edge.dst_conn and edge.dst_conn.startswith("IN_") + for edge in graph.out_edges(access_node) + if edge.dst is map_entry + ): + return False + # NOTE: We allow that the access node is used in multiple places. + + return True + + def apply( + self, + graph: dace.SDFGState | dace.SDFG, + sdfg: dace.SDFG, + ) -> None: + tasklet: dace_nodes.Tasklet = self.tasklet + access_node: dace_nodes.AccessNode = self.access_node + access_desc: dace_data.Scalar = access_node.desc(sdfg) + map_entry: dace_nodes.MapEntry = self.map_entry + + # Find _a_ connection that leads from the access node to the map. 
+ edge_to_map = next( + iter( + edge + for edge in graph.out_edges(access_node) + if edge.dst is map_entry and edge.dst_conn.startswith("IN_") + ) + ) + connector_name: str = edge_to_map.dst_conn[3:] + + # This is the tasklet that we will put inside the map, note we have to do it + # this way to avoid some name clash stuff. + inner_tasklet: dace_nodes.Tasklet = graph.add_tasklet( + name=f"{tasklet.label}__clone_{str(uuid.uuid1()).replace('-', '_')}", + outputs=tasklet.out_connectors.keys(), + inputs=set(), + code=tasklet.code, + language=tasklet.language, + debuginfo=tasklet.debuginfo, + ) + inner_desc: dace_data.Scalar = access_desc.clone() + inner_data_name: str = sdfg.add_datadesc(access_node.data, inner_desc, find_new_name=True) + inner_an: dace_nodes.AccessNode = graph.add_access(inner_data_name) + + # Connect the tasklet with the map entry and the access node. + graph.add_nedge(map_entry, inner_tasklet, dace.Memlet()) + graph.add_edge( + inner_tasklet, + next(iter(inner_tasklet.out_connectors.keys())), + inner_an, + None, + dace.Memlet(f"{inner_data_name}[0]"), + ) + + # Now we will reroute the edges went through the inner map, through the + # inner access node instead. + for old_inner_edge in list( + graph.out_edges_by_connector(map_entry, "OUT_" + connector_name) + ): + # We now modify the downstream data. This is because we no longer refer + # to the data outside but the one inside. + self._modify_downstream_memlets( + state=graph, + edge=old_inner_edge, + old_data=access_node.data, + new_data=inner_data_name, + ) + + # After we have changed the properties of the MemletTree of `edge` + # we will now reroute it, such that the inner access node is used. 
+ graph.add_edge( + inner_an, + None, + old_inner_edge.dst, + old_inner_edge.dst_conn, + old_inner_edge.data, + ) + graph.remove_edge(old_inner_edge) + map_entry.remove_in_connector("IN_" + connector_name) + map_entry.remove_out_connector("OUT_" + connector_name) + + # Now we can remove the map connection between the outer/old access + # node and the map. + graph.remove_edge(edge_to_map) + + # The data is no longer referenced in this state, so we can potentially + # remove + if graph.out_degree(access_node) == 0: + # TODO(phimuell): Use the pipeline to run `StateReachability` once. + if not gtx_transformations.utils.is_accessed_downstream( + start_state=graph, + sdfg=sdfg, + reachable_states=None, + data_to_look=access_node.data, + nodes_to_ignore={access_node}, + ): + graph.remove_nodes_from([tasklet, access_node]) + # Needed if data is accessed in a parallel branch. + try: + sdfg.remove_data(access_node.data, validate=True) + except ValueError as e: + if not str(e).startswith(f"Cannot remove data descriptor {access_node.data}:"): + raise + + def _modify_downstream_memlets( + self, + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge, + old_data: str, + new_data: str, + ) -> None: + """Replaces the data along on the tree defined by `edge`. + + The function will traverse the MemletTree defined by `edge`. + Any Memlet that refers to `old_data` will be replaced with + `new_data`. + + Args: + state: The sate in which we operate. + edge: The edge defining the MemletTree. + old_data: The name of the data that should be replaced. + new_data: The name of the new data the Memlet should refer to. + """ + mtree: dace.memlet.MemletTree = state.memlet_tree(edge) + for tedge in mtree.traverse_children(True): + # Because we only change the name of the data, we do not change the + # direction of the Memlet, so `{src, dst}_subset` will remain the same. 
+ if tedge.edge.data.data == old_data: + tedge.edge.data.data = new_data + + +@dace_properties.make_properties +class GT4PyMapBufferElimination(dace_transformation.SingleStateTransformation): + """Allows to remove unneeded buffering at map output. + + The transformation matches the case `MapExit -> (T) -> (G)`, where `T` is an + AccessNode referring to a transient and `G` an AccessNode that refers to non + transient memory. + If the following conditions are met then `T` is removed. + - `T` is not used to filter computations, i.e. what is written into `G` + is covered by what is written into `T`. + - `T` is not used anywhere else. + - `G` is not also an input to the map, except there is only a pointwise + dependency in `G`, see the note below. + - Everything needs to be at top scope. + + Notes: + - Rule 3 of ADR18 should guarantee that any valid GT4Py program meets the + point wise dependency in `G`, for that reason it is possible to disable + this test by specifying `assume_pointwise`. + + Todo: + - Implement a real pointwise test. + - Run this inside a pipeline. 
+ """ + + map_exit = dace_transformation.PatternNode(dace_nodes.MapExit) + tmp_ac = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) + glob_ac = dace_transformation.PatternNode(dace_nodes.AccessNode) + + assume_pointwise = dace_properties.Property( + dtype=bool, + default=False, + desc="Dimensions that should become the leading dimension.", + ) + + def __init__( + self, + assume_pointwise: Optional[bool] = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + if assume_pointwise is not None: + self.assume_pointwise = assume_pointwise + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.map_exit, cls.tmp_ac, cls.glob_ac)] + + def depends_on(self) -> set[type[dace_transformation.Pass]]: + return {dace_transformation.passes.ConsolidateEdges} + + def can_be_applied( + self, + graph: dace.SDFGState | dace.SDFG, + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + tmp_ac: dace_nodes.AccessNode = self.tmp_ac + glob_ac: dace_nodes.AccessNode = self.glob_ac + tmp_desc: dace_data.Data = tmp_ac.desc(sdfg) + glob_desc: dace_data.Data = glob_ac.desc(sdfg) + + if not tmp_desc.transient: + return False + if glob_desc.transient: + return False + if graph.in_degree(tmp_ac) != 1: + return False + if any(gtx_transformations.utils.is_view(ac, sdfg) for ac in [tmp_ac, glob_ac]): + return False + if len(glob_desc.shape) != len(tmp_desc.shape): + return False + + # Test if we are on the top scope (it is likely). + if graph.scope_dict()[glob_ac] is not None: + return False + + # Now perform if we are point wise + if not self._perform_pointwise_test(graph, sdfg): + return False + + # Test if `tmp` is only anywhere else, this is important for removing it. + if graph.out_degree(tmp_ac) != 1: + return False + # TODO(phimuell): Use the pipeline system to run the `StateReachability` pass + # only once. Taking care of DaCe issue 1911. 
+ if gtx_transformations.utils.is_accessed_downstream( + start_state=graph, + sdfg=sdfg, + reachable_states=None, + data_to_look=tmp_ac.data, + nodes_to_ignore={tmp_ac}, + ): + return False + + # Now we ensure that `tmp` is not used to filter out some computations. + map_to_tmp_edge = next(edge for edge in graph.in_edges(tmp_ac)) + tmp_to_glob_edge = next(edge for edge in graph.out_edges(tmp_ac)) + + tmp_in_subset = map_to_tmp_edge.data.get_dst_subset(map_to_tmp_edge, graph) + tmp_out_subset = tmp_to_glob_edge.data.get_src_subset(tmp_to_glob_edge, graph) + glob_in_subset = tmp_to_glob_edge.data.get_dst_subset(tmp_to_glob_edge, graph) + if tmp_in_subset is None: + tmp_in_subset = dace_subsets.Range.from_array(tmp_desc) + if tmp_out_subset is None: + tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) + if glob_in_subset is None: + return False + + # TODO(phimuell): Do we need simplify in the check. + # TODO(phimuell): Restrict this to having the same size. + if tmp_out_subset != tmp_in_subset: + return False + return True + + def _perform_pointwise_test( + self, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> bool: + """Test if `G` is only point wise accessed. + + This function will also consider the `assume_pointwise` property. + """ + map_exit: dace_nodes.MapExit = self.map_exit + map_entry: dace_nodes.MapEntry = state.entry_node(map_exit) + glob_ac: dace_nodes.AccessNode = self.glob_ac + glob_data: str = glob_ac.data + + # First we check if `G` is also an input to this map. + conflicting_inputs: set[dace_nodes.AccessNode] = set() + for in_edge in state.in_edges(map_entry): + if not isinstance(in_edge.src, dace_nodes.AccessNode): + continue + + # Find the source of this data, if it is a view we trace it to + # its origin. + src_node: dace_nodes.AccessNode = gtx_transformations.utils.track_view( + in_edge.src, state, sdfg + ) + + # Test if there is a conflict; We do not store the source but the + # actual node that is adjacent. 
+ if src_node.data == glob_data: + conflicting_inputs.add(in_edge.src) + + # If there are no conflicting inputs, then we are point wise. + # This is an implementation detail that make life simpler. + if len(conflicting_inputs) == 0: + return True + + # If we can assume pointwise computations, then we do not have to do + # anything. + if self.assume_pointwise: + return True + + # Currently the only test that we do is, if we have a view, then we + # are not point wise. + # TODO(phimuell): Improve/implement this. + return any(gtx_transformations.utils.is_view(node, sdfg) for node in conflicting_inputs) + + def apply( + self, + graph: dace.SDFGState, + sdfg: dace.SDFG, + ) -> None: + # Removal + # Propagation ofthe shift. + map_exit: dace_nodes.MapExit = self.map_exit + tmp_ac: dace_nodes.AccessNode = self.tmp_ac + tmp_desc: dace_data.Data = tmp_ac.desc(sdfg) + tmp_data = tmp_ac.data + glob_ac: dace_nodes.AccessNode = self.glob_ac + glob_data = glob_ac.data + + map_to_tmp_edge = next(edge for edge in graph.in_edges(tmp_ac)) + tmp_to_glob_edge = next(edge for edge in graph.out_edges(tmp_ac)) + + glob_in_subset = tmp_to_glob_edge.data.get_dst_subset(tmp_to_glob_edge, graph) + tmp_out_subset = tmp_to_glob_edge.data.get_src_subset(tmp_to_glob_edge, graph) + if tmp_out_subset is None: + tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) + assert glob_in_subset is not None + + # Recursively visit the nested SDFGs for mapping of strides from inner to outer array + gtx_transformations.gt_map_strides_to_src_nested_sdfg(sdfg, graph, map_to_tmp_edge, glob_ac) + + # We now remove the `tmp` node, and create a new connection between + # the global node and the map exit. 
+ new_map_to_glob_edge = graph.add_edge( + map_exit, + map_to_tmp_edge.src_conn, + glob_ac, + tmp_to_glob_edge.dst_conn, + dace.Memlet( + data=glob_ac.data, + subset=copy.deepcopy(glob_in_subset), + ), + ) + graph.remove_edge(map_to_tmp_edge) + graph.remove_edge(tmp_to_glob_edge) + graph.remove_node(tmp_ac) + + # We can not unconditionally remove the data `tmp` refers to, because + # it could be that in a parallel branch the `tmp` is also defined. + try: + sdfg.remove_data(tmp_ac.data, validate=True) + except ValueError as e: + if not str(e).startswith(f"Cannot remove data descriptor {tmp_ac.data}:"): + raise + + # Now we must modify the memlets inside the map scope, because + # they now write into `G` instead of `tmp`, which has a different + # offset. + # NOTE: Assumes that `tmp_out_subset` and `tmp_in_subset` are the same. + correcting_offset = glob_in_subset.offset_new(tmp_out_subset, negative=True) + mtree = graph.memlet_tree(new_map_to_glob_edge) + for tree in mtree.traverse_children(include_self=False): + curr_edge = tree.edge + curr_dst_subset = curr_edge.data.get_dst_subset(curr_edge, graph) + if curr_edge.data.data == tmp_data: + curr_edge.data.data = glob_data + if curr_dst_subset is not None: + curr_dst_subset.offset(correcting_offset, negative=False) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace/transformations/strides.py new file mode 100644 index 0000000000..c037535124 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/strides.py @@ -0,0 +1,684 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from typing import Optional, TypeAlias + +import dace +from dace import data as dace_data +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations + + +PropagatedStrideRecord: TypeAlias = tuple[str, dace_nodes.NestedSDFG] +"""Record of a stride that has been propagated into a NestedSDFG. + +The type combines the NestedSDFG into which the strides were already propagated +and the data within that NestedSDFG to which we have propagated the strides, +which is the connector name on the NestedSDFG. +We need the NestedSDFG because we have to know what was already processed, +however, we also need the inner array name because of aliasing, i.e. a data +descriptor on the outside could be mapped to multiple data descriptors +inside the NestedSDFG. +""" + + +def gt_change_transient_strides( + sdfg: dace.SDFG, + gpu: bool, +) -> dace.SDFG: + """Modifies the strides of transients. + + The function will analyse the access patterns and set the strides of + transients in the optimal way. + The function should run after all maps have been created. + + After the strides have been adjusted the function will also propagate + the strides into nested SDFG. This propagation will happen with + `ignore_symbol_mapping` set to `True`, see `gt_propagate_strides_of()` + for more. + + Args: + sdfg: The SDFG to process. + gpu: If the SDFG is supposed to run on the GPU. + + Note: + Currently the function will not scan the access pattern. Instead it will + either use FORTRAN order for GPU or C order (which is assumed to be the + default, so it is a no ops). + + Todo: + - Implement the estimation correctly. + """ + # TODO(phimeull): Implement this function correctly. 
+ + # We assume that by default we have C order which is already correct, + # so in this case we have a no ops + if not gpu: + return sdfg + + for nsdfg in sdfg.all_sdfgs_recursive(): + _gt_change_transient_strides_non_recursive_impl(nsdfg) + + +def _gt_change_transient_strides_non_recursive_impl( + sdfg: dace.SDFG, +) -> None: + """Set optimal strides of all transients in the SDFG. + + The function will look for all top level transients, see `_gt_find_toplevel_data_accesses()` + and set their strides such that the access is optimal, see Note. The function + will also run `gt_propagate_strides_of()` to propagate the strides into nested SDFGs. + + This function should never be called directly but always through + `gt_change_transient_strides()`! + + Note: + Currently the function just reverses the strides of the data descriptor + it processes. Since DaCe generates `C` order by default this lead to + FORTRAN order, which is (for now) sufficient to optimize the memory + layout to GPU. + + Todo: + Make this function more intelligent to analyse the access pattern and then + figuring out the best order. + """ + # NOTE: Processing the transient here is enough. If we are inside a + # NestedSDFG then they were handled before on the level above us. + top_level_transients_and_their_accesses = _gt_find_toplevel_data_accesses( + sdfg=sdfg, + only_transients=True, + only_arrays=True, + ) + for top_level_transient, accesses in top_level_transients_and_their_accesses.items(): + desc: dace_data.Array = sdfg.arrays[top_level_transient] + + # Setting the strides only make sense if we have more than one dimensions + ndim = len(desc.shape) + if ndim <= 1: + continue + + # We assume that everything is in C order initially, to get FORTRAN order + # we simply have to reverse the order. + # TODO(phimuell): Improve this. + new_stride_order = list(range(ndim)) + desc.set_strides_from_layout(*new_stride_order) + + # Now we have to propagate the changed strides. 
Because we already have
+    # collected all the AccessNodes we are using the
+    # `gt_propagate_strides_from_access_node()` function, but we have to
+    # create the `processed_nsdfgs` set already outside here.
+    # Furthermore, the same comment as above applies here, we do not have to
+    # propagate the non-transients, because they either come from outside,
+    # or they were already handled in the levels above, where they were
+    # defined and then propagated down.
+    # TODO(phimuell): Update the functions such that only one scan is needed.
+    processed_nsdfgs: set[dace_nodes.NestedSDFG] = set()
+    for state, access_node in accesses:
+        gt_propagate_strides_from_access_node(
+            sdfg=sdfg,
+            state=state,
+            outer_node=access_node,
+            processed_nsdfgs=processed_nsdfgs,
+            ignore_symbol_mapping=True,
+        )
+
+
+def gt_propagate_strides_of(
+    sdfg: dace.SDFG,
+    data_name: str,
+    ignore_symbol_mapping: bool = True,
+) -> None:
+    """Propagates the strides of `data_name` within the whole SDFG.
+
+    This function will call `gt_propagate_strides_from_access_node()` for every
+    AccessNode that refers to `data_name`. It will also make sure that a descriptor
+    inside a NestedSDFG is only processed once.
+
+    Args:
+        sdfg: The SDFG on which we operate.
+        data_name: Name of the data descriptor that should be handled.
+        ignore_symbol_mapping: If `False` (default is `True`) try to modify the `symbol_mapping`
+            of NestedSDFGs instead of manipulating the data descriptor.
+    """
+
+    # Defining it here ensures that we will not enter a NestedSDFG multiple times.
+ processed_nsdfgs: set[PropagatedStrideRecord] = set() + + for state in sdfg.states(): + for dnode in state.data_nodes(): + if dnode.data != data_name: + continue + gt_propagate_strides_from_access_node( + sdfg=sdfg, + state=state, + outer_node=dnode, + processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + +def gt_propagate_strides_from_access_node( + sdfg: dace.SDFG, + state: dace.SDFGState, + outer_node: dace_nodes.AccessNode, + ignore_symbol_mapping: bool = True, + processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, +) -> None: + """Propagates the stride of `outer_node` to any adjacent NestedSDFG. + + The function will propagate the strides of the data descriptor `outer_node` + refers to along all adjacent edges of `outer_node`. If one of these edges + leads to a NestedSDFG then the function will modify the strides of data + descriptor within to match the strides on the outside. The function will then + recursively process NestedSDFG. + + It is important that this function will only handle the NestedSDFGs that are + reachable from `outer_node`. To fully propagate the strides the + `gt_propagate_strides_of()` should be used. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that reads from the data node, the nested SDFG is expected as the destination. + outer_node: The data node whose strides should be propagated. + ignore_symbol_mapping: If `False` (default is `True`), try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. + processed_nsdfgs: Set of NestedSDFG that were already processed and will be ignored. + Only specify when you know what your are doing. + """ + assert isinstance(state, dace.SDFGState) + + if processed_nsdfgs is None: + # For preventing the case that nested SDFGs are handled multiple time. 
+    processed_nsdfgs = set()
+
+    for in_edge in state.in_edges(outer_node):
+        gt_map_strides_to_src_nested_sdfg(
+            sdfg=sdfg,
+            state=state,
+            edge=in_edge,
+            outer_node=outer_node,
+            processed_nsdfgs=processed_nsdfgs,
+            ignore_symbol_mapping=ignore_symbol_mapping,
+        )
+    for out_edge in state.out_edges(outer_node):
+        gt_map_strides_to_dst_nested_sdfg(
+            sdfg=sdfg,
+            state=state,
+            edge=out_edge,
+            outer_node=outer_node,
+            processed_nsdfgs=processed_nsdfgs,
+            ignore_symbol_mapping=ignore_symbol_mapping,
+        )
+
+
+def gt_map_strides_to_dst_nested_sdfg(
+    sdfg: dace.SDFG,
+    state: dace.SDFGState,
+    edge: dace.sdfg.graph.Edge,
+    outer_node: dace.nodes.AccessNode,
+    ignore_symbol_mapping: bool = True,
+    processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None,
+) -> None:
+    """Propagates the strides of `outer_node` along `edge` in the dataflow direction.
+
+    In this context "along the dataflow direction" means that `edge` is an outgoing
+    edge of `outer_node` and the strides are propagated into all NestedSDFGs that
+    are downstream of `outer_node`.
+
+    Except in certain cases this function should not be used directly. It is
+    instead recommended to use `gt_propagate_strides_of()`, which propagates
+    all edges in the SDFG.
+
+    Args:
+        sdfg: The SDFG to process.
+        state: The state where the data node is used.
+        edge: The edge that reads from the data node, the nested SDFG is expected as the destination.
+        outer_node: The data node whose strides should be propagated.
+        ignore_symbol_mapping: If `False` (default is `True`), try to modify the `symbol_mapping`
+            of NestedSDFGs instead of manipulating the data descriptor.
+        processed_nsdfgs: Set of NestedSDFGs that were already processed. Only specify when
+            you know what you are doing.
+    """
+    assert edge.src is outer_node
+    _gt_map_strides_to_nested_sdfg_src_dst(
+        sdfg=sdfg,
+        state=state,
+        edge=edge,
+        outer_node=outer_node,
+        processed_nsdfgs=processed_nsdfgs,
+        propagate_along_dataflow=True,
+        ignore_symbol_mapping=ignore_symbol_mapping,
+    )
+
+
+def gt_map_strides_to_src_nested_sdfg(
+    sdfg: dace.SDFG,
+    state: dace.SDFGState,
+    edge: dace.sdfg.graph.Edge,
+    outer_node: dace.nodes.AccessNode,
+    ignore_symbol_mapping: bool = False,
+    processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None,
+) -> None:
+    """Propagates the strides of `outer_node` along `edge` in the opposite direction of the dataflow.
+
+    In this context "in the opposite direction of the dataflow" means that `edge`
+    is an incoming edge of `outer_node` and the strides are propagated into all
+    NestedSDFGs that are upstream of `outer_node`.
+
+    Except in certain cases this function should not be used directly. It is
+    instead recommended to use `gt_propagate_strides_of()`, which propagates
+    all edges in the SDFG.
+
+    Args:
+        sdfg: The SDFG to process.
+        state: The state where the data node is used.
+        edge: The edge that writes to the data node, the nested SDFG is expected as the source.
+        outer_node: The data node whose strides should be propagated.
+        ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping`
+            of NestedSDFGs instead of manipulating the data descriptor.
+        processed_nsdfgs: Set of NestedSDFGs that were already processed. Only specify when
+            you know what you are doing.
+ """ + _gt_map_strides_to_nested_sdfg_src_dst( + sdfg=sdfg, + state=state, + edge=edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + propagate_along_dataflow=False, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + +def _gt_map_strides_to_nested_sdfg_src_dst( + sdfg: dace.SDFG, + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + outer_node: dace.nodes.AccessNode, + processed_nsdfgs: Optional[set[PropagatedStrideRecord]], + propagate_along_dataflow: bool, + ignore_symbol_mapping: bool = False, +) -> None: + """Propagates the stride of `outer_node` along `edge`. + + The function will follow `edge`, the direction depends on the value of + `propagate_along_dataflow` and propagate the strides of `outer_node` + into every NestedSDFG that is reachable by following `edge`. + + When the function encounters a NestedSDFG it will determine what data + the `outer_node` is mapped to on the inside of the NestedSDFG. + It will then replace the stride of the inner descriptor with the ones + of the outside. Afterwards it will recursively propagate the strides + inside the NestedSDFG. + During this propagation the function will follow any edges. + + If the function reaches a NestedSDFG that is listed inside `processed_nsdfgs` + then it will be skipped. NestedSDFGs that have been processed will be added + to the `processed_nsdfgs`. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that reads from the data node, the nested SDFG is expected as the destination. + outer_node: The data node whose strides should be propagated. + processed_nsdfgs: Set of Nested SDFG that were already processed and will be ignored. + Only specify when you know what your are doing. + propagate_along_dataflow: Determine the direction of propagation. If `True` the + function follows the dataflow. 
+ ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. + + Note: + A user should not use this function directly, instead `gt_propagate_strides_of()`, + `gt_map_strides_to_src_nested_sdfg()` (`propagate_along_dataflow == `False`) + or `gt_map_strides_to_dst_nested_sdfg()` (`propagate_along_dataflow == `True`) + should be used. + + Todo: + Try using `MemletTree` for the propagation. + """ + # If `processed_nsdfg` is `None` then this is the first call. We will now + # allocate the `set` and pass it as argument to all recursive calls, this + # ensures that the `set` is the same everywhere. + if processed_nsdfgs is None: + processed_nsdfgs = set() + + if propagate_along_dataflow: + # Propagate along the dataflow or forward, so we are interested at the `dst` of the edge. + ScopeNode = dace_nodes.MapEntry + + def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_nodes.Node: + return edge.dst + + def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str: + return edge.dst_conn + + def get_subset( + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + ) -> dace.subsets.Subset: + return edge.data.get_src_subset(edge, state) + + def next_edges_by_connector( + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: + if edge.dst_conn is None or not edge.dst_conn.startswith("IN_"): + return [] + return list(state.out_edges_by_connector(edge.dst, "OUT_" + edge.dst_conn[3:])) + + else: + # Propagate against the dataflow or backward, so we are interested at the `src` of the edge. 
+ ScopeNode = dace_nodes.MapExit + + def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_nodes.Node: + return edge.src + + def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str: + return edge.src_conn + + def get_subset( + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + ) -> dace.subsets.Subset: + return edge.data.get_dst_subset(edge, state) + + def next_edges_by_connector( + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: + return list(state.in_edges_by_connector(edge.src, "IN_" + edge.src_conn[4:])) + + if isinstance(get_node(edge), ScopeNode): + for next_edge in next_edges_by_connector(state, edge): + _gt_map_strides_to_nested_sdfg_src_dst( + sdfg=sdfg, + state=state, + edge=next_edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + propagate_along_dataflow=propagate_along_dataflow, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + elif isinstance(get_node(edge), dace.nodes.NestedSDFG): + nsdfg_node = get_node(edge) + inner_data = get_inner_data(edge) + process_record = (inner_data, nsdfg_node) + + if process_record in processed_nsdfgs: + # We already handled this NestedSDFG and the inner data. + return + + # Mark this nested SDFG as processed. + processed_nsdfgs.add(process_record) + + # Now set the stride of the data descriptor inside the nested SDFG to + # the ones it has outside. + _gt_map_strides_into_nested_sdfg( + sdfg=sdfg, + nsdfg_node=nsdfg_node, + inner_data=inner_data, + outer_subset=get_subset(state, edge), + outer_desc=outer_node.desc(sdfg), + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + # Since the function call above is not recursive we have now to propagate + # the change into the NestedSDFGs. Using `_gt_find_toplevel_data_accesses()` + # is a bit overkill, but allows for a more uniform processing. 
+ # TODO(phimuell): Instead of scanning every level for every data we modify + # we should scan the whole SDFG once and then reuse this information. + accesses_in_nested_sdfg = _gt_find_toplevel_data_accesses( + sdfg=nsdfg_node.sdfg, + only_transients=False, # Because on the nested levels they are globals. + only_arrays=True, + ) + for nested_state, nested_access in accesses_in_nested_sdfg.get(inner_data, list()): + # We have to use `gt_propagate_strides_from_access_node()` here because we + # have to handle its entirety. We could wait until the other branch processes + # the nested SDFG, but this might not work, so let's do it fully now. + gt_propagate_strides_from_access_node( + sdfg=nsdfg_node.sdfg, + state=nested_state, + outer_node=nested_access, + processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, + ) + + +def _gt_map_strides_into_nested_sdfg( + sdfg: dace.SDFG, + nsdfg_node: dace.nodes.NestedSDFG, + inner_data: str, + outer_subset: dace.subsets.Subset, + outer_desc: dace_data.Data, + ignore_symbol_mapping: bool, +) -> None: + """Modify the strides of `inner_data` inside `nsdfg_node` to match `outer_desc`. + + `inner_data` is the name of a data descriptor inside the NestedSDFG. + The function will then modify the strides of `inner_data`, assuming this + is an array, to match the ones of `outer_desc`. + + Args: + sdfg: The SDFG containing the NestedSDFG. + nsdfg_node: The node in the parent SDFG that contains the NestedSDFG. + inner_data: The name of the data descriptor that should be processed + inside the NestedSDFG (by construction also a connector name). + outer_subset: The subset that describes what part of the outer data is + mapped into the NestedSDFG. + outer_desc: The data descriptor of the data on the outside. + ignore_symbol_mapping: If possible the function will perform the renaming + through the `symbol_mapping` of the nested SDFG. If `True` then + the function will always perform the renaming. 
+ Note that setting this value to `False` might have negative side effects. + + Todo: + - Handle explicit dimensions of size 1. + - What should we do if the stride symbol is used somewhere else, creating an + alias is probably not the right thing? + - Handle the case if the outer stride symbol is already used in another + context inside the Neste SDFG. + """ + # We need to compute the new strides. In the following we assume that the + # relative order of the dimensions does not change, but we support the case + # where some dimensions of the outer data descriptor are not present on the + # inside. For example this happens for the Memlet `a[__i0, 0:__a_size1]`. We + # detect this case by checking if the Memlet subset in that dimension has size 1. + # TODO(phimuell): Handle the case were some additional size 1 dimensions are added. + inner_desc: dace_data.Data = nsdfg_node.sdfg.arrays[inner_data] + inner_shape = inner_desc.shape + inner_strides_init = inner_desc.strides + + outer_shape = outer_desc.shape + outer_strides = outer_desc.strides + outer_inflow = outer_subset.size() + + if isinstance(inner_desc, dace_data.Scalar): + # A scalar does not have a stride that must be propagated. + return + + # Now determine the new stride that is needed on the inside. + new_strides: list = [] + if len(outer_shape) == len(inner_shape): + # The inner and the outer descriptor have the same dimensionality. + # We now have to decide if we should take the stride from the outside, + # which happens for example in case of `A[0:N, 0:M] -> B[N, M]`, or if we + # must take 1, which happens if we do `A[0:N, i] -> B[N, 1]`, we detect that + # based on the volume that flows in. + for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): + new_strides.append(1 if dim_oinflow == 1 else dim_ostride) + + elif len(inner_shape) < len(outer_shape): + # There are less dimensions on the inside than on the outside. This means + # that some were sliced away. 
We detect this case by checking if the Memlet + # subset in that dimension has size 1. + # NOTE: That this is not always correct as it might be possible that there + # are some explicit size 1 dimensions at several places. + new_strides = [] + for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): + if dim_oinflow == 1: + pass + else: + new_strides.append(dim_ostride) + assert len(new_strides) <= len(inner_shape) + else: + # The case that we have more dimensions on the inside than on the outside. + # This is currently not supported. + raise NotImplementedError("NestedSDFGs can not be used to increase the rank.") + + if len(new_strides) != len(inner_shape): + raise ValueError("Failed to compute the inner strides.") + + # Now we actually replace the strides, there are two ways of doing it. + # The first is to create an alias in the `symbol_mapping`, however, + # this is only possible if the current strides are singular symbols, + # like `__a_strides_1`, but not expressions such as `horizontal_end - horizontal_start` + # or literal values. Furthermore, this would change the meaning of the + # old stride symbol in any context and not only in the one of the stride + # of a single and isolated data descriptor. + # The second way would be to replace `strides` attribute of the + # inner data descriptor. In case the new stride consists of expressions + # such as `value1 - value2` we have to make them available inside the + # NestedSDFG. However, it could be that the strides is used somewhere else. + # We will do the following, if `ignore_symbol_mapping` is `False` and + # the strides of the inner descriptors are symbols, we will use the + # symbol mapping. Otherwise, we will replace the `strides` attribute + # of the inner descriptor, in addition we will install a remapping, + # for those values that were a symbol. 
+ if (not ignore_symbol_mapping) and all( + isinstance(inner_stride, dace.symbol) for inner_stride in inner_strides_init + ): + # Use the symbol + for inner_stride, outer_stride in zip(inner_desc.strides, new_strides, strict=True): + nsdfg_node.symbol_mapping[inner_stride.name] = outer_stride + else: + # We have to replace the `strides` attribute of the inner descriptor. + inner_desc.set_shape(inner_desc.shape, new_strides) + + # Now find the free symbols that the new strides need. + # Note that usually `free_symbols` returns `set[str]`, but here, because + # we fall back on SymPy, we get back symbols. We will keep them, because + # then we can use them to extract the type form them, which we need later. + new_strides_symbols: list[dace.symbol] = [] + for new_stride_dim in new_strides: + if dace.symbolic.issymbolic(new_stride_dim): + new_strides_symbols.extend(sym for sym in new_stride_dim.free_symbols) + else: + # It is not already a symbol, so we turn it into a symbol. + # However, we only add it, if it is also a symbol, for example `1`. + # should not be added. + new_stride_symbol = dace.symbolic.pystr_to_symbolic(new_stride_dim) + if new_stride_symbol.is_symbol: + new_strides_symbols.append(new_stride_symbol) + + # Now we determine the set of symbols that should be mapped inside the NestedSDFG. + # We will exclude all that are already inside the `symbol_mapping` (we do not + # check if they map to the same value, we just hope it). Furthermore, + # we will exclude all symbols that are listed in the `symbols` property + # of the SDFG that is nested, and hope that it has the same meaning. + # TODO(phimuell): Add better checks to avoid overwriting. + missing_symbol_mappings: set[dace.symbol] = { + sym + for sym in new_strides_symbols + if not (sym.name in nsdfg_node.sdfg.symbols or sym.name in nsdfg_node.symbol_mapping) + } + + # Now propagate the symbols from the parent SDFG to the NestedSDFG. 
+ for sym in missing_symbol_mappings: + assert sym.name in sdfg.symbols, f"Expected that '{sym}' is defined in the parent SDFG." + nsdfg_node.sdfg.add_symbol(sym.name, sdfg.symbols[sym.name]) + nsdfg_node.symbol_mapping[sym.name] = sym + + +def _gt_find_toplevel_data_accesses( + sdfg: dace.SDFG, + only_transients: bool, + only_arrays: bool = False, +) -> dict[str, list[tuple[dace.SDFGState, dace_nodes.AccessNode]]]: + """Find all data that is accessed on the top level. + + The function will scan the SDFG, ignoring nested one, and return the + name of all data that only have AccessNodes on the top level. In data + is found that has an AccessNode on both the top level and in a nested + scope and error is generated. + By default the function will return transient and non transient data, + however, if `only_transients` is `True` then only transient data will + be returned. + Furthermore, the function will ignore an access in the following cases: + - The AccessNode refers to data that is a register. + - The AccessNode refers to a View. + + Args: + sdfg: The SDFG to process. + only_transients: If `True` only include transients. + only_arrays: If `True`, defaults to `False`, only arrays are returned. + + Returns: + A `dict` that maps the name of a data container, to a list of tuples + containing the state where the AccessNode was found and the AccessNode. + """ + # List of data that is accessed on the top level and all its access node. + top_level_data: dict[str, list[tuple[dace.SDFGState, dace_nodes.AccessNode]]] = dict() + + # List of all data that were found not on top level. + not_top_level_data: set[str] = set() + + for state in sdfg.states(): + assert isinstance(state, dace.SDFGState) + scope_dict = state.scope_dict() + for dnode in state.data_nodes(): + data: str = dnode.data + if scope_dict[dnode] is not None: + # The node was not found on the top level. So we can ignore it. 
+ # We also check if it was ever found on the top level, this should + # not happen, as everything should go through Maps. But some strange + # DaCe transformation might do it. + assert ( + data not in top_level_data + ), f"Found {data} on the top level and inside a scope." + not_top_level_data.add(data) + continue + + elif data in top_level_data: + # The data is already known to be in top level data, so we must add the + # AccessNode to the list of known nodes. But nothing else. + top_level_data[data].append((state, dnode)) + continue + + elif gtx_transformations.utils.is_view(dnode, sdfg): + # The AccessNode refers to a View so we ignore it anyway. + continue + + # We have found a new data node that is on the top node and is unknown. + assert ( + data not in not_top_level_data + ), f"Found {data} on the top level and inside a scope." + desc: dace_data.Data = dnode.desc(sdfg) + + # Check if we only accept arrays + if only_arrays and not isinstance(desc, dace_data.Array): + continue + + # For now we ignore registers. + # We do this because register are allocated on the stack, so the compiler + # has all information and should organize the best thing possible. + # TODO(phimuell): verify this. + elif desc.storage is dace.StorageType.Register: + continue + + # We are only interested in transients + if only_transients and (not desc.transient): + continue + + # Now create the new entry in the list and record the AccessNode. + top_level_data[data] = [(state, dnode)] + return top_level_data diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py new file mode 100644 index 0000000000..7afb93d5c2 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -0,0 +1,320 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +"""Common functionality for the transformations/optimization pipeline.""" + +from typing import Any, Container, Optional, Sequence, Union + +import dace +from dace import data as dace_data +from dace.sdfg import nodes as dace_nodes +from dace.transformation.passes import analysis as dace_analysis + +from gt4py.next.program_processors.runners.dace import utils as gtx_dace_utils + + +def gt_make_transients_persistent( + sdfg: dace.SDFG, + device: dace.DeviceType, +) -> dict[int, set[str]]: + """ + Changes the lifetime of certain transients to `Persistent`. + + A persistent lifetime means that the transient is allocated only the very first + time the SDFG is executed and only deallocated if the underlying `CompiledSDFG` + object goes out of scope. The main advantage is, that memory must not be + allocated every time the SDFG is run. The downside is that the SDFG can not be + called by different threads. + + Args: + sdfg: The SDFG to process. + device: The device type. + + Returns: + A `dict` mapping SDFG IDs to a set of transient arrays that + were made persistent. + + Note: + This function is based on a similar function in DaCe. However, the DaCe + function does, for unknown reasons, also reset the `wcr_nonatomic` property, + but only for GPU. 
+ """ + result: dict[int, set[str]] = {} + for nsdfg in sdfg.all_sdfgs_recursive(): + fsyms: set[str] = nsdfg.free_symbols + modify_lifetime: set[str] = set() + not_modify_lifetime: set[str] = set() + + for state in nsdfg.states(): + scope_dict = state.scope_dict() + for dnode in state.data_nodes(): + if dnode.data in not_modify_lifetime: + continue + + if dnode.data in nsdfg.constants_prop: + not_modify_lifetime.add(dnode.data) + continue + + desc = dnode.desc(nsdfg) + if not desc.transient or type(desc) not in {dace.data.Array, dace.data.Scalar}: + not_modify_lifetime.add(dnode.data) + continue + if desc.storage == dace.StorageType.Register: + not_modify_lifetime.add(dnode.data) + continue + + if desc.lifetime == dace.AllocationLifetime.External: + not_modify_lifetime.add(dnode.data) + continue + + # If the data is referenced inside a scope, such as a map, it might be possible + # that it is only used inside that scope. If we would make it persistent, then + # it would essentially be allocated outside and be shared among the different + # map iterations. So we can not make it persistent. + # The downside is, that we might have to perform dynamic allocation. + if scope_dict[dnode] is not None: + not_modify_lifetime.add(dnode.data) + continue + + try: + # The symbols describing the total size must be a subset of the + # free symbols of the SDFG (symbols passed as argument). + # NOTE: This ignores the renaming of symbols through the + # `symbol_mapping` property of nested SDFGs. + if not set(map(str, desc.total_size.free_symbols)).issubset(fsyms): + not_modify_lifetime.add(dnode.data) + continue + except AttributeError: # total_size is an integer / has no free symbols + pass + + # Make it persistent. + modify_lifetime.add(dnode.data) + + # Now setting the lifetime. 
+ result[nsdfg.cfg_id] = modify_lifetime - not_modify_lifetime + for aname in result[nsdfg.cfg_id]: + nsdfg.arrays[aname].lifetime = dace.AllocationLifetime.Persistent + + return result + + +def gt_find_constant_arguments( + call_args: dict[str, Any], + include: Optional[Container[str]] = None, +) -> dict[str, Any]: + """Scans the calling arguments for compile time constants. + + The output of this function can be used as input to + `gt_substitute_compiletime_symbols()`, which then removes these symbols. + + By specifying `include` it is possible to force the function to include + additional arguments, that would not be matched otherwise. Importantly, + their value is not checked. + + Args: + call_args: The full list of arguments that will be passed to the SDFG. + include: List of arguments that should be included. + """ + if include is None: + include = set() + ret_value: dict[str, Any] = {} + + for name, value in call_args.items(): + if name in include or (gtx_dace_utils.is_field_symbol(name) and value == 1): + ret_value[name] = value + + return ret_value + + +def is_accessed_downstream( + start_state: dace.SDFGState, + sdfg: dace.SDFG, + data_to_look: str, + reachable_states: Optional[dict[dace.SDFGState, set[dace.SDFGState]]], + nodes_to_ignore: Optional[set[dace_nodes.AccessNode]] = None, + states_to_ignore: Optional[set[dace.SDFGState]] = None, +) -> bool: + """Scans for accesses to the data container `data_to_look`. + + The function will go through states that are reachable from `start_state` + (included) and test if there is an AccessNode that _reads_ from `data_to_look`. + It will return `True` the first time it finds such a node. + + The function will ignore all nodes that are listed in `nodes_to_ignore`. + Furthermore, states listed in `states_to_ignore` will be ignored, i.e. + handled as they did not exist. + + Args: + start_state: The state where the scanning starts. + sdfg: The SDFG on which we operate. 
+ data_to_look: The data that we want to look for. + reachable_states: Maps an `SDFGState` to all `SDFGState`s that can be reached. + If `None` it will be computed, but this is not recommended. + nodes_to_ignore: Ignore these nodes. + states_to_ignore: Ignore these states. + + Note: + Currently, the function will not only ignore the states that are listed in + `states_to_ignore`, but all that are reachable from any of these states. + Thus care must be taken when this option is used. Furthermore, this behaviour + is not intended and will change in further versions. + `reachable_states` can be computed by using the `StateReachability` analysis + pass from DaCe. + + Todo: + - Modify the function such that it is no longer necessary to pass the + `reachable_states` argument. + - Fix the behaviour for `states_to_ignore`. + """ + # After DaCe 1 switched to a hierarchical version of the state machine. Thus + # it is no longer possible in a simple way to traverse the SDFG. As a temporary + # solution we use the `StateReachability` pass. However, this has some issues, + # see the note about `states_to_ignore`. + if reachable_states is None: + state_reachability_pass = dace_analysis.StateReachability() + reachable_states = state_reachability_pass.apply_pass(sdfg, None)[sdfg.cfg_id] + else: + # Ensures that the externally generated result was passed properly. + assert all( + isinstance(state, dace.SDFGState) and state.sdfg is sdfg for state in reachable_states + ) + + ign_dnodes: set[dace_nodes.AccessNode] = nodes_to_ignore or set() + ign_states: set[dace.SDFGState] = states_to_ignore or set() + + # NOTE: We have to include `start_state`, however, we must also consider the + # data in `reachable_states` as immutable, so we have to do it this way. + # TODO(phimuell): Go back to a trivial scan of the graph. + if start_state not in reachable_states: + # This can mean different things, either there was only one state to begin + # with or `start_state` is the last one. 
In this case the `states_to_scan` + # set consists only of the `start_state` because we have to process it. + states_to_scan = {start_state} + else: + # Ensure that `start_state` is scanned. + states_to_scan = reachable_states[start_state].union([start_state]) + + # In the first version we explored the state machine and if we encountered a + # state in the ignore set we simply ignored it. This is no longer possible. + # Instead we will remove all states from the `states_to_scan` that are reachable + # from an ignored state. However, this is not the same as if we would explore + # the state machine (as we did before). Consider the following case: + # + # (STATE_1) ------------> (STATE_2) + # | /\ + # V | + # (STATE_3) ------------------+ + # + # Assume that `STATE_1` is the starting state and `STATE_3` is ignored. + # If we would explore the state machine, we would still scan `STATE_2`. + # However, because `STATE_2` is also reachable from `STATE_3` it will now be + # ignored. In most cases this should be fine, but we have to handle it. + states_to_scan.difference_update(ign_states) + for ign_state in ign_states: + states_to_scan.difference_update(reachable_states.get(ign_state, set())) + assert start_state in states_to_scan + + for downstream_state in states_to_scan: + if downstream_state in ign_states: + continue + for dnode in downstream_state.data_nodes(): + if dnode.data != data_to_look: + continue + if dnode in ign_dnodes: + continue + if downstream_state.out_degree(dnode) != 0: + return True # There is a read operation + return False + + +def is_reachable( + start: Union[dace_nodes.Node, Sequence[dace_nodes.Node]], + target: Union[dace_nodes.Node, Sequence[dace_nodes.Node]], + state: dace.SDFGState, +) -> bool: + """Explores the graph from `start` and checks if `target` is reachable. + + The exploration of the graph is done in a way that ignores the connector names. + It is possible to pass multiple start nodes and targets. 
In case of multiple target nodes, the function returns True if any of them is reachable. + + Args: + start: The node from where to start. + target: The node to look for. + state: The SDFG state on which we operate. + """ + to_visit: list[dace_nodes.Node] = [start] if isinstance(start, dace_nodes.Node) else list(start) + targets: set[dace_nodes.Node] = {target} if isinstance(target, dace_nodes.Node) else set(target) + seen: set[dace_nodes.Node] = set() + + while to_visit: + node = to_visit.pop() + if node in targets: + return True + seen.add(node) + to_visit.extend(oedge.dst for oedge in state.out_edges(node) if oedge.dst not in seen) + + return False + + +def is_view( + node: Union[dace_nodes.AccessNode, dace_data.Data], + sdfg: Optional[dace.SDFG] = None, +) -> bool: + """Tests if `node` points to a view or not.""" + if isinstance(node, dace_nodes.AccessNode): + assert sdfg is not None + node_desc = node.desc(sdfg) + else: + assert isinstance(node, dace_data.Data) + node_desc = node + return isinstance(node_desc, dace_data.View) + + +def track_view( + view: dace_nodes.AccessNode, + state: dace.SDFGState, + sdfg: dace.SDFG, +) -> dace_nodes.AccessNode: + """Find the original data of a View. + + Given the View `view`, the function will trace the view back to the original + access node. For convenience, if `view` is not a `View` the argument will be + returned. + + Args: + view: The view that should be traced. + state: The state in which we operate. + sdfg: The SDFG on which we operate. + """ + + # Test if it is a view at all, if not return the passed node as source. + if not is_view(view, sdfg): + return view + + # First determine if the view is used for reading or writing. + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") + if curr_edge.dst_conn == "views": + # The view is used for reading. 
+ next_node = lambda curr_edge: curr_edge.src # noqa: E731 + elif curr_edge.src_conn == "views": + # The view is used for writing. + next_node = lambda curr_edge: curr_edge.dst # noqa: E731 + else: + raise RuntimeError(f"Failed to determine the direction of the view '{view}' | {curr_edge}.") + + # Now trace the view back. + org_view = view + view = next_node(curr_edge) + while is_view(view, sdfg): + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"View tracing of '{org_view}' failed at node '{view}'.") + view = next_node(curr_edge) + return view diff --git a/src/gt4py/next/program_processors/runners/dace/utils.py b/src/gt4py/next/program_processors/runners/dace/utils.py new file mode 100644 index 0000000000..5fdace73a9 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/utils.py @@ -0,0 +1,128 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory.
+# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import re +from typing import Final, Literal, Mapping, Union + +import dace + +from gt4py.next import common as gtx_common +from gt4py.next.type_system import type_specifications as ts + + +# arrays for connectivity tables use the following prefix +CONNECTIVITY_INDENTIFIER_PREFIX: Final[str] = "connectivity_" +CONNECTIVITY_INDENTIFIER_RE: Final[re.Pattern] = re.compile(r"^connectivity_(.+)$") + + +# regex to match the symbols for field shape and strides +FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile(r"^__.+_((\d+_range_[01])|((size|stride)_\d+))$") + + +def as_dace_type(type_: ts.ScalarType) -> dace.typeclass: + """Converts GT4Py scalar type to corresponding DaCe type.""" + + match type_.kind: + case ts.ScalarKind.BOOL: + return dace.bool_ + case ts.ScalarKind(): + return getattr(dace, type_.kind.name.lower()) + case _: + raise ValueError(f"Scalar type '{type_}' not supported.") + + +def as_itir_type(dtype: dace.typeclass) -> ts.ScalarType: + """Get GT4Py scalar representation of a DaCe type.""" + type_name = str(dtype.as_numpy_dtype()) + try: + kind = getattr(ts.ScalarKind, type_name.upper()) + except AttributeError as ex: + raise ValueError(f"Data type {type_name} not supported.") from ex + return ts.ScalarType(kind) + + +def connectivity_identifier(name: str) -> str: + return f"{CONNECTIVITY_INDENTIFIER_PREFIX}{name}" + + +def is_connectivity_identifier( + name: str, offset_provider_type: gtx_common.OffsetProviderType +) -> bool: + m = CONNECTIVITY_INDENTIFIER_RE.match(name) + if m is None: + return False + return m[1] in offset_provider_type + + +def field_symbol_name(field_name: str, axis: int, sym: Literal["size", "stride"]) -> str: + return f"__{field_name}_{sym}_{axis}" + + +def field_size_symbol_name(field_name: str, axis: int) -> str: + return field_symbol_name(field_name, axis, "size") + + +def field_stride_symbol_name(field_name: str, axis: int) -> str: + return 
field_symbol_name(field_name, axis, "stride") + + +def range_start_symbol(field_name: str, axis: int) -> str: + """Format name of start symbol for domain range, as expected by GTIR.""" + return f"__{field_name}_{axis}_range_0" + + +def range_stop_symbol(field_name: str, axis: int) -> str: + """Format name of stop symbol for domain range, as expected by GTIR.""" + return f"__{field_name}_{axis}_range_1" + + +def is_field_symbol(name: str) -> bool: + return FIELD_SYMBOL_RE.match(name) is not None + + +def filter_connectivity_types( + offset_provider_type: gtx_common.OffsetProviderType, +) -> dict[str, gtx_common.NeighborConnectivityType]: + """ + Filter offset provider types of type `NeighborConnectivityType`. + + In other words, filter out the cartesian offset providers. + """ + return { + offset: conn + for offset, conn in offset_provider_type.items() + if isinstance(conn, gtx_common.NeighborConnectivityType) + } + + +def safe_replace_symbolic( + val: dace.symbolic.SymbolicType, + symbol_mapping: Mapping[ + Union[dace.symbolic.SymbolicType, str], Union[dace.symbolic.SymbolicType, str] + ], +) -> dace.symbolic.SymbolicType: + """ + Replace free symbols in a dace symbolic expression, using `safe_replace()` + in order to avoid clashes in case the new symbol value is also a free symbol + in the original expression. + + Args: + val: The symbolic expression where to apply the replacement. + symbol_mapping: The mapping table for symbol replacement. + + Returns: + A new symbolic expression as result of symbol replacement. + """ + # The list `x` is needed because `subs()` returns a new object and can not handle + # replacement dicts of the form `{'x': 'y', 'y': 'x'}`. + # The utility `safe_replace()` will call `subs()` twice in case of such dicts.
+ x = [val] + dace.symbolic.safe_replace(symbol_mapping, lambda m, xx=x: xx.append(xx[-1].subs(m))) + return x[-1] diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py b/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py new file mode 100644 index 0000000000..4d825c0c9b --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/workflow/__init__.py @@ -0,0 +1,20 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""Implements the On-The-Fly (OTF) compilation workflow for the GTIR-DaCe backend. + +The main module is `backend`, which exports the backends for CPU and GPU devices. +The `backend` module uses `factory` to define a workflow that implements the +`OTFCompileWorkflow` recipe. The different stages are implemented in separate modules: +- `translation` for lowering of GTIR to SDFG and applying SDFG transformations +- `compilation` for compiling the SDFG into a program +- `decoration` to parse the program arguments and pass them to the program call + +The GTIR-DaCe backend factory extends `CachedBackendFactory`, thus it provides +caching of the GTIR program. +""" diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py new file mode 100644 index 0000000000..fb93e3df79 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py @@ -0,0 +1,61 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory.
+# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import factory + +import gt4py.next.allocators as next_allocators +from gt4py._core import definitions as core_defs +from gt4py.next import backend +from gt4py.next.otf import stages, workflow +from gt4py.next.program_processors.runners.dace.workflow.factory import DaCeWorkflowFactory + + +class DaCeBackendFactory(factory.Factory): + class Meta: + model = backend.Backend + + class Params: + name_device = "cpu" + name_cached = "" + name_postfix = "" + gpu = factory.Trait( + allocator=next_allocators.StandardGPUFieldBufferAllocator(), + device_type=core_defs.CUPY_DEVICE_TYPE or core_defs.DeviceType.CUDA, + name_device="gpu", + ) + cached = factory.Trait( + executor=factory.LazyAttribute( + lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) + ), + name_cached="_cached", + ) + device_type = core_defs.DeviceType.CPU + hash_function = stages.compilation_hash + otf_workflow = factory.SubFactory( + DaCeWorkflowFactory, + device_type=factory.SelfAttribute("..device_type"), + auto_optimize=factory.SelfAttribute("..auto_optimize"), + ) + auto_optimize = factory.Trait(name_postfix="_opt") + + name = factory.LazyAttribute( + lambda o: f"run_dace_{o.name_device}{o.name_cached}{o.name_postfix}" + ) + + executor = factory.LazyAttribute(lambda o: o.otf_workflow) + allocator = next_allocators.StandardCPUFieldBufferAllocator() + transforms = backend.DEFAULT_TRANSFORMS + + +run_dace_cpu = DaCeBackendFactory(cached=True, auto_optimize=True) +run_dace_cpu_noopt = DaCeBackendFactory(cached=True, auto_optimize=False) + +run_dace_gpu = DaCeBackendFactory(gpu=True, cached=True, auto_optimize=True) +run_dace_gpu_noopt = DaCeBackendFactory(gpu=True, cached=True, auto_optimize=False) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py new file mode 100644 index 0000000000..c0d1c74c7a 
--- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/workflow/compilation.py @@ -0,0 +1,94 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import dataclasses +from typing import Any + +import dace +import factory + +from gt4py._core import definitions as core_defs +from gt4py.next import config +from gt4py.next.otf import languages, stages, step_types, workflow +from gt4py.next.otf.compilation import cache + + +class CompiledDaceProgram(stages.ExtendedCompiledProgram): + sdfg_program: dace.CompiledSDFG + + # Sorted list of SDFG arguments as they appear in program ABI and corresponding data type; + # scalar arguments that are not used in the SDFG will not be present. + sdfg_arglist: list[tuple[str, dace.dtypes.Data]] + + def __init__(self, program: dace.CompiledSDFG, implicit_domain: bool): + self.sdfg_program = program + self.implicit_domain = implicit_domain + # `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument + # name to its data type, in the same order as arguments appear in the program ABI. + # This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`. 
+ self.sdfg_arglist = [ + (arg_name, arg_type) for arg_name, arg_type in program.sdfg.arglist().items() + ] + + def __call__(self, *args: Any, **kwargs: Any) -> None: + result = self.sdfg_program(*args, **kwargs) + assert result is None + + def fast_call(self) -> None: + result = self.sdfg_program.fast_call(*self.sdfg_program._lastargs) + assert result is None + + +@dataclasses.dataclass(frozen=True) +class DaCeCompiler( + workflow.ChainableWorkflowMixin[ + stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python], + CompiledDaceProgram, + ], + workflow.ReplaceEnabledWorkflowMixin[ + stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python], + CompiledDaceProgram, + ], + step_types.CompilationStep[languages.SDFG, languages.LanguageSettings, languages.Python], +): + """Use the dace build system to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" + + cache_lifetime: config.BuildCacheLifetime + device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + cmake_build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG + + def __call__( + self, + inp: stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python], + ) -> CompiledDaceProgram: + sdfg = dace.SDFG.from_json(inp.program_source.source_code) + + src_dir = cache.get_cache_folder(inp, self.cache_lifetime) + sdfg.build_folder = src_dir / ".dacecache" + + with dace.config.temporary_config(): + dace.config.Config.set("compiler", "build_type", value=self.cmake_build_type.value) + if self.device_type == core_defs.DeviceType.CPU: + compiler_args = dace.config.Config.get("compiler", "cpu", "args") + # disable finite-math-only in order to support isfinite/isinf/isnan builtins + if "-ffast-math" in compiler_args: + compiler_args += " -fno-finite-math-only" + if "-ffinite-math-only" in compiler_args: + compiler_args = compiler_args.replace("-ffinite-math-only", "") + + dace.config.Config.set("compiler", "cpu", "args",
value=compiler_args) + sdfg_program = sdfg.compile(validate=False) + + return CompiledDaceProgram(sdfg_program, inp.program_source.implicit_domain) + + +class DaCeCompilationStepFactory(factory.Factory): + class Meta: + model = DaCeCompiler diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py new file mode 100644 index 0000000000..9648ac9e04 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -0,0 +1,95 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import ctypes +from typing import Any, Sequence + +import dace +from dace.codegen.compiled_sdfg import _array_interface_ptr as get_array_interface_ptr + +from gt4py._core import definitions as core_defs +from gt4py.next import common, utils as gtx_utils +from gt4py.next.otf import arguments, stages +from gt4py.next.program_processors.runners.dace import ( + sdfg_callable, + utils as gtx_dace_utils, + workflow as dace_worflow, +) + + +def convert_args( + inp: dace_worflow.compilation.CompiledDaceProgram, + device: core_defs.DeviceType = core_defs.DeviceType.CPU, + use_field_canonical_representation: bool = False, +) -> stages.CompiledProgram: + sdfg_program = inp.sdfg_program + sdfg = sdfg_program.sdfg + on_gpu = True if device in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM] else False + + def decorated_program( + *args: Any, + offset_provider: common.OffsetProvider, + out: Any = None, + ) -> None: + if out is not None: + args = (*args, out) + flat_args: Sequence[Any] = gtx_utils.flatten_nested_tuple(tuple(args)) + if inp.implicit_domain: + # generate implicit domain size arguments only if necessary + size_args = arguments.iter_size_args(args) + flat_size_args: Sequence[int] = 
gtx_utils.flatten_nested_tuple(tuple(size_args)) + flat_args = (*flat_args, *flat_size_args) + + if sdfg_program._lastargs: + kwargs = dict(zip(sdfg.arg_names, flat_args, strict=True)) + kwargs.update(sdfg_callable.get_sdfg_conn_args(sdfg, offset_provider, on_gpu)) + + use_fast_call = True + last_call_args = sdfg_program._lastargs[0] + # The scalar arguments should be overridden with the new value; for field arguments, + # the data pointer should remain the same otherwise fast_call cannot be used and + # the arguments list has to be reconstructed. + for i, (arg_name, arg_type) in enumerate(inp.sdfg_arglist): + if isinstance(arg_type, dace.data.Array): + assert arg_name in kwargs, f"argument '{arg_name}' not found." + data_ptr = get_array_interface_ptr(kwargs[arg_name], arg_type.storage) + assert isinstance(last_call_args[i], ctypes.c_void_p) + if last_call_args[i].value != data_ptr: + use_fast_call = False + break + else: + assert isinstance(arg_type, dace.data.Scalar) + assert isinstance(last_call_args[i], ctypes._SimpleCData) + if arg_name in kwargs: + # override the scalar value used in previous program call + actype = arg_type.dtype.as_ctypes() + last_call_args[i] = actype(kwargs[arg_name]) + else: + # shape and strides of arrays are supposed not to change, and can therefore be omitted + assert gtx_dace_utils.is_field_symbol( + arg_name + ), f"argument '{arg_name}' not found." 
+ + if use_fast_call: + return inp.fast_call() + + sdfg_args = sdfg_callable.get_sdfg_args( + sdfg, + offset_provider, + *flat_args, + check_args=False, + on_gpu=on_gpu, + ) + + with dace.config.temporary_config(): + dace.config.Config.set("compiler", "allow_view_arguments", value=True) + return inp(**sdfg_args) + + return decorated_program diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py new file mode 100644 index 0000000000..02a089c88c --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -0,0 +1,58 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import functools + +import factory + +from gt4py._core import definitions as core_defs +from gt4py.next import config +from gt4py.next.otf import recipes, stages +from gt4py.next.program_processors.runners.dace.workflow import decoration as decoration_step +from gt4py.next.program_processors.runners.dace.workflow.compilation import ( + DaCeCompilationStepFactory, +) +from gt4py.next.program_processors.runners.dace.workflow.translation import ( + DaCeTranslationStepFactory, +) + + +def _no_bindings(inp: stages.ProgramSource) -> stages.CompilableSource: + return stages.CompilableSource(program_source=inp, binding_source=None) + + +class DaCeWorkflowFactory(factory.Factory): + class Meta: + model = recipes.OTFCompileWorkflow + + class Params: + device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + cmake_build_type: config.CMakeBuildType = factory.LazyFunction( + lambda: config.CMAKE_BUILD_TYPE + ) + auto_optimize: bool = False + + translation = factory.SubFactory( + DaCeTranslationStepFactory, + device_type=factory.SelfAttribute("..device_type"), + 
auto_optimize=factory.SelfAttribute("..auto_optimize"), + ) + bindings = _no_bindings + compilation = factory.SubFactory( + DaCeCompilationStepFactory, + cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), + cmake_build_type=factory.SelfAttribute("..cmake_build_type"), + ) + decoration = factory.LazyAttribute( + lambda o: functools.partial( + decoration_step.convert_args, + device=o.device_type, + ) + ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py similarity index 51% rename from src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py rename to src/gt4py/next/program_processors/runners/dace/workflow/translation.py index ffc33a9f25..e31e4ea741 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/translation.py @@ -9,21 +9,21 @@ from __future__ import annotations import dataclasses -import functools from typing import Optional import dace import factory from gt4py._core import definitions as core_defs -from gt4py.next import common, config -from gt4py.next.iterator import ir as itir -from gt4py.next.otf import languages, recipes, stages, step_types, workflow +from gt4py.next import common +from gt4py.next.iterator import ir as itir, transforms as itir_transforms +from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import interface from gt4py.next.otf.languages import LanguageSettings -from gt4py.next.program_processors.runners.dace_common import workflow as dace_workflow -from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg -from gt4py.next.type_system import type_translation as tt +from gt4py.next.program_processors.runners.dace import ( + gtir_sdfg, + transformations as gtx_transformations, +) @dataclasses.dataclass(frozen=True) @@ -33,7 +33,10 @@ class DaCeTranslator( 
], step_types.TranslationStep[languages.SDFG, languages.LanguageSettings], ): - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + device_type: core_defs.DeviceType + auto_optimize: bool + disable_itir_transforms: bool = False + disable_field_origin_on_program_arguments: bool = False def _language_settings(self) -> languages.LanguageSettings: return languages.LanguageSettings( @@ -45,28 +48,48 @@ def generate_sdfg( ir: itir.Program, offset_provider: common.OffsetProvider, column_axis: Optional[common.Dimension], + auto_opt: bool, + on_gpu: bool, ) -> dace.SDFG: - # TODO(edopao): Call IR transformations and domain inference, finally lower IR to SDFG - raise NotImplementedError + if not self.disable_itir_transforms: + ir = itir_transforms.apply_fieldview_transforms(ir, offset_provider=offset_provider) + sdfg = gtir_sdfg.build_sdfg_from_gtir( + ir, + common.offset_provider_to_type(offset_provider), + column_axis, + disable_field_origin_on_program_arguments=self.disable_field_origin_on_program_arguments, + ) + + if auto_opt: + gtx_transformations.gt_auto_optimize(sdfg, gpu=on_gpu) + elif on_gpu: + # We run simplify to bring the SDFG into a canonical form that the gpu transformations + # can handle. This is a workaround for an issue with scalar expressions that are + # promoted to symbolic expressions and computed on the host (CPU), but the intermediate + # result is written to a GPU global variable (https://github.com/spcl/dace/issues/1773). 
+ gtx_transformations.gt_simplify(sdfg) + gtx_transformations.gt_gpu_transformation(sdfg, try_removing_trivial_maps=True) - return gtir_sdfg.build_sdfg_from_gtir(program=ir, offset_provider=offset_provider) + return sdfg def __call__( self, inp: stages.CompilableProgram ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]: """Generate DaCe SDFG file from the GTIR definition.""" - program: itir.FencilDefinition | itir.Program = inp.data + program: itir.Program = inp.data assert isinstance(program, itir.Program) sdfg = self.generate_sdfg( program, - inp.args.offset_provider, + inp.args.offset_provider, # TODO(havogt): should be offset_provider_type once the transformation don't require run-time info inp.args.column_axis, + auto_opt=self.auto_optimize, + on_gpu=(self.device_type == core_defs.CUPY_DEVICE_TYPE), ) param_types = tuple( - interface.Parameter(param, tt.from_value(arg)) - for param, arg in zip(sdfg.arg_names, inp.args.args) + interface.Parameter(param, arg_type) + for param, arg_type in zip(sdfg.arg_names, inp.args.args) ) module: stages.ProgramSource[languages.SDFG, languages.LanguageSettings] = ( @@ -85,35 +108,3 @@ def __call__( class DaCeTranslationStepFactory(factory.Factory): class Meta: model = DaCeTranslator - - -def _no_bindings(inp: stages.ProgramSource) -> stages.CompilableSource: - return stages.CompilableSource(program_source=inp, binding_source=None) - - -class DaCeWorkflowFactory(factory.Factory): - class Meta: - model = recipes.OTFCompileWorkflow - - class Params: - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cmake_build_type: config.CMakeBuildType = factory.LazyFunction( - lambda: config.CMAKE_BUILD_TYPE - ) - - translation = factory.SubFactory( - DaCeTranslationStepFactory, - device_type=factory.SelfAttribute("..device_type"), - ) - bindings = _no_bindings - compilation = factory.SubFactory( - dace_workflow.DaCeCompilationStepFactory, - cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), - 
cmake_build_type=factory.SelfAttribute("..cmake_build_type"), - ) - decoration = factory.LazyAttribute( - lambda o: functools.partial( - dace_workflow.convert_args, - device=o.device_type, - ) - ) diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py deleted file mode 100644 index dec34ecbac..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ /dev/null @@ -1,101 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import re -from typing import Final, Optional, Sequence - -import dace - -from gt4py.next import common as gtx_common -from gt4py.next.iterator import ir as gtir -from gt4py.next.type_system import type_specifications as ts - - -# regex to match the symbols for field shape and strides -FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile("__.+_(size|stride)_\d+") - - -def as_dace_type(type_: ts.ScalarType) -> dace.typeclass: - """Converts GT4Py scalar type to corresponding DaCe type.""" - if type_.kind == ts.ScalarKind.BOOL: - return dace.bool_ - elif type_.kind == ts.ScalarKind.INT32: - return dace.int32 - elif type_.kind == ts.ScalarKind.INT64: - return dace.int64 - elif type_.kind == ts.ScalarKind.FLOAT32: - return dace.float32 - elif type_.kind == ts.ScalarKind.FLOAT64: - return dace.float64 - raise ValueError(f"Scalar type '{type_}' not supported.") - - -def as_scalar_type(typestr: str) -> ts.ScalarType: - """Obtain GT4Py scalar type from generic numpy string representation.""" - try: - kind = getattr(ts.ScalarKind, typestr.upper()) - except AttributeError as ex: - raise ValueError(f"Data type {typestr} not supported.") from ex - return ts.ScalarType(kind) - - -def connectivity_identifier(name: str) -> str: - return f"connectivity_{name}" 
- - -def field_size_symbol_name(field_name: str, axis: int) -> str: - return f"__{field_name}_size_{axis}" - - -def field_stride_symbol_name(field_name: str, axis: int) -> str: - return f"__{field_name}_stride_{axis}" - - -def is_field_symbol(name: str) -> bool: - return FIELD_SYMBOL_RE.match(name) is not None - - -def debug_info( - node: gtir.Node, *, default: Optional[dace.dtypes.DebugInfo] = None -) -> Optional[dace.dtypes.DebugInfo]: - """Include the GT4Py node location as debug information in the corresponding SDFG nodes.""" - location = node.location - if location: - return dace.dtypes.DebugInfo( - start_line=location.line, - start_column=location.column if location.column else 0, - end_line=location.end_line if location.end_line else -1, - end_column=location.end_column if location.end_column else 0, - filename=location.filename, - ) - return default - - -def filter_connectivities( - offset_provider: gtx_common.OffsetProvider, -) -> dict[str, gtx_common.Connectivity]: - """ - Filter offset providers of type `Connectivity`. - - In other words, filter out the cartesian offset providers. - Returns a new dictionary containing only `Connectivity` values. - """ - return { - offset: table - for offset, table in offset_provider.items() - if isinstance(table, gtx_common.Connectivity) - } - - -def get_sorted_dims( - dims: Sequence[gtx_common.Dimension], -) -> Sequence[tuple[int, gtx_common.Dimension]]: - """Sort list of dimensions in alphabetical order.""" - return sorted(enumerate(dims), key=lambda v: v[1].value) diff --git a/src/gt4py/next/program_processors/runners/dace_common/workflow.py b/src/gt4py/next/program_processors/runners/dace_common/workflow.py deleted file mode 100644 index ae0a24605d..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_common/workflow.py +++ /dev/null @@ -1,164 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. 
-# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import ctypes -import dataclasses -from typing import Any - -import dace -import factory -from dace.codegen.compiled_sdfg import _array_interface_ptr as get_array_interface_ptr - -from gt4py._core import definitions as core_defs -from gt4py.next import common, config -from gt4py.next.otf import arguments, languages, stages, step_types, workflow -from gt4py.next.otf.compilation import cache -from gt4py.next.program_processors.runners.dace_common import dace_backend, utility as dace_utils - - -class CompiledDaceProgram(stages.CompiledProgram): - sdfg_program: dace.CompiledSDFG - - # Sorted list of SDFG arguments as they appear in program ABI and corresponding data type; - # scalar arguments that are not used in the SDFG will not be present. - sdfg_arglist: list[tuple[str, dace.dtypes.Data]] - - def __init__(self, program: dace.CompiledSDFG): - self.sdfg_program = program - # `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument - # name to its data type, in the same order as arguments appear in the program ABI. - # This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`. 
- self.sdfg_arglist = [ - (arg_name, arg_type) for arg_name, arg_type in program.sdfg.arglist().items() - ] - - def __call__(self, *args: Any, **kwargs: Any) -> None: - result = self.sdfg_program(*args, **kwargs) - assert result is None - - def fast_call(self) -> None: - result = self.sdfg_program.fast_call(*self.sdfg_program._lastargs) - assert result is None - - -@dataclasses.dataclass(frozen=True) -class DaCeCompiler( - workflow.ChainableWorkflowMixin[ - stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python], - CompiledDaceProgram, - ], - workflow.ReplaceEnabledWorkflowMixin[ - stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python], - CompiledDaceProgram, - ], - step_types.CompilationStep[languages.SDFG, languages.LanguageSettings, languages.Python], -): - """Use the dace build system to compile a GT4Py program to a ``gt4py.next.otf.stages.CompiledProgram``.""" - - cache_lifetime: config.BuildCacheLifetime - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cmake_build_type: config.CMakeBuildType = config.CMakeBuildType.DEBUG - - def __call__( - self, - inp: stages.CompilableSource[languages.SDFG, languages.LanguageSettings, languages.Python], - ) -> CompiledDaceProgram: - sdfg = dace.SDFG.from_json(inp.program_source.source_code) - - src_dir = cache.get_cache_folder(inp, self.cache_lifetime) - sdfg.build_folder = src_dir / ".dacecache" - - with dace.config.temporary_config(): - dace.config.Config.set("compiler", "build_type", value=self.cmake_build_type.value) - if self.device_type == core_defs.DeviceType.CPU: - compiler_args = dace.config.Config.get("compiler", "cpu", "args") - # disable finite-math-only in order to support isfinite/isinf/isnan builtins - if "-ffast-math" in compiler_args: - compiler_args += " -fno-finite-math-only" - if "-ffinite-math-only" in compiler_args: - compiler_args.replace("-ffinite-math-only", "") - - dace.config.Config.set("compiler", "cpu", "args", 
value=compiler_args) - sdfg_program = sdfg.compile(validate=False) - - return CompiledDaceProgram(sdfg_program) - - -class DaCeCompilationStepFactory(factory.Factory): - class Meta: - model = DaCeCompiler - - -def convert_args( - inp: CompiledDaceProgram, - device: core_defs.DeviceType = core_defs.DeviceType.CPU, - use_field_canonical_representation: bool = False, -) -> stages.CompiledProgram: - sdfg_program = inp.sdfg_program - sdfg = sdfg_program.sdfg - on_gpu = True if device in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM] else False - - def decorated_program( - *args: Any, - offset_provider: common.OffsetProvider, - out: Any = None, - ) -> None: - if out is not None: - args = (*args, out) - if len(sdfg.arg_names) > len(args): - args = (*args, *arguments.iter_size_args(args)) - - if sdfg_program._lastargs: - kwargs = dict(zip(sdfg.arg_names, args, strict=True)) - kwargs.update(dace_backend.get_sdfg_conn_args(sdfg, offset_provider, on_gpu)) - - use_fast_call = True - last_call_args = sdfg_program._lastargs[0] - # The scalar arguments should be overridden with the new value; for field arguments, - # the data pointer should remain the same otherwise fast_call cannot be used and - # the arguments list has to be reconstructed. - for i, (arg_name, arg_type) in enumerate(inp.sdfg_arglist): - if isinstance(arg_type, dace.data.Array): - assert arg_name in kwargs, f"argument '{arg_name}' not found." 
- data_ptr = get_array_interface_ptr(kwargs[arg_name], arg_type.storage) - assert isinstance(last_call_args[i], ctypes.c_void_p) - if last_call_args[i].value != data_ptr: - use_fast_call = False - break - else: - assert isinstance(arg_type, dace.data.Scalar) - assert isinstance(last_call_args[i], ctypes._SimpleCData) - if arg_name in kwargs: - # override the scalar value used in previous program call - actype = arg_type.dtype.as_ctypes() - last_call_args[i] = actype(kwargs[arg_name]) - else: - # shape and strides of arrays are supposed not to change, and can therefore be omitted - assert dace_utils.is_field_symbol( - arg_name - ), f"argument '{arg_name}' not found." - - if use_fast_call: - return inp.fast_call() - - sdfg_args = dace_backend.get_sdfg_args( - sdfg, - *args, - check_args=False, - offset_provider=offset_provider, - on_gpu=on_gpu, - use_field_canonical_representation=use_field_canonical_representation, - ) - - with dace.config.temporary_config(): - dace.config.Config.set("compiler", "allow_view_arguments", value=True) - return inp(**sdfg_args) - - return decorated_program diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py deleted file mode 100644 index 602453fc5a..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - - -from gt4py.next.program_processors.runners.dace_common.dace_backend import get_sdfg_args -from gt4py.next.program_processors.runners.dace_fieldview.gtir_sdfg import build_sdfg_from_gtir - - -__all__ = [ - "build_sdfg_from_gtir", - "get_sdfg_args", -] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py deleted file mode 100644 index e91bd880c6..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ /dev/null @@ -1,576 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import abc -import dataclasses -from typing import TYPE_CHECKING, Iterable, Optional, Protocol, TypeAlias - -import dace -import dace.subsets as sbs - -from gt4py.next import common as gtx_common, utils as gtx_utils -from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.iterator.type_system import type_specifications as itir_ts -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.program_processors.runners.dace_fieldview import ( - gtir_dataflow, - gtir_python_codegen, - utility as dace_gtir_utils, -) -from gt4py.next.type_system import type_specifications as ts - - -if TYPE_CHECKING: - from gt4py.next.program_processors.runners.dace_fieldview import gtir_sdfg - - -IteratorIndexDType: TypeAlias = dace.int32 # type of iterator indexes - - -@dataclasses.dataclass(frozen=True) -class Field: - data_node: dace.nodes.AccessNode - data_type: ts.FieldType | ts.ScalarType - - -FieldopResult: TypeAlias = Field | tuple[Field | tuple, ...] 
- - -class PrimitiveTranslator(Protocol): - @abc.abstractmethod - def __call__( - self, - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], - ) -> FieldopResult: - """Creates the dataflow subgraph representing a GTIR primitive function. - - This method is used by derived classes to build a specialized subgraph - for a specific GTIR primitive function. - - Arguments: - node: The GTIR node describing the primitive to be lowered - sdfg: The SDFG where the primitive subgraph should be instantiated - state: The SDFG state where the result of the primitive function should be made available - sdfg_builder: The object responsible for visiting child nodes of the primitive node. - reduce_identity: The value of the reduction identity, in case the primitive node - is visited in the context of a reduction expression. This value is used - by the `neighbors` primitive to provide the default value of skip neighbors. - - Returns: - A list of data access nodes and the associated GT4Py data type, which provide - access to the result of the primitive subgraph. The GT4Py data type is useful - in the case the returned data is an array, because the type provdes the domain - information (e.g. order of dimensions, dimension types). 
- """ - - -def _parse_fieldop_arg( - node: gtir.Expr, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, - domain: list[ - tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] - ], - reduce_identity: Optional[gtir_dataflow.SymbolExpr], -) -> gtir_dataflow.IteratorExpr | gtir_dataflow.MemletExpr: - arg = sdfg_builder.visit( - node, - sdfg=sdfg, - head_state=state, - reduce_identity=reduce_identity, - ) - - # arguments passed to field operator should be plain fields, not tuples of fields - if not isinstance(arg, Field): - raise ValueError(f"Received {node} as argument to field operator, expected a field.") - - if isinstance(arg.data_type, ts.ScalarType): - return gtir_dataflow.MemletExpr(arg.data_node, sbs.Indices([0])) - elif isinstance(arg.data_type, ts.FieldType): - indices: dict[gtx_common.Dimension, gtir_dataflow.ValueExpr] = { - dim: gtir_dataflow.SymbolExpr( - dace_gtir_utils.get_map_variable(dim), - IteratorIndexDType, - ) - for dim, _, _ in domain - } - dims = arg.data_type.dims + ( - # we add an extra anonymous dimension in the iterator definition to enable - # dereferencing elements in `ListType` - [gtx_common.Dimension("")] if isinstance(arg.data_type.dtype, itir_ts.ListType) else [] - ) - return gtir_dataflow.IteratorExpr(arg.data_node, dims, indices) - else: - raise NotImplementedError(f"Node type {type(arg.data_type)} not supported.") - - -def _create_temporary_field( - sdfg: dace.SDFG, - state: dace.SDFGState, - domain: list[ - tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType] - ], - node_type: ts.FieldType, - output_desc: dace.data.Data, -) -> Field: - domain_dims, _, domain_ubs = zip(*domain) - field_dims = list(domain_dims) - # It should be enough to allocate an array with shape (upper_bound - lower_bound) - # but this would require to use array offset for compensate for the start index. 
- # Suppose that a field operator executes on domain [2,N-2], the dace array to store - # the result only needs size (N-4), but this would require to compensate all array - # accesses with offset -2 (which corresponds to -lower_bound). Instead, we choose - # to allocate (N-2), leaving positions [0:2] unused. The reason is that array offset - # is known to cause issues to SDFG inlining. Besides, map fusion will in any case - # eliminate most of transient arrays. - field_shape = list(domain_ubs) - - if isinstance(output_desc, dace.data.Array): - assert isinstance(node_type.dtype, itir_ts.ListType) - assert isinstance(node_type.dtype.element_type, ts.ScalarType) - field_dtype = node_type.dtype.element_type - # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) - field_shape.extend(output_desc.shape) - elif isinstance(output_desc, dace.data.Scalar): - field_dtype = node_type.dtype - else: - raise ValueError(f"Cannot create field for dace type {output_desc}.") - - # allocate local temporary storage - temp_name, _ = sdfg.add_temp_transient(field_shape, dace_utils.as_dace_type(field_dtype)) - field_node = state.add_access(temp_name) - field_type = ts.FieldType(field_dims, node_type.dtype) - - return Field(field_node, field_type) - - -def translate_as_field_op( - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], -) -> FieldopResult: - """ - Generates the dataflow subgraph for the `as_fieldop` builtin function. - - Expects a `FunCall` node with two arguments: - 1. a lambda function representing the stencil, which is lowered to a dataflow subgraph - 2. the domain of the field operator, which is used as map range - - The dataflow can be as simple as a single tasklet, or implement a local computation - as a composition of tasklets and even include a map to range on local dimensions (e.g. - neighbors and map builtins). 
- The stencil dataflow is instantiated inside a map scope, which apply the stencil over - the field domain. - """ - assert isinstance(node, gtir.FunCall) - assert cpm.is_call_to(node.fun, "as_fieldop") - assert isinstance(node.type, ts.FieldType) - - fun_node = node.fun - assert len(fun_node.args) == 2 - stencil_expr, domain_expr = fun_node.args - assert isinstance(stencil_expr, gtir.Lambda) - assert isinstance(domain_expr, gtir.FunCall) - - # parse the domain of the field operator - domain = dace_gtir_utils.get_domain(domain_expr) - - if cpm.is_applied_reduce(stencil_expr.expr): - if reduce_identity is not None: - raise NotImplementedError("nested reductions not supported.") - - # the reduce identity value is used to fill the skip values in neighbors list - _, _, reduce_identity = gtir_dataflow.get_reduce_params(stencil_expr.expr) - - # visit the list of arguments to be passed to the lambda expression - stencil_args = [ - _parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain, reduce_identity) - for arg in node.args - ] - - # represent the field operator as a mapped tasklet graph, which will range over the field domain - taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder, reduce_identity) - input_edges, output = taskgen.visit(stencil_expr, args=stencil_args) - output_desc = output.expr.node.desc(sdfg) - - domain_index = sbs.Indices([dace_gtir_utils.get_map_variable(dim) for dim, _, _ in domain]) - if isinstance(node.type.dtype, itir_ts.ListType): - assert isinstance(output_desc, dace.data.Array) - assert set(output_desc.offset) == {0} - # additional local dimension for neighbors - # TODO(phimuell): Investigate if we should swap the two. 
- output_subset = sbs.Range.from_indices(domain_index) + sbs.Range.from_array(output_desc) - else: - assert isinstance(output_desc, dace.data.Scalar) - output_subset = sbs.Range.from_indices(domain_index) - - # create map range corresponding to the field operator domain - map_ranges = {dace_gtir_utils.get_map_variable(dim): f"{lb}:{ub}" for dim, lb, ub in domain} - me, mx = sdfg_builder.add_map("field_op", state, map_ranges) - - # allocate local temporary storage for the result field - result_field = _create_temporary_field(sdfg, state, domain, node.type, output_desc) - - # here we setup the edges from the map entry node - for edge in input_edges: - edge.connect(me) - - # and here the edge writing the result data through the map exit node - output.connect(mx, result_field.data_node, output_subset) - - return result_field - - -def translate_if( - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], -) -> FieldopResult: - """Generates the dataflow subgraph for the `if_` builtin function.""" - assert cpm.is_call_to(node, "if_") - assert len(node.args) == 3 - cond_expr, true_expr, false_expr = node.args - - # expect condition as first argument - if_stmt = gtir_python_codegen.get_source(cond_expr) - - # use current head state to terminate the dataflow, and add a entry state - # to connect the true/false branch states as follows: - # - # ------------ - # === | cond | === - # || ------------ || - # \/ \/ - # ------------ ------------- - # | true | | false | - # ------------ ------------- - # || || - # || ------------ || - # ==> | head | <== - # ------------ - # - cond_state = sdfg.add_state_before(state, state.label + "_cond") - sdfg.remove_edge(sdfg.out_edges(cond_state)[0]) - - # expect true branch as second argument - true_state = sdfg.add_state(state.label + "_true_branch") - sdfg.add_edge(cond_state, true_state, dace.InterstateEdge(condition=f"bool({if_stmt})")) - 
sdfg.add_edge(true_state, state, dace.InterstateEdge()) - - # and false branch as third argument - false_state = sdfg.add_state(state.label + "_false_branch") - sdfg.add_edge(cond_state, false_state, dace.InterstateEdge(condition=(f"not bool({if_stmt})"))) - sdfg.add_edge(false_state, state, dace.InterstateEdge()) - - true_br_args = sdfg_builder.visit( - true_expr, - sdfg=sdfg, - head_state=true_state, - reduce_identity=reduce_identity, - ) - false_br_args = sdfg_builder.visit( - false_expr, - sdfg=sdfg, - head_state=false_state, - reduce_identity=reduce_identity, - ) - - def make_temps(x: Field) -> Field: - desc = x.data_node.desc(sdfg) - data_name, _ = sdfg.add_temp_transient_like(desc) - data_node = state.add_access(data_name) - - return Field(data_node, x.data_type) - - result_temps = gtx_utils.tree_map(make_temps)(true_br_args) - - fields: Iterable[tuple[Field, Field, Field]] = zip( - gtx_utils.flatten_nested_tuple((true_br_args,)), - gtx_utils.flatten_nested_tuple((false_br_args,)), - gtx_utils.flatten_nested_tuple((result_temps,)), - strict=True, - ) - - for true_br, false_br, temp in fields: - assert true_br.data_type == false_br.data_type - true_br_node = true_br.data_node - false_br_node = false_br.data_node - - temp_name = temp.data_node.data - true_br_output_node = true_state.add_access(temp_name) - true_state.add_nedge( - true_br_node, - true_br_output_node, - sdfg.make_array_memlet(temp_name), - ) - - false_br_output_node = false_state.add_access(temp_name) - false_state.add_nedge( - false_br_node, - false_br_output_node, - sdfg.make_array_memlet(temp_name), - ) - - return result_temps - - -def _get_data_nodes( - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, - sym_name: str, - sym_type: ts.DataType, -) -> FieldopResult: - if isinstance(sym_type, ts.FieldType): - sym_node = state.add_access(sym_name) - return Field(sym_node, sym_type) - elif isinstance(sym_type, ts.ScalarType): - if sym_name in sdfg.arrays: - # access 
the existing scalar container - sym_node = state.add_access(sym_name) - else: - sym_node = _get_symbolic_value( - sdfg, state, sdfg_builder, sym_name, sym_type, temp_name=f"__{sym_name}" - ) - return Field(sym_node, sym_type) - elif isinstance(sym_type, ts.TupleType): - tuple_fields = dace_gtir_utils.get_tuple_fields(sym_name, sym_type) - return tuple( - _get_data_nodes(sdfg, state, sdfg_builder, fname, ftype) - for fname, ftype in tuple_fields - ) - else: - raise NotImplementedError(f"Symbol type {type(sym_type)} not supported.") - - -def _get_symbolic_value( - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, - symbolic_expr: dace.symbolic.SymExpr, - scalar_type: ts.ScalarType, - temp_name: Optional[str] = None, -) -> dace.nodes.AccessNode: - tasklet_node = sdfg_builder.add_tasklet( - "get_value", - state, - {}, - {"__out"}, - f"__out = {symbolic_expr}", - ) - temp_name, _ = sdfg.add_scalar( - temp_name or sdfg.temp_data_name(), - dace_utils.as_dace_type(scalar_type), - find_new_name=True, - transient=True, - ) - data_node = state.add_access(temp_name) - state.add_edge( - tasklet_node, - "__out", - data_node, - None, - dace.Memlet(data=temp_name, subset="0"), - ) - return data_node - - -def translate_literal( - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], -) -> FieldopResult: - """Generates the dataflow subgraph for a `ir.Literal` node.""" - assert isinstance(node, gtir.Literal) - - data_type = node.type - data_node = _get_symbolic_value(sdfg, state, sdfg_builder, node.value, data_type) - - return Field(data_node, data_type) - - -def translate_make_tuple( - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], -) -> FieldopResult: - assert cpm.is_call_to(node, "make_tuple") - return tuple( - sdfg_builder.visit( - arg, - 
sdfg=sdfg, - head_state=state, - reduce_identity=reduce_identity, - ) - for arg in node.args - ) - - -def translate_tuple_get( - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], -) -> FieldopResult: - assert cpm.is_call_to(node, "tuple_get") - assert len(node.args) == 2 - - if not isinstance(node.args[0], gtir.Literal): - raise ValueError("Tuple can only be subscripted with compile-time constants.") - assert node.args[0].type == dace_utils.as_scalar_type(gtir.INTEGER_INDEX_BUILTIN) - index = int(node.args[0].value) - - data_nodes = sdfg_builder.visit( - node.args[1], - sdfg=sdfg, - head_state=state, - reduce_identity=reduce_identity, - ) - if isinstance(data_nodes, Field): - raise ValueError(f"Invalid tuple expression {node}") - unused_arg_nodes: Iterable[Field] = gtx_utils.flatten_nested_tuple( - tuple(arg for i, arg in enumerate(data_nodes) if i != index) - ) - state.remove_nodes_from( - [arg.data_node for arg in unused_arg_nodes if state.degree(arg.data_node) == 0] - ) - return data_nodes[index] - - -def translate_scalar_expr( - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], -) -> FieldopResult: - assert isinstance(node, gtir.FunCall) - assert isinstance(node.type, ts.ScalarType) - - args = [] - connectors = [] - scalar_expr_args = [] - - for arg_expr in node.args: - visit_expr = True - if isinstance(arg_expr, gtir.SymRef): - try: - # `gt_symbol` refers to symbols defined in the GT4Py program - gt_symbol_type = sdfg_builder.get_symbol_type(arg_expr.id) - if not isinstance(gt_symbol_type, ts.ScalarType): - raise ValueError(f"Invalid argument to scalar expression {arg_expr}.") - except KeyError: - # this is the case of non-variable argument, e.g. 
target type such as `float64`, - # used in a casting expression like `cast_(variable, float64)` - visit_expr = False - - if visit_expr: - # we visit the argument expression and obtain the access node to - # a scalar data container, which will be connected to the tasklet - arg = sdfg_builder.visit( - arg_expr, - sdfg=sdfg, - head_state=state, - reduce_identity=reduce_identity, - ) - if not (isinstance(arg, Field) and isinstance(arg.data_type, ts.ScalarType)): - raise ValueError(f"Invalid argument to scalar expression {arg_expr}.") - param = f"__in_{arg.data_node.data}" - args.append(arg.data_node) - connectors.append(param) - scalar_expr_args.append(gtir.SymRef(id=param)) - else: - assert isinstance(arg_expr, gtir.SymRef) - scalar_expr_args.append(arg_expr) - - # we visit the scalar expression replacing the input arguments with the corresponding data connectors - scalar_node = gtir.FunCall(fun=node.fun, args=scalar_expr_args) - python_code = gtir_python_codegen.get_source(scalar_node) - tasklet_node = sdfg_builder.add_tasklet( - name="scalar_expr", - state=state, - inputs=set(connectors), - outputs={"__out"}, - code=f"__out = {python_code}", - ) - # create edges for the input data connectors - for arg_node, conn in zip(args, connectors, strict=True): - state.add_edge( - arg_node, - None, - tasklet_node, - conn, - dace.Memlet(data=arg_node.data, subset="0"), - ) - # finally, create temporary for the result value - temp_name, _ = sdfg.add_scalar( - sdfg.temp_data_name(), - dace_utils.as_dace_type(node.type), - find_new_name=True, - transient=True, - ) - temp_node = state.add_access(temp_name) - state.add_edge( - tasklet_node, - "__out", - temp_node, - None, - dace.Memlet(data=temp_name, subset="0"), - ) - - return Field(temp_node, node.type) - - -def translate_symbol_ref( - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], -) -> FieldopResult: - """Generates the 
dataflow subgraph for a `ir.SymRef` node.""" - assert isinstance(node, gtir.SymRef) - - symbol_name = str(node.id) - # we retrieve the type of the symbol in the GT4Py prgram - gt_symbol_type = sdfg_builder.get_symbol_type(symbol_name) - - # Create new access node in current state. It is possible that multiple - # access nodes are created in one state for the same data container. - # We rely on the dace simplify pass to remove duplicated access nodes. - return _get_data_nodes(sdfg, state, sdfg_builder, symbol_name, gt_symbol_type) - - -if TYPE_CHECKING: - # Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol - __primitive_translators: list[PrimitiveTranslator] = [ - translate_as_field_op, - translate_if, - translate_literal, - translate_make_tuple, - translate_tuple_get, - translate_scalar_expr, - translate_symbol_ref, - ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py deleted file mode 100644 index 9739d7927a..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ /dev/null @@ -1,925 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import abc -import dataclasses -from typing import Any, Dict, Final, List, Optional, Protocol, Set, Tuple, TypeAlias, Union - -import dace -import dace.subsets as sbs - -from gt4py import eve -from gt4py.next import common as gtx_common -from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.iterator.type_system import type_specifications as itir_ts -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.program_processors.runners.dace_fieldview import ( - gtir_python_codegen, - gtir_sdfg, - utility as dace_gtir_utils, -) -from gt4py.next.type_system import type_specifications as ts - - -@dataclasses.dataclass(frozen=True) -class DataExpr: - """Local storage for the computation result returned by a tasklet node.""" - - node: dace.nodes.AccessNode - dtype: itir_ts.ListType | ts.ScalarType - - -@dataclasses.dataclass(frozen=True) -class MemletExpr: - """Scalar or array data access through a memlet.""" - - node: dace.nodes.AccessNode - subset: sbs.Indices | sbs.Range - - -@dataclasses.dataclass(frozen=True) -class SymbolExpr: - """Any symbolic expression that is constant in the context of current SDFG.""" - - value: dace.symbolic.SymExpr - dtype: dace.typeclass - - -ValueExpr: TypeAlias = DataExpr | MemletExpr | SymbolExpr - - -@dataclasses.dataclass(frozen=True) -class IteratorExpr: - """ - Iterator for field access to be consumed by `deref` or `shift` builtin functions. - - Args: - field: The field this iterator operates on. - dimensions: Field domain represented as a sorted list of dimensions. - In order to dereference an element in the field, we need index values - for all the dimensions in the right order. 
- indices: Maps each dimension to an index value, which could be either a symbolic value - or the result of a tasklet computation like neighbors connectivity or dynamic offset. - - """ - - field: dace.nodes.AccessNode - dimensions: list[gtx_common.Dimension] - indices: dict[gtx_common.Dimension, ValueExpr] - - -class DataflowInputEdge(Protocol): - """ - This protocol represents an open connection into the dataflow. - - It provides the `connect` method to setup an input edge from an external data source. - Since the dataflow represents a stencil, we instantiate the dataflow inside a map scope - and connect its inputs and outputs to external data nodes by means of memlets that - traverse the map entry and exit nodes. - """ - - @abc.abstractmethod - def connect(self, me: dace.nodes.MapEntry) -> None: ... - - -@dataclasses.dataclass(frozen=True) -class MemletInputEdge(DataflowInputEdge): - """ - Allows to setup an input memlet through a map entry node. - - The edge source has to be a data access node, while the destination node can either - be a tasklet, in which case the connector name is also required, or an access node. - """ - - state: dace.SDFGState - source: dace.nodes.AccessNode - subset: sbs.Range - dest: dace.nodes.AccessNode | dace.nodes.Tasklet - dest_conn: Optional[str] - - def connect(self, me: dace.nodes.MapEntry) -> None: - memlet = dace.Memlet(data=self.source.data, subset=self.subset) - self.state.add_memlet_path( - self.source, - me, - self.dest, - dst_conn=self.dest_conn, - memlet=memlet, - ) - - -@dataclasses.dataclass(frozen=True) -class EmptyInputEdge(DataflowInputEdge): - """ - Allows to setup an edge from a map entry node to a tasklet with no arguements. - - The reason behind this kind of connection is that all nodes inside a map scope - must have an in/out path that traverses the entry and exit nodes. 
- """ - - state: dace.SDFGState - node: dace.nodes.Tasklet - - def connect(self, me: dace.nodes.MapEntry) -> None: - self.state.add_nedge(me, self.node, dace.Memlet()) - - -@dataclasses.dataclass(frozen=True) -class DataflowOutputEdge: - """ - Allows to setup an output memlet through a map exit node. - - The result of a dataflow subgraph needs to be written to an external data node. - Since the dataflow represents a stencil and the dataflow is computed over - a field domain, the dataflow is instatiated inside a map scope. The `connect` - method creates a memlet that writes the dataflow result to the external array - passing through the map exit node. - """ - - state: dace.SDFGState - expr: DataExpr - - def connect( - self, - mx: dace.nodes.MapExit, - result_node: dace.nodes.AccessNode, - subset: sbs.Range, - ) -> None: - # retrieve the node which writes the result - last_node = self.state.in_edges(self.expr.node)[0].src - if isinstance(last_node, dace.nodes.Tasklet): - # the last transient node can be deleted - last_node_connector = self.state.in_edges(self.expr.node)[0].src_conn - self.state.remove_node(self.expr.node) - else: - last_node = self.expr.node - last_node_connector = None - - self.state.add_memlet_path( - last_node, - mx, - result_node, - src_conn=last_node_connector, - memlet=dace.Memlet(data=result_node.data, subset=subset), - ) - - -DACE_REDUCTION_MAPPING: dict[str, dace.dtypes.ReductionType] = { - "minimum": dace.dtypes.ReductionType.Min, - "maximum": dace.dtypes.ReductionType.Max, - "plus": dace.dtypes.ReductionType.Sum, - "multiplies": dace.dtypes.ReductionType.Product, - "and_": dace.dtypes.ReductionType.Logical_And, - "or_": dace.dtypes.ReductionType.Logical_Or, - "xor_": dace.dtypes.ReductionType.Logical_Xor, - "minus": dace.dtypes.ReductionType.Sub, - "divides": dace.dtypes.ReductionType.Div, -} - - -def get_reduce_params(node: gtir.FunCall) -> tuple[str, SymbolExpr, SymbolExpr]: - assert isinstance(node.type, ts.ScalarType) - dtype = 
dace_utils.as_dace_type(node.type) - - assert isinstance(node.fun, gtir.FunCall) - assert len(node.fun.args) == 2 - assert isinstance(node.fun.args[0], gtir.SymRef) - op_name = str(node.fun.args[0]) - assert isinstance(node.fun.args[1], gtir.Literal) - assert node.fun.args[1].type == node.type - reduce_init = SymbolExpr(node.fun.args[1].value, dtype) - - if op_name not in DACE_REDUCTION_MAPPING: - raise RuntimeError(f"Reduction operation '{op_name}' not supported.") - identity_value = dace.dtypes.reduction_identity(dtype, DACE_REDUCTION_MAPPING[op_name]) - reduce_identity = SymbolExpr(identity_value, dtype) - - return op_name, reduce_init, reduce_identity - - -class LambdaToDataflow(eve.NodeVisitor): - """ - Translates an `ir.Lambda` expression to a dataflow graph. - - The dataflow graph generated here typically represents the stencil function - of a field operator. It only computes single elements or pure local fields, - in case of neighbor values. In case of local fields, the dataflow contains - inner maps with fixed literal size (max number of neighbors). - Once the lambda expression has been lowered to a dataflow, the dataflow graph - needs to be instantiated, that is we have to connect all in/out edges to - external source/destination data nodes. Since the lambda expression is used - in GTIR as argument to a field operator, the dataflow is instatiated inside - a map scope and applied on the field domain. Therefore, all in/out edges - must traverse the entry/exit map nodes. 
- """ - - sdfg: dace.SDFG - state: dace.SDFGState - subgraph_builder: gtir_sdfg.DataflowBuilder - reduce_identity: Optional[SymbolExpr] - input_edges: list[DataflowInputEdge] - symbol_map: dict[str, IteratorExpr | MemletExpr | SymbolExpr] - - def __init__( - self, - sdfg: dace.SDFG, - state: dace.SDFGState, - subgraph_builder: gtir_sdfg.DataflowBuilder, - reduce_identity: Optional[SymbolExpr], - ): - self.sdfg = sdfg - self.state = state - self.subgraph_builder = subgraph_builder - self.reduce_identity = reduce_identity - self.input_edges = [] - self.symbol_map = {} - - def _add_input_data_edge( - self, - src: dace.nodes.AccessNode, - src_subset: sbs.Range, - dst_node: dace.nodes.Node, - dst_conn: Optional[str] = None, - ) -> None: - edge = MemletInputEdge(self.state, src, src_subset, dst_node, dst_conn) - self.input_edges.append(edge) - - def _add_edge( - self, - src_node: dace.Node, - src_node_connector: Optional[str], - dst_node: dace.Node, - dst_node_connector: Optional[str], - memlet: dace.Memlet, - ) -> None: - """Helper method to add an edge in current state.""" - self.state.add_edge(src_node, src_node_connector, dst_node, dst_node_connector, memlet) - - def _add_map( - self, - name: str, - ndrange: Union[ - Dict[str, Union[str, dace.subsets.Subset]], - List[Tuple[str, Union[str, dace.subsets.Subset]]], - ], - **kwargs: Any, - ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]: - """Helper method to add a map with unique name in current state.""" - return self.subgraph_builder.add_map(name, self.state, ndrange, **kwargs) - - def _add_tasklet( - self, - name: str, - inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], - outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], - code: str, - **kwargs: Any, - ) -> dace.nodes.Tasklet: - """Helper method to add a tasklet with unique name in current state.""" - tasklet_node = self.subgraph_builder.add_tasklet( - name, self.state, inputs, outputs, code, **kwargs - ) - if len(inputs) == 0: - # All nodes 
inside a map scope must have an in/out path that traverses - # the entry and exit nodes. Therefore, a tasklet node with no arguments - # still needs an (empty) input edge from map entry node. - edge = EmptyInputEdge(self.state, tasklet_node) - self.input_edges.append(edge) - return tasklet_node - - def _construct_tasklet_result( - self, - dtype: dace.typeclass, - src_node: dace.nodes.Tasklet, - src_connector: str, - ) -> DataExpr: - temp_name = self.sdfg.temp_data_name() - self.sdfg.add_scalar(temp_name, dtype, transient=True) - data_type = dace_utils.as_scalar_type(str(dtype.as_numpy_dtype())) - temp_node = self.state.add_access(temp_name) - self._add_edge( - src_node, - src_connector, - temp_node, - None, - dace.Memlet(data=temp_name, subset="0"), - ) - return DataExpr(temp_node, data_type) - - def _visit_deref(self, node: gtir.FunCall) -> ValueExpr: - """ - Visit a `deref` node, which represents dereferencing of an iterator. - The iterator is the argument of this node. - - The iterator contains the information for accessing a field, that is the - sorted list of dimensions in the field domain and the index values for - each dimension. The index values can be either symbol values, that is - literal values or scalar arguments which are constant in the SDFG scope; - or they can be the result of some expression, that computes a dynamic - index offset or gets an neighbor index from a connectivity table. - In case all indexes are symbol values, the `deref` node is lowered to a - memlet; otherwise dereferencing is a runtime operation represented in - the SDFG as a tasklet node. 
- """ - # format used for field index tasklet connector - IndexConnectorFmt: Final = "__index_{dim}" - - assert len(node.args) == 1 - arg_expr = self.visit(node.args[0]) - - if isinstance(arg_expr, IteratorExpr): - field_desc = arg_expr.field.desc(self.sdfg) - assert len(field_desc.shape) == len(arg_expr.dimensions) - if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()): - # when all indices are symblic expressions, we can perform direct field access through a memlet - field_subset = sbs.Range( - (arg_expr.indices[dim].value, arg_expr.indices[dim].value, 1) # type: ignore[union-attr] - if dim in arg_expr.indices - else (0, size - 1, 1) - for dim, size in zip(arg_expr.dimensions, field_desc.shape) - ) - return MemletExpr(arg_expr.field, field_subset) - - else: - # we use a tasklet to dereference an iterator when one or more indices are the result of some computation, - # either indirection through connectivity table or dynamic cartesian offset. - assert all(dim in arg_expr.indices for dim in arg_expr.dimensions) - field_indices = [(dim, arg_expr.indices[dim]) for dim in arg_expr.dimensions] - index_connectors = [ - IndexConnectorFmt.format(dim=dim.value) - for dim, index in field_indices - if not isinstance(index, SymbolExpr) - ] - # here `internals` refer to the names used as index in the tasklet code string: - # an index can be either a connector name (for dynamic/indirect indices) - # or a symbol value (for literal values and scalar arguments). 
- index_internals = ",".join( - str(index.value) - if isinstance(index, SymbolExpr) - else IndexConnectorFmt.format(dim=dim.value) - for dim, index in field_indices - ) - deref_node = self._add_tasklet( - "runtime_deref", - {"field"} | set(index_connectors), - {"val"}, - code=f"val = field[{index_internals}]", - ) - # add new termination point for the field parameter - self._add_input_data_edge( - arg_expr.field, - sbs.Range.from_array(field_desc), - deref_node, - "field", - ) - - for dim, index_expr in field_indices: - # add termination points for the dynamic iterator indices - deref_connector = IndexConnectorFmt.format(dim=dim.value) - if isinstance(index_expr, MemletExpr): - self._add_input_data_edge( - index_expr.node, - index_expr.subset, - deref_node, - deref_connector, - ) - - elif isinstance(index_expr, DataExpr): - self._add_edge( - index_expr.node, - None, - deref_node, - deref_connector, - dace.Memlet(data=index_expr.node.data, subset="0"), - ) - else: - assert isinstance(index_expr, SymbolExpr) - - dtype = arg_expr.field.desc(self.sdfg).dtype - return self._construct_tasklet_result(dtype, deref_node, "val") - - else: - # dereferencing a scalar or a literal node results in the node itself - return arg_expr - - def _visit_neighbors(self, node: gtir.FunCall) -> DataExpr: - assert len(node.args) == 2 - - assert isinstance(node.args[0], gtir.OffsetLiteral) - offset = node.args[0].value - assert isinstance(offset, str) - offset_provider = self.subgraph_builder.get_offset_provider(offset) - assert isinstance(offset_provider, gtx_common.Connectivity) - - it = self.visit(node.args[1]) - assert isinstance(it, IteratorExpr) - assert offset_provider.neighbor_axis in it.dimensions - neighbor_dim_index = it.dimensions.index(offset_provider.neighbor_axis) - assert offset_provider.neighbor_axis not in it.indices - assert offset_provider.origin_axis not in it.dimensions - assert offset_provider.origin_axis in it.indices - origin_index = 
it.indices[offset_provider.origin_axis] - assert isinstance(origin_index, SymbolExpr) - assert all(isinstance(index, SymbolExpr) for index in it.indices.values()) - - field_desc = it.field.desc(self.sdfg) - connectivity = dace_utils.connectivity_identifier(offset) - # initially, the storage for the connectivty tables is created as transient; - # when the tables are used, the storage is changed to non-transient, - # as the corresponding arrays are supposed to be allocated by the SDFG caller - connectivity_desc = self.sdfg.arrays[connectivity] - connectivity_desc.transient = False - - # The visitor is constructing a list of input connections that will be handled - # by `translate_as_fieldop` (the primitive translator), that is responsible - # of creating the map for the field domain. For each input connection, it will - # create a memlet that will write to a node specified by the third attribute - # in the `InputConnection` tuple (either a tasklet, or a view node, or a library - # node). For the specific case of `neighbors` we need to nest the neighbors map - # inside the field map and the memlets will traverse the external map and write - # to the view nodes. The simplify pass will remove the redundant access nodes. 
- field_slice_view, field_slice_desc = self.sdfg.add_view( - f"{offset_provider.neighbor_axis.value}_view", - (field_desc.shape[neighbor_dim_index],), - field_desc.dtype, - strides=(field_desc.strides[neighbor_dim_index],), - find_new_name=True, - ) - field_slice_node = self.state.add_access(field_slice_view) - field_subset = ",".join( - it.indices[dim].value # type: ignore[union-attr] - if dim != offset_provider.neighbor_axis - else f"0:{size}" - for dim, size in zip(it.dimensions, field_desc.shape, strict=True) - ) - self._add_input_data_edge( - it.field, - sbs.Range.from_string(field_subset), - field_slice_node, - ) - - connectivity_slice_view, _ = self.sdfg.add_view( - "neighbors_view", - (offset_provider.max_neighbors,), - connectivity_desc.dtype, - strides=(connectivity_desc.strides[1],), - find_new_name=True, - ) - connectivity_slice_node = self.state.add_access(connectivity_slice_view) - self._add_input_data_edge( - self.state.add_access(connectivity), - sbs.Range.from_string(f"{origin_index.value}, 0:{offset_provider.max_neighbors}"), - connectivity_slice_node, - ) - - neighbors_temp, _ = self.sdfg.add_temp_transient( - (offset_provider.max_neighbors,), field_desc.dtype - ) - neighbors_node = self.state.add_access(neighbors_temp) - - offset_dim = gtx_common.Dimension(offset, kind=gtx_common.DimensionKind.LOCAL) - neighbor_idx = dace_gtir_utils.get_map_variable(offset_dim) - me, mx = self._add_map( - f"{offset}_neighbors", - { - neighbor_idx: f"0:{offset_provider.max_neighbors}", - }, - ) - index_connector = "__index" - if offset_provider.has_skip_values: - assert self.reduce_identity is not None - assert self.reduce_identity.dtype == field_desc.dtype - # TODO: Investigate if a NestedSDFG brings benefits - tasklet_node = self._add_tasklet( - "gather_neighbors_with_skip_values", - {"__field", index_connector}, - {"__val"}, - f"__val = __field[{index_connector}] if {index_connector} != {gtx_common._DEFAULT_SKIP_VALUE} else 
{self.reduce_identity.dtype}({self.reduce_identity.value})", - ) - - else: - tasklet_node = self._add_tasklet( - "gather_neighbors", - {"__field", index_connector}, - {"__val"}, - f"__val = __field[{index_connector}]", - ) - - self.state.add_memlet_path( - field_slice_node, - me, - tasklet_node, - dst_conn="__field", - memlet=dace.Memlet.from_array(field_slice_view, field_slice_desc), - ) - self.state.add_memlet_path( - connectivity_slice_node, - me, - tasklet_node, - dst_conn=index_connector, - memlet=dace.Memlet(data=connectivity_slice_view, subset=neighbor_idx), - ) - self.state.add_memlet_path( - tasklet_node, - mx, - neighbors_node, - src_conn="__val", - memlet=dace.Memlet(data=neighbors_temp, subset=neighbor_idx), - ) - - assert isinstance(node.type, itir_ts.ListType) - return DataExpr(neighbors_node, node.type) - - def _visit_reduce(self, node: gtir.FunCall) -> DataExpr: - op_name, reduce_init, reduce_identity = get_reduce_params(node) - dtype = reduce_identity.dtype - - # We store the value of reduce identity in the visitor context while visiting - # the input to reduction; this value will be use by the `neighbors` visitor - # to fill the skip values in the neighbors list. 
- prev_reduce_identity = self.reduce_identity - self.reduce_identity = reduce_identity - - try: - input_expr = self.visit(node.args[0]) - finally: - # ensure that we leave the visitor in the same state as we entered - self.reduce_identity = prev_reduce_identity - - assert isinstance(input_expr, MemletExpr | DataExpr) - input_desc = input_expr.node.desc(self.sdfg) - assert isinstance(input_desc, dace.data.Array) - - if len(input_desc.shape) > 1: - assert isinstance(input_expr, MemletExpr) - ndims = len(input_desc.shape) - 1 - # the axis to be reduced is always the last one, because `reduce` is supposed - # to operate on `ListType` - assert set(input_expr.subset.size()[0:ndims]) == {1} - reduce_axes = [ndims] - else: - reduce_axes = None - - reduce_wcr = "lambda x, y: " + gtir_python_codegen.format_builtin(op_name, "x", "y") - reduce_node = self.state.add_reduce(reduce_wcr, reduce_axes, reduce_init.value) - - if isinstance(input_expr, MemletExpr): - self._add_input_data_edge( - input_expr.node, - input_expr.subset, - reduce_node, - ) - else: - self.state.add_nedge( - input_expr.node, - reduce_node, - dace.Memlet.from_array(input_expr.node.data, input_desc), - ) - - temp_name = self.sdfg.temp_data_name() - self.sdfg.add_scalar(temp_name, dtype, transient=True) - temp_node = self.state.add_access(temp_name) - - self.state.add_nedge( - reduce_node, - temp_node, - dace.Memlet(data=temp_name, subset="0"), - ) - assert isinstance(node.type, ts.ScalarType) - return DataExpr(temp_node, node.type) - - def _split_shift_args( - self, args: list[gtir.Expr] - ) -> tuple[tuple[gtir.Expr, gtir.Expr], Optional[list[gtir.Expr]]]: - """ - Splits the arguments to `shift` builtin function as pairs, each pair containing - the offset provider and the offset expression in one dimension. 
- """ - nargs = len(args) - assert nargs >= 2 and nargs % 2 == 0 - return (args[-2], args[-1]), args[: nargs - 2] if nargs > 2 else None - - def _visit_shift_multidim( - self, iterator: gtir.Expr, shift_args: list[gtir.Expr] - ) -> tuple[gtir.Expr, gtir.Expr, IteratorExpr]: - """Transforms a multi-dimensional shift into recursive shift calls, each in a single dimension.""" - (offset_provider_arg, offset_value_arg), tail = self._split_shift_args(shift_args) - if tail: - node = gtir.FunCall( - fun=gtir.FunCall(fun=gtir.SymRef(id="shift"), args=tail), - args=[iterator], - ) - it = self.visit(node) - else: - it = self.visit(iterator) - - assert isinstance(it, IteratorExpr) - return offset_provider_arg, offset_value_arg, it - - def _make_cartesian_shift( - self, it: IteratorExpr, offset_dim: gtx_common.Dimension, offset_expr: ValueExpr - ) -> IteratorExpr: - """Implements cartesian shift along one dimension.""" - assert offset_dim in it.dimensions - new_index: SymbolExpr | DataExpr - assert offset_dim in it.indices - index_expr = it.indices[offset_dim] - if isinstance(index_expr, SymbolExpr) and isinstance(offset_expr, SymbolExpr): - # purely symbolic expression which can be interpreted at compile time - new_index = SymbolExpr( - dace.symbolic.pystr_to_symbolic(index_expr.value) + offset_expr.value, - index_expr.dtype, - ) - else: - # the offset needs to be calculated by means of a tasklet (i.e. 
dynamic offset) - new_index_connector = "shifted_index" - if isinstance(index_expr, SymbolExpr): - dynamic_offset_tasklet = self._add_tasklet( - "dynamic_offset", - {"offset"}, - {new_index_connector}, - f"{new_index_connector} = {index_expr.value} + offset", - ) - elif isinstance(offset_expr, SymbolExpr): - dynamic_offset_tasklet = self._add_tasklet( - "dynamic_offset", - {"index"}, - {new_index_connector}, - f"{new_index_connector} = index + {offset_expr}", - ) - else: - dynamic_offset_tasklet = self._add_tasklet( - "dynamic_offset", - {"index", "offset"}, - {new_index_connector}, - f"{new_index_connector} = index + offset", - ) - for input_expr, input_connector in [(index_expr, "index"), (offset_expr, "offset")]: - if isinstance(input_expr, MemletExpr): - self._add_input_data_edge( - input_expr.node, - input_expr.subset, - dynamic_offset_tasklet, - input_connector, - ) - elif isinstance(input_expr, DataExpr): - self._add_edge( - input_expr.node, - None, - dynamic_offset_tasklet, - input_connector, - dace.Memlet(data=input_expr.node.data, subset="0"), - ) - - if isinstance(index_expr, SymbolExpr): - dtype = index_expr.dtype - else: - dtype = index_expr.node.desc(self.sdfg).dtype - - new_index = self._construct_tasklet_result( - dtype, dynamic_offset_tasklet, new_index_connector - ) - - # a new iterator with a shifted index along one dimension - return IteratorExpr( - it.field, - it.dimensions, - {dim: (new_index if dim == offset_dim else index) for dim, index in it.indices.items()}, - ) - - def _make_dynamic_neighbor_offset( - self, - offset_expr: MemletExpr | DataExpr, - offset_table_node: dace.nodes.AccessNode, - origin_index: SymbolExpr, - ) -> DataExpr: - """ - Implements access to neighbor connectivity table by means of a tasklet node. - - It requires a dynamic offset value, either obtained from a field/scalar argument (`MemletExpr`) - or computed by another tasklet (`DataExpr`). 
- """ - new_index_connector = "neighbor_index" - tasklet_node = self._add_tasklet( - "dynamic_neighbor_offset", - {"table", "offset"}, - {new_index_connector}, - f"{new_index_connector} = table[{origin_index.value}, offset]", - ) - self._add_input_data_edge( - offset_table_node, - sbs.Range.from_array(offset_table_node.desc(self.sdfg)), - tasklet_node, - "table", - ) - if isinstance(offset_expr, MemletExpr): - self._add_input_data_edge( - offset_expr.node, - offset_expr.subset, - tasklet_node, - "offset", - ) - else: - self._add_edge( - offset_expr.node, - None, - tasklet_node, - "offset", - dace.Memlet(data=offset_expr.node.data, subset="0"), - ) - - dtype = offset_table_node.desc(self.sdfg).dtype - return self._construct_tasklet_result(dtype, tasklet_node, new_index_connector) - - def _make_unstructured_shift( - self, - it: IteratorExpr, - connectivity: gtx_common.Connectivity, - offset_table_node: dace.nodes.AccessNode, - offset_expr: ValueExpr, - ) -> IteratorExpr: - """Implements shift in unstructured domain by means of a neighbor table.""" - assert connectivity.neighbor_axis in it.dimensions - neighbor_dim = connectivity.neighbor_axis - assert neighbor_dim not in it.indices - - origin_dim = connectivity.origin_axis - assert origin_dim in it.indices - origin_index = it.indices[origin_dim] - assert isinstance(origin_index, SymbolExpr) - - shifted_indices = {dim: idx for dim, idx in it.indices.items() if dim != origin_dim} - if isinstance(offset_expr, SymbolExpr): - # use memlet to retrieve the neighbor index - shifted_indices[neighbor_dim] = MemletExpr( - offset_table_node, - sbs.Indices([origin_index.value, offset_expr.value]), - ) - else: - # dynamic offset: we cannot use a memlet to retrieve the offset value, use a tasklet node - shifted_indices[neighbor_dim] = self._make_dynamic_neighbor_offset( - offset_expr, offset_table_node, origin_index - ) - - return IteratorExpr(it.field, it.dimensions, shifted_indices) - - def _visit_shift(self, node: gtir.FunCall) 
-> IteratorExpr: - # convert builtin-index type to dace type - IndexDType: Final = dace_utils.as_dace_type( - ts.ScalarType(kind=getattr(ts.ScalarKind, gtir.INTEGER_INDEX_BUILTIN.upper())) - ) - - assert isinstance(node.fun, gtir.FunCall) - # the iterator to be shifted is the node argument, while the shift arguments - # are provided by the nested function call; the shift arguments consist of - # the offset provider and the offset value in each dimension to be shifted - offset_provider_arg, offset_value_arg, it = self._visit_shift_multidim( - node.args[0], node.fun.args - ) - - # first argument of the shift node is the offset provider - assert isinstance(offset_provider_arg, gtir.OffsetLiteral) - offset = offset_provider_arg.value - assert isinstance(offset, str) - offset_provider = self.subgraph_builder.get_offset_provider(offset) - # second argument should be the offset value, which could be a symbolic expression or a dynamic offset - offset_expr = ( - SymbolExpr(offset_value_arg.value, IndexDType) - if isinstance(offset_value_arg, gtir.OffsetLiteral) - else self.visit(offset_value_arg) - ) - - if isinstance(offset_provider, gtx_common.Dimension): - return self._make_cartesian_shift(it, offset_provider, offset_expr) - else: - # initially, the storage for the connectivity tables is created as transient; - # when the tables are used, the storage is changed to non-transient, - # so the corresponding arrays are supposed to be allocated by the SDFG caller - offset_table = dace_utils.connectivity_identifier(offset) - self.sdfg.arrays[offset_table].transient = False - offset_table_node = self.state.add_access(offset_table) - - return self._make_unstructured_shift( - it, offset_provider, offset_table_node, offset_expr - ) - - def _visit_generic_builtin(self, node: gtir.FunCall) -> DataExpr: - """ - Generic handler called by `visit_FunCall()` when it encounters - a builtin function that does not match any other specific handler. 
- """ - assert isinstance(node.type, ts.ScalarType) - dtype = dace_utils.as_dace_type(node.type) - - node_internals = [] - node_connections: dict[str, MemletExpr | DataExpr] = {} - for i, arg in enumerate(node.args): - arg_expr = self.visit(arg) - if isinstance(arg_expr, MemletExpr | DataExpr): - # the argument value is the result of a tasklet node or direct field access - connector = f"__inp_{i}" - node_connections[connector] = arg_expr - node_internals.append(connector) - else: - assert isinstance(arg_expr, SymbolExpr) - # use the argument value without adding any connector - node_internals.append(arg_expr.value) - - assert isinstance(node.fun, gtir.SymRef) - builtin_name = str(node.fun.id) - # use tasklet connectors as expression arguments - code = gtir_python_codegen.format_builtin(builtin_name, *node_internals) - - out_connector = "result" - tasklet_node = self._add_tasklet( - builtin_name, - set(node_connections.keys()), - {out_connector}, - "{} = {}".format(out_connector, code), - ) - - for connector, arg_expr in node_connections.items(): - if isinstance(arg_expr, DataExpr): - self._add_edge( - arg_expr.node, - None, - tasklet_node, - connector, - dace.Memlet(data=arg_expr.node.data, subset="0"), - ) - else: - self._add_input_data_edge( - arg_expr.node, - arg_expr.subset, - tasklet_node, - connector, - ) - - return self._construct_tasklet_result(dtype, tasklet_node, "result") - - def visit_FunCall(self, node: gtir.FunCall) -> IteratorExpr | ValueExpr: - if cpm.is_call_to(node, "deref"): - return self._visit_deref(node) - - elif cpm.is_call_to(node, "neighbors"): - return self._visit_neighbors(node) - - elif cpm.is_applied_reduce(node): - return self._visit_reduce(node) - - elif cpm.is_applied_shift(node): - return self._visit_shift(node) - - elif isinstance(node.fun, gtir.SymRef): - return self._visit_generic_builtin(node) - - else: - raise NotImplementedError(f"Invalid 'FunCall' node: {node}.") - - def visit_Lambda( - self, node: gtir.Lambda, args: 
list[IteratorExpr | MemletExpr | SymbolExpr] - ) -> tuple[list[DataflowInputEdge], DataflowOutputEdge]: - for p, arg in zip(node.params, args, strict=True): - self.symbol_map[str(p.id)] = arg - output_expr: ValueExpr = self.visit(node.expr) - if isinstance(output_expr, DataExpr): - return self.input_edges, DataflowOutputEdge(self.state, output_expr) - - if isinstance(output_expr, MemletExpr): - # special case where the field operator is simply copying data from source to destination node - output_dtype = output_expr.node.desc(self.sdfg).dtype - tasklet_node = self._add_tasklet("copy", {"__inp"}, {"__out"}, "__out = __inp") - self._add_input_data_edge( - output_expr.node, - output_expr.subset, - tasklet_node, - "__inp", - ) - else: - assert isinstance(output_expr, SymbolExpr) - # even simpler case, where a constant value is written to destination node - output_dtype = output_expr.dtype - tasklet_node = self._add_tasklet("write", {}, {"__out"}, f"__out = {output_expr.value}") - - output_expr = self._construct_tasklet_result(output_dtype, tasklet_node, "__out") - return self.input_edges, DataflowOutputEdge(self.state, output_expr) - - def visit_Literal(self, node: gtir.Literal) -> SymbolExpr: - dtype = dace_utils.as_dace_type(node.type) - return SymbolExpr(node.value, dtype) - - def visit_SymRef(self, node: gtir.SymRef) -> IteratorExpr | MemletExpr | SymbolExpr: - param = str(node.id) - if param in self.symbol_map: - return self.symbol_map[param] - # if not in the lambda symbol map, this must be a symref to a builtin function - assert param in gtir_python_codegen.MATH_BUILTINS_MAPPING - return SymbolExpr(param, dace.string) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py deleted file mode 100644 index 7d878dde99..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ /dev/null @@ -1,666 +0,0 @@ -# GT4Py - GridTools Framework -# 
-# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -""" -Contains visitors to lower GTIR to DaCe SDFG. - -Note: this module covers the fieldview flavour of GTIR. -""" - -from __future__ import annotations - -import abc -import dataclasses -import itertools -from typing import Any, Dict, Iterable, List, Optional, Protocol, Sequence, Set, Tuple, Union - -import dace - -from gt4py import eve -from gt4py.eve import concepts -from gt4py.next import common as gtx_common, utils as gtx_utils -from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.iterator.type_system import inference as gtir_type_inference -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.program_processors.runners.dace_fieldview import ( - gtir_builtin_translators, - gtir_dataflow, - transformations as gtx_transformations, - utility as dace_gtir_utils, -) -from gt4py.next.type_system import type_specifications as ts, type_translation as tt - - -class DataflowBuilder(Protocol): - """Visitor interface to build a dataflow subgraph.""" - - @abc.abstractmethod - def get_offset_provider(self, offset: str) -> gtx_common.OffsetProviderElem: ... - - @abc.abstractmethod - def unique_map_name(self, name: str) -> str: ... - - @abc.abstractmethod - def unique_tasklet_name(self, name: str) -> str: ... 
- - def add_map( - self, - name: str, - state: dace.SDFGState, - ndrange: Union[ - Dict[str, Union[str, dace.subsets.Subset]], - List[Tuple[str, Union[str, dace.subsets.Subset]]], - ], - **kwargs: Any, - ) -> Tuple[dace.nodes.MapEntry, dace.nodes.MapExit]: - """Wrapper of `dace.SDFGState.add_map` that assigns unique name.""" - unique_name = self.unique_map_name(name) - return state.add_map(unique_name, ndrange, **kwargs) - - def add_tasklet( - self, - name: str, - state: dace.SDFGState, - inputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], - outputs: Union[Set[str], Dict[str, dace.dtypes.typeclass]], - code: str, - **kwargs: Any, - ) -> dace.nodes.Tasklet: - """Wrapper of `dace.SDFGState.add_tasklet` that assigns unique name.""" - unique_name = self.unique_tasklet_name(name) - return state.add_tasklet(unique_name, inputs, outputs, code, **kwargs) - - -class SDFGBuilder(DataflowBuilder, Protocol): - """Visitor interface available to GTIR-primitive translators.""" - - @abc.abstractmethod - def get_symbol_type(self, symbol_name: str) -> ts.DataType: - """Retrieve the GT4Py type of a symbol used in the program.""" - ... - - @abc.abstractmethod - def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: - """Visit a node of the GT4Py IR.""" - ... - - -@dataclasses.dataclass(frozen=True) -class GTIRToSDFG(eve.NodeVisitor, SDFGBuilder): - """Provides translation capability from a GTIR program to a DaCe SDFG. - - This class is responsible for translation of `ir.Program`, that is the top level representation - of a GT4Py program as a sequence of `ir.Stmt` (aka statement) expressions. - Each statement is translated to a taskgraph inside a separate state. Statement states are chained - one after the other: concurrency between states should be extracted by means of SDFG analysis. 
- The translator will extend the SDFG while preserving the property of single exit state: - branching is allowed within the context of one statement, but in that case the statement should - terminate with a join state; the join state will represent the head state for next statement, - from where to continue building the SDFG. - """ - - offset_provider: gtx_common.OffsetProvider - global_symbols: dict[str, ts.DataType] = dataclasses.field(default_factory=lambda: {}) - map_uids: eve.utils.UIDGenerator = dataclasses.field( - init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="map") - ) - tesklet_uids: eve.utils.UIDGenerator = dataclasses.field( - init=False, repr=False, default_factory=lambda: eve.utils.UIDGenerator(prefix="tlet") - ) - - def get_offset_provider(self, offset: str) -> gtx_common.OffsetProviderElem: - return self.offset_provider[offset] - - def get_symbol_type(self, symbol_name: str) -> ts.DataType: - return self.global_symbols[symbol_name] - - def unique_map_name(self, name: str) -> str: - return f"{self.map_uids.sequential_id()}_{name}" - - def unique_tasklet_name(self, name: str) -> str: - return f"{self.tesklet_uids.sequential_id()}_{name}" - - def _make_array_shape_and_strides( - self, name: str, dims: Sequence[gtx_common.Dimension] - ) -> tuple[list[dace.symbol], list[dace.symbol]]: - """ - Parse field dimensions and allocate symbols for array shape and strides. - - For local dimensions, the size is known at compile-time and therefore - the corresponding array shape dimension is set to an integer literal value. - - Returns: - Two lists of symbols, one for the shape and the other for the strides of the array. 
- """ - dtype = dace.int32 - neighbor_tables = dace_utils.filter_connectivities(self.offset_provider) - shape = [ - ( - neighbor_tables[dim.value].max_neighbors - if dim.kind == gtx_common.DimensionKind.LOCAL - else dace.symbol(dace_utils.field_size_symbol_name(name, i), dtype) - ) - for i, dim in enumerate(dims) - ] - strides = [ - dace.symbol(dace_utils.field_stride_symbol_name(name, i), dtype) - for i in range(len(dims)) - ] - return shape, strides - - def _add_storage( - self, - sdfg: dace.SDFG, - name: str, - symbol_type: ts.DataType, - transient: bool = True, - is_tuple_member: bool = False, - ) -> list[tuple[str, ts.DataType]]: - """ - Add storage for data containers used in the SDFG. For fields, it allocates dace arrays, - while scalars are stored as SDFG symbols. - - The fields used as temporary arrays, when `transient = True`, are allocated and exist - only within the SDFG; when `transient = False`, the fields have to be allocated outside - and have to be passed as array arguments to the SDFG. - - Returns: - List of data containers or symbols allocated as storage. This is a list, not a single value, - because in case of tuples we flat the tuple fields (eventually nested) and allocate storage - for each tuple element. - """ - if isinstance(symbol_type, ts.TupleType): - tuple_fields = [] - for tname, tsymbol_type in dace_gtir_utils.get_tuple_fields( - name, symbol_type, flatten=True - ): - tuple_fields.extend( - self._add_storage(sdfg, tname, tsymbol_type, transient, is_tuple_member=True) - ) - return tuple_fields - - elif isinstance(symbol_type, ts.FieldType): - dtype = dace_utils.as_dace_type(symbol_type.dtype) - # use symbolic shape, which allows to invoke the program with fields of different size; - # and symbolic strides, which enables decoupling the memory layout from generated code. 
- sym_shape, sym_strides = self._make_array_shape_and_strides(name, symbol_type.dims) - sdfg.add_array(name, sym_shape, dtype, strides=sym_strides, transient=transient) - - return [(name, symbol_type)] - - elif isinstance(symbol_type, ts.ScalarType): - dtype = dace_utils.as_dace_type(symbol_type) - # Scalar arguments passed to the program are represented as symbols in DaCe SDFG; - # the exception are members of tuple arguments, that are represented as scalar containers. - # The field size is sometimes passed as scalar argument to the program, so we have to - # check if the shape symbol was already allocated by `_make_array_shape_and_strides`. - # We assume that the scalar argument for field size always follows the field argument. - if is_tuple_member: - sdfg.add_scalar(name, dtype, transient=transient) - elif name in sdfg.symbols: - assert sdfg.symbols[name].dtype == dtype - else: - sdfg.add_symbol(name, dtype) - - return [(name, symbol_type)] - - raise RuntimeError(f"Data type '{type(symbol_type)}' not supported.") - - def _add_storage_for_temporary(self, temp_decl: gtir.Temporary) -> dict[str, str]: - """ - Add temporary storage (aka transient) for data containers used as GTIR temporaries. - - Assume all temporaries to be fields, therefore represented as dace arrays. - """ - raise NotImplementedError("Temporaries not supported yet by GTIR DaCe backend.") - - def _visit_expression( - self, node: gtir.Expr, sdfg: dace.SDFG, head_state: dace.SDFGState, use_temp: bool = True - ) -> list[gtir_builtin_translators.Field]: - """ - Specialized visit method for fieldview expressions. - - This method represents the entry point to visit `ir.Stmt` expressions. - As such, it must preserve the property of single exit state in the SDFG. - - Returns: - A list of array nodes containing the result fields. - - TODO: Do we need to return the GT4Py `FieldType`/`ScalarType`? 
It is needed - in case the transient arrays containing the expression result are not guaranteed - to have the same memory layout as the target array. - """ - result = self.visit(node, sdfg=sdfg, head_state=head_state, reduce_identity=None) - - # sanity check: each statement should preserve the property of single exit state (aka head state), - # i.e. eventually only introduce internal branches, and keep the same head state - sink_states = sdfg.sink_nodes() - assert len(sink_states) == 1 - assert sink_states[0] == head_state - - def make_temps(field: gtir_builtin_translators.Field) -> gtir_builtin_translators.Field: - desc = sdfg.arrays[field.data_node.data] - if desc.transient or not use_temp: - return field - else: - temp, _ = sdfg.add_temp_transient_like(desc) - temp_node = head_state.add_access(temp) - head_state.add_nedge( - field.data_node, temp_node, sdfg.make_array_memlet(field.data_node.data) - ) - return gtir_builtin_translators.Field(temp_node, field.data_type) - - temp_result = gtx_utils.tree_map(make_temps)(result) - return list(gtx_utils.flatten_nested_tuple((temp_result,))) - - def _add_sdfg_params(self, sdfg: dace.SDFG, node_params: Sequence[gtir.Sym]) -> list[str]: - """Helper function to add storage for node parameters and connectivity tables.""" - # add non-transient arrays and/or SDFG symbols for the program arguments - sdfg_args = [] - for param in node_params: - pname = str(param.id) - assert isinstance(param.type, (ts.DataType)) - sdfg_args += self._add_storage(sdfg, pname, param.type, transient=False) - self.global_symbols[pname] = param.type - - # add SDFG storage for connectivity tables - for offset, offset_provider in dace_utils.filter_connectivities( - self.offset_provider - ).items(): - scalar_kind = tt.get_scalar_kind(offset_provider.index_type) - local_dim = gtx_common.Dimension(offset, kind=gtx_common.DimensionKind.LOCAL) - type_ = ts.FieldType( - [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind) - ) - # We store all 
connectivity tables as transient arrays here; later, while building - # the field operator expressions, we change to non-transient (i.e. allocated externally) - # the tables that are actually used. This way, we avoid adding SDFG arguments for - # the connectivity tables that are not used. The remaining unused transient arrays - # are removed by the dace simplify pass. - self._add_storage(sdfg, dace_utils.connectivity_identifier(offset), type_) - - # the list of all sdfg arguments (aka non-transient arrays) which include tuple-element fields - return [arg_name for arg_name, _ in sdfg_args] - - def visit_Program(self, node: gtir.Program) -> dace.SDFG: - """Translates `ir.Program` to `dace.SDFG`. - - First, it will allocate field and scalar storage for global data. The storage - represents global data, available everywhere in the SDFG, either containing - external data (aka non-transient data) or temporary data (aka transient data). - The temporary data is global, therefore available everywhere in the SDFG - but not outside. Then, all statements are translated, one after the other. 
- """ - if node.function_definitions: - raise NotImplementedError("Functions expected to be inlined as lambda calls.") - - sdfg = dace.SDFG(node.id) - sdfg.debuginfo = dace_utils.debug_info(node, default=sdfg.debuginfo) - entry_state = sdfg.add_state("program_entry", is_start_block=True) - - # declarations of temporaries result in transient array definitions in the SDFG - if node.declarations: - temp_symbols: dict[str, str] = {} - for decl in node.declarations: - temp_symbols |= self._add_storage_for_temporary(decl) - - # define symbols for shape and offsets of temporary arrays as interstate edge symbols - head_state = sdfg.add_state_after(entry_state, "init_temps", assignments=temp_symbols) - else: - head_state = entry_state - - sdfg_arg_names = self._add_sdfg_params(sdfg, node.params) - - # visit one statement at a time and expand the SDFG from the current head state - for i, stmt in enumerate(node.body): - # include `debuginfo` only for `ir.Program` and `ir.Stmt` nodes: finer granularity would be too messy - head_state = sdfg.add_state_after(head_state, f"stmt_{i}") - head_state._debuginfo = dace_utils.debug_info(stmt, default=sdfg.debuginfo) - head_state = self.visit(stmt, sdfg=sdfg, state=head_state) - - # Create the call signature for the SDFG. - # Only the arguments required by the GT4Py program, i.e. `node.params`, are added - # as positional arguments. The implicit arguments, such as the offset providers or - # the arguments created by the translation process, must be passed as keyword arguments. - sdfg.arg_names = sdfg_arg_names - - sdfg.validate() - return sdfg - - def visit_SetAt( - self, stmt: gtir.SetAt, sdfg: dace.SDFG, state: dace.SDFGState - ) -> dace.SDFGState: - """Visits a `SetAt` statement expression and writes the local result to some external storage. - - Each statement expression results in some sort of dataflow gragh writing to temporary storage. 
- The translation of `SetAt` ensures that the result is written back to the target external storage. - - Returns: - The SDFG head state, eventually updated if the target write requires a new state. - """ - - temp_fields = self._visit_expression(stmt.expr, sdfg, state) - - # the target expression could be a `SymRef` to an output node or a `make_tuple` expression - # in case the statement returns more than one field - target_fields = self._visit_expression(stmt.target, sdfg, state, use_temp=False) - - # convert domain expression to dictionary to ease access to dimension boundaries - domain = dace_gtir_utils.get_domain_ranges(stmt.domain) - - expr_input_args = { - sym_id - for sym in eve.walk_values(stmt.expr).if_isinstance(gtir.SymRef) - if (sym_id := str(sym.id)) in sdfg.arrays - } - state_input_data = { - node.data - for node in state.data_nodes() - if node.data in expr_input_args and state.degree(node) != 0 - } - - target_state: Optional[dace.SDFGState] = None - for temp, target in zip(temp_fields, target_fields, strict=True): - target_desc = sdfg.arrays[target.data_node.data] - assert not target_desc.transient - - if isinstance(target.data_type, ts.FieldType): - subset = ",".join( - f"{domain[dim][0]}:{domain[dim][1]}" for dim in target.data_type.dims - ) - else: - assert len(domain) == 0 - subset = "0" - - if target.data_node.data in state_input_data: - # if inout argument, write the result in separate next state - # this is needed to avoid undefined behavior for expressions like: X, Y = X + 1, X - if not target_state: - target_state = sdfg.add_state_after(state, f"post_{state.label}") - # create new access nodes in the target state - target_state.add_nedge( - target_state.add_access(temp.data_node.data), - target_state.add_access(target.data_node.data), - dace.Memlet(data=target.data_node.data, subset=subset, other_subset=subset), - ) - # remove isolated access node - state.remove_node(target.data_node) - else: - state.add_nedge( - temp.data_node, - 
target.data_node, - dace.Memlet(data=target.data_node.data, subset=subset, other_subset=subset), - ) - - return target_state or state - - def visit_FunCall( - self, - node: gtir.FunCall, - sdfg: dace.SDFG, - head_state: dace.SDFGState, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], - ) -> gtir_builtin_translators.FieldopResult: - # use specialized dataflow builder classes for each builtin function - if cpm.is_call_to(node, "if_"): - return gtir_builtin_translators.translate_if( - node, sdfg, head_state, self, reduce_identity - ) - elif cpm.is_call_to(node, "make_tuple"): - return gtir_builtin_translators.translate_make_tuple( - node, sdfg, head_state, self, reduce_identity - ) - elif cpm.is_call_to(node, "tuple_get"): - return gtir_builtin_translators.translate_tuple_get( - node, sdfg, head_state, self, reduce_identity - ) - elif cpm.is_applied_as_fieldop(node): - return gtir_builtin_translators.translate_as_field_op( - node, sdfg, head_state, self, reduce_identity - ) - elif isinstance(node.fun, gtir.Lambda): - lambda_args = [ - self.visit( - arg, - sdfg=sdfg, - head_state=head_state, - reduce_identity=reduce_identity, - ) - for arg in node.args - ] - - return self.visit( - node.fun, - sdfg=sdfg, - head_state=head_state, - reduce_identity=reduce_identity, - args=lambda_args, - ) - elif isinstance(node.type, ts.ScalarType): - return gtir_builtin_translators.translate_scalar_expr( - node, sdfg, head_state, self, reduce_identity - ) - else: - raise NotImplementedError(f"Unexpected 'FunCall' expression ({node}).") - - def visit_Lambda( - self, - node: gtir.Lambda, - sdfg: dace.SDFG, - head_state: dace.SDFGState, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], - args: list[gtir_builtin_translators.FieldopResult], - ) -> gtir_builtin_translators.FieldopResult: - """ - Translates a `Lambda` node to a nested SDFG in the current state. - - All arguments to lambda functions are fields (i.e. 
`as_fieldop`, field or scalar `gtir.SymRef`, - nested let-lambdas thereof). The reason for creating a nested SDFG is to define local symbols - (the lambda paremeters) that map to parent fields, either program arguments or temporary fields. - - If the lambda has a parameter whose name is already present in `GTIRToSDFG.global_symbols`, - i.e. a lambda parameter with the same name as a symbol in scope, the parameter will shadow - the previous symbol during traversal of the lambda expression. - """ - lambda_args_mapping = [ - (str(param.id), arg) for param, arg in zip(node.params, args, strict=True) - ] - - # inherit symbols from parent scope but eventually override with local symbols - lambda_symbols = self.global_symbols | { - pname: dace_gtir_utils.get_tuple_type(arg) if isinstance(arg, tuple) else arg.data_type - for pname, arg in lambda_args_mapping - } - - # lower let-statement lambda node as a nested SDFG - lambda_translator = GTIRToSDFG(self.offset_provider, lambda_symbols) - nsdfg = dace.SDFG(f"{sdfg.label}_lambda") - nstate = nsdfg.add_state("lambda") - - # add sdfg storage for the symbols that need to be passed as input parameters - lambda_translator._add_sdfg_params( - nsdfg, - node_params=[ - gtir.Sym(id=p_name, type=p_type) for p_name, p_type in lambda_symbols.items() - ], - ) - - lambda_result = lambda_translator.visit( - node.expr, - sdfg=nsdfg, - head_state=nstate, - reduce_identity=reduce_identity, - ) - - def _flatten_tuples( - name: str, - arg: gtir_builtin_translators.FieldopResult, - ) -> list[tuple[str, gtir_builtin_translators.Field]]: - if isinstance(arg, tuple): - tuple_type = dace_gtir_utils.get_tuple_type(arg) - tuple_field_names = [ - arg_name for arg_name, _ in dace_gtir_utils.get_tuple_fields(name, tuple_type) - ] - tuple_args = zip(tuple_field_names, arg, strict=True) - return list( - itertools.chain(*[_flatten_tuples(fname, farg) for fname, farg in tuple_args]) - ) - else: - return [(name, arg)] - - # Process lambda inputs - # - 
lambda_arg_nodes = dict( - itertools.chain(*[_flatten_tuples(pname, arg) for pname, arg in lambda_args_mapping]) - ) - connectivity_arrays = { - dace_utils.connectivity_identifier(offset) - for offset in dace_utils.filter_connectivities(self.offset_provider) - } - - input_memlets = {} - nsdfg_symbols_mapping: dict[str, dace.symbolic.SymExpr] = {} - for nsdfg_dataname, nsdfg_datadesc in nsdfg.arrays.items(): - if nsdfg_datadesc.transient: - continue - datadesc: Optional[dace.dtypes.Data] = None - if nsdfg_dataname in lambda_arg_nodes: - src_node = lambda_arg_nodes[nsdfg_dataname].data_node - dataname = src_node.data - datadesc = src_node.desc(sdfg) - else: - dataname = nsdfg_dataname - datadesc = sdfg.arrays[nsdfg_dataname] - - # ensure that connectivity tables are non-transient arrays in parent SDFG - if dataname in connectivity_arrays: - datadesc.transient = False - - input_memlets[nsdfg_dataname] = sdfg.make_array_memlet(dataname) - - nsdfg_symbols_mapping |= { - str(nested_symbol): parent_symbol - for nested_symbol, parent_symbol in zip( - [*nsdfg_datadesc.shape, *nsdfg_datadesc.strides], - [*datadesc.shape, *datadesc.strides], - strict=True, - ) - if isinstance(nested_symbol, dace.symbol) - } - - # Process lambda outputs - # - lambda_output_nodes: Iterable[gtir_builtin_translators.Field] = ( - gtx_utils.flatten_nested_tuple(lambda_result) - ) - # sanity check on isolated nodes - assert all( - nstate.degree(x.data_node) == 0 - for x in lambda_output_nodes - if x.data_node.data in input_memlets - ) - # keep only non-isolated output nodes - lambda_outputs = { - x.data_node.data for x in lambda_output_nodes if x.data_node.data not in input_memlets - } - - if lambda_outputs: - nsdfg_node = head_state.add_nested_sdfg( - nsdfg, - parent=sdfg, - inputs=set(input_memlets.keys()), - outputs=lambda_outputs, - symbol_mapping=nsdfg_symbols_mapping, - debuginfo=dace_utils.debug_info(node, default=sdfg.debuginfo), - ) - - for connector, memlet in input_memlets.items(): - if 
connector in lambda_arg_nodes: - src_node = lambda_arg_nodes[connector].data_node - else: - src_node = head_state.add_access(memlet.data) - - head_state.add_edge(src_node, None, nsdfg_node, connector, memlet) - - def make_temps( - x: gtir_builtin_translators.Field, - ) -> gtir_builtin_translators.Field: - if x.data_node.data in lambda_outputs: - connector = x.data_node.data - desc = x.data_node.desc(nsdfg) - # make lambda result non-transient and map it to external temporary - desc.transient = False - # isolated access node will make validation fail - if nstate.degree(x.data_node) == 0: - nstate.remove_node(x.data_node) - temp, _ = sdfg.add_temp_transient_like(desc) - dst_node = head_state.add_access(temp) - head_state.add_edge( - nsdfg_node, connector, dst_node, None, sdfg.make_array_memlet(temp) - ) - return gtir_builtin_translators.Field(dst_node, x.data_type) - elif x.data_node.data in lambda_arg_nodes: - nstate.remove_node(x.data_node) - return lambda_arg_nodes[x.data_node.data] - else: - nstate.remove_node(x.data_node) - data_node = head_state.add_access(x.data_node.data) - return gtir_builtin_translators.Field(data_node, x.data_type) - - return gtx_utils.tree_map(make_temps)(lambda_result) - - def visit_Literal( - self, - node: gtir.Literal, - sdfg: dace.SDFG, - head_state: dace.SDFGState, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], - ) -> gtir_builtin_translators.FieldopResult: - return gtir_builtin_translators.translate_literal( - node, sdfg, head_state, self, reduce_identity=None - ) - - def visit_SymRef( - self, - node: gtir.SymRef, - sdfg: dace.SDFG, - head_state: dace.SDFGState, - reduce_identity: Optional[gtir_dataflow.SymbolExpr], - ) -> gtir_builtin_translators.FieldopResult: - return gtir_builtin_translators.translate_symbol_ref( - node, sdfg, head_state, self, reduce_identity=None - ) - - -def build_sdfg_from_gtir( - ir: gtir.Program, - offset_provider: gtx_common.OffsetProvider, -) -> dace.SDFG: - """ - Receives a GTIR program and 
lowers it to a DaCe SDFG. - - The lowering to SDFG requires that the program node is type-annotated, therefore this function - runs type ineference as first step. - As a final step, it runs the `simplify` pass to ensure that the SDFG is in the DaCe canonical form. - - Arguments: - ir: The GTIR program node to be lowered to SDFG - offset_provider: The definitions of offset providers used by the program node - - Returns: - An SDFG in the DaCe canonical form (simplified) - """ - ir = gtir_type_inference.infer(ir, offset_provider=offset_provider) - ir = dace_gtir_utils.patch_gtir(ir) - sdfg_genenerator = GTIRToSDFG(offset_provider) - sdfg = sdfg_genenerator.visit(ir) - assert isinstance(sdfg, dace.SDFG) - - gtx_transformations.gt_simplify(sdfg) - return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py deleted file mode 100644 index 8852dd6d2d..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -"""Transformation and optimization pipeline for the DaCe backend in GT4Py. - -Please also see [ADR0018](https://github.com/GridTools/gt4py/tree/main/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md) -that explains the general structure and requirements on the SDFGs. 
-""" - -from .auto_opt import ( - GT_SIMPLIFY_DEFAULT_SKIP_SET, - gt_auto_optimize, - gt_inline_nested_sdfg, - gt_set_iteration_order, - gt_simplify, -) -from .gpu_utils import GPUSetBlockSize, gt_gpu_transformation, gt_set_gpu_blocksize -from .loop_blocking import LoopBlocking -from .map_orderer import MapIterationOrder -from .map_promoter import SerialMapPromoter -from .map_serial_fusion import SerialMapFusion - - -__all__ = [ - "GT_SIMPLIFY_DEFAULT_SKIP_SET", - "GPUSetBlockSize", - "LoopBlocking", - "MapIterationOrder", - "SerialMapFusion", - "SerialMapPromoter", - "SerialMapPromoterGPU", - "gt_auto_optimize", - "gt_gpu_transformation", - "gt_inline_nested_sdfg", - "gt_set_iteration_order", - "gt_set_gpu_blocksize", - "gt_simplify", -] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py deleted file mode 100644 index 16c9600a3a..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ /dev/null @@ -1,394 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -"""Functions for turning an SDFG into a GPU SDFG.""" - -from __future__ import annotations - -import copy -from typing import Any, Optional, Sequence, Union - -import dace -from dace import properties as dace_properties, transformation as dace_transformation -from dace.sdfg import nodes as dace_nodes - -from gt4py.next.program_processors.runners.dace_fieldview import ( - transformations as gtx_transformations, -) - - -def gt_gpu_transformation( - sdfg: dace.SDFG, - try_removing_trivial_maps: bool = True, - use_gpu_storage: bool = True, - gpu_block_size: Optional[Sequence[int | str] | str] = None, - gpu_launch_bounds: Optional[int | str] = None, - gpu_launch_factor: Optional[int] = None, - validate: bool = True, - validate_all: bool = False, - **kwargs: Any, -) -> dace.SDFG: - """Transform an SDFG into a GPU SDFG. - - The transformation expects a rather optimized SDFG and turn it into an SDFG - capable of running on the GPU. - The function performs the following steps: - - If requested, modify the storage location of the non transient arrays such - that they reside in GPU memory. - - Call the normal GPU transform function followed by simplify. - - If requested try to remove trivial kernels. - - If specified, set the `gpu_block_size` parameters of the Maps to the given value. - - Args: - sdfg: The SDFG that should be processed. - try_removing_trivial_maps: Try to get rid of trivial maps by incorporating them. - use_gpu_storage: Assume that the non global memory is already on the GPU. This - will avoid the data copy from host to GPU memory. - gpu_block_size: The size of a thread block on the GPU. - gpu_launch_bounds: Use this value as `__launch_bounds__` for _all_ GPU Maps. - gpu_launch_factor: Use the number of threads times this value as `__launch_bounds__` - validate: Perform validation during the steps. - validate_all: Perform extensive validation. 
- - Notes: - The function might modify the order of the iteration variables of some - maps. - In addition it might fuse Maps together that should not be fused. To prevent - that you should set `try_removing_trivial_maps` to `False`. - - Todo: - - Solve the fusing problem. - - Currently only one block size for all maps is given, add more options. - """ - assert ( - len(kwargs) == 0 - ), f"gt_gpu_transformation(): found unknown arguments: {', '.join(arg for arg in kwargs.keys())}" - - # Turn all global arrays (which we identify as input) into GPU memory. - # This way the GPU transformation will not create this copying stuff. - if use_gpu_storage: - for desc in sdfg.arrays.values(): - if isinstance(desc, dace.data.Array) and not desc.transient: - desc.storage = dace.dtypes.StorageType.GPU_Global - - # Now turn it into a GPU SDFG - sdfg.apply_gpu_transformations( - validate=validate, - validate_all=validate_all, - simplify=False, - ) - # The documentation recommends to run simplify afterwards - gtx_transformations.gt_simplify(sdfg) - - if try_removing_trivial_maps: - # A Tasklet, outside of a Map, that writes into an array on GPU can not work - # `sdfg.appyl_gpu_transformations()` puts Map around it (if said Tasklet - # would write into a Scalar that then goes into a GPU Map, nothing would - # happen. So we might end up with lot of these trivial Maps, that results - # in a single kernel launch. To prevent this we will try to fuse them. - # NOTE: The current implementation has a bug, because promotion and fusion - # are two different steps. Because of this the function will implicitly - # fuse everything together it can find. - # TODO(phimuell): Fix the issue described above. 
- sdfg.apply_transformations_once_everywhere( - TrivialGPUMapPromoter(), - validate=False, - validate_all=False, - ) - sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion( - only_toplevel_maps=True, - ), - validate=validate, - validate_all=validate_all, - ) - - # Set the GPU block size if it is known. - if gpu_block_size is not None: - gt_set_gpu_blocksize( - sdfg=sdfg, - gpu_block_size=gpu_block_size, - gpu_launch_bounds=gpu_launch_bounds, - gpu_launch_factor=gpu_launch_factor, - ) - - return sdfg - - -def gt_set_gpu_blocksize( - sdfg: dace.SDFG, - gpu_block_size: Optional[Sequence[int | str] | str], - gpu_launch_bounds: Optional[int | str] = None, - gpu_launch_factor: Optional[int] = None, -) -> Any: - """Set the block size related properties of _all_ Maps. - - See `GPUSetBlockSize` for more information. - - Args: - sdfg: The SDFG to process. - gpu_block_size: The size of a thread block on the GPU. - launch_bounds: The value for the launch bound that should be used. - launch_factor: If no `launch_bounds` was given use the number of threads - in a block multiplied by this number. - """ - xform = GPUSetBlockSize( - block_size=gpu_block_size, - launch_bounds=gpu_launch_bounds, - launch_factor=gpu_launch_factor, - ) - return sdfg.apply_transformations_once_everywhere([xform]) - - -def _gpu_block_parser( - self: GPUSetBlockSize, - val: Any, -) -> None: - """Used by the setter of `GPUSetBlockSize.block_size`.""" - org_val = val - if isinstance(val, (tuple | list)): - pass - elif isinstance(val, str): - val = tuple(x.strip() for x in val.split(",")) - elif isinstance(val, int): - val = (val,) - else: - raise TypeError( - f"Does not know how to transform '{type(org_val).__name__}' into a proper GPU block size." 
- ) - if 0 < len(val) <= 3: - val = [*val, *([1] * (3 - len(val)))] - else: - raise ValueError(f"Can not parse block size '{org_val}': wrong length") - try: - val = [int(x) for x in val] - except ValueError: - raise TypeError( - f"Currently only block sizes convertible to int are supported, you passed '{val}'." - ) from None - self._block_size = val - - -def _gpu_block_getter( - self: "GPUSetBlockSize", -) -> tuple[int, int, int]: - """Used as getter in the `GPUSetBlockSize.block_size` property.""" - assert isinstance(self._block_size, (tuple, list)) and len(self._block_size) == 3 - assert all(isinstance(x, int) for x in self._block_size) - return tuple(self._block_size) - - -@dace_properties.make_properties -class GPUSetBlockSize(dace_transformation.SingleStateTransformation): - """Sets the GPU block size on GPU Maps. - - The transformation will apply to all Maps that have a GPU schedule, regardless - of their dimensionality. - - The `gpu_block_size` is either a sequence, of up to three integers or a string - of up to three numbers, separated by comma (`,`). - The first number is the size of the block in `x` direction, the second for the - `y` direction and the third for the `z` direction. Missing values will be filled - with `1`. - - Args: - block_size: The size of a thread block on the GPU. - launch_bounds: The value for the launch bound that should be used. - launch_factor: If no `launch_bounds` was given use the number of threads - in a block multiplied by this number. - - Todo: - Add the possibility to specify other bounds for 1, 2, or 3 dimensional maps. 
- """ - - block_size = dace_properties.Property( - dtype=None, - allow_none=False, - default=(32, 1, 1), - setter=_gpu_block_parser, - getter=_gpu_block_getter, - desc="Size of the block size a GPU Map should have.", - ) - - launch_bounds = dace_properties.Property( - dtype=str, - allow_none=True, - default=None, - desc="Set the launch bound property of the map.", - ) - - map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) - - def __init__( - self, - block_size: Sequence[int | str] | str | None = None, - launch_bounds: int | str | None = None, - launch_factor: int | None = None, - ) -> None: - super().__init__() - if block_size is not None: - self.block_size = block_size - - if launch_factor is not None: - assert launch_bounds is None - self.launch_bounds = str( - int(launch_factor) * self.block_size[0] * self.block_size[1] * self.block_size[2] - ) - elif launch_bounds is None: - self.launch_bounds = None - elif isinstance(launch_bounds, (str, int)): - self.launch_bounds = str(launch_bounds) - else: - raise TypeError( - f"Does not know how to parse '{launch_bounds}' as 'launch_bounds' argument." - ) - - @classmethod - def expressions(cls) -> Any: - return [dace.sdfg.utils.node_path_graph(cls.map_entry)] - - def can_be_applied( - self, - graph: Union[dace.SDFGState, dace.SDFG], - expr_index: int, - sdfg: dace.SDFG, - permissive: bool = False, - ) -> bool: - """Test if the block size can be set. - - The function tests: - - If the block size of the map is already set. - - If the map is at global scope. - - If if the schedule of the map is correct. 
- """ - - scope = graph.scope_dict() - if scope[self.map_entry] is not None: - return False - if self.map_entry.map.schedule not in dace.dtypes.GPU_SCHEDULES: - return False - if self.map_entry.map.gpu_block_size is not None: - return False - return True - - def apply( - self, - graph: Union[dace.SDFGState, dace.SDFG], - sdfg: dace.SDFG, - ) -> None: - """Modify the map as requested.""" - self.map_entry.map.gpu_block_size = self.block_size - if self.launch_bounds is not None: # Note empty string has a meaning in DaCe - self.map_entry.map.gpu_launch_bounds = self.launch_bounds - - -@dace_properties.make_properties -class TrivialGPUMapPromoter(dace_transformation.SingleStateTransformation): - """Serial Map promoter for empty GPU maps. - - In CPU mode a Tasklet can be outside of a map, however, this is not - possible in GPU mode. For this reason DaCe wraps such Tasklets in a - trivial Map. - This transformation will look for such Maps and promote them, such - that they can be fused with downstream maps. - - Note: - - This transformation should not be run on its own, instead it - is run within the context of `gt_gpu_transformation()`. - - This transformation must be run after the GPU Transformation. - - Currently the transformation does not do the fusion on its own. - Instead map fusion must be run afterwards. - - The transformation assumes that the upper Map is a trivial Tasklet. - Which should be the majority of all cases. 
- """ - - # Pattern Matching - trivial_map_exit = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) - access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) - second_map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) - - @classmethod - def expressions(cls) -> Any: - return [ - dace.sdfg.utils.node_path_graph( - cls.trivial_map_exit, cls.access_node, cls.second_map_entry - ) - ] - - def can_be_applied( - self, - graph: Union[dace.SDFGState, dace.SDFG], - expr_index: int, - sdfg: dace.SDFG, - permissive: bool = False, - ) -> bool: - """Tests if the promotion is possible. - - The tests includes: - - Schedule of the maps. - - If the map is trivial. - - If the trivial map was not used to define a symbol. - - Intermediate access node can only have in and out degree of 1. - - The trivial map exit can only have one output. - """ - trivial_map_exit: dace_nodes.MapExit = self.trivial_map_exit - trivial_map: dace_nodes.Map = trivial_map_exit.map - trivial_map_entry: dace_nodes.MapEntry = graph.entry_node(trivial_map_exit) - second_map: dace_nodes.Map = self.second_map_entry.map - access_node: dace_nodes.AccessNode = self.access_node - - # The kind of maps we are interested only have one parameter. - if len(trivial_map.params) != 1: - return False - - # Check if it is a GPU map - for map_to_check in [trivial_map, second_map]: - if map_to_check.schedule not in [ - dace.dtypes.ScheduleType.GPU_Device, - dace.dtypes.ScheduleType.GPU_Default, - ]: - return False - - # Check if the map is trivial. - for rng in trivial_map.range.ranges: - if rng[0] != rng[1]: - return False - - # Now we have to ensure that the symbol is not used inside the scope of the - # map, if it is, then the symbol is just there to define a symbol. 
- scope_view = graph.scope_subgraph( - trivial_map_entry, - include_entry=False, - include_exit=False, - ) - if any(map_param in scope_view.free_symbols for map_param in trivial_map.params): - return False - - # ensuring that the trivial map exit and the intermediate node have degree - # one is a cheap way to ensure that the map can be merged into the - # second map. - if graph.in_degree(access_node) != 1: - return False - if graph.out_degree(access_node) != 1: - return False - if graph.out_degree(trivial_map_exit) != 1: - return False - - return True - - def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: - """Performs the Map Promoting. - - The function essentially copies the parameters and the ranges from the - bottom map to the top one. - """ - trivial_map: dace_nodes.Map = self.trivial_map_exit.map - second_map: dace_nodes.Map = self.second_map_entry.map - - trivial_map.params = copy.deepcopy(second_map.params) - trivial_map.range = copy.deepcopy(second_map.range) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py deleted file mode 100644 index ec33e7ea63..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ /dev/null @@ -1,566 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -"""Implements helper functions for the map fusion transformations. - -Note: - After DaCe [PR#1629](https://github.com/spcl/dace/pull/1629), that implements - a better map fusion transformation is merged, this file will be deleted. 
-""" - -import functools -import itertools -from typing import Any, Optional, Sequence, Union - -import dace -from dace import ( - data as dace_data, - properties as dace_properties, - subsets as dace_subsets, - transformation as dace_transformation, -) -from dace.sdfg import graph as dace_graph, nodes as dace_nodes, validation as dace_validation -from dace.transformation import helpers as dace_helpers - -from gt4py.next.program_processors.runners.dace_fieldview.transformations import util - - -@dace_properties.make_properties -class MapFusionHelper(dace_transformation.SingleStateTransformation): - """Contains common part of the fusion for parallel and serial Map fusion. - - The transformation assumes that the SDFG obeys the principals outlined in - [ADR0018](https://github.com/GridTools/gt4py/tree/main/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md). - The main advantage of this structure is, that it is rather easy to determine - if a transient is used anywhere else. This check, performed by - `is_interstate_transient()`. It is further speeded up by cashing some computation, - thus such an object should not be used after interstate optimizations were applied - to the SDFG. - - Args: - only_inner_maps: Only match Maps that are internal, i.e. inside another Map. - only_toplevel_maps: Only consider Maps that are at the top. - """ - - only_toplevel_maps = dace_properties.Property( - dtype=bool, - default=False, - allow_none=False, - desc="Only perform fusing if the Maps are in the top level.", - ) - only_inner_maps = dace_properties.Property( - dtype=bool, - default=False, - allow_none=False, - desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", - ) - shared_transients = dace_properties.DictProperty( - key_type=dace.SDFG, - value_type=set[str], - default=None, - allow_none=True, - desc="Maps SDFGs to the set of array transients that can not be removed. 
" - "The variable acts as a cache, and is managed by 'is_interstate_transient()'.", - ) - - def __init__( - self, - only_inner_maps: Optional[bool] = None, - only_toplevel_maps: Optional[bool] = None, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - if only_toplevel_maps is not None: - self.only_toplevel_maps = bool(only_toplevel_maps) - if only_inner_maps is not None: - self.only_inner_maps = bool(only_inner_maps) - self.shared_transients = {} - - @classmethod - def expressions(cls) -> bool: - raise RuntimeError("The `_MapFusionHelper` is not a transformation on its own.") - - def can_be_fused( - self, - map_entry_1: dace_nodes.MapEntry, - map_entry_2: dace_nodes.MapEntry, - graph: Union[dace.SDFGState, dace.SDFG], - sdfg: dace.SDFG, - permissive: bool = False, - ) -> bool: - """Performs basic checks if the maps can be fused. - - This function only checks constrains that are common between serial and - parallel map fusion process, which includes: - - The scope of the maps. - - The scheduling of the maps. - - The map parameters. - - However, for performance reasons, the function does not check if the node - decomposition exists. - - Args: - map_entry_1: The entry of the first (in serial case the top) map. - map_exit_2: The entry of the second (in serial case the bottom) map. - graph: The SDFGState in which the maps are located. - sdfg: The SDFG itself. - permissive: Currently unused. - """ - if self.only_inner_maps and self.only_toplevel_maps: - raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") - - # Ensure that both have the same schedule - if map_entry_1.map.schedule != map_entry_2.map.schedule: - return False - - # Fusing is only possible if the two entries are in the same scope. 
- scope = graph.scope_dict() - if scope[map_entry_1] != scope[map_entry_2]: - return False - elif self.only_inner_maps: - if scope[map_entry_1] is None: - return False - elif self.only_toplevel_maps: - if scope[map_entry_1] is not None: - return False - # TODO(phimuell): Figuring out why this is here. - elif util.is_nested_sdfg(sdfg): - return False - - # We will now check if there exists a "remapping" that we can use. - # NOTE: The serial map promoter depends on the fact that this is the - # last check. - if not self.map_parameter_compatible( - map_1=map_entry_1.map, map_2=map_entry_2.map, state=graph, sdfg=sdfg - ): - return False - - return True - - @staticmethod - def relocate_nodes( - from_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], - to_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], - state: dace.SDFGState, - sdfg: dace.SDFG, - ) -> None: - """Move the connectors and edges from `from_node` to `to_nodes` node. - - This function will only rewire the edges, it does not remove the nodes - themselves. Furthermore, this function should be called twice per Map, - once for the entry and then for the exit. - While it does not remove the node themselves if guarantees that the - `from_node` has degree zero. - - Args: - from_node: Node from which the edges should be removed. - to_node: Node to which the edges should reconnect. - state: The state in which the operation happens. - sdfg: The SDFG that is modified. - """ - - # Now we relocate empty Memlets, from the `from_node` to the `to_node` - for empty_edge in filter(lambda e: e.data.is_empty(), state.out_edges(from_node)): - dace_helpers.redirect_edge(state, empty_edge, new_src=to_node) - for empty_edge in filter(lambda e: e.data.is_empty(), state.in_edges(from_node)): - dace_helpers.redirect_edge(state, empty_edge, new_dst=to_node) - - # We now ensure that there is only one empty Memlet from the `to_node` to any other node. - # Although it is allowed, we try to prevent it. 
- empty_targets: set[dace_nodes.Node] = set() - for empty_edge in filter(lambda e: e.data.is_empty(), state.all_edges(to_node)): - if empty_edge.dst in empty_targets: - state.remove_edge(empty_edge) - empty_targets.add(empty_edge.dst) - - # We now determine which edges we have to migrate, for this we are looking at - # the incoming edges, because this allows us also to detect dynamic map ranges. - for edge_to_move in state.in_edges(from_node): - assert isinstance(edge_to_move.dst_conn, str) - - if not edge_to_move.dst_conn.startswith("IN_"): - # Dynamic Map Range - # The connector name simply defines a variable name that is used, - # inside the Map scope to define a variable. We handle it directly. - dmr_symbol = edge_to_move.dst_conn - - # TODO(phimuell): Check if the symbol is really unused in the target scope. - if dmr_symbol in to_node.in_connectors: - raise NotImplementedError( - f"Tried to move the dynamic map range '{dmr_symbol}' from {from_node}'" - f" to '{to_node}', but the symbol is already known there, but the" - " renaming is not implemented." - ) - if not to_node.add_in_connector(dmr_symbol, force=False): - raise RuntimeError( # Might fail because of out connectors. - f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'." - ) - dace_helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) - from_node.remove_in_connector(dmr_symbol) - - # There is no other edge that we have to consider, so we just end here - continue - - # We have a Passthrough connection, i.e. there exists a matching `OUT_`. 
- old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix - new_conn = to_node.next_connector(old_conn) - - to_node.add_in_connector("IN_" + new_conn) - for e in state.in_edges_by_connector(from_node, "IN_" + old_conn): - dace_helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) - to_node.add_out_connector("OUT_" + new_conn) - for e in state.out_edges_by_connector(from_node, "OUT_" + old_conn): - dace_helpers.redirect_edge( - state, e, new_src=to_node, new_src_conn="OUT_" + new_conn - ) - from_node.remove_in_connector("IN_" + old_conn) - from_node.remove_out_connector("OUT_" + old_conn) - - # Check if we succeeded. - if state.out_degree(from_node) != 0: - raise dace_validation.InvalidSDFGError( - f"Failed to relocate the outgoing edges from `{from_node}`, there are still `{state.out_edges(from_node)}`", - sdfg, - sdfg.node_id(state), - ) - if state.in_degree(from_node) != 0: - raise dace_validation.InvalidSDFGError( - f"Failed to relocate the incoming edges from `{from_node}`, there are still `{state.in_edges(from_node)}`", - sdfg, - sdfg.node_id(state), - ) - assert len(from_node.in_connectors) == 0 - assert len(from_node.out_connectors) == 0 - - @staticmethod - def map_parameter_compatible( - map_1: dace_nodes.Map, - map_2: dace_nodes.Map, - state: Union[dace.SDFGState, dace.SDFG], - sdfg: dace.SDFG, - ) -> bool: - """Checks if the parameters of `map_1` are compatible with `map_2`. - - The check follows the following rules: - - The names of the map variables must be the same, i.e. no renaming - is performed. - - The ranges must be the same. - """ - range_1: dace_subsets.Range = map_1.range - params_1: Sequence[str] = map_1.params - range_2: dace_subsets.Range = map_2.range - params_2: Sequence[str] = map_2.params - - # The maps are only fuseable if we have an exact match in the parameter names - # this is because we do not do any renaming. This is in accordance with the - # rules. 
- if set(params_1) != set(params_2): - return False - - # Maps the name of a parameter to the dimension index - param_dim_map_1: dict[str, int] = {pname: i for i, pname in enumerate(params_1)} - param_dim_map_2: dict[str, int] = {pname: i for i, pname in enumerate(params_2)} - - # To fuse the two maps the ranges must have the same ranges - for pname in params_1: - idx_1 = param_dim_map_1[pname] - idx_2 = param_dim_map_2[pname] - # TODO(phimuell): do we need to call simplify? - if range_1[idx_1] != range_2[idx_2]: - return False - - return True - - def is_interstate_transient( - self, - transient: Union[str, dace_nodes.AccessNode], - sdfg: dace.SDFG, - state: dace.SDFGState, - ) -> bool: - """Tests if `transient` is an interstate transient, an can not be removed. - - Essentially this function checks if a transient might be needed in a - different state in the SDFG, because it transmit information from - one state to the other. - If only the name of the data container is passed the function will - first look for an corresponding access node. - - The set of these "interstate transients" is computed once per SDFG. - The result is then cached internally for later reuse. - - Args: - transient: The transient that should be checked. - sdfg: The SDFG containing the array. - state: If given the state the node is located in. - """ - - # The following builds upon the HACK MD document and not on ADR0018. - # Therefore the numbers are slightly different, but both documents - # essentially describes the same SDFG. - # According to [rule 6](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) - # the set of such transients is partially given by all source access dace_nodes. - # Because of rule 3 we also include all scalars in this set, as an over - # approximation. Furthermore, because simplify might violate rule 3, - # we also include the sink dace_nodes. 
- - # See if we have already computed the set - if sdfg in self.shared_transients: - shared_sdfg_transients: set[str] = self.shared_transients[sdfg] - else: - # SDFG is not known so we have to compute the set. - shared_sdfg_transients = set() - for state_to_scan in sdfg.all_states(): - # TODO(phimuell): Use `all_nodes_recursive()` once it is available. - shared_sdfg_transients.update( - [ - node.data - for node in itertools.chain( - state_to_scan.source_nodes(), state_to_scan.sink_nodes() - ) - if isinstance(node, dace_nodes.AccessNode) - and sdfg.arrays[node.data].transient - ] - ) - self.shared_transients[sdfg] = shared_sdfg_transients - - if isinstance(transient, str): - name = transient - matching_access_nodes = [node for node in state.data_nodes() if node.data == name] - # Rule 8: There is only one access node per state for data. - assert len(matching_access_nodes) == 1 - transient = matching_access_nodes[0] - else: - assert isinstance(transient, dace_nodes.AccessNode) - name = transient.data - - desc: dace_data.Data = sdfg.arrays[name] - if not desc.transient: - return True - if isinstance(desc, dace_data.Scalar): - return True # Scalars can not be removed by fusion anyway. - - # Rule 8: If degree larger than one then it is used within the state. - if state.out_degree(transient) > 1: - return True - - # Now we check if it is used in a different state. - return name in shared_sdfg_transients - - def partition_first_outputs( - self, - state: dace.SDFGState, - sdfg: dace.SDFG, - map_exit_1: dace_nodes.MapExit, - map_entry_2: dace_nodes.MapEntry, - ) -> Union[ - tuple[ - set[dace_graph.MultiConnectorEdge[dace.Memlet]], - set[dace_graph.MultiConnectorEdge[dace.Memlet]], - set[dace_graph.MultiConnectorEdge[dace.Memlet]], - ], - None, - ]: - """Partition the output edges of `map_exit_1` for serial map fusion. 
- - The output edges of the first map are partitioned into three distinct sets, - defined as follows: - - - Pure Output Set `\mathbb{P}`: - These edges exits the first map and does not enter the second map. These - outputs will be simply be moved to the output of the second map. - - Exclusive Intermediate Set `\mathbb{E}`: - Edges in this set leaves the first map exit, enters an access node, from - where a Memlet then leads immediately to the second map. The memory - referenced by this access node is not used anywhere else, thus it can - be removed. - - Shared Intermediate Set `\mathbb{S}`: - These edges are very similar to the one in `\mathbb{E}` except that they - are used somewhere else, thus they can not be removed and have to be - recreated as output of the second map. - - Returns: - If such a decomposition exists the function will return the three sets - mentioned above in the same order. - In case the decomposition does not exist, i.e. the maps can not be fused - the function returns `None`. - - Args: - state: The in which the two maps are located. - sdfg: The full SDFG in whcih we operate. - map_exit_1: The exit node of the first map. - map_entry_2: The entry node of the second map. - """ - # The three outputs set. - pure_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - exclusive_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - shared_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - - # Set of intermediate nodes that we have already processed. - processed_inter_nodes: set[dace_nodes.Node] = set() - - # Now scan all output edges of the first exit and classify them - for out_edge in state.out_edges(map_exit_1): - intermediate_node: dace_nodes.Node = out_edge.dst - - # We already processed the node, this should indicate that we should - # run simplify again, or we should start implementing this case. 
- if intermediate_node in processed_inter_nodes: - return None - processed_inter_nodes.add(intermediate_node) - - # Now let's look at all nodes that are downstream of the intermediate node. - # This, among other things, will tell us, how we have to handle this node. - downstream_nodes = util.all_nodes_between( - graph=state, - begin=intermediate_node, - end=map_entry_2, - ) - - # If `downstream_nodes` is `None` this means that `map_entry_2` was never - # reached, thus `intermediate_node` does not enter the second map and - # the node is a pure output node. - if downstream_nodes is None: - pure_outputs.add(out_edge) - continue - - # The following tests are _after_ we have determined if we have a pure - # output node, because this allows us to handle more exotic pure node - # cases, as handling them is essentially rerouting an edge, whereas - # handling intermediate nodes is much more complicated. - - # Empty Memlets are only allowed if they are in `\mathbb{P}`, which - # is also the only place they really make sense (for a map exit). - # Thus if we now found an empty Memlet we reject it. - if out_edge.data.is_empty(): - return None - - # In case the intermediate has more than one entry, all must come from the - # first map, otherwise we can not fuse them. Currently we restrict this - # even further by saying that it has only one incoming Memlet. - if state.in_degree(intermediate_node) != 1: - return None - - # It can happen that multiple edges converges at the `IN_` connector - # of the first map exit, but there is only one edge leaving the exit. - # It is complicate to handle this, so for now we ignore it. - # TODO(phimuell): Handle this case properly. - inner_collector_edges = list( - state.in_edges_by_connector(intermediate_node, "IN_" + out_edge.src_conn[3:]) - ) - if len(inner_collector_edges) > 1: - return None - - # For us an intermediate node must always be an access node, because - # everything else we do not know how to handle. 
It is important that - # we do not test for non transient data here, because they can be - # handled has shared intermediates. - if not isinstance(intermediate_node, dace_nodes.AccessNode): - return None - intermediate_desc: dace_data.Data = intermediate_node.desc(sdfg) - if isinstance(intermediate_desc, dace_data.View): - return None - - # There are some restrictions we have on intermediate dace_nodes. The first one - # is that we do not allow WCR, this is because they need special handling - # which is currently not implement (the DaCe transformation has this - # restriction as well). The second one is that we can reduce the - # intermediate node and only feed a part into the second map, consider - # the case `b = a + 1; return b + 2`, where we have arrays. In this - # example only a single element must be available to the second map. - # However, this is hard to check so we will make a simplification. - # First, we will not check it at the producer, but at the consumer point. - # There we assume if the consumer does _not consume the whole_ - # intermediate array, then we can decompose the intermediate, by setting - # the map iteration index to zero and recover the shape, see - # implementation in the actual fusion routine. - # This is an assumption that is in most cases correct, but not always. - # However, doing it correctly is extremely complex. - for _, produce_edge in util.find_upstream_producers(state, out_edge): - if produce_edge.data.wcr is not None: - return None - - if len(downstream_nodes) == 0: - # There is nothing between intermediate node and the entry of the - # second map, thus the edge belongs either in `\mathbb{S}` or - # `\mathbb{E}`. - - # This is a very special situation, i.e. the access node has many - # different connections to the second map entry, this is a special - # case that we do not handle. - # TODO(phimuell): Handle this case. 
- if state.out_degree(intermediate_node) != 1: - return None - - # Certain nodes need more than one element as input. As explained - # above, in this situation we assume that we can naturally decompose - # them iff the node does not consume that whole intermediate. - # Furthermore, it can not be a dynamic map range or a library node. - intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) - consumers = util.find_downstream_consumers(state=state, begin=intermediate_node) - for consumer_node, feed_edge in consumers: - # TODO(phimuell): Improve this approximation. - if ( - intermediate_size != 1 - ) and feed_edge.data.num_elements() == intermediate_size: - return None - if consumer_node is map_entry_2: # Dynamic map range. - return None - if isinstance(consumer_node, dace_nodes.LibraryNode): - # TODO(phimuell): Allow some library dace_nodes. - return None - - # Note that "remove" has a special meaning here, regardless of the - # output of the check function, from within the second map we remove - # the intermediate, it has more the meaning of "do we need to - # reconstruct it after the second map again?" - if self.is_interstate_transient(intermediate_node, sdfg, state): - shared_outputs.add(out_edge) - else: - exclusive_outputs.add(out_edge) - continue - - else: - # There is not only a single connection from the intermediate node to - # the second map, but the intermediate has more connections, thus - # the node might belong to the shared output. Of the many different - # possibilities, we only consider a single case: - # - The intermediate has a single connection to the second map, that - # fulfills the restriction outlined above. - # - All other connections have no connection to the second map. 
- found_second_entry = False - intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) - for edge in state.out_edges(intermediate_node): - if edge.dst is map_entry_2: - if found_second_entry: # The second map was found again. - return None - found_second_entry = True - consumers = util.find_downstream_consumers(state=state, begin=edge) - for consumer_node, feed_edge in consumers: - if feed_edge.data.num_elements() == intermediate_size: - return None - if consumer_node is map_entry_2: # Dynamic map range - return None - if isinstance(consumer_node, dace_nodes.LibraryNode): - # TODO(phimuell): Allow some library dace_nodes. - return None - else: - # Ensure that there is no path that leads to the second map. - after_intermdiate_node = util.all_nodes_between( - graph=state, begin=edge.dst, end=map_entry_2 - ) - if after_intermdiate_node is not None: - return None - # If we are here, then we know that the node is a shared output - shared_outputs.add(out_edge) - continue - - assert exclusive_outputs or shared_outputs or pure_outputs - assert len(processed_inter_nodes) == sum( - len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs] - ) - return (pure_outputs, exclusive_outputs, shared_outputs) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py deleted file mode 100644 index 4b34dd6adc..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py +++ /dev/null @@ -1,125 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -from typing import Any, Optional, Sequence, Union - -import dace -from dace import properties as dace_properties, transformation as dace_transformation -from dace.sdfg import nodes as dace_nodes - -from gt4py.next import common as gtx_common -from gt4py.next.program_processors.runners.dace_fieldview import utility as gtx_dace_fieldview_util - - -@dace_properties.make_properties -class MapIterationOrder(dace_transformation.SingleStateTransformation): - """Modify the order of the iteration variables. - - The iteration order, while irrelevant from an SDFG point of view, is highly - relevant in code, and the fastest varying index ("inner most loop" in CPU or - "x block dimension" in GPU) should be associated with the stride 1 dimension - of the array. - This transformation will reorder the map indexes such that this is the case. - - While the place of the leading dimension is clearly defined, the order of the - other loop indexes, after this transformation is unspecified. - - Args: - leading_dim: A GT4Py dimension object that identifies the dimension that - is supposed to have stride 1. - - Note: - The transformation does follow the rules outlines in - [ADR0018](https://github.com/GridTools/gt4py/tree/main/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md) - especially rule 11, regarding the names. - - Todo: - - Extend that different dimensions can be specified to be leading - dimensions, with some priority mechanism. - - Maybe also process the parameters to bring them in a canonical order. 
- """ - - leading_dim = dace_properties.Property( - dtype=str, - allow_none=True, - desc="Dimension that should become the leading dimension.", - ) - - map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) - - def __init__( - self, - leading_dim: Optional[Union[gtx_common.Dimension, str]] = None, - *args: Any, - **kwargs: Any, - ) -> None: - super().__init__(*args, **kwargs) - if isinstance(leading_dim, gtx_common.Dimension): - self.leading_dim = gtx_dace_fieldview_util.get_map_variable(leading_dim) - elif leading_dim is not None: - self.leading_dim = leading_dim - - @classmethod - def expressions(cls) -> Any: - return [dace.sdfg.utils.node_path_graph(cls.map_entry)] - - def can_be_applied( - self, - graph: Union[dace.SDFGState, dace.SDFG], - expr_index: int, - sdfg: dace.SDFG, - permissive: bool = False, - ) -> bool: - """Test if the map can be reordered. - - Essentially the function checks if the selected dimension is inside the map, - and if so, if it is on the right place. - """ - - if self.leading_dim is None: - return False - map_entry: dace_nodes.MapEntry = self.map_entry - map_params: Sequence[str] = map_entry.map.params - map_var: str = self.leading_dim - - if map_var not in map_params: - return False - if map_params[-1] == map_var: # Already at the correct location - return False - return True - - def apply( - self, - graph: Union[dace.SDFGState, dace.SDFG], - sdfg: dace.SDFG, - ) -> None: - """Performs the actual parameter reordering. - - The function will make the map variable, that corresponds to - `self.leading_dim` the last map variable (this is given by the structure of - DaCe's code generator). - """ - map_entry: dace_nodes.MapEntry = self.map_entry - map_params: list[str] = map_entry.map.params - map_var: str = self.leading_dim - - # This implementation will just swap the variable that is currently the last - # with the one that should be the last. 
- dst_idx = -1 - src_idx = map_params.index(map_var) - - for to_process in [ - map_entry.map.params, - map_entry.map.range.ranges, - map_entry.map.range.tile_sizes, - ]: - assert isinstance(to_process, list) - src_val = to_process[src_idx] - dst_val = to_process[dst_idx] - to_process[dst_idx] = src_val - to_process[src_idx] = dst_val diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py deleted file mode 100644 index bca5aa2268..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py +++ /dev/null @@ -1,483 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -"""Implements the serial map fusing transformation. - -Note: - After DaCe [PR#1629](https://github.com/spcl/dace/pull/1629), that implements - a better map fusion transformation is merged, this file will be deleted. -""" - -import copy -from typing import Any, Union - -import dace -from dace import ( - dtypes as dace_dtypes, - properties as dace_properties, - subsets as dace_subsets, - symbolic as dace_symbolic, - transformation as dace_transformation, -) -from dace.sdfg import graph as dace_graph, nodes as dace_nodes - -from gt4py.next.program_processors.runners.dace_fieldview.transformations import map_fusion_helper - - -@dace_properties.make_properties -class SerialMapFusion(map_fusion_helper.MapFusionHelper): - """Specialized replacement for the map fusion transformation that is provided by DaCe. - - As its name is indicating this transformation is only able to handle Maps that - are in sequence. Compared to the native DaCe transformation, this one is able - to handle more complex cases of connection between the maps. 
In that sense, it - is much more similar to DaCe's `SubgraphFusion` transformation. - - Things that are improved, compared to the native DaCe implementation: - - Nested Maps. - - Temporary arrays and the correct propagation of their Memlets. - - Top Maps that have multiple outputs. - - Conceptually this transformation removes the exit of the first or upper map - and the entry of the lower or second map and then rewrites the connections - appropriately. - - This transformation assumes that an SDFG obeys the structure that is outlined - [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). For that - reason it is not true replacement of the native DaCe transformation. - - Args: - only_inner_maps: Only match Maps that are internal, i.e. inside another Map. - only_toplevel_maps: Only consider Maps that are at the top. - - Notes: - - This transformation modifies more nodes than it matches! - """ - - map_exit1 = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) - access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) - map_entry2 = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) - - def __init__( - self, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - - @classmethod - def expressions(cls) -> Any: - """Get the match expression. - - The transformation matches the exit node of the top Map that is connected to - an access node that again is connected to the entry node of the second Map. - An important note is, that the transformation operates not just on the - matched nodes, but more or less on anything that has an incoming connection - from the first Map or an outgoing connection to the second Map entry. 
- """ - return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] - - def can_be_applied( - self, - graph: Union[dace.SDFGState, dace.SDFG], - expr_index: int, - sdfg: dace.SDFG, - permissive: bool = False, - ) -> bool: - """Tests if the matched Maps can be merged. - - The two Maps are mergeable iff: - - The `can_be_fused()` of the base succeed, which checks some basic constraints. - - The decomposition exists and at least one of the intermediate sets - is not empty. - """ - assert isinstance(self.map_exit1, dace_nodes.MapExit) - assert isinstance(self.map_entry2, dace_nodes.MapEntry) - map_entry_1: dace_nodes.MapEntry = graph.entry_node(self.map_exit1) - map_entry_2: dace_nodes.MapEntry = self.map_entry2 - - # This essentially test the structural properties of the two Maps. - if not self.can_be_fused( - map_entry_1=map_entry_1, map_entry_2=map_entry_2, graph=graph, sdfg=sdfg - ): - return False - - # Two maps can be serially fused if the node decomposition exists and - # at least one of the intermediate output sets is not empty. The state - # of the pure outputs is irrelevant for serial map fusion. - output_partition = self.partition_first_outputs( - state=graph, - sdfg=sdfg, - map_exit_1=self.map_exit1, - map_entry_2=self.map_entry2, - ) - if output_partition is None: - return False - _, exclusive_outputs, shared_outputs = output_partition - if not (exclusive_outputs or shared_outputs): - return False - return True - - def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: - """Performs the serial Map fusing. - - The function first computes the map decomposition and then handles the - three sets. The pure outputs are handled by `relocate_nodes()` while - the two intermediate sets are handled by `handle_intermediate_set()`. - - By assumption we do not have to rename anything. - - Args: - graph: The SDFG state we are operating on. - sdfg: The SDFG we are operating on. 
- """ - # NOTE: `self.map_*` actually stores the ID of the node. - # once we start adding and removing nodes it seems that their ID changes. - # Thus we have to save them here, this is a known behaviour in DaCe. - assert isinstance(graph, dace.SDFGState) - assert isinstance(self.map_exit1, dace_nodes.MapExit) - assert isinstance(self.map_entry2, dace_nodes.MapEntry) - assert self.map_parameter_compatible(self.map_exit1.map, self.map_entry2.map, graph, sdfg) - - map_exit_1: dace_nodes.MapExit = self.map_exit1 - map_entry_2: dace_nodes.MapEntry = self.map_entry2 - map_exit_2: dace_nodes.MapExit = graph.exit_node(self.map_entry2) - map_entry_1: dace_nodes.MapEntry = graph.entry_node(self.map_exit1) - - output_partition = self.partition_first_outputs( - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - ) - assert output_partition is not None # Make MyPy happy. - pure_outputs, exclusive_outputs, shared_outputs = output_partition - - if len(exclusive_outputs) != 0: - self.handle_intermediate_set( - intermediate_outputs=exclusive_outputs, - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - map_exit_2=map_exit_2, - is_exclusive_set=True, - ) - if len(shared_outputs) != 0: - self.handle_intermediate_set( - intermediate_outputs=shared_outputs, - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - map_exit_2=map_exit_2, - is_exclusive_set=False, - ) - assert pure_outputs == set(graph.out_edges(map_exit_1)) - if len(pure_outputs) != 0: - self.relocate_nodes( - from_node=map_exit_1, - to_node=map_exit_2, - state=graph, - sdfg=sdfg, - ) - - # Above we have handled the input of the second map and moved them - # to the first map, now we must move the output of the first map - # to the second one, as this one is used. 
- self.relocate_nodes( - from_node=map_entry_2, - to_node=map_entry_1, - state=graph, - sdfg=sdfg, - ) - - for node_to_remove in [map_exit_1, map_entry_2]: - assert graph.degree(node_to_remove) == 0 - graph.remove_node(node_to_remove) - - # Now turn the second output node into the output node of the first Map. - map_exit_2.map = map_entry_1.map - - @staticmethod - def handle_intermediate_set( - intermediate_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]], - state: dace.SDFGState, - sdfg: dace.SDFG, - map_exit_1: dace_nodes.MapExit, - map_entry_2: dace_nodes.MapEntry, - map_exit_2: dace_nodes.MapExit, - is_exclusive_set: bool, - ) -> None: - """This function handles the intermediate sets. - - The function is able to handle both the shared and exclusive intermediate - output set, see `partition_first_outputs()`. The main difference is that - in exclusive mode the intermediate nodes will be fully removed from - the SDFG. While in shared mode the intermediate node will be preserved. - - Args: - intermediate_outputs: The set of outputs, that should be processed. - state: The state in which the map is processed. - sdfg: The SDFG that should be optimized. - map_exit_1: The exit of the first/top map. - map_entry_2: The entry of the second map. - map_exit_2: The exit of the second map. - is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. - - Notes: - Before the transformation the `state` does not have to be valid and - after this function has run the state is (most likely) invalid. - - Todo: - Rewrite using `MemletTree`. - """ - - # Essentially this function removes the AccessNode between the two maps. - # However, we still need some temporary memory that we can use, which is - # just much smaller, i.e. a scalar. But all Memlets inside the second map - # assumes that the intermediate memory has the bigger shape. 
- # To fix that we will create this replacement dict that will replace all - # occurrences of the iteration variables of the second map with zero. - # Note that this is still not enough as the dimensionality might be different. - memlet_repl: dict[str, int] = {str(param): 0 for param in map_entry_2.map.params} - - # Now we will iterate over all intermediate edges and process them. - # If not stated otherwise the comments assume that we run in exclusive mode. - for out_edge in intermediate_outputs: - # This is the intermediate node that, that we want to get rid of. - # In shared mode we want to recreate it after the second map. - inter_node: dace_nodes.AccessNode = out_edge.dst - inter_name = inter_node.data - inter_desc = inter_node.desc(sdfg) - inter_shape = inter_desc.shape - - # Now we will determine the shape of the new intermediate. This size of - # this temporary is given by the Memlet that goes into the first map exit. - pre_exit_edges = list( - state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) - ) - if len(pre_exit_edges) != 1: - raise NotImplementedError() - pre_exit_edge = pre_exit_edges[0] - new_inter_shape_raw = dace_symbolic.overapproximate(pre_exit_edge.data.subset.size()) - - # Over approximation will leave us with some unneeded size one dimensions. - # That are known to cause some troubles, so we will now remove them. - squeezed_dims: list[int] = [] # These are the dimensions we removed. - new_inter_shape: list[int] = [] # This is the final shape of the new intermediate. - for dim, (proposed_dim_size, full_dim_size) in enumerate( - zip(new_inter_shape_raw, inter_shape) - ): - # Order of checks is important! - if full_dim_size == 1: # Must be kept! - new_inter_shape.append(proposed_dim_size) - elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. - squeezed_dims.append(dim) - else: - new_inter_shape.append(proposed_dim_size) - - # This is the name of the new "intermediate" node that we will create. 
- # It will only have the shape `new_inter_shape` which is basically its - # output within one Map iteration. - # NOTE: The insertion process might generate a new name. - new_inter_name: str = f"__s{sdfg.node_id(state)}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" - - # Now generate the intermediate data container. - if len(new_inter_shape) == 0: - assert pre_exit_edge.data.subset.num_elements() == 1 - is_scalar = True - new_inter_name, new_inter_desc = sdfg.add_scalar( - new_inter_name, - dtype=inter_desc.dtype, - transient=True, - storage=dace_dtypes.StorageType.Register, - find_new_name=True, - ) - - else: - assert (pre_exit_edge.data.subset.num_elements() > 1) or all( - x == 1 for x in new_inter_shape - ) - is_scalar = False - new_inter_name, new_inter_desc = sdfg.add_transient( - new_inter_name, - shape=new_inter_shape, - dtype=inter_desc.dtype, - find_new_name=True, - ) - new_inter_node: dace_nodes.AccessNode = state.add_access(new_inter_name) - - # New we will reroute the output Memlet, thus it will no longer pass - # through the Map exit but through the newly created intermediate. - # we will delete the previous edge later. - pre_exit_memlet: dace.Memlet = pre_exit_edge.data - new_pre_exit_memlet = copy.deepcopy(pre_exit_memlet) - - # We might operate on a different array, but the check below, ensures - # that we do not change the direction of the Memlet. - assert pre_exit_memlet.data == inter_name - new_pre_exit_memlet.data = new_inter_name - - # Now we have to modify the subset of the Memlet. - # Before the subset of the Memlet was dependent on the Map variables, - # however, this is no longer the case, as we removed them. This change - # has to be reflected in the Memlet. - # NOTE: Assert above ensures that the below is correct. 
- new_pre_exit_memlet.replace(memlet_repl) - if is_scalar: - new_pre_exit_memlet.subset = "0" - new_pre_exit_memlet.other_subset = None - else: - new_pre_exit_memlet.subset.pop(squeezed_dims) - - # Now we create the new edge between the producer and the new output - # (the new intermediate node). We will remove the old edge further down. - new_pre_exit_edge = state.add_edge( - pre_exit_edge.src, - pre_exit_edge.src_conn, - new_inter_node, - None, - new_pre_exit_memlet, - ) - - # We just have handled the last Memlet, but we must actually handle the - # whole producer side, i.e. the scope of the top Map. - for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(): - producer_edge = producer_tree.edge - - # Ensure the correctness of the rerouting below. - # TODO(phimuell): Improve the code below to remove the check. - assert producer_edge.data.data == inter_name - - # Will not change the direction, because of test above! - producer_edge.data.data = new_inter_name - producer_edge.data.replace(memlet_repl) - if is_scalar: - producer_edge.data.dst_subset = "0" - elif producer_edge.data.dst_subset is not None: - producer_edge.data.dst_subset.pop(squeezed_dims) - - # Now after we have handled the input of the new intermediate node, - # we must handle its output. For this we have to "inject" the newly - # created intermediate into the second map. We do this by finding - # the input connectors on the map entry, such that we know where we - # have to reroute inside the Map. - # NOTE: Assumes that map (if connected is the direct neighbour). - conn_names: set[str] = set() - for inter_node_out_edge in state.out_edges(inter_node): - if inter_node_out_edge.dst == map_entry_2: - assert inter_node_out_edge.dst_conn.startswith("IN_") - conn_names.add(inter_node_out_edge.dst_conn) - else: - # If we found another target than the second map entry from the - # intermediate node it means that the node _must_ survive, - # i.e. we are not in exclusive mode. 
- assert not is_exclusive_set - - # Now we will reroute the connections inside the second map, i.e. - # instead of consuming the old intermediate node, they will now - # consume the new intermediate node. - for in_conn_name in conn_names: - out_conn_name = "OUT_" + in_conn_name[3:] - - for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name): - assert inner_edge.data.data == inter_name # DIRECTION!! - - # The create the first Memlet to transmit information, within - # the second map, we do this again by copying and modifying - # the original Memlet. - # NOTE: Test above is important to ensure the direction of the - # Memlet and the correctness of the code below. - new_inner_memlet = copy.deepcopy(inner_edge.data) - new_inner_memlet.replace(memlet_repl) - new_inner_memlet.data = new_inter_name # Because of the assert above, this will not change the direction. - - # Now remove the old edge, that started the second map entry. - # Also add the new edge that started at the new intermediate. - state.remove_edge(inner_edge) - new_inner_edge = state.add_edge( - new_inter_node, - None, - inner_edge.dst, - inner_edge.dst_conn, - new_inner_memlet, - ) - - # Now we do subset modification to ensure that nothing failed. - if is_scalar: - new_inner_memlet.src_subset = "0" - elif new_inner_memlet.src_subset is not None: - new_inner_memlet.src_subset.pop(squeezed_dims) - - # Now clean the Memlets of that tree to use the new intermediate node. - for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(): - consumer_edge = consumer_tree.edge - assert consumer_edge.data.data == inter_name - consumer_edge.data.data = new_inter_name - consumer_edge.data.replace(memlet_repl) - if is_scalar: - consumer_edge.data.src_subset = "0" - elif consumer_edge.data.subset is not None: - consumer_edge.data.subset.pop(squeezed_dims) - - # The edge that leaves the second map entry was already deleted. - # We will now delete the edges that brought the data. 
- for edge in state.in_edges_by_connector(map_entry_2, in_conn_name): - assert edge.src == inter_node - state.remove_edge(edge) - map_entry_2.remove_in_connector(in_conn_name) - map_entry_2.remove_out_connector(out_conn_name) - - if is_exclusive_set: - # In exclusive mode the old intermediate node is no longer needed. - assert state.degree(inter_node) == 1 - state.remove_edge_and_connectors(out_edge) - state.remove_node(inter_node) - - state.remove_edge(pre_exit_edge) - map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) - map_exit_1.remove_out_connector(out_edge.src_conn) - del sdfg.arrays[inter_name] - - else: - # This is the shared mode, so we have to recreate the intermediate - # node, but this time it is at the exit of the second map. - state.remove_edge(pre_exit_edge) - map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) - - # This is the Memlet that goes from the map internal intermediate - # temporary node to the Map output. This will essentially restore - # or preserve the output for the intermediate node. It is important - # that we use the data that `preExitEdge` was used. 
- new_exit_memlet = copy.deepcopy(pre_exit_edge.data) - assert new_exit_memlet.data == inter_name - new_exit_memlet.subset = pre_exit_edge.data.dst_subset - new_exit_memlet.other_subset = ( - "0" if is_scalar else dace_subsets.Range.from_array(inter_desc) - ) - - new_pre_exit_conn = map_exit_2.next_connector() - state.add_edge( - new_inter_node, - None, - map_exit_2, - "IN_" + new_pre_exit_conn, - new_exit_memlet, - ) - state.add_edge( - map_exit_2, - "OUT_" + new_pre_exit_conn, - inter_node, - out_edge.dst_conn, - copy.deepcopy(out_edge.data), - ) - map_exit_2.add_in_connector("IN_" + new_pre_exit_conn) - map_exit_2.add_out_connector("OUT_" + new_pre_exit_conn) - - map_exit_1.remove_out_connector(out_edge.src_conn) - state.remove_edge(out_edge) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py deleted file mode 100644 index 29bae7bbe0..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py +++ /dev/null @@ -1,160 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -"""Common functionality for the transformations/optimization pipeline.""" - -from typing import Iterable, Union - -import dace -from dace.sdfg import graph as dace_graph, nodes as dace_nodes - - -def is_nested_sdfg( - sdfg: Union[dace.SDFG, dace.SDFGState, dace_nodes.NestedSDFG], -) -> bool: - """Tests if `sdfg` is a NestedSDFG.""" - if isinstance(sdfg, dace.SDFGState): - sdfg = sdfg.parent - if isinstance(sdfg, dace_nodes.NestedSDFG): - return True - elif isinstance(sdfg, dace.SDFG): - return sdfg.parent_nsdfg_node is not None - raise TypeError(f"Does not know how to handle '{type(sdfg).__name__}'.") - - -def all_nodes_between( - graph: dace.SDFG | dace.SDFGState, - begin: dace_nodes.Node, - end: dace_nodes.Node, - reverse: bool = False, -) -> set[dace_nodes.Node] | None: - """Find all nodes that are reachable from `begin` but bound by `end`. - - Essentially the function starts a DFS at `begin`. If an edge is found that lead - to `end`, this edge is ignored. It will thus found any node that is reachable - from `begin` by a path that does not involve `end`. The returned set will - never contain `end` nor `begin`. In case `end` is never found the function - will return `None`. - - If `reverse` is set to `True` the function will start exploring at `end` and - follows the outgoing edges, i.e. the meaning of `end` and `begin` are swapped. - - Args: - graph: The graph to operate on. - begin: The start of the DFS. - end: The terminator node of the DFS. - reverse: Perform a backward DFS. - - Notes: - - The returned set will also contain the nodes of path that starts at - `begin` and ends at a node that is not `end`. 
- """ - - def next_nodes(node: dace_nodes.Node) -> Iterable[dace_nodes.Node]: - return ( - (edge.src for edge in graph.in_edges(node)) - if reverse - else (edge.dst for edge in graph.out_edges(node)) - ) - - if reverse: - begin, end = end, begin - - to_visit: list[dace_nodes.Node] = [begin] - seen: set[dace_nodes.Node] = set() - - while len(to_visit) > 0: - node: dace_nodes.Node = to_visit.pop() - if node != end and node not in seen: - to_visit.extend(next_nodes(node)) - seen.add(node) - - # If `end` was not found we have to return `None` to indicate this. - if end not in seen: - return None - - # `begin` and `end` are not included in the output set. - return seen - {begin, end} - - -def find_downstream_consumers( - state: dace.SDFGState, - begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], - only_tasklets: bool = False, - reverse: bool = False, -) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: - """Find all downstream connectors of `begin`. - - A consumer, in for this function, is any node that is neither an entry nor - an exit node. The function returns a set of pairs, the first element is the - node that acts as consumer and the second is the edge that leads to it. - By setting `only_tasklets` the nodes the function finds are only Tasklets. - - To find this set the function starts a search at `begin`, however, it is also - possible to pass an edge as `begin`. - If `reverse` is `True` the function essentially finds the producers that are - upstream. - - Args: - state: The state in which to look for the consumers. - begin: The initial node that from which the search starts. - only_tasklets: Return only Tasklets. - reverse: Follow the reverse direction. 
- """ - if isinstance(begin, dace_graph.MultiConnectorEdge): - to_visit: list[dace_graph.MultiConnectorEdge[dace.Memlet]] = [begin] - else: - to_visit = state.in_edges(begin) if reverse else state.out_edges(begin) - - seen: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - found: set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]] = set() - - while len(to_visit) > 0: - curr_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = to_visit.pop() - next_node: dace_nodes.Node = curr_edge.src if reverse else curr_edge.dst - - if curr_edge in seen: - continue - seen.add(curr_edge) - - if isinstance(next_node, (dace_nodes.MapEntry, dace_nodes.MapExit)): - if not reverse: - # In forward mode a Map entry could also mean the definition of a - # dynamic map range. - if isinstance(next_node, dace_nodes.MapEntry) and ( - not curr_edge.dst_conn.startswith("IN_") - ): - if not only_tasklets: - found.add((next_node, curr_edge)) - continue - target_conn = curr_edge.dst_conn[3:] - new_edges = state.out_edges_by_connector(curr_edge.dst, "OUT_" + target_conn) - else: - target_conn = curr_edge.src_conn[4:] - new_edges = state.in_edges_by_connector(curr_edge.src, "IN_" + target_conn) - to_visit.extend(new_edges) - - elif isinstance(next_node, dace_nodes.Tasklet) or not only_tasklets: - # We have found a consumer. 
- found.add((next_node, curr_edge)) - - return found - - -def find_upstream_producers( - state: dace.SDFGState, - begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], - only_tasklets: bool = False, -) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: - """Same as `find_downstream_consumers()` but with `reverse` set to `True`.""" - return find_downstream_consumers( - state=state, - begin=begin, - only_tasklets=only_tasklets, - reverse=True, - ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py deleted file mode 100644 index 2988b01a61..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ /dev/null @@ -1,147 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import itertools -from typing import Any - -import dace - -from gt4py import eve -from gt4py.next import common as gtx_common -from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from gt4py.next.program_processors.runners.dace_fieldview import gtir_python_codegen -from gt4py.next.type_system import type_specifications as ts - - -def get_domain( - node: gtir.Expr, -) -> list[tuple[gtx_common.Dimension, dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: - """ - Specialized visit method for domain expressions. - - Returns for each domain dimension the corresponding range. - - TODO: Domain expressions will be recurrent in the GTIR program. An interesting idea - would be to cache the results of lowering here (e.g. 
using `functools.lru_cache`) - """ - assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain")) - - domain = [] - for named_range in node.args: - assert cpm.is_call_to(named_range, "named_range") - assert len(named_range.args) == 3 - axis = named_range.args[0] - assert isinstance(axis, gtir.AxisLiteral) - bounds = [ - dace.symbolic.pystr_to_symbolic(gtir_python_codegen.get_source(arg)) - for arg in named_range.args[1:3] - ] - dim = gtx_common.Dimension(axis.value, axis.kind) - domain.append((dim, bounds[0], bounds[1])) - - return domain - - -def get_domain_ranges( - node: gtir.Expr, -) -> dict[gtx_common.Dimension, tuple[dace.symbolic.SymbolicType, dace.symbolic.SymbolicType]]: - """ - Returns domain represented in dictionary form. - """ - domain = get_domain(node) - - return {dim: (lb, ub) for dim, lb, ub in domain} - - -def get_map_variable(dim: gtx_common.Dimension) -> str: - """ - Format map variable name based on the naming convention for application-specific SDFG transformations. - """ - suffix = "dim" if dim.kind == gtx_common.DimensionKind.LOCAL else "" - return f"i_{dim.value}_gtx_{dim.kind}{suffix}" - - -def get_tuple_fields( - tuple_name: str, tuple_type: ts.TupleType, flatten: bool = False -) -> list[tuple[str, ts.DataType]]: - """ - Creates a list of names with the corresponding data type for all elements of the given tuple. - - Examples - -------- - >>> sty = ts.ScalarType(kind=ts.ScalarKind.INT32) - >>> fty = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) - >>> t = ts.TupleType(types=[sty, ts.TupleType(types=[fty, sty])]) - >>> assert get_tuple_fields("a", t) == [("a_0", sty), ("a_1", ts.TupleType(types=[fty, sty]))] - >>> assert get_tuple_fields("a", t, flatten=True) == [ - ... ("a_0", sty), - ... ("a_1_0", fty), - ... ("a_1_1", sty), - ... 
] - """ - fields = [(f"{tuple_name}_{i}", field_type) for i, field_type in enumerate(tuple_type.types)] - if flatten: - expanded_fields = [ - get_tuple_fields(field_name, field_type) - if isinstance(field_type, ts.TupleType) - else [(field_name, field_type)] - for field_name, field_type in fields - ] - return list(itertools.chain(*expanded_fields)) - else: - return fields - - -def get_tuple_type(data: tuple[Any, ...]) -> ts.TupleType: - """ - Compute the `ts.TupleType` corresponding to the structure of a tuple of data nodes. - """ - return ts.TupleType( - types=[get_tuple_type(d) if isinstance(d, tuple) else d.data_type for d in data] - ) - - -def patch_gtir(ir: gtir.Program) -> gtir.Program: - """ - Make the IR compliant with the requirements of lowering to SDFG. - - Applies canonicalization of as_fieldop expressions as well as some temporary workarounds. - This allows to lower the IR to SDFG for some special cases. - """ - - class PatchGTIR(eve.PreserveLocationVisitor, eve.NodeTranslator): - def visit_FunCall(self, node: gtir.FunCall) -> gtir.Node: - if cpm.is_applied_as_fieldop(node): - assert isinstance(node.fun, gtir.FunCall) - assert isinstance(node.type, ts.FieldType) - - # Handle the case of fieldop without domain. This case should never happen, but domain - # inference currently produces this kind of nodes for unreferenced tuple fields. - # TODO(tehrengruber): remove this workaround once domain ineference supports this case - if len(node.fun.args) == 1: - return gtir.Literal(value="0", type=node.type.dtype) - - assert len(node.fun.args) == 2 - stencil = node.fun.args[0] - - # Canonicalize as_fieldop: always expect a lambda expression. - # Here we replace the call to deref with a lambda expression and empty arguments list. 
- if cpm.is_ref_to(stencil, "deref"): - node.fun.args[0] = gtir.Lambda( - expr=gtir.FunCall(fun=stencil, args=node.args), params=[] - ) - node.args = [] - - node.args = [self.visit(arg) for arg in node.args] - node.fun = self.visit(node.fun) - return node - - return PatchGTIR().visit(ir) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py deleted file mode 100644 index dab8d29fd1..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ /dev/null @@ -1,356 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import dataclasses -import warnings -from collections import OrderedDict -from collections.abc import Callable, Mapping, Sequence -from dataclasses import field -from inspect import currentframe, getframeinfo -from pathlib import Path -from typing import Any, ClassVar, Optional - -import dace -import numpy as np -from dace.sdfg import utils as sdutils -from dace.transformation.auto import auto_optimize as autoopt - -import gt4py.next.iterator.ir as itir -from gt4py.next import common -from gt4py.next.ffront import decorator -from gt4py.next.iterator import transforms as itir_transforms -from gt4py.next.iterator.transforms import program_to_fencil -from gt4py.next.iterator.type_system import inference as itir_type_inference -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.type_system import type_specifications as ts - -from .itir_to_sdfg import ItirToSDFG - - -def preprocess_program( - program: itir.FencilDefinition, - offset_provider: Mapping[str, Any], - lift_mode: itir_transforms.LiftMode, - symbolic_domain_sizes: Optional[dict[str, str]] = None, - temporary_extraction_heuristics: Optional[ - 
Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None, - unroll_reduce: bool = False, -): - node = itir_transforms.apply_common_transforms( - program, - common_subexpression_elimination=False, - force_inline_lambda_args=True, - lift_mode=lift_mode, - offset_provider=offset_provider, - symbolic_domain_sizes=symbolic_domain_sizes, - temporary_extraction_heuristics=temporary_extraction_heuristics, - unroll_reduce=unroll_reduce, - ) - - node = itir_type_inference.infer(node, offset_provider=offset_provider) - - if isinstance(node, itir.Program): - fencil_definition = program_to_fencil.program_to_fencil(node) - tmps = node.declarations - assert all(isinstance(tmp, itir.Temporary) for tmp in tmps) - else: - raise TypeError(f"Expected 'Program', got '{type(node).__name__}'.") - - return fencil_definition, tmps - - -def build_sdfg_from_itir( - program: itir.FencilDefinition, - arg_types: Sequence[ts.TypeSpec], - offset_provider: dict[str, Any], - auto_optimize: bool = False, - on_gpu: bool = False, - column_axis: Optional[common.Dimension] = None, - lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE, - symbolic_domain_sizes: Optional[dict[str, str]] = None, - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None, - load_sdfg_from_file: bool = False, - save_sdfg: bool = True, - use_field_canonical_representation: bool = True, -) -> dace.SDFG: - """Translate a Fencil into an SDFG. - - Args: - program: The Fencil that should be translated. - arg_types: Types of the arguments passed to the fencil. - offset_provider: The set of offset providers that should be used. - auto_optimize: Apply DaCe's `auto_optimize` heuristic. - on_gpu: Performs the translation for GPU, defaults to `False`. - column_axis: The column axis to be used, defaults to `None`. - lift_mode: Which lift mode should be used, defaults `FORCE_INLINE`. 
- symbolic_domain_sizes: Used for generation of liskov bindings when temporaries are enabled. - load_sdfg_from_file: Allows to read the SDFG from file, instead of generating it, for debug only. - save_sdfg: If `True`, the default the SDFG is stored as a file and can be loaded, this allows to skip the lowering step, requires `load_sdfg_from_file` set to `True`. - use_field_canonical_representation: If `True`, assume that the fields dimensions are sorted alphabetically. - """ - - sdfg_filename = f"_dacegraphs/gt4py/{program.id}.sdfg" - if load_sdfg_from_file and Path(sdfg_filename).exists(): - sdfg: dace.SDFG = dace.SDFG.from_file(sdfg_filename) - sdfg.validate() - return sdfg - - # visit ITIR and generate SDFG - program, tmps = preprocess_program( - program, offset_provider, lift_mode, symbolic_domain_sizes, temporary_extraction_heuristics - ) - sdfg_genenerator = ItirToSDFG( - list(arg_types), offset_provider, tmps, use_field_canonical_representation, column_axis - ) - sdfg = sdfg_genenerator.visit(program) - if sdfg is None: - raise RuntimeError(f"Visit failed for program {program.id}.") - - for nested_sdfg in sdfg.all_sdfgs_recursive(): - if not nested_sdfg.debuginfo: - _, frameinfo = ( - warnings.warn( - f"{nested_sdfg.label} does not have debuginfo. Consider adding them in the corresponding nested sdfg.", - stacklevel=2, - ), - getframeinfo(currentframe()), # type: ignore[arg-type] - ) - nested_sdfg.debuginfo = dace.dtypes.DebugInfo( - start_line=frameinfo.lineno, end_line=frameinfo.lineno, filename=frameinfo.filename - ) - - # TODO(edopao): remove `inline_loop_blocks` when DaCe transformations support LoopRegion construct - sdutils.inline_loop_blocks(sdfg) - - # run DaCe transformations to simplify the SDFG - sdfg.simplify() - - # run DaCe auto-optimization heuristics - if auto_optimize: - # TODO: Investigate performance improvement from SDFG specialization with constant symbols, - # for array shape and strides, although this would imply JIT compilation. 
- symbols: dict[str, int] = {} - device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU - sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu) - elif on_gpu: - autoopt.apply_gpu_storage(sdfg) - - if on_gpu: - sdfg.apply_gpu_transformations() - - # Store the sdfg such that we can later reuse it. - if save_sdfg: - sdfg.save(sdfg_filename) - - return sdfg - - -@dataclasses.dataclass(frozen=True) -class Program(decorator.Program, dace.frontend.python.common.SDFGConvertible): - """Extension of GT4Py Program implementing the SDFGConvertible interface.""" - - sdfg_closure_vars: dict[str, Any] = field(default_factory=dict) - - # Being a ClassVar ensures that in an SDFG with multiple nested GT4Py Programs, - # there is no name mangling of the connectivity tables used across the nested SDFGs - # since they share the same memory address. - connectivity_tables_data_descriptors: ClassVar[ - dict[str, dace.data.Array] - ] = {} # symbolically defined - - def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: - if "dace" not in self.backend.name.lower(): # type: ignore[union-attr] - raise ValueError("The SDFG can be generated only for the DaCe backend.") - - params = {str(p.id): p.type for p in self.itir.params} - fields = {str(p.id): p.type for p in self.itir.params if hasattr(p.type, "dims")} - arg_types = [*params.values()] - - dace_parsed_args = [*args, *kwargs.values()] - gt4py_program_args = [*params.values()] - _crosscheck_dace_parsing(dace_parsed_args, gt4py_program_args) - - if self.connectivities is None: - raise ValueError( - "[DaCe Orchestration] Connectivities -at compile time- are required to generate the SDFG. Use `with_connectivities` method." 
- ) - offset_provider = ( - self.connectivities | self._implicit_offset_provider - ) # tables are None at this point - - sdfg = self.backend.executor.step.translation.generate_sdfg( # type: ignore[union-attr] - self.itir, - arg_types, - offset_provider=offset_provider, - column_axis=kwargs.get("column_axis", None), - ) - self.sdfg_closure_vars["sdfg.arrays"] = sdfg.arrays # use it in __sdfg_closure__ - - # Halo exchange related metadata, i.e. gt4py_program_input_fields, gt4py_program_output_fields, offset_providers_per_input_field - # Add them as dynamic properties to the SDFG - - input_fields = [ - str(in_field.id) - for closure in self.itir.closures - for in_field in closure.inputs - if str(in_field.id) in fields - ] - sdfg.gt4py_program_input_fields = { - in_field: dim - for in_field in input_fields - for dim in fields[in_field].dims # type: ignore[union-attr] - if dim.kind == common.DimensionKind.HORIZONTAL - } - - output_fields = [] - for closure in self.itir.closures: - output = closure.output - if isinstance(output, itir.SymRef): - if str(output.id) in fields: - output_fields.append(str(output.id)) - else: - for arg in output.args: - if str(arg.id) in fields: # type: ignore[attr-defined] - output_fields.append(str(arg.id)) # type: ignore[attr-defined] - sdfg.gt4py_program_output_fields = { - output: dim - for output in output_fields - for dim in fields[output].dims # type: ignore[union-attr] - if dim.kind == common.DimensionKind.HORIZONTAL - } - - sdfg.offset_providers_per_input_field = {} - itir_tmp = itir_transforms.apply_common_transforms( - self.itir, offset_provider=offset_provider - ) - itir_tmp_fencil = program_to_fencil.program_to_fencil(itir_tmp) - for closure in itir_tmp_fencil.closures: - params_shifts = itir_transforms.trace_shifts.trace_stencil( - closure.stencil, num_args=len(closure.inputs) - ) - for param, shifts in zip(closure.inputs, params_shifts): - if not isinstance(param.id, str): - continue - if param.id not in 
sdfg.gt4py_program_input_fields: - continue - sdfg.offset_providers_per_input_field.setdefault(param.id, []).extend(list(shifts)) - - return sdfg - - def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[str, Any]: - """ - Returns the closure arrays of the SDFG represented by this object - as a mapping between array name and the corresponding value. - - The connectivity tables are defined symbolically, i.e. table sizes & strides are DaCe symbols. - The need to define the connectivity tables in the `__sdfg_closure__` arises from the fact that - the offset providers are not part of GT4Py Program's arguments. - Keep in mind, that `__sdfg_closure__` is called after `__sdfg__` method. - """ - offset_provider = self.connectivities - - # Define DaCe symbols - connectivity_table_size_symbols = { - dace_utils.field_size_symbol_name( - dace_utils.connectivity_identifier(k), axis - ): dace.symbol( - dace_utils.field_size_symbol_name(dace_utils.connectivity_identifier(k), axis) - ) - for k, v in offset_provider.items() # type: ignore[union-attr] - for axis in [0, 1] - if hasattr(v, "table") - and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] - } - - connectivity_table_stride_symbols = { - dace_utils.field_stride_symbol_name( - dace_utils.connectivity_identifier(k), axis - ): dace.symbol( - dace_utils.field_stride_symbol_name(dace_utils.connectivity_identifier(k), axis) - ) - for k, v in offset_provider.items() # type: ignore[union-attr] - for axis in [0, 1] - if hasattr(v, "table") - and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] - } - - symbols = {**connectivity_table_size_symbols, **connectivity_table_stride_symbols} - - # Define the storage location (e.g. 
CPU, GPU) of the connectivity tables - if "storage" not in Program.connectivity_tables_data_descriptors: - for k, v in offset_provider.items(): # type: ignore[union-attr] - if not hasattr(v, "table"): - continue - if dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"]: - Program.connectivity_tables_data_descriptors["storage"] = ( - self.sdfg_closure_vars[ - "sdfg.arrays" - ][dace_utils.connectivity_identifier(k)].storage - ) - break - - # Build the closure dictionary - closure_dict = {} - for k, v in offset_provider.items(): # type: ignore[union-attr] - conn_id = dace_utils.connectivity_identifier(k) - if hasattr(v, "table") and conn_id in self.sdfg_closure_vars["sdfg.arrays"]: - if conn_id not in Program.connectivity_tables_data_descriptors: - Program.connectivity_tables_data_descriptors[conn_id] = dace.data.Array( - dtype=dace.int64 if v.index_type == np.int64 else dace.int32, - shape=[ - symbols[dace_utils.field_size_symbol_name(conn_id, 0)], - symbols[dace_utils.field_size_symbol_name(conn_id, 1)], - ], - strides=[ - symbols[dace_utils.field_stride_symbol_name(conn_id, 0)], - symbols[dace_utils.field_stride_symbol_name(conn_id, 1)], - ], - storage=Program.connectivity_tables_data_descriptors["storage"], - ) - closure_dict[conn_id] = Program.connectivity_tables_data_descriptors[conn_id] - - return closure_dict - - def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]: - args = [] - for arg in self.past_stage.past_node.params: - args.append(arg.id) - return (args, []) - - -def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: list[Any]) -> bool: - for dace_parsed_arg, gt4py_program_arg in zip(dace_parsed_args, gt4py_program_args): - if isinstance(dace_parsed_arg, dace.data.Scalar): - assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg) - elif isinstance( - dace_parsed_arg, (bool, int, float, str, np.bool_, np.integer, np.floating, np.str_) - ): # compile-time constant scalar - 
assert isinstance(gt4py_program_arg, ts.ScalarType) - if isinstance(dace_parsed_arg, (bool, np.bool_)): - assert gt4py_program_arg.kind == ts.ScalarKind.BOOL - elif isinstance(dace_parsed_arg, (int, np.integer)): - assert gt4py_program_arg.kind in [ts.ScalarKind.INT32, ts.ScalarKind.INT64] - elif isinstance(dace_parsed_arg, (float, np.floating)): - assert gt4py_program_arg.kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64] - elif isinstance(dace_parsed_arg, (str, np.str_)): - assert gt4py_program_arg.kind == ts.ScalarKind.STRING - elif isinstance(dace_parsed_arg, dace.data.Array): - assert isinstance(gt4py_program_arg, ts.FieldType) - assert len(dace_parsed_arg.shape) == len(gt4py_program_arg.dims) - assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg.dtype) - elif isinstance( - dace_parsed_arg, (dace.data.Structure, dict, OrderedDict) - ): # offset_provider - continue - else: - raise ValueError(f"Unresolved case for {dace_parsed_arg} (==, !=) {gt4py_program_arg}") - - return True diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py deleted file mode 100644 index d52fbc5857..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ /dev/null @@ -1,798 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -import warnings -from typing import Any, Mapping, Optional, Sequence, cast - -import dace -from dace.sdfg.state import LoopRegion - -import gt4py.eve as eve -from gt4py.next import Dimension, DimensionKind -from gt4py.next.common import Connectivity -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.type_system import type_info, type_specifications as ts, type_translation as tt - -from .itir_to_tasklet import ( - Context, - GatherOutputSymbolsPass, - PythonTaskletCodegen, - SymbolExpr, - TaskletExpr, - ValueExpr, - closure_to_tasklet_sdfg, - is_scan, -) -from .utility import ( - add_mapped_nested_sdfg, - flatten_list, - get_used_connectivities, - map_nested_sdfg_symbols, - new_array_symbols, - unique_var_name, -) - - -def _get_scan_args(stencil: Expr) -> tuple[bool, Literal]: - """ - Parse stencil expression to extract the scan arguments. - - Returns - ------- - tuple(is_forward, init_carry) - The output tuple fields verify the following semantics: - - is_forward: forward boolean flag - - init_carry: carry initial value - """ - stencil_fobj = cast(FunCall, stencil) - is_forward = stencil_fobj.args[1] - assert isinstance(is_forward, Literal) and type_info.is_logical(is_forward.type) - init_carry = stencil_fobj.args[2] - assert isinstance(init_carry, Literal) - return is_forward.value == "True", init_carry - - -def _get_scan_dim( - column_axis: Dimension, - storage_types: dict[str, ts.TypeSpec], - output: SymRef, - use_field_canonical_representation: bool, -) -> tuple[str, int, ts.ScalarType]: - """ - Extract information about the scan dimension. 
- - Returns - ------- - tuple(scan_dim_name, scan_dim_index, scan_dim_dtype) - The output tuple fields verify the following semantics: - - scan_dim_name: name of the scan dimension - - scan_dim_index: domain index of the scan dimension - - scan_dim_dtype: data type along the scan dimension - """ - output_type = storage_types[output.id] - assert isinstance(output_type, ts.FieldType) - sorted_dims = [ - dim - for _, dim in ( - dace_utils.get_sorted_dims(output_type.dims) - if use_field_canonical_representation - else enumerate(output_type.dims) - ) - ] - return (column_axis.value, sorted_dims.index(column_axis), output_type.dtype) - - -def _make_array_shape_and_strides( - name: str, dims: Sequence[Dimension], offset_provider: Mapping[str, Any], sort_dims: bool -) -> tuple[list[dace.symbol], list[dace.symbol]]: - """ - Parse field dimensions and allocate symbols for array shape and strides. - - For local dimensions, the size is known at compile-time and therefore - the corresponding array shape dimension is set to an integer literal value. - - Returns - ------- - tuple(shape, strides) - The output tuple fields are arrays of dace symbolic expressions. - """ - dtype = dace.int32 - sorted_dims = dace_utils.get_sorted_dims(dims) if sort_dims else list(enumerate(dims)) - neighbor_tables = dace_utils.filter_connectivities(offset_provider) - shape = [ - ( - neighbor_tables[dim.value].max_neighbors - if dim.kind == DimensionKind.LOCAL - # we reuse the same gt4py symbol for field size passed as scalar argument which is used in closure domain - else dace.symbol(dace_utils.field_size_symbol_name(name, i), dtype) - ) - for i, dim in sorted_dims - ] - strides = [ - dace.symbol(dace_utils.field_stride_symbol_name(name, i), dtype) for i, _ in sorted_dims - ] - return shape, strides - - -def _check_no_lifts(node: itir.StencilClosure): - """ - Parse stencil closure ITIR to check that lift expressions only appear as child nodes in neighbor reductions. 
- - Returns - ------- - True if lifts do not appear in the ITIR exception lift expressions in neighbor reductions. False otherwise. - """ - neighbors_call_count = 0 - for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun"): - if getattr(fun, "id", "") == "neighbors": - neighbors_call_count = 3 - elif getattr(fun, "id", "") == "lift" and neighbors_call_count != 1: - return False - neighbors_call_count = max(0, neighbors_call_count - 1) - return True - - -class ItirToSDFG(eve.NodeVisitor): - param_types: list[ts.TypeSpec] - storage_types: dict[str, ts.TypeSpec] - column_axis: Optional[Dimension] - offset_provider: dict[str, Any] - unique_id: int - use_field_canonical_representation: bool - - def __init__( - self, - param_types: list[ts.TypeSpec], - offset_provider: dict[str, Connectivity | Dimension], - tmps: list[itir.Temporary], - use_field_canonical_representation: bool, - column_axis: Optional[Dimension] = None, - ): - self.param_types = param_types - self.column_axis = column_axis - self.offset_provider = offset_provider - self.storage_types = {} - self.tmps = tmps - self.use_field_canonical_representation = use_field_canonical_representation - - def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, sort_dimensions: bool): - if isinstance(type_, ts.FieldType): - shape, strides = _make_array_shape_and_strides( - name, type_.dims, self.offset_provider, sort_dimensions - ) - dtype = dace_utils.as_dace_type(type_.dtype) - sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) - - elif isinstance(type_, ts.ScalarType): - dtype = dace_utils.as_dace_type(type_) - if name in sdfg.symbols: - assert sdfg.symbols[name].dtype == dtype - else: - sdfg.add_symbol(name, dtype) - - else: - raise NotImplementedError() - self.storage_types[name] = type_ - - def add_storage_for_temporaries( - self, node_params: list[Sym], defs_state: dace.SDFGState, program_sdfg: dace.SDFG - ) -> dict[str, str]: - symbol_map: dict[str, TaskletExpr] = {} 
- # The shape of temporary arrays might be defined based on scalar values passed as program arguments. - # Here we collect these values in a symbol map. - for sym in node_params: - if isinstance(sym.type, ts.ScalarType): - name_ = str(sym.id) - symbol_map[name_] = SymbolExpr(name_, dace_utils.as_dace_type(sym.type)) - - tmp_symbols: dict[str, str] = {} - for tmp in self.tmps: - tmp_name = str(tmp.id) - - # We visit the domain of the temporary field, passing the set of available symbols. - assert isinstance(tmp.domain, itir.FunCall) - domain_ctx = Context(program_sdfg, defs_state, symbol_map) - tmp_domain = self._visit_domain(tmp.domain, domain_ctx) - - if isinstance(tmp.type, ts.TupleType): - raise NotImplementedError("Temporaries of tuples are not supported.") - assert isinstance(tmp.type, ts.FieldType) and isinstance(tmp.dtype, ts.ScalarType) - - # We store the FieldType for this temporary array. - self.storage_types[tmp_name] = tmp.type - - # N.B.: skip generation of symbolic strides and just let dace assign default strides, for now. - # Another option, in the future, is to use symbolic strides and apply auto-tuning or some heuristics - # to assign optimal stride values. - tmp_shape, _ = new_array_symbols(tmp_name, len(tmp.type.dims)) - _, tmp_array = program_sdfg.add_array( - tmp_name, tmp_shape, dace_utils.as_dace_type(tmp.dtype), transient=True - ) - - # Loop through all dimensions to visit the symbolic expressions for array shape and offset. - # These expressions are later mapped to interstate symbols. - for (_, (begin, end)), shape_sym in zip(tmp_domain, tmp_array.shape): - # The temporary field has a dimension range defined by `begin` and `end` values. - # Therefore, the actual size is given by the difference `end.value - begin.value`. - # Instead of allocating the actual size, we allocate space to enable indexing from 0 - # because we want to avoid using dace array offsets (which will be deprecated soon). 
- # The result should still be valid, but the stencil will be using only a subset - # of the array. - if not (isinstance(begin, SymbolExpr) and begin.value == "0"): - warnings.warn( - f"Domain start offset for temporary {tmp_name} is ignored.", stacklevel=2 - ) - tmp_symbols[str(shape_sym)] = end.value - - return tmp_symbols - - def create_memlet_at(self, field_name: str, index: dict[str, str]): - field_type = self.storage_types[field_name] - assert isinstance(field_type, ts.FieldType) - if self.use_field_canonical_representation: - field_index = [ - index[dim.value] for _, dim in dace_utils.get_sorted_dims(field_type.dims) - ] - else: - field_index = [index[dim.value] for dim in field_type.dims] - subset = ", ".join(field_index) - return dace.Memlet(data=field_name, subset=subset) - - def get_output_nodes( - self, closure: itir.StencilClosure, sdfg: dace.SDFG, state: dace.SDFGState - ) -> dict[str, dace.nodes.AccessNode]: - # Visit output node, which could be a `make_tuple` expression, to collect the required access nodes - output_symbols_pass = GatherOutputSymbolsPass(sdfg, state) - output_symbols_pass.visit(closure.output) - # Visit output node again to generate the corresponding tasklet - context = Context(sdfg, state, output_symbols_pass.symbol_refs) - translator = PythonTaskletCodegen( - self.offset_provider, context, self.use_field_canonical_representation - ) - output_nodes = flatten_list(translator.visit(closure.output)) - return {node.value.data: node.value for node in output_nodes} - - def visit_FencilDefinition(self, node: itir.FencilDefinition): - program_sdfg = dace.SDFG(name=node.id) - program_sdfg.debuginfo = dace_utils.debug_info(node) - entry_state = program_sdfg.add_state("program_entry", is_start_block=True) - - # Filter neighbor tables from offset providers. - neighbor_tables = get_used_connectivities(node, self.offset_provider) - - # Add program parameters as SDFG storages. 
- for param, type_ in zip(node.params, self.param_types): - self.add_storage( - program_sdfg, str(param.id), type_, self.use_field_canonical_representation - ) - - if self.tmps: - tmp_symbols = self.add_storage_for_temporaries(node.params, entry_state, program_sdfg) - # on the first interstate edge define symbols for shape and offsets of temporary arrays - last_state = program_sdfg.add_state("init_symbols_for_temporaries") - program_sdfg.add_edge( - entry_state, last_state, dace.InterstateEdge(assignments=tmp_symbols) - ) - else: - last_state = entry_state - - # Add connectivities as SDFG storages. - for offset, offset_provider in neighbor_tables.items(): - scalar_kind = tt.get_scalar_kind(offset_provider.index_type) - local_dim = Dimension(offset, kind=DimensionKind.LOCAL) - type_ = ts.FieldType( - [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind) - ) - self.add_storage( - program_sdfg, - dace_utils.connectivity_identifier(offset), - type_, - sort_dimensions=False, - ) - - # Create a nested SDFG for all stencil closures. - for closure in node.closures: - # Translate the closure and its stencil's body to an SDFG. - closure_sdfg, input_names, output_names = self.visit( - closure, array_table=program_sdfg.arrays - ) - - # Create a new state for the closure. - last_state = program_sdfg.add_state_after(last_state) - - # Create memlets to transfer the program parameters - input_mapping = { - name: dace.Memlet.from_array(name, program_sdfg.arrays[name]) - for name in input_names - } - output_mapping = { - name: dace.Memlet.from_array(name, program_sdfg.arrays[name]) - for name in output_names - } - - symbol_mapping = map_nested_sdfg_symbols(program_sdfg, closure_sdfg, input_mapping) - - # Insert the closure's SDFG as a nested SDFG of the program. 
- nsdfg_node = last_state.add_nested_sdfg( - sdfg=closure_sdfg, - parent=program_sdfg, - inputs=set(input_names), - outputs=set(output_names), - symbol_mapping=symbol_mapping, - debuginfo=closure_sdfg.debuginfo, - ) - - # Add access nodes for the program parameters and connect them to the nested SDFG's inputs via edges. - for inner_name, memlet in input_mapping.items(): - access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo) - last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet) - - for inner_name, memlet in output_mapping.items(): - access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo) - last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) - - # Create the call signature for the SDFG. - # Only the arguments requiered by the Fencil, i.e. `node.params` are added as positional arguments. - # The implicit arguments, such as the offset providers or the arguments created by the translation process, must be passed as keywords only arguments. - program_sdfg.arg_names = [str(a) for a in node.params] - - program_sdfg.validate() - return program_sdfg - - def visit_StencilClosure( - self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array] - ) -> tuple[dace.SDFG, list[str], list[str]]: - assert _check_no_lifts(node) - - # Create the closure's nested SDFG and single state. 
- closure_sdfg = dace.SDFG(name="closure") - closure_sdfg.debuginfo = dace_utils.debug_info(node) - closure_state = closure_sdfg.add_state("closure_entry") - closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True) - - input_names = [str(inp.id) for inp in node.inputs] - neighbor_tables = get_used_connectivities(node, self.offset_provider) - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - output_nodes = self.get_output_nodes(node, closure_sdfg, closure_state) - output_names = [k for k, _ in output_nodes.items()] - - # Add DaCe arrays for inputs, outputs and connectivities to closure SDFG. - input_transients_mapping = {} - for name in [*input_names, *connectivity_names, *output_names]: - if name in closure_sdfg.arrays: - assert name in input_names and name in output_names - # In case of closures with in/out fields, there is risk of race condition - # between read/write access nodes in the (asynchronous) map tasklet. 
- transient_name = unique_var_name() - closure_sdfg.add_array( - transient_name, - shape=array_table[name].shape, - strides=array_table[name].strides, - dtype=array_table[name].dtype, - transient=True, - ) - closure_init_state.add_nedge( - closure_init_state.add_access(name, debuginfo=closure_sdfg.debuginfo), - closure_init_state.add_access(transient_name, debuginfo=closure_sdfg.debuginfo), - dace.Memlet.from_array(name, closure_sdfg.arrays[name]), - ) - input_transients_mapping[name] = transient_name - elif isinstance(self.storage_types[name], ts.FieldType): - closure_sdfg.add_array( - name, - shape=array_table[name].shape, - strides=array_table[name].strides, - dtype=array_table[name].dtype, - ) - else: - assert isinstance(self.storage_types[name], ts.ScalarType) - - input_field_names = [ - input_name - for input_name in input_names - if isinstance(self.storage_types[input_name], ts.FieldType) - ] - - # Closure outputs should all be fields - assert all( - isinstance(self.storage_types[output_name], ts.FieldType) - for output_name in output_names - ) - - # Update symbol table and get output domain of the closure - program_arg_syms: dict[str, TaskletExpr] = {} - for name, type_ in self.storage_types.items(): - if isinstance(type_, ts.ScalarType): - dtype = dace_utils.as_dace_type(type_) - if name in input_names: - out_name = unique_var_name() - closure_sdfg.add_scalar(out_name, dtype, transient=True) - out_tasklet = closure_init_state.add_tasklet( - f"get_{name}", - {}, - {"__result"}, - f"__result = {name}", - debuginfo=closure_sdfg.debuginfo, - ) - access = closure_init_state.add_access( - out_name, debuginfo=closure_sdfg.debuginfo - ) - value = ValueExpr(access, dtype) - memlet = dace.Memlet(data=out_name, subset="0") - closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) - program_arg_syms[name] = value - else: - program_arg_syms[name] = SymbolExpr(name, dtype) - else: - assert isinstance(type_, ts.FieldType) - # make shape symbols 
(corresponding to field size) available as arguments to domain visitor - if name in input_names or name in output_names: - field_symbols = [ - val - for val in closure_sdfg.arrays[name].shape - if isinstance(val, dace.symbol) and str(val) not in input_names - ] - for sym in field_symbols: - sym_name = str(sym) - program_arg_syms[sym_name] = SymbolExpr(sym, sym.dtype) - closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms) - closure_domain = self._visit_domain(node.domain, closure_ctx) - - # Map SDFG tasklet arguments to parameters - input_local_names = [ - ( - input_transients_mapping[input_name] - if input_name in input_transients_mapping - else ( - input_name - if input_name in input_field_names - else cast(ValueExpr, program_arg_syms[input_name]).value.data - ) - ) - for input_name in input_names - ] - input_memlets = [ - dace.Memlet.from_array(name, closure_sdfg.arrays[name]) - for name in [*input_local_names, *connectivity_names] - ] - - # create and write to transient that is then copied back to actual output array to avoid aliasing of - # same memory in nested SDFG with different names - output_connectors_mapping = {unique_var_name(): output_name for output_name in output_names} - # scan operator should always be the first function call in a closure - if is_scan(node.stencil): - assert len(output_connectors_mapping) == 1, "Scan does not support multiple outputs" - transient_name, output_name = next(iter(output_connectors_mapping.items())) - - nsdfg, map_ranges, scan_dim_index = self._visit_scan_stencil_closure( - node, closure_sdfg.arrays, closure_domain, transient_name - ) - results = [transient_name] - - _, (scan_lb, scan_ub) = closure_domain[scan_dim_index] - output_subset = f"{scan_lb.value}:{scan_ub.value}" - - domain_subset = { - dim: ( - f"i_{dim}" - if f"i_{dim}" in map_ranges - else f"0:{closure_sdfg.arrays[output_name].shape[scan_dim_index]}" - ) - for dim, _ in closure_domain - } - output_memlets = 
[self.create_memlet_at(output_name, domain_subset)] - else: - nsdfg, map_ranges, results = self._visit_parallel_stencil_closure( - node, closure_sdfg.arrays, closure_domain - ) - - output_subset = "0" - - output_memlets = [ - self.create_memlet_at(output_name, {dim: f"i_{dim}" for dim, _ in closure_domain}) - for output_name in output_connectors_mapping.values() - ] - - input_mapping = { - param: arg for param, arg in zip([*input_names, *connectivity_names], input_memlets) - } - output_mapping = {param: memlet for param, memlet in zip(results, output_memlets)} - - symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, input_mapping) - - nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg( - closure_state, - sdfg=nsdfg, - map_ranges=map_ranges or {"__dummy": "0"}, - inputs=input_mapping, - outputs=output_mapping, - symbol_mapping=symbol_mapping, - output_nodes=output_nodes, - debuginfo=nsdfg.debuginfo, - ) - access_nodes = {edge.data.data: edge.dst for edge in closure_state.out_edges(map_exit)} - for edge in closure_state.in_edges(map_exit): - memlet = edge.data - if memlet.data not in output_connectors_mapping: - continue - transient_access = closure_state.add_access(memlet.data, debuginfo=nsdfg.debuginfo) - closure_state.add_edge( - nsdfg_node, - edge.src_conn, - transient_access, - None, - dace.Memlet(data=memlet.data, subset=output_subset, debuginfo=nsdfg.debuginfo), - ) - inner_memlet = dace.Memlet( - data=memlet.data, subset=output_subset, other_subset=memlet.subset - ) - closure_state.add_edge(transient_access, None, map_exit, edge.dst_conn, inner_memlet) - closure_state.remove_edge(edge) - access_nodes[memlet.data].data = output_connectors_mapping[memlet.data] - - return closure_sdfg, input_field_names + connectivity_names, output_names - - def _visit_scan_stencil_closure( - self, - node: itir.StencilClosure, - array_table: dict[str, dace.data.Array], - closure_domain: tuple[ - tuple[str, tuple[ValueExpr | SymbolExpr, ValueExpr | SymbolExpr]], 
... - ], - output_name: str, - ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], int]: - # extract scan arguments - is_forward, init_carry_value = _get_scan_args(node.stencil) - # select the scan dimension based on program argument for column axis - assert self.column_axis - assert isinstance(node.output, SymRef) - scan_dim, scan_dim_index, scan_dtype = _get_scan_dim( - self.column_axis, - self.storage_types, - node.output, - self.use_field_canonical_representation, - ) - - assert isinstance(node.output, SymRef) - neighbor_tables = get_used_connectivities(node, self.offset_provider) - input_names = [str(inp.id) for inp in node.inputs] - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - # find the scan dimension, same as output dimension, and exclude it from the map domain - map_ranges = {} - for dim, (lb, ub) in closure_domain: - lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value - ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value - if not dim == scan_dim: - map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" - else: - scan_lb_str = lb_str - scan_ub_str = ub_str - - # the scan operator is implemented as an SDFG to be nested in the closure SDFG - scan_sdfg = dace.SDFG(name="scan") - scan_sdfg.debuginfo = dace_utils.debug_info(node) - - # the carry value of the scan operator exists only in the scope of the scan sdfg - scan_carry_name = unique_var_name() - scan_sdfg.add_scalar( - scan_carry_name, dtype=dace_utils.as_dace_type(scan_dtype), transient=True - ) - - # create a loop region for lambda call over the scan dimension - scan_loop_var = f"i_{scan_dim}" - if is_forward: - scan_loop = LoopRegion( - label="scan", - condition_expr=f"{scan_loop_var} < {scan_ub_str}", - loop_var=scan_loop_var, - initialize_expr=f"{scan_loop_var} = {scan_lb_str}", - update_expr=f"{scan_loop_var} = {scan_loop_var} + 1", - inverted=False, - ) - else: - scan_loop = LoopRegion( - label="scan", - 
condition_expr=f"{scan_loop_var} >= {scan_lb_str}", - loop_var=scan_loop_var, - initialize_expr=f"{scan_loop_var} = {scan_ub_str} - 1", - update_expr=f"{scan_loop_var} = {scan_loop_var} - 1", - inverted=False, - ) - scan_sdfg.add_node(scan_loop) - compute_state = scan_loop.add_state("lambda_compute", is_start_block=True) - update_state = scan_loop.add_state("lambda_update") - scan_loop.add_edge(compute_state, update_state, dace.InterstateEdge()) - - start_state = scan_sdfg.add_state("start", is_start_block=True) - scan_sdfg.add_edge(start_state, scan_loop, dace.InterstateEdge()) - - # tasklet for initialization of carry - carry_init_tasklet = start_state.add_tasklet( - "get_carry_init_value", - {}, - {"__result"}, - f"__result = {init_carry_value}", - debuginfo=scan_sdfg.debuginfo, - ) - start_state.add_edge( - carry_init_tasklet, - "__result", - start_state.add_access(scan_carry_name, debuginfo=scan_sdfg.debuginfo), - None, - dace.Memlet(data=scan_carry_name, subset="0"), - ) - - # add storage to scan SDFG for inputs - for name in [*input_names, *connectivity_names]: - assert name not in scan_sdfg.arrays - if isinstance(self.storage_types[name], ts.FieldType): - scan_sdfg.add_array( - name, - shape=array_table[name].shape, - strides=array_table[name].strides, - dtype=array_table[name].dtype, - ) - else: - scan_sdfg.add_scalar( - name, - dtype=dace_utils.as_dace_type(cast(ts.ScalarType, self.storage_types[name])), - ) - # add storage to scan SDFG for output - scan_sdfg.add_array( - output_name, - shape=(array_table[node.output.id].shape[scan_dim_index],), - strides=(array_table[node.output.id].strides[scan_dim_index],), - dtype=array_table[node.output.id].dtype, - ) - - # implement the lambda function as a nested SDFG that computes a single item in the scan dimension - lambda_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} - input_arrays = [(scan_carry_name, scan_dtype)] + [ - (name, self.storage_types[name]) for name in input_names - ] - 
connectivity_arrays = [(scan_sdfg.arrays[name], name) for name in connectivity_names] - lambda_context, lambda_outputs = closure_to_tasklet_sdfg( - node, - self.offset_provider, - lambda_domain, - input_arrays, - connectivity_arrays, - self.use_field_canonical_representation, - ) - - lambda_input_names = [name for name, _ in input_arrays] - lambda_output_names = [connector.value.data for connector in lambda_outputs] - - input_memlets = [ - dace.Memlet.from_array(name, scan_sdfg.arrays[name]) for name in lambda_input_names - ] - connectivity_memlets = [ - dace.Memlet.from_array(name, scan_sdfg.arrays[name]) for name in connectivity_names - ] - input_mapping = {param: arg for param, arg in zip(lambda_input_names, input_memlets)} - connectivity_mapping = { - param: arg for param, arg in zip(connectivity_names, connectivity_memlets) - } - array_mapping = {**input_mapping, **connectivity_mapping} - symbol_mapping = map_nested_sdfg_symbols(scan_sdfg, lambda_context.body, array_mapping) - - scan_inner_node = compute_state.add_nested_sdfg( - lambda_context.body, - parent=scan_sdfg, - inputs=set(lambda_input_names) | set(connectivity_names), - outputs=set(lambda_output_names), - symbol_mapping=symbol_mapping, - debuginfo=lambda_context.body.debuginfo, - ) - - # connect scan SDFG to lambda inputs - for name, memlet in array_mapping.items(): - access_node = compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo) - compute_state.add_edge(access_node, None, scan_inner_node, name, memlet) - - output_names = [output_name] - assert len(lambda_output_names) == 1 - # connect lambda output to scan SDFG - for name, connector in zip(output_names, lambda_output_names): - compute_state.add_edge( - scan_inner_node, - connector, - compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo), - None, - dace.Memlet(data=name, subset=scan_loop_var), - ) - - update_state.add_nedge( - update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo), - 
update_state.add_access(scan_carry_name, debuginfo=lambda_context.body.debuginfo), - dace.Memlet(data=output_name, subset=scan_loop_var, other_subset="0"), - ) - - return scan_sdfg, map_ranges, scan_dim_index - - def _visit_parallel_stencil_closure( - self, - node: itir.StencilClosure, - array_table: dict[str, dace.data.Array], - closure_domain: tuple[ - tuple[str, tuple[ValueExpr | SymbolExpr, ValueExpr | SymbolExpr]], ... - ], - ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]: - neighbor_tables = get_used_connectivities(node, self.offset_provider) - input_names = [str(inp.id) for inp in node.inputs] - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - # find the scan dimension, same as output dimension, and exclude it from the map domain - map_ranges = {} - for dim, (lb, ub) in closure_domain: - lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value - ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value - map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" - - # Create an SDFG for the tasklet that computes a single item of the output domain. 
- index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} - - input_arrays = [(name, self.storage_types[name]) for name in input_names] - connectivity_arrays = [(array_table[name], name) for name in connectivity_names] - - context, results = closure_to_tasklet_sdfg( - node, - self.offset_provider, - index_domain, - input_arrays, - connectivity_arrays, - self.use_field_canonical_representation, - ) - - return context.body, map_ranges, [r.value.data for r in results] - - def _visit_domain( - self, node: itir.FunCall, context: Context - ) -> tuple[tuple[str, tuple[SymbolExpr | ValueExpr, SymbolExpr | ValueExpr]], ...]: - assert isinstance(node.fun, itir.SymRef) - assert node.fun.id == "cartesian_domain" or node.fun.id == "unstructured_domain" - - bounds: list[tuple[str, tuple[ValueExpr, ValueExpr]]] = [] - - for named_range in node.args: - assert isinstance(named_range, itir.FunCall) - assert isinstance(named_range.fun, itir.SymRef) - assert len(named_range.args) == 3 - dimension = named_range.args[0] - assert isinstance(dimension, itir.AxisLiteral) - lower_bound = named_range.args[1] - upper_bound = named_range.args[2] - translator = PythonTaskletCodegen( - self.offset_provider, - context, - self.use_field_canonical_representation, - ) - lb = translator.visit(lower_bound)[0] - ub = translator.visit(upper_bound)[0] - bounds.append((dimension.value, (lb, ub))) - - return tuple(bounds) - - @staticmethod - def _check_shift_offsets_are_literals(node: itir.StencilClosure): - fun_calls = eve.walk_values(node).if_isinstance(itir.FunCall) - shifts = [nd for nd in fun_calls if getattr(nd.fun, "id", "") == "shift"] - for shift in shifts: - if not all(isinstance(arg, (itir.Literal, itir.OffsetLiteral)) for arg in shift.args): - return False - return True diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py deleted file mode 100644 index 991053b4a5..0000000000 --- 
a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ /dev/null @@ -1,1567 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import copy -import dataclasses -import itertools -from collections.abc import Sequence -from typing import Any, Callable, Optional, TypeAlias, cast - -import dace -import numpy as np - -import gt4py.eve.codegen -from gt4py import eve -from gt4py.next import Dimension -from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value, Connectivity -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir import FunCall, Lambda -from gt4py.next.iterator.type_system import type_specifications as it_ts -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.type_system import type_specifications as ts - -from .utility import ( - add_mapped_nested_sdfg, - flatten_list, - get_used_connectivities, - map_nested_sdfg_symbols, - new_array_symbols, - unique_name, - unique_var_name, -) - - -_TYPE_MAPPING = { - "float": dace.float64, - "float32": dace.float32, - "float64": dace.float64, - "int": dace.int32 if np.dtype(int).itemsize == 4 else dace.int64, - "int32": dace.int32, - "int64": dace.int64, - "bool": dace.bool_, -} - - -def itir_type_as_dace_type(type_: ts.TypeSpec): - # TODO(tehrengruber): this function just converts the scalar type of whatever it is given, - # let it be a field, iterator, or directly a scalar. The caller should take care of the - # extraction. 
- dtype: ts.TypeSpec - if isinstance(type_, ts.FieldType): - dtype = type_.dtype - elif isinstance(type_, it_ts.IteratorType): - dtype = type_.element_type - else: - dtype = type_ - assert isinstance(dtype, ts.ScalarType) - return _TYPE_MAPPING[dtype.kind.name.lower()] - - -def get_reduce_identity_value(op_name_: str, type_: Any): - if op_name_ == "plus": - init_value = type_(0) - elif op_name_ == "multiplies": - init_value = type_(1) - elif op_name_ == "minimum": - init_value = type_("inf") - elif op_name_ == "maximum": - init_value = type_("-inf") - else: - raise NotImplementedError() - - return init_value - - -_MATH_BUILTINS_MAPPING = { - "abs": "abs({})", - "sin": "math.sin({})", - "cos": "math.cos({})", - "tan": "math.tan({})", - "arcsin": "asin({})", - "arccos": "acos({})", - "arctan": "atan({})", - "sinh": "math.sinh({})", - "cosh": "math.cosh({})", - "tanh": "math.tanh({})", - "arcsinh": "asinh({})", - "arccosh": "acosh({})", - "arctanh": "atanh({})", - "sqrt": "math.sqrt({})", - "exp": "math.exp({})", - "log": "math.log({})", - "gamma": "tgamma({})", - "cbrt": "cbrt({})", - "isfinite": "isfinite({})", - "isinf": "isinf({})", - "isnan": "isnan({})", - "floor": "math.ifloor({})", - "ceil": "ceil({})", - "trunc": "trunc({})", - "minimum": "min({}, {})", - "maximum": "max({}, {})", - "fmod": "fmod({}, {})", - "power": "math.pow({}, {})", - "float": "dace.float64({})", - "float32": "dace.float32({})", - "float64": "dace.float64({})", - "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", - "int32": "dace.int32({})", - "int64": "dace.int64({})", - "bool": "dace.bool_({})", - "plus": "({} + {})", - "minus": "({} - {})", - "multiplies": "({} * {})", - "divides": "({} / {})", - "floordiv": "({} // {})", - "eq": "({} == {})", - "not_eq": "({} != {})", - "less": "({} < {})", - "less_equal": "({} <= {})", - "greater": "({} > {})", - "greater_equal": "({} >= {})", - "and_": "({} & {})", - "or_": "({} | {})", - "xor_": "({} ^ {})", - "mod": "({} 
% {})", - "not_": "(not {})", # ~ is not bitwise in numpy -} - - -# Define type of variables used for field indexing -_INDEX_DTYPE = _TYPE_MAPPING["int64"] - - -@dataclasses.dataclass -class SymbolExpr: - value: dace.symbolic.SymbolicType - dtype: dace.typeclass - - -@dataclasses.dataclass -class ValueExpr: - value: dace.nodes.AccessNode - dtype: dace.typeclass - - -@dataclasses.dataclass -class IteratorExpr: - field: dace.nodes.AccessNode - indices: dict[str, dace.nodes.AccessNode] - dtype: dace.typeclass - dimensions: list[str] - - -# Union of possible expression types -TaskletExpr: TypeAlias = IteratorExpr | SymbolExpr | ValueExpr - - -@dataclasses.dataclass -class Context: - body: dace.SDFG - state: dace.SDFGState - symbol_map: dict[str, TaskletExpr] - # if we encounter a reduction node, the reduction state needs to be pushed to child nodes - reduce_identity: Optional[SymbolExpr] - - def __init__( - self, - body: dace.SDFG, - state: dace.SDFGState, - symbol_map: dict[str, TaskletExpr], - reduce_identity: Optional[SymbolExpr] = None, - ): - self.body = body - self.state = state - self.symbol_map = symbol_map - self.reduce_identity = reduce_identity - - -def _visit_lift_in_neighbors_reduction( - transformer: PythonTaskletCodegen, - node: itir.FunCall, - node_args: Sequence[IteratorExpr | list[ValueExpr]], - offset_provider: Connectivity, - map_entry: dace.nodes.MapEntry, - map_exit: dace.nodes.MapExit, - neighbor_index_node: dace.nodes.AccessNode, - neighbor_value_node: dace.nodes.AccessNode, -) -> list[ValueExpr]: - assert transformer.context.reduce_identity is not None - neighbor_dim = offset_provider.neighbor_axis.value - origin_dim = offset_provider.origin_axis.value - - lifted_args: list[IteratorExpr | ValueExpr] = [] - for arg in node_args: - if isinstance(arg, IteratorExpr): - if origin_dim in arg.indices: - lifted_indices = arg.indices.copy() - lifted_indices.pop(origin_dim) - lifted_indices[neighbor_dim] = neighbor_index_node - lifted_args.append( - 
IteratorExpr(arg.field, lifted_indices, arg.dtype, arg.dimensions) - ) - else: - lifted_args.append(arg) - else: - lifted_args.append(arg[0]) - - lift_context, inner_inputs, inner_outputs = transformer.visit(node.args[0], args=lifted_args) - assert len(inner_outputs) == 1 - inner_out_connector = inner_outputs[0].value.data - - input_nodes = {} - iterator_index_nodes = {} - lifted_index_connectors = [] - - for x, y in inner_inputs: - if isinstance(y, IteratorExpr): - field_connector, inner_index_table = x - input_nodes[field_connector] = y.field - for dim, connector in inner_index_table.items(): - if dim == neighbor_dim: - lifted_index_connectors.append(connector) - iterator_index_nodes[connector] = y.indices[dim] - else: - assert isinstance(y, ValueExpr) - input_nodes[x] = y.value - - neighbor_tables = get_used_connectivities(node.args[0], transformer.offset_provider) - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - parent_sdfg = transformer.context.body - parent_state = transformer.context.state - - input_mapping = { - connector: dace.Memlet.from_array(node.data, node.desc(parent_sdfg)) - for connector, node in input_nodes.items() - } - connectivity_mapping = { - name: dace.Memlet.from_array(name, parent_sdfg.arrays[name]) for name in connectivity_names - } - array_mapping = {**input_mapping, **connectivity_mapping} - symbol_mapping = map_nested_sdfg_symbols(parent_sdfg, lift_context.body, array_mapping) - - nested_sdfg_node = parent_state.add_nested_sdfg( - lift_context.body, - parent_sdfg, - inputs={*array_mapping.keys(), *iterator_index_nodes.keys()}, - outputs={inner_out_connector}, - symbol_mapping=symbol_mapping, - debuginfo=lift_context.body.debuginfo, - ) - - for connectivity_connector, memlet in connectivity_mapping.items(): - parent_state.add_memlet_path( - parent_state.add_access(memlet.data, debuginfo=lift_context.body.debuginfo), - map_entry, - nested_sdfg_node, - 
dst_conn=connectivity_connector, - memlet=memlet, - ) - - for inner_connector, access_node in input_nodes.items(): - parent_state.add_memlet_path( - access_node, - map_entry, - nested_sdfg_node, - dst_conn=inner_connector, - memlet=input_mapping[inner_connector], - ) - - for inner_connector, access_node in iterator_index_nodes.items(): - memlet = dace.Memlet(data=access_node.data, subset="0") - if inner_connector in lifted_index_connectors: - parent_state.add_edge(access_node, None, nested_sdfg_node, inner_connector, memlet) - else: - parent_state.add_memlet_path( - access_node, map_entry, nested_sdfg_node, dst_conn=inner_connector, memlet=memlet - ) - - parent_state.add_memlet_path( - nested_sdfg_node, - map_exit, - neighbor_value_node, - src_conn=inner_out_connector, - memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)), - ) - - if offset_provider.has_skip_values: - # check neighbor validity on if/else inter-state edge - # use one branch for connectivity case - start_state = lift_context.body.add_state_before( - lift_context.body.start_state, - "start", - condition=f"{lifted_index_connectors[0]} != {neighbor_skip_value}", - ) - # use the other branch for skip value case - skip_neighbor_state = lift_context.body.add_state("skip_neighbor") - skip_neighbor_state.add_edge( - skip_neighbor_state.add_tasklet( - "identity", {}, {"val"}, f"val = {transformer.context.reduce_identity.value}" - ), - "val", - skip_neighbor_state.add_access(inner_outputs[0].value.data), - None, - dace.Memlet(data=inner_outputs[0].value.data, subset="0"), - ) - lift_context.body.add_edge( - start_state, - skip_neighbor_state, - dace.InterstateEdge(condition=f"{lifted_index_connectors[0]} == {neighbor_skip_value}"), - ) - - return [ValueExpr(neighbor_value_node, inner_outputs[0].dtype)] - - -def builtin_neighbors( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - sdfg: dace.SDFG = transformer.context.body - 
state: dace.SDFGState = transformer.context.state - - di = dace_utils.debug_info(node, default=sdfg.debuginfo) - offset_literal, data = node_args - assert isinstance(offset_literal, itir.OffsetLiteral) - offset_dim = offset_literal.value - assert isinstance(offset_dim, str) - offset_provider = transformer.offset_provider[offset_dim] - if not isinstance(offset_provider, Connectivity): - raise NotImplementedError( - "Neighbor reduction only implemented for connectivity based on neighbor tables." - ) - - lift_node = None - if isinstance(data, FunCall): - assert isinstance(data.fun, itir.FunCall) - fun_node = data.fun - if isinstance(fun_node.fun, itir.SymRef) and fun_node.fun.id == "lift": - lift_node = fun_node - lift_args = transformer.visit(data.args) - iterator = next(filter(lambda x: isinstance(x, IteratorExpr), lift_args), None) - if lift_node is None: - iterator = transformer.visit(data) - assert isinstance(iterator, IteratorExpr) - field_desc = iterator.field.desc(transformer.context.body) - origin_index_node = iterator.indices[offset_provider.origin_axis.value] - - assert transformer.context.reduce_identity is not None - assert transformer.context.reduce_identity.dtype == iterator.dtype - - # gather the neighbors in a result array dimensioned for `max_neighbors` - neighbor_value_var = unique_var_name() - sdfg.add_array( - neighbor_value_var, - dtype=iterator.dtype, - shape=(offset_provider.max_neighbors,), - transient=True, - ) - neighbor_value_node = state.add_access(neighbor_value_var, debuginfo=di) - - # allocate scalar to store index for direct addressing of neighbor field - neighbor_index_var = unique_var_name() - sdfg.add_scalar(neighbor_index_var, _INDEX_DTYPE, transient=True) - neighbor_index_node = state.add_access(neighbor_index_var, debuginfo=di) - - # generate unique map index name to avoid conflict with other maps inside same state - neighbor_map_index = unique_name(f"{offset_dim}_neighbor_map_idx") - me, mx = state.add_map( - 
f"{offset_dim}_neighbor_map", - ndrange={neighbor_map_index: f"0:{offset_provider.max_neighbors}"}, - debuginfo=di, - ) - - table_name = dace_utils.connectivity_identifier(offset_dim) - shift_tasklet = state.add_tasklet( - "shift", - code=f"__result = __table[__idx, {neighbor_map_index}]", - inputs={"__table", "__idx"}, - outputs={"__result"}, - debuginfo=di, - ) - state.add_memlet_path( - state.add_access(table_name, debuginfo=di), - me, - shift_tasklet, - memlet=dace.Memlet.from_array(table_name, sdfg.arrays[table_name]), - dst_conn="__table", - ) - state.add_memlet_path( - origin_index_node, - me, - shift_tasklet, - memlet=dace.Memlet(data=origin_index_node.data, subset="0"), - dst_conn="__idx", - ) - state.add_edge( - shift_tasklet, - "__result", - neighbor_index_node, - None, - dace.Memlet(data=neighbor_index_var, subset="0"), - ) - - if lift_node is not None: - _visit_lift_in_neighbors_reduction( - transformer, - lift_node, - lift_args, - offset_provider, - me, - mx, - neighbor_index_node, - neighbor_value_node, - ) - else: - sorted_dims = transformer.get_sorted_field_dimensions(iterator.dimensions) - data_access_index = ",".join(f"{dim}_v" for dim in sorted_dims) - connector_neighbor_dim = f"{offset_provider.neighbor_axis.value}_v" - data_access_tasklet = state.add_tasklet( - "data_access", - code=f"__data = __field[{data_access_index}] " - + ( - f"if {connector_neighbor_dim} != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" - if offset_provider.has_skip_values - else "" - ), - inputs={"__field"} | {f"{dim}_v" for dim in iterator.dimensions}, - outputs={"__data"}, - debuginfo=di, - ) - state.add_memlet_path( - iterator.field, - me, - data_access_tasklet, - memlet=dace.Memlet.from_array(iterator.field.data, field_desc), - dst_conn="__field", - ) - for dim in iterator.dimensions: - connector = f"{dim}_v" - if dim == offset_provider.neighbor_axis.value: - state.add_edge( - neighbor_index_node, - None, - data_access_tasklet, - connector, 
- dace.Memlet(data=neighbor_index_var, subset="0"), - ) - else: - state.add_memlet_path( - iterator.indices[dim], - me, - data_access_tasklet, - dst_conn=connector, - memlet=dace.Memlet(data=iterator.indices[dim].data, subset="0"), - ) - - state.add_memlet_path( - data_access_tasklet, - mx, - neighbor_value_node, - memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index), - src_conn="__data", - ) - - if not offset_provider.has_skip_values: - return [ValueExpr(neighbor_value_node, iterator.dtype)] - else: - """ - In case of neighbor tables with skip values, in addition to the array of neighbor values this function also - returns an array of booleans to indicate if the neighbor value is present or not. This node is only used - for neighbor reductions with lambda functions, a very specific case. For single input neighbor reductions, - the regular case, this node will be removed by the simplify pass. - """ - neighbor_valid_var = unique_var_name() - sdfg.add_array( - neighbor_valid_var, - dtype=dace.dtypes.bool, - shape=(offset_provider.max_neighbors,), - transient=True, - ) - neighbor_valid_node = state.add_access(neighbor_valid_var, debuginfo=di) - - neighbor_valid_tasklet = state.add_tasklet( - f"check_valid_neighbor_{offset_dim}", - {"__idx"}, - {"__valid"}, - f"__valid = True if __idx != {neighbor_skip_value} else False", - debuginfo=di, - ) - state.add_edge( - neighbor_index_node, - None, - neighbor_valid_tasklet, - "__idx", - dace.Memlet(data=neighbor_index_var, subset="0"), - ) - state.add_memlet_path( - neighbor_valid_tasklet, - mx, - neighbor_valid_node, - memlet=dace.Memlet(data=neighbor_valid_var, subset=neighbor_map_index), - src_conn="__valid", - ) - return [ - ValueExpr(neighbor_value_node, iterator.dtype), - ValueExpr(neighbor_valid_node, dace.dtypes.bool), - ] - - -def builtin_can_deref( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, 
default=transformer.context.body.debuginfo) - # first visit shift, to get set of indices for deref - can_deref_callable = node_args[0] - assert isinstance(can_deref_callable, itir.FunCall) - shift_callable = can_deref_callable.fun - assert isinstance(shift_callable, itir.FunCall) - assert isinstance(shift_callable.fun, itir.SymRef) - assert shift_callable.fun.id == "shift" - iterator = transformer._visit_shift(can_deref_callable) - - # TODO: remove this special case when ITIR reduce-unroll pass is able to catch it - if not isinstance(iterator, IteratorExpr): - assert len(iterator) == 1 and isinstance(iterator[0], ValueExpr) - # We can always deref a value expression, therefore hard-code `can_deref` to True. - # Returning a SymbolExpr would be preferable, but it requires update to type-checking. - result_name = unique_var_name() - transformer.context.body.add_scalar(result_name, dace.dtypes.bool, transient=True) - result_node = transformer.context.state.add_access(result_name, debuginfo=di) - transformer.context.state.add_edge( - transformer.context.state.add_tasklet( - "can_always_deref", {}, {"_out"}, "_out = True", debuginfo=di - ), - "_out", - result_node, - None, - dace.Memlet(data=result_name, subset="0"), - ) - return [ValueExpr(result_node, dace.dtypes.bool)] - - # create tasklet to check that field indices are non-negative (-1 is invalid) - args = [ValueExpr(access_node, _INDEX_DTYPE) for access_node in iterator.indices.values()] - internals = [f"{arg.value.data}_v" for arg in args] - expr_code = " and ".join(f"{v} != {neighbor_skip_value}" for v in internals) - - return transformer.add_expr_tasklet( - list(zip(args, internals)), expr_code, dace.dtypes.bool, "can_deref", dace_debuginfo=di - ) - - -def builtin_if( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - assert len(node_args) == 3 - sdfg = transformer.context.body - current_state = transformer.context.state - is_start_state = sdfg.start_block 
== current_state - - # build an empty state to join true and false branches - join_state = sdfg.add_state_before(current_state, "join") - - def build_if_state(arg, state): - symbol_map = copy.deepcopy(transformer.context.symbol_map) - node_context = Context(sdfg, state, symbol_map) - node_taskgen = PythonTaskletCodegen( - transformer.offset_provider, - node_context, - transformer.use_field_canonical_representation, - ) - return node_taskgen.visit(arg) - - # represent the if-statement condition as a tasklet inside an `if_statement` state preceding `join` state - stmt_state = sdfg.add_state_before(join_state, "if_statement", is_start_state) - stmt_node = build_if_state(node_args[0], stmt_state)[0] - assert isinstance(stmt_node, ValueExpr) - assert stmt_node.dtype == dace.dtypes.bool - assert sdfg.arrays[stmt_node.value.data].shape == (1,) - - # visit true and false branches (here called `tbr` and `fbr`) as separate states, following `if_statement` state - tbr_state = sdfg.add_state("true_branch") - sdfg.add_edge( - stmt_state, tbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == True") - ) - sdfg.add_edge(tbr_state, join_state, dace.InterstateEdge()) - tbr_values = flatten_list(build_if_state(node_args[1], tbr_state)) - # - fbr_state = sdfg.add_state("false_branch") - sdfg.add_edge( - stmt_state, fbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == False") - ) - sdfg.add_edge(fbr_state, join_state, dace.InterstateEdge()) - fbr_values = flatten_list(build_if_state(node_args[2], fbr_state)) - - assert isinstance(stmt_node, ValueExpr) - assert stmt_node.dtype == dace.dtypes.bool - # make the result of the if-statement evaluation available inside current state - ctx_stmt_node = ValueExpr(current_state.add_access(stmt_node.value.data), stmt_node.dtype) - - # we distinguish between select if-statements, where both true and false branches are symbolic expressions, - # and therefore do not require exclusive branch execution, and regular 
if-statements where at least one branch - # is a value expression, which has to be evaluated at runtime with conditional state transition - result_values = [] - assert len(tbr_values) == len(fbr_values) - for tbr_value, fbr_value in zip(tbr_values, fbr_values): - assert isinstance(tbr_value, (SymbolExpr, ValueExpr)) - assert isinstance(fbr_value, (SymbolExpr, ValueExpr)) - assert tbr_value.dtype == fbr_value.dtype - - if all(isinstance(x, SymbolExpr) for x in (tbr_value, fbr_value)): - # both branches return symbolic expressions, therefore the if-node can be translated - # to a select-tasklet inside current state - # TODO: use select-memlet when it becomes available in dace - code = f"{tbr_value.value} if _cond else {fbr_value.value}" - if_expr = transformer.add_expr_tasklet( - [(ctx_stmt_node, "_cond")], code, tbr_value.dtype, "if_select" - )[0] - result_values.append(if_expr) - else: - # at least one of the two branches contains a value expression, which should be evaluated - # only if the corresponding true/false condition is satisfied - desc = sdfg.arrays[ - tbr_value.value.data if isinstance(tbr_value, ValueExpr) else fbr_value.value.data - ] - var = unique_var_name() - if isinstance(desc, dace.data.Scalar): - sdfg.add_scalar(var, desc.dtype, transient=True) - else: - sdfg.add_array(var, desc.shape, desc.dtype, transient=True) - - # write result to transient data container and access it in the original state - for state, expr in [(tbr_state, tbr_value), (fbr_state, fbr_value)]: - val_node = state.add_access(var) - if isinstance(expr, ValueExpr): - state.add_nedge( - expr.value, val_node, dace.Memlet.from_array(expr.value.data, desc) - ) - else: - assert desc.shape == (1,) - state.add_edge( - state.add_tasklet("write_symbol", {}, {"_out"}, f"_out = {expr.value}"), - "_out", - val_node, - None, - dace.Memlet(var, "0"), - ) - result_values.append(ValueExpr(current_state.add_access(var), desc.dtype)) - - if tbr_state.is_empty() and fbr_state.is_empty(): - # if all 
branches are symbolic expressions, the true/false and join states can be removed - # as well as the conditional state transition - sdfg.remove_nodes_from([join_state, tbr_state, fbr_state]) - sdfg.add_edge(stmt_state, current_state, dace.InterstateEdge()) - elif tbr_state.is_empty(): - # use direct edge from if-statement to join state for true branch - tbr_condition = sdfg.edges_between(stmt_state, tbr_state)[0].condition - sdfg.edges_between(stmt_state, join_state)[0].contition = tbr_condition - sdfg.remove_node(tbr_state) - elif fbr_state.is_empty(): - # use direct edge from if-statement to join state for false branch - fbr_condition = sdfg.edges_between(stmt_state, fbr_state)[0].condition - sdfg.edges_between(stmt_state, join_state)[0].contition = fbr_condition - sdfg.remove_node(fbr_state) - else: - # remove direct edge from if-statement to join state - sdfg.remove_edge(sdfg.edges_between(stmt_state, join_state)[0]) - # the if-statement condition is not used in current state - current_state.remove_node(ctx_stmt_node.value) - - return result_values - - -def builtin_list_get( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - args = list(itertools.chain(*transformer.visit(node_args))) - assert len(args) == 2 - # index node - if isinstance(args[0], SymbolExpr): - index_value = args[0].value - result_name = unique_var_name() - transformer.context.body.add_scalar(result_name, args[1].dtype, transient=True) - result_node = transformer.context.state.add_access(result_name) - transformer.context.state.add_nedge( - args[1].value, result_node, dace.Memlet(data=args[1].value.data, subset=index_value) - ) - return [ValueExpr(result_node, args[1].dtype)] - - else: - expr_args = [(arg, f"{arg.value.data}_v") for arg in args] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[1]}[{internals[0]}]" - return 
transformer.add_expr_tasklet( - expr_args, expr, args[1].dtype, "list_get", dace_debuginfo=di - ) - - -def builtin_cast( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - args = transformer.visit(node_args[0]) - internals = [f"{arg.value.data}_v" for arg in args] - target_type = node_args[1] - assert isinstance(target_type, itir.SymRef) - expr = _MATH_BUILTINS_MAPPING[target_type.id].format(*internals) - type_ = itir_type_as_dace_type(node.type) # type: ignore[arg-type] # ensure by type inference - return transformer.add_expr_tasklet( - list(zip(args, internals)), expr, type_, "cast", dace_debuginfo=di - ) - - -def builtin_make_const_list( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - args = [transformer.visit(arg)[0] for arg in node_args] - assert all(isinstance(x, (SymbolExpr, ValueExpr)) for x in args) - args_dtype = [x.dtype for x in args] - assert len(set(args_dtype)) == 1 - dtype = args_dtype[0] - - var_name = unique_var_name() - transformer.context.body.add_array(var_name, (len(args),), dtype, transient=True) - var_node = transformer.context.state.add_access(var_name, debuginfo=di) - - for i, arg in enumerate(args): - if isinstance(arg, SymbolExpr): - transformer.context.state.add_edge( - transformer.context.state.add_tasklet( - f"get_arg{i}", {}, {"val"}, f"val = {arg.value}" - ), - "val", - var_node, - None, - dace.Memlet(data=var_name, subset=f"{i}"), - ) - else: - assert arg.value.desc(transformer.context.body).shape == (1,) - transformer.context.state.add_nedge( - arg.value, - var_node, - dace.Memlet(data=arg.value.data, subset="0", other_subset=f"{i}"), - ) - - return [ValueExpr(var_node, dtype)] - - -def builtin_make_tuple( - transformer: PythonTaskletCodegen, node: itir.Expr, 
node_args: list[itir.Expr] -) -> list[ValueExpr]: - args = [transformer.visit(arg) for arg in node_args] - return args - - -def builtin_tuple_get( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - elements = transformer.visit(node_args[1]) - index = node_args[0] - if isinstance(index, itir.Literal): - return [elements[int(index.value)]] - raise ValueError("Tuple can only be subscripted with compile-time constants.") - - -_GENERAL_BUILTIN_MAPPING: dict[ - str, Callable[[PythonTaskletCodegen, itir.Expr, list[itir.Expr]], list[ValueExpr]] -] = { - "can_deref": builtin_can_deref, - "cast_": builtin_cast, - "if_": builtin_if, - "list_get": builtin_list_get, - "make_const_list": builtin_make_const_list, - "make_tuple": builtin_make_tuple, - "neighbors": builtin_neighbors, - "tuple_get": builtin_tuple_get, -} - - -class GatherLambdaSymbolsPass(eve.NodeVisitor): - _sdfg: dace.SDFG - _state: dace.SDFGState - _symbol_map: dict[str, TaskletExpr | tuple[ValueExpr]] - _parent_symbol_map: dict[str, TaskletExpr] - - def __init__(self, sdfg, state, parent_symbol_map): - self._sdfg = sdfg - self._state = state - self._symbol_map = {} - self._parent_symbol_map = parent_symbol_map - - @property - def symbol_refs(self): - """Dictionary of symbols referenced from the lambda expression.""" - return self._symbol_map - - def _add_symbol(self, param, arg): - if isinstance(arg, ValueExpr): - # create storage in lambda sdfg - self._sdfg.add_scalar(param, dtype=arg.dtype) - # update table of lambda symbols - self._symbol_map[param] = ValueExpr( - self._state.add_access(param, debuginfo=self._sdfg.debuginfo), arg.dtype - ) - elif isinstance(arg, IteratorExpr): - # create storage in lambda sdfg - ndims = len(arg.dimensions) - shape, strides = new_array_symbols(param, ndims) - self._sdfg.add_array(param, shape=shape, strides=strides, dtype=arg.dtype) - index_names = {dim: f"__{param}_i_{dim}" for dim in arg.indices.keys()} - for _, 
index_name in index_names.items(): - self._sdfg.add_scalar(index_name, dtype=_INDEX_DTYPE) - # update table of lambda symbols - field = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) - indices = { - dim: self._state.add_access(index_arg, debuginfo=self._sdfg.debuginfo) - for dim, index_arg in index_names.items() - } - self._symbol_map[param] = IteratorExpr(field, indices, arg.dtype, arg.dimensions) - else: - assert isinstance(arg, SymbolExpr) - self._symbol_map[param] = arg - - def _add_tuple(self, param, args): - nodes = [] - # create storage in lambda sdfg for each tuple element - for arg in args: - var = unique_var_name() - self._sdfg.add_scalar(var, dtype=arg.dtype) - arg_node = self._state.add_access(var, debuginfo=self._sdfg.debuginfo) - nodes.append(ValueExpr(arg_node, arg.dtype)) - # update table of lambda symbols - self._symbol_map[param] = tuple(nodes) - - def visit_SymRef(self, node: itir.SymRef): - name = str(node.id) - if name in self._parent_symbol_map and name not in self._symbol_map: - arg = self._parent_symbol_map[name] - self._add_symbol(name, arg) - - def visit_Lambda(self, node: itir.Lambda, args: Optional[Sequence[TaskletExpr]] = None): - if args is not None: - if len(node.params) == len(args): - for param, arg in zip(node.params, args): - self._add_symbol(str(param.id), arg) - else: - # implicitly make tuple - assert len(node.params) == 1 - self._add_tuple(str(node.params[0].id), args) - self.visit(node.expr) - - -class GatherOutputSymbolsPass(eve.NodeVisitor): - _sdfg: dace.SDFG - _state: dace.SDFGState - _symbol_map: dict[str, TaskletExpr] - - @property - def symbol_refs(self): - """Dictionary of symbols referenced from the output expression.""" - return self._symbol_map - - def __init__(self, sdfg, state): - self._sdfg = sdfg - self._state = state - self._symbol_map = {} - - def visit_SymRef(self, node: itir.SymRef): - param = str(node.id) - if param not in _GENERAL_BUILTIN_MAPPING and param not in self._symbol_map: - 
class GatherOutputSymbolsPass(eve.NodeVisitor):
    """Collects symbols referenced by an output expression and creates an access
    node for each, keyed by symbol name, in ``symbol_refs``.
    """

    _sdfg: dace.SDFG
    _state: dace.SDFGState
    _symbol_map: dict[str, TaskletExpr]

    @property
    def symbol_refs(self):
        """Dictionary of symbols referenced from the output expression."""
        return self._symbol_map

    def __init__(self, sdfg, state):
        self._sdfg = sdfg
        self._state = state
        self._symbol_map = {}

    def visit_SymRef(self, node: itir.SymRef):
        param = str(node.id)
        # Skip builtin names (they are calls, not data) and already-seen symbols.
        if param not in _GENERAL_BUILTIN_MAPPING and param not in self._symbol_map:
            access_node = self._state.add_access(param, debuginfo=self._sdfg.debuginfo)
            self._symbol_map[param] = ValueExpr(
                access_node,
                dtype=itir_type_as_dace_type(node.type),  # type: ignore[arg-type] # ensure by type inference
            )
    def visit_Lambda(
        self, node: itir.Lambda, args: Sequence[TaskletExpr], use_neighbor_tables: bool = True
    ) -> tuple[
        Context,
        list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]],
        list[ValueExpr],
    ]:
        """Translate a lambda into its own nested SDFG.

        Returns the lambda's translation context, the list of (connector, outer
        expression) input pairs, and the list of result value expressions.
        """
        func_name = f"lambda_{abs(hash(node)):x}"
        neighbor_tables = (
            get_used_connectivities(node, self.offset_provider) if use_neighbor_tables else {}
        )
        connectivity_names = [
            dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys()
        ]

        # Create the SDFG for the lambda's body
        lambda_sdfg = dace.SDFG(func_name)
        lambda_sdfg.debuginfo = dace_utils.debug_info(node, default=self.context.body.debuginfo)
        lambda_state = lambda_sdfg.add_state(f"{func_name}_body", is_start_block=True)

        lambda_symbols_pass = GatherLambdaSymbolsPass(
            lambda_sdfg, lambda_state, self.context.symbol_map
        )
        lambda_symbols_pass.visit(node, args=args)

        # Add for input nodes for lambda symbols
        inputs: list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]] = []
        for sym, input_node in lambda_symbols_pass.symbol_refs.items():
            params = [str(p.id) for p in node.params]
            try:
                param_index = params.index(sym)
            except ValueError:
                param_index = -1
            if param_index >= 0:
                outer_node = args[param_index]
            else:
                # the symbol is not found among lambda arguments, then it is inherited from parent scope
                outer_node = self.context.symbol_map[sym]
            if isinstance(input_node, IteratorExpr):
                assert isinstance(outer_node, IteratorExpr)
                index_params = {
                    dim: index_node.data for dim, index_node in input_node.indices.items()
                }
                inputs.append(((sym, index_params), outer_node))
            elif isinstance(input_node, ValueExpr):
                assert isinstance(outer_node, ValueExpr)
                inputs.append((sym, outer_node))
            elif isinstance(input_node, tuple):
                # implicit tuple parameter: one outer argument per element,
                # laid out consecutively starting at the parameter's position
                assert param_index >= 0
                for i, input_node_i in enumerate(input_node):
                    arg_i = args[param_index + i]
                    assert isinstance(arg_i, ValueExpr)
                    assert isinstance(input_node_i, ValueExpr)
                    inputs.append((input_node_i.value.data, arg_i))

        # Add connectivities as arrays
        for name in connectivity_names:
            shape, strides = new_array_symbols(name, ndim=2)
            dtype = self.context.body.arrays[name].dtype
            lambda_sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype)

        # Translate the lambda's body in its own context
        lambda_context = Context(
            lambda_sdfg,
            lambda_state,
            lambda_symbols_pass.symbol_refs,
            reduce_identity=self.context.reduce_identity,
        )
        lambda_taskgen = PythonTaskletCodegen(
            self.offset_provider,
            lambda_context,
            self.use_field_canonical_representation,
        )

        results: list[ValueExpr] = []
        # We are flattening the returned list of value expressions because the multiple outputs of a lambda
        # should be a list of nodes without tuple structure. Ideally, an ITIR transformation could do this.
        node.expr.location = node.location
        for expr in flatten_list(lambda_taskgen.visit(node.expr)):
            if isinstance(expr, ValueExpr):
                result_name = unique_var_name()
                lambda_sdfg.add_scalar(result_name, expr.dtype, transient=True)
                result_access = lambda_state.add_access(
                    result_name, debuginfo=lambda_sdfg.debuginfo
                )
                lambda_state.add_nedge(
                    expr.value, result_access, dace.Memlet(data=result_access.data, subset="0")
                )
                result = ValueExpr(value=result_access, dtype=expr.dtype)
            else:
                # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors
                result = lambda_taskgen.add_expr_tasklet(
                    [], expr.value, expr.dtype, "forward", dace_debuginfo=lambda_sdfg.debuginfo
                )[0]
            # results must be visible to the parent SDFG
            lambda_sdfg.arrays[result.value.data].transient = False
            results.append(result)

        # remove isolated access nodes for connectivity arrays not consumed by lambda
        for sub_node in lambda_state.nodes():
            if isinstance(sub_node, dace.nodes.AccessNode):
                if lambda_state.out_degree(sub_node) == 0 and lambda_state.in_degree(sub_node) == 0:
                    lambda_state.remove_node(sub_node)

        return lambda_context, inputs, results

    def visit_SymRef(self, node: itir.SymRef) -> list[ValueExpr | SymbolExpr] | IteratorExpr:
        """Look up a symbol in the current context; scalars are wrapped in a list."""
        param = str(node.id)
        value = self.context.symbol_map[param]
        if isinstance(value, (ValueExpr, SymbolExpr)):
            return [value]
        return value

    def visit_Literal(self, node: itir.Literal) -> list[SymbolExpr]:
        """Literals become compile-time symbol expressions (no SDFG storage)."""
        return [SymbolExpr(node.value, itir_type_as_dace_type(node.type))]
    def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr:
        """Dispatch a function call to the matching lowering: deref/shift/reduce
        special cases, math or general builtins, or a generic (lambda) call."""
        node.fun.location = node.location
        if isinstance(node.fun, itir.SymRef) and node.fun.id == "deref":
            return self._visit_deref(node)
        if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef):
            if node.fun.fun.id == "shift":
                return self._visit_shift(node)
            elif node.fun.fun.id == "reduce":
                return self._visit_reduce(node)

        if isinstance(node.fun, itir.SymRef):
            builtin_name = str(node.fun.id)
            if builtin_name in _MATH_BUILTINS_MAPPING:
                return self._visit_numeric_builtin(node)
            elif builtin_name in _GENERAL_BUILTIN_MAPPING:
                return self._visit_general_builtin(node)
            else:
                raise NotImplementedError(f"'{builtin_name}' not implemented.")
        return self._visit_call(node)

    def _visit_call(self, node: itir.FunCall):
        """Lower a generic call: translate the callee (a lambda) into a nested
        SDFG, wire its inputs/outputs into the current state, and return the
        result value expressions."""
        args = self.visit(node.args)
        # normalize: every argument becomes a flat list of expressions
        args = [arg if isinstance(arg, Sequence) else [arg] for arg in args]
        args = list(itertools.chain(*args))
        node.fun.location = node.location
        func_context, func_inputs, results = self.visit(node.fun, args=args)

        # Build one memlet per nested-SDFG input connector.
        nsdfg_inputs = {}
        for name, value in func_inputs:
            if isinstance(value, ValueExpr):
                nsdfg_inputs[name] = dace.Memlet.from_array(
                    value.value.data, self.context.body.arrays[value.value.data]
                )
            else:
                assert isinstance(value, IteratorExpr)
                # iterator inputs carry the field plus one index scalar per dimension
                field = name[0]
                indices = name[1]
                nsdfg_inputs[field] = dace.Memlet.from_array(
                    value.field.data, self.context.body.arrays[value.field.data]
                )
                for dim, var in indices.items():
                    store = value.indices[dim].data
                    nsdfg_inputs[var] = dace.Memlet.from_array(
                        store, self.context.body.arrays[store]
                    )

        neighbor_tables = get_used_connectivities(node.fun, self.offset_provider)
        for offset in neighbor_tables.keys():
            var = dace_utils.connectivity_identifier(offset)
            nsdfg_inputs[var] = dace.Memlet.from_array(var, self.context.body.arrays[var])

        symbol_mapping = map_nested_sdfg_symbols(self.context.body, func_context.body, nsdfg_inputs)

        nsdfg_node = self.context.state.add_nested_sdfg(
            func_context.body,
            None,
            inputs=set(nsdfg_inputs.keys()),
            outputs=set(r.value.data for r in results),
            symbol_mapping=symbol_mapping,
            debuginfo=dace_utils.debug_info(node, default=func_context.body.debuginfo),
        )

        # Connect outer access nodes to the nested SDFG's input connectors.
        for name, value in func_inputs:
            if isinstance(value, ValueExpr):
                value_memlet = nsdfg_inputs[name]
                self.context.state.add_edge(value.value, None, nsdfg_node, name, value_memlet)
            else:
                assert isinstance(value, IteratorExpr)
                field = name[0]
                indices = name[1]
                field_memlet = nsdfg_inputs[field]
                self.context.state.add_edge(value.field, None, nsdfg_node, field, field_memlet)
                for dim, var in indices.items():
                    store = value.indices[dim]
                    idx_memlet = nsdfg_inputs[var]
                    self.context.state.add_edge(store, None, nsdfg_node, var, idx_memlet)
        for offset in neighbor_tables.keys():
            var = dace_utils.connectivity_identifier(offset)
            memlet = nsdfg_inputs[var]
            access = self.context.state.add_access(var, debuginfo=nsdfg_node.debuginfo)
            self.context.state.add_edge(access, None, nsdfg_node, var, memlet)

        # Copy each nested result into a fresh transient scalar in the outer SDFG.
        result_exprs = []
        for result in results:
            name = unique_var_name()
            self.context.body.add_scalar(name, result.dtype, transient=True)
            result_access = self.context.state.add_access(name, debuginfo=nsdfg_node.debuginfo)
            result_exprs.append(ValueExpr(result_access, result.dtype))
            memlet = dace.Memlet.from_array(name, self.context.body.arrays[name])
            self.context.state.add_edge(nsdfg_node, result.value.data, result_access, None, memlet)

        return result_exprs
    def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]:
        """Lower `deref`: either a scalar read (iterator fully indexed) or a slice
        along the single unindexed neighbor dimension, produced by a mapped tasklet."""
        di = dace_utils.debug_info(node, default=self.context.body.debuginfo)
        iterator = self.visit(node.args[0])
        if not isinstance(iterator, IteratorExpr):
            # already a list of ValueExpr
            return iterator

        sorted_dims = self.get_sorted_field_dimensions(iterator.dimensions)
        if all([dim in iterator.indices for dim in iterator.dimensions]):
            # The deref iterator has index values on all dimensions: the result will be a scalar
            args = [ValueExpr(iterator.field, iterator.dtype)] + [
                ValueExpr(iterator.indices[dim], _INDEX_DTYPE) for dim in sorted_dims
            ]
            internals = [f"{arg.value.data}_v" for arg in args]
            expr = f"{internals[0]}[{', '.join(internals[1:])}]"
            return self.add_expr_tasklet(
                list(zip(args, internals)), expr, iterator.dtype, "deref", dace_debuginfo=di
            )

        else:
            # Exactly one dimension is unindexed: slice along the neighbor axis.
            dims_not_indexed = [dim for dim in iterator.dimensions if dim not in iterator.indices]
            assert len(dims_not_indexed) == 1
            offset = dims_not_indexed[0]
            offset_provider = self.offset_provider[offset]
            neighbor_dim = offset_provider.neighbor_axis.value

            result_name = unique_var_name()
            self.context.body.add_array(
                result_name, (offset_provider.max_neighbors,), iterator.dtype, transient=True
            )
            result_array = self.context.body.arrays[result_name]
            result_node = self.context.state.add_access(result_name, debuginfo=di)

            deref_connectors = ["_inp"] + [
                f"_i_{dim}" for dim in sorted_dims if dim in iterator.indices
            ]
            deref_nodes = [iterator.field] + [
                iterator.indices[dim] for dim in sorted_dims if dim in iterator.indices
            ]
            deref_memlets = [
                dace.Memlet.from_array(iterator.field.data, iterator.field.desc(self.context.body))
            ] + [dace.Memlet(data=node.data, subset="0") for node in deref_nodes[1:]]

            # we create a mapped tasklet for array slicing
            index_name = unique_name(f"_i_{neighbor_dim}")
            map_ranges = {index_name: f"0:{offset_provider.max_neighbors}"}
            src_subset = ",".join(
                [f"_i_{dim}" if dim in iterator.indices else index_name for dim in sorted_dims]
            )
            self.context.state.add_mapped_tasklet(
                "deref",
                map_ranges,
                inputs={k: v for k, v in zip(deref_connectors, deref_memlets)},
                outputs={"_out": dace.Memlet.from_array(result_name, result_array)},
                code=f"_out[{index_name}] = _inp[{src_subset}]",
                external_edges=True,
                input_nodes={node.data: node for node in deref_nodes},
                output_nodes={result_name: result_node},
                debuginfo=di,
            )
            return [ValueExpr(result_node, iterator.dtype)]
    def _split_shift_args(
        self, args: list[itir.Expr]
    ) -> tuple[list[itir.Expr], Optional[list[itir.Expr]]]:
        """Split shift arguments into the last (offset, value) pair and the
        flattened remainder (None when there is only one pair)."""
        pairs = [args[i : i + 2] for i in range(0, len(args), 2)]
        assert len(pairs) >= 1
        assert all(len(pair) == 2 for pair in pairs)
        return pairs[-1], list(itertools.chain(*pairs[0:-1])) if len(pairs) > 1 else None

    def _make_shift_for_rest(self, rest, iterator):
        """Rebuild a shift call applying the remaining offset pairs to `iterator`."""
        return itir.FunCall(
            fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest),
            args=[iterator],
            location=iterator.location,
        )

    def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]:
        """Lower `shift`: apply one offset (connectivity lookup for unstructured,
        index addition for Cartesian), recursing on any remaining offset pairs."""
        di = dace_utils.debug_info(node, default=self.context.body.debuginfo)
        shift = node.fun
        assert isinstance(shift, itir.FunCall)
        tail, rest = self._split_shift_args(shift.args)
        if rest:
            iterator = self.visit(self._make_shift_for_rest(rest, node.args[0]))
        else:
            iterator = self.visit(node.args[0])
        if not isinstance(iterator, IteratorExpr):
            # shift cannot be applied because the argument is not iterable
            # TODO: remove this special case when ITIR pass is able to catch it
            assert isinstance(iterator, list) and len(iterator) == 1
            assert isinstance(iterator[0], ValueExpr)
            return iterator

        assert isinstance(tail[0], itir.OffsetLiteral)
        offset_dim = tail[0].value
        assert isinstance(offset_dim, str)
        offset_node = self.visit(tail[1])[0]
        assert offset_node.dtype in dace.dtypes.INTEGER_TYPES

        if isinstance(self.offset_provider[offset_dim], Connectivity):
            # unstructured shift: read the target index from the connectivity table
            offset_provider = self.offset_provider[offset_dim]
            connectivity = self.context.state.add_access(
                dace_utils.connectivity_identifier(offset_dim), debuginfo=di
            )

            shifted_dim = offset_provider.origin_axis.value
            target_dim = offset_provider.neighbor_axis.value
            args = [
                ValueExpr(connectivity, _INDEX_DTYPE),
                ValueExpr(iterator.indices[shifted_dim], offset_node.dtype),
                offset_node,
            ]
            internals = [f"{arg.value.data}_v" for arg in args]
            expr = f"{internals[0]}[{internals[1]}, {internals[2]}]"
        else:
            # Cartesian shift: plain index arithmetic on the same dimension
            assert isinstance(self.offset_provider[offset_dim], Dimension)

            shifted_dim = self.offset_provider[offset_dim].value
            target_dim = shifted_dim
            args = [ValueExpr(iterator.indices[shifted_dim], offset_node.dtype), offset_node]
            internals = [f"{arg.value.data}_v" for arg in args]
            expr = f"{internals[0]} + {internals[1]}"

        shifted_value = self.add_expr_tasklet(
            list(zip(args, internals)), expr, offset_node.dtype, "shift", dace_debuginfo=di
        )[0].value

        shifted_index = {dim: value for dim, value in iterator.indices.items()}
        del shifted_index[shifted_dim]
        shifted_index[target_dim] = shifted_value

        return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions)

    def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]:
        """Materialize an integer offset literal as a transient scalar written by a tasklet."""
        di = dace_utils.debug_info(node, default=self.context.body.debuginfo)
        offset = node.value
        assert isinstance(offset, int)
        offset_var = unique_var_name()
        self.context.body.add_scalar(offset_var, _INDEX_DTYPE, transient=True)
        offset_node = self.context.state.add_access(offset_var, debuginfo=di)
        tasklet_node = self.context.state.add_tasklet(
            "get_offset", {}, {"__out"}, f"__out = {offset}", debuginfo=di
        )
        self.context.state.add_edge(
            tasklet_node, "__out", offset_node, None, dace.Memlet(data=offset_var, subset="0")
        )
        return [ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype)]
    def _visit_reduce(self, node: itir.FunCall):
        """Lower `reduce`: gather per-neighbor values (directly from `neighbors`
        or via a mapped lambda over the neighbor axis), then emit a DaCe reduce
        library node with the operator's write-conflict-resolution lambda.

        NOTE(review): code reconstructed from a mangled diff — the exact
        indentation of the shared tail (reduce node construction) relative to
        the if/else branches should be confirmed against the original file.
        """
        di = dace_utils.debug_info(node, default=self.context.body.debuginfo)
        reduce_dtype = itir_type_as_dace_type(node.type)  # type: ignore[arg-type] # ensure by type inference

        if len(node.args) == 1:
            # simple form: reduce(op, init)(neighbors(...)) — the neighbors
            # builtin already yields the value array to reduce
            assert (
                isinstance(node.args[0], itir.FunCall)
                and isinstance(node.args[0].fun, itir.SymRef)
                and node.args[0].fun.id == "neighbors"
            )
            assert isinstance(node.fun, itir.FunCall)
            op_name = node.fun.args[0]
            assert isinstance(op_name, itir.SymRef)
            reduce_identity = node.fun.args[1]
            assert isinstance(reduce_identity, itir.Literal)

            # set reduction state
            # NOTE(review): unlike the lambda branch below, this branch does not
            # visibly reset `reduce_identity` to None afterwards — confirm intended.
            self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype)

            args = self.visit(node.args[0])

            assert 1 <= len(args) <= 2
            reduce_input_node = args[0].value

        else:
            # general form: reduce(lambda, ...)(args) — build a mapped nested
            # SDFG evaluating the lambda body per neighbor
            assert isinstance(node.fun, itir.FunCall)
            assert isinstance(node.fun.args[0], itir.Lambda)
            fun_node = node.fun.args[0]
            assert isinstance(fun_node.expr, itir.FunCall)

            op_name = fun_node.expr.fun
            assert isinstance(op_name, itir.SymRef)
            reduce_identity = get_reduce_identity_value(op_name.id, reduce_dtype)

            # set reduction state in visit context
            self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype)

            args = self.visit(node.args)

            # clear context
            self.context.reduce_identity = None

            # check that all neighbor expressions have the same shape
            args_shape = [
                arg[0].value.desc(self.context.body).shape
                for arg in args
                if arg[0].value.desc(self.context.body).shape != (1,)
            ]
            assert len(set(args_shape)) == 1
            nreduce_shape = args_shape[0]

            input_args = [arg[0] for arg in args]
            input_valid_args = [arg[1] for arg in args if len(arg) == 2]

            assert len(nreduce_shape) == 1
            nreduce_index = unique_name("_i")
            nreduce_domain = {nreduce_index: f"0:{nreduce_shape[0]}"}

            reduce_input_name = unique_var_name()
            self.context.body.add_array(
                reduce_input_name, nreduce_shape, reduce_dtype, transient=True
            )

            # the lambda's first parameter is the accumulator; drop it and lower
            # only the combine expression over the remaining parameters
            lambda_node = itir.Lambda(
                expr=fun_node.expr.args[1], params=fun_node.params[1:], location=node.location
            )
            lambda_context, inner_inputs, inner_outputs = self.visit(
                lambda_node, args=input_args, use_neighbor_tables=False
            )

            input_mapping = {
                param: (
                    dace.Memlet(data=arg.value.data, subset="0")
                    if arg.value.desc(self.context.body).shape == (1,)
                    else dace.Memlet(data=arg.value.data, subset=nreduce_index)
                )
                for (param, _), arg in zip(inner_inputs, input_args)
            }
            output_mapping = {
                inner_outputs[0].value.data: dace.Memlet(
                    data=reduce_input_name, subset=nreduce_index
                )
            }
            symbol_mapping = map_nested_sdfg_symbols(
                self.context.body, lambda_context.body, input_mapping
            )

            if input_valid_args:
                """
                The neighbors builtin returns an array of booleans in case the connectivity table contains skip values.
                These booleans indicate whether the neighbor is present or not, and are used in a tasklet to select
                the result of field access or the identity value, respectively.
                If the neighbor table has full connectivity (no skip values by type definition), the input_valid node
                is not built, and the construction of the select tasklet below is also skipped.
                """
                input_args.append(input_valid_args[0])
                input_valid_node = input_valid_args[0].value
                lambda_output_node = inner_outputs[0].value
                # add input connector to nested sdfg
                lambda_context.body.add_scalar("_valid_neighbor", dace.dtypes.bool)
                input_mapping["_valid_neighbor"] = dace.Memlet(
                    data=input_valid_node.data, subset=nreduce_index
                )
                # add select tasklet before writing to output node
                # TODO: consider replacing it with a select-memlet once it is supported by DaCe SDFG API
                output_edge = lambda_context.state.in_edges(lambda_output_node)[0]
                assert isinstance(
                    lambda_context.body.arrays[output_edge.src.data], dace.data.Scalar
                )
                select_tasklet = lambda_context.state.add_tasklet(
                    "neighbor_select",
                    {"_inp", "_valid"},
                    {"_out"},
                    f"_out = _inp if _valid else {reduce_identity}",
                )
                lambda_context.state.add_edge(
                    output_edge.src,
                    None,
                    select_tasklet,
                    "_inp",
                    dace.Memlet(data=output_edge.src.data, subset="0"),
                )
                lambda_context.state.add_edge(
                    lambda_context.state.add_access("_valid_neighbor"),
                    None,
                    select_tasklet,
                    "_valid",
                    dace.Memlet(data="_valid_neighbor", subset="0"),
                )
                lambda_context.state.add_edge(
                    select_tasklet,
                    "_out",
                    lambda_output_node,
                    None,
                    dace.Memlet(data=lambda_output_node.data, subset="0"),
                )
                lambda_context.state.remove_edge(output_edge)

            reduce_input_node = self.context.state.add_access(reduce_input_name, debuginfo=di)

            nsdfg_node, map_entry, _ = add_mapped_nested_sdfg(
                self.context.state,
                sdfg=lambda_context.body,
                map_ranges=nreduce_domain,
                inputs=input_mapping,
                outputs=output_mapping,
                symbol_mapping=symbol_mapping,
                input_nodes={arg.value.data: arg.value for arg in input_args},
                output_nodes={reduce_input_name: reduce_input_node},
                debuginfo=di,
            )

        reduce_input_desc = reduce_input_node.desc(self.context.body)

        result_name = unique_var_name()
        # we allocate an array instead of a scalar because the reduce library node is generic and expects an array node
        self.context.body.add_array(result_name, (1,), reduce_dtype, transient=True)
        result_access = self.context.state.add_access(result_name, debuginfo=di)

        reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y")
        reduce_node = self.context.state.add_reduce(reduce_wcr, None, reduce_identity)
        self.context.state.add_nedge(
            reduce_input_node,
            reduce_node,
            dace.Memlet.from_array(reduce_input_node.data, reduce_input_desc),
        )
        self.context.state.add_nedge(
            reduce_node, result_access, dace.Memlet(data=result_name, subset="0")
        )

        return [ValueExpr(result_access, reduce_dtype)]

    def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]:
        """Lower a math builtin: format its expression template over the visited
        arguments (SymbolExprs are inlined as literals) into a single tasklet."""
        assert isinstance(node.fun, itir.SymRef)
        fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)]
        args = flatten_list(self.visit(node.args))
        expr_args = [
            (arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr)
        ]
        internals = [
            arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args
        ]
        expr = fmt.format(*internals)
        type_ = itir_type_as_dace_type(node.type)  # type: ignore[arg-type] # ensure by type inference
        return self.add_expr_tasklet(
            expr_args,
            expr,
            type_,
            "numeric",
            dace_debuginfo=dace_utils.debug_info(node, default=self.context.body.debuginfo),
        )

    def _visit_general_builtin(self, node: itir.FunCall) -> list[ValueExpr]:
        """Dispatch a general builtin call to its registered lowering function."""
        assert isinstance(node.fun, itir.SymRef)
        expr_func = _GENERAL_BUILTIN_MAPPING[str(node.fun.id)]
        return expr_func(self, node, node.args)
list[tuple[ValueExpr, str]], - expr: str, - result_type: Any, - name: str, - dace_debuginfo: Optional[dace.dtypes.DebugInfo] = None, - ) -> list[ValueExpr]: - di = dace_debuginfo if dace_debuginfo else self.context.body.debuginfo - result_name = unique_var_name() - self.context.body.add_scalar(result_name, result_type, transient=True) - result_access = self.context.state.add_access(result_name, debuginfo=di) - - expr_tasklet = self.context.state.add_tasklet( - name=name, - inputs={internal for _, internal in args}, - outputs={"__result"}, - code=f"__result = {expr}", - debuginfo=di, - ) - - for arg, internal in args: - edges = self.context.state.in_edges(expr_tasklet) - used = False - for edge in edges: - if edge.dst_conn == internal: - used = True - break - if used: - continue - elif not isinstance(arg, SymbolExpr): - memlet = dace.Memlet.from_array( - arg.value.data, self.context.body.arrays[arg.value.data] - ) - self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet) - - memlet = dace.Memlet(data=result_access.data, subset="0") - self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet) - - return [ValueExpr(result_access, result_type)] - - -def is_scan(node: itir.Node) -> bool: - return isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="scan") - - -def closure_to_tasklet_sdfg( - node: itir.StencilClosure, - offset_provider: dict[str, Any], - domain: dict[str, str], - inputs: Sequence[tuple[str, ts.TypeSpec]], - connectivities: Sequence[tuple[dace.ndarray, str]], - use_field_canonical_representation: bool, -) -> tuple[Context, Sequence[ValueExpr]]: - body = dace.SDFG("tasklet_toplevel") - body.debuginfo = dace_utils.debug_info(node) - state = body.add_state("tasklet_toplevel_entry", True) - symbol_map: dict[str, TaskletExpr] = {} - - idx_accesses = {} - for dim, idx in domain.items(): - name = f"{idx}_value" - body.add_scalar(name, dtype=_INDEX_DTYPE, transient=True) - tasklet = state.add_tasklet( - 
def closure_to_tasklet_sdfg(
    node: itir.StencilClosure,
    offset_provider: dict[str, Any],
    domain: dict[str, str],
    inputs: Sequence[tuple[str, ts.TypeSpec]],
    connectivities: Sequence[tuple[dace.ndarray, str]],
    use_field_canonical_representation: bool,
) -> tuple[Context, Sequence[ValueExpr]]:
    """Build a standalone SDFG evaluating the closure's stencil at one domain point.

    Declares storage for domain indices, field/scalar inputs and connectivity
    tables, then lowers the stencil (unwrapping `scan` into its lambda) with
    `PythonTaskletCodegen`. Returns the translation context and the result
    value expressions (whose arrays are made non-transient for the caller).
    """
    body = dace.SDFG("tasklet_toplevel")
    body.debuginfo = dace_utils.debug_info(node)
    state = body.add_state("tasklet_toplevel_entry", True)
    symbol_map: dict[str, TaskletExpr] = {}

    # One tasklet per domain dimension writes the current index into a scalar.
    idx_accesses = {}
    for dim, idx in domain.items():
        name = f"{idx}_value"
        body.add_scalar(name, dtype=_INDEX_DTYPE, transient=True)
        tasklet = state.add_tasklet(
            f"get_{dim}", set(), {"value"}, f"value = {idx}", debuginfo=body.debuginfo
        )
        access = state.add_access(name, debuginfo=body.debuginfo)
        idx_accesses[dim] = access
        state.add_edge(tasklet, "value", access, None, dace.Memlet(data=name, subset="0"))
    # Fields become iterators positioned at the domain indices; scalars become values.
    for name, ty in inputs:
        if isinstance(ty, ts.FieldType):
            ndim = len(ty.dims)
            shape, strides = new_array_symbols(name, ndim)
            dims = [dim.value for dim in ty.dims]
            dtype = dace_utils.as_dace_type(ty.dtype)
            body.add_array(name, shape=shape, strides=strides, dtype=dtype)
            field = state.add_access(name, debuginfo=body.debuginfo)
            indices = {dim: idx_accesses[dim] for dim in domain.keys()}
            symbol_map[name] = IteratorExpr(field, indices, dtype, dims)
        else:
            assert isinstance(ty, ts.ScalarType)
            dtype = dace_utils.as_dace_type(ty)
            body.add_scalar(name, dtype=dtype)
            symbol_map[name] = ValueExpr(state.add_access(name, debuginfo=body.debuginfo), dtype)
    for arr, name in connectivities:
        shape, strides = new_array_symbols(name, ndim=2)
        body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype)

    context = Context(body, state, symbol_map)
    translator = PythonTaskletCodegen(offset_provider, context, use_field_canonical_representation)

    args = [itir.SymRef(id=name) for name, _ in inputs]
    if is_scan(node.stencil):
        # scan: lower the wrapped lambda directly (scan semantics handled elsewhere)
        stencil = cast(FunCall, node.stencil)
        assert isinstance(stencil.args[0], Lambda)
        lambda_node = itir.Lambda(
            expr=stencil.args[0].expr, params=stencil.args[0].params, location=node.location
        )
        fun_node = itir.FunCall(fun=lambda_node, args=args, location=node.location)
    else:
        fun_node = itir.FunCall(fun=node.stencil, args=args, location=node.location)

    results = translator.visit(fun_node)
    for r in results:
        # expose results to the enclosing SDFG
        context.body.arrays[r.value.data].transient = False

    return context, results
def get_used_connectivities(
    node: itir.Node, offset_provider: Mapping[str, Any]
) -> dict[str, Connectivity]:
    """Return the connectivities from `offset_provider` actually referenced by
    offset literals inside `node`."""
    connectivities = dace_utils.filter_connectivities(offset_provider)
    offset_dims = set(eve.walk_values(node).if_isinstance(itir.OffsetLiteral).getattr("value"))
    return {offset: connectivities[offset] for offset in offset_dims if offset in connectivities}


def map_nested_sdfg_symbols(
    parent_sdfg: dace.SDFG, nested_sdfg: dace.SDFG, array_mapping: dict[str, dace.Memlet]
) -> dict[str, str]:
    """Build a symbol mapping for a nested SDFG from an array mapping.

    For each non-scalar nested array, maps its symbolic shape entries to the
    corresponding memlet subset sizes and its symbolic strides to the parent
    array's strides. Any remaining free symbol maps to itself.
    """
    symbol_mapping: dict[str, str] = {}
    for param, arg in array_mapping.items():
        arg_array = parent_sdfg.arrays[arg.data]
        param_array = nested_sdfg.arrays[param]
        if not isinstance(param_array, dace.data.Scalar):
            assert len(arg.subset.size()) == len(param_array.shape)
            for arg_shape, param_shape in zip(arg.subset.size(), param_array.shape):
                if isinstance(param_shape, dace.symbol):
                    symbol_mapping[str(param_shape)] = str(arg_shape)
            assert len(arg_array.strides) == len(param_array.strides)
            for arg_stride, param_stride in zip(arg_array.strides, param_array.strides):
                if isinstance(param_stride, dace.symbol):
                    symbol_mapping[str(param_stride)] = str(arg_stride)
        else:
            # scalars carry no shape/stride symbols; just sanity-check the subset
            assert arg.subset.num_elements() == 1
    # forward untouched free symbols unchanged
    for sym in nested_sdfg.free_symbols:
        if str(sym) not in symbol_mapping:
            symbol_mapping[str(sym)] = str(sym)
    return symbol_mapping
def add_mapped_nested_sdfg(
    state: dace.SDFGState,
    map_ranges: dict[str, str | dace.subsets.Subset] | list[tuple[str, str | dace.subsets.Subset]],
    inputs: dict[str, dace.Memlet],
    outputs: dict[str, dace.Memlet],
    sdfg: dace.SDFG,
    symbol_mapping: dict[str, Any] | None = None,
    schedule: Any = dace.dtypes.ScheduleType.Default,
    unroll_map: bool = False,
    location: Any = None,
    debuginfo: Any = None,
    input_nodes: dict[str, dace.nodes.AccessNode] | None = None,
    output_nodes: dict[str, dace.nodes.AccessNode] | None = None,
) -> tuple[dace.nodes.NestedSDFG, dace.nodes.MapEntry, dace.nodes.MapExit]:
    """Insert `sdfg` as a nested SDFG enclosed in a map scope within `state`.

    Analogous to `SDFGState.add_mapped_tasklet`, but for a nested SDFG: creates
    the nested-SDFG node and a map over `map_ranges`, then routes every
    input/output memlet through the map entry/exit. Returns the nested-SDFG
    node and the map entry/exit nodes.
    """
    if not symbol_mapping:
        # default: forward every free symbol unchanged
        symbol_mapping = {sym: sym for sym in sdfg.free_symbols}

    nsdfg_node = state.add_nested_sdfg(
        sdfg,
        None,
        set(inputs.keys()),
        set(outputs.keys()),
        symbol_mapping,
        name=sdfg.name,
        schedule=schedule,
        location=location,
        debuginfo=debuginfo,
    )

    map_entry, map_exit = state.add_map(
        f"{sdfg.name}_map", map_ranges, schedule, unroll_map, debuginfo
    )

    if input_nodes is None:
        input_nodes = {
            memlet.data: state.add_access(memlet.data, debuginfo=debuginfo)
            for name, memlet in inputs.items()
        }
    if output_nodes is None:
        output_nodes = {
            memlet.data: state.add_access(memlet.data, debuginfo=debuginfo)
            for name, memlet in outputs.items()
        }
    # empty memlets keep the nested SDFG inside the map scope when it has no I/O
    if not inputs:
        state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet())
    for name, memlet in inputs.items():
        state.add_memlet_path(
            input_nodes[memlet.data],
            map_entry,
            nsdfg_node,
            memlet=memlet,
            src_conn=None,
            dst_conn=name,
            propagate=True,
        )
    if not outputs:
        state.add_edge(nsdfg_node, None, map_exit, None, dace.Memlet())
    for name, memlet in outputs.items():
        state.add_memlet_path(
            nsdfg_node,
            map_exit,
            output_nodes[memlet.data],
            memlet=memlet,
            src_conn=name,
            dst_conn=None,
            propagate=True,
        )

    return nsdfg_node, map_entry, map_exit
# static variable - setattr(unique_name, "_unique_id", unique_id + 1) # noqa: B010 [set-attr-with-constant] - - return f"{prefix}_{unique_id}" - - -def unique_var_name(): - return unique_name("_var") - - -def new_array_symbols(name: str, ndim: int) -> tuple[list[dace.symbol], list[dace.symbol]]: - dtype = dace.int64 - shape = [dace.symbol(unique_name(f"{name}_shape{i}"), dtype) for i in range(ndim)] - strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i in range(ndim)] - return shape, strides - - -def flatten_list(node_list: list[Any]) -> list[Any]: - return list( - itertools.chain.from_iterable( - [flatten_list(e) if isinstance(e, list) else [e] for e in node_list] - ) - ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py deleted file mode 100644 index 7a442e3819..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ /dev/null @@ -1,150 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import dataclasses -import functools -from typing import Callable, Optional, Sequence - -import dace -import factory - -from gt4py._core import definitions as core_defs -from gt4py.next import common, config -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import LiftMode -from gt4py.next.otf import languages, recipes, stages, step_types, workflow -from gt4py.next.otf.binding import interface -from gt4py.next.otf.languages import LanguageSettings -from gt4py.next.program_processors.runners.dace_common import workflow as dace_workflow -from gt4py.next.type_system import type_specifications as ts - -from . 
import build_sdfg_from_itir - - -@dataclasses.dataclass(frozen=True) -class DaCeTranslator( - workflow.ChainableWorkflowMixin[ - stages.CompilableProgram, stages.ProgramSource[languages.SDFG, languages.LanguageSettings] - ], - step_types.TranslationStep[languages.SDFG, languages.LanguageSettings], -): - auto_optimize: bool = False - lift_mode: LiftMode = LiftMode.FORCE_INLINE - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - symbolic_domain_sizes: Optional[dict[str, str]] = None - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None - use_field_canonical_representation: bool = False - - def _language_settings(self) -> languages.LanguageSettings: - return languages.LanguageSettings( - formatter_key="", formatter_style="", file_extension="sdfg" - ) - - def generate_sdfg( - self, - program: itir.FencilDefinition, - arg_types: Sequence[ts.TypeSpec], - offset_provider: dict[str, common.Dimension | common.Connectivity], - column_axis: Optional[common.Dimension], - ) -> dace.SDFG: - on_gpu = ( - True - if self.device_type in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM] - else False - ) - - return build_sdfg_from_itir( - program, - arg_types, - offset_provider=offset_provider, - auto_optimize=self.auto_optimize, - on_gpu=on_gpu, - column_axis=column_axis, - lift_mode=self.lift_mode, - symbolic_domain_sizes=self.symbolic_domain_sizes, - temporary_extraction_heuristics=self.temporary_extraction_heuristics, - load_sdfg_from_file=False, - save_sdfg=False, - use_field_canonical_representation=self.use_field_canonical_representation, - ) - - def __call__( - self, inp: stages.CompilableProgram - ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]: - """Generate DaCe SDFG file from the ITIR definition.""" - program: itir.FencilDefinition | itir.Program = inp.data - assert isinstance(program, itir.FencilDefinition) - - sdfg = self.generate_sdfg( - program, - inp.args.args, - 
inp.args.offset_provider, - inp.args.column_axis, - ) - - param_types = tuple( - interface.Parameter(param, arg) for param, arg in zip(sdfg.arg_names, inp.args.args) - ) - - module: stages.ProgramSource[languages.SDFG, languages.LanguageSettings] = ( - stages.ProgramSource( - entry_point=interface.Function(program.id, param_types), - source_code=sdfg.to_json(), - library_deps=tuple(), - language=languages.SDFG, - language_settings=self._language_settings(), - implicit_domain=inp.data.implicit_domain, - ) - ) - return module - - -class DaCeTranslationStepFactory(factory.Factory): - class Meta: - model = DaCeTranslator - - -def _no_bindings(inp: stages.ProgramSource) -> stages.CompilableSource: - return stages.CompilableSource(program_source=inp, binding_source=None) - - -class DaCeWorkflowFactory(factory.Factory): - class Meta: - model = recipes.OTFCompileWorkflow - - class Params: - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cmake_build_type: config.CMakeBuildType = factory.LazyFunction( - lambda: config.CMAKE_BUILD_TYPE - ) - use_field_canonical_representation: bool = False - - translation = factory.SubFactory( - DaCeTranslationStepFactory, - device_type=factory.SelfAttribute("..device_type"), - use_field_canonical_representation=factory.SelfAttribute( - "..use_field_canonical_representation" - ), - ) - bindings = _no_bindings - compilation = factory.SubFactory( - dace_workflow.DaCeCompilationStepFactory, - cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), - cmake_build_type=factory.SelfAttribute("..cmake_build_type"), - ) - decoration = factory.LazyAttribute( - lambda o: functools.partial( - dace_workflow.convert_args, - device=o.device_type, - use_field_canonical_representation=o.use_field_canonical_representation, - ) - ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 35db4cb7f2..12f5f34a7e 100644 --- 
a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -7,18 +7,18 @@ # SPDX-License-Identifier: BSD-3-Clause import functools +import pathlib +import tempfile import warnings -from typing import Any +from typing import Any, Optional +import diskcache import factory -import numpy.typing as npt +import filelock import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators -from gt4py.eve.utils import content_hash from gt4py.next import backend, common, config -from gt4py.next.iterator import transforms -from gt4py.next.iterator.transforms import global_tmps from gt4py.next.otf import arguments, recipes, stages, workflow from gt4py.next.otf.binding import nanobind from gt4py.next.otf.compilation import compiler @@ -61,8 +61,8 @@ def decorated_program( def _ensure_is_on_device( - connectivity_arg: npt.NDArray, device: core_defs.DeviceType -) -> npt.NDArray: + connectivity_arg: core_defs.NDArrayObject, device: core_defs.DeviceType +) -> core_defs.NDArrayObject: if device in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]: import cupy as cp @@ -77,17 +77,17 @@ def _ensure_is_on_device( def extract_connectivity_args( offset_provider: dict[str, common.Connectivity | common.Dimension], device: core_defs.DeviceType -) -> list[tuple[npt.NDArray, tuple[int, ...]]]: +) -> list[tuple[core_defs.NDArrayObject, tuple[int, ...]]]: # note: the order here needs to agree with the order of the generated bindings - args: list[tuple[npt.NDArray, tuple[int, ...]]] = [] + args: list[tuple[core_defs.NDArrayObject, tuple[int, ...]]] = [] for name, conn in offset_provider.items(): if isinstance(conn, common.Connectivity): - if not isinstance(conn, common.NeighborTable): + if not common.is_neighbor_table(conn): raise NotImplementedError( "Only 'NeighborTable' connectivities implemented at this point." 
) # copying to device here is a fallback for easy testing and might be removed later - conn_arg = _ensure_is_on_device(conn.table, device) + conn_arg = _ensure_is_on_device(conn.ndarray, device) args.append((conn_arg, tuple([0] * 2))) elif isinstance(conn, common.Dimension): pass @@ -99,22 +99,36 @@ def extract_connectivity_args( return args -def compilation_hash(otf_closure: stages.CompilableProgram) -> int: - """Given closure compute a hash uniquely determining if we need to recompile.""" - offset_provider = otf_closure.args.offset_provider - return hash( - ( - otf_closure.data, - # As the frontend types contain lists they are not hashable. As a workaround we just - # use content_hash here. - content_hash(tuple(arg for arg in otf_closure.args.args)), - # Directly using the `id` of the offset provider is not possible as the decorator adds - # the implicitly defined ones (i.e. to allow the `TDim + 1` syntax) resulting in a - # different `id` every time. Instead use the `id` of each individual offset provider. - tuple((k, id(v)) for (k, v) in offset_provider.items()) if offset_provider else None, - otf_closure.args.column_axis, - ) - ) +class FileCache(diskcache.Cache): + """ + This class extends `diskcache.Cache` to ensure the cache is properly + - opened when accessed by multiple processes using a file lock. This guards the creating of the + cache object, which has been reported to cause `sqlite3.OperationalError: database is locked` + errors and slow startup times when multiple processes access the cache concurrently. While this + issue occurred frequently and was observed to be fixed on distributed file systems, the lock + does not guarantee correct behavior in particular for accesses to the cache (beyond opening) + since the underlying SQLite database is unreliable when stored on an NFS based file system. + It does however ensure correctness of concurrent cache accesses on a local file system. See + #1745 for more details. + - closed upon deletion, i.e. 
it ensures that any resources associated with the cache are + properly released when the instance is garbage collected. + """ + + def __init__(self, directory: Optional[str | pathlib.Path] = None, **settings: Any) -> None: + if directory: + lock_dir = pathlib.Path(directory).parent + else: + lock_dir = pathlib.Path(tempfile.gettempdir()) + + lock = filelock.FileLock(lock_dir / "file_cache.lock") + with lock: + super().__init__(directory=directory, **settings) + + self._init_complete = True + + def __del__(self) -> None: + if getattr(self, "_init_complete", False): # skip if `__init__` didn't finished + self.close() class GTFNCompileWorkflowFactory(factory.Factory): @@ -130,10 +144,23 @@ class Params: lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type) ) - translation = factory.SubFactory( - gtfn_module.GTFNTranslationStepFactory, - device_type=factory.SelfAttribute("..device_type"), - ) + cached_translation = factory.Trait( + translation=factory.LazyAttribute( + lambda o: workflow.CachedStep( + o.bare_translation, + hash_function=stages.fingerprint_compilable_program, + cache=FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")), + ) + ), + ) + + bare_translation = factory.SubFactory( + gtfn_module.GTFNTranslationStepFactory, + device_type=factory.SelfAttribute("..device_type"), + ) + + translation = factory.LazyAttribute(lambda o: o.bare_translation) + bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] = ( nanobind.bind_source ) @@ -158,7 +185,7 @@ class Params: name_postfix = "" gpu = factory.Trait( allocator=next_allocators.StandardGPUFieldBufferAllocator(), - device_type=next_allocators.CUPY_DEVICE or core_defs.DeviceType.CUDA, + device_type=core_defs.CUPY_DEVICE_TYPE or core_defs.DeviceType.CUDA, name_device="gpu", ) cached = factory.Trait( @@ -167,13 +194,8 @@ class Params: ), name_cached="_cached", ) - use_temporaries = factory.Trait( - otf_workflow__translation__lift_mode=transforms.LiftMode.USE_TEMPORARIES, - 
otf_workflow__translation__temporary_extraction_heuristics=global_tmps.SimpleTemporaryExtractionHeuristics, - name_temps="_with_temporaries", - ) device_type = core_defs.DeviceType.CPU - hash_function = compilation_hash + hash_function = stages.compilation_hash otf_workflow = factory.SubFactory( GTFNCompileWorkflowFactory, device_type=factory.SelfAttribute("..device_type") ) @@ -193,10 +215,12 @@ class Params: name_postfix="_imperative", otf_workflow__translation__use_imperative_backend=True ) -run_gtfn_cached = GTFNBackendFactory(cached=True) - -run_gtfn_with_temporaries = GTFNBackendFactory(use_temporaries=True) +run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True) run_gtfn_gpu = GTFNBackendFactory(gpu=True) run_gtfn_gpu_cached = GTFNBackendFactory(gpu=True, cached=True) + +run_gtfn_no_transforms = GTFNBackendFactory( + otf_workflow__bare_translation__enable_itir_transforms=False +) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 93e6d09c5b..32c3f7a360 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -9,6 +9,7 @@ from __future__ import annotations import dataclasses +import functools import importlib.util import pathlib import tempfile @@ -20,7 +21,7 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako from gt4py.next import allocators as next_allocators, backend as next_backend, common, config -from gt4py.next.ffront import foast_to_gtir, past_to_itir +from gt4py.next.ffront import foast_to_gtir, foast_to_past, past_to_itir from gt4py.next.iterator import ir as itir, transforms as itir_transforms from gt4py.next.otf import stages, workflow from gt4py.next.type_system import type_specifications as ts @@ -45,7 +46,6 @@ class EmbeddedDSL(codegen.TemplatedGenerator): AxisLiteral = as_fmt("{value}") FunCall = 
as_fmt("{fun}({','.join(args)})") Lambda = as_mako("(lambda ${','.join(params)}: ${expr})") - StencilClosure = as_mako("closure(${domain}, ${stencil}, ${output}, [${','.join(inputs)}])") FunctionDefinition = as_mako( """ @fundef @@ -90,11 +90,11 @@ def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: def fencil_generator( - ir: itir.Node, + ir: itir.Program, debug: bool, - lift_mode: itir_transforms.LiftMode, use_embedded: bool, - offset_provider: dict[str, common.Connectivity | common.Dimension], + offset_provider: common.OffsetProvider, + transforms: itir_transforms.GTIRTransform, ) -> stages.CompiledProgram: """ Generate a directly executable fencil from an ITIR node. @@ -102,7 +102,7 @@ def fencil_generator( Arguments: ir: The iterator IR (ITIR) node. debug: Keep module source containing fencil implementation. - lift_mode: Change the way lifted function calls are evaluated. + extract_temporaries: Extract intermediate field values into temporaries. use_embedded: Directly use builtins from embedded backend instead of generic dispatcher. Gives faster performance and is easier to debug. 
@@ -110,15 +110,21 @@ def fencil_generator( """ # TODO(tehrengruber): just a temporary solution until we have a proper generic # caching mechanism - cache_key = hash((ir, lift_mode, debug, use_embedded, tuple(offset_provider.items()))) + cache_key = hash( + ( + ir, + transforms, + debug, + use_embedded, + tuple(common.offset_provider_to_type(offset_provider).items()), + ) + ) if cache_key in _FENCIL_CACHE: if debug: print(f"Using cached fencil for key {cache_key}") return typing.cast(stages.CompiledProgram, _FENCIL_CACHE[cache_key]) - ir = itir_transforms.apply_common_transforms( - ir, lift_mode=lift_mode, offset_provider=offset_provider - ) + ir = transforms(ir, offset_provider=offset_provider) program = EmbeddedDSL.apply(ir) @@ -152,7 +158,9 @@ def fencil_generator( """ ) - with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as source_file: + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", encoding="utf-8", delete=False + ) as source_file: source_file_name = source_file.name if debug: print(source_file_name) @@ -187,9 +195,9 @@ def fencil_generator( @dataclasses.dataclass(frozen=True) class Roundtrip(workflow.Workflow[stages.CompilableProgram, stages.CompiledProgram]): debug: Optional[bool] = None - lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE use_embedded: bool = True dispatch_backend: Optional[next_backend.Backend] = None + transforms: itir_transforms.GTIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` def __call__(self, inp: stages.CompilableProgram) -> stages.CompiledProgram: debug = config.DEBUG if self.debug is None else self.debug @@ -198,8 +206,8 @@ def __call__(self, inp: stages.CompilableProgram) -> stages.CompiledProgram: inp.data, offset_provider=inp.args.offset_provider, debug=debug, - lift_mode=self.lift_mode, use_embedded=self.use_embedded, + transforms=self.transforms, ) def decorated_fencil( 
@@ -211,7 +219,7 @@ def decorated_fencil( ) -> None: if out is not None: args = (*args, out) - if not column_axis: + if not column_axis: # TODO(tehrengruber): This variable is never used. Bug? column_axis = inp.args.column_axis fencil( *args, @@ -224,28 +232,47 @@ def decorated_fencil( return decorated_fencil -executor = Roundtrip() -executor_with_temporaries = Roundtrip(lift_mode=itir_transforms.LiftMode.USE_TEMPORARIES) - +# TODO(tehrengruber): introduce factory default = next_backend.Backend( name="roundtrip", - executor=executor, + executor=Roundtrip( + transforms=functools.partial( + itir_transforms.apply_common_transforms, + extract_temporaries=False, + ) + ), allocator=next_allocators.StandardCPUFieldBufferAllocator(), transforms=next_backend.DEFAULT_TRANSFORMS, ) with_temporaries = next_backend.Backend( name="roundtrip_with_temporaries", - executor=executor_with_temporaries, + executor=Roundtrip( + transforms=functools.partial( + itir_transforms.apply_common_transforms, + extract_temporaries=True, + ) + ), + allocator=next_allocators.StandardCPUFieldBufferAllocator(), + transforms=next_backend.DEFAULT_TRANSFORMS, +) +no_transforms = next_backend.Backend( + name="roundtrip", + executor=Roundtrip(transforms=lambda o, *, offset_provider: o), allocator=next_allocators.StandardCPUFieldBufferAllocator(), transforms=next_backend.DEFAULT_TRANSFORMS, ) + gtir = next_backend.Backend( name="roundtrip_gtir", - executor=executor, + executor=Roundtrip(transforms=itir_transforms.apply_fieldview_transforms), # type: ignore[arg-type] # don't understand why mypy complains allocator=next_allocators.StandardCPUFieldBufferAllocator(), transforms=next_backend.Transforms( - past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True), + past_to_itir=past_to_itir.past_to_gtir_factory(), foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(cached=True), + field_view_op_to_prog=foast_to_past.operator_to_program_factory( + 
foast_to_itir_step=foast_to_gtir.adapted_foast_to_gtir_factory() + ), ), ) +foast_to_gtir_step = foast_to_gtir.adapted_foast_to_gtir_factory(cached=True) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 5bda9a6f2e..bbaaa82728 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -26,6 +26,7 @@ from gt4py.eve.utils import XIterable, xiter from gt4py.next import common +from gt4py.next.iterator.type_system import type_specifications as it_ts from gt4py.next.type_system import type_specifications as ts @@ -78,15 +79,15 @@ def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]: >>> type_class(ts.TupleType(types=[])).__name__ 'TupleType' """ - match symbol_type: - case ts.DeferredType(constraint): - if constraint is None: - raise ValueError(f"No type information available for '{symbol_type}'.") - elif isinstance(constraint, tuple): - raise ValueError(f"Not sufficient type information available for '{symbol_type}'.") - return constraint - case ts.TypeSpec() as concrete_type: - return concrete_type.__class__ + if isinstance(symbol_type, ts.DeferredType): + constraint = symbol_type.constraint + if constraint is None: + raise ValueError(f"No type information available for '{symbol_type}'.") + elif isinstance(constraint, tuple): + raise ValueError(f"Not sufficient type information available for '{symbol_type}'.") + return constraint + if isinstance(symbol_type, ts.TypeSpec): + return symbol_type.__class__ raise ValueError( f"Invalid type for TypeInfo: requires '{ts.TypeSpec}', got '{type(symbol_type)}'." ) @@ -173,7 +174,7 @@ def apply_to_primitive_constituents( ... with_path_arg=True, ... tuple_constructor=lambda *elements: dict(elements), ... 
) - {(0,): ScalarType(kind=, shape=None), (1,): ScalarType(kind=, shape=None)} + {(0,): ScalarType(kind=, shape=None), (1,): ScalarType(kind=, shape=None)} """ if isinstance(symbol_types[0], ts.TupleType): assert all(isinstance(symbol_type, ts.TupleType) for symbol_type in symbol_types) @@ -197,7 +198,7 @@ def apply_to_primitive_constituents( return fun(*symbol_types) -def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType: +def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.ListType: """ Extract the data type from ``symbol_type`` if it is either `FieldType` or `ScalarType`. @@ -234,7 +235,10 @@ def is_floating_point(symbol_type: ts.TypeSpec) -> bool: >>> is_floating_point(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))) True """ - return extract_dtype(symbol_type).kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64] + return isinstance(dtype := extract_dtype(symbol_type), ts.ScalarType) and dtype.kind in [ + ts.ScalarKind.FLOAT32, + ts.ScalarKind.FLOAT64, + ] def is_integer(symbol_type: ts.TypeSpec) -> bool: @@ -251,7 +255,12 @@ def is_integer(symbol_type: ts.TypeSpec) -> bool: False """ return isinstance(symbol_type, ts.ScalarType) and symbol_type.kind in { + ts.ScalarKind.INT8, + ts.ScalarKind.UINT8, + ts.ScalarKind.INT16, + ts.ScalarKind.UINT16, ts.ScalarKind.INT32, + ts.ScalarKind.UINT32, ts.ScalarKind.INT64, } @@ -295,7 +304,10 @@ def is_number(symbol_type: ts.TypeSpec) -> bool: def is_logical(symbol_type: ts.TypeSpec) -> bool: - return extract_dtype(symbol_type).kind is ts.ScalarKind.BOOL + return ( + isinstance(dtype := extract_dtype(symbol_type), ts.ScalarType) + and dtype.kind is ts.ScalarKind.BOOL + ) def is_arithmetic(symbol_type: ts.TypeSpec) -> bool: @@ -321,8 +333,14 @@ def arithmetic_bounds(arithmetic_type: ts.ScalarType) -> tuple[np.number, np.num return { # type: ignore[return-value] # why resolved to `tuple[object, object]`? 
ts.ScalarKind.FLOAT32: (np.finfo(np.float32).min, np.finfo(np.float32).max), ts.ScalarKind.FLOAT64: (np.finfo(np.float64).min, np.finfo(np.float64).max), + ts.ScalarKind.INT8: (np.iinfo(np.int8).min, np.iinfo(np.int8).max), + ts.ScalarKind.UINT8: (np.iinfo(np.uint8).min, np.iinfo(np.uint8).max), + ts.ScalarKind.INT16: (np.iinfo(np.int16).min, np.iinfo(np.int16).max), + ts.ScalarKind.UINT16: (np.iinfo(np.uint16).min, np.iinfo(np.uint16).max), ts.ScalarKind.INT32: (np.iinfo(np.int32).min, np.iinfo(np.int32).max), + ts.ScalarKind.UINT32: (np.iinfo(np.uint32).min, np.iinfo(np.uint32).max), ts.ScalarKind.INT64: (np.iinfo(np.int64).min, np.iinfo(np.int64).max), + ts.ScalarKind.UINT64: (np.iinfo(np.uint64).min, np.iinfo(np.uint64).max), }[arithmetic_type.kind] @@ -385,11 +403,10 @@ def extract_dims(symbol_type: ts.TypeSpec) -> list[common.Dimension]: >>> extract_dims(ts.FieldType(dims=[I, J], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64))) [Dimension(value='I', kind=), Dimension(value='J', kind=)] """ - match symbol_type: - case ts.ScalarType(): - return [] - case ts.FieldType(dims): - return dims + if isinstance(symbol_type, ts.ScalarType): + return [] + if isinstance(symbol_type, ts.FieldType): + return symbol_type.dims raise ValueError(f"Can not extract dimensions from '{symbol_type}'.") @@ -416,6 +433,69 @@ def contains_local_field(type_: ts.TypeSpec) -> bool: ) +# TODO(tehrengruber): This function has specializations on Iterator types, which are not part of +# the general / shared type system. This functionality should be moved to the iterator-only +# type system, but we need some sort of multiple dispatch for that. +# TODO(tehrengruber): Should this have a direction like is_concretizable? +def is_compatible_type(type_a: ts.TypeSpec, type_b: ts.TypeSpec) -> bool: + """ + Predicate to determine if two types are compatible. 
+ + This function gracefully handles: + - iterators with unknown positions which are considered compatible to any other positions + of another iterator. + - iterators which are defined everywhere, i.e. empty defined dimensions + Beside that this function simply checks for equality of types. + + >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) + >>> IDim = common.Dimension(value="IDim") + >>> type_on_i_of_i_it = it_ts.IteratorType( + ... position_dims=[IDim], defined_dims=[IDim], element_type=bool_type + ... ) + >>> type_on_undefined_of_i_it = it_ts.IteratorType( + ... position_dims="unknown", defined_dims=[IDim], element_type=bool_type + ... ) + >>> is_compatible_type(type_on_i_of_i_it, type_on_undefined_of_i_it) + True + + >>> JDim = common.Dimension(value="JDim") + >>> type_on_j_of_j_it = it_ts.IteratorType( + ... position_dims=[JDim], defined_dims=[JDim], element_type=bool_type + ... ) + >>> is_compatible_type(type_on_i_of_i_it, type_on_j_of_j_it) + False + """ + is_compatible = True + + if isinstance(type_a, it_ts.IteratorType) and isinstance(type_b, it_ts.IteratorType): + if not any(el_type.position_dims == "unknown" for el_type in [type_a, type_b]): + is_compatible &= type_a.position_dims == type_b.position_dims + if type_a.defined_dims and type_b.defined_dims: + is_compatible &= type_a.defined_dims == type_b.defined_dims + is_compatible &= type_a.element_type == type_b.element_type + elif isinstance(type_a, ts.TupleType) and isinstance(type_b, ts.TupleType): + if len(type_a.types) != len(type_b.types): + return False + for el_type_a, el_type_b in zip(type_a.types, type_b.types, strict=True): + is_compatible &= is_compatible_type(el_type_a, el_type_b) + elif isinstance(type_a, ts.FunctionType) and isinstance(type_b, ts.FunctionType): + for arg_a, arg_b in zip(type_a.pos_only_args, type_b.pos_only_args, strict=True): + is_compatible &= is_compatible_type(arg_a, arg_b) + for arg_a, arg_b in zip( + type_a.pos_or_kw_args.values(), 
type_b.pos_or_kw_args.values(), strict=True + ): + is_compatible &= is_compatible_type(arg_a, arg_b) + for arg_a, arg_b in zip( + type_a.kw_only_args.values(), type_b.kw_only_args.values(), strict=True + ): + is_compatible &= is_compatible_type(arg_a, arg_b) + is_compatible &= is_compatible_type(type_a.returns, type_b.returns) + else: + is_compatible &= is_concretizable(type_a, type_b) + + return is_compatible + + def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: """ Check if ``symbol_type`` can be concretized to ``to_type``. @@ -459,7 +539,9 @@ def is_concretizable(symbol_type: ts.TypeSpec, to_type: ts.TypeSpec) -> bool: """ if isinstance(symbol_type, ts.DeferredType) and ( - symbol_type.constraint is None or issubclass(type_class(to_type), symbol_type.constraint) + symbol_type.constraint is None + or (isinstance(to_type, ts.DeferredType) and to_type.constraint is None) + or issubclass(type_class(to_type), symbol_type.constraint) ): return True elif is_concrete(symbol_type): @@ -485,12 +567,11 @@ def promote( >>> promoted.dims == [I, J, K] and promoted.dtype == dtype True - >>> promote( + >>> promoted: ts.FieldType = promote( ... ts.FieldType(dims=[I, J], dtype=dtype), ts.FieldType(dims=[K], dtype=dtype) - ... ) # doctest: +ELLIPSIS - Traceback (most recent call last): - ... - ValueError: Dimensions can not be promoted. Could not determine order of the following dimensions: J, K. + ... 
) + >>> promoted.dims == [I, J, K] and promoted.dtype == dtype + True """ if not always_field and all(isinstance(type_, ts.ScalarType) for type_ in types): if not all(type_ == types[0] for type_ in types): @@ -500,7 +581,9 @@ def promote( return types[0] elif all(isinstance(type_, (ts.ScalarType, ts.FieldType)) for type_ in types): dims = common.promote_dims(*(extract_dims(type_) for type_ in types)) - dtype = cast(ts.ScalarType, promote(*(extract_dtype(type_) for type_ in types))) + extracted_dtypes = [extract_dtype(type_) for type_ in types] + assert all(isinstance(dtype, ts.ScalarType) for dtype in extracted_dtypes) + dtype = cast(ts.ScalarType, promote(*extracted_dtypes)) # type: ignore[arg-type] # checked is `ScalarType` return ts.FieldType(dims=dims, dtype=dtype) raise TypeError("Expected a 'FieldType' or 'ScalarType'.") @@ -558,6 +641,7 @@ def return_type_field( new_dims.append(d) else: new_dims.extend(target_dims) + new_dims = common._ordered_dims(new_dims) # e.g. `Vertex, V2E, K` -> `Vertex, K, V2E` return ts.FieldType(dims=new_dims, dtype=field_type.dtype) @@ -705,11 +789,7 @@ def function_signature_incompatibilities_func( for i, (a_arg, b_arg) in enumerate( zip(list(func_type.pos_only_args) + list(func_type.pos_or_kw_args.values()), args) ): - if ( - b_arg is not UNDEFINED_ARG - and a_arg != b_arg - and not is_concretizable(a_arg, to_type=b_arg) - ): + if b_arg is not UNDEFINED_ARG and a_arg != b_arg and not is_compatible_type(a_arg, b_arg): if i < len(func_type.pos_only_args): arg_repr = f"{_number_to_ordinal_number(i + 1)} argument" else: @@ -719,7 +799,7 @@ def function_signature_incompatibilities_func( for kwarg in set(func_type.kw_only_args.keys()) & set(kwargs.keys()): if (a_kwarg := func_type.kw_only_args[kwarg]) != ( b_kwarg := kwargs[kwarg] - ) and not is_concretizable(a_kwarg, to_type=b_kwarg): + ) and not is_compatible_type(a_kwarg, b_kwarg): yield f"Expected keyword argument '{kwarg}' to be of type '{func_type.kw_only_args[kwarg]}', got 
'{kwargs[kwarg]}'." diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 0827d99cdc..5b46f9dd0d 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -6,21 +6,13 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from dataclasses import dataclass from typing import Iterator, Optional, Sequence, Union -from gt4py.eve.type_definitions import IntEnum -from gt4py.eve.utils import content_hash -from gt4py.next import common as func_common +from gt4py.eve import datamodels as eve_datamodels, type_definitions as eve_types +from gt4py.next import common -@dataclass(frozen=True) -class TypeSpec: - def __hash__(self) -> int: - return hash(content_hash(self)) - - def __init_subclass__(cls) -> None: - cls.__hash__ = TypeSpec.__hash__ # type: ignore[method-assign] +class TypeSpec(eve_datamodels.DataModel, kw_only=False, frozen=True): ... # type: ignore[call-arg] class DataType(TypeSpec): @@ -40,14 +32,12 @@ class CallableType: """ -@dataclass(frozen=True) class DeferredType(TypeSpec): """Dummy used to represent a type not yet inferred.""" constraint: Optional[type[TypeSpec] | tuple[type[TypeSpec], ...]] -@dataclass(frozen=True) class VoidType(TypeSpec): """ Return type of a function without return values. 
@@ -56,30 +46,34 @@ class VoidType(TypeSpec): """ -@dataclass(frozen=True) class DimensionType(TypeSpec): - dim: func_common.Dimension + dim: common.Dimension -@dataclass(frozen=True) class OffsetType(TypeSpec): - source: func_common.Dimension - target: tuple[func_common.Dimension] | tuple[func_common.Dimension, func_common.Dimension] + # TODO(havogt): replace by ConnectivityType + source: common.Dimension + target: tuple[common.Dimension] | tuple[common.Dimension, common.Dimension] def __str__(self) -> str: return f"Offset[{self.source}, {self.target}]" -class ScalarKind(IntEnum): +class ScalarKind(eve_types.IntEnum): BOOL = 1 - INT32 = 32 - INT64 = 64 - FLOAT32 = 1032 - FLOAT64 = 1064 - STRING = 3001 + INT8 = 2 + UINT8 = 3 + INT16 = 4 + UINT16 = 5 + INT32 = 6 + UINT32 = 7 + INT64 = 8 + UINT64 = 9 + FLOAT32 = 10 + FLOAT64 = 11 + STRING = 12 -@dataclass(frozen=True) class ScalarType(DataType): kind: ScalarKind shape: Optional[list[int]] = None @@ -91,31 +85,49 @@ def __str__(self) -> str: return f"{kind_str}{self.shape}" -@dataclass(frozen=True) -class TupleType(DataType): - types: list[DataType] - - def __str__(self) -> str: - return f"tuple[{', '.join(map(str, self.types))}]" +class ListType(DataType): + """Represents a neighbor list in the ITIR representation. - def __iter__(self) -> Iterator[DataType]: - yield from self.types + Note: not used in the frontend. The concept is represented as Field with local Dimension. + """ - def __len__(self) -> int: - return len(self.types) + element_type: DataType + # TODO(havogt): the `offset_type` is not yet used in type_inference, + # it is meant to describe the neighborhood (via the local dimension) + offset_type: Optional[common.Dimension] = None -@dataclass(frozen=True) class FieldType(DataType, CallableType): - dims: list[func_common.Dimension] - dtype: ScalarType + dims: list[common.Dimension] + dtype: ScalarType | ListType def __str__(self) -> str: dims = "..." 
if self.dims is Ellipsis else f"[{', '.join(dim.value for dim in self.dims)}]" return f"Field[{dims}, {self.dtype}]" + @eve_datamodels.validator("dims") + def _dims_validator( + self, attribute: eve_datamodels.Attribute, dims: list[common.Dimension] + ) -> None: + common.check_dims(dims) + + +class TupleType(DataType): + # TODO(tehrengruber): Remove `DeferredType` again. This was erroneously + # introduced before we checked the annotations at runtime. All attributes of + # a type that are types themselves must be concrete. + types: list[DataType | DimensionType | DeferredType] + + def __str__(self) -> str: + return f"tuple[{', '.join(map(str, self.types))}]" + + def __iter__(self) -> Iterator[DataType | DimensionType | DeferredType]: + yield from self.types + + def __len__(self) -> int: + return len(self.types) + -@dataclass(frozen=True) class FunctionType(TypeSpec, CallableType): pos_only_args: Sequence[TypeSpec] pos_or_kw_args: dict[str, TypeSpec] diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 62a6781316..10b82f7861 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -10,7 +10,6 @@ import builtins import collections.abc -import dataclasses import functools import types import typing @@ -42,16 +41,10 @@ def get_scalar_kind(dtype: npt.DTypeLike) -> ts.ScalarKind: match dt: case np.bool_: return ts.ScalarKind.BOOL - case np.int32: - return ts.ScalarKind.INT32 - case np.int64: - return ts.ScalarKind.INT64 - case np.float32: - return ts.ScalarKind.FLOAT32 - case np.float64: - return ts.ScalarKind.FLOAT64 case np.str_: return ts.ScalarKind.STRING + case np.dtype(): + return getattr(ts.ScalarKind, dt.name.upper()) case _: raise ValueError(f"Impossible to map '{dtype}' value to a 'ScalarKind'.") else: @@ -105,7 +98,7 @@ def from_type_hint( raise ValueError(f"Unbound tuples '{type_hint}' are not allowed.") tuple_types = 
[recursive_make_symbol(arg) for arg in args] assert all(isinstance(elem, ts.DataType) for elem in tuple_types) - return ts.TupleType(types=tuple_types) # type: ignore[arg-type] # checked in assert + return ts.TupleType(types=tuple_types) case common.Field: if (n_args := len(args)) != 2: @@ -168,7 +161,6 @@ def from_type_hint( raise ValueError(f"'{type_hint}' type is not supported.") -@dataclasses.dataclass(frozen=True) class UnknownPythonObject(ts.TypeSpec): _object: Any @@ -217,9 +209,9 @@ def from_value(value: Any) -> ts.TypeSpec: # not needed anymore. elems = [from_value(el) for el in value] assert all(isinstance(elem, ts.DataType) for elem in elems) - return ts.TupleType(types=elems) # type: ignore[arg-type] # checked in assert + return ts.TupleType(types=elems) elif isinstance(value, types.ModuleType): - return UnknownPythonObject(_object=value) + return UnknownPythonObject(value) else: type_ = xtyping.infer_type(value, annotate_callable_kwargs=True) symbol_type = from_type_hint(type_) diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 44fa929e56..f1a82c6bd9 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -68,7 +68,12 @@ def flatten_nested_tuple( @overload -def tree_map(fun: Callable[_P, _R], /) -> Callable[..., _R | tuple[_R | tuple, ...]]: ... +def tree_map( + fun: Callable[_P, _R], + *, + collection_type: type | tuple[type, ...] = tuple, + result_collection_constructor: Optional[type | Callable] = None, +) -> Callable[..., _R | tuple[_R | tuple, ...]]: ... @overload @@ -82,7 +87,8 @@ def tree_map( def tree_map( - *args: Callable[_P, _R], + fun: Optional[Callable[_P, _R]] = None, + *, collection_type: type | tuple[type, ...] = tuple, result_collection_constructor: Optional[type | Callable] = None, ) -> Callable[..., _R | tuple[_R | tuple, ...]] | Callable[[Callable[_P, _R]], Callable[..., Any]]: @@ -108,6 +114,12 @@ def tree_map( ... [[1, 2], 3] ... ) ((2, 3), 4) + + >>> @tree_map + ... def impl(x): + ... 
return x + 1 + >>> impl(((1, 2), 3)) + ((2, 3), 4) """ if result_collection_constructor is None: @@ -117,8 +129,7 @@ def tree_map( ) result_collection_constructor = collection_type - if len(args) == 1: - fun = args[0] + if fun: @functools.wraps(fun) def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: @@ -129,17 +140,14 @@ def impl(*args: Any | tuple[Any | tuple, ...]) -> _R | tuple[_R | tuple, ...]: assert result_collection_constructor is not None return result_collection_constructor(impl(*arg) for arg in zip(*args)) - return fun( + return fun( # type: ignore[call-arg, misc] # mypy not smart enough *cast(_P.args, args) ) # mypy doesn't understand that `args` at this point is of type `_P.args` return impl - if len(args) == 0: + else: return functools.partial( tree_map, collection_type=collection_type, result_collection_constructor=result_collection_constructor, ) - raise TypeError( - "tree_map() can be used as decorator with optional kwarg `collection_type` and `result_collection_constructor`." 
- ) diff --git a/src/gt4py/storage/__init__.py b/src/gt4py/storage/__init__.py index 4866cd480c..5986baa65e 100644 --- a/src/gt4py/storage/__init__.py +++ b/src/gt4py/storage/__init__.py @@ -16,12 +16,12 @@ __all__ = [ "cartesian", - "layout", "empty", "from_array", + "from_name", "full", + "layout", "ones", - "zeros", - "from_name", "register", + "zeros", ] diff --git a/src/gt4py/storage/allocators.py b/src/gt4py/storage/allocators.py index 298b9c2e5a..e2311e3e60 100644 --- a/src/gt4py/storage/allocators.py +++ b/src/gt4py/storage/allocators.py @@ -211,9 +211,10 @@ def allocate( # Compute the padding required in the contiguous dimension to get aligned blocks dims_layout = [layout_map.index(i) for i in range(len(shape))] - padded_shape_lst = list(shape) + # Convert shape size to same data type (note that `np.int16` can overflow) + padded_shape_lst = [np.int32(x) for x in shape] if ndim > 0: - padded_shape_lst[dims_layout[-1]] = ( + padded_shape_lst[dims_layout[-1]] = ( # type: ignore[call-overload] math.ceil(shape[dims_layout[-1]] / items_per_aligned_block) * items_per_aligned_block ) diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index 50500e536b..2275c1cd57 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -12,35 +12,23 @@ import functools import math import numbers -from typing import Any, Final, Literal, Optional, Sequence, Tuple, Union, cast +from typing import Literal, Optional, Sequence, Tuple, Union, cast import numpy as np import numpy.typing as npt +from numpy.typing import DTypeLike from gt4py._core import definitions as core_defs -from gt4py.cartesian import config as gt_config from gt4py.eve.extended_typing import ArrayInterface, CUDAArrayInterface from gt4py.storage import allocators -if np.lib.NumpyVersion(np.__version__) >= "1.20.0": - from numpy.typing import DTypeLike -else: - DTypeLike = Any # type: ignore[misc] # assign multiple types in both branches - try: 
import cupy as cp except ImportError: cp = None -CUPY_DEVICE: Final[Literal[None, core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM]] = ( - None - if not cp - else (core_defs.DeviceType.ROCM if cp.cuda.get_hipcc_path() else core_defs.DeviceType.CUDA) -) - - FieldLike = Union["cp.ndarray", np.ndarray, ArrayInterface, CUDAArrayInterface] _CPUBufferAllocator = allocators.NDArrayBufferAllocator( @@ -51,12 +39,12 @@ if cp: assert isinstance(allocators.cupy_array_utils, allocators.ArrayUtils) - if CUPY_DEVICE == core_defs.DeviceType.CUDA: + if core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.CUDA: _GPUBufferAllocator = allocators.NDArrayBufferAllocator( device_type=core_defs.DeviceType.CUDA, array_utils=allocators.cupy_array_utils, ) - elif CUPY_DEVICE == core_defs.DeviceType.ROCM: + elif core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.ROCM: _GPUBufferAllocator = allocators.NDArrayBufferAllocator( device_type=core_defs.DeviceType.ROCM, array_utils=allocators.cupy_array_utils, @@ -203,7 +191,7 @@ def asarray( elif not device: if hasattr(array, "__dlpack_device__"): kind, _ = array.__dlpack_device__() - if kind in [core_defs.DeviceType.CPU, core_defs.DeviceType.CPU_PINNED]: + if kind in [core_defs.DeviceType.CPU]: xp = np elif kind in [ core_defs.DeviceType.CUDA, @@ -270,9 +258,10 @@ def _allocate_gpu( ) -> Tuple["cp.ndarray", "cp.ndarray"]: assert cp is not None assert _GPUBufferAllocator is not None, "GPU allocation library or device not found" + if core_defs.CUPY_DEVICE_TYPE is None: + raise ValueError("CUPY_DEVICE_TYPE detection failed.") device = core_defs.Device( # type: ignore[type-var] - (core_defs.DeviceType.ROCM if gt_config.GT4PY_USE_HIP else core_defs.DeviceType.CUDA), - 0, + core_defs.CUPY_DEVICE_TYPE, 0 ) buffer = _GPUBufferAllocator.allocate( shape, @@ -290,7 +279,7 @@ def _allocate_gpu( allocate_gpu = _allocate_gpu -if CUPY_DEVICE == core_defs.DeviceType.ROCM: +if core_defs.CUPY_DEVICE_TYPE == core_defs.DeviceType.ROCM: class 
CUDAArrayInterfaceNDArray(cp.ndarray): def __new__(cls, input_array: "cp.ndarray") -> CUDAArrayInterfaceNDArray: diff --git a/tach.toml b/tach.toml index 7861ed1fe6..78541c5dff 100644 --- a/tach.toml +++ b/tach.toml @@ -3,7 +3,9 @@ source_roots = [ "src", ] exact = true -forbid_circular_dependencies = true +# forbid_circular_dependencies = true +# TODO(egparedes): try to solve the circular dependencies between +# gt4py.cartesian and gt4py.storage [[modules]] path = "gt4py._core" @@ -14,6 +16,7 @@ depends_on = [ [[modules]] path = "gt4py.cartesian" depends_on = [ + { path = "gt4py._core" }, { path = "gt4py.eve" }, { path = "gt4py.storage" }, ] @@ -34,6 +37,5 @@ depends_on = [ path = "gt4py.storage" depends_on = [ { path = "gt4py._core" }, - { path = "gt4py.cartesian" }, # for backward-compatibility the cartesian allocators are in `gt4py.storage` { path = "gt4py.eve" }, ] diff --git a/tests/cartesian_tests/definitions.py b/tests/cartesian_tests/definitions.py index 7499ad4a95..38cb6caca8 100644 --- a/tests/cartesian_tests/definitions.py +++ b/tests/cartesian_tests/definitions.py @@ -14,7 +14,6 @@ cp = None import datetime - import numpy as np import pytest @@ -22,7 +21,7 @@ from gt4py.cartesian import utils as gt_utils -def _backend_name_as_param(name): +def _backend_name_as_param(name: str): marks = [] if gt4pyc.backend.from_name(name).storage_info["device"] == "gpu": marks.append(pytest.mark.requires_gpu) @@ -48,8 +47,9 @@ def _get_backends_with_storage_info(storage_info_kind: str): GPU_BACKENDS = _get_backends_with_storage_info("gpu") ALL_BACKENDS = CPU_BACKENDS + GPU_BACKENDS -_PERFORMANCE_BACKEND_NAMES = [name for name in _ALL_BACKEND_NAMES if name not in ("numpy", "cuda")] -PERFORMANCE_BACKENDS = [_backend_name_as_param(name) for name in _PERFORMANCE_BACKEND_NAMES] +PERFORMANCE_BACKENDS = [ + _backend_name_as_param(name) for name in _ALL_BACKEND_NAMES if name not in ("numpy", "cuda") +] @pytest.fixture() diff --git 
a/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py b/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py index c1b4e58f97..c3bf40e456 100644 --- a/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py +++ b/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py @@ -26,8 +26,8 @@ def test_numpy_allocators(backend, order): xp = get_array_library(backend) shape = (20, 10, 5) - inp = xp.array(xp.random.randn(*shape), order=order, dtype=xp.float_) - outp = xp.zeros(shape=shape, order=order, dtype=xp.float_) + inp = xp.array(xp.random.randn(*shape), order=order, dtype=xp.float64) + outp = xp.zeros(shape=shape, order=order, dtype=xp.float64) stencil = gtscript.stencil(definition=copy_stencil, backend=backend) stencil(field_a=inp, field_b=outp) @@ -43,8 +43,8 @@ def test_bad_layout_warns(backend): shape = (10, 10, 10) - inp = xp.array(xp.random.randn(*shape), dtype=xp.float_) - outp = gt_storage.zeros(backend=backend, shape=shape, dtype=xp.float_, aligned_index=(0, 0, 0)) + inp = xp.array(xp.random.randn(*shape), dtype=xp.float64) + outp = gt_storage.zeros(backend=backend, shape=shape, dtype=xp.float64, aligned_index=(0, 0, 0)) # set up non-optimal storage layout: if backend_cls.storage_info["is_optimal_layout"](inp, "IJK"): diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py index 1a8cfef695..217c0ee488 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py @@ -13,6 +13,8 @@ __INLINED, BACKWARD, FORWARD, + I, + J, PARALLEL, acos, acosh, @@ -28,6 +30,7 @@ exp, floor, gamma, + horizontal, interval, isfinite, isinf, @@ -35,6 +38,7 @@ log, log10, mod, + region, sin, sinh, sqrt, @@ -57,7 +61,7 @@ def 
_register_decorator(actual_func): return _register_decorator(func) if func else _register_decorator -Field3D = gtscript.Field[np.float_] +Field3D = gtscript.Field[np.float64] Field3DBool = gtscript.Field[np.bool_] @@ -291,7 +295,8 @@ def large_k_interval(in_field: Field3D, out_field: Field3D): with computation(PARALLEL): with interval(0, 6): out_field = in_field - with interval(6, -10): # this stage will only run if field has more than 16 elements + # this stenicl is only legal to call with fields that have more than 16 elements + with interval(6, -10): out_field = in_field + 1 with interval(-10, None): out_field = in_field @@ -402,3 +407,23 @@ def two_optional_fields( out_a = out_a + dt * phys_tend_a if __INLINED(PHYS_TEND_B): out_b = out_b + dt * phys_tend_b + + +@register +def horizontal_regions(field_in: Field3D, field_out: Field3D): + with computation(PARALLEL), interval(...): + with horizontal(region[I[0] : I[2], J[0] : J[2]], region[I[-3] : I[-1], J[-3] : J[-1]]): + field_out = field_in + 1.0 + + with horizontal(region[I[0] : I[2], J[-3] : J[-1]], region[I[-3] : I[-1], J[0] : J[2]]): + field_out = field_in - 1.0 + + +@register +def horizontal_region_with_conditional(field_in: Field3D, field_out: Field3D): + with computation(PARALLEL), interval(...): + with horizontal(region[I[0] : I[2], J[0] : J[2]], region[I[-3] : I[-1], J[-3] : J[-1]]): + if field_in > 0: + field_out = field_in + 1.0 + else: + field_out = 0 diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index 398e312af3..c2b82e4bac 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -34,6 +34,9 @@ ) +rng = np.random.default_rng(1337) + + @pytest.mark.parametrize("name", stencil_definitions) @pytest.mark.parametrize("backend", 
ALL_BACKENDS) def test_generation(name, backend): @@ -52,22 +55,22 @@ def test_generation(name, backend): ) else: args[k] = v(1.5) - # vertical domain size >= 16 required for test_large_k_interval - stencil(**args, origin=(10, 10, 5), domain=(3, 3, 16)) + # vertical domain size > 16 required for test_large_k_interval + stencil(**args, origin=(10, 10, 5), domain=(3, 3, 17)) @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_lazy_stencil(backend): @gtscript.lazy_stencil(backend=backend) - def definition(field_a: gtscript.Field[np.float_], field_b: gtscript.Field[np.float_]): + def definition(field_a: Field[np.float64], field_b: Field[np.float64]): # type: ignore with computation(PARALLEL), interval(...): - field_a = field_b + field_a[0, 0, 0] = field_b @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_temporary_field_declared_in_if(backend): @gtscript.stencil(backend=backend) - def definition(field_a: gtscript.Field[np.float_]): + def definition(field_a: Field[np.float64]): # type: ignore with computation(PARALLEL), interval(...): if field_a < 0: field_b = -field_a @@ -79,19 +82,19 @@ def definition(field_a: gtscript.Field[np.float_]): @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_stage_without_effect(backend): @gtscript.stencil(backend=backend) - def definition(field_a: gtscript.Field[np.float_]): + def definition(field_a: Field[np.float64]): # type: ignore with computation(PARALLEL), interval(...): - field_c = 0.0 + field_c = 0.0 # noqa: F841 def test_ignore_np_errstate(): def setup_and_run(backend, **kwargs): field_a = gt_storage.zeros( - dtype=np.float_, backend=backend, shape=(3, 3, 1), aligned_index=(0, 0, 0) + dtype=np.float64, backend=backend, shape=(3, 3, 1), aligned_index=(0, 0, 0) ) @gtscript.stencil(backend=backend, **kwargs) - def divide_by_zero(field_a: gtscript.Field[np.float_]): + def divide_by_zero(field_a: Field[np.float64]): # type: ignore with computation(PARALLEL), interval(...): field_a = 1.0 / field_a @@ -106,22 
+109,22 @@ def divide_by_zero(field_a: gtscript.Field[np.float_]): @pytest.mark.parametrize("backend", CPU_BACKENDS) def test_stencil_without_effect(backend): - def definition1(field_in: gtscript.Field[np.float_]): + def definition1(field_in: Field[np.float64]): # type: ignore with computation(PARALLEL), interval(...): - tmp = 0.0 + tmp = 0.0 # noqa: F841 - def definition2(f_in: gtscript.Field[np.float_]): - from __externals__ import flag + def definition2(f_in: Field[np.float64]): # type: ignore + from __externals__ import flag # type: ignore with computation(PARALLEL), interval(...): if __INLINED(flag): - B = f_in + B = f_in # noqa: F841 stencil1 = gtscript.stencil(backend, definition1) stencil2 = gtscript.stencil(backend, definition2, externals={"flag": False}) field_in = gt_storage.ones( - dtype=np.float_, backend=backend, shape=(23, 23, 23), aligned_index=(0, 0, 0) + dtype=np.float64, backend=backend, shape=(23, 23, 23), aligned_index=(0, 0, 0) ) # test with explicit domain specified @@ -135,14 +138,14 @@ def definition2(f_in: gtscript.Field[np.float_]): @pytest.mark.parametrize("backend", CPU_BACKENDS) def test_stage_merger_induced_interval_block_reordering(backend): field_in = gt_storage.ones( - dtype=np.float_, backend=backend, shape=(23, 23, 23), aligned_index=(0, 0, 0) + dtype=np.float64, backend=backend, shape=(23, 23, 23), aligned_index=(0, 0, 0) ) field_out = gt_storage.zeros( - dtype=np.float_, backend=backend, shape=(23, 23, 23), aligned_index=(0, 0, 0) + dtype=np.float64, backend=backend, shape=(23, 23, 23), aligned_index=(0, 0, 0) ) @gtscript.stencil(backend=backend) - def stencil(field_in: gtscript.Field[np.float_], field_out: gtscript.Field[np.float_]): + def stencil(field_in: Field[np.float64], field_out: Field[np.float64]): # type: ignore with computation(BACKWARD): with interval(-2, -1): # block 1 field_out = field_in @@ -152,7 +155,7 @@ def stencil(field_in: gtscript.Field[np.float_], field_out: gtscript.Field[np.fl with interval(-1, None): # 
block 3 field_out = 2 * field_in with interval(0, -1): # block 4 - field_out = 3 * field_in + field_out[0, 0, 0] = 3 * field_in stencil(field_in, field_out) @@ -164,9 +167,9 @@ def stencil(field_in: gtscript.Field[np.float_], field_out: gtscript.Field[np.fl def test_lower_dimensional_inputs(backend): @gtscript.stencil(backend=backend) def stencil( - field_3d: gtscript.Field[gtscript.IJK, np.float_], - field_2d: gtscript.Field[gtscript.IJ, np.float_], - field_1d: gtscript.Field[gtscript.K, np.float_], + field_3d: Field[gtscript.IJK, np.float64], # type: ignore + field_2d: Field[gtscript.IJ, np.float64], # type: ignore + field_1d: Field[gtscript.K, np.float64], # type: ignore ): with computation(PARALLEL): with interval(0, -1): @@ -178,7 +181,7 @@ def stencil( with interval(0, 1): field_3d = tmp[1, 0, 0] + field_1d[1] with interval(1, None): - field_3d = tmp[-1, 0, 0] + field_3d[0, 0, 0] = tmp[-1, 0, 0] full_shape = (6, 6, 6) aligned_index = (1, 1, 0) @@ -219,17 +222,17 @@ def stencil( def test_lower_dimensional_masked(backend): @gtscript.stencil(backend=backend) def copy_2to3( - cond: gtscript.Field[gtscript.IJK, np.float_], - inp: gtscript.Field[gtscript.IJ, np.float_], - outp: gtscript.Field[gtscript.IJK, np.float_], + cond: Field[gtscript.IJK, np.float64], # type: ignore + inp: Field[gtscript.IJ, np.float64], # type: ignore + outp: Field[gtscript.IJK, np.float64], # type: ignore ): with computation(PARALLEL), interval(...): if cond > 0.0: - outp = inp + outp[0, 0, 0] = inp - inp = np.random.randn(10, 10) - outp = np.random.randn(10, 10, 10) - cond = np.random.randn(10, 10, 10) + inp = rng.standard_normal((10, 10)) + outp = rng.standard_normal((10, 10, 10)) + cond = rng.standard_normal((10, 10, 10)) inp_f = gt_storage.from_array(inp, aligned_index=(0, 0), backend=backend) outp_f = gt_storage.from_array(outp, aligned_index=(0, 0, 0), backend=backend) @@ -250,17 +253,17 @@ def copy_2to3( def test_lower_dimensional_masked_2dcond(backend): 
@gtscript.stencil(backend=backend) def copy_2to3( - cond: gtscript.Field[gtscript.IJK, np.float_], - inp: gtscript.Field[gtscript.IJ, np.float_], - outp: gtscript.Field[gtscript.IJK, np.float_], + cond: Field[gtscript.IJK, np.float64], # type: ignore + inp: Field[gtscript.IJ, np.float64], # type: ignore + outp: Field[gtscript.IJK, np.float64], # type: ignore ): with computation(FORWARD), interval(...): if cond > 0.0: - outp = inp + outp[0, 0, 0] = inp - inp = np.random.randn(10, 10) - outp = np.random.randn(10, 10, 10) - cond = np.random.randn(10, 10, 10) + inp = rng.standard_normal((10, 10)) + outp = rng.standard_normal((10, 10, 10)) + cond = rng.standard_normal((10, 10, 10)) inp_f = gt_storage.from_array(inp, aligned_index=(0, 0), backend=backend) outp_f = gt_storage.from_array(outp, aligned_index=(0, 0, 0), backend=backend) @@ -281,15 +284,17 @@ def copy_2to3( def test_lower_dimensional_inputs_2d_to_3d_forward(backend): @gtscript.stencil(backend=backend) def copy_2to3( - inp: gtscript.Field[gtscript.IJ, np.float_], - outp: gtscript.Field[gtscript.IJK, np.float_], + inp: Field[gtscript.IJ, np.float64], # type: ignore + outp: Field[gtscript.IJK, np.float64], # type: ignore ): with computation(FORWARD), interval(...): outp[0, 0, 0] = inp - inp_f = gt_storage.from_array(np.random.randn(10, 10), aligned_index=(0, 0), backend=backend) + inp_f = gt_storage.from_array( + rng.standard_normal((10, 10)), aligned_index=(0, 0), backend=backend + ) outp_f = gt_storage.from_array( - np.random.randn(10, 10, 10), aligned_index=(0, 0, 0), backend=backend + rng.standard_normal((10, 10, 10)), aligned_index=(0, 0, 0), backend=backend ) copy_2to3(inp_f, outp_f) inp_f = storage_utils.cpu_copy(inp_f) @@ -304,12 +309,12 @@ def test_higher_dimensional_fields(backend): @gtscript.stencil(backend=backend) def stencil( - field: gtscript.Field[np.float64], - vec_field: gtscript.Field[FLOAT64_VEC2], - mat_field: gtscript.Field[FLOAT64_MAT22], + field: Field[np.float64], # type: ignore + 
vec_field: Field[FLOAT64_VEC2], # type: ignore + mat_field: Field[FLOAT64_MAT22], # type: ignore ): with computation(PARALLEL), interval(...): - tmp = vec_field[0, 0, 0][0] + vec_field[0, 0, 0][1] + tmp = vec_field[0, 0, 0][0] + vec_field[0, 0, 0][1] # noqa: F841 with computation(FORWARD): with interval(0, 1): @@ -356,51 +361,45 @@ def stencil( def test_input_order(backend): @gtscript.stencil(backend=backend) def stencil( - in_field: gtscript.Field[np.float64], + in_field: Field[np.float64], # type: ignore parameter: np.float64, - out_field: gtscript.Field[np.float64], + out_field: Field[np.float64], # type: ignore ): with computation(PARALLEL), interval(...): - out_field = in_field * parameter + out_field[0, 0, 0] = in_field * parameter @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_variable_offsets(backend): - if backend == "dace:cpu": - pytest.skip("Internal compiler error in GitHub action container") - @gtscript.stencil(backend=backend) def stencil_ij( - in_field: gtscript.Field[np.float_], - out_field: gtscript.Field[np.float_], - index_field: gtscript.Field[gtscript.IJ, int], + in_field: Field[np.float64], # type: ignore + out_field: Field[np.float64], # type: ignore + index_field: Field[gtscript.IJ, int], # type: ignore ): with computation(FORWARD), interval(...): - out_field = in_field[0, 0, 1] + in_field[0, 0, index_field + 1] + out_field[0, 0, 0] = in_field[0, 0, 1] + in_field[0, 0, index_field + 1] index_field = index_field + 1 @gtscript.stencil(backend=backend) def stencil_ijk( - in_field: gtscript.Field[np.float_], - out_field: gtscript.Field[np.float_], - index_field: gtscript.Field[int], + in_field: Field[np.float64], # type: ignore + out_field: Field[np.float64], # type: ignore + index_field: Field[int], # type: ignore ): with computation(PARALLEL), interval(...): - out_field = in_field[0, 0, 1] + in_field[0, 0, index_field + 1] + out_field[0, 0, 0] = in_field[0, 0, 1] + in_field[0, 0, index_field + 1] @pytest.mark.parametrize("backend", 
ALL_BACKENDS) def test_variable_offsets_and_while_loop(backend): - if backend == "dace:cpu": - pytest.skip("Internal compiler error in GitHub action container") - @gtscript.stencil(backend=backend) def stencil( - pe1: gtscript.Field[np.float_], - pe2: gtscript.Field[np.float_], - qin: gtscript.Field[np.float_], - qout: gtscript.Field[np.float_], - lev: gtscript.Field[gtscript.IJ, np.int_], + pe1: Field[np.float64], # type: ignore + pe2: Field[np.float64], # type: ignore + qin: Field[np.float64], # type: ignore + qout: Field[np.float64], # type: ignore + lev: Field[gtscript.IJ, np.int_], # type: ignore ): with computation(FORWARD), interval(0, -1): if pe2[0, 0, 1] <= pe1[0, 0, lev]: @@ -410,13 +409,13 @@ def stencil( while pe1[0, 0, lev + 1] < pe2[0, 0, 1]: qsum += qin[0, 0, lev] / (pe2[0, 0, 1] - pe1[0, 0, lev]) lev = lev + 1 - qout = qsum / (pe2[0, 0, 1] - pe2) + qout[0, 0, 0] = qsum / (pe2[0, 0, 1] - pe2) @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_nested_while_loop(backend): @gtscript.stencil(backend=backend) - def stencil(field_a: gtscript.Field[np.float_], field_b: gtscript.Field[np.int_]): + def stencil(field_a: Field[np.float64], field_b: Field[np.int_]): # type: ignore with computation(PARALLEL), interval(...): while field_a < 1: add = 0 @@ -427,14 +426,14 @@ def stencil(field_a: gtscript.Field[np.float_], field_b: gtscript.Field[np.int_] @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_mask_with_offset_written_in_conditional(backend): - @gtscript.stencil(backend, externals={"mord": 5}) - def stencil(outp: gtscript.Field[np.float_]): + @gtscript.stencil(backend) + def stencil(outp: Field[np.float64]): # type: ignore with computation(PARALLEL), interval(...): cond = True if cond[0, -1, 0] or cond[0, 0, 0]: outp = 1.0 else: - outp = 0.0 + outp[0, 0, 0] = 0.0 outp = gt_storage.zeros( shape=(10, 10, 10), backend=backend, aligned_index=(0, 0, 0), dtype=float @@ -451,8 +450,8 @@ def test_write_data_dim_indirect_addressing(backend): 
INT32_VEC2 = (np.int32, (2,)) def stencil( - input_field: gtscript.Field[gtscript.IJK, np.int32], - output_field: gtscript.Field[gtscript.IJK, INT32_VEC2], + input_field: Field[gtscript.IJK, np.int32], # type: ignore + output_field: Field[gtscript.IJK, INT32_VEC2], # type: ignore index: int, ): with computation(PARALLEL), interval(...): @@ -476,12 +475,12 @@ def test_read_data_dim_indirect_addressing(backend): INT32_VEC2 = (np.int32, (2,)) def stencil( - input_field: gtscript.Field[gtscript.IJK, INT32_VEC2], - output_field: gtscript.Field[gtscript.IJK, np.int32], + input_field: Field[gtscript.IJK, INT32_VEC2], # type: ignore + output_field: Field[gtscript.IJK, np.int32], # type: ignore index: int, ): with computation(PARALLEL), interval(...): - output_field = input_field[0, 0, 0][index] + output_field[0, 0, 0] = input_field[0, 0, 0][index] aligned_index = (0, 0, 0) full_shape = (1, 1, 2) @@ -501,11 +500,11 @@ class TestNegativeOrigin: def test_negative_origin_i(self, backend): @gtscript.stencil(backend=backend) def stencil_i( - input_field: gtscript.Field[gtscript.IJK, np.int32], - output_field: gtscript.Field[gtscript.IJK, np.int32], + input_field: Field[gtscript.IJK, np.int32], # type: ignore + output_field: Field[gtscript.IJK, np.int32], # type: ignore ): with computation(PARALLEL), interval(...): - output_field = input_field[1, 0, 0] + output_field[0, 0, 0] = input_field[1, 0, 0] input_field = gt_storage.ones( backend=backend, aligned_index=(0, 0, 0), shape=(1, 1, 1), dtype=np.int32 @@ -520,11 +519,11 @@ def stencil_i( def test_negative_origin_k(self, backend): @gtscript.stencil(backend=backend) def stencil_k( - input_field: gtscript.Field[gtscript.IJK, np.int32], - output_field: gtscript.Field[gtscript.IJK, np.int32], + input_field: Field[gtscript.IJK, np.int32], # type: ignore + output_field: Field[gtscript.IJK, np.int32], # type: ignore ): with computation(PARALLEL), interval(...): - output_field = input_field[0, 0, 1] + output_field[0, 0, 0] = 
input_field[0, 0, 1] input_field = gt_storage.ones( backend=backend, aligned_index=(0, 0, 0), shape=(1, 1, 1), dtype=np.int32 @@ -540,9 +539,9 @@ def stencil_k( @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_origin_k_fields(backend): @gtscript.stencil(backend=backend, rebuild=True) - def k_to_ijk(outp: Field[np.float64], inp: Field[gtscript.K, np.float64]): + def k_to_ijk(outp: Field[np.float64], inp: Field[gtscript.K, np.float64]): # type: ignore with computation(PARALLEL), interval(...): - outp = inp + outp[0, 0, 0] = inp origin = {"outp": (0, 0, 1), "inp": (2,)} domain = (2, 2, 8) @@ -568,11 +567,11 @@ def k_to_ijk(outp: Field[np.float64], inp: Field[gtscript.K, np.float64]): @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_pruned_args_match(backend): @gtscript.stencil(backend=backend) - def test(out: Field[np.float64], inp: Field[np.float64]): + def test(out: Field[np.float64], inp: Field[np.float64]): # type: ignore with computation(PARALLEL), interval(...): out = 0.0 with horizontal(region[I[0] - 1, J[0] - 1]): - out = inp + out[0, 0, 0] = inp inp = gt_storage.zeros( backend=backend, aligned_index=(0, 0, 0), shape=(2, 2, 2), dtype=np.float64 @@ -588,13 +587,11 @@ def test_K_offset_write(backend): # Cuda generates bad code for the K offset if backend == "cuda": pytest.skip("cuda K-offset write generates bad code") - if backend in ["gt:gpu", "dace:gpu"]: - import cupy as cp - if cp.cuda.runtime.runtimeGetVersion() < 12000: - pytest.skip( - f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" - ) + if backend in ["gt:gpu", "dace:gpu"]: + pytest.skip( + f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1684" + ) arraylib = get_array_library(backend) array_shape = (1, 1, 4) @@ -604,7 +601,7 @@ def test_K_offset_write(backend): # A is untouched # B is written in K+1 and should have K_values, except for the first 
element (FORWARD) @gtscript.stencil(backend=backend) - def simple(A: Field[np.float64], B: Field[np.float64]): + def simple(A: Field[np.float64], B: Field[np.float64]): # type: ignore with computation(FORWARD), interval(...): B[0, 0, 1] = A @@ -623,7 +620,7 @@ def simple(A: Field[np.float64], B: Field[np.float64]): # means while A is update B will have non-updated values of A # Because of the interval, value of B[0] is 0 @gtscript.stencil(backend=backend) - def forward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): + def forward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): # type: ignore with computation(FORWARD), interval(1, None): A[0, 0, -1] = scalar B[0, 0, 0] = A @@ -644,7 +641,7 @@ def forward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): # Order of operations: BACKWARD with negative offset # means A is update B will get the updated values of A @gtscript.stencil(backend=backend) - def backward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): + def backward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): # type: ignore with computation(BACKWARD), interval(1, None): A[0, 0, -1] = scalar B[0, 0, 0] = A @@ -666,21 +663,14 @@ def backward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): def test_K_offset_write_conditional(backend): if backend == "cuda": pytest.skip("Cuda backend is not capable of K offset write") - if backend in ["gt:gpu", "dace:gpu"]: - import cupy as cp - - if cp.cuda.runtime.runtimeGetVersion() < 12000: - pytest.skip( - f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" - ) arraylib = get_array_library(backend) array_shape = (1, 1, 4) K_values = arraylib.arange(start=40, stop=44) @gtscript.stencil(backend=backend) - def column_physics_conditional(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): - with computation(BACKWARD), interval(1, None): + def 
column_physics_conditional(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): # type: ignore + with computation(BACKWARD), interval(1, -1): if A > 0 and B > 0: A[0, 0, -1] = scalar B[0, 0, 1] = A @@ -698,6 +688,42 @@ def column_physics_conditional(A: Field[np.float64], B: Field[np.float64], scala backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 ) column_physics_conditional(A, B, 2.0) + # Manual unroll of the above + # Starts with + # - A[...] = [40, 41, 42, 43] + # - B[...] = [1, 1, 1, 1] + # Now in-stencil + # ITERATION k = 2 of [2:1] + # if condition + # - A[2] == 42 && B[2] == 1 => True + # - A[1] = 2.0 + # - B[3] = A[2] = 42 + # while + # - lev = 1 + # - A[2] == 42 && B[2] == 1 => True + # - A[3] = -1 + # - B[2] = -1 + # - lev = 2 + # - A[2] == 42 && B[2] == -1 => False + # End of iteration state + # - A[...] = A[40, 2.0, 2.0, -1] + # - B[...] = A[1, 1, -1, 42] + # ITERATION k = 1 of [2:1] + # if condition + # - A[1] == 2.0 && B[1] == 1 => True + # - A[0] = 2.0 + # - B[2] = A[1] = 2.0 + # while + # - lev = 1 + # - A[1] == 2.0 && B[1] == 1 => True + # - A[2] = -1 + # - B[1] = -1 + # - lev = 2 + # - A[1] == 2.0 && B[2] == -1 => False + # End of stencil state + # - A[...] = A[2.0, 2.0, -1, -1] + # - B[...] 
= A[1, -1, 2.0, 42] + assert (A[0, 0, :] == arraylib.array([2, 2, -1, -1])).all() assert (B[0, 0, :] == arraylib.array([1, -1, 2, 42])).all() @@ -707,9 +733,9 @@ def test_direct_datadims_index(backend): F64_VEC4 = (np.float64, (2, 2, 2, 2)) @gtscript.stencil(backend=backend) - def test(out: Field[np.float64], inp: GlobalTable[F64_VEC4]): + def test(out: Field[np.float64], inp: GlobalTable[F64_VEC4]): # type: ignore with computation(PARALLEL), interval(...): - out = inp.A[1, 0, 1, 0] + out[0, 0, 0] = inp.A[1, 0, 1, 0] inp = gt_storage.ones(backend=backend, shape=(2, 2, 2, 2), dtype=np.float64) inp[1, 0, 1, 0] = 42 @@ -726,8 +752,8 @@ def add_42(v): @gtscript.stencil(backend=backend) def test( - in_field: Field[np.float64], - out_field: Field[np.float64], + in_field: Field[np.float64], # type: ignore + out_field: Field[np.float64], # type: ignore ): with computation(PARALLEL), interval(...): count = 1 @@ -741,3 +767,94 @@ def test( out_arr = gt_storage.ones(backend=backend, shape=domain, dtype=np.float64) test(in_arr, out_arr) assert (out_arr[:, :, :] == 388.0).all() + + +def _xfail_dace_backends(param): + if param.values[0].startswith("dace:"): + marks = [ + *param.marks, + pytest.mark.xfail( + raises=ValueError, + reason="Missing support in DaCe backends, see https://github.com/GridTools/gt4py/issues/1881.", + ), + ] + # make a copy because otherwise we are operating in-place + return pytest.param(*param.values, marks=marks) + return param + + +@pytest.mark.parametrize("backend", map(_xfail_dace_backends, ALL_BACKENDS)) +def test_cast_in_index(backend): + @gtscript.stencil(backend) + def cast_in_index( + in_field: Field[np.float64], # type: ignore + i32: np.int32, + i64: np.int64, + out_field: Field[np.float64], # type: ignore + ): + """Simple copy stencil with forced cast in index calculation.""" + with computation(PARALLEL), interval(...): + out_field[0, 0, 0] = in_field[0, 0, i32 - i64] + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def 
test_read_after_write_stencil(backend): + """Stencil with multiple read after write access patterns.""" + + @gtscript.stencil(backend=backend) + def lagrangian_contributions( + q: Field[np.float64], # type: ignore + pe1: Field[np.float64], # type: ignore + pe2: Field[np.float64], # type: ignore + q4_1: Field[np.float64], # type: ignore + q4_2: Field[np.float64], # type: ignore + q4_3: Field[np.float64], # type: ignore + q4_4: Field[np.float64], # type: ignore + dp1: Field[np.float64], # type: ignore + lev: Field[gtscript.IJ, np.int64], # type: ignore + ): + """ + Args: + q (out): + pe1 (in): + pe2 (in): + q4_1 (in): + q4_2 (in): + q4_3 (in): + q4_4 (in): + dp1 (in): + lev (inout): + """ + with computation(FORWARD), interval(...): + pl = (pe2 - pe1[0, 0, lev]) / dp1[0, 0, lev] + if pe2[0, 0, 1] <= pe1[0, 0, lev + 1]: + pr = (pe2[0, 0, 1] - pe1[0, 0, lev]) / dp1[0, 0, lev] + q[0, 0, 0] = ( + q4_2[0, 0, lev] + + 0.5 * (q4_4[0, 0, lev] + q4_3[0, 0, lev] - q4_2[0, 0, lev]) * (pr + pl) + - q4_4[0, 0, lev] * 1.0 / 3.0 * (pr * (pr + pl) + pl * pl) + ) + else: + qsum = (pe1[0, 0, lev + 1] - pe2) * ( + q4_2[0, 0, lev] + + 0.5 * (q4_4[0, 0, lev] + q4_3[0, 0, lev] - q4_2[0, 0, lev]) * (1.0 + pl) + - q4_4[0, 0, lev] * 1.0 / 3.0 * (1.0 + pl * (1.0 + pl)) + ) + lev = lev + 1 + while pe1[0, 0, lev + 1] < pe2[0, 0, 1]: + qsum += dp1[0, 0, lev] * q4_1[0, 0, lev] + lev = lev + 1 + dp = pe2[0, 0, 1] - pe1[0, 0, lev] + esl = dp / dp1[0, 0, lev] + qsum += dp * ( + q4_2[0, 0, lev] + + 0.5 + * esl + * ( + q4_3[0, 0, lev] + - q4_2[0, 0, lev] + + q4_4[0, 0, lev] * (1.0 - (2.0 / 3.0) * esl) + ) + ) + q = qsum / (pe2[0, 0, 1] - pe2) + lev = lev - 1 diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py index 9fafc27c85..faeca7b8dc 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py +++ 
b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py @@ -6,30 +6,32 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import pathlib -import re -import typing +import pytest +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import dace +else: + dace = pytest.importorskip("dace") import hypothesis.strategies as hyp_st import numpy as np -import pytest +import pathlib +import re +import typing from gt4py import cartesian as gt4pyc, storage as gt_storage from gt4py.cartesian import gtscript from gt4py.cartesian.gtscript import PARALLEL, computation, interval from gt4py.cartesian.stencil_builder import StencilBuilder from gt4py.storage.cartesian import utils as storage_utils +from gt4py.cartesian.backend.dace_lazy_stencil import DaCeLazyStencil from cartesian_tests.utils import OriginWrapper - -dace = pytest.importorskip("dace") -from gt4py.cartesian.backend.dace_lazy_stencil import ( # noqa: E402 [import-shadowed-by-loop-var] 'importorskip' is needed - DaCeLazyStencil, -) - - -pytestmark = pytest.mark.usefixtures("dace_env") +# Because "dace tests" filter by `requires_dace`, we still need to add the marker. +# This global variable add the marker to all test functions in this module. 
+pytestmark = [pytest.mark.requires_dace, pytest.mark.usefixtures("dace_env")] @pytest.fixture(scope="module") diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py index 44112f3899..032dc3bb5e 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_suites.py @@ -7,7 +7,6 @@ # SPDX-License-Identifier: BSD-3-Clause import numpy as np -import pytest from gt4py.cartesian import gtscript, testing as gt_testing from gt4py.cartesian.gtscript import ( @@ -25,7 +24,6 @@ from .stencil_definitions import optional_field, two_optional_fields -# ---- Identity stencil ---- class TestIdentity(gt_testing.StencilTestSuite): """Identity stencil.""" @@ -43,11 +41,10 @@ def validation(field_a, domain=None, origin=None): pass -# ---- Copy stencil ---- class TestCopy(gt_testing.StencilTestSuite): """Copy stencil.""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 25), (1, 25), (1, 25)] backends = ALL_BACKENDS symbols = dict( @@ -66,7 +63,7 @@ def validation(field_a, field_b, domain=None, origin=None): class TestAugAssign(gt_testing.StencilTestSuite): """Increment by one stencil.""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 25), (1, 25), (1, 25)] backends = ALL_BACKENDS symbols = dict( @@ -86,11 +83,10 @@ def validation(field_a, field_b, domain=None, origin=None): field_b[...] = (field_b[...] - 1.0) / 2.0 -# ---- Scale stencil ---- class TestGlobalScale(gt_testing.StencilTestSuite): """Scale stencil using a global global_name.""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -108,11 +104,10 @@ def validation(field_a, domain, origin, **kwargs): field_a[...] 
= SCALE_FACTOR * field_a # noqa: F821 [undefined-name] -# ---- Parametric scale stencil ----- class TestParametricScale(gt_testing.StencilTestSuite): """Scale stencil using a parameter.""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -128,7 +123,6 @@ def validation(field_a, *, scale, domain, origin, **kwargs): field_a[...] = scale * field_a -# --- Parametric-mix stencil ---- class TestParametricMix(gt_testing.StencilTestSuite): """Linear combination of input fields using several parameters.""" @@ -136,7 +130,7 @@ class TestParametricMix(gt_testing.StencilTestSuite): ("USE_ALPHA",): np.int_, ("field_a", "field_b", "field_c"): np.float64, ("field_out",): np.float32, - ("weight", "alpha_factor"): np.float_, + ("weight", "alpha_factor"): np.float64, } domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS @@ -177,7 +171,7 @@ def validation( class TestHeatEquation_FTCS_3D(gt_testing.StencilTestSuite): - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -206,7 +200,7 @@ def validation(u, v, u_new, v_new, *, ru, rv, domain, origin, **kwargs): class TestHorizontalDiffusion(gt_testing.StencilTestSuite): """Diffusion in a horizontal 2D plane .""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -270,7 +264,7 @@ def fwd_diff_op_y(field): class TestHorizontalDiffusionSubroutines(gt_testing.StencilTestSuite): """Diffusion in a horizontal 2D plane .""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -305,7 +299,7 @@ def validation(u, diffusion, *, weight, domain, origin, **kwargs): class TestHorizontalDiffusionSubroutines2(gt_testing.StencilTestSuite): """Diffusion in a horizontal 2D plane .""" - dtypes = (np.float_,) + dtypes 
= (np.float64,) domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -346,7 +340,7 @@ def validation(u, diffusion, *, weight, domain, origin, **kwargs): class TestRuntimeIfFlat(gt_testing.StencilTestSuite): """Tests runtime ifs.""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict(outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)])) @@ -365,7 +359,7 @@ def validation(outfield, *, domain, origin, **kwargs): class TestRuntimeIfNested(gt_testing.StencilTestSuite): """Tests nested runtime ifs.""" - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (1, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict(outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)])) @@ -391,7 +385,7 @@ def add_one(field_in): class Test3FoldNestedIf(gt_testing.StencilTestSuite): - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(3, 3), (3, 3), (3, 3)] backends = ALL_BACKENDS symbols = dict(field_a=gt_testing.field(in_range=(-1, 1), boundary=[(0, 0), (0, 0), (0, 0)])) @@ -411,7 +405,7 @@ def validation(field_a, domain, origin): class TestRuntimeIfNestedDataDependent(gt_testing.StencilTestSuite): - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(3, 3), (3, 3), (3, 3)] backends = ALL_BACKENDS symbols = dict( @@ -444,8 +438,38 @@ def validation(field_a, field_b, field_c, *, factor, domain, origin, **kwargs): field_a += 1 +class TestRuntimeIfNestedWhile(gt_testing.StencilTestSuite): + """Test conditional while statements.""" + + dtypes = (np.float64,) + domain_range = [(1, 15), (1, 15), (1, 15)] + backends = ALL_BACKENDS + symbols = dict( + infield=gt_testing.field(in_range=(-1, 1), boundary=[(0, 0), (0, 0), (0, 0)]), + outfield=gt_testing.field(in_range=(-10, 10), boundary=[(0, 0), (0, 0), (0, 0)]), + ) + + def definition(infield, outfield): + with computation(PARALLEL), 
interval(...): + if infield < 10: + outfield = 1 + done = False + while not done: + outfield = 2 + done = True + else: + condition = True + while condition: + outfield = 4 + condition = False + outfield = 3 + + def validation(infield, outfield, *, domain, origin, **kwargs): + outfield[...] = 2 + + class TestTernaryOp(gt_testing.StencilTestSuite): - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (2, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -466,7 +490,7 @@ def validation(infield, outfield, *, domain, origin, **kwargs): class TestThreeWayAnd(gt_testing.StencilTestSuite): - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (2, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -488,7 +512,7 @@ def validation(outfield, *, a, b, c, domain, origin, **kwargs): class TestThreeWayOr(gt_testing.StencilTestSuite): - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 15), (2, 15), (1, 15)] backends = ALL_BACKENDS symbols = dict( @@ -510,7 +534,7 @@ def validation(outfield, *, a, b, c, domain, origin, **kwargs): class TestOptionalField(gt_testing.StencilTestSuite): - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 32), (1, 32), (1, 32)] backends = ALL_BACKENDS symbols = dict( @@ -538,7 +562,7 @@ class TestNotSpecifiedOptionalField(TestOptionalField): class TestTwoOptionalFields(gt_testing.StencilTestSuite): - dtypes = (np.float_,) + dtypes = (np.float64,) domain_range = [(1, 32), (1, 32), (1, 32)] backends = ALL_BACKENDS symbols = dict( @@ -738,29 +762,10 @@ def validation(field_in, field_out, *, domain, origin): field_out[:, :, 0] = field_in[:, :, 0] -def _skip_dace_cpu_gcc_error(backends): - paramtype = type(pytest.param()) - res = [] - for b in backends: - if isinstance(b, paramtype) and b.values[0] == "dace:cpu": - res.append( - pytest.param( - *b.values, - marks=[ - *b.marks, - pytest.mark.skip("Internal compiler error in GitHub action container"), - ], - ) - ) - else: - 
res.append(b) - return res - - class TestVariableKRead(gt_testing.StencilTestSuite): dtypes = {"field_in": np.float32, "field_out": np.float32, "index": np.int32} domain_range = [(2, 2), (2, 2), (2, 8)] - backends = _skip_dace_cpu_gcc_error(ALL_BACKENDS) + backends = ALL_BACKENDS symbols = { "field_in": gt_testing.field( in_range=(-10, 10), axes="IJK", boundary=[(0, 0), (0, 0), (0, 0)] @@ -782,7 +787,7 @@ def validation(field_in, field_out, index, *, domain, origin): class TestVariableKAndReadOutside(gt_testing.StencilTestSuite): dtypes = {"field_in": np.float64, "field_out": np.float64, "index": np.int32} domain_range = [(2, 2), (2, 2), (2, 8)] - backends = _skip_dace_cpu_gcc_error(ALL_BACKENDS) + backends = ALL_BACKENDS symbols = { "field_in": gt_testing.field( in_range=(0.1, 10), axes="IJK", boundary=[(0, 0), (0, 0), (1, 0)] @@ -865,7 +870,13 @@ def validation(field_in, field_out, *, domain, origin): field_out[:, -1, :] = field_in[:, -1, :] - 1.0 -class TestHorizontalRegionsCorners(gt_testing.StencilTestSuite): +class TestHorizontalRegionsPartialWrites(gt_testing.StencilTestSuite): + """Use horizontal regions to only write to certain parts of the field. 
+ + This test is different from the corner case below because the corner + case follows a different code path (we have specific optimizations for + them).""" + dtypes = {"field_in": np.float32, "field_out": np.float32} domain_range = [(4, 4), (4, 4), (2, 2)] backends = ALL_BACKENDS @@ -874,8 +885,40 @@ class TestHorizontalRegionsCorners(gt_testing.StencilTestSuite): in_range=(-10, 10), axes="IJK", boundary=[(0, 0), (0, 0), (0, 0)] ), "field_out": gt_testing.field( + in_range=(42, 42), axes="IJK", boundary=[(0, 0), (0, 0), (0, 0)] + ), + } + + def definition(field_in, field_out): + with computation(PARALLEL), interval(...): + with horizontal(region[I[0], :], region[I[-1], :]): + field_out = ( # noqa: F841 [unused-variable] + field_in + 1.0 + ) + with horizontal(region[:, J[0]], region[:, J[-1]]): + field_out = ( # noqa: F841 [unused-variable] + field_in - 1.0 + ) + + def validation(field_in, field_out, *, domain, origin): + field_out[:, :, :] = 42 + field_out[0, :, :] = field_in[0, :, :] + 1.0 + field_out[-1, :, :] = field_in[-1, :, :] + 1.0 + field_out[:, 0, :] = field_in[:, 0, :] - 1.0 + field_out[:, -1, :] = field_in[:, -1, :] - 1.0 + + +class TestHorizontalRegionsCorners(gt_testing.StencilTestSuite): + dtypes = {"field_in": np.float32, "field_out": np.float32} + domain_range = [(4, 4), (4, 4), (2, 2)] + backends = ALL_BACKENDS + symbols = { + "field_in": gt_testing.field( in_range=(-10, 10), axes="IJK", boundary=[(0, 0), (0, 0), (0, 0)] ), + "field_out": gt_testing.field( + in_range=(42, 42), axes="IJK", boundary=[(0, 0), (0, 0), (0, 0)] + ), } def definition(field_in, field_out): @@ -890,6 +933,7 @@ def definition(field_in, field_out): ) def validation(field_in, field_out, *, domain, origin): + field_out[:, :, :] = 42 field_out[0:2, 0:2, :] = field_in[0:2, 0:2, :] + 1.0 field_out[-3:-1, -3:-1, :] = field_in[-3:-1, -3:-1, :] + 1.0 field_out[0:2, -3:-1, :] = field_in[0:2, -3:-1, :] - 1.0 diff --git 
a/tests/cartesian_tests/unit_tests/backend_tests/test_backend_api.py b/tests/cartesian_tests/unit_tests/backend_tests/test_backend_api.py index c47ad10e94..3fbf586b35 100644 --- a/tests/cartesian_tests/unit_tests/backend_tests/test_backend_api.py +++ b/tests/cartesian_tests/unit_tests/backend_tests/test_backend_api.py @@ -79,7 +79,7 @@ def test_generate_bindings(backend, tmp_path): ) else: # assumption: only gt backends support python bindings for other languages than python - result = builder.backend.generate_bindings("python", stencil_ir=builder.gtir) + result = builder.backend.generate_bindings("python") assert "init_1_src" in result - srcs = result["init_1_src"] - assert "bindings.cpp" in srcs or "bindings.cu" in srcs + sources = result["init_1_src"] + assert "bindings.cpp" in sources or "bindings.cu" in sources diff --git a/tests/cartesian_tests/unit_tests/backend_tests/test_module_generator.py b/tests/cartesian_tests/unit_tests/backend_tests/test_module_generator.py index 963b824122..8efc414458 100644 --- a/tests/cartesian_tests/unit_tests/backend_tests/test_module_generator.py +++ b/tests/cartesian_tests/unit_tests/backend_tests/test_module_generator.py @@ -36,7 +36,7 @@ def sample_builder(): @pytest.fixture def sample_args_data(): - dtype = np.dtype(np.float_) + dtype = np.dtype(np.float64) yield ModuleData( field_info={ "in_field": FieldInfo( diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py index e62f878746..1f7a779835 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py @@ -720,7 +720,7 @@ def definition_func(field: gtscript.Field[float]): class TestRegions: def test_one_interval_only(self): - def stencil(in_f: gtscript.Field[np.float_]): + def stencil(in_f: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...), 
horizontal(region[I[0:3], :]): in_f = 1.0 @@ -732,7 +732,7 @@ def stencil(in_f: gtscript.Field[np.float_]): assert isinstance(def_ir.computations[0].body.stmts[0], nodes.HorizontalIf) def test_one_interval_only_single(self): - def stencil(in_f: gtscript.Field[np.float_]): + def stencil(in_f: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...), horizontal(region[I[0], :]): in_f = 1.0 @@ -744,7 +744,7 @@ def stencil(in_f: gtscript.Field[np.float_]): assert def_ir.computations[0].body.stmts[0].intervals["I"].is_single_index def test_from_external(self): - def stencil(in_f: gtscript.Field[np.float_]): + def stencil(in_f: gtscript.Field[np.float64]): from gt4py.cartesian.__externals__ import i1 with computation(PARALLEL), interval(...), horizontal(region[i1, :]): @@ -766,7 +766,7 @@ def stencil(in_f: gtscript.Field[np.float_]): assert def_ir.computations[0].body.stmts[0].intervals["I"].is_single_index def test_multiple_inline(self): - def stencil(in_f: gtscript.Field[np.float_]): + def stencil(in_f: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_f = in_f + 1.0 with horizontal(region[I[0], :], region[:, J[-1]]): @@ -789,7 +789,7 @@ def region_func(): return field - def stencil(in_f: gtscript.Field[np.float_]): + def stencil(in_f: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_f = region_func() @@ -801,7 +801,7 @@ def stencil(in_f: gtscript.Field[np.float_]): ) def test_error_undefined(self): - def stencil(in_f: gtscript.Field[np.float_]): + def stencil(in_f: gtscript.Field[np.float64]): from gt4py.cartesian.__externals__ import i0 # forget to add 'ia' with computation(PARALLEL), interval(...): @@ -813,7 +813,7 @@ def stencil(in_f: gtscript.Field[np.float_]): parse_definition(stencil, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_error_nested(self): - def stencil(in_f: gtscript.Field[np.float_]): + def stencil(in_f: gtscript.Field[np.float64]): with computation(PARALLEL), 
interval(...): in_f = in_f + 1.0 with horizontal(region[I[0], :]): @@ -1054,9 +1054,9 @@ def definition(inout_field: gtscript.Field[float]): class TestReducedDimensions: def test_syntax(self): def definition_func( - field_3d: gtscript.Field[gtscript.IJK, np.float_], - field_2d: gtscript.Field[gtscript.IJ, np.float_], - field_1d: gtscript.Field[gtscript.K, np.float_], + field_3d: gtscript.Field[gtscript.IJK, np.float64], + field_2d: gtscript.Field[gtscript.IJ, np.float64], + field_1d: gtscript.Field[gtscript.K, np.float64], ): with computation(FORWARD), interval(...): field_2d = field_1d[1] @@ -1085,8 +1085,8 @@ def definition_func( def test_error_syntax(self): def definition( - field_in: gtscript.Field[gtscript.K, np.float_], - field_out: gtscript.Field[gtscript.IJK, np.float_], + field_in: gtscript.Field[gtscript.K, np.float64], + field_out: gtscript.Field[gtscript.IJK, np.float64], ): with computation(PARALLEL), interval(...): field_out = field_in[0, 0, 1] @@ -1099,8 +1099,8 @@ def definition( def test_error_write_1d(self): def definition( - field_in: gtscript.Field[gtscript.IJK, np.float_], - field_out: gtscript.Field[gtscript.K, np.float_], + field_in: gtscript.Field[gtscript.IJK, np.float64], + field_out: gtscript.Field[gtscript.K, np.float64], ): with computation(PARALLEL), interval(...): field_out = field_in[0, 0, 0] @@ -1113,10 +1113,10 @@ def definition( def test_higher_dim_temp(self): def definition( - field_in: gtscript.Field[gtscript.IJK, np.float_], - field_out: gtscript.Field[gtscript.IJK, np.float_], + field_in: gtscript.Field[gtscript.IJK, np.float64], + field_out: gtscript.Field[gtscript.IJK, np.float64], ): - tmp: Field[IJK, (np.float_, (2,))] = 0.0 + tmp: Field[IJK, (np.float64, (2,))] = 0.0 with computation(PARALLEL), interval(...): tmp[0, 0, 0][0] = field_in field_out = tmp[0, 0, 0][0] @@ -1125,10 +1125,10 @@ def definition( def test_typed_temp_missing(self): def definition( - field_in: gtscript.Field[gtscript.IJK, np.float_], - field_out: 
gtscript.Field[gtscript.IJK, np.float_], + field_in: gtscript.Field[gtscript.IJK, np.float64], + field_out: gtscript.Field[gtscript.IJK, np.float64], ): - tmp: Field[IJ, np.float_] = 0.0 + tmp: Field[IJ, np.float64] = 0.0 with computation(FORWARD), interval(1, None): tmp = field_in[0, 0, -1] field_out = tmp @@ -1143,9 +1143,9 @@ def definition( class TestDataDimensions: def test_syntax(self): def definition( - field_in: gtscript.Field[np.float_], - another_field: gtscript.Field[(np.float_, 3)], - field_out: gtscript.Field[gtscript.IJK, (np.float_, (3,))], + field_in: gtscript.Field[np.float64], + another_field: gtscript.Field[(np.float64, 3)], + field_out: gtscript.Field[gtscript.IJK, (np.float64, (3,))], ): with computation(PARALLEL), interval(...): field_out[0, 0, 0][0] = field_in @@ -1156,8 +1156,8 @@ def definition( def test_syntax_no_datadim(self): def definition( - field_in: gtscript.Field[np.float_], - field_out: gtscript.Field[gtscript.IJK, (np.float_, (3,))], + field_in: gtscript.Field[np.float64], + field_out: gtscript.Field[gtscript.IJK, (np.float64, (3,))], ): with computation(PARALLEL), interval(...): field_out[0, 0, 0][0] = field_in @@ -1169,8 +1169,8 @@ def definition( def test_syntax_out_bounds(self): def definition( - field_in: gtscript.Field[np.float_], - field_out: gtscript.Field[gtscript.IJK, (np.float_, (3,))], + field_in: gtscript.Field[np.float64], + field_out: gtscript.Field[gtscript.IJK, (np.float64, (3,))], ): with computation(PARALLEL), interval(...): field_out[0, 0, 0][3] = field_in[0, 0, 0] @@ -1180,8 +1180,8 @@ def definition( def test_indirect_access_read(self): def definition( - field_3d: gtscript.Field[np.float_], - field_4d: gtscript.Field[gtscript.IJK, (np.float_, (2,))], + field_3d: gtscript.Field[np.float64], + field_4d: gtscript.Field[gtscript.IJK, (np.float64, (2,))], variable: float, ): with computation(PARALLEL), interval(...): @@ -1194,8 +1194,8 @@ def definition( def test_indirect_access_write(self): def definition( - 
field_3d: gtscript.Field[np.float_], - field_4d: gtscript.Field[gtscript.IJK, (np.float_, (2,))], + field_3d: gtscript.Field[np.float64], + field_4d: gtscript.Field[gtscript.IJK, (np.float64, (2,))], variable: float, ): with computation(PARALLEL), interval(...): @@ -1381,14 +1381,14 @@ def test_literal_floating_parametrization(self, the_float): class TestAssignmentSyntax: def test_ellipsis(self): - def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64], out_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): out_field[...] = in_field parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_offset(self): - def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64], out_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): out_field[0, 0, 0] = in_field @@ -1397,15 +1397,15 @@ def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float with pytest.raises(gt_frontend.GTScriptSyntaxError): def func( - in_field: gtscript.Field[np.float_], - out_field: gtscript.Field[np.float_], + in_field: gtscript.Field[np.float64], + out_field: gtscript.Field[np.float64], ): with computation(PARALLEL), interval(...): out_field[0, 0, 1] = in_field parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) - def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64], out_field: gtscript.Field[np.float64]): from gt4py.cartesian.__externals__ import offset with computation(PARALLEL), interval(...): @@ -1471,8 +1471,8 @@ def test_slice(self): with pytest.raises(gt_frontend.GTScriptSyntaxError): def func( - in_field: gtscript.Field[np.float_], - out_field: gtscript.Field[np.float_], + in_field: gtscript.Field[np.float64], + out_field: 
gtscript.Field[np.float64], ): with computation(PARALLEL), interval(...): out_field[:, :, :] = in_field @@ -1483,8 +1483,8 @@ def test_string(self): with pytest.raises(gt_frontend.GTScriptSyntaxError): def func( - in_field: gtscript.Field[np.float_], - out_field: gtscript.Field[np.float_], + in_field: gtscript.Field[np.float64], + out_field: gtscript.Field[np.float64], ): with computation(PARALLEL), interval(...): out_field["a_key"] = in_field @@ -1492,7 +1492,7 @@ def func( parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_augmented(self): - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field += 2.0 in_field -= 0.5 @@ -1589,7 +1589,7 @@ def data_dims_with_at( class TestNestedWithSyntax: def test_nested_with(self): - def definition(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_]): + def definition(in_field: gtscript.Field[np.float64], out_field: gtscript.Field[np.float64]): with computation(PARALLEL): with interval(...): in_field = out_field @@ -1598,7 +1598,7 @@ def definition(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np def test_nested_with_ordering(self): def definition_fw( - in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_] + in_field: gtscript.Field[np.float64], out_field: gtscript.Field[np.float64] ): from gt4py.cartesian.__gtscript__ import FORWARD, computation, interval @@ -1609,7 +1609,7 @@ def definition_fw( in_field = out_field + 2 def definition_bw( - in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_] + in_field: gtscript.Field[np.float64], out_field: gtscript.Field[np.float64] ): from gt4py.cartesian.__gtscript__ import FORWARD, computation, interval @@ -1633,35 +1633,35 @@ def definition_bw( class TestNativeFunctions: def test_simple_call(self): - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: 
gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field += sin(in_field) parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_offset_arg(self): - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field += sin(in_field[1, 0, 0]) parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_nested_calls(self): - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field += sin(abs(in_field)) parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_nested_external_call(self): - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field += sin(add_external_const(in_field)) parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_multi_nested_calls(self): - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field += min(abs(sin(add_external_const(in_field))), -0.5) @@ -1672,28 +1672,28 @@ def test_native_in_function(self): def sinus(field_in): return sin(field_in) - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field += sinus(in_field) parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_native_function_unary(self): - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field = not isfinite(in_field) parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_native_function_binary(self): - def 
func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field = asin(in_field) + 1 parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) def test_native_function_ternary(self): - def func(in_field: gtscript.Field[np.float_]): + def func(in_field: gtscript.Field[np.float64]): with computation(PARALLEL), interval(...): in_field = asin(in_field) + 1 if 1 < in_field else sin(in_field) @@ -1702,7 +1702,7 @@ def func(in_field: gtscript.Field[np.float_]): class TestWarnInlined: def test_inlined_emits_warning(self): - def func(field: gtscript.Field[np.float_]): + def func(field: gtscript.Field[np.float64]): from gt4py.cartesian.__externals__ import SET_TO_ONE with computation(PARALLEL), interval(...): diff --git a/src/gt4py/next/program_processors/runners/dace_common/__init__.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/__init__.py similarity index 58% rename from src/gt4py/next/program_processors/runners/dace_common/__init__.py rename to tests/cartesian_tests/unit_tests/test_gtc/dace/__init__.py index abf4c3e24c..c1f188446b 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/__init__.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/dace/__init__.py @@ -6,3 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import pytest + +# Skip this entire folder when we collecting tests and "dace" is not installed as a dependency. +pytest.importorskip("dace") diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_daceir_builder.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_daceir_builder.py new file mode 100644 index 0000000000..af23d7056a --- /dev/null +++ b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_daceir_builder.py @@ -0,0 +1,109 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. 
+# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from gt4py.cartesian.gtc.dace import daceir as dcir + +from cartesian_tests.unit_tests.test_gtc.dace import utils +from cartesian_tests.unit_tests.test_gtc.oir_utils import ( + AssignStmtFactory, + BinaryOpFactory, + HorizontalExecutionFactory, + LiteralFactory, + LocalScalarFactory, + MaskStmtFactory, + ScalarAccessFactory, + StencilFactory, + WhileFactory, +) + + +# Because "dace tests" filter by `requires_dace`, we still need to add the marker. +# This global variable add the marker to all test functions in this module. +pytestmark = pytest.mark.requires_dace + + +def test_dcir_code_structure_condition() -> None: + """Tests the following code structure: + + ComputationState + Condition + true_states: [ComputationState] + false_states: [] + ComputationState + """ + stencil = StencilFactory( + vertical_loops__0__sections__0__horizontal_executions=[ + HorizontalExecutionFactory( + body=[ + AssignStmtFactory( + left=ScalarAccessFactory(name="tmp"), + right=BinaryOpFactory( + left=LiteralFactory(value="0"), right=LiteralFactory(value="2") + ), + ), + MaskStmtFactory(), + AssignStmtFactory( + left=ScalarAccessFactory(name="other"), + right=ScalarAccessFactory(name="tmp"), + ), + ], + declarations=[LocalScalarFactory(name="tmp"), LocalScalarFactory(name="other")], + ), + ] + ) + expansions = utils.library_node_expansions(stencil) + assert len(expansions) == 1, "expect one vertical loop to be expanded" + + nested_SDFG = utils.nested_SDFG_inside_triple_loop(expansions[0]) + assert isinstance(nested_SDFG.states[0], dcir.ComputationState) + assert isinstance(nested_SDFG.states[1], dcir.Condition) + assert nested_SDFG.states[1].true_states + assert isinstance(nested_SDFG.states[1].true_states[0], dcir.ComputationState) + assert not nested_SDFG.states[1].false_states + assert isinstance(nested_SDFG.states[2], dcir.ComputationState) + + +def 
test_dcir_code_structure_while() -> None: + """Tests the following code structure + + ComputationState + WhileLoop + body: [ComputationState] + ComputationState + """ + stencil = StencilFactory( + vertical_loops__0__sections__0__horizontal_executions=[ + HorizontalExecutionFactory( + body=[ + AssignStmtFactory( + left=ScalarAccessFactory(name="tmp"), + right=BinaryOpFactory( + left=LiteralFactory(value="0"), right=LiteralFactory(value="2") + ), + ), + WhileFactory(), + AssignStmtFactory( + left=ScalarAccessFactory(name="other"), + right=ScalarAccessFactory(name="tmp"), + ), + ], + declarations=[LocalScalarFactory(name="tmp"), LocalScalarFactory(name="other")], + ), + ] + ) + expansions = utils.library_node_expansions(stencil) + assert len(expansions) == 1, "expect one vertical loop to be expanded" + + nested_SDFG = utils.nested_SDFG_inside_triple_loop(expansions[0]) + assert isinstance(nested_SDFG.states[0], dcir.ComputationState) + assert isinstance(nested_SDFG.states[1], dcir.WhileLoop) + assert nested_SDFG.states[1].body + assert isinstance(nested_SDFG.states[1].body[0], dcir.ComputationState) + assert isinstance(nested_SDFG.states[2], dcir.ComputationState) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_sdfg_builder.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_sdfg_builder.py new file mode 100644 index 0000000000..561e994b27 --- /dev/null +++ b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_sdfg_builder.py @@ -0,0 +1,144 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +import dace +import pytest + +from gt4py.cartesian.gtc.common import BuiltInLiteral, DataType +from gt4py.cartesian.gtc.dace.expansion.sdfg_builder import StencilComputationSDFGBuilder + +from cartesian_tests.unit_tests.test_gtc.dace import utils +from cartesian_tests.unit_tests.test_gtc.oir_utils import ( + AssignStmtFactory, + BinaryOpFactory, + HorizontalExecutionFactory, + LiteralFactory, + LocalScalarFactory, + MaskStmtFactory, + ScalarAccessFactory, + StencilFactory, +) + + +# Because "dace tests" filter by `requires_dace`, we still need to add the marker. +# This global variable add the marker to all test functions in this module. +pytestmark = pytest.mark.requires_dace + + +def test_scalar_access_multiple_tasklets() -> None: + """Test scalar access if an oir.CodeBlock is split over multiple Tasklets. + + We are breaking up vertical loops inside stencils in multiple Tasklets. It might thus happen that + we write a "local" scalar in one Tasklet and read it in another Tasklet (downstream). + We thus create output connectors for all writes to scalar variables inside Tasklets. And input + connectors for all scalar reads unless previously written in the same Tasklet. DaCe's simplify + pipeline will get rid of any dead dataflow introduced with this general approach. 
+ """ + stencil = StencilFactory( + vertical_loops__0__sections__0__horizontal_executions=[ + HorizontalExecutionFactory( + body=[ + AssignStmtFactory( + left=ScalarAccessFactory(name="tmp"), + right=BinaryOpFactory( + left=LiteralFactory(value="0"), right=LiteralFactory(value="2") + ), + ), + MaskStmtFactory( + mask=LiteralFactory(value=BuiltInLiteral.TRUE, dtype=DataType.BOOL), body=[] + ), + AssignStmtFactory( + left=ScalarAccessFactory(name="other"), + right=ScalarAccessFactory(name="tmp"), + ), + ], + declarations=[LocalScalarFactory(name="tmp"), LocalScalarFactory(name="other")], + ), + ] + ) + expansions = utils.library_node_expansions(stencil) + nsdfg = StencilComputationSDFGBuilder().visit(expansions[0]) + assert isinstance(nsdfg.sdfg, dace.SDFG) + + for node in nsdfg.sdfg.nodes()[1].nodes(): + if not isinstance(node, dace.nodes.NestedSDFG): + continue + + nested = node.sdfg + for state in nested.states(): + if state.name == "block_0": + nodes = state.nodes() + assert ( + len(list(filter(lambda node: isinstance(node, dace.nodes.Tasklet), nodes))) == 1 + ) + assert ( + len( + list( + filter( + lambda node: isinstance(node, dace.nodes.AccessNode) + and node.data == "tmp", + nodes, + ) + ) + ) + == 1 + ), "one AccessNode of tmp" + + edges = state.edges() + tasklet = list(filter(lambda node: isinstance(node, dace.nodes.Tasklet), nodes))[0] + write_access = list( + filter( + lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "tmp", + nodes, + ) + )[0] + assert len(edges) == 1, "one edge expected" + assert ( + edges[0].src == tasklet and edges[0].dst == write_access + ), "write access of 'tmp'" + + if state.name == "block_1": + nodes = state.nodes() + assert ( + len(list(filter(lambda node: isinstance(node, dace.nodes.Tasklet), nodes))) == 1 + ) + assert ( + len( + list( + filter( + lambda node: isinstance(node, dace.nodes.AccessNode) + and node.data == "tmp", + nodes, + ) + ) + ) + == 1 + ), "one AccessNode of tmp" + + edges = state.edges() + 
tasklet = list(filter(lambda node: isinstance(node, dace.nodes.Tasklet), nodes))[0] + read_access = list( + filter( + lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "tmp", + nodes, + ) + )[0] + write_access = list( + filter( + lambda node: isinstance(node, dace.nodes.AccessNode) + and node.data == "other", + nodes, + ) + )[0] + assert len(edges) == 2, "two edges expected" + assert ( + edges[0].src == tasklet and edges[0].dst == write_access + ), "write access of 'other'" + assert ( + edges[1].src == read_access and edges[1].dst == tasklet + ), "read access of 'tmp'" diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/test_utils.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_utils.py new file mode 100644 index 0000000000..ab501d722e --- /dev/null +++ b/tests/cartesian_tests/unit_tests/test_gtc/dace/test_utils.py @@ -0,0 +1,44 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from typing import Optional + +from gt4py.cartesian.gtc.common import DataType, CartesianOffset +from gt4py.cartesian.gtc.dace import daceir as dcir +from gt4py.cartesian.gtc.dace import prefix +from gt4py.cartesian.gtc.dace import utils + +# Because "dace tests" filter by `requires_dace`, we still need to add the marker. +# This global variable add the marker to all test functions in this module. 
+pytestmark = pytest.mark.requires_dace + + +@pytest.mark.parametrize( + "name,is_target,offset,expected", + [ + ("A", False, None, f"{prefix.TASKLET_IN}A"), + ("A", True, None, f"{prefix.TASKLET_OUT}A"), + ("A", True, CartesianOffset(i=0, j=0, k=-1), f"{prefix.TASKLET_OUT}Akm1"), + ("A", False, CartesianOffset(i=1, j=-2, k=3), f"{prefix.TASKLET_IN}Aip1_jm2_kp3"), + ( + "A", + True, + dcir.VariableKOffset(k=dcir.Literal(value="3", dtype=DataType.INT32)), + f"{prefix.TASKLET_OUT}A", + ), + ], +) +def test_get_tasklet_symbol( + name: str, + is_target: bool, + offset: Optional[CartesianOffset | dcir.VariableKOffset], + expected: str, +) -> None: + assert utils.get_tasklet_symbol(name, is_target=is_target, offset=offset) == expected diff --git a/tests/cartesian_tests/unit_tests/test_gtc/dace/utils.py b/tests/cartesian_tests/unit_tests/test_gtc/dace/utils.py new file mode 100644 index 0000000000..b976631017 --- /dev/null +++ b/tests/cartesian_tests/unit_tests/test_gtc/dace/utils.py @@ -0,0 +1,54 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +import dace + +from gt4py.cartesian.gtc.dace import daceir as dcir +from gt4py.cartesian.gtc.dace.expansion.daceir_builder import DaCeIRBuilder +from gt4py.cartesian.gtc.dace.nodes import StencilComputation +from gt4py.cartesian.gtc.dace.oir_to_dace import OirSDFGBuilder +from gt4py.cartesian.gtc.dace.expansion.expansion import StencilComputationExpansion + +from cartesian_tests.unit_tests.test_gtc.oir_utils import StencilFactory + + +def library_node_expansions(stencil: StencilFactory) -> list[dcir.NestedSDFG]: + """Return all expanded library nodes in a given stencil.""" + sdfg = OirSDFGBuilder().visit(stencil) + assert isinstance(sdfg, dace.SDFG) + + expansions = [] + for state in sdfg.nodes(): + for node in state.nodes(): + if not isinstance(node, StencilComputation): + continue + + arrays = StencilComputationExpansion._get_parent_arrays(node, state, sdfg) + nested_SDFG = DaCeIRBuilder().visit( + node.oir_node, + global_ctx=DaCeIRBuilder.GlobalContext(library_node=node, arrays=arrays), + ) + expansions.append(nested_SDFG) + + return expansions + + +def nested_SDFG_inside_triple_loop(nSDFG: dcir.NestedSDFG) -> dcir.NestedSDFG: + """Pick the inner nested SDFG out of the triple loop.""" + assert isinstance(nSDFG, dcir.NestedSDFG) + assert isinstance(nSDFG.states[0], dcir.ComputationState) + assert isinstance(nSDFG.states[0].computations[0], dcir.DomainMap) + assert isinstance(nSDFG.states[0].computations[0].computations[0], dcir.DomainMap) + assert isinstance( + nSDFG.states[0].computations[0].computations[0].computations[0], dcir.DomainMap + ) + assert isinstance( + nSDFG.states[0].computations[0].computations[0].computations[0].computations[0], + dcir.NestedSDFG, + ) + return nSDFG.states[0].computations[0].computations[0].computations[0].computations[0] diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py index 68006c113b..4e799d2090 100644 
--- a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py @@ -41,6 +41,24 @@ # - For testing non-leave nodes, introduce builders with defaults (for leave nodes as well) +def test_data_type_methods(): + for type in DataType: + if type == DataType.BOOL: + assert type.isbool() + else: + assert not type.isbool() + + if type in (DataType.INT8, DataType.INT16, DataType.INT32, DataType.INT64): + assert type.isinteger() + else: + assert not type.isinteger() + + if type in (DataType.FLOAT32, DataType.FLOAT64): + assert type.isfloat() + else: + assert not type.isfloat() + + class DummyExpr(Expr): """Fake expression for cases where a concrete expression is not needed.""" diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_dace.py b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_dace.py new file mode 100644 index 0000000000..9b8c127156 --- /dev/null +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_dace.py @@ -0,0 +1,159 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import dace +else: + dace = pytest.importorskip("dace") + +from gt4py.cartesian.gtc import oir +from gt4py.cartesian.gtc.common import DataType +from gt4py.cartesian.gtc.dace.nodes import StencilComputation +from gt4py.cartesian.gtc.dace.oir_to_dace import OirSDFGBuilder + +from cartesian_tests.unit_tests.test_gtc.oir_utils import ( + AssignStmtFactory, + FieldAccessFactory, + FieldDeclFactory, + ScalarAccessFactory, + StencilFactory, +) + +# Because "dace tests" filter by `requires_dace`, we still need to add the marker. +# This global variable add the marker to all test functions in this module. 
+pytestmark = pytest.mark.requires_dace + + +def test_oir_sdfg_builder_copy_stencil() -> None: + stencil_name = "copy" + stencil = StencilFactory( + name=stencil_name, + params=[ + FieldDeclFactory(name="A", dtype=DataType.FLOAT32), + FieldDeclFactory(name="B", dtype=DataType.FLOAT32), + ], + vertical_loops__0__sections__0__horizontal_executions__0__body=[ + AssignStmtFactory(left=FieldAccessFactory(name="B"), right=FieldAccessFactory(name="A")) + ], + ) + sdfg = OirSDFGBuilder().visit(stencil) + + assert isinstance(sdfg, dace.SDFG), "DaCe SDFG expected" + assert sdfg.name == stencil_name, "Stencil name is preserved" + assert len(sdfg.arrays) == 2, "two arrays expected (A and B)" + + a_array = sdfg.arrays.get("A") + assert a_array is not None, "Array A expected to be defined" + assert a_array.ctype == "float", "A is of type `float`" + assert a_array.offset == (0, 0, 0), "CartesianOffset.zero() expected" + + b_array = sdfg.arrays.get("B") + assert b_array is not None, "Array B expected to be defined" + assert b_array.ctype == "float", "B is of type `float`" + assert b_array.offset == (0, 0, 0), "CartesianOffset.zero() expected" + + states = sdfg.nodes() + assert len(states) >= 1, "at least one state expected" + + # expect StencilComputation, AccessNode(A), and AccessNode(B) in the last block + last_block = states[len(states) - 1] + nodes = last_block.nodes() + assert ( + len(list(filter(lambda node: isinstance(node, StencilComputation), nodes))) == 1 + ), "one StencilComputation library node" + assert ( + len( + list( + filter( + lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "A", nodes + ) + ) + ) + == 1 + ), "one AccessNode of A" + assert ( + len( + list( + filter( + lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "B", nodes + ) + ) + ) + == 1 + ), "one AccessNode of B" + + edges = last_block.edges() + assert len(edges) == 2, "read and write memlet path expected" + + library_node = list(filter(lambda node: 
isinstance(node, StencilComputation), nodes))[0] + read_access = list( + filter(lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "A", nodes) + )[0] + write_access = list( + filter(lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "B", nodes) + )[0] + + assert edges[0].src == read_access and edges[0].dst == library_node, "read access expected" + assert edges[1].src == library_node and edges[1].dst == write_access, "write access expected" + + +def test_oir_sdfg_builder_assign_scalar_param() -> None: + stencil_name = "scalar_assign" + stencil = StencilFactory( + name=stencil_name, + params=[ + FieldDeclFactory(name="A", dtype=DataType.FLOAT64), + oir.ScalarDecl(name="b", dtype=DataType.INT32), + ], + vertical_loops__0__sections__0__horizontal_executions__0__body=[ + AssignStmtFactory( + left=FieldAccessFactory(name="A"), right=ScalarAccessFactory(name="b") + ) + ], + ) + sdfg = OirSDFGBuilder().visit(stencil) + + assert isinstance(sdfg, dace.SDFG), "DaCe SDFG expected" + assert sdfg.name == stencil_name, "Stencil name is preserved" + assert len(sdfg.arrays) == 1, "one array expected (A)" + + a_array = sdfg.arrays.get("A") + assert a_array is not None, "Array A expected to be defined" + assert a_array.ctype == "double", "Array A is of type `double`" + assert a_array.offset == (0, 0, 0), "CartesianOffset.zeros() expected" + assert "b" in sdfg.symbols.keys(), "expected `b` as scalar parameter" + + states = sdfg.nodes() + assert len(states) >= 1, "at least one state expected" + + last_block = states[len(states) - 1] + nodes = last_block.nodes() + assert ( + len(list(filter(lambda node: isinstance(node, StencilComputation), nodes))) == 1 + ), "one StencilComputation library node" + assert ( + len( + list( + filter( + lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "A", nodes + ) + ) + ) + == 1 + ), "one AccessNode of A" + + edges = last_block.edges() + library_node = list(filter(lambda node: isinstance(node, 
StencilComputation), nodes))[0] + write_access = list( + filter(lambda node: isinstance(node, dace.nodes.AccessNode) and node.data == "A", nodes) + )[0] + assert len(edges) == 1, "write memlet path expected" + assert edges[0].src == library_node and edges[0].dst == write_access, "write access expected" diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py index 4de7f9f5d6..4877a39503 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_oir_to_npir.py @@ -28,6 +28,7 @@ StencilFactory, VerticalLoopFactory, VerticalLoopSectionFactory, + WhileFactory, ) @@ -78,6 +79,18 @@ def test_mask_stmt_to_assigns() -> None: assert len(assign_stmts) == 1 +def test_mask_stmt_to_while() -> None: + mask_oir = MaskStmtFactory(body=[WhileFactory()]) + statements = OirToNpir().visit(mask_oir, extent=Extent.zeros(ndims=2)) + assert len(statements) == 1 + assert isinstance(statements[0], npir.While) + condition = statements[0].cond + assert isinstance(condition, npir.VectorLogic) + assert condition.op == common.LogicalOperator.AND + mask_npir = OirToNpir().visit(mask_oir.mask) + assert condition.left == mask_npir or condition.right == mask_npir + + def test_mask_propagation() -> None: mask_stmt = MaskStmtFactory() assign_stmts = OirToNpir().visit(mask_stmt, extent=Extent.zeros(ndims=2)) diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_min_k_interval.py b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_min_k_interval.py index 078adcc8da..6bb4ec63f6 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_min_k_interval.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_min_k_interval.py @@ -16,7 +16,10 @@ from gt4py import cartesian as gt4pyc from gt4py.cartesian import gtscript as gs from gt4py.cartesian.backend import from_name -from 
gt4py.cartesian.gtc.passes.gtir_k_boundary import compute_k_boundary, compute_min_k_size +from gt4py.cartesian.gtc.passes.gtir_k_boundary import ( + compute_k_boundary, + compute_min_k_size, +) from gt4py.cartesian.gtc.passes.gtir_pipeline import prune_unused_parameters from gt4py.cartesian.gtscript import PARALLEL, computation, interval, stencil from gt4py.cartesian.stencil_builder import StencilBuilder @@ -48,21 +51,21 @@ def stencil_no_extent_0(field_a: gs.Field[float], field_b: gs.Field[float]): field_a = field_b[0, 0, 0] -@register_test_case(k_bounds=(max(0, -2), 0), min_k_size=2) +@register_test_case(k_bounds=(0, 0), min_k_size=2) @typing.no_type_check def stencil_no_extent_1(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(0, 2): field_a = field_b[0, 0, 0] -@register_test_case(k_bounds=(max(-1, -2), 0), min_k_size=2) +@register_test_case(k_bounds=(-1, 0), min_k_size=2) @typing.no_type_check def stencil_no_extent_2(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(1, 2): field_a = field_b[0, 0, 0] -@register_test_case(k_bounds=(max(max(0, -2), max(-2, -2)), 0), min_k_size=3) +@register_test_case(k_bounds=(0, 0), min_k_size=4) @typing.no_type_check def stencil_no_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(0, 2): @@ -73,14 +76,14 @@ def stencil_no_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]): field_a = field_b[0, 0, 0] -@register_test_case(k_bounds=(0, max(-1, 0)), min_k_size=1) +@register_test_case(k_bounds=(0, 0), min_k_size=1) @typing.no_type_check def stencil_no_extent_4(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(-1, None): field_a = field_b[0, 0, 0] -@register_test_case(k_bounds=(max(0, -1), max(-2, 0)), min_k_size=3) +@register_test_case(k_bounds=(0, 0), min_k_size=3) @typing.no_type_check def stencil_no_extent_5(field_a: gs.Field[float], field_b: 
gs.Field[float]): with computation(PARALLEL), interval(0, 1): @@ -89,6 +92,13 @@ def stencil_no_extent_5(field_a: gs.Field[float], field_b: gs.Field[float]): field_a = field_b[0, 0, 0] +@register_test_case(k_bounds=(-1, -2), min_k_size=4) +@typing.no_type_check +def stencil_no_extent_6(field_a: gs.Field[float], field_b: gs.Field[float]): + with computation(PARALLEL), interval(1, -2): + field_a[0, 0, 0] = field_b[0, 0, 0] + + # stencils with extent @register_test_case(k_bounds=(5, -5), min_k_size=0) @typing.no_type_check @@ -111,7 +121,7 @@ def stencil_with_extent_2(field_a: gs.Field[float], field_b: gs.Field[float]): field_a = field_b[0, 0, 5] -@register_test_case(k_bounds=(3, -3), min_k_size=3) +@register_test_case(k_bounds=(3, -3), min_k_size=4) @typing.no_type_check def stencil_with_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(0, 2): @@ -122,7 +132,7 @@ def stencil_with_extent_3(field_a: gs.Field[float], field_b: gs.Field[float]): field_a = field_b[0, 0, -3] -@register_test_case(k_bounds=(-5, 5), min_k_size=1) +@register_test_case(k_bounds=(-5, 5), min_k_size=2) @typing.no_type_check def stencil_with_extent_4(field_a: gs.Field[float], field_b: gs.Field[float]): with computation(PARALLEL), interval(0, -1): @@ -171,7 +181,10 @@ def test_min_k_size(definition, expected_min_k_size): @pytest.mark.parametrize("definition,expected", test_data) def test_k_bounds_exec(definition, expected): - expected_k_bounds, expected_min_k_size = expected["k_bounds"], expected["min_k_size"] + expected_k_bounds, expected_min_k_size = ( + expected["k_bounds"], + expected["min_k_size"], + ) required_field_size = expected_min_k_size + expected_k_bounds[0] + expected_k_bounds[1] @@ -234,7 +247,10 @@ def stencil_with_invalid_temporary_access_end(field_a: gs.Field[float], field_b: @pytest.mark.parametrize( "definition", - [stencil_with_invalid_temporary_access_start, stencil_with_invalid_temporary_access_end], + [ + 
stencil_with_invalid_temporary_access_start, + stencil_with_invalid_temporary_access_end, + ], ) def test_invalid_temporary_access(definition): builder = StencilBuilder(definition, backend=from_name("numpy")) diff --git a/tests/conftest.py b/tests/conftest.py index 285ccda2b0..1bf73651a2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,5 +8,55 @@ """Global configuration of pytest for collecting and running tests.""" +import collections.abc +import functools +import sys +import types +from typing import Final + +import pytest + + # Ignore hidden folders and disabled tests collect_ignore_glob = [".*", "_disabled*"] + +# Custom module attribute to store package-level marks +_PKG_MARKS_ATTR_NAME: Final = "package_pytestmarks" + + +@functools.cache +def _get_pkg_marks(module_name: str) -> list[pytest.Mark | str]: + """Collect markers in the `package_pytestmarks` module attribute (and recursively from its parents).""" + module = sys.modules[module_name] + pkg_markers = getattr(module, _PKG_MARKS_ATTR_NAME, []) + assert isinstance( + pkg_markers, collections.abc.Sequence + ), f"'{_PKG_MARKS_ATTR_NAME}' content must be a sequence of marks" + + if (parent := module_name.rsplit(".", 1)[0]) != module_name: + pkg_markers += _get_pkg_marks(parent) + + return pkg_markers + + +def pytest_collection_modifyitems( + session: pytest.Session, config: pytest.Config, items: list[pytest.Item] +) -> None: + """Pytest hook to modify the collected test items. + + See: https://docs.pytest.org/en/stable/reference/reference.html#pytest.hookspec.pytest_collection_modifyitems + """ + for item in items: + # Visit the chain of parents of the current test item in reverse order, + # until we get to the module object where the test function (or class) + # has been defined. At that point, process the custom package-level marks + # attribute if present, and move to the next collected item in the list. 
+ for node in item.listchain()[-2::-1]: + if not (obj := getattr(node, "obj", None)): + break + if not isinstance(obj, types.ModuleType): + continue + + module_name = obj.__name__ + for marker in _get_pkg_marks(module_name): + item.add_marker(marker) diff --git a/tests/eve_tests/unit_tests/test_datamodels.py b/tests/eve_tests/unit_tests/test_datamodels.py index 05be5f3db0..75b07fd8a0 100644 --- a/tests/eve_tests/unit_tests/test_datamodels.py +++ b/tests/eve_tests/unit_tests/test_datamodels.py @@ -10,9 +10,9 @@ import enum import numbers +import sys import types import typing -from typing import Set # noqa: F401 [unused-import] used in exec() context from typing import ( Any, Callable, @@ -26,6 +26,7 @@ MutableSequence, Optional, Sequence, + Set, # noqa: F401 [unused-import] used in exec() context Tuple, Type, TypeVar, @@ -555,6 +556,18 @@ class WrongModel: ("typing.MutableSequence[int]", ([1, 2, 3], []), ((1, 2, 3), tuple(), 1, [1.0], {1})), ("typing.Set[int]", ({1, 2, 3}, set()), (1, [1], (1,), {1: None})), ("typing.Union[int, float, str]", [1, 3.0, "one"], [[1], [], 1j]), + pytest.param( + "int | float | str", + [1, 3.0, "one"], + [[1], [], 1j], + marks=pytest.mark.skipif(sys.version_info < (3, 10), reason="| union syntax not supported"), + ), + pytest.param( + "typing.List[int|float]", + [[1, 2.0], []], + [1, 2.0, [1, "2.0"]], + marks=pytest.mark.skipif(sys.version_info < (3, 10), reason="| union syntax not supported"), + ), ("typing.Optional[int]", [1, None], [[1], [], 1j]), ( "typing.Dict[Union[int, float, str], Union[Tuple[int, Optional[float]], Set[int]]]", diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 123384a098..1f81076abf 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -11,11 +11,11 @@ import dataclasses import enum import importlib -from typing import Final, Optional, Protocol +from typing import Final import pytest -from gt4py.next import allocators as next_allocators, backend 
as next_backend +from gt4py.next import allocators as next_allocators # Skip definitions @@ -43,11 +43,10 @@ def short_id(self, num_components: int = 2) -> str: class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): GTFN_CPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn" GTFN_CPU_IMPERATIVE = "gt4py.next.program_processors.runners.gtfn.run_gtfn_imperative" - GTFN_CPU_WITH_TEMPORARIES = ( - "gt4py.next.program_processors.runners.gtfn.run_gtfn_with_temporaries" - ) + GTFN_CPU_NO_TRANSFORMS = "gt4py.next.program_processors.runners.gtfn.run_gtfn_no_transforms" GTFN_GPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn_gpu" ROUNDTRIP = "gt4py.next.program_processors.runners.roundtrip.default" + ROUNDTRIP_NO_TRANSFORMS = "gt4py.next.program_processors.runners.roundtrip.no_transforms" GTIR_EMBEDDED = "gt4py.next.program_processors.runners.roundtrip.gtir" ROUNDTRIP_WITH_TEMPORARIES = "gt4py.next.program_processors.runners.roundtrip.with_temporaries" DOUBLE_ROUNDTRIP = "gt4py.next.program_processors.runners.double_roundtrip.backend" @@ -55,11 +54,17 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): @dataclasses.dataclass(frozen=True) class EmbeddedDummyBackend: + name: str allocator: next_allocators.FieldBufferAllocatorProtocol + executor: Final = None -numpy_execution = EmbeddedDummyBackend(next_allocators.StandardCPUFieldBufferAllocator()) -cupy_execution = EmbeddedDummyBackend(next_allocators.StandardGPUFieldBufferAllocator()) +numpy_execution = EmbeddedDummyBackend( + "EmbeddedNumPy", next_allocators.StandardCPUFieldBufferAllocator() +) +cupy_execution = EmbeddedDummyBackend( + "EmbeddedCuPy", next_allocators.StandardGPUFieldBufferAllocator() +) class EmbeddedIds(_PythonObjectIdMixin, str, enum.Enum): @@ -68,9 +73,10 @@ class EmbeddedIds(_PythonObjectIdMixin, str, enum.Enum): class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): - DACE_CPU = "gt4py.next.program_processors.runners.dace.itir_cpu" - DACE_GPU = 
"gt4py.next.program_processors.runners.dace.itir_gpu" - GTIR_DACE_CPU = "gt4py.next.program_processors.runners.dace.gtir_cpu" + DACE_CPU = "gt4py.next.program_processors.runners.dace.run_dace_cpu" + DACE_GPU = "gt4py.next.program_processors.runners.dace.run_dace_gpu" + DACE_CPU_NO_OPT = "gt4py.next.program_processors.runners.dace.run_dace_cpu_noopt" + DACE_GPU_NO_OPT = "gt4py.next.program_processors.runners.dace.run_dace_gpu_noopt" class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): @@ -86,21 +92,22 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): # to avoid needing to mark all tests. ALL = "all" REQUIRES_ATLAS = "requires_atlas" -# TODO(havogt): Remove, skipped during refactoring to GTIR -STARTS_FROM_GTIR_PROGRAM = "starts_from_gtir_program" USES_APPLIED_SHIFTS = "uses_applied_shifts" +USES_CAN_DEREF = "uses_can_deref" +USES_COMPOSITE_SHIFTS = "uses_composite_shifts" USES_CONSTANT_FIELDS = "uses_constant_fields" USES_DYNAMIC_OFFSETS = "uses_dynamic_offsets" USES_FLOORDIV = "uses_floordiv" USES_IF_STMTS = "uses_if_stmts" USES_IR_IF_STMTS = "uses_ir_if_stmts" USES_INDEX_FIELDS = "uses_index_fields" -USES_LIFT_EXPRESSIONS = "uses_lift_expressions" +USES_LIFT = "uses_lift" USES_NEGATIVE_MODULO = "uses_negative_modulo" USES_ORIGIN = "uses_origin" -USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions" +USES_REDUCE_WITH_LAMBDA = "uses_reduce_with_lambda" USES_SCAN = "uses_scan" USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" +USES_SCAN_IN_STENCIL = "uses_scan_in_stencil" USES_SCAN_WITHOUT_FIELD_ARGS = "uses_scan_without_field_args" USES_SCAN_NESTED = "uses_scan_nested" USES_SCAN_REQUIRING_PROJECTOR = "uses_scan_requiring_projector" @@ -109,6 +116,10 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS = "uses_reduction_with_only_sparse_fields" USES_STRIDED_NEIGHBOR_OFFSET = "uses_strided_neighbor_offset" USES_TUPLE_ARGS = "uses_tuple_args" 
+USES_TUPLES_ARGS_WITH_DIFFERENT_BUT_PROMOTABLE_DIMS = ( + "uses_tuple_args_with_different_but_promotable_dims" +) +USES_TUPLE_ITERATOR = "uses_tuple_iterator" USES_TUPLE_RETURNS = "uses_tuple_returns" USES_ZERO_DIMENSIONAL_FIELDS = "uses_zero_dimensional_fields" USES_CARTESIAN_SHIFT = "uses_cartesian_shift" @@ -127,28 +138,29 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): # Common list of feature markers to skip COMMON_SKIP_TEST_LIST = [ (REQUIRES_ATLAS, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), (USES_APPLIED_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), - (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLES_ARGS_WITH_DIFFERENT_BUT_PROMOTABLE_DIMS, XFAIL, UNSUPPORTED_MESSAGE), ] -DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ - (USES_IR_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCALAR_IN_DOMAIN_AND_FO, XFAIL, UNSUPPORTED_MESSAGE), - (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), -] -GTIR_DACE_SKIP_TEST_LIST = [ - (ALL, SKIP, UNSUPPORTED_MESSAGE), +# Markers to skip because of missing features in the domain inference +DOMAIN_INFERENCE_SKIP_LIST = [ + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ] +DACE_SKIP_TEST_LIST = ( + COMMON_SKIP_TEST_LIST + + DOMAIN_INFERENCE_SKIP_LIST + + [ + (USES_CAN_DEREF, XFAIL, UNSUPPORTED_MESSAGE), + (USES_COMPOSITE_SHIFTS, XFAIL, UNSUPPORTED_MESSAGE), + 
(USES_LIFT, XFAIL, UNSUPPORTED_MESSAGE), + (USES_REDUCE_WITH_LAMBDA, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_SPARSE_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLE_ITERATOR, XFAIL, UNSUPPORTED_MESSAGE), + ] +) EMBEDDED_SKIP_LIST = [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (CHECKS_SPECIFIC_ERROR, XFAIL, UNSUPPORTED_MESSAGE), @@ -158,14 +170,23 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): UNSUPPORTED_MESSAGE, ), # we can't extract the field type from scan args ] -GTFN_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ - # floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136 - (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - # max_over broken, see https://github.com/GridTools/gt4py/issues/1289 - (USES_MAX_OVER, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCAN_REQUIRING_PROJECTOR, XFAIL, UNSUPPORTED_MESSAGE), +ROUNDTRIP_SKIP_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ + (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), + (USES_TUPLES_ARGS_WITH_DIFFERENT_BUT_PROMOTABLE_DIMS, XFAIL, UNSUPPORTED_MESSAGE), ] +GTFN_SKIP_TEST_LIST = ( + COMMON_SKIP_TEST_LIST + + DOMAIN_INFERENCE_SKIP_LIST + + [ + # floordiv not yet supported, see https://github.com/GridTools/gt4py/issues/1136 + (USES_FLOORDIV, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + # max_over broken, see https://github.com/GridTools/gt4py/issues/1289 + (USES_MAX_OVER, XFAIL, UNSUPPORTED_MESSAGE), + (USES_SCAN_REQUIRING_PROJECTOR, XFAIL, UNSUPPORTED_MESSAGE), + ] +) #: Skip matrix, contains for each backend processor a list of tuples with following fields: #: (, ) @@ -174,23 +195,26 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, 
OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.GTIR_DACE_CPU: GTIR_DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_GPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], - ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST - + [(ALL, XFAIL, UNSUPPORTED_MESSAGE), (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE)], - ProgramFormatterId.GTFN_CPP_FORMATTER: [ - (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE) + ProgramFormatterId.GTFN_CPP_FORMATTER: DOMAIN_INFERENCE_SKIP_LIST + + [ + (USES_SCAN_IN_STENCIL, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), ], - ProgramBackendId.ROUNDTRIP: [(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE)], - ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES: [ + ProgramFormatterId.LISP_FORMATTER: DOMAIN_INFERENCE_SKIP_LIST, + ProgramBackendId.ROUNDTRIP: ROUNDTRIP_SKIP_LIST, + ProgramBackendId.DOUBLE_ROUNDTRIP: ROUNDTRIP_SKIP_LIST, + ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES: ROUNDTRIP_SKIP_LIST + + [ (ALL, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ], + ProgramBackendId.GTIR_EMBEDDED: ROUNDTRIP_SKIP_LIST, } diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index d85cd5b3df..6e8ff1b3f6 100644 --- a/tests/next_tests/integration_tests/cases.py +++ 
b/tests/next_tests/integration_tests/cases.py @@ -28,6 +28,7 @@ common, constructors, field_utils, + utils as gt_utils, ) from gt4py.next.ffront import decorator from gt4py.next.type_system import type_specifications as ts, type_translation @@ -55,11 +56,11 @@ mesh_descriptor, ) -from gt4py.next import utils as gt_utils # mypy does not accept [IDim, ...] as a type IField: TypeAlias = gtx.Field[[IDim], np.int32] # type: ignore [valid-type] +JField: TypeAlias = gtx.Field[[JDim], np.int32] # type: ignore [valid-type] IFloatField: TypeAlias = gtx.Field[[IDim], np.float64] # type: ignore [valid-type] IBoolField: TypeAlias = gtx.Field[[IDim], bool] # type: ignore [valid-type] KField: TypeAlias = gtx.Field[[KDim], np.int32] # type: ignore [valid-type] @@ -69,6 +70,7 @@ IJKField: TypeAlias = gtx.Field[[IDim, JDim, KDim], np.int32] # type: ignore [valid-type] IJKFloatField: TypeAlias = gtx.Field[[IDim, JDim, KDim], np.float64] # type: ignore [valid-type] VField: TypeAlias = gtx.Field[[Vertex], np.int32] # type: ignore [valid-type] +VBoolField: TypeAlias = gtx.Field[[Vertex], bool] # type: ignore [valid-type] EField: TypeAlias = gtx.Field[[Edge], np.int32] # type: ignore [valid-type] CField: TypeAlias = gtx.Field[[Cell], np.int32] # type: ignore [valid-type] EmptyField: TypeAlias = gtx.Field[[], np.int32] # type: ignore [valid-type] @@ -191,13 +193,13 @@ class UniqueInitializer(DataInitializer): data containers. 
""" - start: int = 0 + start: int = 1 @property def scalar_value(self) -> ScalarValue: start = self.start self.start += 1 - return np.int64(start) + return start def field( self, @@ -380,6 +382,7 @@ def verify( fieldview_prog: decorator.FieldOperator | decorator.Program, *args: FieldViewArg, ref: ReferenceValue, + domain: Optional[dict[common.Dimension, tuple[int, int]]] = None, out: Optional[FieldViewInout] = None, inout: Optional[FieldViewInout] = None, offset_provider: Optional[OffsetProvider] = None, @@ -404,6 +407,8 @@ def verify( or tuple of fields here and they will be compared to ``ref`` under the assumption that the fieldview code stores its results in them. + domain: If given will be passed to the fieldview code as ``domain=`` + keyword argument. offset_provider: An override for the test case's offset_provider. Use with care! comparison: A comparison function, which will be called as @@ -413,10 +418,13 @@ def verify( used as an argument to the fieldview program and compared against ``ref``. Else, ``inout`` will not be passed and compared to ``ref``. 
""" + kwargs = {} if out: - run(case, fieldview_prog, *args, out=out, offset_provider=offset_provider) - else: - run(case, fieldview_prog, *args, offset_provider=offset_provider) + kwargs["out"] = out + if domain: + kwargs["domain"] = domain + + run(case, fieldview_prog, *args, **kwargs, offset_provider=offset_provider) out_comp = out or inout assert out_comp is not None @@ -498,13 +506,21 @@ def unstructured_case( Vertex: mesh_descriptor.num_vertices, Edge: mesh_descriptor.num_edges, Cell: mesh_descriptor.num_cells, - KDim: 10, }, grid_type=common.GridType.UNSTRUCTURED, allocator=exec_alloc_descriptor.allocator, ) +@pytest.fixture +def unstructured_case_3d(unstructured_case): + return dataclasses.replace( + unstructured_case, + default_sizes={**unstructured_case.default_sizes, KDim: 10}, + offset_provider={**unstructured_case.offset_provider, "KOff": KDim}, + ) + + def _allocate_from_type( case: Case, arg_type: ts.TypeSpec, diff --git a/tests/next_tests/integration_tests/feature_tests/dace/__init__.py b/tests/next_tests/integration_tests/feature_tests/dace/__init__.py index abf4c3e24c..7a9cb1ece5 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/__init__.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/__init__.py @@ -6,3 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import pytest + +#: Attribute defining package-level marks used by a custom pytest hook. 
+package_pytestmarks = [pytest.mark.requires_dace] diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index 306f0034b5..2fb780c1bd 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -7,74 +7,61 @@ # SPDX-License-Identifier: BSD-3-Clause import numpy as np -from typing import Optional -from types import ModuleType import pytest import gt4py.next as gtx -from gt4py.next import backend as next_backend -from gt4py.next.otf import arguments +from gt4py.next import allocators as gtx_allocators, common as gtx_common +from gt4py._core import definitions as core_defs from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import cartesian_case, unstructured_case +from next_tests.integration_tests.cases import cartesian_case, unstructured_case # noqa: F401 from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - exec_alloc_descriptor, - mesh_descriptor, - Vertex, - Edge, E2V, + E2VDim, + Edge, + Vertex, + exec_alloc_descriptor, # noqa: F401 + mesh_descriptor, # noqa: F401 ) from next_tests.integration_tests.multi_feature_tests.ffront_tests.test_laplacian import ( lap_program, - laplap_program, lap_ref, + laplap_program, ) -try: - import dace - from gt4py.next.program_processors.runners.dace import run_dace_cpu, run_dace_gpu -except ImportError: - dace: Optional[ModuleType] = None # type:ignore[no-redef] - run_dace_cpu: Optional[next_backend.Backend] = None - run_dace_gpu: Optional[next_backend.Backend] = None -pytestmark = pytest.mark.requires_dace +dace = pytest.importorskip("dace") -def test_sdfgConvertible_laplap(cartesian_case): - # TODO(kotsaloscv): Temporary solution until the `requires_dace` marker is fully functional - if cartesian_case.backend not in [run_dace_cpu, 
run_dace_gpu]: +def test_sdfgConvertible_laplap(cartesian_case): # noqa: F811 + if not cartesian_case.backend or "dace" not in cartesian_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") - if cartesian_case.backend == run_dace_gpu: - import cupy as xp - else: - import numpy as xp + backend = cartesian_case.backend in_field = cases.allocate(cartesian_case, laplap_program, "in_field")() out_field = cases.allocate(cartesian_case, laplap_program, "out_field")() - connectivities = {} # Dict of NeighborOffsetProviders, where self.table = None - for k, v in cartesian_case.offset_provider.items(): - if hasattr(v, "table"): - connectivities[k] = arguments.CompileTimeConnectivity( - v.max_neighbors, v.has_skip_values, v.origin_axis, v.neighbor_axis, v.table.dtype - ) - else: - connectivities[k] = v + xp = in_field.array_ns # Test DaCe closure support @dace.program def sdfg(): tmp_field = xp.empty_like(out_field) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( - cartesian_case.backend - ).with_connectivities(connectivities)(in_field, tmp_field) + backend + ).with_connectivities(gtx_common.offset_provider_to_type(cartesian_case.offset_provider))( + in_field, tmp_field + ) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( - cartesian_case.backend - ).with_connectivities(connectivities)(tmp_field, out_field) + backend + ).with_connectivities(gtx_common.offset_provider_to_type(cartesian_case.offset_provider))( + tmp_field, out_field + ) - sdfg() + # use unique cache name based on process id to avoid clashes between parallel pytest workers + with dace.config.set_temporary("cache", value="unique"): + sdfg() assert np.allclose( gtx.field_utils.asnumpy(out_field)[2:-2, 2:-2], @@ -93,14 +80,13 @@ def testee(a: gtx.Field[gtx.Dims[Vertex], gtx.float64], b: gtx.Field[gtx.Dims[Ed @pytest.mark.uses_unstructured_shift -def test_sdfgConvertible_connectivities(unstructured_case): - # TODO(kotsaloscv): 
Temporary solution until the `requires_dace` marker is fully functional - if unstructured_case.backend not in [run_dace_cpu, run_dace_gpu]: +def test_sdfgConvertible_connectivities(unstructured_case): # noqa: F811 + if not unstructured_case.backend or "dace" not in unstructured_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") allocator, backend = unstructured_case.allocator, unstructured_case.backend - if backend == run_dace_gpu: + if gtx_allocators.is_field_allocator_for(allocator, core_defs.CUPY_DEVICE_TYPE): import cupy as xp dace_storage_type = dace.StorageType.GPU_Global @@ -116,6 +102,15 @@ def test_sdfgConvertible_connectivities(unstructured_case): name="OffsetProvider", ) + e2v = gtx.as_connectivity( + [Edge, E2VDim], + codomain=Vertex, + data=xp.asarray([[0, 1], [1, 2], [2, 0]]), + allocator=allocator, + ) + + testee2 = testee.with_backend(backend).with_connectivities({"E2V": e2v}) + @dace.program def sdfg( a: dace.data.Array(dtype=dace.float64, shape=(rows,), storage=dace_storage_type), @@ -123,17 +118,10 @@ def sdfg( offset_provider: OffsetProvider_t, connectivities: dace.compiletime, ): - testee.with_backend(backend).with_connectivities(connectivities)( - a, out, offset_provider=offset_provider - ) + testee2.with_connectivities(connectivities)(a, out, offset_provider=offset_provider) + return out - e2v = gtx.NeighborTableOffsetProvider( - xp.asarray([[0, 1], [1, 2], [2, 0]]), Edge, Vertex, 2, False - ) - connectivities = {} - connectivities["E2V"] = arguments.CompileTimeConnectivity( - e2v.max_neighbors, e2v.has_skip_values, e2v.origin_axis, e2v.neighbor_axis, e2v.table.dtype - ) + connectivities = {"E2V": e2v} # replace 'e2v' with 'e2v.__gt_type__()' when GTIR is AOT offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr()) SDFG = sdfg.to_sdfg(connectivities=connectivities) @@ -141,50 +129,50 @@ def sdfg( a = gtx.as_field([Vertex], xp.asarray([0.0, 1.0, 2.0]), 
allocator=allocator) out = gtx.zeros({Edge: 3}, allocator=allocator) - # This is a low level interface to call the compiled SDFG. - # It is not supposed to be used in user code. - # The high level interface should be provided by a DaCe Orchestrator, - # i.e. decorator that hides the low level operations. - # This test checks only that the SDFGConvertible interface works correctly. - cSDFG( - a, - out, - offset_provider, - rows=3, - cols=2, - connectivity_E2V=e2v.table, - __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 0 - ), - __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 1 - ), - ) - e2v_xp = xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table - assert np.allclose(gtx.field_utils.asnumpy(out), gtx.field_utils.asnumpy(a)[e2v_xp[:, 0]]) + def get_stride_from_numpy_to_dace(arg: core_defs.NDArrayObject, axis: int) -> int: + # NumPy strides: number of bytes to jump + # DaCe strides: number of elements to jump + return arg.strides[axis] // arg.itemsize + + # use unique cache name based on process id to avoid clashes between parallel pytest workers + with dace.config.set_temporary("cache", value="unique"): + cSDFG( + a, + out, + offset_provider, + rows=3, + cols=2, + connectivity_E2V=e2v, + __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v.ndarray, 0), + __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v.ndarray, 1), + ) + + e2v_np = e2v.asnumpy() + assert np.allclose(out.asnumpy(), a.asnumpy()[e2v_np[:, 0]]) - e2v = gtx.NeighborTableOffsetProvider( - xp.asarray([[1, 0], [2, 1], [0, 2]]), Edge, Vertex, 2, False + e2v = gtx.as_connectivity( + [Edge, E2VDim], + codomain=Vertex, + data=xp.asarray([[1, 0], [2, 1], [0, 2]]), + allocator=allocator, ) offset_provider = OffsetProvider_t.dtype._typeclass.as_ctypes()(E2V=e2v.data_ptr()) - cSDFG( - a, - out, - offset_provider, - rows=3, - 
cols=2, - connectivity_E2V=e2v.table, - __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 0 - ), - __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace( - xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table, 1 - ), - ) + # use unique cache name based on process id to avoid clashes between parallel pytest workers + with dace.config.set_temporary("cache", value="unique"): + cSDFG( + a, + out, + offset_provider, + rows=3, + cols=2, + connectivity_E2V=e2v, + __connectivity_E2V_stride_0=get_stride_from_numpy_to_dace(e2v.ndarray, 0), + __connectivity_E2V_stride_1=get_stride_from_numpy_to_dace(e2v.ndarray, 1), + ) - e2v_xp = xp.asnumpy(e2v.table) if backend == run_dace_gpu else e2v.table - assert np.allclose(gtx.field_utils.asnumpy(out), gtx.field_utils.asnumpy(a)[e2v_xp[:, 0]]) + e2v_np = e2v.asnumpy() + assert np.allclose(out.asnumpy(), a.asnumpy()[e2v_np[:, 0]]) def get_stride_from_numpy_to_dace(numpy_array: np.ndarray, axis: int) -> int: diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_program.py b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py new file mode 100644 index 0000000000..e5e2f18608 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_program.py @@ -0,0 +1,124 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from gt4py import next as gtx +from gt4py.next import common + +from next_tests.integration_tests import cases +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + Cell, + Edge, + IDim, + JDim, + KDim, + Vertex, + mesh_descriptor, # noqa: F401 +) + + +dace = pytest.importorskip("dace") + +from gt4py.next.program_processors.runners import dace as dace_backends + + +@pytest.fixture( + params=[ + pytest.param(dace_backends.run_dace_cpu, marks=pytest.mark.requires_dace), + pytest.param( + dace_backends.run_dace_gpu, marks=(pytest.mark.requires_gpu, pytest.mark.requires_dace) + ), + ] +) +def gtir_dace_backend(request): + yield request.param + + +@pytest.fixture +def cartesian(request, gtir_dace_backend): + if gtir_dace_backend is None: + yield None + + yield cases.Case( + backend=gtir_dace_backend, + offset_provider={ + "Ioff": IDim, + "Joff": JDim, + "Koff": KDim, + }, + default_sizes={IDim: 10, JDim: 10, KDim: 10}, + grid_type=common.GridType.CARTESIAN, + allocator=gtir_dace_backend.allocator, + ) + + +@pytest.fixture +def unstructured(request, gtir_dace_backend, mesh_descriptor): # noqa: F811 + if gtir_dace_backend is None: + yield None + + yield cases.Case( + backend=gtir_dace_backend, + offset_provider=mesh_descriptor.offset_provider, + default_sizes={ + Vertex: mesh_descriptor.num_vertices, + Edge: mesh_descriptor.num_edges, + Cell: mesh_descriptor.num_cells, + KDim: 10, + }, + grid_type=common.GridType.UNSTRUCTURED, + allocator=gtir_dace_backend.allocator, + ) + + +def test_halo_exchange_helper_attrs(unstructured): + local_int = gtx.int + + @gtx.field_operator(backend=unstructured.backend) + def testee_op( + a: gtx.Field[[Vertex, KDim], gtx.int], + ) -> gtx.Field[[Vertex, KDim], gtx.int]: + return a + local_int(10) + + @gtx.program(backend=unstructured.backend) + def testee_prog( + a: gtx.Field[[Vertex, KDim], gtx.int], + b: gtx.Field[[Vertex, KDim], gtx.int], + c: 
gtx.Field[[Vertex, KDim], gtx.int], + ): + testee_op(b, out=c) + testee_op(a, out=b) + + dace_storage_type = ( + dace.StorageType.GPU_Global + if unstructured.backend == dace_backends.run_dace_gpu + else dace.StorageType.Default + ) + + rows = dace.symbol("rows") + cols = dace.symbol("cols") + + @dace.program + def testee_dace( + a: dace.data.Array(dtype=dace.int64, shape=(rows, cols), storage=dace_storage_type), + b: dace.data.Array(dtype=dace.int64, shape=(rows, cols), storage=dace_storage_type), + c: dace.data.Array(dtype=dace.int64, shape=(rows, cols), storage=dace_storage_type), + ): + testee_prog(a, b, c) + + # if simplify=True, DaCe might inline the nested SDFG coming from Program.__sdfg__, + # effectively erasing the attributes we want to test for here + sdfg = testee_dace.to_sdfg(simplify=False) + + testee = next( + subgraph for subgraph in sdfg.all_sdfgs_recursive() if subgraph.name == "testee_prog" + ) + + assert testee.gt4py_program_input_fields == {"a": Vertex, "b": Vertex} + assert testee.gt4py_program_output_fields == {"b": Vertex, "c": Vertex} diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index a0e72ede8d..1147f4bc3e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -46,10 +46,9 @@ def __gt_allocator__( @pytest.fixture( params=[ next_tests.definitions.ProgramBackendId.ROUNDTRIP, - # next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, # FIXME[#1582](havogt): enable once all ingredients for GTIR are available + next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, next_tests.definitions.ProgramBackendId.GTFN_CPU, next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, - next_tests.definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, pytest.param( 
next_tests.definitions.ProgramBackendId.GTFN_GPU, marks=pytest.mark.requires_gpu ), @@ -63,11 +62,15 @@ def __gt_allocator__( marks=pytest.mark.requires_dace, ), pytest.param( - next_tests.definitions.OptionalProgramBackendId.GTIR_DACE_CPU, + next_tests.definitions.OptionalProgramBackendId.DACE_GPU, + marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), + ), + pytest.param( + next_tests.definitions.OptionalProgramBackendId.DACE_CPU_NO_OPT, marks=pytest.mark.requires_dace, ), pytest.param( - next_tests.definitions.OptionalProgramBackendId.DACE_GPU, + next_tests.definitions.OptionalProgramBackendId.DACE_GPU_NO_OPT, marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), ), ], @@ -149,7 +152,10 @@ def num_edges(self) -> int: ... def num_levels(self) -> int: ... @property - def offset_provider(self) -> dict[str, common.Connectivity]: ... + def offset_provider(self) -> common.OffsetProvider: ... + + @property + def offset_provider_type(self) -> common.OffsetProviderType: ... def simple_mesh() -> MeshDescriptor: @@ -208,25 +214,40 @@ def simple_mesh() -> MeshDescriptor: assert all(len(row) == 2 for row in e2v_arr) e2v_arr = np.asarray(e2v_arr, dtype=gtx.IndexType) + offset_provider = { + V2E.value: common._connectivity( + v2e_arr, + codomain=Edge, + domain={Vertex: v2e_arr.shape[0], V2EDim: 4}, + skip_value=None, + ), + E2V.value: common._connectivity( + e2v_arr, + codomain=Vertex, + domain={Edge: e2v_arr.shape[0], E2VDim: 2}, + skip_value=None, + ), + C2V.value: common._connectivity( + c2v_arr, + codomain=Vertex, + domain={Cell: c2v_arr.shape[0], C2VDim: 4}, + skip_value=None, + ), + C2E.value: common._connectivity( + c2e_arr, + codomain=Edge, + domain={Cell: c2e_arr.shape[0], C2EDim: 4}, + skip_value=None, + ), + } + return types.SimpleNamespace( name="simple_mesh", num_vertices=num_vertices, num_edges=np.int32(num_edges), num_cells=num_cells, - offset_provider={ - V2E.value: gtx.NeighborTableOffsetProvider( - v2e_arr, Vertex, Edge, 4, 
has_skip_values=False - ), - E2V.value: gtx.NeighborTableOffsetProvider( - e2v_arr, Edge, Vertex, 2, has_skip_values=False - ), - C2V.value: gtx.NeighborTableOffsetProvider( - c2v_arr, Cell, Vertex, 4, has_skip_values=False - ), - C2E.value: gtx.NeighborTableOffsetProvider( - c2e_arr, Cell, Edge, 4, has_skip_values=False - ), - }, + offset_provider=offset_provider, + offset_provider_type=common.offset_provider_to_type(offset_provider), ) @@ -284,25 +305,40 @@ def skip_value_mesh() -> MeshDescriptor: dtype=gtx.IndexType, ) + offset_provider = { + V2E.value: common._connectivity( + v2e_arr, + codomain=Edge, + domain={Vertex: v2e_arr.shape[0], V2EDim: 5}, + skip_value=common._DEFAULT_SKIP_VALUE, + ), + E2V.value: common._connectivity( + e2v_arr, + codomain=Vertex, + domain={Edge: e2v_arr.shape[0], E2VDim: 2}, + skip_value=None, + ), + C2V.value: common._connectivity( + c2v_arr, + codomain=Vertex, + domain={Cell: c2v_arr.shape[0], C2VDim: 3}, + skip_value=None, + ), + C2E.value: common._connectivity( + c2e_arr, + codomain=Edge, + domain={Cell: c2e_arr.shape[0], C2EDim: 3}, + skip_value=None, + ), + } + return types.SimpleNamespace( name="skip_value_mesh", num_vertices=num_vertices, num_edges=num_edges, num_cells=num_cells, - offset_provider={ - V2E.value: gtx.NeighborTableOffsetProvider( - v2e_arr, Vertex, Edge, 5, has_skip_values=True - ), - E2V.value: gtx.NeighborTableOffsetProvider( - e2v_arr, Edge, Vertex, 2, has_skip_values=False - ), - C2V.value: gtx.NeighborTableOffsetProvider( - c2v_arr, Cell, Vertex, 3, has_skip_values=False - ), - C2E.value: gtx.NeighborTableOffsetProvider( - c2e_arr, Cell, Edge, 3, has_skip_values=False - ), - }, + offset_provider=offset_provider, + offset_provider_type=common.offset_provider_to_type(offset_provider), ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py index cb535f9596..8f67c1d198 
100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_arg_call_interface.py @@ -13,7 +13,7 @@ import numpy as np import pytest -from gt4py.next import errors +from gt4py.next import errors, common, constructors from gt4py.next.ffront.decorator import field_operator, program, scan_operator from gt4py.next.ffront.fbuiltins import broadcast, int32 @@ -296,3 +296,21 @@ def test_call_bound_program_with_already_bound_arg(cartesian_case, bound_args_te ) is not None ) + + +@pytest.mark.uses_origin +def test_direct_fo_call_with_domain_arg(cartesian_case): + @field_operator + def testee(inp: IField) -> IField: + return inp + + size = cartesian_case.default_sizes[IDim] + inp = cases.allocate(cartesian_case, testee, "inp").unique()() + out = cases.allocate( + cartesian_case, testee, cases.RETURN, strategy=cases.ConstInitializer(42) + )() + ref = inp.array_ns.zeros(size) + ref[0] = ref[-1] = 42 + ref[1:-1] = inp.ndarray[1:-1] + + cases.verify(cartesian_case, testee, inp, out=out, domain={IDim: (1, size - 1)}, ref=ref) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index e3e919e52e..9e80dba53b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -21,7 +21,7 @@ ) -def test_program_itir_regression(cartesian_case): +def test_program_gtir_regression(cartesian_case): @gtx.field_operator(backend=None) def testee_op(a: cases.IField) -> cases.IField: return a @@ -30,8 +30,8 @@ def testee_op(a: cases.IField) -> cases.IField: def testee(a: cases.IField, out: cases.IField): testee_op(a, out=out) - assert isinstance(testee.itir, itir.FencilDefinition) - assert isinstance(testee.with_backend(cartesian_case.backend).itir, 
itir.FencilDefinition) + assert isinstance(testee.gtir, itir.Program) + assert isinstance(testee.with_backend(cartesian_case.backend).gtir, itir.Program) def test_frozen(cartesian_case): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 36d6debf9d..a042c60709 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -7,18 +7,14 @@ # SPDX-License-Identifier: BSD-3-Clause from functools import reduce - import numpy as np import pytest - import gt4py.next as gtx from gt4py.next import ( astype, broadcast, common, - constructors, errors, - field_utils, float32, float64, int32, @@ -27,8 +23,6 @@ neighbor_sum, ) from gt4py.next.ffront.experimental import as_offset -from gt4py.next.program_processors.runners import gtfn -from gt4py.next.type_system import type_specifications as ts from gt4py.next import utils as gt_utils from next_tests.integration_tests import cases @@ -47,6 +41,7 @@ Edge, cartesian_case, unstructured_case, + unstructured_case_3d, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, @@ -95,7 +90,21 @@ def testee(a: cases.VField) -> cases.EField: cases.verify_with_default_data( unstructured_case, testee, - ref=lambda a: a[unstructured_case.offset_provider["E2V"].table[:, 0]], + ref=lambda a: a[unstructured_case.offset_provider["E2V"].ndarray[:, 0]], + ) + + +def test_horizontal_only_with_3d_mesh(unstructured_case_3d): + # test field operator operating only on horizontal fields while using an offset provider + # including a vertical dimension. 
+ @gtx.field_operator + def testee(a: cases.VField) -> cases.VField: + return a + + cases.verify_with_default_data( + unstructured_case_3d, + testee, + ref=lambda a: a, ) @@ -121,16 +130,16 @@ def composed_shift_unstructured(inp: cases.VField) -> cases.CField: cases.verify_with_default_data( unstructured_case, composed_shift_unstructured_flat, - ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].table[:, 0]][ - unstructured_case.offset_provider["C2E"].table[:, 0] + ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].ndarray[:, 0]][ + unstructured_case.offset_provider["C2E"].ndarray[:, 0] ], ) cases.verify_with_default_data( unstructured_case, composed_shift_unstructured_intermediate_result, - ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].table[:, 0]][ - unstructured_case.offset_provider["C2E"].table[:, 0] + ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].ndarray[:, 0]][ + unstructured_case.offset_provider["C2E"].ndarray[:, 0] ], comparison=lambda inp, tmp: np.all(inp == tmp), ) @@ -138,8 +147,8 @@ def composed_shift_unstructured(inp: cases.VField) -> cases.CField: cases.verify_with_default_data( unstructured_case, composed_shift_unstructured, - ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].table[:, 0]][ - unstructured_case.offset_provider["C2E"].table[:, 0] + ref=lambda inp: inp[unstructured_case.offset_provider["E2V"].ndarray[:, 0]][ + unstructured_case.offset_provider["C2E"].ndarray[:, 0] ], ) @@ -219,6 +228,7 @@ def testee(a: tuple[int32, tuple[int32, int32]]) -> cases.VField: @pytest.mark.uses_tuple_args +@pytest.mark.uses_zero_dimensional_fields def test_zero_dim_tuple_arg(unstructured_case): @gtx.field_operator def testee( @@ -252,9 +262,7 @@ def testee(a: tuple[int32, tuple[int32, cases.IField, int32]]) -> cases.IField: @pytest.mark.uses_tuple_args -@pytest.mark.xfail( - reason="Not implemented in frontend (implicit size arg handling needs to be adopted) and GTIR embedded backend." 
-) +@pytest.mark.uses_tuple_args_with_different_but_promotable_dims def test_tuple_arg_with_different_but_promotable_dims(cartesian_case): @gtx.field_operator def testee(a: tuple[cases.IField, cases.IJField]) -> cases.IJField: @@ -281,7 +289,6 @@ def testee(a: tuple[cases.VField, cases.EField]) -> cases.VField: ) -@pytest.mark.uses_index_fields @pytest.mark.uses_cartesian_shift def test_scalar_arg_with_field(cartesian_case): @gtx.field_operator @@ -297,6 +304,23 @@ def testee(a: cases.IJKField, b: int32) -> cases.IJKField: cases.verify(cartesian_case, testee, a, b, out=out, ref=ref) +@pytest.mark.uses_tuple_args +def test_double_use_scalar(cartesian_case): + # TODO(tehrengruber): This should be a regression test on ITIR level, but tracing doesn't + # work for this case. + @gtx.field_operator + def testee(a: int32, b: int32, c: cases.IField) -> cases.IField: + tmp = a * b + tmp2 = tmp * tmp + # important part here is that we use the intermediate twice so that it is + # not inlined + return tmp2 * tmp2 * c + + cases.verify_with_default_data( + cartesian_case, testee, ref=lambda a, b, c: a * b * a * b * a * b * a * b * c + ) + + @pytest.mark.uses_scalar_in_domain_and_fo def test_scalar_in_domain_spec_and_fo_call(cartesian_case): @gtx.field_operator @@ -336,6 +360,7 @@ def testee(qc: cases.IKFloatField, scalar: float): @pytest.mark.uses_scan @pytest.mark.uses_scan_in_field_operator +@pytest.mark.uses_tuple_iterator def test_tuple_scalar_scan(cartesian_case): @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def testee_scan( @@ -413,6 +438,22 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: ) +def test_astype_int_local_field(unstructured_case): + @gtx.field_operator + def testee(a: gtx.Field[[Vertex], np.float64]) -> gtx.Field[[Edge], int64]: + tmp = astype(a(E2V), int64) + return neighbor_sum(tmp, axis=E2VDim) + + e2v_table = unstructured_case.offset_provider["E2V"].ndarray + + cases.verify_with_default_data( + unstructured_case, + testee, + 
ref=lambda a: np.sum(a.astype(int64)[e2v_table], axis=1, initial=0), + comparison=lambda a, b: np.all(a == b), + ) + + @pytest.mark.uses_tuple_returns def test_astype_on_tuples(cartesian_case): @gtx.field_operator @@ -561,7 +602,6 @@ def combine(a: cases.IField, b: cases.IField) -> cases.IField: @pytest.mark.uses_unstructured_shift -@pytest.mark.uses_reduction_over_lift_expressions def test_nested_reduction(unstructured_case): @gtx.field_operator def testee(a: cases.VField) -> cases.VField: @@ -573,11 +613,11 @@ def testee(a: cases.VField) -> cases.VField: unstructured_case, testee, ref=lambda a: np.sum( - np.sum(a[unstructured_case.offset_provider["E2V"].table], axis=1, initial=0)[ - unstructured_case.offset_provider["V2E"].table + np.sum(a[unstructured_case.offset_provider["E2V"].ndarray], axis=1, initial=0)[ + unstructured_case.offset_provider["V2E"].ndarray ], axis=1, - where=unstructured_case.offset_provider["V2E"].table != common._DEFAULT_SKIP_VALUE, + where=unstructured_case.offset_provider["V2E"].ndarray != common._DEFAULT_SKIP_VALUE, ), comparison=lambda a, tmp_2: np.all(a == tmp_2), ) @@ -596,8 +636,8 @@ def testee(inp: cases.EField) -> cases.EField: unstructured_case, testee, ref=lambda inp: np.sum( - np.sum(inp[unstructured_case.offset_provider["V2E"].table], axis=1)[ - unstructured_case.offset_provider["E2V"].table + np.sum(inp[unstructured_case.offset_provider["V2E"].ndarray], axis=1)[ + unstructured_case.offset_provider["E2V"].ndarray ], axis=1, ), @@ -617,8 +657,8 @@ def testee(a: cases.EField, b: cases.EField) -> tuple[cases.VField, cases.VField unstructured_case, testee, ref=lambda a, b: [ - np.sum(a[unstructured_case.offset_provider["V2E"].table], axis=1), - np.sum(b[unstructured_case.offset_provider["V2E"].table], axis=1), + np.sum(a[unstructured_case.offset_provider["V2E"].ndarray], axis=1), + np.sum(b[unstructured_case.offset_provider["V2E"].ndarray], axis=1), ], comparison=lambda a, tmp: (np.all(a[0] == tmp[0]), np.all(a[1] == tmp[1])), ) @@ 
-639,11 +679,11 @@ def reduce_tuple_element(e: cases.EField, v: cases.VField) -> cases.EField: unstructured_case, reduce_tuple_element, ref=lambda e, v: np.sum( - e[v2e.table] + np.tile(v, (v2e.max_neighbors, 1)).T, + e[v2e.ndarray] + np.tile(v, (v2e.shape[1], 1)).T, axis=1, initial=0, - where=v2e.table != common._DEFAULT_SKIP_VALUE, - )[unstructured_case.offset_provider["E2V"].table[:, 0]], + where=v2e.ndarray != common._DEFAULT_SKIP_VALUE, + )[unstructured_case.offset_provider["E2V"].ndarray[:, 0]], ) @@ -681,12 +721,8 @@ def simple_scan_operator(carry: float) -> float: @pytest.mark.uses_scan -@pytest.mark.uses_lift_expressions @pytest.mark.uses_scan_nested def test_solve_triag(cartesian_case): - if cartesian_case.backend == gtfn.run_gtfn_with_temporaries: - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0)) def tridiag_forward( state: tuple[float, float], a: float, b: float, c: float, d: float @@ -714,7 +750,16 @@ def expected(a, b, c, d): matrices[:, :, i[1:], i[:-1]] = a[:, :, 1:] matrices[:, :, i, i] = b matrices[:, :, i[:-1], i[1:]] = c[:, :, :-1] - return np.linalg.solve(matrices, d) + # Changed in NumPY version 2.0: In a linear matrix equation ax = b, the b array + # is only treated as a shape (M,) column vector if it is exactly 1-dimensional. + # In all other instances it is treated as a stack of (M, K) matrices. Therefore + # below we add an extra dimension (K) of size 1. Previously b would be treated + # as a stack of (M,) vectors if b.ndim was equal to a.ndim - 1. 
+ # Refer to https://numpy.org/doc/2.0/reference/generated/numpy.linalg.solve.html + d_ext = np.empty(shape=(*shape, 1)) + d_ext[:, :, :, 0] = d + x = np.linalg.solve(matrices, d_ext) + return x[:, :, :, 0] cases.verify_with_default_data(cartesian_case, solve_tridiag, ref=expected) @@ -766,14 +811,13 @@ def testee( @pytest.mark.uses_constant_fields @pytest.mark.uses_unstructured_shift -@pytest.mark.uses_reduction_over_lift_expressions def test_ternary_builtin_neighbor_sum(unstructured_case): @gtx.field_operator def testee(a: cases.EField, b: cases.EField) -> cases.VField: tmp = neighbor_sum(b(V2E) if 2 < 3 else a(V2E), axis=V2EDim) return tmp - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, testee, @@ -785,9 +829,6 @@ def testee(a: cases.EField, b: cases.EField) -> cases.VField: @pytest.mark.uses_scan def test_ternary_scan(cartesian_case): - if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]: - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - @gtx.scan_operator(axis=KDim, forward=True, init=0.0) def simple_scan_operator(carry: float, a: float) -> float: return carry if carry > a else carry + 1.0 @@ -810,9 +851,6 @@ def simple_scan_operator(carry: float, a: float) -> float: @pytest.mark.uses_scan_without_field_args @pytest.mark.uses_tuple_returns def test_scan_nested_tuple_output(forward, cartesian_case): - if cartesian_case.backend in [gtfn.run_gtfn_with_temporaries]: - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - init = (1, (2, 3)) k_size = cartesian_case.default_sizes[KDim] expected = np.arange(1, 1 + k_size, 1, dtype=int32) @@ -839,8 +877,9 @@ def testee(out: tuple[cases.KField, tuple[cases.KField, cases.KField]]): ) -@pytest.mark.uses_tuple_args @pytest.mark.uses_scan +@pytest.mark.uses_tuple_args +@pytest.mark.uses_tuple_iterator def 
test_scan_nested_tuple_input(cartesian_case): init = 1.0 k_size = cartesian_case.default_sizes[KDim] @@ -869,6 +908,7 @@ def simple_scan_operator(carry: float, a: tuple[float, float]) -> float: @pytest.mark.uses_scan +@pytest.mark.uses_tuple_iterator def test_scan_different_domain_in_tuple(cartesian_case): init = 1.0 i_size = cartesian_case.default_sizes[IDim] @@ -908,6 +948,7 @@ def foo( @pytest.mark.uses_scan +@pytest.mark.uses_tuple_iterator def test_scan_tuple_field_scalar_mixed(cartesian_case): init = 1.0 i_size = cartesian_case.default_sizes[IDim] @@ -966,7 +1007,7 @@ def program_domain(a: cases.IField, out: cases.IField): a = cases.allocate(cartesian_case, program_domain, "a")() out = cases.allocate(cartesian_case, program_domain, "out")() - ref = out.asnumpy().copy() # ensure we are not overwriting out outside of the domain + ref = out.asnumpy().copy() # ensure we are not writing to out outside the domain ref[1:9] = a.asnumpy()[1:9] * 2 cases.verify(cartesian_case, program_domain, a, out, inout=out, ref=ref) @@ -1093,7 +1134,7 @@ def implicit_broadcast_scalar(inp: cases.EmptyField): inp = cases.allocate(cartesian_case, implicit_broadcast_scalar, "inp")() out = cases.allocate(cartesian_case, implicit_broadcast_scalar, "inp")() - cases.verify(cartesian_case, implicit_broadcast_scalar, inp, out=out, ref=np.array(0)) + cases.verify(cartesian_case, implicit_broadcast_scalar, inp, out=out, ref=np.array(1)) def test_implicit_broadcast_mixed_dim(cartesian_case): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py index 37f4ee2cd1..33832fb5f0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py @@ -33,11 +33,11 @@ def testee( ) # multiplication with shifted `ones` because 
reduction of only non-shifted field with local dimension is not supported inp = unstructured_case.as_field( - [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table + [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].ndarray ) ones = cases.allocate(unstructured_case, testee, "ones").strategy(cases.ConstInitializer(1))() - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify( unstructured_case, testee, @@ -57,7 +57,7 @@ def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32 return neighbor_sum(inp, axis=V2EDim) inp = unstructured_case.as_field( - [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].table + [Vertex, V2EDim], unstructured_case.offset_provider["V2E"].ndarray ) cases.verify( @@ -65,7 +65,7 @@ def testee(inp: gtx.Field[[Vertex, V2EDim], int32]) -> gtx.Field[[Vertex], int32 testee, inp, out=cases.allocate(unstructured_case, testee, cases.RETURN)(), - ref=np.sum(unstructured_case.offset_provider["V2E"].table, axis=1), + ref=np.sum(unstructured_case.offset_provider["V2E"].ndarray, axis=1), ) @@ -76,7 +76,7 @@ def testee(inp: gtx.Field[[Edge], int32]) -> gtx.Field[[Vertex, V2EDim], int32]: return inp(V2E) out = unstructured_case.as_field( - [Vertex, V2EDim], np.zeros_like(unstructured_case.offset_provider["V2E"].table) + [Vertex, V2EDim], np.zeros_like(unstructured_case.offset_provider["V2E"].ndarray) ) inp = cases.allocate(unstructured_case, testee, "inp")() cases.verify( @@ -84,5 +84,5 @@ def testee(inp: gtx.Field[[Edge], int32]) -> gtx.Field[[Vertex, V2EDim], int32]: testee, inp, out=out, - ref=inp.asnumpy()[unstructured_case.offset_provider["V2E"].table], + ref=inp.asnumpy()[unstructured_case.offset_provider["V2E"].ndarray], ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 
3777de7843..d7fe252cb4 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -29,6 +29,7 @@ Vertex, cartesian_case, unstructured_case, + unstructured_case_3d, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, @@ -52,7 +53,7 @@ def testee(edge_f: cases.EField) -> cases.VField: inp = cases.allocate(unstructured_case, testee, "edge_f", strategy=strategy)() out = cases.allocate(unstructured_case, testee, cases.RETURN)() - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray ref = np.max( inp.asnumpy()[v2e_table], axis=1, @@ -69,7 +70,7 @@ def minover(edge_f: cases.EField) -> cases.VField: out = min_over(edge_f(V2E), axis=V2EDim) return out - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, minover, @@ -94,21 +95,14 @@ def reduction_ek_field( return neighbor_sum(edge_f(V2E), axis=V2EDim) -@gtx.field_operator -def reduction_ke_field( - edge_f: common.Field[[KDim, Edge], np.int32], -) -> common.Field[[KDim, Vertex], np.int32]: - return neighbor_sum(edge_f(V2E), axis=V2EDim) - - @pytest.mark.uses_unstructured_shift @pytest.mark.parametrize( - "fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__ + "fop", [reduction_e_field, reduction_ek_field], ids=lambda fop: fop.__name__ ) -def test_neighbor_sum(unstructured_case, fop): - v2e_table = unstructured_case.offset_provider["V2E"].table +def test_neighbor_sum(unstructured_case_3d, fop): + v2e_table = unstructured_case_3d.offset_provider["V2E"].ndarray - edge_f = cases.allocate(unstructured_case, fop, "edge_f")() + edge_f = cases.allocate(unstructured_case_3d, fop, "edge_f")() local_dim_idx = 
edge_f.domain.dims.index(Edge) + 1 adv_indexing = tuple( @@ -131,10 +125,10 @@ def test_neighbor_sum(unstructured_case, fop): where=broadcasted_table != common._DEFAULT_SKIP_VALUE, ) cases.verify( - unstructured_case, + unstructured_case_3d, fop, edge_f, - out=cases.allocate(unstructured_case, fop, cases.RETURN)(), + out=cases.allocate(unstructured_case_3d, fop, cases.RETURN)(), ref=ref, ) @@ -157,7 +151,7 @@ def fencil_op(edge_f: EKField) -> VKField: def fencil(edge_f: EKField, out: VKField): fencil_op(edge_f, out=out) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray field = cases.allocate(unstructured_case, fencil, "edge_f", sizes={KDim: 2})() out = cases.allocate(unstructured_case, fencil_op, cases.RETURN, sizes={KDim: 1})() @@ -190,7 +184,7 @@ def reduce_expr(edge_f: cases.EField) -> cases.VField: def fencil(edge_f: cases.EField, out: cases.VField): reduce_expr(edge_f, out=out) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, fencil, @@ -210,7 +204,7 @@ def test_reduction_with_common_expression(unstructured_case): def testee(flux: cases.EField) -> cases.VField: return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim) - v2e_table = unstructured_case.offset_provider["V2E"].table + v2e_table = unstructured_case.offset_provider["V2E"].ndarray cases.verify_with_default_data( unstructured_case, testee, @@ -220,6 +214,94 @@ def testee(flux: cases.EField) -> cases.VField: ) +@pytest.mark.uses_unstructured_shift +def test_reduction_expression_with_where(unstructured_case): + @gtx.field_operator + def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: + return neighbor_sum(where(mask, inp(V2E), inp(V2E)), axis=V2EDim) + + v2e_table = unstructured_case.offset_provider["V2E"].ndarray + + mask = unstructured_case.as_field( + [Vertex], np.random.choice(a=[False, 
True], size=unstructured_case.default_sizes[Vertex]) + ) + inp = cases.allocate(unstructured_case, testee, "inp")() + out = cases.allocate(unstructured_case, testee, cases.RETURN)() + + cases.verify( + unstructured_case, + testee, + mask, + inp, + out=out, + ref=np.sum( + inp.asnumpy()[v2e_table], + axis=1, + initial=0, + where=v2e_table != common._DEFAULT_SKIP_VALUE, + ), + ) + + +@pytest.mark.uses_unstructured_shift +def test_reduction_expression_with_where_and_tuples(unstructured_case): + @gtx.field_operator + def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: + return neighbor_sum(where(mask, (inp(V2E), inp(V2E)), (inp(V2E), inp(V2E)))[1], axis=V2EDim) + + v2e_table = unstructured_case.offset_provider["V2E"].ndarray + + mask = unstructured_case.as_field( + [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) + ) + inp = cases.allocate(unstructured_case, testee, "inp")() + out = cases.allocate(unstructured_case, testee, cases.RETURN)() + + cases.verify( + unstructured_case, + testee, + mask, + inp, + out=out, + ref=np.sum( + inp.asnumpy()[v2e_table], + axis=1, + initial=0, + where=v2e_table != common._DEFAULT_SKIP_VALUE, + ), + ) + + +@pytest.mark.uses_unstructured_shift +def test_reduction_expression_with_where_and_scalar(unstructured_case): + @gtx.field_operator + def testee(mask: cases.VBoolField, inp: cases.EField) -> cases.VField: + return neighbor_sum(inp(V2E) + where(mask, inp(V2E), 1), axis=V2EDim) + + v2e_table = unstructured_case.offset_provider["V2E"].ndarray + + mask = unstructured_case.as_field( + [Vertex], np.random.choice(a=[False, True], size=unstructured_case.default_sizes[Vertex]) + ) + inp = cases.allocate(unstructured_case, testee, "inp")() + out = cases.allocate(unstructured_case, testee, cases.RETURN)() + + cases.verify( + unstructured_case, + testee, + mask, + inp, + out=out, + ref=np.sum( + inp.asnumpy()[v2e_table] + + np.where(np.expand_dims(mask.asnumpy(), 1), 
inp.asnumpy()[v2e_table], 1), + axis=1, + initial=0, + where=v2e_table != common._DEFAULT_SKIP_VALUE, + ), + ) + + @pytest.mark.uses_tuple_returns def test_conditional_nested_tuple(cartesian_case): @gtx.field_operator @@ -375,11 +457,13 @@ def conditional_program( ) -def test_promotion(unstructured_case): +def test_promotion(unstructured_case_3d): @gtx.field_operator def promotion( inp1: gtx.Field[[Edge, KDim], float64], inp2: gtx.Field[[KDim], float64] ) -> gtx.Field[[Edge, KDim], float64]: return inp1 / inp2 - cases.verify_with_default_data(unstructured_case, promotion, ref=lambda inp1, inp2: inp1 / inp2) + cases.verify_with_default_data( + unstructured_case_3d, promotion, ref=lambda inp1, inp2: inp1 / inp2 + ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py index 89c341e9a6..1707adada8 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_math_unary_builtins.py @@ -128,6 +128,14 @@ def uneg(inp: cases.IField) -> cases.IField: cases.verify_with_default_data(cartesian_case, uneg, ref=lambda inp1: -inp1) +def test_unary_pos(cartesian_case): + @gtx.field_operator + def upos(inp: cases.IField) -> cases.IField: + return +inp + + cases.verify_with_default_data(cartesian_case, upos, ref=lambda inp1: inp1) + + def test_unary_neg_float_conversion(cartesian_case): @gtx.field_operator def uneg_float() -> cases.IFloatField: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py index f1cb8ffb17..27c4252e14 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py @@ -14,7 +14,7 @@ 
import pytest import gt4py.next as gtx -from gt4py.next import errors +from gt4py.next import errors, constructors, common from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -251,3 +251,42 @@ def empty_domain_program(a: cases.IJField, out_field: cases.IJField): ValueError, match=(r"Dimensions in out field and field domain are not equivalent") ): cases.run(cartesian_case, empty_domain_program, a, out_field, offset_provider={}) + + +@pytest.mark.uses_origin +def test_out_field_arg_with_non_zero_domain_start(cartesian_case, copy_program_def): + copy_program = gtx.program(copy_program_def, backend=cartesian_case.backend) + + size = cartesian_case.default_sizes[IDim] + + inp = cases.allocate(cartesian_case, copy_program, "in_field").unique()() + out = constructors.empty( + common.domain({IDim: (1, size - 2)}), + allocator=cartesian_case.allocator, + ) + ref = inp.ndarray[1:-2] + + cases.verify(cartesian_case, copy_program, inp, out=out, ref=ref) + + +@pytest.mark.uses_origin +def test_in_field_arg_with_non_zero_domain_start(cartesian_case, copy_program_def): + @gtx.field_operator + def identity(a: cases.IField) -> cases.IField: + return a + + @gtx.program + def copy_program(a: cases.IField, out: cases.IField): + identity(a, out=out, domain={IDim: (1, 9)}) + + inp = constructors.empty( + common.domain({IDim: (1, 9)}), + dtype=np.int32, + allocator=cartesian_case.allocator, + ) + inp.ndarray[...] 
= 42 + out = cases.allocate(cartesian_case, copy_program, "out", sizes={IDim: 10})() + ref = out.asnumpy().copy() # ensure we are not writing to `out` outside the domain + ref[1:9] = inp.asnumpy() + + cases.verify(cartesian_case, copy_program, inp, out=out, ref=ref) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py index 0efb599f9e..7ff7edf226 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_scalar_if.py @@ -56,6 +56,7 @@ def simple_if(a: cases.IField, b: cases.IField, condition: bool) -> cases.IField cases.verify(cartesian_case, simple_if, a, b, condition, out=out, ref=a if condition else b) +# TODO(tehrengruber): test with fields on different domains @pytest.mark.parametrize("condition1, condition2", [[True, False], [True, False]]) @pytest.mark.uses_if_stmts def test_simple_if_conditional(condition1, condition2, cartesian_case): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 0305a5841a..7d2eec772c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -5,14 +5,15 @@ # # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause +import platform import pytest from numpy import int32, int64 from gt4py import next as gtx from gt4py.next import backend, common -from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms -from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries +from gt4py.next.iterator.transforms import apply_common_transforms +from gt4py.next.program_processors.runners.gtfn import run_gtfn from next_tests.integration_tests import cases from next_tests.integration_tests.cases import ( @@ -34,8 +35,8 @@ def run_gtfn_with_temporaries_and_symbolic_sizes(): return backend.Backend( name="run_gtfn_with_temporaries_and_sizes", transforms=backend.DEFAULT_TRANSFORMS, - executor=run_gtfn_with_temporaries.executor.replace( - translation=run_gtfn_with_temporaries.executor.translation.replace( + executor=run_gtfn.executor.replace( + translation=run_gtfn.executor.translation.replace( symbolic_domain_sizes={ "Cell": "num_cells", "Edge": "num_edges", @@ -43,7 +44,7 @@ def prog( } ) ), - allocator=run_gtfn_with_temporaries.allocator, + allocator=run_gtfn.allocator, ) @@ -64,8 +65,14 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh_descriptor): - # FIXME[#1582](tehrengruber): enable when temporary pass has been implemented - pytest.xfail("Temporary pass not implemented.") + if platform.machine() == "x86_64": + pytest.xfail( + reason="The C++ code generated in this test contains unicode characters " + "(coming from the ssa pass) which is not supported by gcc 9 used " + "in the CI. Bumping the container version sadly did not work for " + "unrelated and unclear reasons. Since the issue is not present " + "on Alps we just skip the test for now before investing more time."
+ ) unstructured_case = Case( run_gtfn_with_temporaries_and_symbolic_sizes, @@ -83,7 +90,7 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh a = cases.allocate(unstructured_case, testee, "a")() out = cases.allocate(unstructured_case, testee, "out")() - first_nbs, second_nbs = (mesh_descriptor.offset_provider["E2V"].table[:, i] for i in [0, 1]) + first_nbs, second_nbs = (mesh_descriptor.offset_provider["E2V"].ndarray[:, i] for i in [0, 1]) ref = (a.ndarray * 2)[first_nbs] + (a.ndarray * 2)[second_nbs] cases.verify( @@ -100,15 +107,12 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh def test_temporary_symbols(testee, mesh_descriptor): - # FIXME[#1582](tehrengruber): enable when temporary pass has been implemented - pytest.xfail("Temporary pass not implemented.") - - itir_with_tmp = apply_common_transforms( - testee.itir, - lift_mode=LiftMode.USE_TEMPORARIES, + gtir_with_tmp = apply_common_transforms( + testee.gtir, + extract_temporaries=True, offset_provider=mesh_descriptor.offset_provider, ) params = ["num_vertices", "num_edges", "num_cells"] for param in params: - assert any([param == str(p) for p in itir_with_tmp.params]) + assert any([param == str(p) for p in gtir_with_tmp.params]) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index c2f72e4ca7..01637e56e0 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -18,6 +18,7 @@ from gt4py.next.iterator import builtins as it_builtins from gt4py.next.iterator.builtins import ( and_, + as_fieldop, bool, can_deref, cartesian_domain, @@ -45,8 +46,10 @@ plus, shift, xor_, + neg, + abs, ) -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import 
fendef, fundef, offset, set_at from gt4py.next.program_processors.runners.gtfn import run_gtfn from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data @@ -87,7 +90,9 @@ def dispatch(arg0): @fendef(offset_provider={}, column_axis=column_axis) def fenimpl(size, arg0, out): - closure(cartesian_domain(named_range(IDim, 0, size)), dispatch, out, [arg0]) + domain = cartesian_domain(named_range(IDim, 0, size)) + + set_at(as_fieldop(dispatch, domain)(arg0), domain, out) elif len(inps) == 2: @@ -102,7 +107,9 @@ def dispatch(arg0, arg1): @fendef(offset_provider={}, column_axis=column_axis) def fenimpl(size, arg0, arg1, out): - closure(cartesian_domain(named_range(IDim, 0, size)), dispatch, out, [arg0, arg1]) + domain = cartesian_domain(named_range(IDim, 0, size)) + + set_at(as_fieldop(dispatch, domain)(arg0, arg1), domain, out) elif len(inps) == 3: @@ -117,7 +124,9 @@ def dispatch(arg0, arg1, arg2): @fendef(offset_provider={}, column_axis=column_axis) def fenimpl(size, arg0, arg1, arg2, out): - closure(cartesian_domain(named_range(IDim, 0, size)), dispatch, out, [arg0, arg1, arg2]) + domain = cartesian_domain(named_range(IDim, 0, size)) + + set_at(as_fieldop(dispatch, domain)(arg0, arg1, arg2), domain, out) else: raise AssertionError("Add overload.") @@ -128,6 +137,8 @@ def fenimpl(size, arg0, arg1, arg2, out): def arithmetic_and_logical_test_data(): return [ # (builtin, inputs, expected) + (abs, [[-1.0, 1.0]], [1.0, 1.0]), + (neg, [[-1.0, 1.0, -1, 1]], [1.0, -1.0, 1, -1]), (plus, [2.0, 3.0], 5.0), (minus, [2.0, 3.0], -1.0), (multiplies, [2.0, 3.0], 6.0), @@ -173,8 +184,8 @@ def test_arithmetic_and_logical_builtins(program_processor, builtin, inputs, exp @pytest.mark.parametrize("builtin, inputs, expected", arithmetic_and_logical_test_data()) def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): - if builtin == if_: - pytest.skip("If cannot be used unapplied") + if builtin == if_ or builtin == abs: + 
pytest.skip("If and abs cannot be used unapplied.") inps = field_maker(*array_maker(*inputs)) out = field_maker((np.zeros_like(*array_maker(expected))))[0] @@ -237,15 +248,19 @@ def foo(a): @pytest.mark.parametrize("stencil", [_can_deref, _can_deref_lifted]) +@pytest.mark.uses_can_deref def test_can_deref(program_processor, stencil): program_processor, validate = program_processor Node = gtx.Dimension("Node") + NeighDim = gtx.Dimension("Neighbor", kind=gtx.DimensionKind.LOCAL) inp = gtx.as_field([Node], np.ones((1,), dtype=np.int32)) out = gtx.as_field([Node], np.asarray([0], dtype=inp.dtype)) - no_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[-1]]), Node, Node, 1) + no_neighbor_tbl = gtx.as_connectivity( + domain={Node: 1, NeighDim: 1}, codomain=Node, data=np.array([[-1]]), skip_value=-1 + ) run_processor( stencil[{Node: range(1)}], program_processor, @@ -257,7 +272,9 @@ def test_can_deref(program_processor, stencil): if validate: assert np.allclose(out.asnumpy(), -1.0) - a_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[0]]), Node, Node, 1) + a_neighbor_tbl = gtx.as_connectivity( + domain={Node: 1, NeighDim: 1}, codomain=Node, data=np.array([[0]]), skip_value=-1 + ) run_processor( stencil[{Node: range(1)}], program_processor, @@ -270,37 +287,6 @@ def test_can_deref(program_processor, stencil): assert np.allclose(out.asnumpy(), 1.0) -# def test_can_deref_lifted(program_processor): -# program_processor, validate = program_processor - -# Neighbor = offset("Neighbor") -# Node = gtx.Dimension("Node") - -# @fundef -# def _can_deref(inp): -# shifted = shift(Neighbor, 0)(inp) -# return if_(can_deref(shifted), 1, -1) - -# inp = gtx.as_field([Node], np.zeros((1,))) -# out = gtx.as_field([Node], np.asarray([0])) - -# no_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[None]]), Node, Node, 1) -# _can_deref[{Node: range(1)}]( -# inp, out=out, offset_provider={"Neighbor": no_neighbor_tbl}, program_processor=program_processor -# ) - -# if 
validate: -# assert np.allclose(np.asarray(out), -1.0) - -# a_neighbor_tbl = gtx.NeighborTableOffsetProvider(np.array([[0]]), Node, Node, 1) -# _can_deref[{Node: range(1)}]( -# inp, out=out, offset_provider={"Neighbor": a_neighbor_tbl}, program_processor=program_processor -# ) - -# if validate: -# assert np.allclose(np.asarray(out), 1.0) - - @pytest.mark.parametrize( "input_value, dtype, np_dtype", [ diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py index 2ebcd0c033..fedfd83fd2 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py @@ -10,7 +10,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import double_roundtrip, roundtrip @@ -27,16 +27,14 @@ def foo(inp): @fendef(offset_provider={"I": I_loc, "J": J_loc}) def fencil(output, input): - closure( - cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input] - ) + domain = cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)) + set_at(as_fieldop(foo, domain)(input), domain, output) @fendef(offset_provider={"I": J_loc, "J": I_loc}) def fencil_swapped(output, input): - closure( - cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input] - ) + domain = cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)) + set_at(as_fieldop(foo, domain)(input), domain, output) def test_cartesian_offset_provider(): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py 
b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index 551c567e61..eae66d425b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef +from gt4py.next.iterator.runtime import set_at, fendef, fundef from next_tests.unit_tests.conftest import program_processor, run_processor diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_extractors.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_extractors.py new file mode 100644 index 0000000000..2356e9c781 --- /dev/null +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_extractors.py @@ -0,0 +1,62 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import typing + +import pytest + +from gt4py import next as gtx +from gt4py.next.iterator.transforms import extractors + +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + IDim, + JDim, + KDim, +) + + +def test_input_names_extractor_cartesian(): + @gtx.field_operator + def testee_op( + a: gtx.Field[[IDim, JDim, KDim], gtx.int], + ) -> gtx.Field[[IDim, JDim, KDim], gtx.int]: + return a + + @gtx.program + def testee( + a: gtx.Field[[IDim, JDim, KDim], gtx.int], + b: gtx.Field[[IDim, JDim, KDim], gtx.int], + c: gtx.Field[[IDim, JDim, KDim], gtx.int], + ): + testee_op(b, out=c) + testee_op(a, out=b) + + input_field_names = extractors.InputNamesExtractor.only_fields(testee.gtir) + assert input_field_names == {"a", "b"} + + +def test_output_names_extractor(): + @gtx.field_operator + def testee_op( + a: gtx.Field[[IDim, JDim, KDim], gtx.int], + ) -> gtx.Field[[IDim, JDim, KDim], gtx.int]: + return a + + @gtx.program + def testee( + a: gtx.Field[[IDim, JDim, KDim], gtx.int], + b: gtx.Field[[IDim, JDim, KDim], gtx.int], + c: gtx.Field[[IDim, JDim, KDim], gtx.int], + ): + testee_op(a, out=b) + testee_op(a, out=c) + + output_field_names = extractors.OutputNamesExtractor.only_fields(testee.gtir) + assert output_field_names == {"b", "c"} diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py index 4eab7502e7..09dc04acb1 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -10,13 +10,23 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import as_fieldop, cartesian_domain, deref, named_range +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.builtins import ( + 
as_fieldop, + cartesian_domain, + deref, + index, + named_range, + shift, + INTEGER_INDEX_BUILTIN, +) from gt4py.next.iterator.runtime import fendef, fundef, set_at from next_tests.unit_tests.conftest import program_processor, run_processor I = gtx.Dimension("I") +Ioff = gtx.FieldOffset("Ioff", source=I, target=(I,)) @fundef @@ -33,7 +43,6 @@ def copy_program(inp, out, size): ) -@pytest.mark.starts_from_gtir_program def test_prog(program_processor): program_processor, validate = program_processor @@ -44,3 +53,47 @@ def test_prog(program_processor): run_processor(copy_program, program_processor, inp, out, isize, offset_provider={}) if validate: assert np.allclose(inp.asnumpy(), out.asnumpy()) + + +@fendef +def index_program_simple(out, size): + set_at( + as_fieldop(lambda i: deref(i), cartesian_domain(named_range(I, 0, size)))(index(I)), + cartesian_domain(named_range(I, 0, size)), + out, + ) + + +@pytest.mark.uses_index_fields +def test_index_builtin(program_processor): + program_processor, validate = program_processor + + isize = 10 + out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, INTEGER_INDEX_BUILTIN)) + + run_processor(index_program_simple, program_processor, out, isize, offset_provider={}) + if validate: + assert np.allclose(np.arange(10), out.asnumpy()) + + +@fendef +def index_program_shift(out, size): + set_at( + as_fieldop( + lambda i: deref(i) + deref(shift(Ioff, 1)(i)), cartesian_domain(named_range(I, 0, size)) + )(index(I)), + cartesian_domain(named_range(I, 0, size)), + out, + ) + + +@pytest.mark.uses_index_fields +def test_index_builtin_shift(program_processor): + program_processor, validate = program_processor + + isize = 10 + out = gtx.as_field([I], np.zeros(shape=(isize,)), dtype=getattr(np, INTEGER_INDEX_BUILTIN)) + + run_processor(index_program_shift, program_processor, out, isize, offset_provider={"Ioff": I}) + if validate: + assert np.allclose(np.arange(10) + np.arange(1, 11), out.asnumpy()) diff --git 
a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py index a86959d075..e462aa07eb 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py @@ -18,7 +18,9 @@ @pytest.mark.uses_index_fields +@pytest.mark.uses_scan_in_stencil def test_scan_in_stencil(program_processor): + # FIXME[#1582](tehrengruber): Remove test after scan is reworked. program_processor, validate = program_processor isize = 1 diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index 69786b323b..68e5f9d532 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -10,10 +10,11 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain, as_fieldop +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.unit_tests.conftest import program_processor, run_processor +from gt4py.next.iterator.embedded import StridedConnectivityField LocA = gtx.Dimension("LocA") @@ -21,8 +22,10 @@ LocB = gtx.Dimension("LocB") # unused LocA2LocAB = offset("O") -LocA2LocAB_offset_provider = gtx.StridedNeighborOffsetProvider( - origin_axis=LocA, neighbor_axis=LocAB, max_neighbors=2, has_skip_values=False +LocA2LocAB_offset_provider = StridedConnectivityField( + domain_dims=(LocA, gtx.Dimension("Dummy", kind=gtx.DimensionKind.LOCAL)), + 
codomain_dim=LocAB, + max_neighbors=2, ) @@ -33,7 +36,8 @@ def foo(inp): @fendef(offset_provider={"O": LocA2LocAB_offset_provider}) def fencil(size, out, inp): - closure(unstructured_domain(named_range(LocA, 0, size)), foo, out, [inp]) + domain = unstructured_domain(named_range(LocA, 0, size)) + set_at(as_fieldop(foo, domain)(inp), domain, out) @pytest.mark.uses_strided_neighbor_offset @@ -41,7 +45,7 @@ def test_strided_offset_provider(program_processor): program_processor, validate = program_processor LocA_size = 2 - max_neighbors = LocA2LocAB_offset_provider.max_neighbors + max_neighbors = LocA2LocAB_offset_provider.__gt_type__().max_neighbors LocAB_size = LocA_size * max_neighbors rng = np.random.default_rng() diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index 5f1c70a6b3..7836b1b110 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -12,7 +12,7 @@ import gt4py.next as gtx from gt4py.next.iterator import transforms from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.integration_tests.cases import IDim, JDim, KDim from next_tests.unit_tests.conftest import program_processor, run_processor @@ -38,6 +38,7 @@ def baz(baz_inp): return deref(lift(bar)(baz_inp)) +@pytest.mark.uses_lift def test_trivial(program_processor): program_processor, validate = program_processor @@ -66,6 +67,7 @@ def stencil_shifted_arg_to_lift(inp): return deref(lift(deref)(shift(I, -1)(inp))) +@pytest.mark.uses_lift def test_shifted_arg_to_lift(program_processor): program_processor, validate = program_processor @@ -94,12 +96,8 @@ def test_shifted_arg_to_lift(program_processor): @fendef def 
fen_direct_deref(i_size, j_size, out, inp): - closure( - cartesian_domain(named_range(IDim, 0, i_size), named_range(JDim, 0, j_size)), - deref, - out, - [inp], - ) + domain = cartesian_domain(named_range(IDim, 0, i_size), named_range(JDim, 0, j_size)) + set_at(as_fieldop(deref, domain)(inp), domain, out) def test_direct_deref(program_processor): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index 2d84439c93..ea89bb23ba 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef +from gt4py.next.iterator.runtime import set_at, fendef, fundef from next_tests.unit_tests.conftest import program_processor, run_processor @@ -114,16 +114,10 @@ def test_tuple_of_field_output_constructed_inside(program_processor, stencil): @fendef def fencil(size0, size1, size2, inp1, inp2, out1, out2): - closure( - cartesian_domain( - named_range(IDim, 0, size0), - named_range(JDim, 0, size1), - named_range(KDim, 0, size2), - ), - stencil, - make_tuple(out1, out2), - [inp1, inp2], + domain = cartesian_domain( + named_range(IDim, 0, size0), named_range(JDim, 0, size1), named_range(KDim, 0, size2) ) + set_at(as_fieldop(stencil, domain)(inp1, inp2), domain, make_tuple(out1, out2)) shape = [5, 7, 9] rng = np.random.default_rng() @@ -159,15 +153,13 @@ def stencil(inp1, inp2, inp3): @fendef def fencil(size0, size1, size2, inp1, inp2, inp3, out1, out2, out3): - closure( - cartesian_domain( - named_range(IDim, 0, size0), - named_range(JDim, 0, size1), - named_range(KDim, 0, size2), - ), - stencil, + domain = cartesian_domain( + named_range(IDim, 0, size0), named_range(JDim, 0, size1), named_range(KDim, 0, size2) + 
) + set_at( + as_fieldop(stencil, domain)(inp1, inp2, inp3), + domain, make_tuple(make_tuple(out1, out2), out3), - [inp1, inp2, inp3], ) shape = [5, 7, 9] @@ -227,6 +219,7 @@ def tuple_input(inp): @pytest.mark.uses_tuple_args +@pytest.mark.uses_tuple_iterator def test_tuple_field_input(program_processor): program_processor, validate = program_processor @@ -280,6 +273,7 @@ def tuple_tuple_input(inp): @pytest.mark.uses_tuple_args +@pytest.mark.uses_tuple_iterator def test_tuple_of_tuple_of_field_input(program_processor): program_processor, validate = program_processor @@ -327,6 +321,7 @@ def test_field_of_2_extra_dim_input(program_processor): @pytest.mark.uses_tuple_args +@pytest.mark.uses_tuple_iterator def test_scalar_tuple_args(program_processor): @fundef def stencil(inp): @@ -356,6 +351,7 @@ def stencil(inp): @pytest.mark.uses_tuple_args +@pytest.mark.uses_tuple_iterator def test_mixed_field_scalar_tuple_arg(program_processor): @fundef def stencil(inp): diff --git a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py index eaeb76b404..3e3df069bf 100644 --- a/tests/next_tests/integration_tests/feature_tests/test_util_cases.py +++ b/tests/next_tests/integration_tests/feature_tests/test_util_cases.py @@ -35,8 +35,8 @@ def mixed_args( def test_allocate_default_unique(cartesian_case): a = cases.allocate(cartesian_case, mixed_args, "a")() - assert np.min(a.asnumpy()) == 0 - assert np.max(a.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) - 1 + assert np.min(a.asnumpy()) == 1 + assert np.max(a.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) b = cases.allocate(cartesian_case, mixed_args, "b")() @@ -45,7 +45,7 @@ def test_allocate_default_unique(cartesian_case): c = cases.allocate(cartesian_case, mixed_args, "c")() assert np.min(c.asnumpy()) == b + 1 - assert np.max(c.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) * 2 + assert 
np.max(c.asnumpy()) == np.prod(tuple(cartesian_case.default_sizes.values())) * 2 + 1 def test_allocate_return_default_zeros(cartesian_case): diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py index eb59c77201..da354be7ea 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py @@ -11,9 +11,9 @@ import numpy as np import pytest - pytest.importorskip("atlas4py") +import gt4py._core.definitions as core_defs from gt4py import next as gtx from gt4py.next import allocators, neighbor_sum from gt4py.next.iterator import atlas_utils @@ -22,20 +22,17 @@ exec_alloc_descriptor, ) from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( + E2V, + V2E, + E2VDim, + Edge, + V2EDim, + Vertex, assert_close, nabla_setup, ) -Vertex = gtx.Dimension("Vertex") -Edge = gtx.Dimension("Edge") -V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL) -E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL) - -V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) -E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) - - @gtx.field_operator def compute_zavgS( pp: gtx.Field[[Vertex], float], S_M: gtx.Field[[Edge], float] @@ -66,50 +63,50 @@ def pnabla( return compute_pnabla(pp, S_M[0], sign, vol), compute_pnabla(pp, S_M[1], sign, vol) +@pytest.mark.requires_atlas def test_ffront_compute_zavgS(exec_alloc_descriptor): - executor, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator + # TODO(havogt): fix nabla setup to work with GPU + if exec_alloc_descriptor.allocator.device_type != core_defs.DeviceType.CPU: + pytest.skip("This test is only supported on CPU devices yet") - setup = nabla_setup() + setup = 
nabla_setup(allocator=exec_alloc_descriptor.allocator) - pp = gtx.as_field([Vertex], setup.input_field, allocator=allocator) - S_M = tuple(map(gtx.as_field.partial([Edge], allocator=allocator), setup.S_fields)) - - zavgS = gtx.zeros({Edge: setup.edges_size}, allocator=allocator) - - e2v = gtx.NeighborTableOffsetProvider( - atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False - ) + zavgS = gtx.zeros({Edge: setup.edges_size}, allocator=exec_alloc_descriptor.allocator) - compute_zavgS.with_backend(exec_alloc_descriptor)( - pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v} + compute_zavgS.with_backend( + None if exec_alloc_descriptor.executor is None else exec_alloc_descriptor + )( + setup.input_field, + setup.S_fields[0], + out=zavgS, + offset_provider={"E2V": setup.edges2node_connectivity}, ) assert_close(-199755464.25741270, np.min(zavgS.asnumpy())) assert_close(388241977.58389181, np.max(zavgS.asnumpy())) +@pytest.mark.requires_atlas def test_ffront_nabla(exec_alloc_descriptor): - executor, allocator = exec_alloc_descriptor.executor, exec_alloc_descriptor.allocator - - setup = nabla_setup() - - sign = gtx.as_field([Vertex, V2EDim], setup.sign_field, allocator=allocator) - pp = gtx.as_field([Vertex], setup.input_field, allocator=allocator) - S_M = tuple(map(gtx.as_field.partial([Edge], allocator=allocator), setup.S_fields)) - vol = gtx.as_field([Vertex], setup.vol_field, allocator=allocator) - - pnabla_MXX = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator) - pnabla_MYY = gtx.zeros({Vertex: setup.nodes_size}, allocator=allocator) - - e2v = gtx.NeighborTableOffsetProvider( - atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False - ) - v2e = gtx.NeighborTableOffsetProvider( - atlas_utils.AtlasTable(setup.nodes2edge_connectivity).asnumpy(), Vertex, Edge, 7 - ) - - pnabla.with_backend(exec_alloc_descriptor)( - pp, S_M, sign, vol, out=(pnabla_MXX, pnabla_MYY), offset_provider={"E2V": e2v, 
"V2E": v2e} + # TODO(havogt): fix nabla setup to work with GPU + if exec_alloc_descriptor.allocator.device_type != core_defs.DeviceType.CPU: + pytest.skip("This test is only supported on CPU devices yet") + + setup = nabla_setup(allocator=exec_alloc_descriptor.allocator) + + pnabla_MXX = gtx.zeros({Vertex: setup.nodes_size}, allocator=exec_alloc_descriptor.allocator) + pnabla_MYY = gtx.zeros({Vertex: setup.nodes_size}, allocator=exec_alloc_descriptor.allocator) + + pnabla.with_backend(None if exec_alloc_descriptor.executor is None else exec_alloc_descriptor)( + setup.input_field, + setup.S_fields, + setup.sign_field, + setup.vol_field, + out=(pnabla_MXX, pnabla_MYY), + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) # TODO this check is not sensitive enough, need to implement a proper numpy reference! diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py index 505879a506..19664f2dab 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py @@ -227,14 +227,6 @@ def test_solve_nonhydro_stencil_52_like_z_q(test_setup): @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): - if ( - test_setup.case.backend - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() - ): - pytest.xfail( - "Needs implementation of scan projector. Breaks in type inference as executed" - "again after CollapseTuple." 
- ) if test_setup.case.backend == test_definitions.ProgramBackendId.ROUNDTRIP.load(): pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") @@ -254,12 +246,6 @@ def test_solve_nonhydro_stencil_52_like_z_q_tup(test_setup): @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like(test_setup): - if ( - test_setup.case.backend - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() - ): - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") - cases.run( test_setup.case, solve_nonhydro_stencil_52_like, @@ -276,11 +262,6 @@ def test_solve_nonhydro_stencil_52_like(test_setup): @pytest.mark.uses_tuple_returns def test_solve_nonhydro_stencil_52_like_with_gtfn_tuple_merge(test_setup): - if ( - test_setup.case.backend - == test_definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES.load() - ): - pytest.xfail("Temporary extraction does not work correctly in combination with scans.") if test_setup.case.backend == test_definitions.ProgramBackendId.ROUNDTRIP.load(): pytest.xfail("Needs proper handling of tuple[Column] <-> Column[tuple].") diff --git a/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py b/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py index 8d7324f438..6a5865134d 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/fvm_nabla_setup.py @@ -20,6 +20,18 @@ functionspace, ) +from gt4py import next as gtx +from gt4py.next.iterator import atlas_utils + + +Vertex = gtx.Dimension("Vertex") +Edge = gtx.Dimension("Edge") +V2EDim = gtx.Dimension("V2E", kind=gtx.DimensionKind.LOCAL) +E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL) + +V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) +E2V = gtx.FieldOffset("E2V", source=Vertex, target=(Edge, E2VDim)) + def assert_close(expected, actual): assert math.isclose(expected, 
actual), "expected={}, actual={}".format(expected, actual) @@ -33,9 +45,10 @@ def _default_config(): config["angle"] = 20.0 return config - def __init__(self, *, grid=StructuredGrid("O32"), config=None): + def __init__(self, *, allocator, grid=StructuredGrid("O32"), config=None): if config is None: config = self._default_config() + self.allocator = allocator mesh = StructuredMeshGenerator(config).generate(grid) fs_edges = functionspace.EdgeColumns(mesh, halo=1) @@ -55,12 +68,22 @@ def __init__(self, *, grid=StructuredGrid("O32"), config=None): self.edges_per_node = edges_per_node @property - def edges2node_connectivity(self): - return self.mesh.edges.node_connectivity + def edges2node_connectivity(self) -> gtx.Connectivity: + return gtx.as_connectivity( + domain={Edge: self.edges_size, E2VDim: 2}, + codomain=Vertex, + data=atlas_utils.AtlasTable(self.mesh.edges.node_connectivity).asnumpy(), + allocator=self.allocator, + ) @property - def nodes2edge_connectivity(self): - return self.mesh.nodes.edge_connectivity + def nodes2edge_connectivity(self) -> gtx.Connectivity: + return gtx.as_connectivity( + domain={Vertex: self.nodes_size, V2EDim: self.edges_per_node}, + codomain=Edge, + data=atlas_utils.AtlasTable(self.mesh.nodes.edge_connectivity).asnumpy(), + allocator=self.allocator, + ) @property def nodes_size(self): @@ -75,16 +98,16 @@ def _is_pole_edge(e, edge_flags): return Topology.check(edge_flags[e], Topology.POLE) @property - def is_pole_edge_field(self): + def is_pole_edge_field(self) -> gtx.Field: edge_flags = np.array(self.mesh.edges.flags()) pole_edge_field = np.zeros((self.edges_size,), dtype=bool) for e in range(self.edges_size): pole_edge_field[e] = self._is_pole_edge(e, edge_flags) - return pole_edge_field + return gtx.as_field([Edge], pole_edge_field, allocator=self.allocator) @property - def sign_field(self): + def sign_field(self) -> gtx.Field: node2edge_sign = np.zeros((self.nodes_size, self.edges_per_node)) edge_flags = 
np.array(self.mesh.edges.flags()) @@ -100,10 +123,10 @@ def sign_field(self): node2edge_sign[jnode, jedge] = -1.0 if self._is_pole_edge(iedge, edge_flags): node2edge_sign[jnode, jedge] = 1.0 - return node2edge_sign + return gtx.as_field([Vertex, V2EDim], node2edge_sign, allocator=self.allocator) @property - def S_fields(self): + def S_fields(self) -> tuple[gtx.Field, gtx.Field]: S = np.array(self.mesh.edges.field("dual_normals"), copy=False) S_MXX = np.zeros((self.edges_size)) S_MYY = np.zeros((self.edges_size)) @@ -124,10 +147,12 @@ def S_fields(self): assert math.isclose(min(S_MYY), -2001577.7946404363) assert math.isclose(max(S_MYY), 2001577.7946404363) - return S_MXX, S_MYY + return gtx.as_field([Edge], S_MXX, allocator=self.allocator), gtx.as_field( + [Edge], S_MYY, allocator=self.allocator + ) @property - def vol_field(self): + def vol_field(self) -> gtx.Field: rpi = 2.0 * math.asin(1.0) radius = 6371.22e03 deg2rad = 2.0 * rpi / 360.0 @@ -142,10 +167,10 @@ def vol_field(self): # VOL(min/max): 57510668192.214096 851856184496.32886 assert_close(57510668192.214096, min(vol)) assert_close(851856184496.32886, max(vol)) - return vol + return gtx.as_field([Vertex], vol, allocator=self.allocator) @property - def input_field(self): + def input_field(self) -> gtx.Field: klevel = 0 MXX = 0 MYY = 1 @@ -200,4 +225,5 @@ def input_field(self): assert_close(0.0000000000000000, min(rzs)) assert_close(1965.4980340735883, max(rzs)) - return rzs[:, klevel] + + return gtx.as_field([Vertex], rzs[:, klevel], allocator=self.allocator) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 14271efb27..d0a1601816 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -10,8 +10,15 @@ import pytest import 
gt4py.next as gtx -from gt4py.next.iterator.builtins import cartesian_domain, deref, lift, named_range, shift -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.builtins import ( + cartesian_domain, + deref, + lift, + named_range, + shift, + as_fieldop, +) +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn from next_tests.unit_tests.conftest import program_processor, run_processor @@ -78,11 +85,6 @@ def naive_lap(inp): def test_anton_toy(stencil, program_processor): program_processor, validate = program_processor - if program_processor in [ - gtfn.run_gtfn_with_temporaries.executor, - ]: - pytest.xfail("TODO: issue with temporaries that crashes the application") - if stencil is lap: pytest.xfail( "Type inference does not support calling lambdas with offset arguments of changing type." @@ -90,14 +92,10 @@ def test_anton_toy(stencil, program_processor): @fendef(offset_provider={"i": IDim, "j": JDim}) def fencil(x, y, z, out, inp): - closure( - cartesian_domain( - named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z) - ), - stencil, - out, - [inp], + domain = cartesian_domain( + named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z) ) + set_at(as_fieldop(stencil, domain)(inp), domain, out) shape = [5, 7, 9] rng = np.random.default_rng() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 2b858f3025..3b4fc0a70c 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -12,7 +12,7 @@ import gt4py.next as gtx from gt4py.next import field_utils from gt4py.next.iterator.builtins import * -from 
gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.integration_tests.cases import IDim, KDim from next_tests.unit_tests.conftest import program_processor, run_processor @@ -122,19 +122,19 @@ def k_level_condition_upper_tuple(k_idx, k_level): @pytest.mark.parametrize( "fun, k_level, inp_function, ref_function", [ - ( + pytest.param( k_level_condition_lower, lambda inp: 0, lambda k_size: gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)), lambda inp: np.concatenate([[0], inp[:-1]]), ), - ( + pytest.param( k_level_condition_upper, lambda inp: inp.shape[0] - 1, lambda k_size: gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)), lambda inp: np.concatenate([inp[1:], [0]]), ), - ( + pytest.param( k_level_condition_upper_tuple, lambda inp: inp[0].shape[0] - 1, lambda k_size: ( @@ -142,6 +142,7 @@ def k_level_condition_upper_tuple(k_idx, k_level): gtx.as_field([KDim], np.arange(k_size, dtype=np.int32)), ), lambda inp: np.concatenate([(inp[0][1:] + inp[1][1:]), [0]]), + marks=pytest.mark.uses_tuple_iterator, ), ], ) @@ -170,29 +171,21 @@ def test_k_level_condition(program_processor, fun, k_level, inp_function, ref_fu @fundef -def sum_scanpass(state, inp): +def ksum(state, inp): return state + deref(inp) -@fundef -def ksum(inp): - return scan(sum_scanpass, True, 0.0)(inp) - - @fendef(column_axis=KDim) def ksum_fencil(i_size, k_start, k_end, inp, out): - closure( - cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, k_start, k_end)), - ksum, - out, - [inp], - ) + domain = cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, k_start, k_end)) + set_at(as_fieldop(scan(ksum, True, 0.0), domain)(inp), domain, out) @pytest.mark.parametrize( "kstart, reference", [(0, np.asarray([[0, 1, 3, 6, 10, 15, 21]])), (2, np.asarray([[0, 0, 2, 5, 9, 14, 20]]))], ) +@pytest.mark.uses_scan def test_ksum_scan(program_processor, kstart, reference): 
program_processor, validate = program_processor shape = [1, 7] @@ -214,21 +207,13 @@ def test_ksum_scan(program_processor, kstart, reference): assert np.allclose(reference, out.asnumpy()) -@fundef -def ksum_back(inp): - return scan(sum_scanpass, False, 0.0)(inp) - - @fendef(column_axis=KDim) def ksum_back_fencil(i_size, k_size, inp, out): - closure( - cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, 0, k_size)), - ksum_back, - out, - [inp], - ) + domain = cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, 0, k_size)) + set_at(as_fieldop(scan(ksum, False, 0.0), domain)(inp), domain, out) +@pytest.mark.uses_scan def test_ksum_back_scan(program_processor): program_processor, validate = program_processor shape = [1, 7] @@ -252,23 +237,14 @@ def test_ksum_back_scan(program_processor): @fundef -def doublesum_scanpass(state, inp0, inp1): +def kdoublesum(state, inp0, inp1): return make_tuple(tuple_get(0, state) + deref(inp0), tuple_get(1, state) + deref(inp1)) -@fundef -def kdoublesum(inp0, inp1): - return scan(doublesum_scanpass, True, make_tuple(0.0, 0))(inp0, inp1) - - @fendef(column_axis=KDim) def kdoublesum_fencil(i_size, k_start, k_end, inp0, inp1, out): - closure( - cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, k_start, k_end)), - kdoublesum, - out, - [inp0, inp1], - ) + domain = cartesian_domain(named_range(IDim, 0, i_size), named_range(KDim, k_start, k_end)) + set_at(as_fieldop(scan(kdoublesum, True, make_tuple(0.0, 0)), domain)(inp0, inp1), domain, out) @pytest.mark.parametrize( @@ -325,7 +301,8 @@ def sum_shifted(inp0, inp1): @fendef(column_axis=KDim) def sum_shifted_fencil(out, inp0, inp1, k_size): - closure(cartesian_domain(named_range(KDim, 1, k_size)), sum_shifted, out, [inp0, inp1]) + domain = cartesian_domain(named_range(KDim, 1, k_size)) + set_at(as_fieldop(sum_shifted, domain)(inp0, inp1), domain, out) def test_different_vertical_sizes(program_processor): @@ -352,7 +329,8 @@ def sum(inp0, inp1): 
@fendef(column_axis=KDim) def sum_fencil(out, inp0, inp1, k_size): - closure(cartesian_domain(named_range(KDim, 0, k_size)), sum, out, [inp0, inp1]) + domain = cartesian_domain(named_range(KDim, 0, k_size)) + set_at(as_fieldop(sum, domain)(inp0, inp1), domain, out) @pytest.mark.uses_origin diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 156bc1c37f..22b4d8b3c5 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -28,9 +28,9 @@ reduce, tuple_get, unstructured_domain, + as_fieldop, ) -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset -from gt4py.next.iterator.transforms.pass_manager import LiftMode +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( assert_close, @@ -56,7 +56,8 @@ def compute_zavgS(pp, S_M): @fendef def compute_zavgS_fencil(n_edges, out, pp, S_M): - closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS, out, [pp, S_M]) + domain = unstructured_domain(named_range(Edge, 0, n_edges)) + set_at(as_fieldop(compute_zavgS, domain)(pp, S_M), domain, out) @fundef @@ -101,36 +102,25 @@ def compute_pnabla2(pp, S_M, sign, vol): @fendef def nabla(n_nodes, out, pp, S_MXX, S_MYY, sign, vol): - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - pnabla, - out, - [pp, S_MXX, S_MYY, sign, vol], - ) + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at(as_fieldop(pnabla, domain)(pp, S_MXX, S_MYY, sign, vol), domain, out) @pytest.mark.requires_atlas def test_compute_zavgS(program_processor): program_processor, validate = program_processor - setup = nabla_setup() - - pp = gtx.as_field([Vertex], setup.input_field) 
- S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) + setup = nabla_setup(allocator=None) zavgS = gtx.as_field([Edge], np.zeros((setup.edges_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - run_processor( compute_zavgS_fencil, program_processor, setup.edges_size, zavgS, - pp, - S_MXX, - offset_provider={"E2V": e2v}, + setup.input_field, + setup.S_fields[0], + offset_provider={"E2V": setup.edges2node_connectivity}, ) if validate: @@ -142,9 +132,9 @@ def test_compute_zavgS(program_processor): program_processor, setup.edges_size, zavgS, - pp, - S_MYY, - offset_provider={"E2V": e2v}, + setup.input_field, + setup.S_fields[1], + offset_provider={"E2V": setup.edges2node_connectivity}, ) if validate: assert_close(-1000788897.3202186, np.min(zavgS.asnumpy())) @@ -153,35 +143,28 @@ def test_compute_zavgS(program_processor): @fendef def compute_zavgS2_fencil(n_edges, out, pp, S_M): - closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS2, out, [pp, S_M]) + domain = unstructured_domain(named_range(Edge, 0, n_edges)) + set_at(as_fieldop(compute_zavgS2, domain)(pp, S_M), domain, out) @pytest.mark.requires_atlas def test_compute_zavgS2(program_processor): program_processor, validate = program_processor - setup = nabla_setup() - - pp = gtx.as_field([Vertex], setup.input_field) - - S = tuple(gtx.as_field([Edge], s) for s in setup.S_fields) + setup = nabla_setup(allocator=None) zavgS = ( gtx.as_field([Edge], np.zeros((setup.edges_size))), gtx.as_field([Edge], np.zeros((setup.edges_size))), ) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - run_processor( compute_zavgS2_fencil, program_processor, setup.edges_size, zavgS, - pp, - S, - offset_provider={"E2V": e2v}, + setup.input_field, + setup.S_fields, + offset_provider={"E2V": setup.edges2node_connectivity}, ) if validate: @@ -196,34 +179,27 @@ def 
test_compute_zavgS2(program_processor): def test_nabla(program_processor): program_processor, validate = program_processor - setup = nabla_setup() + setup = nabla_setup(allocator=None) - sign = gtx.as_field([Vertex, V2EDim], setup.sign_field) - pp = gtx.as_field([Vertex], setup.input_field) - S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) - vol = gtx.as_field([Vertex], setup.vol_field) + S_MXX, S_MYY = setup.S_fields pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - v2e = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 - ) - run_processor( nabla, program_processor, setup.nodes_size, (pnabla_MXX, pnabla_MYY), - pp, + setup.input_field, S_MXX, S_MYY, - sign, - vol, - offset_provider={"E2V": e2v, "V2E": v2e}, + setup.sign_field, + setup.vol_field, + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) if validate: @@ -235,44 +211,31 @@ def test_nabla(program_processor): @fendef def nabla2(n_nodes, out, pp, S, sign, vol): - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla2, - out, - [pp, S, sign, vol], - ) + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at(as_fieldop(compute_pnabla2, domain)(pp, S, sign, vol), domain, out) @pytest.mark.requires_atlas def test_nabla2(program_processor): program_processor, validate = program_processor - setup = nabla_setup() - - sign = gtx.as_field([Vertex, V2EDim], setup.sign_field) - pp = gtx.as_field([Vertex], setup.input_field) - S_M = tuple(gtx.as_field([Edge], s) for s in setup.S_fields) - vol = gtx.as_field([Vertex], setup.vol_field) + setup = nabla_setup(allocator=None) pnabla_MXX = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) pnabla_MYY = 
gtx.as_field([Vertex], np.zeros((setup.nodes_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - v2e = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 - ) - run_processor( nabla2, program_processor, setup.nodes_size, (pnabla_MXX, pnabla_MYY), - pp, - S_M, - sign, - vol, - offset_provider={"E2V": e2v, "V2E": v2e}, + setup.input_field, + setup.S_fields, + setup.sign_field, + setup.vol_field, + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) if validate: @@ -308,17 +271,16 @@ def compute_pnabla_sign(pp, S_M, vol, node_index, is_pole_edge): @fendef def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_pole_edge): # TODO replace by single stencil which returns tuple - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla_sign, + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at( + as_fieldop(compute_pnabla_sign, domain)(pp, S_MXX, vol, node_index, is_pole_edge), + domain, out_MXX, - [pp, S_MXX, vol, node_index, is_pole_edge], ) - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla_sign, + set_at( + as_fieldop(compute_pnabla_sign, domain)(pp, S_MYY, vol, node_index, is_pole_edge), + domain, out_MYY, - [pp, S_MYY, vol, node_index, is_pole_edge], ) @@ -326,36 +288,29 @@ def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_ def test_nabla_sign(program_processor): program_processor, validate = program_processor - setup = nabla_setup() + setup = nabla_setup(allocator=None) - is_pole_edge = gtx.as_field([Edge], setup.is_pole_edge_field) - pp = gtx.as_field([Vertex], setup.input_field) - S_MXX, S_MYY = tuple(map(gtx.as_field.partial([Edge]), setup.S_fields)) - vol = gtx.as_field([Vertex], setup.vol_field) + S_MXX, S_MYY = setup.S_fields pnabla_MXX = gtx.as_field([Vertex], 
np.zeros((setup.nodes_size))) pnabla_MYY = gtx.as_field([Vertex], np.zeros((setup.nodes_size))) - e2v = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.edges2node_connectivity), Edge, Vertex, 2 - ) - v2e = gtx.NeighborTableOffsetProvider( - AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 - ) - run_processor( nabla_sign, program_processor, setup.nodes_size, pnabla_MXX, pnabla_MYY, - pp, + setup.input_field, S_MXX, S_MYY, - vol, + setup.vol_field, gtx.index_field(Vertex), - is_pole_edge, - offset_provider={"E2V": e2v, "V2E": v2e}, + setup.is_pole_edge_field, + offset_provider={ + "E2V": setup.edges2node_connectivity, + "V2E": setup.nodes2edge_connectivity, + }, ) if validate: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 45793b1d3e..1726956332 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests.cases import IDim, JDim @@ -57,14 +57,11 @@ def hdiff_sten(inp, coeff): @fendef(offset_provider={"I": IDim, "J": JDim}) def hdiff(inp, coeff, out, x, y): - closure( - cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y)), - hdiff_sten, - out, - [inp, coeff], - ) + domain = cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y)) + set_at(as_fieldop(hdiff_sten, domain)(inp, coeff), domain, out) +@pytest.mark.uses_lift @pytest.mark.uses_origin def test_hdiff(hdiff_reference, program_processor): program_processor, validate = program_processor diff --git 
a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py index 2dde7d7653..3c2ac6e7d7 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_if_stmt.py @@ -6,29 +6,16 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . 
-# -# SPDX-License-Identifier: GPL-3.0-or-later import numpy as np import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import cartesian_domain, deref, as_fieldop, named_range -from gt4py.next.iterator.runtime import set_at, if_stmt, fendef, fundef, offset -from gt4py.next.program_processors.runners import gtfn +from gt4py.next.iterator.builtins import as_fieldop, cartesian_domain, deref, named_range +from gt4py.next.iterator.runtime import fendef, fundef, if_stmt, offset, set_at + +from next_tests.unit_tests.conftest import program_processor_no_transforms, run_processor -from next_tests.unit_tests.conftest import program_processor, run_processor i = offset("i") @@ -43,8 +30,8 @@ def multiply(alpha, inp): @pytest.mark.uses_ir_if_stmts @pytest.mark.parametrize("cond", [True, False]) -def test_if_stmt(program_processor, cond): - program_processor, validate = program_processor +def test_if_stmt(program_processor_no_transforms, cond): + program_processor, validate = program_processor_no_transforms size = 10 @fendef(offset_provider={"i": IDim}) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index a89f250571..e98e820f14 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -10,9 +10,9 @@ import pytest import gt4py.next as gtx +from gt4py.next import backend from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef -from gt4py.next.iterator.transforms import LiftMode +from gt4py.next.iterator.runtime import set_at, fendef, fundef from gt4py.next.program_processors.formatters import gtfn as gtfn_formatters from gt4py.next.program_processors.runners import gtfn @@ -42,22 +42,17 @@ def 
tridiag_backward2(x_kp1, cp, dp): @fundef -def solve_tridiag(a, b, c, d): - cpdp = lift(scan(tridiag_forward, True, make_tuple(0.0, 0.0)))(a, b, c, d) - return scan(tridiag_backward, False, 0.0)(cpdp) - - -def tuple_get_it(i, x): - def stencil(x): - return tuple_get(i, deref(x)) - - return lift(stencil)(x) +def solve_tridiag(domain, a, b, c, d): + cpdp = as_fieldop(scan(tridiag_forward, True, make_tuple(0.0, 0.0)), domain)(a, b, c, d) + return as_fieldop(scan(tridiag_backward, False, 0.0), domain)(cpdp) @fundef -def solve_tridiag2(a, b, c, d): - cpdp = lift(scan(tridiag_forward, True, make_tuple(0.0, 0.0)))(a, b, c, d) - return scan(tridiag_backward2, False, 0.0)(tuple_get_it(0, cpdp), tuple_get_it(1, cpdp)) +def solve_tridiag2(domain, a, b, c, d): + cpdp = as_fieldop(scan(tridiag_forward, True, make_tuple(0.0, 0.0)), domain)(a, b, c, d) + return as_fieldop(scan(tridiag_backward2, False, 0.0), domain)( + tuple_get(0, cpdp), tuple_get(1, cpdp) + ) @pytest.fixture @@ -67,7 +62,13 @@ def tridiag_reference(): a = rng.normal(size=shape) b = rng.normal(size=shape) * 2 c = rng.normal(size=shape) - d = rng.normal(size=shape) + # Changed in NumPY version 2.0: In a linear matrix equation ax = b, the b array + # is only treated as a shape (M,) column vector if it is exactly 1-dimensional. + # In all other instances it is treated as a stack of (M, K) matrices. Therefore + # below we add an extra dimension (K) of size 1. Previously b would be treated + # as a stack of (M,) vectors if b.ndim was equal to a.ndim - 1. 
+ # Refer to https://numpy.org/doc/2.0/reference/generated/numpy.linalg.solve.html + d = rng.normal(size=(*shape, 1)) matrices = np.zeros(shape + shape[-1:]) i = np.arange(shape[2]) @@ -75,45 +76,32 @@ def tridiag_reference(): matrices[:, :, i, i] = b matrices[:, :, i[:-1], i[1:]] = c[:, :, :-1] x = np.linalg.solve(matrices, d) - return a, b, c, d, x + return a, b, c, d[:, :, :, 0], x[:, :, :, 0] @fendef def fen_solve_tridiag(i_size, j_size, k_size, a, b, c, d, x): - closure( - cartesian_domain( - named_range(IDim, 0, i_size), named_range(JDim, 0, j_size), named_range(KDim, 0, k_size) - ), - solve_tridiag, - x, - [a, b, c, d], + domain = cartesian_domain( + named_range(IDim, 0, i_size), named_range(JDim, 0, j_size), named_range(KDim, 0, k_size) ) + set_at(solve_tridiag(domain, a, b, c, d), domain, x) @fendef def fen_solve_tridiag2(i_size, j_size, k_size, a, b, c, d, x): - closure( - cartesian_domain( - named_range(IDim, 0, i_size), named_range(JDim, 0, j_size), named_range(KDim, 0, k_size) - ), - solve_tridiag2, - x, - [a, b, c, d], + domain = cartesian_domain( + named_range(IDim, 0, i_size), named_range(JDim, 0, j_size), named_range(KDim, 0, k_size) ) + set_at(solve_tridiag2(domain, a, b, c, d), domain, x) @pytest.mark.parametrize("fencil", [fen_solve_tridiag, fen_solve_tridiag2]) -@pytest.mark.uses_lift_expressions def test_tridiag(fencil, tridiag_reference, program_processor): program_processor, validate = program_processor - if program_processor in [ - gtfn.run_gtfn, - gtfn.run_gtfn_imperative, - gtfn_formatters.format_cpp, - ]: - pytest.skip("gtfn does only support lifted scans when using temporaries") - if program_processor == gtfn.run_gtfn_with_temporaries: - pytest.xfail("tuple_get on columns not supported.") + + if isinstance(program_processor, backend.Backend) and "dace" in program_processor.name: + pytest.xfail("Dace ITIR backend doesn't support the IR format used in this test.") + a, b, c, d, x = tridiag_reference shape = a.shape as_3d_field = 
gtx.as_field.partial([IDim, JDim, KDim]) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 6fb1d4c152..ff87de7348 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -38,9 +38,13 @@ V2VDim, Vertex, c2e_arr, + c2e_conn, e2v_arr, + e2v_conn, v2e_arr, + v2e_conn, v2v_arr, + v2v_conn, ) from next_tests.unit_tests.conftest import program_processor, run_processor @@ -89,7 +93,7 @@ def test_sum_edges_to_vertices(program_processor, stencil): program_processor, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -111,7 +115,7 @@ def test_map_neighbors(program_processor): program_processor, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -134,7 +138,7 @@ def test_map_make_const_list(program_processor): program_processor, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -145,6 +149,7 @@ def first_vertex_neigh_of_first_edge_neigh_of_cells(in_vertices): return deref(shift(E2V, 0)(shift(C2E, 0)(in_vertices))) +@pytest.mark.uses_composite_shifts def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processor): program_processor, validate = program_processor inp = vertex_index_field() @@ -157,8 +162,8 @@ def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processo inp, out=out, 
offset_provider={ - "E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2), - "C2E": gtx.NeighborTableOffsetProvider(c2e_arr, Cell, Edge, 4), + "E2V": e2v_conn, + "C2E": c2e_conn, }, ) if validate: @@ -170,6 +175,7 @@ def sparse_stencil(non_sparse, inp): return reduce(lambda a, b, c: a + c, 0)(neighbors(V2E, non_sparse), deref(inp)) +@pytest.mark.uses_reduce_with_lambda def test_sparse_input_field(program_processor): program_processor, validate = program_processor @@ -185,13 +191,14 @@ def test_sparse_input_field(program_processor): non_sparse, inp, out=out, - offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, + offset_provider={"V2E": v2e_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) +@pytest.mark.uses_reduce_with_lambda def test_sparse_input_field_v2v(program_processor): program_processor, validate = program_processor @@ -208,8 +215,8 @@ def test_sparse_input_field_v2v(program_processor): inp, out=out, offset_provider={ - "V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4), - "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), + "V2V": v2v_conn, + "V2E": v2e_conn, }, ) @@ -235,7 +242,7 @@ def test_slice_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -259,7 +266,7 @@ def test_slice_twice_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -284,7 +291,7 @@ def test_shift_sliced_sparse(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -309,7 +316,7 @@ def test_slice_shifted_sparse(program_processor): program_processor, inp, out=out, - 
offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -326,6 +333,7 @@ def lift_stencil(inp): return deref(shift(V2V, 2)(lift(deref_stencil)(inp))) +@pytest.mark.uses_lift def test_lift(program_processor): program_processor, validate = program_processor inp = vertex_index_field() @@ -337,7 +345,7 @@ def test_lift(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -360,7 +368,7 @@ def test_shift_sparse_input_field(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: @@ -383,7 +391,6 @@ def test_shift_sparse_input_field2(program_processor): if program_processor in [ gtfn.run_gtfn, gtfn.run_gtfn_imperative, - gtfn.run_gtfn_with_temporaries, ]: pytest.xfail( "Bug in bindings/compilation/caching: only the first program seems to be compiled." 
@@ -394,8 +401,8 @@ def test_shift_sparse_input_field2(program_processor): out2 = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) offset_provider = { - "E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2), - "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), + "E2V": e2v_conn, + "V2E": v2e_conn, } domain = {Vertex: range(0, 9)} @@ -449,7 +456,7 @@ def test_sparse_shifted_stencil_reduce(program_processor): program_processor, inp, out=out, - offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, + offset_provider={"V2V": v2v_conn}, ) if validate: diff --git a/tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py b/tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py new file mode 100644 index 0000000000..b69950928d --- /dev/null +++ b/tests/next_tests/regression_tests/embedded_tests/test_domain_pickle.py @@ -0,0 +1,22 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pickle + +from gt4py.next import common + +I = common.Dimension("I") +J = common.Dimension("J") + + +def test_domain_pickle_after_slice(): + domain = common.domain(((I, (2, 4)), (J, (3, 5)))) + # use slice_at to populate cached property + domain.slice_at[2:5, 5:7] + + pickle.dumps(domain) diff --git a/tests/next_tests/regression_tests/ffront_tests/test_offset_dimensions_names.py b/tests/next_tests/regression_tests/ffront_tests/test_offset_dimensions_names.py new file mode 100644 index 0000000000..f95ed4c3a7 --- /dev/null +++ b/tests/next_tests/regression_tests/ffront_tests/test_offset_dimensions_names.py @@ -0,0 +1,63 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +from gt4py import next as gtx +from gt4py.next import Dims, Field, common + +from next_tests import definitions as test_defs +from next_tests.integration_tests import cases +from next_tests.integration_tests.feature_tests.ffront_tests import ffront_test_utils + + +V = gtx.Dimension("V") +E = gtx.Dimension("E") +Neigh = gtx.Dimension("Neigh", kind=common.DimensionKind.LOCAL) +Off = gtx.FieldOffset("Off", source=E, target=(V, Neigh)) + + +@pytest.fixture +def case(): + mesh = ffront_test_utils.simple_mesh() + exec_alloc_descriptor = test_defs.ProgramBackendId.GTFN_CPU.load() + v2e_arr = mesh.offset_provider["V2E"].ndarray + return cases.Case( + exec_alloc_descriptor, + offset_provider={ + "Off": common._connectivity( + v2e_arr, + codomain=E, + domain={V: v2e_arr.shape[0], Neigh: 4}, + skip_value=None, + ), + }, + default_sizes={ + V: mesh.num_vertices, + E: mesh.num_edges, + }, + grid_type=common.GridType.UNSTRUCTURED, + allocator=exec_alloc_descriptor.allocator, + ) + + +def test_offset_dimension_name_differ(case): + """ + Ensure that gtfn works with offset name that differs from the name of the local dimension. + + If the value of the `NeighborConnectivityType.neighbor_dim` did not match the `FieldOffset` value, + gtfn would silently ignore the neighbor index, see https://github.com/GridTools/gridtools/pull/1814. 
+ """ + + @gtx.field_operator + def foo(a: Field[Dims[E], float]) -> Field[Dims[V], float]: + return a(Off[1]) + + cases.verify_with_default_data( + case, foo, lambda a: a[case.offset_provider["Off"].ndarray[:, 1]] + ) diff --git a/tests/next_tests/toy_connectivity.py b/tests/next_tests/toy_connectivity.py index 82c91a5e74..154b666c5d 100644 --- a/tests/next_tests/toy_connectivity.py +++ b/tests/next_tests/toy_connectivity.py @@ -9,7 +9,7 @@ import numpy as np import gt4py.next as gtx -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir Vertex = gtx.Dimension("Vertex") @@ -46,9 +46,11 @@ [7, 17, 1, 16], [8, 15, 2, 17], ], - dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), + dtype=np.dtype(builtins.INTEGER_INDEX_BUILTIN), ) +c2e_conn = gtx.as_connectivity(domain={Cell: 9, C2EDim: 4}, codomain=Edge, data=c2e_arr) + v2v_arr = np.array( [ [1, 3, 2, 6], @@ -61,9 +63,11 @@ [8, 1, 6, 4], [6, 2, 7, 5], ], - dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), + dtype=np.dtype(builtins.INTEGER_INDEX_BUILTIN), ) +v2v_conn = gtx.as_connectivity(domain={Vertex: 9, V2VDim: 4}, codomain=Vertex, data=v2v_arr) + e2v_arr = np.array( [ [0, 1], @@ -85,9 +89,10 @@ [7, 1], [8, 2], ], - dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), + dtype=np.dtype(builtins.INTEGER_INDEX_BUILTIN), ) +e2v_conn = gtx.as_connectivity(domain={Edge: 18, E2VDim: 2}, codomain=Vertex, data=e2v_arr) # order east, north, west, south (counter-clock wise) v2e_arr = np.array( @@ -102,5 +107,7 @@ [7, 13, 6, 16], [8, 14, 7, 17], ], - dtype=np.dtype(itir.INTEGER_INDEX_BUILTIN), + dtype=np.dtype(builtins.INTEGER_INDEX_BUILTIN), ) + +v2e_conn = gtx.as_connectivity(domain={Vertex: 9, V2EDim: 4}, codomain=Edge, data=v2e_arr) diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index 8a4aa50730..0bd8653a03 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -14,43 +14,18 @@ import pytest import gt4py.next 
as gtx -from gt4py.next import backend +from gt4py.next import backend, common +from gt4py.next.embedded import nd_array_field from gt4py.next.iterator import runtime from gt4py.next.program_processors import program_formatter - import next_tests ProgramProcessor: TypeAlias = backend.Backend | program_formatter.ProgramFormatter -@pytest.fixture( - params=[ - (None, True), - (next_tests.definitions.ProgramBackendId.ROUNDTRIP, True), - (next_tests.definitions.ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES, True), - (next_tests.definitions.ProgramBackendId.DOUBLE_ROUNDTRIP, True), - (next_tests.definitions.ProgramBackendId.GTFN_CPU, True), - (next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), - (next_tests.definitions.ProgramBackendId.GTFN_CPU_WITH_TEMPORARIES, True), - # pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu), # TODO(havogt): update tests to use proper allocation - (next_tests.definitions.ProgramFormatterId.LISP_FORMATTER, False), - (next_tests.definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), - (next_tests.definitions.ProgramFormatterId.GTFN_CPP_FORMATTER, False), - pytest.param( - (next_tests.definitions.OptionalProgramBackendId.DACE_CPU, True), - marks=pytest.mark.requires_dace, - ), - # TODO(havogt): update tests to use proper allocation - # pytest.param( - # (next_tests.definitions.OptionalProgramBackendId.DACE_GPU, True), - # marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), - # ), - ], - ids=lambda p: p[0].short_id() if p[0] is not None else "None", -) -def program_processor(request) -> tuple[ProgramProcessor, bool]: +def _program_processor(request) -> tuple[ProgramProcessor, bool]: """ Fixture creating program processors on-demand for tests. 
@@ -72,6 +47,38 @@ def program_processor(request) -> tuple[ProgramProcessor, bool]: return processor, is_backend +program_processor = pytest.fixture( + _program_processor, + params=[ + (None, True), + (next_tests.definitions.ProgramBackendId.ROUNDTRIP, True), + (next_tests.definitions.ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES, True), + (next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, True), + (next_tests.definitions.ProgramBackendId.DOUBLE_ROUNDTRIP, True), + (next_tests.definitions.ProgramBackendId.GTFN_CPU, True), + (next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), + # pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu), # TODO(havogt): update tests to use proper allocation + (next_tests.definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), + (next_tests.definitions.ProgramFormatterId.GTFN_CPP_FORMATTER, False), + pytest.param( + (next_tests.definitions.OptionalProgramBackendId.DACE_CPU_NO_OPT, True), + marks=pytest.mark.requires_dace, + ), + ], + ids=lambda p: p[0].short_id() if p[0] is not None else "None", +) + +program_processor_no_transforms = pytest.fixture( + _program_processor, + params=[ + (None, True), + (next_tests.definitions.ProgramBackendId.GTFN_CPU_NO_TRANSFORMS, True), + (next_tests.definitions.ProgramBackendId.ROUNDTRIP_NO_TRANSFORMS, True), + ], + ids=lambda p: p[0].short_id() if p[0] is not None else "None", +) + + def run_processor( program: runtime.FendefDispatcher, processor: ProgramProcessor, @@ -85,12 +92,21 @@ def run_processor( @dataclasses.dataclass -class DummyConnectivity: +class DummyConnectivity(common.Connectivity): max_neighbors: int has_skip_values: int - origin_axis: gtx.Dimension = gtx.Dimension("dummy_origin") - neighbor_axis: gtx.Dimension = gtx.Dimension("dummy_neighbor") - index_type: type[int] = int + source_dim: gtx.Dimension = gtx.Dimension("dummy_origin") + codomain: gtx.Dimension = gtx.Dimension("dummy_neighbor") + + +def 
nd_array_implementation_params(): + for xp in nd_array_field._nd_array_implementations: + if hasattr(nd_array_field, "cp") and xp == nd_array_field.cp: + yield pytest.param(xp, id=xp.__name__, marks=pytest.mark.requires_gpu) + else: + yield pytest.param(xp, id=xp.__name__) + - def mapped_index(_, __) -> int: - return 0 +@pytest.fixture(params=nd_array_implementation_params()) +def nd_array_implementation(request): + yield request.param diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 9fba633cba..9bdc6ab5c1 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -15,7 +15,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.common import Dimension, Domain, UnitRange, NamedRange, NamedIndex +from gt4py.next.common import Dimension, Domain, NamedIndex, NamedRange, UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice from gt4py.next.ffront import fbuiltins @@ -28,19 +28,6 @@ D2 = Dimension("D2") -def nd_array_implementation_params(): - for xp in nd_array_field._nd_array_implementations: - if hasattr(nd_array_field, "cp") and xp == nd_array_field.cp: - yield pytest.param(xp, id=xp.__name__, marks=pytest.mark.requires_gpu) - else: - yield pytest.param(xp, id=xp.__name__) - - -@pytest.fixture(params=nd_array_implementation_params()) -def nd_array_implementation(request): - yield request.param - - @pytest.fixture( params=[ operator.add, @@ -264,6 +251,16 @@ def test_binary_operations_with_intersection(binary_arithmetic_op, dims, expecte assert np.allclose(op_result.ndarray, expected_result) +def test_as_scalar(nd_array_implementation): + testee = common._field( + nd_array_implementation.asarray(42.0, 
dtype=np.float32), domain=common.Domain() + ) + + result = testee.as_scalar() + assert result == 42.0 + assert isinstance(result, np.float32) + + def product_nd_array_implementation_params(): for xp1 in nd_array_field._nd_array_implementations: for xp2 in nd_array_field._nd_array_implementations: @@ -367,10 +364,11 @@ def test_reshuffling_premap(): ij_field = common._field( np.asarray([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]), - domain=common.Domain(dims=(I, J), ranges=(UnitRange(0, 3), UnitRange(0, 3))), + domain=common.Domain(dims=(I, J), ranges=(UnitRange(1, 4), UnitRange(2, 5))), ) + max_ij_conn = common._connectivity( - np.fromfunction(lambda i, j: np.maximum(i, j), (3, 3), dtype=int), + np.asarray([[1, 2, 3], [2, 2, 3], [3, 3, 3]], dtype=int), domain=common.Domain( dims=ij_field.domain.dims, ranges=ij_field.domain.ranges, @@ -381,7 +379,7 @@ def test_reshuffling_premap(): result = ij_field.premap(max_ij_conn) expected = common._field( np.asarray([[0.0, 4.0, 8.0], [3.0, 4.0, 8.0], [6.0, 7.0, 8.0]]), - domain=common.Domain(dims=(I, J), ranges=(UnitRange(0, 3), UnitRange(0, 3))), + domain=common.Domain(dims=(I, J), ranges=(UnitRange(1, 4), UnitRange(2, 5))), ) assert result.domain == expected.domain diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 3951c410dc..776cd4e1a9 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -91,7 +91,7 @@ def foo(bar: int64, alpha: int64) -> int64: parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.call("multiplies")("alpha", "bar") + reference = im.multiplies_("alpha", "bar") assert lowered.expr == reference @@ -283,10 +283,21 @@ def foo(a: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) + 
lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) + + reference = im.cast_as_fieldop("int32")("a") + + assert lowered_inlined.expr == reference - reference = im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - "a" - ) + +def test_astype_local_field(): + def foo(a: gtx.Field[gtx.Dims[Vertex, V2EDim], float64]): + return astype(a, int32) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.op_as_fieldop(im.map_(im.lambda_("val")(im.cast_("val", "int32"))))("a") assert lowered.expr == reference @@ -297,10 +308,11 @@ def foo(a: float64): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) + lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) - reference = im.call("cast_")("a", "int32") + reference = im.cast_("a", "int32") - assert lowered.expr == reference + assert lowered_inlined.expr == reference def test_astype_tuple(): @@ -312,12 +324,8 @@ def foo(a: tuple[gtx.Field[[TDim], float64], gtx.Field[[TDim], float64]]): lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) reference = im.make_tuple( - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(0, "a") - ), - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(1, "a") - ), + im.cast_as_fieldop("int32")(im.tuple_get(0, "a")), + im.cast_as_fieldop("int32")(im.tuple_get(1, "a")), ) assert lowered_inlined.expr == reference @@ -332,10 +340,8 @@ def foo(a: tuple[gtx.Field[[TDim], float64], float64]): lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) reference = im.make_tuple( - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(0, "a") - ), - im.call("cast_")(im.tuple_get(1, "a"), "int32"), + im.cast_as_fieldop("int32")(im.tuple_get(0, "a")), + im.cast_(im.tuple_get(1, "a"), "int32"), ) assert 
lowered_inlined.expr == reference @@ -356,16 +362,10 @@ def foo( reference = im.make_tuple( im.make_tuple( - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(0, im.tuple_get(0, "a")) - ), - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(1, im.tuple_get(0, "a")) - ), - ), - im.as_fieldop(im.lambda_("__val")(im.call("cast_")(im.deref("__val"), "int32")))( - im.tuple_get(1, "a") + im.cast_as_fieldop("int32")(im.tuple_get(0, im.tuple_get(0, "a"))), + im.cast_as_fieldop("int32")(im.tuple_get(1, im.tuple_get(0, "a"))), ), + im.cast_as_fieldop("int32")(im.tuple_get(1, "a")), ) assert lowered_inlined.expr == reference @@ -378,7 +378,7 @@ def foo(inp: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.op_as_fieldop("minus")(im.literal("0", "float64"), "inp") + reference = im.op_as_fieldop("neg")("inp") assert lowered.expr == reference @@ -390,7 +390,7 @@ def foo(inp: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.op_as_fieldop("plus")(im.literal("0", "float64"), "inp") + reference = im.ref("inp") assert lowered.expr == reference @@ -551,7 +551,7 @@ def foo(a: gtx.Field[[TDim], "int32"]) -> gtx.Field[[TDim], "int32"]: reference = im.let( ssa.unique_name("tmp", 0), - im.call("plus")( + im.plus( im.literal("1", "int32"), im.literal("1", "int32"), ), @@ -656,7 +656,7 @@ def foo() -> bool: parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) - reference = im.call("greater")( + reference = im.greater( im.literal("3", "int32"), im.literal("4", "int32"), ) @@ -761,11 +761,9 @@ def foo(edge_f: gtx.Field[[Edge], float64]): lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop( - im.call( - im.call("reduce")( - "plus", - 
im.literal(value="0", typename="float64"), - ) + im.reduce( + "plus", + im.literal(value="0", typename="float64"), ) )(im.as_fieldop_neighbors("V2E", "edge_f")) @@ -780,11 +778,9 @@ def foo(edge_f: gtx.Field[[Edge], float64]): lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop( - im.call( - im.call("reduce")( - "maximum", - im.literal(value=str(np.finfo(np.float64).min), typename="float64"), - ) + im.reduce( + "maximum", + im.literal(value=str(np.finfo(np.float64).min), typename="float64"), ) )(im.as_fieldop_neighbors("V2E", "edge_f")) @@ -799,11 +795,9 @@ def foo(edge_f: gtx.Field[[Edge], float64]): lowered = FieldOperatorLowering.apply(parsed) reference = im.op_as_fieldop( - im.call( - im.call("reduce")( - "minimum", - im.literal(value=str(np.finfo(np.float64).max), typename="float64"), - ) + im.reduce( + "minimum", + im.literal(value=str(np.finfo(np.float64).max), typename="float64"), ) )(im.as_fieldop_neighbors("V2E", "edge_f")) @@ -828,11 +822,9 @@ def foo(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], float64] im.as_fieldop_neighbors("V2E", "e1"), )( im.op_as_fieldop( - im.call( - im.call("reduce")( - "plus", - im.literal(value="0", typename="float64"), - ) + im.reduce( + "plus", + im.literal(value="0", typename="float64"), ) )(mapped) ) @@ -909,10 +901,21 @@ def foo() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: def test_broadcast(): def foo(inp: gtx.Field[[TDim], float64]): - return broadcast(inp, (UDim, TDim)) + return broadcast(inp, (TDim, UDim)) parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) assert lowered.id == "foo" - assert lowered.expr == im.ref("inp") + assert lowered.expr == im.as_fieldop("deref")(im.ref("inp")) + + +def test_scalar_broadcast(): + def foo(): + return broadcast(1, (TDim, UDim)) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + assert lowered.id == "foo" + assert 
lowered.expr == im.as_fieldop("deref")(1) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py deleted file mode 100644 index c102df9d57..0000000000 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ /dev/null @@ -1,598 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -# TODO(tehrengruber): The style of the tests in this file is not optimal as a single change in the -# lowering can (and often does) make all of them fail. Once we have embedded field view we want to -# switch to executing the different cases here; once with a regular backend (i.e. including -# parsing) and then with embedded field view (i.e. no parsing). If the results match the lowering -# should be correct. - -from __future__ import annotations - -from types import SimpleNamespace - -import pytest - -import gt4py.next as gtx -from gt4py.next import float32, float64, int32, int64, neighbor_sum -from gt4py.next.ffront import type_specifications as ts_ffront -from gt4py.next.ffront.ast_passes import single_static_assign as ssa -from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering -from gt4py.next.ffront.func_to_foast import FieldOperatorParser -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.type_system import type_specifications as ts, type_translation -from gt4py.next.iterator.type_system import type_specifications as it_ts - - -IDim = gtx.Dimension("IDim") -Edge = gtx.Dimension("Edge") -Vertex = gtx.Dimension("Vertex") -Cell = gtx.Dimension("Cell") -V2EDim = gtx.Dimension("V2E", gtx.DimensionKind.LOCAL) -V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) -TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. 
- - -def debug_itir(tree): - """Compare tree snippets while debugging.""" - from devtools import debug - - from gt4py.eve.codegen import format_python_source - from gt4py.next.program_processors import EmbeddedDSL - - debug(format_python_source(EmbeddedDSL.apply(tree))) - - -def test_copy(): - def copy_field(inp: gtx.Field[[TDim], float64]): - return inp - - parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - - assert lowered.id == "copy_field" - assert lowered.expr == im.ref("inp") - - -def test_scalar_arg(): - def scalar_arg(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], int64]: - return alpha * bar - - parsed = FieldOperatorParser.apply_to_function(scalar_arg) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("multiplies")( - "alpha", "bar" - ) # no difference to non-scalar arg - - assert lowered.expr == reference - - -def test_multicopy(): - def multicopy(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): - return inp1, inp2 - - parsed = FieldOperatorParser.apply_to_function(multicopy) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple("inp1", "inp2") - - assert lowered.expr == reference - - -def test_arithmetic(): - def arithmetic(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): - return inp1 + inp2 - - parsed = FieldOperatorParser.apply_to_function(arithmetic) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("plus")("inp1", "inp2") - - assert lowered.expr == reference - - -def test_shift(): - Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) - - def shift_by_one(inp: gtx.Field[[IDim], float64]): - return inp(Ioff[1]) - - parsed = FieldOperatorParser.apply_to_function(shift_by_one) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", 1)("it"))))("inp") - - assert 
lowered.expr == reference - - -def test_negative_shift(): - Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) - - def shift_by_one(inp: gtx.Field[[IDim], float64]): - return inp(Ioff[-1]) - - parsed = FieldOperatorParser.apply_to_function(shift_by_one) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", -1)("it"))))("inp") - - assert lowered.expr == reference - - -def test_temp_assignment(): - def copy_field(inp: gtx.Field[[TDim], float64]): - tmp = inp - inp = tmp - tmp2 = inp - return tmp2 - - parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let(ssa.unique_name("tmp", 0), "inp")( - im.let( - ssa.unique_name("inp", 0), - ssa.unique_name("tmp", 0), - )( - im.let( - ssa.unique_name("tmp2", 0), - ssa.unique_name("inp", 0), - )(ssa.unique_name("tmp2", 0)) - ) - ) - - assert lowered.expr == reference - - -def test_unary_ops(): - def unary(inp: gtx.Field[[TDim], float64]): - tmp = +inp - tmp = -tmp - return tmp - - parsed = FieldOperatorParser.apply_to_function(unary) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let( - ssa.unique_name("tmp", 0), - im.promote_to_lifted_stencil("plus")( - im.promote_to_const_iterator(im.literal("0", "float64")), "inp" - ), - )( - im.let( - ssa.unique_name("tmp", 1), - im.promote_to_lifted_stencil("minus")( - im.promote_to_const_iterator(im.literal("0", "float64")), ssa.unique_name("tmp", 0) - ), - )(ssa.unique_name("tmp", 1)) - ) - - assert lowered.expr == reference - - -@pytest.mark.parametrize("var, var_type", [("-1.0", "float64"), ("True", "bool")]) -def test_unary_op_type_conversion(var, var_type): - def unary_float(): - return float(-1) - - def unary_bool(): - return bool(-1) - - fun = unary_bool if var_type == "bool" else unary_float - parsed = FieldOperatorParser.apply_to_function(fun) - lowered = FieldOperatorLowering.apply(parsed) - reference = 
im.promote_to_const_iterator(im.literal(var, var_type)) - - assert lowered.expr == reference - - -def test_unpacking(): - """Unpacking assigns should get separated.""" - - def unpacking( - inp1: gtx.Field[[TDim], float64], inp2: gtx.Field[[TDim], float64] - ) -> gtx.Field[[TDim], float64]: - tmp1, tmp2 = inp1, inp2 # noqa - return tmp1 - - parsed = FieldOperatorParser.apply_to_function(unpacking) - lowered = FieldOperatorLowering.apply(parsed) - - tuple_expr = im.make_tuple("inp1", "inp2") - tuple_access_0 = im.tuple_get(0, "__tuple_tmp_0") - tuple_access_1 = im.tuple_get(1, "__tuple_tmp_0") - - reference = im.let("__tuple_tmp_0", tuple_expr)( - im.let( - ssa.unique_name("tmp1", 0), - tuple_access_0, - )( - im.let( - ssa.unique_name("tmp2", 0), - tuple_access_1, - )(ssa.unique_name("tmp1", 0)) - ) - ) - - assert lowered.expr == reference - - -def test_annotated_assignment(): - pytest.xfail("Annotated assignments are not properly supported at the moment.") - - def copy_field(inp: gtx.Field[[TDim], float64]): - tmp: gtx.Field[[TDim], float64] = inp - return tmp - - parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let(ssa.unique_name("tmp", 0), "inp")(ssa.unique_name("tmp", 0)) - - assert lowered.expr == reference - - -def test_call(): - # create something that appears to the lowering like a field operator. - # we could also create an actual field operator, but we want to avoid - # using such heavy constructs for testing the lowering. 
- field_type = type_translation.from_type_hint(gtx.Field[[TDim], float64]) - identity = SimpleNamespace( - __gt_type__=lambda: ts_ffront.FieldOperatorType( - definition=ts.FunctionType( - pos_only_args=[field_type], pos_or_kw_args={}, kw_only_args={}, returns=field_type - ) - ) - ) - - def call(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: - return identity(inp) - - parsed = FieldOperatorParser.apply_to_function(call) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.call("identity")("inp") - - assert lowered.expr == reference - - -def test_temp_tuple(): - """Returning a temp tuple should work.""" - - def temp_tuple(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], int64]): - tmp = a, b - return tmp - - parsed = FieldOperatorParser.apply_to_function(temp_tuple) - lowered = FieldOperatorLowering.apply(parsed) - - tuple_expr = im.make_tuple("a", "b") - reference = im.let(ssa.unique_name("tmp", 0), tuple_expr)(ssa.unique_name("tmp", 0)) - - assert lowered.expr == reference - - -def test_unary_not(): - def unary_not(cond: gtx.Field[[TDim], "bool"]): - return not cond - - parsed = FieldOperatorParser.apply_to_function(unary_not) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("not_")("cond") - - assert lowered.expr == reference - - -def test_binary_plus(): - def plus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a + b - - parsed = FieldOperatorParser.apply_to_function(plus) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("plus")("a", "b") - - assert lowered.expr == reference - - -def test_add_scalar_literal_to_field(): - def scalar_plus_field(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: - return 2.0 + a - - parsed = FieldOperatorParser.apply_to_function(scalar_plus_field) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("plus")( - 
im.promote_to_const_iterator(im.literal("2.0", "float64")), "a" - ) - - assert lowered.expr == reference - - -def test_add_scalar_literals(): - def scalar_plus_scalar(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int32"]: - tmp = int32(1) + int32("1") - return a + tmp - - parsed = FieldOperatorParser.apply_to_function(scalar_plus_scalar) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let( - ssa.unique_name("tmp", 0), - im.promote_to_lifted_stencil("plus")( - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int32")), - ), - )(im.promote_to_lifted_stencil("plus")("a", ssa.unique_name("tmp", 0))) - - assert lowered.expr == reference - - -def test_binary_mult(): - def mult(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a * b - - parsed = FieldOperatorParser.apply_to_function(mult) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("multiplies")("a", "b") - - assert lowered.expr == reference - - -def test_binary_minus(): - def minus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a - b - - parsed = FieldOperatorParser.apply_to_function(minus) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("minus")("a", "b") - - assert lowered.expr == reference - - -def test_binary_div(): - def division(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a / b - - parsed = FieldOperatorParser.apply_to_function(division) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("divides")("a", "b") - - assert lowered.expr == reference - - -def test_binary_and(): - def bit_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): - return a & b - - parsed = FieldOperatorParser.apply_to_function(bit_and) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("and_")("a", "b") - 
- assert lowered.expr == reference - - -def test_scalar_and(): - def scalar_and(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]: - return a & False - - parsed = FieldOperatorParser.apply_to_function(scalar_and) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("and_")( - "a", im.promote_to_const_iterator(im.literal("False", "bool")) - ) - - assert lowered.expr == reference - - -def test_binary_or(): - def bit_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): - return a | b - - parsed = FieldOperatorParser.apply_to_function(bit_or) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("or_")("a", "b") - - assert lowered.expr == reference - - -def test_compare_scalars(): - def comp_scalars() -> bool: - return 3 > 4 - - parsed = FieldOperatorParser.apply_to_function(comp_scalars) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("greater")( - im.promote_to_const_iterator(im.literal("3", "int32")), - im.promote_to_const_iterator(im.literal("4", "int32")), - ) - - assert lowered.expr == reference - - -def test_compare_gt(): - def comp_gt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a > b - - parsed = FieldOperatorParser.apply_to_function(comp_gt) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("greater")("a", "b") - - assert lowered.expr == reference - - -def test_compare_lt(): - def comp_lt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a < b - - parsed = FieldOperatorParser.apply_to_function(comp_lt) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("less")("a", "b") - - assert lowered.expr == reference - - -def test_compare_eq(): - def comp_eq(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]): - return a == b - - parsed = FieldOperatorParser.apply_to_function(comp_eq) - 
lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("eq")("a", "b") - - assert lowered.expr == reference - - -def test_compare_chain(): - def compare_chain( - a: gtx.Field[[IDim], float64], b: gtx.Field[[IDim], float64], c: gtx.Field[[IDim], float64] - ) -> gtx.Field[[IDim], bool]: - return a > b > c - - parsed = FieldOperatorParser.apply_to_function(compare_chain) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("and_")( - im.promote_to_lifted_stencil("greater")("a", "b"), - im.promote_to_lifted_stencil("greater")("b", "c"), - ) - - assert lowered.expr == reference - - -def test_reduction_lowering_simple(): - def reduction(edge_f: gtx.Field[[Edge], float64]): - return neighbor_sum(edge_f(V2E), axis=V2EDim) - - parsed = FieldOperatorParser.apply_to_function(reduction) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil( - im.call( - im.call("reduce")( - "plus", - im.deref(im.promote_to_const_iterator(im.literal(value="0", typename="float64"))), - ) - ) - )(im.lifted_neighbors("V2E", "edge_f")) - - assert lowered.expr == reference - - -def test_reduction_lowering_expr(): - def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], float64]): - e1_nbh = e1(V2E) - return neighbor_sum(1.1 * (e1_nbh + e2), axis=V2EDim) - - parsed = FieldOperatorParser.apply_to_function(reduction) - lowered = FieldOperatorLowering.apply(parsed) - - mapped = im.promote_to_lifted_stencil(im.map_("multiplies"))( - im.promote_to_lifted_stencil("make_const_list")( - im.promote_to_const_iterator(im.literal("1.1", "float64")) - ), - im.promote_to_lifted_stencil(im.map_("plus"))(ssa.unique_name("e1_nbh", 0), "e2"), - ) - - reference = im.let( - ssa.unique_name("e1_nbh", 0), - im.lifted_neighbors("V2E", "e1"), - )( - im.promote_to_lifted_stencil( - im.call( - im.call("reduce")( - "plus", - im.deref( - im.promote_to_const_iterator(im.literal(value="0", 
typename="float64")) - ), - ) - ) - )(mapped) - ) - - assert lowered.expr == reference - - -def test_builtin_int_constructors(): - def int_constrs() -> tuple[int32, int32, int64, int32, int64]: - return 1, int32(1), int64(1), int32("1"), int64("1") - - parsed = FieldOperatorParser.apply_to_function(int_constrs) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple( - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int64")), - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int64")), - ) - - assert lowered.expr == reference - - -def test_builtin_float_constructors(): - def float_constrs() -> tuple[float, float, float32, float64, float, float32, float64]: - return ( - 0.1, - float(0.1), - float32(0.1), - float64(0.1), - float(".1"), - float32(".1"), - float64(".1"), - ) - - parsed = FieldOperatorParser.apply_to_function(float_constrs) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple( - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float32")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float32")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - ) - - assert lowered.expr == reference - - -def test_builtin_bool_constructors(): - def bool_constrs() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: - return True, False, bool(True), bool(False), bool(0), bool(5), bool("True"), bool("False") - - parsed = FieldOperatorParser.apply_to_function(bool_constrs) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple( - 
im.promote_to_const_iterator(im.literal(str(True), "bool")), - im.promote_to_const_iterator(im.literal(str(False), "bool")), - im.promote_to_const_iterator(im.literal(str(True), "bool")), - im.promote_to_const_iterator(im.literal(str(False), "bool")), - im.promote_to_const_iterator(im.literal(str(bool(0)), "bool")), - im.promote_to_const_iterator(im.literal(str(bool(5)), "bool")), - im.promote_to_const_iterator(im.literal(str(bool("True")), "bool")), - im.promote_to_const_iterator(im.literal(str(bool("False")), "bool")), - ) - - assert lowered.expr == reference diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py index a6231c22a7..cbaa84454d 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py @@ -17,7 +17,7 @@ from gt4py.next import errors from gt4py.next.ffront.func_to_past import ProgramParser from gt4py.next.ffront.past_to_itir import ProgramLowering -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.type_system import type_specifications as ts @@ -46,7 +46,6 @@ def test_copy_lowering(copy_program_def, gtir_identity_fundef): past_node, function_definitions=[gtir_identity_fundef], grid_type=gtx.GridType.CARTESIAN, - to_gtir=True, ) set_at_pattern = P( itir.SetAt, @@ -59,8 +58,30 @@ def test_copy_lowering(copy_program_def, gtir_identity_fundef): fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), args=[ P(itir.AxisLiteral, value="IDim"), - P(itir.Literal, value="0", type=ts.ScalarType(kind=ts.ScalarKind.INT32)), - P(itir.SymRef, id=eve.SymbolRef("__out_size_0")), + P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")), + args=[ + P( + itir.Literal, + value="0", + type=ts.ScalarType(kind=ts.ScalarKind.INT32), + ), + P(itir.SymRef, 
id=eve.SymbolRef("__out_0_range")), + ], + ), + P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")), + args=[ + P( + itir.Literal, + value="1", + type=ts.ScalarType(kind=ts.ScalarKind.INT32), + ), + P(itir.SymRef, id=eve.SymbolRef("__out_0_range")), + ], + ), ], ) ], @@ -78,8 +99,8 @@ def test_copy_lowering(copy_program_def, gtir_identity_fundef): params=[ P(itir.Sym, id=eve.SymbolName("in_field")), P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_size_0")), - P(itir.Sym, id=eve.SymbolName("__out_size_0")), + P(itir.Sym, id=eve.SymbolName("__in_field_0_range")), + P(itir.Sym, id=eve.SymbolName("__out_0_range")), ], body=[set_at_pattern], ) @@ -93,7 +114,6 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) past_node, function_definitions=[gtir_identity_fundef], grid_type=gtx.GridType.CARTESIAN, - to_gtir=True, ) set_at_pattern = P( itir.SetAt, @@ -107,18 +127,58 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) args=[ P(itir.AxisLiteral, value="IDim"), P( - itir.Literal, - value="1", - type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) - ), + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("plus")), + args=[ + P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")), + args=[ + P( + itir.Literal, + value="0", + type=ts.ScalarType(kind=ts.ScalarKind.INT32), + ), + P(itir.SymRef, id=eve.SymbolRef("__out_0_range")), + ], + ), + P( + itir.Literal, + value="1", + type=ts.ScalarType( + kind=getattr( + ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper() + ) + ), + ), + ], ), P( - itir.Literal, - value="2", - type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) - ), + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("plus")), + args=[ + P( + itir.FunCall, + fun=P(itir.SymRef, id=eve.SymbolRef("tuple_get")), + args=[ + P( + itir.Literal, + value="0", + 
type=ts.ScalarType(kind=ts.ScalarKind.INT32), + ), + P(itir.SymRef, id=eve.SymbolRef("__out_0_range")), + ], + ), + P( + itir.Literal, + value="2", + type=ts.ScalarType( + kind=getattr( + ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper() + ) + ), + ), + ], ), ], ) @@ -131,8 +191,8 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) params=[ P(itir.Sym, id=eve.SymbolName("in_field")), P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_size_0")), - P(itir.Sym, id=eve.SymbolName("__out_size_0")), + P(itir.Sym, id=eve.SymbolName("__in_field_0_range")), + P(itir.Sym, id=eve.SymbolName("__out_0_range")), ], body=[set_at_pattern], ) @@ -149,9 +209,7 @@ def tuple_program( make_tuple_op(inp, out=(out1[1:], out2[1:])) parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply( - parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN, to_gtir=True - ) + ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) @pytest.mark.xfail( @@ -166,9 +224,7 @@ def tuple_program( make_tuple_op(inp, out=(out1[1:], out2)) parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply( - parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN, to_gtir=True - ) + ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) @pytest.mark.xfail @@ -194,7 +250,6 @@ def test_invalid_call_sig_program(invalid_call_sig_program_def): ProgramParser.apply_to_function(invalid_call_sig_program_def), function_definitions=[], grid_type=gtx.GridType.CARTESIAN, - to_gtir=True, ) assert exc_info.match("Invalid call to 'identity'") diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py deleted file mode 100644 index fefd3c653b..0000000000 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ /dev/null @@ -1,214 +0,0 @@ -# GT4Py 
- GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import re - -import pytest - -import gt4py.eve as eve -import gt4py.next as gtx -from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next import errors -from gt4py.next.ffront.func_to_past import ProgramParser -from gt4py.next.ffront.past_to_itir import ProgramLowering -from gt4py.next.iterator import ir as itir -from gt4py.next.type_system import type_specifications as ts - -from next_tests.past_common_fixtures import ( - IDim, - copy_program_def, - copy_restrict_program_def, - float64, - identity_def, - invalid_call_sig_program_def, -) - - -@pytest.fixture -def itir_identity_fundef(): - return itir.FunctionDefinition( - id="identity", - params=[itir.Sym(id="x")], - expr=itir.FunCall(fun=itir.SymRef(id="deref"), args=[itir.SymRef(id="x")]), - ) - - -def test_copy_lowering(copy_program_def, itir_identity_fundef): - past_node = ProgramParser.apply_to_function(copy_program_def) - itir_node = ProgramLowering.apply( - past_node, function_definitions=[itir_identity_fundef], grid_type=gtx.GridType.CARTESIAN - ) - closure_pattern = P( - itir.StencilClosure, - domain=P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("cartesian_domain")), - args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), - args=[ - P(itir.AxisLiteral, value="IDim"), - P(itir.Literal, value="0", type=ts.ScalarType(kind=ts.ScalarKind.INT32)), - P(itir.SymRef, id=eve.SymbolRef("__out_size_0")), - ], - ) - ], - ), - stencil=P( - itir.Lambda, - params=[P(itir.Sym, id=eve.SymbolName("__stencil_arg0"))], - expr=P( - itir.FunCall, - fun=P( - itir.Lambda, - params=[P(itir.Sym)], - expr=P(itir.FunCall, fun=P(itir.SymRef, id=eve.SymbolRef("deref"))), - ), - args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("identity")), - args=[P(itir.SymRef, 
id=eve.SymbolRef("__stencil_arg0"))], - ) - ], - ), - ), - inputs=[P(itir.SymRef, id=eve.SymbolRef("in_field"))], - output=P(itir.SymRef, id=eve.SymbolRef("out")), - ) - fencil_pattern = P( - itir.FencilDefinition, - id=eve.SymbolName("copy_program"), - params=[ - P(itir.Sym, id=eve.SymbolName("in_field")), - P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_size_0")), - P(itir.Sym, id=eve.SymbolName("__out_size_0")), - ], - closures=[closure_pattern], - ) - - fencil_pattern.match(itir_node, raise_exception=True) - - -def test_copy_restrict_lowering(copy_restrict_program_def, itir_identity_fundef): - past_node = ProgramParser.apply_to_function(copy_restrict_program_def) - itir_node = ProgramLowering.apply( - past_node, function_definitions=[itir_identity_fundef], grid_type=gtx.GridType.CARTESIAN - ) - closure_pattern = P( - itir.StencilClosure, - domain=P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("cartesian_domain")), - args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), - args=[ - P(itir.AxisLiteral, value="IDim"), - P( - itir.Literal, - value="1", - type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) - ), - ), - P( - itir.Literal, - value="2", - type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) - ), - ), - ], - ) - ], - ), - ) - fencil_pattern = P( - itir.FencilDefinition, - id=eve.SymbolName("copy_restrict_program"), - params=[ - P(itir.Sym, id=eve.SymbolName("in_field")), - P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_size_0")), - P(itir.Sym, id=eve.SymbolName("__out_size_0")), - ], - closures=[closure_pattern], - ) - - fencil_pattern.match(itir_node, raise_exception=True) - - -def test_tuple_constructed_in_out_with_slicing(make_tuple_op): - def tuple_program( - inp: gtx.Field[[IDim], float64], - out1: gtx.Field[[IDim], float64], - out2: gtx.Field[[IDim], float64], - ): - 
make_tuple_op(inp, out=(out1[1:], out2[1:])) - - parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) - - -@pytest.mark.xfail( - reason="slicing is only allowed if all fields are sliced in the same way." -) # see ADR 10 -def test_tuple_constructed_in_out_with_slicing(make_tuple_op): - def tuple_program( - inp: gtx.Field[[IDim], float64], - out1: gtx.Field[[IDim], float64], - out2: gtx.Field[[IDim], float64], - ): - make_tuple_op(inp, out=(out1[1:], out2)) - - parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) - - -@pytest.mark.xfail -def test_inout_prohibited(identity_def): - identity = gtx.field_operator(identity_def) - - def inout_field_program(inout_field: gtx.Field[[IDim], "float64"]): - identity(inout_field, out=inout_field) - - with pytest.raises( - ValueError, match=(r"Call to function with field as input and output not allowed.") - ): - ProgramLowering.apply( - ProgramParser.apply_to_function(inout_field_program), - function_definitions=[], - grid_type=gtx.GridType.CARTESIAN, - ) - - -def test_invalid_call_sig_program(invalid_call_sig_program_def): - with pytest.raises(errors.DSLError) as exc_info: - ProgramLowering.apply( - ProgramParser.apply_to_function(invalid_call_sig_program_def), - function_definitions=[], - grid_type=gtx.GridType.CARTESIAN, - ) - - assert exc_info.match("Invalid call to 'identity'") - # TODO(tehrengruber): re-enable again when call signature check doesn't return - # immediately after missing `out` argument - # assert ( - # re.search( - # "Function takes 1 arguments, but 2 were given.", exc_info.value.__cause__.args[0] - # ) - # is not None - # ) - assert ( - re.search(r"Missing required keyword argument 'out'", exc_info.value.__cause__.args[0]) - is not None - ) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py 
b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py new file mode 100644 index 0000000000..f9393bd99c --- /dev/null +++ b/tests/next_tests/unit_tests/ffront_tests/test_type_deduction.py @@ -0,0 +1,412 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import re +from typing import Optional, Pattern + +import pytest + +import gt4py.next.ffront.type_specifications +from gt4py.next import ( + Dimension, + DimensionKind, + Field, + FieldOffset, + astype, + broadcast, + common, + errors, + float32, + float64, + int32, + int64, + neighbor_sum, + where, +) +from gt4py.next.ffront.ast_passes import single_static_assign as ssa +from gt4py.next.ffront.experimental import as_offset +from gt4py.next.ffront.func_to_foast import FieldOperatorParser +from gt4py.next.type_system import type_info, type_specifications as ts + +TDim = Dimension("TDim") # Meaningless dimension, used for tests. 
+ + +def test_unpack_assign(): + def unpack_explicit_tuple( + a: Field[[TDim], float64], b: Field[[TDim], float64] + ) -> tuple[Field[[TDim], float64], Field[[TDim], float64]]: + tmp_a, tmp_b = (a, b) + return tmp_a, tmp_b + + parsed = FieldOperatorParser.apply_to_function(unpack_explicit_tuple) + + assert parsed.body.annex.symtable[ssa.unique_name("tmp_a", 0)].type == ts.FieldType( + dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None) + ) + assert parsed.body.annex.symtable[ssa.unique_name("tmp_b", 0)].type == ts.FieldType( + dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None) + ) + + +def test_assign_tuple(): + def temp_tuple(a: Field[[TDim], float64], b: Field[[TDim], int64]): + tmp = a, b + return tmp + + parsed = FieldOperatorParser.apply_to_function(temp_tuple) + + assert parsed.body.annex.symtable[ssa.unique_name("tmp", 0)].type == ts.TupleType( + types=[ + ts.FieldType(dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None)), + ts.FieldType(dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64, shape=None)), + ] + ) + + +def test_adding_bool(): + """Expect an error when using arithmetic on bools.""" + + def add_bools(a: Field[[TDim], bool], b: Field[[TDim], bool]): + return a + b + + with pytest.raises( + errors.DSLError, match=(r"Type 'Field\[\[TDim\], bool\]' can not be used in operator '\+'.") + ): + _ = FieldOperatorParser.apply_to_function(add_bools) + + +def test_binop_nonmatching_dims(): + """Dimension promotion is applied before Binary operations, i.e., they can also work on two fields that don't have the same dimensions.""" + X = Dimension("X") + Y = Dimension("Y") + + def nonmatching(a: Field[[X], float64], b: Field[[Y], float64]): + return a + b + + parsed = FieldOperatorParser.apply_to_function(nonmatching) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[X, Y], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) + + +def test_bitopping_float(): + def 
float_bitop(a: Field[[TDim], float], b: Field[[TDim], float]): + return a & b + + with pytest.raises( + errors.DSLError, + match=(r"Type 'Field\[\[TDim\], float64\]' can not be used in operator '\&'."), + ): + _ = FieldOperatorParser.apply_to_function(float_bitop) + + +def test_signing_bool(): + def sign_bool(a: Field[[TDim], bool]): + return -a + + with pytest.raises( + errors.DSLError, + match=r"Incompatible type for unary operator '\-': 'Field\[\[TDim\], bool\]'.", + ): + _ = FieldOperatorParser.apply_to_function(sign_bool) + + +def test_notting_int(): + def not_int(a: Field[[TDim], int64]): + return not a + + with pytest.raises( + errors.DSLError, + match=r"Incompatible type for unary operator 'not': 'Field\[\[TDim\], int64\]'.", + ): + _ = FieldOperatorParser.apply_to_function(not_int) + + +@pytest.fixture +def premap_setup(): + X = Dimension("X") + Y = Dimension("Y") + Y2XDim = Dimension("Y2X", kind=DimensionKind.LOCAL) + Y2X = FieldOffset("Y2X", source=X, target=(Y, Y2XDim)) + return X, Y, Y2XDim, Y2X + + +def test_premap(premap_setup): + X, Y, Y2XDim, Y2X = premap_setup + + def premap_fo(bar: Field[[X], int64]) -> Field[[Y], int64]: + return bar(Y2X[0]) + + parsed = FieldOperatorParser.apply_to_function(premap_fo) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[Y], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64) + ) + + +def test_premap_nbfield(premap_setup): + X, Y, Y2XDim, Y2X = premap_setup + + def premap_fo(bar: Field[[X], int64]) -> Field[[Y, Y2XDim], int64]: + return bar(Y2X) + + parsed = FieldOperatorParser.apply_to_function(premap_fo) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[Y, Y2XDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64) + ) + + +def test_premap_reduce(premap_setup): + X, Y, Y2XDim, Y2X = premap_setup + + def premap_fo(bar: Field[[X], int32]) -> Field[[Y], int32]: + return 2 * neighbor_sum(bar(Y2X), axis=Y2XDim) + + parsed = FieldOperatorParser.apply_to_function(premap_fo) + + assert 
parsed.body.stmts[0].value.type == ts.FieldType( + dims=[Y], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32) + ) + + +def test_premap_reduce_sparse(premap_setup): + X, Y, Y2XDim, Y2X = premap_setup + + def premap_fo(bar: Field[[Y, Y2XDim], int32]) -> Field[[Y], int32]: + return 5 * neighbor_sum(bar, axis=Y2XDim) + + parsed = FieldOperatorParser.apply_to_function(premap_fo) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[Y], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32) + ) + + +def test_mismatched_literals(): + def mismatched_lit() -> Field[[TDim], "float32"]: + return float32("1.0") + float64("1.0") + + with pytest.raises( + errors.DSLError, + match=(r"Could not promote 'float32' and 'float64' to common type in call to +."), + ): + _ = FieldOperatorParser.apply_to_function(mismatched_lit) + + +def test_broadcast_multi_dim(): + ADim = Dimension("ADim") + BDim = Dimension("BDim") + CDim = Dimension("CDim") + + def simple_broadcast(a: Field[[ADim], float64]): + return broadcast(a, (ADim, BDim, CDim)) + + parsed = FieldOperatorParser.apply_to_function(simple_broadcast) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[ADim, BDim, CDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) + + +def test_broadcast_disjoint(): + ADim = Dimension("ADim") + BDim = Dimension("BDim") + CDim = Dimension("CDim") + + def disjoint_broadcast(a: Field[[ADim], float64]): + return broadcast(a, (BDim, CDim)) + + with pytest.raises(errors.DSLError, match=r"expected broadcast dimension\(s\) \'.*\' missing"): + _ = FieldOperatorParser.apply_to_function(disjoint_broadcast) + + +def test_broadcast_badtype(): + ADim = Dimension("ADim") + BDim = "BDim" + CDim = Dimension("CDim") + + def badtype_broadcast(a: Field[[ADim], float64]): + return broadcast(a, (BDim, CDim)) + + with pytest.raises( + errors.DSLError, match=r"expected all broadcast dimensions to be of type 'Dimension'." 
+ ): + _ = FieldOperatorParser.apply_to_function(badtype_broadcast) + + +def test_where_dim(): + ADim = Dimension("ADim") + BDim = Dimension("BDim") + + def simple_where(a: Field[[ADim], bool], b: Field[[ADim, BDim], float64]): + return where(a, b, 9.0) + + parsed = FieldOperatorParser.apply_to_function(simple_where) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[ADim, BDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) + + +def test_where_broadcast_dim(): + ADim = Dimension("ADim") + + def simple_where(a: Field[[ADim], bool]): + return where(a, 5.0, 9.0) + + parsed = FieldOperatorParser.apply_to_function(simple_where) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ) + + +def test_where_tuple_dim(): + ADim = Dimension("ADim") + + def tuple_where(a: Field[[ADim], bool], b: Field[[ADim], float64]): + return where(a, ((5.0, 9.0), (b, 6.0)), ((8.0, b), (5.0, 9.0))) + + parsed = FieldOperatorParser.apply_to_function(tuple_where) + + assert parsed.body.stmts[0].value.type == ts.TupleType( + types=[ + ts.TupleType( + types=[ + ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), + ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), + ] + ), + ts.TupleType( + types=[ + ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), + ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), + ] + ), + ] + ) + + +def test_where_bad_dim(): + ADim = Dimension("ADim") + + def bad_dim_where(a: Field[[ADim], bool], b: Field[[ADim], float64]): + return where(a, ((5.0, 9.0), (b, 6.0)), b) + + with pytest.raises(errors.DSLError, match=r"Return arguments need to be of same type"): + _ = FieldOperatorParser.apply_to_function(bad_dim_where) + + +def test_where_mixed_dims(): + ADim = Dimension("ADim") + BDim = Dimension("BDim") + + def tuple_where_mix_dims( + a: Field[[ADim], bool], b: Field[[ADim], 
float64], c: Field[[ADim, BDim], float64] + ): + return where(a, ((c, 9.0), (b, 6.0)), ((8.0, b), (5.0, 9.0))) + + parsed = FieldOperatorParser.apply_to_function(tuple_where_mix_dims) + + assert parsed.body.stmts[0].value.type == ts.TupleType( + types=[ + ts.TupleType( + types=[ + ts.FieldType( + dims=[ADim, BDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) + ), + ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), + ] + ), + ts.TupleType( + types=[ + ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), + ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), + ] + ), + ] + ) + + +def test_astype_dtype(): + def simple_astype(a: Field[[TDim], float64]): + return astype(a, bool) + + parsed = FieldOperatorParser.apply_to_function(simple_astype) + + assert parsed.body.stmts[0].value.type == ts.FieldType( + dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL) + ) + + +def test_astype_wrong_dtype(): + def simple_astype(a: Field[[TDim], float64]): + # we just use broadcast here, but anything with type function is fine + return astype(a, broadcast) + + with pytest.raises( + errors.DSLError, + match=r"Invalid call to 'astype': second argument must be a scalar type, got.", + ): + _ = FieldOperatorParser.apply_to_function(simple_astype) + + +def test_astype_wrong_value_type(): + def simple_astype(a: Field[[TDim], float64]): + # we just use broadcast here but anything that is not a field, scalar or tuple thereof works + return astype(broadcast, bool) + + with pytest.raises(errors.DSLError) as exc_info: + _ = FieldOperatorParser.apply_to_function(simple_astype) + + assert ( + re.search("Expected 1st argument to be of type", exc_info.value.__cause__.args[0]) + is not None + ) + + +def test_mod_floats(): + def modulo_floats(inp: Field[[TDim], float]): + return inp % 3.0 + + with pytest.raises(errors.DSLError, match=r"Type 'float64' can not be used in operator '%'"): + _ = 
FieldOperatorParser.apply_to_function(modulo_floats) + + +def test_undefined_symbols(): + def return_undefined(): + return undefined_symbol + + with pytest.raises(errors.DSLError, match="Undeclared symbol"): + _ = FieldOperatorParser.apply_to_function(return_undefined) + + +def test_as_offset_dim(): + ADim = Dimension("ADim") + BDim = Dimension("BDim") + Boff = FieldOffset("Boff", source=BDim, target=(BDim,)) + + def as_offset_dim(a: Field[[ADim, BDim], float], b: Field[[ADim], int]): + return a(as_offset(Boff, b)) + + with pytest.raises(errors.DSLError, match=f"not in list of offset field dimensions"): + _ = FieldOperatorParser.apply_to_function(as_offset_dim) + + +def test_as_offset_dtype(): + ADim = Dimension("ADim") + BDim = Dimension("BDim") + Boff = FieldOffset("Boff", source=BDim, target=(BDim,)) + + def as_offset_dtype(a: Field[[ADim, BDim], float], b: Field[[BDim], float]): + return a(as_offset(Boff, b)) + + with pytest.raises(errors.DSLError, match=f"expected integer for offset field dtype"): + _ = FieldOperatorParser.apply_to_function(as_offset_dtype) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py new file mode 100644 index 0000000000..a91dbeb608 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/test_embedded_field_with_list.py @@ -0,0 +1,150 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +import numpy as np +import pytest + +import gt4py.next as gtx +from gt4py.next.embedded import context as embedded_context +from gt4py.next.iterator import embedded, runtime +from gt4py.next.iterator.builtins import ( + as_fieldop, + deref, + if_, + make_const_list, + map_, + neighbors, + plus, +) + + +E = gtx.Dimension("E") +V = gtx.Dimension("V") +E2VDim = gtx.Dimension("E2V", kind=gtx.DimensionKind.LOCAL) +E2V = gtx.FieldOffset("E2V", source=V, target=(E, E2VDim)) + + +# 0 --0-- 1 --1-- 2 +e2v_arr = np.array([[0, 1], [1, 2]]) +e2v_conn = gtx.as_connectivity( + domain={E: 2, E2VDim: 2}, + codomain=V, + data=e2v_arr, +) + + +def test_write_neighbors(): + def testee(inp): + domain = runtime.UnstructuredDomain({E: range(2)}) + return as_fieldop(lambda it: neighbors(E2V, it), domain)(inp) + + inp = gtx.as_field([V], np.arange(3)) + with embedded_context.new_context(offset_provider={"E2V": e2v_conn}) as ctx: + result = ctx.run(testee, inp) + + ref = e2v_arr + np.testing.assert_array_equal(result.asnumpy(), ref) + + +def test_write_const_list(): + def testee(): + domain = runtime.UnstructuredDomain({E: range(2)}) + return as_fieldop(lambda: make_const_list(42.0), domain)() + + with embedded_context.new_context(offset_provider={}) as ctx: + result = ctx.run(testee) + + ref = np.asarray([[42.0], [42.0]]) + + assert result.domain.dims[0] == E + assert result.domain.dims[1] == embedded._CONST_DIM # this is implementation detail + assert result.shape[1] == 1 # this is implementation detail + np.testing.assert_array_equal(result.asnumpy(), ref) + + +def test_write_map_neighbors_and_const_list(): + def testee(inp): + domain = runtime.UnstructuredDomain({E: range(2)}) + return as_fieldop(lambda x, y: map_(plus)(deref(x), deref(y)), domain)( + as_fieldop(lambda it: neighbors(E2V, it), domain)(inp), + as_fieldop(lambda: make_const_list(42.0), domain)(), + ) + + inp = gtx.as_field([V], np.arange(3)) + with 
embedded_context.new_context(offset_provider={"E2V": e2v_conn}) as ctx: + result = ctx.run(testee, inp) + + ref = e2v_arr + 42.0 + np.testing.assert_array_equal(result.asnumpy(), ref) + + +def test_write_map_conditional_neighbors_and_const_list(): + def testee(inp, mask): + domain = runtime.UnstructuredDomain({E: range(2)}) + return as_fieldop(lambda m, x, y: map_(if_)(deref(m), deref(x), deref(y)), domain)( + as_fieldop(lambda it: make_const_list(deref(it)), domain)(mask), + as_fieldop(lambda it: neighbors(E2V, it), domain)(inp), + as_fieldop(lambda it: make_const_list(deref(it)), domain)(42.0), + ) + + inp = gtx.as_field([V], np.arange(3)) + mask_field = gtx.as_field([E], np.array([True, False])) + with embedded_context.new_context(offset_provider={"E2V": e2v_conn}) as ctx: + result = ctx.run(testee, inp, mask_field) + + ref = np.empty_like(e2v_arr, dtype=float) + ref[0, :] = e2v_arr[0, :] + ref[1, :] = 42.0 + np.testing.assert_array_equal(result.asnumpy(), ref) + + +def test_write_non_mapped_conditional_neighbors_and_const_list(): + """ + This test-case demonstrates a non-supported pattern: + Current ITIR requires the `if_` to be `map_`ed, see `test_write_map_conditional_neighbors_and_const_list`. + We keep it here for documenting corner cases of the `itir.List` implementation for future discussions. 
+ """ + + pytest.skip("Unsupported.") + + def testee(inp, mask): + domain = runtime.UnstructuredDomain({E: range(2)}) + return as_fieldop(lambda m, x, y: if_(deref(m), deref(x), deref(y)), domain)( + mask, + as_fieldop(lambda it: make_const_list(deref(it)), domain)(42.0), + as_fieldop(lambda it: neighbors(E2V, it), domain)(inp), + ) + + inp = gtx.as_field([V], np.arange(3)) + mask_field = gtx.as_field([E], np.array([True, False])) + with embedded_context.new_context(offset_provider={"E2V": e2v_conn}) as ctx: + result = ctx.run(testee, inp, mask_field) + + ref = np.empty_like(e2v_arr, dtype=float) + ref[0, :] = e2v_arr[0, :] + ref[1, :] = 42.0 + np.testing.assert_array_equal(result.asnumpy(), ref) + + +def test_write_map_const_list_and_const_list(): + def testee(): + domain = runtime.UnstructuredDomain({E: range(2)}) + return as_fieldop(lambda x, y: map_(plus)(deref(x), deref(y)), domain)( + as_fieldop(lambda: make_const_list(1.0), domain)(), + as_fieldop(lambda: make_const_list(42.0), domain)(), + ) + + with embedded_context.new_context(offset_provider={}) as ctx: + result = ctx.run(testee) + + ref = np.asarray([[43.0], [43.0]]) + + assert result.domain.dims[0] == E + assert result.domain.dims[1] == embedded._CONST_DIM # this is implementation detail + assert result.shape[1] == 1 # this is implementation detail + np.testing.assert_array_equal(result.asnumpy(), ref) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py new file mode 100644 index 0000000000..ff7a761c5a --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py @@ -0,0 +1,48 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause +from typing import Callable, Optional + +from gt4py import next as gtx +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import inline_dynamic_shifts +from gt4py.next.type_system import type_specifications as ts + +IDim = gtx.Dimension("IDim") +field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + + +def test_inline_dynamic_shift_as_fieldop_arg(): + testee = im.as_fieldop(im.lambda_("a", "b")(im.deref(im.shift("IOff", im.deref("b"))("a"))))( + im.as_fieldop("deref")("inp"), "offset_field" + ) + expected = im.as_fieldop( + im.lambda_("inp", "offset_field")( + im.deref(im.shift("IOff", im.deref("offset_field"))("inp")) + ) + )("inp", "offset_field") + + actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee) + assert actual == expected + + +def test_inline_dynamic_shift_let_var(): + testee = im.let("tmp", im.as_fieldop("deref")("inp"))( + im.as_fieldop(im.lambda_("a", "b")(im.deref(im.shift("IOff", im.deref("b"))("a"))))( + "tmp", "offset_field" + ) + ) + + expected = im.as_fieldop( + im.lambda_("inp", "offset_field")( + im.deref(im.shift("IOff", im.deref("offset_field"))("inp")) + ) + )("inp", "offset_field") + + actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee) + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index da4bea8874..f825c3823b 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -6,9 +6,9 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from gt4py.next.iterator import ir -from gt4py.next.iterator.pretty_parser import pparse +from gt4py.next.iterator import ir, builtins from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.pretty_parser import pparse from gt4py.next.type_system import type_specifications as ts @@ -111,7 +111,7 @@ def test_tuple_get(): testee = "x[42]" expected = ir.FunCall( fun=ir.SymRef(id="tuple_get"), - args=[im.literal("42", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], + args=[im.literal("42", builtins.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], ) actual = pparse(testee) assert actual == expected @@ -127,7 +127,7 @@ def test_make_tuple(): def test_named_range_horizontal(): - testee = "IDimₕ: [x, y)" + testee = "IDimₕ: [x, y[" expected = ir.FunCall( fun=ir.SymRef(id="named_range"), args=[ir.AxisLiteral(value="IDim"), ir.SymRef(id="x"), ir.SymRef(id="y")], @@ -137,7 +137,7 @@ def test_named_range_horizontal(): def test_named_range_vertical(): - testee = "IDimᵥ: [x, y)" + testee = "IDimᵥ: [x, y[" expected = ir.FunCall( fun=ir.SymRef(id="named_range"), args=[ @@ -208,18 +208,6 @@ def test_temporary(): assert actual == expected -def test_stencil_closure(): - testee = "y ← (deref)(x) @ cartesian_domain();" - expected = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="y"), - inputs=[ir.SymRef(id="x")], - ) - actual = pparse(testee) - assert actual == expected - - def test_set_at(): testee = "y @ cartesian_domain() ← x;" expected = ir.SetAt( @@ -262,28 +250,6 @@ def test_if_stmt(): assert actual == expected -# TODO(havogt): remove after refactoring to GTIR -def test_fencil_definition(): - testee = "f(d, x, y) {\n g = λ(x) → x;\n y ← (deref)(x) @ cartesian_domain();\n}" - expected = ir.FencilDefinition( - id="f", - function_definitions=[ - ir.FunctionDefinition(id="g", params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")) - ], - 
params=[ir.Sym(id="d"), ir.Sym(id="x"), ir.Sym(id="y")], - closures=[ - ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="y"), - inputs=[ir.SymRef(id="x")], - ) - ], - ) - actual = pparse(testee) - assert actual == expected - - def test_program(): testee = "f(d, x, y) {\n g = λ(x) → x;\n tmp = temporary(domain=cartesian_domain(), dtype=float64);\n y @ cartesian_domain() ← x;\n}" expected = ir.Program( diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 69a45cf128..b0f7021bc0 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -6,9 +6,9 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from gt4py.next.iterator import ir -from gt4py.next.iterator.pretty_printer import PrettyPrinter, pformat +from gt4py.next.iterator import ir, builtins from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.pretty_printer import PrettyPrinter, pformat from gt4py.next.type_system import type_specifications as ts @@ -200,7 +200,7 @@ def test_shift(): def test_tuple_get(): testee = ir.FunCall( fun=ir.SymRef(id="tuple_get"), - args=[im.literal("42", ir.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], + args=[im.literal("42", builtins.INTEGER_INDEX_BUILTIN), ir.SymRef(id="x")], ) expected = "x[42]" actual = pformat(testee) @@ -233,7 +233,7 @@ def test_named_range_horizontal(): fun=ir.SymRef(id="named_range"), args=[ir.AxisLiteral(value="IDim"), ir.SymRef(id="x"), ir.SymRef(id="y")], ) - expected = "IDimₕ: [x, y)" + expected = "IDimₕ: [x, y[" actual = pformat(testee) assert actual == expected @@ -313,18 +313,6 @@ def test_temporary(): assert actual == expected -def test_stencil_closure(): - testee = ir.StencilClosure( - 
domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="y"), - inputs=[ir.SymRef(id="x")], - ) - expected = "y ← (deref)(x) @ cartesian_domain();" - actual = pformat(testee) - assert actual == expected - - def test_set_at(): testee = ir.SetAt( expr=ir.SymRef(id="x"), @@ -336,28 +324,6 @@ def test_set_at(): assert actual == expected -# TODO(havogt): remove after refactoring. -def test_fencil_definition(): - testee = ir.FencilDefinition( - id="f", - function_definitions=[ - ir.FunctionDefinition(id="g", params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")) - ], - params=[ir.Sym(id="d"), ir.Sym(id="x"), ir.Sym(id="y")], - closures=[ - ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="y"), - inputs=[ir.SymRef(id="x")], - ) - ], - ) - actual = pformat(testee) - expected = "f(d, x, y) {\n g = λ(x) → x;\n y ← (deref)(x) @ cartesian_domain();\n}" - assert actual == expected - - def test_program(): testee = ir.Program( id="f", diff --git a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py index 1f08362f4f..bf2df06bf2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py @@ -10,18 +10,24 @@ import pytest import gt4py.next as gtx +from gt4py.next import common from gt4py.next.iterator.builtins import deref from gt4py.next.iterator.runtime import CartesianDomain, UnstructuredDomain, _deduce_domain, fundef -from next_tests.unit_tests.conftest import DummyConnectivity - @fundef def foo(inp): return deref(inp) -connectivity = DummyConnectivity(max_neighbors=0, has_skip_values=True) +connectivity = common.ConnectivityType( + domain=[gtx.Dimension("dummy_origin"), gtx.Dimension("dummy_neighbor")], + codomain=gtx.Dimension("dummy_codomain"), + 
skip_value=common._DEFAULT_SKIP_VALUE, + dtype=None, +) + +I = gtx.Dimension("I") def test_deduce_domain(): @@ -29,15 +35,12 @@ def test_deduce_domain(): assert isinstance(_deduce_domain(UnstructuredDomain(), {}), UnstructuredDomain) assert isinstance(_deduce_domain({}, {"foo": connectivity}), UnstructuredDomain) assert isinstance( - _deduce_domain(CartesianDomain([("I", range(1))]), {"foo": connectivity}), CartesianDomain + _deduce_domain(CartesianDomain([(I, range(1))]), {"foo": connectivity}), CartesianDomain ) -I = gtx.Dimension("I") - - def test_embedded_error_on_wrong_domain(): - dom = CartesianDomain([("I", range(1))]) + dom = CartesianDomain([(I, range(1))]) out = gtx.as_field([I], np.zeros(1)) with pytest.raises(RuntimeError, match="expected 'UnstructuredDomain'"): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 05cd6b6854..6e2f941095 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -5,6 +5,7 @@ # # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause +import copy # TODO: test failure when something is not typed after inference is run # TODO: test lift with no args @@ -15,6 +16,7 @@ import pytest +from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.type_system import ( @@ -23,13 +25,12 @@ ) from gt4py.next.type_system import type_specifications as ts -from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import simple_mesh - from next_tests.integration_tests.cases import ( C2E, E2V, V2E, E2VDim, + Edge, IDim, Ioff, JDim, @@ -37,21 +38,25 @@ Koff, V2EDim, Vertex, - Edge, - mesh_descriptor, exec_alloc_descriptor, + mesh_descriptor, unstructured_case, ) +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import simple_mesh +from next_tests.integration_tests.cases import IField, JField bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) float64_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) -float64_list_type = it_ts.ListType(element_type=float64_type) -int_list_type = it_ts.ListType(element_type=int_type) +float64_list_type = ts.ListType(element_type=float64_type) +int_list_type = ts.ListType(element_type=int_type) float_i_field = ts.FieldType(dims=[IDim], dtype=float64_type) +float_j_field = ts.FieldType(dims=[JDim], dtype=float64_type) +float_ij_field = ts.FieldType(dims=[IDim, JDim], dtype=float64_type) float_vertex_k_field = ts.FieldType(dims=[Vertex, KDim], dtype=float64_type) float_edge_k_field = ts.FieldType(dims=[Edge, KDim], dtype=float64_type) +float_edge_field = ts.FieldType(dims=[Edge], dtype=float64_type) float_vertex_v2e_field = ts.FieldType(dims=[Vertex, V2EDim], dtype=float64_type) it_on_v_of_e_type = it_ts.IteratorType( @@ -75,12 +80,14 @@ def expression_test_cases(): (im.plus(1, 2), int_type), (im.eq(1, 2), bool_type), (im.deref(im.ref("it", 
it_on_e_of_e_type)), it_on_e_of_e_type.element_type), - (im.call("can_deref")(im.ref("it", it_on_e_of_e_type)), bool_type), + (im.can_deref(im.ref("it", it_on_e_of_e_type)), bool_type), (im.if_(True, 1, 2), int_type), - (im.call("make_const_list")(True), it_ts.ListType(element_type=bool_type)), - (im.call("list_get")(0, im.ref("l", it_ts.ListType(element_type=bool_type))), bool_type), + (im.call("make_const_list")(True), ts.ListType(element_type=bool_type)), + (im.list_get(0, im.ref("l", ts.ListType(element_type=bool_type))), bool_type), ( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 + ), it_ts.NamedRangeType(dim=Vertex), ), ( @@ -91,7 +98,9 @@ def expression_test_cases(): ), ( im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1) + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 + ) ), it_ts.DomainType(dims=[Vertex]), ), @@ -103,13 +112,17 @@ def expression_test_cases(): # tuple_get (im.tuple_get(0, im.make_tuple(im.ref("a", int_type), im.ref("b", bool_type))), int_type), (im.tuple_get(1, im.make_tuple(im.ref("a", int_type), im.ref("b", bool_type))), bool_type), + ( + im.tuple_get(0, im.ref("t", ts.DeferredType(constraint=None))), + ts.DeferredType(constraint=None), + ), # neighbors ( im.neighbors("E2V", im.ref("a", it_on_e_of_e_type)), - it_ts.ListType(element_type=it_on_e_of_e_type.element_type), + ts.ListType(element_type=it_on_e_of_e_type.element_type), ), # cast - (im.call("cast_")(1, "int32"), int_type), + (im.cast_(1, int_type), int_type), # TODO: lift # TODO: scan # map @@ -118,18 +131,16 @@ def expression_test_cases(): int_list_type, ), # reduce - (im.call(im.call("reduce")("plus", 0))(im.ref("l", int_list_type)), int_type), + (im.reduce("plus", 0)(im.ref("l", int_list_type)), int_type), ( - im.call( - im.call("reduce")( - 
im.lambda_("acc", "a", "b")( - im.make_tuple( - im.plus(im.tuple_get(0, "acc"), "a"), - im.plus(im.tuple_get(1, "acc"), "b"), - ) - ), - im.make_tuple(0, 0.0), - ) + im.reduce( + im.lambda_("acc", "a", "b")( + im.make_tuple( + im.plus(im.tuple_get(0, "acc"), "a"), + im.plus(im.tuple_get(1, "acc"), "b"), + ) + ), + im.make_tuple(0, 0.0), )(im.ref("la", int_list_type), im.ref("lb", float64_list_type)), ts.TupleType(types=[int_type, float64_type]), ), @@ -138,58 +149,60 @@ def expression_test_cases(): (im.shift("Ioff", 1)(im.ref("it", it_ijk_type)), it_ijk_type), # as_fieldop ( - im.call( - im.call("as_fieldop")( - "deref", - im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) - ), - ) + im.as_fieldop( + "deref", + im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ), )(im.ref("inp", float_i_field)), float_i_field, ), ( - im.call( - im.call("as_fieldop")( - im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), - im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), - im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), + im.as_fieldop( + im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), + im.call("unstructured_domain")( + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), + 0, + 1, ), - ) + im.call("named_range")( + itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 + ), + ), )(im.ref("inp", float_edge_k_field)), float_vertex_k_field, ), ( - im.call( - im.call("as_fieldop")( - im.lambda_("a", "b")(im.make_tuple(im.deref("a"), im.deref("b"))), - im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) - ), - ) + im.as_fieldop( + im.lambda_("a", "b")(im.make_tuple(im.deref("a"), im.deref("b"))), + im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ), )(im.ref("inp1", float_i_field), 
im.ref("inp2", float_i_field)), ts.TupleType(types=[float_i_field, float_i_field]), ), + ( + im.as_fieldop(im.lambda_("x")(im.deref("x")))( + im.ref("inp", ts.DeferredType(constraint=None)) + ), + ts.DeferredType(constraint=None), + ), # if in field-view scope ( im.if_( False, - im.call( - im.call("as_fieldop")( - im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), - im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) - ), - ) + im.as_fieldop( + im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), + im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ), )(im.ref("inp", float_i_field), 1.0), - im.call( - im.call("as_fieldop")( - "deref", - im.call("cartesian_domain")( - im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) - ), - ) + im.as_fieldop( + "deref", + im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ), )(im.ref("inp", float_i_field)), ), float_i_field, @@ -208,11 +221,11 @@ def expression_test_cases(): @pytest.mark.parametrize("test_case", expression_test_cases()) def test_expression_type(test_case): mesh = simple_mesh() - offset_provider = {**mesh.offset_provider, "Ioff": IDim, "Joff": JDim, "Koff": KDim} + offset_provider_type = {**mesh.offset_provider_type, "Ioff": IDim, "Joff": JDim, "Koff": KDim} testee, expected_type = test_case result = itir_type_inference.infer( - testee, offset_provider=offset_provider, allow_undeclared_symbols=True + testee, offset_provider_type=offset_provider_type, allow_undeclared_symbols=True ) assert result.type == expected_type @@ -221,18 +234,41 @@ def test_adhoc_polymorphism(): func = im.lambda_("a")(im.lambda_("b")(im.make_tuple("a", "b"))) testee = im.call(im.call(func)(im.ref("a_", bool_type)))(im.ref("b_", int_type)) - result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + result = itir_type_inference.infer( + testee, 
offset_provider_type={}, allow_undeclared_symbols=True + ) assert result.type == ts.TupleType(types=[bool_type, int_type]) +def test_binary_lambda(): + func = im.lambda_("a", "b")(im.make_tuple("a", "b")) + testee = im.call(func)(im.ref("a_", bool_type), im.ref("b_", int_type)) + + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + + expected_type = ts.TupleType(types=[bool_type, int_type]) + assert result.type == expected_type + assert result.fun.params[0].type == bool_type + assert result.fun.params[1].type == int_type + assert result.fun.type == ts.FunctionType( + pos_only_args=[bool_type, int_type], + pos_or_kw_args={}, + kw_only_args={}, + returns=expected_type, + ) + + def test_aliased_function(): testee = im.let("f", im.lambda_("x")("x"))(im.call("f")(1)) - result = itir_type_inference.infer(testee, offset_provider={}) + result = itir_type_inference.infer(testee, offset_provider_type={}) assert result.args[0].type == ts.FunctionType( pos_only_args=[int_type], pos_or_kw_args={}, kw_only_args={}, returns=int_type ) + assert result.args[0].params[0].type == int_type assert result.type == int_type @@ -243,7 +279,7 @@ def test_late_offset_axis(): testee = im.call(func)(im.ensure_offset("V2E")) result = itir_type_inference.infer( - testee, offset_provider=mesh.offset_provider, allow_undeclared_symbols=True + testee, offset_provider_type=mesh.offset_provider_type, allow_undeclared_symbols=True ) assert result.type == it_on_e_of_e_type @@ -252,104 +288,81 @@ def test_cast_first_arg_inference(): # since cast_ is a grammar builtin whose return type is given by its second argument it is # easy to forget inferring the types of the first argument and its children. Simply check # if the first argument has a type inferred correctly here. 
- testee = im.call("cast_")( - im.plus(im.literal_from_value(1), im.literal_from_value(2)), "float64" + testee = im.cast_(im.plus(im.literal_from_value(1), im.literal_from_value(2)), "float64") + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True ) - result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) assert result.args[0].type == int_type assert result.type == float64_type -# TODO(tehrengruber): Rewrite tests to use itir.Program def test_cartesian_fencil_definition(): cartesian_domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) ) - testee = itir.FencilDefinition( + testee = itir.Program( id="f", function_definitions=[], params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.ref("deref"), cartesian_domain)(im.ref("inp")), domain=cartesian_domain, - stencil=im.ref("deref"), - output=im.ref("out"), - inputs=[im.ref("inp")], + target=im.ref("out"), ), ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) - closure_type = it_ts.StencilClosureType( - domain=it_ts.DomainType(dims=[IDim]), - stencil=ts.FunctionType( - pos_only_args=[ - it_ts.IteratorType( - position_dims=[IDim], defined_dims=[IDim], element_type=float64_type - ) - ], - pos_or_kw_args={}, - kw_only_args={}, - returns=float64_type, - ), - output=float_i_field, - inputs=[float_i_field], - ) - fencil_type = it_ts.FencilType( - params={"inp": float_i_field, "out": float_i_field}, closures=[closure_type] - ) - assert result.type == fencil_type - assert result.closures[0].type == closure_type + program_type = it_ts.ProgramType(params={"inp": float_i_field, "out": float_i_field}) + assert result.type == program_type + domain_type = 
it_ts.DomainType(dims=[IDim]) + assert result.body[0].domain.type == domain_type + assert result.body[0].expr.type == float_i_field + assert result.body[0].target.type == float_i_field def test_unstructured_fencil_definition(): mesh = simple_mesh() unstructured_domain = im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), - im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 + ), + im.call("named_range")( + itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 + ), ) - testee = itir.FencilDefinition( + testee = itir.Program( id="f", function_definitions=[], params=[im.sym("inp", float_edge_k_field), im.sym("out", float_vertex_k_field)], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop( + im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), unstructured_domain + )(im.ref("inp")), domain=unstructured_domain, - stencil=im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), - output=im.ref("out"), - inputs=[im.ref("inp")], + target=im.ref("out"), ), ], ) - result = itir_type_inference.infer(testee, offset_provider=mesh.offset_provider) + result = itir_type_inference.infer(testee, offset_provider_type=mesh.offset_provider_type) - closure_type = it_ts.StencilClosureType( - domain=it_ts.DomainType(dims=[Vertex, KDim]), - stencil=ts.FunctionType( - pos_only_args=[ - it_ts.IteratorType( - position_dims=[Vertex, KDim], - defined_dims=[Edge, KDim], - element_type=float64_type, - ) - ], - pos_or_kw_args={}, - kw_only_args={}, - returns=float64_type, - ), - output=float_vertex_k_field, - inputs=[float_edge_k_field], - ) - fencil_type = it_ts.FencilType( - params={"inp": float_edge_k_field, "out": float_vertex_k_field}, closures=[closure_type] + program_type = it_ts.ProgramType( + params={"inp": float_edge_k_field, "out": float_vertex_k_field} ) - assert 
result.type == fencil_type - assert result.closures[0].type == closure_type + assert result.type == program_type + domain_type = it_ts.DomainType(dims=[Vertex, KDim]) + assert result.body[0].domain.type == domain_type + assert result.body[0].expr.type == float_vertex_k_field + assert result.body[0].target.type == float_vertex_k_field def test_function_definition(): @@ -357,72 +370,64 @@ def test_function_definition(): im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) ) - testee = itir.FencilDefinition( + testee = itir.Program( id="f", function_definitions=[ itir.FunctionDefinition(id="foo", params=[im.sym("it")], expr=im.deref("it")), itir.FunctionDefinition(id="bar", params=[im.sym("it")], expr=im.call("foo")("it")), ], params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( domain=cartesian_domain, - stencil=im.ref("bar"), - output=im.ref("out"), - inputs=[im.ref("inp")], + expr=im.as_fieldop(im.ref("bar"), cartesian_domain)(im.ref("inp")), + target=im.ref("out"), ), ], ) - result = itir_type_inference.infer(testee, offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) - closure_type = it_ts.StencilClosureType( - domain=it_ts.DomainType(dims=[IDim]), - stencil=ts.FunctionType( - pos_only_args=[ - it_ts.IteratorType( - position_dims=[IDim], defined_dims=[IDim], element_type=float64_type - ) - ], - pos_or_kw_args={}, - kw_only_args={}, - returns=float64_type, - ), - output=float_i_field, - inputs=[float_i_field], - ) - fencil_type = it_ts.FencilType( - params={"inp": float_i_field, "out": float_i_field}, closures=[closure_type] - ) - assert result.type == fencil_type - assert result.closures[0].type == closure_type + program_type = it_ts.ProgramType(params={"inp": float_i_field, "out": float_i_field}) + assert result.type == program_type + assert result.body[0].expr.type == float_i_field + assert 
result.body[0].target.type == float_i_field def test_fencil_with_nb_field_input(): mesh = simple_mesh() unstructured_domain = im.call("unstructured_domain")( - im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), - im.call("named_range")(itir.AxisLiteral(value="KDim"), 0, 1), + im.call("named_range")( + itir.AxisLiteral(value="Vertex", kind=common.DimensionKind.HORIZONTAL), 0, 1 + ), + im.call("named_range")( + itir.AxisLiteral(value="KDim", kind=common.DimensionKind.VERTICAL), 0, 1 + ), ) - testee = itir.FencilDefinition( + testee = itir.Program( id="f", function_definitions=[], params=[im.sym("inp", float_vertex_v2e_field), im.sym("out", float_vertex_k_field)], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( domain=unstructured_domain, - stencil=im.lambda_("it")(im.call(im.call("reduce")("plus", 0.0))(im.deref("it"))), - output=im.ref("out"), - inputs=[im.ref("inp")], + expr=im.as_fieldop( + im.lambda_("it")(im.reduce("plus", 0.0)(im.deref("it"))), + unstructured_domain, + )(im.ref("inp")), + target=im.ref("out"), ), ], ) - result = itir_type_inference.infer(testee, offset_provider=mesh.offset_provider) + result = itir_type_inference.infer(testee, offset_provider_type=mesh.offset_provider_type) - assert result.closures[0].stencil.expr.args[0].type == float64_list_type - assert result.closures[0].stencil.type.returns == float64_type + stencil = result.body[0].expr.fun.args[0] + assert stencil.expr.args[0].type == float64_list_type + assert stencil.type.returns == float64_type def test_program_tuple_setat_short_target(): @@ -437,16 +442,14 @@ def test_program_tuple_setat_short_target(): declarations=[], body=[ itir.SetAt( - expr=im.call( - im.call("as_fieldop")(im.lambda_()(im.make_tuple(1.0, 2.0)), cartesian_domain) - )(), + expr=im.as_fieldop(im.lambda_()(im.make_tuple(1.0, 2.0)), cartesian_domain)(), domain=cartesian_domain, target=im.make_tuple("out"), ) ], ) - result = itir_type_inference.infer(testee, 
offset_provider={"Ioff": IDim}) + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) assert ( isinstance(result.body[0].expr.type, ts.TupleType) @@ -458,6 +461,33 @@ def test_program_tuple_setat_short_target(): ) +def test_program_setat_without_domain(): + cartesian_domain = im.call("cartesian_domain")( + im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) + ) + + testee = itir.Program( + id="f", + function_definitions=[], + params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(im.lambda_("x")(im.deref("x")))("inp"), + domain=cartesian_domain, + target=im.ref("out", float_i_field), + ) + ], + ) + + result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) + + assert ( + isinstance(result.body[0].expr.type, ts.DeferredType) + and result.body[0].expr.type.constraint == ts.FieldType + ) + + def test_if_stmt(): cartesian_domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) @@ -475,6 +505,39 @@ def test_if_stmt(): false_branch=[], ) - result = itir_type_inference.infer(testee, offset_provider={}, allow_undeclared_symbols=True) + result = itir_type_inference.infer( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) assert result.cond.type == bool_type assert result.true_branch[0].expr.type == float_i_field + + +def test_as_fieldop_without_domain(): + testee = im.as_fieldop(im.lambda_("it")(im.deref(im.shift("IOff", 1)("it"))))( + im.ref("inp", float_i_field) + ) + result = itir_type_inference.infer( + testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True + ) + assert result.type == ts.DeferredType(constraint=ts.FieldType) + assert result.fun.args[0].type.pos_only_args[0] == it_ts.IteratorType( + position_dims="unknown", defined_dims=float_i_field.dims, element_type=float_i_field.dtype + ) + + +def test_reinference(): + testee = 
im.make_tuple(im.ref("inp1", float_i_field), im.ref("inp2", float_i_field)) + result = itir_type_inference.reinfer(copy.deepcopy(testee)) + assert result.type == ts.TupleType(types=[float_i_field, float_i_field]) + + +def test_func_reinference(): + f_type = ts.FunctionType( + pos_only_args=[], + pos_or_kw_args={}, + kw_only_args={}, + returns=float_i_field, + ) + testee = im.call(im.ref("f", f_type))() + result = itir_type_inference.reinfer(copy.deepcopy(testee)) + assert result.type == float_i_field diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py index bcf8b726be..916ae4e578 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_collapse_tuple.py @@ -8,6 +8,8 @@ from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple +from gt4py.next.type_system import type_specifications as ts +from next_tests.unit_tests.iterator_tests.test_type_inference import int_type def test_simple_make_tuple_tuple_get(): @@ -17,8 +19,9 @@ def test_simple_make_tuple_tuple_get(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + enabled_transformations=CollapseTuple.Transformation.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) expected = tuple_of_size_2 @@ -34,8 +37,9 @@ def test_nested_make_tuple_tuple_get(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + enabled_transformations=CollapseTuple.Transformation.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == tup_of_size2_from_lambda @@ -49,8 +53,9 @@ def 
test_different_tuples_make_tuple_tuple_get(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + enabled_transformations=CollapseTuple.Transformation.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == testee # did nothing @@ -62,8 +67,9 @@ def test_incompatible_order_make_tuple_tuple_get(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + enabled_transformations=CollapseTuple.Transformation.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == testee # did nothing @@ -73,8 +79,9 @@ def test_incompatible_size_make_tuple_tuple_get(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + enabled_transformations=CollapseTuple.Transformation.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == testee # did nothing @@ -84,8 +91,9 @@ def test_merged_with_smaller_outer_size_make_tuple_tuple_get(): actual = CollapseTuple.apply( testee, ignore_tuple_size=True, - flags=CollapseTuple.Flag.COLLAPSE_MAKE_TUPLE_TUPLE_GET, + enabled_transformations=CollapseTuple.Transformation.COLLAPSE_MAKE_TUPLE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == im.make_tuple("first", "second") @@ -96,8 +104,9 @@ def test_simple_tuple_get_make_tuple(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.COLLAPSE_TUPLE_GET_MAKE_TUPLE, + enabled_transformations=CollapseTuple.Transformation.COLLAPSE_TUPLE_GET_MAKE_TUPLE, allow_undeclared_symbols=True, + within_stencil=False, ) assert expected == actual @@ -108,8 +117,9 @@ def test_propagate_tuple_get(): actual = CollapseTuple.apply( testee, 
remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.PROPAGATE_TUPLE_GET, + enabled_transformations=CollapseTuple.Transformation.PROPAGATE_TUPLE_GET, allow_undeclared_symbols=True, + within_stencil=False, ) assert expected == actual @@ -118,15 +128,16 @@ def test_letify_make_tuple_elements(): # anything that is not trivial, i.e. a SymRef, works here el1, el2 = im.let("foo", "foo")("foo"), im.let("bar", "bar")("bar") testee = im.make_tuple(el1, el2) - expected = im.let(("_tuple_el_1", el1), ("_tuple_el_2", el2))( - im.make_tuple("_tuple_el_1", "_tuple_el_2") + expected = im.let(("__ct_el_1", el1), ("__ct_el_2", el2))( + im.make_tuple("__ct_el_1", "__ct_el_2") ) actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + enabled_transformations=CollapseTuple.Transformation.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -138,8 +149,9 @@ def test_letify_make_tuple_with_trivial_elements(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + enabled_transformations=CollapseTuple.Transformation.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -151,8 +163,9 @@ def test_inline_trivial_make_tuple(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.INLINE_TRIVIAL_MAKE_TUPLE, + enabled_transformations=CollapseTuple.Transformation.INLINE_TRIVIAL_MAKE_TUPLE, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -169,8 +182,9 @@ def test_propagate_to_if_on_tuples(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, + enabled_transformations=CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, 
allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -185,9 +199,10 @@ def test_propagate_to_if_on_tuples_with_let(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=True, - flags=CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES - | CollapseTuple.Flag.LETIFY_MAKE_TUPLE_ELEMENTS, + enabled_transformations=CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES + | CollapseTuple.Transformation.LETIFY_MAKE_TUPLE_ELEMENTS, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -198,8 +213,9 @@ def test_propagate_nested_lift(): actual = CollapseTuple.apply( testee, remove_letified_make_tuple_elements=False, - flags=CollapseTuple.Flag.PROPAGATE_NESTED_LET, + enabled_transformations=CollapseTuple.Transformation.PROPAGATE_NESTED_LET, allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected @@ -210,6 +226,88 @@ def test_if_on_tuples_with_let(): )(im.tuple_get(0, "val")) expected = im.if_("pred", 1, 3) actual = CollapseTuple.apply( - testee, remove_letified_make_tuple_elements=False, allow_undeclared_symbols=True + testee, + remove_letified_make_tuple_elements=False, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_tuple_get_on_untyped_ref(): + # test pass gracefully handles untyped nodes. 
+ testee = im.tuple_get(0, im.ref("val", ts.DeferredType(constraint=None))) + + actual = CollapseTuple.apply(testee, allow_undeclared_symbols=True, within_stencil=False) + assert actual == testee + + +def test_if_make_tuple_reorder_cps(): + testee = im.let("t", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.make_tuple(im.tuple_get(1, "t"), im.tuple_get(0, "t")) + ) + expected = im.if_(True, im.make_tuple(2, 1), im.make_tuple(4, 3)) + actual = CollapseTuple.apply( + testee, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_nested_if_make_tuple_reorder_cps(): + testee = im.let( + ("t1", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4))), + ("t2", im.if_(False, im.make_tuple(5, 6), im.make_tuple(7, 8))), + )( + im.make_tuple( + im.tuple_get(1, "t1"), + im.tuple_get(0, "t1"), + im.tuple_get(1, "t2"), + im.tuple_get(0, "t2"), + ) + ) + expected = im.if_( + True, + im.if_(False, im.make_tuple(2, 1, 6, 5), im.make_tuple(2, 1, 8, 7)), + im.if_(False, im.make_tuple(4, 3, 6, 5), im.make_tuple(4, 3, 8, 7)), + ) + actual = CollapseTuple.apply( + testee, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_if_make_tuple_reorder_cps_nested(): + testee = im.let("t", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.let("c", im.tuple_get(0, "t"))( + im.make_tuple(im.tuple_get(1, "t"), im.tuple_get(0, "t"), "c") + ) + ) + expected = im.if_(True, im.make_tuple(2, 1, 1), im.make_tuple(4, 3, 3)) + actual = CollapseTuple.apply( + testee, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, + ) + assert actual == expected + + +def test_if_make_tuple_reorder_cps_external(): + external_ref = im.tuple_get(0, 
im.ref("external", ts.TupleType(types=[int_type]))) + testee = im.let("t", im.if_(True, im.make_tuple(1, 2), im.make_tuple(3, 4)))( + im.make_tuple(external_ref, im.tuple_get(1, "t"), im.tuple_get(0, "t")) + ) + expected = im.if_(True, im.make_tuple(external_ref, 2, 1), im.make_tuple(external_ref, 4, 3)) + actual = CollapseTuple.apply( + testee, + enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES, + allow_undeclared_symbols=True, + within_stencil=False, ) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py index 0bf8dcb65d..1da2b8cec5 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_constant_folding.py @@ -9,54 +9,155 @@ from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.constant_folding import ConstantFolding - -def test_constant_folding_boolean(): - testee = im.not_(im.literal_from_value(True)) - expected = im.literal_from_value(False) - - actual = ConstantFolding.apply(testee) - assert actual == expected +import pytest +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding -def test_constant_folding_math_op(): - expected = im.literal_from_value(13) - testee = im.plus( - im.literal_from_value(4), - im.plus( - im.literal_from_value(7), im.minus(im.literal_from_value(7), im.literal_from_value(5)) +def test_cases(): + return ( + # expr, simplified expr + (im.plus(1, 1), 2), + (im.not_(True), False), + (im.plus(4, im.plus(7, im.minus(7, 5))), 13), + (im.if_(True, im.plus(im.ref("a"), 2), im.minus(9, 5)), im.plus("a", 2)), + (im.minimum("a", "a"), "a"), + (im.maximum(1, 2), 2), + # canonicalization + (im.plus("a", 1), im.plus("a", 1)), + 
(im.plus(1, "a"), im.plus("a", 1)), + # nested plus + (im.plus(im.plus("a", 1), 1), im.plus("a", 2)), + (im.plus(1, im.plus("a", 1)), im.plus("a", 2)), + # nested maximum + (im.maximum(im.maximum("a", 1), 1), im.maximum("a", 1)), + (im.maximum(im.maximum(1, "a"), 1), im.maximum("a", 1)), + (im.maximum("a", im.maximum(1, "a")), im.maximum("a", 1)), + (im.maximum(im.maximum(1, "a"), im.maximum(1, "a")), im.maximum("a", 1)), + (im.maximum(im.maximum(1, "a"), im.maximum("a", 1)), im.maximum("a", 1)), + (im.maximum(im.minimum("a", 1), "a"), im.maximum(im.minimum("a", 1), "a")), + # maximum & plus + (im.maximum(im.plus("a", 1), im.plus("a", 0)), im.plus("a", 1)), + ( + im.maximum(im.plus("a", 1), im.plus(im.plus("a", 1), 0)), + im.plus("a", 1), + ), + (im.maximum("a", im.plus("a", 1)), im.plus("a", 1)), + (im.maximum("a", im.plus("a", im.literal_from_value(-1))), im.ref("a")), + ( + im.plus("a", im.maximum(0, im.literal_from_value(-1))), + im.ref("a"), + ), + # plus & minus + (im.minus(im.plus("a", 1), im.plus(1, 1)), im.minus("a", 1)), + (im.plus(im.minus("a", 1), 2), im.plus("a", 1)), + (im.plus(im.minus(1, "a"), 1), im.minus(2, "a")), + # nested plus + (im.plus(im.plus("a", 1), im.plus(1, 1)), im.plus("a", 3)), + ( + im.plus(im.plus("a", im.literal_from_value(-1)), im.plus("a", 3)), + im.plus(im.minus("a", 1), im.plus("a", 3)), + ), + # maximum & minus + (im.maximum(im.minus("a", 1), "a"), im.ref("a")), + (im.maximum("a", im.minus("a", im.literal_from_value(-1))), im.plus("a", 1)), + ( + im.maximum(im.plus("a", im.literal_from_value(-1)), 1), + im.maximum(im.minus("a", 1), 1), + ), + # minimum & plus & minus + (im.minimum(im.plus("a", 1), "a"), im.ref("a")), + (im.minimum("a", im.plus("a", im.literal_from_value(-1))), im.minus("a", 1)), + (im.minimum(im.minus("a", 1), "a"), im.minus("a", 1)), + (im.minimum("a", im.minus("a", im.literal_from_value(-1))), im.ref("a")), + # nested maximum + (im.maximum("a", im.maximum("b", "a")), im.maximum("b", "a")), + # maximum & plus 
on complicated expr (tuple_get) + ( + im.maximum( + im.plus(im.tuple_get(1, "a"), 1), + im.maximum(im.tuple_get(1, "a"), im.plus(im.tuple_get(1, "a"), 1)), + ), + im.plus(im.tuple_get(1, "a"), 1), + ), + # nested maximum & plus + ( + im.maximum(im.maximum(im.plus(1, "a"), 1), im.plus(1, "a")), + im.maximum(im.plus("a", 1), 1), + ), + # sanity check that no strange things happen + # complex tests + ( + # 1 - max(max(1, max(1, sym), min(1, sym), sym), 1 + (min(-1, 2) + max(-1, 1 - sym))) + im.minus( + 1, + im.maximum( + im.maximum( + im.maximum(1, im.maximum(1, "a")), + im.maximum(im.maximum(1, "a"), "a"), + ), + im.plus( + 1, + im.plus( + im.minimum(im.literal_from_value(-1), 2), + im.maximum(im.literal_from_value(-1), im.minus(1, "a")), + ), + ), + ), + ), + # 1 - maximum(maximum(sym, 1), maximum(1 - sym, -1)) + im.minus( + 1, + im.maximum( + im.maximum("a", 1), + im.maximum(im.minus(1, "a"), im.literal_from_value(-1)), + ), + ), + ), + ( + # maximum(sym, 1 + sym) + (maximum(1, maximum(1, sym)) + (sym - 1 + (1 + (sym + 1) + 1))) - 2 + im.minus( + im.plus( + im.maximum("a", im.plus(1, "a")), + im.plus( + im.maximum(1, im.maximum(1, "a")), + im.plus(im.minus("a", 1), im.plus(im.plus(1, im.plus("a", 1)), 1)), + ), + ), + 2, + ), + # sym + 1 + (maximum(sym, 1) + (sym - 1 + (sym + 3))) - 2 + im.minus( + im.plus( + im.plus("a", 1), + im.plus( + im.maximum("a", 1), + im.plus(im.minus("a", 1), im.plus("a", 3)), + ), + ), + 2, + ), + ), + ( + # minimum(1 - sym, 1 + sym) + (maximum(maximum(1 - sym, 1 + sym), 1 - sym) + maximum(1 - sym, 1 - sym)) + im.plus( + im.minimum(im.minus(1, "a"), im.plus(1, "a")), + im.plus( + im.maximum(im.maximum(im.minus(1, "a"), im.plus(1, "a")), im.minus(1, "a")), + im.maximum(im.minus(1, "a"), im.minus(1, "a")), + ), + ), + # minimum(1 - sym, sym + 1) + (maximum(1 - sym, sym + 1) + (1 - sym)) + im.plus( + im.minimum(im.minus(1, "a"), im.plus("a", 1)), + im.plus(im.maximum(im.minus(1, "a"), im.plus("a", 1)), im.minus(1, "a")), + ), ), ) - actual 
= ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_if(): - expected = im.call("plus")("a", 2) - testee = im.if_( - im.literal_from_value(True), - im.plus(im.ref("a"), im.literal_from_value(2)), - im.minus(im.literal_from_value(9), im.literal_from_value(5)), - ) - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_minimum(): - testee = im.call("minimum")("a", "a") - expected = im.ref("a") - actual = ConstantFolding.apply(testee) - assert actual == expected - - -def test_constant_folding_literal(): - testee = im.plus(im.literal_from_value(1), im.literal_from_value(2)) - expected = im.literal_from_value(3) - actual = ConstantFolding.apply(testee) - assert actual == expected -def test_constant_folding_literal_maximum(): - testee = im.call("maximum")(im.literal_from_value(1), im.literal_from_value(2)) - expected = im.literal_from_value(2) +@pytest.mark.parametrize("test_case", test_cases()) +def test_constant_folding(test_case): + testee, expected = test_case actual = ConstantFolding.apply(testee) - assert actual == expected + assert actual == im.ensure_expr(expected) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py index 78f95da8ca..3909c6f26a 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_cse.py @@ -21,23 +21,15 @@ @pytest.fixture -def offset_provider(request): +def offset_provider_type(request): return {"I": common.Dimension("I", kind=common.DimensionKind.HORIZONTAL)} def test_trivial(): - common = ir.FunCall(fun=ir.SymRef(id="plus"), args=[ir.SymRef(id="x"), ir.SymRef(id="y")]) - testee = ir.FunCall(fun=ir.SymRef(id="plus"), args=[common, common]) - expected = ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="_cs_1")], - expr=ir.FunCall( - fun=ir.SymRef(id="plus"), 
args=[ir.SymRef(id="_cs_1"), ir.SymRef(id="_cs_1")] - ), - ), - args=[common], - ) - actual = CSE.apply(testee, is_local_view=True) + common = im.plus("x", "y") + testee = im.plus(common, common) + expected = im.let("_cs_1", common)(im.plus("_cs_1", "_cs_1")) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -45,7 +37,7 @@ def test_lambda_capture(): common = ir.FunCall(fun=ir.SymRef(id="plus"), args=[ir.SymRef(id="x"), ir.SymRef(id="y")]) testee = ir.FunCall(fun=ir.Lambda(params=[ir.Sym(id="x")], expr=common), args=[common]) expected = testee - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -53,7 +45,7 @@ def test_lambda_no_capture(): common = im.plus("x", "y") testee = im.call(im.lambda_("z")(im.plus("x", "y")))(im.plus("x", "y")) expected = im.let("_cs_1", common)("_cs_1") - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -65,7 +57,7 @@ def common_expr(): testee = im.call(im.lambda_("x", "y")(common_expr()))(common_expr(), common_expr()) # (λ(_cs_1) → _cs_1 + _cs_1)(x + y) expected = im.let("_cs_1", common_expr())(im.plus("_cs_1", "_cs_1")) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -74,18 +66,12 @@ def common_expr(): return im.plus("x", "x") # λ(x) → (λ(y) → y + (x + x + (x + x)))(z) - testee = im.lambda_("x")( - im.call(im.lambda_("y")(im.plus("y", im.plus(common_expr(), common_expr()))))("z") - ) - # λ(x) → (λ(_cs_1) → (λ(y) → y + (_cs_1 + _cs_1))(z))(x + x) + testee = im.lambda_("x")(im.let("y", "z")(im.plus("y", im.plus(common_expr(), common_expr())))) + # λ(x) → (λ(_cs_1) → z + (_cs_1 + _cs_1))(x + x) expected = im.lambda_("x")( - im.call( - im.lambda_("_cs_1")( - im.call(im.lambda_("y")(im.plus("y", im.plus("_cs_1", "_cs_1"))))("z") - ) - )(common_expr()) + im.let("_cs_1", 
common_expr())(im.plus("z", im.plus("_cs_1", "_cs_1"))) ) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -99,7 +85,7 @@ def common_expr(): ) # (λ(_cs_1) → _cs_1(2) + _cs_1(3))(λ(a) → a + 1) expected = im.let("_cs_1", common_expr())(im.plus(im.call("_cs_1")(2), im.call("_cs_1")(3))) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -115,7 +101,7 @@ def common_expr(): expected = im.let("_cs_1", common_expr())( im.let("_cs_2", im.call("_cs_1")(2))(im.plus("_cs_2", "_cs_2")) ) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected @@ -139,17 +125,17 @@ def common_expr(): ) ) ) - actual = CSE.apply(testee, is_local_view=True) + actual = CSE.apply(testee, within_stencil=True) assert actual == expected -def test_if_can_deref_no_extraction(offset_provider): +def test_if_can_deref_no_extraction(offset_provider_type): # Test that a subexpression only occurring in one branch of an `if_` is not moved outside the # if statement. A case using `can_deref` is used here as it is common. 
# if can_deref(⟪Iₒ, 1ₒ⟫(it)) then ·⟪Iₒ, 1ₒ⟫(it) + ·⟪Iₒ, 1ₒ⟫(it) else 1 testee = im.if_( - im.call("can_deref")(im.shift("I", 1)("it")), + im.can_deref(im.shift("I", 1)("it")), im.plus(im.deref(im.shift("I", 1)("it")), im.deref(im.shift("I", 1)("it"))), # use something more involved where a subexpression can still be eliminated im.literal("1", "int32"), @@ -157,38 +143,38 @@ def test_if_can_deref_no_extraction(offset_provider): # (λ(_cs_1) → if can_deref(_cs_1) then (λ(_cs_2) → _cs_2 + _cs_2)(·_cs_1) else 1)(⟪Iₒ, 1ₒ⟫(it)) expected = im.let("_cs_1", im.shift("I", 1)("it"))( im.if_( - im.call("can_deref")("_cs_1"), + im.can_deref("_cs_1"), im.let("_cs_2", im.deref("_cs_1"))(im.plus("_cs_2", "_cs_2")), im.literal("1", "int32"), ) ) - actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True) + actual = CSE.apply(testee, offset_provider_type=offset_provider_type, within_stencil=True) assert actual == expected -def test_if_can_deref_eligible_extraction(offset_provider): +def test_if_can_deref_eligible_extraction(offset_provider_type): # Test that a subexpression only occurring in both branches of an `if_` is moved outside the # if statement. A case using `can_deref` is used here as it is common. 
# if can_deref(⟪Iₒ, 1ₒ⟫(it)) then ·⟪Iₒ, 1ₒ⟫(it) else ·⟪Iₒ, 1ₒ⟫(it) + ·⟪Iₒ, 1ₒ⟫(it) testee = im.if_( - im.call("can_deref")(im.shift("I", 1)("it")), + im.can_deref(im.shift("I", 1)("it")), im.deref(im.shift("I", 1)("it")), im.plus(im.deref(im.shift("I", 1)("it")), im.deref(im.shift("I", 1)("it"))), ) # (λ(_cs_3) → (λ(_cs_1) → if can_deref(_cs_3) then _cs_1 else _cs_1 + _cs_1)(·_cs_3))(⟪Iₒ, 1ₒ⟫(it)) expected = im.let("_cs_3", im.shift("I", 1)("it"))( im.let("_cs_1", im.deref("_cs_3"))( - im.if_(im.call("can_deref")("_cs_3"), "_cs_1", im.plus("_cs_1", "_cs_1")) + im.if_(im.can_deref("_cs_3"), "_cs_1", im.plus("_cs_1", "_cs_1")) ) ) - actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True) + actual = CSE.apply(testee, offset_provider_type=offset_provider_type, within_stencil=True) assert actual == expected -def test_if_eligible_extraction(offset_provider): +def test_if_eligible_extraction(offset_provider_type): # Test that a subexpression only occurring in the condition of an `if_` is moved outside the # if statement. 
@@ -197,7 +183,7 @@ def test_if_eligible_extraction(offset_provider): # (λ(_cs_1) → if _cs_1 ∧ _cs_1 then c else d)(a ∧ b) expected = im.let("_cs_1", im.and_("a", "b"))(im.if_(im.and_("_cs_1", "_cs_1"), "c", "d")) - actual = CSE.apply(testee, offset_provider=offset_provider, is_local_view=True) + actual = CSE.apply(testee, offset_provider_type=offset_provider_type, within_stencil=True) assert actual == expected @@ -274,7 +260,7 @@ def test_no_extraction_outside_asfieldop(): identity_fieldop(im.ref("a", field_type)), identity_fieldop(im.ref("b", field_type)) ) - actual = CSE.apply(testee, is_local_view=False) + actual = CSE.apply(testee, within_stencil=False) assert actual == testee @@ -295,5 +281,15 @@ def test_field_extraction_outside_asfieldop(): # ) expected = im.let("_cs_1", identity_fieldop(field))(plus_fieldop("_cs_1", "_cs_1")) - actual = CSE.apply(testee, is_local_view=False) + actual = CSE.apply(testee, within_stencil=False) + assert actual == expected + + +def test_scalar_extraction_inside_as_fieldop(): + common = im.plus(1, 2) + + testee = im.as_fieldop(im.lambda_()(im.plus(common, common)))() + expected = im.as_fieldop(im.lambda_()(im.let("_cs_1", common)(im.plus("_cs_1", "_cs_1"))))() + + actual = CSE.apply(testee, within_stencil=False) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 79456e4d85..86cc8a6773 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -8,20 +8,24 @@ # TODO(SF-N): test scan operator -import pytest +from typing import Iterable, Literal, Optional, Union + import numpy as np -from typing import Iterable, Optional, Literal, Union +import pytest from gt4py import eve -from gt4py.next.iterator.ir_utils import common_pattern_matcher 
as cpm, ir_makers as im +from gt4py.next import common, constructors, utils +from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.transforms import infer_domain -from gt4py.next.iterator.transforms.global_tmps import SymbolicDomain -from gt4py.next.common import Dimension -from gt4py.next import common, NeighborTableOffsetProvider -from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding -from gt4py.next import utils +from gt4py.next.type_system import type_specifications as ts + float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) @@ -29,6 +33,7 @@ KDim = common.Dimension(value="KDim", kind=common.DimensionKind.VERTICAL) Vertex = common.Dimension(value="Vertex", kind=common.DimensionKind.HORIZONTAL) Edge = common.Dimension(value="Edge", kind=common.DimensionKind.HORIZONTAL) +E2VDim = common.Dimension(value="E2V", kind=common.DimensionKind.LOCAL) @pytest.fixture @@ -39,11 +44,10 @@ def offset_provider(): @pytest.fixture def unstructured_offset_provider(): return { - "E2V": NeighborTableOffsetProvider( - np.array([[0, 1]], dtype=np.int32), - Edge, - Vertex, - 2, + "E2V": constructors.as_connectivity( + domain={Edge: 1, E2VDim: 2}, + codomain=Vertex, + data=np.array([[0, 1]], dtype=np.int32), ) } @@ -72,7 +76,7 @@ def setup_test_as_fieldop( def run_test_program( testee: itir.Program, expected: itir.Program, offset_provider: common.OffsetProvider ) -> None: - actual_program = infer_domain.infer_program(testee, offset_provider) + actual_program = infer_domain.infer_program(testee, offset_provider=offset_provider) folded_program = constant_fold_domain_exprs(actual_program) assert folded_program == expected @@ -84,9 +88,15 @@ def run_test_expr( 
domain: itir.FunCall, expected_domains: dict[str, itir.Expr | dict[str | Dimension, tuple[itir.Expr, itir.Expr]]], offset_provider: common.OffsetProvider, + symbolic_domain_sizes: Optional[dict[str, str]] = None, + allow_uninferred: bool = False, ): actual_call, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, + domain_utils.SymbolicDomain.from_expr(domain), + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, ) folded_call = constant_fold_domain_exprs(actual_call) folded_domains = constant_fold_accessed_domains(actual_domains) if actual_domains else None @@ -96,10 +106,8 @@ def run_test_expr( def canonicalize_domain(d): if isinstance(d, dict): return im.domain(grid_type, d) - elif isinstance(d, itir.FunCall): + elif isinstance(d, (itir.FunCall, infer_domain.DomainAccessDescriptor)): return d - elif d is None: - return None raise AssertionError() expected_domains = {ref: canonicalize_domain(d) for ref, d in expected_domains.items()} @@ -120,10 +128,12 @@ def constant_fold_domain_exprs(arg: itir.Node) -> itir.Node: def constant_fold_accessed_domains( - domains: infer_domain.ACCESSED_DOMAINS, -) -> infer_domain.ACCESSED_DOMAINS: - def fold_domain(domain: SymbolicDomain | None): - if domain is None: + domains: infer_domain.AccessedDomains, +) -> infer_domain.AccessedDomains: + def fold_domain( + domain: domain_utils.SymbolicDomain | Literal[infer_domain.DomainAccessDescriptor.NEVER], + ): + if isinstance(domain, infer_domain.DomainAccessDescriptor): return domain return constant_fold_domain_exprs(domain.as_expr()) @@ -134,7 +144,7 @@ def translate_domain( domain: itir.FunCall, shifts: dict[str, tuple[itir.Expr, itir.Expr]], offset_provider: common.OffsetProvider, -) -> SymbolicDomain: +) -> domain_utils.SymbolicDomain: shift_tuples = [ ( im.ensure_offset(d), @@ -145,7 +155,9 @@ def translate_domain( shift_list = [item for sublist in 
shift_tuples for item in sublist] - translated_domain_expr = SymbolicDomain.from_expr(domain).translate(shift_list, offset_provider) + translated_domain_expr = domain_utils.SymbolicDomain.from_expr(domain).translate( + shift_list, offset_provider=offset_provider + ) return constant_fold_domain_exprs(translated_domain_expr.as_expr()) @@ -330,7 +342,7 @@ def test_nested_stencils(offset_provider): "in_field2": translate_domain(domain, {"Ioff": 0, "Joff": -2}, offset_provider), } actual_call, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) folded_call = constant_fold_domain_exprs(actual_call) @@ -374,7 +386,7 @@ def test_nested_stencils_n_times(offset_provider, iterations): } actual_call, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -387,7 +399,10 @@ def test_unused_input(offset_provider): stencil = im.lambda_("arg0", "arg1")(im.deref("arg0")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - expected_domains = {"in_field1": {IDim: (0, 11)}, "in_field2": None} + expected_domains = { + "in_field1": {IDim: (0, 11)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } testee, expected = setup_test_as_fieldop( stencil, domain, @@ -399,7 +414,7 @@ def test_let_unused_field(offset_provider): testee = im.let("a", "c")("b") domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.let("a", "c")("b") - expected_domains = {"b": {IDim: (0, 11)}, "c": None} + expected_domains = {"b": {IDim: (0, 11)}, "c": infer_domain.DomainAccessDescriptor.NEVER} run_test_expr(testee, expected, domain, expected_domains, offset_provider) 
@@ -500,7 +515,7 @@ def test_cond(offset_provider): testee = im.if_(cond, field_1, field_2) - domain = im.domain(common.GridType.CARTESIAN, {"IDim": (0, 11)}) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) domain_tmp = translate_domain(domain, {"Ioff": -1}, offset_provider) expected_domains_dict = {"in_field1": {IDim: (0, 12)}, "in_field2": {IDim: (-2, 12)}} expected_tmp2 = im.as_fieldop(tmp_stencil2, domain_tmp)( @@ -512,7 +527,7 @@ def test_cond(offset_provider): expected = im.if_(cond, expected_field_1, expected_field_2) actual_call, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -569,7 +584,7 @@ def test_let(offset_provider): expected_domains_sym = {"in_field": translate_domain(domain, {"Ioff": 2}, offset_provider)} actual_call2, actual_domains2 = infer_domain.infer_expr( - testee2, SymbolicDomain.from_expr(domain), offset_provider + testee2, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains2 = constant_fold_accessed_domains(actual_domains2) folded_call2 = constant_fold_domain_exprs(actual_call2) @@ -789,8 +804,11 @@ def test_make_tuple(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), - offset_provider, + ( + domain_utils.SymbolicDomain.from_expr(domain1), + domain_utils.SymbolicDomain.from_expr(domain2), + ), + offset_provider=offset_provider, ) assert expected == actual @@ -802,13 +820,13 @@ def test_tuple_get_1_make_tuple(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.make_tuple(im.ref("a"), im.ref("b"), im.ref("c"))) expected_domains = { - "a": None, + "a": infer_domain.DomainAccessDescriptor.NEVER, "b": 
im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}), - "c": None, + "c": infer_domain.DomainAccessDescriptor.NEVER, } actual, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -820,12 +838,15 @@ def test_tuple_get_1_nested_make_tuple(offset_provider): domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) expected = im.tuple_get(1, im.make_tuple(im.ref("a"), im.make_tuple(im.ref("b"), im.ref("c")))) - expected_domains = {"a": None, "b": domain1, "c": domain2} + expected_domains = {"a": infer_domain.DomainAccessDescriptor.NEVER, "b": domain1, "c": domain2} actual, actual_domains = infer_domain.infer_expr( testee, - (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), - offset_provider, + ( + domain_utils.SymbolicDomain.from_expr(domain1), + domain_utils.SymbolicDomain.from_expr(domain2), + ), + offset_provider=offset_provider, ) assert expected == actual @@ -836,12 +857,18 @@ def test_tuple_get_let_arg_make_tuple(offset_provider): testee = im.tuple_get(1, im.let("a", im.make_tuple(im.ref("b"), im.ref("c")))("d")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.let("a", im.make_tuple(im.ref("b"), im.ref("c")))("d")) - expected_domains = {"b": None, "c": None, "d": (None, domain)} + expected_domains = { + "b": infer_domain.DomainAccessDescriptor.NEVER, + "c": infer_domain.DomainAccessDescriptor.NEVER, + "d": (infer_domain.DomainAccessDescriptor.NEVER, domain), + } actual, actual_domains = infer_domain.infer_expr( testee, - SymbolicDomain.from_expr(im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)})), - offset_provider, + domain_utils.SymbolicDomain.from_expr( + im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + ), + offset_provider=offset_provider, ) 
assert expected == actual @@ -852,12 +879,16 @@ def test_tuple_get_let_make_tuple(offset_provider): testee = im.tuple_get(1, im.let("a", "b")(im.make_tuple(im.ref("c"), im.ref("d")))) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.let("a", "b")(im.make_tuple(im.ref("c"), im.ref("d")))) - expected_domains = {"c": None, "d": domain, "b": None} + expected_domains = { + "c": infer_domain.DomainAccessDescriptor.NEVER, + "d": domain, + "b": infer_domain.DomainAccessDescriptor.NEVER, + } actual, actual_domains = infer_domain.infer_expr( testee, - SymbolicDomain.from_expr(domain), - offset_provider, + domain_utils.SymbolicDomain.from_expr(domain), + offset_provider=offset_provider, ) assert expected == actual @@ -877,12 +908,15 @@ def test_nested_make_tuple(offset_provider): testee, ( ( - SymbolicDomain.from_expr(domain1), - (SymbolicDomain.from_expr(domain2_1), SymbolicDomain.from_expr(domain2_2)), + domain_utils.SymbolicDomain.from_expr(domain1), + ( + domain_utils.SymbolicDomain.from_expr(domain2_1), + domain_utils.SymbolicDomain.from_expr(domain2_2), + ), ), - SymbolicDomain.from_expr(domain3), + domain_utils.SymbolicDomain.from_expr(domain3), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -893,10 +927,10 @@ def test_tuple_get_1(offset_provider): testee = im.tuple_get(1, im.ref("a")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.ref("a")) - expected_domains = {"a": (None, domain)} + expected_domains = {"a": (infer_domain.DomainAccessDescriptor.NEVER, domain)} actual, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -912,8 +946,11 @@ def test_domain_tuple(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - (SymbolicDomain.from_expr(domain1), 
SymbolicDomain.from_expr(domain2)), - offset_provider, + ( + domain_utils.SymbolicDomain.from_expr(domain1), + domain_utils.SymbolicDomain.from_expr(domain2), + ), + offset_provider=offset_provider, ) assert expected == actual @@ -929,7 +966,7 @@ def test_as_fieldop_tuple_get(offset_provider): expected_domains = {"a": (domain, domain)} actual, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -945,8 +982,11 @@ def test_make_tuple_2tuple_get(offset_provider): actual, actual_domains = infer_domain.infer_expr( testee, - (SymbolicDomain.from_expr(domain1), SymbolicDomain.from_expr(domain2)), - offset_provider, + ( + domain_utils.SymbolicDomain.from_expr(domain1), + domain_utils.SymbolicDomain.from_expr(domain2), + ), + offset_provider=offset_provider, ) assert expected == actual @@ -963,7 +1003,7 @@ def test_make_tuple_non_tuple_domain(offset_provider): expected_domains = {"in_field1": domain, "in_field2": domain} actual, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -971,15 +1011,85 @@ def test_make_tuple_non_tuple_domain(offset_provider): def test_arithmetic_builtin(offset_provider): - testee = im.plus(im.ref("in_field1"), im.ref("in_field2")) + testee = im.plus(im.ref("alpha"), im.ref("beta")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - expected = im.plus(im.ref("in_field1"), im.ref("in_field2")) + expected = im.plus(im.ref("alpha"), im.ref("beta")) expected_domains = {} actual_call, actual_domains = infer_domain.infer_expr( - testee, SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_call = 
constant_fold_domain_exprs(actual_call) assert folded_call == expected assert actual_domains == expected_domains + + +def test_scan(offset_provider): + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) + testee = im.as_fieldop( + im.scan(im.lambda_("init", "it")(im.deref(im.shift("Ioff", 1)("it"))), True, 0.0) + )("a") + expected = im.as_fieldop( + im.scan(im.lambda_("init", "it")(im.deref(im.shift("Ioff", 1)("it"))), True, 0.0), + domain, + )("a") + + run_test_expr( + testee, + expected, + domain, + {"a": im.domain(common.GridType.CARTESIAN, {IDim: (1, 12)})}, + offset_provider, + ) + + +def test_symbolic_domain_sizes(unstructured_offset_provider): + stencil = im.lambda_("arg0")(im.deref(im.shift("E2V", 1)("arg0"))) + domain = im.domain(common.GridType.UNSTRUCTURED, {Edge: (0, 1)}) + symbolic_domain_sizes = {"Vertex": "num_vertices"} + + testee, expected = setup_test_as_fieldop( + stencil, + domain, + ) + run_test_expr( + testee, + expected, + domain, + {"in_field1": {Vertex: (0, im.ref("num_vertices"))}}, + unstructured_offset_provider, + symbolic_domain_sizes, + ) + + +def test_unknown_domain(offset_provider): + stencil = im.lambda_("arg0", "arg1")(im.deref(im.shift("Ioff", im.deref("arg1"))("arg0"))) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": infer_domain.DomainAccessDescriptor.UNKNOWN, + "in_field2": {IDim: (0, 10)}, + } + testee, expected = setup_test_as_fieldop(stencil, domain) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_never_accessed_domain(offset_provider): + stencil = im.lambda_("arg0", "arg1")(im.deref("arg0")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": {IDim: (0, 10)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } + testee, expected = setup_test_as_fieldop(stencil, domain) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def 
test_never_accessed_domain_tuple(offset_provider): + testee = im.tuple_get(0, im.make_tuple("in_field1", "in_field2")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": {IDim: (0, 10)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } + run_test_expr(testee, testee, domain, expected_domains, offset_provider) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py new file mode 100644 index 0000000000..14aebd032c --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_fuse_as_fieldop.py @@ -0,0 +1,381 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +import copy +from typing import Callable, Optional + +from gt4py import next as gtx +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im, domain_utils +from gt4py.next.iterator.transforms import fuse_as_fieldop, collapse_tuple +from gt4py.next.type_system import type_specifications as ts + + +IDim = gtx.Dimension("IDim") +JDim = gtx.Dimension("JDim") +field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + + +def _with_domain_annex(node: itir.Expr, domain: itir.Expr): + node = copy.deepcopy(node) + node.annex.domain = domain_utils.SymbolicDomain.from_expr(domain) + return node + + +def test_trivial(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.op_as_fieldop("plus", d)( + im.op_as_fieldop("multiplies", d)(im.ref("inp1", field_type), im.ref("inp2", field_type)), + im.ref("inp3", field_type), + ) + expected = im.as_fieldop( + im.lambda_("inp1", "inp2", "inp3")( + im.plus(im.multiplies_(im.deref("inp1"), im.deref("inp2")), im.deref("inp3")) + ), + d, + 
)(im.ref("inp1", field_type), im.ref("inp2", field_type), im.ref("inp3", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_trivial_literal(): + d = im.domain("cartesian_domain", {}) + testee = im.op_as_fieldop("plus", d)(im.op_as_fieldop("multiplies", d)(1, 2), 3) + expected = im.as_fieldop(im.lambda_()(im.plus(im.multiplies_(1, 2), 3)), d)() + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_trivial_same_arg_twice(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.op_as_fieldop("plus", d)( + # note: inp1 occurs twice here + im.op_as_fieldop("multiplies", d)(im.ref("inp1", field_type), im.ref("inp1", field_type)), + im.ref("inp2", field_type), + ) + expected = im.as_fieldop( + im.lambda_("inp1", "inp2")( + im.plus(im.multiplies_(im.deref("inp1"), im.deref("inp1")), im.deref("inp2")) + ), + d, + )(im.ref("inp1", field_type), im.ref("inp2", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_tuple_arg(): + d = im.domain("cartesian_domain", {}) + testee = im.op_as_fieldop("plus", d)( + im.op_as_fieldop(im.lambda_("t")(im.plus(im.tuple_get(0, "t"), im.tuple_get(1, "t"))), d)( + im.make_tuple(1, 2) + ), + 3, + ) + expected = im.as_fieldop( + im.lambda_()( + im.plus( + im.let("t", im.make_tuple(1, 2))( + im.plus(im.tuple_get(0, "t"), im.tuple_get(1, "t")) + ), + 3, + ) + ), + d, + )() + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_symref_used_twice(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.as_fieldop(im.lambda_("a", "b")(im.plus(im.deref("a"), im.deref("b"))), d)( + 
im.as_fieldop(im.lambda_("c", "d")(im.multiplies_(im.deref("c"), im.deref("d"))), d)( + im.ref("inp1", field_type), im.ref("inp2", field_type) + ), + im.ref("inp1", field_type), + ) + expected = im.as_fieldop( + im.lambda_("inp1", "inp2")( + im.plus(im.multiplies_(im.deref("inp1"), im.deref("inp2")), im.deref("inp1")) + ), + d, + )("inp1", "inp2") + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_no_inline(): + d1 = im.domain("cartesian_domain", {IDim: (1, 2)}) + d2 = im.domain("cartesian_domain", {IDim: (0, 3)}) + testee = im.as_fieldop( + im.lambda_("a")( + im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))) + ), + d1, + )(im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type))) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True + ) + assert actual == testee + + +def test_staged_inlining(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.let( + "tmp", im.op_as_fieldop("plus", d)(im.ref("a", field_type), im.ref("b", field_type)) + )( + im.op_as_fieldop("plus", d)( + im.op_as_fieldop(im.lambda_("a")(im.plus("a", 1)), d)("tmp"), + im.op_as_fieldop(im.lambda_("a")(im.plus("a", 2)), d)("tmp"), + ) + ) + expected = im.as_fieldop( + im.lambda_("a", "b")( + im.let("_icdlv_1", im.lambda_()(im.plus(im.deref("a"), im.deref("b"))))( + im.plus(im.plus(im.call("_icdlv_1")(), 1), im.plus(im.call("_icdlv_1")(), 2)) + ) + ), + d, + )(im.ref("a", field_type), im.ref("b", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_make_tuple_fusion_trivial(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.make_tuple( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + im.as_fieldop("deref", 
d)(im.ref("a", field_type)), + ) + expected = im.as_fieldop( + im.lambda_("a")(im.make_tuple(im.deref("a"), im.deref("a"))), + d, + )(im.ref("a", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call `{v[0], v[1]}(actual)` + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + + +def test_make_tuple_fusion_symref(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.make_tuple( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + _with_domain_annex(im.ref("b", field_type), d), + ) + expected = im.as_fieldop( + im.lambda_("a", "b")(im.make_tuple(im.deref("a"), im.deref("b"))), + d, + )(im.ref("a", field_type), im.ref("b", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + + +def test_make_tuple_fusion_symref_same_ref(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.make_tuple( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + _with_domain_annex(im.ref("a", field_type), d), + ) + expected = im.as_fieldop( + im.lambda_("a")(im.make_tuple(im.deref("a"), im.deref("a"))), + d, + )(im.ref("a", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + + +def test_make_tuple_nested(): + d = im.domain("cartesian_domain", {IDim: 
(0, 1)}) + testee = im.make_tuple( + _with_domain_annex(im.ref("a", field_type), d), + im.make_tuple( + _with_domain_annex(im.ref("b", field_type), d), + _with_domain_annex(im.ref("c", field_type), d), + ), + ) + expected = im.as_fieldop( + im.lambda_("a", "b", "c")( + im.make_tuple(im.deref("a"), im.make_tuple(im.deref("b"), im.deref("c"))) + ), + d, + )(im.ref("a", field_type), im.ref("b", field_type), im.ref("c", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + # simplify to remove unnecessary make_tuple call + actual_simplified = collapse_tuple.CollapseTuple.apply( + actual, within_stencil=False, allow_undeclared_symbols=True + ) + assert actual_simplified == expected + + +def test_make_tuple_fusion_different_domains(): + d1 = im.domain("cartesian_domain", {IDim: (0, 1)}) + d2 = im.domain("cartesian_domain", {JDim: (0, 1)}) + field_i_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + field_j_type = ts.FieldType(dims=[JDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + testee = im.make_tuple( + im.as_fieldop("deref", d1)(im.ref("a", field_i_type)), + im.as_fieldop("deref", d2)(im.ref("b", field_j_type)), + im.as_fieldop("deref", d1)(im.ref("c", field_i_type)), + im.as_fieldop("deref", d2)(im.ref("d", field_j_type)), + ) + expected = im.let( + ( + "__fasfop_1", + im.as_fieldop(im.lambda_("a", "c")(im.make_tuple(im.deref("a"), im.deref("c"))), d1)( + "a", "c" + ), + ), + ( + "__fasfop_2", + im.as_fieldop(im.lambda_("b", "d")(im.make_tuple(im.deref("b"), im.deref("d"))), d2)( + "b", "d" + ), + ), + )( + im.make_tuple( + im.tuple_get(0, "__fasfop_1"), + im.tuple_get(0, "__fasfop_2"), + im.tuple_get(1, "__fasfop_1"), + im.tuple_get(1, "__fasfop_2"), + ) + ) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_partial_inline(): + d1 = 
im.domain("cartesian_domain", {IDim: (1, 2)}) + d2 = im.domain("cartesian_domain", {IDim: (0, 3)}) + testee = im.as_fieldop( + # first argument read at multiple locations -> not inlined + # second argument only read at a single location -> inlined + im.lambda_("a", "b")( + im.plus( + im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))), + im.deref("b"), + ) + ), + d1, + )( + im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type)), + im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type)), + ) + expected = im.as_fieldop( + im.lambda_("a", "inp1", "inp2")( + im.plus( + im.plus(im.deref(im.shift("IOff", 1)("a")), im.deref(im.shift("IOff", -1)("a"))), + im.plus(im.deref("inp1"), im.deref("inp2")), + ) + ), + d1, + )( + im.op_as_fieldop("plus", d2)(im.ref("inp1", field_type), im.ref("inp2", field_type)), + "inp1", + "inp2", + ) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={"IOff": IDim}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_chained_fusion(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.let( + "a", im.op_as_fieldop("plus", d)(im.ref("inp1", field_type), im.ref("inp2", field_type)) + )( + im.op_as_fieldop("plus", d)( + im.as_fieldop("deref", d)(im.ref("a", field_type)), + im.as_fieldop("deref", d)(im.ref("a", field_type)), + ) + ) + expected = im.as_fieldop( + im.lambda_("inp1", "inp2")( + im.let("_icdlv_1", im.lambda_()(im.plus(im.deref("inp1"), im.deref("inp2"))))( + im.plus(im.call("_icdlv_1")(), im.call("_icdlv_1")()) + ) + ), + d, + )(im.ref("inp1", field_type), im.ref("inp2", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_inline_as_fieldop_with_list_dtype(): + list_field_type = ts.FieldType( + dims=[IDim], 
dtype=ts.ListType(element_type=ts.ScalarType(kind=ts.ScalarKind.INT32)) + ) + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + testee = im.as_fieldop( + im.lambda_("inp")(im.call(im.call("reduce")("plus", 0))(im.deref("inp"))), d + )(im.as_fieldop("deref")(im.ref("inp", list_field_type))) + expected = im.as_fieldop( + im.lambda_("inp")(im.call(im.call("reduce")("plus", 0))(im.deref("inp"))), d + )(im.ref("inp", list_field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_inline_into_scan(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + scan = im.call("scan")(im.lambda_("state", "a")(im.plus("state", im.deref("a"))), True, 0) + testee = im.as_fieldop(scan, d)(im.as_fieldop("deref")(im.ref("a", field_type))) + expected = im.as_fieldop(scan, d)(im.ref("a", field_type)) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == expected + + +def test_no_inline_into_scan(): + d = im.domain("cartesian_domain", {IDim: (0, 1)}) + scan_stencil = im.call("scan")( + im.lambda_("state", "a")(im.plus("state", im.deref("a"))), True, 0 + ) + scan = im.as_fieldop(scan_stencil, d)(im.ref("a", field_type)) + testee = im.as_fieldop(im.lambda_("arg")(im.deref("arg")), d)(scan) + actual = fuse_as_fieldop.FuseAsFieldOp.apply( + testee, offset_provider_type={}, allow_undeclared_symbols=True + ) + assert actual == testee diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py index ffb5447684..52d77e5fda 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py @@ -6,464 +6,219 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -# TODO(tehrengruber): add integration tests for temporaries starting from manually written -# itir. Currently we only test temporaries from frontend code which makes testing changes -# to anything related to temporaries tedious. -import copy +from typing import Optional -import gt4py.next as gtx -from gt4py.eve.utils import UIDs from gt4py.next import common -from gt4py.next.iterator import ir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.iterator.transforms.global_tmps import ( - AUTO_DOMAIN, - FencilWithTemporaries, - SimpleTemporaryExtractionHeuristics, - collect_tmps_info, - split_closures, - update_domains, -) +from gt4py.next.iterator.transforms import global_tmps, infer_domain +from gt4py.next.iterator.type_system import inference as type_inference from gt4py.next.type_system import type_specifications as ts IDim = common.Dimension(value="IDim") JDim = common.Dimension(value="JDim") KDim = common.Dimension(value="KDim", kind=common.DimensionKind.VERTICAL) -index_type = ts.ScalarType(kind=getattr(ts.ScalarKind, ir.INTEGER_INDEX_BUILTIN.upper())) +index_type = ts.ScalarType(kind=getattr(ts.ScalarKind, builtins.INTEGER_INDEX_BUILTIN.upper())) float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) i_field_type = ts.FieldType(dims=[IDim], dtype=float_type) index_field_type_factory = lambda dim: ts.FieldType(dims=[dim], dtype=index_type) -def test_split_closures(): - UIDs.reset_sequence() - testee = ir.FencilDefinition( - id="f", +def program_factory( + params: list[itir.Sym], + body: list[itir.SetAt], + declarations: Optional[list[itir.Temporary]] = None, +) -> itir.Program: + return itir.Program( + id="testee", function_definitions=[], - params=[ - im.sym("d", i_field_type), - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - ], - closures=[ - ir.StencilClosure( - domain=im.call("cartesian_domain")(), - 
stencil=im.lambda_("baz_inp")( - im.deref( - im.lift( - im.lambda_("bar_inp")( - im.deref( - im.lift(im.lambda_("foo_inp")(im.deref("foo_inp")))("bar_inp") - ) - ) - )("baz_inp") - ) - ), - output=im.ref("out"), - inputs=[im.ref("inp")], - ) - ], + params=params, + declarations=declarations or [], + body=body, ) - expected = ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - im.sym("d", i_field_type), - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - im.sym("_tmp_1", i_field_type), - im.sym("_tmp_2", i_field_type), - im.sym("_gtmp_auto_domain", ts.DeferredType(constraint=None)), - ], - closures=[ - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.lambda_("foo_inp")(im.deref("foo_inp")), - output=im.ref("_tmp_2"), - inputs=[im.ref("inp")], - ), - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.lambda_("bar_inp", "_tmp_2")(im.deref("_tmp_2")), - output=im.ref("_tmp_1"), - inputs=[im.ref("inp"), im.ref("_tmp_2")], - ), - ir.StencilClosure( - domain=im.call("cartesian_domain")(), - stencil=im.lambda_("baz_inp", "_tmp_1")(im.deref("_tmp_1")), - output=im.ref("out"), - inputs=[im.ref("inp"), im.ref("_tmp_1")], - ), - ], - ) - actual = split_closures(testee, offset_provider={}) - assert actual.tmps == [ - ir.Temporary(id="_tmp_1", dtype=float_type), - ir.Temporary(id="_tmp_2", dtype=float_type), - ] - assert actual.fencil == expected - -def test_split_closures_simple_heuristics(): - UIDs.reset_sequence() - testee = ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - im.sym("d", i_field_type), - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - ], - closures=[ - ir.StencilClosure( - domain=im.call("cartesian_domain")(), - stencil=im.lambda_("foo")( - im.let("lifted_it", im.lift(im.lambda_("bar")(im.deref("bar")))("foo"))( - im.plus(im.deref("lifted_it"), im.deref(im.shift("I", 1)("lifted_it"))) - ) - ), - output=im.ref("out"), - inputs=[im.ref("inp")], +def test_trivial(): + domain = 
im.domain("cartesian_domain", {IDim: (0, 1)}) + offset_provider = {} + testee = program_factory( + params=[im.sym("inp", i_field_type), im.sym("out", i_field_type)], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", domain)(im.as_fieldop("deref", domain)("inp")), + domain=domain, ) ], ) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) + testee = infer_domain.infer_program(testee, offset_provider=offset_provider) - expected = ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - im.sym("d", i_field_type), - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - im.sym("_tmp_1", i_field_type), - im.sym("_gtmp_auto_domain", ts.DeferredType(constraint=None)), - ], - closures=[ - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.lambda_("bar")(im.deref("bar")), - output=im.ref("_tmp_1"), - inputs=[im.ref("inp")], + expected = program_factory( + params=[im.sym("inp", i_field_type), im.sym("out", i_field_type)], + declarations=[itir.Temporary(id="__tmp_1", domain=domain, dtype=float_type)], + body=[ + itir.SetAt( + target=im.ref("__tmp_1"), expr=im.as_fieldop("deref", domain)("inp"), domain=domain ), - ir.StencilClosure( - domain=im.call("cartesian_domain")(), - stencil=im.lambda_("foo", "_tmp_1")( - im.plus(im.deref("_tmp_1"), im.deref(im.shift("I", 1)("_tmp_1"))) - ), - output=im.ref("out"), - inputs=[im.ref("inp"), im.ref("_tmp_1")], + itir.SetAt( + target=im.ref("out"), expr=im.as_fieldop("deref", domain)("__tmp_1"), domain=domain ), ], ) - actual = split_closures( - testee, - extraction_heuristics=SimpleTemporaryExtractionHeuristics, - offset_provider={"I": IDim}, - ) - assert actual.tmps == [ir.Temporary(id="_tmp_1", dtype=float_type)] - assert actual.fencil == expected + actual = global_tmps.create_global_tmps(testee, offset_provider) + assert actual == expected -def test_split_closures_lifted_scan(): - UIDs.reset_sequence() - testee = ir.FencilDefinition( - id="f", - 
function_definitions=[], +def test_trivial_let(): + domain = im.domain("cartesian_domain", {IDim: (0, 1)}) + offset_provider = {} + testee = program_factory( params=[im.sym("inp", i_field_type), im.sym("out", i_field_type)], - closures=[ - ir.StencilClosure( - domain=im.call("cartesian_domain")(), - stencil=im.lambda_("a")( - im.call( - im.call("scan")( - im.lambda_("carry", "b")(im.plus("carry", im.deref("b"))), - True, - im.literal_from_value(0.0), - ) - )( - im.lift( - im.call("scan")( - im.lambda_("carry", "c")(im.plus("carry", im.deref("c"))), - False, - im.literal_from_value(0.0), - ) - )("a") - ) + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.let("tmp", im.as_fieldop("deref", domain)("inp"))( + im.as_fieldop("deref", domain)("tmp") ), - output=im.ref("out"), - inputs=[im.ref("inp")], + domain=domain, ) ], ) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) + testee = infer_domain.infer_program(testee, offset_provider=offset_provider) - expected = ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - im.sym("_tmp_1", i_field_type), - im.sym("_gtmp_auto_domain", ts.DeferredType(constraint=None)), - ], - closures=[ - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.call("scan")( - im.lambda_("carry", "c")(im.plus("carry", im.deref("c"))), - False, - im.literal_from_value(0.0), - ), - output=im.ref("_tmp_1"), - inputs=[im.ref("inp")], + expected = program_factory( + params=[im.sym("inp", i_field_type), im.sym("out", i_field_type)], + declarations=[itir.Temporary(id="__tmp_1", domain=domain, dtype=float_type)], + body=[ + itir.SetAt( + target=im.ref("__tmp_1"), expr=im.as_fieldop("deref", domain)("inp"), domain=domain ), - ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=im.lambda_("a", "_tmp_1")( - im.call( - im.call("scan")( - im.lambda_("carry", "b")(im.plus("carry", im.deref("b"))), - True, - 
im.literal_from_value(0.0), - ) - )("_tmp_1") - ), - output=im.ref("out"), - inputs=[im.ref("inp"), im.ref("_tmp_1")], + itir.SetAt( + target=im.ref("out"), expr=im.as_fieldop("deref", domain)("__tmp_1"), domain=domain ), ], ) - actual = split_closures(testee, offset_provider={}) - assert actual.tmps == [ir.Temporary(id="_tmp_1", dtype=float_type)] - assert actual.fencil == expected + actual = global_tmps.create_global_tmps(testee, offset_provider) + assert actual == expected -def test_update_cartesian_domains(): - testee = FencilWithTemporaries( - fencil=ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - im.sym("i", index_type), - im.sym("j", index_type), - im.sym("k", index_type), - im.sym("inp", i_field_type), - im.sym("out", i_field_type), - im.sym("_gtmp_0", i_field_type), - im.sym("_gtmp_1", i_field_type), - im.sym("_gtmp_auto_domain", ts.DeferredType(constraint=None)), - ], - closures=[ - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.lambda_("foo_inp")(im.deref("foo_inp")), - output=im.ref("_gtmp_1"), - inputs=[im.ref("inp")], - ), - ir.StencilClosure( - domain=AUTO_DOMAIN, - stencil=im.ref("deref"), - output=im.ref("_gtmp_0"), - inputs=[im.ref("_gtmp_1")], - ), - ir.StencilClosure( - domain=im.call("cartesian_domain")( - *( - im.call("named_range")( - ir.AxisLiteral(value=a), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(s), - ) - for a, s in (("IDim", "i"), ("JDim", "j"), ("KDim", "k")) - ) - ), - stencil=im.lambda_("baz_inp", "_lift_2")(im.deref(im.shift("I", 1)("_lift_2"))), - output=im.ref("out"), - inputs=[im.ref("inp"), im.ref("_gtmp_0")], +def test_top_level_if(): + domain = im.domain("cartesian_domain", {IDim: (0, 1)}) + offset_provider = {} + testee = program_factory( + params=[ + im.sym("inp1", i_field_type), + im.sym("inp2", i_field_type), + im.sym("out", i_field_type), + ], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.if_( + True, + im.as_fieldop("deref", domain)("inp1"), + im.as_fieldop("deref", 
domain)("inp2"), ), - ], - ), - params=[im.sym("i"), im.sym("j"), im.sym("k"), im.sym("inp"), im.sym("out")], - tmps=[ir.Temporary(id="_gtmp_0"), ir.Temporary(id="_gtmp_1")], - ) - expected = copy.deepcopy(testee) - assert expected.fencil.params.pop() == im.sym("_gtmp_auto_domain") - expected.fencil.closures[0].domain = ir.FunCall( - fun=im.ref("cartesian_domain"), - args=[ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value="IDim"), - im.plus( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.literal("1", ir.INTEGER_INDEX_BUILTIN), - ), - im.plus(im.ref("i"), im.literal("1", ir.INTEGER_INDEX_BUILTIN)), - ], - ) - ] - + [ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value=a), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(s), - ], + domain=domain, ) - for a, s in (("JDim", "j"), ("KDim", "k")) ], ) - expected.fencil.closures[1].domain = ir.FunCall( - fun=im.ref("cartesian_domain"), - args=[ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value="IDim"), - im.plus( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.literal("1", ir.INTEGER_INDEX_BUILTIN), - ), - im.plus(im.ref("i"), im.literal("1", ir.INTEGER_INDEX_BUILTIN)), + testee = type_inference.infer(testee, offset_provider_type=offset_provider) + testee = infer_domain.infer_program(testee, offset_provider=offset_provider) + + expected = program_factory( + params=[ + im.sym("inp1", i_field_type), + im.sym("inp2", i_field_type), + im.sym("out", i_field_type), + ], + declarations=[], + body=[ + itir.IfStmt( + cond=im.literal_from_value(True), + true_branch=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", domain)("inp1"), + domain=domain, + ) ], - ) - ] - + [ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value=a), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(s), + false_branch=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", domain)("inp2"), + domain=domain, + ) ], ) - 
for a, s in (("JDim", "j"), ("KDim", "k")) ], ) - actual = update_domains(testee, {"I": gtx.Dimension("IDim")}, symbolic_sizes=None) + + actual = global_tmps.create_global_tmps(testee, offset_provider) assert actual == expected -def test_collect_tmps_info(): - tmp_domain = ir.FunCall( - fun=im.ref("cartesian_domain"), - args=[ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value="IDim"), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - ir.FunCall( - fun=im.ref("plus"), - args=[im.ref("i"), im.literal("1", ir.INTEGER_INDEX_BUILTIN)], - ), - ], - ) - ] - + [ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value=a), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(s), - ], +def test_nested_if(): + domain = im.domain("cartesian_domain", {IDim: (0, 1)}) + offset_provider = {} + testee = program_factory( + params=[ + im.sym("inp1", i_field_type), + im.sym("inp2", i_field_type), + im.sym("out", i_field_type), + ], + body=[ + itir.SetAt( + target=im.ref("out"), + expr=im.as_fieldop("deref", domain)( + im.if_( + True, + im.as_fieldop("deref", domain)("inp1"), + im.as_fieldop("deref", domain)("inp2"), + ) + ), + domain=domain, ) - for a, s in (("JDim", "j"), ("KDim", "k")) ], ) + testee = type_inference.infer(testee, offset_provider_type=offset_provider) + testee = infer_domain.infer_program(testee, offset_provider=offset_provider) - i = im.sym("i", index_type) - j = im.sym("j", index_type) - k = im.sym("k", index_type) - inp = im.sym("inp", i_field_type) - out = im.sym("out", i_field_type) - - testee = FencilWithTemporaries( - fencil=ir.FencilDefinition( - id="f", - function_definitions=[], - params=[ - i, - j, - k, - inp, - out, - im.sym("_gtmp_0", i_field_type), - im.sym("_gtmp_1", i_field_type), - ], - closures=[ - ir.StencilClosure( - domain=tmp_domain, - stencil=ir.Lambda( - params=[ir.Sym(id="foo_inp")], - expr=ir.FunCall(fun=im.ref("deref"), args=[im.ref("foo_inp")]), - ), - output=im.ref("_gtmp_1"), - 
inputs=[im.ref("inp")], - ), - ir.StencilClosure( - domain=tmp_domain, - stencil=im.ref("deref"), - output=im.ref("_gtmp_0"), - inputs=[im.ref("_gtmp_1")], - ), - ir.StencilClosure( - domain=ir.FunCall( - fun=im.ref("cartesian_domain"), - args=[ - ir.FunCall( - fun=im.ref("named_range"), - args=[ - ir.AxisLiteral(value=a), - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.ref(s), - ], - ) - for a, s in (("IDim", "i"), ("JDim", "j"), ("KDim", "k")) - ], - ), - stencil=ir.Lambda( - params=[ir.Sym(id="baz_inp"), ir.Sym(id="_lift_2")], - expr=ir.FunCall( - fun=im.ref("deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=im.ref("shift"), - args=[ - ir.OffsetLiteral(value="I"), - ir.OffsetLiteral(value=1), - ], - ), - args=[im.ref("_lift_2")], - ) - ], - ), - ), - output=im.ref("out"), - inputs=[im.ref("inp"), im.ref("_gtmp_0")], - ), - ], - ), - params=[i, j, k, inp, out], - tmps=[ - ir.Temporary(id="_gtmp_0", dtype=float_type), - ir.Temporary(id="_gtmp_1", dtype=float_type), + expected = program_factory( + params=[ + im.sym("inp1", i_field_type), + im.sym("inp2", i_field_type), + im.sym("out", i_field_type), ], - ) - expected = FencilWithTemporaries( - fencil=testee.fencil, - params=testee.params, - tmps=[ - ir.Temporary(id="_gtmp_0", domain=tmp_domain, dtype=float_type), - ir.Temporary(id="_gtmp_1", domain=tmp_domain, dtype=float_type), + declarations=[itir.Temporary(id="__tmp_1", domain=domain, dtype=float_type)], + body=[ + itir.IfStmt( + cond=im.literal_from_value(True), + true_branch=[ + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop("deref", domain)("inp1"), + domain=domain, + ) + ], + false_branch=[ + itir.SetAt( + target=im.ref("__tmp_1"), + expr=im.as_fieldop("deref", domain)("inp2"), + domain=domain, + ) + ], + ), + itir.SetAt( + target=im.ref("out"), expr=im.as_fieldop("deref", domain)("__tmp_1"), domain=domain + ), ], ) - actual = collect_tmps_info(testee, offset_provider={"I": IDim, "J": JDim, "K": KDim}) + + actual = 
global_tmps.create_global_tmps(testee, offset_provider) assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_center_deref_lift_vars.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_center_deref_lift_vars.py index 6cc2f7cd28..2caa887803 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_center_deref_lift_vars.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_center_deref_lift_vars.py @@ -6,20 +6,33 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import cse from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars +field_type = ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) -def wrap_in_program(expr: itir.Expr) -> itir.Program: + +def wrap_in_program(expr: itir.Expr, *, arg_dtypes=None) -> itir.Program: + if arg_dtypes is None: + arg_dtypes = [ts.ScalarKind.FLOAT64] + arg_types = [ts.FieldType(dims=[], dtype=ts.ScalarType(kind=dtype)) for dtype in arg_dtypes] + indices = [i for i in range(1, len(arg_dtypes) + 1)] if len(arg_dtypes) > 1 else [""] return itir.Program( id="f", function_definitions=[], - params=[im.sym("d"), im.sym("inp"), im.sym("out")], + params=[ + *(im.sym(f"inp{i}", type_) for i, type_ in zip(indices, arg_types)), + im.sym("out", field_type), + ], declarations=[], body=[ itir.SetAt( - expr=im.as_fieldop(im.lambda_("it")(expr))(im.ref("inp")), + expr=im.as_fieldop(im.lambda_(*(f"it{i}" for i in indices))(expr))( + *(im.ref(f"inp{i}") for i in indices) + ), domain=im.call("cartesian_domain")(), target=im.ref("out"), ) @@ -34,7 +47,7 @@ def unwrap_from_program(program: itir.Program) -> 
itir.Expr: def test_simple(): testee = im.let("var", im.lift("deref")("it"))(im.deref("var")) - expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1))())(·it)" + expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1()))())(λ() → ·it)" actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee))) assert str(actual) == expected @@ -42,7 +55,7 @@ def test_simple(): def test_double_deref(): testee = im.let("var", im.lift("deref")("it"))(im.plus(im.deref("var"), im.deref("var"))) - expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1))() + ·(↑(λ() → _icdlv_1))())(·it)" + expected = "(λ(_icdlv_1) → ·(↑(λ() → _icdlv_1()))() + ·(↑(λ() → _icdlv_1()))())(λ() → ·it)" actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee))) assert str(actual) == expected @@ -62,3 +75,18 @@ def test_deref_at_multiple_pos(): actual = unwrap_from_program(InlineCenterDerefLiftVars.apply(wrap_in_program(testee))) assert testee == actual + + +def test_bc(): + # we also check that the common subexpression is able to extract the inlined value, such + # that it is only evaluated once + testee = im.let("var", im.lift("deref")("it2"))( + im.if_(im.deref("it1"), im.literal_from_value(0), im.plus(im.deref("var"), im.deref("var"))) + ) + expected = "(λ(_icdlv_1) → if ·it1 then 0 else (λ(_cs_1) → _cs_1 + _cs_1)(·(↑(λ() → _icdlv_1()))()))(λ() → ·it2)" + + actual = InlineCenterDerefLiftVars.apply( + wrap_in_program(testee, arg_dtypes=[ts.ScalarKind.BOOL, ts.ScalarKind.FLOAT64]) + ) + simplified = unwrap_from_program(cse.CommonSubexpressionElimination.apply(actual)) + assert str(simplified) == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index e45281734b..c10d48ad06 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ 
b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -8,6 +8,7 @@ import pytest +from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas @@ -39,6 +40,21 @@ ), im.multiplies_(im.plus(2, 1), im.plus("x", "x")), ), + ( + # ensure opcount preserving option works whether `itir.SymRef` has a type or not + "typed_ref", + im.let("a", im.call("opaque")())( + im.plus(im.ref("a", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)), im.ref("a", None)) + ), + { + True: im.let("a", im.call("opaque")())( + im.plus( # stays as is + im.ref("a", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)), im.ref("a", None) + ) + ), + False: im.plus(im.call("opaque")(), im.call("opaque")()), + }, + ), ] @@ -68,3 +84,10 @@ def test_inline_lambda_args(): ) inlined = InlineLambdas.apply(testee, opcount_preserving=True, force_inline_lambda_args=True) assert inlined == expected + + +def test_type_preservation(): + testee = im.let("a", "b")("a") + testee.type = testee.annex.type = ts.ScalarType(kind=ts.ScalarKind.FLOAT32) + inlined = InlineLambdas.apply(testee) + assert inlined.type == inlined.annex.type == ts.ScalarType(kind=ts.ScalarKind.FLOAT32) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py index f81ca5a666..957e7ffe63 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lifts.py @@ -27,15 +27,15 @@ def inline_lift_test_data(): ), ( # can_deref(lift(f)(args...)) -> and(can_deref(arg[0]), and(can_deref(arg[1]), ...)) - im.call("can_deref")(im.lift("f")("arg1", "arg2")), - im.and_(im.call("can_deref")("arg1"), im.call("can_deref")("arg2")), + im.can_deref(im.lift("f")("arg1", "arg2")), + 
im.and_(im.can_deref("arg1"), im.can_deref("arg2")), ), ( # can_deref(shift(...)(lift(f)(args...)) -> and(can_deref(shift(...)(arg[0])), and(can_deref(shift(...)(arg[1])), ...)) - im.call("can_deref")(im.shift("I", 1)(im.lift("f")("arg1", "arg2"))), + im.can_deref(im.shift("I", 1)(im.lift("f")("arg1", "arg2"))), im.and_( - im.call("can_deref")(im.shift("I", 1)("arg1")), - im.call("can_deref")(im.shift("I", 1)("arg2")), + im.can_deref(im.shift("I", 1)("arg1")), + im.can_deref(im.shift("I", 1)("arg2")), ), ), ( diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_scalar.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_scalar.py new file mode 100644 index 0000000000..3e655b71f4 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_scalar.py @@ -0,0 +1,47 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause +import pytest + +from gt4py.next import common +from gt4py.next.iterator import ir as itir +from gt4py.next.type_system import type_specifications as ts +from gt4py.next.iterator.transforms import inline_scalar +from gt4py.next.iterator.ir_utils import ir_makers as im + +TDim = common.Dimension(value="TDim") +int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) + + +def program_factory(expr: itir.Expr) -> itir.Program: + return itir.Program( + id="testee", + function_definitions=[], + params=[im.sym("out", ts.FieldType(dims=[TDim], dtype=int_type))], + declarations=[], + body=[ + itir.SetAt( + expr=expr, + target=im.ref("out"), + domain=im.domain(common.GridType.CARTESIAN, {TDim: (0, 1)}), + ) + ], + ) + + +def test_simple(): + testee = program_factory(im.let("a", 1)(im.op_as_fieldop("plus")("a", "a"))) + expected = program_factory(im.op_as_fieldop("plus")(1, 1)) + actual = inline_scalar.InlineScalar.apply(testee, offset_provider_type={}) + assert actual == expected + + +def test_fo_inline_only(): + scalar_expr = im.let("a", 1)(im.plus("a", "a")) + testee = program_factory(im.as_fieldop(im.lambda_()(scalar_expr))()) + actual = inline_scalar.InlineScalar.apply(testee, offset_provider_type={}) + assert actual == testee diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py new file mode 100644 index 0000000000..b1a18ddab8 --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_casts.py @@ -0,0 +1,42 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from gt4py import next as gtx +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.prune_casts import PruneCasts +from gt4py.next.iterator.type_system import inference as type_inference +from gt4py.next.type_system import type_specifications as ts + + +def test_prune_casts_simple(): + x_ref = im.ref("x", ts.ScalarType(kind=ts.ScalarKind.FLOAT32)) + y_ref = im.ref("y", ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + testee = im.plus(im.cast_(x_ref, "float64"), im.cast_(y_ref, "float64")) + testee = type_inference.infer(testee, offset_provider_type={}, allow_undeclared_symbols=True) + + expected = im.plus(im.cast_(x_ref, "float64"), y_ref) + actual = PruneCasts.apply(testee) + assert actual == expected + + +def test_prune_casts_fieldop(): + IDim = gtx.Dimension("IDim") + x_ref = im.ref("x", ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))) + y_ref = im.ref("y", ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64))) + testee = im.op_as_fieldop("plus")( + im.cast_as_fieldop("float64")(x_ref), + im.cast_as_fieldop("float64")(y_ref), + ) + testee = type_inference.infer(testee, offset_provider_type={}, allow_undeclared_symbols=True) + + expected = im.op_as_fieldop("plus")( + im.cast_as_fieldop("float64")(x_ref), + y_ref, + ) + actual = PruneCasts.apply(testee) + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_closure_inputs.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_closure_inputs.py deleted file mode 100644 index 407ccad924..0000000000 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_closure_inputs.py +++ /dev/null @@ -1,68 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.prune_closure_inputs import PruneClosureInputs - - -def test_simple(): - testee = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="x"), ir.Sym(id="y"), ir.Sym(id="z")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="y")]), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="foo"), ir.SymRef(id="bar"), ir.SymRef(id="baz")], - ) - expected = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="y")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="y")]), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="bar")], - ) - actual = PruneClosureInputs().visit(testee) - assert actual == expected - - -def test_shadowing(): - testee = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="x"), ir.Sym(id="y"), ir.Sym(id="z")], - expr=ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="z")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="z")]), - ), - args=[ir.SymRef(id="y")], - ), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="foo"), ir.SymRef(id="bar"), ir.SymRef(id="baz")], - ) - expected = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="y")], - expr=ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="z")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="z")]), - ), - args=[ir.SymRef(id="y")], - ), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="bar")], - ) - actual = PruneClosureInputs().visit(testee) - assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py 
b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py index 0c118ff6dc..c162860c7c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py @@ -6,28 +6,23 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from dataclasses import dataclass -from typing import Optional -from gt4py import eve from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms.symbol_ref_utils import ( - collect_symbol_refs, - get_user_defined_symbols, -) +from gt4py.next.iterator.transforms.symbol_ref_utils import get_user_defined_symbols def test_get_user_defined_symbols(): - ir = itir.FencilDefinition( + domain = itir.FunCall(fun=itir.SymRef(id="cartesian_domain"), args=[]) + ir = itir.Program( id="foo", function_definitions=[], params=[itir.Sym(id="target_symbol")], - closures=[ - itir.StencilClosure( - domain=itir.FunCall(fun=itir.SymRef(id="cartesian_domain"), args=[]), - stencil=itir.SymRef(id="deref"), - output=itir.SymRef(id="target_symbol"), - inputs=[], + declarations=[], + body=[ + itir.SetAt( + expr=itir.Lambda(params=[itir.Sym(id="foo")], expr=itir.SymRef(id="foo")), + domain=domain, + target=itir.SymRef(id="target_symbol"), ) ], ) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py index 1cf662e221..dd7a8f4d43 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py @@ -44,7 +44,7 @@ def test_neighbors(): def test_reduce(): # λ(inp) → reduce(plus, 0.)(·inp) - testee = im.lambda_("inp")(im.call(im.call("reduce")("plus", 0.0))(im.deref("inp"))) + testee = im.lambda_("inp")(im.reduce("plus", 
0.0)(im.deref("inp"))) expected = [{()}] actual = TraceShifts.trace_stencil(testee) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py index 09ed204a91..2415a42267 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_unroll_reduce.py @@ -11,11 +11,20 @@ import pytest from gt4py.eve.utils import UIDs +from gt4py.next import common from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce, _get_partial_offset_tags from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce, _get_partial_offset_tags -from next_tests.unit_tests.conftest import DummyConnectivity + +def dummy_connectivity_type(max_neighbors: int, has_skip_values: bool): + return common.NeighborConnectivityType( + domain=[common.Dimension("dummy_origin"), common.Dimension("dummy_neighbor")], + codomain=common.Dimension("dummy_codomain"), + skip_value=common._DEFAULT_SKIP_VALUE if has_skip_values else None, + dtype=None, + max_neighbors=max_neighbors, + ) @pytest.fixture(params=[True, False]) @@ -26,93 +35,33 @@ def has_skip_values(request): @pytest.fixture def basic_reduction(): UIDs.reset_sequence() - return ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], - ), - args=[ - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim"), ir.SymRef(id="x")], - ) - ], - ) + return im.reduce("foo", 0.0)(im.neighbors("Dim", "x")) @pytest.fixture def reduction_with_shift_on_second_arg(): UIDs.reset_sequence() - return ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], - ), - args=[ - ir.SymRef(id="x"), - 
ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim"), ir.SymRef(id="y")], - ), - ], - ) + return im.reduce("foo", 0.0)("x", im.neighbors("Dim", "y")) @pytest.fixture def reduction_with_incompatible_shifts(): UIDs.reset_sequence() - return ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], - ), - args=[ - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim"), ir.SymRef(id="x")], - ), - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim2"), ir.SymRef(id="y")], - ), - ], - ) + return im.reduce("foo", 0.0)(im.neighbors("Dim", "x"), im.neighbors("Dim2", "y")) @pytest.fixture def reduction_with_irrelevant_full_shift(): UIDs.reset_sequence() - return ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="foo"), im.literal("0.0", "float64")], - ), - args=[ - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ - ir.OffsetLiteral(value="Dim"), - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="shift"), - args=[ - ir.OffsetLiteral(value="IrrelevantDim"), - ir.OffsetLiteral(value="0"), - ], - ), - args=[ir.SymRef(id="x")], - ), - ], - ), - ir.FunCall( - fun=ir.SymRef(id="neighbors"), - args=[ir.OffsetLiteral(value="Dim"), ir.SymRef(id="y")], - ), - ], + return im.reduce("foo", 0.0)( + im.neighbors("Dim", im.shift("IrrelevantDim", 0)("x")), im.neighbors("Dim", "y") ) -# TODO add a test with lift +@pytest.fixture +def reduction_if(): + UIDs.reset_sequence() + return im.reduce("foo", 0.0)(im.if_(True, im.neighbors("Dim", "x"), "y")) @pytest.mark.parametrize( @@ -121,99 +70,109 @@ def reduction_with_irrelevant_full_shift(): "basic_reduction", "reduction_with_irrelevant_full_shift", "reduction_with_shift_on_second_arg", + "reduction_if", ], ) def test_get_partial_offsets(reduction, request): - offset_provider = {"Dim": SimpleNamespace(max_neighbors=3, has_skip_values=False)} + offset_provider_type 
= {"Dim": SimpleNamespace(max_neighbors=3, has_skip_values=False)} partial_offsets = _get_partial_offset_tags(request.getfixturevalue(reduction).args) assert set(partial_offsets) == {"Dim"} def _expected(red, dim, max_neighbors, has_skip_values, shifted_arg=0): - acc = ir.SymRef(id="_acc_1") - offset = ir.SymRef(id="_i_2") - step = ir.SymRef(id="_step_3") + acc, offset, step = "_acc_1", "_i_2", "_step_3" red_fun, red_init = red.fun.args - elements = [ir.FunCall(fun=ir.SymRef(id="list_get"), args=[offset, arg]) for arg in red.args] + elements = [im.list_get(offset, arg) for arg in red.args] - step_expr = ir.FunCall(fun=red_fun, args=[acc] + elements) + step_expr = im.call(red_fun)(acc, *elements) if has_skip_values: neighbors_offset = red.args[shifted_arg].args[0] neighbors_it = red.args[shifted_arg].args[1] - can_deref = ir.FunCall( - fun=ir.SymRef(id="can_deref"), - args=[ - ir.FunCall( - fun=ir.FunCall(fun=ir.SymRef(id="shift"), args=[neighbors_offset, offset]), - args=[neighbors_it], - ) - ], - ) - step_expr = ir.FunCall(fun=ir.SymRef(id="if_"), args=[can_deref, step_expr, acc]) - step_fun = ir.Lambda(params=[ir.Sym(id=acc.id), ir.Sym(id=offset.id)], expr=step_expr) + can_deref = im.can_deref(im.shift(neighbors_offset, offset)(neighbors_it)) + + step_expr = im.if_(can_deref, step_expr, acc) + step_fun = im.lambda_(acc, offset)(step_expr) step_app = red_init for i in range(max_neighbors): - step_app = ir.FunCall(fun=step, args=[step_app, ir.OffsetLiteral(value=i)]) + step_app = im.call(step)(step_app, ir.OffsetLiteral(value=i)) - return ir.FunCall(fun=ir.Lambda(params=[ir.Sym(id=step.id)], expr=step_app), args=[step_fun]) + return im.let(step, step_fun)(step_app) def test_basic(basic_reduction, has_skip_values): expected = _expected(basic_reduction, "Dim", 3, has_skip_values) - offset_provider = {"Dim": DummyConnectivity(max_neighbors=3, has_skip_values=has_skip_values)} - actual = UnrollReduce.apply(basic_reduction, offset_provider=offset_provider) + 
offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=has_skip_values) + } + actual = UnrollReduce.apply(basic_reduction, offset_provider_type=offset_provider_type) assert actual == expected def test_reduction_with_shift_on_second_arg(reduction_with_shift_on_second_arg, has_skip_values): expected = _expected(reduction_with_shift_on_second_arg, "Dim", 1, has_skip_values, 1) - offset_provider = {"Dim": DummyConnectivity(max_neighbors=1, has_skip_values=has_skip_values)} - actual = UnrollReduce.apply(reduction_with_shift_on_second_arg, offset_provider=offset_provider) + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=1, has_skip_values=has_skip_values) + } + actual = UnrollReduce.apply( + reduction_with_shift_on_second_arg, offset_provider_type=offset_provider_type + ) + assert actual == expected + + +def test_reduction_with_if(reduction_if): + expected = _expected(reduction_if, "Dim", 2, False) + + offset_provider_type = {"Dim": dummy_connectivity_type(max_neighbors=2, has_skip_values=False)} + actual = UnrollReduce.apply(reduction_if, offset_provider_type=offset_provider_type) assert actual == expected def test_reduction_with_irrelevant_full_shift(reduction_with_irrelevant_full_shift): expected = _expected(reduction_with_irrelevant_full_shift, "Dim", 3, False) - offset_provider = { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "IrrelevantDim": DummyConnectivity( + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "IrrelevantDim": dummy_connectivity_type( max_neighbors=1, has_skip_values=True ), # different max_neighbors and skip value to trigger error } actual = UnrollReduce.apply( - reduction_with_irrelevant_full_shift, offset_provider=offset_provider + reduction_with_irrelevant_full_shift, offset_provider_type=offset_provider_type ) assert actual == expected @pytest.mark.parametrize( - "offset_provider", + "offset_provider_type", 
[ { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=2, has_skip_values=False), + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=2, has_skip_values=False), }, { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=3, has_skip_values=True), + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=3, has_skip_values=True), }, { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=2, has_skip_values=True), + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=2, has_skip_values=True), }, ], ) -def test_reduction_with_incompatible_shifts(reduction_with_incompatible_shifts, offset_provider): - offset_provider = { - "Dim": DummyConnectivity(max_neighbors=3, has_skip_values=False), - "Dim2": DummyConnectivity(max_neighbors=2, has_skip_values=False), +def test_reduction_with_incompatible_shifts( + reduction_with_incompatible_shifts, offset_provider_type +): + offset_provider_type = { + "Dim": dummy_connectivity_type(max_neighbors=3, has_skip_values=False), + "Dim2": dummy_connectivity_type(max_neighbors=2, has_skip_values=False), } with pytest.raises(RuntimeError, match="incompatible"): - UnrollReduce.apply(reduction_with_incompatible_shifts, offset_provider=offset_provider) + UnrollReduce.apply( + reduction_with_incompatible_shifts, offset_provider_type=offset_provider_type + ) diff --git a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py index b1e051c82b..a25732649a 100644 --- a/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py +++ 
b/tests/next_tests/unit_tests/otf_tests/binding_tests/test_cpp_interface.py @@ -33,10 +33,10 @@ def test_render_function_declaration_scalar(function_scalar_example): expected = format_source( "cpp", """\ - decltype(auto) example(double a, std::int64_t b) { +decltype(auto) example(double a, std::int64_t b) { return; }\ - """, +""", style="LLVM", ) assert rendered == expected @@ -60,14 +60,14 @@ def function_buffer_example(): interface.Parameter( name="a_buf", type_=ts.FieldType( - dims=[gtx.Dimension("foo"), gtx.Dimension("bar")], + dims=[gtx.Dimension("bar"), gtx.Dimension("foo")], dtype=ts.ScalarType(ts.ScalarKind.FLOAT64), ), ), interface.Parameter( name="b_buf", type_=ts.FieldType( - dims=[gtx.Dimension("foo")], dtype=ts.ScalarType(ts.ScalarKind.INT64) + dims=[gtx.Dimension("bar")], dtype=ts.ScalarType(ts.ScalarKind.INT64) ), ), ], @@ -81,11 +81,11 @@ def test_render_function_declaration_buffer(function_buffer_example): expected = format_source( "cpp", """\ - template - decltype(auto) example(ArgT0 &&a_buf, ArgT1 &&b_buf) { +template + decltype(auto) example(ArgT0&& a_buf, ArgT1&& b_buf) { return; }\ - """, +""", style="LLVM", ) assert rendered == expected @@ -111,11 +111,11 @@ def function_tuple_example(): type_=ts.TupleType( types=[ ts.FieldType( - dims=[gtx.Dimension("foo"), gtx.Dimension("bar")], + dims=[gtx.Dimension("bar"), gtx.Dimension("foo")], dtype=ts.ScalarType(ts.ScalarKind.FLOAT64), ), ts.FieldType( - dims=[gtx.Dimension("foo"), gtx.Dimension("bar")], + dims=[gtx.Dimension("bar"), gtx.Dimension("foo")], dtype=ts.ScalarType(ts.ScalarKind.FLOAT64), ), ] @@ -132,11 +132,11 @@ def test_render_function_declaration_tuple(function_tuple_example): expected = format_source( "cpp", """\ - template - decltype(auto) example(ArgT0 &&a_buf) { +template + decltype(auto) example(ArgT0&& a_buf) { return; }\ - """, +""", style="LLVM", ) assert rendered == expected diff --git 
a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index e3e0ee474f..e7053d3317 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -6,19 +6,30 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import copy + +import diskcache import numpy as np import pytest import gt4py.next as gtx -from gt4py.next.iterator import ir as itir +from gt4py.next.iterator import builtins, ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.otf import arguments, languages, stages from gt4py.next.program_processors.codegens.gtfn import gtfn_module +from gt4py.next.program_processors.runners import gtfn from gt4py.next.type_system import type_translation +from next_tests.integration_tests import cases +from next_tests.integration_tests.cases import cartesian_case +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + KDim, + exec_alloc_descriptor, +) + @pytest.fixture -def fencil_example(): +def program_example(): IDim = gtx.Dimension("I") params = [gtx.as_field([IDim], np.empty((1,), dtype=np.float32)), np.float32(3.14)] param_types = [type_translation.from_value(param) for param in params] @@ -30,13 +41,13 @@ def fencil_example(): fun=itir.SymRef(id="named_range"), args=[ itir.AxisLiteral(value="I"), - im.literal("0", itir.INTEGER_INDEX_BUILTIN), - im.literal("10", itir.INTEGER_INDEX_BUILTIN), + im.literal("0", builtins.INTEGER_INDEX_BUILTIN), + im.literal("10", builtins.INTEGER_INDEX_BUILTIN), ], ) ], ) - fencil = itir.FencilDefinition( + program = itir.Program( id="example", params=[im.sym(name, type_) for name, type_ in zip(("buf", "sc"), param_types)], 
function_definitions=[ @@ -46,20 +57,22 @@ def fencil_example(): expr=im.literal("1", "float32"), ) ], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( + expr=im.as_fieldop(itir.SymRef(id="stencil"), domain)( + itir.SymRef(id="buf"), itir.SymRef(id="sc") + ), domain=domain, - stencil=itir.SymRef(id="stencil"), - output=itir.SymRef(id="buf"), - inputs=[itir.SymRef(id="buf"), itir.SymRef(id="sc")], + target=itir.SymRef(id="buf"), ) ], ) - return fencil, params + return program, params -def test_codegen(fencil_example): - fencil, parameters = fencil_example +def test_codegen(program_example): + fencil, parameters = program_example module = gtfn_module.translate_program_cpu( stages.CompilableProgram( data=fencil, @@ -71,3 +84,103 @@ def test_codegen(fencil_example): assert module.entry_point.name == fencil.id assert any(d.name == "gridtools_cpu" for d in module.library_deps) assert module.language is languages.CPP + + +def test_hash_and_diskcache(program_example, tmp_path): + fencil, parameters = program_example + compilable_program = stages.CompilableProgram( + data=fencil, + args=arguments.CompileTimeArgs.from_concrete_no_size( + *parameters, **{"offset_provider": {}} + ), + ) + hash = stages.fingerprint_compilable_program(compilable_program) + + with diskcache.Cache(tmp_path) as cache: + cache[hash] = compilable_program + + # check content of cash file + with diskcache.Cache(tmp_path) as reopened_cache: + assert hash in reopened_cache + compilable_program_from_cache = reopened_cache[hash] + assert compilable_program == compilable_program_from_cache + del reopened_cache[hash] # delete data + + # hash creation is deterministic + assert hash == stages.fingerprint_compilable_program(compilable_program) + assert hash == stages.fingerprint_compilable_program(compilable_program_from_cache) + + # hash is different if program changes + altered_program_id = copy.deepcopy(compilable_program) + altered_program_id.data.id = "example2" + assert 
stages.fingerprint_compilable_program( + compilable_program + ) != stages.fingerprint_compilable_program(altered_program_id) + + altered_program_offset_provider = copy.deepcopy(compilable_program) + object.__setattr__(altered_program_offset_provider.args, "offset_provider", {"Koff": KDim}) + assert stages.fingerprint_compilable_program( + compilable_program + ) != stages.fingerprint_compilable_program(altered_program_offset_provider) + + altered_program_column_axis = copy.deepcopy(compilable_program) + object.__setattr__(altered_program_column_axis.args, "column_axis", KDim) + assert stages.fingerprint_compilable_program( + compilable_program + ) != stages.fingerprint_compilable_program(altered_program_column_axis) + + +def test_gtfn_file_cache(program_example): + fencil, parameters = program_example + compilable_program = stages.CompilableProgram( + data=fencil, + args=arguments.CompileTimeArgs.from_concrete_no_size( + *parameters, **{"offset_provider": {}} + ), + ) + cached_gtfn_translation_step = gtfn.GTFNBackendFactory( + gpu=False, cached=True, otf_workflow__cached_translation=True + ).executor.step.translation + + bare_gtfn_translation_step = gtfn.GTFNBackendFactory( + gpu=False, cached=True, otf_workflow__cached_translation=False + ).executor.step.translation + + cache_key = stages.fingerprint_compilable_program(compilable_program) + + # ensure the actual cached step in the backend generates the cache item for the test + if cache_key in (translation_cache := cached_gtfn_translation_step.cache): + del translation_cache[cache_key] + cached_gtfn_translation_step(compilable_program) + assert bare_gtfn_translation_step(compilable_program) == cached_gtfn_translation_step( + compilable_program + ) + + assert cache_key in cached_gtfn_translation_step.cache + assert ( + bare_gtfn_translation_step(compilable_program) + == cached_gtfn_translation_step.cache[cache_key] + ) + + +# TODO(egparedes): we should switch to use the cached backend by default and then remove this 
test +def test_gtfn_file_cache_whole_workflow(cartesian_case): + if cartesian_case.backend != gtfn.run_gtfn: + pytest.skip("Skipping backend.") + cartesian_case.backend = gtfn.GTFNBackendFactory( + gpu=False, cached=True, otf_workflow__cached_translation=True + ) + + @gtx.field_operator + def testee(a: cases.IJKField) -> cases.IJKField: + field_tuple = (a, a) + field_0 = field_tuple[0] + field_1 = field_tuple[1] + return field_0 + + # first call: this generates the cache file + cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) + # clearing the OTFCompileWorkflow cache such that the OTFCompileWorkflow step is executed again + object.__setattr__(cartesian_case.backend.executor, "cache", {}) + # second call: the cache file is used + cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py index 1a86f7b0f8..50e8fa43f0 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_itir_to_gtfn_ir.py @@ -21,7 +21,7 @@ def test_funcall_to_op(): ) actual = it2gtfn.GTFN_lowering( - grid_type=gtx.GridType.CARTESIAN, offset_provider={}, column_axis=None + grid_type=gtx.GridType.CARTESIAN, offset_provider_type={}, column_axis=None ).visit(testee) assert expected == actual @@ -32,7 +32,7 @@ def test_unapplied_funcall_to_function_object(): expected = gtfn_ir.SymRef(id="plus") actual = it2gtfn.GTFN_lowering( - grid_type=gtx.GridType.CARTESIAN, offset_provider={}, column_axis=None + grid_type=gtx.GridType.CARTESIAN, offset_provider_type={}, column_axis=None ).visit(testee) assert expected == actual @@ -47,7 +47,7 @@ def test_get_domains(): declarations=[], body=[ itir.SetAt( - 
expr=im.call(im.call("as_fieldop")("deref"))(), + expr=im.as_fieldop("deref")(), domain=domain, target=itir.SymRef(id="bar"), ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/__init__.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/__init__.py index 9fa07e46e9..1cdf0f0591 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/__init__.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/__init__.py @@ -9,4 +9,5 @@ import pytest -pytestmark = pytest.mark.requires_dace +#: Attribute defining package-level marks used by a custom pytest hook. +package_pytestmarks = [pytest.mark.requires_dace] diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py index 329b2814d2..64ec757f16 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace.py @@ -20,18 +20,11 @@ from gt4py.next.ffront.fbuiltins import where from next_tests.integration_tests import cases -from next_tests.integration_tests.cases import ( - E2V, - cartesian_case, - unstructured_case, -) +from next_tests.integration_tests.cases import E2V, cartesian_case, unstructured_case # noqa: F401 from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( - exec_alloc_descriptor, - mesh_descriptor, + exec_alloc_descriptor, # noqa: F401 + mesh_descriptor, # noqa: F401 ) -from unittest.mock import patch - -from . 
import pytestmark dace = pytest.importorskip("dace") @@ -151,14 +144,14 @@ def test_dace_fastcall_with_connectivity(unstructured_case, monkeypatch): # check that test connectivities are allocated on host memory # this is an assumption to test that fast_call cannot be used for gpu tests - assert isinstance(connectivity_E2V.table, np.ndarray) + assert isinstance(connectivity_E2V.ndarray, np.ndarray) @gtx.field_operator def testee(a: cases.VField) -> cases.EField: return a(E2V[0]) (a,), kwfields = cases.get_default_data(unstructured_case, testee) - numpy_ref = lambda a: a[connectivity_E2V.table[:, 0]] + numpy_ref = lambda a: a[connectivity_E2V.ndarray[:, 0]] mock_fast_call, mock_construct_args = make_mocks(monkeypatch) @@ -182,7 +175,7 @@ def verify_testee(offset_provider): offset_provider = unstructured_case.offset_provider else: assert gtx.allocators.is_field_allocator_for( - unstructured_case.backend.allocator, gtx.allocators.CUPY_DEVICE + unstructured_case.backend.allocator, core_defs.CUPY_DEVICE_TYPE ) import cupy as cp @@ -191,15 +184,14 @@ def verify_testee(offset_provider): # to gpu memory at each program call (see `dace_backend._ensure_is_on_device`), # therefore fast_call cannot be used (unless cupy reuses the same cupy array # from the its memory pool, but this behavior is random and unpredictable). - # Here we copy the connectivity to gpu memory, and resuse the same cupy array + # Here we copy the connectivity to gpu memory, and reuse the same cupy array # on multiple program calls, in order to ensure that fast_call is used. 
offset_provider = { - "E2V": gtx.NeighborTableOffsetProvider( - table=cp.asarray(connectivity_E2V.table), - origin_axis=connectivity_E2V.origin_axis, - neighbor_axis=connectivity_E2V.neighbor_axis, - max_neighbors=connectivity_E2V.max_neighbors, - has_skip_values=connectivity_E2V.has_skip_values, + "E2V": gtx.as_connectivity( + domain=connectivity_E2V.domain, + codomain=connectivity_E2V.codomain, + data=cp.asarray(connectivity_E2V.ndarray), + skip_value=connectivity_E2V.skip_value, ) } diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_utils.py new file mode 100644 index 0000000000..eec68a6486 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_utils.py @@ -0,0 +1,21 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +"""Test utility functions of the dace backend module.""" + +import pytest + +dace = pytest.importorskip("dace") + +from gt4py.next.program_processors.runners.dace import utils as gtx_dace_utils + + +def test_safe_replace_symbolic(): + assert gtx_dace_utils.safe_replace_symbolic( + dace.symbolic.pystr_to_symbolic("x*x + y"), symbol_mapping={"x": "y", "y": "x"} + ) == dace.symbolic.pystr_to_symbolic("y*y + x") diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index e819cdcd8c..030aa9b131 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -12,8 +12,8 @@ Note: this test module covers the fieldview flavour of ITIR. """ -import copy import functools +from typing import Any, Callable import numpy as np import pytest @@ -21,6 +21,7 @@ from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import infer_domain from gt4py.next.type_system import type_specifications as ts from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( @@ -34,37 +35,31 @@ skip_value_mesh, ) -from . 
import pytestmark -dace_backend = pytest.importorskip("gt4py.next.program_processors.runners.dace_fieldview") +dace_backend = pytest.importorskip("gt4py.next.program_processors.runners.dace") N = 10 -IFTYPE = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) -CFTYPE = ts.FieldType(dims=[Cell], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) -EFTYPE = ts.FieldType(dims=[Edge], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) -VFTYPE = ts.FieldType(dims=[Vertex], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) +FLOAT_TYPE = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) +IFTYPE = ts.FieldType(dims=[IDim], dtype=FLOAT_TYPE) +CFTYPE = ts.FieldType(dims=[Cell], dtype=FLOAT_TYPE) +EFTYPE = ts.FieldType(dims=[Edge], dtype=FLOAT_TYPE) +VFTYPE = ts.FieldType(dims=[Vertex], dtype=FLOAT_TYPE) V2E_FTYPE = ts.FieldType(dims=[Vertex, V2EDim], dtype=EFTYPE.dtype) CARTESIAN_OFFSETS = { - "IDim": IDim, + IDim.value: IDim, } SIMPLE_MESH: MeshDescriptor = simple_mesh() -SIMPLE_MESH_OFFSET_PROVIDER: dict[str, gtx_common.Connectivity | gtx_common.Dimension] = ( - SIMPLE_MESH.offset_provider | CARTESIAN_OFFSETS -) SKIP_VALUE_MESH: MeshDescriptor = skip_value_mesh() -SKIP_VALUE_MESH_OFFSET_PROVIDER: dict[str, gtx_common.Connectivity | gtx_common.Dimension] = ( - SKIP_VALUE_MESH.offset_provider | CARTESIAN_OFFSETS -) SIZE_TYPE = ts.ScalarType(ts.ScalarKind.INT32) FSYMBOLS = dict( - __w_size_0=N, + __w_0_range_1=N, __w_stride_0=1, - __x_size_0=N, + __x_0_range_1=N, __x_stride_0=1, - __y_size_0=N, + __y_0_range_1=N, __y_stride_0=1, - __z_size_0=N, + __z_0_range_1=N, __z_stride_0=1, size=N, ) @@ -75,31 +70,39 @@ def make_mesh_symbols(mesh: MeshDescriptor): ncells=mesh.num_cells, nedges=mesh.num_edges, nvertices=mesh.num_vertices, - __cells_size_0=mesh.num_cells, + __cells_0_range_1=mesh.num_cells, __cells_stride_0=1, - __edges_size_0=mesh.num_edges, + __edges_0_range_1=mesh.num_edges, __edges_stride_0=1, - __vertices_size_0=mesh.num_vertices, + 
__vertices_0_range_1=mesh.num_vertices, __vertices_stride_0=1, - __connectivity_C2E_size_0=mesh.num_cells, - __connectivity_C2E_size_1=mesh.offset_provider["C2E"].max_neighbors, - __connectivity_C2E_stride_0=mesh.offset_provider["C2E"].max_neighbors, + __connectivity_C2E_0_range_1=mesh.num_cells, + __connectivity_C2E_size_1=mesh.offset_provider_type["C2E"].max_neighbors, + __connectivity_C2E_stride_0=mesh.offset_provider_type["C2E"].max_neighbors, __connectivity_C2E_stride_1=1, - __connectivity_C2V_size_0=mesh.num_cells, - __connectivity_C2V_size_1=mesh.offset_provider["C2V"].max_neighbors, - __connectivity_C2V_stride_0=mesh.offset_provider["C2V"].max_neighbors, + __connectivity_C2V_0_range_1=mesh.num_cells, + __connectivity_C2V_size_1=mesh.offset_provider_type["C2V"].max_neighbors, + __connectivity_C2V_stride_0=mesh.offset_provider_type["C2V"].max_neighbors, __connectivity_C2V_stride_1=1, - __connectivity_E2V_size_0=mesh.num_edges, - __connectivity_E2V_size_1=mesh.offset_provider["E2V"].max_neighbors, - __connectivity_E2V_stride_0=mesh.offset_provider["E2V"].max_neighbors, + __connectivity_E2V_0_range_1=mesh.num_edges, + __connectivity_E2V_size_1=mesh.offset_provider_type["E2V"].max_neighbors, + __connectivity_E2V_stride_0=mesh.offset_provider_type["E2V"].max_neighbors, __connectivity_E2V_stride_1=1, - __connectivity_V2E_size_0=mesh.num_vertices, - __connectivity_V2E_size_1=mesh.offset_provider["V2E"].max_neighbors, - __connectivity_V2E_stride_0=mesh.offset_provider["V2E"].max_neighbors, + __connectivity_V2E_0_range_1=mesh.num_vertices, + __connectivity_V2E_size_1=mesh.offset_provider_type["V2E"].max_neighbors, + __connectivity_V2E_stride_0=mesh.offset_provider_type["V2E"].max_neighbors, __connectivity_V2E_stride_1=1, ) +def build_dace_sdfg( + ir: gtir.Program, offset_provider_type: gtx_common.OffsetProviderType +) -> Callable[..., Any]: + return dace_backend.build_sdfg_from_gtir( + ir, offset_provider_type, disable_field_origin_on_program_arguments=True + ) + + 
def test_gtir_broadcast(): val = np.random.rand() domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) @@ -122,7 +125,7 @@ def test_gtir_broadcast(): a = np.empty(N, dtype=np.float64) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, **FSYMBOLS) np.testing.assert_array_equal(a, val) @@ -145,9 +148,7 @@ def test_gtir_cast(): body=[ gtir.SetAt( expr=im.op_as_fieldop("eq", domain)( - im.as_fieldop( - im.lambda_("a")(im.call("cast_")(im.deref("a"), "float32")), domain - )("x"), + im.cast_as_fieldop("float32", domain)("x"), "y", ), domain=domain, @@ -160,7 +161,7 @@ def test_gtir_cast(): b = a.astype(np.float32) c = np.empty_like(a, dtype=np.bool_) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, c, **FSYMBOLS) np.testing.assert_array_equal(c, True) @@ -188,7 +189,7 @@ def test_gtir_copy_self(): a = np.random.rand(N) ref = a.copy() - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, **FSYMBOLS) assert np.allclose(a, ref) @@ -219,7 +220,7 @@ def test_gtir_tuple_swap(): b = np.random.rand(N) ref = (a.copy(), b.copy()) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, **FSYMBOLS) assert np.allclose(a, ref[1]) @@ -258,19 +259,20 @@ def test_gtir_tuple_args(): b = np.random.rand(N) c = np.empty_like(a) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) x_fields = (a, a, b) - x_symbols = dict( - __x_0_size_0=FSYMBOLS["__x_size_0"], - __x_0_stride_0=FSYMBOLS["__x_stride_0"], - __x_1_0_size_0=FSYMBOLS["__x_size_0"], - __x_1_0_stride_0=FSYMBOLS["__x_stride_0"], - __x_1_1_size_0=FSYMBOLS["__y_size_0"], - __x_1_1_stride_0=FSYMBOLS["__y_stride_0"], - ) - 
sdfg(*x_fields, c, **FSYMBOLS, **x_symbols) + tuple_symbols = { + "__x_0_0_range_1": N, + "__x_0_stride_0": 1, + "__x_1_0_0_range_1": N, + "__x_1_0_stride_0": 1, + "__x_1_1_0_range_1": N, + "__x_1_1_stride_0": 1, + } + + sdfg(*x_fields, c, **FSYMBOLS, **tuple_symbols) assert np.allclose(c, a * 2 + b) @@ -309,12 +311,97 @@ def test_gtir_tuple_expr(): b = np.random.rand(N) c = np.empty_like(a) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, c, **FSYMBOLS) assert np.allclose(c, a * 2 + b) +def test_gtir_tuple_broadcast_scalar(): + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + testee = gtir.Program( + id="gtir_tuple_broadcast_scalar", + function_definitions=[], + params=[ + gtir.Sym( + id="x", + type=ts.TupleType(types=[FLOAT_TYPE, ts.TupleType(types=[FLOAT_TYPE, FLOAT_TYPE])]), + ), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.as_fieldop("deref", domain)( + im.plus( + im.tuple_get(0, "x"), + im.plus( + im.multiplies_( + im.tuple_get( + 0, + im.tuple_get(1, "x"), + ), + 2.0, + ), + im.multiplies_( + im.tuple_get( + 1, + im.tuple_get(1, "x"), + ), + 3.0, + ), + ), + ) + ), + domain=domain, + target=gtir.SymRef(id="y"), + ) + ], + ) + + a = np.random.rand() + b = np.random.rand() + c = np.random.rand() + d = np.empty(N, dtype=type(a)) + + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) + + x_fields = (a, b, c) + + sdfg(*x_fields, d, **FSYMBOLS) + assert np.allclose(d, a + 2 * b + 3 * c) + + +def test_gtir_zero_dim_fields(): + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + testee = gtir.Program( + id="gtir_zero_dim_fields", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=ts.FieldType(dims=[], dtype=IFTYPE.dtype)), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + 
gtir.SetAt( + expr=im.as_fieldop("deref", domain)("x"), + domain=domain, + target=gtir.SymRef(id="y"), + ) + ], + ) + + a = np.asarray(np.random.rand()) + b = np.empty(N) + + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) + + sdfg(a.item(), b, **FSYMBOLS) + assert np.allclose(a, b) + + def test_gtir_tuple_return(): domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) testee = gtir.Program( @@ -343,19 +430,20 @@ def test_gtir_tuple_return(): a = np.random.rand(N) b = np.random.rand(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) z_fields = (np.empty_like(a), np.empty_like(a), np.empty_like(a)) - z_symbols = dict( - __z_0_0_size_0=FSYMBOLS["__x_size_0"], - __z_0_0_stride_0=FSYMBOLS["__x_stride_0"], - __z_0_1_size_0=FSYMBOLS["__x_size_0"], - __z_0_1_stride_0=FSYMBOLS["__x_stride_0"], - __z_1_size_0=FSYMBOLS["__x_size_0"], - __z_1_stride_0=FSYMBOLS["__x_stride_0"], - ) - sdfg(a, b, *z_fields, **FSYMBOLS, **z_symbols) + tuple_symbols = { + "__z_0_0_0_range_1": N, + "__z_0_0_stride_0": 1, + "__z_0_1_0_range_1": N, + "__z_0_1_stride_0": 1, + "__z_1_0_range_1": N, + "__z_1_stride_0": 1, + } + + sdfg(a, b, *z_fields, **FSYMBOLS, **tuple_symbols) assert np.allclose(z_fields[0], a + b) assert np.allclose(z_fields[1], a) assert np.allclose(z_fields[2], b) @@ -385,7 +473,7 @@ def test_gtir_tuple_target(): b = np.empty_like(a) ref = a.copy() - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, **FSYMBOLS) assert np.allclose(a, ref + 1) @@ -417,7 +505,7 @@ def test_gtir_update(): ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) a = np.random.rand(N) ref = a - 1.0 @@ -451,7 +539,7 @@ def test_gtir_sum2(): b = np.random.rand(N) c = np.empty_like(a) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, 
CARTESIAN_OFFSETS) sdfg(a, b, c, **FSYMBOLS) assert np.allclose(c, (a + b)) @@ -480,7 +568,7 @@ def test_gtir_sum2_sym(): a = np.random.rand(N) b = np.empty_like(a) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) sdfg(a, b, **FSYMBOLS) assert np.allclose(b, (a + a)) @@ -522,7 +610,7 @@ def test_gtir_sum3(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) d = np.empty_like(a) @@ -551,7 +639,7 @@ def test_gtir_cond(): expr=im.op_as_fieldop("plus", domain)( "x", im.if_( - im.greater(gtir.SymRef(id="s1"), gtir.SymRef(id="s2")), + im.greater("s1", "s2"), im.op_as_fieldop("plus", domain)("y", "scalar"), im.op_as_fieldop("plus", domain)("w", "scalar"), ), @@ -566,7 +654,7 @@ def test_gtir_cond(): b = np.random.rand(N) c = np.random.rand(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) for s1, s2 in [(1, 2), (2, 1)]: d = np.empty_like(a) @@ -593,7 +681,7 @@ def test_gtir_cond_with_tuple_return(): expr=im.tuple_get( 0, im.if_( - gtir.SymRef(id="pred"), + "pred", im.make_tuple(im.make_tuple("x", "y"), "w"), im.make_tuple(im.make_tuple("y", "x"), "w"), ), @@ -608,18 +696,18 @@ def test_gtir_cond_with_tuple_return(): b = np.random.rand(N) c = np.random.rand(N) - z_symbols = dict( - __z_0_size_0=FSYMBOLS["__x_size_0"], - __z_0_stride_0=FSYMBOLS["__x_stride_0"], - __z_1_size_0=FSYMBOLS["__x_size_0"], - __z_1_stride_0=FSYMBOLS["__x_stride_0"], - ) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + tuple_symbols = { + "__z_0_0_range_1": N, + "__z_0_stride_0": 1, + "__z_1_0_range_1": N, + "__z_1_stride_0": 1, + } for s in [False, True]: z_fields = (np.empty_like(a), np.empty_like(a)) - sdfg(a, b, c, *z_fields, pred=np.bool_(s), **FSYMBOLS, **z_symbols) + sdfg(a, b, c, *z_fields, pred=np.bool_(s), **FSYMBOLS, **tuple_symbols) assert 
np.allclose(z_fields[0], a if s else b) assert np.allclose(z_fields[1], b if s else a) @@ -640,10 +728,10 @@ def test_gtir_cond_nested(): body=[ gtir.SetAt( expr=im.if_( - gtir.SymRef(id="pred_1"), + "pred_1", im.op_as_fieldop("plus", domain)("x", 1.0), im.if_( - gtir.SymRef(id="pred_2"), + "pred_2", im.op_as_fieldop("plus", domain)("x", 2.0), im.op_as_fieldop("plus", domain)("x", 3.0), ), @@ -656,7 +744,7 @@ def test_gtir_cond_nested(): a = np.random.rand(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) for s1 in [False, True]: for s2 in [False, True]: @@ -680,13 +768,13 @@ def test_gtir_cartesian_shift_left(): # cartesian shift with literal integer offset stencil1_inlined = im.as_fieldop( - im.lambda_("a")(im.plus(im.deref(im.shift("IDim", OFFSET)("a")), DELTA)), + im.lambda_("a")(im.plus(im.deref(im.shift(IDim.value, OFFSET)("a")), DELTA)), domain, )("x") # fieldview flavor of same stencil, in which a temporary field is initialized with the `DELTA` constant value stencil1_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a")(im.deref(im.shift("IDim", OFFSET)("a"))), + im.lambda_("a")(im.deref(im.shift(IDim.value, OFFSET)("a"))), domain, )("x"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -694,13 +782,15 @@ def test_gtir_cartesian_shift_left(): # use dynamic offset retrieved from field stencil2_inlined = im.as_fieldop( - im.lambda_("a", "off")(im.plus(im.deref(im.shift("IDim", im.deref("off"))("a")), DELTA)), + im.lambda_("a", "off")( + im.plus(im.deref(im.shift(IDim.value, im.deref("off"))("a")), DELTA) + ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil2_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )("x", "x_offset"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -709,14 +799,14 @@ 
def test_gtir_cartesian_shift_left(): # use the result of an arithmetic field operation as dynamic offset stencil3_inlined = im.as_fieldop( im.lambda_("a", "off")( - im.plus(im.deref(im.shift("IDim", im.plus(im.deref("off"), 0))("a")), DELTA) + im.plus(im.deref(im.shift(IDim.value, im.plus(im.deref("off"), 0))("a")), DELTA) ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil3_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )( "x", @@ -760,11 +850,9 @@ def test_gtir_cartesian_shift_left(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) - FSYMBOLS_tmp = FSYMBOLS.copy() - FSYMBOLS_tmp["__x_offset_stride_0"] = 1 - sdfg(a, a_offset, b, **FSYMBOLS_tmp) + sdfg(a, a_offset, b, **FSYMBOLS, __x_offset_0_range_1=N, __x_offset_stride_0=1) assert np.allclose(a[OFFSET:] + DELTA, b[:-OFFSET]) @@ -775,13 +863,13 @@ def test_gtir_cartesian_shift_right(): # cartesian shift with literal integer offset stencil1_inlined = im.as_fieldop( - im.lambda_("a")(im.plus(im.deref(im.shift("IDim", -OFFSET)("a")), DELTA)), + im.lambda_("a")(im.plus(im.deref(im.shift(IDim.value, -OFFSET)("a")), DELTA)), domain, )("x") # fieldview flavor of same stencil, in which a temporary field is initialized with the `DELTA` constant value stencil1_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a")(im.deref(im.shift("IDim", -OFFSET)("a"))), + im.lambda_("a")(im.deref(im.shift(IDim.value, -OFFSET)("a"))), domain, )("x"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -789,13 +877,15 @@ def test_gtir_cartesian_shift_right(): # use dynamic offset retrieved from field stencil2_inlined = im.as_fieldop( - im.lambda_("a", "off")(im.plus(im.deref(im.shift("IDim", im.deref("off"))("a")), DELTA)), + im.lambda_("a", 
"off")( + im.plus(im.deref(im.shift(IDim.value, im.deref("off"))("a")), DELTA) + ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil2_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )("x", "x_offset"), im.as_fieldop(im.lambda_()(DELTA), domain)(), @@ -804,14 +894,14 @@ def test_gtir_cartesian_shift_right(): # use the result of an arithmetic field operation as dynamic offset stencil3_inlined = im.as_fieldop( im.lambda_("a", "off")( - im.plus(im.deref(im.shift("IDim", im.plus(im.deref("off"), 0))("a")), DELTA) + im.plus(im.deref(im.shift(IDim.value, im.plus(im.deref("off"), 0))("a")), DELTA) ), domain, )("x", "x_offset") # fieldview flavor of same stencil stencil3_fieldview = im.op_as_fieldop("plus", domain)( im.as_fieldop( - im.lambda_("a", "off")(im.deref(im.shift("IDim", im.deref("off"))("a"))), + im.lambda_("a", "off")(im.deref(im.shift(IDim.value, im.deref("off"))("a"))), domain, )( "x", @@ -855,9 +945,9 @@ def test_gtir_cartesian_shift_right(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) - sdfg(a, a_offset, b, **FSYMBOLS, __x_offset_stride_0=1) + sdfg(a, a_offset, b, **FSYMBOLS, __x_offset_0_range_1=N, __x_offset_stride_0=1) assert np.allclose(a[:-OFFSET] + DELTA, b[OFFSET:]) @@ -954,19 +1044,19 @@ def test_gtir_connectivity_shift(): im.op_as_fieldop("plus", edge_domain)("e2v_offset", 0), ) - CE_FTYPE = ts.FieldType(dims=[Cell, Edge], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) - EV_FTYPE = ts.FieldType(dims=[Edge, Vertex], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)) + CE_FTYPE = ts.FieldType(dims=[Cell, Edge], dtype=FLOAT_TYPE) + EV_FTYPE = ts.FieldType(dims=[Edge, Vertex], dtype=FLOAT_TYPE) CELL_OFFSET_FTYPE = ts.FieldType(dims=[Cell], dtype=SIZE_TYPE) EDGE_OFFSET_FTYPE = 
ts.FieldType(dims=[Edge], dtype=SIZE_TYPE) - connectivity_C2E = SIMPLE_MESH_OFFSET_PROVIDER["C2E"] + connectivity_C2E = SIMPLE_MESH.offset_provider["C2E"] assert isinstance(connectivity_C2E, gtx_common.NeighborTable) - connectivity_E2V = SIMPLE_MESH_OFFSET_PROVIDER["E2V"] + connectivity_E2V = SIMPLE_MESH.offset_provider["E2V"] assert isinstance(connectivity_E2V, gtx_common.NeighborTable) ev = np.random.rand(SIMPLE_MESH.num_edges, SIMPLE_MESH.num_vertices) - ref = ev[connectivity_C2E.table[:, C2E_neighbor_idx], :][ - :, connectivity_E2V.table[:, E2V_neighbor_idx] + ref = ev[connectivity_C2E.ndarray[:, C2E_neighbor_idx], :][ + :, connectivity_E2V.ndarray[:, E2V_neighbor_idx] ] for i, stencil in enumerate( @@ -994,7 +1084,7 @@ def test_gtir_connectivity_shift(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) ce = np.empty([SIMPLE_MESH.num_cells, SIMPLE_MESH.num_edges]) @@ -1003,19 +1093,21 @@ def test_gtir_connectivity_shift(): ev, c2e_offset=np.full(SIMPLE_MESH.num_cells, C2E_neighbor_idx, dtype=np.int32), e2v_offset=np.full(SIMPLE_MESH.num_edges, E2V_neighbor_idx, dtype=np.int32), - connectivity_C2E=connectivity_C2E.table, - connectivity_E2V=connectivity_E2V.table, + connectivity_C2E=connectivity_C2E.ndarray, + connectivity_E2V=connectivity_E2V.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), - __ce_field_size_0=SIMPLE_MESH.num_cells, + __ce_field_0_range_1=SIMPLE_MESH.num_cells, __ce_field_size_1=SIMPLE_MESH.num_edges, __ce_field_stride_0=SIMPLE_MESH.num_edges, __ce_field_stride_1=1, - __ev_field_size_0=SIMPLE_MESH.num_edges, + __ev_field_0_range_1=SIMPLE_MESH.num_edges, __ev_field_size_1=SIMPLE_MESH.num_vertices, __ev_field_stride_0=SIMPLE_MESH.num_vertices, __ev_field_stride_1=1, + __c2e_offset_0_range_1=SIMPLE_MESH.num_cells, __c2e_offset_stride_0=1, + __e2v_offset_0_range_1=SIMPLE_MESH.num_edges, __e2v_offset_stride_0=1, ) assert np.allclose(ce, ref) 
@@ -1053,15 +1145,17 @@ def test_gtir_connectivity_shift_chain(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) - connectivity_E2V = SIMPLE_MESH_OFFSET_PROVIDER["E2V"] + connectivity_E2V = SIMPLE_MESH.offset_provider["E2V"] assert isinstance(connectivity_E2V, gtx_common.NeighborTable) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SIMPLE_MESH.num_edges) - ref = e[connectivity_V2E.table[connectivity_E2V.table[:, E2V_neighbor_idx], V2E_neighbor_idx]] + ref = e[ + connectivity_V2E.ndarray[connectivity_E2V.ndarray[:, E2V_neighbor_idx], V2E_neighbor_idx] + ] # new empty output field e_out = np.empty_like(e) @@ -1069,19 +1163,17 @@ def test_gtir_connectivity_shift_chain(): sdfg( e, e_out, - connectivity_E2V=connectivity_E2V.table, - connectivity_V2E=connectivity_V2E.table, + connectivity_E2V=connectivity_E2V.ndarray, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), - __edges_out_size_0=SIMPLE_MESH.num_edges, + __edges_out_0_range_1=SIMPLE_MESH.num_edges, __edges_out_stride_0=1, ) assert np.allclose(e_out, ref) def test_gtir_neighbors_as_input(): - # FIXME[#1582](edopao): Enable testcase when type inference is working - pytest.skip("Field of lists not fully supported by GTIR type inference") init_value = np.random.rand() vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) testee = gtir.Program( @@ -1089,49 +1181,54 @@ def test_gtir_neighbors_as_input(): function_definitions=[], params=[ gtir.Sym(id="v2e_field", type=V2E_FTYPE), - gtir.Sym(id="vertex", type=EFTYPE), + gtir.Sym(id="edges", type=EFTYPE), + gtir.Sym(id="vertices", type=VFTYPE), gtir.Sym(id="nvertices", type=SIZE_TYPE), ], declarations=[], body=[ gtir.SetAt( - expr=im.call( - 
im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - "it" - ) - ), - vertex_domain, + expr=im.as_fieldop( + im.lambda_("it")( + im.reduce("plus", im.literal_from_value(init_value))(im.deref("it")) + ), + vertex_domain, + )( + im.op_as_fieldop(im.map_("plus"), vertex_domain)( + "v2e_field", + im.as_fieldop_neighbors("V2E", "edges", vertex_domain), ) - )("v2e_field"), + ), domain=vertex_domain, - target=gtir.SymRef(id="vertex"), + target=gtir.SymRef(id="vertices"), ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - v2e_field = np.random.rand(SIMPLE_MESH.num_vertices, connectivity_V2E.max_neighbors) + v2e_field = np.random.rand(SIMPLE_MESH.num_vertices, connectivity_V2E.shape[1]) + e = np.random.rand(SIMPLE_MESH.num_edges) v = np.empty(SIMPLE_MESH.num_vertices, dtype=v2e_field.dtype) v_ref = [ - functools.reduce(lambda x, y: x + y, v2e_neighbors, init_value) - for v2e_neighbors in v2e_field + functools.reduce(lambda x, y: x + y, v2e_values + e[v2e_neighbors], init_value) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.ndarray, v2e_field, strict=True) ] sdfg( v2e_field, + e, v, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), - __v2e_field_size_0=SIMPLE_MESH.num_vertices, - __v2e_field_size_1=connectivity_V2E.max_neighbors, - __v2e_field_stride_0=connectivity_V2E.max_neighbors, + __v2e_field_0_range_1=SIMPLE_MESH.num_vertices, + __v2e_field_size_1=connectivity_V2E.shape[1], + __v2e_field_stride_0=connectivity_V2E.shape[1], __v2e_field_stride_1=1, ) assert np.allclose(v, v_ref) @@ -1144,7 +1241,7 @@ def test_gtir_neighbors_as_output(): gtx_common.GridType.UNSTRUCTURED, ranges={ 
Vertex: (0, "nvertices"), - V2EDim: (0, SIMPLE_MESH_OFFSET_PROVIDER["V2E"].max_neighbors), + V2EDim: (0, SIMPLE_MESH.offset_provider_type["V2E"].max_neighbors), }, ) vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) @@ -1166,9 +1263,9 @@ def test_gtir_neighbors_as_output(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SIMPLE_MESH.num_edges) @@ -1177,48 +1274,38 @@ def test_gtir_neighbors_as_output(): sdfg( e, v2e_field, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), - __v2e_field_size_0=SIMPLE_MESH.num_vertices, + __v2e_field_0_range_1=SIMPLE_MESH.num_vertices, __v2e_field_size_1=connectivity_V2E.max_neighbors, __v2e_field_stride_0=connectivity_V2E.max_neighbors, __v2e_field_stride_1=1, ) - assert np.allclose(v2e_field, e[connectivity_V2E.table]) + assert np.allclose(v2e_field, e[connectivity_V2E.ndarray]) def test_gtir_reduce(): init_value = np.random.rand() vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) - stencil_inlined = im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.neighbors("V2E", "it") - ) - ), - vertex_domain, - ) + stencil_inlined = im.as_fieldop( + im.lambda_("it")( + im.reduce("plus", im.literal_from_value(init_value))(im.neighbors("V2E", "it")) + ), + vertex_domain, )("edges") - stencil_fieldview = im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) - ), - vertex_domain, - ) + stencil_fieldview = 
im.as_fieldop( + im.lambda_("it")(im.reduce("plus", im.literal_from_value(init_value))(im.deref("it"))), + vertex_domain, )(im.as_fieldop_neighbors("V2E", "edges", vertex_domain)) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SIMPLE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SIMPLE_MESH.num_edges) v_ref = [ functools.reduce(lambda x, y: x + y, e[v2e_neighbors], init_value) - for v2e_neighbors in connectivity_V2E.table + for v2e_neighbors in connectivity_V2E.ndarray ] for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): @@ -1239,7 +1326,7 @@ def test_gtir_reduce(): ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) # new empty output field v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) @@ -1247,7 +1334,7 @@ def test_gtir_reduce(): sdfg( e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), ) @@ -1257,36 +1344,28 @@ def test_gtir_reduce(): def test_gtir_reduce_with_skip_values(): init_value = np.random.rand() vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) - stencil_inlined = im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.neighbors("V2E", "it") - ) - ), - vertex_domain, - ) + stencil_inlined = im.as_fieldop( + im.lambda_("it")( + im.reduce("plus", im.literal_from_value(init_value))(im.neighbors("V2E", "it")) + ), + vertex_domain, )("edges") - stencil_fieldview = im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) - ), - vertex_domain, - ) + stencil_fieldview = im.as_fieldop( + im.lambda_("it")(im.reduce("plus", 
im.literal_from_value(init_value))(im.deref("it"))), + vertex_domain, )(im.as_fieldop_neighbors("V2E", "edges", vertex_domain)) - connectivity_V2E = SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"] + connectivity_V2E = SKIP_VALUE_MESH.offset_provider["V2E"] assert isinstance(connectivity_V2E, gtx_common.NeighborTable) e = np.random.rand(SKIP_VALUE_MESH.num_edges) v_ref = [ functools.reduce( - lambda x, y: x + y, [e[i] if i != -1 else 0.0 for i in v2e_neighbors], init_value + lambda x, y: x + y, + [e[i] if i != gtx_common._DEFAULT_SKIP_VALUE else 0.0 for i in v2e_neighbors], + init_value, ) - for v2e_neighbors in connectivity_V2E.table + for v2e_neighbors in connectivity_V2E.ndarray ] for i, stencil in enumerate([stencil_inlined, stencil_fieldview]): @@ -1307,7 +1386,7 @@ def test_gtir_reduce_with_skip_values(): ) ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SKIP_VALUE_MESH_OFFSET_PROVIDER) + sdfg = build_dace_sdfg(testee, SKIP_VALUE_MESH.offset_provider_type) # new empty output field v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) @@ -1315,7 +1394,7 @@ def test_gtir_reduce_with_skip_values(): sdfg( e, v, - connectivity_V2E=connectivity_V2E.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, **make_mesh_symbols(SKIP_VALUE_MESH), ) @@ -1323,15 +1402,32 @@ def test_gtir_reduce_with_skip_values(): def test_gtir_reduce_dot_product(): - # FIXME[#1582](edopao): Enable testcase when type inference is working - pytest.skip("Field of lists not fully supported as a type in GTIR yet") init_value = np.random.rand() vertex_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Vertex: (0, "nvertices")}) + connectivity_V2E = SKIP_VALUE_MESH.offset_provider["V2E"] + assert isinstance(connectivity_V2E, gtx_common.NeighborTable) + + v2e_field = np.random.rand(*connectivity_V2E.shape) + e = np.random.rand(SKIP_VALUE_MESH.num_edges) + v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) + v_ref = [ + functools.reduce( + lambda x, y: x + y, + map( + 
lambda x: 0.0 if x[1] == gtx_common._DEFAULT_SKIP_VALUE else x[0], + zip((e[v2e_neighbors] * v2e_values) + 1.0, v2e_neighbors), + ), + init_value, + ) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.ndarray, v2e_field) + ] + testee = gtir.Program( - id="reduce_dot_product", + id="reduce_dot_product", function_definitions=[], params=[ + gtir.Sym(id="v2e_field", type=V2E_FTYPE), gtir.Sym(id="edges", type=EFTYPE), gtir.Sym(id="vertices", type=VFTYPE), gtir.Sym(id="nvertices", type=SIZE_TYPE), @@ -1339,20 +1435,19 @@ def test_gtir_reduce_dot_product(): declarations=[], body=[ gtir.SetAt( - expr=im.call( - im.call("as_fieldop")( - im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) + expr=im.as_fieldop( + im.lambda_("it")( + im.reduce("plus", im.literal_from_value(init_value))(im.deref("it")) + ), + vertex_domain, + )( + im.op_as_fieldop(im.map_("plus"), vertex_domain)( + im.op_as_fieldop(im.map_("multiplies"), vertex_domain)( + im.as_fieldop_neighbors("V2E", "edges", vertex_domain), + "v2e_field", ), - vertex_domain, + im.op_as_fieldop("make_const_list", vertex_domain)(1.0), ) - )( - im.op_as_fieldop("multiplies", vertex_domain)( - im.as_fieldop_neighbors("V2E", "edges", vertex_domain), - im.as_fieldop_neighbors("V2E", "edges", vertex_domain), - ), ), domain=vertex_domain, target=gtir.SymRef(id="vertices"), @@ -1360,24 +1455,18 @@ def test_gtir_reduce_dot_product(): ], ) - connectivity_V2E = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] - assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) - - e = np.random.rand(SIMPLE_MESH.num_edges) - v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) - v_ref = [ - reduce(lambda x, y: x + y, e[v2e_neighbors] * e[v2e_neighbors], init_value) - for v2e_neighbors in connectivity_V2E.table - ] + sdfg = build_dace_sdfg(testee, SKIP_VALUE_MESH.offset_provider_type) sdfg( + v2e_field, e, v, 
- connectivity_V2E=connectivity_V2E.table, - **FSYMBOLS, - **make_mesh_symbols(SIMPLE_MESH), + connectivity_V2E=connectivity_V2E.ndarray, + **make_mesh_symbols(SKIP_VALUE_MESH), + __v2e_field_0_range_1=SKIP_VALUE_MESH.num_vertices, + __v2e_field_size_1=connectivity_V2E.shape[1], + __v2e_field_stride_0=connectivity_V2E.shape[1], + __v2e_field_stride_1=1, ) assert np.allclose(v, v_ref) @@ -1390,6 +1479,7 @@ def test_gtir_reduce_with_cond_neighbors(): function_definitions=[], params=[ gtir.Sym(id="pred", type=ts.ScalarType(ts.ScalarKind.BOOL)), + gtir.Sym(id="v2e_field", type=V2E_FTYPE), gtir.Sym(id="edges", type=EFTYPE), gtir.Sym(id="vertices", type=VFTYPE), gtir.Sym(id="nvertices", type=SIZE_TYPE), @@ -1399,15 +1489,13 @@ def test_gtir_reduce_with_cond_neighbors(): gtir.SetAt( expr=im.as_fieldop( im.lambda_("it")( - im.call(im.call("reduce")("plus", im.literal_from_value(init_value)))( - im.deref("it") - ) + im.reduce("plus", im.literal_from_value(init_value))(im.deref("it")) ), vertex_domain, )( im.if_( - gtir.SymRef(id="pred"), - im.as_fieldop_neighbors("V2E_FULL", "edges", vertex_domain), + "pred", + "v2e_field", im.as_fieldop_neighbors("V2E", "edges", vertex_domain), ) ), @@ -1417,53 +1505,134 @@ def test_gtir_reduce_with_cond_neighbors(): ], ) - connectivity_V2E_simple = SIMPLE_MESH_OFFSET_PROVIDER["V2E"] - assert isinstance(connectivity_V2E_simple, gtx_common.NeighborTable) - connectivity_V2E_skip_values = copy.deepcopy(SKIP_VALUE_MESH_OFFSET_PROVIDER["V2E"]) - assert isinstance(connectivity_V2E_skip_values, gtx_common.NeighborTable) - assert SKIP_VALUE_MESH.num_vertices <= SIMPLE_MESH.num_vertices - connectivity_V2E_skip_values.table = np.concatenate( - ( - connectivity_V2E_skip_values.table[:, 0 : connectivity_V2E_simple.max_neighbors], - connectivity_V2E_simple.table[SKIP_VALUE_MESH.num_vertices :, :], - ), - axis=0, - ) - connectivity_V2E_skip_values.max_neighbors = connectivity_V2E_simple.max_neighbors + connectivity_V2E = 
SKIP_VALUE_MESH.offset_provider["V2E"] + assert isinstance(connectivity_V2E, gtx_common.NeighborTable) - e = np.random.rand(SIMPLE_MESH.num_edges) + v2e_field = np.random.rand(*connectivity_V2E.shape) + e = np.random.rand(SKIP_VALUE_MESH.num_edges) - for use_full in [False, True]: - sdfg = dace_backend.build_sdfg_from_gtir( - testee, - SIMPLE_MESH_OFFSET_PROVIDER | {"V2E_FULL": connectivity_V2E_skip_values}, - ) + for use_sparse in [False, True]: + sdfg = build_dace_sdfg(testee, SKIP_VALUE_MESH.offset_provider_type) - v = np.empty(SIMPLE_MESH.num_vertices, dtype=e.dtype) + v = np.empty(SKIP_VALUE_MESH.num_vertices, dtype=e.dtype) v_ref = [ functools.reduce( - lambda x, y: x + y, [e[i] if i != -1 else 0.0 for i in v2e_neighbors], init_value + lambda x, y: x + y, + [ + v if i != gtx_common._DEFAULT_SKIP_VALUE else 0.0 + for i, v in zip(v2e_neighbors, v2e_values, strict=True) + ], + init_value, ) - for v2e_neighbors in ( - connectivity_V2E_simple.table if use_full else connectivity_V2E_skip_values.table + if use_sparse + else functools.reduce( + lambda x, y: x + y, + [e[i] if i != gtx_common._DEFAULT_SKIP_VALUE else 0.0 for i in v2e_neighbors], + init_value, ) + for v2e_neighbors, v2e_values in zip(connectivity_V2E.ndarray, v2e_field, strict=True) ] sdfg( - np.bool_(use_full), + np.bool_(use_sparse), + v2e_field, e, v, - connectivity_V2E=connectivity_V2E_skip_values.table, - connectivity_V2E_FULL=connectivity_V2E_simple.table, + connectivity_V2E=connectivity_V2E.ndarray, **FSYMBOLS, - **make_mesh_symbols(SIMPLE_MESH), - __connectivity_V2E_FULL_size_0=SIMPLE_MESH.num_edges, - __connectivity_V2E_FULL_size_1=connectivity_V2E_skip_values.max_neighbors, - __connectivity_V2E_FULL_stride_0=connectivity_V2E_skip_values.max_neighbors, - __connectivity_V2E_FULL_stride_1=1, + **make_mesh_symbols(SKIP_VALUE_MESH), + __v2e_field_0_range_1=SKIP_VALUE_MESH.num_vertices, + __v2e_field_size_1=connectivity_V2E.shape[1], + __v2e_field_stride_0=connectivity_V2E.shape[1], + 
__v2e_field_stride_1=1, ) assert np.allclose(v, v_ref) +def test_gtir_symbolic_domain(): + MARGIN = 2 + assert MARGIN < N + OFFSET = 1000 * 1000 * 1000 + domain = im.domain( + gtx_common.GridType.CARTESIAN, ranges={IDim: (MARGIN, im.minus("size", MARGIN))} + ) + left_domain = im.domain( + gtx_common.GridType.CARTESIAN, + ranges={IDim: (im.minus(MARGIN, OFFSET), im.minus(im.minus("size", MARGIN), OFFSET))}, + ) + right_domain = im.domain( + gtx_common.GridType.CARTESIAN, + ranges={IDim: (im.plus(MARGIN, OFFSET), im.plus(im.plus("size", MARGIN), OFFSET))}, + ) + shift_left_stencil = im.lambda_("a")(im.deref(im.shift(IDim.value, OFFSET)("a"))) + shift_right_stencil = im.lambda_("a")(im.deref(im.shift(IDim.value, -OFFSET)("a"))) + testee = gtir.Program( + id="symbolic_domain", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let( + "xᐞ1", + im.op_as_fieldop("multiplies", left_domain)( + 4.0, + im.as_fieldop( + shift_left_stencil, + left_domain, + )("x"), + ), + )( + im.let( + "xᐞ2", + im.op_as_fieldop("multiplies", right_domain)( + 3.0, + im.as_fieldop( + shift_right_stencil, + right_domain, + )("x"), + ), + )( + im.let( + "xᐞ3", + im.as_fieldop( + shift_right_stencil, + domain, + )("xᐞ1"), + )( + im.let( + "xᐞ4", + im.as_fieldop( + shift_left_stencil, + domain, + )("xᐞ2"), + )( + im.let("xᐞ5", im.op_as_fieldop("plus", domain)("xᐞ3", "xᐞ4"))( + im.op_as_fieldop("plus", domain)("xᐞ5", "x") + ) + ) + ) + ) + ), + domain=domain, + target=gtir.SymRef(id="y"), + ) + ], + ) + + a = np.random.rand(N) + b = np.random.rand(N) + ref = np.concatenate((b[0:MARGIN], a[MARGIN : N - MARGIN] * 8, b[N - MARGIN : N])) + + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) + + sdfg(a, b, **FSYMBOLS) + assert np.allclose(b, ref) + + def test_gtir_let_lambda(): domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) 
subdomain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, im.minus("size", 1))}) @@ -1478,20 +1647,20 @@ def test_gtir_let_lambda(): declarations=[], body=[ gtir.SetAt( - # `x1` is a let-lambda expression representing `x * 3` - # `x2` is a let-lambda expression representing `x * 4` - # - note that the let-symbol `x2` is used twice, in a nested let-expression, to test aliasing of the symbol - # `x3` is a let-lambda expression simply accessing `x` field symref - expr=im.let("x1", im.op_as_fieldop("multiplies", subdomain)(3.0, "x"))( + # `xᐞ1` is a let-lambda expression representing `x * 3` + # `xᐞ2` is a let-lambda expression representing `x * 4` + # - note that the let-symbol `xᐞ2` is used twice, in a nested let-expression, to test aliasing of the symbol + # `xᐞ3` is a let-lambda expression simply accessing `x` field symref + expr=im.let("xᐞ1", im.op_as_fieldop("multiplies", subdomain)(3.0, "x"))( im.let( - "x2", - im.let("x2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))( - im.op_as_fieldop("plus", subdomain)("x2", "x2") + "xᐞ2", + im.let("xᐞ2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))( + im.op_as_fieldop("plus", subdomain)("xᐞ2", "xᐞ2") ), )( - im.let("x3", "x")( + im.let("xᐞ3", "x")( im.op_as_fieldop("plus", subdomain)( - "x1", im.op_as_fieldop("plus", subdomain)("x2", "x3") + "xᐞ1", im.op_as_fieldop("plus", subdomain)("xᐞ2", "xᐞ3") ) ) ) @@ -1506,20 +1675,55 @@ def test_gtir_let_lambda(): b = np.random.rand(N) ref = np.concatenate((b[0:1], a[1 : N - 1] * 8, b[N - 1 : N])) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) sdfg(a, b, **FSYMBOLS) assert np.allclose(b, ref) +def test_gtir_let_lambda_scalar_expression(): + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + testee = gtir.Program( + id="let_lambda_scalar_expression", + function_definitions=[], + params=[ + gtir.Sym(id="a", type=IFTYPE.dtype), + gtir.Sym(id="b", type=IFTYPE.dtype), + gtir.Sym(id="x", 
type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let("tmp", im.multiplies_("a", "b"))( + im.op_as_fieldop("multiplies", domain)("x", im.multiplies_("tmp", "tmp")) + ), + domain=domain, + target=gtir.SymRef(id="y"), + ) + ], + ) + + a = np.random.rand() + b = np.random.rand() + c = np.random.rand(N) + d = np.empty_like(c) + + sdfg = build_dace_sdfg(testee, {}) + + sdfg(a, b, c, d, **FSYMBOLS) + assert np.allclose(d, (a * a * b * b * c)) + + def test_gtir_let_lambda_with_connectivity(): C2E_neighbor_idx = 1 C2V_neighbor_idx = 2 cell_domain = im.domain(gtx_common.GridType.UNSTRUCTURED, ranges={Cell: (0, "ncells")}) - connectivity_C2E = SIMPLE_MESH_OFFSET_PROVIDER["C2E"] + connectivity_C2E = SIMPLE_MESH.offset_provider["C2E"] assert isinstance(connectivity_C2E, gtx_common.NeighborTable) - connectivity_C2V = SIMPLE_MESH_OFFSET_PROVIDER["C2V"] + connectivity_C2V = SIMPLE_MESH.offset_provider["C2V"] assert isinstance(connectivity_C2V, gtx_common.NeighborTable) testee = gtir.Program( @@ -1555,22 +1759,22 @@ def test_gtir_let_lambda_with_connectivity(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, SIMPLE_MESH_OFFSET_PROVIDER) + sdfg = build_dace_sdfg(testee, SIMPLE_MESH.offset_provider_type) e = np.random.rand(SIMPLE_MESH.num_edges) v = np.random.rand(SIMPLE_MESH.num_vertices) c = np.empty(SIMPLE_MESH.num_cells) ref = ( - e[connectivity_C2E.table[:, C2E_neighbor_idx]] - + v[connectivity_C2V.table[:, C2V_neighbor_idx]] + e[connectivity_C2E.ndarray[:, C2E_neighbor_idx]] + + v[connectivity_C2V.ndarray[:, C2V_neighbor_idx]] ) sdfg( cells=c, edges=e, vertices=v, - connectivity_C2E=connectivity_C2E.table, - connectivity_C2V=connectivity_C2V.table, + connectivity_C2E=connectivity_C2E.ndarray, + connectivity_C2V=connectivity_C2V.ndarray, **FSYMBOLS, **make_mesh_symbols(SIMPLE_MESH), ) @@ -1593,11 +1797,7 @@ def test_gtir_let_lambda_with_cond(): gtir.SetAt( expr=im.let("x1", 
"x")( im.let("x2", im.op_as_fieldop("multiplies", domain)(2.0, "x"))( - im.if_( - gtir.SymRef(id="pred"), - im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x1"), - im.as_fieldop(im.lambda_("a")(im.deref("a")), domain)("x2"), - ) + im.if_("pred", "x1", "x2") ) ), domain=domain, @@ -1606,7 +1806,7 @@ def test_gtir_let_lambda_with_cond(): ], ) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) + sdfg = build_dace_sdfg(testee, {}) a = np.random.rand(N) for s in [False, True]: @@ -1615,10 +1815,10 @@ def test_gtir_let_lambda_with_cond(): assert np.allclose(b, a if s else a * 2) -def test_gtir_let_lambda_with_tuple(): - domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) +def test_gtir_let_lambda_with_tuple1(): + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (1, im.minus("size", 1))}) testee = gtir.Program( - id="let_lambda_with_tuple", + id="let_lambda_with_tuple1", function_definitions=[], params=[ gtir.Sym(id="x", type=IFTYPE), @@ -1644,19 +1844,72 @@ def test_gtir_let_lambda_with_tuple(): a = np.random.rand(N) b = np.random.rand(N) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) + + z_fields = (np.zeros_like(a), np.zeros_like(a)) + a_ref = np.concatenate((z_fields[0][:1], a[1 : N - 1], z_fields[0][N - 1 :])) + b_ref = np.concatenate((z_fields[1][:1], b[1 : N - 1], z_fields[1][N - 1 :])) + + tuple_symbols = { + "__z_0_0_range_1": N, + "__z_0_stride_0": 1, + "__z_1_0_range_1": N, + "__z_1_stride_0": 1, + } - z_fields = (np.empty_like(a), np.empty_like(a)) - z_symbols = dict( - __z_0_size_0=FSYMBOLS["__x_size_0"], - __z_0_stride_0=FSYMBOLS["__x_stride_0"], - __z_1_size_0=FSYMBOLS["__x_size_0"], - __z_1_stride_0=FSYMBOLS["__x_stride_0"], + sdfg(a, b, *z_fields, **FSYMBOLS, **tuple_symbols) + assert np.allclose(z_fields[0], a_ref) + assert np.allclose(z_fields[1], b_ref) + + +def test_gtir_let_lambda_with_tuple2(): + domain = 
im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + val = np.random.rand() + testee = gtir.Program( + id="let_lambda_with_tuple2", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=IFTYPE), + gtir.Sym(id="y", type=IFTYPE), + gtir.Sym(id="z", type=ts.TupleType(types=[IFTYPE, IFTYPE, IFTYPE])), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let("s", im.as_fieldop("deref", domain)(val))( + im.let("t", im.make_tuple("x", "y"))( + im.let("p", im.op_as_fieldop("plus", domain)("x", "y"))( + im.make_tuple("p", "s", im.tuple_get(1, "t")) + ) + ) + ), + domain=domain, + target=gtir.SymRef(id="z"), + ) + ], ) - sdfg(a, b, *z_fields, **FSYMBOLS, **z_symbols) - assert np.allclose(z_fields[0], a) - assert np.allclose(z_fields[1], b) + a = np.random.rand(N) + b = np.random.rand(N) + + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) + + z_fields = (np.empty_like(a), np.empty_like(a), np.empty_like(a)) + + tuple_symbols = { + "__z_0_0_range_1": N, + "__z_0_stride_0": 1, + "__z_1_0_range_1": N, + "__z_1_stride_0": 1, + "__z_2_0_range_1": N, + "__z_2_stride_0": 1, + } + + sdfg(a, b, *z_fields, **FSYMBOLS, **tuple_symbols) + assert np.allclose(z_fields[0], a + b) + assert np.allclose(z_fields[1], val) + assert np.allclose(z_fields[2], b) def test_gtir_if_scalars(): @@ -1677,15 +1930,15 @@ def test_gtir_if_scalars(): body=[ gtir.SetAt( expr=im.let("f", im.tuple_get(0, "x"))( - im.let("y", im.tuple_get(1, "x"))( - im.let("y_0", im.tuple_get(0, "y"))( - im.let("y_1", im.tuple_get(1, "y"))( + im.let("g", im.tuple_get(1, "x"))( + im.let("y_0", im.tuple_get(0, "g"))( + im.let("y_1", im.tuple_get(1, "g"))( im.op_as_fieldop("plus", domain)( "f", im.if_( "pred", - im.call("cast_")("y_0", "float64"), - im.call("cast_")("y_1", "float64"), + im.cast_("y_0", "float64"), + im.cast_("y_1", "float64"), ), ) ) @@ -1703,14 +1956,19 @@ def test_gtir_if_scalars(): d1 = np.random.randint(0, 1000) d2 = np.random.randint(0, 
1000) - sdfg = dace_backend.build_sdfg_from_gtir(testee, {}) - x_symbols = dict( - __x_0_size_0=FSYMBOLS["__x_size_0"], - __x_0_stride_0=FSYMBOLS["__x_stride_0"], - ) + sdfg = build_dace_sdfg(testee, {}) + + tuple_symbols = { + "__x_0_0_range_1": N, + "__x_0_stride_0": 1, + "__x_1_0_0_range_1": N, + "__x_1_0_stride_0": 1, + "__x_1_1_0_range_1": N, + "__x_1_1_stride_0": 1, + } for s in [False, True]: - sdfg(x_0=a, x_1_0=d1, x_1_1=d2, z=b, pred=np.bool_(s), **FSYMBOLS, **x_symbols) + sdfg(x_0=a, x_1_0=d1, x_1_1=d2, z=b, pred=np.bool_(s), **FSYMBOLS, **tuple_symbols) assert np.allclose(b, (a + d1 if s else a + d2)) @@ -1741,7 +1999,53 @@ def test_gtir_if_values(): b = np.random.rand(N) c = np.empty_like(a) - sdfg = dace_backend.build_sdfg_from_gtir(testee, CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) sdfg(a, b, c, **FSYMBOLS) assert np.allclose(c, np.where(a < b, a, b)) + + +def test_gtir_index(): + MARGIN = 2 + assert MARGIN < N + domain = im.domain(gtx_common.GridType.CARTESIAN, ranges={IDim: (0, "size")}) + subdomain = im.domain( + gtx_common.GridType.CARTESIAN, ranges={IDim: (MARGIN, im.minus("size", MARGIN))} + ) + + testee = gtir.Program( + id="gtir_index", + function_definitions=[], + params=[ + gtir.Sym(id="x", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)), + gtir.Sym(id="size", type=SIZE_TYPE), + ], + declarations=[], + body=[ + gtir.SetAt( + expr=im.let("i", im.index(IDim))( + im.op_as_fieldop("plus", domain)( + "i", + im.as_fieldop( + im.lambda_("a")(im.deref(im.shift(IDim.value, 1)("a"))), subdomain + )("i"), + ) + ), + domain=subdomain, + target=gtir.SymRef(id="x"), + ) + ], + ) + + v = np.zeros(N, dtype=np.int32) + + # we need to run domain inference in order to add the domain annex information to the index node. 
+ testee = infer_domain.infer_program(testee, offset_provider=CARTESIAN_OFFSETS) + sdfg = build_dace_sdfg(testee, CARTESIAN_OFFSETS) + + ref = np.concatenate( + (v[:MARGIN], np.arange(MARGIN, N - MARGIN, dtype=np.int32), v[N - MARGIN :]) + ) + + sdfg(v, **FSYMBOLS) + np.allclose(v, ref) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/__init__.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/__init__.py index 6c3b1060b6..a576665ee3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/__init__.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/__init__.py @@ -9,4 +9,5 @@ import pytest -pytestmark = [pytest.mark.requires_dace, pytest.mark.usefixtures("set_dace_settings")] +#: Attribute defining package-level marks used by a custom pytest hook. +package_pytestmarks = [pytest.mark.usefixtures("common_dace_config")] diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py index e85ef6ad1f..c3455c37cc 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py @@ -11,8 +11,8 @@ import pytest -@pytest.fixture() -def set_dace_settings() -> Generator[None, None, None]: +@pytest.fixture(autouse=True) +def common_dace_config() -> Generator[None, None, None]: """Sets the common DaCe settings for the tests. 
The function will modify the following settings: @@ -24,6 +24,6 @@ def set_dace_settings() -> Generator[None, None, None]: import dace with dace.config.temporary_config(): - dace.Config.set("optimizer", "match_exception", value=False) + dace.Config.set("optimizer", "match_exception", value=True) dace.Config.set("compiler", "allow_view_arguments", value=True) yield diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py new file mode 100644 index 0000000000..8177ea9ae7 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py @@ -0,0 +1,142 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, +) + +from . import util + + +def test_constant_substitution(): + sdfg, nsdfg = _make_sdfg() + + # Ensure that `One` is present. 
+ assert len(sdfg.symbols) == 2 + assert len(nsdfg.sdfg.symbols) == 2 + assert len(nsdfg.symbol_mapping) == 2 + assert "One" in sdfg.symbols + assert "One" in nsdfg.sdfg.symbols + assert "One" in nsdfg.symbol_mapping + assert "One" == str(nsdfg.symbol_mapping["One"]) + assert all(str(desc.strides[1]) == "One" for desc in sdfg.arrays.values()) + assert all(str(desc.strides[1]) == "One" for desc in nsdfg.sdfg.arrays.values()) + assert all(str(desc.strides[0]) == "N" for desc in sdfg.arrays.values()) + assert all(str(desc.strides[0]) == "N" for desc in nsdfg.sdfg.arrays.values()) + assert "One" in sdfg.used_symbols(True) + + # Now replace `One` with 1 + gtx_transformations.gt_substitute_compiletime_symbols(sdfg, {"One": 1}) + + assert len(sdfg.symbols) == 1 + assert len(nsdfg.sdfg.symbols) == 1 + assert len(nsdfg.symbol_mapping) == 1 + assert "One" not in sdfg.symbols + assert "One" not in nsdfg.sdfg.symbols + assert "One" not in nsdfg.symbol_mapping + assert all(desc.strides[1] == 1 and len(desc.strides) == 2 for desc in sdfg.arrays.values()) + assert all( + desc.strides[1] == 1 and len(desc.strides) == 2 for desc in nsdfg.sdfg.arrays.values() + ) + assert all(str(desc.strides[0]) == "N" for desc in sdfg.arrays.values()) + assert all(str(desc.strides[0]) == "N" for desc in nsdfg.sdfg.arrays.values()) + assert "One" not in sdfg.used_symbols(True) + + +def _make_nested_sdfg() -> dace.SDFG: + sdfg = dace.SDFG("nested") + N = dace.symbol(sdfg.add_symbol("N", dace.int32)) + One = dace.symbol(sdfg.add_symbol("One", dace.int32)) + for name in "ABC": + sdfg.add_array( + name=name, + dtype=dace.float64, + shape=(N, N), + strides=(N, One), + transient=False, + ) + state = sdfg.add_state(is_start_block=True) + state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:N", "__i1": "0:N"}, + inputs={ + "__in0": dace.Memlet("A[__i0, __i1]"), + "__in1": dace.Memlet("B[__i0, __i1]"), + }, + code="__out = __in0 + __in1", + outputs={"__out": dace.Memlet("C[__i0, __i1]")}, + 
external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_sdfg() -> tuple[dace.SDFG, dace.nodes.NestedSDFG]: + sdfg = dace.SDFG("outer_sdfg") + N = dace.symbol(sdfg.add_symbol("N", dace.int32)) + One = dace.symbol(sdfg.add_symbol("One", dace.int32)) + for name in "ABCD": + sdfg.add_array( + name=name, + dtype=dace.float64, + shape=(N, N), + strides=(N, One), + transient=False, + ) + sdfg.arrays["C"].transient = True + + first_state: dace.SDFGState = sdfg.add_state(is_start_block=True) + nested_sdfg: dace.SDFG = _make_nested_sdfg() + nsdfg = first_state.add_nested_sdfg( + nested_sdfg, + parent=sdfg, + inputs={"A", "B"}, + outputs={"C"}, + symbol_mapping={"One": "One", "N": "N"}, + ) + first_state.add_edge( + first_state.add_access("A"), + None, + nsdfg, + "A", + dace.Memlet("A[0:N, 0:N]"), + ) + first_state.add_edge( + first_state.add_access("B"), + None, + nsdfg, + "B", + dace.Memlet("B[0:N, 0:N]"), + ) + first_state.add_edge( + nsdfg, + "C", + first_state.add_access("C"), + None, + dace.Memlet("C[0:N, 0:N]"), + ) + + second_state: dace.SDFGState = sdfg.add_state_after(first_state) + second_state.add_mapped_tasklet( + "outer_computation", + map_ranges={"__i0": "0:N", "__i1": "0:N"}, + inputs={ + "__in0": dace.Memlet("A[__i0, __i1]"), + "__in1": dace.Memlet("C[__i0, __i1]"), + }, + code="__out = __in0 * __in1", + outputs={"__out": dace.Memlet("D[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg, nsdfg diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py new file mode 100644 index 0000000000..8251352e49 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py @@ -0,0 +1,553 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 
2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import pytest +import copy +import numpy as np + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, +) + +from . import util + + +def _make_simple_linear_chain_sdfg() -> dace.SDFG: + """Creates a simple linear chain. + + All intermediates have the same size. + """ + sdfg = dace.SDFG(util.unique_name("simple_linear_chain_sdfg")) + + for name in ["a", "b", "c", "d", "e"]: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=True, + ) + sdfg.arrays["a"].transient = False + sdfg.arrays["e"].transient = False + + state = sdfg.add_state(is_start_block=True) + b, c, d, e = (state.add_access(name) for name in "bcde") + + state.add_mapped_tasklet( + "comp1", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in", + outputs={"__out": dace.Memlet("b[__i]")}, + output_nodes={b}, + external_edges=True, + ) + state.add_nedge(b, c, dace.Memlet("b[0:10] -> [0:10]")) + state.add_nedge(c, d, dace.Memlet("c[0:10] -> [0:10]")) + state.add_nedge(d, e, dace.Memlet("d[0:10] -> [0:10]")) + sdfg.validate() + return sdfg + + +def _make_diff_sizes_linear_chain_sdfg() -> ( + tuple[dace.SDFG, dace.SDFGState, dace_nodes.AccessNode, dace_nodes.Tasklet] +): + """Creates a linear chain of copies. + + The main differences compared to the SDFG made by `_make_simple_linear_chain_sdfg()` + is that here the intermediate arrays have different sizes, that become bigger. + It essentially checks the adjusting of the memlet subset during copying. + + The function returns a tuple with the following content. + - The SDFG that was generated. + - The SDFG state. + - The AccessNode that is used as final output, refers to `e`. 
+ - The Tasklet that is within the Map. + """ + sdfg = dace.SDFG(util.unique_name("diff_size_linear_chain_sdfg")) + + array_size_increment = 10 + array_size = 10 + for name in ["a", "b", "c", "d", "e"]: + sdfg.add_array( + name, + shape=(array_size,), + dtype=dace.float64, + transient=True, + ) + array_size += array_size_increment + sdfg.arrays["a"].transient = False + sdfg.arrays["e"].transient = False + assert sdfg.arrays["e"].shape[0] == 50 + + state = sdfg.add_state(is_start_block=True) + b, c, d, e = (state.add_access(name) for name in "bcde") + + tasklet, _, _ = state.add_mapped_tasklet( + "comp1", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in", + outputs={"__out": dace.Memlet("b[__i + 3]")}, + output_nodes={b}, + external_edges=True, + ) + state.add_nedge(b, c, dace.Memlet("b[0:20] -> [10:30]")) + state.add_nedge(c, d, dace.Memlet("c[0:30] -> [2:32]")) + state.add_nedge(d, e, dace.Memlet("d[0:40] -> [3:43]")) + sdfg.validate() + return sdfg, state, e, tasklet + + +def _make_multi_stage_reduction_sdfg() -> dace.SDFG: + """Creates an SDFG that has a two stage copy reduction.""" + sdfg = dace.SDFG(util.unique_name("multi_stage_reduction")) + state: dace.SDFGState = sdfg.add_state(is_start_block=True) + + # This is the size of the arrays, if not mentioned here, then its size is 10. 
+ array_sizes: dict[str, int] = {"d": 20, "f": 40, "o1": 40} + def_array_size = 10 + + array_names: list[str] = ["i1", "i2", "i3", "i4", "a", "b", "c", "d", "e", "f", "o1"] + for name in array_names: + sdfg.add_array( + name, + shape=(array_sizes.get(name, def_array_size),), + dtype=dace.float64, + transient=(len(name) == 1), + ) + + a, b, c, d, e, f = (state.add_access(name) for name in "abcdef") + + state.add_mapped_tasklet( + "comp_i1", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("i1[__i]")}, + code="__out = __in + 1.0", + outputs={"__out": dace.Memlet("a[__i]")}, + output_nodes={a}, + external_edges=True, + ) + state.add_mapped_tasklet( + "comp_i2", + map_ranges={"__j": "0:10"}, + inputs={"__in": dace.Memlet("i2[__j]")}, + code="__out = __in + 2.", + outputs={"__out": dace.Memlet("b[__j]")}, + output_nodes={b}, + external_edges=True, + ) + state.add_mapped_tasklet( + "comp_i3", + map_ranges={"__k": "0:10"}, + inputs={"__in": dace.Memlet("i3[__k]")}, + code="__out = __in + 3.", + outputs={"__out": dace.Memlet("c[__k]")}, + output_nodes={c}, + external_edges=True, + ) + + state.add_nedge(state.add_access("i4"), e, dace.Memlet("i4[0:10] -> [0:10]")) + + state.add_nedge(b, d, dace.Memlet("b[0:10] -> [0:10]")) + state.add_nedge(c, d, dace.Memlet("c[0:10] -> [10:20]")) + + state.add_nedge(a, f, dace.Memlet("a[0:10] -> [0:10]")) + state.add_nedge(d, f, dace.Memlet("d[0:20] -> [10:30]")) + state.add_nedge(e, f, dace.Memlet("e[0:10] -> [30:40]")) + + state.add_nedge(f, state.add_access("o1"), dace.Memlet("f[0:40] -> [0:40]")) + + sdfg.validate() + return sdfg + + +def _make_not_fully_copied() -> dace.SDFG: + """ + Make an SDFG where two intermediate array is not fully copied. Thus the + transformation only applies once, when `d` is removed. 
+ """ + sdfg = dace.SDFG(util.unique_name("not_fully_copied_intermediate")) + + for name in ["a", "b", "c", "d", "e"]: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=True, + ) + sdfg.arrays["a"].transient = False + sdfg.arrays["e"].transient = False + + state = sdfg.add_state(is_start_block=True) + b, c, d, e = (state.add_access(name) for name in "bcde") + + state.add_mapped_tasklet( + "comp1", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in", + outputs={"__out": dace.Memlet("b[__i]")}, + output_nodes={b}, + external_edges=True, + ) + state.add_nedge(b, c, dace.Memlet("b[2:10] -> [0:8]")) + state.add_nedge(c, d, dace.Memlet("c[0:8] -> [0:8]")) + state.add_nedge(d, e, dace.Memlet("d[0:10] -> [0:10]")) + sdfg.validate() + return sdfg + + +def _make_possible_cyclic_sdfg() -> dace.SDFG: + """ + If the transformation would remove `a1` then it would create a cycle. Thus the + transformation should not apply. + """ + sdfg = dace.SDFG(util.unique_name("possible_cyclic_sdfg")) + + anames = ["i1", "a1", "a2", "o1"] + for name in anames: + sdfg.add_array( + name, + shape=((30,) if name in ["o1", "a2"] else (10,)), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["a1"].transient = True + sdfg.arrays["a2"].transient = True + + state = sdfg.add_state(is_start_block=True) + i1, a1, a2, o1 = (state.add_access(name) for name in anames) + + state.add_mapped_tasklet( + "comp1", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("i1[__i]")}, + code="__out = __in + 1", + outputs={"__out": dace.Memlet("a2[__i]")}, + input_nodes={i1}, + output_nodes={a2}, + external_edges=True, + ) + + state.add_nedge(i1, a1, dace.Memlet("i1[0:10] -> [0:10]")) + state.add_nedge(a1, a2, dace.Memlet("a1[0:10] -> [10:20]")) + + state.add_mapped_tasklet( + "comp2", + map_ranges={"__j": "0:10"}, + inputs={"__in": dace.Memlet("a1[__j]")}, + code="__out = __in + 2.0", + outputs={"__out": dace.Memlet("a2[__j + 20]")}, + 
input_nodes={a1}, + output_nodes={a2}, + external_edges=True, + ) + + state.add_nedge(a2, o1, dace.Memlet("a2[0:30] -> [0:30]")) + + sdfg.validate() + return sdfg + + +def _make_linear_chain_with_nested_sdfg_sdfg() -> tuple[dace.SDFG, dace.SDFG]: + """ + The structure is very similar than `_make_diff_sizes_linear_chain_sdfg()`, with + the main difference that the Map is inside a NestedSDFG. + """ + + def make_inner_sdfg() -> dace.SDFG: + inner_sdfg = dace.SDFG("inner_sdfg") + inner_state = inner_sdfg.add_state(is_start_block=True) + for name in ["i0", "o0"]: + inner_sdfg.add_array(name=name, shape=(10, 10), dtype=dace.float64, transient=False) + inner_state.add_mapped_tasklet( + "inner_comp", + map_ranges={ + "__i0": "0:10", + "__i1": "0:10", + }, + inputs={"__in": dace.Memlet("i0[__i0, __i1]")}, + code="__out = __in + 10.", + outputs={"__out": dace.Memlet("o0[__i0, __i1]")}, + external_edges=True, + ) + inner_sdfg.validate() + return inner_sdfg + + inner_sdfg = make_inner_sdfg() + + sdfg = dace.SDFG(util.unique_name("linear_chain_with_nested_sdfg")) + state = sdfg.add_state(is_start_block=True) + + array_size_increment = 10 + array_size = 10 + for name in ["a", "b", "c", "d", "e"]: + sdfg.add_array( + name, + shape=(array_size, array_size), + dtype=dace.float64, + transient=True, + ) + if name != "a": + array_size += array_size_increment + assert sdfg.arrays["a"].shape == sdfg.arrays["b"].shape + assert sdfg.arrays["e"].shape == (40, 40) + sdfg.arrays["a"].transient = False + sdfg.arrays["e"].transient = False + a, b, c, d, e = (state.add_access(name) for name in "abcde") + + nsdfg = state.add_nested_sdfg( + inner_sdfg, + parent=sdfg, + inputs={"i0"}, + outputs={"o0"}, + symbol_mapping={}, + ) + + state.add_edge(a, None, nsdfg, "i0", sdfg.make_array_memlet("a")) + state.add_edge(nsdfg, "o0", b, None, sdfg.make_array_memlet("b")) + + state.add_nedge(b, c, dace.Memlet("b[0:10, 0:10] -> [5:15, 3:13]")) + state.add_nedge(c, d, dace.Memlet("c[0:20, 0:20] -> [2:22, 
6:26]")) + state.add_nedge(d, e, dace.Memlet("d[0:30, 0:30] -> [1:31, 8:38]")) + sdfg.validate() + return sdfg, inner_sdfg + + +def _make_a1_has_output_sdfg() -> dace.SDFG: + """Here `a1` has an output degree of 2, one to `a2` and one to another output.""" + sdfg = dace.SDFG(util.unique_name("a1_has_an_additional_output_sdfg")) + state = sdfg.add_state(is_start_block=True) + + # All other arrays have a size of 10. + anames = ["i1", "i2", "i3", "a1", "a2", "o1", "o2"] + def_array_size = 10 + asizes = {"a1": 20, "a2": 30, "o2": 30} + for name in anames: + sdfg.add_array( + name=name, + shape=(asizes.get(name, def_array_size),), + dtype=dace.float64, + transient=name[0] == "a", + ) + a1, a2 = (state.add_access("a1"), state.add_access("a2")) + + state.add_nedge(state.add_access("i1"), a1, dace.Memlet("i1[0:10] -> [0:10]")) + state.add_nedge(state.add_access("i2"), a1, dace.Memlet("i2[0:10] -> [10:20]")) + + state.add_nedge(a1, state.add_access("o1"), dace.Memlet("a1[5:15] -> [0:10]")) + + state.add_nedge(state.add_access("i3"), a2, dace.Memlet("i3[0:10] -> [0:10]")) + state.add_nedge(a1, a2, dace.Memlet("a1[0:20] -> [10:30]")) + + state.add_nedge(a2, state.add_access("o2"), dace.Memlet("a2[0:30] -> [0:30]")) + + sdfg.validate() + return sdfg + + +def test_simple_linear_chain(): + sdfg = _make_simple_linear_chain_sdfg() + + nb_applies = gtx_transformations.gt_remove_copy_chain(sdfg, validate_all=True) + + acnodes: list[dace_nodes.AccessNode] = util.count_nodes( + sdfg, dace_nodes.AccessNode, return_nodes=True + ) + assert len(acnodes) == 2 + assert not any(ac.desc(sdfg).transient for ac in acnodes) + assert nb_applies == 3 + + +def test_diff_size_linear_chain(): + sdfg, state, output, tasklet = _make_diff_sizes_linear_chain_sdfg() + + nb_applies = gtx_transformations.gt_remove_copy_chain(sdfg, validate_all=True) + + acnodes: list[dace_nodes.AccessNode] = util.count_nodes( + sdfg, dace_nodes.AccessNode, return_nodes=True + ) + assert len(acnodes) == 2 + assert not 
any(ac.desc(sdfg).transient for ac in acnodes) + assert nb_applies == 3 + assert output in acnodes + assert state.in_degree(output) == 1 + assert state.out_degree(output) == 0 + + # Look if the subsets were correctly adapted, for that we look at the output + # AccessNode and the tasklet inside the map. + output_memlet: dace.Memlet = next(iter(state.in_edges(output))).data + assert output_memlet.dst_subset.min_element()[0] == 18 + assert output_memlet.dst_subset.max_element()[0] == 27 + + tasklet_memlet: dace.Memlet = next(iter(state.out_edges(tasklet))).data + assert str(tasklet_memlet.subset[0][0] - 18).strip() == "__i" + + +def test_multi_stage_reduction(): + sdfg = _make_multi_stage_reduction_sdfg() + + # Make the input + ref = { + "i1": np.array(np.random.rand(10), dtype=np.float64, copy=True), + "i2": np.array(np.random.rand(10), dtype=np.float64, copy=True), + "i3": np.array(np.random.rand(10), dtype=np.float64, copy=True), + "i4": np.array(np.random.rand(10), dtype=np.float64, copy=True), + "o1": np.zeros(40, dtype=np.float64), + } + res = copy.deepcopy(ref) + + # Generate the reference solution. + csdfg_ref = sdfg.compile() + csdfg_ref(**ref) + + # Apply the transformation. + nb_applies = gtx_transformations.gt_remove_copy_chain(sdfg, validate_all=True) + + # Run the processed SDFG + csdfg_res = sdfg.compile() + csdfg_res(**res) + + # Perform all the checks. + acnodes: list[dace_nodes.AccessNode] = util.count_nodes( + sdfg, dace_nodes.AccessNode, return_nodes=True + ) + assert len(acnodes) == 5 + assert not any(ac.desc(sdfg).transient for ac in acnodes) + assert all(np.allclose(ref[name], res[name]) for name in ref.keys()) + + +def test_not_fully_copied(): + sdfg = _make_not_fully_copied() + + # Apply the transformation. + # It will only remove `d` all the others are retained, because they are not read + # correctly, i.e. fully. + nb_applies = gtx_transformations.gt_remove_copy_chain(sdfg, validate_all=True) + + # Perform all the checks. 
+ acnodes: list[dace_nodes.AccessNode] = util.count_nodes( + sdfg, dace_nodes.AccessNode, return_nodes=True + ) + assert len(acnodes) == 4 + assert nb_applies == 1 + assert "d" not in acnodes + + +def test_possible_cyclic_sdfg(): + sdfg = _make_possible_cyclic_sdfg() + + # Apply the transformation. + # It will not remove `a1`, because it it would and replace it with `a2` then + # the resulting SDFG is cyclic. It will, however, replace `a2` with `o1`. + nb_applies = gtx_transformations.gt_remove_copy_chain(sdfg, validate_all=True) + + # Perform all the checks. + acnodes: list[dace_nodes.AccessNode] = util.count_nodes( + sdfg, dace_nodes.AccessNode, return_nodes=True + ) + assert len(acnodes) == 3 + assert nb_applies == 1 + assert "o1" not in acnodes + + +def test_a1_additional_output(): + sdfg = _make_a1_has_output_sdfg() + + # Make the input + ref = { + "i1": np.array(np.random.rand(10), dtype=np.float64, copy=True), + "i2": np.array(np.random.rand(10), dtype=np.float64, copy=True), + "i3": np.array(np.random.rand(10), dtype=np.float64, copy=True), + "o1": np.zeros(10, dtype=np.float64), + "o2": np.zeros(30, dtype=np.float64), + } + res = copy.deepcopy(ref) + + csdfg_ref = sdfg.compile() + csdfg_ref(**ref) + + # Apply the transformation. + # The transformation removes `a1` and `a2`. + nb_applies = gtx_transformations.gt_remove_copy_chain(sdfg, validate_all=True) + + # Perform the tests. + acnodes: list[dace_nodes.AccessNode] = util.count_nodes( + sdfg, dace_nodes.AccessNode, return_nodes=True + ) + assert len(acnodes) == 5 + assert nb_applies == 2 + assert not any(acnode.data.startswith("a") for acnode in acnodes) + + # Now run the SDFG, which is essentially to check if the subsets were handled + # correctly. This is especially important for `o1` which is composed of both + # `i1` and `i2`. 
+ csdfg_res = sdfg.compile() + csdfg_res(**res) + assert all(np.allclose(ref[name], res[name]) for name in ref.keys()) + + +def test_linear_chain_with_nested_sdfg(): + sdfg, inner_sdfg = _make_linear_chain_with_nested_sdfg_sdfg() + + # Ensure that the SDFG was constructed in the correct way. + assert inner_sdfg.arrays["i0"].strides == sdfg.arrays["a"].strides + assert inner_sdfg.arrays["o0"].strides == sdfg.arrays["b"].strides + assert inner_sdfg.arrays["i0"].shape == inner_sdfg.arrays["o0"].shape + assert inner_sdfg.arrays["i0"].shape == sdfg.arrays["a"].shape + + def ref_comp(a, e): + def inner_ref(i0, o0): + for i in range(10): + for j in range(10): + o0[i, j] = i0[i, j] + 10 + + b, c, d = np.zeros_like(a), np.zeros((20, 20)), np.zeros((30, 30)) + inner_ref(i0=a, o0=b) + c[5:15, 3:13] = b + d[2:22, 6:26] = c + e[1:31, 8:38] = d + + # Make the input + ref = { + "a": np.array(np.random.rand(10, 10), dtype=np.float64, copy=True), + "e": np.zeros((40, 40), dtype=np.float64), + } + res = copy.deepcopy(ref) + + ref_comp(**ref) + + # Apply the transformation. + # It should remove all non transient arrays. + nb_applies = gtx_transformations.gt_remove_copy_chain(sdfg, validate_all=True) + + # Perform the tests. + acnodes: list[dace_nodes.AccessNode] = util.count_nodes( + sdfg, dace_nodes.AccessNode, return_nodes=True + ) + assert {ac.data for ac in acnodes} == {"a", "e"} + assert util.count_nodes(sdfg, dace_nodes.NestedSDFG) == 1 + + # The shapes should be the same as before. + assert inner_sdfg.arrays["i0"].shape == inner_sdfg.arrays["o0"].shape + assert inner_sdfg.arrays["i0"].shape == sdfg.arrays["a"].shape + + # The strides of `i0` should also be the same as before, but the strides + # of `o0` should now be the same as `e`. + assert inner_sdfg.arrays["i0"].strides == sdfg.arrays["a"].strides + assert inner_sdfg.arrays["o0"].strides == sdfg.arrays["e"].strides + + # Now run the transformed SDFG to see if the same output is generated. 
+ csdfg_res = sdfg.compile() + csdfg_res(**res) + assert all(np.allclose(ref[name], res[name]) for name in ref.keys()) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py new file mode 100644 index 0000000000..88786ee0e3 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py @@ -0,0 +1,239 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import numpy as np +import copy + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes +from dace import data as dace_data + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, +) + +from . 
import util + +import dace + + +def _create_sdfg_double_read_part_1( + sdfg: dace.SDFG, + state: dace.SDFGState, + me: dace.nodes.MapEntry, + mx: dace.nodes.MapExit, + A_in: dace.nodes.AccessNode, + nb: int, +) -> dace.nodes.Tasklet: + tskl = state.add_tasklet( + name=f"tasklet_1", inputs={"__in1"}, outputs={"__out"}, code="__out = __in1 + 1.0" + ) + + state.add_edge(A_in, None, me, f"IN_{nb}", dace.Memlet("A[0:10]")) + state.add_edge(me, f"OUT_{nb}", tskl, "__in1", dace.Memlet("A[__i0]")) + me.add_in_connector(f"IN_{nb}") + me.add_out_connector(f"OUT_{nb}") + + state.add_edge(tskl, "__out", mx, f"IN_{nb}", dace.Memlet("A[__i0]")) + state.add_edge(mx, f"OUT_{nb}", state.add_access("A"), None, dace.Memlet("A[0:10]")) + mx.add_in_connector(f"IN_{nb}") + mx.add_out_connector(f"OUT_{nb}") + + +def _create_sdfg_double_read_part_2( + sdfg: dace.SDFG, + state: dace.SDFGState, + me: dace.nodes.MapEntry, + mx: dace.nodes.MapExit, + A_in: dace.nodes.AccessNode, + nb: int, +) -> dace.nodes.Tasklet: + tskl = state.add_tasklet( + name=f"tasklet_2", inputs={"__in1"}, outputs={"__out"}, code="__out = __in1 + 3.0" + ) + + state.add_edge(A_in, None, me, f"IN_{nb}", dace.Memlet("A[0:10]")) + state.add_edge(me, f"OUT_{nb}", tskl, "__in1", dace.Memlet("A[__i0]")) + me.add_in_connector(f"IN_{nb}") + me.add_out_connector(f"OUT_{nb}") + + state.add_edge(tskl, "__out", mx, f"IN_{nb}", dace.Memlet("B[__i0]")) + state.add_edge(mx, f"OUT_{nb}", state.add_access("B"), None, dace.Memlet("B[0:10]")) + mx.add_in_connector(f"IN_{nb}") + mx.add_out_connector(f"OUT_{nb}") + + +def _create_sdfg_double_read( + version: int, +) -> tuple[dace.SDFG]: + sdfg = dace.SDFG(util.unique_name(f"double_read_version_{version}")) + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + A_in = state.add_access("A") + me, mx = state.add_map("map", ndrange={"__i0": "0:10"}) + + if version == 0: + 
_create_sdfg_double_read_part_1(sdfg, state, me, mx, A_in, 0) + _create_sdfg_double_read_part_2(sdfg, state, me, mx, A_in, 1) + elif version == 1: + _create_sdfg_double_read_part_1(sdfg, state, me, mx, A_in, 1) + _create_sdfg_double_read_part_2(sdfg, state, me, mx, A_in, 0) + else: + raise ValueError(f"Does not know version {version}") + sdfg.validate() + return sdfg + + +def test_local_double_buffering_double_read_sdfg(): + sdfg0 = _create_sdfg_double_read(0) + sdfg1 = _create_sdfg_double_read(1) + args0 = {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in "AB"} + args1 = copy.deepcopy(args0) + + count0 = gtx_transformations.gt_create_local_double_buffering(sdfg0) + assert count0 == 1 + + count1 = gtx_transformations.gt_create_local_double_buffering(sdfg1) + assert count1 == 1 + + sdfg0(**args0) + sdfg1(**args1) + for name in args0: + assert np.allclose(args0[name], args1[name]), f"Failed verification in '{name}'." + + +def test_local_double_buffering_no_connection(): + """There is no direct connection between read and write.""" + sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_connection")) + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + A_in, B, A_out = (state.add_access(name) for name in "ABA") + + comp_tskl, me, mx = state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("A[__i0]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("B[__i0]")}, + input_nodes={A_in}, + output_nodes={B}, + external_edges=True, + ) + + fill_tasklet = state.add_tasklet( + name="fill_tasklet", + inputs=set(), + code="__out = 2.", + outputs={"__out"}, + ) + state.add_nedge(me, fill_tasklet, dace.Memlet()) + state.add_edge(fill_tasklet, "__out", mx, "IN_1", dace.Memlet("A[__i0]")) + state.add_edge(mx, "OUT_1", A_out, None, dace.Memlet("A[0:10]")) + mx.add_in_connector("IN_1") + 
mx.add_out_connector("OUT_1") + sdfg.validate() + + count = gtx_transformations.gt_create_local_double_buffering(sdfg) + assert count == 1 + + # Ensure that a second application of the transformation does not run again. + count_again = gtx_transformations.gt_create_local_double_buffering(sdfg) + assert count_again == 0 + + # Find the newly created access node. + comp_tasklet_producers = [in_edge.src for in_edge in state.in_edges(comp_tskl)] + assert len(comp_tasklet_producers) == 1 + new_double_buffer = comp_tasklet_producers[0] + assert isinstance(new_double_buffer, dace_nodes.AccessNode) + assert not any(new_double_buffer.data == name for name in "AB") + assert isinstance(new_double_buffer.desc(sdfg), dace_data.Scalar) + assert new_double_buffer.desc(sdfg).transient + + # The newly created access node, must have an empty Memlet to the fill tasklet. + read_dependencies = [ + out_edge.dst for out_edge in state.out_edges(new_double_buffer) if out_edge.data.is_empty() + ] + assert len(read_dependencies) == 1 + assert read_dependencies[0] is fill_tasklet + + res = {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in "AB"} + ref = {"A": np.full_like(res["A"], 2.0), "B": res["A"] + 10.0} + sdfg(**res) + for name in res: + assert np.allclose(res[name], ref[name]), f"Failed verification in '{name}'." 
+ + +def test_local_double_buffering_no_apply(): + """Here it does not apply, because are all distinct.""" + sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_apply")) + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("A[__i0]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("B[__i0]")}, + external_edges=True, + ) + sdfg.validate() + + count = gtx_transformations.gt_create_local_double_buffering(sdfg) + assert count == 0 + + +def test_local_double_buffering_already_buffered(): + """It is already buffered.""" + sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_apply")) + state = sdfg.add_state(is_start_block=True) + sdfg.add_array( + "A", + shape=(10,), + dtype=dace.float64, + transient=False, + ) + + tsklt, me, mx = state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("A[__i0]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("A[__i0]")}, + external_edges=True, + ) + + sdfg.add_scalar("tmp", dtype=dace.float64, transient=True) + tmp = state.add_access("tmp") + me_to_tskl_edge = next(iter(state.out_edges(me))) + + state.add_edge(me, me_to_tskl_edge.src_conn, tmp, None, dace.Memlet("A[__i0]")) + state.add_edge(tmp, None, tsklt, "__in1", dace.Memlet("tmp[0]")) + state.remove_edge(me_to_tskl_edge) + sdfg.validate() + + count = gtx_transformations.gt_create_local_double_buffering(sdfg) + assert count == 0 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py new file mode 100644 index 0000000000..8befcf0610 --- /dev/null +++ 
b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py @@ -0,0 +1,389 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import pytest +import numpy as np + +dace = pytest.importorskip("dace") + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, +) + +from . import util + + +def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: + sdfg = dace.SDFG(util.unique_name("distributed_buffer_sdfg")) + + for name in ["a", "b", "tmp"]: + sdfg.add_array(name, shape=(10, 10), dtype=dace.float64, transient=False) + sdfg.arrays["tmp"].transient = True + sdfg.arrays["b"].shape = (100, 100) + + state1: dace.SDFGState = sdfg.add_state(is_start_block=True) + state1.add_mapped_tasklet( + "computation", + map_ranges={"__i1": "0:10", "__i2": "0:10"}, + inputs={"__in": dace.Memlet("a[__i1, __i2]")}, + code="__out = __in + 10.0", + outputs={"__out": dace.Memlet("tmp[__i1, __i2]")}, + external_edges=True, + ) + + state2 = sdfg.add_state_after(state1) + state2_tskl = state2.add_tasklet( + name="empty_blocker_tasklet", + inputs={}, + code="pass", + outputs={"__out"}, + side_effects=True, + ) + state2.add_edge( + state2_tskl, + "__out", + state2.add_access("a"), + None, + dace.Memlet("a[0, 0]"), + ) + + state3 = sdfg.add_state_after(state2) + state3.add_edge( + state3.add_access("tmp"), + None, + state3.add_access("b"), + None, + dace.Memlet("tmp[0:10, 0:10] -> [11:21, 22:32]"), + ) + sdfg.validate() + assert sdfg.number_of_nodes() == 3 + + return sdfg, state1, state3 + + +def test_distributed_buffer_remover(): + sdfg, state1, state3 = _mk_distributed_buffer_sdfg() + assert state1.number_of_nodes() == 5 + assert not any(dnode.data == "b" for 
dnode in state1.data_nodes()) + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert res[sdfg]["DistributedBufferRelocator"][state3] == {"tmp"} + + # Because the final state has now become empty + assert sdfg.number_of_nodes() == 3 + assert state1.number_of_nodes() == 6 + assert any(dnode.data == "b" for dnode in state1.data_nodes()) + assert any(dnode.data == "tmp" for dnode in state1.data_nodes()) + + +def _make_distributed_buffer_global_memory_data_race_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: + sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_race")) + arr_names = ["a", "b", "t"] + for name in arr_names: + sdfg.add_array( + name=name, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["t"].transient = True + + state1 = sdfg.add_state(is_start_block=True) + state2 = sdfg.add_state_after(state1) + + a_state1 = state1.add_access("a") + state1.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("t[__i0, __i1]")}, + input_nodes={a_state1}, + external_edges=True, + ) + state1.add_nedge(a_state1, state1.add_access("b"), dace.Memlet("a[0:10, 0:10]")) + + state2.add_nedge(state2.add_access("t"), state2.add_access("a"), dace.Memlet("t[0:10, 0:10]")) + sdfg.validate() + + return sdfg, state2 + + +def test_distributed_buffer_global_memory_data_race(): + """Tests if the transformation realized that it would create a data race. + + If the transformation would apply, then `a` is read twice, once from two + different branches, whose order of execution is indeterminate. 
+ """ + sdfg, state2 = _make_distributed_buffer_global_memory_data_race_sdfg() + assert state2.number_of_nodes() == 2 + + sdfg.simplify() + assert sdfg.number_of_nodes() == 2 + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert "DistributedBufferRelocator" not in res[sdfg] + assert state2.number_of_nodes() == 2 + + +def _make_distributed_buffer_global_memory_data_race_sdfg2() -> ( + tuple[dace.SDFG, dace.SDFGState, dace.SDFGState] +): + sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_race2_sdfg")) + arr_names = ["a", "b", "t"] + for name in arr_names: + sdfg.add_array( + name=name, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["t"].transient = True + + state1 = sdfg.add_state(is_start_block=True) + state2 = sdfg.add_state_after(state1) + + state1.add_mapped_tasklet( + "computation1", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("t[__i0, __i1]")}, + external_edges=True, + ) + state1.add_mapped_tasklet( + "computation1", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in - 10", + outputs={"__out": dace.Memlet("b[__i0, __i1]")}, + external_edges=True, + ) + state2.add_nedge(state2.add_access("t"), state2.add_access("a"), dace.Memlet("t[0:10, 0:10]")) + sdfg.validate() + + return sdfg, state1, state2 + + +def test_distributed_buffer_global_memory_data_race2(): + """Tests if the transformation realized that it would create a data race. + + Similar situation but now there are two different subgraphs. This is needed + because it is another branch that checks it. 
+ """ + sdfg, state1, state2 = _make_distributed_buffer_global_memory_data_race_sdfg2() + assert state1.number_of_nodes() == 10 + assert state2.number_of_nodes() == 2 + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert "DistributedBufferRelocator" not in res[sdfg] + assert state1.number_of_nodes() == 10 + assert state2.number_of_nodes() == 2 + + +def _make_distributed_buffer_global_memory_data_no_rance() -> tuple[dace.SDFG, dace.SDFGState]: + sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_no_rance_sdfg")) + arr_names = ["a", "t"] + for name in arr_names: + sdfg.add_array( + name=name, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["t"].transient = True + + state1 = sdfg.add_state(is_start_block=True) + state2 = sdfg.add_state_after(state1) + + a_state1 = state1.add_access("a") + state1.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("t[__i0, __i1]")}, + input_nodes={a_state1}, + external_edges=True, + ) + + state2.add_nedge(state2.add_access("t"), state2.add_access("a"), dace.Memlet("t[0:10, 0:10]")) + sdfg.validate() + + return sdfg, state2 + + +def test_distributed_buffer_global_memory_data_no_rance(): + """Transformation applies if there is no data race. + + According to ADR18, pointwise dependencies are fine. This tests checks if the + checks for the read-write conflicts are not too strong. 
+ """ + sdfg, state2 = _make_distributed_buffer_global_memory_data_no_rance() + assert state2.number_of_nodes() == 2 + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert res[sdfg]["DistributedBufferRelocator"][state2] == {"t"} + assert state2.number_of_nodes() == 0 + + +def _make_distributed_buffer_global_memory_data_no_rance2() -> tuple[dace.SDFG, dace.SDFGState]: + sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_no_rance2_sdfg")) + arr_names = ["a", "t"] + for name in arr_names: + sdfg.add_array( + name=name, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["t"].transient = True + + state1 = sdfg.add_state(is_start_block=True) + state2 = sdfg.add_state_after(state1) + + a_state1 = state1.add_access("a") + state1.add_mapped_tasklet( + "computation1", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("a[__i0, __i1]")}, + output_nodes={a_state1}, + external_edges=True, + ) + state1.add_mapped_tasklet( + "computation2", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0, __i1]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("t[__i0, __i1]")}, + input_nodes={a_state1}, + external_edges=True, + ) + + state2.add_nedge(state2.add_access("t"), state2.add_access("a"), dace.Memlet("t[0:10, 0:10]")) + sdfg.validate() + + return sdfg, state2 + + +def test_distributed_buffer_global_memory_data_no_rance2(): + """Transformation applies if there is no data race. + + These dependency is fine, because the access nodes are in a clear serial order. 
+ """ + sdfg, state2 = _make_distributed_buffer_global_memory_data_no_rance2() + assert state2.number_of_nodes() == 2 + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert res[sdfg]["DistributedBufferRelocator"][state2] == {"t"} + assert state2.number_of_nodes() == 0 + + +def _make_distributed_buffer_non_sink_temporary_sdfg() -> ( + tuple[dace.SDFG, dace.SDFGState, dace.SDFGState] +): + sdfg = dace.SDFG(util.unique_name("distributed_buffer_non_sink_temporary_sdfg")) + state = sdfg.add_state(is_start_block=True) + wb_state = sdfg.add_state_after(state) + + names = ["a", "b", "c", "t1", "t2"] + for name in names: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["t1"].transient = True + sdfg.arrays["t2"].transient = True + t1 = state.add_access("t1") + + state.add_mapped_tasklet( + "comp1", + map_ranges={"__i": "0:10"}, + inputs={"__in1": dace.Memlet("a[__i]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("t1[__i]")}, + output_nodes={t1}, + external_edges=True, + ) + state.add_mapped_tasklet( + "comp2", + map_ranges={"__i": "0:10"}, + inputs={"__in1": dace.Memlet("t1[__i]")}, + code="__out = __in1 / 2.0", + outputs={"__out": dace.Memlet("t2[__i]")}, + input_nodes={t1}, + external_edges=True, + ) + + wb_state.add_nedge(wb_state.add_access("t1"), wb_state.add_access("b"), dace.Memlet("t1[0:10]")) + wb_state.add_nedge(wb_state.add_access("t2"), wb_state.add_access("b"), dace.Memlet("t2[0:10]")) + + sdfg.validate() + return sdfg, state, wb_state + + +def test_distributed_buffer_non_sink_temporary(): + """Tests the transformation if one of the temporaries is not a sink node. + + Note that the SDFG has two temporaries, `t1` is not a sink node and `t2` is + a sink node. 
+ """ + sdfg, state, wb_state = _make_distributed_buffer_non_sink_temporary_sdfg() + assert wb_state.number_of_nodes() == 4 + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert res[sdfg]["DistributedBufferRelocator"][wb_state] == {"t1", "t2"} + assert wb_state.number_of_nodes() == 0 + + +def _make_distributed_buffer_conditional_block_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: + sdfg = dace.SDFG("distributed_buffer_conditional_block_sdfg") + + for name in ["a", "b", "c", "t"]: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["t"].transient = True + sdfg.add_symbol("cond", dace.bool_) + + # create states inside the nested SDFG for the if-branches + if_region = dace.sdfg.state.ConditionalBlock("if") + sdfg.add_node(if_region) + entry_state = sdfg.add_state("entry", is_start_block=True) + sdfg.add_edge(entry_state, if_region, dace.InterstateEdge()) + + then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=sdfg) + tstate = then_body.add_state("true_branch", is_start_block=True) + tstate.add_nedge(tstate.add_access("a"), tstate.add_access("t"), dace.Memlet("a[0:10]")) + if_region.add_branch(dace.sdfg.state.CodeBlock("cond"), then_body) + + else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=sdfg) + fstate = else_body.add_state("false_branch", is_start_block=True) + fstate.add_nedge(fstate.add_access("b"), fstate.add_access("t"), dace.Memlet("b[0:10]")) + if_region.add_branch(dace.sdfg.state.CodeBlock("not (cond)"), else_body) + + wb_state = sdfg.add_state_after(if_region) + wb_state.add_nedge(wb_state.add_access("t"), wb_state.add_access("c"), dace.Memlet("t[0:10]")) + sdfg.validate() + return sdfg, wb_state + + +def test_distributed_buffer_conditional_block(): + sdfg, wb_state = _make_distributed_buffer_conditional_block_sdfg() + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert res[sdfg]["DistributedBufferRelocator"][wb_state] == {"t"} diff --git 
a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py index 30266d71d1..350fa807a1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py @@ -14,21 +14,22 @@ dace = pytest.importorskip("dace") from dace.sdfg import nodes as dace_nodes -from gt4py.next.program_processors.runners.dace_fieldview.transformations import ( +from gt4py.next.program_processors.runners.dace.transformations import ( gpu_utils as gtx_dace_fieldview_gpu_utils, ) -from . import pytestmark + from . import util def _get_trivial_gpu_promotable( tasklet_code: str, + trivial_map_range: str = "0", ) -> tuple[dace.SDFG, dace_nodes.MapEntry, dace_nodes.MapEntry]: - """Returns an SDFG that is suitable to test the `TrivialGPUMapPromoter` promoter. + """Returns an SDFG that is suitable to test the `TrivialGPUMapElimination` promoter. The first map is a trivial map (`Map[__trival_gpu_it=0]`) containing a Tasklet, - that does not have an output, but writes a scalar value into `tmp` (output + that does not have an input, but writes a scalar value into `tmp` (output connector `__out`), the body of this Tasklet can be controlled through the `tasklet_code` argument. The second map (`Map[__i0=0:N]`) contains a Tasklet that computes the sum of its @@ -41,6 +42,7 @@ def _get_trivial_gpu_promotable( Args: tasklet_code: The body of the Tasklet inside the trivial map. + trivial_map_range: Range of the trivial map, defaults to `"0"`. 
""" sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) state = sdfg.add_state("state", is_start_block=True) @@ -57,11 +59,11 @@ def _get_trivial_gpu_promotable( _, trivial_map_entry, _ = state.add_mapped_tasklet( "trivail_top_tasklet", - map_ranges={"__trivial_gpu_it": "0"}, + map_ranges={"__trivial_gpu_it": trivial_map_range}, inputs={}, code=tasklet_code, outputs={"__out": dace.Memlet("tmp[0]")}, - output_nodes={"tmp": tmp}, + output_nodes={tmp}, external_edges=True, schedule=schedule, ) @@ -74,15 +76,15 @@ def _get_trivial_gpu_promotable( }, code="__out = __in0 + __in1", outputs={"__out": dace.Memlet("b[__i0]")}, - input_nodes={"a": a, "tmp": tmp}, - output_nodes={"b": b}, + input_nodes={a, tmp}, + output_nodes={b}, external_edges=True, schedule=schedule, ) return sdfg, trivial_map_entry, second_map_entry -def test_trivial_gpu_map_promoter(): +def test_trivial_gpu_map_promoter_1(): """Tests if the GPU map promoter works. By using a body such as `__out = 3.0`, the transformation will apply. @@ -92,15 +94,15 @@ def test_trivial_gpu_map_promoter(): org_second_map_ranges = copy.deepcopy(second_map_entry.map.range) nb_runs = sdfg.apply_transformations_once_everywhere( - gtx_dace_fieldview_gpu_utils.TrivialGPUMapPromoter(), + gtx_dace_fieldview_gpu_utils.TrivialGPUMapElimination(do_not_fuse=True), validate=True, validate_all=True, ) assert ( nb_runs == 1 - ), f"Expected that 'TrivialGPUMapPromoter' applies once but it applied {nb_runs}." + ), f"Expected that 'TrivialGPUMapElimination' applies once but it applied {nb_runs}." 
trivial_map_params = trivial_map_entry.map.params - trivial_map_ranges = trivial_map_ranges.map.range + trivial_map_ranges = trivial_map_entry.map.range second_map_params = second_map_entry.map.params second_map_ranges = second_map_entry.map.range @@ -119,32 +121,82 @@ def test_trivial_gpu_map_promoter(): assert sdfg.is_valid() -def test_trivial_gpu_map_promoter(): +def test_trivial_gpu_map_promoter_2(): """Test if the GPU promoter does not fuse a special trivial map. By using a body such as `__out = __trivial_gpu_it` inside the - Tasklet's body, the map parameter is now used, and thus can not be fused. + Tasklet's body, the map parameter must now be replaced inside + the Tasklet's body. """ sdfg, trivial_map_entry, second_map_entry = _get_trivial_gpu_promotable( - "__out = __trivial_gpu_it" + tasklet_code="__out = __trivial_gpu_it", + trivial_map_range="2", + ) + state: dace.SDFGStae = sdfg.nodes()[0] + trivial_tasklet: dace_nodes.Tasklet = next( + iter( + out_edge.dst + for out_edge in state.out_edges(trivial_map_entry) + if isinstance(out_edge.dst, dace_nodes.Tasklet) + ) ) - org_trivial_map_params = list(trivial_map_entry.map.params) - org_second_map_params = list(second_map_entry.map.params) nb_runs = sdfg.apply_transformations_once_everywhere( - gtx_dace_fieldview_gpu_utils.TrivialGPUMapPromoter(), + gtx_dace_fieldview_gpu_utils.TrivialGPUMapElimination(do_not_fuse=True), validate=True, validate_all=True, ) - assert ( - nb_runs == 0 - ), f"Expected that 'TrivialGPUMapPromoter' does not apply but it applied {nb_runs}." - trivial_map_params = trivial_map_entry.map.params - second_map_params = second_map_entry.map.params - assert ( - trivial_map_params == org_trivial_map_params - ), f"Expected the trivial map to have parameters '{org_trivial_map_params}', but it had '{trivial_map_params}'." - assert ( - second_map_params == org_second_map_params - ), f"Expected the trivial map to have parameters '{org_trivial_map_params}', but it had '{trivial_map_params}'." 
- assert sdfg.is_valid() + assert nb_runs == 1 + + expected_trivial_code = "__out = 2" + assert trivial_tasklet.code == expected_trivial_code + + +def test_set_gpu_properties(): + """Tests the `gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize()`.""" + sdfg = dace.SDFG("gpu_properties_test") + state = sdfg.add_state(is_start_block=True) + + map_entries: dict[int, dace_nodes.MapEntry] = {} + for dim in [1, 2, 3]: + shape = (10,) * dim + sdfg.add_array( + f"A_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global + ) + sdfg.add_array( + f"B_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global + ) + _, me, _ = state.add_mapped_tasklet( + f"map_{dim}", + map_ranges={f"__i{i}": f"0:{s}" for i, s in enumerate(shape)}, + inputs={"__in": dace.Memlet(f"A_{dim}[{','.join(f'__i{i}' for i in range(dim))}]")}, + code="__out = math.cos(__in)", + outputs={"__out": dace.Memlet(f"B_{dim}[{','.join(f'__i{i}' for i in range(dim))}]")}, + external_edges=True, + ) + map_entries[dim] = me + + sdfg.apply_gpu_transformations() + sdfg.validate() + + gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize( + sdfg=sdfg, + block_size=(10, "11", 12), + launch_factor_2d=2, + block_size_2d=(2, 2, 2), + launch_bounds_3d=200, + ) + + map1, map2, map3 = (map_entries[d].map for d in [1, 2, 3]) + + assert len(map1.params) == 1 + assert map1.gpu_block_size == [10, 1, 1] + assert map1.gpu_launch_bounds == "0" + + assert len(map2.params) == 2 + assert map2.gpu_block_size == [2, 2, 1] + assert map2.gpu_launch_bounds == "8" + + assert len(map3.params) == 3 + assert map3.gpu_block_size == [10, 11, 12] + assert map3.gpu_launch_bounds == "200" diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py index c1e0ddd2f6..3b41da6336 100644 --- 
a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py @@ -7,6 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import copy +from enum import Enum from typing import Callable import numpy as np @@ -16,11 +17,11 @@ dace = pytest.importorskip("dace") from dace.sdfg import nodes as dace_nodes, propagation as dace_propagation -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( transformations as gtx_transformations, ) -from . import pytestmark + from . import util @@ -29,7 +30,7 @@ def _get_simple_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], np The k blocking transformation can be applied to the SDFG, however no node can be taken out. This is because how it is constructed. However, applying - some simplistic transformations this can be done. + some simplistic transformations will enable the transformation. """ sdfg = dace.SDFG(util.unique_name("simple_block_sdfg")) state = sdfg.add_state("state", is_start_block=True) @@ -136,6 +137,83 @@ def _get_chained_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], n return sdfg, lambda a, b: (a + (2 * b.reshape((-1, 1)) + 3)) +def _get_sdfg_with_empty_memlet( + first_tasklet_independent: bool, + only_empty_memlets: bool, +) -> tuple[ + dace.SDFG, dace_nodes.MapEntry, dace_nodes.Tasklet, dace_nodes.AccessNode, dace_nodes.Tasklet +]: + """Generates an SDFG with an empty tasklet. + + The map contains two (serial) tasklets, connected through an access node. + The first tasklet has an empty memlet that connects it to the map entry. + Depending on `first_tasklet_independent` the tasklet is either independent + or not. The second tasklet has an additional in connector that accesses an array. 
+ + If `only_empty_memlets` is given then the second memlet will only depend + on the input of the first tasklet. However, since it is connected to the + map exit, it will be classified as dependent. + + Returns: + The function returns the SDFG, the map entry and the first tasklet (that + is either dependent or independent), the access node between the tasklets + and the second tasklet that is always dependent. + """ + sdfg = dace.SDFG(util.unique_name("empty_memlet_sdfg")) + state = sdfg.add_state("state", is_start_block=True) + sdfg.add_symbol("N", dace.int32) + sdfg.add_symbol("M", dace.int32) + sdfg.add_array("b", ("N", "M"), dace.float64, transient=False) + b = state.add_access("b") + sdfg.add_scalar("tmp", dtype=dace.float64, transient=True) + tmp = state.add_access("tmp") + + if not only_empty_memlets: + sdfg.add_array("a", ("N", "M"), dace.float64, transient=False) + a = state.add_access("a") + + # This is the first tasklet. + task1 = state.add_tasklet( + "task1", + inputs={}, + outputs={"__out0"}, + code="__out0 = 1.0" if first_tasklet_independent else "__out0 = j", + ) + + if only_empty_memlets: + task2 = state.add_tasklet( + "task2", inputs={"__in0"}, outputs={"__out0"}, code="__out0 = __in0 + 1.0" + ) + else: + task2 = state.add_tasklet( + "task2", inputs={"__in0", "__in1"}, outputs={"__out0"}, code="__out0 = __in0 + __in1" + ) + + # Now create the map + mentry, mexit = state.add_map("map", ndrange={"i": "0:N", "j": "0:M"}) + + if not only_empty_memlets: + state.add_edge(a, None, mentry, "IN_a", dace.Memlet("a[0:N, 0:M]")) + state.add_edge(mentry, "OUT_a", task2, "__in1", dace.Memlet("a[i, j]")) + + state.add_edge(task2, "__out0", mexit, "IN_b", dace.Memlet("b[i, j]")) + state.add_edge(mexit, "OUT_b", b, None, dace.Memlet("b[0:N, 0:M]")) + + state.add_edge(mentry, None, task1, None, dace.Memlet()) + state.add_edge(task1, "__out0", tmp, None, dace.Memlet("tmp[0]")) + state.add_edge(tmp, None, task2, "__in0", dace.Memlet("tmp[0]")) + + if not 
only_empty_memlets: + mentry.add_in_connector("IN_a") + mentry.add_out_connector("OUT_a") + mexit.add_in_connector("IN_b") + mexit.add_out_connector("OUT_b") + + sdfg.validate() + + return sdfg, mentry, task1, tmp, task2 + + def test_only_dependent(): """Just applying the transformation to the SDFG. @@ -152,11 +230,12 @@ def test_only_dependent(): ref = reff(a, b) # Apply the transformation - sdfg.apply_transformations_repeated( + count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) + assert count == 1 assert len(sdfg.states()) == 1 state = sdfg.states()[0] @@ -216,11 +295,12 @@ def test_intermediate_access_node(): assert np.allclose(ref, c) # Apply the transformation. - sdfg.apply_transformations_repeated( + count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) + assert count == 1 # Inspect if the SDFG was modified correctly. # We only inspect `tmp` which now has to be between the two maps. @@ -254,12 +334,12 @@ def test_chained_access() -> None: c[:] = 0 # Apply the transformation. - ret = sdfg.apply_transformations_repeated( + count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) - assert ret == 1, f"Expected that the transformation was applied 1 time, but it was {ret}." + assert count == 1 # Now run the SDFG to see if it is still the same sdfg(a=a, b=b, c=c, M=M, N=N) @@ -305,3 +385,610 @@ def test_chained_access() -> None: assert isinstance(inner_tasklet, dace_nodes.Tasklet) assert inner_tasklet not in first_level_tasklets + + +def test_direct_map_exit_connection() -> dace.SDFG: + """Generates a SDFG with a mapped independent tasklet connected to the map exit. + + Because the tasklet is connected to the map exit it can not be independent. 
+ """ + sdfg = dace.SDFG(util.unique_name("mapped_tasklet_sdfg")) + state = sdfg.add_state("state", is_start_block=True) + sdfg.add_array("a", (10,), dace.float64, transient=False) + sdfg.add_array("b", (10, 30), dace.float64, transient=False) + tsklt, me, mx = state.add_mapped_tasklet( + name="comp", + map_ranges=dict(i=f"0:10", j=f"0:30"), + inputs=dict(__in0=dace.Memlet("a[i]")), + outputs=dict(__out=dace.Memlet("b[i, j]")), + code="__out = __in0 + 1", + external_edges=True, + ) + + assert all(out_edge.dst is tsklt for out_edge in state.out_edges(me)) + assert all(in_edge.src is tsklt for in_edge in state.in_edges(mx)) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1 + + assert all(isinstance(out_edge.dst, dace_nodes.MapEntry) for out_edge in state.out_edges(me)) + assert all(isinstance(in_edge.src, dace_nodes.MapExit) for in_edge in state.in_edges(mx)) + + +def test_empty_memlet_1(): + sdfg, mentry, itask, tmp, task2 = _get_sdfg_with_empty_memlet( + first_tasklet_independent=True, + only_empty_memlets=False, + ) + state: dace.SDFGState = next(iter(sdfg.nodes())) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1 + + scope_dict = state.scope_dict() + assert scope_dict[mentry] is None + assert scope_dict[itask] is mentry + assert scope_dict[tmp] is mentry + assert scope_dict[task2] is not mentry + assert scope_dict[task2] is not None + assert all( + isinstance(in_edge.src, dace_nodes.MapEntry) and in_edge.src is not mentry + for in_edge in state.in_edges(task2) + ) + + +def test_empty_memlet_2(): + sdfg, mentry, dtask, tmp, task2 = _get_sdfg_with_empty_memlet( + first_tasklet_independent=False, + only_empty_memlets=False, + ) + state: dace.SDFGState = next(iter(sdfg.nodes())) + + count = 
sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1 + + # Find the inner map entry + assert all( + isinstance(out_edge.dst, dace_nodes.MapEntry) for out_edge in state.out_edges(mentry) + ) + inner_mentry = next(iter(state.out_edges(mentry))).dst + + scope_dict = state.scope_dict() + assert scope_dict[mentry] is None + assert scope_dict[inner_mentry] is mentry + assert scope_dict[dtask] is inner_mentry + assert scope_dict[tmp] is inner_mentry + assert scope_dict[task2] is inner_mentry + + +def test_empty_memlet_3(): + # This is the only interesting case with only empty memlet. + sdfg, mentry, dtask, tmp, task2 = _get_sdfg_with_empty_memlet( + first_tasklet_independent=False, + only_empty_memlets=True, + ) + state: dace.SDFGState = next(iter(sdfg.nodes())) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1 + + # The top map only has a single output, which is the empty edge, that is holding + # the inner map entry in the scope. 
+ assert all(out_edge.data.is_empty() for out_edge in state.out_edges(mentry)) + assert state.in_degree(mentry) == 0 + assert state.out_degree(mentry) == 1 + assert all( + isinstance(out_edge.dst, dace_nodes.MapEntry) for out_edge in state.out_edges(mentry) + ) + + inner_mentry = next(iter(state.out_edges(mentry))).dst + + scope_dict = state.scope_dict() + assert scope_dict[mentry] is None + assert scope_dict[inner_mentry] is mentry + assert scope_dict[dtask] is inner_mentry + assert scope_dict[tmp] is inner_mentry + assert scope_dict[task2] is inner_mentry + + +class IndependentPart(Enum): + NONE = 0 + TASKLET = 1 + NESTED_SDFG = 2 + + +def _make_loop_blocking_sdfg_with_inner_map( + add_independent_part: IndependentPart, +) -> tuple[dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapEntry]: + """ + Generate the SDFGs with an inner map. + + The SDFG has an inner map that is classified as dependent. If + `add_independent_part` is `True` then the SDFG has a part that is independent. + Note that everything is read from a single connector. + + Return: + The function will return the SDFG, the state and the map entry for the outer + and inner map. + """ + sdfg = dace.SDFG(util.unique_name("sdfg_with_inner_map")) + state = sdfg.add_state(is_start_block=True) + + for name in "AB": + sdfg.add_array(name, shape=(10, 10), dtype=dace.float64, transient=False) + + me_out, mx_out = state.add_map("outer_map", ndrange={"__i0": "0:10"}) + me_in, mx_in = state.add_map("inner_map", ndrange={"__i1": "0:10"}) + A, B = (state.add_access(name) for name in "AB") + tskl = state.add_tasklet( + "computation", inputs={"__in1", "__in2"}, outputs={"__out"}, code="__out = __in1 + __in2" + ) + + # construct the inner map of the map. 
+ state.add_edge(A, None, me_out, "IN_A", dace.Memlet("A[0:10, 0:10]")) + me_out.add_in_connector("IN_A") + state.add_edge(me_out, "OUT_A", me_in, "IN_A", dace.Memlet("A[__i0, 0:10]")) + me_out.add_out_connector("OUT_A") + me_in.add_in_connector("IN_A") + state.add_edge(me_in, "OUT_A", tskl, "__in1", dace.Memlet("A[__i0, __i1]")) + me_in.add_out_connector("OUT_A") + + state.add_edge(me_out, "OUT_A", me_in, "IN_A1", dace.Memlet("A[__i0, 0:10]")) + me_in.add_in_connector("IN_A1") + state.add_edge(me_in, "OUT_A1", tskl, "__in2", dace.Memlet("A[__i0, 9 - __i1]")) + me_in.add_out_connector("OUT_A1") + + state.add_edge(tskl, "__out", mx_in, "IN_B", dace.Memlet("B[__i0, __i1]")) + mx_in.add_in_connector("IN_B") + state.add_edge(mx_in, "OUT_B", mx_out, "IN_B", dace.Memlet("B[__i0, 0:10]")) + mx_in.add_out_connector("OUT_B") + mx_out.add_in_connector("IN_B") + state.add_edge(mx_out, "OUT_B", B, None, dace.Memlet("B[0:10, 0:10]")) + mx_out.add_out_connector("OUT_B") + + # If requested add a part that is independent, i.e. 
is before the inner loop + if add_independent_part != IndependentPart.NONE: + sdfg.add_array("C", shape=(10,), dtype=dace.float64, transient=False) + sdfg.add_scalar("tmp", dtype=dace.float64, transient=True) + sdfg.add_scalar("tmp2", dtype=dace.float64, transient=True) + tmp, tmp2, C = (state.add_access(name) for name in ("tmp", "tmp2", "C")) + state.add_edge(tmp, None, tmp2, None, dace.Memlet("tmp2[0]")) + state.add_edge(tmp2, None, mx_out, "IN_tmp", dace.Memlet("C[__i0]")) + mx_out.add_in_connector("IN_tmp") + state.add_edge(mx_out, "OUT_tmp", C, None, dace.Memlet("C[0:10]")) + mx_out.add_out_connector("OUT_tmp") + match add_independent_part: + case IndependentPart.TASKLET: + tskli = state.add_tasklet( + "independent_comp", + inputs={"__field"}, + outputs={"__out"}, + code="__out = __field[1, 1]", + ) + state.add_edge(me_out, "OUT_A", tskli, "__field", dace.Memlet("A[0:10, 0:10]")) + state.add_edge(tskli, "__out", tmp, None, dace.Memlet("tmp[0]")) + case IndependentPart.NESTED_SDFG: + nsdfg_sym, nsdfg_inp, nsdfg_out = ("S", "I", "V") + nsdfg = _make_conditional_block_sdfg( + "independent_comp", nsdfg_sym, nsdfg_inp, nsdfg_out + ) + nsdfg_node = state.add_nested_sdfg( + nsdfg, + sdfg, + inputs={nsdfg_inp}, + outputs={nsdfg_out}, + symbol_mapping={nsdfg_sym: 0}, + ) + state.add_edge(me_out, "OUT_A", nsdfg_node, nsdfg_inp, dace.Memlet("A[1, 1]")) + state.add_edge(nsdfg_node, nsdfg_out, tmp, None, dace.Memlet("tmp[0]")) + case _: + raise NotImplementedError() + + sdfg.validate() + return sdfg, state, me_out, me_in + + +def test_loop_blocking_inner_map(): + """ + Tests with an inner map, without an independent part. 
+ """ + sdfg, state, outer_map, inner_map = _make_loop_blocking_sdfg_with_inner_map( + IndependentPart.NONE + ) + assert all(oedge.dst is inner_map for oedge in state.out_edges(outer_map)) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="__i0"), + validate=True, + validate_all=True, + ) + assert count == 1 + assert all( + oedge.dst is not inner_map and isinstance(oedge.dst, dace_nodes.MapEntry) + for oedge in state.out_edges(outer_map) + ) + inner_blocking_map: dace_nodes.MapEntry = next( + oedge.dst + for oedge in state.out_edges(outer_map) + if isinstance(oedge.dst, dace_nodes.MapEntry) + ) + assert inner_blocking_map is not inner_map + + assert all(oedge.dst is inner_map for oedge in state.out_edges(inner_blocking_map)) + + +@pytest.mark.parametrize("independent_part", [IndependentPart.TASKLET, IndependentPart.NESTED_SDFG]) +def test_loop_blocking_inner_map_with_independent_part(independent_part): + """ + Tests with an inner map with an independent part. + """ + sdfg, state, outer_map, inner_map = _make_loop_blocking_sdfg_with_inner_map(independent_part) + + # Find the parts that are independent. 
+ independent_node: dace_nodes.Tasklet | dace_nodes.NestedSDFG = next( + oedge.dst + for oedge in state.out_edges(outer_map) + if isinstance(oedge.dst, (dace_nodes.Tasklet, dace_nodes.NestedSDFG)) + ) + assert independent_node.label == "independent_comp" + i_access_node: dace_nodes.AccessNode = next( + oedge.dst for oedge in state.out_edges(independent_node) + ) + assert i_access_node.data == "tmp" + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="__i0"), + validate=True, + validate_all=True, + ) + assert count == 1 + inner_blocking_map: dace_nodes.MapEntry = next( + oedge.dst + for oedge in state.out_edges(outer_map) + if isinstance(oedge.dst, dace_nodes.MapEntry) + ) + assert inner_blocking_map is not inner_map + + assert all( + oedge.dst in {inner_blocking_map, independent_node} for oedge in state.out_edges(outer_map) + ) + assert state.scope_dict()[i_access_node] is outer_map + assert all(oedge.dst is inner_blocking_map for oedge in state.out_edges(i_access_node)) + + +def _make_mixed_memlet_sdfg( + tskl1_independent: bool, +) -> tuple[dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.Tasklet, dace_nodes.Tasklet]: + """ + Generates the SDFGs for the mixed Memlet tests. + + The SDFG that is generated has the following structure: + - `tsklt2`, is always dependent, it has an incoming connection from the + map entry, and an incoming, but empty, connection with `tskl1`. + - `tskl1` is connected to the map entry, depending on `tskl1_independent` + it is independent or dependent, it has an empty connection to `tskl2`, + thus it is sequenced before. + - Both have connection to other nodes down stream, but they are dependent. + + Returns: + A tuple containing the following objects. + - The SDFG. + - The SDFG state. + - The outer map entry node. + - `tskl1`. + - `tskl2`. 
+ """ + sdfg = dace.SDFG(util.unique_name("mixed_memlet_sdfg")) + state = sdfg.add_state(is_start_block=True) + names_array = ["A", "B", "C"] + names_scalar = ["tmp1", "tmp2"] + for aname in names_array: + sdfg.add_array( + aname, + shape=((10,) if aname == "A" else (10, 10)), + dtype=dace.float64, + transient=False, + ) + for sname in names_scalar: + sdfg.add_scalar( + sname, + dtype=dace.float64, + transient=True, + ) + A, B, C, tmp1, tmp2 = (state.add_access(name) for name in names_array + names_scalar) + + me, mx = state.add_map("outer_map", ndrange={"i": "0:10", "j": "0:10"}) + tskl1 = state.add_tasklet( + "tskl1", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1" if tskl1_independent else "__out = __in1 + j", + ) + tskl2 = state.add_tasklet( + "tskl2", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 + 10.0", + ) + tskl3 = state.add_tasklet( + "tskl3", + inputs={"__in1", "__in2"}, + outputs={"__out"}, + code="__out = __in1 + __in2", + ) + + state.add_edge(A, None, me, "IN_A", dace.Memlet("A[0:10]")) + me.add_in_connector("IN_A") + state.add_edge(me, "OUT_A", tskl1, "__in1", dace.Memlet("A[i]")) + me.add_out_connector("OUT_A") + state.add_edge(tskl1, "__out", tmp1, None, dace.Memlet("tmp1[0]")) + + state.add_edge(B, None, me, "IN_B", dace.Memlet("B[0:10, 0:10]")) + me.add_in_connector("IN_B") + state.add_edge(me, "OUT_B", tskl2, "__in1", dace.Memlet("B[i, j]")) + me.add_out_connector("OUT_B") + state.add_edge(tskl2, "__out", tmp2, None, dace.Memlet("tmp2[0]")) + + # Add the empty Memlet that sequences `tskl1` before `tskl2`. 
+ state.add_edge(tskl1, None, tskl2, None, dace.Memlet()) + + state.add_edge(tmp1, None, tskl3, "__in1", dace.Memlet("tmp1[0]")) + state.add_edge(tmp2, None, tskl3, "__in2", dace.Memlet("tmp2[0]")) + state.add_edge(tskl3, "__out", mx, "IN_C", dace.Memlet("C[i, j]")) + mx.add_in_connector("IN_C") + state.add_edge(mx, "OUT_C", C, None, dace.Memlet("C[0:10, 0:10]")) + mx.add_out_connector("OUT_C") + sdfg.validate() + + return (sdfg, state, me, tskl1, tskl2) + + +def _apply_and_run_mixed_memlet_sdfg(sdfg: dace.SDFG) -> None: + ref = { + "A": np.array(np.random.rand(10), dtype=np.float64, copy=True), + "B": np.array(np.random.rand(10, 10), dtype=np.float64, copy=True), + "C": np.array(np.random.rand(10, 10), dtype=np.float64, copy=True), + } + res = copy.deepcopy(ref) + sdfg(**ref) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=2, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1, f"Expected one application, but git {count}" + sdfg(**res) + assert all(np.allclose(ref[name], res[name]) for name in ref) + + +def _make_conditional_block_sdfg(sdfg_label: str, sym: str, inp: str, out: str): + sdfg = dace.SDFG(sdfg_label) + for data in [inp, out]: + sdfg.add_scalar(data, dtype=dace.float64) + + if_region = dace.sdfg.state.ConditionalBlock("if") + sdfg.add_node(if_region) + entry_state = sdfg.add_state("entry", is_start_block=True) + sdfg.add_edge(entry_state, if_region, dace.InterstateEdge()) + + then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=sdfg) + tstate = then_body.add_state("true_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock(f"{sym} % 2 == 0"), then_body) + tskli = tstate.add_tasklet("write_0", inputs={"inp"}, outputs={"val"}, code=f"val = inp + 0") + tstate.add_edge(tstate.add_access(inp), None, tskli, "inp", dace.Memlet(f"{inp}[0]")) + tstate.add_edge(tskli, "val", tstate.add_access(out), None, dace.Memlet(f"{out}[0]")) + + else_body 
= dace.sdfg.state.ControlFlowRegion("else_body", sdfg=sdfg) + fstate = else_body.add_state("false_branch", is_start_block=True) + if_region.add_branch(dace.sdfg.state.CodeBlock(f"{sym} % 2 != 0"), else_body) + tskli = fstate.add_tasklet("write_1", inputs={"inp"}, outputs={"val"}, code=f"val = inp + 1") + fstate.add_edge(fstate.add_access(inp), None, tskli, "inp", dace.Memlet(f"{inp}[0]")) + fstate.add_edge(tskli, "val", fstate.add_access(out), None, dace.Memlet(f"{out}[0]")) + + return sdfg + + +def test_loop_blocking_mixed_memlets_1(): + sdfg, state, me, tskl1, tskl2 = _make_mixed_memlet_sdfg(True) + mx = state.exit_node(me) + + _apply_and_run_mixed_memlet_sdfg(sdfg) + scope_dict = state.scope_dict() + + # Ensure that `tskl1` is independent. + assert scope_dict[tskl1] is me + + # The output of `tskl1`, which is `tmp1` should also be classified as independent. + tmp1 = next(iter(edge.dst for edge in state.out_edges(tskl1) if not edge.data.is_empty())) + assert scope_dict[tmp1] is me + assert isinstance(tmp1, dace_nodes.AccessNode) + assert tmp1.data == "tmp1" + + # Find the inner map. 
+ inner_map_entry: dace_nodes.MapEntry = scope_dict[tskl2] + assert inner_map_entry is not me and isinstance(inner_map_entry, dace_nodes.MapEntry) + inner_map_exit: dace_nodes.MapExit = state.exit_node(inner_map_entry) + + outer_scope = {tskl1, tmp1, inner_map_entry, inner_map_exit, mx} + for node in state.nodes(): + if scope_dict[node] is None: + assert (node is me) or ( + isinstance(node, dace_nodes.AccessNode) and node.data in {"A", "B", "C"} + ) + elif scope_dict[node] is me: + assert node in outer_scope + else: + assert ( + (node is inner_map_exit) + or (isinstance(node, dace_nodes.AccessNode) and node.data == "tmp2") + or (isinstance(node, dace_nodes.Tasklet) and node.label in {"tskl2", "tskl3"}) + ) + + +def test_loop_blocking_mixed_memlets_2(): + sdfg, state, me, tskl1, tskl2 = _make_mixed_memlet_sdfg(False) + mx = state.exit_node(me) + + _apply_and_run_mixed_memlet_sdfg(sdfg) + scope_dict = state.scope_dict() + + # Because `tskl1` is now dependent, everything is now dependent. + inner_map_entry = scope_dict[tskl1] + assert isinstance(inner_map_entry, dace_nodes.MapEntry) + assert inner_map_entry is not me + + for node in state.nodes(): + if scope_dict[node] is None: + assert (node is me) or ( + isinstance(node, dace_nodes.AccessNode) and node.data in {"A", "B", "C"} + ) + elif scope_dict[node] is me: + assert isinstance(node, dace_nodes.MapEntry) or (node is mx) + else: + assert scope_dict[node] is inner_map_entry + + +def test_loop_blocking_no_independent_nodes(): + import dace + + sdfg = dace.SDFG(util.unique_name("mixed_memlet_sdfg")) + state = sdfg.add_state(is_start_block=True) + names = ["A", "B", "C"] + for aname in names: + sdfg.add_array( + aname, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + A = state.add_access("A") + _, me, mx = state.add_mapped_tasklet( + "fully_dependent_computation", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in1": dace.Memlet("A[__i0, __i1]")}, + code="__out = __in1 + 10.0", + 
outputs={"__out": dace.Memlet("B[__i0, __i1]")}, + external_edges=True, + input_nodes={A}, + ) + nsdfg_sym, nsdfg_inp, nsdfg_out = ("S", "I", "V") + nsdfg = _make_conditional_block_sdfg("dependent_component", nsdfg_sym, nsdfg_inp, nsdfg_out) + nsdfg_node = state.add_nested_sdfg( + nsdfg, sdfg, inputs={nsdfg_inp}, outputs={nsdfg_out}, symbol_mapping={nsdfg_sym: "__i1"} + ) + state.add_memlet_path(A, me, nsdfg_node, dst_conn=nsdfg_inp, memlet=dace.Memlet("A[1,1]")) + state.add_memlet_path( + nsdfg_node, + mx, + state.add_access("C"), + src_conn=nsdfg_out, + memlet=dace.Memlet("C[__i0, __i1]"), + ) + sdfg.validate() + + # Because there is nothing that is independent the transformation will + # not apply if `require_independent_nodes` is enabled. + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking( + blocking_size=2, + blocking_parameter="__i1", + require_independent_nodes=True, + ), + validate=True, + validate_all=True, + ) + assert count == 0 + + # But it will apply once this requirement is lifted. 
+ count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking( + blocking_size=2, + blocking_parameter="__i1", + require_independent_nodes=False, + ), + validate=True, + validate_all=True, + ) + assert count == 1 + + +def _make_only_last_two_elements_sdfg() -> dace.SDFG: + sdfg = dace.SDFG(util.unique_name("simple_block_sdfg")) + state = sdfg.add_state("state", is_start_block=True) + sdfg.add_symbol("N", dace.int32) + sdfg.add_symbol("B", dace.int32) + sdfg.add_symbol("M", dace.int32) + + for name in "acb": + sdfg.add_array( + name, + shape=(20, 10), + dtype=dace.float64, + ) + + state.add_mapped_tasklet( + "computation", + map_ranges={"i": "B:N", "k": "(M-2):M"}, + inputs={ + "__in1": dace.Memlet("a[i, k]"), + "__in2": dace.Memlet("b[i, k]"), + }, + code="__out = __in1 + __in2", + outputs={"__out": dace.Memlet("c[i, k]")}, + external_edges=True, + ) + sdfg.validate() + + return sdfg + + +def test_only_last_two_elements_sdfg(): + sdfg = _make_only_last_two_elements_sdfg() + + def ref_comp(a, b, c, B, N, M): + for i in range(B, N): + for k in range(M - 2, M): + c[i, k] = a[i, k] + b[i, k] + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking( + blocking_size=1, + blocking_parameter="k", + require_independent_nodes=False, + ), + validate=True, + validate_all=True, + ) + assert count == 1 + + ref = { + "a": np.array(np.random.rand(20, 10), dtype=np.float64), + "b": np.array(np.random.rand(20, 10), dtype=np.float64), + "c": np.zeros((20, 10), dtype=np.float64), + "B": 0, + "N": 20, + "M": 6, + } + res = copy.deepcopy(ref) + + ref_comp(**ref) + sdfg(**res) + + assert np.allclose(ref["c"], res["c"]) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_make_transients_persistent.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_make_transients_persistent.py new file mode 100644 index 
0000000000..d8cf8e33f8 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_make_transients_persistent.py @@ -0,0 +1,74 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _make_transients_persistent_inner_access_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: + sdfg = dace.SDFG(util.unique_name("transients_persistent_inner_access_sdfg")) + state = sdfg.add_state(is_start_block=True) + + for name in "abc": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays["b"].transient = True + + me: dace_nodes.MapEntry + mx: dace_nodes.MapExit + me, mx = state.add_map("comp", ndrange={"__i0": "0:10"}) + a, b, c = (state.add_access(name) for name in "abc") + tsklt: dace_nodes.Tasklet = state.add_tasklet( + "tsklt", + inputs={"__in"}, + code="__out = __in + 1.0", + outputs={"__out"}, + ) + + me.add_in_connector("IN_A") + state.add_edge(a, None, me, "IN_A", dace.Memlet("a[0:10]")) + + me.add_out_connector("OUT_A") + state.add_edge(me, "OUT_A", b, None, dace.Memlet("a[__i0] -> [__i0]")) + + state.add_edge(b, None, tsklt, "__in", dace.Memlet("b[__i0]")) + + mx.add_in_connector("IN_C") + state.add_edge(tsklt, "__out", mx, "IN_C", dace.Memlet("c[__i0]")) + + mx.add_out_connector("OUT_C") + state.add_edge(mx, "OUT_C", c, None, dace.Memlet("c[0:10]")) + sdfg.validate() + return sdfg, state + + +def test_make_transients_persistent_inner_access(): + sdfg, state = _make_transients_persistent_inner_access_sdfg() + assert sdfg.arrays["b"].lifetime is dace.dtypes.AllocationLifetime.Scope 
+ + # Because `b`, the only transient, is used inside a map scope, it is not selected, + # although in this situation it would be possible. + change_report: dict[int, set[str]] = gtx_transformations.gt_make_transients_persistent( + sdfg, device=dace.DeviceType.CPU + ) + assert len(change_report) == 1 + assert change_report[sdfg.cfg_id] == set() diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py new file mode 100644 index 0000000000..f2c31a7188 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py @@ -0,0 +1,349 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import numpy as np +import copy + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, +) + +from . 
import dace


def _make_test_sdfg(
    output_name: str = "G",
    input_name: str = "G",
    tmp_name: str = "T",
    array_size: int | str = 10,
    tmp_size: int | str | None = None,
    map_range: tuple[int | str, int | str] | None = None,
    tmp_to_glob_memlet: str | None = None,
    in_offset: str | None = None,
    out_offset: str | None = None,
) -> dace.SDFG:
    """Creates the `input -(map)-> tmp -(copy)-> output` SDFG used by the tests below.

    Args:
        output_name: Name of the global output array.
        input_name: Name of the global input array (aliases the output by default).
        tmp_name: Name of the transient buffering the map output.
        array_size: Size of the global arrays; a string registers a size symbol.
        tmp_size: Size of the transient; defaults to `array_size`.
        map_range: `(begin, end)` of the map; defaults to `(0, array_size)`.
        tmp_to_glob_memlet: Memlet string of the `tmp -> output` copy; a value
            starting with `[` is prefixed with `tmp_name`.
        in_offset: Offset added to the map index when reading the input.
        out_offset: Offset when writing the transient; defaults to `in_offset`.
    """
    # BUG FIX: the SDFG must exist before `add_symbol()` can be called on it.
    #  Previously the symbol was created before `sdfg` was assigned, raising a
    #  `NameError` whenever `array_size` was passed as a string.
    sdfg = dace.SDFG(util.unique_name("map_buffer"))

    if isinstance(array_size, str):
        array_size = sdfg.add_symbol(array_size, dace.int32, find_new_name=True)
    if tmp_size is None:
        tmp_size = array_size
    if map_range is None:
        map_range = (0, array_size)
    if tmp_to_glob_memlet is None:
        tmp_to_glob_memlet = f"{tmp_name}[0:{array_size}] -> [0:{array_size}]"
    elif tmp_to_glob_memlet[0] == "[":
        tmp_to_glob_memlet = tmp_name + tmp_to_glob_memlet
    if in_offset is None:
        in_offset = "0"
    if out_offset is None:
        out_offset = in_offset

    state = sdfg.add_state(is_start_block=True)
    # A set, so input and output may be the same array.
    names = {input_name, tmp_name, output_name}
    for name in names:
        sdfg.add_array(
            name,
            shape=((array_size,) if name != tmp_name else (tmp_size,)),
            dtype=dace.float64,
            transient=False,
        )
    sdfg.arrays[tmp_name].transient = True

    input_ac = state.add_access(input_name)
    tmp_ac = state.add_access(tmp_name)
    output_ac = state.add_access(output_name)

    state.add_mapped_tasklet(
        "computation",
        map_ranges={"__i0": f"{map_range[0]}:{map_range[1]}"},
        inputs={"__in1": dace.Memlet(data=input_ac.data, subset=f"__i0 + {in_offset}")},
        code="__out = __in1 + 10.0",
        outputs={"__out": dace.Memlet(data=tmp_ac.data, subset=f"__i0 + {out_offset}")},
        input_nodes={input_ac},
        output_nodes={tmp_ac},
        external_edges=True,
    )
    # Copy the buffered result into the global output array.
    state.add_edge(
        tmp_ac,
        None,
        output_ac,
        None,
        dace.Memlet(tmp_to_glob_memlet),
    )
    sdfg.validate()
    return sdfg


def _perform_test(
    sdfg: dace.SDFG,
    xform: gtx_transformations.GT4PyMapBufferElimination,
    exp_count: int,
    array_size: int = 10,
) -> None:
    """Applies `xform` to `sdfg`, checks the application count and, when the
    transformation applied, verifies the results before/after are identical.
    """
    ref = {
        name: np.array(np.random.rand(array_size), dtype=np.float64, copy=True)
        for name, desc in sdfg.arrays.items()
        if not desc.transient
    }
    if "array_size" in sdfg.symbols:
        ref["array_size"] = array_size

    res = copy.deepcopy(ref)
    sdfg(**ref)

    count = sdfg.apply_transformations_repeated([xform], validate=True, validate_all=True)
    assert count == exp_count, f"Expected {exp_count} applications, but got {count}"

    if count == 0:
        return

    sdfg(**res)
    # BUG FIX: compare per array.  The original
    #  `assert all(...), f"Failed for '{name}'."` referenced `name` outside the
    #  generator's scope, so a mismatch raised `NameError` instead of a useful
    #  assertion message.
    for name in ref.keys():
        assert np.allclose(ref[name], res[name]), f"Failed for '{name}'."


def test_map_buffer_elimination_simple():
    sdfg = _make_test_sdfg()
    _perform_test(
        sdfg,
        gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=True),
        exp_count=1,
    )


def test_map_buffer_elimination_simple_2():
    sdfg = _make_test_sdfg()
    _perform_test(
        sdfg,
        gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False),
        exp_count=0,
    )


def test_map_buffer_elimination_simple_3():
    sdfg = _make_test_sdfg(input_name="A", output_name="O")
    _perform_test(
        sdfg,
        gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False),
        exp_count=1,
    )


def test_map_buffer_elimination_offset_1():
    sdfg = _make_test_sdfg(
        map_range=(2, 8),
        tmp_to_glob_memlet="[2:8] -> [2:8]",
        input_name="A",
        output_name="O",
    )
    _perform_test(
        sdfg,
        gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False),
        exp_count=1,
    )


def test_map_buffer_elimination_offset_2():
    sdfg = _make_test_sdfg(
        map_range=(2, 8),
        in_offset="-2",
        out_offset="-2",
        tmp_to_glob_memlet="[0:6] -> [0:6]",
        input_name="A",
        output_name="O",
    )
    _perform_test(
        sdfg,
        gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False),
        exp_count=1,
    )


def test_map_buffer_elimination_offset_3():
    sdfg = _make_test_sdfg(
        map_range=(2, 8),
        in_offset="-2",
        out_offset="-2",
        tmp_to_glob_memlet="[0:6] -> [2:8]",
        input_name="A",
        output_name="O",
    )
    _perform_test(
        sdfg,
        gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False),
        exp_count=1,
    )


def test_map_buffer_elimination_offset_4():
    sdfg = _make_test_sdfg(
        map_range=(2, 8),
        in_offset="-2",
        out_offset="-2",
        tmp_to_glob_memlet="[1:7] -> [2:8]",
        input_name="A",
        output_name="O",
    )
    _perform_test(
        sdfg,
        gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False),
        exp_count=0,
    )


def test_map_buffer_elimination_offset_5():
    sdfg = _make_test_sdfg(
        map_range=(2, 8),
        tmp_size=6,
        in_offset="0",
        out_offset="-2",
        tmp_to_glob_memlet="[0:6] -> [2:8]",
        input_name="A",
        output_name="O",
    )
    _perform_test(
        sdfg,
        gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False),
        exp_count=1,
    )


def test_map_buffer_elimination_not_apply():
    """Indirect accessing, because of this the double buffer is needed."""
    sdfg = dace.SDFG(util.unique_name("map_buffer"))
    state = sdfg.add_state(is_start_block=True)

    names = ["A", "tmp", "idx"]
    for name in names:
        sdfg.add_array(
            name,
            shape=(10,),
            # BUG FIX: the *index* array `idx` must be integral, as it is used
            #  as `__field[__idx]` below.  The original condition tested
            #  `name == "tmp"`, which made `idx` float64 and `tmp` int32 (cf.
            #  the `test_indirect_access_2` fixture in test_map_fusion.py).
            dtype=dace.int32 if name == "idx" else dace.float64,
            transient=False,
        )
    sdfg.arrays["tmp"].transient = True

    tmp = state.add_access("tmp")
    state.add_mapped_tasklet(
        "indirect_accessing",
        map_ranges={"__i0": "0:10"},
        inputs={
            "__field": dace.Memlet("A[0:10]"),
            "__idx": dace.Memlet("idx[__i0]"),
        },
        code="__out = __field[__idx]",
        outputs={"__out": dace.Memlet("tmp[__i0]")},
        output_nodes={tmp},
        external_edges=True,
    )
    state.add_nedge(tmp, state.add_access("A"), dace.Memlet("tmp[0:10] -> [0:10]"))

    # TODO(phimuell): Update the transformation such that we can specify
    # `assume_pointwise=True` and the test would still pass.
    # The indirect access makes `A` non-pointwise, so the buffer must stay.
    count = sdfg.apply_transformations_repeated(
        gtx_transformations.GT4PyMapBufferElimination(
            assume_pointwise=False,
        ),
        validate=True,
        validate_all=True,
    )
    assert count == 0


def test_map_buffer_elimination_with_nested_sdfgs():
    """
    After removing a transient connected to a nested SDFG node, ensure that the strides
    are propagated to the arrays in nested SDFG.
    """

    # Symbolic strides of the top-level output; after elimination they must
    # show up (suffix-wise) on the nested SDFGs' `__out` descriptors.
    stride1, stride2, stride3 = [dace.symbol(f"stride{i}", dace.int32) for i in range(3)]

    # top-level sdfg: 1D input, 3D strided output, buffered through `tmp`.
    sdfg = dace.SDFG(util.unique_name("map_buffer"))
    inp, inp_desc = sdfg.add_array("__inp", (10,), dace.float64)
    out, out_desc = sdfg.add_array(
        "__out", (10, 10, 10), dace.float64, strides=(stride1, stride2, stride3)
    )
    tmp, _ = sdfg.add_temp_transient_like(out_desc)
    state = sdfg.add_state()
    tmp_node = state.add_access(tmp)

    # First nesting level: 2D output buffered through its own transient.
    nsdfg1 = dace.SDFG(util.unique_name("map_buffer"))
    inp1, inp1_desc = nsdfg1.add_array("__inp", (10,), dace.float64)
    out1, out1_desc = nsdfg1.add_array("__out", (10, 10), dace.float64)
    tmp1, _ = nsdfg1.add_temp_transient_like(out1_desc)
    state1 = nsdfg1.add_state()
    tmp1_node = state1.add_access(tmp1)

    # Second (innermost) nesting level: plain 1D computation.
    nsdfg2 = dace.SDFG(util.unique_name("map_buffer"))
    inp2, _ = nsdfg2.add_array("__inp", (10,), dace.float64)
    out2, out2_desc = nsdfg2.add_array("__out", (10,), dace.float64)
    tmp2, _ = nsdfg2.add_temp_transient_like(out2_desc)
    state2 = nsdfg2.add_state()
    tmp2_node = state2.add_access(tmp2)

    # Innermost computation writes through `tmp2` and copies into `__out`.
    state2.add_mapped_tasklet(
        "broadcast2",
        map_ranges={"__i": "0:10"},
        code="__oval = __ival + 1.0",
        inputs={
            "__ival": dace.Memlet(f"{inp2}[__i]"),
        },
        outputs={
            "__oval": dace.Memlet(f"{tmp2}[__i]"),
        },
        output_nodes={tmp2_node},
        external_edges=True,
    )
    state2.add_nedge(tmp2_node, state2.add_access(out2), dace.Memlet.from_array(out2, out2_desc))

    # Embed nsdfg2 inside a map of nsdfg1 (broadcast over one more dimension).
    nsdfg2_node = state1.add_nested_sdfg(nsdfg2, nsdfg1, inputs={"__inp"}, outputs={"__out"})
    me1, mx1 = state1.add_map("broadcast1", ndrange={"__i": "0:10"})
state1.add_memlet_path( + state1.add_access(inp1), + me1, + nsdfg2_node, + dst_conn="__inp", + memlet=dace.Memlet.from_array(inp1, inp1_desc), + ) + state1.add_memlet_path( + nsdfg2_node, mx1, tmp1_node, src_conn="__out", memlet=dace.Memlet(f"{tmp1}[__i, 0:10]") + ) + state1.add_nedge(tmp1_node, state1.add_access(out1), dace.Memlet.from_array(out1, out1_desc)) + + nsdfg1_node = state.add_nested_sdfg(nsdfg1, sdfg, inputs={"__inp"}, outputs={"__out"}) + me, mx = state.add_map("broadcast", ndrange={"__i": "0:10"}) + state.add_memlet_path( + state.add_access(inp), + me, + nsdfg1_node, + dst_conn="__inp", + memlet=dace.Memlet.from_array(inp, inp_desc), + ) + state.add_memlet_path( + nsdfg1_node, mx, tmp_node, src_conn="__out", memlet=dace.Memlet(f"{tmp}[__i, 0:10, 0:10]") + ) + state.add_nedge(tmp_node, state.add_access(out), dace.Memlet.from_array(out, out_desc)) + + sdfg.validate() + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMapBufferElimination( + assume_pointwise=False, + ), + validate=True, + validate_all=True, + ) + assert count == 3 + assert out1_desc.strides == out_desc.strides[1:] + assert out2_desc.strides == out_desc.strides[2:] diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py index c9d467ba80..ecf5a4762b 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py @@ -17,11 +17,13 @@ from dace.sdfg import nodes as dace_nodes from dace.transformation import dataflow as dace_dataflow -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( transformations as 
gtx_transformations, ) -from . import pytestmark +import dace + + from . import util @@ -58,17 +60,17 @@ def _make_serial_sdfg_1( inputs={"__in0": dace.Memlet("a[__i0, __i1]")}, code="__out = __in0 + 1.0", outputs={"__out": dace.Memlet("tmp[__i0, __i1]")}, - output_nodes={"tmp": tmp}, + output_nodes={tmp}, external_edges=True, ) state.add_mapped_tasklet( name="second_computation", - map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], - input_nodes={"tmp": tmp}, - inputs={"__in0": dace.Memlet("tmp[__i0, __i1]")}, + map_ranges=[("__i4", f"0:{N}"), ("__i6", f"0:{N}")], + input_nodes={tmp}, + inputs={"__in0": dace.Memlet("tmp[__i4, __i6]")}, code="__out = __in0 + 3.0", - outputs={"__out": dace.Memlet("b[__i0, __i1]")}, + outputs={"__out": dace.Memlet("b[__i4, __i6]")}, external_edges=True, ) @@ -118,17 +120,14 @@ def _make_serial_sdfg_2( "__out0": dace.Memlet("tmp_1[__i0, __i1]"), "__out1": dace.Memlet("tmp_2[__i0, __i1]"), }, - output_nodes={ - "tmp_1": tmp_1, - "tmp_2": tmp_2, - }, + output_nodes={tmp_1, tmp_2}, external_edges=True, ) state.add_mapped_tasklet( name="first_computation", map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], - input_nodes={"tmp_1": tmp_1}, + input_nodes={tmp_1}, inputs={"__in0": dace.Memlet("tmp_1[__i0, __i1]")}, code="__out = __in0 + 3.0", outputs={"__out": dace.Memlet("b[__i0, __i1]")}, @@ -136,11 +135,11 @@ def _make_serial_sdfg_2( ) state.add_mapped_tasklet( name="second_computation", - map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], - input_nodes={"tmp_2": tmp_2}, - inputs={"__in0": dace.Memlet("tmp_2[__i0, __i1]")}, + map_ranges=[("__i3", f"0:{N}"), ("__i6", f"0:{N}")], + input_nodes={tmp_2}, + inputs={"__in0": dace.Memlet("tmp_2[__i3, __i6]")}, code="__out = __in0 - 3.0", - outputs={"__out": dace.Memlet("c[__i0, __i1]")}, + outputs={"__out": dace.Memlet("c[__i3, __i6]")}, external_edges=True, ) @@ -194,45 +193,93 @@ def _make_serial_sdfg_3( }, code="__out = __in0 + __in1", outputs={"__out": dace.Memlet("tmp[__i0]")}, - 
output_nodes={"tmp": tmp}, + output_nodes={tmp}, external_edges=True, ) state.add_mapped_tasklet( name="indirect_access", - map_ranges=[("__i0", f"0:{N_output}")], - input_nodes={"tmp": tmp}, + map_ranges=[("__i1", f"0:{N_output}")], + input_nodes={tmp}, inputs={ - "__index": dace.Memlet("idx[__i0]"), + "__index": dace.Memlet("idx[__i1]"), "__array": dace.Memlet.simple("tmp", subset_str=f"0:{N_input}", num_accesses=1), }, code="__out = __array[__index]", - outputs={"__out": dace.Memlet("c[__i0]")}, + outputs={"__out": dace.Memlet("c[__i1]")}, external_edges=True, ) return sdfg +def _make_parallel_sdfg_1( + single_input_node: bool, +) -> tuple[dace.SDFG, dace.SDFGState]: + """Make a parallel SDFG. + + The maps access both the same Data but uses different AccessNodes for that. + If `single_input_node` is `True` then there will only one AccessNode for `a` + be created, otherwise each map has its own. + """ + sdfg = dace.SDFG(util.unique_name("parallel_sdfg_1")) + state = sdfg.add_state(is_start_block=True) + + for name in "abc": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + + a1, b, c = (state.add_access(name) for name in "abc") + a2 = a1 if single_input_node else state.add_access("a") + + state.add_mapped_tasklet( + "map1", + map_ranges={"__i0": "0:10"}, + inputs={"__in": dace.Memlet("a[__i0]")}, + code="__out = __in + 10.", + outputs={"__out": dace.Memlet("b[__i0]")}, + input_nodes={a1}, + output_nodes={b}, + external_edges=True, + ) + state.add_mapped_tasklet( + "map2", + map_ranges={"__i1": "0:10"}, + inputs={"__in": dace.Memlet("a[__i1]")}, + code="__out = __in - 10.", + outputs={"__out": dace.Memlet("c[__i1]")}, + input_nodes={a2}, + output_nodes={c}, + external_edges=True, + ) + sdfg.validate() + + return sdfg, state + + def test_exclusive_itermediate(): """Tests if the exclusive intermediate branch works.""" N = 10 sdfg = _make_serial_sdfg_1(N) # Now apply the optimizations. 
- assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 1 assert "tmp" not in sdfg.arrays # Test if the intermediate is a scalar intermediate_nodes: list[dace_nodes.Node] = [ node - for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) + for node in util.count_nodes(sdfg, dace_nodes.AccessNode, True) if node.data not in ["a", "b"] ] assert len(intermediate_nodes) == 1 @@ -257,19 +304,19 @@ def test_shared_itermediate(): sdfg.arrays["tmp"].transient = False # Now apply the optimizations. - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 1 assert "tmp" in sdfg.arrays # Test if the intermediate is a scalar intermediate_nodes: list[dace_nodes.Node] = [ node - for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) + for node in util.count_nodes(sdfg, dace_nodes.AccessNode, True) if node.data not in ["a", "b", "tmp"] ] assert len(intermediate_nodes) == 1 @@ -291,21 +338,21 @@ def test_pure_output_node(): """Tests the path of a pure intermediate.""" N = 10 sdfg = _make_serial_sdfg_2(N) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 3 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 3 # The first fusion will only bring it down to two maps. 
sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 1 a = np.random.rand(N, N) b = np.empty_like(a) @@ -327,17 +374,17 @@ def test_array_intermediate(): """ N = 10 sdfg = _make_serial_sdfg_1(N) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations_repeated([dace_dataflow.MapExpansion]) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 4 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 4 # Now perform the fusion sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(only_toplevel_maps=True), + gtx_transformations.MapFusionSerial(only_toplevel_maps=True), validate=True, validate_all=True, ) - map_entries = util._count_nodes(sdfg, dace_nodes.MapEntry, return_nodes=True) + map_entries = util.count_nodes(sdfg, dace_nodes.MapEntry, return_nodes=True) scope = next(iter(sdfg.states())).scope_dict() assert len(map_entries) == 3 @@ -349,7 +396,7 @@ def test_array_intermediate(): # Find the access node that is the new intermediate node. 
inner_access_nodes: list[dace_nodes.AccessNode] = [ node - for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) + for node in util.count_nodes(sdfg, dace_nodes.AccessNode, True) if scope[node] is not None ] assert len(inner_access_nodes) == 1 @@ -374,7 +421,7 @@ def test_interstate_transient(): """ N = 10 sdfg = _make_serial_sdfg_2(N) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 3 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 3 assert sdfg.number_of_nodes() == 1 # Now add the new state and the new output. @@ -393,15 +440,15 @@ def test_interstate_transient(): # Now apply the transformation sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) assert "tmp_1" in sdfg.arrays assert "tmp_2" not in sdfg.arrays assert sdfg.number_of_nodes() == 2 - assert util._count_nodes(head_state, dace_nodes.MapEntry) == 1 - assert util._count_nodes(new_state, dace_nodes.MapEntry) == 1 + assert util.count_nodes(head_state, dace_nodes.MapEntry) == 1 + assert util.count_nodes(new_state, dace_nodes.MapEntry) == 1 a = np.random.rand(N, N) b = np.empty_like(a) @@ -430,7 +477,7 @@ def test_indirect_access(): c = np.empty(N_output) idx = np.random.randint(low=0, high=N_input, size=N_output, dtype=np.int32) sdfg = _make_serial_sdfg_3(N_input=N_input, N_output=N_output) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 def _ref(a, b, idx): tmp = a + b @@ -443,11 +490,11 @@ def _ref(a, b, idx): # Now "apply" the transformation sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 c[:] = -1.0 sdfg(a=a, b=b, idx=idx, c=c) @@ -455,5 +502,107 @@ def _ref(a, b, idx): def 
test_indirect_access_2(): - # TODO(phimuell): Index should be computed and that map should be fusable. - pass + """Indirect accesses, with non point wise input dependencies. + + Because `a` is used as input and output and `a` is indirectly accessed + the access to `a` can not be point wise so, fusing is not possible. + """ + sdfg = dace.SDFG(util.unique_name("indirect_access_sdfg_2")) + state = sdfg.add_state(is_start_block=True) + + names = ["a", "b", "idx", "tmp"] + + for name in names: + sdfg.add_array( + name=name, + shape=(10,), + dtype=dace.int32 if name == "idx" else dace.float64, + transient=False, + ) + sdfg.arrays["tmp"].transient = True + + a_in, b, idx, tmp, a_out = (state.add_access(name) for name in (names + ["a"])) + + state.add_mapped_tasklet( + "indirect_access", + map_ranges={"__i0": "0:10"}, + inputs={ + "__idx": dace.Memlet("idx[__i0]"), + "__field": dace.Memlet("a[0:10]", volume=1), + }, + code="__out = __field[__idx]", + outputs={"__out": dace.Memlet("tmp[__i0]")}, + input_nodes={a_in, idx}, + output_nodes={tmp}, + external_edges=True, + ) + state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10"}, + inputs={ + "__in1": dace.Memlet("tmp[__i0]"), + "__in2": dace.Memlet("b[__i0]"), + }, + code="__out = __in1 + __in2", + outputs={"__out": dace.Memlet("a[__i0]")}, + input_nodes={tmp, b}, + output_nodes={a_out}, + external_edges=True, + ) + sdfg.validate() + + count = sdfg.apply_transformations_repeated( + gtx_transformations.MapFusionSerial(), + validate=True, + validate_all=True, + ) + assert count == 0 + + +def test_parallel_1(): + sdfg, state = _make_parallel_sdfg_1(single_input_node=False) + assert util.count_nodes(state, dace_nodes.AccessNode) == 4 + assert util.count_nodes(state, dace_nodes.MapEntry) == 2 + + # Because we request a common ancestor it will not apply. + # NOTE: We might have to change that if the implementation changes. 
+ nb_applies = sdfg.apply_transformations_repeated( + [gtx_transformations.MapFusionParallel(only_if_common_ancestor=True)] + ) + assert nb_applies == 0 + + # If we do not restrict common ancestor then it will work. + nb_applies = sdfg.apply_transformations_repeated( + [gtx_transformations.MapFusionParallel(only_if_common_ancestor=False)] + ) + + assert nb_applies == 1 + assert util.count_nodes(state, dace_nodes.AccessNode) == 4 + assert util.count_nodes(state, dace_nodes.MapEntry) == 1 + + +def test_parallel_2(): + sdfg, state = _make_parallel_sdfg_1(single_input_node=True) + assert util.count_nodes(state, dace_nodes.AccessNode) == 3 + assert util.count_nodes(state, dace_nodes.MapEntry) == 2 + + nb_applies = sdfg.apply_transformations_repeated([gtx_transformations.MapFusionParallel()]) + + assert nb_applies == 1 + assert util.count_nodes(state, dace_nodes.AccessNode) == 3 + assert util.count_nodes(state, dace_nodes.MapEntry) == 1 + + +def test_parallel_3(): + """Test that the parallel map fusion does not apply for serial maps.""" + sdfg = _make_serial_sdfg_1(20) + assert util.count_nodes(sdfg, dace_nodes.AccessNode) == 3 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 + + # Because the maps are fully serial, parallel map fusion should never apply. 
+ nb_applies = sdfg.apply_transformations_repeated( + [gtx_transformations.MapFusionParallel(only_if_common_ancestor=False)] + ) + assert nb_applies == 0 + assert util.count_nodes(sdfg, dace_nodes.AccessNode) == 3 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py new file mode 100644 index 0000000000..d82127f6f3 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py @@ -0,0 +1,100 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import pytest +import numpy as np + +dace = pytest.importorskip("dace") + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, +) + +from . import util + + +def _perform_reorder_test( + sdfg: dace.SDFG, + leading_dim: list[str], + expected_order: list[str], +) -> None: + """Performs the reorder transformation and test it. + + If `expected_order` is the empty list, then the transformation should not apply. + """ + map_entries: list[dace.nodes.MapEntry] = util.count_nodes(sdfg, dace.nodes.MapEntry, True) + assert len(map_entries) == 1 + map_entry: dace.nodes.MapEntry = map_entries[0] + old_map_params = map_entry.map.params.copy() + + apply_count = sdfg.apply_transformations_repeated( + gtx_transformations.MapIterationOrder( + leading_dims=leading_dim, + ), + validate=True, + validate_all=True, + ) + new_map_params = map_entry.map.params.copy() + + if len(expected_order) == 0: + assert ( + apply_count == 0 + ), f"Expected that the transformation was not applied. 
New map order: {map_entry.map.params}" + return + else: + assert ( + apply_count > 0 + ), f"Expected that the transformation was applied. Old map order: {map_entry.map.params}; Expected order: {expected_order}" + assert len(expected_order) == len(new_map_params) + + assert ( + expected_order == new_map_params + ), f"Expected map order {expected_order} but got {new_map_params} instead." + + +def _make_test_sdfg(map_params: list[str]) -> dace.SDFG: + """Generate an SDFG for the test.""" + sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) + state: dace.SDFGState = sdfg.add_state("state", is_start_block=True) + dim = len(map_params) + for aname in ["a", "b"]: + sdfg.add_array(aname, shape=((4,) * dim), dtype=dace.float64, transient=False) + + state.add_mapped_tasklet( + "mapped_tasklet", + map_ranges=[(map_param, "0:4") for map_param in map_params], + inputs={"__in": dace.Memlet("a[" + ",".join(map_params) + "]")}, + code="__out = __in + 1", + outputs={"__out": dace.Memlet("b[" + ",".join(map_params) + "]")}, + external_edges=True, + ) + sdfg.validate() + + return sdfg + + +def test_map_order_1(): + sdfg = _make_test_sdfg(["EDim", "KDim", "VDim"]) + _perform_reorder_test(sdfg, ["EDim", "VDim"], ["KDim", "VDim", "EDim"]) + + +def test_map_order_2(): + sdfg = _make_test_sdfg(["VDim", "KDim"]) + _perform_reorder_test(sdfg, ["EDim", "VDim"], ["KDim", "VDim"]) + + +def test_map_order_3(): + sdfg = _make_test_sdfg(["EDim", "KDim"]) + _perform_reorder_test(sdfg, ["EDim", "VDim"], ["KDim", "EDim"]) + + +def test_map_order_4(): + sdfg = _make_test_sdfg(["CDim", "KDim"]) + _perform_reorder_test(sdfg, ["EDim", "VDim"], []) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py new file mode 100644 index 0000000000..7718977d53 --- /dev/null +++ 
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np
import pytest


dace = pytest.importorskip("dace")
from dace.sdfg import nodes as dace_nodes, propagation as dace_propagation
from dace.transformation import dataflow as dace_dataflow

from gt4py.next.program_processors.runners.dace import (
    transformations as gtx_transformations,
)

from . import util

import dace


def _make_movable_tasklet(
    outer_tasklet_code: str,
) -> tuple[
    dace.SDFG, dace.SDFGState, dace_nodes.Tasklet, dace_nodes.AccessNode, dace_nodes.MapEntry
]:
    """Builds an SDFG in which a map-external tasklet feeds a scalar into a map.

    The outer tasklet computes `outer_tasklet_code` into the transient scalar
    `outer_scalar`, which a 10x10 map then adds to every element of `A`,
    writing the result into `B`.

    Returns:
        The SDFG, its state, the outer tasklet, the `outer_scalar` access node
        and the entry node of the map.
    """
    sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg"))
    state = sdfg.add_state("state", is_start_block=True)

    # Transient scalar written by the outer tasklet and read by the map.
    sdfg.add_scalar("outer_scalar", dtype=dace.float64, transient=True)
    for array_name in "AB":
        sdfg.add_array(array_name, shape=(10, 10), dtype=dace.float64, transient=False)
    acc_a = state.add_access("A")
    acc_b = state.add_access("B")
    outer_scalar = state.add_access("outer_scalar")

    # Constant-producing tasklet that lives outside the map scope.
    outer_tasklet = state.add_tasklet(
        name="outer_tasklet",
        inputs=set(),
        outputs={"__out"},
        code=f"__out = {outer_tasklet_code}",
    )
    state.add_edge(outer_tasklet, "__out", outer_scalar, None, dace.Memlet("outer_scalar[0]"))

    _, map_entry, _ = state.add_mapped_tasklet(
        "map",
        map_ranges={"__i0": "0:10", "__i1": "0:10"},
        inputs={
            "__in0": dace.Memlet("A[__i0, __i1]"),
            "__in1": dace.Memlet("outer_scalar[0]"),
        },
        code="__out = __in0 + __in1",
        outputs={"__out": dace.Memlet("B[__i0, __i1]")},
        external_edges=True,
        input_nodes={outer_scalar, acc_a},
        output_nodes={acc_b},
    )
    sdfg.validate()

    return sdfg, state, outer_tasklet, outer_scalar, map_entry


def test_move_tasklet_inside_trivial_memlet_tree():
    sdfg, state, outer_tasklet, outer_scalar, me = _make_movable_tasklet(
        outer_tasklet_code="1.2",
    )

    n_applied = sdfg.apply_transformations_repeated(
        gtx_transformations.GT4PyMoveTaskletIntoMap,
        validate_all=True,
    )
    assert n_applied == 1

    a_arr = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True)
    b_arr = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True)
    expected = a_arr + 1.2

    compiled = sdfg.compile()
    compiled(A=a_arr, B=b_arr)
    assert np.allclose(b_arr, expected)


def test_move_tasklet_inside_non_trivial_memlet_tree():
    sdfg, state, outer_tasklet, outer_scalar, me = _make_movable_tasklet(
        outer_tasklet_code="1.2",
    )
    # Expanding the 2D map into nested 1D maps makes the memlet tree of the
    # scalar non-trivial; the transformation must still apply.
    sdfg.apply_transformations_repeated(dace_dataflow.MapExpansion)
    assert util.count_nodes(state, dace_nodes.MapEntry) == 2
    me = None

    n_applied = sdfg.apply_transformations_repeated(
        gtx_transformations.GT4PyMoveTaskletIntoMap,
        validate_all=True,
    )
    assert n_applied == 1

    a_arr = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True)
    b_arr = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True)
    expected = a_arr + 1.2

    compiled = sdfg.compile()
    compiled(A=a_arr, B=b_arr)
    assert np.allclose(b_arr, expected)


def test_move_tasklet_inside_two_inner_connector():
    sdfg, state, outer_tasklet, outer_scalar, me = _make_movable_tasklet(
        outer_tasklet_code="32.2",
    )
    # Locate the tasklet inside the map scope.
    mapped_tasklet = next(
        e.dst for e in state.out_edges(me) if isinstance(e.dst, dace_nodes.Tasklet)
    )

    # Route the outer scalar into the mapped tasklet a second time.
    state.add_edge(
        me,
        f"OUT_{outer_scalar.data}",
        mapped_tasklet,
        "__in2",
        dace.Memlet(f"{outer_scalar.data}[0]"),
    )
    mapped_tasklet.add_in_connector("__in2")
    mapped_tasklet.code.as_string = "__out = __in0 + __in1 + __in2"

    n_applied = sdfg.apply_transformations_repeated(
        gtx_transformations.GT4PyMoveTaskletIntoMap,
        validate_all=True,
    )
    assert n_applied == 1

    a_arr = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True)
    b_arr = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True)
    expected = a_arr + 2 * (32.2)

    compiled = sdfg.compile()
    compiled(A=a_arr, B=b_arr)
    assert np.allclose(b_arr, expected)


def test_move_tasklet_inside_outer_scalar_used_outside():
    sdfg, state, outer_tasklet, outer_scalar, me = _make_movable_tasklet(
        outer_tasklet_code="22.6",
    )
    # The scalar additionally escapes into the global array `C`.
    sdfg.add_array("C", shape=(1,), dtype=dace.float64, transient=False)
    state.add_edge(outer_scalar, None, state.add_access("C"), None, dace.Memlet("C[0]"))

    n_applied = sdfg.apply_transformations_repeated(
        gtx_transformations.GT4PyMoveTaskletIntoMap,
        validate_all=True,
    )
    assert n_applied == 1

    a_arr = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True)
    b_arr = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True)
    c_arr = np.array(np.random.rand(1), dtype=np.float64, copy=True)
    expected_c = 22.6
    expected_b = a_arr + expected_c

    compiled = sdfg.compile()
    compiled(A=a_arr, B=b_arr, C=c_arr)
    assert np.allclose(b_arr, expected_b)
    assert np.allclose(c_arr, expected_c)


# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import pytest

from typing import Optional

dace = pytest.importorskip("dace")
from dace.sdfg import nodes as dace_nodes
from dace.transformation import pass_pipeline as dace_ppl

from gt4py.next.program_processors.runners.dace import (
    transformations as gtx_transformations,
)

from . import util
import util + +import dace + + +def apply_distributed_self_copy_elimination( + sdfg: dace.SDFG, +) -> Optional[dict[dace.SDFG, set[str]]]: + return gtx_transformations.gt_multi_state_global_self_copy_elimination(sdfg=sdfg, validate=True) + + +def _make_not_apply_because_of_write_to_g_sdfg() -> dace.SDFG: + """This SDFG is not eligible, because there is a write to `G`.""" + sdfg = dace.SDFG(util.unique_name("not_apply_because_of_write_to_g_sdfg")) + + # This is the `G` array. + sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) + # This is the `T` array. + sdfg.add_array(name="t", shape=(5,), dtype=dace.float64, transient=True) + + # This is an unrelated array that is used as output. + sdfg.add_array( + name="b", + shape=(10,), + dtype=dace.float64, + transient=False, + ) + + state1 = sdfg.add_state(is_start_block=True) + state1.add_nedge(state1.add_access("a"), state1.add_access("t"), dace.Memlet("a[0:5] -> [0:5]")) + + state2 = sdfg.add_state_after(state1) + state2.add_mapped_tasklet( + "make_a_non_applicable", + map_ranges={"__i": "3:8"}, + inputs={}, + code="__out = 10.", + outputs={"__out": dace.Memlet("a[__i]")}, + external_edges=True, + ) + + state3 = sdfg.add_state_after(state2) + a3 = state3.add_access("a") + state3.add_nedge(state3.add_access("t"), a3, dace.Memlet("t[0:5] -> [0:5]")) + state3.add_mapped_tasklet( + "comp", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in + 1.", + outputs={"__out": dace.Memlet("b[__i]")}, + input_nodes={a3}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + +def _make_eligible_sdfg_1() -> dace.SDFG: + """This SDFG is very similar to the one generated by `_make_not_apply_because_of_write_to_g_sdfg()`. + + The main difference is that there is no mutating write to `a` and thus the + transformation applies. + """ + sdfg = dace.SDFG(util.unique_name("_make_eligible_sdfg_1")) + + # This is the `G` array. 
+ sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) + # This is the `T` array. + sdfg.add_array(name="t", shape=(5,), dtype=dace.float64, transient=True) + + # These are some unrelated arrays that is used as output. + sdfg.add_array(name="b", shape=(10,), dtype=dace.float64, transient=False) + sdfg.add_array(name="c", shape=(10,), dtype=dace.float64, transient=False) + + state1 = sdfg.add_state(is_start_block=True) + state1.add_nedge(state1.add_access("a"), state1.add_access("t"), dace.Memlet("a[0:5] -> [0:5]")) + + state2 = sdfg.add_state_after(state1) + state2.add_mapped_tasklet( + "comp1", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in + 1.", + outputs={"__out": dace.Memlet("b[__i]")}, + external_edges=True, + ) + + state3 = sdfg.add_state_after(state2) + a3 = state3.add_access("a") + state3.add_nedge(state3.add_access("t"), a3, dace.Memlet("t[0:5] -> [0:5]")) + state3.add_mapped_tasklet( + "comp2", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in + 1.", + outputs={"__out": dace.Memlet("c[__i]")}, + input_nodes={a3}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + +def _make_multiple_temporaries_sdfg1() -> dace.SDFG: + """Generates an SDFG in which `G` is saved into different temporaries.""" + sdfg = dace.SDFG(util.unique_name("multiple_temporaries")) + + # This is the `G` array. + sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) + # This is the first `T` array. + sdfg.add_array(name="t1", shape=(5,), dtype=dace.float64, transient=True) + # This is the second `T` array. + sdfg.add_array(name="t2", shape=(5,), dtype=dace.float64, transient=True) + + # This are some unrelated array that is used as output. 
+ sdfg.add_array(name="b", shape=(10,), dtype=dace.float64, transient=False) + + state1 = sdfg.add_state(is_start_block=True) + a1 = state1.add_access("a") + state1.add_nedge(a1, state1.add_access("t1"), dace.Memlet("a[0:5] -> [0:5]")) + state1.add_nedge(a1, state1.add_access("t2"), dace.Memlet("a[5:10] -> [0:5]")) + + state2 = sdfg.add_state_after(state1) + a2 = state2.add_access("a") + + state2.add_nedge(state2.add_access("t1"), a2, dace.Memlet("t1[0:5] -> [0:5]")) + state2.add_nedge(state2.add_access("t2"), a2, dace.Memlet("t2[0:5] -> [5:10]")) + + state2.add_mapped_tasklet( + "comp", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in + 1.", + outputs={"__out": dace.Memlet("b[__i]")}, + input_nodes={a2}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + +def _make_multiple_temporaries_sdfg2() -> dace.SDFG: + """Generates an SDFG where there are multiple `T` used. + + The main difference between the SDFG produced by this function and the one + generated by `_make_multiple_temporaries_sdfg()` is that the temporaries + are used sequentially. + """ + sdfg = dace.SDFG(util.unique_name("multiple_temporaries_sequential")) + + # This is the `G` array. + sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) + # This is the first `T` array. + sdfg.add_array(name="t1", shape=(5,), dtype=dace.float64, transient=True) + # This is the second `T` array. + sdfg.add_array(name="t2", shape=(5,), dtype=dace.float64, transient=True) + + # This are some unrelated array that is used as output. 
+ sdfg.add_array(name="b", shape=(10,), dtype=dace.float64, transient=False) + sdfg.add_array(name="c", shape=(10,), dtype=dace.float64, transient=False) + + state1 = sdfg.add_state(is_start_block=True) + state1.add_nedge( + state1.add_access("a"), state1.add_access("t1"), dace.Memlet("a[0:5] -> [0:5]") + ) + + state2 = sdfg.add_state_after(state1) + a2 = state2.add_access("a") + + state2.add_nedge(state2.add_access("t1"), a2, dace.Memlet("t1[0:5] -> [0:5]")) + + state2.add_mapped_tasklet( + "comp", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in + 1.", + outputs={"__out": dace.Memlet("b[__i]")}, + input_nodes={a2}, + external_edges=True, + ) + + # This essentially repeats the same thing as above again, but this time with `t2`. + state3 = sdfg.add_state_after(state2) + state3.add_nedge( + state3.add_access("a"), state3.add_access("t2"), dace.Memlet("a[5:10] -> [0:5]") + ) + + state4 = sdfg.add_state_after(state3) + a4 = state4.add_access("a") + state4.add_nedge(state4.add_access("t2"), a4, dace.Memlet("t2[0:5] -> [5:10]")) + state4.add_mapped_tasklet( + "comp2", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in - 1.", + outputs={"__out": dace.Memlet("c[__i]")}, + input_nodes={a4}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + +def _make_multiple_temporaries_sdfg_keep_one_1() -> dace.SDFG: + """ + The generated SDFG is very similar to `_make_multiple_temporaries_sdfg1()` except + that `t1` can not be removed because it is used to generate `c`. 
+ """ + sdfg = _make_multiple_temporaries_sdfg1() + + sdfg.add_array("c", shape=(5,), dtype=dace.float64, transient=False) + + state = sdfg.add_state_after( + next(iter(state for state in sdfg.states() if sdfg.out_degree(state) == 0)) + ) + state.add_mapped_tasklet( + "comp_that_needs_t1", + map_ranges={"__j": "0:5"}, + inputs={"__in": dace.Memlet("t1[__j]")}, + code="__out = __in + 4.0", + outputs={"__out": dace.Memlet("c[__j]")}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + +def _make_multiple_temporaries_sdfg_keep_one_2() -> dace.SDFG: + """ + The generated SDFG is very similar to `_make_multiple_temporaries_sdfg2()` except + that `t1` can not be removed because it is used to generate `d`. + """ + sdfg = _make_multiple_temporaries_sdfg2() + + sdfg.add_array("d", shape=(5,), dtype=dace.float64, transient=False) + + state = sdfg.add_state_after( + next(iter(state for state in sdfg.states() if sdfg.out_degree(state) == 0)) + ) + state.add_mapped_tasklet( + "comp_that_needs_t1", + map_ranges={"__j": "0:5"}, + inputs={"__in": dace.Memlet("t1[__j]")}, + code="__out = __in + 4.0", + outputs={"__out": dace.Memlet("d[__j]")}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + +def _make_non_eligible_because_of_pseudo_temporary() -> dace.SDFG: + """Generates an SDFG that that defines `T` from two souces, which is not handled. + + Note that in this particular case it would be possible, but we do not support it. + """ + sdfg = dace.SDFG(util.unique_name("multiple_temporaries_sequential")) + + # This is the `G` array. + sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) + # This is the `T` array. + sdfg.add_array(name="t", shape=(10,), dtype=dace.float64, transient=True) + + # This is the array that also writes to `T` and thus makes it inapplicable. + sdfg.add_array(name="pg", shape=(10,), dtype=dace.float64, transient=True) + + # This are some unrelated array that is used as output. 
+ sdfg.add_array(name="b", shape=(10,), dtype=dace.float64, transient=False) + + state1 = sdfg.add_state(is_start_block=True) + t1 = state1.add_access("t") + state1.add_nedge(state1.add_access("a"), t1, dace.Memlet("a[0:5] -> [0:5]")) + state1.add_nedge(state1.add_access("pg"), t1, dace.Memlet("pg[0:5] -> [5:10]")) + + state2 = sdfg.add_state_after(state1) + a2 = state2.add_access("a") + state2.add_nedge(state2.add_access("t"), a2, dace.Memlet("t[0:5] -> [0:5]")) + state2.add_mapped_tasklet( + "comp", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i]")}, + code="__out = __in + 1.0", + outputs={"__out": dace.Memlet("b[__i]")}, + input_nodes={a2}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + +def _make_wb_single_state_sdfg() -> dace.SDFG: + """Generates an SDFG with the pattern `(G) -> (T) -> (G)` which is not handled. + + This pattern is handled by the `SingleStateGlobalSelfCopyElimination` transformation. + """ + sdfg = dace.SDFG(util.unique_name("single_state_write_back_sdfg")) + + sdfg.add_array("g", shape=(10,), dtype=dace.float64, transient=False) + sdfg.add_array("t", shape=(10,), dtype=dace.float64, transient=True) + sdfg.add_array("b", shape=(10,), dtype=dace.float64, transient=False) + + state1 = sdfg.add_state(is_start_block=True) + t1 = state1.add_access("t") + state1.add_nedge(state1.add_access("g"), t1, dace.Memlet("g[0:10] -> [0:10]")) + g1 = state1.add_access("g") + state1.add_nedge(t1, g1, dace.Memlet("t[0:10] -> [0:10]")) + + # return sdfg + + state1.add_mapped_tasklet( + "comp", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("g[__i]")}, + code="__out = __in + 1.0", + outputs={"__out": dace.Memlet("b[__i]")}, + input_nodes={g1}, + external_edges=True, + ) + + sdfg.validate() + return sdfg + + +def _make_non_eligible_sdfg_with_branches(): + """Creates an SDFG with two different definitions of `T`.""" + sdfg = dace.SDFG(util.unique_name("non_eligible_sdfg_with_branches_sdfg")) + + # This is the `G` 
array, it is also used as output. + sdfg.add_array("a", shape=(10,), dtype=dace.float64, transient=False) + # This is the (possible) `T` array. + sdfg.add_array("t", shape=(10,), dtype=dace.float64, transient=True) + + # This is an additional array that serves as input. In one case it defines `t`. + sdfg.add_array("b", shape=(10,), dtype=dace.float64, transient=False) + # This is the condition on which we switch. + sdfg.add_scalar("cond", dtype=dace.bool, transient=False) + + # This is the init state. + state1 = sdfg.add_state(is_start_block=True) + + # This is the state where `T` is not defined in terms of `G`. + stateT = sdfg.add_state(is_start_block=False) + sdfg.add_edge(state1, stateT, dace.InterstateEdge(condition="cond == True")) + stateT.add_nedge( + stateT.add_access("b"), stateT.add_access("t"), dace.Memlet("b[0:10] -> [0:10]") + ) + + # This is the state where `T` is defined in terms of `G`. + stateF = sdfg.add_state(is_start_block=False) + sdfg.add_edge(state1, stateF, dace.InterstateEdge(condition="cond != True")) + stateF.add_nedge( + stateF.add_access("a"), stateF.add_access("t"), dace.Memlet("a[0:10] -> [0:10]") + ) + + # Now the write back state, where `T` is written back into `G`. 
+ stateWB = sdfg.add_state(is_start_block=False) + stateWB.add_nedge( + stateWB.add_access("t"), stateWB.add_access("a"), dace.Memlet("t[0:10] -> [0:10]") + ) + + sdfg.add_edge(stateF, stateWB, dace.InterstateEdge()) + sdfg.add_edge(stateT, stateWB, dace.InterstateEdge()) + + sdfg.validate() + return sdfg + + +def test_not_apply_because_of_write_to_g(): + sdfg = _make_not_apply_because_of_write_to_g_sdfg() + old_hash = sdfg.hash_sdfg() + nb_access_nodes_initially = util.count_nodes(sdfg, dace_nodes.AccessNode) + + res = apply_distributed_self_copy_elimination(sdfg) + nb_access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode) + + assert res is None + assert nb_access_nodes_initially == nb_access_nodes_after + assert old_hash == sdfg.hash_sdfg() + + +def test_eligible_sdfg_1(): + sdfg = _make_eligible_sdfg_1() + nb_access_nodes_initially = util.count_nodes(sdfg, dace_nodes.AccessNode) + + res = apply_distributed_self_copy_elimination(sdfg) + access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode, return_nodes=True) + + assert res == {"a", "t"} + assert nb_access_nodes_initially == len(access_nodes_after) + 3 + assert not any(an.data == "t" for an in access_nodes_after) + + +def test_multiple_temporaries(): + sdfg = _make_multiple_temporaries_sdfg1() + nb_access_nodes_initially = util.count_nodes(sdfg, dace_nodes.AccessNode) + + res = apply_distributed_self_copy_elimination(sdfg) + access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode, return_nodes=True) + + assert res == {"a", "t1", "t2"} + assert not any(an.data.startswith("t") for an in access_nodes_after) + assert nb_access_nodes_initially == len(access_nodes_after) + 5 + + +def test_multiple_temporaries_2(): + sdfg = _make_multiple_temporaries_sdfg2() + nb_access_nodes_initially = util.count_nodes(sdfg, dace_nodes.AccessNode) + + res = apply_distributed_self_copy_elimination(sdfg) + access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode, return_nodes=True) + + assert 
res == {"a", "t1", "t2"} + assert not any(an.data.startswith("t") for an in access_nodes_after) + assert nb_access_nodes_initially == len(access_nodes_after) + 6 + + +def test_multiple_temporaries_keep_one_1(): + sdfg = _make_multiple_temporaries_sdfg_keep_one_1() + nb_access_nodes_initially = util.count_nodes(sdfg, dace_nodes.AccessNode) + + # NOTE: The transformation will not only remove the `(t2) -> (a)` write in the + # second block, but also the `(t1) -> (a)` write, this is because it was + # concluded that this was a noops write. This might be a bit unintuitive + # considering that `t1` is used in the third state. However, this is why the + # `(a) -> (t1)` write in the first state is maintained. + res = apply_distributed_self_copy_elimination(sdfg) + access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode, return_nodes=True) + start_block_nodes = util.count_nodes(sdfg.start_block, dace_nodes.AccessNode, return_nodes=True) + + assert res == {"a", "t2"} + assert not any(an.data == "t2" for an in access_nodes_after) + assert sum(an.data == "t1" for an in access_nodes_after) == 2 + assert nb_access_nodes_initially == len(access_nodes_after) + 3 + assert len(start_block_nodes) == 2 + assert {nb.data for nb in start_block_nodes} == {"a", "t1"} + + +def test_multiple_temporaries_keep_one_2(): + sdfg = _make_multiple_temporaries_sdfg_keep_one_2() + nb_access_nodes_initially = util.count_nodes(sdfg, dace_nodes.AccessNode) + + res = apply_distributed_self_copy_elimination(sdfg) + access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode, return_nodes=True) + + assert res == {"a", "t2"} + assert not any(an.data == "t2" for an in access_nodes_after) + assert sum(an.data == "t1" for an in access_nodes_after) == 2 + assert nb_access_nodes_initially == len(access_nodes_after) + 4 + + +def test_pseudo_temporaries(): + sdfg = _make_non_eligible_because_of_pseudo_temporary() + old_hash = sdfg.hash_sdfg() + nb_access_nodes_initially = util.count_nodes(sdfg, 
dace_nodes.AccessNode) + + res = apply_distributed_self_copy_elimination(sdfg) + nb_access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode) + + assert res is None + assert nb_access_nodes_initially == nb_access_nodes_after + assert old_hash == sdfg.hash_sdfg() + + +def test_single_state(): + sdfg = _make_wb_single_state_sdfg() + old_hash = sdfg.hash_sdfg() + nb_access_nodes_initially = util.count_nodes(sdfg, dace_nodes.AccessNode) + + res = apply_distributed_self_copy_elimination(sdfg) + nb_access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode) + + assert res is None + assert nb_access_nodes_initially == nb_access_nodes_after + assert old_hash == sdfg.hash_sdfg() + + +def test_branches(): + sdfg = _make_non_eligible_sdfg_with_branches() + old_hash = sdfg.hash_sdfg() + nb_access_nodes_initially = util.count_nodes(sdfg, dace_nodes.AccessNode) + + res = apply_distributed_self_copy_elimination(sdfg) + nb_access_nodes_after = util.count_nodes(sdfg, dace_nodes.AccessNode) + + assert res is None + assert nb_access_nodes_initially == nb_access_nodes_after + assert old_hash == sdfg.hash_sdfg() diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py index 96584b8273..3bd0ed2dc3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py @@ -12,11 +12,11 @@ dace = pytest.importorskip("dace") from dace.sdfg import nodes as dace_nodes -from gt4py.next.program_processors.runners.dace_fieldview import ( +from gt4py.next.program_processors.runners.dace import ( transformations as gtx_transformations, ) -from . import pytestmark + from . 
import util @@ -68,7 +68,7 @@ def test_serial_map_promotion(): external_edges=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 assert len(map_entry_1d.map.params) == 1 assert len(map_entry_1d.map.range) == 1 assert len(map_entry_2d.map.params) == 2 @@ -83,7 +83,7 @@ def test_serial_map_promotion(): validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 assert len(map_entry_1d.map.params) == 2 assert len(map_entry_1d.map.range) == 2 assert len(map_entry_2d.map.params) == 2 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_single_state_global_self_copy_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_single_state_global_self_copy_elimination.py new file mode 100644 index 0000000000..2264102182 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_single_state_global_self_copy_elimination.py @@ -0,0 +1,146 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, +) + +from . 
import util + + +def _make_self_copy_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: + """Generates an SDFG that contains the self copying pattern.""" + sdfg = dace.SDFG(util.unique_name("self_copy_sdfg")) + state = sdfg.add_state(is_start_block=True) + + for name in "GT": + sdfg.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + transient=True, + ) + sdfg.arrays["G"].transient = False + g_read, tmp_node, g_write = (state.add_access(name) for name in "GTG") + + state.add_nedge(g_read, tmp_node, dace.Memlet("G[0:10, 0:10]")) + state.add_nedge(tmp_node, g_write, dace.Memlet("G[0:10, 0:10]")) + sdfg.validate() + + return sdfg, state + + +def test_global_self_copy_elimination_only_pattern(): + """Contains only the pattern -> Total elimination.""" + sdfg, state = _make_self_copy_sdfg() + assert sdfg.number_of_nodes() == 1 + assert state.number_of_nodes() == 3 + assert util.count_nodes(state, dace_nodes.AccessNode) == 3 + assert state.number_of_edges() == 2 + + count = sdfg.apply_transformations_repeated( + gtx_transformations.SingleStateGlobalSelfCopyElimination, validate=True, validate_all=True + ) + assert count != 0 + + assert sdfg.number_of_nodes() == 1 + assert ( + state.number_of_nodes() == 0 + ), f"Expected that 0 access nodes remained, but {state.number_of_nodes()} were there." + + +def test_global_self_copy_elimination_g_downstream(): + """`G` is read downstream. + + Since we ignore reads to `G` downstream, this will not influence the + transformation. + """ + sdfg, state1 = _make_self_copy_sdfg() + + # Add a read to `G` downstream. 
+ state2 = sdfg.add_state_after(state1) + sdfg.add_array( + "output", + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + + state2.add_mapped_tasklet( + "downstream_computation", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("G[__i0, __i1]")}, + code="__out = __in + 10.0", + outputs={"__out": dace.Memlet("output[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + assert state2.number_of_nodes() == 5 + + count = sdfg.apply_transformations_repeated( + gtx_transformations.SingleStateGlobalSelfCopyElimination, validate=True, validate_all=True + ) + assert count != 0 + + assert sdfg.number_of_nodes() == 2 + assert ( + state1.number_of_nodes() == 0 + ), f"Expected that 0 access nodes remained, but {state.number_of_nodes()} were there." + assert state2.number_of_nodes() == 5 + assert util.count_nodes(state2, dace_nodes.AccessNode) == 2 + assert util.count_nodes(state2, dace_nodes.MapEntry) == 1 + + +def test_global_self_copy_elimination_tmp_downstream(): + """`T` is read downstream. + + Because `T` is read downstream, the read to `G` will be retained, but the write + will be removed. + """ + sdfg, state1 = _make_self_copy_sdfg() + + # Add a read to `G` downstream. 
+ state2 = sdfg.add_state_after(state1) + sdfg.add_array( + "output", + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + + state2.add_mapped_tasklet( + "downstream_computation", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("T[__i0, __i1]")}, + code="__out = __in + 10.0", + outputs={"__out": dace.Memlet("output[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + assert state2.number_of_nodes() == 5 + + count = sdfg.apply_transformations_repeated( + gtx_transformations.SingleStateGlobalSelfCopyElimination, validate=True, validate_all=True + ) + assert count != 0 + + assert sdfg.number_of_nodes() == 2 + assert state1.number_of_nodes() == 2 + assert util.count_nodes(state1, dace_nodes.AccessNode) == 2 + assert all(state1.degree(node) == 1 for node in state1.nodes()) + assert next(iter(state1.source_nodes())).data == "G" + assert next(iter(state1.sink_nodes())).data == "T" + + assert state2.number_of_nodes() == 5 + assert util.count_nodes(state2, dace_nodes.AccessNode) == 2 + assert util.count_nodes(state2, dace_nodes.MapEntry) == 1 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py new file mode 100644 index 0000000000..c89fe566c0 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -0,0 +1,637 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import numpy as np +import copy + +dace = pytest.importorskip("dace") +from dace import symbolic as dace_symbolic +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _make_strides_propagation_level3_sdfg() -> dace.SDFG: + """Generates the level 3 SDFG (nested-nested) SDFG for `test_strides_propagation()`.""" + sdfg = dace.SDFG(util.unique_name("level3")) + state = sdfg.add_state(is_start_block=True) + names = ["a3", "c3"] + + for name in names: + stride_name = name + "_stride" + stride_sym = dace_symbolic.pystr_to_symbolic(stride_name) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + strides=(stride_sym,), + ) + + state.add_mapped_tasklet( + "compL3", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("a3[__i0]")}, + code="__out = __in1 + 10.", + outputs={"__out": dace.Memlet("c3[__i0]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_strides_propagation_level2_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + """Generates the level 2 SDFG (nested) SDFG for `test_strides_propagation()`. + + The function returns the level 2 SDFG and the NestedSDFG node that contains + the level 3 SDFG. 
+ """ + sdfg = dace.SDFG(util.unique_name("level2")) + state = sdfg.add_state(is_start_block=True) + names = ["a2", "a2_alias", "b2", "c2"] + + for name in names: + stride_name = name + "_stride" + stride_sym = dace_symbolic.pystr_to_symbolic(stride_name) + sdfg.add_symbol(stride_name, dace.int64) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + strides=(stride_sym,), + ) + + state.add_mapped_tasklet( + "compL2_1", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("a2[__i0]")}, + code="__out = __in1 + 10", + outputs={"__out": dace.Memlet("b2[__i0]")}, + external_edges=True, + ) + + state.add_mapped_tasklet( + "compL2_2", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("c2[__i0]")}, + code="__out = __in1", + outputs={"__out": dace.Memlet("a2_alias[__i0]")}, + external_edges=True, + ) + + # This is the nested SDFG we have here. + sdfg_level3 = _make_strides_propagation_level3_sdfg() + + nsdfg = state.add_nested_sdfg( + sdfg=sdfg_level3, + parent=sdfg, + inputs={"a3"}, + outputs={"c3"}, + symbol_mapping={s3: s3 for s3 in sdfg_level3.free_symbols}, + ) + + state.add_edge(state.add_access("a2"), None, nsdfg, "a3", dace.Memlet("a2[0:10]")) + state.add_edge(nsdfg, "c3", state.add_access("c2"), None, dace.Memlet("c2[0:10]")) + sdfg.validate() + + return sdfg, nsdfg + + +def _make_strides_propagation_level1_sdfg() -> ( + tuple[dace.SDFG, dace_nodes.NestedSDFG, dace_nodes.NestedSDFG] +): + """Generates the level 1 SDFG (top) SDFG for `test_strides_propagation()`. + + Note that the SDFG is valid, but will be indeterminate. The only point of + this SDFG is to have a lot of different situations that have to be handled + for renaming. + + Returns: + A tuple of length three, with the following members: + - The top level SDFG. + - The NestedSDFG node that contains the level 2 SDFG (member of the top level SDFG). + - The NestedSDFG node that contains the lebel 3 SDFG (member of the level 2 SDFG). 
+ """ + + sdfg = dace.SDFG(util.unique_name("level1")) + state = sdfg.add_state(is_start_block=True) + names = ["a1", "b1", "c1"] + + for name in names: + stride_name = name + "_stride" + stride_sym = dace_symbolic.pystr_to_symbolic(stride_name) + sdfg.add_symbol(stride_name, dace.int64) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + strides=(stride_sym,), + ) + + sdfg_level2, nsdfg_level3 = _make_strides_propagation_level2_sdfg() + + nsdfg_level2: dace_nodes.NestedSDFG = state.add_nested_sdfg( + sdfg=sdfg_level2, + parent=sdfg, + inputs={"a2", "c2"}, + outputs={"a2_alias", "b2", "c2"}, + symbol_mapping={s: s for s in sdfg_level2.free_symbols}, + ) + + for inner_name in nsdfg_level2.in_connectors: + outer_name = inner_name[0] + "1" + state.add_edge( + state.add_access(outer_name), + None, + nsdfg_level2, + inner_name, + dace.Memlet(f"{outer_name}[0:10]"), + ) + for inner_name in nsdfg_level2.out_connectors: + outer_name = inner_name[0] + "1" + state.add_edge( + nsdfg_level2, + inner_name, + state.add_access(outer_name), + None, + dace.Memlet(f"{outer_name}[0:10]"), + ) + + sdfg.validate() + + return sdfg, nsdfg_level2, nsdfg_level3 + + +def test_strides_propagation_use_symbol_mapping(): + # Note that the SDFG we are building here is not really meaningful. + sdfg_level1, nsdfg_level2, nsdfg_level3 = _make_strides_propagation_level1_sdfg() + + # Tests if all strides are distinct in the beginning and match what we expect. + for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: + for aname, adesc in sdfg.arrays.items(): + exp_stride = f"{aname}_stride" + actual_stride = adesc.strides[0] + assert len(adesc.strides) == 1 + assert ( + str(actual_stride) == exp_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." 
+ + nsdfg = sdfg.parent_nsdfg_node + if nsdfg is not None: + assert exp_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[exp_stride]) == exp_stride + + # Now we propagate `a` and `b`, but not `c`. + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=False) + sdfg_level1.validate() + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=False) + sdfg_level1.validate() + + # Because `ignore_symbol_mapping=False` the strides of the data descriptor should + # not have changed. But the `symbol_mapping` has been updated for `a` and `b`. + # However, the symbols will only point one level above. + for level, sdfg in enumerate([sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg], start=1): + for aname, adesc in sdfg.arrays.items(): + nsdfg = sdfg.parent_nsdfg_node + original_stride = f"{aname}_stride" + + if aname.startswith("c"): + target_symbol = f"{aname}_stride" + else: + target_symbol = f"{aname[0]}{level - 1}_stride" + + if nsdfg is not None: + assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == target_symbol + assert len(adesc.strides) == 1 + assert ( + str(adesc.strides[0]) == original_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." 
+ + # Now we also propagate `c` thus now all data descriptors have the same stride + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=False) + sdfg_level1.validate() + for level, sdfg in enumerate([sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg], start=1): + for aname, adesc in sdfg.arrays.items(): + nsdfg = sdfg.parent_nsdfg_node + original_stride = f"{aname}_stride" + target_symbol = f"{aname[0]}{level-1}_stride" + if nsdfg is not None: + assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == target_symbol + assert len(adesc.strides) == 1 + assert ( + str(adesc.strides[0]) == original_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + +def test_strides_propagation_ignore_symbol_mapping(): + # Note that the SDFG we are building here is not really meaningful. + sdfg_level1, nsdfg_level2, nsdfg_level3 = _make_strides_propagation_level1_sdfg() + + # Tests if all strides are distinct in the beginning and match what we expect. + for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: + for aname, adesc in sdfg.arrays.items(): + exp_stride = f"{aname}_stride" + actual_stride = adesc.strides[0] + assert len(adesc.strides) == 1 + assert ( + str(actual_stride) == exp_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + nsdfg = sdfg.parent_nsdfg_node + if nsdfg is not None: + assert exp_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[exp_stride]) == exp_stride + + # Now we propagate `a` and `b`, but not `c`. + # TODO(phimuell): Create a version where we can set `ignore_symbol_mapping=False`. 
+ gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True)
+ sdfg_level1.validate()
+ gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=True)
+ sdfg_level1.validate()
+
+ # After the propagation `a` and `b` should use the same stride (the one that
+ # it has on level 1), but `c` should still be level dependent.
+ for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]:
+ for aname, adesc in sdfg.arrays.items():
+ original_stride = f"{aname}_stride"
+ if aname.startswith("c"):
+ exp_stride = f"{aname}_stride"
+ else:
+ exp_stride = f"{aname[0]}1_stride"
+ assert len(adesc.strides) == 1
+ assert (
+ str(adesc.strides[0]) == exp_stride
+ ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'."
+
+ nsdfg = sdfg.parent_nsdfg_node
+ if nsdfg is not None:
+ assert original_stride in nsdfg.symbol_mapping
+ assert str(nsdfg.symbol_mapping[original_stride]) == original_stride
+
+ # Now we also propagate `c` thus now all data descriptors have the same stride
+ gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=True)
+ sdfg_level1.validate()
+ for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]:
+ for aname, adesc in sdfg.arrays.items():
+ exp_stride = f"{aname[0]}1_stride"
+ original_stride = f"{aname}_stride"
+ assert len(adesc.strides) == 1
+ assert (
+ str(adesc.strides[0]) == exp_stride
+ ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'."
+
+ nsdfg = sdfg.parent_nsdfg_node
+ if nsdfg is not None:
+ # The symbol mapping must not be updated.
+ assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == original_stride + + +def _make_strides_propagation_dependent_symbol_nsdfg() -> dace.SDFG: + sdfg = dace.SDFG(util.unique_name("strides_propagation_dependent_symbol_nsdfg")) + state = sdfg.add_state(is_start_block=True) + + array_names = ["a2", "b2"] + for name in array_names: + stride_sym = dace.symbol(f"{name}_stride", dtype=dace.uint64) + sdfg.add_symbol(stride_sym.name, stride_sym.dtype) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + strides=(stride_sym,), + transient=False, + ) + + state.add_mapped_tasklet( + "nested_comp", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("a2[__i0]")}, + code="__out = __in1 + 10.", + outputs={"__out": dace.Memlet("b2[__i0]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_strides_propagation_dependent_symbol_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_dependent_symbol_sdfg")) + state = sdfg_level1.add_state(is_start_block=True) + + array_names = ["a1", "b1"] + for name in array_names: + stride_sym1 = dace.symbol(f"{name}_1stride", dtype=dace.uint64) + stride_sym2 = dace.symbol(f"{name}_2stride", dtype=dace.int64) + sdfg_level1.add_symbol(stride_sym1.name, stride_sym1.dtype) + sdfg_level1.add_symbol(stride_sym2.name, stride_sym2.dtype) + stride_sym = stride_sym1 * stride_sym2 + sdfg_level1.add_array( + name, + shape=(10,), + dtype=dace.float64, + strides=(stride_sym,), + transient=False, + ) + + sdfg_level2 = _make_strides_propagation_dependent_symbol_nsdfg() + + for sym, sym_dtype in sdfg_level2.symbols.items(): + sdfg_level1.add_symbol(sym, sym_dtype) + + nsdfg = state.add_nested_sdfg( + sdfg=sdfg_level2, + parent=sdfg_level1, + inputs={"a2"}, + outputs={"b2"}, + symbol_mapping={s: s for s in sdfg_level2.symbols}, + ) + + state.add_edge(state.add_access("a1"), None, nsdfg, "a2", 
dace.Memlet("a1[0:10]")) + state.add_edge(nsdfg, "b2", state.add_access("b1"), None, dace.Memlet("b1[0:10]")) + sdfg_level1.validate() + + return sdfg_level1, nsdfg + + +def test_strides_propagation_dependent_symbol(): + sdfg_level1, nsdfg_level2 = _make_strides_propagation_dependent_symbol_sdfg() + sym1_dtype = dace.uint64 + sym2_dtype = dace.int64 + + # Ensure that the special symbols are not already present inside the nested SDFG. + for aname, adesc in sdfg_level1.arrays.items(): + sym1 = f"{aname}_1stride" + sym2 = f"{aname}_2stride" + for sym, dtype in [(sym1, sym1_dtype), (sym2, sym2_dtype)]: + assert sym in {fs.name for fs in adesc.strides[0].free_symbols} + assert sym not in nsdfg_level2.symbol_mapping + assert sym not in nsdfg_level2.sdfg.symbols + assert sym in sdfg_level1.symbols + assert sdfg_level1.symbols[sym] == dtype + + # Now propagate `a1` and `b1`. + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True) + sdfg_level1.validate() + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=True) + sdfg_level1.validate() + + # Now we check if the update has worked. + for aname, adesc in sdfg_level1.arrays.items(): + sym1 = f"{aname}_1stride" + sym2 = f"{aname}_2stride" + adesc2 = nsdfg_level2.sdfg.arrays[aname.replace("1", "2")] + assert adesc2.strides == adesc.strides + + for sym, dtype in [(sym1, sym1_dtype), (sym2, sym2_dtype)]: + assert sym in nsdfg_level2.symbol_mapping + assert nsdfg_level2.symbol_mapping[sym].name == sym + assert sym in sdfg_level1.symbols + assert sdfg_level1.symbols[sym] == dtype + assert sym in nsdfg_level2.sdfg.symbols + assert nsdfg_level2.sdfg.symbols[sym] == dtype + + +def _make_strides_propagation_shared_symbols_nsdfg() -> dace.SDFG: + sdfg = dace.SDFG(util.unique_name("strides_propagation_shared_symbols_nsdfg")) + state = sdfg.add_state(is_start_block=True) + + # NOTE: Both arrays have the same symbols used for strides. 
+ array_names = ["a2", "b2"] + stride_sym0 = dace.symbol(f"__stride_0", dtype=dace.uint64) + stride_sym1 = dace.symbol(f"__stride_1", dtype=dace.uint64) + sdfg.add_symbol(stride_sym0.name, stride_sym0.dtype) + sdfg.add_symbol(stride_sym1.name, stride_sym1.dtype) + for name in array_names: + sdfg.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + strides=(stride_sym0, stride_sym1), + transient=False, + ) + + state.add_mapped_tasklet( + "nested_comp", + map_ranges={ + "__i0": "0:10", + "__i1": "0:10", + }, + inputs={"__in1": dace.Memlet("a2[__i0, __i1]")}, + code="__out = __in1 + 10.", + outputs={"__out": dace.Memlet("b2[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_strides_propagation_shared_symbols_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_shared_symbols_sdfg")) + state = sdfg_level1.add_state(is_start_block=True) + + # NOTE: Both arrays use the same symbols as strides. + # Furthermore, they are the same as in the nested SDFG, i.e. they are shared. 
+ array_names = ["a1", "b1"]
+ stride_sym0 = dace.symbol(f"__stride_0", dtype=dace.uint64)
+ stride_sym1 = dace.symbol(f"__stride_1", dtype=dace.uint64)
+ sdfg_level1.add_symbol(stride_sym0.name, stride_sym0.dtype)
+ sdfg_level1.add_symbol(stride_sym1.name, stride_sym1.dtype)
+ for name in array_names:
+ sdfg_level1.add_array(
+ name,
+ shape=(10, 10),
+ dtype=dace.float64,
+ strides=(
+ stride_sym0,
+ stride_sym1,
+ ),
+ transient=False,
+ )
+
+ sdfg_level2 = _make_strides_propagation_shared_symbols_nsdfg()
+ nsdfg = state.add_nested_sdfg(
+ sdfg=sdfg_level2,
+ parent=sdfg_level1,
+ inputs={"a2"},
+ outputs={"b2"},
+ symbol_mapping={s: s for s in sdfg_level2.symbols},
+ )
+
+ state.add_edge(state.add_access("a1"), None, nsdfg, "a2", dace.Memlet("a1[0:10, 0:10]"))
+ state.add_edge(nsdfg, "b2", state.add_access("b1"), None, dace.Memlet("b1[0:10, 0:10]"))
+ sdfg_level1.validate()
+
+ return sdfg_level1, nsdfg
+
+
+def test_strides_propagation_shared_symbols_sdfg():
+ """Tests what happens if symbols are (unintentionally) shared between descriptors.
+
+ This test looks rather artificial, but it is actually quite likely. Because
+ transients will most likely have the same shape and if the strides are not
+ set explicitly, which is the case, the strides will also be related to their
+ shape. This test explores the situation, where we can, for whatever reason,
+ only propagate the strides of one such data descriptor.
+
+ Note:
+ If `ignore_symbol_mapping` is `False` then this test will fail.
+ This is because the `symbol_mapping` of the NestedSDFG will act on the
+ whole SDFG. Thus it will not only change the strides of `b` but as an
+ unintended side effect also the strides of `a`.
+ """ + + def ref(a1, b1): + for i in range(10): + for j in range(10): + b1[i, j] = a1[i, j] + 10.0 + + sdfg_level1, nsdfg_level2 = _make_strides_propagation_shared_symbols_sdfg() + + res_args = { + "a1": np.array(np.random.rand(10, 10), order="C", dtype=np.float64, copy=True), + "b1": np.array(np.random.rand(10, 10), order="F", dtype=np.float64, copy=True), + } + ref_args = copy.deepcopy(res_args) + + # Now we change the strides of `b1`, and then we propagate the new strides + # into the nested SDFG. We want to keep (for whatever reasons) strides of `a1`. + stride_b1_sym0 = dace.symbol(f"__b1_stride_0", dtype=dace.uint64) + stride_b1_sym1 = dace.symbol(f"__b1_stride_1", dtype=dace.uint64) + sdfg_level1.add_symbol(stride_b1_sym0.name, stride_b1_sym0.dtype) + sdfg_level1.add_symbol(stride_b1_sym1.name, stride_b1_sym1.dtype) + + desc_b1 = sdfg_level1.arrays["b1"] + desc_b1.set_shape((10, 10), (stride_b1_sym0, stride_b1_sym1)) + + # Now we propagate the data into it. + gtx_transformations.gt_propagate_strides_of( + sdfg=sdfg_level1, + data_name="b1", + ) + + # Now we have to prepare the call arguments, i.e. 
adding the strides + itemsize = res_args["b1"].itemsize + res_args.update( + { + "__b1_stride_0": res_args["b1"].strides[0] // itemsize, + "__b1_stride_1": res_args["b1"].strides[1] // itemsize, + "__stride_0": res_args["a1"].strides[0] // itemsize, + "__stride_1": res_args["a1"].strides[1] // itemsize, + } + ) + ref(**ref_args) + sdfg_level1(**res_args) + assert np.allclose(ref_args["b1"], res_args["b1"]) + + +def _make_strides_propagation_stride_1_nsdfg() -> dace.SDFG: + sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_stride_1_nsdfg")) + state = sdfg_level1.add_state(is_start_block=True) + + a_stride_sym = dace.symbol("a_stride", dtype=dace.uint64) + b_stride_sym = dace.symbol("b_stride", dtype=dace.uint64) + stride_syms = {"a": a_stride_sym, "b": b_stride_sym} + + for name in ["a", "b"]: + sdfg_level1.add_array( + name, + shape=(10, 1), + strides=(stride_syms[name], 1), + dtype=dace.float64, + transient=False, + ) + + state.add_mapped_tasklet( + "computation", + map_ranges={"__i": "0:10"}, + inputs={"__in": dace.Memlet("a[__i, 0]")}, + code="__out = __in + 10", + outputs={"__out": dace.Memlet("b[__i, 0]")}, + external_edges=True, + ) + sdfg_level1.validate() + return sdfg_level1 + + +def _make_strides_propagation_stride_1_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + sdfg = dace.SDFG(util.unique_name("strides_propagation_stride_1_sdfg")) + state = sdfg.add_state(is_start_block=True) + + a_stride_sym = dace.symbol("a_stride", dtype=dace.uint64) + b_stride_sym = dace.symbol("b_stride", dtype=dace.uint64) + stride_syms = {"a": a_stride_sym, "b": b_stride_sym} + + for name in ["a", "b"]: + sdfg.add_array( + name, + shape=(10, 10), + strides=(stride_syms[name], 1), + dtype=dace.float64, + transient=False, + ) + + # Now get the nested SDFG. 
+ sdfg_level1 = _make_strides_propagation_stride_1_nsdfg()
+
+ nsdfg = state.add_nested_sdfg(
+ parent=sdfg,
+ sdfg=sdfg_level1,
+ inputs={"a"},
+ outputs={"b"},
+ symbol_mapping=None,
+ )
+
+ state.add_edge(state.add_access("a"), None, nsdfg, "a", dace.Memlet("a[0:10, 3]"))
+ state.add_edge(nsdfg, "b", state.add_access("b"), None, dace.Memlet("b[0:10, 2]"))
+ sdfg.validate()
+ return sdfg, nsdfg
+
+
+def test_strides_propagation_stride_1():
+ def ref(a, b):
+ for i in range(10):
+ b[i, 2] = a[i, 3] + 10.0
+
+ sdfg, nsdfg = _make_strides_propagation_stride_1_sdfg()
+
+ outer_desc_a = sdfg.arrays["a"]
+ inner_desc_a = nsdfg.sdfg.arrays["a"]
+ assert outer_desc_a.strides == inner_desc_a.strides
+
+ # Now switch the strides of `a` on the top level.
+ # Essentially going from `C` to FORTRAN order.
+ stride_outer_a_0, stride_outer_a_1 = outer_desc_a.strides
+ outer_desc_a.set_shape(outer_desc_a.shape, (stride_outer_a_1, stride_outer_a_0))
+
+ # Now we propagate the data into it.
+ gtx_transformations.gt_propagate_strides_of(sdfg=sdfg, data_name="a")
+
+ # Because of the propagation it must now have been changed to `(1, 1)` on the inside.
+ assert inner_desc_a.strides == (1, 1) + + res_args = { + "a": np.array(np.random.rand(10, 10), order="F", dtype=np.float64, copy=True), + "b": np.array(np.random.rand(10, 10), order="C", dtype=np.float64, copy=True), + } + ref_args = copy.deepcopy(res_args) + + sdfg(**res_args, a_stride=10, b_stride=10) + ref(**ref_args) + assert np.allclose(ref_args["b"], res_args["b"]) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py index ac88f4fef8..b82cecee98 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py @@ -14,7 +14,7 @@ @overload -def _count_nodes( +def count_nodes( graph: Union[dace.SDFG, dace.SDFGState], node_type: tuple[type, ...] | type, return_nodes: Literal[False], @@ -22,14 +22,14 @@ def _count_nodes( @overload -def _count_nodes( +def count_nodes( graph: Union[dace.SDFG, dace.SDFGState], node_type: tuple[type, ...] | type, return_nodes: Literal[True], ) -> list[dace_nodes.Node]: ... -def _count_nodes( +def count_nodes( graph: Union[dace.SDFG, dace.SDFGState], node_type: tuple[type, ...] 
| type, return_nodes: bool = False, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py index ab86dda16b..3d82dd8ee5 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/test_gtfn.py @@ -50,20 +50,6 @@ def test_backend_factory_trait_cached(): assert cached_version.name == "run_gtfn_cpu_cached" -def test_backend_factory_trait_temporaries(): - inline_version = gtfn.GTFNBackendFactory(cached=False) - temps_version = gtfn.GTFNBackendFactory(cached=False, use_temporaries=True) - - assert inline_version.executor.translation.lift_mode is None - assert temps_version.executor.translation.lift_mode is transforms.LiftMode.USE_TEMPORARIES - - assert inline_version.executor.translation.temporary_extraction_heuristics is None - assert ( - temps_version.executor.translation.temporary_extraction_heuristics - is global_tmps.SimpleTemporaryExtractionHeuristics - ) - - def test_backend_factory_build_cache_config(monkeypatch): monkeypatch.setattr(config, "BUILD_CACHE_LIFETIME", config.BuildCacheLifetime.SESSION) session_version = gtfn.GTFNBackendFactory() diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 8f46fc7ce1..09ca44aaac 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -10,7 +10,9 @@ from typing import Optional, Pattern import pytest +import re +from gt4py import next as gtx import gt4py.next.common as common from gt4py.next.common import ( Dimension, @@ -25,7 +27,11 @@ unit_range, ) - +C2E = Dimension("C2E", kind=DimensionKind.LOCAL) +V2E = Dimension("V2E", kind=DimensionKind.LOCAL) +E2V = Dimension("E2V", kind=DimensionKind.LOCAL) +E2C = Dimension("E2C", kind=DimensionKind.LOCAL) +E2C2V = Dimension("E2C2V", kind=DimensionKind.LOCAL) 
ECDim = Dimension("ECDim") IDim = Dimension("IDim") JDim = Dimension("JDim") @@ -324,16 +330,6 @@ def test_domain_intersection_different_dimensions(a_domain, second_domain, expec assert result_domain == expected -def test_domain_intersection_reversed_dimensions(a_domain): - domain2 = Domain(dims=(JDim, IDim), ranges=(UnitRange(2, 12), UnitRange(7, 17))) - - with pytest.raises( - ValueError, - match="Dimensions can not be promoted. The following dimensions appear in contradicting order: IDim, JDim.", - ): - a_domain & domain2 - - @pytest.mark.parametrize( "index, expected", [ @@ -571,27 +567,29 @@ def dimension_promotion_cases() -> ( ): raw_list = [ # list of list of dimensions, expected result, expected error message - ([["I", "J"], ["I"]], ["I", "J"], None), - ([["I", "J"], ["J"]], ["I", "J"], None), - ([["I", "J"], ["J", "K"]], ["I", "J", "K"], None), + ([[IDim, JDim], [IDim]], [IDim, JDim], None), + ([[JDim], [IDim, JDim]], [IDim, JDim], None), + ([[JDim, KDim], [IDim, JDim]], [IDim, JDim, KDim], None), ( - [["I", "J"], ["J", "I"]], + [[IDim, JDim], [JDim, IDim]], None, - r"The following dimensions appear in contradicting order: I, J.", + "Dimensions 'JDim[horizontal], IDim[horizontal]' are not ordered correctly, expected 'IDim[horizontal], JDim[horizontal]'.", ), + ([[JDim, KDim], [IDim, KDim]], [IDim, JDim, KDim], None), ( - [["I", "K"], ["J", "K"]], + [[KDim, JDim], [IDim, KDim]], None, - r"Could not determine order of the following dimensions: I, J", + "Dimensions 'KDim[vertical], JDim[horizontal]' are not ordered correctly, expected 'JDim[horizontal], KDim[vertical]'.", ), + ( + [[JDim, V2E], [IDim, KDim, E2C2V]], + None, + "There are more than one dimension with DimensionKind 'LOCAL'.", + ), + ([[JDim, V2E], [IDim, KDim]], [IDim, JDim, KDim, V2E], None), ] - # transform dimension names into Dimension objects return [ - ( - [[Dimension(el) for el in arg] for arg in args], - [Dimension(el) for el in result] if result else result, - msg, - ) + ([[el for el in 
arg] for arg in args], [el for el in result] if result else result, msg) for args, result, msg in raw_list ] @@ -608,7 +606,7 @@ def test_dimension_promotion( with pytest.raises(Exception) as exc_info: promote_dims(*dim_list) - assert exc_info.match(expected_error_msg) + assert exc_info.match(re.escape(expected_error_msg)) class TestCartesianConnectivity: diff --git a/tests/next_tests/unit_tests/test_constructors.py b/tests/next_tests/unit_tests/test_constructors.py index 6e9dfa3d64..0998ab8eab 100644 --- a/tests/next_tests/unit_tests/test_constructors.py +++ b/tests/next_tests/unit_tests/test_constructors.py @@ -11,10 +11,7 @@ from gt4py import next as gtx from gt4py._core import definitions as core_defs -from gt4py.next import allocators as next_allocators, common, float32 -from gt4py.next.program_processors.runners import roundtrip - -from next_tests.integration_tests import cases +from gt4py.next import allocators as next_allocators, common I = gtx.Dimension("I") @@ -154,3 +151,12 @@ def test_field_wrong_origin(): @pytest.mark.xfail(reason="aligned_index not supported yet") def test_aligned_index(): gtx.as_field([I], np.random.rand(sizes[I]).astype(gtx.float32), aligned_index=[I, 0]) + + +@pytest.mark.parametrize( + "data, skip_value", + [([0, 1, 2], None), ([0, 1, common._DEFAULT_SKIP_VALUE], common._DEFAULT_SKIP_VALUE)], +) +def test_as_connectivity(nd_array_implementation, data, skip_value): + testee = gtx.as_connectivity([I], J, nd_array_implementation.array(data)) + assert testee.skip_value is skip_value diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py b/tests/next_tests/unit_tests/test_type_system.py similarity index 50% rename from tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py rename to tests/next_tests/unit_tests/test_type_system.py index 5352724827..69ff54b711 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_type_deduction.py +++ 
b/tests/next_tests/unit_tests/test_type_system.py @@ -11,28 +11,13 @@ import pytest -import gt4py.next.ffront.type_specifications from gt4py.next import ( Dimension, DimensionKind, - Field, - FieldOffset, - astype, - broadcast, - common, - errors, - float32, - float64, - int32, - int64, - neighbor_sum, - where, ) -from gt4py.next.ffront.ast_passes import single_static_assign as ssa -from gt4py.next.ffront.experimental import as_offset -from gt4py.next.ffront.func_to_foast import FieldOperatorParser from gt4py.next.type_system import type_info, type_specifications as ts - +from gt4py.next.ffront import type_specifications as ts_ffront +from gt4py.next.iterator.type_system import type_specifications as ts_it TDim = Dimension("TDim") # Meaningless dimension, used for tests. @@ -107,7 +92,7 @@ def callable_type_info_cases(): unary_tuple_arg_func_type = ts.FunctionType( pos_only_args=[tuple_type], pos_or_kw_args={}, kw_only_args={}, returns=ts.VoidType() ) - fieldop_type = gt4py.next.ffront.type_specifications.FieldOperatorType( + fieldop_type = ts_ffront.FieldOperatorType( definition=ts.FunctionType( pos_only_args=[field_type, float_type], pos_or_kw_args={}, @@ -115,7 +100,7 @@ def callable_type_info_cases(): returns=field_type, ) ) - scanop_type = gt4py.next.ffront.type_specifications.ScanOperatorType( + scanop_type = ts_ffront.ScanOperatorType( axis=KDim, definition=ts.FunctionType( pos_only_args=[], @@ -124,7 +109,7 @@ def callable_type_info_cases(): returns=float_type, ), ) - tuple_scanop_type = gt4py.next.ffront.type_specifications.ScanOperatorType( + tuple_scanop_type = ts_ffront.ScanOperatorType( axis=KDim, definition=ts.FunctionType( pos_only_args=[], @@ -320,10 +305,7 @@ def callable_type_info_cases(): ts.FieldType(dims=[KDim], dtype=int_type), ], {}, - [ - r"Dimensions can not be promoted. Could not determine order of the " - r"following dimensions: J, K." 
- ], + [], ts.FieldType(dims=[IDim, JDim, KDim], dtype=float_type), ), ( @@ -367,6 +349,22 @@ def callable_type_info_cases(): ], ts.FieldType(dims=[IDim, JDim, KDim], dtype=float_type), ), + ( + ts.FunctionType( + pos_only_args=[ + ts_it.IteratorType( + position_dims="unknown", defined_dims=[], element_type=float_type + ), + ], + pos_or_kw_args={}, + kw_only_args={}, + returns=ts.VoidType(), + ), + [ts_it.IteratorType(position_dims=[IDim], defined_dims=[], element_type=float_type)], + {}, + [], + ts.VoidType(), + ), ] @@ -408,381 +406,3 @@ def test_return_type( accepts_args = type_info.accepts_args(func_type, with_args=args, with_kwargs=kwargs) if accepts_args: assert type_info.return_type(func_type, with_args=args, with_kwargs=kwargs) == return_type - - -def test_unpack_assign(): - def unpack_explicit_tuple( - a: Field[[TDim], float64], b: Field[[TDim], float64] - ) -> tuple[Field[[TDim], float64], Field[[TDim], float64]]: - tmp_a, tmp_b = (a, b) - return tmp_a, tmp_b - - parsed = FieldOperatorParser.apply_to_function(unpack_explicit_tuple) - - assert parsed.body.annex.symtable[ssa.unique_name("tmp_a", 0)].type == ts.FieldType( - dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None) - ) - assert parsed.body.annex.symtable[ssa.unique_name("tmp_b", 0)].type == ts.FieldType( - dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None) - ) - - -def test_assign_tuple(): - def temp_tuple(a: Field[[TDim], float64], b: Field[[TDim], int64]): - tmp = a, b - return tmp - - parsed = FieldOperatorParser.apply_to_function(temp_tuple) - - assert parsed.body.annex.symtable[ssa.unique_name("tmp", 0)].type == ts.TupleType( - types=[ - ts.FieldType(dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64, shape=None)), - ts.FieldType(dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64, shape=None)), - ] - ) - - -def test_adding_bool(): - """Expect an error when using arithmetic on bools.""" - - def add_bools(a: Field[[TDim], bool], b: 
Field[[TDim], bool]): - return a + b - - with pytest.raises( - errors.DSLError, match=(r"Type 'Field\[\[TDim\], bool\]' can not be used in operator '\+'.") - ): - _ = FieldOperatorParser.apply_to_function(add_bools) - - -def test_binop_nonmatching_dims(): - """Binary operations can only work when both fields have the same dimensions.""" - X = Dimension("X") - Y = Dimension("Y") - - def nonmatching(a: Field[[X], float64], b: Field[[Y], float64]): - return a + b - - with pytest.raises( - errors.DSLError, - match=( - r"Could not promote 'Field\[\[X], float64\]' and 'Field\[\[Y\], float64\]' to common type in call to +." - ), - ): - _ = FieldOperatorParser.apply_to_function(nonmatching) - - -def test_bitopping_float(): - def float_bitop(a: Field[[TDim], float], b: Field[[TDim], float]): - return a & b - - with pytest.raises( - errors.DSLError, - match=(r"Type 'Field\[\[TDim\], float64\]' can not be used in operator '\&'."), - ): - _ = FieldOperatorParser.apply_to_function(float_bitop) - - -def test_signing_bool(): - def sign_bool(a: Field[[TDim], bool]): - return -a - - with pytest.raises( - errors.DSLError, - match=r"Incompatible type for unary operator '\-': 'Field\[\[TDim\], bool\]'.", - ): - _ = FieldOperatorParser.apply_to_function(sign_bool) - - -def test_notting_int(): - def not_int(a: Field[[TDim], int64]): - return not a - - with pytest.raises( - errors.DSLError, - match=r"Incompatible type for unary operator 'not': 'Field\[\[TDim\], int64\]'.", - ): - _ = FieldOperatorParser.apply_to_function(not_int) - - -@pytest.fixture -def premap_setup(): - X = Dimension("X") - Y = Dimension("Y") - Y2XDim = Dimension("Y2X", kind=DimensionKind.LOCAL) - Y2X = FieldOffset("Y2X", source=X, target=(Y, Y2XDim)) - return X, Y, Y2XDim, Y2X - - -def test_premap(premap_setup): - X, Y, Y2XDim, Y2X = premap_setup - - def premap_fo(bar: Field[[X], int64]) -> Field[[Y], int64]: - return bar(Y2X[0]) - - parsed = FieldOperatorParser.apply_to_function(premap_fo) - - assert 
parsed.body.stmts[0].value.type == ts.FieldType( - dims=[Y], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64) - ) - - -def test_premap_nbfield(premap_setup): - X, Y, Y2XDim, Y2X = premap_setup - - def premap_fo(bar: Field[[X], int64]) -> Field[[Y, Y2XDim], int64]: - return bar(Y2X) - - parsed = FieldOperatorParser.apply_to_function(premap_fo) - - assert parsed.body.stmts[0].value.type == ts.FieldType( - dims=[Y, Y2XDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64) - ) - - -def test_premap_reduce(premap_setup): - X, Y, Y2XDim, Y2X = premap_setup - - def premap_fo(bar: Field[[X], int32]) -> Field[[Y], int32]: - return 2 * neighbor_sum(bar(Y2X), axis=Y2XDim) - - parsed = FieldOperatorParser.apply_to_function(premap_fo) - - assert parsed.body.stmts[0].value.type == ts.FieldType( - dims=[Y], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32) - ) - - -def test_premap_reduce_sparse(premap_setup): - X, Y, Y2XDim, Y2X = premap_setup - - def premap_fo(bar: Field[[Y, Y2XDim], int32]) -> Field[[Y], int32]: - return 5 * neighbor_sum(bar, axis=Y2XDim) - - parsed = FieldOperatorParser.apply_to_function(premap_fo) - - assert parsed.body.stmts[0].value.type == ts.FieldType( - dims=[Y], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32) - ) - - -def test_mismatched_literals(): - def mismatched_lit() -> Field[[TDim], "float32"]: - return float32("1.0") + float64("1.0") - - with pytest.raises( - errors.DSLError, - match=(r"Could not promote 'float32' and 'float64' to common type in call to +."), - ): - _ = FieldOperatorParser.apply_to_function(mismatched_lit) - - -def test_broadcast_multi_dim(): - ADim = Dimension("ADim") - BDim = Dimension("BDim") - CDim = Dimension("CDim") - - def simple_broadcast(a: Field[[ADim], float64]): - return broadcast(a, (ADim, BDim, CDim)) - - parsed = FieldOperatorParser.apply_to_function(simple_broadcast) - - assert parsed.body.stmts[0].value.type == ts.FieldType( - dims=[ADim, BDim, CDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - ) - - -def 
test_broadcast_disjoint(): - ADim = Dimension("ADim") - BDim = Dimension("BDim") - CDim = Dimension("CDim") - - def disjoint_broadcast(a: Field[[ADim], float64]): - return broadcast(a, (BDim, CDim)) - - with pytest.raises(errors.DSLError, match=r"expected broadcast dimension\(s\) \'.*\' missing"): - _ = FieldOperatorParser.apply_to_function(disjoint_broadcast) - - -def test_broadcast_badtype(): - ADim = Dimension("ADim") - BDim = "BDim" - CDim = Dimension("CDim") - - def badtype_broadcast(a: Field[[ADim], float64]): - return broadcast(a, (BDim, CDim)) - - with pytest.raises( - errors.DSLError, match=r"expected all broadcast dimensions to be of type 'Dimension'." - ): - _ = FieldOperatorParser.apply_to_function(badtype_broadcast) - - -def test_where_dim(): - ADim = Dimension("ADim") - BDim = Dimension("BDim") - - def simple_where(a: Field[[ADim], bool], b: Field[[ADim, BDim], float64]): - return where(a, b, 9.0) - - parsed = FieldOperatorParser.apply_to_function(simple_where) - - assert parsed.body.stmts[0].value.type == ts.FieldType( - dims=[ADim, BDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - ) - - -def test_where_broadcast_dim(): - ADim = Dimension("ADim") - - def simple_where(a: Field[[ADim], bool]): - return where(a, 5.0, 9.0) - - parsed = FieldOperatorParser.apply_to_function(simple_where) - - assert parsed.body.stmts[0].value.type == ts.FieldType( - dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - ) - - -def test_where_tuple_dim(): - ADim = Dimension("ADim") - - def tuple_where(a: Field[[ADim], bool], b: Field[[ADim], float64]): - return where(a, ((5.0, 9.0), (b, 6.0)), ((8.0, b), (5.0, 9.0))) - - parsed = FieldOperatorParser.apply_to_function(tuple_where) - - assert parsed.body.stmts[0].value.type == ts.TupleType( - types=[ - ts.TupleType( - types=[ - ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), - ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), - ] - ), - ts.TupleType( - 
types=[ - ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), - ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), - ] - ), - ] - ) - - -def test_where_bad_dim(): - ADim = Dimension("ADim") - - def bad_dim_where(a: Field[[ADim], bool], b: Field[[ADim], float64]): - return where(a, ((5.0, 9.0), (b, 6.0)), b) - - with pytest.raises(errors.DSLError, match=r"Return arguments need to be of same type"): - _ = FieldOperatorParser.apply_to_function(bad_dim_where) - - -def test_where_mixed_dims(): - ADim = Dimension("ADim") - BDim = Dimension("BDim") - - def tuple_where_mix_dims( - a: Field[[ADim], bool], b: Field[[ADim], float64], c: Field[[ADim, BDim], float64] - ): - return where(a, ((c, 9.0), (b, 6.0)), ((8.0, b), (5.0, 9.0))) - - parsed = FieldOperatorParser.apply_to_function(tuple_where_mix_dims) - - assert parsed.body.stmts[0].value.type == ts.TupleType( - types=[ - ts.TupleType( - types=[ - ts.FieldType( - dims=[ADim, BDim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64) - ), - ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), - ] - ), - ts.TupleType( - types=[ - ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), - ts.FieldType(dims=[ADim], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)), - ] - ), - ] - ) - - -def test_astype_dtype(): - def simple_astype(a: Field[[TDim], float64]): - return astype(a, bool) - - parsed = FieldOperatorParser.apply_to_function(simple_astype) - - assert parsed.body.stmts[0].value.type == ts.FieldType( - dims=[TDim], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL) - ) - - -def test_astype_wrong_dtype(): - def simple_astype(a: Field[[TDim], float64]): - # we just use broadcast here, but anything with type function is fine - return astype(a, broadcast) - - with pytest.raises( - errors.DSLError, - match=r"Invalid call to 'astype': second argument must be a scalar type, got.", - ): - _ = FieldOperatorParser.apply_to_function(simple_astype) - - 
-def test_astype_wrong_value_type(): - def simple_astype(a: Field[[TDim], float64]): - # we just use broadcast here but anything that is not a field, scalar or tuple thereof works - return astype(broadcast, bool) - - with pytest.raises(errors.DSLError) as exc_info: - _ = FieldOperatorParser.apply_to_function(simple_astype) - - assert ( - re.search("Expected 1st argument to be of type", exc_info.value.__cause__.args[0]) - is not None - ) - - -def test_mod_floats(): - def modulo_floats(inp: Field[[TDim], float]): - return inp % 3.0 - - with pytest.raises(errors.DSLError, match=r"Type 'float64' can not be used in operator '%'"): - _ = FieldOperatorParser.apply_to_function(modulo_floats) - - -def test_undefined_symbols(): - def return_undefined(): - return undefined_symbol - - with pytest.raises(errors.DSLError, match="Undeclared symbol"): - _ = FieldOperatorParser.apply_to_function(return_undefined) - - -def test_as_offset_dim(): - ADim = Dimension("ADim") - BDim = Dimension("BDim") - Boff = FieldOffset("Boff", source=BDim, target=(BDim,)) - - def as_offset_dim(a: Field[[ADim, BDim], float], b: Field[[ADim], int]): - return a(as_offset(Boff, b)) - - with pytest.raises(errors.DSLError, match=f"not in list of offset field dimensions"): - _ = FieldOperatorParser.apply_to_function(as_offset_dim) - - -def test_as_offset_dtype(): - ADim = Dimension("ADim") - BDim = Dimension("BDim") - Boff = FieldOffset("Boff", source=BDim, target=(BDim,)) - - def as_offset_dtype(a: Field[[ADim, BDim], float], b: Field[[BDim], float]): - return a(as_offset(Boff, b)) - - with pytest.raises(errors.DSLError, match=f"expected integer for offset field dtype"): - _ = FieldOperatorParser.apply_to_function(as_offset_dtype) diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 8da0e45810..0000000000 --- a/tox.ini +++ /dev/null @@ -1,199 +0,0 @@ -[tox] -requires = - tox>=4.2 - virtualenv>20.2 -envlist = - cartesian-py{310}-{internal,dace}-{cpu} - eve-py{310} - 
next-py{310}-{nomesh,atlas}-{cpu} - storage-py{310}-{internal,dace}-{cpu} - # docs -labels = - test-cartesian-cpu = cartesian-py38-internal-cpu, cartesian-internal-py39-cpu, \ - cartesian-internal-py310-cpu, cartesian-py311-internal-cpu, \ - cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, cartesian-py311-dace-cpu - test-eve-cpu = eve-py38, eve-py39, eve-py310, eve-py311 - test-next-cpu = next-py310-nomesh-cpu, next-py311-nomesh-cpu, next-py310-atlas-cpu, next-py311-atlas-cpu - test-storage-cpu = storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, storage-py311-internal-cpu, \ - storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu, storage-py311-dace-cpu - test-cpu = cartesian-py38-internal-cpu, cartesian-py39-internal-cpu, cartesian-py310-internal-cpu, cartesian-py311-internal-cpu, \ - cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, cartesian-py311-dace-cpu, \ - eve-py38, eve-py39, eve-py310, eve-py311, \ - next-py310-nomesh-cpu, next-py311-nomesh-cpu, next-py310-atlas-cpu, next-py311-atlas-cpu, \ - storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, storage-py311-internal-cpu, \ - storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu, storage-py311-dace-cpu - -[testenv] -deps = -r {tox_root}{/}{env:ENV_REQUIREMENTS_FILE:requirements-dev.txt} -constrain_package_deps = true -use_frozen_constraints = true -extras = - testing - formatting - dace: dace - cuda: cuda - cuda11x: cuda11x - cuda12x: cuda12x -package = wheel -wheel_build_env = .pkg -pass_env = CUDAARCHS, NUM_PROCESSES, GT4PY_* -set_env = - PYTEST_ADDOPTS = --color=auto --instafail - PYTHONWARNINGS = {env:PYTHONWARNINGS:ignore:Support for `[tool.setuptools]` in `pyproject.toml` is still *beta*:UserWarning,ignore:Field View Program:UserWarning} - -# -- Primary tests -- -[testenv:cartesian-py{38,39,310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] -description = 
Run 'gt4py.cartesian' tests -pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH, CXX, CC, OPENMP_CPPFLAGS, OPENMP_LDFLAGS, PIP_USER, PYTHONUSERBASE -allowlist_externals = - make - gcc - g++ - ldd - rm -commands = - python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "\ - internal: not requires_dace \ - dace: requires_dace \ - cpu: and not requires_gpu \ - {cuda,cuda11x,cuda12x}: and requires_gpu \ - " {posargs} tests{/}cartesian_tests - python -m pytest --doctest-modules --doctest-ignore-import-errors src{/}gt4py{/}cartesian -# commands_pre = -# rm -Rf tests/_reports/coverage* -# commands_post = -# coverage json --rcfile=setup.cfg -# coverage html --rcfile=setup.cfg --show-contexts - -[testenv:eve-py{38,39,310,311}] -description = Run 'gt4py.eve' tests -commands = - python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} {posargs} tests{/}eve_tests - python -m pytest --doctest-modules src{/}gt4py{/}eve - -[testenv:next-py{310,311}-{nomesh,atlas}-{cpu,cuda,cuda11x,cuda12x}] -description = Run 'gt4py.next' tests -pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH -deps = - -r {tox_root}{/}requirements-dev.txt - atlas: atlas4py -set_env = - {[testenv]set_env} - PIP_EXTRA_INDEX_URL = {env:PIP_EXTRA_INDEX_URL:https://test.pypi.org/simple/} -commands = - python -m pytest --suppress-no-test-exit-code --cache-clear -v -n {env:NUM_PROCESSES:1} -m "\ - nomesh: not requires_atlas \ - atlas: requires_atlas \ - cpu: and not requires_gpu \ - {cuda,cuda11x,cuda12x}: and requires_gpu \ - " {posargs} tests{/}next_tests - pytest --doctest-modules src{/}gt4py{/}next - -[testenv:storage-py{38,39,310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}] -description = Run 'gt4py.storage' tests -commands = - python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "\ - cpu: not requires_gpu \ - {cuda,cuda11x,cuda12x}: requires_gpu \ - " {posargs} tests{/}storage_tests - # pytest doctest-modules {posargs} 
src{/}gt4py{/}storage - -# -- Secondary tests -- -[testenv:notebooks-py{310,311}] -description = Run notebooks -commands_pre = - jupytext docs/user/next/QuickstartGuide.md --to .ipynb - jupytext docs/user/next/advanced/*.md --to .ipynb -commands = - python -m pytest --nbmake docs/user/next/workshop/slides -v -n {env:NUM_PROCESSES:1} - python -m pytest --nbmake docs/user/next/workshop/exercises -k 'solutions' -v -n {env:NUM_PROCESSES:1} - python -m pytest --nbmake docs/user/next/QuickstartGuide.ipynb -v -n {env:NUM_PROCESSES:1} - python -m pytest --nbmake docs/user/next/advanced -v -n {env:NUM_PROCESSES:1} - python -m pytest --nbmake examples -v -n {env:NUM_PROCESSES:1} - -# -- Other artefacts -- -[testenv:dev-py{38,39,310,311}{-atlas,}] -description = Initialize development environment for gt4py -deps = - -r {tox_root}{/}requirements-dev.txt - atlas: atlas4py -package = editable-legacy # => use_develop = True -set_env = - {[testenv]set_env} - PIP_EXTRA_INDEX_URL = {env:PIP_EXTRA_INDEX_URL:https://test.pypi.org/simple/} - -# [testenv:diagrams] -# install_command = echo {packages} -# skip_install = true -# allowlist_externals = -# /bin/bash -# make -# gcc -# g++ -# ldd -# rm -# plantuml -# git -# echo -# changedir = docs/development/ADRs -# commands = -# plantuml ./*.md -tsvg -o _static -# git add _static -# commands_post = - -[testenv:requirements-{base,py38,py39,py310,py311}] -description = - base: Update pinned development requirements - py38: Update requirements for testing a specific python version - py39: Update requirements for testing a specific python version - py310: Update requirements for testing a specific python version - py311: Update requirements for testing a specific python version -base_python = - base: py38 - py38: py38 - py39: py39 - py310: py310 - py311: py311 -deps = - cogapp>=3.3 - packaging>=20.0 - pip-tools>=6.10 -package = skip -set_env = - CUSTOM_COMPILE_COMMAND = "tox run -e requirements-base" -allowlist_externals = - mv -commands = - -mv 
constraints.txt constraints.txt.old - -mv requirements-dev.txt requirements-dev.old - # Run cog to update requirements files from pyproject - cog -r -P min-requirements-test.txt min-extra-requirements-test.txt - # Generate constraints file removing extras - # (extras are not supported by pip in constraints files) - pip-compile -r --resolver=backtracking \ - --annotation-style line \ - --build-isolation \ - --strip-extras \ - --allow-unsafe \ - --extra dace \ - --extra formatting \ - --extra jax-cpu \ - --extra performance \ - --extra testing \ - -o constraints.txt \ - pyproject.toml requirements-dev.in - # Generate actual requirements file - # (compiling from scratch again to print actual package sources) - pip-compile --resolver=backtracking \ - --annotation-style line \ - --build-isolation \ - --allow-unsafe \ - --extra dace \ - --extra formatting \ - --extra jax-cpu \ - --extra testing \ - -c constraints.txt \ - -o requirements-dev.txt \ - pyproject.toml requirements-dev.in - # Run cog to update .pre-commit-config.yaml with new versions - base: cog -r -P .pre-commit-config.yaml diff --git a/uv.lock b/uv.lock new file mode 100644 index 0000000000..dbcb32411d --- /dev/null +++ b/uv.lock @@ -0,0 +1,3494 @@ +version = 1 +requires-python = ">=3.10, <3.12" +resolution-markers = [ + "python_full_version >= '3.11'", + "python_full_version < '3.11'", +] +conflicts = [[ + { package = "gt4py", extra = "cuda11" }, + { package = "gt4py", extra = "jax-cuda12" }, + { package = "gt4py", extra = "rocm4-3" }, + { package = "gt4py", extra = "rocm5-0" }, +], [ + { package = "gt4py", extra = "dace" }, + { package = "gt4py", extra = "dace-next" }, +], [ + { package = "gt4py", extra = "all" }, + { package = "gt4py", extra = "dace-next" }, +]] + +[[package]] +name = "aenum" +version = "3.1.15" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/d0/f8/33e75863394f42e429bb553e05fda7c59763f0fd6848de847a25b3fbccf6/aenum-3.1.15.tar.gz", hash = "sha256:8cbd76cd18c4f870ff39b24284d3ea028fbe8731a58df3aa581e434c575b9559", size = 134730 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/fa/ca0c66b388624ba9dbbf35aab3a9f326bfdf5e56a7237fe8f1b600da6864/aenum-3.1.15-py3-none-any.whl", hash = "sha256:e0dfaeea4c2bd362144b87377e2c61d91958c5ed0b4daf89cb6f45ae23af6288", size = 137633 }, +] + +[[package]] +name = "alabaster" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a6/f8/d9c74d0daf3f742840fd818d69cfae176fa332022fd44e3469487d5a9420/alabaster-1.0.0.tar.gz", hash = "sha256:c00dca57bca26fa62a6d7d0a9fcce65f3e026e9bfe33e9c538fd3fbb2144fd9e", size = 24210 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/b3/6b4067be973ae96ba0d615946e314c5ae35f9f993eca561b356540bb0c2b/alabaster-1.0.0-py3-none-any.whl", hash = "sha256:fc6786402dc3fcb2de3cabd5fe455a2db534b371124f1f21de8731783dec828b", size = 13929 }, +] + +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643 }, +] + +[[package]] +name = "anyio" +version = "4.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 
'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "idna" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/73/199a98fc2dae33535d6b8e8e6ec01f8c1d76c9adb096c6b7d64823038cde/anyio-4.8.0.tar.gz", hash = "sha256:1d9fe889df5212298c0c0723fa20479d1b94883a2df44bd3897aa91083316f7a", size = 181126 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/eb/e7f063ad1fec6b3178a3cd82d1a3c4de82cccf283fc42746168188e1cdd5/anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a", size = 96041 }, +] + +[[package]] +name = "apeye" +version = "1.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "apeye-core" }, + { name = "domdf-python-tools" }, + { name = "platformdirs" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4f/6b/cc65e31843d7bfda8313a9dc0c77a21e8580b782adca53c7cb3e511fe023/apeye-1.4.1.tar.gz", hash = "sha256:14ea542fad689e3bfdbda2189a354a4908e90aee4bf84c15ab75d68453d76a36", size = 99219 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/7b/2d63664777b3e831ac1b1d8df5bbf0b7c8bee48e57115896080890527b1b/apeye-1.4.1-py3-none-any.whl", hash = "sha256:44e58a9104ec189bf42e76b3a7fe91e2b2879d96d48e9a77e5e32ff699c9204e", size = 107989 }, +] + +[[package]] +name = "apeye-core" +version = "1.1.5" +source = { registry = "https://pypi.org/simple" } 
+dependencies = [ + { name = "domdf-python-tools" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e5/4c/4f108cfd06923bd897bf992a6ecb6fb122646ee7af94d7f9a64abd071d4c/apeye_core-1.1.5.tar.gz", hash = "sha256:5de72ed3d00cc9b20fea55e54b7ab8f5ef8500eb33a5368bc162a5585e238a55", size = 96511 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/9f/fa9971d2a0c6fef64c87ba362a493a4f230eff4ea8dfb9f4c7cbdf71892e/apeye_core-1.1.5-py3-none-any.whl", hash = "sha256:dc27a93f8c9e246b3b238c5ea51edf6115ab2618ef029b9f2d9a190ec8228fbf", size = 99286 }, +] + +[[package]] +name = "appnope" +version = "0.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/35/5d/752690df9ef5b76e169e68d6a129fa6d08a7100ca7f754c89495db3c6019/appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee", size = 4170 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/29/5ecc3a15d5a33e31b26c11426c45c501e439cb865d0bff96315d86443b78/appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c", size = 4321 }, +] + +[[package]] +name = "argcomplete" +version = "3.5.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0c/be/6c23d80cb966fb8f83fb1ebfb988351ae6b0554d0c3a613ee4531c026597/argcomplete-3.5.3.tar.gz", hash = "sha256:c12bf50eded8aebb298c7b7da7a5ff3ee24dffd9f5281867dfe1424b58c55392", size = 72999 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c4/08/2a4db06ec3d203124c967fc89295e85a202e5cbbcdc08fd6a64b65217d1e/argcomplete-3.5.3-py3-none-any.whl", hash = "sha256:2ab2c4a215c59fd6caaff41a869480a23e8f6a5f910b266c1808037f4e375b61", size = 43569 }, +] + +[[package]] +name = "asttokens" +version = "2.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/45/1d/f03bcb60c4a3212e15f99a56085d93093a497718adf828d050b9d675da81/asttokens-2.4.1.tar.gz", hash = "sha256:b03869718ba9a6eb027e134bfdf69f38a236d681c83c160d510768af11254ba0", size = 62284 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/86/4736ac618d82a20d87d2f92ae19441ebc7ac9e7a581d7e58bbe79233b24a/asttokens-2.4.1-py2.py3-none-any.whl", hash = "sha256:051ed49c3dcae8913ea7cd08e46a606dba30b79993209636c4875bc1d637bc24", size = 27764 }, +] + +[[package]] +name = "astunparse" +version = "1.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, + { name = "wheel" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f3/af/4182184d3c338792894f34a62672919db7ca008c89abee9b564dd34d8029/astunparse-1.6.3.tar.gz", hash = "sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872", size = 18290 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2b/03/13dde6512ad7b4557eb792fbcf0c653af6076b81e5941d36ec61f7ce6028/astunparse-1.6.3-py2.py3-none-any.whl", hash = "sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8", size = 12732 }, +] + +[[package]] +name = "atlas4py" +version = "0.35.1.dev15" +source = { registry = "https://test.pypi.org/simple/" } +sdist = { url = "https://test-files.pythonhosted.org/packages/59/e4/48ede747be846f80b30d6303d732f96ca44ee9858504140db5222d2345bb/atlas4py-0.35.1.dev15.tar.gz", hash = "sha256:3c4274261d99a03ffd14a23dfb9ee9265ce79d8db7887751f4fbf1a315091664", size = 15079 } +wheels = [ + { url = "https://test-files.pythonhosted.org/packages/7a/47/0d1f8f7ba596a60bef920638724dfcc76f4edbfdb6bb79932b7e12ec45fc/atlas4py-0.35.1.dev15-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:244ae7f016d28ad04f8e9071de34192c1f8a58fd075477e327c4528cad8daacf", size = 6040572 }, + { url = 
"https://test-files.pythonhosted.org/packages/5d/f5/2b5645ec670b4088816ca7089fae06c6d72f0a4c301ef186ec8ac8e715fd/atlas4py-0.35.1.dev15-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:12b8f63df17e22d0ddc8310ced42e4db903e73b167c4f261f180cc2c011888ca", size = 5752419 }, + { url = "https://test-files.pythonhosted.org/packages/41/c3/03f3f061d28865f307c7916a0b82b8d37efeddb6cd4085aa687718341aee/atlas4py-0.35.1.dev15-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cadd6e5de2e0771d129b6242cbe0bd9268bed16d37ee3cc65b97a7de19a67933", size = 5251334 }, + { url = "https://test-files.pythonhosted.org/packages/22/fe/32d912deb54d7e9eaecc652b813b86925616be358222e069ded6e3bea8c6/atlas4py-0.35.1.dev15-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:0cacea024adb384aacb5da5a5a233e23cf8563e4f357e9687eeac0d9c4c9a4d8", size = 6041915 }, + { url = "https://test-files.pythonhosted.org/packages/ef/94/e85cc3588d836e58974f7be1b362ce321f5989ae8c355a75faee5b09f131/atlas4py-0.35.1.dev15-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:d5f3147e8ad52b890ffc4d92b51d6fd2b34bb39b89e09d6d4d5d7fec9f48aa0f", size = 5753565 }, + { url = "https://test-files.pythonhosted.org/packages/36/5e/71c7c054ae756f7cd5a984a44edad85ca20f4a0364ccc10052363314a9f2/atlas4py-0.35.1.dev15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c13f4a4a88dbe0eb056920d57eafa3e0f1e9fc117bd3c8773cfebca945ed8d76", size = 5253094 }, +] + +[[package]] +name = "attrs" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/49/7c/fdf464bcc51d23881d110abd74b512a42b3d5d376a55a831b44c603ae17f/attrs-25.1.0.tar.gz", hash = "sha256:1c97078a80c814273a76b2a298a932eb681c87415c11dee0a6921de7f1b02c3e", size = 810562 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/30/d4986a882011f9df997a55e6becd864812ccfcd821d64aac8570ee39f719/attrs-25.1.0-py3-none-any.whl", hash = 
"sha256:c75a69e28a550a7e93789579c22aa26b0f5b83b75dc4e08fe092980051e1090a", size = 63152 }, +] + +[[package]] +name = "autodocsumm" +version = "0.2.14" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "sphinx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/03/96/92afe8a7912b327c01f0a8b6408c9556ee13b1aba5b98d587ac7327ff32d/autodocsumm-0.2.14.tar.gz", hash = "sha256:2839a9d4facc3c4eccd306c08695540911042b46eeafcdc3203e6d0bab40bc77", size = 46357 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/bc/3f66af9beb683728e06ca08797e4e9d3e44f432f339718cae3ba856a9cad/autodocsumm-0.2.14-py3-none-any.whl", hash = "sha256:3bad8717fc5190802c60392a7ab04b9f3c97aa9efa8b3780b3d81d615bfe5dc0", size = 14640 }, +] + +[[package]] +name = "babel" +version = "2.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7d/6b/d52e42361e1aa00709585ecc30b3f9684b3ab62530771402248b1b1d6240/babel-2.17.0.tar.gz", hash = "sha256:0c54cffb19f690cdcc52a3b50bcbf71e07a808d1c80d549f2459b9d2cf0afb9d", size = 9951852 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537 }, +] + +[[package]] +name = "beautifulsoup4" +version = "4.13.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "soupsieve" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f0/3c/adaf39ce1fb4afdd21b611e3d530b183bb7759c9b673d60db0e347fd4439/beautifulsoup4-4.13.3.tar.gz", hash = "sha256:1bd32405dacc920b42b83ba01644747ed77456a65760e285fbc47633ceddaf8b", size = 619516 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/49/6abb616eb3cbab6a7cca303dc02fdf3836de2e0b834bf966a7f5271a34d8/beautifulsoup4-4.13.3-py3-none-any.whl", 
hash = "sha256:99045d7d3f08f91f0d656bc9b7efbae189426cd913d830294a15eefa0ea4df16", size = 186015 }, +] + +[[package]] +name = "black" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "mypy-extensions" }, + { name = "packaging" }, + { name = "pathspec" }, + { name = "platformdirs" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/94/49/26a7b0f3f35da4b5a65f081943b7bcd22d7002f5f0fb8098ec1ff21cb6ef/black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666", size = 649449 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/4d/3b/4ba3f93ac8d90410423fdd31d7541ada9bcee1df32fb90d26de41ed40e1d/black-25.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759e7ec1e050a15f89b770cefbf91ebee8917aac5c20483bc2d80a6c3a04df32", size = 1629419 }, + { url = "https://files.pythonhosted.org/packages/b4/02/0bde0485146a8a5e694daed47561785e8b77a0466ccc1f3e485d5ef2925e/black-25.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e519ecf93120f34243e6b0054db49c00a35f84f195d5bce7e9f5cfc578fc2da", size = 1461080 }, + { url = "https://files.pythonhosted.org/packages/52/0e/abdf75183c830eaca7589144ff96d49bce73d7ec6ad12ef62185cc0f79a2/black-25.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:055e59b198df7ac0b7efca5ad7ff2516bca343276c466be72eb04a3bcc1f82d7", size = 1766886 }, + { url = "https://files.pythonhosted.org/packages/dc/a6/97d8bb65b1d8a41f8a6736222ba0a334db7b7b77b8023ab4568288f23973/black-25.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:db8ea9917d6f8fc62abd90d944920d95e73c83a5ee3383493e35d271aca872e9", size = 1419404 }, + { url = "https://files.pythonhosted.org/packages/7e/4f/87f596aca05c3ce5b94b8663dbfe242a12843caaa82dd3f85f1ffdc3f177/black-25.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a39337598244de4bae26475f77dda852ea00a93bd4c728e09eacd827ec929df0", size = 1614372 }, + { url = "https://files.pythonhosted.org/packages/e7/d0/2c34c36190b741c59c901e56ab7f6e54dad8df05a6272a9747ecef7c6036/black-25.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96c1c7cd856bba8e20094e36e0f948718dc688dba4a9d78c3adde52b9e6c2299", size = 1442865 }, + { url = "https://files.pythonhosted.org/packages/21/d4/7518c72262468430ead45cf22bd86c883a6448b9eb43672765d69a8f1248/black-25.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bce2e264d59c91e52d8000d507eb20a9aca4a778731a08cfff7e5ac4a4bb7096", size = 1749699 }, + { url = 
"https://files.pythonhosted.org/packages/58/db/4f5beb989b547f79096e035c4981ceb36ac2b552d0ac5f2620e941501c99/black-25.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:172b1dbff09f86ce6f4eb8edf9dede08b1fce58ba194c87d7a4f1a5aa2f5b3c2", size = 1428028 }, + { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646 }, +] + +[[package]] +name = "boltons" +version = "25.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/63/54/71a94d8e02da9a865587fb3fff100cb0fc7aa9f4d5ed9ed3a591216ddcc7/boltons-25.0.0.tar.gz", hash = "sha256:e110fbdc30b7b9868cb604e3f71d4722dd8f4dcb4a5ddd06028ba8f1ab0b5ace", size = 246294 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/7f/0e961cf3908bc4c1c3e027de2794f867c6c89fb4916fc7dba295a0e80a2d/boltons-25.0.0-py3-none-any.whl", hash = "sha256:dc9fb38bf28985715497d1b54d00b62ea866eca3938938ea9043e254a3a6ca62", size = 194210 }, +] + +[[package]] +name = "bracex" +version = "2.5.post1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/6c/57418c4404cd22fe6275b8301ca2b46a8cdaa8157938017a9ae0b3edf363/bracex-2.5.post1.tar.gz", hash = "sha256:12c50952415bfa773d2d9ccb8e79651b8cdb1f31a42f6091b804f6ba2b4a66b6", size = 26641 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/02/8db98cdc1a58e0abd6716d5e63244658e6e63513c65f469f34b6f1053fd0/bracex-2.5.post1-py3-none-any.whl", hash = "sha256:13e5732fec27828d6af308628285ad358047cec36801598368cb28bc631dbaf6", size = 11558 }, +] + +[[package]] +name = "bump-my-version" +version = "0.32.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "questionary" }, 
+ { name = "rich" }, + { name = "rich-click" }, + { name = "tomlkit" }, + { name = "wcmatch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e7/8b/72f0cd91ca6e296b71b05d39fcfbcf365eebaa5679a863ce7bb4d9d8aad7/bump_my_version-0.32.0.tar.gz", hash = "sha256:e8d964d13ba3ab6c090a872d0b5094ecf8df7ae8052b09288ace00fc6647df27", size = 1028515 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/67/92853455bb91f09cb1bb9d3a4993b2e5fda80d6c44c727eb93993dc1cc60/bump_my_version-0.32.0-py3-none-any.whl", hash = "sha256:7c807110bdd8ecc845019e68a050ff378d836effb116440ba7f4a8ad59652b63", size = 57572 }, +] + +[[package]] +name = "cachecontrol" +version = "0.14.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "msgpack" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b7/a4/3390ac4dfa1773f661c8780368018230e8207ec4fd3800d2c0c3adee4456/cachecontrol-0.14.2.tar.gz", hash = "sha256:7d47d19f866409b98ff6025b6a0fca8e4c791fb31abbd95f622093894ce903a2", size = 28832 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/63/baffb44ca6876e7b5fc8fe17b24a7c07bf479d604a592182db9af26ea366/cachecontrol-0.14.2-py3-none-any.whl", hash = "sha256:ebad2091bf12d0d200dfc2464330db638c5deb41d546f6d7aca079e87290f3b0", size = 21780 }, +] + +[package.optional-dependencies] +filecache = [ + { name = "filelock" }, +] + +[[package]] +name = "cached-property" +version = "2.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/76/4b/3d870836119dbe9a5e3c9a61af8cc1a8b69d75aea564572e385882d5aefb/cached_property-2.0.1.tar.gz", hash = "sha256:484d617105e3ee0e4f1f58725e72a8ef9e93deee462222dbd51cd91230897641", size = 10574 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/0e/7d8225aab3bc1a0f5811f8e1b557aa034ac04bdf641925b30d3caf586b28/cached_property-2.0.1-py3-none-any.whl", hash = 
"sha256:f617d70ab1100b7bcf6e42228f9ddcb78c676ffa167278d9f730d1c2fba69ccb", size = 7428 }, +] + +[[package]] +name = "cattrs" +version = "24.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/64/65/af6d57da2cb32c076319b7489ae0958f746949d407109e3ccf4d115f147c/cattrs-24.1.2.tar.gz", hash = "sha256:8028cfe1ff5382df59dd36474a86e02d817b06eaf8af84555441bac915d2ef85", size = 426462 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/c8/d5/867e75361fc45f6de75fe277dd085627a9db5ebb511a87f27dc1396b5351/cattrs-24.1.2-py3-none-any.whl", hash = "sha256:67c7495b760168d931a10233f979b28dc04daf853b30752246f4f8471c6d68d0", size = 66446 }, +] + +[[package]] +name = "certifi" +version = "2025.1.31" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1c/ab/c9f1e32b7b1bf505bf26f0ef697775960db7932abeb7b516de930ba2705f/certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651", size = 167577 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/38/fc/bce832fd4fd99766c04d1ee0eead6b0ec6486fb100ae5e74c1d91292b982/certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe", size = 166393 }, +] + +[[package]] +name = "cffi" +version = "1.17.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pycparser" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fc/97/c783634659c2920c3fc70419e3af40972dbaf758daa229a7d6ea6135c90d/cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824", size = 516621 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/90/07/f44ca684db4e4f08a3fdc6eeb9a0d15dc6883efc7b8c90357fdbf74e186c/cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14", size = 182191 }, + { url = "https://files.pythonhosted.org/packages/08/fd/cc2fedbd887223f9f5d170c96e57cbf655df9831a6546c1727ae13fa977a/cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67", size = 178592 }, + { url = 
"https://files.pythonhosted.org/packages/de/cc/4635c320081c78d6ffc2cab0a76025b691a91204f4aa317d568ff9280a2d/cffi-1.17.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382", size = 426024 }, + { url = "https://files.pythonhosted.org/packages/b6/7b/3b2b250f3aab91abe5f8a51ada1b717935fdaec53f790ad4100fe2ec64d1/cffi-1.17.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702", size = 448188 }, + { url = "https://files.pythonhosted.org/packages/d3/48/1b9283ebbf0ec065148d8de05d647a986c5f22586b18120020452fff8f5d/cffi-1.17.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3", size = 455571 }, + { url = "https://files.pythonhosted.org/packages/40/87/3b8452525437b40f39ca7ff70276679772ee7e8b394934ff60e63b7b090c/cffi-1.17.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6", size = 436687 }, + { url = "https://files.pythonhosted.org/packages/8d/fb/4da72871d177d63649ac449aec2e8a29efe0274035880c7af59101ca2232/cffi-1.17.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17", size = 446211 }, + { url = "https://files.pythonhosted.org/packages/ab/a0/62f00bcb411332106c02b663b26f3545a9ef136f80d5df746c05878f8c4b/cffi-1.17.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8", size = 461325 }, + { url = "https://files.pythonhosted.org/packages/36/83/76127035ed2e7e27b0787604d99da630ac3123bfb02d8e80c633f218a11d/cffi-1.17.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e", size 
= 438784 }, + { url = "https://files.pythonhosted.org/packages/21/81/a6cd025db2f08ac88b901b745c163d884641909641f9b826e8cb87645942/cffi-1.17.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be", size = 461564 }, + { url = "https://files.pythonhosted.org/packages/f8/fe/4d41c2f200c4a457933dbd98d3cf4e911870877bd94d9656cc0fcb390681/cffi-1.17.1-cp310-cp310-win32.whl", hash = "sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c", size = 171804 }, + { url = "https://files.pythonhosted.org/packages/d1/b6/0b0f5ab93b0df4acc49cae758c81fe4e5ef26c3ae2e10cc69249dfd8b3ab/cffi-1.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15", size = 181299 }, + { url = "https://files.pythonhosted.org/packages/6b/f4/927e3a8899e52a27fa57a48607ff7dc91a9ebe97399b357b85a0c7892e00/cffi-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401", size = 182264 }, + { url = "https://files.pythonhosted.org/packages/6c/f5/6c3a8efe5f503175aaddcbea6ad0d2c96dad6f5abb205750d1b3df44ef29/cffi-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf", size = 178651 }, + { url = "https://files.pythonhosted.org/packages/94/dd/a3f0118e688d1b1a57553da23b16bdade96d2f9bcda4d32e7d2838047ff7/cffi-1.17.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4", size = 445259 }, + { url = "https://files.pythonhosted.org/packages/2e/ea/70ce63780f096e16ce8588efe039d3c4f91deb1dc01e9c73a287939c79a6/cffi-1.17.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41", size = 469200 }, + { url = 
"https://files.pythonhosted.org/packages/1c/a0/a4fa9f4f781bda074c3ddd57a572b060fa0df7655d2a4247bbe277200146/cffi-1.17.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1", size = 477235 }, + { url = "https://files.pythonhosted.org/packages/62/12/ce8710b5b8affbcdd5c6e367217c242524ad17a02fe5beec3ee339f69f85/cffi-1.17.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6", size = 459721 }, + { url = "https://files.pythonhosted.org/packages/ff/6b/d45873c5e0242196f042d555526f92aa9e0c32355a1be1ff8c27f077fd37/cffi-1.17.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d", size = 467242 }, + { url = "https://files.pythonhosted.org/packages/1a/52/d9a0e523a572fbccf2955f5abe883cfa8bcc570d7faeee06336fbd50c9fc/cffi-1.17.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6", size = 477999 }, + { url = "https://files.pythonhosted.org/packages/44/74/f2a2460684a1a2d00ca799ad880d54652841a780c4c97b87754f660c7603/cffi-1.17.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f", size = 454242 }, + { url = "https://files.pythonhosted.org/packages/f8/4a/34599cac7dfcd888ff54e801afe06a19c17787dfd94495ab0c8d35fe99fb/cffi-1.17.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b", size = 478604 }, + { url = "https://files.pythonhosted.org/packages/34/33/e1b8a1ba29025adbdcda5fb3a36f94c03d771c1b7b12f726ff7fef2ebe36/cffi-1.17.1-cp311-cp311-win32.whl", hash = "sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655", size = 171727 }, + { url = 
"https://files.pythonhosted.org/packages/3d/97/50228be003bb2802627d28ec0627837ac0bf35c90cf769812056f235b2d1/cffi-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0", size = 181400 }, +] + +[[package]] +name = "cfgv" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/11/74/539e56497d9bd1d484fd863dd69cbbfa653cd2aa27abfe35653494d85e94/cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560", size = 7114 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249 }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/16/b0/572805e227f01586461c80e0fd25d65a2115599cc9dad142fee4b747c357/charset_normalizer-3.4.1.tar.gz", hash = "sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3", size = 123188 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/58/5580c1716040bc89206c77d8f74418caf82ce519aae06450393ca73475d1/charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de", size = 198013 }, + { url = "https://files.pythonhosted.org/packages/d0/11/00341177ae71c6f5159a08168bcb98c6e6d196d372c94511f9f6c9afe0c6/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176", size = 141285 }, + { url = 
"https://files.pythonhosted.org/packages/01/09/11d684ea5819e5a8f5100fb0b38cf8d02b514746607934134d31233e02c8/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e218488cd232553829be0664c2292d3af2eeeb94b32bea483cf79ac6a694e037", size = 151449 }, + { url = "https://files.pythonhosted.org/packages/08/06/9f5a12939db324d905dc1f70591ae7d7898d030d7662f0d426e2286f68c9/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80ed5e856eb7f30115aaf94e4a08114ccc8813e6ed1b5efa74f9f82e8509858f", size = 143892 }, + { url = "https://files.pythonhosted.org/packages/93/62/5e89cdfe04584cb7f4d36003ffa2936681b03ecc0754f8e969c2becb7e24/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b010a7a4fd316c3c484d482922d13044979e78d1861f0e0650423144c616a46a", size = 146123 }, + { url = "https://files.pythonhosted.org/packages/a9/ac/ab729a15c516da2ab70a05f8722ecfccc3f04ed7a18e45c75bbbaa347d61/charset_normalizer-3.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4532bff1b8421fd0a320463030c7520f56a79c9024a4e88f01c537316019005a", size = 147943 }, + { url = "https://files.pythonhosted.org/packages/03/d2/3f392f23f042615689456e9a274640c1d2e5dd1d52de36ab8f7955f8f050/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d973f03c0cb71c5ed99037b870f2be986c3c05e63622c017ea9816881d2dd247", size = 142063 }, + { url = "https://files.pythonhosted.org/packages/f2/e3/e20aae5e1039a2cd9b08d9205f52142329f887f8cf70da3650326670bddf/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:3a3bd0dcd373514dcec91c411ddb9632c0d7d92aed7093b8c3bbb6d69ca74408", size = 150578 }, + { url = "https://files.pythonhosted.org/packages/8d/af/779ad72a4da0aed925e1139d458adc486e61076d7ecdcc09e610ea8678db/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = 
"sha256:d9c3cdf5390dcd29aa8056d13e8e99526cda0305acc038b96b30352aff5ff2bb", size = 153629 }, + { url = "https://files.pythonhosted.org/packages/c2/b6/7aa450b278e7aa92cf7732140bfd8be21f5f29d5bf334ae987c945276639/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:2bdfe3ac2e1bbe5b59a1a63721eb3b95fc9b6817ae4a46debbb4e11f6232428d", size = 150778 }, + { url = "https://files.pythonhosted.org/packages/39/f4/d9f4f712d0951dcbfd42920d3db81b00dd23b6ab520419626f4023334056/charset_normalizer-3.4.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:eab677309cdb30d047996b36d34caeda1dc91149e4fdca0b1a039b3f79d9a807", size = 146453 }, + { url = "https://files.pythonhosted.org/packages/49/2b/999d0314e4ee0cff3cb83e6bc9aeddd397eeed693edb4facb901eb8fbb69/charset_normalizer-3.4.1-cp310-cp310-win32.whl", hash = "sha256:c0429126cf75e16c4f0ad00ee0eae4242dc652290f940152ca8c75c3a4b6ee8f", size = 95479 }, + { url = "https://files.pythonhosted.org/packages/2d/ce/3cbed41cff67e455a386fb5e5dd8906cdda2ed92fbc6297921f2e4419309/charset_normalizer-3.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:9f0b8b1c6d84c8034a44893aba5e767bf9c7a211e313a9605d9c617d7083829f", size = 102790 }, + { url = "https://files.pythonhosted.org/packages/72/80/41ef5d5a7935d2d3a773e3eaebf0a9350542f2cab4eac59a7a4741fbbbbe/charset_normalizer-3.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8bfa33f4f2672964266e940dd22a195989ba31669bd84629f05fab3ef4e2d125", size = 194995 }, + { url = "https://files.pythonhosted.org/packages/7a/28/0b9fefa7b8b080ec492110af6d88aa3dea91c464b17d53474b6e9ba5d2c5/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28bf57629c75e810b6ae989f03c0828d64d6b26a5e205535585f96093e405ed1", size = 139471 }, + { url = "https://files.pythonhosted.org/packages/71/64/d24ab1a997efb06402e3fc07317e94da358e2585165930d9d59ad45fcae2/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:f08ff5e948271dc7e18a35641d2f11a4cd8dfd5634f55228b691e62b37125eb3", size = 149831 }, + { url = "https://files.pythonhosted.org/packages/37/ed/be39e5258e198655240db5e19e0b11379163ad7070962d6b0c87ed2c4d39/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:234ac59ea147c59ee4da87a0c0f098e9c8d169f4dc2a159ef720f1a61bbe27cd", size = 142335 }, + { url = "https://files.pythonhosted.org/packages/88/83/489e9504711fa05d8dde1574996408026bdbdbd938f23be67deebb5eca92/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd4ec41f914fa74ad1b8304bbc634b3de73d2a0889bd32076342a573e0779e00", size = 143862 }, + { url = "https://files.pythonhosted.org/packages/c6/c7/32da20821cf387b759ad24627a9aca289d2822de929b8a41b6241767b461/charset_normalizer-3.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eea6ee1db730b3483adf394ea72f808b6e18cf3cb6454b4d86e04fa8c4327a12", size = 145673 }, + { url = "https://files.pythonhosted.org/packages/68/85/f4288e96039abdd5aeb5c546fa20a37b50da71b5cf01e75e87f16cd43304/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c96836c97b1238e9c9e3fe90844c947d5afbf4f4c92762679acfe19927d81d77", size = 140211 }, + { url = "https://files.pythonhosted.org/packages/28/a3/a42e70d03cbdabc18997baf4f0227c73591a08041c149e710045c281f97b/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4d86f7aff21ee58f26dcf5ae81a9addbd914115cdebcbb2217e4f0ed8982e146", size = 148039 }, + { url = "https://files.pythonhosted.org/packages/85/e4/65699e8ab3014ecbe6f5c71d1a55d810fb716bbfd74f6283d5c2aa87febf/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:09b5e6733cbd160dcc09589227187e242a30a49ca5cefa5a7edd3f9d19ed53fd", size = 151939 }, + { url = 
"https://files.pythonhosted.org/packages/b1/82/8e9fe624cc5374193de6860aba3ea8070f584c8565ee77c168ec13274bd2/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:5777ee0881f9499ed0f71cc82cf873d9a0ca8af166dfa0af8ec4e675b7df48e6", size = 149075 }, + { url = "https://files.pythonhosted.org/packages/3d/7b/82865ba54c765560c8433f65e8acb9217cb839a9e32b42af4aa8e945870f/charset_normalizer-3.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:237bdbe6159cff53b4f24f397d43c6336c6b0b42affbe857970cefbb620911c8", size = 144340 }, + { url = "https://files.pythonhosted.org/packages/b5/b6/9674a4b7d4d99a0d2df9b215da766ee682718f88055751e1e5e753c82db0/charset_normalizer-3.4.1-cp311-cp311-win32.whl", hash = "sha256:8417cb1f36cc0bc7eaba8ccb0e04d55f0ee52df06df3ad55259b9a323555fc8b", size = 95205 }, + { url = "https://files.pythonhosted.org/packages/1e/ab/45b180e175de4402dcf7547e4fb617283bae54ce35c27930a6f35b6bef15/charset_normalizer-3.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:d7f50a1f8c450f3925cb367d011448c39239bb3eb4117c36a6d354794de4ce76", size = 102441 }, + { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767 }, +] + +[[package]] +name = "clang-format" +version = "19.1.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/ee/71d017fe603c06b83d6720df6b3f6f07f03abf330f39beee3fee2a067c56/clang_format-19.1.7.tar.gz", hash = "sha256:bd6fc5272a41034a7844149203461d1f311bece9ed100d22eb3eebd952a25f49", size = 11122 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/c3/2f1c53bc298c1740d0c9f8dc2d9b7030be4826b6f2aa8a04f07ef25a3d9b/clang_format-19.1.7-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:a09f34d2c89d176581858ff718c327eebc14eb6415c176dab4af5bfd8582a999", size = 1428184 }, + 
{ url = "https://files.pythonhosted.org/packages/8e/9d/7c246a3d08105de305553d14971ed6c16cde06d20ab12d6ce7f243cf66f0/clang_format-19.1.7-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:776f89c7b056c498c0e256485bc031cbf514aaebe71e929ed54e50c478524b65", size = 1398224 }, + { url = "https://files.pythonhosted.org/packages/b1/7d/002aa5571351ee7f00f87aae5104cdd30cad1a46f25936226f7d2aed06bf/clang_format-19.1.7-py2.py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dac394c83a9233ab6707f66e1cdbd950f8b014b58604142a5b6f7998bf0bcc8c", size = 1730962 }, + { url = "https://files.pythonhosted.org/packages/1c/fe/24b7c13af432e609d65dc32c47c61f0a6c3b80d78eb7b3df37daf0395c56/clang_format-19.1.7-py2.py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bbd4f94d929edf6d8d81e990dfaafc22bb10deaefcb2762150a136f281b01c00", size = 1908820 }, + { url = "https://files.pythonhosted.org/packages/7d/a8/86595ffd6ea0bf3a3013aad94e3d55be32ef987567781eddf4621e316d09/clang_format-19.1.7-py2.py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bdcda63fffdbe2aac23b54d46408a6283ad16676a5230a95b3ed49eacd99129b", size = 2622838 }, + { url = "https://files.pythonhosted.org/packages/48/d1/731ebf78c5d5cc043c20b0755c89239350b8e75ac5d667b99689e8110bc7/clang_format-19.1.7-py2.py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c13a5802da986b1400afbee97162c29f841890ab9e20a0be7ede18189219f5f1", size = 1723352 }, + { url = "https://files.pythonhosted.org/packages/3c/e7/0e526915a3a4a23100cc721c24226a192fa0385d394019d06920dc83fe6c/clang_format-19.1.7-py2.py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f4906fb463dd2033032978f56962caab268c9428a384126b9400543eb667f11c", size = 1740347 }, + { url = "https://files.pythonhosted.org/packages/52/04/ed8e2af6b3e29655a858b3aad145f3f0539df0dd1c77815b95f578260bd3/clang_format-19.1.7-py2.py3-none-musllinux_1_2_aarch64.whl", hash = 
"sha256:ffca915c09aed9137f8c649ad7521bd5ce690c939121db1ba54af2ba63ac8374", size = 2675802 }, + { url = "https://files.pythonhosted.org/packages/9a/ab/7874a6f45c167f4cc4d02f517b85d14b6b5fa8412f6e9c7482588d00fccb/clang_format-19.1.7-py2.py3-none-musllinux_1_2_i686.whl", hash = "sha256:fc011dc7bbe3ac8a32e0caa37ab8ba6c1639ceef6ecd04feea8d37360fc175e4", size = 2977872 }, + { url = "https://files.pythonhosted.org/packages/46/b5/c87b6c46eb7e9d0f07e2bd56cd0a62bf7e679f146b4e1447110cfae4bd01/clang_format-19.1.7-py2.py3-none-musllinux_1_2_ppc64le.whl", hash = "sha256:afdfb11584f5a6f15127a7061673a7ea12a0393fe9ee8d2ed84e74bb191ffc3b", size = 3125795 }, + { url = "https://files.pythonhosted.org/packages/22/3e/7ea08aba446c1e838367d3c0e13eb3d2e482b23e099a25149d4f7f6b8c75/clang_format-19.1.7-py2.py3-none-musllinux_1_2_s390x.whl", hash = "sha256:6ce81d5b08e0169dc52037d3ff1802eafcaf86c281ceb8b38b8359ba7b6b7bdc", size = 3069663 }, + { url = "https://files.pythonhosted.org/packages/f5/f9/6ce7fe8ff52ded01d02a568358f2ddf993347e44202b6506b039a583b7ed/clang_format-19.1.7-py2.py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d27ac1a5a8783c9271d41cd5851766ca547ea003efa4e3764f880f319b2d3ed3", size = 2763172 }, + { url = "https://files.pythonhosted.org/packages/82/fa/77fe5636bb6b6252918bf129226a248506af218a2256deece3a9d95af850/clang_format-19.1.7-py2.py3-none-win32.whl", hash = "sha256:5dfde0be33f038114af89efb917144c2f766f8b7f3a3d3e4cb9c25f76d71ef81", size = 1243262 }, + { url = "https://files.pythonhosted.org/packages/e4/32/0b44f3582b9df0b8f90266ef43975e37ec8ad52bae4f85b71552f264d5a2/clang_format-19.1.7-py2.py3-none-win_amd64.whl", hash = "sha256:3e3c75fbdf8827bbb7277226b3057fc3785dabe7284d3a9d15fceb250f68f529", size = 1441132 }, +] + +[[package]] +name = "click" +version = "8.1.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 
'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 }, +] + +[[package]] +name = "cmake" +version = "3.31.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/50/cb/3a327fa784a5dbaf838b135cb1729f43535c52d83bbf02191fb8a0cb118e/cmake-3.31.4.tar.gz", hash = "sha256:a6ac2242e0b16ad7d94c9f8572d6f232e6169747be50e5cdf497f206c4819ce1", size = 34278 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/db/50efa1d3e29cb2a6e8e143e522e52698b3fc08f4b56100fb35f97a70af79/cmake-3.31.4-py3-none-macosx_10_10_universal2.whl", hash = "sha256:fc048b4b70facd16699a43c737f6782b4eff56e8e6093090db5979532d9db0f6", size = 47198138 }, + { url = "https://files.pythonhosted.org/packages/c7/76/ccb8764761c739ef16bd8957a16ecbda01b03c2d7d241c376bfca6bf2822/cmake-3.31.4-py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:2a37be93534df04513f0845492d71bc80899c3f87b77e3b01c95aff1a7fc9bde", size = 27556485 }, + { url = 
"https://files.pythonhosted.org/packages/ad/8e/888e2944655d7fa1ea5af46b60883a0e7847bbf9fb7ecc321c8e5f0a1394/cmake-3.31.4-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c9f5f8289c5e7bd2ed654cbac164021fa7723064fee0443a2f0068bc08413d81", size = 26808834 }, + { url = "https://files.pythonhosted.org/packages/59/f4/0b2b1430a441c3c09ee102bf8c5d9ec1dc11d002ff4affef15c656f37ce9/cmake-3.31.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:926d91cae2ba7d2f3df857d0fc066bdac4f3904bf5c95e99b60435e85aabedb4", size = 27140820 }, + { url = "https://files.pythonhosted.org/packages/d1/f9/a274b4e36e457d8e99db1038cc31a6c391bf3bc26230c2dc9caf37499753/cmake-3.31.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:929a8d8d289d69e43784661748ddd08933ce1ec5db8f9bcfce6ee817a48f8787", size = 28868269 }, + { url = "https://files.pythonhosted.org/packages/9b/35/8da1ffa00a3f3853881aa5025cdf11c744303013df70c8716155b83825d3/cmake-3.31.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b463efdf5b92f3b290235aa9f8da092b3dac19b7636c563fd156022dab580649", size = 30732267 }, + { url = "https://files.pythonhosted.org/packages/79/48/bb8485687f5a64d52ac68cfcb02e9b8e46a9e107f380c54d484b6632c87e/cmake-3.31.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:225d9a643b0b60ffce0399ff0cabd7a4820e0dbcb794e97d3aacfcf7c0589ae6", size = 26908885 }, + { url = "https://files.pythonhosted.org/packages/e5/9e/2594d7fa8b263296497bf044469b4ab4797c51675ea629f9672011cdfe09/cmake-3.31.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89143a5e2a5916061f2cfc5012e9fe6281aaf7c0dae7930bdc68d105d22ddc39", size = 27784555 }, + { url = "https://files.pythonhosted.org/packages/95/16/5b1989f1d2287b05cd68792c0a48b721c060f728506d719fcf0e3b80ceb2/cmake-3.31.4-py3-none-manylinux_2_31_armv7l.whl", hash = "sha256:f96127bf663168accd29d5a50ee68ea80f26bcd37f96c7a14ef2378781f19936", size = 24965366 }, + 
{ url = "https://files.pythonhosted.org/packages/5a/4c/289fb0986c6ff63583383eca0c9479147f362330938856a9b5201c84cee8/cmake-3.31.4-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:25c5094394f0cee21130b5678e5b4552f72470e266df6d6fb1d5c505100f0eaa", size = 27824887 }, + { url = "https://files.pythonhosted.org/packages/3c/f3/d45ba2b5bb54f4ef615a6a24cf6258600eec790a9d5017c9584107b445b9/cmake-3.31.4-py3-none-musllinux_1_1_i686.whl", hash = "sha256:466c9295af440bb4a47cc5e1af10576cf2227620528afd0fd0b3effa1d513b49", size = 31368421 }, + { url = "https://files.pythonhosted.org/packages/34/3d/f6b712241ede5fb8e32c13e119c06e142f3f12ead1656721b1f67756106b/cmake-3.31.4-py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:f6af3b83a1b1fc1d990d18b6a566ee9c95c0393f986c6df15f2505dda8ad1bcc", size = 32074545 }, + { url = "https://files.pythonhosted.org/packages/f0/23/48cd0404d7238d703a4cd4d7434eeaf12e8fbe68160d52f1489f55f582df/cmake-3.31.4-py3-none-musllinux_1_1_s390x.whl", hash = "sha256:23781e17563693a68b0cef85749746894b8a61488e56e96fc6649b73652e8236", size = 27946950 }, + { url = "https://files.pythonhosted.org/packages/21/03/014d9710bccf5a7e04c6f6ee27bfaba1220e79ee145d7b95f84e7843729b/cmake-3.31.4-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:838a388b559137f3654d8cf30f62bbdec10f8d1c3624f0d289614d33cdf4fba1", size = 29473412 }, + { url = "https://files.pythonhosted.org/packages/23/de/5a8142732f0a52dedac2887e0c105c9bbb449e517ade500e56bf2af520d1/cmake-3.31.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a6a3b0b9557f41c955a6b25c94205f2ca9c3a46edca809ad87507c5ef6bc4274", size = 32971081 }, + { url = "https://files.pythonhosted.org/packages/a5/a1/50c11f0b110986c753592f025970094030b25748df126abe8e38265be722/cmake-3.31.4-py3-none-win32.whl", hash = "sha256:d378c9e58eac906bddafd673c7571262dcd5a9946bb1e8f9e3902572a8fa95ca", size = 33351393 }, + { url = 
"https://files.pythonhosted.org/packages/0c/7f/331d181b6b1b8942ec5fad23e98fff85218485f29f62f6bc60663d424df8/cmake-3.31.4-py3-none-win_amd64.whl", hash = "sha256:20be7cdb41903edf85e8a498c4beff8d6854acbb087abfb07c362c738bdf0018", size = 36496715 }, + { url = "https://files.pythonhosted.org/packages/65/26/11a78723364716004928b7bea7d96cf2c72dc3abfaa7c163159110fcb649/cmake-3.31.4-py3-none-win_arm64.whl", hash = "sha256:9479a9255197c49e135df039d8484c69aa63158a06ae9c2d0eb939da2f0f7dff", size = 35559239 }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, +] + +[[package]] +name = "colorlog" +version = "6.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/d3/7a/359f4d5df2353f26172b3cc39ea32daa39af8de522205f512f458923e677/colorlog-6.9.0.tar.gz", hash = "sha256:bfba54a1b93b94f54e1f4fe48395725a3d92fd2a4af702f6bd70946bdc0c6ac2", size = 16624 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e3/51/9b208e85196941db2f0654ad0357ca6388ab3ed67efdbfc799f35d1f83aa/colorlog-6.9.0-py3-none-any.whl", hash = "sha256:5906e71acd67cb07a71e779c47c4bcb45fb8c2993eebe9e5adcd6a6f1b283eff", size = 11424 }, +] + +[[package]] +name = "comm" +version = "0.2.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/a8/fb783cb0abe2b5fded9f55e5703015cdf1c9c85b3669087c538dd15a6a86/comm-0.2.2.tar.gz", hash = "sha256:3fd7a84065306e07bea1773df6eb8282de51ba82f77c72f9c85716ab11fe980e", size = 6210 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e6/75/49e5bfe642f71f272236b5b2d2691cf915a7283cc0ceda56357b61daa538/comm-0.2.2-py3-none-any.whl", hash = "sha256:e6fb86cb70ff661ee8c9c14e7d36d6de3b4066f1441be4063df9c5009f0a64d3", size = 7180 }, +] + +[[package]] +name = "contourpy" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-5-gt4py-all' or extra == 'extra-5-gt4py-dace' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 
'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/25/c2/fc7193cc5383637ff390a712e88e4ded0452c9fbcf84abe3de5ea3df1866/contourpy-1.3.1.tar.gz", hash = "sha256:dfd97abd83335045a913e3bcc4a09c0ceadbe66580cf573fe961f4a825efa699", size = 13465753 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b2/a3/80937fe3efe0edacf67c9a20b955139a1a622730042c1ea991956f2704ad/contourpy-1.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = 
"sha256:a045f341a77b77e1c5de31e74e966537bba9f3c4099b35bf4c2e3939dd54cdab", size = 268466 }, + { url = "https://files.pythonhosted.org/packages/82/1d/e3eaebb4aa2d7311528c048350ca8e99cdacfafd99da87bc0a5f8d81f2c2/contourpy-1.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:500360b77259914f7805af7462e41f9cb7ca92ad38e9f94d6c8641b089338124", size = 253314 }, + { url = "https://files.pythonhosted.org/packages/de/f3/d796b22d1a2b587acc8100ba8c07fb7b5e17fde265a7bb05ab967f4c935a/contourpy-1.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b2f926efda994cdf3c8d3fdb40b9962f86edbc4457e739277b961eced3d0b4c1", size = 312003 }, + { url = "https://files.pythonhosted.org/packages/bf/f5/0e67902bc4394daee8daa39c81d4f00b50e063ee1a46cb3938cc65585d36/contourpy-1.3.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:adce39d67c0edf383647a3a007de0a45fd1b08dedaa5318404f1a73059c2512b", size = 351896 }, + { url = "https://files.pythonhosted.org/packages/1f/d6/e766395723f6256d45d6e67c13bb638dd1fa9dc10ef912dc7dd3dcfc19de/contourpy-1.3.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:abbb49fb7dac584e5abc6636b7b2a7227111c4f771005853e7d25176daaf8453", size = 320814 }, + { url = "https://files.pythonhosted.org/packages/a9/57/86c500d63b3e26e5b73a28b8291a67c5608d4aa87ebd17bd15bb33c178bc/contourpy-1.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0cffcbede75c059f535725c1680dfb17b6ba8753f0c74b14e6a9c68c29d7ea3", size = 324969 }, + { url = "https://files.pythonhosted.org/packages/b8/62/bb146d1289d6b3450bccc4642e7f4413b92ebffd9bf2e91b0404323704a7/contourpy-1.3.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ab29962927945d89d9b293eabd0d59aea28d887d4f3be6c22deaefbb938a7277", size = 1265162 }, + { url = "https://files.pythonhosted.org/packages/18/04/9f7d132ce49a212c8e767042cc80ae390f728060d2eea47058f55b9eff1c/contourpy-1.3.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = 
"sha256:974d8145f8ca354498005b5b981165b74a195abfae9a8129df3e56771961d595", size = 1324328 }, + { url = "https://files.pythonhosted.org/packages/46/23/196813901be3f97c83ababdab1382e13e0edc0bb4e7b49a7bff15fcf754e/contourpy-1.3.1-cp310-cp310-win32.whl", hash = "sha256:ac4578ac281983f63b400f7fe6c101bedc10651650eef012be1ccffcbacf3697", size = 173861 }, + { url = "https://files.pythonhosted.org/packages/e0/82/c372be3fc000a3b2005061ca623a0d1ecd2eaafb10d9e883a2fc8566e951/contourpy-1.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:174e758c66bbc1c8576992cec9599ce8b6672b741b5d336b5c74e35ac382b18e", size = 218566 }, + { url = "https://files.pythonhosted.org/packages/12/bb/11250d2906ee2e8b466b5f93e6b19d525f3e0254ac8b445b56e618527718/contourpy-1.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3e8b974d8db2c5610fb4e76307e265de0edb655ae8169e8b21f41807ccbeec4b", size = 269555 }, + { url = "https://files.pythonhosted.org/packages/67/71/1e6e95aee21a500415f5d2dbf037bf4567529b6a4e986594d7026ec5ae90/contourpy-1.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:20914c8c973f41456337652a6eeca26d2148aa96dd7ac323b74516988bea89fc", size = 254549 }, + { url = "https://files.pythonhosted.org/packages/31/2c/b88986e8d79ac45efe9d8801ae341525f38e087449b6c2f2e6050468a42c/contourpy-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19d40d37c1c3a4961b4619dd9d77b12124a453cc3d02bb31a07d58ef684d3d86", size = 313000 }, + { url = "https://files.pythonhosted.org/packages/c4/18/65280989b151fcf33a8352f992eff71e61b968bef7432fbfde3a364f0730/contourpy-1.3.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:113231fe3825ebf6f15eaa8bc1f5b0ddc19d42b733345eae0934cb291beb88b6", size = 352925 }, + { url = "https://files.pythonhosted.org/packages/f5/c7/5fd0146c93220dbfe1a2e0f98969293b86ca9bc041d6c90c0e065f4619ad/contourpy-1.3.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:4dbbc03a40f916a8420e420d63e96a1258d3d1b58cbdfd8d1f07b49fcbd38e85", size = 323693 }, + { url = "https://files.pythonhosted.org/packages/85/fc/7fa5d17daf77306840a4e84668a48ddff09e6bc09ba4e37e85ffc8e4faa3/contourpy-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a04ecd68acbd77fa2d39723ceca4c3197cb2969633836ced1bea14e219d077c", size = 326184 }, + { url = "https://files.pythonhosted.org/packages/ef/e7/104065c8270c7397c9571620d3ab880558957216f2b5ebb7e040f85eeb22/contourpy-1.3.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c414fc1ed8ee1dbd5da626cf3710c6013d3d27456651d156711fa24f24bd1291", size = 1268031 }, + { url = "https://files.pythonhosted.org/packages/e2/4a/c788d0bdbf32c8113c2354493ed291f924d4793c4a2e85b69e737a21a658/contourpy-1.3.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:31c1b55c1f34f80557d3830d3dd93ba722ce7e33a0b472cba0ec3b6535684d8f", size = 1325995 }, + { url = "https://files.pythonhosted.org/packages/a6/e6/a2f351a90d955f8b0564caf1ebe4b1451a3f01f83e5e3a414055a5b8bccb/contourpy-1.3.1-cp311-cp311-win32.whl", hash = "sha256:f611e628ef06670df83fce17805c344710ca5cde01edfdc72751311da8585375", size = 174396 }, + { url = "https://files.pythonhosted.org/packages/a8/7e/cd93cab453720a5d6cb75588cc17dcdc08fc3484b9de98b885924ff61900/contourpy-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:b2bdca22a27e35f16794cf585832e542123296b4687f9fd96822db6bae17bfc9", size = 219787 }, + { url = "https://files.pythonhosted.org/packages/3e/4f/e56862e64b52b55b5ddcff4090085521fc228ceb09a88390a2b103dccd1b/contourpy-1.3.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:b457d6430833cee8e4b8e9b6f07aa1c161e5e0d52e118dc102c8f9bd7dd060d6", size = 265605 }, + { url = "https://files.pythonhosted.org/packages/b0/2e/52bfeeaa4541889f23d8eadc6386b442ee2470bd3cff9baa67deb2dd5c57/contourpy-1.3.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:cb76c1a154b83991a3cbbf0dfeb26ec2833ad56f95540b442c73950af2013750", size = 315040 }, + { url = "https://files.pythonhosted.org/packages/52/94/86bfae441707205634d80392e873295652fc313dfd93c233c52c4dc07874/contourpy-1.3.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:44a29502ca9c7b5ba389e620d44f2fbe792b1fb5734e8b931ad307071ec58c53", size = 218221 }, +] + +[[package]] +name = "coverage" +version = "7.6.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/84/ba/ac14d281f80aab516275012e8875991bb06203957aa1e19950139238d658/coverage-7.6.10.tar.gz", hash = "sha256:7fb105327c8f8f0682e29843e2ff96af9dcbe5bab8eeb4b398c6a33a16d80a23", size = 803868 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/12/2a2a923edf4ddabdffed7ad6da50d96a5c126dae7b80a33df7310e329a1e/coverage-7.6.10-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5c912978f7fbf47ef99cec50c4401340436d200d41d714c7a4766f377c5b7b78", size = 207982 }, + { url = "https://files.pythonhosted.org/packages/ca/49/6985dbca9c7be3f3cb62a2e6e492a0c88b65bf40579e16c71ae9c33c6b23/coverage-7.6.10-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a01ec4af7dfeb96ff0078ad9a48810bb0cc8abcb0115180c6013a6b26237626c", size = 208414 }, + { url = "https://files.pythonhosted.org/packages/35/93/287e8f1d1ed2646f4e0b2605d14616c9a8a2697d0d1b453815eb5c6cebdb/coverage-7.6.10-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3b204c11e2b2d883946fe1d97f89403aa1811df28ce0447439178cc7463448a", size = 236860 }, + { url = "https://files.pythonhosted.org/packages/de/e1/cfdb5627a03567a10031acc629b75d45a4ca1616e54f7133ca1fa366050a/coverage-7.6.10-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32ee6d8491fcfc82652a37109f69dee9a830e9379166cb73c16d8dc5c2915165", size = 234758 }, + { url = 
"https://files.pythonhosted.org/packages/6d/85/fc0de2bcda3f97c2ee9fe8568f7d48f7279e91068958e5b2cc19e0e5f600/coverage-7.6.10-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675cefc4c06e3b4c876b85bfb7c59c5e2218167bbd4da5075cbe3b5790a28988", size = 235920 }, + { url = "https://files.pythonhosted.org/packages/79/73/ef4ea0105531506a6f4cf4ba571a214b14a884630b567ed65b3d9c1975e1/coverage-7.6.10-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f4f620668dbc6f5e909a0946a877310fb3d57aea8198bde792aae369ee1c23b5", size = 234986 }, + { url = "https://files.pythonhosted.org/packages/c6/4d/75afcfe4432e2ad0405c6f27adeb109ff8976c5e636af8604f94f29fa3fc/coverage-7.6.10-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:4eea95ef275de7abaef630c9b2c002ffbc01918b726a39f5a4353916ec72d2f3", size = 233446 }, + { url = "https://files.pythonhosted.org/packages/86/5b/efee56a89c16171288cafff022e8af44f8f94075c2d8da563c3935212871/coverage-7.6.10-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e2f0280519e42b0a17550072861e0bc8a80a0870de260f9796157d3fca2733c5", size = 234566 }, + { url = "https://files.pythonhosted.org/packages/f2/db/67770cceb4a64d3198bf2aa49946f411b85ec6b0a9b489e61c8467a4253b/coverage-7.6.10-cp310-cp310-win32.whl", hash = "sha256:bc67deb76bc3717f22e765ab3e07ee9c7a5e26b9019ca19a3b063d9f4b874244", size = 210675 }, + { url = "https://files.pythonhosted.org/packages/8d/27/e8bfc43f5345ec2c27bc8a1fa77cdc5ce9dcf954445e11f14bb70b889d14/coverage-7.6.10-cp310-cp310-win_amd64.whl", hash = "sha256:0f460286cb94036455e703c66988851d970fdfd8acc2a1122ab7f4f904e4029e", size = 211518 }, + { url = "https://files.pythonhosted.org/packages/85/d2/5e175fcf6766cf7501a8541d81778fd2f52f4870100e791f5327fd23270b/coverage-7.6.10-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ea3c8f04b3e4af80e17bab607c386a830ffc2fb88a5484e1df756478cf70d1d3", size = 208088 }, + { url = 
"https://files.pythonhosted.org/packages/4b/6f/06db4dc8fca33c13b673986e20e466fd936235a6ec1f0045c3853ac1b593/coverage-7.6.10-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:507a20fc863cae1d5720797761b42d2d87a04b3e5aeb682ef3b7332e90598f43", size = 208536 }, + { url = "https://files.pythonhosted.org/packages/0d/62/c6a0cf80318c1c1af376d52df444da3608eafc913b82c84a4600d8349472/coverage-7.6.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d37a84878285b903c0fe21ac8794c6dab58150e9359f1aaebbeddd6412d53132", size = 240474 }, + { url = "https://files.pythonhosted.org/packages/a3/59/750adafc2e57786d2e8739a46b680d4fb0fbc2d57fbcb161290a9f1ecf23/coverage-7.6.10-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a534738b47b0de1995f85f582d983d94031dffb48ab86c95bdf88dc62212142f", size = 237880 }, + { url = "https://files.pythonhosted.org/packages/2c/f8/ef009b3b98e9f7033c19deb40d629354aab1d8b2d7f9cfec284dbedf5096/coverage-7.6.10-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d7a2bf79378d8fb8afaa994f91bfd8215134f8631d27eba3e0e2c13546ce994", size = 239750 }, + { url = "https://files.pythonhosted.org/packages/a6/e2/6622f3b70f5f5b59f705e680dae6db64421af05a5d1e389afd24dae62e5b/coverage-7.6.10-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6713ba4b4ebc330f3def51df1d5d38fad60b66720948112f114968feb52d3f99", size = 238642 }, + { url = "https://files.pythonhosted.org/packages/2d/10/57ac3f191a3c95c67844099514ff44e6e19b2915cd1c22269fb27f9b17b6/coverage-7.6.10-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ab32947f481f7e8c763fa2c92fd9f44eeb143e7610c4ca9ecd6a36adab4081bd", size = 237266 }, + { url = "https://files.pythonhosted.org/packages/ee/2d/7016f4ad9d553cabcb7333ed78ff9d27248ec4eba8dd21fa488254dff894/coverage-7.6.10-cp311-cp311-musllinux_1_2_x86_64.whl", hash = 
"sha256:7bbd8c8f1b115b892e34ba66a097b915d3871db7ce0e6b9901f462ff3a975377", size = 238045 }, + { url = "https://files.pythonhosted.org/packages/a7/fe/45af5c82389a71e0cae4546413266d2195c3744849669b0bab4b5f2c75da/coverage-7.6.10-cp311-cp311-win32.whl", hash = "sha256:299e91b274c5c9cdb64cbdf1b3e4a8fe538a7a86acdd08fae52301b28ba297f8", size = 210647 }, + { url = "https://files.pythonhosted.org/packages/db/11/3f8e803a43b79bc534c6a506674da9d614e990e37118b4506faf70d46ed6/coverage-7.6.10-cp311-cp311-win_amd64.whl", hash = "sha256:489a01f94aa581dbd961f306e37d75d4ba16104bbfa2b0edb21d29b73be83609", size = 211508 }, + { url = "https://files.pythonhosted.org/packages/a1/70/de81bfec9ed38a64fc44a77c7665e20ca507fc3265597c28b0d989e4082e/coverage-7.6.10-pp39.pp310-none-any.whl", hash = "sha256:fd34e7b3405f0cc7ab03d54a334c17a9e802897580d964bd8c2001f4b9fd488f", size = 200223 }, +] + +[package.optional-dependencies] +toml = [ + { name = "tomli", marker = "python_full_version <= '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] + +[[package]] +name = "cssutils" +version = "2.11.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "more-itertools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/9f/329d26121fe165be44b1dfff21aa0dc348f04633931f1d20ed6cf448a236/cssutils-2.11.1.tar.gz", hash = "sha256:0563a76513b6af6eebbe788c3bf3d01c920e46b3f90c8416738c5cfc773ff8e2", size = 711657 } +wheels 
= [ + { url = "https://files.pythonhosted.org/packages/a7/ec/bb273b7208c606890dc36540fe667d06ce840a6f62f9fae7e658fcdc90fb/cssutils-2.11.1-py3-none-any.whl", hash = "sha256:a67bfdfdff4f3867fab43698ec4897c1a828eca5973f4073321b3bccaf1199b1", size = 385747 }, +] + +[[package]] +name = "cupy-cuda11x" +version = "13.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastrlock" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace')" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/1b/3afbaea2b78114c82b33ecc9affc79b7d9f4899945940b9b50790c93fd33/cupy_cuda11x-13.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ef854f0c63525d8163ab7af19f503d964de9dde0dd1cf9ea806a6ecb302cdce3", size = 109578634 }, + { url = 
"https://files.pythonhosted.org/packages/82/94/1da4205249baa861ac848dcbc36208a0b08f2ba2c414634525e53dabf818/cupy_cuda11x-13.3.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:54bf12a6663d0471e3e37e62972add348c5263ce803688f48bbfab1b20ebdb02", size = 96619611 }, + { url = "https://files.pythonhosted.org/packages/3f/ef/6924de40b67d4a0176e9c27f1ea9b0c8700935424473afd104cf72b36eb0/cupy_cuda11x-13.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:972d133efa2af80bb8ef321858ffe7cabc3abf8f58bcc4f13541dd497c05077d", size = 76006133 }, + { url = "https://files.pythonhosted.org/packages/4d/2d/9f01f25a81535572050f77ca618a54d8ad08afc13963c9fc57c162931e42/cupy_cuda11x-13.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:766ef1558a3ed967d5f092829bfb99edbcfaf75224925e1fb1a9f531e1e79f36", size = 110899612 }, + { url = "https://files.pythonhosted.org/packages/96/8f/b92bbf066ed86ec9dbeb969a5d6e6b6597bf0bab730f9e8b4c589f7cf198/cupy_cuda11x-13.3.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:77a81fa48d1a392b731885555a53cf2febde39cc33db55f2d78ba64b5ef4689b", size = 97172154 }, + { url = "https://files.pythonhosted.org/packages/08/94/113cc947b06b45b950979441a4f12f257b203d9a33796b1dbe6b82a2c36c/cupy_cuda11x-13.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:a8e8b7f7f73677afe2f70c38562f01f82688e43147550b3e192a5a2206e17fe1", size = 75976673 }, +] + +[[package]] +name = "cupy-cuda12x" +version = "13.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastrlock" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-5-gt4py-all' or extra == 'extra-5-gt4py-dace' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 
'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace')" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/34/60/dc268d1d9c5fdde4673a463feff5e9c70c59f477e647b54b501f65deef60/cupy_cuda12x-13.3.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:674488e990998042cc54d2486d3c37cae80a12ba3787636be5a10b9446dd6914", size = 103601326 }, + { url = "https://files.pythonhosted.org/packages/7a/a9/1e19ecf008011df2935d038f26f721f22f2804c00077fc024f088e0996e6/cupy_cuda12x-13.3.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:cf4a2a0864364715881b50012927e88bd7ec1e6f1de3987970870861ae5ed25e", size = 90619949 }, + { url = "https://files.pythonhosted.org/packages/ce/6b/e77e3fc20648d323021f55d4e0fafc5572eff50c37750d6aeae868e110d8/cupy_cuda12x-13.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:7c0dc8c49d271d1c03e49a5d6c8e42e8fee3114b10f269a5ecc387731d693eaa", size = 69594183 }, + { url = "https://files.pythonhosted.org/packages/95/c9/0b88c015e98aad808c18f938267585d79e6211fe08650e0de7132e235e40/cupy_cuda12x-13.3.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:c0cc095b9a3835fd5db66c45ed3c58ecdc5a3bb14e53e1defbfd4a0ce5c8ecdb", size = 104925909 }, + { url = "https://files.pythonhosted.org/packages/8c/1f/596803c35833c01a41da21c6a7bb552f1ed56d807090ddc6727c8f396d7d/cupy_cuda12x-13.3.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:a0e3bead04e502ebde515f0343444ca3f4f7aed09cbc3a316a946cba97f2ea66", size = 91172049 }, + { url = 
"https://files.pythonhosted.org/packages/d0/a8/5b5929830d2da94608d8126bafe2c52d69929a197fd8698ac09142c068ba/cupy_cuda12x-13.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:5f11df1149c7219858b27e4c8be92cb4eaf7364c94af6b78c40dffb98050a61f", size = 69564719 }, +] + +[[package]] +name = "cupy-rocm-4-3" +version = "13.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastrlock" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace')" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/16/7fd4bc8a8f1a4697f76e52c13f348f284fcc5c37195efd7e4c5d0eb2b15c/cupy_rocm_4_3-13.3.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:fc6b93be093bcea8b820baed856b61efc5c8cb09b02ebdc890431655714366ad", size = 41259087 }, + { url = 
"https://files.pythonhosted.org/packages/2e/ee/e893b0fdc6b347d8d65024442e5baf5ae13ee92c1364152e8f343906793d/cupy_rocm_4_3-13.3.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:f5e6886f1750810ddc3d261adf84d98b4d42f1d3cb2be5b7f5da181c8bf1593d", size = 41775360 }, +] + +[[package]] +name = "cupy-rocm-5-0" +version = "13.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastrlock" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-cuda11' and extra != 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace')" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/2e/6e4ecd65f5158808a54ef75d90fc7a884afb55bd405c4a7dbc34bb4a8f96/cupy_rocm_5_0-13.3.0-cp310-cp310-manylinux2014_x86_64.whl", hash = 
"sha256:d4c370441f7778b00f3ab80d6f0d669ea0215b6e96bbed9663ecce7ffce83fa9", size = 60056031 }, + { url = "https://files.pythonhosted.org/packages/08/52/8b5b6b32c84616989a2a84f02d9f4ca39d812de9f630276a664f321840bf/cupy_rocm_5_0-13.3.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:00907762735d182737bee317f532dc381337fb8e978bd846acb268df463b2d7b", size = 60576552 }, +] + +[[package]] +name = "cycler" +version = "0.12.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a9/95/a3dbbb5028f35eafb79008e7522a75244477d2838f38cbb722248dabc2a8/cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c", size = 7615 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30", size = 8321 }, +] + +[[package]] +name = "cython" +version = "3.0.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/84/4d/b720d6000f4ca77f030bd70f12550820f0766b568e43f11af7f7ad9061aa/cython-3.0.11.tar.gz", hash = "sha256:7146dd2af8682b4ca61331851e6aebce9fe5158e75300343f80c07ca80b1faff", size = 2755544 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/13/7f/ab5796a0951328d7818b771c36fe7e1a2077cffa28c917d9fa4a642728c3/Cython-3.0.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:44292aae17524abb4b70a25111fe7dec1a0ad718711d47e3786a211d5408fdaa", size = 3100879 }, + { url = "https://files.pythonhosted.org/packages/d8/3b/67480e609537e9fc899864847910ded481b82d033fea1b7fcf85893a2fc4/Cython-3.0.11-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a75d45fbc20651c1b72e4111149fed3b33d270b0a4fb78328c54d965f28d55e1", size = 3461957 }, + { url = 
"https://files.pythonhosted.org/packages/f0/89/b1ae45689abecca777f95462781a76e67ff46b55495a481ec5a73a739994/Cython-3.0.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d89a82937ce4037f092e9848a7bbcc65bc8e9fc9aef2bb74f5c15e7d21a73080", size = 3627062 }, + { url = "https://files.pythonhosted.org/packages/44/77/a651da74d5d41c6045bbe0b6990b1515bf4850cd7a8d8580333c90dfce2e/Cython-3.0.11-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a8ea2e7e2d3bc0d8630dafe6c4a5a89485598ff8a61885b74f8ed882597efd5", size = 3680431 }, + { url = "https://files.pythonhosted.org/packages/59/45/60e7e8db93c3eb8b2af8c64020c1fa502e355f4b762886a24d46e433f395/Cython-3.0.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:cee29846471ce60226b18e931d8c1c66a158db94853e3e79bc2da9bd22345008", size = 3497314 }, + { url = "https://files.pythonhosted.org/packages/f8/0b/6919025958926625319f83523ee7f45e7e7ae516b8054dcff6eb710daf32/Cython-3.0.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:eeb6860b0f4bfa402de8929833fe5370fa34069c7ebacb2d543cb017f21fb891", size = 3709091 }, + { url = "https://files.pythonhosted.org/packages/52/3c/c21b9b9271dfaa46fa2938de730f62fc94b9c2ec25ec400585e372f35dcd/Cython-3.0.11-cp310-cp310-win32.whl", hash = "sha256:3699391125ab344d8d25438074d1097d9ba0fb674d0320599316cfe7cf5f002a", size = 2576110 }, + { url = "https://files.pythonhosted.org/packages/f9/de/19fdd1c7a52e0534bf5f544e0346c15d71d20338dbd013117f763b94613f/Cython-3.0.11-cp310-cp310-win_amd64.whl", hash = "sha256:d02f4ebe15aac7cdacce1a628e556c1983f26d140fd2e0ac5e0a090e605a2d38", size = 2776386 }, + { url = "https://files.pythonhosted.org/packages/f8/73/e55be864199cd674cb3426a052726c205589b1ac66fb0090e7fe793b60b3/Cython-3.0.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:75ba1c70b6deeaffbac123856b8d35f253da13552207aa969078611c197377e4", size = 3113599 }, + { url = 
"https://files.pythonhosted.org/packages/09/c9/537108d0980beffff55336baaf8b34162ad0f3f33ededcb5db07069bc8ef/Cython-3.0.11-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af91497dc098718e634d6ec8f91b182aea6bb3690f333fc9a7777bc70abe8810", size = 3441131 }, + { url = "https://files.pythonhosted.org/packages/93/03/e330b241ad8aa12bb9d98b58fb76d4eb7dcbe747479aab5c29fce937b9e7/Cython-3.0.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3999fb52d3328a6a5e8c63122b0a8bd110dfcdb98dda585a3def1426b991cba7", size = 3595065 }, + { url = "https://files.pythonhosted.org/packages/4a/84/a3c40f2c0439d425daa5aa4e3a6fdbbb41341a14a6fd97f94906f528d9a4/Cython-3.0.11-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d566a4e09b8979be8ab9f843bac0dd216c81f5e5f45661a9b25cd162ed80508c", size = 3641667 }, + { url = "https://files.pythonhosted.org/packages/6d/93/bdb61e0254ed8f1d21a14088a473584ecb1963d68dba5682158aa45c70ef/Cython-3.0.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:46aec30f217bdf096175a1a639203d44ac73a36fe7fa3dd06bd012e8f39eca0f", size = 3503650 }, + { url = "https://files.pythonhosted.org/packages/f8/62/0da548144c71176155ff5355c4cc40fb28b9effe22e830b55cec8072bdf2/Cython-3.0.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ddd1fe25af330f4e003421636746a546474e4ccd8f239f55d2898d80983d20ed", size = 3709662 }, + { url = "https://files.pythonhosted.org/packages/56/d3/d9c9eaf3611a9fe5256266d07b6a5f9069aa84d20d9f6aa5824289513315/Cython-3.0.11-cp311-cp311-win32.whl", hash = "sha256:221de0b48bf387f209003508e602ce839a80463522fc6f583ad3c8d5c890d2c1", size = 2577870 }, + { url = "https://files.pythonhosted.org/packages/fd/10/236fcc0306f85a2db1b8bc147aea714b66a2f27bac4d9e09e5b2c5d5dcca/Cython-3.0.11-cp311-cp311-win_amd64.whl", hash = "sha256:3ff8ac1f0ecd4f505db4ab051e58e4531f5d098b6ac03b91c3b902e8d10c67b3", size = 2785053 }, + { url = 
"https://files.pythonhosted.org/packages/43/39/bdbec9142bc46605b54d674bf158a78b191c2b75be527c6dcf3e6dfe90b8/Cython-3.0.11-py2.py3-none-any.whl", hash = "sha256:0e25f6425ad4a700d7f77cd468da9161e63658837d1bc34861a9861a4ef6346d", size = 1171267 }, +] + +[[package]] +name = "cytoolz" +version = "1.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "toolz" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a7/f9/3243eed3a6545c2a33a21f74f655e3fcb5d2192613cd3db81a93369eb339/cytoolz-1.0.1.tar.gz", hash = "sha256:89cc3161b89e1bb3ed7636f74ed2e55984fd35516904fc878cae216e42b2c7d6", size = 626652 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/d9/f13d66c16cff1fa1cb6c234698029877c456f35f577ef274aba3b86e7c51/cytoolz-1.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cec9af61f71fc3853eb5dca3d42eb07d1f48a4599fa502cbe92adde85f74b042", size = 403515 }, + { url = "https://files.pythonhosted.org/packages/4b/2d/4cdf848a69300c7d44984f2ebbebb3b8576e5449c8dea157298f3bdc4da3/cytoolz-1.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:140bbd649dbda01e91add7642149a5987a7c3ccc251f2263de894b89f50b6608", size = 383936 }, + { url = "https://files.pythonhosted.org/packages/72/a4/ccfdd3f0ed9cc818f734b424261f6018fc61e3ec833bf85225a9aca0d994/cytoolz-1.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e90124bdc42ff58b88cdea1d24a6bc5f776414a314cc4d94f25c88badb3a16d1", size = 1934569 }, + { url = "https://files.pythonhosted.org/packages/50/fc/38d5344fa595683ad10dc819cfc1d8b9d2b3391ccf3e8cb7bab4899a01f5/cytoolz-1.0.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e74801b751e28f7c5cc3ad264c123954a051f546f2fdfe089f5aa7a12ccfa6da", size = 2015129 }, + { url = "https://files.pythonhosted.org/packages/28/29/75261748dc54a20a927f33641f4e9aac674cfc6d3fbd4f332e10d0b37639/cytoolz-1.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:582dad4545ddfb5127494ef23f3fa4855f1673a35d50c66f7638e9fb49805089", size = 2000506 }, + { url = "https://files.pythonhosted.org/packages/00/ae/e4ead004cc2698281d153c4a5388638d67cdb5544d6d6cc1e5b3db2bd2a3/cytoolz-1.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd7bd0618e16efe03bd12f19c2a26a27e6e6b75d7105adb7be1cd2a53fa755d8", size = 1957537 }, + { url = "https://files.pythonhosted.org/packages/4a/ff/4f3aa07f4f47701f7f63df60ce0a5669fa09c256c3d4a33503a9414ea5cc/cytoolz-1.0.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d74cca6acf1c4af58b2e4a89cc565ed61c5e201de2e434748c93e5a0f5c541a5", size = 1863331 }, + { url = "https://files.pythonhosted.org/packages/a2/29/654f57f2a9b8e9765a4ab876765f64f94530b61fc6471a07feea42ece6d4/cytoolz-1.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:823a3763828d8d457f542b2a45d75d6b4ced5e470b5c7cf2ed66a02f508ed442", size = 1849938 }, + { url = "https://files.pythonhosted.org/packages/bc/7b/11f457db6b291060a98315ab2c7198077d8bddeeebe5f7126d9dad98cc54/cytoolz-1.0.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:51633a14e6844c61db1d68c1ffd077cf949f5c99c60ed5f1e265b9e2966f1b52", size = 1852345 }, + { url = "https://files.pythonhosted.org/packages/6b/92/0dccc96ce0323be236d404f5084479b79b747fa0e74e43a270e95868b5f9/cytoolz-1.0.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:f3ec9b01c45348f1d0d712507d54c2bfd69c62fbd7c9ef555c9d8298693c2432", size = 1989877 }, + { url = "https://files.pythonhosted.org/packages/a3/c8/1c5203a81200bae51aa8f7b5fad613f695bf1afa03f16251ca23ecb2ef9f/cytoolz-1.0.1-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:1855022b712a9c7a5bce354517ab4727a38095f81e2d23d3eabaf1daeb6a3b3c", size = 1994492 }, + { url = "https://files.pythonhosted.org/packages/e2/8a/04bc193c4d7ced8ef6bb62cdcd0bf40b5e5eb26586ed2cfb4433ec7dfd0a/cytoolz-1.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = 
"sha256:9930f7288c4866a1dc1cc87174f0c6ff4cad1671eb1f6306808aa6c445857d78", size = 1896077 }, + { url = "https://files.pythonhosted.org/packages/21/a5/bee63a58f51d2c74856db66e6119a014464ff8cb1c9387fa4bd2d94e49b0/cytoolz-1.0.1-cp310-cp310-win32.whl", hash = "sha256:a9baad795d72fadc3445ccd0f122abfdbdf94269157e6d6d4835636dad318804", size = 322135 }, + { url = "https://files.pythonhosted.org/packages/e8/16/7abfb1685e8b7f2838264551ee33651748994813f566ac4c3d737dfe90e5/cytoolz-1.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:ad95b386a84e18e1f6136f6d343d2509d4c3aae9f5a536f3dc96808fcc56a8cf", size = 363599 }, + { url = "https://files.pythonhosted.org/packages/dc/ea/8131ae39119820b8867cddc23716fa9f681f2b3bbce6f693e68dfb36b55b/cytoolz-1.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2d958d4f04d9d7018e5c1850790d9d8e68b31c9a2deebca74b903706fdddd2b6", size = 406162 }, + { url = "https://files.pythonhosted.org/packages/26/18/3d9bd4c146f6ea6e51300c242b20cb416966b21d481dac230e1304f1e54b/cytoolz-1.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0f445b8b731fc0ecb1865b8e68a070084eb95d735d04f5b6c851db2daf3048ab", size = 384961 }, + { url = "https://files.pythonhosted.org/packages/e4/73/9034827907c7f85c7c484c9494e905d022fb8174526004e9ef332570349e/cytoolz-1.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f546a96460a7e28eb2ec439f4664fa646c9b3e51c6ebad9a59d3922bbe65e30", size = 2091698 }, + { url = "https://files.pythonhosted.org/packages/74/af/d5c2733b0fde1a08254ff1a8a8d567874040c9eb1606363cfebc0713c73f/cytoolz-1.0.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0317681dd065532d21836f860b0563b199ee716f55d0c1f10de3ce7100c78a3b", size = 2188452 }, + { url = "https://files.pythonhosted.org/packages/6a/bb/77c71fa9c217260b4056a732d754748903423c2cdd82a673d6064741e375/cytoolz-1.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:0c0ef52febd5a7821a3fd8d10f21d460d1a3d2992f724ba9c91fbd7a96745d41", size = 2174203 }, + { url = "https://files.pythonhosted.org/packages/fc/a9/a5b4a3ff5d22faa1b60293bfe97362e2caf4a830c26d37ab5557f60d04b2/cytoolz-1.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5ebaf419acf2de73b643cf96108702b8aef8e825cf4f63209ceb078d5fbbbfd", size = 2099831 }, + { url = "https://files.pythonhosted.org/packages/35/08/7f6869ea1ff31ce5289a7d58d0e7090acfe7058baa2764473048ff61ea3c/cytoolz-1.0.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5f7f04eeb4088947585c92d6185a618b25ad4a0f8f66ea30c8db83cf94a425e3", size = 1996744 }, + { url = "https://files.pythonhosted.org/packages/46/b4/9ac424c994b51763fd1bbed62d95f8fba8fa0e45c8c3c583904fdaf8f51d/cytoolz-1.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:f61928803bb501c17914b82d457c6f50fe838b173fb40d39c38d5961185bd6c7", size = 2013733 }, + { url = "https://files.pythonhosted.org/packages/3e/99/03009765c4b87d742d5b5a8670abb56a8c7ede033c2cdaa4be8662d3b001/cytoolz-1.0.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:d2960cb4fa01ccb985ad1280db41f90dc97a80b397af970a15d5a5de403c8c61", size = 1994850 }, + { url = "https://files.pythonhosted.org/packages/40/9a/8458af9a5557e177ea42f8cf7e477bede518b0bbef564e28c4151feaa52c/cytoolz-1.0.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:b2b407cc3e9defa8df5eb46644f6f136586f70ba49eba96f43de67b9a0984fd3", size = 2155352 }, + { url = "https://files.pythonhosted.org/packages/5e/5c/2a701423e001fcbec288b4f3fc2bf67557d114c2388237fc1ae67e1e2686/cytoolz-1.0.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:8245f929144d4d3bd7b972c9593300195c6cea246b81b4c46053c48b3f044580", size = 2163515 }, + { url = "https://files.pythonhosted.org/packages/36/16/ee2e06e65d9d533bc05cd52a0b355ba9072fc8f60d77289e529c6d2e3750/cytoolz-1.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = 
"sha256:e37385db03af65763933befe89fa70faf25301effc3b0485fec1c15d4ce4f052", size = 2054431 }, + { url = "https://files.pythonhosted.org/packages/d8/d5/2fac8315f210fa1bc7106e27c19e1211580aa25bb7fa17dfd79505e5baf2/cytoolz-1.0.1-cp311-cp311-win32.whl", hash = "sha256:50f9c530f83e3e574fc95c264c3350adde8145f4f8fc8099f65f00cc595e5ead", size = 322004 }, + { url = "https://files.pythonhosted.org/packages/a9/9e/0b70b641850a95f9ff90adde9d094a4b1d81ec54dadfd97fec0a2aaf440e/cytoolz-1.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:b7f6b617454b4326af7bd3c7c49b0fc80767f134eb9fd6449917a058d17a0e3c", size = 365358 }, + { url = "https://files.pythonhosted.org/packages/d9/f7/ef2a10daaec5c0f7d781d50758c6187eee484256e356ae8ef178d6c48497/cytoolz-1.0.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:83d19d55738ad9c60763b94f3f6d3c6e4de979aeb8d76841c1401081e0e58d96", size = 345702 }, + { url = "https://files.pythonhosted.org/packages/c8/14/53c84adddedb67ff1546abb86fea04d26e24298c3ceab8436d20122ed0b9/cytoolz-1.0.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f112a71fad6ea824578e6393765ce5c054603afe1471a5c753ff6c67fd872d10", size = 385695 }, + { url = "https://files.pythonhosted.org/packages/bd/80/3ae356c5e7b8d7dc7d1adb52f6932fee85cd748ed4e1217c269d2dfd610f/cytoolz-1.0.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a515df8f8aa6e1eaaf397761a6e4aff2eef73b5f920aedf271416d5471ae5ee", size = 406261 }, + { url = "https://files.pythonhosted.org/packages/0c/31/8e43761ffc82d90bf9cab7e0959712eedcd1e33c211397e143dd42d7af57/cytoolz-1.0.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92c398e7b7023460bea2edffe5fcd0a76029580f06c3f6938ac3d198b47156f3", size = 397207 }, + { url = 
"https://files.pythonhosted.org/packages/d1/b9/fe9da37090b6444c65f848a83e390f87d8cb43d6a4df46de1556ad7e5ceb/cytoolz-1.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:3237e56211e03b13df47435b2369f5df281e02b04ad80a948ebd199b7bc10a47", size = 343358 }, +] + +[[package]] +name = "dace" +version = "1.0.0" +source = { git = "https://github.com/spcl/dace?branch=main#5097d6f1a4b6e1dc8e06be6eb4aa585a6c6e04f3" } +resolution-markers = [ + "python_full_version >= '3.11'", + "python_full_version < '3.11'", +] +dependencies = [ + { name = "aenum" }, + { name = "astunparse" }, + { name = "dill" }, + { name = "fparser" }, + { name = "networkx" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" } }, + { name = "packaging" }, + { name = "ply" }, + { name = "pyreadline", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "pyyaml" }, + { name = "sympy" }, +] + +[[package]] +name = "dace" +version = "1.0.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.11'", + "python_full_version < '3.11'", +] +dependencies = [ + { name = "aenum" }, + { name = "astunparse" }, + { name = "dill" }, + { name = "fparser" }, + { name = "networkx" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" } }, + { name = "packaging" }, + { name = "ply" }, + { name = "pyreadline", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and 
extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "pyyaml" }, + { name = "sympy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b4/67/fb1be2673868ee1f08e9c7bacc0b9b77d2bd5ff17ab47896f20006a2a1a5/dace-1.0.1.tar.gz", hash = "sha256:6f7a5defb082ed4f1a81f857d4268ed2bb606f6d9ea9c28d2831d1151e3a80f7", size = 5801727 } + +[[package]] +name = "debugpy" +version = "1.8.12" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/68/25/c74e337134edf55c4dfc9af579eccb45af2393c40960e2795a94351e8140/debugpy-1.8.12.tar.gz", hash = "sha256:646530b04f45c830ceae8e491ca1c9320a2d2f0efea3141487c82130aba70dce", size = 1641122 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/19/dd58334c0a1ec07babf80bf29fb8daf1a7ca4c1a3bbe61548e40616ac087/debugpy-1.8.12-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:a2ba7ffe58efeae5b8fad1165357edfe01464f9aef25e814e891ec690e7dd82a", size = 2076091 }, + { url = "https://files.pythonhosted.org/packages/4c/37/bde1737da15f9617d11ab7b8d5267165f1b7dae116b2585a6643e89e1fa2/debugpy-1.8.12-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbbd4149c4fc5e7d508ece083e78c17442ee13b0e69bfa6bd63003e486770f45", size = 3560717 }, + { url = "https://files.pythonhosted.org/packages/d9/ca/bc67f5a36a7de072908bc9e1156c0f0b272a9a2224cf21540ab1ffd71a1f/debugpy-1.8.12-cp310-cp310-win32.whl", hash = "sha256:b202f591204023b3ce62ff9a47baa555dc00bb092219abf5caf0e3718ac20e7c", size = 5180672 }, + { url = 
"https://files.pythonhosted.org/packages/c1/b9/e899c0a80dfa674dbc992f36f2b1453cd1ee879143cdb455bc04fce999da/debugpy-1.8.12-cp310-cp310-win_amd64.whl", hash = "sha256:9649eced17a98ce816756ce50433b2dd85dfa7bc92ceb60579d68c053f98dff9", size = 5212702 }, + { url = "https://files.pythonhosted.org/packages/af/9f/5b8af282253615296264d4ef62d14a8686f0dcdebb31a669374e22fff0a4/debugpy-1.8.12-cp311-cp311-macosx_14_0_universal2.whl", hash = "sha256:36f4829839ef0afdfdd208bb54f4c3d0eea86106d719811681a8627ae2e53dd5", size = 2174643 }, + { url = "https://files.pythonhosted.org/packages/ef/31/f9274dcd3b0f9f7d1e60373c3fa4696a585c55acb30729d313bb9d3bcbd1/debugpy-1.8.12-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a28ed481d530e3138553be60991d2d61103ce6da254e51547b79549675f539b7", size = 3133457 }, + { url = "https://files.pythonhosted.org/packages/ab/ca/6ee59e9892e424477e0c76e3798046f1fd1288040b927319c7a7b0baa484/debugpy-1.8.12-cp311-cp311-win32.whl", hash = "sha256:4ad9a94d8f5c9b954e0e3b137cc64ef3f579d0df3c3698fe9c3734ee397e4abb", size = 5106220 }, + { url = "https://files.pythonhosted.org/packages/d5/1a/8ab508ab05ede8a4eae3b139bbc06ea3ca6234f9e8c02713a044f253be5e/debugpy-1.8.12-cp311-cp311-win_amd64.whl", hash = "sha256:4703575b78dd697b294f8c65588dc86874ed787b7348c65da70cfc885efdf1e1", size = 5130481 }, + { url = "https://files.pythonhosted.org/packages/38/c4/5120ad36405c3008f451f94b8f92ef1805b1e516f6ff870f331ccb3c4cc0/debugpy-1.8.12-py2.py3-none-any.whl", hash = "sha256:274b6a2040349b5c9864e475284bce5bb062e63dce368a394b8cc865ae3b00c6", size = 5229490 }, +] + +[[package]] +name = "decorator" +version = "5.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/66/0c/8d907af351aa16b42caae42f9d6aa37b900c67308052d10fdce809f8d952/decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330", size = 35016 } +wheels = [ 
+ { url = "https://files.pythonhosted.org/packages/d5/50/83c593b07763e1161326b3b8c6686f0f4b0f24d5526546bee538c89837d6/decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186", size = 9073 }, +] + +[[package]] +name = "deepdiff" +version = "8.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "orderly-set" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/89/12/207d2ec96a526cf9d04fc2423ff9832e93b665e94b9d7c9b5198903e18a7/deepdiff-8.2.0.tar.gz", hash = "sha256:6ec78f65031485735545ffbe7a61e716c3c2d12ca6416886d5e9291fc76c46c3", size = 432573 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6c/13/d7dd6b8c297b1d5cfea4f1ebd678e68d90ab04b6613d005c0a7c506d11e1/deepdiff-8.2.0-py3-none-any.whl", hash = "sha256:5091f2cdfd372b1b9f6bfd8065ba323ae31118dc4e42594371b38c8bea3fd0a4", size = 83672 }, +] + +[[package]] +name = "devtools" +version = "0.12.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "asttokens" }, + { name = "executing" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/84/75/b78198620640d394bc435c17bb49db18419afdd6cfa3ed8bcfe14034ec80/devtools-0.12.2.tar.gz", hash = "sha256:efceab184cb35e3a11fa8e602cc4fadacaa2e859e920fc6f87bf130b69885507", size = 75005 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/ae/afb1487556e2dc827a17097aac8158a25b433a345386f0e249f6d2694ccb/devtools-0.12.2-py3-none-any.whl", hash = "sha256:c366e3de1df4cdd635f1ad8cbcd3af01a384d7abda71900e68d43b04eb6aaca7", size = 19411 }, +] + +[[package]] +name = "dict2css" +version = "0.3.0.post1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cssutils" }, + { name = "domdf-python-tools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/24/eb/776eef1f1aa0188c0fc165c3a60b71027539f71f2eedc43ad21b060e9c39/dict2css-0.3.0.post1.tar.gz", hash = 
"sha256:89c544c21c4ca7472c3fffb9d37d3d926f606329afdb751dc1de67a411b70719", size = 7845 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/47/290daabcf91628f4fc0e17c75a1690b354ba067066cd14407712600e609f/dict2css-0.3.0.post1-py3-none-any.whl", hash = "sha256:f006a6b774c3e31869015122ae82c491fd25e7de4a75607a62aa3e798f837e0d", size = 25647 }, +] + +[[package]] +name = "dill" +version = "0.3.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/70/43/86fe3f9e130c4137b0f1b50784dd70a5087b911fe07fa81e53e0c4c47fea/dill-0.3.9.tar.gz", hash = "sha256:81aa267dddf68cbfe8029c42ca9ec6a4ab3b22371d1c450abc54422577b4512c", size = 187000 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/d1/e73b6ad76f0b1fb7f23c35c6d95dbc506a9c8804f43dda8cb5b0fa6331fd/dill-0.3.9-py3-none-any.whl", hash = "sha256:468dff3b89520b474c0397703366b7b95eebe6303f108adf9b19da1f702be87a", size = 119418 }, +] + +[[package]] +name = "diskcache" +version = "5.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3f/21/1c1ffc1a039ddcc459db43cc108658f32c57d271d7289a2794e401d0fdb6/diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc", size = 67916 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/27/4570e78fc0bf5ea0ca45eb1de3818a23787af9b390c0b0a0033a1b8236f9/diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19", size = 45550 }, +] + +[[package]] +name = "distlib" +version = "0.3.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0d/dd/1bec4c5ddb504ca60fc29472f3d27e8d4da1257a854e1d96742f15c1d02d/distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403", size = 613923 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973 }, +] + +[[package]] +name = "docutils" +version = "0.21.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ae/ed/aefcc8cd0ba62a0560c3c18c33925362d46c6075480bfa4df87b28e169a9/docutils-0.21.2.tar.gz", hash = "sha256:3a6b18732edf182daa3cd12775bbb338cf5691468f91eeeb109deff6ebfa986f", size = 2204444 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/d7/9322c609343d929e75e7e5e6255e614fcc67572cfd083959cdef3b7aad79/docutils-0.21.2-py3-none-any.whl", hash = "sha256:dafca5b9e384f0e419294eb4d2ff9fa826435bf15f15b7bd45723e8ad76811b2", size = 587408 }, +] + +[[package]] +name = "domdf-python-tools" +version = "3.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "natsort" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6b/78/974e10c583ba9d2302e748c9585313a7f2c7ba00e4f600324f432e38fe68/domdf_python_tools-3.9.0.tar.gz", hash = "sha256:1f8a96971178333a55e083e35610d7688cd7620ad2b99790164e1fc1a3614c18", size = 103792 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/de/e9/7447a88b217650a74927d3444a89507986479a69b83741900eddd34167fe/domdf_python_tools-3.9.0-py3-none-any.whl", hash = "sha256:4e1ef365cbc24627d6d1e90cf7d46d8ab8df967e1237f4a26885f6986c78872e", size = 127106 }, +] + +[[package]] +name = "esbonio" +version = "0.16.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "platformdirs" }, + { name = "pygls" }, + { name = "pyspellchecker" }, + { name = "sphinx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/67/c5/0c89af3da1f3133b53f3ba8ae677ed4d4ddff33eec50dbf32c95e01ed2d2/esbonio-0.16.5.tar.gz", hash = 
"sha256:acab2e16c6cf8f7232fb04e0d48514ce50566516b1f6fcf669ccf2f247e8b10f", size = 145347 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d8/ca/a0296fca375d4324f471bb34d2ce8a585b48fb9eae21cf9abe00913eb899/esbonio-0.16.5-py3-none-any.whl", hash = "sha256:04ba926e3603f7b1fde1abc690b47afd60749b64b1029b6bce8e1de0bb284921", size = 170830 }, +] + +[[package]] +name = "exceptiongroup" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/09/35/2495c4ac46b980e4ca1f6ad6db102322ef3ad2410b79fdde159a4b0f3b92/exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc", size = 28883 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453 }, +] + +[[package]] +name = "execnet" +version = "2.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bb/ff/b4c0dc78fbe20c3e59c0c7334de0c27eb4001a2b2017999af398bf730817/execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3", size = 166524 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/09/2aea36ff60d16dd8879bdb2f5b3ee0ba8d08cbbdcdfe870e695ce3784385/execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc", size = 40612 }, +] + +[[package]] +name = "executing" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/91/50/a9d80c47ff289c611ff12e63f7c5d13942c65d68125160cefd768c73e6e4/executing-2.2.0.tar.gz", hash = "sha256:5d108c028108fe2551d1a7b2e8b713341e2cb4fc0aa7dcf966fa4327a5226755", size = 978693 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 }, +] + +[[package]] +name = "factory-boy" +version = "3.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "faker" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/99/3d/8070dde623341401b1c80156583d4c793058fe250450178218bb6e45526c/factory_boy-3.3.1.tar.gz", hash = "sha256:8317aa5289cdfc45f9cae570feb07a6177316c82e34d14df3c2e1f22f26abef0", size = 163924 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/cf/44ec67152f3129d0114c1499dd34f0a0a0faf43d9c2af05bc535746ca482/factory_boy-3.3.1-py2.py3-none-any.whl", hash = "sha256:7b1113c49736e1e9995bc2a18f4dbf2c52cf0f841103517010b1d825712ce3ca", size = 36878 }, +] + +[[package]] +name = "faker" +version = "35.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dateutil" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6c/d9/c5bc5edaeea1a3a5da6e7f93a5c0bdd49e0740d8c4a1e7ea9515fd4da2ed/faker-35.2.0.tar.gz", hash = "sha256:28c24061780f83b45d9cb15a72b8f143b09d276c9ff52eb557744b7a89e8ba19", size = 1874908 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/db/bab82efcf241dabc93ad65cebaf0f2332cb2827b55a5d3a6ef1d52fa2c29/Faker-35.2.0-py3-none-any.whl", hash = "sha256:609abe555761ff31b0e5e16f958696e9b65c9224a7ac612ac96bfc2b8f09fe35", size = 1917786 }, +] + +[[package]] +name = "fastjsonschema" +version = "2.21.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/50/4b769ce1ac4071a1ef6d86b1a3fb56cdc3a37615e8c5519e1af96cdac366/fastjsonschema-2.21.1.tar.gz", hash = "sha256:794d4f0a58f848961ba16af7b9c85a3e88cd360df008c59aac6fc5ae9323b5d4", size = 373939 } +wheels = [ + 
{ url = "https://files.pythonhosted.org/packages/90/2b/0817a2b257fe88725c25589d89aec060581aabf668707a8d03b2e9e0cb2a/fastjsonschema-2.21.1-py3-none-any.whl", hash = "sha256:c9e5b7e908310918cf494a434eeb31384dd84a98b57a30bcb1f535015b554667", size = 23924 }, +] + +[[package]] +name = "fastrlock" +version = "0.8.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/73/b1/1c3d635d955f2b4bf34d45abf8f35492e04dbd7804e94ce65d9f928ef3ec/fastrlock-0.8.3.tar.gz", hash = "sha256:4af6734d92eaa3ab4373e6c9a1dd0d5ad1304e172b1521733c6c3b3d73c8fa5d", size = 79327 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/02/3f771177380d8690812d5b2b7736dc6b6c8cd1c317e4572e65f823eede08/fastrlock-0.8.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:cc5fa9166e05409f64a804d5b6d01af670979cdb12cd2594f555cb33cdc155bd", size = 55094 }, + { url = "https://files.pythonhosted.org/packages/be/b4/aae7ed94b8122c325d89eb91336084596cebc505dc629b795fcc9629606d/fastrlock-0.8.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:7a77ebb0a24535ef4f167da2c5ee35d9be1e96ae192137e9dc3ff75b8dfc08a5", size = 48220 }, + { url = "https://files.pythonhosted.org/packages/96/87/9807af47617fdd65c68b0fcd1e714542c1d4d3a1f1381f591f1aa7383a53/fastrlock-0.8.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:d51f7fb0db8dab341b7f03a39a3031678cf4a98b18533b176c533c122bfce47d", size = 49551 }, + { url = "https://files.pythonhosted.org/packages/9d/12/e201634810ac9aee59f93e3953cb39f98157d17c3fc9d44900f1209054e9/fastrlock-0.8.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:767ec79b7f6ed9b9a00eb9ff62f2a51f56fdb221c5092ab2dadec34a9ccbfc6e", size = 49398 }, + { url = 
"https://files.pythonhosted.org/packages/15/a1/439962ed439ff6f00b7dce14927e7830e02618f26f4653424220a646cd1c/fastrlock-0.8.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0d6a77b3f396f7d41094ef09606f65ae57feeb713f4285e8e417f4021617ca62", size = 53334 }, + { url = "https://files.pythonhosted.org/packages/b5/9e/1ae90829dd40559ab104e97ebe74217d9da794c4bb43016da8367ca7a596/fastrlock-0.8.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:92577ff82ef4a94c5667d6d2841f017820932bc59f31ffd83e4a2c56c1738f90", size = 52495 }, + { url = "https://files.pythonhosted.org/packages/e5/8c/5e746ee6f3d7afbfbb0d794c16c71bfd5259a4e3fb1dda48baf31e46956c/fastrlock-0.8.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3df8514086e16bb7c66169156a8066dc152f3be892c7817e85bf09a27fa2ada2", size = 51972 }, + { url = "https://files.pythonhosted.org/packages/76/a7/8b91068f00400931da950f143fa0f9018bd447f8ed4e34bed3fe65ed55d2/fastrlock-0.8.3-cp310-cp310-win_amd64.whl", hash = "sha256:001fd86bcac78c79658bac496e8a17472d64d558cd2227fdc768aa77f877fe40", size = 30946 }, + { url = "https://files.pythonhosted.org/packages/90/9e/647951c579ef74b6541493d5ca786d21a0b2d330c9514ba2c39f0b0b0046/fastrlock-0.8.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:f68c551cf8a34b6460a3a0eba44bd7897ebfc820854e19970c52a76bf064a59f", size = 55233 }, + { url = "https://files.pythonhosted.org/packages/be/91/5f3afba7d14b8b7d60ac651375f50fff9220d6ccc3bef233d2bd74b73ec7/fastrlock-0.8.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:55d42f6286b9d867370af4c27bc70d04ce2d342fe450c4a4fcce14440514e695", size = 48911 }, + { url = "https://files.pythonhosted.org/packages/d5/7a/e37bd72d7d70a8a551b3b4610d028bd73ff5d6253201d5d3cf6296468bee/fastrlock-0.8.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_24_i686.whl", hash = "sha256:bbc3bf96dcbd68392366c477f78c9d5c47e5d9290cb115feea19f20a43ef6d05", size = 50357 }, 
+ { url = "https://files.pythonhosted.org/packages/0d/ef/a13b8bab8266840bf38831d7bf5970518c02603d00a548a678763322d5bf/fastrlock-0.8.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:77ab8a98417a1f467dafcd2226718f7ca0cf18d4b64732f838b8c2b3e4b55cb5", size = 50222 }, + { url = "https://files.pythonhosted.org/packages/01/e2/5e5515562b2e9a56d84659377176aef7345da2c3c22909a1897fe27e14dd/fastrlock-0.8.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:04bb5eef8f460d13b8c0084ea5a9d3aab2c0573991c880c0a34a56bb14951d30", size = 54553 }, + { url = "https://files.pythonhosted.org/packages/c0/8f/65907405a8cdb2fc8beaf7d09a9a07bb58deff478ff391ca95be4f130b70/fastrlock-0.8.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8c9d459ce344c21ff03268212a1845aa37feab634d242131bc16c2a2355d5f65", size = 53362 }, + { url = "https://files.pythonhosted.org/packages/ec/b9/ae6511e52738ba4e3a6adb7c6a20158573fbc98aab448992ece25abb0b07/fastrlock-0.8.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:33e6fa4af4f3af3e9c747ec72d1eadc0b7ba2035456c2afb51c24d9e8a56f8fd", size = 52836 }, + { url = "https://files.pythonhosted.org/packages/88/3e/c26f8192c93e8e43b426787cec04bb46ac36e72b1033b7fe5a9267155fdf/fastrlock-0.8.3-cp311-cp311-win_amd64.whl", hash = "sha256:5e5f1665d8e70f4c5b4a67f2db202f354abc80a321ce5a26ac1493f055e3ae2c", size = 31046 }, +] + +[[package]] +name = "filelock" +version = "3.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/dc/9c/0b15fb47b464e1b663b1acd1253a062aa5feecb07d4e597daea542ebd2b5/filelock-3.17.0.tar.gz", hash = "sha256:ee4e77401ef576ebb38cd7f13b9b28893194acc20a8e68e18730ba9c0e54660e", size = 18027 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/ec/00d68c4ddfedfe64159999e5f8a98fb8442729a63e2077eb9dcd89623d27/filelock-3.17.0-py3-none-any.whl", hash = 
"sha256:533dc2f7ba78dc2f0f531fc6c4940addf7b70a481e269a5a3b93be94ffbe8338", size = 16164 }, +] + +[[package]] +name = "fonttools" +version = "4.55.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/24/de7e40adc99be2aa5adc6321bbdf3cf58dbe751b87343da658dd3fc7d946/fonttools-4.55.8.tar.gz", hash = "sha256:54d481d456dcd59af25d4a9c56b2c4c3f20e9620b261b84144e5950f33e8df17", size = 3458915 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/b8/82b3444cb081798eabb8397452ddf73680e623d7fdf9c575594a2240b8a2/fonttools-4.55.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d11600f5343092697d7434f3bf77a393c7ae74be206fe30e577b9a195fd53165", size = 2752288 }, + { url = "https://files.pythonhosted.org/packages/86/8f/9c5f2172e9f6dcf52bb6477bcd5a023d056114787c8184b683c34996f5a1/fonttools-4.55.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c96f2506ce1a0beeaa9595f9a8b7446477eb133f40c0e41fc078744c28149f80", size = 2280718 }, + { url = "https://files.pythonhosted.org/packages/c6/a6/b7cd7b54412bb7a27e282ee54459cae24524ad0eab6f81ead2a91d435287/fonttools-4.55.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b5f05ef72e846e9f49ccdd74b9da4309901a4248434c63c1ee9321adcb51d65", size = 4562177 }, + { url = "https://files.pythonhosted.org/packages/0e/16/eff3be24cecb9336639148c40507f949c193642d8369352af480597633fb/fonttools-4.55.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba45b637da80a262b55b7657aec68da2ac54b8ae7891cd977a5dbe5fd26db429", size = 4604843 }, + { url = "https://files.pythonhosted.org/packages/b5/95/737574364439cbcc5e6d4f3e000f15432141680ca8cb5c216b619a3d1cab/fonttools-4.55.8-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:edcffaeadba9a334c1c3866e275d7dd495465e7dbd296f688901bdbd71758113", size = 4559127 }, + { url = 
"https://files.pythonhosted.org/packages/5f/07/ea90834742f9b3e51a05f0f15f7c817eb7aab3d6ebf4f06c4626825ccb89/fonttools-4.55.8-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b9f9fce3c9b2196e162182ec5db8af8eb3acd0d76c2eafe9fdba5f370044e556", size = 4728575 }, + { url = "https://files.pythonhosted.org/packages/93/74/0c816d83cd2945a25aed592b0cb3c9ba32e8b259781bf41dc112204129d9/fonttools-4.55.8-cp310-cp310-win32.whl", hash = "sha256:f089e8da0990cfe2d67e81d9cf581ff372b48dc5acf2782701844211cd1f0eb3", size = 2155662 }, + { url = "https://files.pythonhosted.org/packages/78/bc/f5a24229edd8cdd7494f2099e1c62fca288dad4c8637ee62df04459db27e/fonttools-4.55.8-cp310-cp310-win_amd64.whl", hash = "sha256:01ea3901b0802fc5f9e854f5aeb5bc27770dd9dd24c28df8f74ba90f8b3f5915", size = 2200126 }, + { url = "https://files.pythonhosted.org/packages/0a/e3/834e0919b34b40a6a2895f533323231bba3b8f5ae22c19ab725b84cf84c0/fonttools-4.55.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:95f5a1d4432b3cea6571f5ce4f4e9b25bf36efbd61c32f4f90130a690925d6ee", size = 2753424 }, + { url = "https://files.pythonhosted.org/packages/b6/f9/9cf7fc04da85d37cfa1c287f0a25c274d6940dad259dbaa9fd796b87bd3c/fonttools-4.55.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3d20f152de7625a0008ba1513f126daaaa0de3b4b9030aa72dd5c27294992260", size = 2281635 }, + { url = "https://files.pythonhosted.org/packages/35/1f/25330293a5bb6bd50825725270c587c2b25c2694020a82d2c424d2fd5469/fonttools-4.55.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5a3ff5bb95fd5a3962b2754f8435e6d930c84fc9e9921c51e802dddf40acd56", size = 4869363 }, + { url = "https://files.pythonhosted.org/packages/f2/e0/e58b10ef50830145ba94dbeb64b70773af61cfccea663d485c7fae2aab65/fonttools-4.55.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b99d4fd2b6d0a00c7336c8363fccc7a11eccef4b17393af75ca6e77cf93ff413", size = 4898604 }, + { url = 
"https://files.pythonhosted.org/packages/e0/66/b59025011dbae1ea10dcb60f713a10e54d17cde5c8dc48db75af79dc2088/fonttools-4.55.8-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d637e4d33e46619c79d1a6c725f74d71b574cd15fb5bbb9b6f3eba8f28363573", size = 4877804 }, + { url = "https://files.pythonhosted.org/packages/67/76/abbbae972af55d54f83fcaeb90e26aaac937c8711b5a32d7c63768c37891/fonttools-4.55.8-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0f38bfb6b7a39c4162c3eb0820a0bdf8e3bdd125cd54e10ba242397d15e32439", size = 5045913 }, + { url = "https://files.pythonhosted.org/packages/8b/f2/5eb68b5202731b008ccfd4ad6d82af9a8abdec411609e76fdd6c43881f2c/fonttools-4.55.8-cp311-cp311-win32.whl", hash = "sha256:acfec948de41cd5e640d5c15d0200e8b8e7c5c6bb82afe1ca095cbc4af1188ee", size = 2154525 }, + { url = "https://files.pythonhosted.org/packages/42/d6/96dc2462006ffa16c8d475244e372abdc47d03a7bd38be0f29e7ae552af4/fonttools-4.55.8-cp311-cp311-win_amd64.whl", hash = "sha256:604c805b41241b4880e2dc86cf2d4754c06777371c8299799ac88d836cb18c3b", size = 2201043 }, + { url = "https://files.pythonhosted.org/packages/cc/e6/efdcd5d6858b951c29d56de31a19355579d826712bf390d964a21b076ddb/fonttools-4.55.8-py3-none-any.whl", hash = "sha256:07636dae94f7fe88561f9da7a46b13d8e3f529f87fdb221b11d85f91eabceeb7", size = 1089900 }, +] + +[[package]] +name = "fparser" +version = "0.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools-scm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f6/af/570c36d7bc374646ab82f579e2bf9d24a619cc53d83f95b38b0992de3492/fparser-0.2.0.tar.gz", hash = "sha256:3901d31c104062c4e532248286929e7405e43b79a6a85815146a176673e69c82", size = 433559 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/91/03999b30650f5621dd5ec9e8245024dea1b71c4e28e52e0c7300aa0c769d/fparser-0.2.0-py3-none-any.whl", hash = "sha256:49fab105e3a977b9b9d5d4489649287c5060e94c688f9936f3d5af3a45d6f4eb", size = 639408 }, +] + +[[package]] 
+name = "frozendict" +version = "2.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bb/59/19eb300ba28e7547538bdf603f1c6c34793240a90e1a7b61b65d8517e35e/frozendict-2.4.6.tar.gz", hash = "sha256:df7cd16470fbd26fc4969a208efadc46319334eb97def1ddf48919b351192b8e", size = 316416 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/7f/e80cdbe0db930b2ba9d46ca35a41b0150156da16dfb79edcc05642690c3b/frozendict-2.4.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c3a05c0a50cab96b4bb0ea25aa752efbfceed5ccb24c007612bc63e51299336f", size = 37927 }, + { url = "https://files.pythonhosted.org/packages/29/98/27e145ff7e8e63caa95fb8ee4fc56c68acb208bef01a89c3678a66f9a34d/frozendict-2.4.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f5b94d5b07c00986f9e37a38dd83c13f5fe3bf3f1ccc8e88edea8fe15d6cd88c", size = 37945 }, + { url = "https://files.pythonhosted.org/packages/ac/f1/a10be024a9d53441c997b3661ea80ecba6e3130adc53812a4b95b607cdd1/frozendict-2.4.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4c789fd70879ccb6289a603cdebdc4953e7e5dea047d30c1b180529b28257b5", size = 117656 }, + { url = "https://files.pythonhosted.org/packages/46/a6/34c760975e6f1cb4db59a990d58dcf22287e10241c851804670c74c6a27a/frozendict-2.4.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da6a10164c8a50b34b9ab508a9420df38f4edf286b9ca7b7df8a91767baecb34", size = 117444 }, + { url = "https://files.pythonhosted.org/packages/62/dd/64bddd1ffa9617f50e7e63656b2a7ad7f0a46c86b5f4a3d2c714d0006277/frozendict-2.4.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:9a8a43036754a941601635ea9c788ebd7a7efbed2becba01b54a887b41b175b9", size = 116801 }, + { url = "https://files.pythonhosted.org/packages/45/ae/af06a8bde1947277aad895c2f26c3b8b8b6ee9c0c2ad988fb58a9d1dde3f/frozendict-2.4.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = 
"sha256:c9905dcf7aa659e6a11b8051114c9fa76dfde3a6e50e6dc129d5aece75b449a2", size = 117329 }, + { url = "https://files.pythonhosted.org/packages/d2/df/be3fa0457ff661301228f4c59c630699568c8ed9b5480f113b3eea7d0cb3/frozendict-2.4.6-cp310-cp310-win_amd64.whl", hash = "sha256:323f1b674a2cc18f86ab81698e22aba8145d7a755e0ac2cccf142ee2db58620d", size = 37522 }, + { url = "https://files.pythonhosted.org/packages/4a/6f/c22e0266b4c85f58b4613fec024e040e93753880527bf92b0c1bc228c27c/frozendict-2.4.6-cp310-cp310-win_arm64.whl", hash = "sha256:eabd21d8e5db0c58b60d26b4bb9839cac13132e88277e1376970172a85ee04b3", size = 34056 }, + { url = "https://files.pythonhosted.org/packages/04/13/d9839089b900fa7b479cce495d62110cddc4bd5630a04d8469916c0e79c5/frozendict-2.4.6-py311-none-any.whl", hash = "sha256:d065db6a44db2e2375c23eac816f1a022feb2fa98cbb50df44a9e83700accbea", size = 16148 }, +] + +[[package]] +name = "gitdb" +version = "4.0.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "smmap" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794 }, +] + +[[package]] +name = "gitpython" +version = "3.1.44" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gitdb" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/89/37df0b71473153574a5cdef8f242de422a0f5d26d7a9e231e6f169b4ad14/gitpython-3.1.44.tar.gz", hash = "sha256:c87e30b26253bf5418b01b0660f818967f3c503193838337fe5e573331249269", size = 214196 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/1d/9a/4114a9057db2f1462d5c8f8390ab7383925fe1ac012eaa42402ad65c2963/GitPython-3.1.44-py3-none-any.whl", hash = "sha256:9e0e10cda9bed1ee64bc9a6de50e7e38a9c9943241cd7f585f6df3ed28011110", size = 207599 }, +] + +[[package]] +name = "gridtools-cpp" +version = "2.3.8" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b8/b8/352120417da7a3e16cc822e95668e1843d0cd9ee7f0269b9a098893471cc/gridtools_cpp-2.3.8-py3-none-any.whl", hash = "sha256:d9cb8aadc5dca7e864677072de15596feb883844eee2158ab108d04f2f17f355", size = 420716 }, +] + +[[package]] +name = "gt4py" +source = { editable = "." } +dependencies = [ + { name = "attrs" }, + { name = "black" }, + { name = "boltons" }, + { name = "cached-property" }, + { name = "click" }, + { name = "cmake" }, + { name = "cytoolz" }, + { name = "deepdiff" }, + { name = "devtools" }, + { name = "diskcache" }, + { name = "factory-boy" }, + { name = "filelock" }, + { name = "frozendict" }, + { name = "gridtools-cpp" }, + { name = "jinja2" }, + { name = "lark" }, + { name = "mako" }, + { name = "nanobind" }, + { name = "ninja" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-5-gt4py-all' or extra == 'extra-5-gt4py-dace' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' 
and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "packaging" }, + { name = "pybind11" }, + { name = "setuptools" }, + { name = "tabulate" }, + { name = "toolz" }, + { name = "typing-extensions" }, + { name = "xxhash" }, +] + +[package.optional-dependencies] +all = [ + { name = "clang-format" }, + { name = "dace", version = "1.0.1", source = { registry = "https://pypi.org/simple" } }, + { name = "hypothesis" }, + { name = "jax" }, + { name = "pytest" }, + { name = "scipy" }, +] +cuda11 = [ + { name = "cupy-cuda11x" }, +] +cuda12 = [ + { name = "cupy-cuda12x" }, +] +dace = [ + { name = "dace", version = 
"1.0.1", source = { registry = "https://pypi.org/simple" } }, +] +dace-next = [ + { name = "dace", version = "1.0.0", source = { git = "https://github.com/spcl/dace?branch=main#5097d6f1a4b6e1dc8e06be6eb4aa585a6c6e04f3" } }, +] +formatting = [ + { name = "clang-format" }, +] +jax = [ + { name = "jax" }, +] +jax-cuda12 = [ + { name = "cupy-cuda12x" }, + { name = "jax", extra = ["cuda12-local"] }, +] +performance = [ + { name = "scipy" }, +] +rocm4-3 = [ + { name = "cupy-rocm-4-3" }, +] +rocm5-0 = [ + { name = "cupy-rocm-5-0" }, +] +testing = [ + { name = "hypothesis" }, + { name = "pytest" }, +] + +[package.dev-dependencies] +build = [ + { name = "bump-my-version" }, + { name = "cython" }, + { name = "pip" }, + { name = "setuptools" }, + { name = "wheel" }, +] +dev = [ + { name = "atlas4py" }, + { name = "bump-my-version" }, + { name = "coverage", extra = ["toml"] }, + { name = "cython" }, + { name = "esbonio" }, + { name = "hypothesis" }, + { name = "jupytext" }, + { name = "matplotlib" }, + { name = "mypy", extra = ["faster-cache"] }, + { name = "myst-parser" }, + { name = "nbmake" }, + { name = "nox" }, + { name = "pip" }, + { name = "pre-commit" }, + { name = "pygments" }, + { name = "pytest" }, + { name = "pytest-benchmark" }, + { name = "pytest-cache" }, + { name = "pytest-cov" }, + { name = "pytest-factoryboy" }, + { name = "pytest-instafail" }, + { name = "pytest-xdist", extra = ["psutil"] }, + { name = "ruff" }, + { name = "setuptools" }, + { name = "sphinx" }, + { name = "sphinx-rtd-theme" }, + { name = "sphinx-toolbox" }, + { name = "tach" }, + { name = "types-decorator" }, + { name = "types-docutils" }, + { name = "types-pytz" }, + { name = "types-pyyaml" }, + { name = "types-tabulate" }, + { name = "wheel" }, +] +docs = [ + { name = "esbonio" }, + { name = "jupytext" }, + { name = "matplotlib" }, + { name = "myst-parser" }, + { name = "pygments" }, + { name = "sphinx" }, + { name = "sphinx-rtd-theme" }, + { name = "sphinx-toolbox" }, +] +frameworks = [ + 
{ name = "atlas4py" }, +] +lint = [ + { name = "pre-commit" }, + { name = "ruff" }, + { name = "tach" }, +] +test = [ + { name = "coverage", extra = ["toml"] }, + { name = "hypothesis" }, + { name = "nbmake" }, + { name = "nox" }, + { name = "pytest" }, + { name = "pytest-benchmark" }, + { name = "pytest-cache" }, + { name = "pytest-cov" }, + { name = "pytest-factoryboy" }, + { name = "pytest-instafail" }, + { name = "pytest-xdist", extra = ["psutil"] }, +] +typing = [ + { name = "mypy", extra = ["faster-cache"] }, + { name = "types-decorator" }, + { name = "types-docutils" }, + { name = "types-pytz" }, + { name = "types-pyyaml" }, + { name = "types-tabulate" }, +] + +[package.metadata] +requires-dist = [ + { name = "attrs", specifier = ">=21.3" }, + { name = "black", specifier = ">=22.3" }, + { name = "boltons", specifier = ">=20.1" }, + { name = "cached-property", specifier = ">=1.5.1" }, + { name = "clang-format", marker = "extra == 'formatting'", specifier = ">=9.0" }, + { name = "click", specifier = ">=8.0.0" }, + { name = "cmake", specifier = ">=3.22" }, + { name = "cupy-cuda11x", marker = "extra == 'cuda11'", specifier = ">=12.0" }, + { name = "cupy-cuda12x", marker = "extra == 'cuda12'", specifier = ">=12.0" }, + { name = "cupy-rocm-4-3", marker = "extra == 'rocm4-3'", specifier = ">=13.3.0" }, + { name = "cupy-rocm-5-0", marker = "extra == 'rocm5-0'", specifier = ">=13.3.0" }, + { name = "cytoolz", specifier = ">=0.12.1" }, + { name = "dace", marker = "extra == 'dace'", specifier = ">=1.0.1,<1.1.0" }, + { name = "dace", marker = "extra == 'dace-next'", git = "https://github.com/spcl/dace?branch=main" }, + { name = "deepdiff", specifier = ">=5.6.0" }, + { name = "devtools", specifier = ">=0.6" }, + { name = "diskcache", specifier = ">=5.6.3" }, + { name = "factory-boy", specifier = ">=3.3.0" }, + { name = "filelock", specifier = ">=3.16.1" }, + { name = "frozendict", specifier = ">=2.3" }, + { name = "gridtools-cpp", specifier = "==2.*,>=2.3.8" }, + { name 
= "gt4py", extras = ["cuda12"], marker = "extra == 'jax-cuda12'" }, + { name = "gt4py", extras = ["dace", "formatting", "jax", "performance", "testing"], marker = "extra == 'all'" }, + { name = "hypothesis", marker = "extra == 'testing'", specifier = ">=6.0.0" }, + { name = "jax", marker = "extra == 'jax'", specifier = ">=0.4.26" }, + { name = "jax", extras = ["cuda12-local"], marker = "extra == 'jax-cuda12'", specifier = ">=0.4.26" }, + { name = "jinja2", specifier = ">=3.0.0" }, + { name = "lark", specifier = ">=1.1.2" }, + { name = "mako", specifier = ">=1.1" }, + { name = "nanobind", specifier = ">=1.4.0" }, + { name = "ninja", specifier = ">=1.10" }, + { name = "numpy", specifier = ">=1.23.3" }, + { name = "packaging", specifier = ">=20.0" }, + { name = "pybind11", specifier = ">=2.10.1" }, + { name = "pytest", marker = "extra == 'testing'", specifier = ">=7.0" }, + { name = "scipy", marker = "extra == 'performance'", specifier = ">=1.9.2" }, + { name = "setuptools", specifier = ">=70.0.0" }, + { name = "tabulate", specifier = ">=0.8.10" }, + { name = "toolz", specifier = ">=0.12.1" }, + { name = "typing-extensions", specifier = ">=4.11.0" }, + { name = "xxhash", specifier = ">=1.4.4,<3.1.0" }, +] + +[package.metadata.requires-dev] +build = [ + { name = "bump-my-version", specifier = ">=0.16.0" }, + { name = "cython", specifier = ">=3.0.0" }, + { name = "pip", specifier = ">=22.1.1" }, + { name = "setuptools", specifier = ">=70.0.0" }, + { name = "wheel", specifier = ">=0.33.6" }, +] +dev = [ + { name = "atlas4py", specifier = ">=0.35", index = "https://test.pypi.org/simple/" }, + { name = "bump-my-version", specifier = ">=0.16.0" }, + { name = "coverage", extras = ["toml"], specifier = ">=7.5.0" }, + { name = "cython", specifier = ">=3.0.0" }, + { name = "esbonio", specifier = ">=0.16.0" }, + { name = "hypothesis", specifier = ">=6.0.0" }, + { name = "jupytext", specifier = ">=1.14" }, + { name = "matplotlib", specifier = ">=3.8.4" }, + { name = "mypy", 
extras = ["faster-cache"], specifier = ">=1.13.0" }, + { name = "myst-parser", specifier = ">=4.0.0" }, + { name = "nbmake", specifier = ">=1.4.6" }, + { name = "nox", specifier = ">=2024.10.9" }, + { name = "pip", specifier = ">=22.1.1" }, + { name = "pre-commit", specifier = ">=4.0.1" }, + { name = "pygments", specifier = ">=2.7.3" }, + { name = "pytest", specifier = ">=8.0.1" }, + { name = "pytest-benchmark", specifier = ">=5.0.0" }, + { name = "pytest-cache", specifier = ">=1.0" }, + { name = "pytest-cov", specifier = ">=5.0.0" }, + { name = "pytest-factoryboy", specifier = ">=2.6.1" }, + { name = "pytest-instafail", specifier = ">=0.5.0" }, + { name = "pytest-xdist", extras = ["psutil"], specifier = ">=3.5.0" }, + { name = "ruff", specifier = ">=0.8.0" }, + { name = "setuptools", specifier = ">=70.0.0" }, + { name = "sphinx", specifier = ">=7.3.7" }, + { name = "sphinx-rtd-theme", specifier = ">=3.0.1" }, + { name = "sphinx-toolbox", specifier = ">=3.8.1" }, + { name = "tach", specifier = ">=0.16.0" }, + { name = "types-decorator", specifier = ">=5.1.8" }, + { name = "types-docutils", specifier = ">=0.21.0" }, + { name = "types-pytz", specifier = ">=2024.2.0" }, + { name = "types-pyyaml", specifier = ">=6.0.10" }, + { name = "types-tabulate", specifier = ">=0.8.10" }, + { name = "wheel", specifier = ">=0.33.6" }, +] +docs = [ + { name = "esbonio", specifier = ">=0.16.0" }, + { name = "jupytext", specifier = ">=1.14" }, + { name = "matplotlib", specifier = ">=3.8.4" }, + { name = "myst-parser", specifier = ">=4.0.0" }, + { name = "pygments", specifier = ">=2.7.3" }, + { name = "sphinx", specifier = ">=7.3.7" }, + { name = "sphinx-rtd-theme", specifier = ">=3.0.1" }, + { name = "sphinx-toolbox", specifier = ">=3.8.1" }, +] +frameworks = [{ name = "atlas4py", specifier = ">=0.35", index = "https://test.pypi.org/simple/" }] +lint = [ + { name = "pre-commit", specifier = ">=4.0.1" }, + { name = "ruff", specifier = ">=0.8.0" }, + { name = "tach", specifier = 
">=0.16.0" }, +] +test = [ + { name = "coverage", extras = ["toml"], specifier = ">=7.5.0" }, + { name = "hypothesis", specifier = ">=6.0.0" }, + { name = "nbmake", specifier = ">=1.4.6" }, + { name = "nox", specifier = ">=2024.10.9" }, + { name = "pytest", specifier = ">=8.0.1" }, + { name = "pytest-benchmark", specifier = ">=5.0.0" }, + { name = "pytest-cache", specifier = ">=1.0" }, + { name = "pytest-cov", specifier = ">=5.0.0" }, + { name = "pytest-factoryboy", specifier = ">=2.6.1" }, + { name = "pytest-instafail", specifier = ">=0.5.0" }, + { name = "pytest-xdist", extras = ["psutil"], specifier = ">=3.5.0" }, +] +typing = [ + { name = "mypy", extras = ["faster-cache"], specifier = ">=1.13.0" }, + { name = "types-decorator", specifier = ">=5.1.8" }, + { name = "types-docutils", specifier = ">=0.21.0" }, + { name = "types-pytz", specifier = ">=2024.2.0" }, + { name = "types-pyyaml", specifier = ">=6.0.10" }, + { name = "types-tabulate", specifier = ">=0.8.10" }, +] + +[[package]] +name = "h11" +version = "0.14.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f5/38/3af3d3633a34a3316095b39c8e8fb4853a28a536e55d347bd8d8e9a14b03/h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d", size = 100418 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259 }, +] + +[[package]] +name = "html5lib" +version = "1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, + { name = "webencodings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ac/b6/b55c3f49042f1df3dcd422b7f224f939892ee94f22abcf503a9b7339eaf2/html5lib-1.1.tar.gz", hash = "sha256:b2e5b40261e20f354d198eae92afc10d750afb487ed5e50f9c4eaf07c184146f", size = 
272215 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6c/dd/a834df6482147d48e225a49515aabc28974ad5a4ca3215c18a882565b028/html5lib-1.1-py2.py3-none-any.whl", hash = "sha256:0d78f8fde1c230e99fe37986a60526d7049ed4bf8a9fadbad5f00e22e58e041d", size = 112173 }, +] + +[[package]] +name = "httpcore" +version = "1.0.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6a/41/d7d0a89eb493922c37d343b607bc1b5da7f5be7e383740b4753ad8943e90/httpcore-1.0.7.tar.gz", hash = "sha256:8551cb62a169ec7162ac7be8d4817d561f60e08eaa485234898414bb5a8a0b4c", size = 85196 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/f5/72347bc88306acb359581ac4d52f23c0ef445b57157adedb9aee0cd689d2/httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd", size = 78551 }, +] + +[[package]] +name = "httpx" +version = "0.28.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "certifi" }, + { name = "httpcore" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 }, +] + +[[package]] +name = "hypothesis" +version = "6.125.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 
'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "sortedcontainers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f9/69/3273c85add01293b0ed8fc71554cecb256c9e7826fa102c72cc847bb8bac/hypothesis-6.125.2.tar.gz", hash = "sha256:c70f0a12deb688ce90f2765a507070c4bff57e48ac86849f4350bbddc1df41a3", size = 417961 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/1b/e78605ce304554451a36c6e24e603cfcee808c9ed09be5112bf00a10eb5e/hypothesis-6.125.2-py3-none-any.whl", hash = "sha256:55d4966d521b85d2f77e916dabb00d66d5530ea9fbb89c7489ee810625fac802", size = 480692 }, +] + +[[package]] +name = "identify" +version = "2.6.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/82/bf/c68c46601bacd4c6fb4dd751a42b6e7087240eaabc6487f2ef7a48e0e8fc/identify-2.6.6.tar.gz", hash = "sha256:7bec12768ed44ea4761efb47806f0a41f86e7c0a5fdf5950d4648c90eca7e251", size = 99217 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/74/a1/68a395c17eeefb04917034bd0a1bfa765e7654fa150cca473d669aa3afb5/identify-2.6.6-py2.py3-none-any.whl", hash = "sha256:cbd1810bce79f8b671ecb20f53ee0ae8e86ae84b557de31d89709dc2a48ba881", size = 99083 }, +] + +[[package]] +name = "idna" +version = "3.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 }, +] + +[[package]] +name = "imagesize" +version = "1.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a7/84/62473fb57d61e31fef6e36d64a179c8781605429fd927b5dd608c997be31/imagesize-1.4.1.tar.gz", hash = "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a", size = 1280026 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ff/62/85c4c919272577931d407be5ba5d71c20f0b616d31a0befe0ae45bb79abd/imagesize-1.4.1-py2.py3-none-any.whl", hash = "sha256:0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b", size = 8769 }, +] + +[[package]] +name = "inflection" +version = "0.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e1/7e/691d061b7329bc8d54edbf0ec22fbfb2afe61facb681f9aaa9bff7a27d04/inflection-0.5.1.tar.gz", hash = "sha256:1a29730d366e996aaacffb2f1f1cb9593dc38e2ddd30c91250c6dde09ea9b417", size = 15091 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/59/91/aa6bde563e0085a02a435aa99b49ef75b0a4b062635e606dab23ce18d720/inflection-0.5.1-py2.py3-none-any.whl", hash = "sha256:f38b2b640938a4f35ade69ac3d053042959b62a0f1076a5bbaa1b9526605a8a2", size = 9454 }, +] + +[[package]] +name = "iniconfig" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, +] + +[[package]] +name = "ipykernel" +version = "6.29.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "appnope", marker = "sys_platform == 'darwin' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "comm" }, + { name = "debugpy" }, + { name = "ipython" }, + { name = "jupyter-client" }, + { name = "jupyter-core" }, + { name = "matplotlib-inline" }, + { name = "nest-asyncio" }, + { name = "packaging" }, + { name = "psutil" }, + { name = "pyzmq" }, + { name = "tornado" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/5c/67594cb0c7055dc50814b21731c22a601101ea3b1b50a9a1b090e11f5d0f/ipykernel-6.29.5.tar.gz", hash = "sha256:f093a22c4a40f8828f8e330a9c297cb93dcab13bd9678ded6de8e5cf81c56215", size = 163367 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/5c/368ae6c01c7628438358e6d337c19b05425727fbb221d2a3c4303c372f42/ipykernel-6.29.5-py3-none-any.whl", hash = "sha256:afdb66ba5aa354b09b91379bac28ae4afebbb30e8b39510c9690afb7a10421b5", size = 117173 }, +] + +[[package]] +name = "ipython" +version = "8.32.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = 
"colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "decorator" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "jedi" }, + { name = "matplotlib-inline" }, + { name = "pexpect", marker = "(sys_platform != 'emscripten' and sys_platform != 'win32') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-dace' 
and extra == 'extra-5-gt4py-dace-next') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'emscripten' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform == 'win32' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "prompt-toolkit" }, + { name = "pygments" }, + { name = "stack-data" }, + { name = "traitlets" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/36/80/4d2a072e0db7d250f134bc11676517299264ebe16d62a8619d49a78ced73/ipython-8.32.0.tar.gz", hash = "sha256:be2c91895b0b9ea7ba49d33b23e2040c352b33eb6a519cca7ce6e0c743444251", size = 5507441 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/e1/f4474a7ecdb7745a820f6f6039dc43c66add40f1bcc66485607d93571af6/ipython-8.32.0-py3-none-any.whl", hash = "sha256:cae85b0c61eff1fc48b0a8002de5958b6528fa9c8defb1894da63f42613708aa", size = 825524 }, +] + +[[package]] +name = "jax" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jaxlib" 
}, + { name = "ml-dtypes" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-5-gt4py-all' or extra == 'extra-5-gt4py-dace' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 
'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "opt-einsum" }, + { name = "scipy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4a/cb/22d62b26284f08e62d6eb64603d3b010004cfdb7a97ce6cca5c6cf86edab/jax-0.5.0.tar.gz", hash = "sha256:49df70bf293a345a7fb519f71193506d37a024c4f850b358042eb32d502c81c8", size = 1959707 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/58/cc0721a1030fcbab0984beea0bf3c4610ec103f738423cdfa9c4ceb40598/jax-0.5.0-py3-none-any.whl", hash = "sha256:b3907aa87ae2c340b39cdbf80c07a74550369cafcaf7398fb60ba58d167345ab", size = 2270365 }, +] + +[package.optional-dependencies] +cuda12-local = [ + { name = "jax-cuda12-plugin" }, + { name = "jaxlib" }, +] + +[[package]] +name = "jax-cuda12-pjrt" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/a6/4b161016aaafe04d92e8d9a50b47e6767ea5cf874a8a9d2d1bcd049409d3/jax_cuda12_pjrt-0.5.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:6025cd4b32d8ec04a11705a749764cd96a6cbc8b6273beac947cc481f2584b8c", size = 89441461 }, + { url = "https://files.pythonhosted.org/packages/8e/ac/824ff70eb5b5dd2a4b597a2017ae62f24b9aaa5fd846f04c94dc447aa1ec/jax_cuda12_pjrt-0.5.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:d23833c1b885d96c2764000e95052f2b5827c77d492ea68f67e903a132656dbb", size = 103122594 }, +] + +[[package]] +name = "jax-cuda12-plugin" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jax-cuda12-pjrt" }, +] +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/57/58/3dab6bb4cdbc43663093c2af4671e87312236a23c84a3fc152d3c3979019/jax_cuda12_plugin-0.5.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d497dcc9205a11d283c308d8f400fb71507cf808753168d47effd1d4c47f9c3d", size = 16777702 }, + { url = "https://files.pythonhosted.org/packages/c2/46/a54402df9e2d057bb16d7e2ab045bd536fc8b83662cfc8d503fc56f5fc41/jax_cuda12_plugin-0.5.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:0f443a6b37298edfb0796fcdbd1f86ce85a4b084b6bd3f1f50a4fbfd67ded86b", size = 16733143 }, + { url = "https://files.pythonhosted.org/packages/d9/d5/64ad0b832122d938cbad07652625679a35c03e16e2ce4b8eda4ead8feed5/jax_cuda12_plugin-0.5.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:25407ccb030e4eed7d7e2ccccac8ab65f932aa05936ca5cf0e8ded4adfdcad1a", size = 16777553 }, + { url = "https://files.pythonhosted.org/packages/a2/7b/cc9fa545db9397de9054357de8440c8b10d28a6ab5d1cef1eba184c3d426/jax_cuda12_plugin-0.5.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:a98135a0223064b8f5c6853e22ddc1a4e3862152d37fb685f0dbdeffe0c80122", size = 16734352 }, +] + +[[package]] +name = "jaxlib" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ml-dtypes" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-5-gt4py-all' or extra == 'extra-5-gt4py-dace' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra 
== 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "scipy" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/41/3e4ac64df72c4da126df3fd66a2214025a46b6263f7be266728e7b8e473e/jaxlib-0.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1b8a6c4345f137f387650de2dbc488c20251b7412b55dd648e1a4f13bcf507fb", size = 79248968 }, + { url = "https://files.pythonhosted.org/packages/1e/5f/2a16e61f1d54ae5f55fbf3cb3e22ef5bb01bf9d7d6474e0d34fedba19c4d/jaxlib-0.5.0-cp310-cp310-manylinux2014_aarch64.whl", 
hash = "sha256:5b2efe3dfebf18a84c451d3803ac884ee242021c1113b279c13f4bbc378c3dc0", size = 93181077 }, + { url = "https://files.pythonhosted.org/packages/08/c3/573e2f01b99f1247e8fbe1aa46b95a0faa68ef208f9a8e8ef775d607b3e6/jaxlib-0.5.0-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:74440b632107336400d4f97a16481d767f13ea914c53ba14e544c6fda54819b3", size = 101969119 }, + { url = "https://files.pythonhosted.org/packages/6e/38/512f61ea13da41ca47f2411d7c05af0cf74a37f225e16725ed0e6fb58893/jaxlib-0.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:53478a28eee6c2ef01759b05a9491702daef9268c3ed013d6f8e2e5f5cae0887", size = 63883394 }, + { url = "https://files.pythonhosted.org/packages/92/4b/8875870ff52ad3fbea876c905228f691f05c8dc8556b226cbfaf0fba7f62/jaxlib-0.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6cd762ed1623132499fa701c4203446102e0a9c82ca23194b87288f746d12a29", size = 79242870 }, + { url = "https://files.pythonhosted.org/packages/a0/0f/00cdfa411d7218e4696c10c5867f7d3c396219adbcaeb02e95108ca802de/jaxlib-0.5.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:63088dbfaa85bb56cd521a925a3472fd7328b18ec93c2d8ffa85af331095c995", size = 93181807 }, + { url = "https://files.pythonhosted.org/packages/58/8e/a5c29db03d5a93b0326e297b556d0e0a9805e9c9c1ae5f82f69557273faa/jaxlib-0.5.0-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:09113ef1582ba34d7cbc440fedb318f4855b59b776711a8aba2473c9727d3025", size = 101969212 }, + { url = "https://files.pythonhosted.org/packages/70/86/ceae20e4f37fa07f1cc95551cc0f49170d0db46d2e82fdf511d26bffd801/jaxlib-0.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:78289fc3ddc1e4e9510de2536a6375df9fe1c50de0ac60826c286b7a5c5090fe", size = 63881994 }, +] + +[[package]] +name = "jedi" +version = "0.19.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "parso" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", 
hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278 }, +] + +[[package]] +name = "jinja2" +version = "3.1.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/af/92/b3130cbbf5591acf9ade8708c365f3238046ac7cb8ccba6e81abccb0ccff/jinja2-3.1.5.tar.gz", hash = "sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb", size = 244674 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bd/0f/2ba5fbcd631e3e88689309dbe978c5769e883e4b84ebfe7da30b43275c5a/jinja2-3.1.5-py3-none-any.whl", hash = "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb", size = 134596 }, +] + +[[package]] +name = "jsonschema" +version = "4.23.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "jsonschema-specifications" }, + { name = "referencing" }, + { name = "rpds-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/38/2e/03362ee4034a4c917f697890ccd4aec0800ccf9ded7f511971c75451deec/jsonschema-4.23.0.tar.gz", hash = "sha256:d71497fef26351a33265337fa77ffeb82423f3ea21283cd9467bb03999266bc4", size = 325778 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/4a/4f9dbeb84e8850557c02365a0eee0649abe5eb1d84af92a25731c6c0f922/jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566", size = 88462 }, +] + +[[package]] +name = "jsonschema-specifications" +version = "2024.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "referencing" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/10/db/58f950c996c793472e336ff3655b13fbcf1e3b359dcf52dcf3ed3b52c352/jsonschema_specifications-2024.10.1.tar.gz", hash = "sha256:0f38b83639958ce1152d02a7f062902c41c8fd20d558b0c34344292d417ae272", size = 15561 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/0f/8910b19ac0670a0f80ce1008e5e751c4a57e14d2c4c13a482aa6079fa9d6/jsonschema_specifications-2024.10.1-py3-none-any.whl", hash = "sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf", size = 18459 }, +] + +[[package]] +name = "jupyter-client" +version = "8.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jupyter-core" }, + { name = "python-dateutil" }, + { name = "pyzmq" }, + { name = "tornado" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/22/bf9f12fdaeae18019a468b68952a60fe6dbab5d67cd2a103cac7659b41ca/jupyter_client-8.6.3.tar.gz", hash = "sha256:35b3a0947c4a6e9d589eb97d7d4cd5e90f910ee73101611f01283732bd6d9419", size = 342019 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/85/b0394e0b6fcccd2c1eeefc230978a6f8cb0c5df1e4cd3e7625735a0d7d1e/jupyter_client-8.6.3-py3-none-any.whl", hash = "sha256:e8a19cc986cc45905ac3362915f410f3af85424b4c0905e94fa5f2cb08e8f23f", size = 106105 }, +] + +[[package]] +name = "jupyter-core" +version = "5.7.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "platformdirs" }, + { name = "pywin32", marker = "(platform_python_implementation != 'PyPy' and sys_platform == 'win32') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (platform_python_implementation == 'PyPy' and extra == 
'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (platform_python_implementation == 'PyPy' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (sys_platform != 'win32' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/11/b56381fa6c3f4cc5d2cf54a7dbf98ad9aa0b339ef7a601d6053538b079a7/jupyter_core-5.7.2.tar.gz", hash = "sha256:aa5f8d32bbf6b431ac830496da7392035d6f61b4f54872f15c4bd2a9c3f536d9", size = 87629 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/fb/108ecd1fe961941959ad0ee4e12ee7b8b1477247f30b1fdfd83ceaf017f0/jupyter_core-5.7.2-py3-none-any.whl", hash = "sha256:4f7315d2f6b4bcf2e3e7cb6e46772eba760ae459cd1f59d29eb57b0a01bd7409", size = 28965 }, +] + +[[package]] +name = "jupytext" +version = "1.16.6" +source = { registry 
= "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "mdit-py-plugins" }, + { name = "nbformat" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/10/e7/58d6fd374e1065d2bccefd07953d2f1f911d8de03fd7dc33dd5a25ac659c/jupytext-1.16.6.tar.gz", hash = "sha256:dbd03f9263c34b737003f388fc069e9030834fb7136879c4c32c32473557baa0", size = 3726029 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/02/27191f18564d4f2c0e543643aa94b54567de58f359cd6a3bed33adb723ac/jupytext-1.16.6-py3-none-any.whl", hash = "sha256:900132031f73fee15a1c9ebd862e05eb5f51e1ad6ab3a2c6fdd97ce2f9c913b4", size = 154200 }, +] + +[[package]] +name = "kiwisolver" +version = "1.4.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/82/59/7c91426a8ac292e1cdd53a63b6d9439abd573c875c3f92c146767dd33faf/kiwisolver-1.4.8.tar.gz", hash = "sha256:23d5f023bdc8c7e54eb65f03ca5d5bb25b601eac4d7f1a042888a1f45237987e", size = 97538 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/5f/4d8e9e852d98ecd26cdf8eaf7ed8bc33174033bba5e07001b289f07308fd/kiwisolver-1.4.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:88c6f252f6816a73b1f8c904f7bbe02fd67c09a69f7cb8a0eecdbf5ce78e63db", size 
= 124623 }, + { url = "https://files.pythonhosted.org/packages/1d/70/7f5af2a18a76fe92ea14675f8bd88ce53ee79e37900fa5f1a1d8e0b42998/kiwisolver-1.4.8-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c72941acb7b67138f35b879bbe85be0f6c6a70cab78fe3ef6db9c024d9223e5b", size = 66720 }, + { url = "https://files.pythonhosted.org/packages/c6/13/e15f804a142353aefd089fadc8f1d985561a15358c97aca27b0979cb0785/kiwisolver-1.4.8-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ce2cf1e5688edcb727fdf7cd1bbd0b6416758996826a8be1d958f91880d0809d", size = 65413 }, + { url = "https://files.pythonhosted.org/packages/ce/6d/67d36c4d2054e83fb875c6b59d0809d5c530de8148846b1370475eeeece9/kiwisolver-1.4.8-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:c8bf637892dc6e6aad2bc6d4d69d08764166e5e3f69d469e55427b6ac001b19d", size = 1650826 }, + { url = "https://files.pythonhosted.org/packages/de/c6/7b9bb8044e150d4d1558423a1568e4f227193662a02231064e3824f37e0a/kiwisolver-1.4.8-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:034d2c891f76bd3edbdb3ea11140d8510dca675443da7304205a2eaa45d8334c", size = 1628231 }, + { url = "https://files.pythonhosted.org/packages/b6/38/ad10d437563063eaaedbe2c3540a71101fc7fb07a7e71f855e93ea4de605/kiwisolver-1.4.8-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d47b28d1dfe0793d5e96bce90835e17edf9a499b53969b03c6c47ea5985844c3", size = 1408938 }, + { url = "https://files.pythonhosted.org/packages/52/ce/c0106b3bd7f9e665c5f5bc1e07cc95b5dabd4e08e3dad42dbe2faad467e7/kiwisolver-1.4.8-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb158fe28ca0c29f2260cca8c43005329ad58452c36f0edf298204de32a9a3ed", size = 1422799 }, + { url = "https://files.pythonhosted.org/packages/d0/87/efb704b1d75dc9758087ba374c0f23d3254505edaedd09cf9d247f7878b9/kiwisolver-1.4.8-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:d5536185fce131780ebd809f8e623bf4030ce1b161353166c49a3c74c287897f", size = 1354362 }, + { url = "https://files.pythonhosted.org/packages/eb/b3/fd760dc214ec9a8f208b99e42e8f0130ff4b384eca8b29dd0efc62052176/kiwisolver-1.4.8-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:369b75d40abedc1da2c1f4de13f3482cb99e3237b38726710f4a793432b1c5ff", size = 2222695 }, + { url = "https://files.pythonhosted.org/packages/a2/09/a27fb36cca3fc01700687cc45dae7a6a5f8eeb5f657b9f710f788748e10d/kiwisolver-1.4.8-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:641f2ddf9358c80faa22e22eb4c9f54bd3f0e442e038728f500e3b978d00aa7d", size = 2370802 }, + { url = "https://files.pythonhosted.org/packages/3d/c3/ba0a0346db35fe4dc1f2f2cf8b99362fbb922d7562e5f911f7ce7a7b60fa/kiwisolver-1.4.8-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:d561d2d8883e0819445cfe58d7ddd673e4015c3c57261d7bdcd3710d0d14005c", size = 2334646 }, + { url = "https://files.pythonhosted.org/packages/41/52/942cf69e562f5ed253ac67d5c92a693745f0bed3c81f49fc0cbebe4d6b00/kiwisolver-1.4.8-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:1732e065704b47c9afca7ffa272f845300a4eb959276bf6970dc07265e73b605", size = 2467260 }, + { url = "https://files.pythonhosted.org/packages/32/26/2d9668f30d8a494b0411d4d7d4ea1345ba12deb6a75274d58dd6ea01e951/kiwisolver-1.4.8-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bcb1ebc3547619c3b58a39e2448af089ea2ef44b37988caf432447374941574e", size = 2288633 }, + { url = "https://files.pythonhosted.org/packages/98/99/0dd05071654aa44fe5d5e350729961e7bb535372935a45ac89a8924316e6/kiwisolver-1.4.8-cp310-cp310-win_amd64.whl", hash = "sha256:89c107041f7b27844179ea9c85d6da275aa55ecf28413e87624d033cf1f6b751", size = 71885 }, + { url = "https://files.pythonhosted.org/packages/6c/fc/822e532262a97442989335394d441cd1d0448c2e46d26d3e04efca84df22/kiwisolver-1.4.8-cp310-cp310-win_arm64.whl", hash = "sha256:b5773efa2be9eb9fcf5415ea3ab70fc785d598729fd6057bea38d539ead28271", size = 65175 }, + { url = 
"https://files.pythonhosted.org/packages/da/ed/c913ee28936c371418cb167b128066ffb20bbf37771eecc2c97edf8a6e4c/kiwisolver-1.4.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a4d3601908c560bdf880f07d94f31d734afd1bb71e96585cace0e38ef44c6d84", size = 124635 }, + { url = "https://files.pythonhosted.org/packages/4c/45/4a7f896f7467aaf5f56ef093d1f329346f3b594e77c6a3c327b2d415f521/kiwisolver-1.4.8-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:856b269c4d28a5c0d5e6c1955ec36ebfd1651ac00e1ce0afa3e28da95293b561", size = 66717 }, + { url = "https://files.pythonhosted.org/packages/5f/b4/c12b3ac0852a3a68f94598d4c8d569f55361beef6159dce4e7b624160da2/kiwisolver-1.4.8-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c2b9a96e0f326205af81a15718a9073328df1173a2619a68553decb7097fd5d7", size = 65413 }, + { url = "https://files.pythonhosted.org/packages/a9/98/1df4089b1ed23d83d410adfdc5947245c753bddfbe06541c4aae330e9e70/kiwisolver-1.4.8-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c5020c83e8553f770cb3b5fc13faac40f17e0b205bd237aebd21d53d733adb03", size = 1343994 }, + { url = "https://files.pythonhosted.org/packages/8d/bf/b4b169b050c8421a7c53ea1ea74e4ef9c335ee9013216c558a047f162d20/kiwisolver-1.4.8-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dace81d28c787956bfbfbbfd72fdcef014f37d9b48830829e488fdb32b49d954", size = 1434804 }, + { url = "https://files.pythonhosted.org/packages/66/5a/e13bd341fbcf73325ea60fdc8af752addf75c5079867af2e04cc41f34434/kiwisolver-1.4.8-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:11e1022b524bd48ae56c9b4f9296bce77e15a2e42a502cceba602f804b32bb79", size = 1450690 }, + { url = "https://files.pythonhosted.org/packages/9b/4f/5955dcb376ba4a830384cc6fab7d7547bd6759fe75a09564910e9e3bb8ea/kiwisolver-1.4.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:3b9b4d2892fefc886f30301cdd80debd8bb01ecdf165a449eb6e78f79f0fabd6", size = 1376839 }, + { url = "https://files.pythonhosted.org/packages/3a/97/5edbed69a9d0caa2e4aa616ae7df8127e10f6586940aa683a496c2c280b9/kiwisolver-1.4.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3a96c0e790ee875d65e340ab383700e2b4891677b7fcd30a699146f9384a2bb0", size = 1435109 }, + { url = "https://files.pythonhosted.org/packages/13/fc/e756382cb64e556af6c1809a1bbb22c141bbc2445049f2da06b420fe52bf/kiwisolver-1.4.8-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:23454ff084b07ac54ca8be535f4174170c1094a4cff78fbae4f73a4bcc0d4dab", size = 2245269 }, + { url = "https://files.pythonhosted.org/packages/76/15/e59e45829d7f41c776d138245cabae6515cb4eb44b418f6d4109c478b481/kiwisolver-1.4.8-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:87b287251ad6488e95b4f0b4a79a6d04d3ea35fde6340eb38fbd1ca9cd35bbbc", size = 2393468 }, + { url = "https://files.pythonhosted.org/packages/e9/39/483558c2a913ab8384d6e4b66a932406f87c95a6080112433da5ed668559/kiwisolver-1.4.8-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:b21dbe165081142b1232a240fc6383fd32cdd877ca6cc89eab93e5f5883e1c25", size = 2355394 }, + { url = "https://files.pythonhosted.org/packages/01/aa/efad1fbca6570a161d29224f14b082960c7e08268a133fe5dc0f6906820e/kiwisolver-1.4.8-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:768cade2c2df13db52475bd28d3a3fac8c9eff04b0e9e2fda0f3760f20b3f7fc", size = 2490901 }, + { url = "https://files.pythonhosted.org/packages/c9/4f/15988966ba46bcd5ab9d0c8296914436720dd67fca689ae1a75b4ec1c72f/kiwisolver-1.4.8-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d47cfb2650f0e103d4bf68b0b5804c68da97272c84bb12850d877a95c056bd67", size = 2312306 }, + { url = "https://files.pythonhosted.org/packages/2d/27/bdf1c769c83f74d98cbc34483a972f221440703054894a37d174fba8aa68/kiwisolver-1.4.8-cp311-cp311-win_amd64.whl", hash = 
"sha256:ed33ca2002a779a2e20eeb06aea7721b6e47f2d4b8a8ece979d8ba9e2a167e34", size = 71966 }, + { url = "https://files.pythonhosted.org/packages/4a/c9/9642ea855604aeb2968a8e145fc662edf61db7632ad2e4fb92424be6b6c0/kiwisolver-1.4.8-cp311-cp311-win_arm64.whl", hash = "sha256:16523b40aab60426ffdebe33ac374457cf62863e330a90a0383639ce14bf44b2", size = 65311 }, + { url = "https://files.pythonhosted.org/packages/1f/f9/ae81c47a43e33b93b0a9819cac6723257f5da2a5a60daf46aa5c7226ea85/kiwisolver-1.4.8-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:e7a019419b7b510f0f7c9dceff8c5eae2392037eae483a7f9162625233802b0a", size = 60403 }, + { url = "https://files.pythonhosted.org/packages/58/ca/f92b5cb6f4ce0c1ebfcfe3e2e42b96917e16f7090e45b21102941924f18f/kiwisolver-1.4.8-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:286b18e86682fd2217a48fc6be6b0f20c1d0ed10958d8dc53453ad58d7be0bf8", size = 58657 }, + { url = "https://files.pythonhosted.org/packages/80/28/ae0240f732f0484d3a4dc885d055653c47144bdf59b670aae0ec3c65a7c8/kiwisolver-1.4.8-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4191ee8dfd0be1c3666ccbac178c5a05d5f8d689bbe3fc92f3c4abec817f8fe0", size = 84948 }, + { url = "https://files.pythonhosted.org/packages/5d/eb/78d50346c51db22c7203c1611f9b513075f35c4e0e4877c5dde378d66043/kiwisolver-1.4.8-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cd2785b9391f2873ad46088ed7599a6a71e762e1ea33e87514b1a441ed1da1c", size = 81186 }, + { url = "https://files.pythonhosted.org/packages/43/f8/7259f18c77adca88d5f64f9a522792e178b2691f3748817a8750c2d216ef/kiwisolver-1.4.8-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c07b29089b7ba090b6f1a669f1411f27221c3662b3a1b7010e67b59bb5a6f10b", size = 80279 }, + { url = 
"https://files.pythonhosted.org/packages/3a/1d/50ad811d1c5dae091e4cf046beba925bcae0a610e79ae4c538f996f63ed5/kiwisolver-1.4.8-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:65ea09a5a3faadd59c2ce96dc7bf0f364986a315949dc6374f04396b0d60e09b", size = 71762 }, +] + +[[package]] +name = "lark" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/af/60/bc7622aefb2aee1c0b4ba23c1446d3e30225c8770b38d7aedbfb65ca9d5a/lark-1.2.2.tar.gz", hash = "sha256:ca807d0162cd16cef15a8feecb862d7319e7a09bdb13aef927968e45040fed80", size = 252132 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2d/00/d90b10b962b4277f5e64a78b6609968859ff86889f5b898c1a778c06ec00/lark-1.2.2-py3-none-any.whl", hash = "sha256:c2276486b02f0f1b90be155f2c8ba4a8e194d42775786db622faccd652d8e80c", size = 111036 }, +] + +[[package]] +name = "lsprotocol" +version = "2023.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "cattrs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9d/f6/6e80484ec078d0b50699ceb1833597b792a6c695f90c645fbaf54b947e6f/lsprotocol-2023.0.1.tar.gz", hash = "sha256:cc5c15130d2403c18b734304339e51242d3018a05c4f7d0f198ad6e0cd21861d", size = 69434 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/37/2351e48cb3309673492d3a8c59d407b75fb6630e560eb27ecd4da03adc9a/lsprotocol-2023.0.1-py3-none-any.whl", hash = "sha256:c75223c9e4af2f24272b14c6375787438279369236cd568f596d4951052a60f2", size = 70826 }, +] + +[[package]] +name = "mako" +version = "1.3.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/62/4f/ddb1965901bc388958db9f0c991255b2c469349a741ae8c9cd8a562d70a6/mako-1.3.9.tar.gz", hash = "sha256:b5d65ff3462870feec922dbccf38f6efb44e5714d7b593a656be86663d8600ac", size = 392195 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/cd/83/de0a49e7de540513f53ab5d2e105321dedeb08a8f5850f0208decf4390ec/Mako-1.3.9-py3-none-any.whl", hash = "sha256:95920acccb578427a9aa38e37a186b1e43156c87260d7ba18ca63aa4c7cbd3a1", size = 78456 }, +] + +[[package]] +name = "markdown-it-py" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/38/71/3b932df36c1a044d397a1f92d1cf91ee0a503d91e470cbd670aa66b07ed0/markdown-it-py-3.0.0.tar.gz", hash = "sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb", size = 74596 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", size = 87528 }, +] + +[[package]] +name = "markupsafe" +version = "3.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/97/5d42485e71dfc078108a86d6de8fa46db44a1a9295e89c5d6d4a06e23a62/markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0", size = 20537 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/90/d08277ce111dd22f77149fd1a5d4653eeb3b3eaacbdfcbae5afb2600eebd/MarkupSafe-3.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8", size = 14357 }, + { url = "https://files.pythonhosted.org/packages/04/e1/6e2194baeae0bca1fae6629dc0cbbb968d4d941469cbab11a3872edff374/MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158", size = 12393 }, + { url = 
"https://files.pythonhosted.org/packages/1d/69/35fa85a8ece0a437493dc61ce0bb6d459dcba482c34197e3efc829aa357f/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38a9ef736c01fccdd6600705b09dc574584b89bea478200c5fbf112a6b0d5579", size = 21732 }, + { url = "https://files.pythonhosted.org/packages/22/35/137da042dfb4720b638d2937c38a9c2df83fe32d20e8c8f3185dbfef05f7/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbcb445fa71794da8f178f0f6d66789a28d7319071af7a496d4d507ed566270d", size = 20866 }, + { url = "https://files.pythonhosted.org/packages/29/28/6d029a903727a1b62edb51863232152fd335d602def598dade38996887f0/MarkupSafe-3.0.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:57cb5a3cf367aeb1d316576250f65edec5bb3be939e9247ae594b4bcbc317dfb", size = 20964 }, + { url = "https://files.pythonhosted.org/packages/cc/cd/07438f95f83e8bc028279909d9c9bd39e24149b0d60053a97b2bc4f8aa51/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:3809ede931876f5b2ec92eef964286840ed3540dadf803dd570c3b7e13141a3b", size = 21977 }, + { url = "https://files.pythonhosted.org/packages/29/01/84b57395b4cc062f9c4c55ce0df7d3108ca32397299d9df00fedd9117d3d/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e07c3764494e3776c602c1e78e298937c3315ccc9043ead7e685b7f2b8d47b3c", size = 21366 }, + { url = "https://files.pythonhosted.org/packages/bd/6e/61ebf08d8940553afff20d1fb1ba7294b6f8d279df9fd0c0db911b4bbcfd/MarkupSafe-3.0.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b424c77b206d63d500bcb69fa55ed8d0e6a3774056bdc4839fc9298a7edca171", size = 21091 }, + { url = "https://files.pythonhosted.org/packages/11/23/ffbf53694e8c94ebd1e7e491de185124277964344733c45481f32ede2499/MarkupSafe-3.0.2-cp310-cp310-win32.whl", hash = "sha256:fcabf5ff6eea076f859677f5f0b6b5c1a51e70a376b0579e0eadef8db48c6b50", size = 15065 }, + { url = 
"https://files.pythonhosted.org/packages/44/06/e7175d06dd6e9172d4a69a72592cb3f7a996a9c396eee29082826449bbc3/MarkupSafe-3.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:6af100e168aa82a50e186c82875a5893c5597a0c1ccdb0d8b40240b1f28b969a", size = 15514 }, + { url = "https://files.pythonhosted.org/packages/6b/28/bbf83e3f76936960b850435576dd5e67034e200469571be53f69174a2dfd/MarkupSafe-3.0.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:9025b4018f3a1314059769c7bf15441064b2207cb3f065e6ea1e7359cb46db9d", size = 14353 }, + { url = "https://files.pythonhosted.org/packages/6c/30/316d194b093cde57d448a4c3209f22e3046c5bb2fb0820b118292b334be7/MarkupSafe-3.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:93335ca3812df2f366e80509ae119189886b0f3c2b81325d39efdb84a1e2ae93", size = 12392 }, + { url = "https://files.pythonhosted.org/packages/f2/96/9cdafba8445d3a53cae530aaf83c38ec64c4d5427d975c974084af5bc5d2/MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2cb8438c3cbb25e220c2ab33bb226559e7afb3baec11c4f218ffa7308603c832", size = 23984 }, + { url = "https://files.pythonhosted.org/packages/f1/a4/aefb044a2cd8d7334c8a47d3fb2c9f328ac48cb349468cc31c20b539305f/MarkupSafe-3.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a123e330ef0853c6e822384873bef7507557d8e4a082961e1defa947aa59ba84", size = 23120 }, + { url = "https://files.pythonhosted.org/packages/8d/21/5e4851379f88f3fad1de30361db501300d4f07bcad047d3cb0449fc51f8c/MarkupSafe-3.0.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e084f686b92e5b83186b07e8a17fc09e38fff551f3602b249881fec658d3eca", size = 23032 }, + { url = "https://files.pythonhosted.org/packages/00/7b/e92c64e079b2d0d7ddf69899c98842f3f9a60a1ae72657c89ce2655c999d/MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d8213e09c917a951de9d09ecee036d5c7d36cb6cb7dbaece4c71a60d79fb9798", size = 24057 }, + { url = 
"https://files.pythonhosted.org/packages/f9/ac/46f960ca323037caa0a10662ef97d0a4728e890334fc156b9f9e52bcc4ca/MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:5b02fb34468b6aaa40dfc198d813a641e3a63b98c2b05a16b9f80b7ec314185e", size = 23359 }, + { url = "https://files.pythonhosted.org/packages/69/84/83439e16197337b8b14b6a5b9c2105fff81d42c2a7c5b58ac7b62ee2c3b1/MarkupSafe-3.0.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0bff5e0ae4ef2e1ae4fdf2dfd5b76c75e5c2fa4132d05fc1b0dabcd20c7e28c4", size = 23306 }, + { url = "https://files.pythonhosted.org/packages/9a/34/a15aa69f01e2181ed8d2b685c0d2f6655d5cca2c4db0ddea775e631918cd/MarkupSafe-3.0.2-cp311-cp311-win32.whl", hash = "sha256:6c89876f41da747c8d3677a2b540fb32ef5715f97b66eeb0c6b66f5e3ef6f59d", size = 15094 }, + { url = "https://files.pythonhosted.org/packages/da/b8/3a3bd761922d416f3dc5d00bfbed11f66b1ab89a0c2b6e887240a30b0f6b/MarkupSafe-3.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:70a87b411535ccad5ef2f1df5136506a10775d267e197e4cf531ced10537bd6b", size = 15521 }, +] + +[[package]] +name = "matplotlib" +version = "3.10.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "contourpy" }, + { name = "cycler" }, + { name = "fonttools" }, + { name = "kiwisolver" }, + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-5-gt4py-all' or extra == 'extra-5-gt4py-dace' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker 
= "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "packaging" }, + { name = "pillow" }, + { name = "pyparsing" }, + { name = "python-dateutil" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/68/dd/fa2e1a45fce2d09f4aea3cee169760e672c8262325aa5796c49d543dc7e6/matplotlib-3.10.0.tar.gz", hash = "sha256:b886d02a581b96704c9d1ffe55709e49b4d2d52709ccebc4be42db856e511278", size = 36686418 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/09/ec/3cdff7b5239adaaacefcc4f77c316dfbbdf853c4ed2beec467e0fec31b9f/matplotlib-3.10.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2c5829a5a1dd5a71f0e31e6e8bb449bc0ee9dbfb05ad28fc0c6b55101b3a4be6", size = 8160551 }, + { url = "https://files.pythonhosted.org/packages/41/f2/b518f2c7f29895c9b167bf79f8529c63383ae94eaf49a247a4528e9a148d/matplotlib-3.10.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a2a43cbefe22d653ab34bb55d42384ed30f611bcbdea1f8d7f431011a2e1c62e", size = 8034853 }, + { url = "https://files.pythonhosted.org/packages/ed/8d/45754b4affdb8f0d1a44e4e2bcd932cdf35b256b60d5eda9f455bb293ed0/matplotlib-3.10.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:607b16c8a73943df110f99ee2e940b8a1cbf9714b65307c040d422558397dac5", size = 8446724 }, + { url = "https://files.pythonhosted.org/packages/09/5a/a113495110ae3e3395c72d82d7bc4802902e46dc797f6b041e572f195c56/matplotlib-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:01d2b19f13aeec2e759414d3bfe19ddfb16b13a1250add08d46d5ff6f9be83c6", size = 8583905 }, + { url = "https://files.pythonhosted.org/packages/12/b1/8b1655b4c9ed4600c817c419f7eaaf70082630efd7556a5b2e77a8a3cdaf/matplotlib-3.10.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5e6c6461e1fc63df30bf6f80f0b93f5b6784299f721bc28530477acd51bfc3d1", size = 9395223 }, + { url = "https://files.pythonhosted.org/packages/5a/85/b9a54d64585a6b8737a78a61897450403c30f39e0bd3214270bb0b96f002/matplotlib-3.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:994c07b9d9fe8d25951e3202a68c17900679274dadfc1248738dcfa1bd40d7f3", size = 8025355 }, + { url = "https://files.pythonhosted.org/packages/0c/f1/e37f6c84d252867d7ddc418fff70fc661cfd363179263b08e52e8b748e30/matplotlib-3.10.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:fd44fc75522f58612ec4a33958a7e5552562b7705b42ef1b4f8c0818e304a363", size = 8171677 }, + { url = 
"https://files.pythonhosted.org/packages/c7/8b/92e9da1f28310a1f6572b5c55097b0c0ceb5e27486d85fb73b54f5a9b939/matplotlib-3.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c58a9622d5dbeb668f407f35f4e6bfac34bb9ecdcc81680c04d0258169747997", size = 8044945 }, + { url = "https://files.pythonhosted.org/packages/c5/cb/49e83f0fd066937a5bd3bc5c5d63093703f3637b2824df8d856e0558beef/matplotlib-3.10.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:845d96568ec873be63f25fa80e9e7fae4be854a66a7e2f0c8ccc99e94a8bd4ef", size = 8458269 }, + { url = "https://files.pythonhosted.org/packages/b2/7d/2d873209536b9ee17340754118a2a17988bc18981b5b56e6715ee07373ac/matplotlib-3.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5439f4c5a3e2e8eab18e2f8c3ef929772fd5641876db71f08127eed95ab64683", size = 8599369 }, + { url = "https://files.pythonhosted.org/packages/b8/03/57d6cbbe85c61fe4cbb7c94b54dce443d68c21961830833a1f34d056e5ea/matplotlib-3.10.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4673ff67a36152c48ddeaf1135e74ce0d4bce1bbf836ae40ed39c29edf7e2765", size = 9405992 }, + { url = "https://files.pythonhosted.org/packages/14/cf/e382598f98be11bf51dd0bc60eca44a517f6793e3dc8b9d53634a144620c/matplotlib-3.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:7e8632baebb058555ac0cde75db885c61f1212e47723d63921879806b40bec6a", size = 8034580 }, + { url = "https://files.pythonhosted.org/packages/32/5f/29def7ce4e815ab939b56280976ee35afffb3bbdb43f332caee74cb8c951/matplotlib-3.10.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:81713dd0d103b379de4516b861d964b1d789a144103277769238c732229d7f03", size = 8155500 }, + { url = "https://files.pythonhosted.org/packages/de/6d/d570383c9f7ca799d0a54161446f9ce7b17d6c50f2994b653514bcaa108f/matplotlib-3.10.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:359f87baedb1f836ce307f0e850d12bb5f1936f70d035561f90d41d305fdacea", size = 8032398 }, + { url = 
"https://files.pythonhosted.org/packages/c9/b4/680aa700d99b48e8c4393fa08e9ab8c49c0555ee6f4c9c0a5e8ea8dfde5d/matplotlib-3.10.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ae80dc3a4add4665cf2faa90138384a7ffe2a4e37c58d83e115b54287c4f06ef", size = 8587361 }, +] + +[[package]] +name = "matplotlib-inline" +version = "0.1.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/99/5b/a36a337438a14116b16480db471ad061c36c3694df7c2084a0da7ba538b7/matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90", size = 8159 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899 }, +] + +[[package]] +name = "mdit-py-plugins" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/19/03/a2ecab526543b152300717cf232bb4bb8605b6edb946c845016fa9c9c9fd/mdit_py_plugins-0.4.2.tar.gz", hash = "sha256:5f2cd1fdb606ddf152d37ec30e46101a60512bc0e5fa1a7002c36647b09e26b5", size = 43542 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/f7/7782a043553ee469c1ff49cfa1cdace2d6bf99a1f333cf38676b3ddf30da/mdit_py_plugins-0.4.2-py3-none-any.whl", hash = "sha256:0c673c3f889399a33b95e88d2f0d111b4447bdfea7f237dab2d488f459835636", size = 55316 }, +] + +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = 
"sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 }, +] + +[[package]] +name = "ml-dtypes" +version = "0.5.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-5-gt4py-all' or extra == 'extra-5-gt4py-dace' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 
'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/32/49/6e67c334872d2c114df3020e579f3718c333198f8312290e09ec0216703a/ml_dtypes-0.5.1.tar.gz", hash = "sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9", size = 698772 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/88/11ebdbc75445eeb5b6869b708a0d787d1ed812ff86c2170bbfb95febdce1/ml_dtypes-0.5.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190", size = 671450 }, + { url = "https://files.pythonhosted.org/packages/a4/a4/9321cae435d6140f9b0e7af8334456a854b60e3a9c6101280a16e3594965/ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed", size = 4621075 }, + { url = "https://files.pythonhosted.org/packages/16/d8/4502e12c6a10d42e13a552e8d97f20198e3cf82a0d1411ad50be56a5077c/ml_dtypes-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe", size = 4738414 }, + { url = 
"https://files.pythonhosted.org/packages/6b/7e/bc54ae885e4d702e60a4bf50aa9066ff35e9c66b5213d11091f6bffb3036/ml_dtypes-0.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4", size = 209718 }, + { url = "https://files.pythonhosted.org/packages/c9/fd/691335926126bb9beeb030b61a28f462773dcf16b8e8a2253b599013a303/ml_dtypes-0.5.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327", size = 671448 }, + { url = "https://files.pythonhosted.org/packages/ff/a6/63832d91f2feb250d865d069ba1a5d0c686b1f308d1c74ce9764472c5e22/ml_dtypes-0.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f", size = 4625792 }, + { url = "https://files.pythonhosted.org/packages/cc/2a/5421fd3dbe6eef9b844cc9d05f568b9fb568503a2e51cb1eb4443d9fc56b/ml_dtypes-0.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab", size = 4743893 }, + { url = "https://files.pythonhosted.org/packages/60/30/d3f0fc9499a22801219679a7f3f8d59f1429943c6261f445fb4bfce20718/ml_dtypes-0.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478", size = 209712 }, +] + +[[package]] +name = "more-itertools" +version = "10.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/88/3b/7fa1fe835e2e93fd6d7b52b2f95ae810cf5ba133e1845f726f5a992d62c2/more-itertools-10.6.0.tar.gz", hash = "sha256:2cd7fad1009c31cc9fb6a035108509e6547547a7a738374f10bd49a09eb3ee3b", size = 125009 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/62/0fe302c6d1be1c777cab0616e6302478251dfbf9055ad426f5d0def75c89/more_itertools-10.6.0-py3-none-any.whl", hash = "sha256:6eb054cb4b6db1473f6e15fcc676a08e4732548acd47c708f0e179c2c7c01e89", 
size = 63038 }, +] + +[[package]] +name = "mpmath" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/47/dd32fa426cc72114383ac549964eecb20ecfd886d1e5ccf5340b55b02f57/mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", size = 508106 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198 }, +] + +[[package]] +name = "msgpack" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cb/d0/7555686ae7ff5731205df1012ede15dd9d927f6227ea151e901c7406af4f/msgpack-1.1.0.tar.gz", hash = "sha256:dd432ccc2c72b914e4cb77afce64aab761c1137cc698be3984eee260bcb2896e", size = 167260 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/f9/a892a6038c861fa849b11a2bb0502c07bc698ab6ea53359e5771397d883b/msgpack-1.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:7ad442d527a7e358a469faf43fda45aaf4ac3249c8310a82f0ccff9164e5dccd", size = 150428 }, + { url = "https://files.pythonhosted.org/packages/df/7a/d174cc6a3b6bb85556e6a046d3193294a92f9a8e583cdbd46dc8a1d7e7f4/msgpack-1.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:74bed8f63f8f14d75eec75cf3d04ad581da6b914001b474a5d3cd3372c8cc27d", size = 84131 }, + { url = "https://files.pythonhosted.org/packages/08/52/bf4fbf72f897a23a56b822997a72c16de07d8d56d7bf273242f884055682/msgpack-1.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:914571a2a5b4e7606997e169f64ce53a8b1e06f2cf2c3a7273aa106236d43dd5", size = 81215 }, + { url = "https://files.pythonhosted.org/packages/02/95/dc0044b439b518236aaf012da4677c1b8183ce388411ad1b1e63c32d8979/msgpack-1.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:c921af52214dcbb75e6bdf6a661b23c3e6417f00c603dd2070bccb5c3ef499f5", size = 371229 }, + { url = "https://files.pythonhosted.org/packages/ff/75/09081792db60470bef19d9c2be89f024d366b1e1973c197bb59e6aabc647/msgpack-1.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8ce0b22b890be5d252de90d0e0d119f363012027cf256185fc3d474c44b1b9e", size = 378034 }, + { url = "https://files.pythonhosted.org/packages/32/d3/c152e0c55fead87dd948d4b29879b0f14feeeec92ef1fd2ec21b107c3f49/msgpack-1.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:73322a6cc57fcee3c0c57c4463d828e9428275fb85a27aa2aa1a92fdc42afd7b", size = 363070 }, + { url = "https://files.pythonhosted.org/packages/d9/2c/82e73506dd55f9e43ac8aa007c9dd088c6f0de2aa19e8f7330e6a65879fc/msgpack-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e1f3c3d21f7cf67bcf2da8e494d30a75e4cf60041d98b3f79875afb5b96f3a3f", size = 359863 }, + { url = "https://files.pythonhosted.org/packages/cb/a0/3d093b248837094220e1edc9ec4337de3443b1cfeeb6e0896af8ccc4cc7a/msgpack-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:64fc9068d701233effd61b19efb1485587560b66fe57b3e50d29c5d78e7fef68", size = 368166 }, + { url = "https://files.pythonhosted.org/packages/e4/13/7646f14f06838b406cf5a6ddbb7e8dc78b4996d891ab3b93c33d1ccc8678/msgpack-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:42f754515e0f683f9c79210a5d1cad631ec3d06cea5172214d2176a42e67e19b", size = 370105 }, + { url = "https://files.pythonhosted.org/packages/67/fa/dbbd2443e4578e165192dabbc6a22c0812cda2649261b1264ff515f19f15/msgpack-1.1.0-cp310-cp310-win32.whl", hash = "sha256:3df7e6b05571b3814361e8464f9304c42d2196808e0119f55d0d3e62cd5ea044", size = 68513 }, + { url = "https://files.pythonhosted.org/packages/24/ce/c2c8fbf0ded750cb63cbcbb61bc1f2dfd69e16dca30a8af8ba80ec182dcd/msgpack-1.1.0-cp310-cp310-win_amd64.whl", hash = 
"sha256:685ec345eefc757a7c8af44a3032734a739f8c45d1b0ac45efc5d8977aa4720f", size = 74687 }, + { url = "https://files.pythonhosted.org/packages/b7/5e/a4c7154ba65d93be91f2f1e55f90e76c5f91ccadc7efc4341e6f04c8647f/msgpack-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3d364a55082fb2a7416f6c63ae383fbd903adb5a6cf78c5b96cc6316dc1cedc7", size = 150803 }, + { url = "https://files.pythonhosted.org/packages/60/c2/687684164698f1d51c41778c838d854965dd284a4b9d3a44beba9265c931/msgpack-1.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:79ec007767b9b56860e0372085f8504db5d06bd6a327a335449508bbee9648fa", size = 84343 }, + { url = "https://files.pythonhosted.org/packages/42/ae/d3adea9bb4a1342763556078b5765e666f8fdf242e00f3f6657380920972/msgpack-1.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6ad622bf7756d5a497d5b6836e7fc3752e2dd6f4c648e24b1803f6048596f701", size = 81408 }, + { url = "https://files.pythonhosted.org/packages/dc/17/6313325a6ff40ce9c3207293aee3ba50104aed6c2c1559d20d09e5c1ff54/msgpack-1.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e59bca908d9ca0de3dc8684f21ebf9a690fe47b6be93236eb40b99af28b6ea6", size = 396096 }, + { url = "https://files.pythonhosted.org/packages/a8/a1/ad7b84b91ab5a324e707f4c9761633e357820b011a01e34ce658c1dda7cc/msgpack-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e1da8f11a3dd397f0a32c76165cf0c4eb95b31013a94f6ecc0b280c05c91b59", size = 403671 }, + { url = "https://files.pythonhosted.org/packages/bb/0b/fd5b7c0b308bbf1831df0ca04ec76fe2f5bf6319833646b0a4bd5e9dc76d/msgpack-1.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:452aff037287acb1d70a804ffd022b21fa2bb7c46bee884dbc864cc9024128a0", size = 387414 }, + { url = "https://files.pythonhosted.org/packages/f0/03/ff8233b7c6e9929a1f5da3c7860eccd847e2523ca2de0d8ef4878d354cfa/msgpack-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = 
"sha256:8da4bf6d54ceed70e8861f833f83ce0814a2b72102e890cbdfe4b34764cdd66e", size = 383759 }, + { url = "https://files.pythonhosted.org/packages/1f/1b/eb82e1fed5a16dddd9bc75f0854b6e2fe86c0259c4353666d7fab37d39f4/msgpack-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:41c991beebf175faf352fb940bf2af9ad1fb77fd25f38d9142053914947cdbf6", size = 394405 }, + { url = "https://files.pythonhosted.org/packages/90/2e/962c6004e373d54ecf33d695fb1402f99b51832631e37c49273cc564ffc5/msgpack-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a52a1f3a5af7ba1c9ace055b659189f6c669cf3657095b50f9602af3a3ba0fe5", size = 396041 }, + { url = "https://files.pythonhosted.org/packages/f8/20/6e03342f629474414860c48aeffcc2f7f50ddaf351d95f20c3f1c67399a8/msgpack-1.1.0-cp311-cp311-win32.whl", hash = "sha256:58638690ebd0a06427c5fe1a227bb6b8b9fdc2bd07701bec13c2335c82131a88", size = 68538 }, + { url = "https://files.pythonhosted.org/packages/aa/c4/5a582fc9a87991a3e6f6800e9bb2f3c82972912235eb9539954f3e9997c7/msgpack-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fd2906780f25c8ed5d7b323379f6138524ba793428db5d0e9d226d3fa6aa1788", size = 74871 }, +] + +[[package]] +name = "mypy" +version = "1.14.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mypy-extensions" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "typing-extensions" }, +] 
+sdist = { url = "https://files.pythonhosted.org/packages/b9/eb/2c92d8ea1e684440f54fa49ac5d9a5f19967b7b472a281f419e69a8d228e/mypy-1.14.1.tar.gz", hash = "sha256:7ec88144fe9b510e8475ec2f5f251992690fcf89ccb4500b214b4226abcd32d6", size = 3216051 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/7a/87ae2adb31d68402da6da1e5f30c07ea6063e9f09b5e7cfc9dfa44075e74/mypy-1.14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:52686e37cf13d559f668aa398dd7ddf1f92c5d613e4f8cb262be2fb4fedb0fcb", size = 11211002 }, + { url = "https://files.pythonhosted.org/packages/e1/23/eada4c38608b444618a132be0d199b280049ded278b24cbb9d3fc59658e4/mypy-1.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1fb545ca340537d4b45d3eecdb3def05e913299ca72c290326be19b3804b39c0", size = 10358400 }, + { url = "https://files.pythonhosted.org/packages/43/c9/d6785c6f66241c62fd2992b05057f404237deaad1566545e9f144ced07f5/mypy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:90716d8b2d1f4cd503309788e51366f07c56635a3309b0f6a32547eaaa36a64d", size = 12095172 }, + { url = "https://files.pythonhosted.org/packages/c3/62/daa7e787770c83c52ce2aaf1a111eae5893de9e004743f51bfcad9e487ec/mypy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2ae753f5c9fef278bcf12e1a564351764f2a6da579d4a81347e1d5a15819997b", size = 12828732 }, + { url = "https://files.pythonhosted.org/packages/1b/a2/5fb18318a3637f29f16f4e41340b795da14f4751ef4f51c99ff39ab62e52/mypy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e0fe0f5feaafcb04505bcf439e991c6d8f1bf8b15f12b05feeed96e9e7bf1427", size = 13012197 }, + { url = "https://files.pythonhosted.org/packages/28/99/e153ce39105d164b5f02c06c35c7ba958aaff50a2babba7d080988b03fe7/mypy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:7d54bd85b925e501c555a3227f3ec0cfc54ee8b6930bd6141ec872d1c572f81f", size = 9780836 }, + { url = 
"https://files.pythonhosted.org/packages/da/11/a9422850fd506edbcdc7f6090682ecceaf1f87b9dd847f9df79942da8506/mypy-1.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f995e511de847791c3b11ed90084a7a0aafdc074ab88c5a9711622fe4751138c", size = 11120432 }, + { url = "https://files.pythonhosted.org/packages/b6/9e/47e450fd39078d9c02d620545b2cb37993a8a8bdf7db3652ace2f80521ca/mypy-1.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d64169ec3b8461311f8ce2fd2eb5d33e2d0f2c7b49116259c51d0d96edee48d1", size = 10279515 }, + { url = "https://files.pythonhosted.org/packages/01/b5/6c8d33bd0f851a7692a8bfe4ee75eb82b6983a3cf39e5e32a5d2a723f0c1/mypy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ba24549de7b89b6381b91fbc068d798192b1b5201987070319889e93038967a8", size = 12025791 }, + { url = "https://files.pythonhosted.org/packages/f0/4c/e10e2c46ea37cab5c471d0ddaaa9a434dc1d28650078ac1b56c2d7b9b2e4/mypy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:183cf0a45457d28ff9d758730cd0210419ac27d4d3f285beda038c9083363b1f", size = 12749203 }, + { url = "https://files.pythonhosted.org/packages/88/55/beacb0c69beab2153a0f57671ec07861d27d735a0faff135a494cd4f5020/mypy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f2a0ecc86378f45347f586e4163d1769dd81c5a223d577fe351f26b179e148b1", size = 12885900 }, + { url = "https://files.pythonhosted.org/packages/a2/75/8c93ff7f315c4d086a2dfcde02f713004357d70a163eddb6c56a6a5eff40/mypy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:ad3301ebebec9e8ee7135d8e3109ca76c23752bac1e717bc84cd3836b4bf3eae", size = 9777869 }, + { url = "https://files.pythonhosted.org/packages/a0/b5/32dd67b69a16d088e533962e5044e51004176a9952419de0370cdaead0f8/mypy-1.14.1-py3-none-any.whl", hash = "sha256:b66a60cc4073aeb8ae00057f9c1f64d49e90f918fbcef9a977eb121da8b8f1d1", size = 2752905 }, +] + +[package.optional-dependencies] +faster-cache = [ + { name = 
"orjson" }, +] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/98/a4/1ab47638b92648243faf97a5aeb6ea83059cc3624972ab6b8d2316078d3f/mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782", size = 4433 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695 }, +] + +[[package]] +name = "myst-parser" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "docutils" }, + { name = "jinja2" }, + { name = "markdown-it-py" }, + { name = "mdit-py-plugins" }, + { name = "pyyaml" }, + { name = "sphinx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/85/55/6d1741a1780e5e65038b74bce6689da15f620261c490c3511eb4c12bac4b/myst_parser-4.0.0.tar.gz", hash = "sha256:851c9dfb44e36e56d15d05e72f02b80da21a9e0d07cba96baf5e2d476bb91531", size = 93858 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ca/b4/b036f8fdb667587bb37df29dc6644681dd78b7a2a6321a34684b79412b28/myst_parser-4.0.0-py3-none-any.whl", hash = "sha256:b9317997552424448c6096c2558872fdb6f81d3ecb3a40ce84a7518798f3f28d", size = 84563 }, +] + +[[package]] +name = "nanobind" +version = "2.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/fa/8e5930837f9b08202c4e566cf529480b0c3266e88f39723388baf8c69700/nanobind-2.5.0.tar.gz", hash = "sha256:cc8412e94acffa20a369191382bcdbb6fbfb302e475e87cacff9516d51023a15", size = 962802 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/9e/dadc3831f40e22c1b3925f07894646ada7906ef5b48db5c5eb2b03ca9faa/nanobind-2.5.0-py3-none-any.whl", hash = 
"sha256:e1e5c816e5d10f0b252d82ba7f769f0f6679f5e043cf406aec3d9e184bf2a60d", size = 236912 }, +] + +[[package]] +name = "natsort" +version = "8.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e2/a9/a0c57aee75f77794adaf35322f8b6404cbd0f89ad45c87197a937764b7d0/natsort-8.4.0.tar.gz", hash = "sha256:45312c4a0e5507593da193dedd04abb1469253b601ecaf63445ad80f0a1ea581", size = 76575 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/82/7a9d0550484a62c6da82858ee9419f3dd1ccc9aa1c26a1e43da3ecd20b0d/natsort-8.4.0-py3-none-any.whl", hash = "sha256:4732914fb471f56b5cce04d7bae6f164a592c7712e1c85f9ef585e197299521c", size = 38268 }, +] + +[[package]] +name = "nbclient" +version = "0.10.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jupyter-client" }, + { name = "jupyter-core" }, + { name = "nbformat" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/87/66/7ffd18d58eae90d5721f9f39212327695b749e23ad44b3881744eaf4d9e8/nbclient-0.10.2.tar.gz", hash = "sha256:90b7fc6b810630db87a6d0c2250b1f0ab4cf4d3c27a299b0cde78a4ed3fd9193", size = 62424 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/34/6d/e7fa07f03a4a7b221d94b4d586edb754a9b0dc3c9e2c93353e9fa4e0d117/nbclient-0.10.2-py3-none-any.whl", hash = "sha256:4ffee11e788b4a27fabeb7955547e4318a5298f34342a4bfd01f2e1faaeadc3d", size = 25434 }, +] + +[[package]] +name = "nbformat" +version = "5.10.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastjsonschema" }, + { name = "jsonschema" }, + { name = "jupyter-core" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/fd/91545e604bc3dad7dca9ed03284086039b294c6b3d75c0d2fa45f9e9caf3/nbformat-5.10.4.tar.gz", hash = "sha256:322168b14f937a5d11362988ecac2a4952d3d8e3a2cbeb2319584631226d5b3a", size = 142749 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/a9/82/0340caa499416c78e5d8f5f05947ae4bc3cba53c9f038ab6e9ed964e22f1/nbformat-5.10.4-py3-none-any.whl", hash = "sha256:3b48d6c8fbca4b299bf3982ea7db1af21580e4fec269ad087b9e81588891200b", size = 78454 }, +] + +[[package]] +name = "nbmake" +version = "1.5.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ipykernel" }, + { name = "nbclient" }, + { name = "nbformat" }, + { name = "pygments" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/43/9a/aae201cee5639e1d562b3843af8fd9f8d018bb323e776a2b973bdd5fc64b/nbmake-1.5.5.tar.gz", hash = "sha256:239dc868ea13a7c049746e2aba2c229bd0f6cdbc6bfa1d22f4c88638aa4c5f5c", size = 85929 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/be/b257e12f9710819fde40adc972578bee6b72c5992da1bc8369bef2597756/nbmake-1.5.5-py3-none-any.whl", hash = "sha256:c6fbe6e48b60cacac14af40b38bf338a3b88f47f085c54ac5b8639ff0babaf4b", size = 12818 }, +] + +[[package]] +name = "nest-asyncio" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/83/f8/51569ac65d696c8ecbee95938f89d4abf00f47d58d48f6fbabfe8f0baefe/nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe", size = 7418 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195 }, +] + +[[package]] +name = "networkx" +version = "3.4.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fd/1d/06475e1cd5264c0b870ea2cc6fdb3e37177c1e565c43f56ff17a10e3937f/networkx-3.4.2.tar.gz", hash = "sha256:307c3669428c5362aab27c8a1260aa8f47c4e91d3891f48be0141738d8d053e1", size = 2151368 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263 }, +] + +[[package]] +name = "ninja" +version = "1.11.1.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bd/8f/21a2701f95b7d0d5137736561b3427ece0c4a1e085d4a223b92d16ab7d8b/ninja-1.11.1.3.tar.gz", hash = "sha256:edfa0d2e9d7ead1635b03e40a32ad56cc8f56798b6e2e9848d8300b174897076", size = 129532 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/ba/0069cd4a83d68f7b0308be70e219b15d675e50c8ea28763a3f0373c45bfc/ninja-1.11.1.3-py3-none-macosx_10_9_universal2.whl", hash = "sha256:2b4879ea3f1169f3d855182c57dcc84d1b5048628c8b7be0d702b81882a37237", size = 279132 }, + { url = "https://files.pythonhosted.org/packages/72/6b/3805be87df8417a0c7b21078c8045f2a1e59b34f371bfe4cb4fb0d6df7f2/ninja-1.11.1.3-py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:bc3ebc8b2e47716149f3541742b5cd8e0b08f51013b825c05baca3e34854370d", size = 472101 }, + { url = "https://files.pythonhosted.org/packages/6b/35/a8e38d54768e67324e365e2a41162be298f51ec93e6bd4b18d237d7250d8/ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:a27e78ca71316c8654965ee94b286a98c83877bfebe2607db96897bbfe458af0", size = 422884 }, + { url = "https://files.pythonhosted.org/packages/2f/99/7996457319e139c02697fb2aa28e42fe32bb0752cef492edc69d56a3552e/ninja-1.11.1.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2883ea46b3c5079074f56820f9989c6261fcc6fd873d914ee49010ecf283c3b2", size = 157046 }, + { url = "https://files.pythonhosted.org/packages/6d/8b/93f38e5cddf76ccfdab70946515b554f25d2b4c95ef9b2f9cfbc43fa7cc1/ninja-1.11.1.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:8c4bdb9fd2d0c06501ae15abfd23407660e95659e384acd36e013b6dd7d8a8e4", size = 180014 }, + { url = "https://files.pythonhosted.org/packages/7d/1d/713884d0fa3c972164f69d552e0701d30e2bf25eba9ef160bfb3dc69926a/ninja-1.11.1.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:114ed5c61c8474df6a69ab89097a20749b769e2c219a452cb2fadc49b0d581b0", size = 157098 }, + { url = "https://files.pythonhosted.org/packages/c7/22/ecb0f70e77c9e22ee250aa717a608a142756833a34d43943d7d658ee0e56/ninja-1.11.1.3-py3-none-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:7fa2247fce98f683bc712562d82b22b8a0a5c000738a13147ca2d1b68c122298", size = 130089 }, + { url = "https://files.pythonhosted.org/packages/ec/a6/3ee846c20ab6ad95b90c5c8703c76cb1f39cc8ce2d1ae468956e3b1b2581/ninja-1.11.1.3-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:a38c6c6c8032bed68b70c3b065d944c35e9f903342875d3a3218c1607987077c", size = 372508 }, + { url = "https://files.pythonhosted.org/packages/95/0d/aa44abe4141f29148ce671ac8c92045878906b18691c6f87a29711c2ff1c/ninja-1.11.1.3-py3-none-musllinux_1_1_i686.whl", hash = "sha256:56ada5d33b8741d298836644042faddebc83ee669782d661e21563034beb5aba", size = 419369 }, + { url = "https://files.pythonhosted.org/packages/f7/ec/48bf5105568ac9bd2016b701777bdd5000cc09a14ac837fef9f15e8d634e/ninja-1.11.1.3-py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:53409151da081f3c198bb0bfc220a7f4e821e022c5b7d29719adda892ddb31bb", size = 420304 }, + { url = "https://files.pythonhosted.org/packages/18/e5/69df63976cf971a03379899f8520a036c9dbab26330b37197512aed5b3df/ninja-1.11.1.3-py3-none-musllinux_1_1_s390x.whl", hash = "sha256:1ad2112c2b0159ed7c4ae3731595191b1546ba62316fc40808edecd0306fefa3", size = 416056 }, + { url = "https://files.pythonhosted.org/packages/6f/4f/bdb401af7ed0e24a3fef058e13a149f2de1ce4b176699076993615d55610/ninja-1.11.1.3-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:28aea3c1c280cba95b8608d50797169f3a34280e3e9a6379b6e340f0c9eaeeb0", size = 
379725 }, + { url = "https://files.pythonhosted.org/packages/bd/68/05e7863bf13128c61652eeb3ec7096c3d3a602f32f31752dbfb034e3fa07/ninja-1.11.1.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b6966f83064a88a51693073eea3decd47e08c3965241e09578ef7aa3a7738329", size = 434881 }, + { url = "https://files.pythonhosted.org/packages/bd/ad/edc0d1efe77f29f45bbca2e1dab07ef597f61a88de6e4bccffc0aec2256c/ninja-1.11.1.3-py3-none-win32.whl", hash = "sha256:a4a3b71490557e18c010cbb26bd1ea9a0c32ee67e8f105e9731515b6e0af792e", size = 255988 }, + { url = "https://files.pythonhosted.org/packages/03/93/09a9f7672b4f97438aca6217ac54212a63273f1cd3b46b731d0bb22c53e7/ninja-1.11.1.3-py3-none-win_amd64.whl", hash = "sha256:04d48d14ea7ba11951c156599ab526bdda575450797ff57c6fdf99b2554d09c7", size = 296502 }, + { url = "https://files.pythonhosted.org/packages/d9/9d/0cc1e82849070ff3cbee69f326cb48a839407bcd15d8844443c30a5e7509/ninja-1.11.1.3-py3-none-win_arm64.whl", hash = "sha256:17978ad611d8ead578d83637f5ae80c2261b033db0b493a7ce94f88623f29e1b", size = 270571 }, +] + +[[package]] +name = "nodeenv" +version = "1.9.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 }, +] + +[[package]] +name = "nox" +version = "2024.10.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "argcomplete" }, + { name = "colorlog" }, + { name = "packaging" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or 
(extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "virtualenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/93/4df547afcd56e0b2bbaa99bc2637deb218a01802ed62d80f763189be802c/nox-2024.10.9.tar.gz", hash = "sha256:7aa9dc8d1c27e9f45ab046ffd1c3b2c4f7c91755304769df231308849ebded95", size = 4003197 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/66/00/981f0dcaddf111b6caf6e03d7f7f01b07fd4af117316a7eb1c22039d9e37/nox-2024.10.9-py3-none-any.whl", hash = "sha256:1d36f309a0a2a853e9bccb76bbef6bb118ba92fa92674d15604ca99adeb29eab", size = 61210 }, +] + +[[package]] +name = "numpy" +version = "1.26.4" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.11'", + "python_full_version < '3.11'", +] +sdist = { url = "https://files.pythonhosted.org/packages/65/6e/09db70a523a96d25e115e71cc56a6f9031e7b8cd166c1ac8438307c14058/numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010", size = 15786129 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/94/ace0fdea5241a27d13543ee117cbc65868e82213fb31a8eb7fe9ff23f313/numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0", size = 20631468 }, + { url = "https://files.pythonhosted.org/packages/20/f7/b24208eba89f9d1b58c1668bc6c8c4fd472b20c45573cb767f59d49fb0f6/numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = 
"sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a", size = 13966411 }, + { url = "https://files.pythonhosted.org/packages/fc/a5/4beee6488160798683eed5bdb7eead455892c3b4e1f78d79d8d3f3b084ac/numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4", size = 14219016 }, + { url = "https://files.pythonhosted.org/packages/4b/d7/ecf66c1cd12dc28b4040b15ab4d17b773b87fa9d29ca16125de01adb36cd/numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f", size = 18240889 }, + { url = "https://files.pythonhosted.org/packages/24/03/6f229fe3187546435c4f6f89f6d26c129d4f5bed40552899fcf1f0bf9e50/numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a", size = 13876746 }, + { url = "https://files.pythonhosted.org/packages/39/fe/39ada9b094f01f5a35486577c848fe274e374bbf8d8f472e1423a0bbd26d/numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2", size = 18078620 }, + { url = "https://files.pythonhosted.org/packages/d5/ef/6ad11d51197aad206a9ad2286dc1aac6a378059e06e8cf22cd08ed4f20dc/numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07", size = 5972659 }, + { url = "https://files.pythonhosted.org/packages/19/77/538f202862b9183f54108557bfda67e17603fc560c384559e769321c9d92/numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5", size = 15808905 }, + { url = "https://files.pythonhosted.org/packages/11/57/baae43d14fe163fa0e4c47f307b6b2511ab8d7d30177c491960504252053/numpy-1.26.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71", size = 
20630554 }, + { url = "https://files.pythonhosted.org/packages/1a/2e/151484f49fd03944c4a3ad9c418ed193cfd02724e138ac8a9505d056c582/numpy-1.26.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef", size = 13997127 }, + { url = "https://files.pythonhosted.org/packages/79/ae/7e5b85136806f9dadf4878bf73cf223fe5c2636818ba3ab1c585d0403164/numpy-1.26.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e", size = 14222994 }, + { url = "https://files.pythonhosted.org/packages/3a/d0/edc009c27b406c4f9cbc79274d6e46d634d139075492ad055e3d68445925/numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5", size = 18252005 }, + { url = "https://files.pythonhosted.org/packages/09/bf/2b1aaf8f525f2923ff6cfcf134ae5e750e279ac65ebf386c75a0cf6da06a/numpy-1.26.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a", size = 13885297 }, + { url = "https://files.pythonhosted.org/packages/df/a0/4e0f14d847cfc2a633a1c8621d00724f3206cfeddeb66d35698c4e2cf3d2/numpy-1.26.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a", size = 18093567 }, + { url = "https://files.pythonhosted.org/packages/d2/b7/a734c733286e10a7f1a8ad1ae8c90f2d33bf604a96548e0a4a3a6739b468/numpy-1.26.4-cp311-cp311-win32.whl", hash = "sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20", size = 5968812 }, + { url = "https://files.pythonhosted.org/packages/3f/6b/5610004206cf7f8e7ad91c5a85a8c71b2f2f8051a0c0c4d5916b76d6cbb2/numpy-1.26.4-cp311-cp311-win_amd64.whl", hash = "sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2", size = 15811913 }, +] + +[[package]] +name = "numpy" +version = "2.2.2" +source = { registry 
= "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.11'", + "python_full_version < '3.11'", +] +sdist = { url = "https://files.pythonhosted.org/packages/ec/d0/c12ddfd3a02274be06ffc71f3efc6d0e457b0409c4481596881e748cb264/numpy-2.2.2.tar.gz", hash = "sha256:ed6906f61834d687738d25988ae117683705636936cc605be0bb208b23df4d8f", size = 20233295 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/70/2a/69033dc22d981ad21325314f8357438078f5c28310a6d89fb3833030ec8a/numpy-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7079129b64cb78bdc8d611d1fd7e8002c0a2565da6a47c4df8062349fee90e3e", size = 21215825 }, + { url = "https://files.pythonhosted.org/packages/31/2c/39f91e00bbd3d5639b027ac48c55dc5f2992bd2b305412d26be4c830862a/numpy-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2ec6c689c61df613b783aeb21f945c4cbe6c51c28cb70aae8430577ab39f163e", size = 14354996 }, + { url = "https://files.pythonhosted.org/packages/0a/2c/d468ebd253851af10de5b3e8f3418ebabfaab5f0337a75299fbeb8b8c17a/numpy-2.2.2-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:40c7ff5da22cd391944a28c6a9c638a5eef77fcf71d6e3a79e1d9d9e82752715", size = 5393621 }, + { url = "https://files.pythonhosted.org/packages/7f/f4/3d8a5a0da297034106c5de92be881aca7079cde6058934215a1de91334f6/numpy-2.2.2-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:995f9e8181723852ca458e22de5d9b7d3ba4da3f11cc1cb113f093b271d7965a", size = 6928931 }, + { url = "https://files.pythonhosted.org/packages/47/a7/029354ab56edd43dd3f5efbfad292b8844f98b93174f322f82353fa46efa/numpy-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b78ea78450fd96a498f50ee096f69c75379af5138f7881a51355ab0e11286c97", size = 14333157 }, + { url = "https://files.pythonhosted.org/packages/e3/d7/11fc594838d35c43519763310c316d4fd56f8600d3fc80a8e13e325b5c5c/numpy-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:3fbe72d347fbc59f94124125e73fc4976a06927ebc503ec5afbfb35f193cd957", size = 16381794 }, + { url = "https://files.pythonhosted.org/packages/af/d4/dd9b19cd4aff9c79d3f54d17f8be815407520d3116004bc574948336981b/numpy-2.2.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8e6da5cffbbe571f93588f562ed130ea63ee206d12851b60819512dd3e1ba50d", size = 15543990 }, + { url = "https://files.pythonhosted.org/packages/30/97/ab96b7650f27f684a9b1e46757a7294ecc50cab27701d05f146e9f779627/numpy-2.2.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:09d6a2032faf25e8d0cadde7fd6145118ac55d2740132c1d845f98721b5ebcfd", size = 18170896 }, + { url = "https://files.pythonhosted.org/packages/81/9b/bae9618cab20db67a2ca9d711795cad29b2ca4b73034dd3b5d05b962070a/numpy-2.2.2-cp310-cp310-win32.whl", hash = "sha256:159ff6ee4c4a36a23fe01b7c3d07bd8c14cc433d9720f977fcd52c13c0098160", size = 6573458 }, + { url = "https://files.pythonhosted.org/packages/92/9b/95678092febd14070cfb7906ea7932e71e9dd5a6ab3ee948f9ed975e905d/numpy-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:64bd6e1762cd7f0986a740fee4dff927b9ec2c5e4d9a28d056eb17d332158014", size = 12915812 }, + { url = "https://files.pythonhosted.org/packages/21/67/32c68756eed84df181c06528ff57e09138f893c4653448c4967311e0f992/numpy-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:642199e98af1bd2b6aeb8ecf726972d238c9877b0f6e8221ee5ab945ec8a2189", size = 21220002 }, + { url = "https://files.pythonhosted.org/packages/3b/89/f43bcad18f2b2e5814457b1c7f7b0e671d0db12c8c0e43397ab8cb1831ed/numpy-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6d9fc9d812c81e6168b6d405bf00b8d6739a7f72ef22a9214c4241e0dc70b323", size = 14391215 }, + { url = "https://files.pythonhosted.org/packages/9c/e6/efb8cd6122bf25e86e3dd89d9dbfec9e6861c50e8810eed77d4be59b51c6/numpy-2.2.2-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:c7d1fd447e33ee20c1f33f2c8e6634211124a9aabde3c617687d8b739aa69eac", size = 5391918 }, + { url = 
"https://files.pythonhosted.org/packages/47/e2/fccf89d64d9b47ffb242823d4e851fc9d36fa751908c9aac2807924d9b4e/numpy-2.2.2-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:451e854cfae0febe723077bd0cf0a4302a5d84ff25f0bfece8f29206c7bed02e", size = 6933133 }, + { url = "https://files.pythonhosted.org/packages/34/22/5ece749c0e5420a9380eef6fbf83d16a50010bd18fef77b9193d80a6760e/numpy-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd249bc894af67cbd8bad2c22e7cbcd46cf87ddfca1f1289d1e7e54868cc785c", size = 14338187 }, + { url = "https://files.pythonhosted.org/packages/5b/86/caec78829311f62afa6fa334c8dfcd79cffb4d24bcf96ee02ae4840d462b/numpy-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02935e2c3c0c6cbe9c7955a8efa8908dd4221d7755644c59d1bba28b94fd334f", size = 16393429 }, + { url = "https://files.pythonhosted.org/packages/c8/4e/0c25f74c88239a37924577d6ad780f3212a50f4b4b5f54f5e8c918d726bd/numpy-2.2.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:a972cec723e0563aa0823ee2ab1df0cb196ed0778f173b381c871a03719d4826", size = 15559103 }, + { url = "https://files.pythonhosted.org/packages/d4/bd/d557f10fa50dc4d5871fb9606af563249b66af2fc6f99041a10e8757c6f1/numpy-2.2.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d6d6a0910c3b4368d89dde073e630882cdb266755565155bc33520283b2d9df8", size = 18182967 }, + { url = "https://files.pythonhosted.org/packages/30/e9/66cc0f66386d78ed89e45a56e2a1d051e177b6e04477c4a41cd590ef4017/numpy-2.2.2-cp311-cp311-win32.whl", hash = "sha256:860fd59990c37c3ef913c3ae390b3929d005243acca1a86facb0773e2d8d9e50", size = 6571499 }, + { url = "https://files.pythonhosted.org/packages/66/a3/4139296b481ae7304a43581046b8f0a20da6a0dfe0ee47a044cade796603/numpy-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:da1eeb460ecce8d5b8608826595c777728cdf28ce7b5a5a8c8ac8d949beadcf2", size = 12919805 }, + { url = 
"https://files.pythonhosted.org/packages/96/7e/1dd770ee68916ed358991ab62c2cc353ffd98d0b75b901d52183ca28e8bb/numpy-2.2.2-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:b0531f0b0e07643eb089df4c509d30d72c9ef40defa53e41363eca8a8cc61495", size = 21047291 }, + { url = "https://files.pythonhosted.org/packages/d1/3c/ccd08578dc532a8e6927952339d4a02682b776d5e85be49ed0760308433e/numpy-2.2.2-pp310-pypy310_pp73-macosx_14_0_x86_64.whl", hash = "sha256:e9e82dcb3f2ebbc8cb5ce1102d5f1c5ed236bf8a11730fb45ba82e2841ec21df", size = 6792494 }, + { url = "https://files.pythonhosted.org/packages/7c/28/8754b9aee4f97199f9a047f73bb644b5a2014994a6d7b061ba67134a42de/numpy-2.2.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0d4142eb40ca6f94539e4db929410f2a46052a0fe7a2c1c59f6179c39938d2a", size = 16197312 }, + { url = "https://files.pythonhosted.org/packages/26/96/deb93f871f401045a684ca08a009382b247d14996d7a94fea6aa43c67b94/numpy-2.2.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:356ca982c188acbfa6af0d694284d8cf20e95b1c3d0aefa8929376fea9146f60", size = 12822674 }, +] + +[[package]] +name = "opt-einsum" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/b9/2ac072041e899a52f20cf9510850ff58295003aa75525e58343591b0cbfb/opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac", size = 63004 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd", size = 71932 }, +] + +[[package]] +name = "orderly-set" +version = "5.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/0e/ef328b512c2595831304e51f25e9287697b7bf13be0527ca9592a2659c16/orderly_set-5.3.0.tar.gz", hash = 
"sha256:80b3d8fdd3d39004d9aad389eaa0eab02c71f0a0511ba3a6d54a935a6c6a0acc", size = 20026 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/df/fe/8009ebb64a19cf4bdf51b16d3074375010735d8c30408efada6ce02bf37e/orderly_set-5.3.0-py3-none-any.whl", hash = "sha256:c2c0bfe604f5d3d9b24e8262a06feb612594f37aa3845650548befd7772945d1", size = 12179 }, +] + +[[package]] +name = "orjson" +version = "3.10.15" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ae/f9/5dea21763eeff8c1590076918a446ea3d6140743e0e36f58f369928ed0f4/orjson-3.10.15.tar.gz", hash = "sha256:05ca7fe452a2e9d8d9d706a2984c95b9c2ebc5db417ce0b7a49b91d50642a23e", size = 5282482 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/09/e5ff18ad009e6f97eb7edc5f67ef98b3ce0c189da9c3eaca1f9587cd4c61/orjson-3.10.15-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:552c883d03ad185f720d0c09583ebde257e41b9521b74ff40e08b7dec4559c04", size = 249532 }, + { url = "https://files.pythonhosted.org/packages/bd/b8/a75883301fe332bd433d9b0ded7d2bb706ccac679602c3516984f8814fb5/orjson-3.10.15-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:616e3e8d438d02e4854f70bfdc03a6bcdb697358dbaa6bcd19cbe24d24ece1f8", size = 125229 }, + { url = "https://files.pythonhosted.org/packages/83/4b/22f053e7a364cc9c685be203b1e40fc5f2b3f164a9b2284547504eec682e/orjson-3.10.15-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7c2c79fa308e6edb0ffab0a31fd75a7841bf2a79a20ef08a3c6e3b26814c8ca8", size = 150148 }, + { url = "https://files.pythonhosted.org/packages/63/64/1b54fc75ca328b57dd810541a4035fe48c12a161d466e3cf5b11a8c25649/orjson-3.10.15-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:73cb85490aa6bf98abd20607ab5c8324c0acb48d6da7863a51be48505646c814", size = 139748 }, + { url = 
"https://files.pythonhosted.org/packages/5e/ff/ff0c5da781807bb0a5acd789d9a7fbcb57f7b0c6e1916595da1f5ce69f3c/orjson-3.10.15-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:763dadac05e4e9d2bc14938a45a2d0560549561287d41c465d3c58aec818b164", size = 154559 }, + { url = "https://files.pythonhosted.org/packages/4e/9a/11e2974383384ace8495810d4a2ebef5f55aacfc97b333b65e789c9d362d/orjson-3.10.15-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a330b9b4734f09a623f74a7490db713695e13b67c959713b78369f26b3dee6bf", size = 130349 }, + { url = "https://files.pythonhosted.org/packages/2d/c4/dd9583aea6aefee1b64d3aed13f51d2aadb014028bc929fe52936ec5091f/orjson-3.10.15-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a61a4622b7ff861f019974f73d8165be1bd9a0855e1cad18ee167acacabeb061", size = 138514 }, + { url = "https://files.pythonhosted.org/packages/53/3e/dcf1729230654f5c5594fc752de1f43dcf67e055ac0d300c8cdb1309269a/orjson-3.10.15-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:acd271247691574416b3228db667b84775c497b245fa275c6ab90dc1ffbbd2b3", size = 130940 }, + { url = "https://files.pythonhosted.org/packages/e8/2b/b9759fe704789937705c8a56a03f6c03e50dff7df87d65cba9a20fec5282/orjson-3.10.15-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:e4759b109c37f635aa5c5cc93a1b26927bfde24b254bcc0e1149a9fada253d2d", size = 414713 }, + { url = "https://files.pythonhosted.org/packages/a7/6b/b9dfdbd4b6e20a59238319eb203ae07c3f6abf07eef909169b7a37ae3bba/orjson-3.10.15-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:9e992fd5cfb8b9f00bfad2fd7a05a4299db2bbe92e6440d9dd2fab27655b3182", size = 141028 }, + { url = "https://files.pythonhosted.org/packages/7c/b5/40f5bbea619c7caf75eb4d652a9821875a8ed04acc45fe3d3ef054ca69fb/orjson-3.10.15-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f95fb363d79366af56c3f26b71df40b9a583b07bbaaf5b317407c4d58497852e", size = 129715 }, + { url = 
"https://files.pythonhosted.org/packages/38/60/2272514061cbdf4d672edbca6e59c7e01cd1c706e881427d88f3c3e79761/orjson-3.10.15-cp310-cp310-win32.whl", hash = "sha256:f9875f5fea7492da8ec2444839dcc439b0ef298978f311103d0b7dfd775898ab", size = 142473 }, + { url = "https://files.pythonhosted.org/packages/11/5d/be1490ff7eafe7fef890eb4527cf5bcd8cfd6117f3efe42a3249ec847b60/orjson-3.10.15-cp310-cp310-win_amd64.whl", hash = "sha256:17085a6aa91e1cd70ca8533989a18b5433e15d29c574582f76f821737c8d5806", size = 133564 }, + { url = "https://files.pythonhosted.org/packages/7a/a2/21b25ce4a2c71dbb90948ee81bd7a42b4fbfc63162e57faf83157d5540ae/orjson-3.10.15-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:c4cc83960ab79a4031f3119cc4b1a1c627a3dc09df125b27c4201dff2af7eaa6", size = 249533 }, + { url = "https://files.pythonhosted.org/packages/b2/85/2076fc12d8225698a51278009726750c9c65c846eda741e77e1761cfef33/orjson-3.10.15-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ddbeef2481d895ab8be5185f2432c334d6dec1f5d1933a9c83014d188e102cef", size = 125230 }, + { url = "https://files.pythonhosted.org/packages/06/df/a85a7955f11274191eccf559e8481b2be74a7c6d43075d0a9506aa80284d/orjson-3.10.15-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9e590a0477b23ecd5b0ac865b1b907b01b3c5535f5e8a8f6ab0e503efb896334", size = 150148 }, + { url = "https://files.pythonhosted.org/packages/37/b3/94c55625a29b8767c0eed194cb000b3787e3c23b4cdd13be17bae6ccbb4b/orjson-3.10.15-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a6be38bd103d2fd9bdfa31c2720b23b5d47c6796bcb1d1b598e3924441b4298d", size = 139749 }, + { url = "https://files.pythonhosted.org/packages/53/ba/c608b1e719971e8ddac2379f290404c2e914cf8e976369bae3cad88768b1/orjson-3.10.15-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ff4f6edb1578960ed628a3b998fa54d78d9bb3e2eb2cfc5c2a09732431c678d0", size = 154558 }, + { url 
= "https://files.pythonhosted.org/packages/b2/c4/c1fb835bb23ad788a39aa9ebb8821d51b1c03588d9a9e4ca7de5b354fdd5/orjson-3.10.15-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b0482b21d0462eddd67e7fce10b89e0b6ac56570424662b685a0d6fccf581e13", size = 130349 }, + { url = "https://files.pythonhosted.org/packages/78/14/bb2b48b26ab3c570b284eb2157d98c1ef331a8397f6c8bd983b270467f5c/orjson-3.10.15-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bb5cc3527036ae3d98b65e37b7986a918955f85332c1ee07f9d3f82f3a6899b5", size = 138513 }, + { url = "https://files.pythonhosted.org/packages/4a/97/d5b353a5fe532e92c46467aa37e637f81af8468aa894cd77d2ec8a12f99e/orjson-3.10.15-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d569c1c462912acdd119ccbf719cf7102ea2c67dd03b99edcb1a3048651ac96b", size = 130942 }, + { url = "https://files.pythonhosted.org/packages/b5/5d/a067bec55293cca48fea8b9928cfa84c623be0cce8141d47690e64a6ca12/orjson-3.10.15-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:1e6d33efab6b71d67f22bf2962895d3dc6f82a6273a965fab762e64fa90dc399", size = 414717 }, + { url = "https://files.pythonhosted.org/packages/6f/9a/1485b8b05c6b4c4db172c438cf5db5dcfd10e72a9bc23c151a1137e763e0/orjson-3.10.15-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:c33be3795e299f565681d69852ac8c1bc5c84863c0b0030b2b3468843be90388", size = 141033 }, + { url = "https://files.pythonhosted.org/packages/f8/d2/fc67523656e43a0c7eaeae9007c8b02e86076b15d591e9be11554d3d3138/orjson-3.10.15-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:eea80037b9fae5339b214f59308ef0589fc06dc870578b7cce6d71eb2096764c", size = 129720 }, + { url = "https://files.pythonhosted.org/packages/79/42/f58c7bd4e5b54da2ce2ef0331a39ccbbaa7699b7f70206fbf06737c9ed7d/orjson-3.10.15-cp311-cp311-win32.whl", hash = "sha256:d5ac11b659fd798228a7adba3e37c010e0152b78b1982897020a8e019a94882e", size = 142473 }, + { url = 
"https://files.pythonhosted.org/packages/00/f8/bb60a4644287a544ec81df1699d5b965776bc9848d9029d9f9b3402ac8bb/orjson-3.10.15-cp311-cp311-win_amd64.whl", hash = "sha256:cf45e0214c593660339ef63e875f32ddd5aa3b4adc15e662cdb80dc49e194f8e", size = 133570 }, +] + +[[package]] +name = "packaging" +version = "24.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451 }, +] + +[[package]] +name = "parso" +version = "0.8.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/66/94/68e2e17afaa9169cf6412ab0f28623903be73d1b32e208d9e8e541bb086d/parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d", size = 400609 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650 }, +] + +[[package]] +name = "pathspec" +version = "0.12.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", 
hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191 }, +] + +[[package]] +name = "pexpect" +version = "4.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ptyprocess" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772 }, +] + +[[package]] +name = "pillow" +version = "11.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/af/c097e544e7bd278333db77933e535098c259609c4eb3b85381109602fb5b/pillow-11.1.0.tar.gz", hash = "sha256:368da70808b36d73b4b390a8ffac11069f8a5c85f29eff1f1b01bcf3ef5b2a20", size = 46742715 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/50/1c/2dcea34ac3d7bc96a1fd1bd0a6e06a57c67167fec2cff8d95d88229a8817/pillow-11.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:e1abe69aca89514737465752b4bcaf8016de61b3be1397a8fc260ba33321b3a8", size = 3229983 }, + { url = "https://files.pythonhosted.org/packages/14/ca/6bec3df25e4c88432681de94a3531cc738bd85dea6c7aa6ab6f81ad8bd11/pillow-11.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c640e5a06869c75994624551f45e5506e4256562ead981cce820d5ab39ae2192", size = 3101831 }, + { url = "https://files.pythonhosted.org/packages/d4/2c/668e18e5521e46eb9667b09e501d8e07049eb5bfe39d56be0724a43117e6/pillow-11.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a07dba04c5e22824816b2615ad7a7484432d7f540e6fa86af60d2de57b0fcee2", size = 4314074 }, + { url = 
"https://files.pythonhosted.org/packages/02/80/79f99b714f0fc25f6a8499ecfd1f810df12aec170ea1e32a4f75746051ce/pillow-11.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e267b0ed063341f3e60acd25c05200df4193e15a4a5807075cd71225a2386e26", size = 4394933 }, + { url = "https://files.pythonhosted.org/packages/81/aa/8d4ad25dc11fd10a2001d5b8a80fdc0e564ac33b293bdfe04ed387e0fd95/pillow-11.1.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:bd165131fd51697e22421d0e467997ad31621b74bfc0b75956608cb2906dda07", size = 4353349 }, + { url = "https://files.pythonhosted.org/packages/84/7a/cd0c3eaf4a28cb2a74bdd19129f7726277a7f30c4f8424cd27a62987d864/pillow-11.1.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:abc56501c3fd148d60659aae0af6ddc149660469082859fa7b066a298bde9482", size = 4476532 }, + { url = "https://files.pythonhosted.org/packages/8f/8b/a907fdd3ae8f01c7670dfb1499c53c28e217c338b47a813af8d815e7ce97/pillow-11.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:54ce1c9a16a9561b6d6d8cb30089ab1e5eb66918cb47d457bd996ef34182922e", size = 4279789 }, + { url = "https://files.pythonhosted.org/packages/6f/9a/9f139d9e8cccd661c3efbf6898967a9a337eb2e9be2b454ba0a09533100d/pillow-11.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:73ddde795ee9b06257dac5ad42fcb07f3b9b813f8c1f7f870f402f4dc54b5269", size = 4413131 }, + { url = "https://files.pythonhosted.org/packages/a8/68/0d8d461f42a3f37432203c8e6df94da10ac8081b6d35af1c203bf3111088/pillow-11.1.0-cp310-cp310-win32.whl", hash = "sha256:3a5fe20a7b66e8135d7fd617b13272626a28278d0e578c98720d9ba4b2439d49", size = 2291213 }, + { url = "https://files.pythonhosted.org/packages/14/81/d0dff759a74ba87715509af9f6cb21fa21d93b02b3316ed43bda83664db9/pillow-11.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:b6123aa4a59d75f06e9dd3dac5bf8bc9aa383121bb3dd9a7a612e05eabc9961a", size = 2625725 }, + { url = 
"https://files.pythonhosted.org/packages/ce/1f/8d50c096a1d58ef0584ddc37e6f602828515219e9d2428e14ce50f5ecad1/pillow-11.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:a76da0a31da6fcae4210aa94fd779c65c75786bc9af06289cd1c184451ef7a65", size = 2375213 }, + { url = "https://files.pythonhosted.org/packages/dd/d6/2000bfd8d5414fb70cbbe52c8332f2283ff30ed66a9cde42716c8ecbe22c/pillow-11.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:e06695e0326d05b06833b40b7ef477e475d0b1ba3a6d27da1bb48c23209bf457", size = 3229968 }, + { url = "https://files.pythonhosted.org/packages/d9/45/3fe487010dd9ce0a06adf9b8ff4f273cc0a44536e234b0fad3532a42c15b/pillow-11.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96f82000e12f23e4f29346e42702b6ed9a2f2fea34a740dd5ffffcc8c539eb35", size = 3101806 }, + { url = "https://files.pythonhosted.org/packages/e3/72/776b3629c47d9d5f1c160113158a7a7ad177688d3a1159cd3b62ded5a33a/pillow-11.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3cd561ded2cf2bbae44d4605837221b987c216cff94f49dfeed63488bb228d2", size = 4322283 }, + { url = "https://files.pythonhosted.org/packages/e4/c2/e25199e7e4e71d64eeb869f5b72c7ddec70e0a87926398785ab944d92375/pillow-11.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f189805c8be5ca5add39e6f899e6ce2ed824e65fb45f3c28cb2841911da19070", size = 4402945 }, + { url = "https://files.pythonhosted.org/packages/c1/ed/51d6136c9d5911f78632b1b86c45241c712c5a80ed7fa7f9120a5dff1eba/pillow-11.1.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:dd0052e9db3474df30433f83a71b9b23bd9e4ef1de13d92df21a52c0303b8ab6", size = 4361228 }, + { url = "https://files.pythonhosted.org/packages/48/a4/fbfe9d5581d7b111b28f1d8c2762dee92e9821bb209af9fa83c940e507a0/pillow-11.1.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:837060a8599b8f5d402e97197d4924f05a2e0d68756998345c829c33186217b1", size = 4484021 }, + { url = 
"https://files.pythonhosted.org/packages/39/db/0b3c1a5018117f3c1d4df671fb8e47d08937f27519e8614bbe86153b65a5/pillow-11.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:aa8dd43daa836b9a8128dbe7d923423e5ad86f50a7a14dc688194b7be5c0dea2", size = 4287449 }, + { url = "https://files.pythonhosted.org/packages/d9/58/bc128da7fea8c89fc85e09f773c4901e95b5936000e6f303222490c052f3/pillow-11.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0a2f91f8a8b367e7a57c6e91cd25af510168091fb89ec5146003e424e1558a96", size = 4419972 }, + { url = "https://files.pythonhosted.org/packages/5f/bb/58f34379bde9fe197f51841c5bbe8830c28bbb6d3801f16a83b8f2ad37df/pillow-11.1.0-cp311-cp311-win32.whl", hash = "sha256:c12fc111ef090845de2bb15009372175d76ac99969bdf31e2ce9b42e4b8cd88f", size = 2291201 }, + { url = "https://files.pythonhosted.org/packages/3a/c6/fce9255272bcf0c39e15abd2f8fd8429a954cf344469eaceb9d0d1366913/pillow-11.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:fbd43429d0d7ed6533b25fc993861b8fd512c42d04514a0dd6337fb3ccf22761", size = 2625686 }, + { url = "https://files.pythonhosted.org/packages/c8/52/8ba066d569d932365509054859f74f2a9abee273edcef5cd75e4bc3e831e/pillow-11.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:f7955ecf5609dee9442cbface754f2c6e541d9e6eda87fad7f7a989b0bdb9d71", size = 2375194 }, + { url = "https://files.pythonhosted.org/packages/fa/c5/389961578fb677b8b3244fcd934f720ed25a148b9a5cc81c91bdf59d8588/pillow-11.1.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:8c730dc3a83e5ac137fbc92dfcfe1511ce3b2b5d7578315b63dbbb76f7f51d90", size = 3198345 }, + { url = "https://files.pythonhosted.org/packages/c4/fa/803c0e50ffee74d4b965229e816af55276eac1d5806712de86f9371858fd/pillow-11.1.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:7d33d2fae0e8b170b6a6c57400e077412240f6f5bb2a342cf1ee512a787942bb", size = 3072938 }, + { url = 
"https://files.pythonhosted.org/packages/dc/67/2a3a5f8012b5d8c63fe53958ba906c1b1d0482ebed5618057ef4d22f8076/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a8d65b38173085f24bc07f8b6c505cbb7418009fa1a1fcb111b1f4961814a442", size = 3400049 }, + { url = "https://files.pythonhosted.org/packages/e5/a0/514f0d317446c98c478d1872497eb92e7cde67003fed74f696441e647446/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:015c6e863faa4779251436db398ae75051469f7c903b043a48f078e437656f83", size = 3422431 }, + { url = "https://files.pythonhosted.org/packages/cd/00/20f40a935514037b7d3f87adfc87d2c538430ea625b63b3af8c3f5578e72/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:d44ff19eea13ae4acdaaab0179fa68c0c6f2f45d66a4d8ec1eda7d6cecbcc15f", size = 3446208 }, + { url = "https://files.pythonhosted.org/packages/28/3c/7de681727963043e093c72e6c3348411b0185eab3263100d4490234ba2f6/pillow-11.1.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:d3d8da4a631471dfaf94c10c85f5277b1f8e42ac42bade1ac67da4b4a7359b73", size = 3509746 }, + { url = "https://files.pythonhosted.org/packages/41/67/936f9814bdd74b2dfd4822f1f7725ab5d8ff4103919a1664eb4874c58b2f/pillow-11.1.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:4637b88343166249fe8aa94e7c4a62a180c4b3898283bb5d3d2fd5fe10d8e4e0", size = 2626353 }, +] + +[[package]] +name = "pip" +version = "25.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/47/3e/68beeeeb306ea20ffd30b3ed993f531d16cd884ec4f60c9b1e238f69f2af/pip-25.0.tar.gz", hash = "sha256:8e0a97f7b4c47ae4a494560da84775e9e2f671d415d8d828e052efefb206b30b", size = 1950328 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/8a/1ddf40be20103bcc605db840e9ade09c8e8c9f920a03e9cfe88eae97a058/pip-25.0-py3-none-any.whl", hash = 
"sha256:b6eb97a803356a52b2dd4bb73ba9e65b2ba16caa6bcb25a7497350a4e5859b65", size = 1841506 }, +] + +[[package]] +name = "platformdirs" +version = "4.3.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/13/fc/128cc9cb8f03208bdbf93d3aa862e16d376844a14f9a0ce5cf4507372de4/platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907", size = 21302 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/a6/bc1012356d8ece4d66dd75c4b9fc6c1f6650ddd5991e421177d9f8f671be/platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb", size = 18439 }, +] + +[[package]] +name = "pluggy" +version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, +] + +[[package]] +name = "ply" +version = "3.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e5/69/882ee5c9d017149285cab114ebeab373308ef0f874fcdac9beb90e0ac4da/ply-3.11.tar.gz", hash = "sha256:00c7c1aaa88358b9c765b6d3000c6eec0ba42abca5351b095321aef446081da3", size = 159130 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/58/35da89ee790598a0700ea49b2a66594140f44dec458c07e8e3d4979137fc/ply-3.11-py2.py3-none-any.whl", hash = "sha256:096f9b8350b65ebd2fd1346b12452efe5b9607f7482813ffca50c22722a807ce", size = 49567 }, +] + +[[package]] +name = "pre-commit" +version = "4.1.0" +source = 
{ registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cfgv" }, + { name = "identify" }, + { name = "nodeenv" }, + { name = "pyyaml" }, + { name = "virtualenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2a/13/b62d075317d8686071eb843f0bb1f195eb332f48869d3c31a4c6f1e063ac/pre_commit-4.1.0.tar.gz", hash = "sha256:ae3f018575a588e30dfddfab9a05448bfbd6b73d78709617b5a2b853549716d4", size = 193330 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/b3/df14c580d82b9627d173ceea305ba898dca135feb360b6d84019d0803d3b/pre_commit-4.1.0-py2.py3-none-any.whl", hash = "sha256:d29e7cb346295bcc1cc75fc3e92e343495e3ea0196c9ec6ba53f49f10ab6ae7b", size = 220560 }, +] + +[[package]] +name = "prompt-toolkit" +version = "3.0.50" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wcwidth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/e1/bd15cb8ffdcfeeb2bdc215de3c3cffca11408d829e4b8416dcfe71ba8854/prompt_toolkit-3.0.50.tar.gz", hash = "sha256:544748f3860a2623ca5cd6d2795e7a14f3d0e1c3c9728359013f79877fc89bab", size = 429087 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e4/ea/d836f008d33151c7a1f62caf3d8dd782e4d15f6a43897f64480c2b8de2ad/prompt_toolkit-3.0.50-py3-none-any.whl", hash = "sha256:9b6427eb19e479d98acff65196a307c555eb567989e6d88ebbb1b509d9779198", size = 387816 }, +] + +[[package]] +name = "psutil" +version = "6.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1f/5a/07871137bb752428aa4b659f910b399ba6f291156bdea939be3e96cae7cb/psutil-6.1.1.tar.gz", hash = "sha256:cf8496728c18f2d0b45198f06895be52f36611711746b7f30c464b422b50e2f5", size = 508502 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/99/ca79d302be46f7bdd8321089762dd4476ee725fce16fc2b2e1dbba8cac17/psutil-6.1.1-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:fc0ed7fe2231a444fc219b9c42d0376e0a9a1a72f16c5cfa0f68d19f1a0663e8", size = 
247511 }, + { url = "https://files.pythonhosted.org/packages/0b/6b/73dbde0dd38f3782905d4587049b9be64d76671042fdcaf60e2430c6796d/psutil-6.1.1-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:0bdd4eab935276290ad3cb718e9809412895ca6b5b334f5a9111ee6d9aff9377", size = 248985 }, + { url = "https://files.pythonhosted.org/packages/17/38/c319d31a1d3f88c5b79c68b3116c129e5133f1822157dd6da34043e32ed6/psutil-6.1.1-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b6e06c20c05fe95a3d7302d74e7097756d4ba1247975ad6905441ae1b5b66003", size = 284488 }, + { url = "https://files.pythonhosted.org/packages/9c/39/0f88a830a1c8a3aba27fededc642da37613c57cbff143412e3536f89784f/psutil-6.1.1-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97f7cb9921fbec4904f522d972f0c0e1f4fabbdd4e0287813b21215074a0f160", size = 287477 }, + { url = "https://files.pythonhosted.org/packages/47/da/99f4345d4ddf2845cb5b5bd0d93d554e84542d116934fde07a0c50bd4e9f/psutil-6.1.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:33431e84fee02bc84ea36d9e2c4a6d395d479c9dd9bba2376c1f6ee8f3a4e0b3", size = 289017 }, + { url = "https://files.pythonhosted.org/packages/38/53/bd755c2896f4461fd4f36fa6a6dcb66a88a9e4b9fd4e5b66a77cf9d4a584/psutil-6.1.1-cp37-abi3-win32.whl", hash = "sha256:eaa912e0b11848c4d9279a93d7e2783df352b082f40111e078388701fd479e53", size = 250602 }, + { url = "https://files.pythonhosted.org/packages/7b/d7/7831438e6c3ebbfa6e01a927127a6cb42ad3ab844247f3c5b96bea25d73d/psutil-6.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649", size = 254444 }, +] + +[[package]] +name = "ptyprocess" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/e5/16ff212c1e452235a90aeb09066144d0c5a6a8c0834397e03f5224495c4e/ptyprocess-0.7.0.tar.gz", hash = 
"sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220", size = 70762 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35", size = 13993 }, +] + +[[package]] +name = "pure-eval" +version = "0.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/05/0a34433a064256a578f1783a10da6df098ceaa4a57bbeaa96a6c0352786b/pure_eval-0.2.3.tar.gz", hash = "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42", size = 19752 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842 }, +] + +[[package]] +name = "py-cpuinfo" +version = "9.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/37/a8/d832f7293ebb21690860d2e01d8115e5ff6f2ae8bbdc953f0eb0fa4bd2c7/py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690", size = 104716 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335 }, +] + +[[package]] +name = "pybind11" +version = "2.13.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d2/c1/72b9622fcb32ff98b054f724e213c7f70d6898baa714f4516288456ceaba/pybind11-2.13.6.tar.gz", hash = "sha256:ba6af10348c12b24e92fa086b39cfba0eff619b61ac77c406167d813b096d39a", size = 218403 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/13/2f/0f24b288e2ce56f51c920137620b4434a38fd80583dbbe24fc2a1656c388/pybind11-2.13.6-py3-none-any.whl", hash = "sha256:237c41e29157b962835d356b370ededd57594a26d5894a795960f0047cb5caf5", size = 243282 }, +] + +[[package]] +name = "pycparser" +version = "2.22" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1d/b2/31537cf4b1ca988837256c910a668b553fceb8f069bedc4b1c826024b52c/pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6", size = 172736 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/13/a3/a812df4e2dd5696d1f351d58b8fe16a405b234ad2886a0dab9183fb78109/pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc", size = 117552 }, +] + +[[package]] +name = "pydantic" +version = "2.10.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b7/ae/d5220c5c52b158b1de7ca89fc5edb72f304a70a4c540c84c8844bf4008de/pydantic-2.10.6.tar.gz", hash = "sha256:ca5daa827cce33de7a42be142548b0096bf05a7e7b365aebfa5f8eeec7128236", size = 761681 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/3c/8cc1cc84deffa6e25d2d0c688ebb80635dfdbf1dbea3e30c541c8cf4d860/pydantic-2.10.6-py3-none-any.whl", hash = "sha256:427d664bf0b8a2b34ff5dd0f5a18df00591adcee7198fbd71981054cef37b584", size = 431696 }, +] + +[[package]] +name = "pydantic-core" +version = "2.27.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fc/01/f3e5ac5e7c25833db5eb555f7b7ab24cd6f8c322d3a3ad2d67a952dc0abc/pydantic_core-2.27.2.tar.gz", hash = 
"sha256:eb026e5a4c1fee05726072337ff51d1efb6f59090b7da90d30ea58625b1ffb39", size = 413443 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3a/bc/fed5f74b5d802cf9a03e83f60f18864e90e3aed7223adaca5ffb7a8d8d64/pydantic_core-2.27.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2d367ca20b2f14095a8f4fa1210f5a7b78b8a20009ecced6b12818f455b1e9fa", size = 1895938 }, + { url = "https://files.pythonhosted.org/packages/71/2a/185aff24ce844e39abb8dd680f4e959f0006944f4a8a0ea372d9f9ae2e53/pydantic_core-2.27.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:491a2b73db93fab69731eaee494f320faa4e093dbed776be1a829c2eb222c34c", size = 1815684 }, + { url = "https://files.pythonhosted.org/packages/c3/43/fafabd3d94d159d4f1ed62e383e264f146a17dd4d48453319fd782e7979e/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7969e133a6f183be60e9f6f56bfae753585680f3b7307a8e555a948d443cc05a", size = 1829169 }, + { url = "https://files.pythonhosted.org/packages/a2/d1/f2dfe1a2a637ce6800b799aa086d079998959f6f1215eb4497966efd2274/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3de9961f2a346257caf0aa508a4da705467f53778e9ef6fe744c038119737ef5", size = 1867227 }, + { url = "https://files.pythonhosted.org/packages/7d/39/e06fcbcc1c785daa3160ccf6c1c38fea31f5754b756e34b65f74e99780b5/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e2bb4d3e5873c37bb3dd58714d4cd0b0e6238cebc4177ac8fe878f8b3aa8e74c", size = 2037695 }, + { url = "https://files.pythonhosted.org/packages/7a/67/61291ee98e07f0650eb756d44998214231f50751ba7e13f4f325d95249ab/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:280d219beebb0752699480fe8f1dc61ab6615c2046d76b7ab7ee38858de0a4e7", size = 2741662 }, + { url = 
"https://files.pythonhosted.org/packages/32/90/3b15e31b88ca39e9e626630b4c4a1f5a0dfd09076366f4219429e6786076/pydantic_core-2.27.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47956ae78b6422cbd46f772f1746799cbb862de838fd8d1fbd34a82e05b0983a", size = 1993370 }, + { url = "https://files.pythonhosted.org/packages/ff/83/c06d333ee3a67e2e13e07794995c1535565132940715931c1c43bfc85b11/pydantic_core-2.27.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:14d4a5c49d2f009d62a2a7140d3064f686d17a5d1a268bc641954ba181880236", size = 1996813 }, + { url = "https://files.pythonhosted.org/packages/7c/f7/89be1c8deb6e22618a74f0ca0d933fdcb8baa254753b26b25ad3acff8f74/pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:337b443af21d488716f8d0b6164de833e788aa6bd7e3a39c005febc1284f4962", size = 2005287 }, + { url = "https://files.pythonhosted.org/packages/b7/7d/8eb3e23206c00ef7feee17b83a4ffa0a623eb1a9d382e56e4aa46fd15ff2/pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:03d0f86ea3184a12f41a2d23f7ccb79cdb5a18e06993f8a45baa8dfec746f0e9", size = 2128414 }, + { url = "https://files.pythonhosted.org/packages/4e/99/fe80f3ff8dd71a3ea15763878d464476e6cb0a2db95ff1c5c554133b6b83/pydantic_core-2.27.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7041c36f5680c6e0f08d922aed302e98b3745d97fe1589db0a3eebf6624523af", size = 2155301 }, + { url = "https://files.pythonhosted.org/packages/2b/a3/e50460b9a5789ca1451b70d4f52546fa9e2b420ba3bfa6100105c0559238/pydantic_core-2.27.2-cp310-cp310-win32.whl", hash = "sha256:50a68f3e3819077be2c98110c1f9dcb3817e93f267ba80a2c05bb4f8799e2ff4", size = 1816685 }, + { url = "https://files.pythonhosted.org/packages/57/4c/a8838731cb0f2c2a39d3535376466de6049034d7b239c0202a64aaa05533/pydantic_core-2.27.2-cp310-cp310-win_amd64.whl", hash = "sha256:e0fd26b16394ead34a424eecf8a31a1f5137094cabe84a1bcb10fa6ba39d3d31", size = 1982876 }, + { url = 
"https://files.pythonhosted.org/packages/c2/89/f3450af9d09d44eea1f2c369f49e8f181d742f28220f88cc4dfaae91ea6e/pydantic_core-2.27.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:8e10c99ef58cfdf2a66fc15d66b16c4a04f62bca39db589ae8cba08bc55331bc", size = 1893421 }, + { url = "https://files.pythonhosted.org/packages/9e/e3/71fe85af2021f3f386da42d291412e5baf6ce7716bd7101ea49c810eda90/pydantic_core-2.27.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:26f32e0adf166a84d0cb63be85c562ca8a6fa8de28e5f0d92250c6b7e9e2aff7", size = 1814998 }, + { url = "https://files.pythonhosted.org/packages/a6/3c/724039e0d848fd69dbf5806894e26479577316c6f0f112bacaf67aa889ac/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c19d1ea0673cd13cc2f872f6c9ab42acc4e4f492a7ca9d3795ce2b112dd7e15", size = 1826167 }, + { url = "https://files.pythonhosted.org/packages/2b/5b/1b29e8c1fb5f3199a9a57c1452004ff39f494bbe9bdbe9a81e18172e40d3/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5e68c4446fe0810e959cdff46ab0a41ce2f2c86d227d96dc3847af0ba7def306", size = 1865071 }, + { url = "https://files.pythonhosted.org/packages/89/6c/3985203863d76bb7d7266e36970d7e3b6385148c18a68cc8915fd8c84d57/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d9640b0059ff4f14d1f37321b94061c6db164fbe49b334b31643e0528d100d99", size = 2036244 }, + { url = "https://files.pythonhosted.org/packages/0e/41/f15316858a246b5d723f7d7f599f79e37493b2e84bfc789e58d88c209f8a/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:40d02e7d45c9f8af700f3452f329ead92da4c5f4317ca9b896de7ce7199ea459", size = 2737470 }, + { url = "https://files.pythonhosted.org/packages/a8/7c/b860618c25678bbd6d1d99dbdfdf0510ccb50790099b963ff78a124b754f/pydantic_core-2.27.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:1c1fd185014191700554795c99b347d64f2bb637966c4cfc16998a0ca700d048", size = 1992291 }, + { url = "https://files.pythonhosted.org/packages/bf/73/42c3742a391eccbeab39f15213ecda3104ae8682ba3c0c28069fbcb8c10d/pydantic_core-2.27.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d81d2068e1c1228a565af076598f9e7451712700b673de8f502f0334f281387d", size = 1994613 }, + { url = "https://files.pythonhosted.org/packages/94/7a/941e89096d1175d56f59340f3a8ebaf20762fef222c298ea96d36a6328c5/pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1a4207639fb02ec2dbb76227d7c751a20b1a6b4bc52850568e52260cae64ca3b", size = 2002355 }, + { url = "https://files.pythonhosted.org/packages/6e/95/2359937a73d49e336a5a19848713555605d4d8d6940c3ec6c6c0ca4dcf25/pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:3de3ce3c9ddc8bbd88f6e0e304dea0e66d843ec9de1b0042b0911c1663ffd474", size = 2126661 }, + { url = "https://files.pythonhosted.org/packages/2b/4c/ca02b7bdb6012a1adef21a50625b14f43ed4d11f1fc237f9d7490aa5078c/pydantic_core-2.27.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:30c5f68ded0c36466acede341551106821043e9afaad516adfb6e8fa80a4e6a6", size = 2153261 }, + { url = "https://files.pythonhosted.org/packages/72/9d/a241db83f973049a1092a079272ffe2e3e82e98561ef6214ab53fe53b1c7/pydantic_core-2.27.2-cp311-cp311-win32.whl", hash = "sha256:c70c26d2c99f78b125a3459f8afe1aed4d9687c24fd677c6a4436bc042e50d6c", size = 1812361 }, + { url = "https://files.pythonhosted.org/packages/e8/ef/013f07248041b74abd48a385e2110aa3a9bbfef0fbd97d4e6d07d2f5b89a/pydantic_core-2.27.2-cp311-cp311-win_amd64.whl", hash = "sha256:08e125dbdc505fa69ca7d9c499639ab6407cfa909214d500897d02afb816e7cc", size = 1982484 }, + { url = "https://files.pythonhosted.org/packages/10/1c/16b3a3e3398fd29dca77cea0a1d998d6bde3902fa2706985191e2313cc76/pydantic_core-2.27.2-cp311-cp311-win_arm64.whl", hash = "sha256:26f0d68d4b235a2bae0c3fc585c585b4ecc51382db0e3ba402a22cbc440915e4", 
size = 1867102 }, + { url = "https://files.pythonhosted.org/packages/46/72/af70981a341500419e67d5cb45abe552a7c74b66326ac8877588488da1ac/pydantic_core-2.27.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:2bf14caea37e91198329b828eae1618c068dfb8ef17bb33287a7ad4b61ac314e", size = 1891159 }, + { url = "https://files.pythonhosted.org/packages/ad/3d/c5913cccdef93e0a6a95c2d057d2c2cba347815c845cda79ddd3c0f5e17d/pydantic_core-2.27.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:b0cb791f5b45307caae8810c2023a184c74605ec3bcbb67d13846c28ff731ff8", size = 1768331 }, + { url = "https://files.pythonhosted.org/packages/f6/f0/a3ae8fbee269e4934f14e2e0e00928f9346c5943174f2811193113e58252/pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:688d3fd9fcb71f41c4c015c023d12a79d1c4c0732ec9eb35d96e3388a120dcf3", size = 1822467 }, + { url = "https://files.pythonhosted.org/packages/d7/7a/7bbf241a04e9f9ea24cd5874354a83526d639b02674648af3f350554276c/pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d591580c34f4d731592f0e9fe40f9cc1b430d297eecc70b962e93c5c668f15f", size = 1979797 }, + { url = "https://files.pythonhosted.org/packages/4f/5f/4784c6107731f89e0005a92ecb8a2efeafdb55eb992b8e9d0a2be5199335/pydantic_core-2.27.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:82f986faf4e644ffc189a7f1aafc86e46ef70372bb153e7001e8afccc6e54133", size = 1987839 }, + { url = "https://files.pythonhosted.org/packages/6d/a7/61246562b651dff00de86a5f01b6e4befb518df314c54dec187a78d81c84/pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:bec317a27290e2537f922639cafd54990551725fc844249e64c523301d0822fc", size = 1998861 }, + { url = "https://files.pythonhosted.org/packages/86/aa/837821ecf0c022bbb74ca132e117c358321e72e7f9702d1b6a03758545e2/pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = 
"sha256:0296abcb83a797db256b773f45773da397da75a08f5fcaef41f2044adec05f50", size = 2116582 }, + { url = "https://files.pythonhosted.org/packages/81/b0/5e74656e95623cbaa0a6278d16cf15e10a51f6002e3ec126541e95c29ea3/pydantic_core-2.27.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:0d75070718e369e452075a6017fbf187f788e17ed67a3abd47fa934d001863d9", size = 2151985 }, + { url = "https://files.pythonhosted.org/packages/63/37/3e32eeb2a451fddaa3898e2163746b0cffbbdbb4740d38372db0490d67f3/pydantic_core-2.27.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:7e17b560be3c98a8e3aa66ce828bdebb9e9ac6ad5466fba92eb74c4c95cb1151", size = 2004715 }, +] + +[[package]] +name = "pydantic-settings" +version = "2.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "python-dotenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/73/7b/c58a586cd7d9ac66d2ee4ba60ca2d241fa837c02bca9bea80a9a8c3d22a9/pydantic_settings-2.7.1.tar.gz", hash = "sha256:10c9caad35e64bfb3c2fbf70a078c0e25cc92499782e5200747f942a065dec93", size = 79920 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b4/46/93416fdae86d40879714f72956ac14df9c7b76f7d41a4d68aa9f71a0028b/pydantic_settings-2.7.1-py3-none-any.whl", hash = "sha256:590be9e6e24d06db33a4262829edef682500ef008565a969c73d39d5f8bfb3fd", size = 29718 }, +] + +[[package]] +name = "pydot" +version = "3.0.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyparsing" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/dd/e0e6a4fb84c22050f6a9701ad9fd6a67ef82faa7ba97b97eb6fdc6b49b34/pydot-3.0.4.tar.gz", hash = "sha256:3ce88b2558f3808b0376f22bfa6c263909e1c3981e2a7b629b65b451eee4a25d", size = 168167 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b0/5f/1ebfd430df05c4f9e438dd3313c4456eab937d976f6ab8ce81a98f9fb381/pydot-3.0.4-py3-none-any.whl", hash = 
"sha256:bfa9c3fc0c44ba1d132adce131802d7df00429d1a79cc0346b0a5cd374dbe9c6", size = 35776 }, +] + +[[package]] +name = "pygls" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cattrs" }, + { name = "lsprotocol" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/b9/41d173dad9eaa9db9c785a85671fc3d68961f08d67706dc2e79011e10b5c/pygls-1.3.1.tar.gz", hash = "sha256:140edceefa0da0e9b3c533547c892a42a7d2fd9217ae848c330c53d266a55018", size = 45527 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/19/b74a10dd24548e96e8c80226cbacb28b021bc3a168a7d2709fb0d0185348/pygls-1.3.1-py3-none-any.whl", hash = "sha256:6e00f11efc56321bdeb6eac04f6d86131f654c7d49124344a9ebb968da3dd91e", size = 56031 }, +] + +[[package]] +name = "pygments" +version = "2.19.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7c/2d/c3338d48ea6cc0feb8446d8e6937e1408088a72a39937982cc6111d17f84/pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f", size = 4968581 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/0b/9fcc47d19c48b59121088dd6da2488a49d5f72dacf8262e2790a1d2c7d15/pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c", size = 1225293 }, +] + +[[package]] +name = "pyparsing" +version = "3.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/1a/3544f4f299a47911c2ab3710f534e52fea62a633c96806995da5d25be4b2/pyparsing-3.2.1.tar.gz", hash = "sha256:61980854fd66de3a90028d679a954d5f2623e83144b5afe5ee86f43d762e5f0a", size = 1067694 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1c/a7/c8a2d361bf89c0d9577c934ebb7421b25dc84bf3a8e3ac0a40aed9acc547/pyparsing-3.2.1-py3-none-any.whl", hash = "sha256:506ff4f4386c4cec0590ec19e6302d3aedb992fdc02c761e90416f158dacf8e1", size 
= 107716 }, +] + +[[package]] +name = "pyreadline" +version = "2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bc/7c/d724ef1ec3ab2125f38a1d53285745445ec4a8f19b9bb0761b4064316679/pyreadline-2.1.zip", hash = "sha256:4530592fc2e85b25b1a9f79664433da09237c1a270e4d78ea5aa3a2c7229e2d1", size = 109189 } + +[[package]] +name = "pyspellchecker" +version = "0.8.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/42/5d/86d94aceb9c0813f27004ec71c036d8ec6a6324d989854ff0fe13fe036dc/pyspellchecker-0.8.2.tar.gz", hash = "sha256:2b026be14a162ba810bdda8e5454c56e364f42d3b9e14aeff31706e5ebcdc78f", size = 7149207 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/8e/7c79443d302a80cfd59bc365938d51e36e7e9aa7ce8ab1d8a0ca0c8e6065/pyspellchecker-0.8.2-py3-none-any.whl", hash = "sha256:4fee22e1859c5153c3bc3953ac3041bf07d4541520b7e01901e955062022290a", size = 7147898 }, +] + +[[package]] +name = "pytest" +version = "8.3.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 
'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/05/35/30e0d83068951d90a01852cb1cef56e5d8a09d20c7f511634cc2f7e0372a/pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761", size = 1445919 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 }, +] + +[[package]] +name = "pytest-benchmark" +version = "5.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "py-cpuinfo" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/39/d0/a8bd08d641b393db3be3819b03e2d9bb8760ca8479080a26a5f6e540e99c/pytest-benchmark-5.1.0.tar.gz", hash = "sha256:9ea661cdc292e8231f7cd4c10b0319e56a2118e2c09d9f50e1b3d150d2aca105", size = 337810 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/d6/b41653199ea09d5969d4e385df9bbfd9a100f28ca7e824ce7c0a016e3053/pytest_benchmark-5.1.0-py3-none-any.whl", hash = "sha256:922de2dfa3033c227c96da942d1878191afa135a29485fb942e85dff1c592c89", size = 44259 }, +] 
+ +[[package]] +name = "pytest-cache" +version = "1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "execnet" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d1/15/082fd0428aab33d2bafa014f3beb241830427ba803a8912a5aaeaf3a5663/pytest-cache-1.0.tar.gz", hash = "sha256:be7468edd4d3d83f1e844959fd6e3fd28e77a481440a7118d430130ea31b07a9", size = 16242 } + +[[package]] +name = "pytest-cov" +version = "6.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "coverage", extra = ["toml"] }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/be/45/9b538de8cef30e17c7b45ef42f538a94889ed6a16f2387a6c89e73220651/pytest-cov-6.0.0.tar.gz", hash = "sha256:fde0b595ca248bb8e2d76f020b465f3b107c9632e6a1d1705f17834c89dcadc0", size = 66945 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/36/3b/48e79f2cd6a61dbbd4807b4ed46cb564b4fd50a76166b1c4ea5c1d9e2371/pytest_cov-6.0.0-py3-none-any.whl", hash = "sha256:eee6f1b9e61008bd34975a4d5bab25801eb31898b032dd55addc93e96fcaaa35", size = 22949 }, +] + +[[package]] +name = "pytest-factoryboy" +version = "2.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "factory-boy" }, + { name = "inflection" }, + { name = "packaging" }, + { name = "pytest" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a6/bc/179653e8cce651575ac95377e4fdf9afd3c4821ab4bba101aae913ebcc27/pytest_factoryboy-2.7.0.tar.gz", hash = "sha256:67fc54ec8669a3feb8ac60094dd57cd71eb0b20b2c319d2957873674c776a77b", size = 17398 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/56/d3ef25286dc8df9d1da0b325ee4b1b1ffd9736e44f9b30cfbe464e9f4f14/pytest_factoryboy-2.7.0-py3-none-any.whl", hash = "sha256:bf3222db22d954fbf46f4bff902a0a8d82f3fc3594a47c04bbdc0546ff4c59a6", size = 16268 }, +] + +[[package]] +name = "pytest-instafail" +version = "0.5.0" +source 
= { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/bd/e0ba6c3cd20b9aa445f0af229f3a9582cce589f083537978a23e6f14e310/pytest-instafail-0.5.0.tar.gz", hash = "sha256:33a606f7e0c8e646dc3bfee0d5e3a4b7b78ef7c36168cfa1f3d93af7ca706c9e", size = 5849 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e8/c0/c32dc39fc172e684fdb3d30169843efb65c067be1e12689af4345731126e/pytest_instafail-0.5.0-py3-none-any.whl", hash = "sha256:6855414487e9e4bb76a118ce952c3c27d3866af15487506c4ded92eb72387819", size = 4176 }, +] + +[[package]] +name = "pytest-xdist" +version = "3.6.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "execnet" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/41/c4/3c310a19bc1f1e9ef50075582652673ef2bfc8cd62afef9585683821902f/pytest_xdist-3.6.1.tar.gz", hash = "sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d", size = 84060 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6d/82/1d96bf03ee4c0fdc3c0cbe61470070e659ca78dc0086fb88b66c185e2449/pytest_xdist-3.6.1-py3-none-any.whl", hash = "sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7", size = 46108 }, +] + +[package.optional-dependencies] +psutil = [ + { name = "psutil" }, +] + +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = 
"sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892 }, +] + +[[package]] +name = "python-dotenv" +version = "1.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bc/57/e84d88dfe0aec03b7a2d4327012c1627ab5f03652216c63d49846d7a6c58/python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca", size = 39115 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863 }, +] + +[[package]] +name = "pywin32" +version = "308" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/a6/3e9f2c474895c1bb61b11fa9640be00067b5c5b363c501ee9c3fa53aec01/pywin32-308-cp310-cp310-win32.whl", hash = "sha256:796ff4426437896550d2981b9c2ac0ffd75238ad9ea2d3bfa67a1abd546d262e", size = 5927028 }, + { url = "https://files.pythonhosted.org/packages/d9/b4/84e2463422f869b4b718f79eb7530a4c1693e96b8a4e5e968de38be4d2ba/pywin32-308-cp310-cp310-win_amd64.whl", hash = "sha256:4fc888c59b3c0bef905ce7eb7e2106a07712015ea1c8234b703a088d46110e8e", size = 6558484 }, + { url = "https://files.pythonhosted.org/packages/9f/8f/fb84ab789713f7c6feacaa08dad3ec8105b88ade8d1c4f0f0dfcaaa017d6/pywin32-308-cp310-cp310-win_arm64.whl", hash = "sha256:a5ab5381813b40f264fa3495b98af850098f814a25a63589a8e9eb12560f450c", size = 7971454 }, + { url = "https://files.pythonhosted.org/packages/eb/e2/02652007469263fe1466e98439831d65d4ca80ea1a2df29abecedf7e47b7/pywin32-308-cp311-cp311-win32.whl", hash = "sha256:5d8c8015b24a7d6855b1550d8e660d8daa09983c80e5daf89a273e5c6fb5095a", size = 5928156 }, + { url = 
"https://files.pythonhosted.org/packages/48/ef/f4fb45e2196bc7ffe09cad0542d9aff66b0e33f6c0954b43e49c33cad7bd/pywin32-308-cp311-cp311-win_amd64.whl", hash = "sha256:575621b90f0dc2695fec346b2d6302faebd4f0f45c05ea29404cefe35d89442b", size = 6559559 }, + { url = "https://files.pythonhosted.org/packages/79/ef/68bb6aa865c5c9b11a35771329e95917b5559845bd75b65549407f9fc6b4/pywin32-308-cp311-cp311-win_arm64.whl", hash = "sha256:100a5442b7332070983c4cd03f2e906a5648a5104b8a7f50175f7906efd16bb6", size = 7972495 }, +] + +[[package]] +name = "pyyaml" +version = "6.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/54/ed/79a089b6be93607fa5cdaedf301d7dfb23af5f25c398d5ead2525b063e17/pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e", size = 130631 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/95/a3fac87cb7158e231b5a6012e438c647e1a87f09f8e0d123acec8ab8bf71/PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086", size = 184199 }, + { url = "https://files.pythonhosted.org/packages/c7/7a/68bd47624dab8fd4afbfd3c48e3b79efe09098ae941de5b58abcbadff5cb/PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf", size = 171758 }, + { url = "https://files.pythonhosted.org/packages/49/ee/14c54df452143b9ee9f0f29074d7ca5516a36edb0b4cc40c3f280131656f/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237", size = 718463 }, + { url = "https://files.pythonhosted.org/packages/4d/61/de363a97476e766574650d742205be468921a7b532aa2499fcd886b62530/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b", size = 719280 }, + { url = 
"https://files.pythonhosted.org/packages/6b/4e/1523cb902fd98355e2e9ea5e5eb237cbc5f3ad5f3075fa65087aa0ecb669/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed", size = 751239 }, + { url = "https://files.pythonhosted.org/packages/b7/33/5504b3a9a4464893c32f118a9cc045190a91637b119a9c881da1cf6b7a72/PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180", size = 695802 }, + { url = "https://files.pythonhosted.org/packages/5c/20/8347dcabd41ef3a3cdc4f7b7a2aff3d06598c8779faa189cdbf878b626a4/PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68", size = 720527 }, + { url = "https://files.pythonhosted.org/packages/be/aa/5afe99233fb360d0ff37377145a949ae258aaab831bde4792b32650a4378/PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99", size = 144052 }, + { url = "https://files.pythonhosted.org/packages/b5/84/0fa4b06f6d6c958d207620fc60005e241ecedceee58931bb20138e1e5776/PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e", size = 161774 }, + { url = "https://files.pythonhosted.org/packages/f8/aa/7af4e81f7acba21a4c6be026da38fd2b872ca46226673c89a758ebdc4fd2/PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774", size = 184612 }, + { url = "https://files.pythonhosted.org/packages/8b/62/b9faa998fd185f65c1371643678e4d58254add437edb764a08c5a98fb986/PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee", size = 172040 }, + { url = 
"https://files.pythonhosted.org/packages/ad/0c/c804f5f922a9a6563bab712d8dcc70251e8af811fce4524d57c2c0fd49a4/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c", size = 736829 }, + { url = "https://files.pythonhosted.org/packages/51/16/6af8d6a6b210c8e54f1406a6b9481febf9c64a3109c541567e35a49aa2e7/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317", size = 764167 }, + { url = "https://files.pythonhosted.org/packages/75/e4/2c27590dfc9992f73aabbeb9241ae20220bd9452df27483b6e56d3975cc5/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85", size = 762952 }, + { url = "https://files.pythonhosted.org/packages/9b/97/ecc1abf4a823f5ac61941a9c00fe501b02ac3ab0e373c3857f7d4b83e2b6/PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4", size = 735301 }, + { url = "https://files.pythonhosted.org/packages/45/73/0f49dacd6e82c9430e46f4a027baa4ca205e8b0a9dce1397f44edc23559d/PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e", size = 756638 }, + { url = "https://files.pythonhosted.org/packages/22/5f/956f0f9fc65223a58fbc14459bf34b4cc48dec52e00535c79b8db361aabd/PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5", size = 143850 }, + { url = "https://files.pythonhosted.org/packages/ed/23/8da0bbe2ab9dcdd11f4f4557ccaf95c10b9811b13ecced089d43ce59c3c8/PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44", size = 161980 }, +] + +[[package]] +name = "pyzmq" +version = "26.2.1" +source = { registry = 
"https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "implementation_name == 'pypy' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5a/e3/8d0382cb59feb111c252b54e8728257416a38ffcb2243c4e4775a3c990fe/pyzmq-26.2.1.tar.gz", hash = "sha256:17d72a74e5e9ff3829deb72897a175333d3ef5b5413948cae3cf7ebf0b02ecca", size = 278433 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/70/3d/c2d9d46c033d1b51692ea49a22439f7f66d91d5c938e8b5c56ed7a2151c2/pyzmq-26.2.1-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:f39d1227e8256d19899d953e6e19ed2ccb689102e6d85e024da5acf410f301eb", size = 1345451 }, + { url = "https://files.pythonhosted.org/packages/0e/df/4754a8abcdeef280651f9bb51446c47659910940b392a66acff7c37f5cef/pyzmq-26.2.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a23948554c692df95daed595fdd3b76b420a4939d7a8a28d6d7dea9711878641", size = 942766 }, + { url = "https://files.pythonhosted.org/packages/74/da/e6053a3b13c912eded6c2cdeee22ff3a4c33820d17f9eb24c7b6e957ffe7/pyzmq-26.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95f5728b367a042df146cec4340d75359ec6237beebf4a8f5cf74657c65b9257", size = 678488 }, + { url = "https://files.pythonhosted.org/packages/9e/50/614934145244142401ca174ca81071777ab93aa88173973ba0154f491e09/pyzmq-26.2.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:95f7b01b3f275504011cf4cf21c6b885c8d627ce0867a7e83af1382ebab7b3ff", size = 917115 }, + { url = "https://files.pythonhosted.org/packages/80/2b/ebeb7bc4fc8e9e61650b2e09581597355a4341d413fa9b2947d7a6558119/pyzmq-26.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80a00370a2ef2159c310e662c7c0f2d030f437f35f478bb8b2f70abd07e26b24", size = 874162 }, + { url = "https://files.pythonhosted.org/packages/79/48/93210621c331ad16313dc2849801411fbae10d91d878853933f2a85df8e7/pyzmq-26.2.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:8531ed35dfd1dd2af95f5d02afd6545e8650eedbf8c3d244a554cf47d8924459", size = 874180 }, + { url = "https://files.pythonhosted.org/packages/f0/8b/40924b4d8e33bfdd54c1970fb50f327e39b90b902f897cf09b30b2e9ac48/pyzmq-26.2.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:cdb69710e462a38e6039cf17259d328f86383a06c20482cc154327968712273c", size = 1208139 }, + { url = "https://files.pythonhosted.org/packages/c8/b2/82d6675fc89bd965eae13c45002c792d33f06824589844b03f8ea8fc6d86/pyzmq-26.2.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e7eeaef81530d0b74ad0d29eec9997f1c9230c2f27242b8d17e0ee67662c8f6e", size = 1520666 }, + { url = "https://files.pythonhosted.org/packages/9d/e2/5ff15f2d3f920dcc559d477bd9bb3faacd6d79fcf7c5448e585c78f84849/pyzmq-26.2.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:361edfa350e3be1f987e592e834594422338d7174364763b7d3de5b0995b16f3", size = 1420056 }, + { url = "https://files.pythonhosted.org/packages/40/a2/f9bbeccf7f75aa0d8963e224e5730abcefbf742e1f2ae9ea60fd9d6ff72b/pyzmq-26.2.1-cp310-cp310-win32.whl", hash = "sha256:637536c07d2fb6a354988b2dd1d00d02eb5dd443f4bbee021ba30881af1c28aa", size = 583874 }, + { url = "https://files.pythonhosted.org/packages/56/b1/44f513135843272f0e12f5aebf4af35839e2a88eb45411f2c8c010d8c856/pyzmq-26.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:45fad32448fd214fbe60030aa92f97e64a7140b624290834cc9b27b3a11f9473", size = 647367 }, + { url = 
"https://files.pythonhosted.org/packages/27/9c/1bef14a37b02d651a462811bbdb1390b61cd4a5b5e95cbd7cc2d60ef848c/pyzmq-26.2.1-cp310-cp310-win_arm64.whl", hash = "sha256:d9da0289d8201c8a29fd158aaa0dfe2f2e14a181fd45e2dc1fbf969a62c1d594", size = 561784 }, + { url = "https://files.pythonhosted.org/packages/b9/03/5ecc46a6ed5971299f5c03e016ca637802d8660e44392bea774fb7797405/pyzmq-26.2.1-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:c059883840e634a21c5b31d9b9a0e2b48f991b94d60a811092bc37992715146a", size = 1346032 }, + { url = "https://files.pythonhosted.org/packages/40/51/48fec8f990ee644f461ff14c8fe5caa341b0b9b3a0ad7544f8ef17d6f528/pyzmq-26.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ed038a921df836d2f538e509a59cb638df3e70ca0fcd70d0bf389dfcdf784d2a", size = 943324 }, + { url = "https://files.pythonhosted.org/packages/c1/f4/f322b389727c687845e38470b48d7a43c18a83f26d4d5084603c6c3f79ca/pyzmq-26.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9027a7fcf690f1a3635dc9e55e38a0d6602dbbc0548935d08d46d2e7ec91f454", size = 678418 }, + { url = "https://files.pythonhosted.org/packages/a8/df/2834e3202533bd05032d83e02db7ac09fa1be853bbef59974f2b2e3a8557/pyzmq-26.2.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6d75fcb00a1537f8b0c0bb05322bc7e35966148ffc3e0362f0369e44a4a1de99", size = 915466 }, + { url = "https://files.pythonhosted.org/packages/b5/e2/45c0f6e122b562cb8c6c45c0dcac1160a4e2207385ef9b13463e74f93031/pyzmq-26.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0019cc804ac667fb8c8eaecdb66e6d4a68acf2e155d5c7d6381a5645bd93ae4", size = 873347 }, + { url = "https://files.pythonhosted.org/packages/de/b9/3e0fbddf8b87454e914501d368171466a12550c70355b3844115947d68ea/pyzmq-26.2.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:f19dae58b616ac56b96f2e2290f2d18730a898a171f447f491cc059b073ca1fa", size = 874545 }, + { url = 
"https://files.pythonhosted.org/packages/1f/1c/1ee41d6e10b2127263b1994bc53b9e74ece015b0d2c0a30e0afaf69b78b2/pyzmq-26.2.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f5eeeb82feec1fc5cbafa5ee9022e87ffdb3a8c48afa035b356fcd20fc7f533f", size = 1208630 }, + { url = "https://files.pythonhosted.org/packages/3d/a9/50228465c625851a06aeee97c74f253631f509213f979166e83796299c60/pyzmq-26.2.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:000760e374d6f9d1a3478a42ed0c98604de68c9e94507e5452951e598ebecfba", size = 1519568 }, + { url = "https://files.pythonhosted.org/packages/c6/f2/6360b619e69da78863c2108beb5196ae8b955fe1e161c0b886b95dc6b1ac/pyzmq-26.2.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:817fcd3344d2a0b28622722b98500ae9c8bfee0f825b8450932ff19c0b15bebd", size = 1419677 }, + { url = "https://files.pythonhosted.org/packages/da/d5/f179da989168f5dfd1be8103ef508ade1d38a8078dda4f10ebae3131a490/pyzmq-26.2.1-cp311-cp311-win32.whl", hash = "sha256:88812b3b257f80444a986b3596e5ea5c4d4ed4276d2b85c153a6fbc5ca457ae7", size = 582682 }, + { url = "https://files.pythonhosted.org/packages/60/50/e5b2e9de3ffab73ff92bee736216cf209381081fa6ab6ba96427777d98b1/pyzmq-26.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:ef29630fde6022471d287c15c0a2484aba188adbfb978702624ba7a54ddfa6c1", size = 648128 }, + { url = "https://files.pythonhosted.org/packages/d9/fe/7bb93476dd8405b0fc9cab1fd921a08bd22d5e3016aa6daea1a78d54129b/pyzmq-26.2.1-cp311-cp311-win_arm64.whl", hash = "sha256:f32718ee37c07932cc336096dc7403525301fd626349b6eff8470fe0f996d8d7", size = 562465 }, + { url = "https://files.pythonhosted.org/packages/65/d1/e630a75cfb2534574a1258fda54d02f13cf80b576d4ce6d2aa478dc67829/pyzmq-26.2.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:380816d298aed32b1a97b4973a4865ef3be402a2e760204509b52b6de79d755d", size = 847743 }, + { url = 
"https://files.pythonhosted.org/packages/27/df/f94a711b4f6c4b41e227f9a938103f52acf4c2e949d91cbc682495a48155/pyzmq-26.2.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:97cbb368fd0debdbeb6ba5966aa28e9a1ae3396c7386d15569a6ca4be4572b99", size = 570991 }, + { url = "https://files.pythonhosted.org/packages/bf/08/0c6f97fb3c9dbfa23382f0efaf8f9aa1396a08a3358974eaae3ee659ed5c/pyzmq-26.2.1-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abf7b5942c6b0dafcc2823ddd9154f419147e24f8df5b41ca8ea40a6db90615c", size = 799664 }, + { url = "https://files.pythonhosted.org/packages/05/14/f4d4fd8bb8988c667845734dd756e9ee65b9a17a010d5f288dfca14a572d/pyzmq-26.2.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3fe6e28a8856aea808715f7a4fc11f682b9d29cac5d6262dd8fe4f98edc12d53", size = 758156 }, + { url = "https://files.pythonhosted.org/packages/e3/fe/72e7e166bda3885810bee7b23049133e142f7c80c295bae02c562caeea16/pyzmq-26.2.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:bd8fdee945b877aa3bffc6a5a8816deb048dab0544f9df3731ecd0e54d8c84c9", size = 556563 }, +] + +[[package]] +name = "questionary" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "prompt-toolkit" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/b8/d16eb579277f3de9e56e5ad25280fab52fc5774117fb70362e8c2e016559/questionary-2.1.0.tar.gz", hash = "sha256:6302cdd645b19667d8f6e6634774e9538bfcd1aad9be287e743d96cacaf95587", size = 26775 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ad/3f/11dd4cd4f39e05128bfd20138faea57bec56f9ffba6185d276e3107ba5b2/questionary-2.1.0-py3-none-any.whl", hash = "sha256:44174d237b68bc828e4878c763a9ad6790ee61990e0ae72927694ead57bab8ec", size = 36747 }, +] + +[[package]] +name = "referencing" +version = "0.36.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "rpds-py" }, + 
{ name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2f/db/98b5c277be99dd18bfd91dd04e1b759cad18d1a338188c936e92f921c7e2/referencing-0.36.2.tar.gz", hash = "sha256:df2e89862cd09deabbdba16944cc3f10feb6b3e6f18e902f7cc25609a34775aa", size = 74744 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/b1/3baf80dc6d2b7bc27a95a67752d0208e410351e3feb4eb78de5f77454d8d/referencing-0.36.2-py3-none-any.whl", hash = "sha256:e8699adbbf8b5c7de96d8ffa0eb5c158b3beafce084968e2ea8bb08c6794dcd0", size = 26775 }, +] + +[[package]] +name = "requests" +version = "2.32.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/63/70/2bf7780ad2d390a8d301ad0b550f1581eadbd9a20f896afe06353c2a2913/requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760", size = 131218 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 }, +] + +[[package]] +name = "rich" +version = "13.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, + { name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or 
(extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ab/3a/0316b28d0761c6734d6bc14e770d85506c986c85ffb239e688eeaab2c2bc/rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098", size = 223149 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90", size = 242424 }, +] + +[[package]] +name = "rich-click" +version = "1.8.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "rich" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9a/31/103501e85e885e3e202c087fa612cfe450693210372766552ce1ab5b57b9/rich_click-1.8.5.tar.gz", hash = "sha256:a3eebe81da1c9da3c32f3810017c79bd687ff1b3fa35bfc9d8a3338797f1d1a1", size = 38229 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/aa/0b/e2de98c538c0ee9336211d260f88b7e69affab44969750aaca0b48a697c8/rich_click-1.8.5-py3-none-any.whl", hash = "sha256:0fab7bb5b66c15da17c210b4104277cd45f3653a7322e0098820a169880baee0", size = 35081 }, +] + +[[package]] +name = "rpds-py" +version = "0.22.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/80/cce854d0921ff2f0a9fa831ba3ad3c65cee3a46711addf39a2af52df2cfd/rpds_py-0.22.3.tar.gz", hash = "sha256:e32fee8ab45d3c2db6da19a5323bc3362237c8b653c70194414b892fd06a080d", size = 26771 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/42/2a/ead1d09e57449b99dcc190d8d2323e3a167421d8f8fdf0f217c6f6befe47/rpds_py-0.22.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:6c7b99ca52c2c1752b544e310101b98a659b720b21db00e65edca34483259967", 
size = 359514 }, + { url = "https://files.pythonhosted.org/packages/8f/7e/1254f406b7793b586c68e217a6a24ec79040f85e030fff7e9049069284f4/rpds_py-0.22.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:be2eb3f2495ba669d2a985f9b426c1797b7d48d6963899276d22f23e33d47e37", size = 349031 }, + { url = "https://files.pythonhosted.org/packages/aa/da/17c6a2c73730d426df53675ff9cc6653ac7a60b6438d03c18e1c822a576a/rpds_py-0.22.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70eb60b3ae9245ddea20f8a4190bd79c705a22f8028aaf8bbdebe4716c3fab24", size = 381485 }, + { url = "https://files.pythonhosted.org/packages/aa/13/2dbacd820466aa2a3c4b747afb18d71209523d353cf865bf8f4796c969ea/rpds_py-0.22.3-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4041711832360a9b75cfb11b25a6a97c8fb49c07b8bd43d0d02b45d0b499a4ff", size = 386794 }, + { url = "https://files.pythonhosted.org/packages/6d/62/96905d0a35ad4e4bc3c098b2f34b2e7266e211d08635baa690643d2227be/rpds_py-0.22.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:64607d4cbf1b7e3c3c8a14948b99345eda0e161b852e122c6bb71aab6d1d798c", size = 423523 }, + { url = "https://files.pythonhosted.org/packages/eb/1b/d12770f2b6a9fc2c3ec0d810d7d440f6d465ccd8b7f16ae5385952c28b89/rpds_py-0.22.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e69b0a0e2537f26d73b4e43ad7bc8c8efb39621639b4434b76a3de50c6966e", size = 446695 }, + { url = "https://files.pythonhosted.org/packages/4d/cf/96f1fd75512a017f8e07408b6d5dbeb492d9ed46bfe0555544294f3681b3/rpds_py-0.22.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc27863442d388870c1809a87507727b799c8460573cfbb6dc0eeaef5a11b5ec", size = 381959 }, + { url = "https://files.pythonhosted.org/packages/ab/f0/d1c5b501c8aea85aeb938b555bfdf7612110a2f8cdc21ae0482c93dd0c24/rpds_py-0.22.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = 
"sha256:e79dd39f1e8c3504be0607e5fc6e86bb60fe3584bec8b782578c3b0fde8d932c", size = 410420 }, + { url = "https://files.pythonhosted.org/packages/33/3b/45b6c58fb6aad5a569ae40fb890fc494c6b02203505a5008ee6dc68e65f7/rpds_py-0.22.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e0fa2d4ec53dc51cf7d3bb22e0aa0143966119f42a0c3e4998293a3dd2856b09", size = 557620 }, + { url = "https://files.pythonhosted.org/packages/83/62/3fdd2d3d47bf0bb9b931c4c73036b4ab3ec77b25e016ae26fab0f02be2af/rpds_py-0.22.3-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:fda7cb070f442bf80b642cd56483b5548e43d366fe3f39b98e67cce780cded00", size = 584202 }, + { url = "https://files.pythonhosted.org/packages/04/f2/5dced98b64874b84ca824292f9cee2e3f30f3bcf231d15a903126684f74d/rpds_py-0.22.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:cff63a0272fcd259dcc3be1657b07c929c466b067ceb1c20060e8d10af56f5bf", size = 552787 }, + { url = "https://files.pythonhosted.org/packages/67/13/2273dea1204eda0aea0ef55145da96a9aa28b3f88bb5c70e994f69eda7c3/rpds_py-0.22.3-cp310-cp310-win32.whl", hash = "sha256:9bd7228827ec7bb817089e2eb301d907c0d9827a9e558f22f762bb690b131652", size = 220088 }, + { url = "https://files.pythonhosted.org/packages/4e/80/8c8176b67ad7f4a894967a7a4014ba039626d96f1d4874d53e409b58d69f/rpds_py-0.22.3-cp310-cp310-win_amd64.whl", hash = "sha256:9beeb01d8c190d7581a4d59522cd3d4b6887040dcfc744af99aa59fef3e041a8", size = 231737 }, + { url = "https://files.pythonhosted.org/packages/15/ad/8d1ddf78f2805a71253fcd388017e7b4a0615c22c762b6d35301fef20106/rpds_py-0.22.3-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d20cfb4e099748ea39e6f7b16c91ab057989712d31761d3300d43134e26e165f", size = 359773 }, + { url = "https://files.pythonhosted.org/packages/c8/75/68c15732293a8485d79fe4ebe9045525502a067865fa4278f178851b2d87/rpds_py-0.22.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:68049202f67380ff9aa52f12e92b1c30115f32e6895cd7198fa2a7961621fc5a", size = 349214 }, + { url = 
"https://files.pythonhosted.org/packages/3c/4c/7ce50f3070083c2e1b2bbd0fb7046f3da55f510d19e283222f8f33d7d5f4/rpds_py-0.22.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb4f868f712b2dd4bcc538b0a0c1f63a2b1d584c925e69a224d759e7070a12d5", size = 380477 }, + { url = "https://files.pythonhosted.org/packages/9a/e9/835196a69cb229d5c31c13b8ae603bd2da9a6695f35fe4270d398e1db44c/rpds_py-0.22.3-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bc51abd01f08117283c5ebf64844a35144a0843ff7b2983e0648e4d3d9f10dbb", size = 386171 }, + { url = "https://files.pythonhosted.org/packages/f9/8e/33fc4eba6683db71e91e6d594a2cf3a8fbceb5316629f0477f7ece5e3f75/rpds_py-0.22.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0f3cec041684de9a4684b1572fe28c7267410e02450f4561700ca5a3bc6695a2", size = 422676 }, + { url = "https://files.pythonhosted.org/packages/37/47/2e82d58f8046a98bb9497a8319604c92b827b94d558df30877c4b3c6ccb3/rpds_py-0.22.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7ef9d9da710be50ff6809fed8f1963fecdfecc8b86656cadfca3bc24289414b0", size = 446152 }, + { url = "https://files.pythonhosted.org/packages/e1/78/79c128c3e71abbc8e9739ac27af11dc0f91840a86fce67ff83c65d1ba195/rpds_py-0.22.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59f4a79c19232a5774aee369a0c296712ad0e77f24e62cad53160312b1c1eaa1", size = 381300 }, + { url = "https://files.pythonhosted.org/packages/c9/5b/2e193be0e8b228c1207f31fa3ea79de64dadb4f6a4833111af8145a6bc33/rpds_py-0.22.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1a60bce91f81ddaac922a40bbb571a12c1070cb20ebd6d49c48e0b101d87300d", size = 409636 }, + { url = "https://files.pythonhosted.org/packages/c2/3f/687c7100b762d62186a1c1100ffdf99825f6fa5ea94556844bbbd2d0f3a9/rpds_py-0.22.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:e89391e6d60251560f0a8f4bd32137b077a80d9b7dbe6d5cab1cd80d2746f648", 
size = 556708 }, + { url = "https://files.pythonhosted.org/packages/8c/a2/c00cbc4b857e8b3d5e7f7fc4c81e23afd8c138b930f4f3ccf9a41a23e9e4/rpds_py-0.22.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e3fb866d9932a3d7d0c82da76d816996d1667c44891bd861a0f97ba27e84fc74", size = 583554 }, + { url = "https://files.pythonhosted.org/packages/d0/08/696c9872cf56effdad9ed617ac072f6774a898d46b8b8964eab39ec562d2/rpds_py-0.22.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1352ae4f7c717ae8cba93421a63373e582d19d55d2ee2cbb184344c82d2ae55a", size = 552105 }, + { url = "https://files.pythonhosted.org/packages/18/1f/4df560be1e994f5adf56cabd6c117e02de7c88ee238bb4ce03ed50da9d56/rpds_py-0.22.3-cp311-cp311-win32.whl", hash = "sha256:b0b4136a252cadfa1adb705bb81524eee47d9f6aab4f2ee4fa1e9d3cd4581f64", size = 220199 }, + { url = "https://files.pythonhosted.org/packages/b8/1b/c29b570bc5db8237553002788dc734d6bd71443a2ceac2a58202ec06ef12/rpds_py-0.22.3-cp311-cp311-win_amd64.whl", hash = "sha256:8bd7c8cfc0b8247c8799080fbff54e0b9619e17cdfeb0478ba7295d43f635d7c", size = 231775 }, + { url = "https://files.pythonhosted.org/packages/8b/63/e29f8ee14fcf383574f73b6bbdcbec0fbc2e5fc36b4de44d1ac389b1de62/rpds_py-0.22.3-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:d48424e39c2611ee1b84ad0f44fb3b2b53d473e65de061e3f460fc0be5f1939d", size = 360786 }, + { url = "https://files.pythonhosted.org/packages/d3/e0/771ee28b02a24e81c8c0e645796a371350a2bb6672753144f36ae2d2afc9/rpds_py-0.22.3-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:24e8abb5878e250f2eb0d7859a8e561846f98910326d06c0d51381fed59357bd", size = 350589 }, + { url = "https://files.pythonhosted.org/packages/cf/49/abad4c4a1e6f3adf04785a99c247bfabe55ed868133e2d1881200aa5d381/rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b232061ca880db21fa14defe219840ad9b74b6158adb52ddf0e87bead9e8493", size = 381848 }, + { url = 
"https://files.pythonhosted.org/packages/3a/7d/f4bc6d6fbe6af7a0d2b5f2ee77079efef7c8528712745659ec0026888998/rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac0a03221cdb5058ce0167ecc92a8c89e8d0decdc9e99a2ec23380793c4dcb96", size = 387879 }, + { url = "https://files.pythonhosted.org/packages/13/b0/575c797377fdcd26cedbb00a3324232e4cb2c5d121f6e4b0dbf8468b12ef/rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb0c341fa71df5a4595f9501df4ac5abfb5a09580081dffbd1ddd4654e6e9123", size = 423916 }, + { url = "https://files.pythonhosted.org/packages/54/78/87157fa39d58f32a68d3326f8a81ad8fb99f49fe2aa7ad9a1b7d544f9478/rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bf9db5488121b596dbfc6718c76092fda77b703c1f7533a226a5a9f65248f8ad", size = 448410 }, + { url = "https://files.pythonhosted.org/packages/59/69/860f89996065a88be1b6ff2d60e96a02b920a262d8aadab99e7903986597/rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0b8db6b5b2d4491ad5b6bdc2bc7c017eec108acbf4e6785f42a9eb0ba234f4c9", size = 382841 }, + { url = "https://files.pythonhosted.org/packages/bd/d7/bc144e10d27e3cb350f98df2492a319edd3caaf52ddfe1293f37a9afbfd7/rpds_py-0.22.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b3d504047aba448d70cf6fa22e06cb09f7cbd761939fdd47604f5e007675c24e", size = 409662 }, + { url = "https://files.pythonhosted.org/packages/14/2a/6bed0b05233c291a94c7e89bc76ffa1c619d4e1979fbfe5d96024020c1fb/rpds_py-0.22.3-pp310-pypy310_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:e61b02c3f7a1e0b75e20c3978f7135fd13cb6cf551bf4a6d29b999a88830a338", size = 558221 }, + { url = "https://files.pythonhosted.org/packages/11/23/cd8f566de444a137bc1ee5795e47069a947e60810ba4152886fe5308e1b7/rpds_py-0.22.3-pp310-pypy310_pp73-musllinux_1_2_i686.whl", hash = 
"sha256:e35ba67d65d49080e8e5a1dd40101fccdd9798adb9b050ff670b7d74fa41c566", size = 583780 }, + { url = "https://files.pythonhosted.org/packages/8d/63/79c3602afd14d501f751e615a74a59040328da5ef29ed5754ae80d236b84/rpds_py-0.22.3-pp310-pypy310_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:26fd7cac7dd51011a245f29a2cc6489c4608b5a8ce8d75661bb4a1066c52dfbe", size = 553619 }, + { url = "https://files.pythonhosted.org/packages/9f/2e/c5c1689e80298d4e94c75b70faada4c25445739d91b94c211244a3ed7ed1/rpds_py-0.22.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:177c7c0fce2855833819c98e43c262007f42ce86651ffbb84f37883308cb0e7d", size = 233338 }, +] + +[[package]] +name = "ruamel-yaml" +version = "0.18.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ruamel-yaml-clib", marker = "platform_python_implementation == 'CPython' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ea/46/f44d8be06b85bc7c4d8c95d658be2b68f27711f279bf9dd0612a5e4794f5/ruamel.yaml-0.18.10.tar.gz", hash = "sha256:20c86ab29ac2153f80a428e1254a8adf686d3383df04490514ca3b79a362db58", size = 143447 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/36/dfc1ebc0081e6d39924a2cc53654497f967a084a436bb64402dfce4254d9/ruamel.yaml-0.18.10-py3-none-any.whl", hash = "sha256:30f22513ab2301b3d2b577adc121c6471f28734d3d9728581245f1e76468b4f1", size = 117729 }, +] + +[[package]] 
+name = "ruamel-yaml-clib" +version = "0.2.12" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/84/80203abff8ea4993a87d823a5f632e4d92831ef75d404c9fc78d0176d2b5/ruamel.yaml.clib-0.2.12.tar.gz", hash = "sha256:6c8fbb13ec503f99a91901ab46e0b07ae7941cd527393187039aec586fdfd36f", size = 225315 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/70/57/40a958e863e299f0c74ef32a3bde9f2d1ea8d69669368c0c502a0997f57f/ruamel.yaml.clib-0.2.12-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:11f891336688faf5156a36293a9c362bdc7c88f03a8a027c2c1d8e0bcde998e5", size = 131301 }, + { url = "https://files.pythonhosted.org/packages/98/a8/29a3eb437b12b95f50a6bcc3d7d7214301c6c529d8fdc227247fa84162b5/ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:a606ef75a60ecf3d924613892cc603b154178ee25abb3055db5062da811fd969", size = 633728 }, + { url = "https://files.pythonhosted.org/packages/35/6d/ae05a87a3ad540259c3ad88d71275cbd1c0f2d30ae04c65dcbfb6dcd4b9f/ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd5415dded15c3822597455bc02bcd66e81ef8b7a48cb71a33628fc9fdde39df", size = 722230 }, + { url = "https://files.pythonhosted.org/packages/7f/b7/20c6f3c0b656fe609675d69bc135c03aac9e3865912444be6339207b6648/ruamel.yaml.clib-0.2.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f66efbc1caa63c088dead1c4170d148eabc9b80d95fb75b6c92ac0aad2437d76", size = 686712 }, + { url = "https://files.pythonhosted.org/packages/cd/11/d12dbf683471f888d354dac59593873c2b45feb193c5e3e0f2ebf85e68b9/ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:22353049ba4181685023b25b5b51a574bce33e7f51c759371a7422dcae5402a6", size = 663936 }, + { url = 
"https://files.pythonhosted.org/packages/72/14/4c268f5077db5c83f743ee1daeb236269fa8577133a5cfa49f8b382baf13/ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:932205970b9f9991b34f55136be327501903f7c66830e9760a8ffb15b07f05cd", size = 696580 }, + { url = "https://files.pythonhosted.org/packages/30/fc/8cd12f189c6405a4c1cf37bd633aa740a9538c8e40497c231072d0fef5cf/ruamel.yaml.clib-0.2.12-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a52d48f4e7bf9005e8f0a89209bf9a73f7190ddf0489eee5eb51377385f59f2a", size = 663393 }, + { url = "https://files.pythonhosted.org/packages/80/29/c0a017b704aaf3cbf704989785cd9c5d5b8ccec2dae6ac0c53833c84e677/ruamel.yaml.clib-0.2.12-cp310-cp310-win32.whl", hash = "sha256:3eac5a91891ceb88138c113f9db04f3cebdae277f5d44eaa3651a4f573e6a5da", size = 100326 }, + { url = "https://files.pythonhosted.org/packages/3a/65/fa39d74db4e2d0cd252355732d966a460a41cd01c6353b820a0952432839/ruamel.yaml.clib-0.2.12-cp310-cp310-win_amd64.whl", hash = "sha256:ab007f2f5a87bd08ab1499bdf96f3d5c6ad4dcfa364884cb4549aa0154b13a28", size = 118079 }, + { url = "https://files.pythonhosted.org/packages/fb/8f/683c6ad562f558cbc4f7c029abcd9599148c51c54b5ef0f24f2638da9fbb/ruamel.yaml.clib-0.2.12-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:4a6679521a58256a90b0d89e03992c15144c5f3858f40d7c18886023d7943db6", size = 132224 }, + { url = "https://files.pythonhosted.org/packages/3c/d2/b79b7d695e2f21da020bd44c782490578f300dd44f0a4c57a92575758a76/ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:d84318609196d6bd6da0edfa25cedfbabd8dbde5140a0a23af29ad4b8f91fb1e", size = 641480 }, + { url = "https://files.pythonhosted.org/packages/68/6e/264c50ce2a31473a9fdbf4fa66ca9b2b17c7455b31ef585462343818bd6c/ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb43a269eb827806502c7c8efb7ae7e9e9d0573257a46e8e952f4d4caba4f31e", size = 739068 }, + { url = 
"https://files.pythonhosted.org/packages/86/29/88c2567bc893c84d88b4c48027367c3562ae69121d568e8a3f3a8d363f4d/ruamel.yaml.clib-0.2.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:811ea1594b8a0fb466172c384267a4e5e367298af6b228931f273b111f17ef52", size = 703012 }, + { url = "https://files.pythonhosted.org/packages/11/46/879763c619b5470820f0cd6ca97d134771e502776bc2b844d2adb6e37753/ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cf12567a7b565cbf65d438dec6cfbe2917d3c1bdddfce84a9930b7d35ea59642", size = 704352 }, + { url = "https://files.pythonhosted.org/packages/02/80/ece7e6034256a4186bbe50dee28cd032d816974941a6abf6a9d65e4228a7/ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7dd5adc8b930b12c8fc5b99e2d535a09889941aa0d0bd06f4749e9a9397c71d2", size = 737344 }, + { url = "https://files.pythonhosted.org/packages/f0/ca/e4106ac7e80efbabdf4bf91d3d32fc424e41418458251712f5672eada9ce/ruamel.yaml.clib-0.2.12-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1492a6051dab8d912fc2adeef0e8c72216b24d57bd896ea607cb90bb0c4981d3", size = 714498 }, + { url = "https://files.pythonhosted.org/packages/67/58/b1f60a1d591b771298ffa0428237afb092c7f29ae23bad93420b1eb10703/ruamel.yaml.clib-0.2.12-cp311-cp311-win32.whl", hash = "sha256:bd0a08f0bab19093c54e18a14a10b4322e1eacc5217056f3c063bd2f59853ce4", size = 100205 }, + { url = "https://files.pythonhosted.org/packages/b4/4f/b52f634c9548a9291a70dfce26ca7ebce388235c93588a1068028ea23fcc/ruamel.yaml.clib-0.2.12-cp311-cp311-win_amd64.whl", hash = "sha256:a274fb2cb086c7a3dea4322ec27f4cb5cc4b6298adb583ab0e211a4682f241eb", size = 118185 }, +] + +[[package]] +name = "ruff" +version = "0.9.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/02/74/6c359f6b9ed85b88df6ef31febce18faeb852f6c9855651dfb1184a46845/ruff-0.9.5.tar.gz", hash = 
"sha256:11aecd7a633932875ab3cb05a484c99970b9d52606ce9ea912b690b02653d56c", size = 3634177 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/4b/82b7c9ac874e72b82b19fd7eab57d122e2df44d2478d90825854f9232d02/ruff-0.9.5-py3-none-linux_armv6l.whl", hash = "sha256:d466d2abc05f39018d53f681fa1c0ffe9570e6d73cde1b65d23bb557c846f442", size = 11681264 }, + { url = "https://files.pythonhosted.org/packages/27/5c/f5ae0a9564e04108c132e1139d60491c0abc621397fe79a50b3dc0bd704b/ruff-0.9.5-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:38840dbcef63948657fa7605ca363194d2fe8c26ce8f9ae12eee7f098c85ac8a", size = 11657554 }, + { url = "https://files.pythonhosted.org/packages/2a/83/c6926fa3ccb97cdb3c438bb56a490b395770c750bf59f9bc1fe57ae88264/ruff-0.9.5-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d56ba06da53536b575fbd2b56517f6f95774ff7be0f62c80b9e67430391eeb36", size = 11088959 }, + { url = "https://files.pythonhosted.org/packages/af/a7/42d1832b752fe969ffdbfcb1b4cb477cb271bed5835110fb0a16ef31ab81/ruff-0.9.5-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f7cb2a01da08244c50b20ccfaeb5972e4228c3c3a1989d3ece2bc4b1f996001", size = 11902041 }, + { url = "https://files.pythonhosted.org/packages/53/cf/1fffa09fb518d646f560ccfba59f91b23c731e461d6a4dedd21a393a1ff1/ruff-0.9.5-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:96d5c76358419bc63a671caac70c18732d4fd0341646ecd01641ddda5c39ca0b", size = 11421069 }, + { url = "https://files.pythonhosted.org/packages/09/27/bb8f1b7304e2a9431f631ae7eadc35550fe0cf620a2a6a0fc4aa3d736f94/ruff-0.9.5-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:deb8304636ed394211f3a6d46c0e7d9535b016f53adaa8340139859b2359a070", size = 12625095 }, + { url = "https://files.pythonhosted.org/packages/d7/ce/ab00bc9d3df35a5f1b64f5117458160a009f93ae5caf65894ebb63a1842d/ruff-0.9.5-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = 
"sha256:df455000bf59e62b3e8c7ba5ed88a4a2bc64896f900f311dc23ff2dc38156440", size = 13257797 }, + { url = "https://files.pythonhosted.org/packages/88/81/c639a082ae6d8392bc52256058ec60f493c6a4d06d5505bccface3767e61/ruff-0.9.5-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de92170dfa50c32a2b8206a647949590e752aca8100a0f6b8cefa02ae29dce80", size = 12763793 }, + { url = "https://files.pythonhosted.org/packages/b3/d0/0a3d8f56d1e49af466dc770eeec5c125977ba9479af92e484b5b0251ce9c/ruff-0.9.5-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d28532d73b1f3f627ba88e1456f50748b37f3a345d2be76e4c653bec6c3e393", size = 14386234 }, + { url = "https://files.pythonhosted.org/packages/04/70/e59c192a3ad476355e7f45fb3a87326f5219cc7c472e6b040c6c6595c8f0/ruff-0.9.5-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c746d7d1df64f31d90503ece5cc34d7007c06751a7a3bbeee10e5f2463d52d2", size = 12437505 }, + { url = "https://files.pythonhosted.org/packages/55/4e/3abba60a259d79c391713e7a6ccabf7e2c96e5e0a19100bc4204f1a43a51/ruff-0.9.5-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:11417521d6f2d121fda376f0d2169fb529976c544d653d1d6044f4c5562516ee", size = 11884799 }, + { url = "https://files.pythonhosted.org/packages/a3/db/b0183a01a9f25b4efcae919c18fb41d32f985676c917008620ad692b9d5f/ruff-0.9.5-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:5b9d71c3879eb32de700f2f6fac3d46566f644a91d3130119a6378f9312a38e1", size = 11527411 }, + { url = "https://files.pythonhosted.org/packages/0a/e4/3ebfcebca3dff1559a74c6becff76e0b64689cea02b7aab15b8b32ea245d/ruff-0.9.5-py3-none-musllinux_1_2_i686.whl", hash = "sha256:2e36c61145e70febcb78483903c43444c6b9d40f6d2f800b5552fec6e4a7bb9a", size = 12078868 }, + { url = "https://files.pythonhosted.org/packages/ec/b2/5ab808833e06c0a1b0d046a51c06ec5687b73c78b116e8d77687dc0cd515/ruff-0.9.5-py3-none-musllinux_1_2_x86_64.whl", hash = 
"sha256:2f71d09aeba026c922aa7aa19a08d7bd27c867aedb2f74285a2639644c1c12f5", size = 12524374 }, + { url = "https://files.pythonhosted.org/packages/e0/51/1432afcc3b7aa6586c480142caae5323d59750925c3559688f2a9867343f/ruff-0.9.5-py3-none-win32.whl", hash = "sha256:134f958d52aa6fdec3b294b8ebe2320a950d10c041473c4316d2e7d7c2544723", size = 9853682 }, + { url = "https://files.pythonhosted.org/packages/b7/ad/c7a900591bd152bb47fc4882a27654ea55c7973e6d5d6396298ad3fd6638/ruff-0.9.5-py3-none-win_amd64.whl", hash = "sha256:78cc6067f6d80b6745b67498fb84e87d32c6fc34992b52bffefbdae3442967d6", size = 10865744 }, + { url = "https://files.pythonhosted.org/packages/75/d9/fde7610abd53c0c76b6af72fc679cb377b27c617ba704e25da834e0a0608/ruff-0.9.5-py3-none-win_arm64.whl", hash = "sha256:18a29f1a005bddb229e580795627d297dfa99f16b30c7039e73278cf6b5f9fa9", size = 10064595 }, +] + +[[package]] +name = "scipy" +version = "1.15.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "1.26.4", source = { registry = "https://pypi.org/simple" }, marker = "extra == 'extra-5-gt4py-all' or extra == 'extra-5-gt4py-dace' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "numpy", version = "2.2.2", source = { registry = "https://pypi.org/simple" }, marker = "(extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra != 'extra-5-gt4py-all' and extra != 'extra-5-gt4py-dace') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 
'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/76/c6/8eb0654ba0c7d0bb1bf67bf8fbace101a8e4f250f7722371105e8b6f68fc/scipy-1.15.1.tar.gz", hash = "sha256:033a75ddad1463970c96a88063a1df87ccfddd526437136b6ee81ff0312ebdf6", size = 59407493 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/86/53/b204ce5a4433f1864001b9d16f103b9c25f5002a602ae83585d0ea5f9c4a/scipy-1.15.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:c64ded12dcab08afff9e805a67ff4480f5e69993310e093434b10e85dc9d43e1", size = 41414518 }, + { url = "https://files.pythonhosted.org/packages/c7/fc/54ffa7a8847f7f303197a6ba65a66104724beba2e38f328135a78f0dc480/scipy-1.15.1-cp310-cp310-macosx_12_0_arm64.whl", hash = 
"sha256:5b190b935e7db569960b48840e5bef71dc513314cc4e79a1b7d14664f57fd4ff", size = 32519265 }, + { url = "https://files.pythonhosted.org/packages/f1/77/a98b8ba03d6f371dc31a38719affd53426d4665729dcffbed4afe296784a/scipy-1.15.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:4b17d4220df99bacb63065c76b0d1126d82bbf00167d1730019d2a30d6ae01ea", size = 24792859 }, + { url = "https://files.pythonhosted.org/packages/a7/78/70bb9f0df7444b18b108580934bfef774822e28fd34a68e5c263c7d2828a/scipy-1.15.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:63b9b6cd0333d0eb1a49de6f834e8aeaefe438df8f6372352084535ad095219e", size = 27886506 }, + { url = "https://files.pythonhosted.org/packages/14/a7/f40f6033e06de4176ddd6cc8c3ae9f10a226c3bca5d6b4ab883bc9914a14/scipy-1.15.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f151e9fb60fbf8e52426132f473221a49362091ce7a5e72f8aa41f8e0da4f25", size = 38375041 }, + { url = "https://files.pythonhosted.org/packages/17/03/390a1c5c61fd76b0fa4b3c5aa3bdd7e60f6c46f712924f1a9df5705ec046/scipy-1.15.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21e10b1dd56ce92fba3e786007322542361984f8463c6d37f6f25935a5a6ef52", size = 40597556 }, + { url = "https://files.pythonhosted.org/packages/4e/70/fa95b3ae026b97eeca58204a90868802e5155ac71b9d7bdee92b68115dd3/scipy-1.15.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5dff14e75cdbcf07cdaa1c7707db6017d130f0af9ac41f6ce443a93318d6c6e0", size = 42938505 }, + { url = "https://files.pythonhosted.org/packages/d6/07/427859116bdd71847c898180f01802691f203c3e2455a1eb496130ff07c5/scipy-1.15.1-cp310-cp310-win_amd64.whl", hash = "sha256:f82fcf4e5b377f819542fbc8541f7b5fbcf1c0017d0df0bc22c781bf60abc4d8", size = 43909663 }, + { url = "https://files.pythonhosted.org/packages/8e/2e/7b71312da9c2dabff53e7c9a9d08231bc34d9d8fdabe88a6f1155b44591c/scipy-1.15.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = 
"sha256:5bd8d27d44e2c13d0c1124e6a556454f52cd3f704742985f6b09e75e163d20d2", size = 41424362 }, + { url = "https://files.pythonhosted.org/packages/81/8c/ab85f1aa1cc200c796532a385b6ebf6a81089747adc1da7482a062acc46c/scipy-1.15.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:be3deeb32844c27599347faa077b359584ba96664c5c79d71a354b80a0ad0ce0", size = 32535910 }, + { url = "https://files.pythonhosted.org/packages/3b/9c/6f4b787058daa8d8da21ddff881b4320e28de4704a65ec147adb50cb2230/scipy-1.15.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:5eb0ca35d4b08e95da99a9f9c400dc9f6c21c424298a0ba876fdc69c7afacedf", size = 24809398 }, + { url = "https://files.pythonhosted.org/packages/16/2b/949460a796df75fc7a1ee1becea202cf072edbe325ebe29f6d2029947aa7/scipy-1.15.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:74bb864ff7640dea310a1377d8567dc2cb7599c26a79ca852fc184cc851954ac", size = 27918045 }, + { url = "https://files.pythonhosted.org/packages/5f/36/67fe249dd7ccfcd2a38b25a640e3af7e59d9169c802478b6035ba91dfd6d/scipy-1.15.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:667f950bf8b7c3a23b4199db24cb9bf7512e27e86d0e3813f015b74ec2c6e3df", size = 38332074 }, + { url = "https://files.pythonhosted.org/packages/fc/da/452e1119e6f720df3feb588cce3c42c5e3d628d4bfd4aec097bd30b7de0c/scipy-1.15.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:395be70220d1189756068b3173853029a013d8c8dd5fd3d1361d505b2aa58fa7", size = 40588469 }, + { url = "https://files.pythonhosted.org/packages/7f/71/5f94aceeac99a4941478af94fe9f459c6752d497035b6b0761a700f5f9ff/scipy-1.15.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ce3a000cd28b4430426db2ca44d96636f701ed12e2b3ca1f2b1dd7abdd84b39a", size = 42965214 }, + { url = "https://files.pythonhosted.org/packages/af/25/caa430865749d504271757cafd24066d596217e83326155993980bc22f97/scipy-1.15.1-cp311-cp311-win_amd64.whl", hash = "sha256:3fe1d95944f9cf6ba77aa28b82dd6bb2a5b52f2026beb39ecf05304b8392864b", 
size = 43896034 }, +] + +[[package]] +name = "setuptools" +version = "75.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/92/ec/089608b791d210aec4e7f97488e67ab0d33add3efccb83a056cbafe3a2a6/setuptools-75.8.0.tar.gz", hash = "sha256:c5afc8f407c626b8313a86e10311dd3f661c6cd9c09d4bf8c15c0e11f9f2b0e6", size = 1343222 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/8a/b9dc7678803429e4a3bc9ba462fa3dd9066824d3c607490235c6a796be5a/setuptools-75.8.0-py3-none-any.whl", hash = "sha256:e3982f444617239225d675215d51f6ba05f845d4eec313da4418fdbb56fb27e3", size = 1228782 }, +] + +[[package]] +name = "setuptools-scm" +version = "8.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, + { name = "setuptools" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4f/a4/00a9ac1b555294710d4a68d2ce8dfdf39d72aa4d769a7395d05218d88a42/setuptools_scm-8.1.0.tar.gz", hash = "sha256:42dea1b65771cba93b7a515d65a65d8246e560768a66b9106a592c8e7f26c8a7", size = 76465 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/b9/1906bfeb30f2fc13bb39bf7ddb8749784c05faadbd18a21cf141ba37bff2/setuptools_scm-8.1.0-py3-none-any.whl", hash = "sha256:897a3226a6fd4a6eb2f068745e49733261a21f70b1bb28fce0339feb978d9af3", size = 43666 }, +] + +[[package]] +name = "six" +version = "1.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050 }, +] + +[[package]] +name = "smmap" +version = "5.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/cd/a040c4b3119bbe532e5b0732286f805445375489fceaec1f48306068ee3b/smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5", size = 22329 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303 }, +] + +[[package]] +name = "sniffio" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235 }, +] + +[[package]] +name = "snowballstemmer" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/7b/af302bebf22c749c56c9c3e8ae13190b5b5db37a33d9068652e8f73b7089/snowballstemmer-2.2.0.tar.gz", hash = 
"sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1", size = 86699 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/dc/c02e01294f7265e63a7315fe086dd1df7dacb9f840a804da846b96d01b96/snowballstemmer-2.2.0-py2.py3-none-any.whl", hash = "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a", size = 93002 }, +] + +[[package]] +name = "sortedcontainers" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/c4/ba2f8066cceb6f23394729afe52f3bf7adec04bf9ed2c820b39e19299111/sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88", size = 30594 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575 }, +] + +[[package]] +name = "soupsieve" +version = "2.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/ce/fbaeed4f9fb8b2daa961f90591662df6a86c1abf25c548329a86920aedfb/soupsieve-2.6.tar.gz", hash = "sha256:e2e68417777af359ec65daac1057404a3c8a5455bb8abc36f1a9866ab1a51abb", size = 101569 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/c2/fe97d779f3ef3b15f05c94a2f1e3d21732574ed441687474db9d342a7315/soupsieve-2.6-py3-none-any.whl", hash = "sha256:e72c4ff06e4fb6e4b5a9f0f55fe6e81514581fca1515028625d0f299c602ccc9", size = 36186 }, +] + +[[package]] +name = "sphinx" +version = "8.1.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "alabaster" }, + { name = "babel" }, + { name = "colorama", marker = "sys_platform == 'win32' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or 
(extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, + { name = "docutils" }, + { name = "imagesize" }, + { name = "jinja2" }, + { name = "packaging" }, + { name = "pygments" }, + { name = "requests" }, + { name = "snowballstemmer" }, + { name = "sphinxcontrib-applehelp" }, + { name = "sphinxcontrib-devhelp" }, + { name = "sphinxcontrib-htmlhelp" }, + { name = "sphinxcontrib-jsmath" }, + { name = "sphinxcontrib-qthelp" }, + { name = "sphinxcontrib-serializinghtml" }, + { name = "tomli", marker = "python_full_version < '3.11' or (extra == 'extra-5-gt4py-all' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-jax-cuda12') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-cuda11' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-dace' and extra == 'extra-5-gt4py-dace-next') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm4-3') or (extra == 'extra-5-gt4py-jax-cuda12' and extra == 'extra-5-gt4py-rocm5-0') or (extra == 'extra-5-gt4py-rocm4-3' and extra == 'extra-5-gt4py-rocm5-0')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6f/6d/be0b61178fe2cdcb67e2a92fc9ebb488e3c51c4f74a36a7824c0adf23425/sphinx-8.1.3.tar.gz", hash = "sha256:43c1911eecb0d3e161ad78611bc905d1ad0e523e4ddc202a58a821773dc4c927", size = 8184611 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/26/60/1ddff83a56d33aaf6f10ec8ce84b4c007d9368b21008876fceda7e7381ef/sphinx-8.1.3-py3-none-any.whl", hash = 
"sha256:09719015511837b76bf6e03e42eb7595ac8c2e41eeb9c29c5b755c6b677992a2", size = 3487125 }, +] + +[[package]] +name = "sphinx-autodoc-typehints" +version = "3.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "sphinx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/26/f0/43c6a5ff3e7b08a8c3b32f81b859f1b518ccc31e45f22e2b41ced38be7b9/sphinx_autodoc_typehints-3.0.1.tar.gz", hash = "sha256:b9b40dd15dee54f6f810c924f863f9cf1c54f9f3265c495140ea01be7f44fa55", size = 36282 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/dc/dc46c5c7c566b7ec5e8f860f9c89533bf03c0e6aadc96fb9b337867e4460/sphinx_autodoc_typehints-3.0.1-py3-none-any.whl", hash = "sha256:4b64b676a14b5b79cefb6628a6dc8070e320d4963e8ff640a2f3e9390ae9045a", size = 20245 }, +] + +[[package]] +name = "sphinx-jinja2-compat" +version = "0.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jinja2" }, + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/26/df/27282da6f8c549f765beca9de1a5fc56f9651ed87711a5cac1e914137753/sphinx_jinja2_compat-0.3.0.tar.gz", hash = "sha256:f3c1590b275f42e7a654e081db5e3e5fb97f515608422bde94015ddf795dfe7c", size = 4998 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/42/2fd09d672eaaa937d6893d8b747d07943f97a6e5e30653aee6ebd339b704/sphinx_jinja2_compat-0.3.0-py3-none-any.whl", hash = "sha256:b1e4006d8e1ea31013fa9946d1b075b0c8d2a42c6e3425e63542c1e9f8be9084", size = 7883 }, +] + +[[package]] +name = "sphinx-prompt" +version = "1.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "docutils" }, + { name = "idna" }, + { name = "pygments" }, + { name = "sphinx" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/34/fe/ac4e24f35b5148b31ac717ae7dcc7a2f7ec56eb729e22c7252ed8ad2d9a5/sphinx_prompt-1.9.0.tar.gz", hash = 
"sha256:471b3c6d466dce780a9b167d9541865fd4e9a80ed46e31b06a52a0529ae995a1", size = 5340 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/98/e90ca466e0ede452d3e5a8d92b8fb68db6de269856e019ed9cab69440522/sphinx_prompt-1.9.0-py3-none-any.whl", hash = "sha256:fd731446c03f043d1ff6df9f22414495b23067c67011cc21658ea8d36b3575fc", size = 7311 }, +] + +[[package]] +name = "sphinx-rtd-theme" +version = "3.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "docutils" }, + { name = "sphinx" }, + { name = "sphinxcontrib-jquery" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/44/c97faec644d29a5ceddd3020ae2edffa69e7d00054a8c7a6021e82f20335/sphinx_rtd_theme-3.0.2.tar.gz", hash = "sha256:b7457bc25dda723b20b086a670b9953c859eab60a2a03ee8eb2bb23e176e5f85", size = 7620463 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/77/46e3bac77b82b4df5bb5b61f2de98637724f246b4966cfc34bc5895d852a/sphinx_rtd_theme-3.0.2-py2.py3-none-any.whl", hash = "sha256:422ccc750c3a3a311de4ae327e82affdaf59eb695ba4936538552f3b00f4ee13", size = 7655561 }, +] + +[[package]] +name = "sphinx-tabs" +version = "3.4.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "docutils" }, + { name = "pygments" }, + { name = "sphinx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/27/32/ab475e252dc2b704e82a91141fa404cdd8901a5cf34958fd22afacebfccd/sphinx-tabs-3.4.5.tar.gz", hash = "sha256:ba9d0c1e3e37aaadd4b5678449eb08176770e0fc227e769b6ce747df3ceea531", size = 16070 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/9f/4ac7dbb9f23a2ff5a10903a4f9e9f43e0ff051f63a313e989c962526e305/sphinx_tabs-3.4.5-py3-none-any.whl", hash = "sha256:92cc9473e2ecf1828ca3f6617d0efc0aa8acb06b08c56ba29d1413f2f0f6cf09", size = 9904 }, +] + +[[package]] +name = "sphinx-toolbox" +version = "3.8.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "apeye" }, + { name = 
"autodocsumm" }, + { name = "beautifulsoup4" }, + { name = "cachecontrol", extra = ["filecache"] }, + { name = "dict2css" }, + { name = "docutils" }, + { name = "domdf-python-tools" }, + { name = "filelock" }, + { name = "html5lib" }, + { name = "ruamel-yaml" }, + { name = "sphinx" }, + { name = "sphinx-autodoc-typehints" }, + { name = "sphinx-jinja2-compat" }, + { name = "sphinx-prompt" }, + { name = "sphinx-tabs" }, + { name = "tabulate" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/30/80/f837e85c8c216cdeef9b60393e4b00c9092a1e3d734106e0021abbf5930c/sphinx_toolbox-3.8.1.tar.gz", hash = "sha256:a4b39a6ea24fc8f10e24f052199bda17837a0bf4c54163a56f521552395f5e1a", size = 111977 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/d6/2a28ee4cbc158ae65afb2cfcb6895ef54d972ce1e167f8a63c135b14b080/sphinx_toolbox-3.8.1-py3-none-any.whl", hash = "sha256:53d8e77dd79e807d9ef18590c4b2960a5aa3c147415054b04c31a91afed8b88b", size = 194621 }, +] + +[[package]] +name = "sphinxcontrib-applehelp" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/6e/b837e84a1a704953c62ef8776d45c3e8d759876b4a84fe14eba2859106fe/sphinxcontrib_applehelp-2.0.0.tar.gz", hash = "sha256:2f29ef331735ce958efa4734873f084941970894c6090408b079c61b2e1c06d1", size = 20053 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/85/9ebeae2f76e9e77b952f4b274c27238156eae7979c5421fba91a28f4970d/sphinxcontrib_applehelp-2.0.0-py3-none-any.whl", hash = "sha256:4cd3f0ec4ac5dd9c17ec65e9ab272c9b867ea77425228e68ecf08d6b28ddbdb5", size = 119300 }, +] + +[[package]] +name = "sphinxcontrib-devhelp" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f6/d2/5beee64d3e4e747f316bae86b55943f51e82bb86ecd325883ef65741e7da/sphinxcontrib_devhelp-2.0.0.tar.gz", hash = 
"sha256:411f5d96d445d1d73bb5d52133377b4248ec79db5c793ce7dbe59e074b4dd1ad", size = 12967 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/35/7a/987e583882f985fe4d7323774889ec58049171828b58c2217e7f79cdf44e/sphinxcontrib_devhelp-2.0.0-py3-none-any.whl", hash = "sha256:aefb8b83854e4b0998877524d1029fd3e6879210422ee3780459e28a1f03a8a2", size = 82530 }, +] + +[[package]] +name = "sphinxcontrib-htmlhelp" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/93/983afd9aa001e5201eab16b5a444ed5b9b0a7a010541e0ddfbbfd0b2470c/sphinxcontrib_htmlhelp-2.1.0.tar.gz", hash = "sha256:c9e2916ace8aad64cc13a0d233ee22317f2b9025b9cf3295249fa985cc7082e9", size = 22617 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0a/7b/18a8c0bcec9182c05a0b3ec2a776bba4ead82750a55ff798e8d406dae604/sphinxcontrib_htmlhelp-2.1.0-py3-none-any.whl", hash = "sha256:166759820b47002d22914d64a075ce08f4c46818e17cfc9470a9786b759b19f8", size = 98705 }, +] + +[[package]] +name = "sphinxcontrib-jquery" +version = "4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "sphinx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/de/f3/aa67467e051df70a6330fe7770894b3e4f09436dea6881ae0b4f3d87cad8/sphinxcontrib-jquery-4.1.tar.gz", hash = "sha256:1620739f04e36a2c779f1a131a2dfd49b2fd07351bf1968ced074365933abc7a", size = 122331 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/85/749bd22d1a68db7291c89e2ebca53f4306c3f205853cf31e9de279034c3c/sphinxcontrib_jquery-4.1-py2.py3-none-any.whl", hash = "sha256:f936030d7d0147dd026a4f2b5a57343d233f1fc7b363f68b3d4f1cb0993878ae", size = 121104 }, +] + +[[package]] +name = "sphinxcontrib-jsmath" +version = "1.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/e8/9ed3830aeed71f17c026a07a5097edcf44b692850ef215b161b8ad875729/sphinxcontrib-jsmath-1.0.1.tar.gz", 
hash = "sha256:a9925e4a4587247ed2191a22df5f6970656cb8ca2bd6284309578f2153e0c4b8", size = 5787 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/42/4c8646762ee83602e3fb3fbe774c2fac12f317deb0b5dbeeedd2d3ba4b77/sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl", hash = "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178", size = 5071 }, +] + +[[package]] +name = "sphinxcontrib-qthelp" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/68/bc/9104308fc285eb3e0b31b67688235db556cd5b0ef31d96f30e45f2e51cae/sphinxcontrib_qthelp-2.0.0.tar.gz", hash = "sha256:4fe7d0ac8fc171045be623aba3e2a8f613f8682731f9153bb2e40ece16b9bbab", size = 17165 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/27/83/859ecdd180cacc13b1f7e857abf8582a64552ea7a061057a6c716e790fce/sphinxcontrib_qthelp-2.0.0-py3-none-any.whl", hash = "sha256:b18a828cdba941ccd6ee8445dbe72ffa3ef8cbe7505d8cd1fa0d42d3f2d5f3eb", size = 88743 }, +] + +[[package]] +name = "sphinxcontrib-serializinghtml" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3b/44/6716b257b0aa6bfd51a1b31665d1c205fb12cb5ad56de752dfa15657de2f/sphinxcontrib_serializinghtml-2.0.0.tar.gz", hash = "sha256:e9d912827f872c029017a53f0ef2180b327c3f7fd23c87229f7a8e8b70031d4d", size = 16080 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/a7/d2782e4e3f77c8450f727ba74a8f12756d5ba823d81b941f1b04da9d033a/sphinxcontrib_serializinghtml-2.0.0-py3-none-any.whl", hash = "sha256:6e2cb0eef194e10c27ec0023bfeb25badbbb5868244cf5bc5bdc04e4464bf331", size = 92072 }, +] + +[[package]] +name = "stack-data" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "asttokens" }, + { name = "executing" }, + { name = "pure-eval" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/28/e3/55dcc2cfbc3ca9c29519eb6884dd1415ecb53b0e934862d3559ddcb7e20b/stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9", size = 44707 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521 }, +] + +[[package]] +name = "sympy" +version = "1.13.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mpmath" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/11/8a/5a7fd6284fa8caac23a26c9ddf9c30485a48169344b4bd3b0f02fef1890f/sympy-1.13.3.tar.gz", hash = "sha256:b27fd2c6530e0ab39e275fc9b683895367e51d5da91baa8d3d64db2565fec4d9", size = 7533196 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/ff/c87e0622b1dadea79d2fb0b25ade9ed98954c9033722eb707053d310d4f3/sympy-1.13.3-py3-none-any.whl", hash = "sha256:54612cf55a62755ee71824ce692986f23c88ffa77207b30c1368eda4a7060f73", size = 6189483 }, +] + +[[package]] +name = "tabulate" +version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/fe/802052aecb21e3797b8f7902564ab6ea0d60ff8ca23952079064155d1ae1/tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c", size = 81090 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252 }, +] + +[[package]] +name = "tach" +version = "0.24.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gitpython" }, + { name = "networkx" }, + { name = "prompt-toolkit" }, + { name = "pydot" }, + { 
name = "pyyaml" }, + { name = "rich" }, + { name = "tomli" }, + { name = "tomli-w" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/05/2c/1afb1a3c16125b9cfc5a1da79ba2329dec11e16b9c9eea7ac411074a49cb/tach-0.24.1.tar.gz", hash = "sha256:63f7f3b3e3458a97ded020b524f32fc72bc731ff880d0709301b2802ff759721", size = 490250 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/b3/2af242caa456cd48c83ed8a3872c8eabe9d616d556ea52c1b39835f661c3/tach-0.24.1-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:b965048c4918bd8d24d54a8a7a232bf6b210c1dd0c97caed83ac2f8db271db45", size = 3403749 }, + { url = "https://files.pythonhosted.org/packages/d1/2d/a64f5a9b0674527cc6c95fba681d7d53652f0cc092ce3d768e11409c3378/tach-0.24.1-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:bd203c8a581c6cf1f3813d5eeacd612bdb0c2681939677b33cc7d555d9216ff0", size = 3252234 }, + { url = "https://files.pythonhosted.org/packages/41/36/627ef905e792a0a281ce416581eae33e963b7dda5023460fd81ea0ab944e/tach-0.24.1-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73230ce1af9be01b08e42bd6002344562a5e51942b806869e0c3d784a38ae117", size = 3537522 }, + { url = "https://files.pythonhosted.org/packages/cc/90/d79c0cbfcae6f91b9c3cf5f2c077786057fcd59a4ca06608a3df1c072b3b/tach-0.24.1-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb982d6ead606ead1ca2d5decf1aa10414d6eecdded92de9755940acb18fd1df", size = 3497754 }, + { url = "https://files.pythonhosted.org/packages/77/5b/07fb1554509539cd4a2582a24b49ff3961cdb39cfe064429c8fd7b4fc9cb/tach-0.24.1-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50a56b14fcb8d311d07ac49fdec1a6619b4644b991112c17e894838827f198bb", size = 3814772 }, + { url = "https://files.pythonhosted.org/packages/36/33/1c9b051aada11d4171ba4a64cb537f1f95bc6d093cfae4d235bb0124813a/tach-0.24.1-cp37-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = 
"sha256:3756ea8fdd7ffeaaa4c2bb272ff3c407f51e7c83d8108ecc28f4acdcb11f5bd4", size = 3789273 }, + { url = "https://files.pythonhosted.org/packages/96/d8/6b3f624d5fa7db9a43e29887b643ae4c560127764e94aea93a4ec51a87e4/tach-0.24.1-cp37-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8f278a930651e7cafb5b2b8fd398cfc0ac205f9c81e618aad1d5bedcce86217d", size = 4057183 }, + { url = "https://files.pythonhosted.org/packages/b3/63/bd8028d67f36f4a35acbed746eb822be8825c1cc02eb990c780ad24877ee/tach-0.24.1-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6eb884a8936d9910d2d8675ad04726ecfba7ac830e09c2463acd561250f507e", size = 3655117 }, + { url = "https://files.pythonhosted.org/packages/6a/be/4a8ff273365dbafe2414665d81bb7416e0ed76b836ebfa6e5aa92ab579f9/tach-0.24.1-cp37-abi3-win32.whl", hash = "sha256:7d5db6480ea33ee95f023d9882b1d67863fb06eb802e97948d5b6c7b0a56bb39", size = 2857513 }, + { url = "https://files.pythonhosted.org/packages/8e/1a/92e7b283147e27750d1485fbe6bd595c64d9d8d017104971175bd82d4072/tach-0.24.1-cp37-abi3-win_amd64.whl", hash = "sha256:4e321f45a1457da49e9aab2f11630907776b0031e78242a80650b27413cb925c", size = 3071088 }, +] + +[[package]] +name = "tomli" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/18/87/302344fed471e44a87289cf4967697d07e532f2421fdaf868a303cbae4ff/tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff", size = 17175 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/ca/75707e6efa2b37c77dadb324ae7d9571cb424e61ea73fad7c56c2d14527f/tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249", size = 131077 }, + { url = "https://files.pythonhosted.org/packages/c7/16/51ae563a8615d472fdbffc43a3f3d46588c264ac4f024f63f01283becfbb/tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6", size = 123429 }, + { url = "https://files.pythonhosted.org/packages/f1/dd/4f6cd1e7b160041db83c694abc78e100473c15d54620083dbd5aae7b990e/tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a", size = 226067 }, + { url = "https://files.pythonhosted.org/packages/a9/6b/c54ede5dc70d648cc6361eaf429304b02f2871a345bbdd51e993d6cdf550/tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee", size = 236030 }, + { url = "https://files.pythonhosted.org/packages/1f/47/999514fa49cfaf7a92c805a86c3c43f4215621855d151b61c602abb38091/tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e", size = 240898 }, + { url = "https://files.pythonhosted.org/packages/73/41/0a01279a7ae09ee1573b423318e7934674ce06eb33f50936655071d81a24/tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4", size = 229894 }, + { url = "https://files.pythonhosted.org/packages/55/18/5d8bc5b0a0362311ce4d18830a5d28943667599a60d20118074ea1b01bb7/tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106", size = 245319 }, + { url = "https://files.pythonhosted.org/packages/92/a3/7ade0576d17f3cdf5ff44d61390d4b3febb8a9fc2b480c75c47ea048c646/tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8", size = 238273 }, + { url = "https://files.pythonhosted.org/packages/72/6f/fa64ef058ac1446a1e51110c375339b3ec6be245af9d14c87c4a6412dd32/tomli-2.2.1-cp311-cp311-win32.whl", hash = 
"sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff", size = 98310 }, + { url = "https://files.pythonhosted.org/packages/6a/1c/4a2dcde4a51b81be3530565e92eda625d94dafb46dbeb15069df4caffc34/tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b", size = 108309 }, + { url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 }, +] + +[[package]] +name = "tomli-w" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/19/75/241269d1da26b624c0d5e110e8149093c759b7a286138f4efd61a60e75fe/tomli_w-1.2.0.tar.gz", hash = "sha256:2dd14fac5a47c27be9cd4c976af5a12d87fb1f0b4512f81d69cce3b35ae25021", size = 7184 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/18/c86eb8e0202e32dd3df50d43d7ff9854f8e0603945ff398974c1d91ac1ef/tomli_w-1.2.0-py3-none-any.whl", hash = "sha256:188306098d013b691fcadc011abd66727d3c414c571bb01b1a174ba8c983cf90", size = 6675 }, +] + +[[package]] +name = "tomlkit" +version = "0.13.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b1/09/a439bec5888f00a54b8b9f05fa94d7f901d6735ef4e55dcec9bc37b5d8fa/tomlkit-0.13.2.tar.gz", hash = "sha256:fff5fe59a87295b278abd31bec92c15d9bc4a06885ab12bcea52c71119392e79", size = 192885 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/b6/a447b5e4ec71e13871be01ba81f5dfc9d0af7e473da256ff46bc0e24026f/tomlkit-0.13.2-py3-none-any.whl", hash = "sha256:7a974427f6e119197f670fbbbeae7bef749a6c14e793db934baefc1b5f03efde", size = 37955 }, +] + +[[package]] +name = "toolz" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/8a/0b/d80dfa675bf592f636d1ea0b835eab4ec8df6e9415d8cfd766df54456123/toolz-1.0.0.tar.gz", hash = "sha256:2c86e3d9a04798ac556793bced838816296a2f085017664e4995cb40a1047a02", size = 66790 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/03/98/eb27cc78ad3af8e302c9d8ff4977f5026676e130d28dd7578132a457170c/toolz-1.0.0-py3-none-any.whl", hash = "sha256:292c8f1c4e7516bf9086f8850935c799a874039c8bcf959d47b600e4c44a6236", size = 56383 }, +] + +[[package]] +name = "tornado" +version = "6.4.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/59/45/a0daf161f7d6f36c3ea5fc0c2de619746cc3dd4c76402e9db545bd920f63/tornado-6.4.2.tar.gz", hash = "sha256:92bad5b4746e9879fd7bf1eb21dce4e3fc5128d71601f80005afa39237ad620b", size = 501135 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/26/7e/71f604d8cea1b58f82ba3590290b66da1e72d840aeb37e0d5f7291bd30db/tornado-6.4.2-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e828cce1123e9e44ae2a50a9de3055497ab1d0aeb440c5ac23064d9e44880da1", size = 436299 }, + { url = "https://files.pythonhosted.org/packages/96/44/87543a3b99016d0bf54fdaab30d24bf0af2e848f1d13d34a3a5380aabe16/tornado-6.4.2-cp38-abi3-macosx_10_9_x86_64.whl", hash = "sha256:072ce12ada169c5b00b7d92a99ba089447ccc993ea2143c9ede887e0937aa803", size = 434253 }, + { url = "https://files.pythonhosted.org/packages/cb/fb/fdf679b4ce51bcb7210801ef4f11fdac96e9885daa402861751353beea6e/tornado-6.4.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a017d239bd1bb0919f72af256a970624241f070496635784d9bf0db640d3fec", size = 437602 }, + { url = "https://files.pythonhosted.org/packages/4f/3b/e31aeffffc22b475a64dbeb273026a21b5b566f74dee48742817626c47dc/tornado-6.4.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c36e62ce8f63409301537222faffcef7dfc5284f27eec227389f2ad11b09d946", size = 436972 }, + 
{ url = "https://files.pythonhosted.org/packages/22/55/b78a464de78051a30599ceb6983b01d8f732e6f69bf37b4ed07f642ac0fc/tornado-6.4.2-cp38-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bca9eb02196e789c9cb5c3c7c0f04fb447dc2adffd95265b2c7223a8a615ccbf", size = 437173 }, + { url = "https://files.pythonhosted.org/packages/79/5e/be4fb0d1684eb822c9a62fb18a3e44a06188f78aa466b2ad991d2ee31104/tornado-6.4.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:304463bd0772442ff4d0f5149c6f1c2135a1fae045adf070821c6cdc76980634", size = 437892 }, + { url = "https://files.pythonhosted.org/packages/f5/33/4f91fdd94ea36e1d796147003b490fe60a0215ac5737b6f9c65e160d4fe0/tornado-6.4.2-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:c82c46813ba483a385ab2a99caeaedf92585a1f90defb5693351fa7e4ea0bf73", size = 437334 }, + { url = "https://files.pythonhosted.org/packages/2b/ae/c1b22d4524b0e10da2f29a176fb2890386f7bd1f63aacf186444873a88a0/tornado-6.4.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:932d195ca9015956fa502c6b56af9eb06106140d844a335590c1ec7f5277d10c", size = 437261 }, + { url = "https://files.pythonhosted.org/packages/b5/25/36dbd49ab6d179bcfc4c6c093a51795a4f3bed380543a8242ac3517a1751/tornado-6.4.2-cp38-abi3-win32.whl", hash = "sha256:2876cef82e6c5978fde1e0d5b1f919d756968d5b4282418f3146b79b58556482", size = 438463 }, + { url = "https://files.pythonhosted.org/packages/61/cc/58b1adeb1bb46228442081e746fcdbc4540905c87e8add7c277540934edb/tornado-6.4.2-cp38-abi3-win_amd64.whl", hash = "sha256:908b71bf3ff37d81073356a5fadcc660eb10c1476ee6e2725588626ce7e5ca38", size = 438907 }, +] + +[[package]] +name = "traitlets" +version = "5.14.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/eb/79/72064e6a701c2183016abbbfedaba506d81e30e232a68c9f0d6f6fcd1574/traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7", size = 161621 } 
+wheels = [ + { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359 }, +] + +[[package]] +name = "types-decorator" +version = "5.1.8.20250121" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f4/e6/88de14bb1d1073495b9d9459f90fbb78fe93d89beefcf0af94b871993a56/types_decorator-5.1.8.20250121.tar.gz", hash = "sha256:1b89bb1c481a1d3399e28f1aa3459366b76dde951490992ae8475ba91287cd04", size = 8496 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/0e/59b9637fa66fbe419886b17d59b90e5e4256325c01f94f81dcc44fbeda53/types_decorator-5.1.8.20250121-py3-none-any.whl", hash = "sha256:6bfd5f4464f444a1ee0aea92705ed8466d74c0ddd7ade4bbd003c235db51d21a", size = 8078 }, +] + +[[package]] +name = "types-docutils" +version = "0.21.0.20241128" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/dd/df/64e7ab01a4fc5ce46895dc94e31cffc8b8087c8d91ee54c45ac2d8d82445/types_docutils-0.21.0.20241128.tar.gz", hash = "sha256:4dd059805b83ac6ec5a223699195c4e9eeb0446a4f7f2aeff1759a4a7cc17473", size = 26739 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/59/b6/10ba95739f2cbb9c5bd2f6568148d62b468afe01a94c633e8892a2936d8a/types_docutils-0.21.0.20241128-py3-none-any.whl", hash = "sha256:e0409204009639e9b0bf4521eeabe58b5e574ce9c0db08421c2ac26c32be0039", size = 34677 }, +] + +[[package]] +name = "types-pytz" +version = "2025.1.0.20250204" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b3/d2/2190c54d53c04491ad72a1df019c5dfa692e6ab6c2dba1be7b6c9d530e30/types_pytz-2025.1.0.20250204.tar.gz", hash = "sha256:00f750132769f1c65a4f7240bc84f13985b4da774bd17dfbe5d9cd442746bd49", size = 10352 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/be/50/65ffad73746f1d8b15992c030e0fd22965fd5ae2c0206dc28873343b3230/types_pytz-2025.1.0.20250204-py3-none-any.whl", hash = "sha256:32ca4a35430e8b94f6603b35beb7f56c32260ddddd4f4bb305fdf8f92358b87e", size = 10059 }, +] + +[[package]] +name = "types-pyyaml" +version = "6.0.12.20241230" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/f9/4d566925bcf9396136c0a2e5dc7e230ff08d86fa011a69888dd184469d80/types_pyyaml-6.0.12.20241230.tar.gz", hash = "sha256:7f07622dbd34bb9c8b264fe860a17e0efcad00d50b5f27e93984909d9363498c", size = 17078 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e8/c1/48474fbead512b70ccdb4f81ba5eb4a58f69d100ba19f17c92c0c4f50ae6/types_PyYAML-6.0.12.20241230-py3-none-any.whl", hash = "sha256:fa4d32565219b68e6dee5f67534c722e53c00d1cfc09c435ef04d7353e1e96e6", size = 20029 }, +] + +[[package]] +name = "types-tabulate" +version = "0.9.0.20241207" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/3f/43/16030404a327e4ff8c692f2273854019ed36718667b2993609dc37d14dd4/types_tabulate-0.9.0.20241207.tar.gz", hash = "sha256:ac1ac174750c0a385dfd248edc6279fa328aaf4ea317915ab879a2ec47833230", size = 8195 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5e/86/a9ebfd509cbe74471106dffed320e208c72537f9aeb0a55eaa6b1b5e4d17/types_tabulate-0.9.0.20241207-py3-none-any.whl", hash = "sha256:b8dad1343c2a8ba5861c5441370c3e35908edd234ff036d4298708a1d4cf8a85", size = 8307 }, +] + +[[package]] +name = "typing-extensions" +version = "4.12.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/df/db/f35a00659bc03fec321ba8bce9420de607a1d37f8342eee1863174c69557/typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8", size = 85321 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d", size = 37438 }, +] + +[[package]] +name = "urllib3" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/aa/63/e53da845320b757bf29ef6a9062f5c669fe997973f966045cb019c3f4b66/urllib3-2.3.0.tar.gz", hash = "sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d", size = 307268 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 }, +] + +[[package]] +name = "virtualenv" +version = "20.29.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "distlib" }, + { name = "filelock" }, + { name = "platformdirs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a7/ca/f23dcb02e161a9bba141b1c08aa50e8da6ea25e6d780528f1d385a3efe25/virtualenv-20.29.1.tar.gz", hash = "sha256:b8b8970138d32fb606192cb97f6cd4bb644fa486be9308fb9b63f81091b5dc35", size = 7658028 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/89/9b/599bcfc7064fbe5740919e78c5df18e5dceb0887e676256a1061bb5ae232/virtualenv-20.29.1-py3-none-any.whl", hash = "sha256:4e4cb403c0b0da39e13b46b1b2476e505cb0046b25f242bee80f62bf990b2779", size = 4282379 }, +] + +[[package]] +name = "wcmatch" +version = "10.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "bracex" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/41/ab/b3a52228538ccb983653c446c1656eddf1d5303b9cb8b9aef6a91299f862/wcmatch-10.0.tar.gz", hash = "sha256:e72f0de09bba6a04e0de70937b0cf06e55f36f37b3deb422dfaf854b867b840a", size = 115578 } +wheels = [ 
+ { url = "https://files.pythonhosted.org/packages/ab/df/4ee467ab39cc1de4b852c212c1ed3becfec2e486a51ac1ce0091f85f38d7/wcmatch-10.0-py3-none-any.whl", hash = "sha256:0dd927072d03c0a6527a20d2e6ad5ba8d0380e60870c383bc533b71744df7b7a", size = 39347 }, +] + +[[package]] +name = "wcwidth" +version = "0.2.13" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/63/53559446a878410fc5a5974feb13d31d78d752eb18aeba59c7fef1af7598/wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5", size = 101301 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 }, +] + +[[package]] +name = "webencodings" +version = "0.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/02/ae6ceac1baeda530866a85075641cec12989bd8d31af6d5ab4a3e8c92f47/webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923", size = 9721 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", size = 11774 }, +] + +[[package]] +name = "wheel" +version = "0.45.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8a/98/2d9906746cdc6a6ef809ae6338005b3f21bb568bea3165cfc6a243fdc25c/wheel-0.45.1.tar.gz", hash = "sha256:661e1abd9198507b1409a20c02106d9670b2576e916d58f520316666abca6729", size = 107545 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/0b/2c/87f3254fd8ffd29e4c02732eee68a83a1d3c346ae39bc6822dcbcb697f2b/wheel-0.45.1-py3-none-any.whl", hash = "sha256:708e7481cc80179af0e556bbf0cc00b8444c7321e2700b8d8580231d13017248", size = 72494 }, +] + +[[package]] +name = "xxhash" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/3e/ca49932bade8b3308e74df951c36cbc84c8230c9b8715bae1e0014831aa7/xxhash-3.0.0.tar.gz", hash = "sha256:30b2d97aaf11fb122023f6b44ebb97c6955e9e00d7461a96415ca030b5ceb9c7", size = 74279 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f0/fe/41444c518df82da46bc7125c9daa4159e6cfc2b682ccc73493b0485b8a70/xxhash-3.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:219cba13991fd73cf21a5efdafa5056f0ae0b8f79e5e0112967e3058daf73eea", size = 34110 }, + { url = "https://files.pythonhosted.org/packages/6f/83/0afffed636656f65f78e35da174c9bdd86367f9d4da23a87fc9d1b933bbe/xxhash-3.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3fcbb846af15eff100c412ae54f4974ff277c92eacd41f1ec7803a64fd07fa0c", size = 30664 }, + { url = "https://files.pythonhosted.org/packages/8c/b1/cde24bf3c9d4d6bbe02e9e82604dbd40ab21c9799b0fdb66a4fe2046e96d/xxhash-3.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f475fa817ff7955fc118fc1ca29a6e691d329b7ff43f486af36c22dbdcff1db", size = 241825 }, + { url = "https://files.pythonhosted.org/packages/70/fd/7ebfe1549551c87875b64cf9c925e3cf8be53e475d29aed933643f6dd8aa/xxhash-3.0.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9200a90f02ff6fd5fb63dea107842da71d8626d99b768fd31be44f3002c60bbe", size = 206492 }, + { url = "https://files.pythonhosted.org/packages/d2/6f/eafbb4ec3baf499423f2de3a5f3b6c5898f3bf4a8714e100d5dfb911fbad/xxhash-3.0.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a1403e4f551c9ef7bcef09af55f1adb169f13e4de253db0887928e5129f87af1", size = 
286394 }, + { url = "https://files.pythonhosted.org/packages/64/05/504e1a7accc8f115ebfba96104c2f4a4aea3fb415bd664a6a1cc8915671e/xxhash-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa7f6ca53170189a2268c83af0980e6c10aae69e6a5efa7ca989f89fff9f8c02", size = 211550 }, + { url = "https://files.pythonhosted.org/packages/f8/b9/b6558ba62479dbdd18f894842f6ec01bbbf94aa8a26340f889c1af550fa8/xxhash-3.0.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5b63fbeb6d9c93d50ae0dc2b8a8b7f52f2de19e40fe9edc86637bfa5743b8ba2", size = 219718 }, + { url = "https://files.pythonhosted.org/packages/19/7a/270f9c47d9748b7d43ec2ce0ee1d50c189ccf21e7ba6adc39e4045fcd450/xxhash-3.0.0-cp310-cp310-win32.whl", hash = "sha256:31f25efd10b6f1f6d5c34cd231986d8aae9a42e042daa90b783917f170807869", size = 30157 }, + { url = "https://files.pythonhosted.org/packages/67/54/f98d6eccb96da4fc51f4397123828c593c6f2731ede141f2318d1aab8a6b/xxhash-3.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:807e88ed56e0fb347cb57d5bf44851f9878360fed700f2f63e622ef4eede87a5", size = 29918 }, +]