From 4e2ed067d53c8f03242807bf77f9dd43b3ddd0a7 Mon Sep 17 00:00:00 2001 From: etorreborre Date: Wed, 12 Jun 2024 14:10:23 +0200 Subject: [PATCH] feat(rust): use the any driver for sqlx to add support for postgres --- .github/workflows/rust.yml | 429 +++++++++------- Cargo.lock | 23 +- NOTICE.md | 16 +- implementations/rust/Makefile | 7 + .../rust/ockam/ockam_abac/Cargo.toml | 2 +- .../storage/resource_policy_repository_sql.rs | 56 +- .../policy/storage/resource_repository_sql.rs | 48 +- .../resource_type_policy_repository_sql.rs | 91 +++- .../rust/ockam/ockam_abac/src/types.rs | 4 + .../rust/ockam/ockam_api/Cargo.toml | 2 +- .../direct/direct_authenticator.rs | 7 +- .../src/authenticator/one_time_code.rs | 7 - ...thority_enrollment_token_repository_sql.rs | 408 ++++++++------- .../authenticator/storage/authority_member.rs | 7 +- .../authority_members_repository_sql.rs | 294 ++++++----- .../authenticator/storage/enrollment_token.rs | 5 +- .../ockam_api/src/authority_node/authority.rs | 18 +- .../src/authority_node/configuration.rs | 5 +- .../ockam_api/src/cli_state/cli_state.rs | 126 +++-- .../ockam_api/src/cli_state/identities.rs | 4 +- .../ockam/ockam_api/src/cli_state/nodes.rs | 7 +- .../storage/enrollments_repository_sql.rs | 42 +- .../storage/identities_repository_sql.rs | 292 ++++++----- .../storage/journeys_repository_sql.rs | 383 +++++++------- .../cli_state/storage/nodes_repository_sql.rs | 334 ++++++------ .../storage/projects_repository_sql.rs | 395 ++++++++------ .../storage/spaces_repository_sql.rs | 193 +++---- .../storage/tcp_portals_repository_sql.rs | 140 ++--- .../cli_state/storage/users_repository_sql.rs | 164 +++--- .../cli_state/storage/vaults_repository.rs | 14 +- .../storage/vaults_repository_sql.rs | 190 ++++--- .../ockam_api/src/cli_state/test_support.rs | 20 +- .../ockam/ockam_api/src/cli_state/vaults.rs | 486 +++++++++++------- .../ockam_api/src/kafka/integration_test.rs | 3 +- .../ockam_api/src/kafka/portal_worker.rs | 8 +- 
.../src/kafka/protocol_aware/tests.rs | 2 + .../src/nodes/service/in_memory_node.rs | 9 +- .../ockam/ockam_api/src/test_utils/mod.rs | 19 +- .../ockam/ockam_api/src/ui/output/utils.rs | 15 + .../ockam/ockam_api/tests/common/common.rs | 6 +- .../rust/ockam/ockam_api/tests/latency.rs | 2 + .../rust/ockam/ockam_api/tests/portals.rs | 5 + .../rust/ockam/ockam_app_lib/Cargo.toml | 2 +- .../src/state/model_state_repository_sql.rs | 185 +++---- .../ockam_command/src/authority/create.rs | 2 +- .../src/environment/static/env_info.txt | 14 + .../ockam/ockam_command/src/vault/create.rs | 15 +- .../ockam/ockam_command/src/vault/util.rs | 109 ++-- .../ockam_command/tests/bats/local/jq.bats | 2 +- .../ockam_command/tests/bats/local/vault.bats | 8 +- .../rust/ockam/ockam_identity/Cargo.toml | 2 +- .../storage/change_history_repository_sql.rs | 178 ++++--- .../storage/credential_repository_sql.rs | 204 +++++--- .../identity_attributes_repository_sql.rs | 247 +++++---- .../ockam_identity/src/models/identifiers.rs | 1 - .../storage/purpose_keys_repository_sql.rs | 71 ++- .../storage/secure_channel_repository_sql.rs | 27 +- .../ockam/ockam_identity/tests/persistence.rs | 9 +- .../rust/ockam/ockam_node/Cargo.toml | 8 +- .../rust/ockam/ockam_node/src/lib.rs | 1 + .../database/database_configuration.rs | 198 +++++++ .../application_migration_set.rs | 29 +- .../20240613110000_project_journey.sql | 12 + .../20241701150000_project_journey.sql | 0 .../20242102180000_time_limited_journey.sql | 0 .../migrations/migration_support/migrator.rs | 70 +-- .../migration_support/rust_migration.rs | 4 +- .../migrations/node_migrations/mod.rs | 5 +- .../node_migrations/node_migration_set.rs | 59 ++- .../migrations/node_migrations/rust.rs | 16 +- ...231100000_node_name_identity_attributes.rs | 83 +-- ...ion_20240111100001_add_authority_tables.rs | 94 ++-- ...ion_20240111100002_delete_trust_context.rs | 54 +- ...migration_20240212100000_split_policies.rs | 38 +- 
..._20240313100000_remove_orphan_resources.rs | 52 +- ...0240503100000_update_policy_expressions.rs | 47 +- .../node_migrations/rust/sqlite/mod.rs | 14 + .../20240613100000_create_database.sql | 359 +++++++++++++ .../20231006100000_create_database.sql | 0 .../20231230100000_add_rust_migrations.sql | 0 ...31100000_node_name_identity_attributes.sql | 0 ...08100000_rename_confluent_config_table.sql | 0 .../20240111100001_add_authority_tables.sql | 0 .../20240111100002_delete_trust_context.sql | 0 .../20240111100003_add_credential.sql | 0 .../20240212100000_split_policies.sql | 0 .../20240212100001_outlet_remove_alias.sql | 0 .../20240213100000_add_controller_history.sql | 0 ...3100001_add_enrollment_token_reference.sql | 0 .../20240214100000_extend_project.sql | 0 .../20240307100000_credential_add_scope.sql | 0 .../20240314150000_tcp_portals.sql | 0 .../20240321100000_add_enrollment_email.sql | 0 .../20240507100000_node_add_http_column.sql | 0 .../20240527100000_add_sc_persistence.sql | 0 .../sqlite/20240619100000_database_vault.sql | 29 ++ .../ockam_node/src/storage/database/mod.rs | 6 +- .../src/storage/database/sqlx_database.rs | 337 ++++++++++-- .../storage/database/sqlx_from_row_types.rs | 104 ++++ .../src/storage/database/sqlx_types.rs | 206 -------- .../rust/ockam/ockam_vault/Cargo.toml | 2 +- .../src/software/vault_for_signing/types.rs | 10 + .../src/storage/secrets_repository_sql.rs | 152 ++++-- 103 files changed, 4455 insertions(+), 2898 deletions(-) create mode 100644 implementations/rust/ockam/ockam_node/src/storage/database/database_configuration.rs create mode 100644 implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/sql/postgres/20240613110000_project_journey.sql rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/sql/{ => sqlite}/20241701150000_project_journey.sql (100%) rename 
implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/sql/{ => sqlite}/20242102180000_time_limited_journey.sql (100%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/{ => sqlite}/migration_20231231100000_node_name_identity_attributes.rs (72%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/{ => sqlite}/migration_20240111100001_add_authority_tables.rs (73%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/{ => sqlite}/migration_20240111100002_delete_trust_context.rs (89%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/{ => sqlite}/migration_20240212100000_split_policies.rs (83%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/{ => sqlite}/migration_20240313100000_remove_orphan_resources.rs (78%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/{ => sqlite}/migration_20240503100000_update_policy_expressions.rs (75%) create mode 100644 implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/mod.rs create mode 100644 implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/postgres/20240613100000_create_database.sql rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/{ => sqlite}/20231006100000_create_database.sql (100%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/{ => sqlite}/20231230100000_add_rust_migrations.sql (100%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/{ => sqlite}/20231231100000_node_name_identity_attributes.sql (100%) rename 
implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/{ => sqlite}/20240108100000_rename_confluent_config_table.sql (100%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/{ => sqlite}/20240111100001_add_authority_tables.sql (100%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/{ => sqlite}/20240111100002_delete_trust_context.sql (100%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/{ => sqlite}/20240111100003_add_credential.sql (100%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/{ => sqlite}/20240212100000_split_policies.sql (100%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/{ => sqlite}/20240212100001_outlet_remove_alias.sql (100%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/{ => sqlite}/20240213100000_add_controller_history.sql (100%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/{ => sqlite}/20240213100001_add_enrollment_token_reference.sql (100%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/{ => sqlite}/20240214100000_extend_project.sql (100%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/{ => sqlite}/20240307100000_credential_add_scope.sql (100%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/{ => sqlite}/20240314150000_tcp_portals.sql (100%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/{ => sqlite}/20240321100000_add_enrollment_email.sql (100%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/{ => 
sqlite}/20240507100000_node_add_http_column.sql (100%) rename implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/{ => sqlite}/20240527100000_add_sc_persistence.sql (100%) create mode 100644 implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240619100000_database_vault.sql create mode 100644 implementations/rust/ockam/ockam_node/src/storage/database/sqlx_from_row_types.rs delete mode 100644 implementations/rust/ockam/ockam_node/src/storage/database/sqlx_types.rs diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index bd8f714b4e0..d424d0af51d 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -7,33 +7,33 @@ on: merge_group: pull_request: paths: - - ".github/workflows/rust.yml" - - ".github/actions/**" - - "**.rs" - - "**.toml" - - "**/Cargo.lock" - - "implementations/rust/ockam/ockam_command/tests/**" - - "**/Makefile" - - "tools/nix/**" + - ".github/workflows/rust.yml" + - ".github/actions/**" + - "**.rs" + - "**.toml" + - "**/Cargo.lock" + - "implementations/rust/ockam/ockam_command/tests/**" + - "**/Makefile" + - "tools/nix/**" push: paths: - - ".github/workflows/rust.yml" - - ".github/actions/**" - - "**.rs" - - "**.toml" - - "**/Cargo.lock" - - "implementations/rust/ockam/ockam_command/tests/**" - - "**/Makefile" - - "tools/nix/**" + - ".github/workflows/rust.yml" + - ".github/actions/**" + - "**.rs" + - "**.toml" + - "**/Cargo.lock" + - "implementations/rust/ockam/ockam_command/tests/**" + - "**/Makefile" + - "tools/nix/**" branches: - - develop + - develop schedule: - # We only save cache when a cron job is started, this is to ensure - # that we don't save cache on every push causing excessive caching - # and github deleting useful caches we use in our workflows, we now - # run a cron job every 2 hours so as to update the cache store with the - # latest data so that we don't have stale cache. 
- - cron: "0 */2 * * *" + # We only save cache when a cron job is started, this is to ensure + # that we don't save cache on every push causing excessive caching + # and github deleting useful caches we use in our workflows, we now + # run a cron job every 2 hours so as to update the cache store with the + # latest data so that we don't have stale cache. + - cron: "0 */2 * * *" workflow_dispatch: inputs: commit_sha: @@ -56,37 +56,37 @@ jobs: fail-fast: false matrix: lint_projects: - - cargo_readme - - cargo_fmt_check - - cargo_clippy - - cargo_deny - - cargo_toml_files - - cargo_machete + - cargo_readme + - cargo_fmt_check + - cargo_clippy + - cargo_deny + - cargo_toml_files + - cargo_machete defaults: run: shell: nix develop ./tools/nix#rust --keep CI --ignore-environment --command bash {0} steps: - - name: Checkout repository - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 - with: - ref: ${{ github.event.inputs.commit_sha }} + - name: Checkout repository + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 + with: + ref: ${{ github.event.inputs.commit_sha }} - - name: Install Nix - uses: ./.github/actions/cache_nix - with: - cache-unique-id: ${{ matrix.lint_projects }} - id: nix-installer + - name: Install Nix + uses: ./.github/actions/cache_nix + with: + cache-unique-id: ${{ matrix.lint_projects }} + id: nix-installer - - uses: ./.github/actions/cache_rust - with: - job_name: "${{ github.job }}-${{ matrix.lint_projects }}" + - uses: ./.github/actions/cache_rust + with: + job_name: "${{ github.job }}-${{ matrix.lint_projects }}" - - name: Run lint ${{ matrix.lint_projects }} - run: make -f implementations/rust/Makefile lint_${{ matrix.lint_projects }} + - name: Run lint ${{ matrix.lint_projects }} + run: make -f implementations/rust/Makefile lint_${{ matrix.lint_projects }} - - name: Nix Upload Store - uses: ./.github/actions/nix_upload_store - if: ${{ steps.nix-installer.outputs.cache-hit != 'true' }} + - name: Nix Upload Store + 
uses: ./.github/actions/nix_upload_store + if: ${{ steps.nix-installer.outputs.cache-hit != 'true' }} build: name: Rust - build${{ matrix.build_projects != 'packages' && format('_{0}', matrix.build_projects) || '' }} @@ -95,42 +95,42 @@ jobs: fail-fast: false matrix: include: - - build_projects: packages - make_name: 'build' - - build_projects: docs - make_name: 'build_docs' - - build_projects: examples - make_name: 'build_examples' - - build_projects: nightly - make_name: 'build' - - build_projects: release - make_name: 'build_release' + - build_projects: packages + make_name: 'build' + - build_projects: docs + make_name: 'build_docs' + - build_projects: examples + make_name: 'build_examples' + - build_projects: nightly + make_name: 'build' + - build_projects: release + make_name: 'build_release' defaults: run: shell: nix develop ./tools/nix#rust${{matrix.build_projects == 'nightly' && '_nightly' || '' }} --keep CI --ignore-environment --command bash {0} steps: - - name: Checkout repository - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 - with: - ref: ${{ github.event.inputs.commit_sha }} + - name: Checkout repository + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 + with: + ref: ${{ github.event.inputs.commit_sha }} - - name: Install Nix - uses: ./.github/actions/cache_nix - with: - cache-unique-id: ${{ matrix.build_projects }} - id: nix-installer + - name: Install Nix + uses: ./.github/actions/cache_nix + with: + cache-unique-id: ${{ matrix.build_projects }} + id: nix-installer - - uses: ./.github/actions/cache_rust - with: - job_name: "${{ github.job }}-${{ matrix.build_projects }}" + - uses: ./.github/actions/cache_rust + with: + job_name: "${{ github.job }}-${{ matrix.build_projects }}" - - name: Run build ${{ matrix.build_projects }} - run: make -f implementations/rust/Makefile ${{ matrix.make_name }} + - name: Run build ${{ matrix.build_projects }} + run: make -f implementations/rust/Makefile ${{ matrix.make_name }} - - 
name: Nix Upload Store - uses: ./.github/actions/nix_upload_store - if: ${{ steps.nix-installer.outputs.cache-hit != 'true' }} + - name: Nix Upload Store + uses: ./.github/actions/nix_upload_store + if: ${{ steps.nix-installer.outputs.cache-hit != 'true' }} @@ -149,31 +149,80 @@ jobs: fail-fast: false matrix: test_projects: - - stable - - nightly + - stable + - nightly steps: - - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 - with: - ref: ${{ github.event.inputs.commit_sha }} + - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 + with: + ref: ${{ github.event.inputs.commit_sha }} - - name: Install Nix - uses: ./.github/actions/cache_nix - with: - cache-unique-id: ${{ matrix.test_projects }} - id: nix-installer + - name: Install Nix + uses: ./.github/actions/cache_nix + with: + cache-unique-id: ${{ matrix.test_projects }} + id: nix-installer - - uses: ./.github/actions/cache_rust - with: - job_name: "${{ github.job }}-${{ matrix.test_projects }}" + - uses: ./.github/actions/cache_rust + with: + job_name: "${{ github.job }}-${{ matrix.test_projects }}" - - name: Run test on ${{ matrix.test_projects }} - run: make -f implementations/rust/Makefile test + - name: Run test on ${{ matrix.test_projects }} + run: make -f implementations/rust/Makefile test - - name: Nix Upload Store - uses: ./.github/actions/nix_upload_store - if: ${{ steps.nix-installer.outputs.cache-hit != 'true' }} + - name: Nix Upload Store + uses: ./.github/actions/nix_upload_store + if: ${{ steps.nix-installer.outputs.cache-hit != 'true' }} + test_postgres: + name: Rust - test_postgres${{ matrix.test_projects != 'stable' && format('_{0}', matrix.test_projects) || '' }} + runs-on: ubuntu-22.04 + services: + postgres: + image: postgres:latest + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: password + POSTGRES_DB: test + ports: + - 5432:5432 + options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 + defaults: + run: + shell: 
nix develop ./tools/nix#rust${{ matrix.test_projects == 'nightly' && '_nightly' || '' }} --keep CI --ignore-environment --command bash {0} + strategy: + fail-fast: false + matrix: + test_projects: + - stable + steps: + - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 + with: + ref: ${{ github.event.inputs.commit_sha }} + + - name: Install Nix + uses: ./.github/actions/cache_nix + with: + cache-unique-id: ${{ matrix.test_projects }} + id: nix-installer + + - uses: ./.github/actions/cache_rust + with: + job_name: "${{ github.job }}-${{ matrix.test_projects }}" + + - name: Run postgres test on ${{ matrix.test_projects }} + run: | + pg_ctl -D /var/lib/postgresql/data -l logfile start + export OCKAM_POSTGRES_HOST=localhost + export OCKAM_POSTGRES_PORT=5432 + export OCKAM_POSTGRES_DATABASE_NAME=test + export OCKAM_POSTGRES_USER=postgres + export OCKAM_POSTGRES_PASSWORD=password + make -f implementations/rust/Makefile test_postgres + + - name: Nix Upload Store + uses: ./.github/actions/nix_upload_store + if: ${{ steps.nix-installer.outputs.cache-hit != 'true' }} check: name: Rust - check_${{ matrix.check_projects }} @@ -182,37 +231,37 @@ jobs: fail-fast: false matrix: include: - - check_projects: cargo_update - nix_toolchain: 'rust' - - check_projects: no_std - nix_toolchain: 'rust_nightly' - - check_projects: nightly - nix_toolchain: 'rust_nightly' + - check_projects: cargo_update + nix_toolchain: 'rust' + - check_projects: no_std + nix_toolchain: 'rust_nightly' + - check_projects: nightly + nix_toolchain: 'rust_nightly' defaults: run: shell: nix develop ./tools/nix#${{matrix.nix_toolchain }} --keep CI --ignore-environment --command bash {0} steps: - - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 - with: - ref: ${{ github.event.inputs.commit_sha }} + - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 + with: + ref: ${{ github.event.inputs.commit_sha }} - - name: Install Nix - uses: ./.github/actions/cache_nix - with: - 
cache-unique-id: ${{ matrix.check_projects }} - id: nix-installer + - name: Install Nix + uses: ./.github/actions/cache_nix + with: + cache-unique-id: ${{ matrix.check_projects }} + id: nix-installer - - uses: ./.github/actions/cache_rust - with: - job_name: "${{ github.job }}-${{ matrix.check_projects }}" + - uses: ./.github/actions/cache_rust + with: + job_name: "${{ github.job }}-${{ matrix.check_projects }}" - - name: Run check on ${{ matrix.check_projects }} - run: make -f implementations/rust/Makefile check${{ matrix.check_projects != 'nightly' && format('_{0}', matrix.check_projects) || '' }} + - name: Run check on ${{ matrix.check_projects }} + run: make -f implementations/rust/Makefile check${{ matrix.check_projects != 'nightly' && format('_{0}', matrix.check_projects) || '' }} - - name: Nix Upload Store - uses: ./.github/actions/nix_upload_store - if: ${{ steps.nix-installer.outputs.cache-hit != 'true' }} + - name: Nix Upload Store + uses: ./.github/actions/nix_upload_store + if: ${{ steps.nix-installer.outputs.cache-hit != 'true' }} @@ -221,64 +270,64 @@ jobs: strategy: fail-fast: false matrix: - build: [ linux_86 ] + build: [linux_86] include: - - build: linux_86 - os: ubuntu-22.04 - rust: stable - target: x86_64-unknown-linux-gnu + - build: linux_86 + os: ubuntu-22.04 + rust: stable + target: x86_64-unknown-linux-gnu runs-on: ${{ matrix.os }} steps: - - name: Checkout ockam cli repository - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 - with: - ref: ${{ inputs.ockam_command_cli_version != '' && inputs.ockam_command_cli_version || inputs.commit_sha }} - path: ockam_cli - - - name: Checkout ockam bats repository - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 - with: - ref: ${{ inputs.commit_sha }} - path: ockam_bats - - - uses: ./ockam_bats/.github/actions/cache_rust - with: - directory_to_cache: "ockam_cli" - job_name: ${{ github.job }} - - - name: Install Nix - uses: ./ockam_bats/.github/actions/cache_nix - with: - 
cache-unique-id: test_ockam_command - id: nix-installer - - - name: Build Binary - working-directory: ockam_cli - shell: nix develop ./tools/nix#rust --keep CI --ignore-environment --command bash {0} - run: | - rustc --version - set -x - cargo build --bin ockam - - - name: Set Path - run: | - echo "PATH=$(pwd)/ockam_cli/target/debug:$PATH" >> $GITHUB_ENV; - - - name: Run Script On Ubuntu - working-directory: ockam_bats - shell: nix develop ./tools/nix#tooling --command bash {0} - run: | - ockam --version - echo $(which ockam) - echo $BATS_TEST_RETRIES - bash implementations/rust/ockam/ockam_command/tests/bats/run.sh - env: - OCKAM_DISABLE_UPGRADE_CHECK: 1 - BATS_TEST_RETRIES: 2 - - - name: Nix Upload Store - uses: ./ockam_bats/.github/actions/nix_upload_store - if: ${{ steps.nix-installer.outputs.cache-hit != 'true' }} + - name: Checkout ockam cli repository + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 + with: + ref: ${{ inputs.ockam_command_cli_version != '' && inputs.ockam_command_cli_version || inputs.commit_sha }} + path: ockam_cli + + - name: Checkout ockam bats repository + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 + with: + ref: ${{ inputs.commit_sha }} + path: ockam_bats + + - uses: ./ockam_bats/.github/actions/cache_rust + with: + directory_to_cache: "ockam_cli" + job_name: ${{ github.job }} + + - name: Install Nix + uses: ./ockam_bats/.github/actions/cache_nix + with: + cache-unique-id: test_ockam_command + id: nix-installer + + - name: Build Binary + working-directory: ockam_cli + shell: nix develop ./tools/nix#rust --keep CI --ignore-environment --command bash {0} + run: | + rustc --version + set -x + cargo build --bin ockam + + - name: Set Path + run: | + echo "PATH=$(pwd)/ockam_cli/target/debug:$PATH" >> $GITHUB_ENV; + + - name: Run Script On Ubuntu + working-directory: ockam_bats + shell: nix develop ./tools/nix#tooling --command bash {0} + run: | + ockam --version + echo $(which ockam) + echo 
$BATS_TEST_RETRIES + bash implementations/rust/ockam/ockam_command/tests/bats/run.sh + env: + OCKAM_DISABLE_UPGRADE_CHECK: 1 + BATS_TEST_RETRIES: 2 + + - name: Nix Upload Store + uses: ./ockam_bats/.github/actions/nix_upload_store + if: ${{ steps.nix-installer.outputs.cache-hit != 'true' }} ockam_command_cross_build: @@ -286,31 +335,31 @@ jobs: strategy: fail-fast: false matrix: - build: [ linux_armv7, macos_silicon ] + build: [linux_armv7, macos_silicon] include: - - build: linux_armv7 - os: ubuntu-22.04 - toolchain: stable - target: armv7-unknown-linux-musleabihf - use-cross-build: true - - build: macos_silicon - os: macos-14 - toolchain: stable - target: aarch64-apple-darwin - use-cross-build: false + - build: linux_armv7 + os: ubuntu-22.04 + toolchain: stable + target: armv7-unknown-linux-musleabihf + use-cross-build: true + - build: macos_silicon + os: macos-14 + toolchain: stable + target: aarch64-apple-darwin + use-cross-build: false runs-on: ${{ matrix.os }} steps: - - name: Checkout repository - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 - with: - ref: ${{ inputs.commit_sha }} - - - uses: ./.github/actions/build_binaries - with: - use_cross_build: ${{ matrix.use-cross-build }} - toolchain: ${{ matrix.toolchain }} - target: ${{ matrix.target }} - platform_operating_system: ${{ matrix.os }} + - name: Checkout repository + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 + with: + ref: ${{ inputs.commit_sha }} + + - uses: ./.github/actions/build_binaries + with: + use_cross_build: ${{ matrix.use-cross-build }} + toolchain: ${{ matrix.toolchain }} + target: ${{ matrix.target }} + platform_operating_system: ${{ matrix.os }} # test_orchestrator_ockam_command: # name: Rust - test_orchestrator_ockam_command diff --git a/Cargo.lock b/Cargo.lock index d7f58b9cb1b..f07e8641d27 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4799,7 +4799,6 @@ name = "ockam_node" version = "0.119.0" dependencies = [ "cfg-if", - "chrono", "fs2", "futures 
0.3.30", "heapless 0.8.0", @@ -6848,8 +6847,7 @@ dependencies = [ [[package]] name = "sqlx" version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9a2ccff1a000a5a59cd33da541d9f2fdcd9e6e8229cc200565942bff36d0aaa" +source = "git+https://github.com/etorreborre/sqlx?rev=5fec648d2de0cbeed738dcf1c6f5bc9194fc439b#5fec648d2de0cbeed738dcf1c6f5bc9194fc439b" dependencies = [ "sqlx-core", "sqlx-macros", @@ -6861,8 +6859,7 @@ dependencies = [ [[package]] name = "sqlx-core" version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24ba59a9342a3d9bab6c56c118be528b27c9b60e490080e9711a04dccac83ef6" +source = "git+https://github.com/etorreborre/sqlx?rev=5fec648d2de0cbeed738dcf1c6f5bc9194fc439b#5fec648d2de0cbeed738dcf1c6f5bc9194fc439b" dependencies = [ "ahash", "atoi", @@ -6900,8 +6897,7 @@ dependencies = [ [[package]] name = "sqlx-macros" version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ea40e2345eb2faa9e1e5e326db8c34711317d2b5e08d0d5741619048a803127" +source = "git+https://github.com/etorreborre/sqlx?rev=5fec648d2de0cbeed738dcf1c6f5bc9194fc439b#5fec648d2de0cbeed738dcf1c6f5bc9194fc439b" dependencies = [ "proc-macro2", "quote", @@ -6913,8 +6909,7 @@ dependencies = [ [[package]] name = "sqlx-macros-core" version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5833ef53aaa16d860e92123292f1f6a3d53c34ba8b1969f152ef1a7bb803f3c8" +source = "git+https://github.com/etorreborre/sqlx?rev=5fec648d2de0cbeed738dcf1c6f5bc9194fc439b#5fec648d2de0cbeed738dcf1c6f5bc9194fc439b" dependencies = [ "dotenvy", "either", @@ -6928,6 +6923,7 @@ dependencies = [ "sha2", "sqlx-core", "sqlx-mysql", + "sqlx-postgres", "sqlx-sqlite", "syn 1.0.109", "tempfile", @@ -6938,8 +6934,7 @@ dependencies = [ [[package]] name = "sqlx-mysql" version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"1ed31390216d20e538e447a7a9b959e06ed9fc51c37b514b46eb758016ecd418" +source = "git+https://github.com/etorreborre/sqlx?rev=5fec648d2de0cbeed738dcf1c6f5bc9194fc439b#5fec648d2de0cbeed738dcf1c6f5bc9194fc439b" dependencies = [ "atoi", "base64 0.21.7", @@ -6980,8 +6975,7 @@ dependencies = [ [[package]] name = "sqlx-postgres" version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c824eb80b894f926f89a0b9da0c7f435d27cdd35b8c655b114e58223918577e" +source = "git+https://github.com/etorreborre/sqlx?rev=5fec648d2de0cbeed738dcf1c6f5bc9194fc439b#5fec648d2de0cbeed738dcf1c6f5bc9194fc439b" dependencies = [ "atoi", "base64 0.21.7", @@ -7018,8 +7012,7 @@ dependencies = [ [[package]] name = "sqlx-sqlite" version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa" +source = "git+https://github.com/etorreborre/sqlx?rev=5fec648d2de0cbeed738dcf1c6f5bc9194fc439b#5fec648d2de0cbeed738dcf1c6f5bc9194fc439b" dependencies = [ "atoi", "flume", diff --git a/NOTICE.md b/NOTICE.md index 352eed3c3f6..6ac3a394505 100644 --- a/NOTICE.md +++ b/NOTICE.md @@ -192,6 +192,7 @@ This file contains attributions for any 3rd-party open source code used in this | equivalent | Apache-2.0, MIT | https://crates.io/crates/equivalent | | errno | MIT, Apache-2.0 | https://crates.io/crates/errno | | error-code | BSL-1.0 | https://crates.io/crates/error-code | +| etcetera | MIT, Apache-2.0 | https://crates.io/crates/etcetera | | event-listener | Apache-2.0, MIT | https://crates.io/crates/event-listener | | event-listener-strategy | Apache-2.0, MIT | https://crates.io/crates/event-listener-strategy | | fastrand | Apache-2.0, MIT | https://crates.io/crates/fastrand | @@ -199,6 +200,7 @@ This file contains attributions for any 3rd-party open source code used in this | fdeflate | MIT, Apache-2.0 | https://crates.io/crates/fdeflate | | ff | MIT, Apache-2.0 | 
https://crates.io/crates/ff | | fiat-crypto | MIT, Apache-2.0, BSD-1-Clause | https://crates.io/crates/fiat-crypto | +| finl_unicode | MIT, Apache-2.0 | https://crates.io/crates/finl_unicode | | flate2 | MIT, Apache-2.0 | https://crates.io/crates/flate2 | | flexi_logger | MIT, Apache-2.0 | https://crates.io/crates/flexi_logger | | flume | Apache-2.0, MIT | https://crates.io/crates/flume | @@ -487,11 +489,12 @@ This file contains attributions for any 3rd-party open source code used in this | spin | MIT | https://crates.io/crates/spin | | spki | Apache-2.0, MIT | https://crates.io/crates/spki | | sqlformat | MIT, Apache-2.0 | https://crates.io/crates/sqlformat | -| sqlx | MIT, Apache-2.0 | https://crates.io/crates/sqlx | -| sqlx-core | MIT, Apache-2.0 | https://crates.io/crates/sqlx-core | -| sqlx-macros | MIT, Apache-2.0 | https://crates.io/crates/sqlx-macros | -| sqlx-macros-core | MIT, Apache-2.0 | https://crates.io/crates/sqlx-macros-core | -| sqlx-sqlite | MIT, Apache-2.0 | https://crates.io/crates/sqlx-sqlite | +| sqlx | MIT, Apache-2.0 | https://github.com/etorreborre/sqlx?rev=5fec648d2de0cbeed738dcf1c6f5bc9194fc439b | +| sqlx-core | MIT, Apache-2.0 | https://github.com/etorreborre/sqlx?rev=5fec648d2de0cbeed738dcf1c6f5bc9194fc439b | +| sqlx-macros | MIT, Apache-2.0 | https://github.com/etorreborre/sqlx?rev=5fec648d2de0cbeed738dcf1c6f5bc9194fc439b | +| sqlx-macros-core | MIT, Apache-2.0 | https://github.com/etorreborre/sqlx?rev=5fec648d2de0cbeed738dcf1c6f5bc9194fc439b | +| sqlx-postgres | MIT, Apache-2.0 | https://github.com/etorreborre/sqlx?rev=5fec648d2de0cbeed738dcf1c6f5bc9194fc439b | +| sqlx-sqlite | MIT, Apache-2.0 | https://github.com/etorreborre/sqlx?rev=5fec648d2de0cbeed738dcf1c6f5bc9194fc439b | | stable_deref_trait | MIT, Apache-2.0 | https://crates.io/crates/stable_deref_trait | | static_assertions | MIT, Apache-2.0 | https://crates.io/crates/static_assertions | | stm32-device-signature | MIT, Apache-2.0 | 
https://crates.io/crates/stm32-device-signature | @@ -500,6 +503,7 @@ This file contains attributions for any 3rd-party open source code used in this | stm32h7 | MIT, Apache-2.0 | https://crates.io/crates/stm32h7 | | stm32h7xx-hal | 0BSD | https://crates.io/crates/stm32h7xx-hal | | str-buf | BSL-1.0 | https://crates.io/crates/str-buf | +| stringprep | MIT, Apache-2.0 | https://crates.io/crates/stringprep | | strip-ansi-escapes | Apache-2.0, MIT | https://crates.io/crates/strip-ansi-escapes | | strsim | MIT | https://crates.io/crates/strsim | | strum | MIT | https://crates.io/crates/strum | @@ -586,6 +590,7 @@ This file contains attributions for any 3rd-party open source code used in this | waker-fn | Apache-2.0, MIT | https://crates.io/crates/waker-fn | | walkdir | Unlicense, MIT | https://crates.io/crates/walkdir | | want | MIT | https://crates.io/crates/want | +| wasite | Apache-2.0, BSL-1.0, MIT | https://crates.io/crates/wasite | | wasm-bindgen | MIT, Apache-2.0 | https://crates.io/crates/wasm-bindgen | | wasm-bindgen-backend | MIT, Apache-2.0 | https://crates.io/crates/wasm-bindgen-backend | | wasm-bindgen-futures | MIT, Apache-2.0 | https://crates.io/crates/wasm-bindgen-futures | @@ -604,6 +609,7 @@ This file contains attributions for any 3rd-party open source code used in this | web-time | MIT, Apache-2.0 | https://crates.io/crates/web-time | | weezl | MIT, Apache-2.0 | https://crates.io/crates/weezl | | which | MIT | https://crates.io/crates/which | +| whoami | Apache-2.0, BSL-1.0, MIT | https://crates.io/crates/whoami | | winapi | MIT, Apache-2.0 | https://crates.io/crates/winapi | | winapi-i686-pc-windows-gnu | MIT, Apache-2.0 | https://crates.io/crates/winapi-i686-pc-windows-gnu | | winapi-util | Unlicense, MIT | https://crates.io/crates/winapi-util | diff --git a/implementations/rust/Makefile b/implementations/rust/Makefile index d487fb52ba2..d7ebf49dbc3 100644 --- a/implementations/rust/Makefile +++ b/implementations/rust/Makefile @@ -48,6 +48,13 @@ 
nextest: nextest_%: cargo --locked nextest --config-file $(ROOT_DIR)/tools/nextest/.config/nextest.toml run -E 'package($*)' --no-fail-fast cargo --locked test --doc +test_postgres: + export OCKAM_POSTGRES_HOST=localhost + export OCKAM_POSTGRES_PORT=5433 + export OCKAM_POSTGRES_DATABASE_NAME=test + export OCKAM_POSTGRES_USER=postgres + export OCKAM_POSTGRES_PASSWORD=password + cargo --locked nextest --config-file $(ROOT_DIR)/tools/nextest/.config/nextest.toml run -E 'test(sql) or test(cli_state)' --no-fail-fast --test-threads 1 lint: lint_cargo_fmt_check lint_cargo_deny lint_cargo_clippy lint_cargo_fmt_check: diff --git a/implementations/rust/ockam/ockam_abac/Cargo.toml b/implementations/rust/ockam/ockam_abac/Cargo.toml index 2483f78284d..93a9172e3fd 100644 --- a/implementations/rust/ockam/ockam_abac/Cargo.toml +++ b/implementations/rust/ockam/ockam_abac/Cargo.toml @@ -50,7 +50,7 @@ ockam_executor = { version = "0.80.0", path = "../ockam_executor", default-featu regex = { version = "1.10.5", default-features = false, optional = true } rustyline = { version = "14.0.0", optional = true } rustyline-derive = { version = "0.10.0", optional = true } -sqlx = { version = "0.7.4", optional = true } +sqlx = { git = "https://github.com/etorreborre/sqlx", rev = "5fec648d2de0cbeed738dcf1c6f5bc9194fc439b", optional = true } str-buf = "3.0.3" tokio = { version = "1.38", default-features = false, optional = true, features = ["sync", "time", "rt", "rt-multi-thread", "macros"] } tracing = { version = "0.1", default-features = false, features = ["attributes"] } diff --git a/implementations/rust/ockam/ockam_abac/src/policy/storage/resource_policy_repository_sql.rs b/implementations/rust/ockam/ockam_abac/src/policy/storage/resource_policy_repository_sql.rs index d07b453c907..7d7c7ed62a5 100644 --- a/implementations/rust/ockam/ockam_abac/src/policy/storage/resource_policy_repository_sql.rs +++ b/implementations/rust/ockam/ockam_abac/src/policy/storage/resource_policy_repository_sql.rs 
@@ -5,7 +5,7 @@ use tracing::debug; use ockam_core::async_trait; use ockam_core::compat::vec::Vec; use ockam_core::Result; -use ockam_node::database::{FromSqlxError, SqlxDatabase, SqlxType, ToSqlxType, ToVoid}; +use ockam_node::database::{FromSqlxError, SqlxDatabase, ToVoid}; use crate::{Action, Expr, ResourceName, ResourcePoliciesRepository, ResourcePolicy}; @@ -43,13 +43,15 @@ impl ResourcePoliciesRepository for ResourcePolicySqlxDatabase { expression: &Expr, ) -> Result<()> { let query = query( - r#"INSERT OR REPLACE INTO resource_policy - VALUES (?, ?, ?, ?)"#, + r#"INSERT INTO resource_policy (resource_name, action, expression, node_name) + VALUES ($1, $2, $3, $4) + ON CONFLICT (resource_name, action, node_name) + DO UPDATE SET expression = $3"#, ) - .bind(resource_name.to_sql()) - .bind(action.to_sql()) - .bind(expression.to_string().to_sql()) - .bind(self.node_name.to_sql()); + .bind(resource_name) + .bind(action) + .bind(expression) + .bind(&self.node_name); query.execute(&*self.database.pool).await.void() } @@ -61,11 +63,11 @@ impl ResourcePoliciesRepository for ResourcePolicySqlxDatabase { let query = query_as( r#"SELECT resource_name, action, expression FROM resource_policy - WHERE node_name=$1 and resource_name=$2 and action=$3"#, + WHERE node_name = $1 and resource_name = $2 and action = $3"#, ) - .bind(self.node_name.to_sql()) - .bind(resource_name.to_sql()) - .bind(action.to_sql()); + .bind(&self.node_name) + .bind(resource_name) + .bind(action); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -77,9 +79,9 @@ impl ResourcePoliciesRepository for ResourcePolicySqlxDatabase { let query = query_as( r#"SELECT resource_name, action, expression FROM resource_policy - WHERE node_name=$1"#, + WHERE node_name = $1"#, ) - .bind(self.node_name.to_sql()); + .bind(&self.node_name); let row: Vec = query.fetch_all(&*self.database.pool).await.into_core()?; row.into_iter() .map(|r| r.try_into()) @@ -93,10 +95,10 @@ impl 
ResourcePoliciesRepository for ResourcePolicySqlxDatabase { let query = query_as( r#"SELECT resource_name, action, expression FROM resource_policy - WHERE node_name=$1 and resource_name=$2"#, + WHERE node_name = $1 and resource_name = $2"#, ) - .bind(self.node_name.to_sql()) - .bind(resource_name.to_sql()); + .bind(&self.node_name) + .bind(resource_name); let row: Vec = query.fetch_all(&*self.database.pool).await.into_core()?; row.into_iter() .map(|r| r.try_into()) @@ -106,29 +108,15 @@ impl ResourcePoliciesRepository for ResourcePolicySqlxDatabase { async fn delete_policy(&self, resource_name: &ResourceName, action: &Action) -> Result<()> { let query = query( r#"DELETE FROM resource_policy - WHERE node_name=? and resource_name=? and action=?"#, + WHERE node_name = $1 and resource_name = $2 and action = $3"#, ) - .bind(self.node_name.to_sql()) - .bind(resource_name.to_sql()) - .bind(action.to_sql()); + .bind(&self.node_name) + .bind(resource_name) + .bind(action); query.execute(&*self.database.pool).await.void() } } -// Database serialization / deserialization - -impl ToSqlxType for ResourceName { - fn to_sql(&self) -> SqlxType { - SqlxType::Text(self.as_str().to_string()) - } -} - -impl ToSqlxType for Action { - fn to_sql(&self) -> SqlxType { - SqlxType::Text(self.to_string()) - } -} - /// Low-level representation of a row in the resource_policy table #[derive(FromRow)] struct PolicyRow { diff --git a/implementations/rust/ockam/ockam_abac/src/policy/storage/resource_repository_sql.rs b/implementations/rust/ockam/ockam_abac/src/policy/storage/resource_repository_sql.rs index 2046978cc7c..6f9a5cfcf50 100644 --- a/implementations/rust/ockam/ockam_abac/src/policy/storage/resource_repository_sql.rs +++ b/implementations/rust/ockam/ockam_abac/src/policy/storage/resource_repository_sql.rs @@ -1,10 +1,12 @@ use core::str::FromStr; +use sqlx::database::HasArguments; +use sqlx::encode::IsNull; use sqlx::*; use tracing::debug; use ockam_core::async_trait; use 
ockam_core::Result; -use ockam_node::database::{FromSqlxError, SqlxDatabase, ToSqlxType, ToVoid}; +use ockam_node::database::{FromSqlxError, SqlxDatabase, ToVoid}; use crate::{Resource, ResourceName, ResourceType, ResourcesRepository}; @@ -37,12 +39,14 @@ impl ResourcesSqlxDatabase { impl ResourcesRepository for ResourcesSqlxDatabase { async fn store_resource(&self, resource: &Resource) -> Result<()> { let query = query( - r#"INSERT OR REPLACE INTO resource - VALUES (?, ?, ?)"#, + r#" + INSERT INTO resource (resource_name, resource_type, node_name) + VALUES ($1, $2, $3) + ON CONFLICT DO NOTHING"#, ) - .bind(resource.resource_name.to_sql()) - .bind(resource.resource_type.to_sql()) - .bind(self.node_name.to_sql()); + .bind(&resource.resource_name) + .bind(&resource.resource_type) + .bind(&self.node_name); query.execute(&*self.database.pool).await.void() } @@ -50,10 +54,10 @@ impl ResourcesRepository for ResourcesSqlxDatabase { let query = query_as( r#"SELECT resource_name, resource_type FROM resource - WHERE node_name=$1 and resource_name=$2"#, + WHERE node_name = $1 and resource_name = $2"#, ) - .bind(self.node_name.to_sql()) - .bind(resource_name.to_sql()); + .bind(&self.node_name) + .bind(resource_name); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -66,24 +70,38 @@ impl ResourcesRepository for ResourcesSqlxDatabase { let query = query( r#"DELETE FROM resource - WHERE node_name=? and resource_name=?"#, + WHERE node_name = $1 and resource_name = $2"#, ) - .bind(self.node_name.to_sql()) - .bind(resource_name.to_sql()); + .bind(&self.node_name) + .bind(resource_name); query.execute(&mut *transaction).await.void()?; let query = sqlx::query( r#"DELETE FROM resource_policy - WHERE node_name=? 
and resource_name=?"#, + WHERE node_name = $1 and resource_name = $2"#, ) - .bind(self.node_name.to_sql()) - .bind(resource_name.to_sql()); + .bind(&self.node_name) + .bind(resource_name); query.execute(&mut *transaction).await.void()?; transaction.commit().await.void() } } +// Database serialization / deserialization + +impl Type for ResourceName { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +impl sqlx::Encode<'_, Any> for ResourceName { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + >::encode_by_ref(&self.to_string(), buf) + } +} + /// Low-level representation of a row in the resource_type_policy table #[derive(FromRow)] #[allow(dead_code)] diff --git a/implementations/rust/ockam/ockam_abac/src/policy/storage/resource_type_policy_repository_sql.rs b/implementations/rust/ockam/ockam_abac/src/policy/storage/resource_type_policy_repository_sql.rs index 94cdab3317f..4a87d0f5d78 100644 --- a/implementations/rust/ockam/ockam_abac/src/policy/storage/resource_type_policy_repository_sql.rs +++ b/implementations/rust/ockam/ockam_abac/src/policy/storage/resource_type_policy_repository_sql.rs @@ -1,11 +1,13 @@ use core::str::FromStr; +use sqlx::database::HasArguments; +use sqlx::encode::IsNull; use sqlx::*; use tracing::debug; use ockam_core::async_trait; use ockam_core::compat::vec::Vec; use ockam_core::Result; -use ockam_node::database::{FromSqlxError, SqlxDatabase, SqlxType, ToSqlxType, ToVoid}; +use ockam_node::database::{FromSqlxError, SqlxDatabase, ToVoid}; use crate::policy::ResourceTypePolicy; use crate::{Action, Expr, ResourceType, ResourceTypePoliciesRepository}; @@ -44,13 +46,16 @@ impl ResourceTypePoliciesRepository for ResourceTypePolicySqlxDatabase { expression: &Expr, ) -> Result<()> { let query = query( - r#"INSERT OR REPLACE INTO - resource_type_policy VALUES (?, ?, ?, ?)"#, + r#"INSERT INTO + resource_type_policy (resource_type, action, expression, node_name) + VALUES ($1, $2, $3, $4) + ON CONFLICT (node_name, 
resource_type, action) + DO UPDATE SET expression = $3"#, ) - .bind(resource_type.to_sql()) - .bind(action.to_sql()) - .bind(expression.to_string().to_sql()) - .bind(self.node_name.to_sql()); + .bind(resource_type) + .bind(action) + .bind(expression) + .bind(&self.node_name); query.execute(&*self.database.pool).await.void() } @@ -62,11 +67,11 @@ impl ResourceTypePoliciesRepository for ResourceTypePolicySqlxDatabase { let query = query_as( r#"SELECT resource_type, action, expression FROM resource_type_policy - WHERE node_name=$1 and resource_type=$2 and action=$3"#, + WHERE node_name = $1 and resource_type = $2 and action = $3"#, ) - .bind(self.node_name.to_sql()) - .bind(resource_type.to_sql()) - .bind(action.to_sql()); + .bind(&self.node_name) + .bind(resource_type) + .bind(action); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -77,9 +82,9 @@ impl ResourceTypePoliciesRepository for ResourceTypePolicySqlxDatabase { async fn get_policies(&self) -> Result> { let query = query_as( r#"SELECT resource_type, action, expression - FROM resource_type_policy where node_name=$1"#, + FROM resource_type_policy where node_name = $1"#, ) - .bind(self.node_name.to_sql()); + .bind(&self.node_name); let row: Vec = query.fetch_all(&*self.database.pool).await.into_core()?; row.into_iter() .map(|r| r.try_into()) @@ -92,10 +97,10 @@ impl ResourceTypePoliciesRepository for ResourceTypePolicySqlxDatabase { ) -> Result> { let query = query_as( r#"SELECT resource_type, action, expression - FROM resource_type_policy where node_name=$1 and resource_type=$2"#, + FROM resource_type_policy where node_name = $1 and resource_type = $2"#, ) - .bind(self.node_name.to_sql()) - .bind(resource_type.to_sql()); + .bind(&self.node_name) + .bind(resource_type); let row: Vec = query.fetch_all(&*self.database.pool).await.into_core()?; row.into_iter() .map(|r| r.try_into()) @@ -105,15 +110,53 @@ impl ResourceTypePoliciesRepository for ResourceTypePolicySqlxDatabase { async fn 
delete_policy(&self, resource_type: &ResourceType, action: &Action) -> Result<()> { let query = query( r#"DELETE FROM resource_type_policy - WHERE node_name=? and resource_type=? and action=?"#, + WHERE node_name = $1 and resource_type = $2 and action = $3"#, ) - .bind(self.node_name.to_sql()) - .bind(resource_type.to_sql()) - .bind(action.to_sql()); + .bind(&self.node_name) + .bind(resource_type) + .bind(action); query.execute(&*self.database.pool).await.void() } } +// Database serialization / deserialization + +impl Type for ResourceType { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +impl Encode<'_, Any> for ResourceType { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + >::encode_by_ref(&self.to_string(), buf) + } +} + +impl Type for Action { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +impl sqlx::Encode<'_, Any> for Action { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + >::encode_by_ref(&self.to_string(), buf) + } +} + +impl Type for Expr { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +impl Encode<'_, Any> for Expr { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + >::encode_by_ref(&self.to_string(), buf) + } +} + /// Low-level representation of a row in the resource_type_policy table #[derive(FromRow)] struct PolicyRow { @@ -148,14 +191,6 @@ impl TryFrom for ResourceTypePolicy { } } -// Database serialization / deserialization - -impl ToSqlxType for ResourceType { - fn to_sql(&self) -> SqlxType { - SqlxType::Text(self.to_string()) - } -} - #[cfg(test)] mod test { use super::*; diff --git a/implementations/rust/ockam/ockam_abac/src/types.rs b/implementations/rust/ockam/ockam_abac/src/types.rs index 61795beae8e..d7b8ab621e2 100644 --- a/implementations/rust/ockam/ockam_abac/src/types.rs +++ b/implementations/rust/ockam/ockam_abac/src/types.rs @@ -37,6 +37,10 @@ macro_rules! 
define { pub fn as_str(&self) -> &str { &self.0 } + + pub fn to_string(&self) -> String { + self.as_str().to_string() + } } impl From<&str> for $t { diff --git a/implementations/rust/ockam/ockam_api/Cargo.toml b/implementations/rust/ockam/ockam_api/Cargo.toml index 49edd8d23a1..145a27fe10b 100644 --- a/implementations/rust/ockam/ockam_api/Cargo.toml +++ b/implementations/rust/ockam/ockam_api/Cargo.toml @@ -75,7 +75,7 @@ reqwest = { version = "0.12", default-features = false, features = ["json", "rus serde = { version = "1.0.203", features = ["derive"] } serde_json = "1.0.118" sha2 = "0.10.8" -sqlx = { version = "0.7.4", features = ["runtime-tokio", "sqlite"] } +sqlx = { git = "https://github.com/etorreborre/sqlx", rev = "5fec648d2de0cbeed738dcf1c6f5bc9194fc439b" } strip-ansi-escapes = "0.2" sysinfo = "0.30" thiserror = "1.0" diff --git a/implementations/rust/ockam/ockam_api/src/authenticator/direct/direct_authenticator.rs b/implementations/rust/ockam/ockam_api/src/authenticator/direct/direct_authenticator.rs index 9027c8e125f..fae1aa71ff2 100644 --- a/implementations/rust/ockam/ockam_api/src/authenticator/direct/direct_authenticator.rs +++ b/implementations/rust/ockam/ockam_api/src/authenticator/direct/direct_authenticator.rs @@ -90,9 +90,10 @@ impl DirectAuthenticator { "{} is trying to add member {}, but {} is not an enroller", enroller, identifier, enroller ); - return Ok(Either::Right(DirectAuthenticatorError( - "Non-enroller is trying to add a member".to_string(), - ))); + return Ok(Either::Right(DirectAuthenticatorError(format!( + "Non-enroller {} is trying to add a member {}", + enroller, identifier + )))); } let attrs = attributes diff --git a/implementations/rust/ockam/ockam_api/src/authenticator/one_time_code.rs b/implementations/rust/ockam/ockam_api/src/authenticator/one_time_code.rs index 09eebf16321..b6f5060a66d 100644 --- a/implementations/rust/ockam/ockam_api/src/authenticator/one_time_code.rs +++ 
b/implementations/rust/ockam/ockam_api/src/authenticator/one_time_code.rs @@ -8,7 +8,6 @@ use ockam_core::compat::string::{String, ToString}; use ockam_core::errcode::{Kind, Origin}; use ockam_core::Error; use ockam_core::Result; -use ockam_node::database::{SqlxType, ToSqlxType}; use serde::{Deserialize, Serialize}; use std::fmt::{Display, Formatter}; @@ -91,12 +90,6 @@ impl<'de> Deserialize<'de> for OneTimeCode { } } -impl ToSqlxType for OneTimeCode { - fn to_sql(&self) -> SqlxType { - self.to_string().to_sql() - } -} - /// Create an Identity Error fn error(message: String) -> Error { Error::new(Origin::Identity, Kind::Invalid, message.as_str()) diff --git a/implementations/rust/ockam/ockam_api/src/authenticator/storage/authority_enrollment_token_repository_sql.rs b/implementations/rust/ockam/ockam_api/src/authenticator/storage/authority_enrollment_token_repository_sql.rs index fca4de9e620..8e57534e120 100644 --- a/implementations/rust/ockam/ockam_api/src/authenticator/storage/authority_enrollment_token_repository_sql.rs +++ b/implementations/rust/ockam/ockam_api/src/authenticator/storage/authority_enrollment_token_repository_sql.rs @@ -1,10 +1,12 @@ use ockam::identity::TimestampInSeconds; +use sqlx::database::HasArguments; +use sqlx::encode::IsNull; use sqlx::*; use tracing::debug; use ockam_core::async_trait; use ockam_core::Result; -use ockam_node::database::{FromSqlxError, SqlxDatabase, ToSqlxType, ToVoid}; +use ockam_node::database::{FromSqlxError, SqlxDatabase, ToVoid}; use crate::authenticator::one_time_code::OneTimeCode; use crate::authenticator::{ @@ -12,7 +14,7 @@ use crate::authenticator::{ }; /// Implementation of [`AuthorityEnrollmentTokenRepository`] trait based on an underlying database -/// using sqlx as its API, and Sqlite as its driver +/// using sqlx as its API #[derive(Clone)] pub struct AuthorityEnrollmentTokenSqlxDatabase { database: SqlxDatabase, @@ -42,24 +44,25 @@ impl AuthorityEnrollmentTokenRepository for 
AuthorityEnrollmentTokenSqlxDatabase ) -> Result> { // We need to delete expired tokens regularly // Also makes sure we don't get expired tokens later inside this function - let query1 = - query("DELETE FROM authority_enrollment_token WHERE expires_at<=?").bind(now.to_sql()); + let query1 = query("DELETE FROM authority_enrollment_token WHERE expires_at <= $1") + .bind(now.0 as i64); let res = query1.execute(&*self.database.pool).await.into_core()?; debug!("Deleted {} expired enrollment tokens", res.rows_affected()); let mut transaction = self.database.pool.begin().await.into_core()?; - let query2 = query_as("SELECT one_time_code, reference, issued_by, created_at, expires_at, ttl_count, attributes FROM authority_enrollment_token WHERE one_time_code=?") - .bind(one_time_code.to_sql()); + let query2 = query_as("SELECT one_time_code, reference, issued_by, created_at, expires_at, ttl_count, attributes FROM authority_enrollment_token WHERE one_time_code = $1") + .bind(&one_time_code); let row: Option = query2.fetch_optional(&mut *transaction).await.into_core()?; let token: Option = row.map(|r| r.try_into()).transpose()?; if let Some(token) = &token { if token.ttl_count <= 1 { - let query3 = query("DElETE FROM authority_enrollment_token WHERE one_time_code=?") - .bind(one_time_code.to_sql()); + let query3 = + query("DElETE FROM authority_enrollment_token WHERE one_time_code = $1") + .bind(&one_time_code); query3.execute(&mut *transaction).await.void()?; debug!( "Deleted enrollment token because it has been used. Reference: {}", @@ -68,10 +71,10 @@ impl AuthorityEnrollmentTokenRepository for AuthorityEnrollmentTokenSqlxDatabase } else { let new_ttl_count = token.ttl_count - 1; let query3 = query( - "UPDATE authority_enrollment_token SET ttl_count=? 
WHERE one_time_code=?", + "UPDATE authority_enrollment_token SET ttl_count = $1 WHERE one_time_code = $2", ) .bind(new_ttl_count as i64) - .bind(one_time_code.to_sql()); + .bind(&one_time_code); query3.execute(&mut *transaction).await.void()?; debug!( "Decreasing enrollment token usage count to {}. Reference: {}", @@ -88,213 +91,242 @@ impl AuthorityEnrollmentTokenRepository for AuthorityEnrollmentTokenSqlxDatabase async fn store_new_token(&self, token: EnrollmentToken) -> Result<()> { let query = query( - "INSERT OR REPLACE INTO authority_enrollment_token (one_time_code, reference, issued_by, created_at, expires_at, ttl_count, attributes) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)", + r#" + INSERT INTO authority_enrollment_token (one_time_code, reference, issued_by, created_at, expires_at, ttl_count, attributes) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (one_time_code) + DO UPDATE SET reference = $2, issued_by = $3, created_at = $4, expires_at = $5, ttl_count = $6, attributes = $7"#, ) - .bind(token.one_time_code.to_sql()) - .bind(token.reference.map(|r| r.to_sql())) - .bind(token.issued_by.to_sql()) - .bind(token.created_at.to_sql()) - .bind(token.expires_at.to_sql()) - .bind(token.ttl_count.to_sql()) - .bind(minicbor::to_vec(token.attrs)?.to_sql()); + .bind(token.one_time_code) + .bind(token.reference) + .bind(token.issued_by) + .bind(token.created_at) + .bind(token.expires_at) + .bind(token.ttl_count as i64) + .bind(minicbor::to_vec(token.attrs)?); query.execute(&*self.database.pool).await.void() } } +// Database serialization / deserialization + +impl Type for OneTimeCode { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +impl sqlx::Encode<'_, Any> for OneTimeCode { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + >::encode_by_ref(&self.to_string(), buf) + } +} + #[cfg(test)] mod tests { use super::*; use ockam::identity::utils::now; use ockam::identity::Identifier; use ockam_core::compat::sync::Arc; + use 
ockam_node::database::with_dbs; use std::collections::BTreeMap; use std::str::FromStr; use std::time::Duration; #[tokio::test] async fn test_authority_enrollment_token_repository_one_time_token() -> Result<()> { - let repository = create_repository().await?; - - let one_time_code = OneTimeCode::new(); - - let issued_by = Identifier::from_str( - "I0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", - ) - .unwrap(); - - let created_at = now()?; - let expires_at = created_at + 10; - - let mut attrs = BTreeMap::::default(); - attrs.insert("role".to_string(), "user".to_string()); - - let token = EnrollmentToken { - one_time_code: one_time_code.clone(), - reference: None, - issued_by: issued_by.clone(), - created_at, - expires_at, - ttl_count: 1, - attrs: attrs.clone(), - }; - - repository.store_new_token(token).await?; - - let token1 = repository.use_token(one_time_code.clone(), now()?).await?; - assert!(token1.is_some()); - let token1 = token1.unwrap(); - assert_eq!(token1.one_time_code, one_time_code); - assert_eq!(token1.reference, None); - assert_eq!(token1.issued_by, issued_by); - assert_eq!(token1.created_at, created_at); - assert_eq!(token1.expires_at, expires_at); - assert_eq!(token1.ttl_count, 1); - assert_eq!(token1.attrs, attrs); - - let token2 = repository.use_token(one_time_code, now()?).await?; - assert!(token2.is_none()); - - Ok(()) + with_dbs(|db| async move { + let repository: Arc = + Arc::new(AuthorityEnrollmentTokenSqlxDatabase::new(db)); + + let one_time_code = OneTimeCode::new(); + + let issued_by = Identifier::from_str( + "I0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + ) + .unwrap(); + + let created_at = now()?; + let expires_at = created_at + 10; + + let mut attrs = BTreeMap::::default(); + attrs.insert("role".to_string(), "user".to_string()); + + let token = EnrollmentToken { + one_time_code: one_time_code.clone(), + reference: None, + issued_by: issued_by.clone(), + created_at, + expires_at, + ttl_count: 1, 
+ attrs: attrs.clone(), + }; + + repository.store_new_token(token).await?; + + let token1 = repository.use_token(one_time_code.clone(), now()?).await?; + assert!(token1.is_some()); + let token1 = token1.unwrap(); + assert_eq!(token1.one_time_code, one_time_code); + assert_eq!(token1.reference, None); + assert_eq!(token1.issued_by, issued_by); + assert_eq!(token1.created_at, created_at); + assert_eq!(token1.expires_at, expires_at); + assert_eq!(token1.ttl_count, 1); + assert_eq!(token1.attrs, attrs); + + let token2 = repository.use_token(one_time_code, now()?).await?; + assert!(token2.is_none()); + + Ok(()) + }) + .await } #[tokio::test] async fn test_authority_enrollment_token_repository_with_reference() -> Result<()> { - let repository = create_repository().await?; - - let one_time_code = OneTimeCode::new(); - let reference = Some(OneTimeCode::new().to_string()); - - let issued_by = Identifier::from_str( - "I0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", - ) - .unwrap(); - - let created_at = now()?; - let expires_at = created_at + 10; - - let mut attrs = BTreeMap::::default(); - attrs.insert("role".to_string(), "user".to_string()); - - let token = EnrollmentToken { - one_time_code: one_time_code.clone(), - reference: reference.clone(), - issued_by: issued_by.clone(), - created_at, - expires_at, - ttl_count: 1, - attrs: attrs.clone(), - }; - - repository.store_new_token(token).await?; - - let token1 = repository.use_token(one_time_code.clone(), now()?).await?; - assert!(token1.is_some()); - let token1 = token1.unwrap(); - assert_eq!(token1.one_time_code, one_time_code); - assert_eq!(token1.reference, reference); - assert_eq!(token1.issued_by, issued_by); - assert_eq!(token1.created_at, created_at); - assert_eq!(token1.expires_at, expires_at); - assert_eq!(token1.ttl_count, 1); - assert_eq!(token1.attrs, attrs); - - Ok(()) + with_dbs(|db| async move { + let repository: Arc = + Arc::new(AuthorityEnrollmentTokenSqlxDatabase::new(db)); + + let 
one_time_code = OneTimeCode::new(); + let reference = Some(OneTimeCode::new().to_string()); + + let issued_by = Identifier::from_str( + "I0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + ) + .unwrap(); + + let created_at = now()?; + let expires_at = created_at + 10; + + let mut attrs = BTreeMap::::default(); + attrs.insert("role".to_string(), "user".to_string()); + + let token = EnrollmentToken { + one_time_code: one_time_code.clone(), + reference: reference.clone(), + issued_by: issued_by.clone(), + created_at, + expires_at, + ttl_count: 1, + attrs: attrs.clone(), + }; + + repository.store_new_token(token).await?; + + let token1 = repository.use_token(one_time_code.clone(), now()?).await?; + assert!(token1.is_some()); + let token1 = token1.unwrap(); + assert_eq!(token1.one_time_code, one_time_code); + assert_eq!(token1.reference, reference); + assert_eq!(token1.issued_by, issued_by); + assert_eq!(token1.created_at, created_at); + assert_eq!(token1.expires_at, expires_at); + assert_eq!(token1.ttl_count, 1); + assert_eq!(token1.attrs, attrs); + + Ok(()) + }) + .await } #[tokio::test] async fn test_authority_enrollment_token_repository_two_time_token() -> Result<()> { - let repository = create_repository().await?; - - let one_time_code = OneTimeCode::new(); - - let issued_by = Identifier::from_str( - "I0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", - ) - .unwrap(); - - let created_at = now()?; - let expires_at = created_at + 10; - - let mut attrs = BTreeMap::::default(); - attrs.insert("role".to_string(), "user".to_string()); - - let token = EnrollmentToken { - one_time_code: one_time_code.clone(), - reference: None, - issued_by: issued_by.clone(), - created_at, - expires_at, - ttl_count: 2, - attrs: attrs.clone(), - }; - - repository.store_new_token(token).await?; - - let token1 = repository.use_token(one_time_code.clone(), now()?).await?; - let token2 = repository.use_token(one_time_code.clone(), now()?).await?; - let token3 
= repository.use_token(one_time_code.clone(), now()?).await?; - assert!(token1.is_some()); - assert!(token2.is_some()); - assert!(token3.is_none()); - - let token1 = token1.unwrap(); - let token2 = token2.unwrap(); - - assert_eq!(token1.reference, token2.reference); - assert_eq!(token1.one_time_code, token2.one_time_code); - assert_eq!(token1.issued_by, token2.issued_by); - assert_eq!(token1.created_at, token2.created_at); - assert_eq!(token1.expires_at, token2.expires_at); - assert_eq!(token1.attrs, token2.attrs); - - assert_eq!(token1.ttl_count, 2); - assert_eq!(token2.ttl_count, 1); - - Ok(()) + with_dbs(|db| async move { + let repository: Arc = + Arc::new(AuthorityEnrollmentTokenSqlxDatabase::new(db)); + + let one_time_code = OneTimeCode::new(); + + let issued_by = Identifier::from_str( + "I0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + ) + .unwrap(); + + let created_at = now()?; + let expires_at = created_at + 10; + + let mut attrs = BTreeMap::::default(); + attrs.insert("role".to_string(), "user".to_string()); + + let token = EnrollmentToken { + one_time_code: one_time_code.clone(), + reference: None, + issued_by: issued_by.clone(), + created_at, + expires_at, + ttl_count: 2, + attrs: attrs.clone(), + }; + + repository.store_new_token(token).await?; + + let token1 = repository.use_token(one_time_code.clone(), now()?).await?; + let token2 = repository.use_token(one_time_code.clone(), now()?).await?; + let token3 = repository.use_token(one_time_code.clone(), now()?).await?; + assert!(token1.is_some()); + assert!(token2.is_some()); + assert!(token3.is_none()); + + let token1 = token1.unwrap(); + let token2 = token2.unwrap(); + + assert_eq!(token1.reference, token2.reference); + assert_eq!(token1.one_time_code, token2.one_time_code); + assert_eq!(token1.issued_by, token2.issued_by); + assert_eq!(token1.created_at, token2.created_at); + assert_eq!(token1.expires_at, token2.expires_at); + assert_eq!(token1.attrs, token2.attrs); + + 
assert_eq!(token1.ttl_count, 2); + assert_eq!(token2.ttl_count, 1); + + Ok(()) + }) + .await } #[tokio::test] async fn test_authority_enrollment_token_repository_expired_token() -> Result<()> { - let repository = create_repository().await?; - - let one_time_code = OneTimeCode::new(); - - let issued_by = Identifier::from_str( - "I0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", - ) - .unwrap(); - - let created_at = now()?; - let expires_at = created_at + 1; - - let mut attrs = BTreeMap::::default(); - attrs.insert("role".to_string(), "user".to_string()); - - let token = EnrollmentToken { - one_time_code: one_time_code.clone(), - reference: None, - issued_by: issued_by.clone(), - created_at, - expires_at, - ttl_count: 1, - attrs: attrs.clone(), - }; - - repository.store_new_token(token).await?; - - tokio::time::sleep(Duration::from_secs(2)).await; - - let token1 = repository.use_token(one_time_code.clone(), now()?).await?; - assert!(token1.is_none()); - - Ok(()) - } - - /// HELPERS - async fn create_repository() -> Result> { - Ok(Arc::new( - AuthorityEnrollmentTokenSqlxDatabase::create().await?, - )) + with_dbs(|db| async move { + let repository: Arc = + Arc::new(AuthorityEnrollmentTokenSqlxDatabase::new(db)); + + let one_time_code = OneTimeCode::new(); + + let issued_by = Identifier::from_str( + "I0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + ) + .unwrap(); + + let created_at = now()?; + let expires_at = created_at + 1; + + let mut attrs = BTreeMap::::default(); + attrs.insert("role".to_string(), "user".to_string()); + + let token = EnrollmentToken { + one_time_code: one_time_code.clone(), + reference: None, + issued_by: issued_by.clone(), + created_at, + expires_at, + ttl_count: 1, + attrs: attrs.clone(), + }; + + repository.store_new_token(token.clone()).await?; + // Try to store the same token twice + repository.store_new_token(token).await?; + + tokio::time::sleep(Duration::from_secs(2)).await; + + let token1 = 
repository.use_token(one_time_code.clone(), now()?).await?; + assert!(token1.is_none()); + Ok(()) + }) + .await } } diff --git a/implementations/rust/ockam/ockam_api/src/authenticator/storage/authority_member.rs b/implementations/rust/ockam/ockam_api/src/authenticator/storage/authority_member.rs index 06844a22b1c..47b7de3b5a2 100644 --- a/implementations/rust/ockam/ockam_api/src/authenticator/storage/authority_member.rs +++ b/implementations/rust/ockam/ockam_api/src/authenticator/storage/authority_member.rs @@ -2,6 +2,7 @@ use ockam::identity::{Identifier, TimestampInSeconds}; use ockam_core::compat::collections::BTreeMap; use ockam_core::compat::str::FromStr; use ockam_core::{Error, Result}; +use ockam_node::database::Boolean; /// Project member stored on the Authority node #[derive(Debug, Clone, PartialEq, Eq)] @@ -52,10 +53,10 @@ impl AuthorityMember { #[derive(sqlx::FromRow)] pub(crate) struct AuthorityMemberRow { identifier: String, - attributes: Vec, added_by: String, added_at: i64, - is_pre_trusted: bool, + is_pre_trusted: Boolean, + attributes: Vec, } impl TryFrom for AuthorityMember { @@ -67,7 +68,7 @@ impl TryFrom for AuthorityMember { minicbor::decode(&value.attributes)?, Identifier::from_str(&value.added_by)?, TimestampInSeconds(value.added_at as u64), - value.is_pre_trusted, + value.is_pre_trusted.to_bool(), ); Ok(member) diff --git a/implementations/rust/ockam/ockam_api/src/authenticator/storage/authority_members_repository_sql.rs b/implementations/rust/ockam/ockam_api/src/authenticator/storage/authority_members_repository_sql.rs index ae72be86df0..6eeda59b53b 100644 --- a/implementations/rust/ockam/ockam_api/src/authenticator/storage/authority_members_repository_sql.rs +++ b/implementations/rust/ockam/ockam_api/src/authenticator/storage/authority_members_repository_sql.rs @@ -5,7 +5,7 @@ use tracing::debug; use ockam::identity::Identifier; use ockam_core::async_trait; use ockam_core::Result; -use ockam_node::database::{FromSqlxError, SqlxDatabase, 
ToSqlxType, ToVoid}; +use ockam_node::database::{FromSqlxError, SqlxDatabase, ToVoid}; use crate::authenticator::{ AuthorityMember, AuthorityMemberRow, AuthorityMembersRepository, PreTrustedIdentities, @@ -34,8 +34,8 @@ impl AuthorityMembersSqlxDatabase { #[async_trait] impl AuthorityMembersRepository for AuthorityMembersSqlxDatabase { async fn get_member(&self, identifier: &Identifier) -> Result> { - let query = query_as("SELECT identifier, attributes, added_by, added_at, is_pre_trusted FROM authority_member WHERE identifier=?") - .bind(identifier.to_sql()); + let query = query_as("SELECT identifier, added_by, added_at, is_pre_trusted, attributes FROM authority_member WHERE identifier = $1") + .bind(identifier); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -44,26 +44,31 @@ impl AuthorityMembersRepository for AuthorityMembersSqlxDatabase { } async fn get_members(&self) -> Result> { - let query = query_as("SELECT identifier, attributes, added_by, added_at, is_pre_trusted FROM authority_member"); + let query = query_as("SELECT identifier, added_by, added_at, is_pre_trusted, attributes FROM authority_member"); let row: Vec = query.fetch_all(&*self.database.pool).await.into_core()?; row.into_iter().map(|r| r.try_into()).collect() } async fn delete_member(&self, identifier: &Identifier) -> Result<()> { - let query = query("DELETE FROM authority_member WHERE identifier=? 
AND is_pre_trusted=?") - .bind(identifier.to_sql()) - .bind(false.to_sql()); + let query = + query("DELETE FROM authority_member WHERE identifier = $1 AND is_pre_trusted = $2") + .bind(identifier) + .bind(false); query.execute(&*self.database.pool).await.void() } async fn add_member(&self, member: AuthorityMember) -> Result<()> { - let query = query("INSERT OR REPLACE INTO authority_member VALUES (?1, ?2, ?3, ?4, ?5)") - .bind(member.identifier().to_sql()) - .bind(member.added_by().to_sql()) - .bind(member.added_at().to_sql()) - .bind(member.is_pre_trusted().to_sql()) - .bind(minicbor::to_vec(member.attributes())?.to_sql()); + let query = query(r#" + INSERT INTO authority_member (identifier, added_by, added_at, is_pre_trusted, attributes) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (identifier) + DO UPDATE SET added_by = $2, added_at = $3, is_pre_trusted = $4, attributes = $5"#) + .bind(member.identifier()) + .bind(member.added_by()) + .bind(member.added_at()) + .bind(member.is_pre_trusted()) + .bind(minicbor::to_vec(member.attributes())?); query.execute(&*self.database.pool).await.void() } @@ -73,18 +78,21 @@ impl AuthorityMembersRepository for AuthorityMembersSqlxDatabase { pre_trusted_identities: &PreTrustedIdentities, ) -> Result<()> { let mut transaction = self.database.begin().await.into_core()?; - let query1 = - query("DELETE FROM authority_member WHERE is_pre_trusted=?").bind(true.to_sql()); + let query1 = query("DELETE FROM authority_member WHERE is_pre_trusted = $1").bind(true); query1.execute(&mut *transaction).await.void()?; for (identifier, pre_trusted_identity) in pre_trusted_identities.deref() { let query2 = - query("INSERT OR REPLACE INTO authority_member VALUES (?1, ?2, ?3, ?4, ?5)") - .bind(identifier.to_sql()) - .bind(pre_trusted_identity.attested_by().to_sql()) - .bind(pre_trusted_identity.added_at().to_sql()) - .bind(true.to_sql()) - .bind(minicbor::to_vec(pre_trusted_identity.attrs())?.to_sql()); + query(r#" + INSERT INTO authority_member 
(identifier, added_by, added_at, is_pre_trusted, attributes) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (identifier) + DO UPDATE SET added_by = $2, added_at = $3, is_pre_trusted = $4, attributes = $5"#) + .bind(identifier) + .bind(pre_trusted_identity.attested_by()) + .bind(pre_trusted_identity.added_at()) + .bind(true) + .bind(minicbor::to_vec(pre_trusted_identity.attrs())?); query2.execute(&mut *transaction).await.void()?; } @@ -104,6 +112,7 @@ mod tests { use ockam_core::compat::collections::BTreeMap; use ockam_core::compat::rand::RngCore; use ockam_core::compat::sync::Arc; + use ockam_node::database::with_dbs; use rand::thread_rng; fn random_identifier() -> Identifier { @@ -117,130 +126,133 @@ mod tests { #[tokio::test] async fn test_authority_members_repository_crud() -> Result<()> { - let repository = create_repository().await?; - - let admin = random_identifier(); - let timestamp1 = now()?; - - let identifier1 = random_identifier(); - let mut attributes1 = BTreeMap::, Vec>::default(); - attributes1.insert( - "role".as_bytes().to_vec(), - OCKAM_ROLE_ATTRIBUTE_ENROLLER_VALUE.as_bytes().to_vec(), - ); - let member1 = AuthorityMember::new( - identifier1.clone(), - attributes1, - admin.clone(), - timestamp1, - false, - ); - repository.add_member(member1.clone()).await?; - - let members = repository.get_members().await?; - assert_eq!(members.len(), 1); - assert!(members.contains(&member1)); - - let identifier2 = random_identifier(); - let mut attributes2 = BTreeMap::, Vec>::default(); - attributes2.insert("role".as_bytes().to_vec(), "user".as_bytes().to_vec()); - let timestamp2 = timestamp1 + 10; - let member2 = AuthorityMember::new( - identifier2.clone(), - attributes2, - admin.clone(), - timestamp2, - false, - ); - repository.add_member(member2.clone()).await?; - - let members = repository.get_members().await?; - assert_eq!(members.len(), 2); - assert!(members.contains(&member1)); - assert!(members.contains(&member2)); - - 
repository.delete_member(&identifier1).await?; - - let members = repository.get_members().await?; - assert_eq!(members.len(), 1); - assert!(members.contains(&member2)); - - Ok(()) + with_dbs(|db| async move { + let repository: Arc = + Arc::new(AuthorityMembersSqlxDatabase::new(db)); + + let admin = random_identifier(); + let timestamp1 = now()?; + + let identifier1 = random_identifier(); + let mut attributes1 = BTreeMap::, Vec>::default(); + attributes1.insert( + "role".as_bytes().to_vec(), + OCKAM_ROLE_ATTRIBUTE_ENROLLER_VALUE.as_bytes().to_vec(), + ); + let member1 = AuthorityMember::new( + identifier1.clone(), + attributes1, + admin.clone(), + timestamp1, + false, + ); + repository.add_member(member1.clone()).await?; + + let members = repository.get_members().await?; + assert_eq!(members.len(), 1); + assert!(members.contains(&member1)); + + let identifier2 = random_identifier(); + let mut attributes2 = BTreeMap::, Vec>::default(); + attributes2.insert("role".as_bytes().to_vec(), "user".as_bytes().to_vec()); + let timestamp2 = timestamp1 + 10; + let member2 = AuthorityMember::new( + identifier2.clone(), + attributes2, + admin.clone(), + timestamp2, + false, + ); + repository.add_member(member2.clone()).await?; + + let members = repository.get_members().await?; + assert_eq!(members.len(), 2); + assert!(members.contains(&member1)); + assert!(members.contains(&member2)); + + repository.delete_member(&identifier1).await?; + + let members = repository.get_members().await?; + assert_eq!(members.len(), 1); + assert!(members.contains(&member2)); + + Ok(()) + }) + .await } #[tokio::test] async fn test_authority_members_repository_bootstrap() -> Result<()> { - let repository = create_repository().await?; - - let mut pre_trusted_identities = BTreeMap::::default(); - - let timestamp1 = now()?; - - let authority = random_identifier(); - let identifier1 = random_identifier(); - let mut attributes1 = BTreeMap::, Vec>::default(); - attributes1.insert( - 
"role".as_bytes().to_vec(), - OCKAM_ROLE_ATTRIBUTE_ENROLLER_VALUE.as_bytes().to_vec(), - ); - - pre_trusted_identities.insert( - identifier1.clone(), - PreTrustedIdentity::new(attributes1.clone(), timestamp1, None, authority.clone()), - ); - - let identifier2 = random_identifier(); - let mut attributes2 = BTreeMap::, Vec>::default(); - attributes2.insert("role".as_bytes().to_vec(), "user".as_bytes().to_vec()); - let timestamp2 = timestamp1 + 10; - let timestamp3 = timestamp2 + 10; - - pre_trusted_identities.insert( - identifier2.clone(), - PreTrustedIdentity::new( - attributes2.clone(), - timestamp2, - Some(timestamp3), - identifier1.clone(), - ), - ); - - repository - .bootstrap_pre_trusted_members(&pre_trusted_identities.into()) - .await?; - - let members = repository.get_members().await?; - assert_eq!(members.len(), 2); - let member1 = members - .iter() - .find(|x| x.identifier() == &identifier1) - .unwrap(); - assert_eq!(member1.added_at(), timestamp1); - assert_eq!(member1.added_by(), &authority); - assert_eq!(member1.attributes(), &attributes1); - assert!(member1.is_pre_trusted()); - - let member2 = members - .iter() - .find(|x| x.identifier() == &identifier2) - .unwrap(); - assert_eq!(member2.added_at(), timestamp2); - assert_eq!(member2.added_by(), &identifier1); - assert_eq!(member2.attributes(), &attributes2); - assert!(member2.is_pre_trusted()); - - repository.delete_member(&identifier1).await?; - - let members = repository.get_members().await?; - assert_eq!(members.len(), 2); - assert!(members.contains(member2)); - assert!(members.contains(member1)); - - Ok(()) - } + with_dbs(|db| async move { + let repository: Arc = + Arc::new(AuthorityMembersSqlxDatabase::new(db)); + + let mut pre_trusted_identities = BTreeMap::::default(); - /// HELPERS - async fn create_repository() -> Result> { - Ok(Arc::new(AuthorityMembersSqlxDatabase::create().await?)) + let timestamp1 = now()?; + + let authority = random_identifier(); + let identifier1 = random_identifier(); + 
let mut attributes1 = BTreeMap::, Vec>::default(); + attributes1.insert( + "role".as_bytes().to_vec(), + OCKAM_ROLE_ATTRIBUTE_ENROLLER_VALUE.as_bytes().to_vec(), + ); + + pre_trusted_identities.insert( + identifier1.clone(), + PreTrustedIdentity::new(attributes1.clone(), timestamp1, None, authority.clone()), + ); + + let identifier2 = random_identifier(); + let mut attributes2 = BTreeMap::, Vec>::default(); + attributes2.insert("role".as_bytes().to_vec(), "user".as_bytes().to_vec()); + let timestamp2 = timestamp1 + 10; + let timestamp3 = timestamp2 + 10; + + pre_trusted_identities.insert( + identifier2.clone(), + PreTrustedIdentity::new( + attributes2.clone(), + timestamp2, + Some(timestamp3), + identifier1.clone(), + ), + ); + + repository + .bootstrap_pre_trusted_members(&pre_trusted_identities.into()) + .await?; + + let members = repository.get_members().await?; + assert_eq!(members.len(), 2); + let member1 = members + .iter() + .find(|x| x.identifier() == &identifier1) + .unwrap(); + assert_eq!(member1.added_at(), timestamp1); + assert_eq!(member1.added_by(), &authority); + assert_eq!(member1.attributes(), &attributes1); + assert!(member1.is_pre_trusted()); + + let member2 = members + .iter() + .find(|x| x.identifier() == &identifier2) + .unwrap(); + assert_eq!(member2.added_at(), timestamp2); + assert_eq!(member2.added_by(), &identifier1); + assert_eq!(member2.attributes(), &attributes2); + assert!(member2.is_pre_trusted()); + + repository.delete_member(&identifier1).await?; + + let members = repository.get_members().await?; + assert_eq!(members.len(), 2); + assert!(members.contains(member2)); + assert!(members.contains(member1)); + + Ok(()) + }) + .await } } diff --git a/implementations/rust/ockam/ockam_api/src/authenticator/storage/enrollment_token.rs b/implementations/rust/ockam/ockam_api/src/authenticator/storage/enrollment_token.rs index e7a677f1a10..2e1aee569ed 100644 --- 
a/implementations/rust/ockam/ockam_api/src/authenticator/storage/enrollment_token.rs +++ b/implementations/rust/ockam/ockam_api/src/authenticator/storage/enrollment_token.rs @@ -2,6 +2,7 @@ use crate::authenticator::one_time_code::OneTimeCode; use ockam::identity::{Identifier, TimestampInSeconds}; use ockam_core::compat::str::FromStr; use ockam_core::{Error, Result}; +use ockam_node::database::Nullable; use std::collections::BTreeMap; #[derive(Clone, Eq, PartialEq)] @@ -35,7 +36,7 @@ impl EnrollmentToken { #[derive(sqlx::FromRow)] pub(crate) struct EnrollmentTokenRow { one_time_code: String, - reference: Option, + reference: Nullable, issued_by: String, created_at: i64, expires_at: i64, @@ -49,7 +50,7 @@ impl TryFrom for EnrollmentToken { fn try_from(value: EnrollmentTokenRow) -> Result { let member = EnrollmentToken { one_time_code: OneTimeCode::from_str(&value.one_time_code)?, - reference: value.reference, + reference: value.reference.to_option(), issued_by: Identifier::from_str(&value.issued_by)?, created_at: TimestampInSeconds(value.created_at as u64), expires_at: TimestampInSeconds(value.expires_at as u64), diff --git a/implementations/rust/ockam/ockam_api/src/authority_node/authority.rs b/implementations/rust/ockam/ockam_api/src/authority_node/authority.rs index 9b4999a1a4a..e8be31a00cf 100644 --- a/implementations/rust/ockam/ockam_api/src/authority_node/authority.rs +++ b/implementations/rust/ockam/ockam_api/src/authority_node/authority.rs @@ -1,6 +1,4 @@ use std::collections::BTreeMap; -use std::path::Path; - use tracing::info; use crate::authenticator::credential_issuer::CredentialIssuerWorker; @@ -20,9 +18,8 @@ use ockam::identity::{ use ockam::tcp::{TcpListenerOptions, TcpTransport}; use ockam_core::compat::sync::Arc; use ockam_core::env::get_env; -use ockam_core::errcode::{Kind, Origin}; use ockam_core::flow_control::FlowControlId; -use ockam_core::{Error, Result}; +use ockam_core::Result; use ockam_node::database::SqlxDatabase; use ockam_node::Context; 
@@ -67,9 +64,7 @@ impl Authority { // create the database let node_name = "authority"; - let database_path = &configuration.database_path; - Self::create_ockam_directory_if_necessary(database_path)?; - let database = SqlxDatabase::create(database_path).await?; + let database = SqlxDatabase::create(&configuration.database_configuration).await?; let members = Arc::new(AuthorityMembersSqlxDatabase::new(database.clone())); let tokens = Arc::new(AuthorityEnrollmentTokenSqlxDatabase::new(database.clone())); let secure_channel_repository = Arc::new(SecureChannelSqlxDatabase::new(database.clone())); @@ -315,15 +310,6 @@ impl Authority { /// Private Authority functions impl Authority { - /// Create a directory to save storage files if they haven't been created before - fn create_ockam_directory_if_necessary(path: &Path) -> Result<()> { - let parent = path.parent().unwrap(); - if !parent.exists() { - std::fs::create_dir_all(parent).map_err(|e| Error::new(Origin::Node, Kind::Io, e))?; - } - Ok(()) - } - /// Make an identities repository pre-populated with the attributes of some trusted /// identities. 
The values either come from the command line or are read directly from a file /// every time we try to retrieve some attributes diff --git a/implementations/rust/ockam/ockam_api/src/authority_node/configuration.rs b/implementations/rust/ockam/ockam_api/src/authority_node/configuration.rs index bbb7eaadd91..a411a0231bf 100644 --- a/implementations/rust/ockam/ockam_api/src/authority_node/configuration.rs +++ b/implementations/rust/ockam/ockam_api/src/authority_node/configuration.rs @@ -1,5 +1,3 @@ -use std::path::PathBuf; - use ockam::identity::models::ChangeHistory; use serde::{Deserialize, Serialize}; @@ -7,6 +5,7 @@ use ockam::identity::Identifier; use ockam_core::compat::collections::HashMap; use ockam_core::compat::fmt; use ockam_core::compat::fmt::{Display, Formatter}; +use ockam_node::database::DatabaseConfiguration; use crate::authenticator::PreTrustedIdentities; use crate::config::lookup::InternetAddress; @@ -19,7 +18,7 @@ pub struct Configuration { pub identifier: Identifier, /// path where the database should be stored - pub database_path: PathBuf, + pub database_configuration: DatabaseConfiguration, /// Project id on the Orchestrator node pub project_identifier: String, diff --git a/implementations/rust/ockam/ockam_api/src/cli_state/cli_state.rs b/implementations/rust/ockam/ockam_api/src/cli_state/cli_state.rs index 163d2a36825..d9e792e0796 100644 --- a/implementations/rust/ockam/ockam_api/src/cli_state/cli_state.rs +++ b/implementations/rust/ockam/ockam_api/src/cli_state/cli_state.rs @@ -1,15 +1,14 @@ -use std::path::{Path, PathBuf}; - use rand::random; +use std::path::{Path, PathBuf}; +use tokio::sync::broadcast::{channel, Receiver, Sender}; -use cli_state::error::Result; use ockam::SqlxDatabase; use ockam_core::env::get_env_with_default; -use ockam_node::database::application_migration_set::ApplicationMigrationSet; +use ockam_node::database::DatabaseConfiguration; use ockam_node::Executor; -use tokio::sync::broadcast::{channel, Receiver, Sender}; -use 
crate::cli_state::{self, CliStateError}; +use crate::cli_state::error::Result; +use crate::cli_state::CliStateError; use crate::logs::ExportingEnabled; use crate::terminal::notification::Notification; @@ -68,16 +67,24 @@ impl CliState { &self.database } - pub fn database_path(&self) -> PathBuf { - Self::make_database_path(&self.dir) + pub fn database_configuration(&self) -> Result { + Self::make_database_configuration(&self.dir) + } + + pub fn is_database_path(&self, path: &Path) -> bool { + let database_configuration = self.database_configuration().ok(); + match database_configuration { + Some(c) => c.path() == Some(path.to_path_buf()), + None => false, + } } pub fn application_database(&self) -> SqlxDatabase { self.application_database.clone() } - pub fn application_database_path(&self) -> PathBuf { - Self::make_application_database_path(&self.dir) + pub fn application_database_configuration(&self) -> Result { + Self::make_application_database_configuration(&self.dir) } pub fn subscribe_to_notifications(&self) -> Receiver { @@ -118,17 +125,24 @@ impl CliState { self.delete_all_named_identities().await?; self.delete_all_nodes(true).await?; self.delete_all_named_vaults().await?; - self.delete() + self.delete().await } /// Removes all the directories storing state without loading the current state + /// The database data is only removed if the database is a SQLite one pub fn hard_reset() -> Result<()> { let dir = Self::default_dir()?; Self::delete_at(&dir) } /// Delete the local database and log files - pub fn delete(&self) -> Result<()> { + pub async fn delete(&self) -> Result<()> { + self.database.drop_postgres_node_tables().await?; + self.delete_local_data() + } + + /// Delete the local data on disk: sqlite database file and log files + pub fn delete_local_data(&self) -> Result<()> { Self::delete_at(&self.dir) } @@ -139,7 +153,8 @@ impl CliState { } /// Backup and reset is used to save aside - /// some corrupted local state for later inspection and then reset the 
state + /// some corrupted local state for later inspection and then reset the state. + /// The database is backed-up only if it is a SQLite database. pub fn backup_and_reset() -> Result<()> { let dir = Self::default_dir()?; @@ -189,12 +204,10 @@ impl CliState { /// Create a new CliState where the data is stored at a given path pub async fn create(dir: PathBuf) -> Result { std::fs::create_dir_all(&dir)?; - let database = SqlxDatabase::create(Self::make_database_path(&dir)).await?; - let application_database = SqlxDatabase::create_with_migration( - Self::make_application_database_path(&dir), - ApplicationMigrationSet, - ) - .await?; + let database = SqlxDatabase::create(&Self::make_database_configuration(&dir)?).await?; + let configuration = Self::make_application_database_configuration(&dir)?; + let application_database = + SqlxDatabase::create_application_database(&configuration).await?; debug!("Opened the main database with options {:?}", database); debug!( "Opened the application database with options {:?}", @@ -230,12 +243,26 @@ impl CliState { } } - pub(super) fn make_database_path(root_path: &Path) -> PathBuf { - root_path.join("database.sqlite3") + /// If the postgres database is configured, return the postgres configuration + pub(super) fn make_database_configuration(root_path: &Path) -> Result { + match DatabaseConfiguration::postgres()? { + Some(configuration) => Ok(configuration), + None => Ok(DatabaseConfiguration::sqlite( + root_path.join("database.sqlite3").as_path(), + )), + } } - pub(super) fn make_application_database_path(root_path: &Path) -> PathBuf { - root_path.join("application_database.sqlite3") + /// If the postgres database is configured, return the postgres configuration + pub(super) fn make_application_database_configuration( + root_path: &Path, + ) -> Result { + match DatabaseConfiguration::postgres()? 
{ + Some(configuration) => Ok(configuration), + None => Ok(DatabaseConfiguration::sqlite( + root_path.join("application_database.sqlite3").as_path(), + )), + } } pub(super) fn make_node_dir_path(root_path: &Path, node_name: &str) -> PathBuf { @@ -251,7 +278,9 @@ impl CliState { // Delete nodes logs let _ = std::fs::remove_dir_all(Self::make_nodes_dir_path(root_path)); // Delete the nodes database, keep the application database - let _ = std::fs::remove_file(Self::make_database_path(root_path)); + if let Some(path) = Self::make_database_configuration(root_path)?.path() { + std::fs::remove_file(path)? + }; Ok(()) } @@ -279,6 +308,9 @@ pub fn random_name() -> String { mod tests { use super::*; use itertools::Itertools; + use ockam_node::database::DatabaseType; + use sqlx::any::AnyRow; + use sqlx::Row; use std::fs; use tempfile::NamedTempFile; @@ -286,6 +318,11 @@ mod tests { async fn test_reset() -> Result<()> { let db_file = NamedTempFile::new().unwrap(); let cli_state_directory = db_file.path().parent().unwrap().join(random_name()); + let db = SqlxDatabase::create(&CliState::make_database_configuration( + &cli_state_directory, + )?) 
+ .await?; + db.drop_all_postgres_tables().await?; let cli = CliState::create(cli_state_directory.clone()).await?; // create 2 vaults @@ -310,16 +347,18 @@ mod tests { .await?; let file_names = list_file_names(&cli_state_directory); - assert_eq!( - file_names.iter().sorted().as_slice(), - [ + let expected = match cli.database_configuration()?.database_type() { + DatabaseType::Sqlite => vec![ "vault-vault2".to_string(), "application_database.sqlite3".to_string(), - "database.sqlite3".to_string() - ] - .iter() - .sorted() - .as_slice() + "database.sqlite3".to_string(), + ], + DatabaseType::Postgres => vec!["vault-vault2".to_string()], + }; + + assert_eq!( + file_names.iter().sorted().as_slice(), + expected.iter().sorted().as_slice() ); // reset the local state @@ -327,10 +366,25 @@ mod tests { let result = fs::read_dir(&cli_state_directory); assert!(result.is_ok(), "the cli state directory is not deleted"); - // only the application database must remain - let file_names = list_file_names(&cli_state_directory); - assert_eq!(file_names, vec!["application_database.sqlite3".to_string()]); - + match cli.database_configuration()?.database_type() { + DatabaseType::Sqlite => { + // When the database is SQLite, only the application database must remain + let file_names = list_file_names(&cli_state_directory); + let expected = vec!["application_database.sqlite3".to_string()]; + assert_eq!(file_names, expected); + } + DatabaseType::Postgres => { + // When the database is Postgres, only the journey tables must remain + let tables: Vec = sqlx::query( + "SELECT tablename::text FROM pg_tables WHERE schemaname = 'public'", + ) + .fetch_all(&*db.pool) + .await + .unwrap(); + let actual: Vec = tables.iter().map(|r| r.get(0)).sorted().collect(); + assert_eq!(actual, vec!["host_journey", "project_journey"]); + } + }; Ok(()) } diff --git a/implementations/rust/ockam/ockam_api/src/cli_state/identities.rs b/implementations/rust/ockam/ockam_api/src/cli_state/identities.rs index 
77597bd1ef9..cb3f9817ca3 100644 --- a/implementations/rust/ockam/ockam_api/src/cli_state/identities.rs +++ b/implementations/rust/ockam/ockam_api/src/cli_state/identities.rs @@ -66,8 +66,8 @@ impl CliState { ) -> Result { let vault = self.get_named_vault(vault_name).await?; - // Check that the vault is an KMS vault - if !vault.is_kms() { + // Check that the vault is an AWS KMS vault + if !vault.use_aws_kms() { return Err(Error::new( Origin::Api, Kind::Misuse, diff --git a/implementations/rust/ockam/ockam_api/src/cli_state/nodes.rs b/implementations/rust/ockam/ockam_api/src/cli_state/nodes.rs index d6d01288192..72619525ab0 100644 --- a/implementations/rust/ockam/ockam_api/src/cli_state/nodes.rs +++ b/implementations/rust/ockam/ockam_api/src/cli_state/nodes.rs @@ -155,18 +155,13 @@ impl CliState { /// - remove the node log files #[instrument(skip_all, fields(node_name = node_name))] pub async fn remove_node(&self, node_name: &str) -> Result<()> { - // don't try to remove a node on a non-existent database - if !self.database_path().exists() { - return Ok(()); - }; - // remove the node from the database let repository = self.nodes_repository(); let node_exists = repository.get_node(node_name).await.is_ok(); - repository.delete_node(node_name).await?; // set another node as the default node if node_exists { + repository.delete_node(node_name).await?; let other_nodes = repository.get_nodes().await?; if let Some(other_node) = other_nodes.first() { repository.set_default_node(&other_node.name()).await?; diff --git a/implementations/rust/ockam/ockam_api/src/cli_state/storage/enrollments_repository_sql.rs b/implementations/rust/ockam/ockam_api/src/cli_state/storage/enrollments_repository_sql.rs index 48303edd1fa..c5042e68644 100644 --- a/implementations/rust/ockam/ockam_api/src/cli_state/storage/enrollments_repository_sql.rs +++ b/implementations/rust/ockam/ockam_api/src/cli_state/storage/enrollments_repository_sql.rs @@ -1,14 +1,15 @@ use std::str::FromStr; -use 
sqlx::sqlite::SqliteRow; +use sqlx::any::AnyRow; use sqlx::FromRow; use sqlx::*; use time::OffsetDateTime; use ockam::identity::Identifier; -use ockam::{FromSqlxError, SqlxDatabase, ToSqlxType, ToVoid}; +use ockam::{FromSqlxError, SqlxDatabase, ToVoid}; use ockam_core::async_trait; use ockam_core::Result; +use ockam_node::database::{Boolean, Nullable}; use crate::cli_state::enrollments::IdentityEnrollment; use crate::cli_state::EnrollmentsRepository; @@ -35,10 +36,16 @@ impl EnrollmentsSqlxDatabase { #[async_trait] impl EnrollmentsRepository for EnrollmentsSqlxDatabase { async fn set_as_enrolled(&self, identifier: &Identifier, email: &EmailAddress) -> Result<()> { - let query = query("INSERT OR REPLACE INTO identity_enrollment(identifier, enrolled_at, email) VALUES (?, ?, ?)") - .bind(identifier.to_sql()) - .bind(OffsetDateTime::now_utc().to_sql()) - .bind(email.to_sql()); + let query = query( + r#" + INSERT INTO identity_enrollment (identifier, enrolled_at, email) + VALUES ($1, $2, $3) + ON CONFLICT (identifier) + DO UPDATE SET enrolled_at = $2, email = $3"#, + ) + .bind(identifier) + .bind(OffsetDateTime::now_utc().unix_timestamp()) + .bind(email); Ok(query.execute(&*self.database.pool).await.void()?) } @@ -94,11 +101,11 @@ impl EnrollmentsRepository for EnrollmentsSqlxDatabase { INNER JOIN named_identity ON identity.identifier = named_identity.identifier WHERE - named_identity.is_default = ? + named_identity.is_default = $1 "#, ) - .bind(true.to_sql()); - let result: Option = query + .bind(true); + let result: Option = query .fetch_optional(&*self.database.pool) .await .into_core()?; @@ -116,11 +123,11 @@ impl EnrollmentsRepository for EnrollmentsSqlxDatabase { INNER JOIN named_identity ON identity.identifier = named_identity.identifier WHERE - named_identity.name = ? 
+ named_identity.name = $1 "#, ) - .bind(name.to_sql()); - let result: Option = query + .bind(name); + let result: Option = query .fetch_optional(&*self.database.pool) .await .into_core()?; @@ -132,9 +139,9 @@ impl EnrollmentsRepository for EnrollmentsSqlxDatabase { pub struct EnrollmentRow { identifier: String, name: String, - email: Option, - is_default: bool, - enrolled_at: Option, + email: Nullable, + is_default: Boolean, + enrolled_at: Nullable, } impl EnrollmentRow { @@ -142,14 +149,14 @@ impl EnrollmentRow { let identifier = Identifier::from_str(self.identifier.as_str())?; let email = self .email - .as_ref() + .to_option() .map(|e| EmailAddress::parse(e.as_str())) .transpose()?; Ok(IdentityEnrollment::new( identifier, self.name.clone(), - self.is_default, + self.is_default.to_bool(), self.enrolled_at(), email, )) @@ -157,6 +164,7 @@ impl EnrollmentRow { fn enrolled_at(&self) -> Option { self.enrolled_at + .to_option() .map(|at| OffsetDateTime::from_unix_timestamp(at).unwrap_or(OffsetDateTime::now_utc())) } } diff --git a/implementations/rust/ockam/ockam_api/src/cli_state/storage/identities_repository_sql.rs b/implementations/rust/ockam/ockam_api/src/cli_state/storage/identities_repository_sql.rs index 41e57055471..d35ac1c2ac6 100644 --- a/implementations/rust/ockam/ockam_api/src/cli_state/storage/identities_repository_sql.rs +++ b/implementations/rust/ockam/ockam_api/src/cli_state/storage/identities_repository_sql.rs @@ -5,7 +5,7 @@ use sqlx::*; use ockam::identity::Identifier; use ockam_core::async_trait; use ockam_core::Result; -use ockam_node::database::{FromSqlxError, SqlxDatabase, ToSqlxType, ToVoid}; +use ockam_node::database::{Boolean, FromSqlxError, SqlxDatabase, ToVoid}; use crate::cli_state::{IdentitiesRepository, NamedIdentity}; @@ -40,17 +40,24 @@ impl IdentitiesRepository for IdentitiesSqlxDatabase { let mut transaction = self.database.begin().await.into_core()?; let query1 = query_scalar( - "SELECT EXISTS(SELECT 1 FROM named_identity WHERE 
is_default=$1 AND name=$2)", + "SELECT EXISTS(SELECT 1 FROM named_identity WHERE is_default = $1 AND name = $2)", ) - .bind(true.to_sql()) - .bind(name.to_sql()); - let is_already_default: bool = query1.fetch_one(&mut *transaction).await.into_core()?; - - let query2 = query("INSERT OR REPLACE INTO named_identity VALUES (?, ?, ?, ?)") - .bind(identifier.to_sql()) - .bind(name.to_sql()) - .bind(vault_name.to_sql()) - .bind(is_already_default.to_sql()); + .bind(true) + .bind(name); + let is_already_default: Boolean = query1.fetch_one(&mut *transaction).await.into_core()?; + let is_already_default = is_already_default.to_bool(); + + let query2 = query( + r#" + INSERT INTO named_identity (identifier, name, vault_name, is_default) + VALUES ($1, $2, $3, $4) + ON CONFLICT (identifier) + DO UPDATE SET name = $2, vault_name = $3, is_default = $4"#, + ) + .bind(identifier) + .bind(name) + .bind(vault_name) + .bind(is_already_default); query2.execute(&mut *transaction).await.void()?; transaction.commit().await.void()?; @@ -68,9 +75,9 @@ impl IdentitiesRepository for IdentitiesSqlxDatabase { // get the named identity let query1 = query_as( - "SELECT identifier, name, vault_name, is_default FROM named_identity WHERE name=$1", + "SELECT identifier, name, vault_name, is_default FROM named_identity WHERE name = $1", ) - .bind(name.to_sql()); + .bind(name); let row: Option = query1.fetch_optional(&mut *transaction).await.into_core()?; let named_identity = row.map(|r| r.named_identity()).transpose()?; @@ -81,7 +88,7 @@ impl IdentitiesRepository for IdentitiesSqlxDatabase { // otherwise delete it and set another identity as the default Some(named_identity) => { - let query2 = query("DELETE FROM named_identity WHERE name=?").bind(name.to_sql()); + let query2 = query("DELETE FROM named_identity WHERE name = $1").bind(name); query2.execute(&mut *transaction).await.void()?; // if the deleted identity was the default one, select another identity to be the default one @@ -93,9 +100,9 @@ 
impl IdentitiesRepository for IdentitiesSqlxDatabase { .into_core()? { let query3 = - query("UPDATE named_identity SET is_default = ? WHERE name = ?") - .bind(true.to_sql()) - .bind(other_name.to_sql()); + query("UPDATE named_identity SET is_default = $1 WHERE name = $2") + .bind(true) + .bind(other_name); query3.execute(&mut *transaction).await.void()? } } @@ -120,9 +127,9 @@ impl IdentitiesRepository for IdentitiesSqlxDatabase { async fn get_identifier(&self, name: &str) -> Result> { let query = query_as( - "SELECT identifier, name, vault_name, is_default FROM named_identity WHERE name=$1", + "SELECT identifier, name, vault_name, is_default FROM named_identity WHERE name = $1", ) - .bind(name.to_sql()); + .bind(name); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -135,7 +142,7 @@ impl IdentitiesRepository for IdentitiesSqlxDatabase { identifier: &Identifier, ) -> Result> { let query = - query_as("SELECT identifier, name, vault_name, is_default FROM named_identity WHERE identifier=$1").bind(identifier.to_sql()); + query_as("SELECT identifier, name, vault_name, is_default FROM named_identity WHERE identifier = $1").bind(identifier); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -145,9 +152,9 @@ impl IdentitiesRepository for IdentitiesSqlxDatabase { async fn get_named_identity(&self, name: &str) -> Result> { let query = query_as( - "SELECT identifier, name, vault_name, is_default FROM named_identity WHERE name=$1", + "SELECT identifier, name, vault_name, is_default FROM named_identity WHERE name = $1", ) - .bind(name.to_sql()); + .bind(name); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -160,7 +167,7 @@ impl IdentitiesRepository for IdentitiesSqlxDatabase { identifier: &Identifier, ) -> Result> { let query = - query_as("SELECT identifier, name, vault_name, is_default FROM named_identity WHERE identifier=$1").bind(identifier.to_sql()); + query_as("SELECT identifier, name, vault_name, 
is_default FROM named_identity WHERE identifier = $1").bind(identifier); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -178,7 +185,7 @@ impl IdentitiesRepository for IdentitiesSqlxDatabase { &self, vault_name: &str, ) -> Result> { - let query = query_as("SELECT identifier, name, vault_name, is_default FROM named_identity WHERE vault_name=?").bind(vault_name.to_sql()); + let query = query_as("SELECT identifier, name, vault_name, is_default FROM named_identity WHERE vault_name = $1").bind(vault_name); let row: Vec = query.fetch_all(&*self.database.pool).await.into_core()?; row.iter().map(|r| r.named_identity()).collect() } @@ -186,15 +193,15 @@ impl IdentitiesRepository for IdentitiesSqlxDatabase { async fn set_as_default(&self, name: &str) -> Result<()> { let mut transaction = self.database.begin().await.into_core()?; // set the identifier as the default one - let query1 = query("UPDATE named_identity SET is_default = ? WHERE name = ?") - .bind(true.to_sql()) - .bind(name.to_sql()); + let query1 = query("UPDATE named_identity SET is_default = $1 WHERE name = $2") + .bind(true) + .bind(name); query1.execute(&mut *transaction).await.void()?; // set all the others as non-default - let query2 = query("UPDATE named_identity SET is_default = ? WHERE name <> ?") - .bind(false.to_sql()) - .bind(name.to_sql()); + let query2 = query("UPDATE named_identity SET is_default = $1 WHERE name <> $2") + .bind(false) + .bind(name); query2.execute(&mut *transaction).await.void()?; transaction.commit().await.void() } @@ -202,22 +209,22 @@ impl IdentitiesRepository for IdentitiesSqlxDatabase { async fn set_as_default_by_identifier(&self, identifier: &Identifier) -> Result<()> { let mut transaction = self.database.begin().await.into_core()?; // set the identifier as the default one - let query1 = query("UPDATE named_identity SET is_default = ? 
WHERE identifier = ?") - .bind(true.to_sql()) - .bind(identifier.to_sql()); + let query1 = query("UPDATE named_identity SET is_default = $1 WHERE identifier = $2") + .bind(true) + .bind(identifier); query1.execute(&mut *transaction).await.void()?; // set all the others as non-default - let query2 = query("UPDATE named_identity SET is_default = ? WHERE identifier <> ?") - .bind(false.to_sql()) - .bind(identifier.to_sql()); + let query2 = query("UPDATE named_identity SET is_default = $1 WHERE identifier <> $2") + .bind(false) + .bind(identifier); query2.execute(&mut *transaction).await.void()?; transaction.commit().await.void() } async fn get_default_named_identity(&self) -> Result> { let query = - query_as("SELECT identifier, name, vault_name, is_default FROM named_identity WHERE is_default=$1").bind(true.to_sql()); + query_as("SELECT identifier, name, vault_name, is_default FROM named_identity WHERE is_default = $1").bind(true); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -231,7 +238,7 @@ pub(crate) struct NamedIdentityRow { identifier: String, name: String, vault_name: String, - is_default: bool, + is_default: Boolean, } impl NamedIdentityRow { @@ -253,7 +260,7 @@ impl NamedIdentityRow { self.identifier()?, self.name.clone(), self.vault_name.clone(), - self.is_default, + self.is_default.to_bool(), )) } } @@ -262,121 +269,130 @@ impl NamedIdentityRow { mod tests { use ockam::identity::identities; use ockam_core::compat::sync::Arc; + use ockam_node::database::with_dbs; use super::*; #[tokio::test] async fn test_identities_repository_named_identities() -> Result<()> { - let repository = create_repository().await?; - - // A name can be associated to an identity - let identifier1 = create_identity().await?; - repository - .store_named_identity(&identifier1, "name1", "vault") - .await?; - - let identifier2 = create_identity().await?; - repository - .store_named_identity(&identifier2, "name2", "vault") - .await?; - - let result = 
repository.get_identifier("name1").await?; - assert_eq!(result, Some(identifier1.clone())); - - let result = repository - .get_identity_name_by_identifier(&identifier1) - .await?; - assert_eq!(result, Some("name1".into())); - - let result = repository.get_named_identity("name2").await?; - assert_eq!(result.map(|n| n.identifier()), Some(identifier2.clone())); - - let result = repository.get_named_identities().await?; - assert_eq!( - result.iter().map(|n| n.identifier()).collect::>(), - vec![identifier1.clone(), identifier2.clone()] - ); - - repository.delete_identity("name1").await?; - let result = repository.get_named_identities().await?; - assert_eq!( - result.iter().map(|n| n.identifier()).collect::>(), - vec![identifier2.clone()] - ); - - Ok(()) + with_dbs(|db| async move { + let repository: Arc = + Arc::new(IdentitiesSqlxDatabase::new(db)); + + // A name can be associated to an identity + let identifier1 = create_identity().await?; + repository + .store_named_identity(&identifier1, "name1", "vault") + .await?; + + let identifier2 = create_identity().await?; + repository + .store_named_identity(&identifier2, "name2", "vault") + .await?; + + let result = repository.get_identifier("name1").await?; + assert_eq!(result, Some(identifier1.clone())); + + let result = repository + .get_identity_name_by_identifier(&identifier1) + .await?; + assert_eq!(result, Some("name1".into())); + + let result = repository.get_named_identity("name2").await?; + assert_eq!(result.map(|n| n.identifier()), Some(identifier2.clone())); + + let result = repository.get_named_identities().await?; + assert_eq!( + result.iter().map(|n| n.identifier()).collect::>(), + vec![identifier1.clone(), identifier2.clone()] + ); + + repository.delete_identity("name1").await?; + let result = repository.get_named_identities().await?; + assert_eq!( + result.iter().map(|n| n.identifier()).collect::>(), + vec![identifier2.clone()] + ); + + Ok(()) + }) + .await } #[tokio::test] async fn 
test_identities_repository_default_identities() -> Result<()> { - let repository = create_repository().await?; - - // A name can be associated to an identity - let identifier1 = create_identity().await?; - let named_identity1 = repository - .store_named_identity(&identifier1, "name1", "vault") - .await?; - - let identifier2 = create_identity().await?; - let named_identity2 = repository - .store_named_identity(&identifier2, "name2", "vault") - .await?; - - // An identity can be marked as being the default one - repository - .set_as_default_by_identifier(&identifier1) - .await?; - let result = repository.get_default_named_identity().await?; - assert_eq!(result, Some(named_identity1.set_as_default())); - - // An identity can be marked as being the default one by passing its name - repository.set_as_default("name2").await?; - let result = repository.get_default_named_identity().await?; - assert_eq!(result, Some(named_identity2.set_as_default())); - - let result = repository.get_named_identity("name1").await?; - assert!(!result.unwrap().is_default()); - - let result = repository.get_default_named_identity().await?; - assert_eq!(result.map(|i| i.name()), Some("name2".to_string())); - - Ok(()) + with_dbs(|db| async move { + let repository: Arc = + Arc::new(IdentitiesSqlxDatabase::new(db)); + + // A name can be associated to an identity + let identifier1 = create_identity().await?; + let named_identity1 = repository + .store_named_identity(&identifier1, "name1", "vault") + .await?; + + let identifier2 = create_identity().await?; + let named_identity2 = repository + .store_named_identity(&identifier2, "name2", "vault") + .await?; + + // An identity can be marked as being the default one + repository + .set_as_default_by_identifier(&identifier1) + .await?; + let result = repository.get_default_named_identity().await?; + assert_eq!(result, Some(named_identity1.set_as_default())); + + // An identity can be marked as being the default one by passing its name + 
repository.set_as_default("name2").await?; + let result = repository.get_default_named_identity().await?; + assert_eq!(result, Some(named_identity2.set_as_default())); + + let result = repository.get_named_identity("name1").await?; + assert!(!result.unwrap().is_default()); + + let result = repository.get_default_named_identity().await?; + assert_eq!(result.map(|i| i.name()), Some("name2".to_string())); + + Ok(()) + }) + .await } #[tokio::test] async fn test_get_identities_by_vault_name() -> Result<()> { - let repository = create_repository().await?; - - // A name can be associated to an identity - let identifier1 = create_identity().await?; - repository - .store_named_identity(&identifier1, "name1", "vault1") - .await?; - - let identifier2 = create_identity().await?; - repository - .store_named_identity(&identifier2, "name2", "vault2") - .await?; - - let identifier3 = create_identity().await?; - repository - .store_named_identity(&identifier3, "name3", "vault1") - .await?; - - let result = repository - .get_named_identities_by_vault_name("vault1") - .await?; - let names: Vec = result.iter().map(|i| i.name()).collect(); - assert_eq!(names, vec!["name1", "name3"]); - - Ok(()) + with_dbs(|db| async move { + let repository: Arc = + Arc::new(IdentitiesSqlxDatabase::new(db)); + + // A name can be associated to an identity + let identifier1 = create_identity().await?; + repository + .store_named_identity(&identifier1, "name1", "vault1") + .await?; + + let identifier2 = create_identity().await?; + repository + .store_named_identity(&identifier2, "name2", "vault2") + .await?; + + let identifier3 = create_identity().await?; + repository + .store_named_identity(&identifier3, "name3", "vault1") + .await?; + + let result = repository + .get_named_identities_by_vault_name("vault1") + .await?; + let names: Vec = result.iter().map(|i| i.name()).collect(); + assert_eq!(names, vec!["name1", "name3"]); + + Ok(()) + }) + .await } /// HELPERS - async fn create_repository() -> Result> { 
- Ok(Arc::new(IdentitiesSqlxDatabase::create().await?)) - } - async fn create_identity() -> Result { let identities = identities().await?; identities.identities_creation().create_identity().await diff --git a/implementations/rust/ockam/ockam_api/src/cli_state/storage/journeys_repository_sql.rs b/implementations/rust/ockam/ockam_api/src/cli_state/storage/journeys_repository_sql.rs index 7ebdf8020c1..4d160009b82 100644 --- a/implementations/rust/ockam/ockam_api/src/cli_state/storage/journeys_repository_sql.rs +++ b/implementations/rust/ockam/ockam_api/src/cli_state/storage/journeys_repository_sql.rs @@ -6,7 +6,7 @@ use crate::cli_state::JourneysRepository; use ockam_core::errcode::{Kind, Origin}; use ockam_core::Result; use ockam_core::{async_trait, OpenTelemetryContext}; -use ockam_node::database::{FromSqlxError, SqlxDatabase, ToSqlxType, ToVoid}; +use ockam_node::database::{FromSqlxError, Nullable, SqlxDatabase, ToVoid}; #[derive(Clone)] pub struct JourneysSqlxDatabase { @@ -16,7 +16,7 @@ pub struct JourneysSqlxDatabase { impl JourneysSqlxDatabase { /// Create a new database pub fn new(database: SqlxDatabase) -> Self { - debug!("create a repository for spaces"); + debug!("create a repository for user journeys"); Self { database } } @@ -31,15 +31,20 @@ impl JourneysSqlxDatabase { #[async_trait] impl JourneysRepository for JourneysSqlxDatabase { async fn store_project_journey(&self, project_journey: ProjectJourney) -> Result<()> { - let query = query("INSERT OR REPLACE INTO project_journey VALUES (?, ?, ?, ?)") - .bind(project_journey.project_id().to_sql()) - .bind(project_journey.opentelemetry_context().to_string().to_sql()) - .bind(project_journey.start().to_sql()) - .bind( - project_journey - .previous_opentelemetry_context() - .map(|c| c.to_string().to_sql()), - ); + let previous: Option = project_journey + .previous_opentelemetry_context() + .map(|c| c.to_string()); + let query = query( + r#" + INSERT INTO project_journey (project_id, opentelemetry_context, 
start_datetime, previous_opentelemetry_context) + VALUES ($1, $2, $3, $4) + ON CONFLICT (opentelemetry_context) + DO UPDATE SET project_id = $1, start_datetime = $3, previous_opentelemetry_context = $4"#, + ) + .bind(project_journey.project_id()) + .bind(project_journey.opentelemetry_context().to_string()) + .bind(project_journey.start().to_rfc3339()) + .bind(previous); query.execute(&*self.database.pool).await.void() } @@ -52,12 +57,12 @@ impl JourneysRepository for JourneysSqlxDatabase { "\ SELECT project_id, opentelemetry_context, start_datetime, previous_opentelemetry_context \ FROM project_journey \ - WHERE project_id = ? AND start_datetime <= ? \ + WHERE project_id = $1 AND start_datetime <= $2 \ ORDER BY start_datetime DESC \ LIMIT 1 OFFSET 0", ) - .bind(project_id.to_sql()) - .bind(now.to_sql()); + .bind(project_id) + .bind(now.to_rfc3339()); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -66,33 +71,39 @@ impl JourneysRepository for JourneysSqlxDatabase { } async fn delete_project_journeys(&self, project_id: &str) -> Result<()> { - let query = - query("DELETE FROM project_journey where project_id = ?").bind(project_id.to_sql()); + let query = query("DELETE FROM project_journey where project_id = $1").bind(project_id); query.execute(&*self.database.pool).await.void() } async fn store_host_journey(&self, host_journey: Journey) -> Result<()> { - let query = query("INSERT OR REPLACE INTO host_journey VALUES (?, ?, ?)") - .bind(host_journey.opentelemetry_context().to_string().to_sql()) - .bind(host_journey.start().to_sql()) - .bind( - host_journey - .previous_opentelemetry_context() - .map(|c| c.to_string().to_sql()), - ); + let query = query( + r#" + INSERT INTO host_journey (opentelemetry_context, start_datetime, previous_opentelemetry_context) + VALUES ($1, $2, $3) + ON CONFLICT (opentelemetry_context) + DO UPDATE SET start_datetime = $2, previous_opentelemetry_context = $3"#, + ) + 
.bind(host_journey.opentelemetry_context().to_string()) + .bind(host_journey.start().to_rfc3339()) + .bind( + host_journey + .previous_opentelemetry_context() + .map(|c| c.to_string()), + ); query.execute(&*self.database.pool).await.void() } async fn get_host_journey(&self, now: DateTime) -> Result> { let query = query_as( - "\ - SELECT opentelemetry_context, start_datetime, previous_opentelemetry_context \ - FROM host_journey \ - WHERE start_datetime <= ? \ - ORDER BY start_datetime DESC \ - LIMIT 1 OFFSET 0", + r#" + SELECT opentelemetry_context, start_datetime, previous_opentelemetry_context + FROM host_journey + WHERE start_datetime <= $1 + ORDER BY start_datetime DESC + LIMIT 1 OFFSET 0 + "#, ) - .bind(now.to_sql()); + .bind(now.to_rfc3339()); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -101,15 +112,13 @@ impl JourneysRepository for JourneysSqlxDatabase { } } -// Database serialization / deserialization - /// Low-level representation of a row in the project journey table #[derive(sqlx::FromRow)] struct ProjectJourneyRow { project_id: String, opentelemetry_context: String, start_datetime: String, - previous_opentelemetry_context: Option, + previous_opentelemetry_context: Nullable, } impl ProjectJourneyRow { @@ -128,7 +137,7 @@ impl ProjectJourneyRow { fn previous_opentelemetry_context(&self) -> Result> { self.previous_opentelemetry_context - .clone() + .to_option() .map(|c| c.try_into()) .transpose() } @@ -147,7 +156,7 @@ impl ProjectJourneyRow { struct HostJourneyRow { opentelemetry_context: String, start_datetime: String, - previous_opentelemetry_context: Option, + previous_opentelemetry_context: Nullable, } impl HostJourneyRow { @@ -165,7 +174,7 @@ impl HostJourneyRow { fn previous_opentelemetry_context(&self) -> Result> { self.previous_opentelemetry_context - .clone() + .to_option() .map(|c| c.try_into()) .transpose() } @@ -184,6 +193,7 @@ mod test { use super::*; use crate::cli_state::journeys::{Journey, ProjectJourney}; use 
crate::cli_state::JourneysRepository; + use ockam_node::database::with_application_dbs; use std::ops::{Add, Sub}; use std::str::FromStr; use std::sync::Arc; @@ -191,167 +201,170 @@ mod test { #[tokio::test] async fn test_repository() -> Result<()> { - let repository = create_repository().await?; - - // the repository is initially empty - let actual = repository.get_host_journey(Utc::now()).await?; - assert_eq!(actual, None); - - // create and store a host journey - let opentelemetry_context = OpenTelemetryContext::from_str("{\"traceparent\":\"00-b9ce70eaad5a86ef6b9fa4db00589e86-8e2d99c5e5ed66e4-01\",\"tracestate\":\"\"}").unwrap(); - let host_journey = Journey::new(opentelemetry_context.clone(), None, Utc::now()); - repository.store_host_journey(host_journey.clone()).await?; - let actual = repository.get_host_journey(Utc::now()).await?; - assert_eq!(actual, Some(host_journey)); - - // create and store a project journey - let project_journey = - ProjectJourney::new("project_id", opentelemetry_context, None, Utc::now()); - repository - .store_project_journey(project_journey.clone()) - .await?; - let actual = repository - .get_project_journey("project_id", Utc::now()) - .await?; - assert_eq!(actual, Some(project_journey)); - - // delete a project journey - repository.delete_project_journeys("project_id").await?; - let actual = repository - .get_project_journey("project_id", Utc::now()) - .await?; - assert_eq!(actual, None); - Ok(()) + with_application_dbs(|db| async move { + let repository: Arc = + Arc::new(JourneysSqlxDatabase::new(db)); + + // the repository is initially empty + let actual = repository.get_host_journey(Utc::now()).await?; + assert_eq!(actual, None); + + // create and store a host journey + let opentelemetry_context = OpenTelemetryContext::from_str("{\"traceparent\":\"00-b9ce70eaad5a86ef6b9fa4db00589e86-8e2d99c5e5ed66e4-01\",\"tracestate\":\"\"}").unwrap(); + let host_journey = Journey::new(opentelemetry_context.clone(), None, Utc::now()); + 
repository.store_host_journey(host_journey.clone()).await?; + let actual = repository.get_host_journey(Utc::now()).await?; + assert_eq!(actual, Some(host_journey)); + + // create and store a project journey + let project_journey = + ProjectJourney::new("project_id", opentelemetry_context, None, Utc::now()); + repository + .store_project_journey(project_journey.clone()) + .await?; + let actual = repository + .get_project_journey("project_id", Utc::now()) + .await?; + assert_eq!(actual, Some(project_journey)); + + // delete a project journey + repository.delete_project_journeys("project_id").await?; + let actual = repository + .get_project_journey("project_id", Utc::now()) + .await?; + assert_eq!(actual, None); + Ok(()) + }).await } /// This test checks that we can store host journeys with a previous / next relationship #[tokio::test] async fn test_several_host_journeys() -> Result<()> { - let repository = create_repository().await?; - - // create and store a the first host journey - let opentelemetry_context1 = OpenTelemetryContext::from_str("{\"traceparent\":\"00-b9ce70eaad5a86ef6b9fa4db00589e86-8e2d99c5e5ed66e4-01\",\"tracestate\":\"\"}").unwrap(); - let start1 = Utc::now(); - let host_journey1 = Journey::new(opentelemetry_context1.clone(), None, start1); - repository.store_host_journey(host_journey1.clone()).await?; - - // retrieve the journey based on the time - // before the journey 1 start -> None - // equal or after the journey 1 start -> Some(journey1) - let actual = repository - .get_host_journey(start1.sub(Duration::from_secs(3))) - .await?; - assert_eq!(actual, None); - - let actual = repository.get_host_journey(start1).await?; - assert_eq!(actual, Some(host_journey1.clone())); - - let actual = repository - .get_host_journey(start1.add(Duration::from_secs(3))) - .await?; - assert_eq!(actual, Some(host_journey1.clone())); - - // Create the next journey - let opentelemetry_context2 = 
OpenTelemetryContext::from_str("{\"traceparent\":\"00-b9ce70eaad5a86ef6b9fa4db00589e86-8e2d99c5e5ed66e4-02\",\"tracestate\":\"\"}").unwrap(); - let start2 = start1.add(Duration::from_secs(1000)); - let host_journey2 = Journey::new( - opentelemetry_context2.clone(), - Some(opentelemetry_context1), - start2, - ); - repository.store_host_journey(host_journey2.clone()).await?; - - // retrieve the journey based on the time - // right before the journey 2 start -> Some(journey1) - // equal or after the journey 2 start -> Some(journey2) - let actual = repository - .get_host_journey(start2.sub(Duration::from_secs(3))) - .await?; - assert_eq!(actual, Some(host_journey1.clone())); - - let actual = repository.get_host_journey(start2).await?; - assert_eq!(actual, Some(host_journey2.clone())); - assert_eq!( - host_journey2.previous_opentelemetry_context(), - Some(host_journey1.opentelemetry_context()) - ); + with_application_dbs(|db| async move { + let repository: Arc = + Arc::new(JourneysSqlxDatabase::new(db)); + + // create and store a the first host journey + let opentelemetry_context1 = OpenTelemetryContext::from_str("{\"traceparent\":\"00-b9ce70eaad5a86ef6b9fa4db00589e86-8e2d99c5e5ed66e4-01\",\"tracestate\":\"\"}").unwrap(); + let start1 = Utc::now(); + let host_journey1 = Journey::new(opentelemetry_context1.clone(), None, start1); + repository.store_host_journey(host_journey1.clone()).await?; + + // retrieve the journey based on the time + // before the journey 1 start -> None + // equal or after the journey 1 start -> Some(journey1) + let actual = repository + .get_host_journey(start1.sub(Duration::from_secs(3))) + .await?; + assert_eq!(actual, None); + + let actual = repository.get_host_journey(start1).await?; + assert_eq!(actual, Some(host_journey1.clone())); + + let actual = repository + .get_host_journey(start1.add(Duration::from_secs(3))) + .await?; + assert_eq!(actual, Some(host_journey1.clone())); + + // Create the next journey + let opentelemetry_context2 = 
OpenTelemetryContext::from_str("{\"traceparent\":\"00-b9ce70eaad5a86ef6b9fa4db00589e86-8e2d99c5e5ed66e4-02\",\"tracestate\":\"\"}").unwrap(); + let start2 = start1.add(Duration::from_secs(1000)); + let host_journey2 = Journey::new( + opentelemetry_context2.clone(), + Some(opentelemetry_context1), + start2, + ); + repository.store_host_journey(host_journey2.clone()).await?; + // retrieve the journey based on the time + // right before the journey 2 start -> Some(journey1) + // equal or after the journey 2 start -> Some(journey2) + let actual = repository + .get_host_journey(start2.sub(Duration::from_secs(3))) + .await?; + assert_eq!(actual, Some(host_journey1.clone())); + + let actual = repository.get_host_journey(start2).await?; + assert_eq!(actual, Some(host_journey2.clone())); + assert_eq!( + host_journey2.previous_opentelemetry_context(), + Some(host_journey1.opentelemetry_context()) + ); - let actual = repository - .get_host_journey(start2.add(Duration::from_secs(3))) - .await?; - assert_eq!(actual, Some(host_journey2)); + let actual = repository + .get_host_journey(start2.add(Duration::from_secs(3))) + .await?; + assert_eq!(actual, Some(host_journey2)); - Ok(()) + Ok(()) + }).await } /// This test checks that we can store project journeys with a previous / next relationship #[tokio::test] async fn test_several_project_journeys() -> Result<()> { - let repository = create_repository().await?; - - // create and store a the first host journey - let opentelemetry_context1 = OpenTelemetryContext::from_str("{\"traceparent\":\"00-b9ce70eaad5a86ef6b9fa4db00589e86-8e2d99c5e5ed66e4-01\",\"tracestate\":\"\"}").unwrap(); - let start1 = Utc::now(); - let project_journey1 = - ProjectJourney::new("project_id", opentelemetry_context1.clone(), None, start1); - repository - .store_project_journey(project_journey1.clone()) - .await?; - - // retrieve the journey based on the time - // before the journey 1 start -> None - // equal or after the journey 1 start -> Some(journey1) - 
let actual = repository - .get_project_journey("project_id", start1.sub(Duration::from_secs(3))) - .await?; - assert_eq!(actual, None); - - let actual = repository.get_project_journey("project_id", start1).await?; - assert_eq!(actual, Some(project_journey1.clone())); - - let actual = repository - .get_project_journey("project_id", start1.add(Duration::from_secs(3))) - .await?; - assert_eq!(actual, Some(project_journey1.clone())); - - // Create the next journey - let opentelemetry_context2 = OpenTelemetryContext::from_str("{\"traceparent\":\"00-b9ce70eaad5a86ef6b9fa4db00589e86-8e2d99c5e5ed66e4-02\",\"tracestate\":\"\"}").unwrap(); - let start2 = start1.add(Duration::from_secs(1000)); - let project_journey2 = ProjectJourney::new( - "project_id", - opentelemetry_context2.clone(), - Some(opentelemetry_context1), - start2, - ); - repository - .store_project_journey(project_journey2.clone()) - .await?; - - // retrieve the journey based on the time - // right before the journey 2 start -> Some(journey1) - // equal or after the journey 2 start -> Some(journey2) - let actual = repository - .get_project_journey("project_id", start2.sub(Duration::from_secs(3))) - .await?; - assert_eq!(actual, Some(project_journey1.clone())); - - let actual = repository.get_project_journey("project_id", start2).await?; - assert_eq!(actual, Some(project_journey2.clone())); - assert_eq!( - project_journey2.previous_opentelemetry_context(), - Some(project_journey1.opentelemetry_context()) - ); - - let actual = repository - .get_project_journey("project_id", start2.add(Duration::from_secs(3))) - .await?; - assert_eq!(actual, Some(project_journey2)); + with_application_dbs(|db| async move { + let repository: Arc = + Arc::new(JourneysSqlxDatabase::new(db)); + + // create and store a the first host journey + let opentelemetry_context1 = OpenTelemetryContext::from_str("{\"traceparent\":\"00-b9ce70eaad5a86ef6b9fa4db00589e86-8e2d99c5e5ed66e4-01\",\"tracestate\":\"\"}").unwrap(); + let start1 = 
Utc::now(); + let project_journey1 = + ProjectJourney::new("project_id", opentelemetry_context1.clone(), None, start1); + repository + .store_project_journey(project_journey1.clone()) + .await?; + + // retrieve the journey based on the time + // before the journey 1 start -> None + // equal or after the journey 1 start -> Some(journey1) + let actual = repository + .get_project_journey("project_id", start1.sub(Duration::from_secs(3))) + .await?; + assert_eq!(actual, None); + + let actual = repository.get_project_journey("project_id", start1).await?; + assert_eq!(actual, Some(project_journey1.clone())); + + let actual = repository + .get_project_journey("project_id", start1.add(Duration::from_secs(3))) + .await?; + assert_eq!(actual, Some(project_journey1.clone())); + + // Create the next journey + let opentelemetry_context2 = OpenTelemetryContext::from_str("{\"traceparent\":\"00-b9ce70eaad5a86ef6b9fa4db00589e86-8e2d99c5e5ed66e4-02\",\"tracestate\":\"\"}").unwrap(); + let start2 = start1.add(Duration::from_secs(1000)); + let project_journey2 = ProjectJourney::new( + "project_id", + opentelemetry_context2.clone(), + Some(opentelemetry_context1), + start2, + ); + repository + .store_project_journey(project_journey2.clone()) + .await?; + + // retrieve the journey based on the time + // right before the journey 2 start -> Some(journey1) + // equal or after the journey 2 start -> Some(journey2) + let actual = repository + .get_project_journey("project_id", start2.sub(Duration::from_secs(3))) + .await?; + assert_eq!(actual, Some(project_journey1.clone())); + + let actual = repository.get_project_journey("project_id", start2).await?; + assert_eq!(actual, Some(project_journey2.clone())); + assert_eq!( + project_journey2.previous_opentelemetry_context(), + Some(project_journey1.opentelemetry_context()) + ); - Ok(()) - } + let actual = repository + .get_project_journey("project_id", start2.add(Duration::from_secs(3))) + .await?; + assert_eq!(actual, Some(project_journey2)); - 
/// HELPERS - async fn create_repository() -> Result> { - Ok(Arc::new(JourneysSqlxDatabase::create().await?)) + Ok(()) + }).await } } diff --git a/implementations/rust/ockam/ockam_api/src/cli_state/storage/nodes_repository_sql.rs b/implementations/rust/ockam/ockam_api/src/cli_state/storage/nodes_repository_sql.rs index 581a87b1ea4..0e813487f54 100644 --- a/implementations/rust/ockam/ockam_api/src/cli_state/storage/nodes_repository_sql.rs +++ b/implementations/rust/ockam/ockam_api/src/cli_state/storage/nodes_repository_sql.rs @@ -1,13 +1,16 @@ use std::str::FromStr; -use sqlx::sqlite::SqliteRow; +use sqlx::any::AnyRow; +use sqlx::database::HasArguments; +use sqlx::encode::IsNull; use sqlx::*; use ockam::identity::Identifier; -use ockam::{FromSqlxError, SqlxDatabase, ToSqlxType, ToVoid}; +use ockam::{FromSqlxError, SqlxDatabase, ToVoid}; use ockam_core::async_trait; use ockam_core::Result; +use ockam_node::database::{Boolean, Nullable}; use crate::cli_state::{NodeInfo, NodesRepository}; use crate::config::lookup::InternetAddress; @@ -32,24 +35,28 @@ impl NodesSqlxDatabase { #[async_trait] impl NodesRepository for NodesSqlxDatabase { async fn store_node(&self, node_info: &NodeInfo) -> Result<()> { - let query = query("INSERT OR REPLACE INTO node VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)") - .bind(node_info.name().to_sql()) - .bind(node_info.identifier().to_sql()) - .bind(node_info.verbosity().to_sql()) - .bind(node_info.is_default().to_sql()) - .bind(node_info.is_authority_node().to_sql()) + let query = query(r#" + INSERT INTO node (name, identifier, verbosity, is_default, is_authority, tcp_listener_address, pid, http_server_address) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (name) + DO UPDATE SET identifier = $2, verbosity = $3, is_default = $4, is_authority = $5, tcp_listener_address = $6, pid = $7, http_server_address = $8"#) + .bind(node_info.name()) + .bind(node_info.identifier()) + .bind(node_info.verbosity() as i16) + .bind(node_info.is_default()) + 
.bind(node_info.is_authority_node()) .bind( node_info .tcp_listener_address() .as_ref() - .map(|a| a.to_string().to_sql()), + .map(|a| a.to_string()), ) - .bind(node_info.pid().map(|p| p.to_sql())) + .bind(node_info.pid().map(|p| p as i32)) .bind( node_info .http_server_address() .as_ref() - .map(|a| a.to_string().to_sql()), + .map(|a| a.to_string()), ); Ok(query.execute(&*self.database.pool).await.void()?) } @@ -61,7 +68,7 @@ impl NodesRepository for NodesSqlxDatabase { } async fn get_node(&self, node_name: &str) -> Result> { - let query = query_as("SELECT name, identifier, verbosity, is_default, is_authority, tcp_listener_address, pid, http_server_address FROM node WHERE name = ?").bind(node_name.to_sql()); + let query = query_as("SELECT name, identifier, verbosity, is_default, is_authority, tcp_listener_address, pid, http_server_address FROM node WHERE name = $1").bind(node_name); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -70,13 +77,13 @@ impl NodesRepository for NodesSqlxDatabase { } async fn get_nodes_by_identifier(&self, identifier: &Identifier) -> Result> { - let query = query_as("SELECT name, identifier, verbosity, is_default, is_authority, tcp_listener_address, pid, http_server_address FROM node WHERE identifier = ?").bind(identifier.to_sql()); + let query = query_as("SELECT name, identifier, verbosity, is_default, is_authority, tcp_listener_address, pid, http_server_address FROM node WHERE identifier = $1").bind(identifier.to_string()); let rows: Vec = query.fetch_all(&*self.database.pool).await.into_core()?; rows.iter().map(|r| r.node_info()).collect() } async fn get_default_node(&self) -> Result> { - let query = query_as("SELECT name, identifier, verbosity, is_default, is_authority, tcp_listener_address, pid, http_server_address FROM node WHERE is_default = ?").bind(true.to_sql()); + let query = query_as("SELECT name, identifier, verbosity, is_default, is_authority, tcp_listener_address, pid, http_server_address FROM node 
WHERE is_default = $1").bind(true); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -85,26 +92,28 @@ impl NodesRepository for NodesSqlxDatabase { } async fn is_default_node(&self, node_name: &str) -> Result { - let query = query("SELECT is_default FROM node WHERE name = ?").bind(node_name.to_sql()); - let row: Option = query + let query = query("SELECT is_default FROM node WHERE name = $1").bind(node_name); + let row: Option = query .fetch_optional(&*self.database.pool) .await .into_core()?; - Ok(row.map(|r| r.get(0)).unwrap_or(false)) + Ok(row + .map(|r| r.get::(0).to_bool()) + .unwrap_or(false)) } async fn set_default_node(&self, node_name: &str) -> Result<()> { let mut transaction = self.database.begin().await.into_core()?; // set the node as the default one - let query1 = query("UPDATE node SET is_default = ? WHERE name = ?") - .bind(true.to_sql()) - .bind(node_name.to_sql()); + let query1 = query("UPDATE node SET is_default = $1 WHERE name = $2") + .bind(true) + .bind(node_name); query1.execute(&mut *transaction).await.void()?; // set all the others as non-default - let query2 = query("UPDATE node SET is_default = ? 
WHERE name <> ?") - .bind(false.to_sql()) - .bind(node_name.to_sql()); + let query2 = query("UPDATE node SET is_default = $1 WHERE name <> $2") + .bind(false) + .bind(node_name); query2.execute(&mut *transaction).await.void()?; transaction.commit().await.void() } @@ -112,37 +121,34 @@ impl NodesRepository for NodesSqlxDatabase { async fn delete_node(&self, node_name: &str) -> Result<()> { let mut transaction = self.database.begin().await.into_core()?; - let query = query("DELETE FROM node WHERE name=?").bind(node_name.to_sql()); + let query = query("DELETE FROM node WHERE name = $1").bind(node_name); query.execute(&mut *transaction).await.void()?; - let query = - sqlx::query("DELETE FROM credential WHERE node_name=?").bind(node_name.to_sql()); + let query = sqlx::query("DELETE FROM credential WHERE node_name = $1").bind(node_name); query.execute(&mut *transaction).await.void()?; - let query = sqlx::query("DELETE FROM resource WHERE node_name=?").bind(node_name.to_sql()); + let query = sqlx::query("DELETE FROM resource WHERE node_name = $1").bind(node_name); query.execute(&mut *transaction).await.void()?; - let query = - sqlx::query("DELETE FROM resource_policy WHERE node_name=?").bind(node_name.to_sql()); + let query = sqlx::query("DELETE FROM resource_policy WHERE node_name = $1").bind(node_name); query.execute(&mut *transaction).await.void()?; - let query = sqlx::query("DELETE FROM resource_type_policy WHERE node_name=?") - .bind(node_name.to_sql()); + let query = + sqlx::query("DELETE FROM resource_type_policy WHERE node_name = $1").bind(node_name); query.execute(&mut *transaction).await.void()?; - let query = sqlx::query("DELETE FROM identity_attributes WHERE node_name=?") - .bind(node_name.to_sql()); + let query = + sqlx::query("DELETE FROM identity_attributes WHERE node_name = $1").bind(node_name); query.execute(&mut *transaction).await.void()?; - let query = sqlx::query("DELETE FROM tcp_inlet WHERE node_name=?").bind(node_name.to_sql()); + let query = 
sqlx::query("DELETE FROM tcp_inlet WHERE node_name = $1").bind(node_name); query.execute(&mut *transaction).await.void()?; let query = - sqlx::query("DELETE FROM tcp_outlet_status WHERE node_name=?").bind(node_name.to_sql()); + sqlx::query("DELETE FROM tcp_outlet_status WHERE node_name = $1").bind(node_name); query.execute(&mut *transaction).await.void()?; - let query = - sqlx::query("DELETE FROM node_project WHERE node_name=?").bind(node_name.to_sql()); + let query = sqlx::query("DELETE FROM node_project WHERE node_name = $1").bind(node_name); query.execute(&mut *transaction).await.void()?; transaction.commit().await.void() @@ -153,9 +159,9 @@ impl NodesRepository for NodesSqlxDatabase { node_name: &str, address: &InternetAddress, ) -> Result<()> { - let query = query("UPDATE node SET tcp_listener_address = ? WHERE name = ?") - .bind(address.to_string().to_sql()) - .bind(node_name.to_sql()); + let query = query("UPDATE node SET tcp_listener_address = $1 WHERE name = $2") + .bind(address) + .bind(node_name); query.execute(&*self.database.pool).await.void() } @@ -164,16 +170,16 @@ impl NodesRepository for NodesSqlxDatabase { node_name: &str, address: &InternetAddress, ) -> Result<()> { - let query = query("UPDATE node SET http_server_address = ? WHERE name = ?") - .bind(address.to_string().to_sql()) - .bind(node_name.to_sql()); + let query = query("UPDATE node SET http_server_address = $1 WHERE name = $2") + .bind(address) + .bind(node_name); query.execute(&*self.database.pool).await.void() } async fn set_as_authority_node(&self, node_name: &str) -> Result<()> { - let query = query("UPDATE node SET is_authority = ? 
WHERE name = ?") - .bind(true.to_sql()) - .bind(node_name.to_sql()); + let query = query("UPDATE node SET is_authority = $1 WHERE name = $2") + .bind(true) + .bind(node_name); query.execute(&*self.database.pool).await.void() } @@ -192,28 +198,34 @@ impl NodesRepository for NodesSqlxDatabase { } async fn set_node_pid(&self, node_name: &str, pid: u32) -> Result<()> { - let query = query("UPDATE node SET pid = ? WHERE name = ?") - .bind(pid.to_sql()) - .bind(node_name.to_sql()); + let query = query("UPDATE node SET pid = $1 WHERE name = $2") + .bind(pid as i32) + .bind(node_name); query.execute(&*self.database.pool).await.void() } async fn set_no_node_pid(&self, node_name: &str) -> Result<()> { - let query = query("UPDATE node SET pid = NULL WHERE name = ?").bind(node_name.to_sql()); + let query = query("UPDATE node SET pid=NULL WHERE name = $1 ").bind(node_name); query.execute(&*self.database.pool).await.void() } async fn set_node_project_name(&self, node_name: &str, project_name: &str) -> Result<()> { - let query = query("INSERT OR REPLACE INTO node_project VALUES (?1, ?2)") - .bind(node_name.to_sql()) - .bind(project_name.to_sql()); + let query = query( + r#" + INSERT INTO node_project (node_name, project_name) + VALUES ($1, $2) + ON CONFLICT (node_name) + DO UPDATE SET project_name = $2"#, + ) + .bind(node_name) + .bind(project_name); Ok(query.execute(&*self.database.pool).await.void()?) 
} async fn get_node_project_name(&self, node_name: &str) -> Result> { - let query = query("SELECT project_name FROM node_project WHERE node_name = ?") - .bind(node_name.to_sql()); - let row: Option = query + let query = + query("SELECT project_name FROM node_project WHERE node_name = $1").bind(node_name); + let row: Option = query .fetch_optional(&*self.database.pool) .await .into_core()?; @@ -224,37 +236,49 @@ impl NodesRepository for NodesSqlxDatabase { // Database serialization / deserialization +impl Type for InternetAddress { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +impl sqlx::Encode<'_, Any> for InternetAddress { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + >::encode_by_ref(&self.to_string(), buf) + } +} + #[derive(FromRow)] pub(crate) struct NodeRow { name: String, identifier: String, - verbosity: u8, - is_default: bool, - is_authority: bool, - tcp_listener_address: Option, - pid: Option, - http_server_address: Option, + verbosity: i64, + is_default: Boolean, + is_authority: Boolean, + tcp_listener_address: Nullable, + pid: Nullable, + http_server_address: Nullable, } impl NodeRow { pub(crate) fn node_info(&self) -> Result { let tcp_listener_address = self .tcp_listener_address - .as_ref() - .and_then(|a| InternetAddress::new(a)); + .to_option() + .and_then(|a| InternetAddress::new(&a)); let http_server_address = self .http_server_address - .as_ref() - .and_then(|a| InternetAddress::new(a)); + .to_option() + .and_then(|a| InternetAddress::new(&a)); Ok(NodeInfo::new( self.name.clone(), Identifier::from_str(&self.identifier.clone())?, - self.verbosity, - self.is_default, - self.is_authority, + self.verbosity as u8, + self.is_default.to_bool(), + self.is_authority.to_bool(), tcp_listener_address, - self.pid, + self.pid.to_option().map(|p| p as u32), http_server_address, )) } @@ -264,110 +288,118 @@ impl NodeRow { mod test { use crate::cli_state::NodeInfo; use ockam::identity::identities; + use 
ockam_node::database::with_dbs; use std::sync::Arc; use super::*; #[tokio::test] async fn test_repository() -> Result<()> { - let repository = create_repository().await?; - let identifier = create_identity().await?; - - // The information about a node can be stored - let node_info1 = NodeInfo::new( - "node1".to_string(), - identifier.clone(), - 0, - false, - false, - InternetAddress::new("127.0.0.1:51591"), - Some(1234), - InternetAddress::new("127.0.0.1:51592"), - ); - - repository.store_node(&node_info1).await?; - - // get the node by name - let result = repository.get_node("node1").await?; - assert_eq!(result, Some(node_info1.clone())); + with_dbs(|db| async move { + let repository: Arc = Arc::new(NodesSqlxDatabase::new(db)); + + let identifier = create_identity().await?; + + // The information about a node can be stored + let node_info1 = NodeInfo::new( + "node1".to_string(), + identifier.clone(), + 0, + false, + false, + InternetAddress::new("127.0.0.1:51591"), + Some(1234), + InternetAddress::new("127.0.0.1:51592"), + ); - // get the node by identifier - let result = repository.get_nodes_by_identifier(&identifier).await?; - assert_eq!(result, vec![node_info1.clone()]); + repository.store_node(&node_info1).await?; + + // get the node by name + let result = repository.get_node("node1").await?; + assert_eq!(result, Some(node_info1.clone())); + + // get the node by identifier + let result = repository.get_nodes_by_identifier(&identifier).await?; + assert_eq!(result, vec![node_info1.clone()]); + + // the list of all the nodes can be retrieved + let node_info2 = NodeInfo::new( + "node2".to_string(), + identifier.clone(), + 0, + false, + false, + None, + Some(5678), + None, + ); - // the list of all the nodes can be retrieved - let node_info2 = NodeInfo::new( - "node2".to_string(), - identifier.clone(), - 0, - false, - false, - None, - Some(5678), - None, - ); - - repository.store_node(&node_info2).await?; - let result = repository.get_nodes().await?; - 
assert_eq!(result, vec![node_info1.clone(), node_info2.clone()]); - - // a node can be set as the default - repository.set_default_node("node2").await?; - let result = repository.get_default_node().await?; - assert_eq!(result, Some(node_info2.set_as_default())); - - // a node can be deleted - repository.delete_node("node2").await?; - let result = repository.get_nodes().await?; - assert_eq!(result, vec![node_info1.clone()]); - - // in that case there is no more default node - let result = repository.get_default_node().await?; - assert!(result.is_none()); - Ok(()) + repository.store_node(&node_info2).await?; + let result = repository.get_nodes().await?; + assert_eq!(result, vec![node_info1.clone(), node_info2.clone()]); + + // a node can be set as the default + repository.set_default_node("node2").await?; + let result = repository.get_default_node().await?; + assert_eq!(result, Some(node_info2.set_as_default())); + + // a node can be deleted + repository.delete_node("node2").await?; + let result = repository.get_nodes().await?; + assert_eq!(result, vec![node_info1.clone()]); + + // in that case there is no more default node + let result = repository.get_default_node().await?; + assert!(result.is_none()); + Ok(()) + }) + .await } #[tokio::test] async fn test_an_identity_used_by_two_nodes() -> Result<()> { - let repository = create_repository().await?; - let identifier1 = create_identity().await?; - let identifier2 = create_identity().await?; + with_dbs(|db| async move { + let repository: Arc = Arc::new(NodesSqlxDatabase::new(db)); + + let identifier1 = create_identity().await?; + let identifier2 = create_identity().await?; - // Create 3 nodes: 2 with the same identifier, 1 with a different identifier - let node_info1 = create_node("node1", &identifier1); - repository.store_node(&node_info1).await?; + // Create 3 nodes: 2 with the same identifier, 1 with a different identifier + let node_info1 = create_node("node1", &identifier1); + 
repository.store_node(&node_info1).await?; - let node_info2 = create_node("node2", &identifier1); - repository.store_node(&node_info2).await?; + let node_info2 = create_node("node2", &identifier1); + repository.store_node(&node_info2).await?; - let node_info3 = create_node("node3", &identifier2); - repository.store_node(&node_info3).await?; + let node_info3 = create_node("node3", &identifier2); + repository.store_node(&node_info3).await?; - // get the nodes for identifier1 - let result = repository.get_nodes_by_identifier(&identifier1).await?; - assert_eq!(result, vec![node_info1.clone(), node_info2.clone()]); - Ok(()) + // get the nodes for identifier1 + let result = repository.get_nodes_by_identifier(&identifier1).await?; + assert_eq!(result, vec![node_info1.clone(), node_info2.clone()]); + Ok(()) + }) + .await } #[tokio::test] async fn test_node_project() -> Result<()> { - let repository = create_repository().await?; - - // a node can be associated to a project name - repository - .set_node_project_name("node_name", "project1") - .await?; - let result = repository.get_node_project_name("node_name").await?; - assert_eq!(result, Some("project1".into())); - - Ok(()) + with_dbs(|db| async move { + let repository: Arc = Arc::new(NodesSqlxDatabase::new(db)); + + // a node can be associated to a project name + repository + .set_node_project_name("node_name", "project1") + .await?; + let result = repository.get_node_project_name("node_name").await?; + assert_eq!(result, Some("project1".into())); + + Ok(()) + }) + .await } /// HELPERS - async fn create_repository() -> Result> { - Ok(Arc::new(NodesSqlxDatabase::create().await?)) - } - async fn create_identity() -> Result { let identities = identities().await?; identities.identities_creation().create_identity().await diff --git a/implementations/rust/ockam/ockam_api/src/cli_state/storage/projects_repository_sql.rs b/implementations/rust/ockam/ockam_api/src/cli_state/storage/projects_repository_sql.rs index 
492a8e3cd01..b6b9853bda3 100644 --- a/implementations/rust/ockam/ockam_api/src/cli_state/storage/projects_repository_sql.rs +++ b/implementations/rust/ockam/ockam_api/src/cli_state/storage/projects_repository_sql.rs @@ -1,6 +1,8 @@ use std::str::FromStr; -use sqlx::sqlite::SqliteRow; +use sqlx::any::AnyRow; +use sqlx::database::HasArguments; +use sqlx::encode::IsNull; use sqlx::*; use ockam::identity::Identifier; @@ -8,7 +10,7 @@ use ockam_core::async_trait; use ockam_core::env::FromString; use ockam_core::errcode::{Kind, Origin}; use ockam_core::{Error, Result}; -use ockam_node::database::{FromSqlxError, SqlxDatabase, SqlxType, ToSqlxType, ToVoid}; +use ockam_node::database::{Boolean, FromSqlxError, Nullable, SqlxDatabase, ToVoid}; use crate::cloud::addon::KafkaConfig; use crate::cloud::email_address::EmailAddress; @@ -50,81 +52,109 @@ impl ProjectsRepository for ProjectsSqlxDatabase { let mut transaction = self.database.begin().await.into_core()?; let query1 = query_scalar( - "SELECT EXISTS(SELECT 1 FROM project WHERE is_default=$1 AND project_id=$2)", + "SELECT EXISTS(SELECT 1 FROM project WHERE is_default = $1 AND project_id = $2)", ) - .bind(true.to_sql()) - .bind(project.id.to_sql()); - let is_already_default: bool = query1.fetch_one(&mut *transaction).await.into_core()?; + .bind(true) + .bind(project.id.clone()); + let is_already_default: Boolean = query1.fetch_one(&mut *transaction).await.into_core()?; + let is_already_default = is_already_default.to_bool(); let query2 = query( - "INSERT OR REPLACE INTO project (project_id, project_name, is_default, space_id, space_name, project_identifier, project_change_history, access_route, authority_change_history, authority_access_route, version, running, operation_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)", + r#" + INSERT INTO project (project_id, project_name, is_default, space_id, space_name, project_identifier, project_change_history, access_route, authority_change_history, 
authority_access_route, version, running, operation_id) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + ON CONFLICT (project_id) + DO UPDATE SET project_name = $2, is_default = $3, space_id = $4, space_name = $5, project_identifier = $6, project_change_history = $7, access_route = $8, authority_change_history = $9, authority_access_route = $10, version = $11, running = $12, operation_id = $13"#, ) - .bind(project.id.to_sql()) - .bind(project.name.to_sql()) - .bind(is_already_default.to_sql()) - .bind(project.space_id.to_sql()) - .bind(project.space_name.to_sql()) - .bind(project.identity.as_ref().map(|i| i.to_sql())) - .bind(project.project_change_history.as_ref().map(|r| r.to_sql())) - .bind(project.access_route.to_sql()) - .bind(project.authority_identity.as_ref().map(|r| r.to_sql())) - .bind(project.authority_access_route.as_ref().map(|r| r.to_sql())) - .bind(project.version.as_ref().map(|r| r.to_sql())) - .bind(project.running.as_ref().map(|r| r.to_sql())) - .bind(project.operation_id.as_ref().map(|r| r.to_sql())); + .bind(&project.id) + .bind(&project.name) + .bind(is_already_default) + .bind(&project.space_id) + .bind(&project.space_name) + .bind(&project.identity) + .bind(project.project_change_history.as_ref()) + .bind(&project.access_route) + .bind(project.authority_identity.as_ref()) + .bind(project.authority_access_route.as_ref()) + .bind(project.version.as_ref()) + .bind(project.running.as_ref()) + .bind(project.operation_id.as_ref()); query2.execute(&mut *transaction).await.void()?; // remove any existing users related to that project if any - let query3 = - query("DELETE FROM user_project WHERE project_id=$1").bind(project.id.to_sql()); + let query3 = query("DELETE FROM user_project WHERE project_id = $1").bind(&project.id); query3.execute(&mut *transaction).await.void()?; // store the users associated to that project for user_email in &project.users { - let query = query("INSERT OR REPLACE INTO user_project VALUES (?, ?)") - 
.bind(user_email.to_sql()) - .bind(project.id.to_sql()); + let query = query( + r#" + INSERT INTO user_project (user_email, project_id) + VALUES ($1, $2) + ON CONFLICT DO NOTHING"#, + ) + .bind(user_email) + .bind(&project.id); query.execute(&mut *transaction).await.void()?; } // remove any existing user roles related to that project if any - let query4 = query("DELETE FROM user_role WHERE project_id=$1").bind(project.id.to_sql()); + let query4 = query("DELETE FROM user_role WHERE project_id = $1").bind(&project.id); query4.execute(&mut *transaction).await.void()?; // store the user roles associated to that project for user_role in &project.user_roles { - let query = query("INSERT OR REPLACE INTO user_role VALUES (?, ?, ?, ?, ?)") - .bind(user_role.id.to_sql()) - .bind(project.id.to_sql()) - .bind(user_role.email.to_sql()) - .bind(user_role.role.to_string().to_sql()) - .bind(user_role.scope.to_string().to_sql()); + let query = query( + r#" + INSERT INTO user_role (user_id, project_id, user_email, role, scope) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT DO NOTHING"#, + ) + .bind(user_role.id as i64) + .bind(&project.id) + .bind(&user_role.email) + .bind(&user_role.role) + .bind(&user_role.scope); query.execute(&mut *transaction).await.void()?; } // make sure that the project space is also saved - let query5 = query("INSERT OR IGNORE INTO space VALUES ($1, $2, $3)") - .bind(project.space_id.to_sql()) - .bind(project.space_name.to_sql()) - .bind(true.to_sql()); + let query5 = query( + r#" + INSERT INTO space (space_id, space_name, is_default) + VALUES ($1, $2, $3) + ON CONFLICT (space_id) + DO UPDATE SET space_name = $2, is_default = $3"#, + ) + .bind(&project.space_id) + .bind(&project.space_name) + .bind(true); query5.execute(&mut *transaction).await.void()?; // store the okta configuration if any for okta_config in &project.okta_config { - let query = query("INSERT OR REPLACE INTO okta_config VALUES (?, ?, ?, ?, ?)") - .bind(project.id.to_sql()) - 
.bind(okta_config.tenant_base_url.to_string().to_sql()) - .bind(okta_config.client_id.to_sql()) - .bind(okta_config.certificate.to_string().to_sql()) - .bind(okta_config.attributes.join(",").to_string().to_sql()); + let query = query(r#" + INSERT INTO okta_config (project_id, tenant_base_url, client_id, certificate, attributes) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT DO NOTHING"#) + .bind(&project.id) + .bind(&okta_config.tenant_base_url) + .bind(&okta_config.client_id) + .bind(&okta_config.certificate) + .bind(okta_config.attributes.join(",").to_string()); query.execute(&mut *transaction).await.void()?; } // store the kafka configuration if any for kafka_config in &project.kafka_config { - let query = query("INSERT OR REPLACE INTO kafka_config VALUES (?, ?)") - .bind(project.id.to_sql()) - .bind(kafka_config.bootstrap_server.to_sql()); + let query = query( + r#" + INSERT INTO kafka_config (project_id, bootstrap_server) + VALUES ($1, $2) + ON CONFLICT DO NOTHING"#, + ) + .bind(&project.id) + .bind(&kafka_config.bootstrap_server); query.execute(&mut *transaction).await.void()?; } @@ -133,8 +163,8 @@ impl ProjectsRepository for ProjectsSqlxDatabase { async fn get_project(&self, project_id: &str) -> Result> { let query = - query("SELECT project_name FROM project WHERE project_id=$1").bind(project_id.to_sql()); - let row: Option = query + query("SELECT project_name FROM project WHERE project_id = $1").bind(project_id); + let row: Option = query .fetch_optional(&*self.database.pool) .await .into_core()?; @@ -150,14 +180,15 @@ impl ProjectsRepository for ProjectsSqlxDatabase { async fn get_project_by_name(&self, name: &str) -> Result> { let mut transaction = self.database.begin().await.into_core()?; - let query = query_as("SELECT project_id, project_name, is_default, space_id, space_name, project_identifier, project_change_history, access_route, authority_change_history, authority_access_route, version, running, operation_id FROM project WHERE 
project_name=$1").bind(name.to_sql()); + let query = query_as("SELECT project_id, project_name, is_default, space_id, space_name, project_identifier, project_change_history, access_route, authority_change_history, authority_access_route, version, running, operation_id FROM project WHERE project_name = $1").bind(name); let row: Option = query.fetch_optional(&mut *transaction).await.into_core()?; let project = match row.map(|r| r.project()).transpose()? { Some(mut project) => { // get the project users emails - let query2 = - query_as("SELECT project_id, user_email FROM user_project WHERE project_id=$1") - .bind(project.id.to_sql()); + let query2 = query_as( + "SELECT project_id, user_email FROM user_project WHERE project_id = $1", + ) + .bind(&project.id); let rows: Vec = query2.fetch_all(&mut *transaction).await.into_core()?; let users: Result> = @@ -165,8 +196,8 @@ impl ProjectsRepository for ProjectsSqlxDatabase { project.users = users?; // get the project users roles - let query3 = query_as("SELECT user_id, project_id, user_email, role, scope FROM user_role WHERE project_id=$1") - .bind(project.id.to_sql()); + let query3 = query_as("SELECT user_id, project_id, user_email, role, scope FROM user_role WHERE project_id = $1") + .bind(&project.id); let rows: Vec = query3.fetch_all(&mut *transaction).await.into_core()?; let user_roles: Vec = rows @@ -176,17 +207,17 @@ impl ProjectsRepository for ProjectsSqlxDatabase { project.user_roles = user_roles; // get the project okta configuration - let query4 = query_as("SELECT project_id, tenant_base_url, client_id, certificate, attributes FROM okta_config WHERE project_id=$1") - .bind(project.id.to_sql()); + let query4 = query_as("SELECT project_id, tenant_base_url, client_id, certificate, attributes FROM okta_config WHERE project_id = $1") + .bind(&project.id); let row: Option = query4.fetch_optional(&mut *transaction).await.into_core()?; project.okta_config = row.map(|r| r.okta_config()).transpose()?; // get the project 
kafka configuration let query5 = query_as( - "SELECT project_id, bootstrap_server FROM kafka_config WHERE project_id=$1", + "SELECT project_id, bootstrap_server FROM kafka_config WHERE project_id = $1", ) - .bind(project.id.to_sql()); + .bind(&project.id); let row: Option = query5.fetch_optional(&mut *transaction).await.into_core()?; project.kafka_config = row.map(|r| r.kafka_config()); @@ -202,7 +233,7 @@ impl ProjectsRepository for ProjectsSqlxDatabase { async fn get_projects(&self) -> Result> { let query = query("SELECT project_name FROM project"); - let rows: Vec = query.fetch_all(&*self.database.pool).await.into_core()?; + let rows: Vec = query.fetch_all(&*self.database.pool).await.into_core()?; let project_names: Vec = rows.iter().map(|r| r.get(0)).collect(); let mut projects = vec![]; for project_name in project_names { @@ -215,9 +246,8 @@ impl ProjectsRepository for ProjectsSqlxDatabase { } async fn get_default_project(&self) -> Result> { - let query = - query("SELECT project_name FROM project WHERE is_default=$1").bind(true.to_sql()); - let row: Option = query + let query = query("SELECT project_name FROM project WHERE is_default = $1").bind(true); + let row: Option = query .fetch_optional(&*self.database.pool) .await .into_core()?; @@ -233,15 +263,15 @@ impl ProjectsRepository for ProjectsSqlxDatabase { async fn set_default_project(&self, project_id: &str) -> Result<()> { let mut transaction = self.database.begin().await.into_core()?; // set the project as the default one - let query1 = query("UPDATE project SET is_default = ? WHERE project_id = ?") - .bind(true.to_sql()) - .bind(project_id.to_sql()); + let query1 = query("UPDATE project SET is_default = $1 WHERE project_id = $2") + .bind(true) + .bind(project_id); query1.execute(&mut *transaction).await.void()?; // set all the others as non-default - let query2 = query("UPDATE project SET is_default = ? 
WHERE project_id <> ?") - .bind(false.to_sql()) - .bind(project_id.to_sql()); + let query2 = query("UPDATE project SET is_default = $1 WHERE project_id <> $2") + .bind(false) + .bind(project_id); query2.execute(&mut *transaction).await.void()?; transaction.commit().await.void() } @@ -249,19 +279,19 @@ impl ProjectsRepository for ProjectsSqlxDatabase { async fn delete_project(&self, project_id: &str) -> Result<()> { let mut transaction = self.database.begin().await.into_core()?; - let query1 = query("DELETE FROM project WHERE project_id=?").bind(project_id.to_sql()); + let query1 = query("DELETE FROM project WHERE project_id = $1").bind(project_id); query1.execute(&mut *transaction).await.void()?; - let query2 = query("DELETE FROM user_project WHERE project_id=?").bind(project_id.to_sql()); + let query2 = query("DELETE FROM user_project WHERE project_id = $1").bind(project_id); query2.execute(&mut *transaction).await.void()?; - let query3 = query("DELETE FROM user_role WHERE project_id=?").bind(project_id.to_sql()); + let query3 = query("DELETE FROM user_role WHERE project_id = $1").bind(project_id); query3.execute(&mut *transaction).await.void()?; - let query4 = query("DELETE FROM okta_config WHERE project_id=?").bind(project_id.to_sql()); + let query4 = query("DELETE FROM okta_config WHERE project_id = $1").bind(project_id); query4.execute(&mut *transaction).await.void()?; - let query5 = query("DELETE FROM kafka_config WHERE project_id=?").bind(project_id.to_sql()); + let query5 = query("DELETE FROM kafka_config WHERE project_id = $1").bind(project_id); query5.execute(&mut *transaction).await.void()?; transaction.commit().await.void()?; @@ -277,17 +307,17 @@ struct ProjectRow { project_id: String, project_name: String, #[allow(unused)] - is_default: bool, + is_default: Boolean, space_id: String, space_name: String, - project_identifier: Option, - project_change_history: Option, + project_identifier: Nullable, + project_change_history: Nullable, access_route: 
String, - authority_change_history: Option, - authority_access_route: Option, - version: Option, - running: Option, - operation_id: Option, + authority_change_history: Nullable, + authority_access_route: Nullable, + version: Nullable, + running: Nullable, + operation_id: Nullable, } impl ProjectRow { @@ -304,8 +334,8 @@ impl ProjectRow { ) -> Result { let project_identifier = self .project_identifier - .as_ref() - .map(|i| Identifier::from_string(i)) + .to_option() + .map(|i| Identifier::from_string(&i)) .transpose()?; Ok(ProjectModel { id: self.project_id.clone(), @@ -313,13 +343,13 @@ impl ProjectRow { space_id: self.space_id.clone(), space_name: self.space_name.clone(), identity: project_identifier, - project_change_history: self.project_change_history.clone(), + project_change_history: self.project_change_history.to_option(), access_route: self.access_route.clone(), - authority_access_route: self.authority_access_route.clone(), - authority_identity: self.authority_change_history.clone(), - version: self.version.clone(), - running: self.running, - operation_id: self.operation_id.clone(), + authority_access_route: self.authority_access_route.to_option(), + authority_identity: self.authority_change_history.to_option(), + version: self.version.to_option(), + running: self.running.to_option().map(|r| r.to_bool()), + operation_id: self.operation_id.to_option(), users: user_emails, user_roles, okta_config, @@ -353,9 +383,51 @@ struct UserRoleRow { scope: String, } -impl ToSqlxType for EmailAddress { - fn to_sql(&self) -> SqlxType { - self.to_string().to_sql() +impl Type for EmailAddress { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +impl Encode<'_, Any> for EmailAddress { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + >::encode_by_ref(&self.to_string(), buf) + } +} + +impl Type for RoleInShare { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +impl Encode<'_, Any> for RoleInShare { + fn encode_by_ref(&self, buf: &mut 
::ArgumentBuffer) -> IsNull { + >::encode_by_ref(&self.to_string(), buf) + } +} + +impl Type for ShareScope { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +impl Encode<'_, Any> for ShareScope { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + >::encode_by_ref(&self.to_string(), buf) + } +} + +impl Type for Url { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +impl Encode<'_, Any> for Url { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + >::encode_by_ref(&self.to_string(), buf) } } @@ -419,87 +491,92 @@ mod test { use super::*; use crate::cli_state::{SpacesRepository, SpacesSqlxDatabase}; + use ockam_node::database::with_dbs; use std::sync::Arc; #[tokio::test] async fn test_repository() -> Result<()> { - let repository = create_repository().await?; - - // create and store 2 projects - let project1 = create_project( - "1", - "name1", - vec!["me@ockam.io", "you@ockam.io"], - vec![ - create_project_user_role(1, RoleInShare::Admin), - create_project_user_role(2, RoleInShare::Guest), - ], - ); - let mut project2 = create_project( - "2", - "name2", - vec!["me@ockam.io", "him@ockam.io", "her@ockam.io"], - vec![ - create_project_user_role(1, RoleInShare::Admin), - create_project_user_role(2, RoleInShare::Guest), - ], - ); - repository.store_project(&project1).await?; - repository.store_project(&project2).await?; - - // retrieve them as a list or by name - let result = repository.get_projects().await?; - assert_eq!(result, vec![project1.clone(), project2.clone()]); - - let result = repository.get_project_by_name("name1").await?; - assert_eq!(result, Some(project1.clone())); - - // a project can be marked as the default project - repository.set_default_project("1").await?; - let result = repository.get_default_project().await?; - assert_eq!(result, Some(project1.clone())); - - repository.set_default_project("2").await?; - let result = repository.get_default_project().await?; - assert_eq!(result, 
Some(project2.clone())); - - // updating a project which was already the default should keep it the default - project2.users = vec!["someone@ockam.io".try_into().unwrap()]; - repository.store_project(&project2).await?; - let result = repository.get_default_project().await?; - assert_eq!(result, Some(project2.clone())); - - // a project can be deleted - repository.delete_project("2").await?; - let result = repository.get_default_project().await?; - assert_eq!(result, None); - - let result = repository.get_projects().await?; - assert_eq!(result, vec![project1.clone()]); - Ok(()) + with_dbs(|db| async move { + let repository: Arc = Arc::new(ProjectsSqlxDatabase::new(db)); + + // create and store 2 projects + let project1 = create_project( + "1", + "name1", + vec!["me@ockam.io", "you@ockam.io"], + vec![ + create_project_user_role(1, RoleInShare::Admin), + create_project_user_role(2, RoleInShare::Guest), + ], + ); + let mut project2 = create_project( + "2", + "name2", + vec!["me@ockam.io", "him@ockam.io", "her@ockam.io"], + vec![ + create_project_user_role(1, RoleInShare::Admin), + create_project_user_role(2, RoleInShare::Guest), + ], + ); + repository.store_project(&project1).await?; + repository.store_project(&project2).await?; + + // retrieve them as a list or by name + let result = repository.get_projects().await?; + assert_eq!(result, vec![project1.clone(), project2.clone()]); + + let result = repository.get_project_by_name("name1").await?; + assert_eq!(result, Some(project1.clone())); + + // a project can be marked as the default project + repository.set_default_project("1").await?; + let result = repository.get_default_project().await?; + assert_eq!(result, Some(project1.clone())); + + repository.set_default_project("2").await?; + let result = repository.get_default_project().await?; + assert_eq!(result, Some(project2.clone())); + + // updating a project which was already the default should keep it the default + project2.users = 
vec!["someone@ockam.io".try_into().unwrap()]; + repository.store_project(&project2).await?; + let result = repository.get_default_project().await?; + assert_eq!(result, Some(project2.clone())); + + // a project can be deleted + repository.delete_project("2").await?; + let result = repository.get_default_project().await?; + assert_eq!(result, None); + + let result = repository.get_projects().await?; + assert_eq!(result, vec![project1.clone()]); + Ok(()) + }) + .await } #[tokio::test] async fn test_store_project_space() -> Result<()> { - let db = SqlxDatabase::in_memory("projects").await?; - let projects_repository = ProjectsSqlxDatabase::new(db.clone()); - let project = create_project("1", "name1", vec![], vec![]); - projects_repository.store_project(&project).await?; + with_dbs(|db| async move { + let projects_repository: Arc = + Arc::new(ProjectsSqlxDatabase::new(db.clone())); - // the space information coming from the project must also be stored in the spaces table - let spaces_repository: Arc = Arc::new(SpacesSqlxDatabase::new(db)); - let space = spaces_repository.get_default_space().await?.unwrap(); - assert_eq!(project.space_id, space.id); - assert_eq!(project.space_name, space.name); + let project = create_project("1", "name1", vec![], vec![]); + projects_repository.store_project(&project).await?; - Ok(()) - } + // the space information coming from the project must also be stored in the spaces table + let spaces_repository: Arc = + Arc::new(SpacesSqlxDatabase::new(db)); + let space = spaces_repository.get_default_space().await?.unwrap(); + assert_eq!(project.space_id, space.id); + assert_eq!(project.space_name, space.name); - /// HELPERS - async fn create_repository() -> Result> { - Ok(Arc::new(ProjectsSqlxDatabase::create().await?)) + Ok(()) + }) + .await } + /// HELPERS fn create_project( id: &str, name: &str, diff --git a/implementations/rust/ockam/ockam_api/src/cli_state/storage/spaces_repository_sql.rs 
b/implementations/rust/ockam/ockam_api/src/cli_state/storage/spaces_repository_sql.rs index 69e463e4e38..a1bd917b514 100644 --- a/implementations/rust/ockam/ockam_api/src/cli_state/storage/spaces_repository_sql.rs +++ b/implementations/rust/ockam/ockam_api/src/cli_state/storage/spaces_repository_sql.rs @@ -1,9 +1,9 @@ -use sqlx::sqlite::SqliteRow; +use sqlx::any::AnyRow; use sqlx::*; use ockam_core::async_trait; use ockam_core::Result; -use ockam_node::database::{FromSqlxError, SqlxDatabase, ToSqlxType, ToVoid}; +use ockam_node::database::{Boolean, FromSqlxError, SqlxDatabase, ToVoid}; use crate::cloud::space::Space; @@ -32,27 +32,40 @@ impl SpacesRepository for SpacesSqlxDatabase { async fn store_space(&self, space: &Space) -> Result<()> { let mut transaction = self.database.begin().await.into_core()?; - let query1 = - query_scalar("SELECT EXISTS (SELECT 1 FROM space WHERE is_default=$1 AND space_id=$2)") - .bind(true.to_sql()) - .bind(space.id.to_sql()); - let is_already_default: bool = query1.fetch_one(&mut *transaction).await.into_core()?; - - let query2 = query("INSERT OR REPLACE INTO space VALUES (?, ?, ?)") - .bind(space.id.to_sql()) - .bind(space.name.to_sql()) - .bind(is_already_default.to_sql()); + let query1 = query_scalar( + "SELECT EXISTS (SELECT 1 FROM space WHERE is_default = $1 AND space_id = $2)", + ) + .bind(true) + .bind(&space.id); + let is_already_default: Boolean = query1.fetch_one(&mut *transaction).await.into_core()?; + let is_already_default = is_already_default.to_bool(); + + let query2 = query( + r#" + INSERT INTO space (space_id, space_name, is_default) + VALUES ($1, $2, $3) + ON CONFLICT (space_id) + DO UPDATE SET space_name = $2, is_default = $3"#, + ) + .bind(&space.id) + .bind(&space.name) + .bind(is_already_default); query2.execute(&mut *transaction).await.void()?; // remove any existing users related to that space if any - let query3 = query("DELETE FROM user_space WHERE space_id=$1").bind(space.id.to_sql()); + let query3 = 
query("DELETE FROM user_space WHERE space_id = $1").bind(&space.id); query3.execute(&mut *transaction).await.void()?; // store the users associated to that space for user_email in &space.users { - let query4 = query("INSERT OR REPLACE INTO user_space VALUES (?, ?)") - .bind(user_email.to_sql()) - .bind(space.id.to_sql()); + let query4 = query( + r#" + INSERT INTO user_space (user_email, space_id) + VALUES ($1, $2) + ON CONFLICT DO NOTHING"#, + ) + .bind(user_email) + .bind(&space.id); query4.execute(&mut *transaction).await.void()?; } @@ -60,8 +73,8 @@ impl SpacesRepository for SpacesSqlxDatabase { } async fn get_space(&self, space_id: &str) -> Result> { - let query = query("SELECT space_name FROM space WHERE space_id=$1").bind(space_id.to_sql()); - let row: Option = query + let query = query("SELECT space_name FROM space WHERE space_id = $1").bind(space_id); + let row: Option = query .fetch_optional(&*self.database.pool) .await .into_core()?; @@ -77,14 +90,14 @@ impl SpacesRepository for SpacesSqlxDatabase { async fn get_space_by_name(&self, name: &str) -> Result> { let mut transaction = self.database.begin().await.into_core()?; - let query1 = query_as("SELECT space_id, space_name FROM space WHERE space_name=$1") - .bind(name.to_sql()); + let query1 = + query_as("SELECT space_id, space_name FROM space WHERE space_name = $1").bind(name); let row: Option = query1.fetch_optional(&mut *transaction).await.into_core()?; let space = match row.map(|r| r.space()) { Some(mut space) => { let query2 = - query_as("SELECT space_id, user_email FROM user_space WHERE space_id=$1") - .bind(space.id.to_sql()); + query_as("SELECT space_id, user_email FROM user_space WHERE space_id = $1") + .bind(&space.id); let rows: Vec = query2.fetch_all(&mut *transaction).await.into_core()?; let users = rows.into_iter().map(|r| r.user_email).collect(); @@ -105,8 +118,9 @@ impl SpacesRepository for SpacesSqlxDatabase { let mut spaces = vec![]; for space_row in row { - let query2 = query_as("SELECT 
space_id, user_email FROM user_space WHERE space_id=$1") - .bind(space_row.space_id.to_sql()); + let query2 = + query_as("SELECT space_id, user_email FROM user_space WHERE space_id = $1") + .bind(&space_row.space_id); let rows: Vec = query2.fetch_all(&mut *transaction).await.into_core()?; let users = rows.into_iter().map(|r| r.user_email).collect(); spaces.push(space_row.space_with_user_emails(users)) @@ -118,8 +132,8 @@ impl SpacesRepository for SpacesSqlxDatabase { } async fn get_default_space(&self) -> Result> { - let query = query("SELECT space_name FROM space WHERE is_default=$1").bind(true.to_sql()); - let row: Option = query + let query = query("SELECT space_name FROM space WHERE is_default = $1").bind(true); + let row: Option = query .fetch_optional(&*self.database.pool) .await .into_core()?; @@ -133,15 +147,15 @@ impl SpacesRepository for SpacesSqlxDatabase { async fn set_default_space(&self, space_id: &str) -> Result<()> { let mut transaction = self.database.begin().await.into_core()?; // set the space as the default one - let query1 = query("UPDATE space SET is_default = ? WHERE space_id = ?") - .bind(true.to_sql()) - .bind(space_id.to_sql()); + let query1 = query("UPDATE space SET is_default = $1 WHERE space_id = $2") + .bind(true) + .bind(space_id); query1.execute(&mut *transaction).await.void()?; // set all the others as non-default - let query2 = query("UPDATE space SET is_default = ? 
WHERE space_id <> ?") - .bind(false.to_sql()) - .bind(space_id.to_sql()); + let query2 = query("UPDATE space SET is_default = $1 WHERE space_id <> $2") + .bind(false) + .bind(space_id); query2.execute(&mut *transaction).await.void()?; transaction.commit().await.void() } @@ -149,10 +163,10 @@ impl SpacesRepository for SpacesSqlxDatabase { async fn delete_space(&self, space_id: &str) -> Result<()> { let mut transaction = self.database.begin().await.into_core()?; - let query1 = query("DELETE FROM space WHERE space_id=?").bind(space_id.to_sql()); + let query1 = query("DELETE FROM space WHERE space_id = $1").bind(space_id); query1.execute(&mut *transaction).await.void()?; - let query2 = query("DELETE FROM user_space WHERE space_id=?").bind(space_id.to_sql()); + let query2 = query("DELETE FROM user_space WHERE space_id = $1").bind(space_id); query2.execute(&mut *transaction).await.void()?; transaction.commit().await.void() @@ -193,65 +207,64 @@ struct UserSpaceRow { #[cfg(test)] mod test { use super::*; + use ockam_node::database::with_dbs; use std::sync::Arc; #[tokio::test] async fn test_repository() -> Result<()> { - let repository = create_repository().await?; - - // create and store 2 spaces - let space1 = Space { - id: "1".to_string(), - name: "name1".to_string(), - users: vec!["me@ockam.io".to_string(), "you@ockam.io".to_string()], - }; - let mut space2 = Space { - id: "2".to_string(), - name: "name2".to_string(), - users: vec![ - "me@ockam.io".to_string(), - "him@ockam.io".to_string(), - "her@ockam.io".to_string(), - ], - }; - - repository.store_space(&space1).await?; - repository.store_space(&space2).await?; - - // retrieve them as a vector or by name - let result = repository.get_spaces().await?; - assert_eq!(result, vec![space1.clone(), space2.clone()]); - - let result = repository.get_space_by_name("name1").await?; - assert_eq!(result, Some(space1.clone())); - - // a space can be marked as the default space - repository.set_default_space("1").await?; - let 
result = repository.get_default_space().await?; - assert_eq!(result, Some(space1.clone())); - - repository.set_default_space("2").await?; - let result = repository.get_default_space().await?; - assert_eq!(result, Some(space2.clone())); - - // updating a space which was already the default should keep it the default - space2.users = vec!["someone@ockam.io".to_string()]; - repository.store_space(&space2).await?; - let result = repository.get_default_space().await?; - assert_eq!(result, Some(space2.clone())); - - // a space can be deleted - repository.delete_space("2").await?; - let result = repository.get_default_space().await?; - assert_eq!(result, None); - - let result = repository.get_spaces().await?; - assert_eq!(result, vec![space1.clone()]); - Ok(()) - } - - /// HELPERS - async fn create_repository() -> Result> { - Ok(Arc::new(SpacesSqlxDatabase::create().await?)) + with_dbs(|db| async move { + let repository: Arc = Arc::new(SpacesSqlxDatabase::new(db)); + + // create and store 2 spaces + let space1 = Space { + id: "1".to_string(), + name: "name1".to_string(), + users: vec!["me@ockam.io".to_string(), "you@ockam.io".to_string()], + }; + let mut space2 = Space { + id: "2".to_string(), + name: "name2".to_string(), + users: vec![ + "me@ockam.io".to_string(), + "him@ockam.io".to_string(), + "her@ockam.io".to_string(), + ], + }; + + repository.store_space(&space1).await?; + repository.store_space(&space2).await?; + + // retrieve them as a vector or by name + let result = repository.get_spaces().await?; + assert_eq!(result, vec![space1.clone(), space2.clone()]); + + let result = repository.get_space_by_name("name1").await?; + assert_eq!(result, Some(space1.clone())); + + // a space can be marked as the default space + repository.set_default_space("1").await?; + let result = repository.get_default_space().await?; + assert_eq!(result, Some(space1.clone())); + + repository.set_default_space("2").await?; + let result = repository.get_default_space().await?; + 
assert_eq!(result, Some(space2.clone())); + + // updating a space which was already the default should keep it the default + space2.users = vec!["someone@ockam.io".to_string()]; + repository.store_space(&space2).await?; + let result = repository.get_default_space().await?; + assert_eq!(result, Some(space2.clone())); + + // a space can be deleted + repository.delete_space("2").await?; + let result = repository.get_default_space().await?; + assert_eq!(result, None); + + let result = repository.get_spaces().await?; + assert_eq!(result, vec![space1.clone()]); + Ok(()) + }) + .await } } diff --git a/implementations/rust/ockam/ockam_api/src/cli_state/storage/tcp_portals_repository_sql.rs b/implementations/rust/ockam/ockam_api/src/cli_state/storage/tcp_portals_repository_sql.rs index bd3f097e899..bf8df959335 100644 --- a/implementations/rust/ockam/ockam_api/src/cli_state/storage/tcp_portals_repository_sql.rs +++ b/implementations/rust/ockam/ockam_api/src/cli_state/storage/tcp_portals_repository_sql.rs @@ -6,7 +6,7 @@ use sqlx::*; use tracing::debug; use crate::nodes::models::portal::OutletStatus; -use ockam::{FromSqlxError, SqlxDatabase, ToSqlxType, ToVoid}; +use ockam::{FromSqlxError, SqlxDatabase, ToVoid}; use ockam_core::errcode::{Kind, Origin}; use ockam_core::Error; use ockam_core::Result; @@ -44,11 +44,16 @@ impl TcpPortalsRepository for TcpPortalsSqlxDatabase { node_name: &str, tcp_inlet: &TcpInlet, ) -> ockam_core::Result<()> { - let query = query("INSERT OR REPLACE INTO tcp_inlet VALUES (?, ?, ?, ?)") - .bind(node_name.to_sql()) - .bind(tcp_inlet.bind_addr().to_string().to_sql()) - .bind(tcp_inlet.outlet_addr().to_string().to_sql()) - .bind(tcp_inlet.alias().to_sql()); + let query = query( + r#" + INSERT INTO tcp_inlet (node_name, bind_addr, outlet_addr, alias) + VALUES ($1, $2, $3, $4) + ON CONFLICT DO NOTHING"#, + ) + .bind(node_name) + .bind(tcp_inlet.bind_addr().to_string()) + .bind(tcp_inlet.outlet_addr().to_string()) + .bind(tcp_inlet.alias()); 
query.execute(&*self.database.pool).await.void()?; Ok(()) } @@ -59,10 +64,10 @@ impl TcpPortalsRepository for TcpPortalsSqlxDatabase { alias: &str, ) -> ockam_core::Result> { let query = query_as( - "SELECT bind_addr, outlet_addr, alias FROM tcp_inlet WHERE node_name = ? AND alias = ?", + "SELECT bind_addr, outlet_addr, alias FROM tcp_inlet WHERE node_name = $1 AND alias = $2", ) - .bind(node_name.to_sql()) - .bind(alias.to_sql()); + .bind(node_name) + .bind(alias); let result: Option = query .fetch_optional(&*self.database.pool) .await @@ -71,9 +76,9 @@ impl TcpPortalsRepository for TcpPortalsSqlxDatabase { } async fn delete_tcp_inlet(&self, node_name: &str, alias: &str) -> ockam_core::Result<()> { - let query = query("DELETE FROM tcp_inlet WHERE node_name = ? AND alias = ?") - .bind(node_name.to_sql()) - .bind(alias.to_sql()); + let query = query("DELETE FROM tcp_inlet WHERE node_name = $1 AND alias = $2") + .bind(node_name) + .bind(alias); query.execute(&*self.database.pool).await.into_core()?; Ok(()) } @@ -83,11 +88,16 @@ impl TcpPortalsRepository for TcpPortalsSqlxDatabase { node_name: &str, tcp_outlet_status: &OutletStatus, ) -> ockam_core::Result<()> { - let query = query("INSERT OR REPLACE INTO tcp_outlet_status VALUES (?, ?, ?, ?)") - .bind(node_name.to_sql()) - .bind(tcp_outlet_status.socket_addr.to_sql()) - .bind(tcp_outlet_status.worker_addr.to_sql()) - .bind(tcp_outlet_status.payload.as_ref().map(|p| p.to_sql())); + let query = query( + r#" + INSERT INTO tcp_outlet_status (node_name, socket_addr, worker_addr, payload) + VALUES ($1, $2, $3, $4) + ON CONFLICT DO NOTHING"#, + ) + .bind(node_name) + .bind(tcp_outlet_status.socket_addr.to_string()) + .bind(tcp_outlet_status.worker_addr.to_string()) + .bind(tcp_outlet_status.payload.as_ref()); query.execute(&*self.database.pool).await.void()?; Ok(()) } @@ -97,9 +107,9 @@ impl TcpPortalsRepository for TcpPortalsSqlxDatabase { node_name: &str, worker_addr: &Address, ) -> ockam_core::Result> { - let query = 
query_as("SELECT socket_addr, worker_addr, payload FROM tcp_outlet_status WHERE node_name = ? AND worker_addr = ?") - .bind(node_name.to_sql()) - .bind(worker_addr.to_sql()); + let query = query_as("SELECT socket_addr, worker_addr, payload FROM tcp_outlet_status WHERE node_name = $1 AND worker_addr = $2") + .bind(node_name) + .bind(worker_addr.to_string()); let result: Option = query .fetch_optional(&*self.database.pool) .await @@ -112,9 +122,10 @@ impl TcpPortalsRepository for TcpPortalsSqlxDatabase { node_name: &str, worker_addr: &Address, ) -> ockam_core::Result<()> { - let query = query("DELETE FROM tcp_outlet_status WHERE node_name = ? AND worker_addr = ?") - .bind(node_name.to_sql()) - .bind(worker_addr.to_sql()); + let query = + query("DELETE FROM tcp_outlet_status WHERE node_name = $1 AND worker_addr = $2") + .bind(node_name) + .bind(worker_addr.to_string()); query.execute(&*self.database.pool).await.into_core()?; Ok(()) } @@ -174,52 +185,47 @@ impl TcpOutletStatusRow { #[cfg(test)] mod tests { use super::*; + use ockam_node::database::with_dbs; #[tokio::test] async fn test_repository() -> Result<()> { - let db = create_database().await?; - let repository = create_repository(db.clone()); - - let tcp_inlet = TcpInlet::new( - &SocketAddr::from_str("127.0.0.1:80").unwrap(), - &MultiAddr::from_str("/node/outlet").unwrap(), - "alias", - ); - repository.store_tcp_inlet("node_name", &tcp_inlet).await?; - let actual = repository.get_tcp_inlet("node_name", "alias").await?; - assert_eq!(actual, Some(tcp_inlet.clone())); - - repository.delete_tcp_inlet("node_name", "alias").await?; - let actual = repository.get_tcp_inlet("node_name", "alias").await?; - assert_eq!(actual, None); - - let worker_addr = Address::from_str("worker_addr").unwrap(); - let tcp_outlet_status = OutletStatus::new( - SocketAddr::from_str("127.0.0.1:80").unwrap(), - worker_addr.clone(), - Some("payload".to_string()), - ); - repository - .store_tcp_outlet("node_name", &tcp_outlet_status) - .await?; 
- let actual = repository.get_tcp_outlet("node_name", &worker_addr).await?; - assert_eq!(actual, Some(tcp_outlet_status.clone())); - - repository - .delete_tcp_outlet("node_name", &worker_addr) - .await?; - let actual = repository.get_tcp_outlet("node_name", &worker_addr).await?; - assert_eq!(actual, None); - - Ok(()) - } - - /// HELPERS - fn create_repository(db: SqlxDatabase) -> Arc { - Arc::new(TcpPortalsSqlxDatabase::new(db)) - } - - async fn create_database() -> Result { - SqlxDatabase::in_memory("test").await + with_dbs(|db| async move { + let repository: Arc = + Arc::new(TcpPortalsSqlxDatabase::new(db)); + + let tcp_inlet = TcpInlet::new( + &SocketAddr::from_str("127.0.0.1:80").unwrap(), + &MultiAddr::from_str("/node/outlet").unwrap(), + "alias", + ); + repository.store_tcp_inlet("node_name", &tcp_inlet).await?; + let actual = repository.get_tcp_inlet("node_name", "alias").await?; + assert_eq!(actual, Some(tcp_inlet.clone())); + + repository.delete_tcp_inlet("node_name", "alias").await?; + let actual = repository.get_tcp_inlet("node_name", "alias").await?; + assert_eq!(actual, None); + + let worker_addr = Address::from_str("worker_addr").unwrap(); + let tcp_outlet_status = OutletStatus::new( + SocketAddr::from_str("127.0.0.1:80").unwrap(), + worker_addr.clone(), + Some("payload".to_string()), + ); + repository + .store_tcp_outlet("node_name", &tcp_outlet_status) + .await?; + let actual = repository.get_tcp_outlet("node_name", &worker_addr).await?; + assert_eq!(actual, Some(tcp_outlet_status.clone())); + + repository + .delete_tcp_outlet("node_name", &worker_addr) + .await?; + let actual = repository.get_tcp_outlet("node_name", &worker_addr).await?; + assert_eq!(actual, None); + + Ok(()) + }) + .await } } diff --git a/implementations/rust/ockam/ockam_api/src/cli_state/storage/users_repository_sql.rs b/implementations/rust/ockam/ockam_api/src/cli_state/storage/users_repository_sql.rs index 195abd9e989..c45b0f44653 100644 --- 
a/implementations/rust/ockam/ockam_api/src/cli_state/storage/users_repository_sql.rs +++ b/implementations/rust/ockam/ockam_api/src/cli_state/storage/users_repository_sql.rs @@ -3,7 +3,7 @@ use sqlx::*; use crate::cloud::email_address::EmailAddress; use ockam_core::async_trait; use ockam_core::Result; -use ockam_node::database::{FromSqlxError, SqlxDatabase, ToSqlxType, ToVoid}; +use ockam_node::database::{Boolean, FromSqlxError, SqlxDatabase, ToVoid}; use crate::cloud::enroll::auth0::UserInfo; @@ -32,28 +32,34 @@ impl UsersRepository for UsersSqlxDatabase { async fn store_user(&self, user: &UserInfo) -> Result<()> { let mut transaction = self.database.begin().await.into_core()?; - let query1 = - query_scalar("SELECT EXISTS(SELECT email FROM user WHERE is_default=$1 AND email=$2)") - .bind(true.to_sql()) - .bind(user.email.to_sql()); - let is_already_default: bool = query1.fetch_one(&mut *transaction).await.into_core()?; - - let query2 = query("INSERT OR REPLACE INTO user VALUES ($1, $2, $3, $4, $5, $6, $7, $8)") - .bind(user.email.to_sql()) - .bind(user.sub.to_sql()) - .bind(user.nickname.to_sql()) - .bind(user.name.to_sql()) - .bind(user.picture.to_sql()) - .bind(user.updated_at.to_sql()) - .bind(user.email_verified.to_sql()) - .bind(is_already_default.to_sql()); + let query1 = query_scalar( + r#"SELECT EXISTS(SELECT email FROM "user" WHERE is_default = $1 AND email = $2)"#, + ) + .bind(true) + .bind(&user.email); + let is_already_default: Boolean = query1.fetch_one(&mut *transaction).await.into_core()?; + let is_already_default = is_already_default.to_bool(); + + let query2 = query(r#" + INSERT INTO "user" (email, sub, nickname, name, picture, updated_at, email_verified, is_default) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (email) + DO UPDATE SET sub = $2, nickname = $3, name = $4, picture = $5, updated_at = $6, email_verified = $7, is_default = $8"#) + .bind(&user.email) + .bind(&user.sub) + .bind(&user.nickname) + .bind(&user.name) + 
.bind(&user.picture) + .bind(&user.updated_at) + .bind(user.email_verified) + .bind(is_already_default); query2.execute(&mut *transaction).await.void()?; transaction.commit().await.void() } async fn get_default_user(&self) -> Result> { - let query = query_as("SELECT email, sub, nickname, name, picture, updated_at, email_verified, is_default FROM user WHERE is_default=$1").bind(true.to_sql()); + let query = query_as(r#"SELECT email, sub, nickname, name, picture, updated_at, email_verified, is_default FROM "user" WHERE is_default = $1"#).bind(true); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -62,14 +68,14 @@ impl UsersRepository for UsersSqlxDatabase { } async fn set_default_user(&self, email: &EmailAddress) -> Result<()> { - let query = query("UPDATE user SET is_default = ? WHERE email = ?") - .bind(true.to_sql()) - .bind(email.to_sql()); + let query = query(r#"UPDATE "user" SET is_default = $1 WHERE email = $2"#) + .bind(true) + .bind(email); query.execute(&*self.database.pool).await.void() } async fn get_user(&self, email: &EmailAddress) -> Result> { - let query = query_as("SELECT email, sub, nickname, name, picture, updated_at, email_verified, is_default FROM user WHERE email=$1").bind(email.to_sql()); + let query = query_as(r#"SELECT email, sub, nickname, name, picture, updated_at, email_verified, is_default FROM "user" WHERE email = $1"#).bind(email); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -78,13 +84,15 @@ impl UsersRepository for UsersSqlxDatabase { } async fn get_users(&self) -> Result> { - let query = query_as("SELECT email, sub, nickname, name, picture, updated_at, email_verified, is_default FROM user"); + let query = query_as( + r#"SELECT email, sub, nickname, name, picture, updated_at, email_verified, is_default FROM "user""#, + ); let rows: Vec = query.fetch_all(&*self.database.pool).await.into_core()?; rows.iter().map(|u| u.user()).collect() } async fn delete_user(&self, email: &EmailAddress) 
-> Result<()> { - let query1 = query("DELETE FROM user WHERE email=?").bind(email.to_sql()); + let query1 = query(r#"DELETE FROM "user" WHERE email = $1"#).bind(email); query1.execute(&*self.database.pool).await.void() } } @@ -100,9 +108,9 @@ struct UserRow { name: String, picture: String, updated_at: String, - email_verified: bool, + email_verified: Boolean, #[allow(unused)] - is_default: bool, + is_default: Boolean, } impl UserRow { @@ -114,7 +122,7 @@ impl UserRow { name: self.name.clone(), picture: self.picture.clone(), updated_at: self.updated_at.clone(), - email_verified: self.email_verified, + email_verified: self.email_verified.to_bool(), }) } } @@ -123,61 +131,61 @@ impl UserRow { mod test { use super::*; + use ockam_node::database::with_dbs; use std::sync::Arc; #[tokio::test] async fn test_repository() -> Result<()> { - let repository = create_repository().await?; - let my_email_address: EmailAddress = "me@ockam.io".try_into().unwrap(); - let your_email_address: EmailAddress = "you@ockam.io".try_into().unwrap(); - - // create and store 2 users - let user1 = UserInfo { - sub: "sub".into(), - nickname: "me".to_string(), - name: "me".to_string(), - picture: "me".to_string(), - updated_at: "today".to_string(), - email: my_email_address.clone(), - email_verified: false, - }; - let user2 = UserInfo { - sub: "sub".into(), - nickname: "you".to_string(), - name: "you".to_string(), - picture: "you".to_string(), - updated_at: "today".to_string(), - email: your_email_address.clone(), - email_verified: false, - }; - - repository.store_user(&user1).await?; - repository.store_user(&user2).await?; - - // retrieve them as a vector or by name - let result = repository.get_users().await?; - assert_eq!(result, vec![user1.clone(), user2.clone()]); - - let result = repository.get_user(&my_email_address).await?; - assert_eq!(result, Some(user1.clone())); - - // a user can be set created as the default user - repository.set_default_user(&my_email_address).await?; - let result = 
repository.get_default_user().await?; - assert_eq!(result, Some(user1.clone())); - - // a user can be deleted - repository.delete_user(&your_email_address).await?; - let result = repository.get_user(&your_email_address).await?; - assert_eq!(result, None); - - let result = repository.get_users().await?; - assert_eq!(result, vec![user1.clone()]); - Ok(()) - } - - /// HELPERS - async fn create_repository() -> Result> { - Ok(Arc::new(UsersSqlxDatabase::create().await?)) + with_dbs(|db| async move { + let repository: Arc = Arc::new(UsersSqlxDatabase::new(db)); + + let my_email_address: EmailAddress = "me@ockam.io".try_into().unwrap(); + let your_email_address: EmailAddress = "you@ockam.io".try_into().unwrap(); + + // create and store 2 users + let user1 = UserInfo { + sub: "sub".into(), + nickname: "me".to_string(), + name: "me".to_string(), + picture: "me".to_string(), + updated_at: "today".to_string(), + email: my_email_address.clone(), + email_verified: false, + }; + let user2 = UserInfo { + sub: "sub".into(), + nickname: "you".to_string(), + name: "you".to_string(), + picture: "you".to_string(), + updated_at: "today".to_string(), + email: your_email_address.clone(), + email_verified: false, + }; + + repository.store_user(&user1).await?; + repository.store_user(&user2).await?; + + // retrieve them as a vector or by name + let result = repository.get_users().await?; + assert_eq!(result, vec![user1.clone(), user2.clone()]); + + let result = repository.get_user(&my_email_address).await?; + assert_eq!(result, Some(user1.clone())); + + // a user can be set created as the default user + repository.set_default_user(&my_email_address).await?; + let result = repository.get_default_user().await?; + assert_eq!(result, Some(user1.clone())); + + // a user can be deleted + repository.delete_user(&your_email_address).await?; + let result = repository.get_user(&your_email_address).await?; + assert_eq!(result, None); + + let result = repository.get_users().await?; + 
assert_eq!(result, vec![user1.clone()]); + Ok(()) + }) + .await } } diff --git a/implementations/rust/ockam/ockam_api/src/cli_state/storage/vaults_repository.rs b/implementations/rust/ockam/ockam_api/src/cli_state/storage/vaults_repository.rs index 8e5f95067a0..66cb236a225 100644 --- a/implementations/rust/ockam/ockam_api/src/cli_state/storage/vaults_repository.rs +++ b/implementations/rust/ockam/ockam_api/src/cli_state/storage/vaults_repository.rs @@ -1,6 +1,4 @@ -use std::path::Path; - -use crate::cli_state::NamedVault; +use crate::cli_state::{NamedVault, VaultType}; use ockam_core::async_trait; use ockam_core::Result; @@ -9,20 +7,20 @@ use ockam_core::Result; #[async_trait] pub trait VaultsRepository: Send + Sync + 'static { /// Store a new vault path with an associated name - async fn store_vault(&self, name: &str, path: &Path, is_kms: bool) -> Result; + async fn store_vault(&self, name: &str, vault_type: VaultType) -> Result; /// Update a vault path - async fn update_vault(&self, name: &str, path: &Path) -> Result<()>; + async fn update_vault(&self, name: &str, vault_type: VaultType) -> Result<()>; /// Delete a vault given its name async fn delete_named_vault(&self, name: &str) -> Result<()>; + /// Return the database vault if it has been created + async fn get_database_vault(&self) -> Result>; + /// Return a vault by name async fn get_named_vault(&self, name: &str) -> Result>; - /// Return a vault by path - async fn get_named_vault_with_path(&self, path: &Path) -> Result>; - /// Return all vaults async fn get_named_vaults(&self) -> Result>; } diff --git a/implementations/rust/ockam/ockam_api/src/cli_state/storage/vaults_repository_sql.rs b/implementations/rust/ockam/ockam_api/src/cli_state/storage/vaults_repository_sql.rs index 4d674a78545..8b085c70f5e 100644 --- a/implementations/rust/ockam/ockam_api/src/cli_state/storage/vaults_repository_sql.rs +++ b/implementations/rust/ockam/ockam_api/src/cli_state/storage/vaults_repository_sql.rs @@ -1,12 +1,13 @@ -use 
std::path::{Path, PathBuf}; -use std::str::FromStr; +use std::path::PathBuf; use sqlx::*; -use crate::cli_state::{NamedVault, VaultsRepository}; -use ockam::{FromSqlxError, SqlxDatabase, ToSqlxType, ToVoid}; +use ockam::{FromSqlxError, SqlxDatabase, ToVoid}; use ockam_core::async_trait; use ockam_core::Result; +use ockam_node::database::{Boolean, Nullable}; + +use crate::cli_state::{NamedVault, UseAwsKms, VaultType, VaultsRepository}; #[derive(Clone)] pub struct VaultsSqlxDatabase { @@ -27,33 +28,47 @@ impl VaultsSqlxDatabase { #[async_trait] impl VaultsRepository for VaultsSqlxDatabase { - async fn store_vault(&self, name: &str, path: &Path, is_kms: bool) -> Result { - let query = query("INSERT INTO vault VALUES (?1, ?2, ?3, ?4)") - .bind(name.to_sql()) - .bind(path.to_sql()) - .bind(true.to_sql()) - .bind(is_kms.to_sql()); - query.execute(&*self.database.pool).await.void()?; - - Ok(NamedVault::new(name, path.into(), is_kms)) + async fn store_vault(&self, name: &str, vault_type: VaultType) -> Result { + let mut transaction = self.database.begin().await.into_core()?; + + let query1 = + query_scalar("SELECT EXISTS(SELECT 1 FROM vault WHERE is_default = $1)").bind(true); + let default_exists: Boolean = query1.fetch_one(&mut *transaction).await.into_core()?; + let default_exists = default_exists.to_bool(); + + let query = query( + r#" + INSERT INTO + vault (name, path, is_default, is_kms) + VALUES ($1, $2, $3, $4) + ON CONFLICT (name) + DO UPDATE SET path = $2, is_default = $3, is_kms = $4"#, + ) + .bind(name) + .bind(vault_type.path().map(|p| p.to_string_lossy().to_string())) + .bind(!default_exists) + .bind(vault_type.use_aws_kms()); + query.execute(&mut *transaction).await.void()?; + + transaction.commit().await.void()?; + Ok(NamedVault::new(name, vault_type, !default_exists)) } - async fn update_vault(&self, name: &str, path: &Path) -> Result<()> { - let query = query("UPDATE vault SET path=$1 WHERE name=$2") - .bind(path.to_sql()) - .bind(name.to_sql()); + async 
fn update_vault(&self, name: &str, vault_type: VaultType) -> Result<()> { + let query = query("UPDATE vault SET path = $1, is_kms = $2 WHERE name = $3") + .bind(vault_type.path().map(|p| p.to_string_lossy().to_string())) + .bind(vault_type.use_aws_kms()) + .bind(name); query.execute(&*self.database.pool).await.void() } - /// Delete a vault by name async fn delete_named_vault(&self, name: &str) -> Result<()> { - let query = query("DELETE FROM vault WHERE name=?").bind(name.to_sql()); + let query = query("DELETE FROM vault WHERE name = $1").bind(name); query.execute(&*self.database.pool).await.void() } - async fn get_named_vault(&self, name: &str) -> Result> { - let query = - query_as("SELECT name, path, is_kms FROM vault WHERE name = $1").bind(name.to_sql()); + async fn get_database_vault(&self) -> Result> { + let query = query_as("SELECT name, path, is_default, is_kms FROM vault WHERE path is NULL"); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -61,9 +76,9 @@ impl VaultsRepository for VaultsSqlxDatabase { row.map(|r| r.named_vault()).transpose() } - async fn get_named_vault_with_path(&self, path: &Path) -> Result> { + async fn get_named_vault(&self, name: &str) -> Result> { let query = - query_as("SELECT name, path, is_kms FROM vault WHERE path = $1").bind(path.to_sql()); + query_as("SELECT name, path, is_default, is_kms FROM vault WHERE name = $1").bind(name); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -72,7 +87,7 @@ impl VaultsRepository for VaultsSqlxDatabase { } async fn get_named_vaults(&self) -> Result> { - let query = query_as("SELECT name, path, is_kms FROM vault"); + let query = query_as("SELECT name, path, is_default, is_kms FROM vault"); let rows: Vec = query.fetch_all(&*self.database.pool).await.into_core()?; rows.iter().map(|r| r.named_vault()).collect() } @@ -83,84 +98,93 @@ impl VaultsRepository for VaultsSqlxDatabase { #[derive(FromRow)] pub(crate) struct VaultRow { name: String, - path: 
String, - is_kms: bool, + path: Nullable, + is_default: Boolean, + is_kms: Boolean, } impl VaultRow { pub(crate) fn named_vault(&self) -> Result { Ok(NamedVault::new( &self.name, - PathBuf::from_str(self.path.as_str()).unwrap(), - self.is_kms, + self.vault_type(), + self.is_default(), )) } + + pub(crate) fn vault_type(&self) -> VaultType { + match self.path.to_option() { + None => VaultType::database(UseAwsKms::from(self.is_kms.to_bool())), + Some(p) => VaultType::local_file( + PathBuf::from(p).as_path(), + UseAwsKms::from(self.is_kms.to_bool()), + ), + } + } + + pub(crate) fn is_default(&self) -> bool { + self.is_default.to_bool() + } } #[cfg(test)] mod test { use super::*; + use ockam_node::database::with_dbs; use std::sync::Arc; #[tokio::test] async fn test_repository() -> Result<()> { - let repository = create_repository().await?; - - // A vault can be defined with a path and stored under a specific name - let named_vault1 = repository - .store_vault("vault1", Path::new("path"), false) - .await?; - let expected = NamedVault::new("vault1", Path::new("path").into(), false); - assert_eq!(named_vault1, expected); - - // A vault with the same name can not be created twice - let result = repository - .store_vault("vault1", Path::new("path"), false) - .await; - assert!(result.is_err()); - - // The vault can then be retrieved with its name - let result = repository.get_named_vault("vault1").await?; - assert_eq!(result, Some(named_vault1.clone())); - - // The vault can then be retrieved with its path - let result = repository - .get_named_vault_with_path(Path::new("path")) - .await?; - assert_eq!(result, Some(named_vault1.clone())); - - // The vault can be set at another path - repository - .update_vault("vault1", Path::new("path2")) - .await?; - let result = repository.get_named_vault("vault1").await?; - assert_eq!( - result, - Some(NamedVault::new("vault1", Path::new("path2").into(), false)) - ); - - // The vault can also be deleted - 
repository.delete_named_vault("vault1").await?; - let result = repository.get_named_vault("vault1").await?; - assert_eq!(result, None); - Ok(()) + with_dbs(|db| async move { + let repository: Arc = Arc::new(VaultsSqlxDatabase::new(db)); + + // A vault can be defined with a path and stored under a specific name + let vault_type = VaultType::local_file("path", UseAwsKms::No); + let named_vault1 = repository.store_vault("vault1", vault_type.clone()).await?; + let expected = NamedVault::new("vault1", vault_type.clone(), true); + assert_eq!(named_vault1, expected); + + // The vault can then be retrieved with its name + let result = repository.get_named_vault("vault1").await?; + assert_eq!(result, Some(named_vault1.clone())); + + // Another vault can be created. + // It is not the default vault + let vault_type = VaultType::local_file("path2", UseAwsKms::No); + let named_vault2 = repository.store_vault("vault2", vault_type.clone()).await?; + let expected = NamedVault::new("vault2", vault_type.clone(), false); + // it is not the default vault + assert_eq!(named_vault2, expected); + + // The first vault can be set at another path + let vault_type = VaultType::local_file("path2", UseAwsKms::No); + repository + .update_vault("vault1", vault_type.clone()) + .await?; + let result = repository.get_named_vault("vault1").await?; + assert_eq!(result, Some(NamedVault::new("vault1", vault_type, true))); + + // The first vault can be deleted + repository.delete_named_vault("vault1").await?; + let result = repository.get_named_vault("vault1").await?; + assert_eq!(result, None); + Ok(()) + }) + .await } #[tokio::test] async fn test_store_kms_vault() -> Result<()> { - let repository = create_repository().await?; - - // A KMS vault can be created by setting the kms flag to true - let kms = repository - .store_vault("kms", Path::new("path"), true) - .await?; - let expected = NamedVault::new("kms", Path::new("path").into(), true); - assert_eq!(kms, expected); - Ok(()) - } - - /// HELPERS - 
async fn create_repository() -> Result> { - Ok(Arc::new(VaultsSqlxDatabase::create().await?)) + with_dbs(|db| async move { + let repository: Arc = Arc::new(VaultsSqlxDatabase::new(db)); + + // It is possible to create a vault storing its signing keys in an AWS KMS + let vault_type = VaultType::database(UseAwsKms::Yes); + let kms = repository.store_vault("kms", vault_type.clone()).await?; + let expected = NamedVault::new("kms", vault_type, true); + assert_eq!(kms, expected); + Ok(()) + }) + .await } } diff --git a/implementations/rust/ockam/ockam_api/src/cli_state/test_support.rs b/implementations/rust/ockam/ockam_api/src/cli_state/test_support.rs index a12176e58d7..d34baa6726b 100644 --- a/implementations/rust/ockam/ockam_api/src/cli_state/test_support.rs +++ b/implementations/rust/ockam/ockam_api/src/cli_state/test_support.rs @@ -1,12 +1,30 @@ use crate::cli_state::Result; use crate::cli_state::{random_name, CliState, CliStateError}; +use ockam_node::database::SqlxDatabase; use std::path::PathBuf; /// Test support impl CliState { /// Return a test CliState with a random root directory + /// Use this CliState for a simple integration test since every call to that function deletes + /// all previous state if the database being used is Postgres. pub async fn test() -> Result { - Self::create(Self::test_dir()?).await + let test_dir = Self::test_dir()?; + + // clean the existing state if any + let db = SqlxDatabase::create(&CliState::make_database_configuration(&test_dir)?).await?; + db.drop_all_postgres_tables().await?; + + Self::create(test_dir).await + } + + /// Return a test CliState with a random root directory + /// Use this CliState for system tests involving several nodes + /// since calls to that function do not delete + /// any previous state if the database being used is Postgres. 
+ pub async fn system() -> Result { + let test_dir = Self::test_dir()?; + Self::create(test_dir).await } /// Return a random root directory diff --git a/implementations/rust/ockam/ockam_api/src/cli_state/vaults.rs b/implementations/rust/ockam/ockam_api/src/cli_state/vaults.rs index 899f6f9a565..af66a49d322 100644 --- a/implementations/rust/ockam/ockam_api/src/cli_state/vaults.rs +++ b/implementations/rust/ockam/ockam_api/src/cli_state/vaults.rs @@ -28,26 +28,51 @@ impl CliState { /// If the path is not specified then: /// - if this is the first vault then secrets are persisted in the main database /// - if this is a new vault then secrets are persisted in $OCKAM_HOME/vault_name - #[instrument(skip_all, fields(vault_name = vault_name.clone(), path = path.clone().map_or("n/a".to_string(), |p| p.to_string_lossy().to_string())))] + #[instrument(skip_all, fields(vault_name = vault_name.clone()))] pub async fn create_named_vault( &self, - vault_name: &Option, - path: &Option, + vault_name: Option, + path: Option, + use_aws_kms: UseAwsKms, ) -> Result { - self.create_a_vault(vault_name, path, false).await - } + let vaults_repository = self.vaults_repository(); - /// Create a KMS vault with a given name - /// If the path is not specified then: - /// - if this is the first vault then secrets are persisted in the main database - /// - if this is a new vault then secrets are persisted in $OCKAM_HOME/vault_name - #[instrument(skip_all, fields(vault_name = vault_name.clone(), path = path.clone().map_or("n/a".to_string(), |p| p.to_string_lossy().to_string())))] - pub async fn create_kms_vault( - &self, - vault_name: &Option, - path: &Option, - ) -> Result { - self.create_a_vault(vault_name, path, true).await + // determine the vault name to use if not given by the user + let vault_name = match vault_name { + Some(vault_name) => vault_name.clone(), + None => self.make_vault_name().await?, + }; + + // verify that a vault with that name does not exist + if vaults_repository + 
.get_named_vault(&vault_name) + .await? + .is_some() + { + return Err(CliStateError::AlreadyExists { + resource: "vault".to_string(), + name: vault_name.to_string(), + }); + } + + // Determine if the vault needs to be created at a specific path + // or if data can be stored in the main database directly + match path { + None => match self.vaults_repository().get_database_vault().await? { + None => Ok(vaults_repository + .store_vault(&vault_name, VaultType::database(use_aws_kms)) + .await?), + Some(_) => { + let path = self.make_vault_path(&vault_name); + Ok(self + .create_local_vault(vault_name, &path, use_aws_kms) + .await?) + } + }, + Some(path) => Ok(self + .create_local_vault(vault_name, &path, use_aws_kms) + .await?), + } } /// Delete an existing vault @@ -78,15 +103,14 @@ impl CliState { let vault = repository.get_named_vault(vault_name).await?; if let Some(vault) = vault { repository.delete_named_vault(vault_name).await?; - - // if the vault is stored in a separate file - // remove that file - if vault.path != self.database_path() { - let _ = std::fs::remove_file(vault.path); - } else { - // otherwise delete the tables used by the database vault - self.purpose_keys_repository().delete_all().await?; - self.secrets_repository().delete_all().await?; + match vault.vault_type { + VaultType::DatabaseVault { .. } => { + self.purpose_keys_repository().delete_all().await?; + self.secrets_repository().delete_all().await?; + } + VaultType::LocalFileVault { path, .. 
} => { + let _ = std::fs::remove_file(path); + } } } Ok(()) @@ -144,7 +168,6 @@ impl CliState { #[instrument(skip_all, fields(vault_name = vault_name))] pub async fn get_or_create_named_vault(&self, vault_name: &str) -> Result { let vaults_repository = self.vaults_repository(); - let is_default = vault_name == DEFAULT_VAULT_NAME; if let Ok(Some(existing_vault)) = vaults_repository.get_named_vault(vault_name).await { return Ok(existing_vault); @@ -153,17 +176,39 @@ impl CliState { self.notify_message(fmt_log!( "This Identity needs a Vault to store its secrets." )); - self.notify_message(fmt_log!( - "There is no default Vault on this machine, creating one..." - )); - let named_vault = self - .create_a_vault(&Some(vault_name.to_string()), &None, false) - .await?; - self.notify_message(fmt_ok!( - "Created a new Vault named {} on your disk.", - color_primary(vault_name) - )); - if is_default { + let named_vault = if self + .vaults_repository() + .get_database_vault() + .await? + .is_none() + { + self.notify_message(fmt_log!( + "There is no default Vault on this machine, creating one..." 
+ )); + let vault = self + .create_database_vault(vault_name.to_string(), UseAwsKms::No) + .await?; + self.notify_message(fmt_ok!( + "Created a new Vault named {}.", + color_primary(vault_name) + )); + vault + } else { + let vault = self + .create_local_vault( + vault_name.to_string(), + &self.make_vault_path(vault_name), + UseAwsKms::No, + ) + .await?; + self.notify_message(fmt_ok!( + "Created a new Vault named {} on your disk.", + color_primary(vault_name) + )); + vault + }; + + if named_vault.is_default() { self.notify_message(fmt_ok!( "Marked this new Vault as your default Vault, on this machine.\n" )); @@ -211,38 +256,52 @@ impl CliState { pub async fn move_vault(&self, vault_name: &str, path: &Path) -> Result<()> { let repository = self.vaults_repository(); let vault = self.get_named_vault(vault_name).await?; - if vault.path() == self.database_path() { - return Err(ockam_core::Error::new(Origin::Api, Kind::Invalid, format!("The vault at path {:?} cannot be moved to {path:?} because this is the default vault", vault.path())))?; - }; - - // copy the file to the new location - std::fs::copy(vault.path(), path)?; - // update the path in the database - repository.update_vault(vault_name, path).await?; - // remove the old file - std::fs::remove_file(vault.path())?; + match vault.vault_type { + VaultType::DatabaseVault { .. 
} => Err(ockam_core::Error::new( + Origin::Api, + Kind::Invalid, + format!( + "The vault {} cannot be moved to {path:?} because this is the default vault", + vault.name() + ), + ))?, + VaultType::LocalFileVault { + path: old_path, + use_aws_kms, + } => { + // copy the file to the new location + std::fs::copy(&old_path, path)?; + // update the path in the database + repository + .update_vault(vault_name, VaultType::local_file(path, use_aws_kms)) + .await?; + // remove the old file + std::fs::remove_file(old_path)?; + } + } Ok(()) } - /// Move a vault file to another location if the vault is not the default vault - /// contained in the main database - #[instrument(skip_all, fields(vault_name = named_vault.name, path = named_vault.path.to_string_lossy().to_string()))] + /// Make a concrete vault based on the NamedVault metadata + #[instrument(skip_all, fields(vault_name = named_vault.name))] pub async fn make_vault(&self, named_vault: NamedVault) -> Result { - let db = if Some(named_vault.path.as_path()) == self.database_ref().path() { - self.database() - } else { + let db = match named_vault.vault_type { + VaultType::DatabaseVault { .. } => self.database(), + VaultType::LocalFileVault { ref path, .. } => // TODO: Avoid creating multiple dbs with the same file - SqlxDatabase::create(named_vault.path.as_path()).await? + { + SqlxDatabase::create_sqlite(path.as_path()).await? 
+ } }; - let mut vault = Vault::create_with_database(db); - if named_vault.is_kms { + if named_vault.vault_type.use_aws_kms() { + let mut vault = Vault::create_with_database(db); let aws_vault = Arc::new(AwsSigningVault::create().await?); vault.identity_vault = aws_vault.clone(); vault.credential_vault = aws_vault; Ok(vault) } else { - Ok(vault) + Ok(Vault::create_with_database(db)) } } } @@ -259,66 +318,54 @@ impl CliState { /// Private functions impl CliState { - /// Create a vault with the given name and indicate if it is going to be used as a KMS vault - /// If the vault with the same name already exists then an error is returned - /// If there is already a file at the provided path, then an error is returned - #[instrument(skip_all, fields(vault_name = vault_name))] - async fn create_a_vault( + /// Create the database vault if it doesn't exist already + async fn create_database_vault( &self, - vault_name: &Option, - path: &Option, - is_kms: bool, + vault_name: String, + use_aws_kms: UseAwsKms, ) -> Result { - let vaults_repository = self.vaults_repository(); - - // determine the vault name to use if not given by the user - let vault_name = match vault_name { - Some(vault_name) => vault_name.clone(), - None => self.make_vault_name().await?, - }; - - // verify that a vault with that name does not exist - if vaults_repository - .get_named_vault(&vault_name) - .await? - .is_some() - { - return Err(CliStateError::AlreadyExists { - resource: "vault".to_string(), - name: vault_name.to_string(), - }); + match self.vaults_repository().get_database_vault().await? 
{ + None => Ok(self + .vaults_repository() + .store_vault(&vault_name, VaultType::database(use_aws_kms)) + .await?), + Some(vault) => Err(CliStateError::AlreadyExists { + resource: "database vault".to_string(), + name: vault.name().to_string(), + }), } + } - // determine the vault path - // if the vault is the first vault we store the data directly in the main database - // otherwise we open a new file with the vault name - let path = match path { - Some(path) => path.clone(), - None => self.make_vault_path(&vault_name).await?, - }; - + /// Create a vault store in a local file if the path has not been taken already + async fn create_local_vault( + &self, + vault_name: String, + path: &PathBuf, + use_aws_kms: UseAwsKms, + ) -> Result { // check if the new file can be created - let path_taken = self.get_named_vault_with_path(&path).await?.is_some(); + let path_taken = self + .get_named_vaults() + .await? + .iter() + .any(|v| v.path() == Some(path.as_path())); if path_taken { - return Err(CliStateError::AlreadyExists { + Err(CliStateError::AlreadyExists { resource: "vault path".to_string(), name: format!("{path:?}"), - }); + })?; } else { // create a new file if we need to store the vault data outside of the main database - if path != self.database_path() { - // similar to File::create_new which is unstable for now - OpenOptions::new() - .read(true) - .write(true) - .create_new(true) - .open(&path)?; - } + // similar to File::create_new which is unstable for now + OpenOptions::new() + .read(true) + .write(true) + .create_new(true) + .open(path)?; }; - - // store the vault metadata - Ok(vaults_repository - .store_vault(&vault_name, &path, is_kms) + Ok(self + .vaults_repository() + .store_vault(&vault_name, VaultType::local_file(path, use_aws_kms)) .await?) 
} @@ -338,46 +385,100 @@ impl CliState { } /// Decide which path to use for a vault path: - /// - if no vault has been using the main database, use it /// - otherwise return a new path alongside the database $OCKAM_HOME/vault-{vault_name} - /// - async fn make_vault_path(&self, vault_name: &str) -> Result { - let vaults_repository = self.vaults_repository(); - // is there already a vault using the main database? - let is_database_path_available = vaults_repository - .get_named_vaults() - .await? - .iter() - .all(|v| v.path() != self.database_path()); - if is_database_path_available { - Ok(self.database_path()) - } else { - Ok(self.dir().join(format!("vault-{vault_name}"))) - } - } - - async fn get_named_vault_with_path(&self, path: &Path) -> Result> { - Ok(self - .vaults_repository() - .get_named_vault_with_path(path) - .await?) + fn make_vault_path(&self, vault_name: &str) -> PathBuf { + self.dir().join(format!("vault-{vault_name}")) } } #[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)] pub struct NamedVault { name: String, - path: PathBuf, - is_kms: bool, + vault_type: VaultType, + is_default: bool, +} + +#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)] +pub enum VaultType { + DatabaseVault { + use_aws_kms: UseAwsKms, + }, + LocalFileVault { + path: PathBuf, + use_aws_kms: UseAwsKms, + }, +} + +impl Display for VaultType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + writeln!( + f, + "Type: {}", + match &self { + VaultType::DatabaseVault { .. } => "INTERNAL", + VaultType::LocalFileVault { .. 
} => "EXTERNAL", + } + )?; + if self.use_aws_kms() { + writeln!(f, "Uses AWS KMS: true",)?; + } + Ok(()) + } +} + +#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)] +pub enum UseAwsKms { + Yes, + No, +} + +impl UseAwsKms { + pub fn from(b: bool) -> Self { + if b { + UseAwsKms::Yes + } else { + UseAwsKms::No + } + } +} + +impl VaultType { + pub fn database(use_aws_kms: UseAwsKms) -> Self { + VaultType::DatabaseVault { use_aws_kms } + } + + pub fn local_file(path: impl Into, use_aws_kms: UseAwsKms) -> Self { + VaultType::LocalFileVault { + path: path.into(), + use_aws_kms, + } + } + + pub fn path(&self) -> Option<&Path> { + match self { + VaultType::DatabaseVault { .. } => None, + VaultType::LocalFileVault { path, .. } => Some(path.as_path()), + } + } + + pub fn use_aws_kms(&self) -> bool { + match self { + VaultType::DatabaseVault { use_aws_kms } => use_aws_kms == &UseAwsKms::Yes, + VaultType::LocalFileVault { + path: _, + use_aws_kms, + } => use_aws_kms == &UseAwsKms::Yes, + } + } } impl NamedVault { /// Create a new named vault - pub fn new(name: &str, path: PathBuf, is_kms: bool) -> Self { + pub fn new(name: &str, vault_type: VaultType, is_default: bool) -> Self { Self { name: name.to_string(), - path, - is_kms, + vault_type, + is_default, } } @@ -386,33 +487,38 @@ impl NamedVault { self.name.clone() } - /// Return the vault path - pub fn path(&self) -> PathBuf { - self.path.clone() + /// Return the vault type + pub fn vault_type(&self) -> VaultType { + self.vault_type.clone() } - /// Return the vault path as a String - pub fn path_as_string(&self) -> String { - self.path.clone().to_string_lossy().to_string() + /// Return true if this is the default vault + pub fn is_default(&self) -> bool { + self.is_default } - /// Return true if this vault is a KMS vault - pub fn is_kms(&self) -> bool { - self.is_kms + /// Return true if an AWS KMS is used to store signing keys + pub fn use_aws_kms(&self) -> bool { + self.vault_type.use_aws_kms() + } 
+ + /// Return the vault path if the vault data is stored in a local file + pub fn path(&self) -> Option<&Path> { + self.vault_type.path() + } + + /// Return the vault path as a String + pub fn path_as_string(&self) -> Option { + self.vault_type + .path() + .map(|p| p.to_string_lossy().to_string()) } } impl Display for NamedVault { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { writeln!(f, "Name: {}", self.name)?; - writeln!( - f, - "Type: {}", - match self.is_kms { - true => "AWS KMS", - false => "OCKAM", - } - )?; + writeln!(f, "{}", self.vault_type)?; Ok(()) } } @@ -420,15 +526,18 @@ impl Display for NamedVault { impl Output for NamedVault { fn item(&self) -> crate::Result { let mut output = String::new(); - writeln!(output, "Name: {}", self.name())?; + writeln!(output, "Name: {}", self.name)?; writeln!( output, "Type: {}", - match self.is_kms() { - true => "AWS KMS", - false => "OCKAM", + match &self.vault_type { + VaultType::DatabaseVault { .. } => "INTERNAL", + VaultType::LocalFileVault { .. 
} => "EXTERNAL", } )?; + if self.vault_type.use_aws_kms() { + writeln!(output, "Uses AWS KMS: true",)?; + } Ok(output) } } @@ -442,13 +551,13 @@ mod tests { ECDSASHA256CurveP256SecretKey, ECDSASHA256CurveP256Signature, HandleToSecret, SigningSecret, SigningSecretKeyHandle, X25519SecretKey, X25519SecretKeyHandle, }; - use tempfile::NamedTempFile; #[tokio::test] async fn test_create_named_vault() -> Result<()> { let cli = CliState::test().await?; // create a vault + // since this is the first one, the data is stored in the database let named_vault1 = cli.get_or_create_named_vault("vault1").await?; let result = cli.get_named_vault("vault1").await?; @@ -456,19 +565,12 @@ mod tests { // another vault cannot be created with the same name let result = cli - .create_named_vault(&Some("vault1".to_string()), &None) + .create_named_vault(Some("vault1".to_string()), None, UseAwsKms::No) .await .ok(); assert_eq!(result, None); - // another vault cannot be created with the same path - let result = cli - .create_named_vault(&None, &Some(named_vault1.path())) - .await - .ok(); - assert_eq!(result, None); - - // the first created vault is the default one if it is the only one + // the first created vault is the default one let result = cli.get_or_create_default_named_vault().await?; assert_eq!(result, named_vault1.clone()); @@ -484,6 +586,19 @@ mod tests { let named_vault2 = cli.get_or_create_named_vault("vault2").await?; + // that vault is using a local file + assert!(named_vault2.path().is_some()); + // another vault cannot be created with the same path + let result = cli + .create_named_vault( + Some("another name".to_string()), + named_vault2.path().map(|p| p.to_path_buf()), + UseAwsKms::No, + ) + .await + .ok(); + assert_eq!(result, None); + let result = cli.get_named_vaults().await?; assert_eq!(result, vec![named_vault1.clone(), named_vault2.clone()]); @@ -537,7 +652,7 @@ mod tests { // if we create a second vault, it can be returned by name let vault2 = cli - 
.create_named_vault(&Some("vault-2".to_string()), &None) + .create_named_vault(Some("vault-2".to_string()), None, UseAwsKms::No) .await?; let result = cli.get_named_vault_or_default(&Some(vault2.name())).await?; assert_eq!(result, vault2); @@ -551,16 +666,14 @@ mod tests { #[tokio::test] async fn test_move_vault() -> Result<()> { - let db_file = NamedTempFile::new().unwrap(); - let cli_state_directory = db_file.path().parent().unwrap().join(random_name()); - let cli = CliState::create(cli_state_directory.clone()).await?; + let cli = CliState::test().await?; // create a vault let _ = cli.get_or_create_named_vault("vault1").await?; // try to move it. That should fail because the first vault is // stored in the main database - let new_vault_path = cli_state_directory.join("new-vault-name"); + let new_vault_path = cli.dir().join("new-vault-name"); let result = cli.move_vault("vault1", &new_vault_path).await; assert!(result.is_err()); @@ -569,15 +682,15 @@ mod tests { // try to move it. This should succeed let result = cli - .move_vault("vault2", &cli_state_directory.join("new-vault-name")) + .move_vault("vault2", &cli.dir().join("new-vault-name")) .await; if let Err(e) = result { panic!("{}", e.to_string()) }; let vault = cli.get_named_vault("vault2").await?; - assert_eq!(vault.path(), new_vault_path); - assert!(vault.path().exists()); + assert_eq!(vault.path(), Some(new_vault_path.as_path())); + assert!(new_vault_path.exists()); Ok(()) } @@ -587,33 +700,36 @@ mod tests { let cli = CliState::test().await?; // the first vault is stored in the main database with the name 'default' - let result = cli.create_named_vault(&None, &None).await?; + let result = cli.create_named_vault(None, None, UseAwsKms::No).await?; assert_eq!(result.name(), DEFAULT_VAULT_NAME.to_string()); - assert_eq!(result.path(), cli.database_path()); + assert_eq!(result.vault_type(), VaultType::database(UseAwsKms::No)); // the second vault is stored in a separate file, with a random name // that name is 
used to create the file name - let result = cli.create_named_vault(&None, &None).await?; + let result = cli.create_named_vault(None, None, UseAwsKms::No).await?; + assert!(result.path().is_some()); assert!(result .path_as_string() + .unwrap() .ends_with(&format!("vault-{}", result.name()))); // a third vault with a name is also stored in a separate file let result = cli - .create_named_vault(&Some("secrets".to_string()), &None) + .create_named_vault(Some("secrets".to_string()), None, UseAwsKms::No) .await?; assert_eq!(result.name(), "secrets".to_string()); - assert!(result.path_as_string().contains("vault-secrets")); + assert!(result.path().is_some()); + assert!(result.path_as_string().unwrap().contains("vault-secrets")); // if we reset, we can check that the first vault gets the user defined name // instead of default cli.reset().await?; let cli = CliState::test().await?; let result = cli - .create_named_vault(&Some("secrets".to_string()), &None) + .create_named_vault(Some("secrets".to_string()), None, UseAwsKms::No) .await?; assert_eq!(result.name(), "secrets".to_string()); - assert_eq!(result.path(), cli.database_path()); + assert_eq!(result.vault_type(), VaultType::database(UseAwsKms::No)); Ok(()) } @@ -621,13 +737,17 @@ mod tests { #[tokio::test] async fn test_create_vault_with_a_user_path() -> Result<()> { let cli = CliState::test().await?; - let vault_path = cli.database_path().parent().unwrap().join(random_name()); + let vault_path = cli.dir().join(random_name()); let result = cli - .create_named_vault(&Some("secrets".to_string()), &Some(vault_path.clone())) + .create_named_vault( + Some("secrets".to_string()), + Some(vault_path.clone()), + UseAwsKms::No, + ) .await?; assert_eq!(result.name(), "secrets".to_string()); - assert_eq!(result.path(), vault_path); + assert_eq!(result.path(), Some(vault_path.as_path())); Ok(()) } @@ -637,7 +757,7 @@ mod tests { let cli = CliState::test().await?; // create a vault and populate the tables used by the vault - let 
vault = cli.create_named_vault(&None, &None).await?; + let vault = cli.create_named_vault(None, None, UseAwsKms::No).await?; let purpose_keys_repository = cli.purpose_keys_repository(); let identity = cli.create_identity_with_name("name").await?; diff --git a/implementations/rust/ockam/ockam_api/src/kafka/integration_test.rs b/implementations/rust/ockam/ockam_api/src/kafka/integration_test.rs index 7950ce28d37..af946fa9b7b 100644 --- a/implementations/rust/ockam/ockam_api/src/kafka/integration_test.rs +++ b/implementations/rust/ockam/ockam_api/src/kafka/integration_test.rs @@ -45,7 +45,7 @@ mod test { use crate::kafka::{ ConsumerPublishing, ConsumerResolution, KafkaInletController, KafkaPortalListener, }; - use crate::test_utils::NodeManagerHandle; + use crate::test_utils::{NodeManagerHandle, TestNode}; // TODO: upgrade to 13 by adding a metadata request to map uuid<=>topic_name const TEST_KAFKA_API_VERSION: i16 = 12; @@ -131,6 +131,7 @@ mod test { async fn producer__flow_with_mock_kafka__content_encryption_and_decryption( context: &mut Context, ) -> ockam::Result<()> { + TestNode::clean().await?; let handle = crate::test_utils::start_manager_for_tests(context, None, None).await?; let consumer_bootstrap_port = create_kafka_service( diff --git a/implementations/rust/ockam/ockam_api/src/kafka/portal_worker.rs b/implementations/rust/ockam/ockam_api/src/kafka/portal_worker.rs index 049dbe02316..c6c85af0796 100644 --- a/implementations/rust/ockam/ockam_api/src/kafka/portal_worker.rs +++ b/implementations/rust/ockam/ockam_api/src/kafka/portal_worker.rs @@ -477,7 +477,7 @@ mod test { use crate::kafka::secure_channel_map::controller::KafkaSecureChannelControllerImpl; use crate::kafka::{ConsumerPublishing, ConsumerResolution}; use crate::port_range::PortRange; - use crate::test_utils::NodeManagerHandle; + use crate::test_utils::{NodeManagerHandle, TestNode}; use ockam::MessageReceiveOptions; use ockam_abac::{ Action, Env, Policies, Resource, ResourcePolicySqlxDatabase, 
ResourceType, @@ -517,6 +517,7 @@ mod test { async fn kafka_portal_worker__ping_pong_pass_through__should_pass( context: &mut Context, ) -> ockam::Result<()> { + TestNode::clean().await?; let handle = crate::test_utils::start_manager_for_tests(context, None, None).await?; let portal_inlet_address = setup_only_worker(context, &handle).await; @@ -554,6 +555,7 @@ mod test { async fn kafka_portal_worker__pieces_of_kafka_message__message_assembled( context: &mut Context, ) -> ockam::Result<()> { + TestNode::clean().await?; let handle = crate::test_utils::start_manager_for_tests(context, None, None).await?; let portal_inlet_address = setup_only_worker(context, &handle).await; @@ -596,6 +598,7 @@ mod test { async fn kafka_portal_worker__double_kafka_message__message_assembled( context: &mut Context, ) -> ockam::Result<()> { + TestNode::clean().await?; let handle = crate::test_utils::start_manager_for_tests(context, None, None).await?; let portal_inlet_address = setup_only_worker(context, &handle).await; @@ -633,6 +636,7 @@ mod test { async fn kafka_portal_worker__bigger_than_limit_kafka_message__error( context: &mut Context, ) -> ockam::Result<()> { + TestNode::clean().await?; let handle = crate::test_utils::start_manager_for_tests(context, None, None).await?; let portal_inlet_address = setup_only_worker(context, &handle).await; @@ -686,6 +690,7 @@ mod test { async fn kafka_portal_worker__almost_over_limit_than_limit_kafka_message__two_kafka_message_pass( context: &mut Context, ) -> ockam::Result<()> { + TestNode::clean().await?; let handle = crate::test_utils::start_manager_for_tests(context, None, None).await?; let portal_inlet_address = setup_only_worker(context, &handle).await; @@ -859,6 +864,7 @@ mod test { async fn kafka_portal_worker__metadata_exchange__response_changed( context: &mut Context, ) -> ockam::Result<()> { + TestNode::clean().await?; let handle = crate::test_utils::start_manager_for_tests(context, None, None).await?; let project_authority = handle 
.node_manager diff --git a/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/tests.rs b/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/tests.rs index f9799553518..62d7afe7fab 100644 --- a/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/tests.rs +++ b/implementations/rust/ockam/ockam_api/src/kafka/protocol_aware/tests.rs @@ -7,6 +7,7 @@ mod test { use crate::kafka::secure_channel_map::controller::KafkaSecureChannelControllerImpl; use crate::kafka::{ConsumerPublishing, ConsumerResolution}; use crate::port_range::PortRange; + use crate::test_utils::TestNode; use kafka_protocol::messages::ApiKey; use kafka_protocol::messages::BrokerId; use kafka_protocol::messages::{ApiVersionsRequest, MetadataRequest, MetadataResponse}; @@ -22,6 +23,7 @@ mod test { async fn interceptor__basic_messages_with_several_api_versions__parsed_correctly( context: &mut Context, ) -> ockam::Result<()> { + TestNode::clean().await?; let handle = crate::test_utils::start_manager_for_tests(context, None, None).await?; let inlet_map = KafkaInletController::new( diff --git a/implementations/rust/ockam/ockam_api/src/nodes/service/in_memory_node.rs b/implementations/rust/ockam/ockam_api/src/nodes/service/in_memory_node.rs index 2112e2cc6b9..80bc6f68f37 100644 --- a/implementations/rust/ockam/ockam_api/src/nodes/service/in_memory_node.rs +++ b/implementations/rust/ockam/ockam_api/src/nodes/service/in_memory_node.rs @@ -58,10 +58,13 @@ impl Drop for InMemoryNode { // because in that case they can be restarted if !self.persistent { executor::block_on(async { - self.node_manager - .delete_node() + // We need to recreate the CliState here to make sure that + // we get a fresh connection to the database (otherwise this code blocks) + let cli_state = CliState::create(self.cli_state.dir()).await.unwrap(); + cli_state + .remove_node(&self.node_name) .await - .unwrap_or_else(|e| panic!("cannot delete the node {}: {e:?}", self.node_name)) + .unwrap_or_else(|e| 
panic!("cannot delete the node {}: {e:?}", self.node_name)); }); } } diff --git a/implementations/rust/ockam/ockam_api/src/test_utils/mod.rs b/implementations/rust/ockam/ockam_api/src/test_utils/mod.rs index 1507d7ab1c6..6b0c32fbaf0 100644 --- a/implementations/rust/ockam/ockam_api/src/test_utils/mod.rs +++ b/implementations/rust/ockam/ockam_api/src/test_utils/mod.rs @@ -21,6 +21,7 @@ use ockam::tcp::{TcpListenerOptions, TcpTransport}; use ockam::transport::HostnamePort; use ockam::Result; use ockam_core::AsyncTryClone; +use ockam_node::database::{DatabaseConfiguration, SqlxDatabase}; use crate::authenticator::credential_issuer::{DEFAULT_CREDENTIAL_VALIDITY, PROJECT_MEMBER_SCHEMA}; use crate::cli_state::{random_name, CliState}; @@ -42,7 +43,9 @@ pub struct NodeManagerHandle { impl Drop for NodeManagerHandle { fn drop(&mut self) { - self.cli_state.delete().expect("cannot delete cli state"); + self.cli_state + .delete_local_data() + .expect("cannot delete cli state"); } } @@ -63,7 +66,7 @@ pub async fn start_manager_for_tests( ) .await?; - let cli_state = CliState::test().await?; + let cli_state = CliState::system().await?; let node_name = random_name(); cli_state @@ -202,6 +205,18 @@ pub struct TestNode { } impl TestNode { + /// If the database being used for the tests is Postgres then it is shared across all the tests and + /// needs to be cleaned up before a test is executed + pub async fn clean() -> Result<()> { + if let Some(configuration) = DatabaseConfiguration::postgres()?
{ + let db = SqlxDatabase::create_no_migration(&configuration) + .await + .unwrap(); + db.drop_all_postgres_tables().await?; + }; + Ok(()) + } + pub async fn create(runtime: Arc, listen_addr: Option<&str>) -> Self { let (mut context, mut executor) = NodeBuilder::new().with_runtime(runtime.clone()).build(); runtime.spawn(async move { diff --git a/implementations/rust/ockam/ockam_api/src/ui/output/utils.rs b/implementations/rust/ockam/ockam_api/src/ui/output/utils.rs index bd390c53324..348d62f1900 100644 --- a/implementations/rust/ockam/ockam_api/src/ui/output/utils.rs +++ b/implementations/rust/ockam/ockam_api/src/ui/output/utils.rs @@ -47,6 +47,15 @@ pub fn colorize_connection_status(status: ConnectionStatus) -> CString { } } +pub fn indent(indent: impl Into, text: impl Into) -> String { + let indent: String = indent.into(); + text.into() + .split('\n') + .map(|line| format!("{indent}{line}")) + .collect::>() + .join("\n") +} + #[cfg(test)] mod tests { use super::*; @@ -57,4 +66,10 @@ mod tests { let result = comma_separated(&data); assert_eq!(result, "a, b, c"); } + + #[test] + fn test_indent() { + let result = indent("---", "line1\nthen line2\n and finally line3"); + assert_eq!(result, "---line1\n---then line2\n--- and finally line3"); + } } diff --git a/implementations/rust/ockam/ockam_api/tests/common/common.rs b/implementations/rust/ockam/ockam_api/tests/common/common.rs index 745d2a125ff..5388d4d53cf 100644 --- a/implementations/rust/ockam/ockam_api/tests/common/common.rs +++ b/implementations/rust/ockam/ockam_api/tests/common/common.rs @@ -1,5 +1,6 @@ use core::time::Duration; +use log::debug; use ockam::identity::models::CredentialSchemaIdentifier; use ockam::identity::utils::AttributesBuilder; use ockam::identity::{ @@ -13,6 +14,7 @@ use ockam_api::config::lookup::InternetAddress; use ockam_api::nodes::NodeManager; use ockam_core::Result; use ockam_multiaddr::MultiAddr; +use ockam_node::database::DatabaseConfiguration; use ockam_node::Context; use 
ockam_transport_tcp::TcpTransport; use rand::{thread_rng, Rng}; @@ -29,7 +31,7 @@ pub async fn default_configuration() -> Result { let mut configuration = authority_node::Configuration { identifier: "I4dba4b2e53b2ed95967b3bab350b6c9ad9c624e5a1b2c3d4e5f6a6b5c4d3e2f1" .try_into()?, - database_path, + database_configuration: DatabaseConfiguration::sqlite(database_path.as_path()), project_identifier: "123456".to_string(), tcp_listener_address: InternetAddress::new(&format!("127.0.0.1:{}", port)).unwrap(), secure_channel_listener_name: None, @@ -129,7 +131,7 @@ pub async fn start_authority( configuration.no_direct_authentication = false; configuration.no_token_enrollment = false; - println!( + debug!( "common.rs about to call authority::start_node with {:?}", configuration.account_authority.is_some() ); diff --git a/implementations/rust/ockam/ockam_api/tests/latency.rs b/implementations/rust/ockam/ockam_api/tests/latency.rs index 451b633a7a1..a5874d833d3 100644 --- a/implementations/rust/ockam/ockam_api/tests/latency.rs +++ b/implementations/rust/ockam/ockam_api/tests/latency.rs @@ -27,6 +27,7 @@ pub fn measure_message_latency_two_nodes() { let result: ockam::Result<()> = runtime_cloned.block_on(async move { let test_body = async move { + TestNode::clean().await?; let mut first_node = TestNode::create(runtime.clone(), None).await; let second_node = TestNode::create(runtime.clone(), None).await; @@ -124,6 +125,7 @@ pub fn measure_buffer_latency_two_nodes_portal() { let test_body = async move { let echo_server_handle = start_tcp_echo_server().await; + TestNode::clean().await?; let first_node = TestNode::create(runtime.clone(), None).await; let second_node = TestNode::create(runtime.clone(), None).await; diff --git a/implementations/rust/ockam/ockam_api/tests/portals.rs b/implementations/rust/ockam/ockam_api/tests/portals.rs index 3ee6832b2a3..5e4b6441d66 100644 --- a/implementations/rust/ockam/ockam_api/tests/portals.rs +++ 
b/implementations/rust/ockam/ockam_api/tests/portals.rs @@ -21,6 +21,7 @@ use tracing::info; #[ockam_macros::test] async fn inlet_outlet_local_successful(context: &mut Context) -> ockam::Result<()> { + TestNode::clean().await?; let echo_server_handle = start_tcp_echo_server().await; let node_manager_handle = start_manager_for_tests(context, None, None).await?; @@ -96,6 +97,7 @@ fn portal_node_goes_down_reconnect() { let test_body = async move { let echo_server_handle = start_tcp_echo_server().await; + TestNode::clean().await?; let first_node = TestNode::create(runtime_cloned.clone(), None).await; let second_node = TestNode::create(runtime_cloned.clone(), None).await; @@ -237,6 +239,7 @@ fn portal_low_bandwidth_connection_keep_working_for_60s() { let test_body = async move { let echo_server_handle = start_tcp_echo_server().await; + TestNode::clean().await?; let first_node = TestNode::create(runtime_cloned.clone(), None).await; let second_node = TestNode::create(runtime_cloned, None).await; @@ -354,6 +357,7 @@ fn portal_heavy_load_exchanged() { let test_body = async move { let echo_server_handle = start_tcp_echo_server().await; + TestNode::clean().await?; let first_node = TestNode::create(runtime_cloned.clone(), None).await; let second_node = TestNode::create(runtime_cloned, None).await; @@ -496,6 +500,7 @@ fn test_portal_payload_transfer(outgoing_disruption: Disruption, incoming_disrup let test_body = async move { let echo_server_handle = start_tcp_echo_server().await; + TestNode::clean().await?; let first_node = TestNode::create(runtime_cloned.clone(), None).await; let second_node = TestNode::create(runtime_cloned, None).await; diff --git a/implementations/rust/ockam/ockam_app_lib/Cargo.toml b/implementations/rust/ockam/ockam_app_lib/Cargo.toml index 728d790fff6..99860d66b6e 100644 --- a/implementations/rust/ockam/ockam_app_lib/Cargo.toml +++ b/implementations/rust/ockam/ockam_app_lib/Cargo.toml @@ -43,7 +43,7 @@ ockam_core = { path = "../ockam_core", version = 
"^0.111.0" } ockam_multiaddr = { path = "../ockam_multiaddr", version = "0.55.0", features = ["cbor", "serde"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -sqlx = { version = "0.7.4", features = ["runtime-tokio", "sqlite", "migrate"] } +sqlx = { git = "https://github.com/etorreborre/sqlx", rev = "5fec648d2de0cbeed738dcf1c6f5bc9194fc439b" } thiserror = "1.0" tokio = { version = "1.38.0", features = ["full"] } tracing = { version = "0.1", default-features = false } diff --git a/implementations/rust/ockam/ockam_app_lib/src/state/model_state_repository_sql.rs b/implementations/rust/ockam/ockam_app_lib/src/state/model_state_repository_sql.rs index 9e7b8092049..af2fdf6970a 100644 --- a/implementations/rust/ockam/ockam_app_lib/src/state/model_state_repository_sql.rs +++ b/implementations/rust/ockam/ockam_app_lib/src/state/model_state_repository_sql.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use sqlx::*; use tracing::debug; -use ockam::{FromSqlxError, SqlxDatabase, ToSqlxType, ToVoid}; +use ockam::{Boolean, FromSqlxError, Nullable, SqlxDatabase, ToVoid}; use ockam_api::nodes::models::portal::OutletStatus; use ockam_core::errcode::{Kind, Origin}; use ockam_core::Error; @@ -44,19 +44,24 @@ impl ModelStateRepository for ModelStateSqlxDatabase { let mut transaction = self.database.begin().await.into_core()?; // remove previous tcp_outlet_status state - query("DELETE FROM tcp_outlet_status where node_name = ?") - .bind(node_name.to_sql()) + query("DELETE FROM tcp_outlet_status where node_name = $1") + .bind(node_name) .execute(&mut *transaction) .await .void()?; // re-insert the new state for tcp_outlet_status in &model_state.tcp_outlets { - let query = query("INSERT OR REPLACE INTO tcp_outlet_status VALUES (?, ?, ?, ?)") - .bind(node_name.to_sql()) - .bind(tcp_outlet_status.socket_addr.to_sql()) - .bind(tcp_outlet_status.worker_addr.to_sql()) - .bind(tcp_outlet_status.payload.as_ref().map(|p| p.to_sql())); + let query = query( + r#" + INSERT INTO 
tcp_outlet_status (node_name, socket_addr, worker_addr, payload) + VALUES ($1, $2, $3, $4) + ON CONFLICT DO NOTHING"#, + ) + .bind(node_name) + .bind(tcp_outlet_status.socket_addr.to_string()) + .bind(tcp_outlet_status.worker_addr.to_string()) + .bind(tcp_outlet_status.payload.as_ref()); query.execute(&mut *transaction).await.void()?; } @@ -68,10 +73,16 @@ impl ModelStateRepository for ModelStateSqlxDatabase { // re-insert the new state for incoming_service in &model_state.incoming_services { - let query = query("INSERT OR REPLACE INTO incoming_service VALUES (?, ?, ?)") - .bind(incoming_service.invitation_id.to_sql()) - .bind(incoming_service.enabled.to_sql()) - .bind(incoming_service.name.as_ref().map(|n| n.to_sql())); + let query = query( + r#" + INSERT INTO incoming_service (invitation_id, enabled, name) + VALUES ($1, $2, $3) + ON CONFLICT (invitation_id) + DO UPDATE SET enabled = $2, name = $3"#, + ) + .bind(&incoming_service.invitation_id) + .bind(incoming_service.enabled) + .bind(incoming_service.name.as_ref()); query.execute(&mut *transaction).await.void()?; } transaction.commit().await.void()?; @@ -81,9 +92,9 @@ impl ModelStateRepository for ModelStateSqlxDatabase { async fn load(&self, node_name: &str) -> Result { let query1 = query_as( - "SELECT socket_addr, worker_addr, payload FROM tcp_outlet_status WHERE node_name = ?", + "SELECT socket_addr, worker_addr, payload FROM tcp_outlet_status WHERE node_name = $1", ) - .bind(node_name.to_sql()); + .bind(node_name); let result: Vec = query1.fetch_all(&*self.database.pool).await.into_core()?; let tcp_outlets = result @@ -109,7 +120,7 @@ impl ModelStateRepository for ModelStateSqlxDatabase { struct TcpOutletStatusRow { socket_addr: String, worker_addr: String, - payload: Option, + payload: Nullable, } impl TcpOutletStatusRow { @@ -120,7 +131,7 @@ impl TcpOutletStatusRow { Ok(OutletStatus { socket_addr, worker_addr, - payload: self.payload.clone(), + payload: self.payload.to_option(), }) } } @@ -129,101 +140,97 
@@ impl TcpOutletStatusRow { #[derive(sqlx::FromRow)] struct PersistentIncomingServiceRow { invitation_id: String, - enabled: bool, - name: Option, + enabled: Boolean, + name: Nullable, } impl PersistentIncomingServiceRow { fn persistent_incoming_service(&self) -> Result { Ok(PersistentIncomingService { invitation_id: self.invitation_id.clone(), - enabled: self.enabled, - name: self.name.clone(), + enabled: self.enabled.to_bool(), + name: self.name.to_option(), }) } } #[cfg(test)] mod tests { + use super::*; + use ockam::with_dbs; use ockam_api::nodes::models::portal::OutletStatus; use ockam_core::Address; - use super::*; - #[tokio::test] - async fn store_and_load() -> Result<()> { - let db = create_database().await?; - let repository = create_repository(db.clone()); - let node_name = "node"; - - let mut state = ModelState::default(); - repository.store(node_name, &state).await?; - let loaded = repository.load(node_name).await?; - assert!(state.tcp_outlets.is_empty()); - assert_eq!(state, loaded); - - // Add a tcp outlet - state.add_tcp_outlet(OutletStatus::new( - "127.0.0.1:1001".parse()?, - Address::from_string("s1"), - None, - )); - // Add an incoming service - state.add_incoming_service(PersistentIncomingService { - invitation_id: "1235".to_string(), - enabled: true, - name: Some("aws".to_string()), - }); - repository.store(node_name, &state).await?; - let loaded = repository.load(node_name).await?; - assert_eq!(state.tcp_outlets.len(), 1); - assert_eq!(state.incoming_services.len(), 1); - assert_eq!(state, loaded); - - // Add a few more - for i in 2..=5 { + async fn store_and_load() -> ockam_core::Result<()> { + with_dbs(|db| async move { + let repository: Arc = + Arc::new(ModelStateSqlxDatabase::new(db.clone())); + + let node_name = "node"; + + let mut state = ModelState::default(); + repository.store(node_name, &state).await.unwrap(); + let loaded = repository.load(node_name).await.unwrap(); + assert!(state.tcp_outlets.is_empty()); + assert_eq!(state, 
loaded); + + // Add a tcp outlet state.add_tcp_outlet(OutletStatus::new( - format!("127.0.0.1:100{i}").parse().unwrap(), - Address::from_string(format!("s{i}")), + "127.0.0.1:1001".parse().unwrap(), + Address::from_string("s1"), None, )); + // Add an incoming service + state.add_incoming_service(PersistentIncomingService { + invitation_id: "1235".to_string(), + enabled: true, + name: Some("aws".to_string()), + }); repository.store(node_name, &state).await.unwrap(); - } - let loaded = repository.load(node_name).await?; - assert_eq!(state.tcp_outlets.len(), 5); - assert_eq!(state, loaded); - - // Reload from DB scratch to emulate an app restart - let repository = create_repository(db); - let loaded = repository.load(node_name).await?; - assert_eq!(state.tcp_outlets.len(), 5); - assert_eq!(state.incoming_services.len(), 1); - assert_eq!(state, loaded); - - // Remove some values from the current state - let _ = state.tcp_outlets.split_off(2); - state.add_incoming_service(PersistentIncomingService { - invitation_id: "4567".to_string(), - enabled: true, - name: Some("aws".to_string()), - }); - - repository.store(node_name, &state).await?; - let loaded = repository.load(node_name).await?; - - assert_eq!(state.tcp_outlets.len(), 2); - assert_eq!(state.incoming_services.len(), 2); - assert_eq!(state, loaded); + let loaded = repository.load(node_name).await.unwrap(); + assert_eq!(state.tcp_outlets.len(), 1); + assert_eq!(state.incoming_services.len(), 1); + assert_eq!(state, loaded); + + // Add a few more + for i in 2..=5 { + state.add_tcp_outlet(OutletStatus::new( + format!("127.0.0.1:100{i}").parse().unwrap(), + Address::from_string(format!("s{i}")), + None, + )); + repository.store(node_name, &state).await.unwrap(); + } + let loaded = repository.load(node_name).await.unwrap(); + assert_eq!(state.tcp_outlets.len(), 5); + assert_eq!(state, loaded); + + // Reload from DB scratch to emulate an app restart + let repository: Arc = + Arc::new(ModelStateSqlxDatabase::new(db)); + 
let loaded = repository.load(node_name).await.unwrap(); + assert_eq!(state.tcp_outlets.len(), 5); + assert_eq!(state.incoming_services.len(), 1); + assert_eq!(state, loaded); + + // Remove some values from the current state + let _ = state.tcp_outlets.split_off(2); + state.add_incoming_service(PersistentIncomingService { + invitation_id: "4567".to_string(), + enabled: true, + name: Some("aws".to_string()), + }); - Ok(()) - } + repository.store(node_name, &state).await.unwrap(); + let loaded = repository.load(node_name).await.unwrap(); - /// HELPERS - fn create_repository(db: SqlxDatabase) -> Arc { - Arc::new(ModelStateSqlxDatabase::new(db)) - } + assert_eq!(state.tcp_outlets.len(), 2); + assert_eq!(state.incoming_services.len(), 2); + assert_eq!(state, loaded); - async fn create_database() -> Result { - Ok(SqlxDatabase::in_memory("enrollments-test").await?) + Ok(()) + }) + .await } } diff --git a/implementations/rust/ockam/ockam_command/src/authority/create.rs b/implementations/rust/ockam/ockam_command/src/authority/create.rs index 7e02766ee62..a5c9d8e2be4 100644 --- a/implementations/rust/ockam/ockam_command/src/authority/create.rs +++ b/implementations/rust/ockam/ockam_command/src/authority/create.rs @@ -392,7 +392,7 @@ impl CreateCommand { let configuration = authority_node::Configuration { identifier: node.identifier(), - database_path: opts.state.database_path(), + database_configuration: opts.state.database_configuration()?, project_identifier: self.project_identifier.clone(), tcp_listener_address: self.tcp_listener_address.clone(), secure_channel_listener_name: None, diff --git a/implementations/rust/ockam/ockam_command/src/environment/static/env_info.txt b/implementations/rust/ockam/ockam_command/src/environment/static/env_info.txt index 5fa2551e97c..2a95a1cf4bd 100644 --- a/implementations/rust/ockam/ockam_command/src/environment/static/env_info.txt +++ b/implementations/rust/ockam/ockam_command/src/environment/static/env_info.txt @@ -21,6 +21,20 @@ 
Logging - OCKAM_LOG_MAX_FILES: an `integer` that defines the maximum number of log files to keep per node. Default value `60`. - OCKAM_LOG_CRATES_FILTER: a filter for log messages based on crate names: `all`, `default`, comma-separated list of crate names. Default value: `default`, i.e. the list of `ockam` crates. +Database +- OCKAM_POSTGRES_HOST: Postgres database host. Example: 'localhost'. +- OCKAM_POSTGRES_PORT: Postgres database port. Example: 5432. +- OCKAM_POSTGRES_DATABASE_NAME: Postgres database name. Default value: 'postgres'. +- OCKAM_POSTGRES_USER: Postgres database user. If it is not set, no authorization will be used to access the database. +- OCKAM_POSTGRES_PASSWORD: Postgres database password. If it is not set, no authorization will be used to access the database. + +- OCKAM_LOGGING: set this variable to any value in order to enable logging. +- OCKAM_LOG_LEVEL: a `string` that defines the verbosity of the logs when the `--verbose` argument is not passed: `info`, `warn`, `error`, `debug` or `trace`. Default value: `debug`. +- OCKAM_LOG_FORMAT: a `string` that overrides the default format of the logs: `default`, `json`, or `pretty`. Default value: `default`. +- OCKAM_LOG_MAX_SIZE_MB: an `integer` that defines the maximum size of a log file in MB. Default value `100`. +- OCKAM_LOG_MAX_FILES: an `integer` that defines the maximum number of log files to keep per node. Default value `60`. +- OCKAM_LOG_CRATES_FILTER: a filter for log messages based on crate names: `all`, `default`, comma-separated list of crate names. Default value: `default`, i.e. the list of `ockam` crates. + Tracing - OCKAM_OPENTELEMETRY_EXPORT: set this variable to a false value to disable tracing: `0`, `false`, `no`. Default value: `true` - OCKAM_OPENTELEMETRY_ENDPOINT: the URL of an OpenTelemetry collector accepting gRPC. 
diff --git a/implementations/rust/ockam/ockam_command/src/vault/create.rs b/implementations/rust/ockam/ockam_command/src/vault/create.rs index 927b225fa8e..74c75432a4c 100644 --- a/implementations/rust/ockam/ockam_command/src/vault/create.rs +++ b/implementations/rust/ockam/ockam_command/src/vault/create.rs @@ -3,6 +3,7 @@ use std::path::PathBuf; use async_trait::async_trait; use clap::Args; use colorful::Colorful; +use ockam_api::cli_state::UseAwsKms; use ockam_api::{fmt_info, fmt_ok}; use ockam_node::Context; @@ -39,19 +40,17 @@ impl Command for CreateCommand { "This is the first vault to be created in this environment. It will be set as the default vault" ))?; } - let vault = if self.aws_kms { - opts.state.create_kms_vault(&self.name, &self.path).await? - } else { - opts.state - .create_named_vault(&self.name, &self.path) - .await? - }; + + let vault = opts + .state + .create_named_vault(self.name, self.path, UseAwsKms::from(self.aws_kms)) + .await?; opts.terminal .stdout() .plain(fmt_ok!("Vault created with name '{}'!", vault.name())) .machine(vault.name()) - .json(serde_json::json!({ "name": &self.name })) + .json(serde_json::json!({ "name": &vault.name() })) .write_line()?; Ok(()) } diff --git a/implementations/rust/ockam/ockam_command/src/vault/util.rs b/implementations/rust/ockam/ockam_command/src/vault/util.rs index fb600d5bb18..7a23f0d53a1 100644 --- a/implementations/rust/ockam/ockam_command/src/vault/util.rs +++ b/implementations/rust/ockam/ockam_command/src/vault/util.rs @@ -2,8 +2,9 @@ use colorful::Colorful; use indoc::formatdoc; use ockam_api::cli_state::vaults::NamedVault; +use ockam_api::cli_state::{UseAwsKms, VaultType}; use ockam_api::colors::OckamColor; -use ockam_api::output::Output; +use ockam_api::output::{indent, Output}; #[derive(serde::Serialize)] pub struct VaultOutput { @@ -27,48 +28,82 @@ impl Output for VaultOutput { Ok(formatdoc!( r#" Vault: - Name: {name} - Type: {vault_type} - Path: {vault_path} + {vault} "#, - name = self - .vault 
- .name() - .to_string() - .color(OckamColor::PrimaryResource.color()), - vault_type = match self.vault.is_kms() { - true => "AWS KMS", - false => "OCKAM", - } - .to_string() - .color(OckamColor::PrimaryResource.color()), - vault_path = self - .vault - .path_as_string() - .color(OckamColor::PrimaryResource.color()), + vault = indent(" ", self.as_list_item()?) )) } fn as_list_item(&self) -> ockam_api::Result { - Ok(formatdoc!( - r#"Name: {name} + let name = self + .vault + .name() + .to_string() + .color(OckamColor::PrimaryResource.color()); + + let vault_type = if self.vault.path().is_some() { + "External" + } else { + "Internal" + } + .to_string() + .color(OckamColor::PrimaryResource.color()); + + let uses_aws_kms = if self.vault.use_aws_kms() { + "true" + } else { + "false" + } + .to_string() + .color(OckamColor::PrimaryResource.color()); + + Ok(match self.vault.vault_type() { + VaultType::DatabaseVault { + use_aws_kms: UseAwsKms::No, + } => formatdoc!( + r#"Name: {name} + Type: {vault_type}"#, + name = name, + vault_type = vault_type + ), + VaultType::DatabaseVault { + use_aws_kms: UseAwsKms::Yes, + } => formatdoc!( + r#"Name: {name} + Type: {vault_type} + Uses AWS KMS: {uses_aws_kms}"#, + name = name, + uses_aws_kms = uses_aws_kms + ), + VaultType::LocalFileVault { + path, + use_aws_kms: UseAwsKms::No, + } => formatdoc!( + r#"Name: {name} Type: {vault_type} Path: {vault_path}"#, - name = self - .vault - .name() - .to_string() - .color(OckamColor::PrimaryResource.color()), - vault_type = match self.vault.is_kms() { - true => "AWS KMS", - false => "OCKAM", - } - .to_string() - .color(OckamColor::PrimaryResource.color()), - vault_path = self - .vault - .path_as_string() - .color(OckamColor::PrimaryResource.color()), - )) + name = name, + vault_type = vault_type, + vault_path = path + .to_string_lossy() + .to_string() + .color(OckamColor::PrimaryResource.color()) + ), + VaultType::LocalFileVault { + path, + use_aws_kms: UseAwsKms::Yes, + } => formatdoc!( + r#"Name: 
{name} + Type: External + Path: {vault_path} + Uses AWS KMS: {uses_aws_kms}"#, + name = name, + vault_path = path + .to_string_lossy() + .to_string() + .color(OckamColor::PrimaryResource.color()), + uses_aws_kms = uses_aws_kms, + ), + }) } } diff --git a/implementations/rust/ockam/ockam_command/tests/bats/local/jq.bats b/implementations/rust/ockam/ockam_command/tests/bats/local/jq.bats index 5fec1942e31..baa75dce795 100644 --- a/implementations/rust/ockam/ockam_command/tests/bats/local/jq.bats +++ b/implementations/rust/ockam/ockam_command/tests/bats/local/jq.bats @@ -19,7 +19,7 @@ teardown() { run_success "$OCKAM" vault show v1 --output json --jq . assert_output --partial "\"name\":\"v1\"" - assert_output --partial "\"is_kms\":false" + assert_output --partial "\"use_aws_kms\":\"No\"" run_success "$OCKAM" vault show v1 --output json --jq .vault.name assert_output --partial "v1" diff --git a/implementations/rust/ockam/ockam_command/tests/bats/local/vault.bats b/implementations/rust/ockam/ockam_command/tests/bats/local/vault.bats index 8dd0b53a79c..2b0de25b42f 100644 --- a/implementations/rust/ockam/ockam_command/tests/bats/local/vault.bats +++ b/implementations/rust/ockam/ockam_command/tests/bats/local/vault.bats @@ -19,22 +19,22 @@ teardown() { run_success "$OCKAM" vault show v1 --output json assert_output --partial "\"name\":\"v1\"" - assert_output --partial "\"is_kms\":false" + assert_output --partial "\"use_aws_kms\":\"No\"" run_success "$OCKAM" vault list --output json assert_output --partial "\"name\":\"v1\"" - assert_output --partial "\"is_kms\":false" + assert_output --partial "\"use_aws_kms\":\"No\"" run_success "$OCKAM" vault create v2 run_success "$OCKAM" vault show v2 --output json assert_output --partial "\"name\":\"v2\"" - assert_output --partial "\"is_kms\":false" + assert_output --partial "\"use_aws_kms\":\"No\"" run_success "$OCKAM" vault list --output json assert_output --partial "\"name\":\"v1\"" assert_output --partial "\"name\":\"v2\"" - 
assert_output --partial "\"is_kms\":false" + assert_output --partial "\"use_aws_kms\":\"No\"" } @test "vault - CRUD" { diff --git a/implementations/rust/ockam/ockam_identity/Cargo.toml b/implementations/rust/ockam/ockam_identity/Cargo.toml index 8b47d39f7dc..fc5d34e1d04 100644 --- a/implementations/rust/ockam/ockam_identity/Cargo.toml +++ b/implementations/rust/ockam/ockam_identity/Cargo.toml @@ -88,7 +88,7 @@ serde = { version = "1.0", default-features = false, features = ["derive"] } serde_bare = { version = "0.5.0", default-features = false, features = ["alloc"] } serde_json = { version = "1.0", optional = true } sha2 = { version = "0.10", default-features = false } -sqlx = { version = "0.7.4", optional = true } +sqlx = { git = "https://github.com/etorreborre/sqlx", rev = "5fec648d2de0cbeed738dcf1c6f5bc9194fc439b", optional = true } tokio-retry = { version = "0.3.0", default-features = false, optional = true } tracing = { version = "0.1", default_features = false } tracing-attributes = { version = "0.1", default_features = false } diff --git a/implementations/rust/ockam/ockam_identity/src/identities/storage/change_history_repository_sql.rs b/implementations/rust/ockam/ockam_identity/src/identities/storage/change_history_repository_sql.rs index b2d8ce8143b..685a04c8a28 100644 --- a/implementations/rust/ockam/ockam_identity/src/identities/storage/change_history_repository_sql.rs +++ b/implementations/rust/ockam/ockam_identity/src/identities/storage/change_history_repository_sql.rs @@ -1,13 +1,15 @@ use core::str::FromStr; +use sqlx::any::AnyArguments; +use sqlx::database::HasArguments; +use sqlx::encode::IsNull; use sqlx::query::Query; -use sqlx::sqlite::SqliteArguments; use sqlx::*; use tracing::debug; use ockam_core::async_trait; use ockam_core::Result; -use ockam_node::database::{FromSqlxError, SqlxDatabase, SqlxType, ToSqlxType, ToVoid}; +use ockam_node::database::{FromSqlxError, SqlxDatabase, ToVoid}; use crate::models::{ChangeHistory, Identifier}; use 
crate::{ChangeHistoryRepository, Identity, IdentityError, IdentityHistoryComparison, Vault}; @@ -37,8 +39,8 @@ impl ChangeHistoryRepository for ChangeHistorySqlxDatabase { async fn update_identity(&self, identity: &Identity, ignore_older: bool) -> Result<()> { let mut transaction = self.database.begin().await.into_core()?; let query1 = - query_as("SELECT identifier, change_history FROM identity WHERE identifier=$1") - .bind(identity.identifier().to_sql()); + query_as("SELECT identifier, change_history FROM identity WHERE identifier = $1") + .bind(identity.identifier()); let row: Option = query1.fetch_optional(&mut *transaction).await.into_core()?; @@ -91,19 +93,20 @@ impl ChangeHistoryRepository for ChangeHistorySqlxDatabase { async fn delete_change_history(&self, identifier: &Identifier) -> Result<()> { let mut transaction = self.database.begin().await.into_core()?; - let query1 = query("DELETE FROM identity where identifier=?").bind(identifier.to_sql()); + let query1 = query("DELETE FROM identity where identifier = $1").bind(identifier); query1.execute(&mut *transaction).await.void()?; let query2 = - query("DELETE FROM identity_attributes where identifier=?").bind(identifier.to_sql()); + query("DELETE FROM identity_attributes where identifier = $1").bind(identifier); query2.execute(&mut *transaction).await.void()?; transaction.commit().await.void()?; Ok(()) } async fn get_change_history(&self, identifier: &Identifier) -> Result> { - let query = query_as("SELECT identifier, change_history FROM identity WHERE identifier=$1") - .bind(identifier.to_sql()); + let query = + query_as("SELECT identifier, change_history FROM identity WHERE identifier = $1") + .bind(identifier); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -120,26 +123,44 @@ impl ChangeHistoryRepository for ChangeHistorySqlxDatabase { impl ChangeHistorySqlxDatabase { fn insert_query<'a>( - identifier: &Identifier, - change_history: &ChangeHistory, - ) -> Query<'a, Sqlite, 
SqliteArguments<'a>> { - query("INSERT OR REPLACE INTO identity VALUES (?, ?)") - .bind(identifier.to_sql()) - .bind(change_history.to_sql()) + identifier: &'a Identifier, + change_history: &'a ChangeHistory, + ) -> Query<'a, Any, AnyArguments<'a>> { + query( + r#" + INSERT INTO identity (identifier, change_history) + VALUES ($1, $2) + ON CONFLICT (identifier) + DO UPDATE SET change_history = $2"#, + ) + .bind(identifier) + .bind(change_history) } } // Database serialization / deserialization -impl ToSqlxType for Identifier { - fn to_sql(&self) -> SqlxType { - self.to_string().to_sql() +impl Type for Identifier { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +impl Encode<'_, Any> for Identifier { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + >::encode_by_ref(&self.to_string(), buf) + } +} + +impl Type for ChangeHistory { + fn type_info() -> ::TypeInfo { + >::type_info() } } -impl ToSqlxType for ChangeHistory { - fn to_sql(&self) -> SqlxType { - self.export_as_string().unwrap().to_sql() +impl Encode<'_, Any> for ChangeHistory { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + >::encode_by_ref(&self.export_as_string().unwrap(), buf) } } @@ -167,6 +188,7 @@ mod tests { use crate::{identities, Identity}; use ockam_core::compat::sync::Arc; + use ockam_node::database::with_dbs; fn orchestrator_identity() -> (Identifier, ChangeHistory) { let identifier = Identifier::from_str( @@ -180,69 +202,77 @@ mod tests { #[tokio::test] async fn test_identities_repository_has_orchestrator_history() -> Result<()> { - // Clean repository should already have the orchestartor change history - let repository = create_repository().await?; + with_dbs(|db| async move { + // Clean repository should already have the Orchestrator change history + let repository: Arc = + Arc::new(ChangeHistorySqlxDatabase::new(db)); - let (orchestrator_identifier, orchestrator_change_history) = orchestrator_identity(); + let (orchestrator_identifier, 
orchestrator_change_history) = orchestrator_identity(); - // the change history can be retrieved - let result = repository - .get_change_history(&orchestrator_identifier) - .await?; - assert_eq!(result.as_ref(), Some(&orchestrator_change_history)); + // the change history can be retrieved + let result = repository + .get_change_history(&orchestrator_identifier) + .await?; + assert_eq!(result.as_ref(), Some(&orchestrator_change_history)); - let result = repository.get_change_histories().await?; - assert_eq!(result, vec![orchestrator_change_history]); + let result = repository.get_change_histories().await?; + assert_eq!(result, vec![orchestrator_change_history]); - Ok(()) + Ok(()) + }) + .await } #[tokio::test] async fn test_identities_repository() -> Result<()> { - let identity1 = create_identity().await?; - let identity2 = create_identity().await?; - let repository = create_repository().await?; - - // store and retrieve an identity - repository - .store_change_history(identity1.identifier(), identity1.change_history().clone()) - .await?; + with_dbs(|db| async move { + let repository: Arc = + Arc::new(ChangeHistorySqlxDatabase::new(db)); + let identity1 = create_identity().await?; + let identity2 = create_identity().await?; + + // store and retrieve an identity + repository + .store_change_history(identity1.identifier(), identity1.change_history().clone()) + .await?; - // the change history can be retrieved - let result = repository - .get_change_history(identity1.identifier()) - .await?; - assert_eq!(result, Some(identity1.change_history().clone())); + // the change history can be retrieved + let result = repository + .get_change_history(identity1.identifier()) + .await?; + assert_eq!(result, Some(identity1.change_history().clone())); - // trying to retrieve a missing identity returns None - let result = repository - .get_change_history(identity2.identifier()) - .await?; - assert_eq!(result, None); + // trying to retrieve a missing identity returns None + let 
result = repository + .get_change_history(identity2.identifier()) + .await?; + assert_eq!(result, None); - // the repository can return the list of all change histories - let (_orchestrator_identifier, orchestrator_change_history) = orchestrator_identity(); - repository - .store_change_history(identity2.identifier(), identity2.change_history().clone()) - .await?; - let result = repository.get_change_histories().await?; - assert_eq!( - result, - vec![ - orchestrator_change_history, - identity1.change_history().clone(), - identity2.change_history().clone(), - ] - ); - // a change history can also be deleted from the repository - repository - .delete_change_history(identity2.identifier()) - .await?; - let result = repository - .get_change_history(identity2.identifier()) - .await?; - assert_eq!(result, None); - Ok(()) + // the repository can return the list of all change histories + let (_orchestrator_identifier, orchestrator_change_history) = orchestrator_identity(); + repository + .store_change_history(identity2.identifier(), identity2.change_history().clone()) + .await?; + let result = repository.get_change_histories().await?; + assert_eq!( + result, + vec![ + orchestrator_change_history, + identity1.change_history().clone(), + identity2.change_history().clone(), + ] + ); + // a change history can also be deleted from the repository + repository + .delete_change_history(identity2.identifier()) + .await?; + let result = repository + .get_change_history(identity2.identifier()) + .await?; + assert_eq!(result, None); + Ok(()) + }) + .await } #[tokio::test] @@ -288,10 +318,6 @@ mod tests { } /// HELPERS - async fn create_repository() -> Result> { - Ok(Arc::new(ChangeHistorySqlxDatabase::create().await?)) - } - async fn create_identity() -> Result { let identities = identities().await?; let identifier = identities.identities_creation().create_identity().await?; diff --git a/implementations/rust/ockam/ockam_identity/src/identities/storage/credential_repository_sql.rs 
b/implementations/rust/ockam/ockam_identity/src/identities/storage/credential_repository_sql.rs index d49f2a55ec2..739a22aacc4 100644 --- a/implementations/rust/ockam/ockam_identity/src/identities/storage/credential_repository_sql.rs +++ b/implementations/rust/ockam/ockam_identity/src/identities/storage/credential_repository_sql.rs @@ -1,9 +1,11 @@ +use sqlx::database::HasArguments; +use sqlx::encode::IsNull; use sqlx::*; use tracing::debug; use ockam_core::async_trait; use ockam_core::Result; -use ockam_node::database::{FromSqlxError, SqlxDatabase, ToSqlxType, ToVoid}; +use ockam_node::database::{FromSqlxError, SqlxDatabase, ToVoid}; use crate::models::{CredentialAndPurposeKey, Identifier}; use crate::{CredentialRepository, TimestampInSeconds}; @@ -38,8 +40,8 @@ impl CredentialSqlxDatabase { impl CredentialSqlxDatabase { /// Return all cached credentials for the given node pub async fn get_all(&self) -> Result> { - let query = query_as("SELECT credential, scope FROM credential WHERE node_name=?") - .bind(self.node_name.to_sql()); + let query = query_as("SELECT credential, scope FROM credential WHERE node_name = $1") + .bind(self.node_name.clone()); let cached_credential: Vec = query.fetch_all(&*self.database.pool).await.into_core()?; @@ -65,12 +67,12 @@ impl CredentialRepository for CredentialSqlxDatabase { scope: &str, ) -> Result> { let query = query_as( - "SELECT credential FROM credential WHERE subject_identifier=$1 AND issuer_identifier=$2 AND scope=$3 AND node_name=$4" + "SELECT credential FROM credential WHERE subject_identifier = $1 AND issuer_identifier = $2 AND scope = $3 AND node_name = $4" ) - .bind(subject.to_sql()) - .bind(issuer.to_sql()) - .bind(scope.to_sql()) - .bind(self.node_name.to_sql()); + .bind(subject) + .bind(issuer) + .bind(scope) + .bind(self.node_name.clone()); let cached_credential: Option = query .fetch_optional(&*self.database.pool) .await @@ -87,27 +89,55 @@ impl CredentialRepository for CredentialSqlxDatabase { credential: 
CredentialAndPurposeKey, ) -> Result<()> { let query = query( - "INSERT OR REPLACE INTO credential (subject_identifier, issuer_identifier, scope, credential, expires_at, node_name) VALUES (?, ?, ?, ?, ?, ?)" - ) - .bind(subject.to_sql()) - .bind(issuer.to_sql()) - .bind(scope.to_sql()) - .bind(credential.encode_as_cbor_bytes()?.to_sql()) - .bind(expires_at.to_sql()) - .bind(self.node_name.to_sql()); + r#"INSERT INTO credential (subject_identifier, issuer_identifier, scope, credential, expires_at, node_name) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (subject_identifier, issuer_identifier, scope) + DO UPDATE SET credential = $4, expires_at = $5, node_name = $6"#) + .bind(subject) + .bind(issuer) + .bind(scope) + .bind(credential) + .bind(expires_at) + .bind(self.node_name.clone()); query.execute(&*self.database.pool).await.void() } async fn delete(&self, subject: &Identifier, issuer: &Identifier, scope: &str) -> Result<()> { - let query = query("DELETE FROM credential WHERE subject_identifier=$1 AND issuer_identifier=$2 AND scope=$3 AND node_name=$4") - .bind(subject.to_sql()) - .bind(issuer.to_sql()) - .bind(scope.to_sql()) - .bind(self.node_name.to_sql()); + let query = query("DELETE FROM credential WHERE subject_identifier = $1 AND issuer_identifier = $2 AND scope = $3 AND node_name = $4") + .bind(subject) + .bind(issuer) + .bind(scope) + .bind(self.node_name.clone()); query.execute(&*self.database.pool).await.void() } } +// Database serialization / deserialization + +impl Type for CredentialAndPurposeKey { + fn type_info() -> ::TypeInfo { + as Type>::type_info() + } +} + +impl Encode<'_, Any> for CredentialAndPurposeKey { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + as Encode<'_, Any>>::encode_by_ref(&self.encode_as_cbor_bytes().unwrap(), buf) + } +} + +impl Type for TimestampInSeconds { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +impl Encode<'_, Any> for TimestampInSeconds { + fn encode_by_ref(&self, buf: &mut 
::ArgumentBuffer) -> IsNull { + >::encode_by_ref(&(self.0 as i64), buf) + } +} + // Low-level representation of a table row #[derive(FromRow)] struct CachedCredentialRow { @@ -138,6 +168,7 @@ impl CachedCredentialAndScopeRow { #[cfg(test)] mod tests { use ockam_core::compat::sync::Arc; + use ockam_node::database::with_dbs; use std::time::Duration; use super::*; @@ -147,68 +178,73 @@ mod tests { #[tokio::test] async fn test_cached_credential_repository() -> Result<()> { - let scope = "test".to_string(); - let repository = Arc::new(CredentialSqlxDatabase::create().await?); - - let all = repository.get_all().await?; - assert_eq!(all.len(), 0); - - let identities = identities().await?; - - let issuer = identities.identities_creation().create_identity().await?; - let subject = identities.identities_creation().create_identity().await?; - - let attributes1 = AttributesBuilder::with_schema(CredentialSchemaIdentifier(1)) - .with_attribute("key1", "value1") - .build(); - let credential1 = identities - .credentials() - .credentials_creation() - .issue_credential(&issuer, &subject, attributes1, Duration::from_secs(60 * 60)) - .await?; - - repository - .put( - &subject, - &issuer, - &scope, - credential1.get_credential_data()?.expires_at, - credential1.clone(), - ) - .await?; - - let all = repository.get_all().await?; - assert_eq!(all.len(), 1); - - let credential2 = repository.get(&subject, &issuer, &scope).await?; - assert_eq!(credential2, Some(credential1)); - - let attributes2 = AttributesBuilder::with_schema(CredentialSchemaIdentifier(1)) - .with_attribute("key2", "value2") - .build(); - let credential3 = identities - .credentials() - .credentials_creation() - .issue_credential(&issuer, &subject, attributes2, Duration::from_secs(60 * 60)) - .await?; - repository - .put( - &subject, - &issuer, - &scope, - credential3.get_credential_data()?.expires_at, - credential3.clone(), - ) - .await?; - let all = repository.get_all().await?; - assert_eq!(all.len(), 1); - let credential4 
= repository.get(&subject, &issuer, &scope).await?; - assert_eq!(credential4, Some(credential3)); - - repository.delete(&subject, &issuer, &scope).await?; - let result = repository.get(&subject, &issuer, &scope).await?; - assert_eq!(result, None); - - Ok(()) + with_dbs(|db| async move { + let credentials_database = CredentialSqlxDatabase::new(db, "node"); + let repository: Arc = Arc::new(credentials_database.clone()); + + let scope = "test".to_string(); + + let all = credentials_database.get_all().await?; + assert_eq!(all.len(), 0); + + let identities = identities().await?; + + let issuer = identities.identities_creation().create_identity().await?; + let subject = identities.identities_creation().create_identity().await?; + + let attributes1 = AttributesBuilder::with_schema(CredentialSchemaIdentifier(1)) + .with_attribute("key1", "value1") + .build(); + let credential1 = identities + .credentials() + .credentials_creation() + .issue_credential(&issuer, &subject, attributes1, Duration::from_secs(60 * 60)) + .await?; + + repository + .put( + &subject, + &issuer, + &scope, + credential1.get_credential_data()?.expires_at, + credential1.clone(), + ) + .await?; + + let all = credentials_database.get_all().await?; + assert_eq!(all.len(), 1); + + let credential2 = repository.get(&subject, &issuer, &scope).await?; + assert_eq!(credential2, Some(credential1)); + + let attributes2 = AttributesBuilder::with_schema(CredentialSchemaIdentifier(1)) + .with_attribute("key2", "value2") + .build(); + let credential3 = identities + .credentials() + .credentials_creation() + .issue_credential(&issuer, &subject, attributes2, Duration::from_secs(60 * 60)) + .await?; + repository + .put( + &subject, + &issuer, + &scope, + credential3.get_credential_data()?.expires_at, + credential3.clone(), + ) + .await?; + let all = credentials_database.get_all().await?; + assert_eq!(all.len(), 1); + let credential4 = repository.get(&subject, &issuer, &scope).await?; + assert_eq!(credential4, 
Some(credential3)); + + repository.delete(&subject, &issuer, &scope).await?; + let result = repository.get(&subject, &issuer, &scope).await?; + assert_eq!(result, None); + + Ok(()) + }) + .await } } diff --git a/implementations/rust/ockam/ockam_identity/src/identities/storage/identity_attributes_repository_sql.rs b/implementations/rust/ockam/ockam_identity/src/identities/storage/identity_attributes_repository_sql.rs index 54bbe2b2671..f580a58b24b 100644 --- a/implementations/rust/ockam/ockam_identity/src/identities/storage/identity_attributes_repository_sql.rs +++ b/implementations/rust/ockam/ockam_identity/src/identities/storage/identity_attributes_repository_sql.rs @@ -1,11 +1,13 @@ use core::str::FromStr; +use sqlx::database::HasArguments; +use sqlx::encode::IsNull; use sqlx::*; use tracing::debug; use ockam_core::async_trait; use ockam_core::Result; -use ockam_node::database::{FromSqlxError, SqlxDatabase, SqlxType, ToSqlxType, ToVoid}; +use ockam_node::database::{FromSqlxError, Nullable, SqlxDatabase, ToVoid}; use crate::models::Identifier; use crate::{AttributesEntry, IdentityAttributesRepository, TimestampInSeconds}; @@ -45,11 +47,11 @@ impl IdentityAttributesRepository for IdentityAttributesSqlxDatabase { attested_by: &Identifier, ) -> Result> { let query = query_as( - "SELECT identifier, attributes, added, expires, attested_by FROM identity_attributes WHERE identifier=$1 AND attested_by=$2 AND node_name=$3" + "SELECT identifier, attributes, added, expires, attested_by FROM identity_attributes WHERE identifier = $1 AND attested_by = $2 AND node_name = $3" ) - .bind(identity.to_sql()) - .bind(attested_by.to_sql()) - .bind(self.node_name.to_sql()); + .bind(identity) + .bind(attested_by) + .bind(&self.node_name); let identity_attributes: Option = query .fetch_optional(&*self.database.pool) .await @@ -59,31 +61,40 @@ impl IdentityAttributesRepository for IdentityAttributesSqlxDatabase { async fn put_attributes(&self, subject: &Identifier, entry: AttributesEntry) 
-> Result<()> { let query = query( - "INSERT OR REPLACE INTO identity_attributes (identifier, attributes, added, expires, attested_by, node_name) VALUES (?, ?, ?, ?, ?, ?)" - ) - .bind(subject.to_sql()) - .bind(minicbor::to_vec(entry.attrs())?.to_sql()) - .bind(entry.added_at().to_sql()) - .bind(entry.expires_at().map(|e| e.to_sql())) - .bind(entry.attested_by().map(|e| e.to_sql())) - .bind(self.node_name.to_sql()); + r#" + INSERT INTO identity_attributes (identifier, attributes, added, expires, attested_by, node_name) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (identifier, node_name) + DO UPDATE SET attributes = $2, added = $3, expires = $4, attested_by = $5, node_name = $6"#) + .bind(subject) + .bind(&entry) + .bind(entry.added_at()) + .bind(entry.expires_at()) + .bind(entry.attested_by()) + .bind(&self.node_name); query.execute(&*self.database.pool).await.void() } // This query is regularly invoked by IdentitiesAttributes to make sure that we expire attributes regularly async fn delete_expired_attributes(&self, now: TimestampInSeconds) -> Result<()> { - let query = query("DELETE FROM identity_attributes WHERE expires<=? 
AND node_name=?") - .bind(now.to_sql()) - .bind(self.node_name.to_sql()); + let query = query("DELETE FROM identity_attributes WHERE expires <= $1 AND node_name = $2") + .bind(now) + .bind(&self.node_name); query.execute(&*self.database.pool).await.void() } } // Database serialization / deserialization -impl ToSqlxType for TimestampInSeconds { - fn to_sql(&self) -> SqlxType { - self.0.to_sql() +impl Type for AttributesEntry { + fn type_info() -> ::TypeInfo { + as Type>::type_info() + } +} + +impl Encode<'_, Any> for AttributesEntry { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + as Encode<'_, Any>>::encode_by_ref(&minicbor::to_vec(self.attrs()).unwrap(), buf) } } @@ -93,8 +104,8 @@ struct IdentityAttributesRow { identifier: String, attributes: Vec, added: i64, - expires: Option, - attested_by: Option, + expires: Nullable, + attested_by: Nullable, } impl IdentityAttributesRow { @@ -107,10 +118,13 @@ impl IdentityAttributesRow { let attributes = minicbor::decode(self.attributes.as_slice()).map_err(SqlxDatabase::map_decode_err)?; let added = TimestampInSeconds(self.added as u64); - let expires = self.expires.map(|v| TimestampInSeconds(v as u64)); + let expires = self + .expires + .to_option() + .map(|v| TimestampInSeconds(v as u64)); let attested_by = self .attested_by - .clone() + .to_option() .map(|v| Identifier::from_str(&v)) .transpose()?; @@ -127,6 +141,7 @@ impl IdentityAttributesRow { mod tests { use ockam_core::compat::collections::BTreeMap; use ockam_core::compat::sync::Arc; + use ockam_node::database::with_dbs; use std::ops::Add; use super::*; @@ -135,96 +150,106 @@ mod tests { #[tokio::test] async fn test_identities_attributes_repository() -> Result<()> { - let repository = create_repository().await?; - let now = now()?; - - // store and retrieve attributes by identity - let identifier1 = create_identity().await?; - let attributes1 = create_attributes_entry(&identifier1, now, Some(2.into())).await?; - let identifier2 = 
create_identity().await?; - let attributes2 = create_attributes_entry(&identifier2, now, Some(2.into())).await?; - - repository - .put_attributes(&identifier1, attributes1.clone()) - .await?; - repository - .put_attributes(&identifier2, attributes2.clone()) - .await?; - - let result = repository - .get_attributes(&identifier1, &identifier1) - .await?; - assert_eq!(result, Some(attributes1.clone())); - - let result = repository - .get_attributes(&identifier2, &identifier2) - .await?; - assert_eq!(result, Some(attributes2.clone())); - - Ok(()) + with_dbs(|db| async move { + let repository: Arc = + Arc::new(IdentityAttributesSqlxDatabase::new(db, "node")); + + let now = now()?; + + // store and retrieve attributes by identity + let identifier1 = create_identity().await?; + let attributes1 = create_attributes_entry(&identifier1, now, Some(2.into())).await?; + let identifier2 = create_identity().await?; + let attributes2 = create_attributes_entry(&identifier2, now, Some(2.into())).await?; + + repository + .put_attributes(&identifier1, attributes1.clone()) + .await?; + repository + .put_attributes(&identifier2, attributes2.clone()) + .await?; + + let result = repository + .get_attributes(&identifier1, &identifier1) + .await?; + assert_eq!(result, Some(attributes1.clone())); + + let result = repository + .get_attributes(&identifier2, &identifier2) + .await?; + assert_eq!(result, Some(attributes2.clone())); + + Ok(()) + }) + .await } #[tokio::test] async fn test_delete_expired_attributes() -> Result<()> { - let repository = create_repository().await?; - let now = now()?; - - // store some attributes with and without an expiry date - let identifier1 = create_identity().await?; - let identifier2 = create_identity().await?; - let identifier3 = create_identity().await?; - let identifier4 = create_identity().await?; - let attributes1 = create_attributes_entry(&identifier1, now, Some(1.into())).await?; - let attributes2 = create_attributes_entry(&identifier2, now, 
Some(10.into())).await?; - let attributes3 = create_attributes_entry(&identifier3, now, Some(100.into())).await?; - let attributes4 = create_attributes_entry(&identifier4, now, None).await?; - - repository - .put_attributes(&identifier1, attributes1.clone()) - .await?; - repository - .put_attributes(&identifier2, attributes2.clone()) - .await?; - repository - .put_attributes(&identifier3, attributes3.clone()) - .await?; - repository - .put_attributes(&identifier4, attributes4.clone()) - .await?; - - // delete all the attributes with an expiry date <= now + 10 - // only attributes1 and attributes2 must be deleted - repository.delete_expired_attributes(now.add(10)).await?; - - let result = repository - .get_attributes(&identifier1, &identifier1) - .await?; - assert_eq!(result, None); - - let result = repository - .get_attributes(&identifier2, &identifier2) - .await?; - assert_eq!(result, None); - - let result = repository - .get_attributes(&identifier3, &identifier3) - .await?; - assert_eq!( - result, - Some(attributes3), - "attributes 3 are not expired yet" - ); - - let result = repository - .get_attributes(&identifier4, &identifier4) - .await?; - assert_eq!( - result, - Some(attributes4), - "attributes 4 have no expiry date" - ); - - Ok(()) + with_dbs(|db| async move { + let repository: Arc = + Arc::new(IdentityAttributesSqlxDatabase::new(db, "node")); + + let now = now()?; + + // store some attributes with and without an expiry date + let identifier1 = create_identity().await?; + let identifier2 = create_identity().await?; + let identifier3 = create_identity().await?; + let identifier4 = create_identity().await?; + let attributes1 = create_attributes_entry(&identifier1, now, Some(1.into())).await?; + let attributes2 = create_attributes_entry(&identifier2, now, Some(10.into())).await?; + let attributes3 = create_attributes_entry(&identifier3, now, Some(100.into())).await?; + let attributes4 = create_attributes_entry(&identifier4, now, None).await?; + + repository + 
.put_attributes(&identifier1, attributes1.clone()) + .await?; + repository + .put_attributes(&identifier2, attributes2.clone()) + .await?; + repository + .put_attributes(&identifier3, attributes3.clone()) + .await?; + repository + .put_attributes(&identifier4, attributes4.clone()) + .await?; + + // delete all the attributes with an expiry date <= now + 10 + // only attributes1 and attributes2 must be deleted + repository.delete_expired_attributes(now.add(10)).await?; + + let result = repository + .get_attributes(&identifier1, &identifier1) + .await?; + assert_eq!(result, None); + + let result = repository + .get_attributes(&identifier2, &identifier2) + .await?; + assert_eq!(result, None); + + let result = repository + .get_attributes(&identifier3, &identifier3) + .await?; + assert_eq!( + result, + Some(attributes3), + "attributes 3 are not expired yet" + ); + + let result = repository + .get_attributes(&identifier4, &identifier4) + .await?; + assert_eq!( + result, + Some(attributes4), + "attributes 4 have no expiry date" + ); + + Ok(()) + }) + .await } /// HELPERS @@ -248,8 +273,4 @@ mod tests { let identities = identities().await?; identities.identities_creation().create_identity().await } - - async fn create_repository() -> Result> { - Ok(Arc::new(IdentityAttributesSqlxDatabase::create().await?)) - } } diff --git a/implementations/rust/ockam/ockam_identity/src/models/identifiers.rs b/implementations/rust/ockam/ockam_identity/src/models/identifiers.rs index a51c7d4e406..3df30677e11 100644 --- a/implementations/rust/ockam/ockam_identity/src/models/identifiers.rs +++ b/implementations/rust/ockam/ockam_identity/src/models/identifiers.rs @@ -1,5 +1,4 @@ use core::fmt::{Debug, Formatter}; - use minicbor::{Decode, Encode}; use crate::alloc::string::ToString; diff --git a/implementations/rust/ockam/ockam_identity/src/purpose_keys/storage/purpose_keys_repository_sql.rs b/implementations/rust/ockam/ockam_identity/src/purpose_keys/storage/purpose_keys_repository_sql.rs 
index 2be4bb23738..0f8e11d3768 100644 --- a/implementations/rust/ockam/ockam_identity/src/purpose_keys/storage/purpose_keys_repository_sql.rs +++ b/implementations/rust/ockam/ockam_identity/src/purpose_keys/storage/purpose_keys_repository_sql.rs @@ -1,5 +1,7 @@ use core::str::FromStr; +use sqlx::database::HasArguments; +use sqlx::encode::IsNull; use sqlx::*; use tracing::debug; @@ -8,7 +10,7 @@ use ockam_core::compat::string::{String, ToString}; use ockam_core::compat::vec::Vec; use ockam_core::errcode::{Kind, Origin}; use ockam_core::Result; -use ockam_node::database::{FromSqlxError, SqlxDatabase, SqlxType, ToSqlxType, ToVoid}; +use ockam_node::database::{FromSqlxError, SqlxDatabase, ToVoid}; use crate::identity::IdentityConstants; use crate::models::{Identifier, PurposeKeyAttestation}; @@ -42,17 +44,23 @@ impl PurposeKeysRepository for PurposeKeysSqlxDatabase { purpose: Purpose, purpose_key_attestation: &PurposeKeyAttestation, ) -> Result<()> { - let query = query("INSERT OR REPLACE INTO purpose_key VALUES (?, ?, ?)") - .bind(subject.to_sql()) - .bind(purpose.to_sql()) - .bind(minicbor::to_vec(purpose_key_attestation)?.to_sql()); + let query = query( + r#" + INSERT INTO purpose_key (identifier, purpose, purpose_key_attestation) + VALUES ($1, $2, $3) + ON CONFLICT (identifier, purpose) + DO UPDATE SET purpose_key_attestation = $3"#, + ) + .bind(subject) + .bind(purpose) + .bind(purpose_key_attestation); query.execute(&*self.database.pool).await.void() } async fn delete_purpose_key(&self, subject: &Identifier, purpose: Purpose) -> Result<()> { - let query = query("DELETE FROM purpose_key WHERE identifier = ? 
and purpose = ?") - .bind(subject.to_sql()) - .bind(purpose.to_sql()); + let query = query("DELETE FROM purpose_key WHERE identifier = $1 and purpose = $2") + .bind(subject) + .bind(purpose); query.execute(&*self.database.pool).await.void() } @@ -61,9 +69,9 @@ impl PurposeKeysRepository for PurposeKeysSqlxDatabase { identifier: &Identifier, purpose: Purpose, ) -> Result> { - let query = query_as("SELECT identifier, purpose, purpose_key_attestation FROM purpose_key WHERE identifier=$1 and purpose=$2") - .bind(identifier.to_sql()) - .bind(purpose.to_sql()); + let query = query_as("SELECT identifier, purpose, purpose_key_attestation FROM purpose_key WHERE identifier = $1 and purpose = $2") + .bind(identifier) + .bind(purpose); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -81,6 +89,34 @@ impl PurposeKeysRepository for PurposeKeysSqlxDatabase { // Database serialization / deserialization +impl Type for Purpose { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +impl Encode<'_, Any> for Purpose { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + let purpose = match self { + Purpose::SecureChannel => IdentityConstants::SECURE_CHANNEL_PURPOSE_KEY, + Purpose::Credentials => IdentityConstants::CREDENTIALS_PURPOSE_KEY, + }; + >::encode_by_ref(&purpose.to_string(), buf) + } +} + +impl Type for PurposeKeyAttestation { + fn type_info() -> ::TypeInfo { + as Type>::type_info() + } +} + +impl Encode<'_, Any> for PurposeKeyAttestation { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + as Encode<'_, Any>>::encode_by_ref(&minicbor::to_vec(self).unwrap(), buf) + } +} + #[derive(FromRow)] pub(crate) struct PurposeKeyRow { // The identifier who is using this key @@ -115,19 +151,6 @@ impl PurposeKeyRow { } } -impl ToSqlxType for Purpose { - fn to_sql(&self) -> SqlxType { - match self { - Purpose::SecureChannel => { - SqlxType::Text(IdentityConstants::SECURE_CHANNEL_PURPOSE_KEY.to_string()) - } - Purpose::Credentials 
=> { - SqlxType::Text(IdentityConstants::CREDENTIALS_PURPOSE_KEY.to_string()) - } - } - } -} - #[cfg(test)] mod tests { use super::*; diff --git a/implementations/rust/ockam/ockam_identity/src/secure_channels/storage/secure_channel_repository_sql.rs b/implementations/rust/ockam/ockam_identity/src/secure_channels/storage/secure_channel_repository_sql.rs index f1728cb0241..bd9651a0c39 100644 --- a/implementations/rust/ockam/ockam_identity/src/secure_channels/storage/secure_channel_repository_sql.rs +++ b/implementations/rust/ockam/ockam_identity/src/secure_channels/storage/secure_channel_repository_sql.rs @@ -5,7 +5,7 @@ use crate::secure_channel::Role; use crate::Identifier; use ockam_core::{async_trait, Address}; use ockam_core::{Error, Result}; -use ockam_node::database::{FromSqlxError, SqlxDatabase, ToSqlxType, ToVoid}; +use ockam_node::database::{FromSqlxError, SqlxDatabase, ToVoid}; use ockam_vault::{AeadSecretKeyHandle, HandleToSecret}; use crate::secure_channels::storage::secure_channel_repository::{ @@ -39,9 +39,9 @@ impl SecureChannelRepository for SecureChannelSqlxDatabase { decryptor_remote_address: &Address, ) -> Result> { let query = query_as( - "SELECT role, my_identifier, their_identifier, decryptor_remote_address, decryptor_api_address, decryption_key_handle FROM secure_channel WHERE decryptor_remote_address=$1" + "SELECT role, my_identifier, their_identifier, decryptor_remote_address, decryptor_api_address, decryption_key_handle FROM secure_channel WHERE decryptor_remote_address = $1" ) - .bind(decryptor_remote_address.to_string().to_sql()); + .bind(decryptor_remote_address.to_string()); let secure_channel: Option = query .fetch_optional(&*self.database.pool) .await @@ -52,20 +52,23 @@ impl SecureChannelRepository for SecureChannelSqlxDatabase { async fn put(&self, secure_channel: PersistedSecureChannel) -> Result<()> { let query = query( - "INSERT OR REPLACE INTO secure_channel (role, my_identifier, their_identifier, decryptor_remote_address, 
decryptor_api_address, decryption_key_handle) VALUES (?, ?, ?, ?, ?, ?)" + r#"INSERT INTO secure_channel (role, my_identifier, their_identifier, decryptor_remote_address, decryptor_api_address, decryption_key_handle) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (decryptor_remote_address) + DO UPDATE SET role = $1, my_identifier = $2, their_identifier = $3, decryptor_api_address = $5, decryption_key_handle = $6"# ) - .bind(secure_channel.role().str().to_sql()) - .bind(secure_channel.my_identifier().to_string().to_sql()) - .bind(secure_channel.their_identifier().to_sql()) - .bind(secure_channel.decryptor_remote().to_sql()) - .bind(secure_channel.decryptor_api().to_sql()) - .bind(secure_channel.decryption_key_handle().to_sql()); + .bind(secure_channel.role().str()) + .bind(secure_channel.my_identifier()) + .bind(secure_channel.their_identifier()) + .bind(secure_channel.decryptor_remote().to_string()) + .bind(secure_channel.decryptor_api().to_string()) + .bind(secure_channel.decryption_key_handle()); query.execute(&*self.database.pool).await.void() } async fn delete(&self, decryptor_remote_address: &Address) -> Result<()> { - let query = query("DELETE FROM secure_channel WHERE decryptor_remote_address=$1") - .bind(decryptor_remote_address.to_string().to_sql()); + let query = query("DELETE FROM secure_channel WHERE decryptor_remote_address = $1") + .bind(decryptor_remote_address.to_string()); query.execute(&*self.database.pool).await.void() } } diff --git a/implementations/rust/ockam/ockam_identity/tests/persistence.rs b/implementations/rust/ockam/ockam_identity/tests/persistence.rs index d87b095617f..3e859728d64 100644 --- a/implementations/rust/ockam/ockam_identity/tests/persistence.rs +++ b/implementations/rust/ockam/ockam_identity/tests/persistence.rs @@ -162,11 +162,12 @@ fn test_persistence() -> ockam_core::Result<()> { let data = executor1 .execute(async move { let data = std::panic::AssertUnwindSafe(async { - let db_alice = 
SqlxDatabase::create(db_file_alice_path_clone.as_path()).await?; + let db_alice = + SqlxDatabase::create_sqlite(db_file_alice_path_clone.as_path()).await?; let secure_channel_repository_alice = Arc::new(SecureChannelSqlxDatabase::new(db_alice.clone())); let secrets_repository_alice = Arc::new(SecretsSqlxDatabase::new(db_alice)); - let db_bob = SqlxDatabase::create(db_file_bob_path_clone.as_path()).await?; + let db_bob = SqlxDatabase::create_sqlite(db_file_bob_path_clone.as_path()).await?; let secure_channel_repository_bob = Arc::new(SecureChannelSqlxDatabase::new(db_bob.clone())); let secrets_repository_bob = Arc::new(SecretsSqlxDatabase::new(db_bob)); @@ -288,11 +289,11 @@ fn test_persistence() -> ockam_core::Result<()> { executor2 .execute(async move { let res = std::panic::AssertUnwindSafe(async { - let db_alice = SqlxDatabase::create(db_file_alice_path.as_path()).await?; + let db_alice = SqlxDatabase::create_sqlite(db_file_alice_path.as_path()).await?; let secure_channel_repository_alice = Arc::new(SecureChannelSqlxDatabase::new(db_alice.clone())); let secrets_repository_alice = Arc::new(SecretsSqlxDatabase::new(db_alice)); - let db_bob = SqlxDatabase::create(db_file_bob_path.as_path()).await?; + let db_bob = SqlxDatabase::create_sqlite(db_file_bob_path.as_path()).await?; let secure_channel_repository_bob = Arc::new(SecureChannelSqlxDatabase::new(db_bob.clone())); let secrets_repository_bob = Arc::new(SecretsSqlxDatabase::new(db_bob)); diff --git a/implementations/rust/ockam/ockam_node/Cargo.toml b/implementations/rust/ockam/ockam_node/Cargo.toml index b7d3f3c4875..8cf5c57bbce 100644 --- a/implementations/rust/ockam/ockam_node/Cargo.toml +++ b/implementations/rust/ockam/ockam_node/Cargo.toml @@ -36,7 +36,6 @@ default = ["std"] # Feature (enabled by default): "std" enables functionality expected to # be available on a standard platform. 
std = [ - "chrono", "ockam_core/std", "ockam_transport_core/std", "once_cell/std", @@ -69,11 +68,10 @@ metrics = [] # message flows within Ockam apps. debugger = ["ockam_core/debugger"] -storage = ["std", "time", "serde_json", "sqlx", "tokio-retry", "regex"] +storage = ["std", "time", "serde_json", "sqlx", "tokio-retry", "regex", "tempfile"] [dependencies] cfg-if = "1.0.0" -chrono = { version = "0.4", optional = true } fs2 = { version = "0.4.3", optional = true } futures = { version = "0.3.30", default-features = false } heapless = { version = "0.8", features = ["mpmc_large"], optional = true } @@ -87,7 +85,8 @@ opentelemetry = { version = "0.23.0", features = ["logs", "metrics", "trace"], o regex = { version = "1.10.5", default-features = false, optional = true } serde = { version = "1.0", default-features = false, features = ["derive"] } serde_json = { version = "1", optional = true } -sqlx = { version = "0.7.4", optional = true, features = ["sqlite", "migrate", "runtime-tokio"] } +sqlx = { git = "https://github.com/etorreborre/sqlx", rev = "5fec648d2de0cbeed738dcf1c6f5bc9194fc439b", optional = true, features = ["postgres", "sqlite", "any", "migrate", "runtime-tokio"] } +tempfile = { version = "3.10.1", optional = true } time = { version = "0.3.36", default-features = false, optional = true } tokio = { version = "1.38", default-features = false, optional = true, features = ["sync", "time", "rt", "rt-multi-thread", "macros"] } tokio-retry = { version = "0.3.0", optional = true } @@ -98,7 +97,6 @@ tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"], option [dev-dependencies] hex = { version = "0.4", default-features = false } -tempfile = { version = "3.10.1" } [package.metadata.cargo-machete] ignored = ["fs2", "serde_json", "tracing-opentelemetry"] diff --git a/implementations/rust/ockam/ockam_node/src/lib.rs b/implementations/rust/ockam/ockam_node/src/lib.rs index aba14ea5386..ff0399f5121 100644 --- 
a/implementations/rust/ockam/ockam_node/src/lib.rs +++ b/implementations/rust/ockam/ockam_node/src/lib.rs @@ -65,6 +65,7 @@ mod router; /// Support for storing persistent values pub mod storage; + mod worker_builder; /// Singleton for the runtime executor diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/database_configuration.rs b/implementations/rust/ockam/ockam_node/src/storage/database/database_configuration.rs new file mode 100644 index 00000000000..187dba88967 --- /dev/null +++ b/implementations/rust/ockam/ockam_node/src/storage/database/database_configuration.rs @@ -0,0 +1,198 @@ +use ockam_core::compat::rand::random_string; +use ockam_core::env::get_env; +use ockam_core::errcode::{Kind, Origin}; +use ockam_core::{Error, Result}; +use std::fs::create_dir_all; +use std::path::{Path, PathBuf}; + +/// Database host environment variable +pub const OCKAM_POSTGRES_HOST: &str = "OCKAM_POSTGRES_HOST"; +/// Database port environment variable +pub const OCKAM_POSTGRES_PORT: &str = "OCKAM_POSTGRES_PORT"; +/// Database name environment variable +pub const OCKAM_POSTGRES_DATABASE_NAME: &str = "OCKAM_POSTGRES_DATABASE_NAME"; +/// Database user environment variable +pub const OCKAM_POSTGRES_USER: &str = "OCKAM_POSTGRES_USER"; +/// Database password environment variable +pub const OCKAM_POSTGRES_PASSWORD: &str = "OCKAM_POSTGRES_PASSWORD"; + +/// Configuration for the database. 
+/// We either use Sqlite or Postgres +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum DatabaseConfiguration { + /// Configuration for a SQLite database + Sqlite { + /// Database file path if the database is stored on disk + path: Option, + }, + /// Configuration for a Postgres database + Postgres { + /// Database host name + host: String, + /// Database host port + port: u16, + /// Database name + database_name: String, + /// Database user + user: Option, + }, +} + +/// Type of database +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum DatabaseType { + /// Type for SQLite + Sqlite, + /// Type for Postgres + Postgres, +} + +/// User of the Postgres database +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct DatabaseUser { + /// Database user + user_name: String, + /// Database password + password: String, +} + +impl DatabaseUser { + /// Create a new database user + pub fn new(user_name: impl Into, password: impl Into) -> Self { + Self { + user_name: user_name.into(), + password: password.into(), + } + } + /// Return the user name + pub fn user_name(&self) -> String { + self.user_name.clone() + } + /// Return the password + pub fn password(&self) -> String { + self.password.clone() + } +} + +impl DatabaseConfiguration { + /// Create a postgres database configuration from environment variables. + /// + /// At a minimum, the database host and port must be provided.
+ pub fn postgres() -> Result> { + let host: Option = get_env(OCKAM_POSTGRES_HOST)?; + let port: Option = get_env(OCKAM_POSTGRES_PORT)?; + let database_name: String = + get_env(OCKAM_POSTGRES_DATABASE_NAME)?.unwrap_or("postgres".to_string()); + let user: Option = get_env(OCKAM_POSTGRES_USER)?; + let password: Option = get_env(OCKAM_POSTGRES_PASSWORD)?; + match (host, port) { + (Some(host), Some(port)) => match (user, password) { + (Some(user), Some(password)) => Ok(Some(DatabaseConfiguration::Postgres { + host, + port, + database_name, + user: Some(DatabaseUser::new(user, password)), + })), + _ => Ok(Some(DatabaseConfiguration::Postgres { + host, + port, + database_name, + user: None, + })), + }, + _ => Ok(None), + } + } + + /// Create a local sqlite configuration + pub fn sqlite(path: &Path) -> DatabaseConfiguration { + DatabaseConfiguration::Sqlite { + path: Some(path.to_path_buf()), + } + } + + /// Create an in-memory sqlite configuration + pub fn sqlite_in_memory() -> DatabaseConfiguration { + DatabaseConfiguration::Sqlite { path: None } + } + + /// Return the type of database that has been configured + pub fn database_type(&self) -> DatabaseType { + match self { + DatabaseConfiguration::Sqlite { .. } => DatabaseType::Sqlite, + DatabaseConfiguration::Postgres { .. 
} => DatabaseType::Postgres, + } + } + + /// Return the connection string for the configured database + pub fn connection_string(&self) -> String { + match self { + DatabaseConfiguration::Sqlite { path: None } => { + Self::create_sqlite_in_memory_connection_string() + } + DatabaseConfiguration::Sqlite { path: Some(path) } => { + Self::create_sqlite_on_disk_connection_string(path) + } + DatabaseConfiguration::Postgres { + host, + port, + database_name, + user, + } => Self::create_postgres_connection_string( + host.clone(), + *port, + database_name.clone(), + user.clone(), + ), + } + } + + /// Create a directory for the SQLite database file if necessary + pub fn create_directory_if_necessary(&self) -> Result<()> { + if let DatabaseConfiguration::Sqlite { path: Some(path) } = self { + if let Some(parent) = path.parent() { + if !parent.exists() { + create_dir_all(parent) + .map_err(|e| Error::new(Origin::Api, Kind::Io, e.to_string()))? + } + } + } + Ok(()) + } + + /// Return true if the path for a SQLite database exists + pub fn exists(&self) -> bool { + self.path().map(|p| p.exists()).unwrap_or(false) + } + + /// Return the database path if the database is a SQLite file.
+ pub fn path(&self) -> Option { + match self { + DatabaseConfiguration::Sqlite { path } => path.clone(), + _ => None, + } + } + + fn create_sqlite_in_memory_connection_string() -> String { + let file_name = random_string(); + format!("sqlite:file:{file_name}?mode=memory&cache=shared") + } + + fn create_sqlite_on_disk_connection_string(path: &Path) -> String { + let url_string = &path.to_string_lossy().to_string(); + format!("sqlite:file://{url_string}?mode=rwc") + } + + fn create_postgres_connection_string( + host: String, + port: u16, + database_name: String, + user: Option, + ) -> String { + let user_password = match user { + Some(user) => format!("{}:{}@", user.user_name(), user.password()), + None => "".to_string(), + }; + format!("postgres://{user_password}{host}:{port}/{database_name}") + } +} diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/application_migration_set.rs b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/application_migration_set.rs index 0c0503814b9..b267975b4d5 100644 --- a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/application_migration_set.rs +++ b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/application_migration_set.rs @@ -1,21 +1,37 @@ use crate::database::migrations::migration_set::MigrationSet; -use crate::database::Migrator; +use crate::database::{DatabaseType, Migrator}; use crate::migrate; use ockam_core::Result; /// This struct defines the migration to apply to the persistent database -pub struct ApplicationMigrationSet; +pub struct ApplicationMigrationSet { + database_type: DatabaseType, +} + +impl ApplicationMigrationSet { + /// Create a new migration set + pub fn new(database_type: DatabaseType) -> Self { + Self { database_type } + } +} impl MigrationSet for ApplicationMigrationSet { fn create_migrator(&self) -> Result { - 
migrate!("./src/storage/database/migrations/application_migrations/sql") + match self.database_type { + DatabaseType::Sqlite => { + migrate!("./src/storage/database/migrations/application_migrations/sql/sqlite") + } + DatabaseType::Postgres => { + migrate!("./src/storage/database/migrations/application_migrations/sql/postgres") + } + } } } #[cfg(test)] mod tests { use crate::database::application_migration_set::ApplicationMigrationSet; - use crate::database::{MigrationSet, SqlxDatabase}; + use crate::database::{DatabaseConfiguration, DatabaseType, MigrationSet, SqlxDatabase}; use ockam_core::Result; use tempfile::NamedTempFile; @@ -23,9 +39,10 @@ mod tests { async fn test() -> Result<()> { let db_file = NamedTempFile::new().unwrap(); - let db = SqlxDatabase::create_no_migration(db_file.path()).await?; + let db = SqlxDatabase::create_no_migration(&DatabaseConfiguration::sqlite(db_file.path())) + .await?; - ApplicationMigrationSet + ApplicationMigrationSet::new(DatabaseType::Sqlite) .create_migrator()? 
.migrate(&db.pool) .await?; diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/sql/postgres/20240613110000_project_journey.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/sql/postgres/20240613110000_project_journey.sql new file mode 100644 index 00000000000..ba3704d56b9 --- /dev/null +++ b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/sql/postgres/20240613110000_project_journey.sql @@ -0,0 +1,12 @@ +CREATE TABLE project_journey ( + project_id TEXT NOT NULL, + opentelemetry_context TEXT NOT NULL UNIQUE, + start_datetime TEXT NOT NULL, + previous_opentelemetry_context TEXT +); + +CREATE TABLE host_journey ( + opentelemetry_context TEXT NOT NULL UNIQUE, + start_datetime TEXT NOT NULL, + previous_opentelemetry_context TEXT +); diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/sql/20241701150000_project_journey.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/sql/sqlite/20241701150000_project_journey.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/sql/20241701150000_project_journey.sql rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/sql/sqlite/20241701150000_project_journey.sql diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/sql/20242102180000_time_limited_journey.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/sql/sqlite/20242102180000_time_limited_journey.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/sql/20242102180000_time_limited_journey.sql rename to 
implementations/rust/ockam/ockam_node/src/storage/database/migrations/application_migrations/sql/sqlite/20242102180000_time_limited_journey.sql diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/migration_support/migrator.rs b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/migration_support/migrator.rs index 4876178402c..dab03d369e6 100644 --- a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/migration_support/migrator.rs +++ b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/migration_support/migrator.rs @@ -1,14 +1,14 @@ use ockam_core::compat::collections::HashSet; use ockam_core::compat::time::now; use ockam_core::errcode::{Kind, Origin}; +use sqlx::any::AnyRow; use sqlx::migrate::{AppliedMigration, Migrate, Migration as SqlxMigration}; -use sqlx::sqlite::SqliteRow; -use sqlx::{query, Row, SqliteConnection, SqlitePool}; +use sqlx::{query, Any, AnyConnection, Pool, Row}; use std::cmp::Ordering; use time::OffsetDateTime; use crate::database::migrations::migration_support::rust_migration::RustMigration; -use crate::database::{FromSqlxError, ToSqlxType, ToVoid}; +use crate::database::{FromSqlxError, ToVoid}; use ockam_core::Result; /// Migrator is responsible for running Sql and Rust migrations side by side in the correct order, @@ -67,11 +67,7 @@ impl Migrator { } impl Migrator { - async fn run_migrations( - &self, - connection: &mut SqliteConnection, - up_to: Version, - ) -> Result<()> { + async fn run_migrations(&self, connection: &mut AnyConnection, up_to: Version) -> Result<()> { connection.ensure_migrations_table().await.into_core()?; let version = connection.dirty_version().await.into_core()?; @@ -133,13 +129,12 @@ impl Migrator { impl Migrator { pub(crate) async fn has_migrated( - connection: &mut SqliteConnection, + connection: &mut AnyConnection, migration_name: &str, ) -> Result { - let query = query("SELECT COUNT(*) FROM _rust_migrations WHERE name=?") - 
.bind(migration_name.to_sql()); - let count_raw: Option = - query.fetch_optional(&mut *connection).await.into_core()?; + let query = + query("SELECT COUNT(*) FROM _rust_migrations WHERE name = $1").bind(migration_name); + let count_raw: Option = query.fetch_optional(&mut *connection).await.into_core()?; if let Some(count_raw) = count_raw { let count: i64 = count_raw.get(0); @@ -150,16 +145,22 @@ impl Migrator { } pub(crate) async fn mark_as_migrated( - connection: &mut SqliteConnection, + connection: &mut AnyConnection, migration_name: &str, ) -> Result<()> { let now = now()?; let now = OffsetDateTime::from_unix_timestamp(now as i64).map_err(|_| { ockam_core::Error::new(Origin::Node, Kind::Internal, "Can't convert timestamp") })?; - let query = query("INSERT OR REPLACE INTO _rust_migrations (name, run_on) VALUES (?, ?)") - .bind(migration_name.to_sql()) - .bind(now.to_sql()); + let query = query( + r#" + INSERT INTO _rust_migrations (name, run_on) + VALUES ($1, $2) + ON CONFLICT (name) + DO UPDATE SET run_on = $2"#, + ) + .bind(migration_name) + .bind(now.unix_timestamp()); query.execute(&mut *connection).await.void()?; Ok(()) @@ -168,23 +169,22 @@ impl Migrator { impl Migrator { /// Run migrations up to the specified version (inclusive) - pub(crate) async fn migrate_up_to(&self, pool: &SqlitePool, up_to: Version) -> Result<()> { + pub(crate) async fn migrate_up_to(&self, pool: &Pool, up_to: Version) -> Result<()> { let mut connection = pool.acquire().await.into_core()?; - // Apparently does nothing for sqlite... 
+ // This lock is only effective for Postgres connection.lock().await.into_core()?; let res = self.run_migrations(&mut connection, up_to).await; connection.unlock().await.into_core()?; - res?; Ok(()) } /// Run all migrations - pub async fn migrate(&self, pool: &SqlitePool) -> Result<()> { + pub async fn migrate(&self, pool: &Pool) -> Result<()> { self.migrate_up_to(pool, i64::MAX).await } } @@ -194,7 +194,7 @@ impl Migrator { /// Run migrations up to the specified version (inclusive) but skip the last rust migration pub(crate) async fn migrate_up_to_skip_last_rust_migration( mut self, - pool: &SqlitePool, + pool: &Pool, up_to: Version, ) -> Result<()> { self.rust_migrations.retain(|m| m.version() < up_to); @@ -224,7 +224,7 @@ impl NextMigration<'_> { async fn apply_sql_migration<'a>( migration: &'a SqlxMigration, - connection: &mut SqliteConnection, + connection: &mut AnyConnection, applied_migrations: &[AppliedMigration], ) -> Result<()> { if migration.migration_type.is_down_migration() { @@ -240,22 +240,32 @@ impl NextMigration<'_> { Origin::Node, Kind::Conflict, format!( - "Checksum mismatch for sql migration for version {}", - migration.version + "Checksum mismatch for sql migration '{}' for version {}", + migration.description, migration.version, ), )); } } - None => { - connection.apply(migration).await.into_core()?; - } + None => match connection.apply(migration).await.into_core() { + Ok(_) => (), + Err(e) => { + return Err(ockam_core::Error::new( + Origin::Node, + Kind::Conflict, + format!( + "Failed to run the migration {}: {e:?}", + migration.description + ), + )) + } + }, } Ok(()) } async fn apply_rust_migration( migration: &dyn RustMigration, - connection: &mut SqliteConnection, + connection: &mut AnyConnection, ) -> Result<()> { if Migrator::has_migrated(connection, migration.name()).await? 
{ return Ok(()); @@ -369,7 +379,7 @@ mod tests { self.version } - async fn migrate(&self, _connection: &mut SqliteConnection) -> Result { + async fn migrate(&self, _connection: &mut AnyConnection) -> Result { Ok(true) } } diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/migration_support/rust_migration.rs b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/migration_support/rust_migration.rs index e5e6b116e58..9736ee39fac 100644 --- a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/migration_support/rust_migration.rs +++ b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/migration_support/rust_migration.rs @@ -1,5 +1,5 @@ use core::fmt::Debug; -use sqlx::SqliteConnection; +use sqlx::AnyConnection; use ockam_core::{async_trait, Result}; @@ -13,5 +13,5 @@ pub trait RustMigration: Debug + Send + Sync { fn version(&self) -> i64; /// Execute the migration - async fn migrate(&self, connection: &mut SqliteConnection) -> Result; + async fn migrate(&self, connection: &mut AnyConnection) -> Result; } diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/mod.rs b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/mod.rs index 7ddbade1a5b..9a6c7af7db3 100644 --- a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/mod.rs +++ b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/mod.rs @@ -1,5 +1,4 @@ +/// This module defines the migrations to apply to the application database +pub mod node_migration_set; mod rust; pub use rust::*; - -/// This module defines the migrations to apply to the appliaction database -pub mod node_migration_set; diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/node_migration_set.rs 
b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/node_migration_set.rs index 211cd504eb1..32f9f7894b4 100644 --- a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/node_migration_set.rs +++ b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/node_migration_set.rs @@ -1,9 +1,10 @@ -use crate::database::migrations::migration_20240111100001_add_authority_tables::AuthorityAttributes; -use crate::database::migrations::migration_20240111100002_delete_trust_context::PolicyTrustContextId; -use crate::database::migrations::migration_20240212100000_split_policies::SplitPolicies; -use crate::database::migrations::node_migrations::migration_20231231100000_node_name_identity_attributes::NodeNameIdentityAttributes; -use crate::database::migration_20240313100000_remove_orphan_resources::RemoveOrphanResources; -use crate::database::migration_20240503100000_update_policy_expressions::UpdatePolicyExpressions; +use crate::database::migrations::sqlite::migration_20231231100000_node_name_identity_attributes::NodeNameIdentityAttributes; +use crate::database::migrations::sqlite::migration_20240111100001_add_authority_tables::AuthorityAttributes; +use crate::database::migrations::sqlite::migration_20240111100002_delete_trust_context::PolicyTrustContextId; +use crate::database::migrations::sqlite::migration_20240212100000_split_policies::SplitPolicies; +use crate::database::migrations::sqlite::migration_20240313100000_remove_orphan_resources::RemoveOrphanResources; +use crate::database::migrations::sqlite::migration_20240503100000_update_policy_expressions::UpdatePolicyExpressions; +use crate::database::DatabaseType; use ockam_core::Result; use crate::database::migrations::migration_set::MigrationSet; @@ -11,19 +12,38 @@ use crate::database::migrations::{Migrator, RustMigration}; use crate::migrate; /// This struct defines the migration to apply to the nodes database -pub struct 
NodeMigrationSet; +pub struct NodeMigrationSet { + database_type: DatabaseType, +} + +impl NodeMigrationSet { + /// Create a new migration set for a node + pub fn new(database_type: DatabaseType) -> Self { + Self { database_type } + } +} impl MigrationSet for NodeMigrationSet { fn create_migrator(&self) -> Result { - let rust_migrations: Vec> = vec![ - Box::new(NodeNameIdentityAttributes), - Box::new(AuthorityAttributes), - Box::new(PolicyTrustContextId), - Box::new(SplitPolicies), - Box::new(RemoveOrphanResources), - Box::new(UpdatePolicyExpressions), - ]; - let mut migrator = migrate!("./src/storage/database/migrations/node_migrations/sql")?; + let rust_migrations: Vec> = match self.database_type { + DatabaseType::Sqlite => vec![ + Box::new(NodeNameIdentityAttributes), + Box::new(AuthorityAttributes), + Box::new(PolicyTrustContextId), + Box::new(SplitPolicies), + Box::new(RemoveOrphanResources), + Box::new(UpdatePolicyExpressions), + ], + DatabaseType::Postgres => vec![], + }; + let mut migrator = match self.database_type { + DatabaseType::Sqlite => { + migrate!("./src/storage/database/migrations/node_migrations/sql/sqlite")? + } + DatabaseType::Postgres => { + migrate!("./src/storage/database/migrations/node_migrations/sql/postgres")? 
+ } + }; migrator.set_rust_migrations(rust_migrations)?; Ok(migrator) @@ -33,7 +53,7 @@ impl MigrationSet for NodeMigrationSet { #[cfg(test)] mod tests { use crate::database::migrations::node_migration_set::NodeMigrationSet; - use crate::database::{MigrationSet, SqlxDatabase}; + use crate::database::{DatabaseConfiguration, DatabaseType, MigrationSet, SqlxDatabase}; use ockam_core::Result; use tempfile::NamedTempFile; @@ -41,9 +61,10 @@ mod tests { async fn test() -> Result<()> { let db_file = NamedTempFile::new().unwrap(); - let db = SqlxDatabase::create_no_migration(db_file.path()).await?; + let db = SqlxDatabase::create_no_migration(&DatabaseConfiguration::sqlite(db_file.path())) + .await?; - NodeMigrationSet + NodeMigrationSet::new(DatabaseType::Sqlite) .create_migrator()? .migrate(&db.pool) .await?; diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust.rs b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust.rs index 5a137b11843..24ee27f6b5a 100644 --- a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust.rs +++ b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust.rs @@ -1,14 +1,2 @@ -/// This migration adds a node name column to the identity attributes table -pub mod migration_20231231100000_node_name_identity_attributes; -/// This migration moves attributes from identity_attributes to the authority_member table for authority nodes -pub mod migration_20240111100001_add_authority_tables; -/// This migration updates policies to not rely on trust_context_id, -/// also introduces `node_name` and replicates policy for each existing node -pub mod migration_20240111100002_delete_trust_context; -/// This migration moves policies attached to resource types from -/// table "resource_policy" to "resource_type_policy" -pub mod migration_20240212100000_split_policies; -/// This migration removes orphan 
resources -pub mod migration_20240313100000_remove_orphan_resources; -/// This migration updates the policy expressions so that they start with an operator -pub mod migration_20240503100000_update_policy_expressions; +/// SQLite rust migrations +pub mod sqlite; diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20231231100000_node_name_identity_attributes.rs b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20231231100000_node_name_identity_attributes.rs similarity index 72% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20231231100000_node_name_identity_attributes.rs rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20231231100000_node_name_identity_attributes.rs index dd96a2cdb37..183fa23ac0f 100644 --- a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20231231100000_node_name_identity_attributes.rs +++ b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20231231100000_node_name_identity_attributes.rs @@ -1,6 +1,6 @@ -use crate::database::{FromSqlxError, RustMigration, ToSqlxType, ToVoid}; +use crate::database::{Boolean, FromSqlxError, Nullable, RustMigration, ToVoid}; use ockam_core::{async_trait, Result}; -use sqlx::sqlite::SqliteRow; +use sqlx::any::AnyRow; use sqlx::*; /// This struct adds a node name column to the identity attributes table @@ -17,7 +17,7 @@ impl RustMigration for NodeNameIdentityAttributes { Self::version() } - async fn migrate(&self, connection: &mut SqliteConnection) -> Result { + async fn migrate(&self, connection: &mut AnyConnection) -> Result { Self::migrate_attributes_node_name(connection).await } } @@ -39,15 +39,17 @@ impl NodeNameIdentityAttributes { /// Duplicate all 
attributes entry for every known node pub(crate) async fn migrate_attributes_node_name( - connection: &mut SqliteConnection, + connection: &mut AnyConnection, ) -> Result { // don't run the migration twice - let data_migration_needed: Option = + let data_migration_needed: Option = query(&Self::table_exists("identity_attributes_old")) .fetch_optional(&mut *connection) .await .into_core()?; - let data_migration_needed = data_migration_needed.map(|r| r.get(0)).unwrap_or(false); + let data_migration_needed = data_migration_needed + .map(|r| r.get::(0).to_bool()) + .unwrap_or(false); if !data_migration_needed { // Trigger marking as migrated @@ -68,13 +70,13 @@ impl NodeNameIdentityAttributes { for row in rows { for node_name in &node_names { - let insert = query("INSERT INTO identity_attributes (identifier, attributes, added, expires, attested_by, node_name) VALUES (?, ?, ?, ?, ?, ?)") - .bind(row.identifier.to_sql()) - .bind(row.attributes.to_sql()) - .bind((row.added as u64).to_sql()) - .bind(row.expires.map(|e| (e as u64).to_sql())) - .bind(row.attested_by.clone().map(|e| e.to_sql())) - .bind(node_name.name.to_sql()); + let insert = query("INSERT INTO identity_attributes (identifier, attributes, added, expires, attested_by, node_name) VALUES ($1, $2, $3, $4, $5, $6)") + .bind(&row.identifier) + .bind(&row.attributes) + .bind(row.added) + .bind(row.expires.to_option()) + .bind(row.attested_by.to_option()) + .bind(&node_name.name); insert.execute(&mut *transaction).await.void()?; } @@ -98,8 +100,8 @@ struct IdentityAttributesRow { identifier: String, attributes: Vec, added: i64, - expires: Option, - attested_by: Option, + expires: Nullable, + attested_by: Nullable, } #[derive(FromRow)] @@ -110,9 +112,9 @@ struct NodeNameRow { #[cfg(test)] mod test { use crate::database::migrations::node_migration_set::NodeMigrationSet; - use crate::database::{MigrationSet, SqlxDatabase}; + use crate::database::{DatabaseType, MigrationSet, SqlxDatabase}; + use sqlx::any::AnyArguments; 
use sqlx::query::Query; - use sqlx::sqlite::SqliteArguments; use std::collections::BTreeMap; use tempfile::NamedTempFile; @@ -122,9 +124,10 @@ mod test { async fn test_migration() -> Result<()> { // create the database pool and migrate the tables let db_file = NamedTempFile::new().unwrap(); - let pool = SqlxDatabase::create_connection_pool(db_file.path()).await?; + let db_file = db_file.path(); + let pool = SqlxDatabase::create_sqlite_connection_pool(db_file).await?; let mut connection = pool.acquire().await.into_core()?; - NodeMigrationSet + NodeMigrationSet::new(DatabaseType::Sqlite) .create_migrator()? .migrate_up_to_skip_last_rust_migration(&pool, NodeNameIdentityAttributes::version()) .await?; @@ -141,21 +144,21 @@ mod test { insert_node2.execute(&mut *connection).await.void()?; // apply migrations - NodeMigrationSet + NodeMigrationSet::new(DatabaseType::Sqlite) .create_migrator()? .migrate_up_to(&pool, NodeNameIdentityAttributes::version()) .await?; // check data let rows1: Vec = - query_as("SELECT identifier, attributes, added, expires, attested_by FROM identity_attributes WHERE node_name = ?") - .bind("node1".to_string().to_sql()) + query_as("SELECT identifier, attributes, added, expires, attested_by FROM identity_attributes WHERE node_name = $1") + .bind("node1".to_string()) .fetch_all(&mut *connection) .await .into_core()?; let rows2: Vec = - query_as("SELECT identifier, attributes, added, expires, attested_by FROM identity_attributes WHERE node_name = ?") - .bind("node2".to_string().to_sql()) + query_as("SELECT identifier, attributes, added, expires, attested_by FROM identity_attributes WHERE node_name = $1") + .bind("node2".to_string()) .fetch_all(&mut *connection) .await .into_core()?; @@ -173,8 +176,8 @@ mod test { assert_eq!(row1.identifier, "identifier1"); assert_eq!(row1.attributes, attributes); assert_eq!(row1.added, 1); - assert_eq!(row1.expires, Some(2)); - assert_eq!(row1.attested_by, Some("authority".to_string())); + 
assert_eq!(row1.expires.to_option(), Some(2)); + assert_eq!(row1.attested_by.to_option(), Some("authority".to_string())); Ok(()) } @@ -187,21 +190,21 @@ mod test { ]))?) } - fn insert_query(identifier: &str, attributes: Vec) -> Query { - query("INSERT INTO identity_attributes_old VALUES (?, ?, ?, ?, ?)") - .bind(identifier.to_sql()) - .bind(attributes.to_sql()) - .bind(1.to_sql()) - .bind(Some(2).map(|e| e.to_sql())) - .bind(Some("authority").map(|e| e.to_sql())) + fn insert_query(identifier: &str, attributes: Vec) -> Query { + query("INSERT INTO identity_attributes_old VALUES ($1, $2, $3, $4, $5)") + .bind(identifier) + .bind(attributes) + .bind(1) + .bind(Some(2)) + .bind(Some("authority")) } - fn insert_node(name: String) -> Query<'static, Sqlite, SqliteArguments<'static>> { - query("INSERT INTO node (name, identifier, verbosity, is_default, is_authority) VALUES (?, ?, ?, ?, ?)") - .bind(name.to_sql()) - .bind("I_TEST".to_string().to_sql()) - .bind(1.to_sql()) - .bind(0.to_sql()) - .bind(0.to_sql()) + fn insert_node(name: String) -> Query<'static, Any, AnyArguments<'static>> { + query("INSERT INTO node (name, identifier, verbosity, is_default, is_authority) VALUES ($1, $2, $3, $4, $5)") + .bind(name) + .bind("I_TEST".to_string()) + .bind(1) + .bind(0) + .bind(0) } } diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20240111100001_add_authority_tables.rs b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20240111100001_add_authority_tables.rs similarity index 73% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20240111100001_add_authority_tables.rs rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20240111100001_add_authority_tables.rs index 0e6f411dd26..61bdbd69bc6 100644 --- 
a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20240111100001_add_authority_tables.rs +++ b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20240111100001_add_authority_tables.rs @@ -1,5 +1,5 @@ use crate::database::migrations::RustMigration; -use crate::database::{FromSqlxError, ToSqlxType, ToVoid}; +use crate::database::{Boolean, FromSqlxError, Nullable, ToVoid}; use ockam_core::{async_trait, Result}; use sqlx::*; @@ -17,7 +17,7 @@ impl RustMigration for AuthorityAttributes { Self::version() } - async fn migrate(&self, connection: &mut SqliteConnection) -> Result { + async fn migrate(&self, connection: &mut AnyConnection) -> Result { Self::migrate_authority_attributes_to_members(connection).await } } @@ -36,7 +36,7 @@ impl AuthorityAttributes { /// Duplicate all attributes entry for every known node pub(crate) async fn migrate_authority_attributes_to_members( - connection: &mut SqliteConnection, + connection: &mut AnyConnection, ) -> Result { let mut transaction = Connection::begin(&mut *connection).await.into_core()?; @@ -46,27 +46,27 @@ impl AuthorityAttributes { .await .into_core()?; - for node_name in node_names.into_iter().filter(|n| n.is_authority) { + for node_name in node_names.into_iter().filter(|n| n.is_authority.to_bool()) { let rows: Vec = - query_as("SELECT identifier, attributes, added, attested_by FROM identity_attributes WHERE node_name=?") - .bind(node_name.name.to_sql()) + query_as("SELECT identifier, attributes, added, attested_by FROM identity_attributes WHERE node_name = $1") + .bind(node_name.name.clone()) .fetch_all(&mut *transaction) .await .into_core()?; for row in rows { - let insert = query("INSERT INTO authority_member (identifier, added_by, added_at, is_pre_trusted, attributes) VALUES (?, ?, ?, ?, ?)") - .bind(row.identifier.to_sql()) - .bind(row.attested_by.clone().map(|e| e.to_sql())) - .bind((row.added as 
u64).to_sql()) - .bind(0.to_sql()) - .bind(row.attributes.to_sql()); + let insert = query("INSERT INTO authority_member (identifier, added_by, added_at, is_pre_trusted, attributes) VALUES ($1, $2, $3, $4, $5)") + .bind(row.identifier) + .bind(row.attested_by.to_option()) + .bind(row.added) + .bind(0) + .bind(row.attributes); insert.execute(&mut *transaction).await.void()?; } - query("DELETE FROM identity_attributes WHERE node_name=?") - .bind(node_name.name.to_sql()) + query("DELETE FROM identity_attributes WHERE node_name = $1") + .bind(node_name.name.clone()) .execute(&mut *transaction) .await .void()?; @@ -84,21 +84,21 @@ struct IdentityAttributesRow { identifier: String, attributes: Vec, added: i64, - attested_by: Option, + attested_by: Nullable, } #[derive(FromRow)] struct NodeNameRow { name: String, - is_authority: bool, + is_authority: Boolean, } #[cfg(test)] mod test { use crate::database::migrations::node_migration_set::NodeMigrationSet; - use crate::database::{MigrationSet, SqlxDatabase}; + use crate::database::{DatabaseType, MigrationSet, SqlxDatabase}; + use sqlx::any::AnyArguments; use sqlx::query::Query; - use sqlx::sqlite::SqliteArguments; use std::collections::BTreeMap; use tempfile::NamedTempFile; @@ -108,19 +108,19 @@ mod test { struct MemberRow { identifier: String, attributes: Vec, - added_by: Option, + added_by: Nullable, added_at: i64, - is_pre_trusted: bool, + is_pre_trusted: Boolean, } #[tokio::test] async fn test_migration() -> Result<()> { let db_file = NamedTempFile::new().unwrap(); - let pool = SqlxDatabase::create_connection_pool(db_file.path()).await?; + let pool = SqlxDatabase::create_sqlite_connection_pool(db_file.path()).await?; let mut connection = pool.acquire().await.into_core()?; - NodeMigrationSet + NodeMigrationSet::new(DatabaseType::Sqlite) .create_migrator()? 
.migrate_up_to_skip_last_rust_migration(&pool, AuthorityAttributes::version()) .await?; @@ -154,15 +154,15 @@ mod test { insert.execute(&mut *connection).await.void()?; // apply migrations - NodeMigrationSet + NodeMigrationSet::new(DatabaseType::Sqlite) .create_migrator()? .migrate_up_to(&pool, AuthorityAttributes::version()) .await?; // check data let rows1: Vec = - query_as("SELECT identifier, attributes, added, attested_by FROM identity_attributes WHERE node_name = ?") - .bind(regular_node_name.to_sql()) + query_as("SELECT identifier, attributes, added, attested_by FROM identity_attributes WHERE node_name = $1") + .bind(regular_node_name) .fetch_all(&mut *connection) .await .into_core()?; @@ -170,8 +170,8 @@ mod test { assert_eq!(rows1[0].attributes, attributes1); let rows2: Vec = - query_as("SELECT identifier, attributes, added, attested_by FROM identity_attributes WHERE node_name = ?") - .bind(authority_node_name.to_sql()) + query_as("SELECT identifier, attributes, added, attested_by FROM identity_attributes WHERE node_name = $1") + .bind(authority_node_name) .fetch_all(&mut *connection) .await .into_core()?; @@ -185,9 +185,12 @@ mod test { let member = &rows3[0]; assert_eq!(member.identifier, "identifier1".to_string()); - assert_eq!(member.added_by, Some("authority_id".to_string())); + assert_eq!( + member.added_by.to_option(), + Some("authority_id".to_string()) + ); assert_eq!(member.added_at, 1); - assert!(!member.is_pre_trusted); + assert!(!member.is_pre_trusted.to_bool()); assert_eq!(member.attributes, attributes2); Ok(()) @@ -203,25 +206,22 @@ mod test { identifier: &str, attributes: Vec, node_name: String, - ) -> Query { - query("INSERT INTO identity_attributes (identifier, attributes, added, expires, attested_by, node_name) VALUES (?, ?, ?, ?, ?, ?)") - .bind(identifier.to_sql()) - .bind(attributes.to_sql()) - .bind(1.to_sql()) - .bind(Some(2).map(|e| e.to_sql())) - .bind(Some("authority_id").map(|e| e.to_sql())) - .bind(node_name.to_sql()) + ) -> 
Query { + query("INSERT INTO identity_attributes (identifier, attributes, added, expires, attested_by, node_name) VALUES ($1, $2, $3, $4, $5, $6)") + .bind(identifier) + .bind(attributes) + .bind(1) + .bind(Some(2)) + .bind(Some("authority_id")) + .bind(node_name) } - fn insert_node( - name: String, - is_authority: bool, - ) -> Query<'static, Sqlite, SqliteArguments<'static>> { - query("INSERT INTO node (name, identifier, verbosity, is_default, is_authority) VALUES (?, ?, ?, ?, ?)") - .bind(name.to_sql()) - .bind("I_TEST".to_string().to_sql()) - .bind(1.to_sql()) - .bind(0.to_sql()) - .bind(is_authority.to_sql()) + fn insert_node(name: String, is_authority: bool) -> Query<'static, Any, AnyArguments<'static>> { + query("INSERT INTO node (name, identifier, verbosity, is_default, is_authority) VALUES ($1, $2, $3, $4, $5)") + .bind(name) + .bind("I_TEST".to_string()) + .bind(1) + .bind(0) + .bind(is_authority) } } diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20240111100002_delete_trust_context.rs b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20240111100002_delete_trust_context.rs similarity index 89% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20240111100002_delete_trust_context.rs rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20240111100002_delete_trust_context.rs index d242b999d86..670fec25a6e 100644 --- a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20240111100002_delete_trust_context.rs +++ b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20240111100002_delete_trust_context.rs @@ -1,5 +1,5 @@ use crate::database::migrations::RustMigration; -use crate::database::{FromSqlxError, 
ToSqlxType, ToVoid}; +use crate::database::{FromSqlxError, ToVoid}; use core::fmt; use minicbor::{Decode, Encode}; use ockam_core::{async_trait, Result}; @@ -21,7 +21,7 @@ impl RustMigration for PolicyTrustContextId { Self::version() } - async fn migrate(&self, connection: &mut SqliteConnection) -> Result { + async fn migrate(&self, connection: &mut AnyConnection) -> Result { Self::migrate_update_policies(connection).await } } @@ -39,7 +39,7 @@ impl PolicyTrustContextId { /// This migration updates policies to not rely on trust_context_id, /// also introduces `node_name` and replicates policy for each existing node - pub(crate) async fn migrate_update_policies(connection: &mut SqliteConnection) -> Result { + pub(crate) async fn migrate_update_policies(connection: &mut AnyConnection) -> Result { let mut transaction = Connection::begin(&mut *connection).await.into_core()?; let query_node_names = query_as("SELECT name FROM node"); @@ -61,11 +61,11 @@ impl PolicyTrustContextId { Self::update_expression(&expression) }; for node_name in &node_names { - let insert = query("INSERT INTO policy (resource, action, expression, node_name) VALUES (?, ?, ?, ?)") - .bind(row.resource.to_sql()) - .bind(row.action.to_sql()) - .bind(expression.to_sql()) - .bind(node_name.to_sql()); + let insert = query("INSERT INTO policy (resource, action, expression, node_name) VALUES ($1, $2, $3, $4)") + .bind(&row.resource) + .bind(&row.action) + .bind(&expression) + .bind(node_name); insert.execute(&mut *transaction).await.void()?; } @@ -191,9 +191,9 @@ struct PolicyRow { #[cfg(test)] mod test { use crate::database::migrations::node_migration_set::NodeMigrationSet; - use crate::database::{MigrationSet, SqlxDatabase}; + use crate::database::{DatabaseType, MigrationSet, SqlxDatabase}; + use sqlx::any::AnyArguments; use sqlx::query::Query; - use sqlx::sqlite::SqliteArguments; use tempfile::NamedTempFile; use super::*; @@ -228,11 +228,11 @@ mod test { // create the database pool and migrate the 
tables let db_file = NamedTempFile::new().unwrap(); - let pool = SqlxDatabase::create_connection_pool(db_file.path()).await?; + let pool = SqlxDatabase::create_sqlite_connection_pool(db_file.path()).await?; let mut connection = pool.acquire().await.into_core()?; - NodeMigrationSet + NodeMigrationSet::new(DatabaseType::Sqlite) .create_migrator()? .migrate_up_to_skip_last_rust_migration(&pool, PolicyTrustContextId::version()) .await?; @@ -259,16 +259,16 @@ mod test { insert3.execute(&pool).await.void()?; // apply migrations - NodeMigrationSet + NodeMigrationSet::new(DatabaseType::Sqlite) .create_migrator()? .migrate_up_to(&pool, PolicyTrustContextId::version()) .await?; for node_name in &["n1", "n2"] { let rows: Vec = query_as( - "SELECT resource, action, expression, node_name FROM policy WHERE node_name = ?", + "SELECT resource, action, expression, node_name FROM policy WHERE node_name = $1", ) - .bind(node_name.to_sql()) + .bind(node_name) .fetch_all(&mut *connection) .await .into_core()?; @@ -302,19 +302,19 @@ mod test { resource: String, action: String, expression: Vec, - ) -> Query<'static, Sqlite, SqliteArguments<'static>> { - query("INSERT INTO policy_old (resource, action, expression) VALUES (?, ?, ?)") - .bind(resource.to_sql()) - .bind(action.to_sql()) - .bind(expression.to_sql()) + ) -> Query<'static, Any, AnyArguments<'static>> { + query("INSERT INTO policy_old (resource, action, expression) VALUES ($1, $2, $3)") + .bind(resource) + .bind(action) + .bind(expression) } - fn insert_node(name: String) -> Query<'static, Sqlite, SqliteArguments<'static>> { - query("INSERT INTO node (name, identifier, verbosity, is_default, is_authority) VALUES (?, ?, ?, ?, ?)") - .bind(name.to_sql()) - .bind("I_TEST".to_string().to_sql()) - .bind(1.to_sql()) - .bind(0.to_sql()) - .bind(false.to_sql()) + fn insert_node(name: String) -> Query<'static, Any, AnyArguments<'static>> { + query("INSERT INTO node (name, identifier, verbosity, is_default, is_authority) VALUES ($1, $2, 
$3, $4, $5)") + .bind(name) + .bind("I_TEST".to_string()) + .bind(1) + .bind(0) + .bind(false) } } diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20240212100000_split_policies.rs b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20240212100000_split_policies.rs similarity index 83% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20240212100000_split_policies.rs rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20240212100000_split_policies.rs index 0eba248a3d8..a6ef706f8ca 100644 --- a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20240212100000_split_policies.rs +++ b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20240212100000_split_policies.rs @@ -1,5 +1,5 @@ use crate::database::migrations::RustMigration; -use crate::database::{FromSqlxError, ToSqlxType, ToVoid}; +use crate::database::{FromSqlxError, ToVoid}; use ockam_core::{async_trait, Result}; use sqlx::*; @@ -18,7 +18,7 @@ impl RustMigration for SplitPolicies { Self::version() } - async fn migrate(&self, connection: &mut SqliteConnection) -> Result { + async fn migrate(&self, connection: &mut AnyConnection) -> Result { Self::migrate_policies(connection).await } } @@ -34,7 +34,7 @@ impl SplitPolicies { "migration_20240212100000_migrate_policies" } - pub(crate) async fn migrate_policies(connection: &mut SqliteConnection) -> Result { + pub(crate) async fn migrate_policies(connection: &mut AnyConnection) -> Result { let mut transaction = Connection::begin(&mut *connection).await.into_core()?; let query_policies = @@ -46,11 +46,11 @@ impl SplitPolicies { // Copy resource type policies to table "resource_type_policy" for row in rows { if 
row.resource_name == "tcp-outlet" || row.resource_name == "tcp-inlet" { - query("INSERT INTO resource_type_policy (resource_type, action, expression, node_name) VALUES (?, ?, ?, ?)") - .bind(row.resource_name.to_sql()) - .bind(row.action.to_sql()) - .bind(row.expression.to_sql()) - .bind(row.node_name.to_sql()) + query("INSERT INTO resource_type_policy (resource_type, action, expression, node_name) VALUES ($1, $2, $3, $4)") + .bind(row.resource_name) + .bind(row.action) + .bind(row.expression) + .bind(row.node_name) .execute(&mut *transaction) .await .void()?; @@ -82,10 +82,10 @@ struct ResourcePolicyRow { #[cfg(test)] mod test { use crate::database::migrations::node_migration_set::NodeMigrationSet; - use crate::database::{MigrationSet, SqlxDatabase}; + use crate::database::{DatabaseType, MigrationSet, SqlxDatabase}; use ockam_core::compat::rand::random_string; + use sqlx::any::AnyArguments; use sqlx::query::Query; - use sqlx::sqlite::SqliteArguments; use tempfile::NamedTempFile; use super::*; @@ -95,11 +95,11 @@ mod test { // create the database pool and migrate the tables let db_file = NamedTempFile::new().unwrap(); - let pool = SqlxDatabase::create_connection_pool(db_file.path()).await?; + let pool = SqlxDatabase::create_sqlite_connection_pool(db_file.path()).await?; let mut connection = pool.acquire().await.into_core()?; - NodeMigrationSet + NodeMigrationSet::new(DatabaseType::Sqlite) .create_migrator()? .migrate_up_to_skip_last_rust_migration(&pool, SplitPolicies::version()) .await?; @@ -118,7 +118,7 @@ mod test { policy5.execute(&mut *connection).await.void()?; // apply migrations - NodeMigrationSet + NodeMigrationSet::new(DatabaseType::Sqlite) .create_migrator()? 
.migrate_up_to(&pool, SplitPolicies::version()) .await?; @@ -168,14 +168,14 @@ mod test { } /// HELPERS - fn insert_policy(resource: &str) -> Query<'static, Sqlite, SqliteArguments<'static>> { + fn insert_policy(resource: &str) -> Query { let action = "handle_message"; let expression = random_string(); let node_name = random_string(); - query("INSERT INTO resource_policy (resource_name, action, expression, node_name) VALUES (?, ?, ?, ?)") - .bind(resource.to_sql()) - .bind(action.to_sql()) - .bind(expression.to_sql()) - .bind(node_name.to_sql()) + query("INSERT INTO resource_policy (resource_name, action, expression, node_name) VALUES ($1, $2, $3, $4)") + .bind(resource) + .bind(action) + .bind(expression) + .bind(node_name) } } diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20240313100000_remove_orphan_resources.rs b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20240313100000_remove_orphan_resources.rs similarity index 78% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20240313100000_remove_orphan_resources.rs rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20240313100000_remove_orphan_resources.rs index dc04a289c7f..6a491dcbb00 100644 --- a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20240313100000_remove_orphan_resources.rs +++ b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20240313100000_remove_orphan_resources.rs @@ -1,5 +1,5 @@ use crate::database::migrations::RustMigration; -use crate::database::{FromSqlxError, ToSqlxType, ToVoid}; +use crate::database::{FromSqlxError, ToVoid}; use ockam_core::{async_trait, Result}; use sqlx::*; @@ -17,7 +17,7 @@ impl RustMigration for 
RemoveOrphanResources { Self::version() } - async fn migrate(&self, connection: &mut SqliteConnection) -> Result { + async fn migrate(&self, connection: &mut AnyConnection) -> Result { Self::migrate(connection).await } } @@ -33,7 +33,7 @@ impl RemoveOrphanResources { "migration_20240313100000_remove_orphan_resources" } - pub(crate) async fn migrate(connection: &mut SqliteConnection) -> Result { + pub(crate) async fn migrate(connection: &mut AnyConnection) -> Result { let mut transaction = Connection::begin(&mut *connection).await.into_core()?; // Get existing node names @@ -52,10 +52,10 @@ impl RemoveOrphanResources { // Remove resources that are not associated with a node for resource in resources { if !node_names.iter().any(|n| n.name == resource.node_name) { - query("DELETE FROM resource WHERE resource_name = ? AND resource_type = ? AND node_name = ?") - .bind(resource.resource_name.to_sql()) - .bind(resource.resource_type.to_sql()) - .bind(resource.node_name.to_sql()) + query("DELETE FROM resource WHERE resource_name = $1 AND resource_type = $2 AND node_name = $3") + .bind(resource.resource_name) + .bind(resource.resource_type) + .bind(resource.node_name) .execute(&mut *transaction) .await .void()?; @@ -84,10 +84,10 @@ struct ResourceRow { #[cfg(test)] mod test { use crate::database::migrations::node_migration_set::NodeMigrationSet; - use crate::database::{MigrationSet, SqlxDatabase}; + use crate::database::{DatabaseType, MigrationSet, SqlxDatabase}; use ockam_core::compat::rand::random_string; + use sqlx::any::AnyArguments; use sqlx::query::Query; - use sqlx::sqlite::SqliteArguments; use tempfile::NamedTempFile; use super::*; @@ -97,22 +97,22 @@ mod test { // create the database pool and migrate the tables let db_file = NamedTempFile::new().unwrap(); - let pool = SqlxDatabase::create_connection_pool(db_file.path()).await?; + let pool = SqlxDatabase::create_sqlite_connection_pool(db_file.path()).await?; let mut connection = pool.acquire().await.into_core()?; - 
NodeMigrationSet + NodeMigrationSet::new(DatabaseType::Sqlite) .create_migrator()? .migrate_up_to_skip_last_rust_migration(&pool, RemoveOrphanResources::version()) .await?; // insert a node - query("INSERT INTO node (name, identifier, verbosity, is_default, is_authority) VALUES (?, ?, ?, ?, ?)") - .bind("n1".to_sql()) - .bind(random_string().to_sql()) - .bind(0.to_sql()) - .bind(false.to_sql()) - .bind(false.to_sql()) + query("INSERT INTO node (name, identifier, verbosity, is_default, is_authority) VALUES ($1, $2, $3, $4, $5)") + .bind("n1") + .bind(random_string()) + .bind(0) + .bind(false) + .bind(false) .execute(&mut *connection) .await .void()?; @@ -131,7 +131,7 @@ mod test { resource5.execute(&mut *connection).await.void()?; // apply migrations - NodeMigrationSet + NodeMigrationSet::new(DatabaseType::Sqlite) .create_migrator()? .migrate_up_to(&pool, RemoveOrphanResources::version()) .await?; @@ -157,14 +157,14 @@ mod test { Ok(()) } /// HELPERS - fn insert_resource( - resource: &str, - node_name: &str, - ) -> Query<'static, Sqlite, SqliteArguments<'static>> { + fn insert_resource<'a>( + resource: &'a str, + node_name: &'a str, + ) -> Query<'a, Any, AnyArguments<'a>> { let resource_type = random_string(); - query("INSERT INTO resource (resource_name, resource_type, node_name) VALUES (?, ?, ?)") - .bind(resource.to_sql()) - .bind(resource_type.to_sql()) - .bind(node_name.to_sql()) + query("INSERT INTO resource (resource_name, resource_type, node_name) VALUES ($1, $2, $3)") + .bind(resource) + .bind(resource_type) + .bind(node_name) } } diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20240503100000_update_policy_expressions.rs b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20240503100000_update_policy_expressions.rs similarity index 75% rename from 
implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20240503100000_update_policy_expressions.rs rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20240503100000_update_policy_expressions.rs index 0b713b0c7fa..1c61bd2acff 100644 --- a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/migration_20240503100000_update_policy_expressions.rs +++ b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/migration_20240503100000_update_policy_expressions.rs @@ -18,7 +18,7 @@ impl RustMigration for UpdatePolicyExpressions { Self::version() } - async fn migrate(&self, connection: &mut SqliteConnection) -> Result { + async fn migrate(&self, connection: &mut AnyConnection) -> Result { Self::migrate_policy_expressions(connection).await } } @@ -34,9 +34,7 @@ impl UpdatePolicyExpressions { "migration_20240503100000_update_policy_expressions" } - pub(crate) async fn migrate_policy_expressions( - connection: &mut SqliteConnection, - ) -> Result { + pub(crate) async fn migrate_policy_expressions(connection: &mut AnyConnection) -> Result { let mut transaction = Connection::begin(&mut *connection).await.into_core()?; query("UPDATE resource_policy SET expression = '(= subject.has_credential \"true\")' WHERE expression = 'subject.has_credential'").execute(&mut *transaction).await.void()?; query("UPDATE resource_type_policy SET expression = '(= subject.has_credential \"true\")' WHERE expression = 'subject.has_credential'").execute(&mut *transaction).await.void()?; @@ -51,11 +49,10 @@ impl UpdatePolicyExpressions { #[cfg(test)] mod test { use crate::database::migrations::node_migration_set::NodeMigrationSet; - use crate::database::{MigrationSet, SqlxDatabase}; - use crate::storage::database::sqlx_types::ToSqlxType; + use crate::database::{DatabaseType, MigrationSet, SqlxDatabase}; use 
ockam_core::compat::rand::random_string; + use sqlx::any::{AnyArguments, AnyRow}; use sqlx::query::Query; - use sqlx::sqlite::{SqliteArguments, SqliteRow}; use tempfile::NamedTempFile; use super::*; @@ -65,11 +62,11 @@ mod test { // create the database pool and migrate the tables let db_file = NamedTempFile::new().unwrap(); - let pool = SqlxDatabase::create_connection_pool(db_file.path()).await?; + let pool = SqlxDatabase::create_sqlite_connection_pool(db_file.path()).await?; let mut connection = pool.acquire().await.into_core()?; - NodeMigrationSet + NodeMigrationSet::new(DatabaseType::Sqlite) .create_migrator()? .migrate_up_to_skip_last_rust_migration(&pool, UpdatePolicyExpressions::version()) .await?; @@ -86,13 +83,13 @@ mod test { policy4.execute(&mut *connection).await.void()?; // apply migrations - NodeMigrationSet + NodeMigrationSet::new(DatabaseType::Sqlite) .create_migrator()? .migrate_up_to(&pool, UpdatePolicyExpressions::version()) .await?; // check that the update was successful for resource policies - let rows: Vec = query("SELECT expression FROM resource_policy") + let rows: Vec = query("SELECT expression FROM resource_policy") .fetch_all(&mut *connection) .await .into_core()?; @@ -104,7 +101,7 @@ mod test { })); // check that the update was successful for resource type policies - let rows: Vec = query("SELECT expression FROM resource_type_policy") + let rows: Vec = query("SELECT expression FROM resource_type_policy") .fetch_all(&mut *connection) .await .into_core()?; @@ -127,27 +124,25 @@ mod test { } /// HELPERS - fn insert_resource_policy(resource: &str) -> Query<'static, Sqlite, SqliteArguments<'static>> { + fn insert_resource_policy(resource: &str) -> Query<'_, Any, AnyArguments<'_>> { let action = "handle_message"; let expression = "subject.has_credential"; let node_name = random_string(); - query("INSERT INTO resource_policy (resource_name, action, expression, node_name) VALUES (?, ?, ?, ?)") - .bind(resource.to_sql()) - .bind(action.to_sql()) 
- .bind(expression.to_sql()) - .bind(node_name.to_sql()) + query("INSERT INTO resource_policy (resource_name, action, expression, node_name) VALUES ($1, $2, $3, $4)") + .bind(resource) + .bind(action) + .bind(expression) + .bind(node_name) } - fn insert_resource_type_policy( - resource: &str, - ) -> Query<'static, Sqlite, SqliteArguments<'static>> { + fn insert_resource_type_policy(resource: &str) -> Query<'_, Any, AnyArguments<'_>> { let action = "handle_message"; let expression = "subject.has_credential"; let node_name = random_string(); - query("INSERT INTO resource_type_policy (resource_type, action, expression, node_name) VALUES (?, ?, ?, ?)") - .bind(resource.to_sql()) - .bind(action.to_sql()) - .bind(expression.to_sql()) - .bind(node_name.to_sql()) + query("INSERT INTO resource_type_policy (resource_type, action, expression, node_name) VALUES ($1, $2, $3, $4)") + .bind(resource) + .bind(action) + .bind(expression) + .bind(node_name) } } diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/mod.rs b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/mod.rs new file mode 100644 index 00000000000..5a137b11843 --- /dev/null +++ b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/rust/sqlite/mod.rs @@ -0,0 +1,14 @@ +/// This migration adds a node name column to the identity attributes table +pub mod migration_20231231100000_node_name_identity_attributes; +/// This migration moves attributes from identity_attributes to the authority_member table for authority nodes +pub mod migration_20240111100001_add_authority_tables; +/// This migration updates policies to not rely on trust_context_id, +/// also introduces `node_name` and replicates policy for each existing node +pub mod migration_20240111100002_delete_trust_context; +/// This migration moves policies attached to resource types from +/// table "resource_policy" to 
"resource_type_policy" +pub mod migration_20240212100000_split_policies; +/// This migration removes orphan resources +pub mod migration_20240313100000_remove_orphan_resources; +/// This migration updates the policy expressions so that they start with an operator +pub mod migration_20240503100000_update_policy_expressions; diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/postgres/20240613100000_create_database.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/postgres/20240613100000_create_database.sql new file mode 100644 index 00000000000..dbbdced58cd --- /dev/null +++ b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/postgres/20240613100000_create_database.sql @@ -0,0 +1,359 @@ +-------------- +-- MIGRATIONS +-------------- + +-- Create a table to support rust migrations +CREATE TABLE IF NOT EXISTS _rust_migrations +( + name TEXT NOT NULL, + run_on TIMESTAMP NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS name_index ON _rust_migrations (name); + +-------------- +-- IDENTITIES +-------------- + +-- This table stores identities with +-- - the identity identifier (as a hex-encoded string) +-- - the encoded history of all the key rotations for this identity +CREATE TABLE identity +( + identifier TEXT NOT NULL UNIQUE, + change_history TEXT NOT NULL +); + +-- Insert the controller identity +INSERT INTO identity VALUES ('I84502ce0d9a0a91bae29026b84e19be69fb4203a6bdd1424c85a43c812772a00', '81825858830101585385f6820181584104ebf9d78281a04f180029c12a74e994386c7c9fee24903f3bfe351497a9952758ee5f4b57d7ed6236ab5082ed85e1ae8c07d5600e0587f652d36727904b3e310df41a656a365d1a7836395d820181584050bf79071ecaf08a966228c712295a17da53994dc781a22103602afe656276ef83ba83a1004845b1e979e0944abff3cd8c7ceef834a8f5eeeca0e8f720fa38f4'); + +-- This table some local metadata about identities +CREATE TABLE named_identity +( + identifier TEXT NOT NULL UNIQUE, 
-- Identity identifier + name TEXT UNIQUE, -- user-specified name + vault_name TEXT NOT NULL, -- name of the vault used to store the identity keys + is_default BOOLEAN DEFAULT FALSE -- boolean indicating if this identity is the default one +); + +-- This table stores the time when a given identity was enrolled +-- in the current project +CREATE TABLE identity_enrollment +( + identifier TEXT NOT NULL UNIQUE, -- Identifier of the identity + enrolled_at INTEGER NOT NULL, -- UNIX timestamp in seconds + email TEXT -- Enrollment email +); + +-- This table lists attributes associated to a given identity +CREATE TABLE identity_attributes +( + identifier TEXT PRIMARY KEY, -- identity possessing those attributes + attributes BYTEA NOT NULL, -- serialized list of attribute names and values for the identity + added INTEGER NOT NULL, -- UNIX timestamp in seconds: when those attributes were inserted in the database + expires INTEGER, -- optional UNIX timestamp in seconds: when those attributes expire + attested_by TEXT, -- optional identifier which attested to these attributes + node_name TEXT NOT NULL -- node name to isolate attributes that each node knows +); + +CREATE UNIQUE INDEX identity_attributes_index ON identity_attributes (identifier, node_name); + +CREATE INDEX identity_attributes_identifier_attested_by_node_name_index ON identity_attributes (identifier, attested_by, node_name); + +CREATE INDEX identity_attributes_expires_node_name_index ON identity_attributes (expires, node_name); + +CREATE INDEX identity_identifier_index ON identity_attributes (identifier); + +CREATE INDEX identity_node_name_index ON identity_attributes (node_name); + + +-- This table stores purpose keys that have been created by a given identity +CREATE TABLE purpose_key +( + identifier TEXT NOT NULL, -- Identity identifier + purpose TEXT NOT NULL, -- Purpose of the key: SecureChannels, or Credentials + purpose_key_attestation BYTEA NOT NULL -- Encoded attestation: attestation data and attestation 
signature +); + +CREATE UNIQUE INDEX purpose_key_index ON purpose_key (identifier, purpose); + +---------- +-- VAULTS +---------- + +-- This table stores vault metadata when several vaults have been created locally +CREATE TABLE vault +( + name TEXT PRIMARY KEY, -- User-specified name for a vault + path TEXT NULL, -- Path where the vault is saved. This path can be the current database path. In that case the vault data is stored in the *-secrets table below + is_default BOOLEAN, -- boolean indicating if this vault is the default one + is_kms BOOLEAN -- boolean indicating if this vault is a KMS one. In that case only key handles are stored in the database +); + +-- This table stores secrets for signing data +CREATE TABLE signing_secret +( + handle BYTEA PRIMARY KEY, -- Secret handle + secret_type TEXT NOT NULL, -- Secret type (EdDSACurve25519 or ECDSASHA256CurveP256) + secret BYTEA NOT NULL -- Secret binary +); + +-- This table stores secrets for encrypting / decrypting data +CREATE TABLE x25519_secret +( + handle BYTEA PRIMARY KEY, -- Secret handle + secret BYTEA NOT NULL -- Secret binary +); + + +--------------- +-- CREDENTIALS +--------------- + +-- This table stores credentials as received by the application +CREATE TABLE credential +( + subject_identifier TEXT NOT NULL, + issuer_identifier TEXT NOT NULL, + scope TEXT NOT NULL, + credential BYTEA NOT NULL, + expires_at INTEGER, + node_name TEXT NOT NULL -- node name to isolate credential that each node has +); + +CREATE UNIQUE INDEX credential_issuer_subject_scope_index ON credential (issuer_identifier, subject_identifier, scope); +CREATE UNIQUE INDEX credential_issuer_subject_index ON credential(issuer_identifier, subject_identifier); + +------------------ +-- AUTHORITY +------------------ + +CREATE TABLE authority_member +( + identifier TEXT NOT NULL UNIQUE, + added_by TEXT NOT NULL, + added_at INTEGER NOT NULL, + is_pre_trusted BOOLEAN NOT NULL, + attributes BYTEA +); + +CREATE 
UNIQUE INDEX authority_member_identifier_index ON authority_member(identifier); +CREATE INDEX authority_member_is_pre_trusted_index ON authority_member(is_pre_trusted); + +-- Reference is a random string that uniquely identifies an enrollment token. However, unlike the one_time_code, +-- it's not sensitive so can be logged and used to track a lifecycle of a specific enrollment token. +CREATE TABLE authority_enrollment_token +( + one_time_code TEXT NOT NULL UNIQUE, + issued_by TEXT NOT NULL, + created_at INTEGER NOT NULL, + expires_at INTEGER NOT NULL, + ttl_count INTEGER NOT NULL, + attributes BYTEA, + reference TEXT +); + +CREATE UNIQUE INDEX authority_enrollment_token_one_time_code_index ON authority_enrollment_token(one_time_code); +CREATE INDEX authority_enrollment_token_expires_at_index ON authority_enrollment_token(expires_at); + +-- This table stores policies. A policy is an expression which +-- can be evaluated against an environment (a list of name/value pairs) +-- to assess if a given action can be performed on a given resource +CREATE TABLE resource_policy +( + resource_name TEXT NOT NULL, -- resource name + action TEXT NOT NULL, -- action name + expression TEXT NOT NULL, -- encoded expression to evaluate + node_name TEXT NOT NULL -- node name +); + +CREATE UNIQUE INDEX resource_policy_index ON resource_policy (node_name, resource_name, action); + +-- Create a new table for resource type policies +CREATE TABLE resource_type_policy +( + resource_type TEXT NOT NULL, -- resource type + action TEXT NOT NULL, -- action name + expression TEXT NOT NULL, -- encoded expression to evaluate + node_name TEXT NOT NULL -- node name +); +CREATE UNIQUE INDEX resource_type_policy_index ON resource_type_policy (node_name, resource_type, action); + +-- Create a new table for resource to resource type mapping +CREATE TABLE resource +( + resource_name TEXT NOT NULL, -- resource name + resource_type TEXT NOT NULL, -- resource type + node_name TEXT NOT NULL -- node name +); 
+CREATE UNIQUE INDEX resource_index ON resource (node_name, resource_name, resource_type); + + +--------- +-- NODES +--------- + +-- This table stores information about local nodes +CREATE TABLE node +( + name TEXT PRIMARY KEY, -- Node name + identifier TEXT NOT NULL, -- Identifier of the default identity associated to the node + verbosity INTEGER NOT NULL, -- Verbosity level used for logging + is_default BOOLEAN NOT NULL, -- boolean indicating if this node is the default one + is_authority BOOLEAN NOT NULL, -- boolean indicating if this node is an authority node. This boolean is used to be able to show an authority node as UP even if its TCP listener cannot be accessed. + tcp_listener_address TEXT, -- Socket address for the node default TCP Listener (can be NULL if the node has not been started) + pid INTEGER, -- Current process id of the node if it has been started + http_server_address TEXT -- Address of the server supporting the HTTP status endpoint for the node +); + +-- This table stores the project name to use for a given node +CREATE TABLE node_project +( + node_name TEXT PRIMARY KEY, -- Node name + project_name TEXT NOT NULL -- Project name +); + +--------------------------- +-- PROJECTS, SPACES, USERS +--------------------------- + +-- This table stores data about projects as returned by the Controller +CREATE TABLE project +( + project_id TEXT PRIMARY KEY, -- Identifier of the project + project_name TEXT NOT NULL, -- Name of the project + is_default BOOLEAN NOT NULL, -- boolean indicating if this project is the default one + space_id TEXT NOT NULL, -- Identifier of the space associated to the project + space_name TEXT NOT NULL, -- Name of the space associated to the project + project_identifier TEXT, -- optional: identifier of the project identity + access_route TEXT NOT NULL, -- Route used to create a secure channel to the project + authority_change_history TEXT, -- Change history for the authority identity + 
authority_access_route TEXT, -- Route to the authority associated to the project + version TEXT, -- Orchestrator software version + running BOOLEAN, -- boolean indicating if this project is currently accessible + operation_id TEXT, -- optional id of the operation currently creating the project on the Controller side + project_change_history TEXT -- Change history for the project identity +); + +-- This table provides the list of users associated to a given project +CREATE TABLE user_project +( + user_email TEXT NOT NULL, -- User email + project_id TEXT NOT NULL -- Project id +); + +-- This table provides additional information for users associated to a project or a space +CREATE TABLE user_role +( + user_id INTEGER NOT NULL, -- User id + project_id TEXT NOT NULL, -- Project id + user_email TEXT NOT NULL, -- User email + role TEXT NOT NULL, -- Role of the user: admin or member + scope TEXT NOT NULL -- Scope of the role: space, project, or service +); + +-- This table stores data about spaces as returned by the controller +CREATE TABLE space +( + space_id TEXT PRIMARY KEY, -- Identifier of the space + space_name TEXT NOT NULL, -- Name of the space + is_default BOOLEAN NOT NULL -- boolean indicating if this space is the default one +); + +-- This table provides the list of users associated to a given space +CREATE TABLE user_space +( + user_email TEXT NOT NULL, -- User email + space_id TEXT NOT NULL -- Space id +); + +-- This table provides additional information for users after they have been authenticated +CREATE TABLE "user" +( + email TEXT PRIMARY KEY, -- User email + sub TEXT NOT NULL, -- (Sub)ject: unique identifier for the user + nickname TEXT NOT NULL, -- User nickname (or handle) + name TEXT NOT NULL, -- User name + picture TEXT NOT NULL, -- Link to a user picture + updated_at TEXT NOT NULL, -- ISO-8601 date: when this user information was last updated + email_verified BOOLEAN NOT NULL, -- boolean indicating if the user email has been 
verified (0 means true) + is_default BOOLEAN NOT NULL -- boolean indicating if this user is the default user locally (0 means true) +); + +------------------- +-- SECURE CHANNELS +------------------- + +-- This table stores secure channels in order to restore them on a restart +CREATE TABLE secure_channel +( + role TEXT NOT NULL, + my_identifier TEXT NOT NULL, + their_identifier TEXT NOT NULL, + decryptor_remote_address TEXT PRIMARY KEY, + decryptor_api_address TEXT NOT NULL, + decryption_key_handle BYTEA NOT NULL + -- TODO: Add date? +); + +CREATE UNIQUE INDEX secure_channel_decryptor_api_address_index ON secure_channel(decryptor_remote_address); + +-- This table stores aead secrets +CREATE TABLE aead_secret +( + handle BYTEA PRIMARY KEY, -- Secret handle + type TEXT NOT NULL, -- Secret type + secret BYTEA NOT NULL -- Secret binary +); + +--------------- +-- APPLICATION +--------------- + +-- This table stores the current state of an outlet created to expose a service with the desktop application +CREATE TABLE tcp_outlet_status +( + node_name TEXT NOT NULL, -- Node where that tcp outlet has been created + socket_addr TEXT NOT NULL, -- Socket address that the outlet connects to + worker_addr TEXT NOT NULL, -- Worker address for the outlet itself + payload TEXT -- Optional status payload +); + +-- This table stores the current state of an inlet created to expose a service with the desktop application +CREATE TABLE tcp_inlet +( + node_name TEXT NOT NULL, -- Node where that tcp inlet has been created + bind_addr TEXT NOT NULL, -- Input address to connect to + outlet_addr TEXT NOT NULL, -- MultiAddress to the outlet + alias TEXT NOT NULL -- Alias for that inlet +); + +-- This table stores the list of services that a user has been invited to connect to +-- via the desktop application +CREATE TABLE incoming_service +( + invitation_id TEXT PRIMARY KEY, -- Invitation id + enabled BOOLEAN NOT NULL, -- boolean indicating if the user wants to service to be accessible (0 means 
true) + name TEXT NULL -- Optional user-defined name for the service +); + +---------- +-- ADDONS +---------- + +-- This table stores the data necessary to configure the Okta addon +CREATE TABLE okta_config +( + project_id TEXT NOT NULL, -- Project id of the project using the addon + tenant_base_url TEXT NOT NULL, -- Base URL of the tenant + client_id TEXT NOT NULL, -- Client id + certificate TEXT NOT NULL, -- Certificate + attributes TEXT -- Comma-separated list of attribute names +); + +-- This table stores the data necessary to configure the Kafka addons +CREATE TABLE kafka_config +( + project_id TEXT NOT NULL, -- Project id of the project using the addon + bootstrap_server TEXT NOT NULL -- URL of the bootstrap server +); diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20231006100000_create_database.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20231006100000_create_database.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20231006100000_create_database.sql rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20231006100000_create_database.sql diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20231230100000_add_rust_migrations.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20231230100000_add_rust_migrations.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20231230100000_add_rust_migrations.sql rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20231230100000_add_rust_migrations.sql diff --git 
a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20231231100000_node_name_identity_attributes.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20231231100000_node_name_identity_attributes.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20231231100000_node_name_identity_attributes.sql rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20231231100000_node_name_identity_attributes.sql diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240108100000_rename_confluent_config_table.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240108100000_rename_confluent_config_table.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240108100000_rename_confluent_config_table.sql rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240108100000_rename_confluent_config_table.sql diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240111100001_add_authority_tables.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240111100001_add_authority_tables.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240111100001_add_authority_tables.sql rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240111100001_add_authority_tables.sql diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240111100002_delete_trust_context.sql 
b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240111100002_delete_trust_context.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240111100002_delete_trust_context.sql rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240111100002_delete_trust_context.sql diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240111100003_add_credential.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240111100003_add_credential.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240111100003_add_credential.sql rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240111100003_add_credential.sql diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240212100000_split_policies.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240212100000_split_policies.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240212100000_split_policies.sql rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240212100000_split_policies.sql diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240212100001_outlet_remove_alias.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240212100001_outlet_remove_alias.sql similarity index 100% rename from 
implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240212100001_outlet_remove_alias.sql rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240212100001_outlet_remove_alias.sql diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240213100000_add_controller_history.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240213100000_add_controller_history.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240213100000_add_controller_history.sql rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240213100000_add_controller_history.sql diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240213100001_add_enrollment_token_reference.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240213100001_add_enrollment_token_reference.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240213100001_add_enrollment_token_reference.sql rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240213100001_add_enrollment_token_reference.sql diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240214100000_extend_project.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240214100000_extend_project.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240214100000_extend_project.sql rename to 
implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240214100000_extend_project.sql diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240307100000_credential_add_scope.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240307100000_credential_add_scope.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240307100000_credential_add_scope.sql rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240307100000_credential_add_scope.sql diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240314150000_tcp_portals.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240314150000_tcp_portals.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240314150000_tcp_portals.sql rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240314150000_tcp_portals.sql diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240321100000_add_enrollment_email.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240321100000_add_enrollment_email.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240321100000_add_enrollment_email.sql rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240321100000_add_enrollment_email.sql diff --git 
a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240507100000_node_add_http_column.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240507100000_node_add_http_column.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240507100000_node_add_http_column.sql rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240507100000_node_add_http_column.sql diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240527100000_add_sc_persistence.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240527100000_add_sc_persistence.sql similarity index 100% rename from implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/20240527100000_add_sc_persistence.sql rename to implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240527100000_add_sc_persistence.sql diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240619100000_database_vault.sql b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240619100000_database_vault.sql new file mode 100644 index 00000000000..391d25278bc --- /dev/null +++ b/implementations/rust/ockam/ockam_node/src/storage/database/migrations/node_migrations/sql/sqlite/20240619100000_database_vault.sql @@ -0,0 +1,29 @@ +-- This migration allows the path column to be NULL +-- When the path is NULL then the vault 'name' has all its keys stored in the current database +CREATE TABLE new_vault +( + name TEXT PRIMARY KEY, -- User-specified name for a vault + path TEXT NULL, -- Path where the vault is saved + is_default INTEGER, -- boolean indicating if 
this vault is the default one (1 means true) + is_kms INTEGER -- boolean indicating if this signing keys are stored in an AWS KMS (1 means true) +); + +INSERT INTO new_vault (name, path, is_default, is_kms) +SELECT + name, + -- set the path to NULL when the vault is stored in the database + CASE + WHEN path LIKE '%database.sqlite3%' THEN NULL + ELSE path + END as path, + -- fix the setting of the is_default flag which could occur more than once + CASE + WHEN name = 'default' THEN 1 + ELSE 0 + END as is_default, + is_kms +FROM vault; + +DROP TABLE vault; + +ALTER TABLE new_vault RENAME TO vault; diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/mod.rs b/implementations/rust/ockam/ockam_node/src/storage/database/mod.rs index 5b29c8deb90..aa19faf6447 100644 --- a/implementations/rust/ockam/ockam_node/src/storage/database/mod.rs +++ b/implementations/rust/ockam/ockam_node/src/storage/database/mod.rs @@ -1,7 +1,9 @@ +mod database_configuration; mod migrations; mod sqlx_database; -mod sqlx_types; +mod sqlx_from_row_types; +pub use database_configuration::*; pub use migrations::*; pub use sqlx_database::*; -pub use sqlx_types::*; +pub use sqlx_from_row_types::*; diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/sqlx_database.rs b/implementations/rust/ockam/ockam_node/src/storage/database/sqlx_database.rs index 1ffc247c7b7..fe4f8390521 100644 --- a/implementations/rust/ockam/ockam_node/src/storage/database/sqlx_database.rs +++ b/implementations/rust/ockam/ockam_node/src/storage/database/sqlx_database.rs @@ -1,20 +1,26 @@ use core::fmt::{Debug, Formatter}; -use sqlx::pool::PoolOptions; -use sqlx::sqlite::SqliteConnectOptions; +use core::str::FromStr; +use std::future::Future; use std::ops::Deref; use std::path::{Path, PathBuf}; use std::time::Duration; use ockam_core::errcode::{Kind, Origin}; -use sqlx::{ConnectOptions, SqlitePool}; +use sqlx::any::{install_default_drivers, AnyConnectOptions}; +use sqlx::pool::PoolOptions; +use 
sqlx::{Any, ConnectOptions, Pool}; +use tempfile::NamedTempFile; use tokio_retry::strategy::{jitter, FixedInterval}; use tokio_retry::Retry; use tracing::debug; use tracing::log::LevelFilter; +use crate::database::database_configuration::DatabaseConfiguration; use crate::database::migrations::application_migration_set::ApplicationMigrationSet; use crate::database::migrations::node_migration_set::NodeMigrationSet; use crate::database::migrations::MigrationSet; +use crate::database::DatabaseType; +use ockam_core::compat::rand::random_string; use ockam_core::compat::sync::Arc; use ockam_core::{Error, Result}; @@ -27,9 +33,8 @@ use ockam_core::{Error, Result}; #[derive(Clone)] pub struct SqlxDatabase { /// Pool of connections to the database - pub pool: Arc, - - path: Option, + pub pool: Arc>, + configuration: DatabaseConfiguration, } impl Debug for SqlxDatabase { @@ -39,7 +44,7 @@ impl Debug for SqlxDatabase { } impl Deref for SqlxDatabase { - type Target = SqlitePool; + type Target = Pool; fn deref(&self) -> &Self::Target { &self.pool @@ -47,33 +52,85 @@ impl Deref for SqlxDatabase { } impl SqlxDatabase { - /// Constructor for a database persisted on disk - pub async fn create(path: impl AsRef) -> Result { - Self::create_impl(path, Some(NodeMigrationSet)).await + /// Constructor for a database + pub async fn create(configuration: &DatabaseConfiguration) -> Result { + Self::create_impl( + configuration, + Some(NodeMigrationSet::new(configuration.database_type())), + ) + .await + } + + /// Constructor for an application database + pub async fn create_application_database( + configuration: &DatabaseConfiguration, + ) -> Result { + Self::create_impl( + configuration, + Some(ApplicationMigrationSet::new(configuration.database_type())), + ) + .await + } + + /// Constructor for a sqlite database + pub async fn create_sqlite(path: &Path) -> Result { + Self::create(&DatabaseConfiguration::sqlite(path)).await + } + + /// Constructor for a sqlite application database + pub 
async fn create_application_sqlite(path: &Path) -> Result { + Self::create_application_database(&DatabaseConfiguration::sqlite(path)).await + } + + /// Constructor for a local postgres database with no data + pub async fn create_new_postgres() -> Result { + match DatabaseConfiguration::postgres()? { + Some(configuration) => { + let db = Self::create_no_migration(&configuration).await?; + db.drop_all_postgres_tables().await?; + SqlxDatabase::create(&configuration).await + }, + None => Err(Error::new(Origin::Core, Kind::NotFound, "There is no postgres database configuration, or it is incomplete. Please run ockam environment to check the database environment variables".to_string())), + } + } + + /// Constructor for a local application postgres database with no data + pub async fn create_new_application_postgres() -> Result { + match DatabaseConfiguration::postgres()? { + Some(configuration) => { + let db = Self::create_application_no_migration(&configuration).await?; + db.drop_all_postgres_tables().await?; + SqlxDatabase::create_application_database(&configuration).await + }, + None => Err(Error::new(Origin::Core, Kind::NotFound, "There is no postgres database configuration, or it is incomplete. 
Please run ockam environment to check the database environment variables".to_string())), + } } /// Constructor for a database persisted on disk, with a specific schema / migration pub async fn create_with_migration( - path: impl AsRef, + configuration: &DatabaseConfiguration, migration_set: impl MigrationSet, ) -> Result { - Self::create_impl(path, Some(migration_set)).await + Self::create_impl(configuration, Some(migration_set)).await } /// Constructor for a database persisted on disk without migration - pub async fn create_no_migration(path: impl AsRef) -> Result { - Self::create_impl(path, None::).await + pub async fn create_no_migration(configuration: &DatabaseConfiguration) -> Result { + Self::create_impl(configuration, None::).await + } + + /// Constructor for an application database persisted on disk without migration + pub async fn create_application_no_migration( + configuration: &DatabaseConfiguration, + ) -> Result { + Self::create_impl(configuration, None::).await } async fn create_impl( - path: impl AsRef, + configuration: &DatabaseConfiguration, migration_set: Option, ) -> Result { - path.as_ref() - .parent() - .map(std::fs::create_dir_all) - .transpose() - .map_err(|e| Error::new(Origin::Api, Kind::Io, e.to_string()))?; + configuration.create_directory_if_necessary()?; // creating a new database might be failing a few times // if the files are currently being held by another pod which is shutting down. @@ -83,7 +140,13 @@ impl SqlxDatabase { .take(10); // limit to 10 retries let db = Retry::spawn(retry_strategy, || async { - Self::create_at(path.as_ref()).await + match Self::create_at(configuration).await { + Ok(db) => Ok(db), + Err(e) => { + println!("{e:?}"); + Err(e) + } + } }) .await?; @@ -98,14 +161,15 @@ impl SqlxDatabase { /// Create a nodes database in memory /// => this database is deleted on an `ockam reset` command! 
(contrary to the application database below) pub async fn in_memory(usage: &str) -> Result { - Self::in_memory_with_migration(usage, NodeMigrationSet).await + Self::in_memory_with_migration(usage, NodeMigrationSet::new(DatabaseType::Sqlite)).await } /// Create an application database in memory /// The application database which contains the application configurations /// => this database is NOT deleted on an `ockam reset` command! pub async fn application_in_memory(usage: &str) -> Result { - Self::in_memory_with_migration(usage, ApplicationMigrationSet).await + Self::in_memory_with_migration(usage, ApplicationMigrationSet::new(DatabaseType::Sqlite)) + .await } /// Create an in-memory database with a specific migration @@ -114,54 +178,65 @@ impl SqlxDatabase { migration_set: impl MigrationSet, ) -> Result { debug!("create an in memory database for {usage}"); + let configuration = DatabaseConfiguration::sqlite_in_memory(); let pool = Self::create_in_memory_connection_pool().await?; let migrator = migration_set.create_migrator()?; migrator.migrate(&pool).await?; // FIXME: We should be careful if we run multiple nodes in one process let db = SqlxDatabase { pool: Arc::new(pool), - path: None, + configuration, }; Ok(db) } - async fn create_at(path: &Path) -> Result { - let path = path.to_path_buf(); + async fn create_at(configuration: &DatabaseConfiguration) -> Result { // Creates database file if it doesn't exist - let pool = Self::create_connection_pool(path.as_path()).await?; + let pool = Self::create_connection_pool(configuration).await?; Ok(SqlxDatabase { pool: Arc::new(pool), - path: Some(path), + configuration: configuration.clone(), }) } - pub(crate) async fn create_connection_pool(path: &Path) -> Result { - let options = SqliteConnectOptions::new() - .filename(path) - .create_if_missing(true) + pub(crate) async fn create_connection_pool( + configuration: &DatabaseConfiguration, + ) -> Result> { + install_default_drivers(); + let connection_string = 
configuration.connection_string(); + debug!("connecting to {connection_string}"); + let options = AnyConnectOptions::from_str(&connection_string) + .map_err(Self::map_sql_err)? .log_statements(LevelFilter::Trace) .log_slow_statements(LevelFilter::Trace, Duration::from_secs(1)); - let pool = SqlitePool::connect_with(options) + let pool = Pool::connect_with(options) .await .map_err(Self::map_sql_err)?; Ok(pool) } - pub(crate) async fn create_in_memory_connection_pool() -> Result { + /// Create a connection for a SQLite database + pub async fn create_sqlite_connection_pool(path: &Path) -> Result> { + Self::create_connection_pool(&DatabaseConfiguration::sqlite(path)).await + } + + pub(crate) async fn create_in_memory_connection_pool() -> Result> { + install_default_drivers(); // SQLite in-memory DB get wiped if there is no connection to it. // The below setting tries to ensure there is always an open connection let pool_options = PoolOptions::new().idle_timeout(None).max_lifetime(None); + let file_name = random_string(); let pool = pool_options - .connect("sqlite::memory:") + .connect(format!("sqlite:file:{file_name}?mode=memory&cache=shared").as_str()) .await .map_err(Self::map_sql_err)?; Ok(pool) } - /// Path to the db file - pub fn path(&self) -> Option<&Path> { - self.path.as_deref() + /// Path to the db file if there is one + pub fn path(&self) -> Option { + self.configuration.path() } /// Map a sqlx error into an ockam error @@ -175,6 +250,121 @@ impl SqlxDatabase { pub fn map_decode_err(err: minicbor::decode::Error) -> Error { Error::new(Origin::Application, Kind::Io, err) } + + /// Drop all the postgres database tables + pub async fn drop_all_postgres_tables(&self) -> Result<()> { + self.clean_postgres_node_tables(Clean::Drop, None).await + } + + /// Truncate all the postgres database tables + pub async fn truncate_all_postgres_tables(&self) -> Result<()> { + self.clean_postgres_node_tables(Clean::Truncate, None).await + } + + /// Drop all the database tables 
_except_ for the journey tables + pub async fn drop_postgres_node_tables(&self) -> Result<()> { + self.clean_postgres_node_tables(Clean::Drop, Some("AND tablename NOT LIKE '%journey%'")) + .await + } + + /// Truncate all the database tables _except_ for the journey tables + pub async fn truncate_postgres_node_tables(&self) -> Result<()> { + self.clean_postgres_node_tables(Clean::Truncate, Some("AND tablename NOT LIKE '%journey%'")) + .await + } + + /// Drop or truncate all the postgres tables, except those excluded by the optional filter + async fn clean_postgres_node_tables(&self, clean: Clean, filter: Option<&str>) -> Result<()> { + match self.configuration.database_type() { + DatabaseType::Sqlite => Ok(()), + DatabaseType::Postgres => { + sqlx::query( + format!(r#"DO $$ + DECLARE + r RECORD; + BEGIN + FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public' {}) LOOP + EXECUTE '{} TABLE ' || quote_ident(r.tablename) || ' CASCADE'; + END LOOP; + END $$;"#, filter.unwrap_or(""), clean.as_str(), + ).as_str()) + .execute(&*self.pool) + .await + .void() + } + } + } +} + +enum Clean { + Drop, + Truncate, +} + +impl Clean { + fn as_str(&self) -> &str { + match self { + Clean::Drop => "DROP", + Clean::Truncate => "TRUNCATE", + } + } +} + +/// This function can be used to run some test code with the 3 different database implementations +pub async fn with_dbs(f: F) -> Result<()> +where + F: Fn(SqlxDatabase) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, +{ + let db = SqlxDatabase::in_memory("test").await?; + rethrow("SQLite in memory", f(db)).await?; + + let db_file = NamedTempFile::new().unwrap(); + let db = SqlxDatabase::create_sqlite(db_file.path()).await?; + rethrow("SQLite on disk", f(db)).await?; + + // only run the postgres tests if the OCKAM_POSTGRES_* environment variables are set + if let Ok(db) = SqlxDatabase::create_new_postgres().await { + rethrow("Postgres local", f(db.clone())).await?; + db.drop_all_postgres_tables().await?; + }; + Ok(()) +} + +/// 
This function can be used to run some test code with the 3 different databases implementations +/// of the application database +pub async fn with_application_dbs(f: F) -> Result<()> +where + F: Fn(SqlxDatabase) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, +{ + let db = SqlxDatabase::application_in_memory("test").await?; + rethrow("SQLite in memory", f(db)).await?; + + let db_file = NamedTempFile::new().unwrap(); + let db = SqlxDatabase::create_application_sqlite(db_file.path()).await?; + rethrow("SQLite on disk", f(db)).await?; + + // only run the postgres tests if the OCKAM_POSTGRES_* environment variables are set + if let Ok(db) = SqlxDatabase::create_new_application_postgres().await { + rethrow("Postgres local", f(db.clone())).await?; + db.drop_all_postgres_tables().await?; + } + Ok(()) +} + +/// Specify which database was used to run a test +async fn rethrow(database_type: &str, f: Fut) -> Result<()> +where + Fut: Future> + Send + 'static, +{ + f.await.map_err(|e| { + Error::new( + Origin::Core, + Kind::Invalid, + format!("{database_type}: {e:?}"), + ) + }) } /// This trait provides some syntax for transforming sqlx errors into ockam errors @@ -223,22 +413,28 @@ impl ToVoid for core::result::Result { } } -#[cfg(test)] -mod tests { - use sqlx::sqlite::SqliteQueryResult; - use sqlx::FromRow; - use tempfile::NamedTempFile; - - use crate::database::ToSqlxType; +/// Create a temporary database file that won't be cleaned-up automatically +pub fn create_temp_db_file() -> Result { + let (_, path) = NamedTempFile::new() + .map_err(|e| Error::new(Origin::Core, Kind::Io, format!("{e:?}")))? 
+ .keep() + .map_err(|e| Error::new(Origin::Core, Kind::Io, format!("{e:?}")))?; + Ok(path) +} +#[cfg(test)] +pub mod tests { use super::*; + use crate::database::Boolean; + use sqlx::any::AnyQueryResult; + use sqlx::FromRow; /// This is a sanity check to test that the database can be created with a file path /// and that migrations are running ok, at least for one table #[tokio::test] - async fn test_create_identity_table() -> Result<()> { + async fn test_create_sqlite_database() -> Result<()> { let db_file = NamedTempFile::new().unwrap(); - let db = SqlxDatabase::create(db_file.path()).await?; + let db = SqlxDatabase::create_sqlite(db_file.path()).await?; let inserted = insert_identity(&db).await.unwrap(); @@ -246,31 +442,51 @@ mod tests { Ok(()) } + /// This is a sanity check to test that we can use Postgres as a database + #[tokio::test] + async fn test_create_postgres_database() -> Result<()> { + if let Some(configuration) = DatabaseConfiguration::postgres()? { + let db = SqlxDatabase::create_no_migration(&configuration).await?; + db.drop_all_postgres_tables().await?; + + let db = SqlxDatabase::create(&configuration).await?; + let inserted = insert_identity(&db).await.unwrap(); + assert_eq!(inserted.rows_affected(), 1); + } + Ok(()) + } + /// This test checks that we can run a query and return an entity #[tokio::test] async fn test_query() -> Result<()> { let db_file = NamedTempFile::new().unwrap(); - let db = SqlxDatabase::create(db_file.path()).await?; + let db_file = db_file.path(); + let db = SqlxDatabase::create_sqlite(db_file).await?; insert_identity(&db).await.unwrap(); // successful query let result: Option = - sqlx::query_as("SELECT identifier FROM identity WHERE identifier=?1") + sqlx::query_as("SELECT identifier, name, vault_name, is_default FROM named_identity WHERE identifier = $1") .bind("Ifa804b7fca12a19eed206ae180b5b576860ae651") .fetch_optional(&*db.pool) .await .unwrap(); assert_eq!( result, - Some(IdentifierRow( - 
"Ifa804b7fca12a19eed206ae180b5b576860ae651".into() - )) + Some(IdentifierRow { + identifier: "Ifa804b7fca12a19eed206ae180b5b576860ae651".into(), + name: "identity-1".to_string(), + vault_name: "vault-1".to_string(), + // This line tests the proper deserialization of a Boolean + // in SQLite where a Boolean maps to an INTEGER + is_default: Boolean::new(true), + }) ); // failed query let result: Option = - sqlx::query_as("SELECT identifier FROM identity WHERE identifier=?1") + sqlx::query_as("SELECT identifier FROM named_identity WHERE identifier = $1") .bind("x") .fetch_optional(&*db.pool) .await @@ -280,15 +496,22 @@ mod tests { } /// HELPERS - async fn insert_identity(db: &SqlxDatabase) -> Result { - sqlx::query("INSERT INTO identity VALUES (?1, ?2)") + async fn insert_identity(db: &SqlxDatabase) -> Result { + sqlx::query("INSERT INTO named_identity (identifier, name, vault_name, is_default) VALUES ($1, $2, $3, $4)") .bind("Ifa804b7fca12a19eed206ae180b5b576860ae651") - .bind("123".to_sql()) + .bind("identity-1") + .bind("vault-1") + .bind(true) .execute(&*db.pool) .await .into_core() } #[derive(FromRow, PartialEq, Eq, Debug)] - struct IdentifierRow(String); + struct IdentifierRow { + identifier: String, + name: String, + vault_name: String, + is_default: Boolean, + } } diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/sqlx_from_row_types.rs b/implementations/rust/ockam/ockam_node/src/storage/database/sqlx_from_row_types.rs new file mode 100644 index 00000000000..ce7e5c79a03 --- /dev/null +++ b/implementations/rust/ockam/ockam_node/src/storage/database/sqlx_from_row_types.rs @@ -0,0 +1,104 @@ +use sqlx::database::HasValueRef; +use sqlx::error::BoxDynError; +use sqlx::postgres::any::AnyTypeInfoKind; +use sqlx::{Any, Database, Decode, Type, Value, ValueRef}; + +/// This type is used to map Option fields for the types deriving `FromRow` +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Nullable(Option); + +impl From> for Nullable { + fn 
from(value: Option) -> Self { + Nullable(value) + } +} + +impl From> for Option { + fn from(value: Nullable) -> Self { + value.0 + } +} + +impl Nullable { + /// Create a new Nullable value + pub fn new(t: Option) -> Self { + Nullable(t) + } + + /// Return the Option corresponding to this value in the database + pub fn to_option(&self) -> Option { + self.0.clone() + } +} + +impl<'d, T: Decode<'d, Any>> Decode<'d, Any> for Nullable { + fn decode(value: >::ValueRef) -> Result { + match value.type_info().kind() { + AnyTypeInfoKind::Null => Ok(Nullable(None)), + _ => Ok(Nullable(Some(T::decode(value)?))), + } + } +} + +impl> Type for Nullable { + fn type_info() -> ::TypeInfo { + >::type_info() + } + + fn compatible(ty: &::TypeInfo) -> bool { + >::compatible(ty) || ty.kind() == AnyTypeInfoKind::Null + } +} + +/// This type is used to map boolean fields for the types deriving `FromRow`. +/// Postgres provides a proper boolean type but SQLite maps them as integers. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Boolean(bool); + +impl From for Boolean { + fn from(value: bool) -> Self { + Boolean(value) + } +} + +impl From for bool { + fn from(value: Boolean) -> Self { + value.0 + } +} + +impl Boolean { + /// Create a new Boolean value + pub fn new(b: bool) -> Self { + Boolean(b) + } + + /// Return the bool value + pub fn to_bool(&self) -> bool { + self.0 + } +} + +impl<'d> Decode<'d, Any> for Boolean { + fn decode(value: >::ValueRef) -> Result { + match value.type_info().kind() { + AnyTypeInfoKind::Bool => Ok(Boolean(ValueRef::to_owned(&value).decode())), + AnyTypeInfoKind::Integer => { + let v: i64 = ValueRef::to_owned(&value).decode(); + Ok(Boolean(v == 1)) + } + other => Err(format!("expected BOOLEAN or INTEGER, got {:?}", other).into()), + } + } +} + +impl Type for Boolean { + fn type_info() -> ::TypeInfo { + >::type_info() + } + + fn compatible(ty: &::TypeInfo) -> bool { + >::type_info().kind() == ty.kind() + || ty.kind() == AnyTypeInfoKind::Integer + } +} diff 
--git a/implementations/rust/ockam/ockam_node/src/storage/database/sqlx_types.rs b/implementations/rust/ockam/ockam_node/src/storage/database/sqlx_types.rs deleted file mode 100644 index c02e5fcf476..00000000000 --- a/implementations/rust/ockam/ockam_node/src/storage/database/sqlx_types.rs +++ /dev/null @@ -1,206 +0,0 @@ -use chrono::{DateTime, Utc}; -use std::net::SocketAddr; -use std::path::{Path, PathBuf}; - -use ockam_core::Address; -use sqlx::database::HasArguments; -use sqlx::encode::IsNull; -use sqlx::{Database, Encode, Sqlite, Type}; -use time::OffsetDateTime; - -/// This enum represents the set of types that we currently support in our database -/// Since we support only Sqlite at the moment, those types are close to what is supported by Sqlite: -/// https://www.sqlite.org/datatype3.html -/// -/// The purpose of this type is to ease the serialization of data types in Ockam into data types in -/// our database. For example, if we describe how to translate an `Identifier` into some `Text` then -/// we can use the `Text` as a parameter in a sqlx query. -/// -/// Note: see the `ToSqlxType` trait and its instances for how the conversion is done -/// -pub enum SqlxType { - /// This type represents text in the database - Text(String), - /// This type represents arbitrary bytes in the database - Blob(Vec), - /// This type represents ints, signed or unsigned - Integer(i64), - /// This type represents floats - #[allow(unused)] - Real(f64), -} - -/// The SqlxType implements the Type trait from sqlx to allow its values to be serialized -/// to an Sqlite database -impl Type for SqlxType { - fn type_info() -> ::TypeInfo { - as Type>::type_info() - } -} - -/// The SqlType implements the Encode trait from sqlx to allow its values to be serialized -/// to an Sqlite database. 
There is a 1 to 1 mapping with the database native types -impl Encode<'_, Sqlite> for SqlxType { - fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { - match self { - SqlxType::Text(v) => >::encode_by_ref(v, buf), - SqlxType::Blob(v) => as Encode<'_, Sqlite>>::encode_by_ref(v, buf), - SqlxType::Integer(v) => >::encode_by_ref(v, buf), - SqlxType::Real(v) => >::encode_by_ref(v, buf), - } - } - - fn produces(&self) -> Option<::TypeInfo> { - Some(match self { - SqlxType::Text(_) => >::type_info(), - SqlxType::Blob(_) => as Type>::type_info(), - SqlxType::Integer(_) => >::type_info(), - SqlxType::Real(_) => >::type_info(), - }) - } -} - -/// This trait can be implemented by any type that can be converted to a database type -/// Typically an `Identifier` (to a `Text`), a `TimestampInSeconds` (to an `Integer`) etc... -/// -/// This allows a value to be used as a bind parameters in a sqlx query for example: -/// -/// use std::str::FromStr; -/// use sqlx::query_as; -/// use ockam_node::database::{SqlxType, ToSqlxType}; -/// -/// // newtype for a UNIX-like timestamp -/// struct TimestampInSeconds(u64); -/// -/// // this implementation maps the TimestampInSecond type to one of the types that Sqlx -/// // can serialize for sqlite -/// impl ToSqlxType for TimestampInSeconds { -/// fn to_sql(&self) -> SqlxType { -/// self.0.to_sql() -/// } -/// } -/// -/// let timestamp = TimestampInSeconds(10000000); -/// let query = query_as("SELECT identifier, change_history FROM identity WHERE created_at >= $1").bind(timestamp.as_sql()); -/// -/// -pub trait ToSqlxType { - /// Return the appropriate sql type - fn to_sql(&self) -> SqlxType; -} - -impl ToSqlxType for String { - fn to_sql(&self) -> SqlxType { - SqlxType::Text(self.clone()) - } -} - -impl ToSqlxType for &str { - fn to_sql(&self) -> SqlxType { - self.to_string().to_sql() - } -} - -impl ToSqlxType for bool { - fn to_sql(&self) -> SqlxType { - if *self { - 1.to_sql() - } else { - 0.to_sql() - } - } -} - -impl 
ToSqlxType for u64 { - fn to_sql(&self) -> SqlxType { - SqlxType::Integer(*self as i64) - } -} - -impl ToSqlxType for u32 { - fn to_sql(&self) -> SqlxType { - SqlxType::Integer(*self as i64) - } -} - -impl ToSqlxType for u16 { - fn to_sql(&self) -> SqlxType { - SqlxType::Integer(*self as i64) - } -} - -impl ToSqlxType for u8 { - fn to_sql(&self) -> SqlxType { - SqlxType::Integer(*self as i64) - } -} - -impl ToSqlxType for i32 { - fn to_sql(&self) -> SqlxType { - SqlxType::Integer(*self as i64) - } -} - -impl ToSqlxType for i16 { - fn to_sql(&self) -> SqlxType { - SqlxType::Integer(*self as i64) - } -} - -impl ToSqlxType for i8 { - fn to_sql(&self) -> SqlxType { - SqlxType::Integer(*self as i64) - } -} - -impl ToSqlxType for OffsetDateTime { - fn to_sql(&self) -> SqlxType { - SqlxType::Integer(self.unix_timestamp()) - } -} - -impl ToSqlxType for DateTime { - fn to_sql(&self) -> SqlxType { - self.to_rfc3339().to_sql() - } -} - -impl ToSqlxType for Vec { - fn to_sql(&self) -> SqlxType { - SqlxType::Blob(self.clone()) - } -} - -impl ToSqlxType for &[u8; 32] { - fn to_sql(&self) -> SqlxType { - SqlxType::Blob(self.to_vec().clone()) - } -} - -impl ToSqlxType for SocketAddr { - fn to_sql(&self) -> SqlxType { - SqlxType::Text(self.to_string()) - } -} - -impl ToSqlxType for Address { - fn to_sql(&self) -> SqlxType { - SqlxType::Text(self.to_string()) - } -} - -impl ToSqlxType for PathBuf { - fn to_sql(&self) -> SqlxType { - self.as_path().to_sql() - } -} - -impl ToSqlxType for &Path { - fn to_sql(&self) -> SqlxType { - SqlxType::Text( - self.to_str() - .unwrap_or("a path should be a valid string") - .into(), - ) - } -} diff --git a/implementations/rust/ockam/ockam_vault/Cargo.toml b/implementations/rust/ockam/ockam_vault/Cargo.toml index f9a08f767a6..a5586de0d66 100644 --- a/implementations/rust/ockam/ockam_vault/Cargo.toml +++ b/implementations/rust/ockam/ockam_vault/Cargo.toml @@ -89,7 +89,7 @@ rand = { version = "0.8", default-features = false } rand_pcg = { version = 
"0.3.1", default-features = false, optional = true } serde = { version = "1", default-features = false, features = ["derive"] } sha2 = { version = "0.10", default-features = false } -sqlx = { version = "0.7.4", optional = true } +sqlx = { git = "https://github.com/etorreborre/sqlx", rev = "5fec648d2de0cbeed738dcf1c6f5bc9194fc439b", optional = true } static_assertions = "1.1.0" tracing = { version = "0.1", default-features = false } x25519-dalek = { version = "2.0.1", default_features = false, features = ["precomputed-tables", "static_secrets", "zeroize"] } diff --git a/implementations/rust/ockam/ockam_vault/src/software/vault_for_signing/types.rs b/implementations/rust/ockam/ockam_vault/src/software/vault_for_signing/types.rs index c348caffb95..e4b4015e914 100644 --- a/implementations/rust/ockam/ockam_vault/src/software/vault_for_signing/types.rs +++ b/implementations/rust/ockam/ockam_vault/src/software/vault_for_signing/types.rs @@ -47,6 +47,16 @@ pub enum SigningSecret { ECDSASHA256CurveP256(ECDSASHA256CurveP256SecretKey), } +impl SigningSecret { + /// Return the secret key + pub fn key(&self) -> &[u8; 32] { + match self { + SigningSecret::EdDSACurve25519(k) => k.key(), + SigningSecret::ECDSASHA256CurveP256(k) => k.key(), + } + } +} + const_assert_eq!( ed25519_dalek::SECRET_KEY_LENGTH, EDDSA_CURVE25519_SECRET_KEY_LENGTH diff --git a/implementations/rust/ockam/ockam_vault/src/storage/secrets_repository_sql.rs b/implementations/rust/ockam/ockam_vault/src/storage/secrets_repository_sql.rs index b16a1ada90e..7ed77bece38 100644 --- a/implementations/rust/ockam/ockam_vault/src/storage/secrets_repository_sql.rs +++ b/implementations/rust/ockam/ockam_vault/src/storage/secrets_repository_sql.rs @@ -1,3 +1,5 @@ +use sqlx::database::HasArguments; +use sqlx::encode::IsNull; use sqlx::*; use tracing::debug; use zeroize::{Zeroize, ZeroizeOnDrop}; @@ -6,7 +8,7 @@ use ockam_core::async_trait; use ockam_core::compat::vec::Vec; use ockam_core::errcode::{Kind, Origin}; use 
ockam_core::Result; -use ockam_node::database::{FromSqlxError, SqlxDatabase, SqlxType, ToSqlxType, ToVoid}; +use ockam_node::database::{FromSqlxError, SqlxDatabase, ToVoid}; use crate::storage::secrets_repository::SecretsRepository; @@ -50,15 +52,21 @@ impl SecretsRepository for SecretsSqlxDatabase { SigningSecretKeyHandle::ECDSASHA256CurveP256(_) => EC_DSA_SHA256_CURVE_P256.into(), }; - let query = query("INSERT OR REPLACE INTO signing_secret VALUES (?, ?, ?)") - .bind(handle.to_sql()) - .bind(secret_type.to_sql()) - .bind(secret.to_sql()); + let query = query( + r#" + INSERT INTO signing_secret (handle, secret_type, secret) + VALUES ($1, $2, $3) + ON CONFLICT (handle) + DO UPDATE SET secret_type = $2, secret = $3"#, + ) + .bind(handle) + .bind(secret_type) + .bind(secret); query.execute(&*self.database.pool).await.void() } async fn delete_signing_secret(&self, handle: &SigningSecretKeyHandle) -> Result { - let query = query("DELETE FROM signing_secret WHERE handle = ?").bind(handle.to_sql()); + let query = query("DELETE FROM signing_secret WHERE handle = $1").bind(handle); let res = query.execute(&*self.database.pool).await.into_core()?; Ok(res.rows_affected() != 0) @@ -69,8 +77,8 @@ impl SecretsRepository for SecretsSqlxDatabase { handle: &SigningSecretKeyHandle, ) -> Result> { let query = - query_as("SELECT handle, secret_type, secret FROM signing_secret WHERE handle=?") - .bind(handle.to_sql()); + query_as("SELECT handle, secret_type, secret FROM signing_secret WHERE handle = $1") + .bind(handle); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -93,14 +101,20 @@ impl SecretsRepository for SecretsSqlxDatabase { handle: &X25519SecretKeyHandle, secret: X25519SecretKey, ) -> Result<()> { - let query = query("INSERT OR REPLACE INTO x25519_secret VALUES (?, ?)") - .bind(handle.to_sql()) - .bind(secret.to_sql()); + let query = query( + r#" + INSERT INTO x25519_secret (handle, secret) + VALUES ($1, $2) + ON CONFLICT (handle) + DO UPDATE SET 
secret = $2"#, + ) + .bind(handle) + .bind(secret); query.execute(&*self.database.pool).await.void() } async fn delete_x25519_secret(&self, handle: &X25519SecretKeyHandle) -> Result { - let query = query("DELETE FROM x25519_secret WHERE handle = ?").bind(handle.to_sql()); + let query = query("DELETE FROM x25519_secret WHERE handle = $1").bind(handle); let res = query.execute(&*self.database.pool).await.into_core()?; Ok(res.rows_affected() != 0) @@ -110,8 +124,8 @@ impl SecretsRepository for SecretsSqlxDatabase { &self, handle: &X25519SecretKeyHandle, ) -> Result> { - let query = query_as("SELECT handle, secret FROM x25519_secret WHERE handle=?") - .bind(handle.to_sql()); + let query = + query_as("SELECT handle, secret FROM x25519_secret WHERE handle = $1").bind(handle); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -133,25 +147,30 @@ impl SecretsRepository for SecretsSqlxDatabase { handle: &AeadSecretKeyHandle, secret: AeadSecret, ) -> Result<()> { - let query = - query("INSERT OR REPLACE INTO aead_secret(handle, type, secret) VALUES (?, ?, ?)") - .bind(handle.to_sql()) - .bind(AEAD_TYPE.to_sql()) - .bind(secret.to_sql()); + let query = query( + r#" + INSERT INTO aead_secret (handle, type, secret) + VALUES ($1, $2, $3) + ON CONFLICT (handle) + DO UPDATE SET type = $2, secret = $3"#, + ) + .bind(handle) + .bind(AEAD_TYPE) + .bind(secret); query.execute(&*self.database.pool).await.void() } async fn delete_aead_secret(&self, handle: &AeadSecretKeyHandle) -> Result { - let query = query("DELETE FROM aead_secret WHERE handle = ?").bind(handle.to_sql()); + let query = query("DELETE FROM aead_secret WHERE handle = $1").bind(handle); let res = query.execute(&*self.database.pool).await.into_core()?; Ok(res.rows_affected() != 0) } async fn get_aead_secret(&self, handle: &AeadSecretKeyHandle) -> Result> { - let query = query_as("SELECT secret FROM aead_secret WHERE handle=? 
AND type=?") - .bind(handle.to_sql()) - .bind(AEAD_TYPE.to_sql()); + let query = query_as("SELECT secret FROM aead_secret WHERE handle = $1 AND type = $2") + .bind(handle) + .bind(AEAD_TYPE); let row: Option = query .fetch_optional(&*self.database.pool) .await @@ -174,48 +193,87 @@ impl SecretsRepository for SecretsSqlxDatabase { } } -impl ToSqlxType for SigningSecret { - fn to_sql(&self) -> SqlxType { - match self { - SigningSecret::EdDSACurve25519(k) => k.key().to_sql(), - SigningSecret::ECDSASHA256CurveP256(k) => k.key().to_sql(), - } +impl Type for SigningSecret { + fn type_info() -> ::TypeInfo { + as Type>::type_info() + } +} + +impl Encode<'_, Any> for SigningSecret { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + as Encode<'_, Any>>::encode_by_ref(&self.key().to_vec(), buf) + } +} + +impl Type for SigningSecretKeyHandle { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +impl Encode<'_, Any> for SigningSecretKeyHandle { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + >::encode_by_ref(self.handle(), buf) + } +} + +impl Type for HandleToSecret { + fn type_info() -> ::TypeInfo { + as Type>::type_info() + } +} + +impl Encode<'_, Any> for HandleToSecret { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + as Encode<'_, Any>>::encode_by_ref(self.value(), buf) + } +} + +impl Type for X25519SecretKeyHandle { + fn type_info() -> ::TypeInfo { + >::type_info() + } +} + +impl Encode<'_, Any> for X25519SecretKeyHandle { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + >::encode_by_ref(&self.0, buf) } } -impl ToSqlxType for SigningSecretKeyHandle { - fn to_sql(&self) -> SqlxType { - self.handle().to_sql() +impl Type for AeadSecretKeyHandle { + fn type_info() -> ::TypeInfo { + >::type_info() } } -impl ToSqlxType for X25519SecretKeyHandle { - fn to_sql(&self) -> SqlxType { - self.0.value().to_sql() +impl Encode<'_, Any> for AeadSecretKeyHandle { + fn encode_by_ref(&self, buf: &mut 
::ArgumentBuffer) -> IsNull { + >::encode_by_ref(&self.0 .0, buf) } } -impl ToSqlxType for AeadSecretKeyHandle { - fn to_sql(&self) -> SqlxType { - self.0 .0.to_sql() +impl Type for X25519SecretKey { + fn type_info() -> ::TypeInfo { + as Type>::type_info() } } -impl ToSqlxType for HandleToSecret { - fn to_sql(&self) -> SqlxType { - self.value().to_sql() +impl Encode<'_, Any> for X25519SecretKey { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + as Encode<'_, Any>>::encode_by_ref(&self.key().to_vec(), buf) } } -impl ToSqlxType for X25519SecretKey { - fn to_sql(&self) -> SqlxType { - self.key().to_sql() +impl Type for AeadSecret { + fn type_info() -> ::TypeInfo { + as Type>::type_info() } } -impl ToSqlxType for AeadSecret { - fn to_sql(&self) -> SqlxType { - self.0.to_vec().to_sql() +impl Encode<'_, Any> for AeadSecret { + fn encode_by_ref(&self, buf: &mut ::ArgumentBuffer) -> IsNull { + as Encode<'_, Any>>::encode_by_ref(&self.0.to_vec(), buf) } }