From ab4dbad3ed04a9bb7e818bfb6b25c210541fbf8f Mon Sep 17 00:00:00 2001 From: w-gc <25614556+w-gc@users.noreply.github.com> Date: Mon, 10 Feb 2025 10:19:44 +0800 Subject: [PATCH] repo-sync-2025-02-06T20:02:27+0800 (#988) --- .vscode/settings.json | 8 +- MODULE.bazel | 10 +- MODULE.bazel.lock | 76 +-- examples/cpp/utils.cc | 10 +- examples/python/ir_dump/ir_dump.py | 4 +- .../python/ml/flax_llama7b/flax_llama7b.py | 4 +- .../flax_llama7b_split/flax_llama7b_split.py | 4 +- .../python/ml/flax_whisper/flax_whisper.py | 4 +- experimental/squirrel/objectives.cc | 6 +- experimental/squirrel/objectives_test.cc | 36 +- experimental/squirrel/squirrel_demo_main.cc | 16 +- experimental/squirrel/tree_builder.cc | 4 +- experimental/squirrel/utils.cc | 16 +- experimental/squirrel/utils_test.cc | 59 +- pyrightconfig.json | 5 + sml/cluster/tests/BUILD.bazel | 3 +- sml/cluster/tests/kmeans_test.py | 18 +- sml/decomposition/nmf.py | 6 +- sml/decomposition/pca.py | 6 +- sml/decomposition/tests/BUILD.bazel | 5 +- sml/decomposition/tests/nmf_test.py | 8 +- sml/decomposition/tests/pca_test.py | 10 +- sml/ensemble/tests/BUILD.bazel | 5 +- sml/ensemble/tests/adaboost_test.py | 6 +- sml/ensemble/tests/forest_test.py | 6 +- sml/faq.md | 6 +- sml/feature_selection/tests/BUILD.bazel | 3 +- sml/feature_selection/tests/chi2_test.py | 7 +- sml/gaussian_process/tests/BUILD.bazel | 3 +- sml/gaussian_process/tests/gpc_test.py | 4 +- sml/linear_model/tests/BUILD.bazel | 26 +- sml/linear_model/tests/glm_test.py | 4 +- sml/linear_model/tests/logistic_test.py | 6 +- sml/linear_model/tests/pla_test.py | 6 +- sml/linear_model/tests/quantile_test.py | 6 +- sml/linear_model/tests/ridge_test.py | 6 +- sml/linear_model/tests/sgd_classifier_test.py | 6 +- sml/metrics/classification/BUILD.bazel | 3 +- .../classification/classification_test.py | 10 +- sml/metrics/regression/BUILD.bazel | 3 +- sml/metrics/regression/regression_test.py | 18 +- sml/naive_bayes/tests/BUILD.bazel | 3 +- sml/naive_bayes/tests/gnb_test.py | 10 +- sml/neighbors/tests/BUILD.bazel | 3 +- sml/neighbors/tests/knn_test.py | 6 +- sml/preprocessing/tests/BUILD.bazel | 3 +- sml/preprocessing/tests/preprocessing_test.py | 98 +-- sml/svm/emulations/svm_emul.py | 2 +- sml/svm/tests/BUILD.bazel | 3 +- sml/svm/tests/svm_test.py | 6 +- sml/tree/tests/BUILD.bazel | 3 +- sml/tree/tests/tree_test.py | 6 +- sml/utils/BUILD.bazel | 2 +- sml/utils/emulation.py | 4 +- sml/utils/tests/extmath_test.py | 14 +- spu/BUILD.bazel | 4 +- spu/__init__.py | 10 +- spu/api.py | 49 +- spu/libspu.cc | 451 +++++++++++--- spu/libspu.pyi | 291 +++++++++ spu/ops/groupby/BUILD.bazel | 3 +- spu/ops/groupby/groupby_test.py | 8 +- spu/spu_pb2.py | 15 - spu/tests/BUILD.bazel | 21 +- spu/tests/distributed_test.py | 42 +- spu/tests/frontend_test.py | 14 +- spu/tests/jax_compile_test.py | 6 +- spu/tests/jax_sanity_test.py | 16 +- spu/tests/jnp_aby3_r128_test.py | 4 +- spu/tests/jnp_aby3_r64_test.py | 4 +- spu/tests/jnp_cheetah_r64_test.py | 4 +- spu/tests/jnp_debug.py | 6 +- spu/tests/jnp_ref2k_r64_test.py | 4 +- spu/tests/jnp_semi2k_r128_test.py | 10 +- spu/tests/jnp_semi2k_r64_test.py | 4 +- spu/tests/spu_compiler_test.py | 1 - spu/tests/spu_io_test.py | 158 +++-- spu/tests/spu_runtime_test.py | 26 +- spu/utils/BUILD.bazel | 2 +- spu/utils/distributed_impl.py | 69 +-- spu/utils/frontend.py | 21 +- spu/utils/simulation.py | 28 +- src/MODULE.bazel | 3 +- src/MODULE.bazel.lock | 5 +- src/libspu/BUILD.bazel | 12 +- src/libspu/compiler/common/BUILD.bazel | 2 +- .../compiler/common/compilation_context.cc | 4 +- .../compiler/common/compilation_context.h | 6 +- src/libspu/compiler/compile.h | 2 +- src/libspu/compiler/core/core.cc | 18 +- src/libspu/compiler/front_end/BUILD.bazel | 1 + src/libspu/compiler/front_end/fe.cc | 16 +- src/libspu/compiler/front_end/fe.h | 2 +- src/libspu/compiler/tools/spu-translate.cc | 17 +- src/libspu/core/BUILD.bazel | 8 +- src/libspu/core/config.cc | 65 +- src/libspu/core/config.h | 2 +- src/libspu/core/context.cc | 16 +- src/libspu/core/context.h | 7 +- src/libspu/core/ndarray_ref.cc | 1 + src/libspu/core/prelude.h | 12 +- src/libspu/core/pt_buffer_view.h | 3 +- src/libspu/core/type.cc | 38 ++ src/libspu/core/type.h | 31 +- src/libspu/core/type_util.cc | 7 +- src/libspu/core/type_util.h | 3 +- src/libspu/core/value.cc | 27 +- src/libspu/core/value.h | 8 +- src/libspu/device/BUILD.bazel | 5 +- src/libspu/device/api.cc | 35 +- src/libspu/device/api.h | 5 +- src/libspu/device/io.cc | 8 +- src/libspu/device/io_test.cc | 18 +- .../device/pphlo/pphlo_executor_test.cc | 6 +- .../device/pphlo/pphlo_intrinsic_executor.cc | 2 +- .../device/pphlo/pphlo_verifier_test.cc | 8 +- src/libspu/device/test_utils.h | 3 +- .../utils/pphlo_executor_debug_runner.cc | 17 +- .../utils/pphlo_executor_test_runner.cc | 22 +- .../device/utils/pphlo_executor_test_runner.h | 5 +- src/libspu/kernel/hal/fxp_approx.cc | 28 +- src/libspu/kernel/hal/fxp_approx_test.cc | 28 +- src/libspu/kernel/hal/fxp_base.cc | 2 +- src/libspu/kernel/hal/permute.cc | 22 +- src/libspu/kernel/hal/polymorphic_test.cc | 6 +- src/libspu/kernel/hal/ring.cc | 2 +- src/libspu/kernel/hlo/BUILD.bazel | 1 + src/libspu/kernel/hlo/const_test.cc | 2 +- src/libspu/kernel/hlo/sort_test.cc | 23 +- src/libspu/kernel/test_util.cc | 6 +- src/libspu/mpc/BUILD.bazel | 2 +- src/libspu/mpc/ab_api_test.cc | 76 +-- src/libspu/mpc/aby3/BUILD.bazel | 1 + src/libspu/mpc/aby3/protocol.cc | 2 +- src/libspu/mpc/aby3/protocol_test.cc | 12 +- src/libspu/mpc/aby3/type.cc | 17 +- src/libspu/mpc/aby3/type.h | 14 +- src/libspu/mpc/api_test.cc | 12 +- src/libspu/mpc/api_test.h | 3 +- src/libspu/mpc/cheetah/ot/basic_ot_prot.h | 3 +- src/libspu/mpc/cheetah/protocol.cc | 8 +- src/libspu/mpc/cheetah/protocol_ab_test.cc | 13 +- src/libspu/mpc/cheetah/protocol_api_test.cc | 9 +- src/libspu/mpc/cheetah/state.h | 3 +- src/libspu/mpc/cheetah/type.cc | 16 + src/libspu/mpc/cheetah/type.h | 14 +- src/libspu/mpc/common/BUILD.bazel | 1 + src/libspu/mpc/common/pv2k.cc | 18 +- src/libspu/mpc/common/pv2k.h | 14 +- src/libspu/mpc/factory.cc | 18 +- src/libspu/mpc/factory.h | 3 +- src/libspu/mpc/ref2k/ref2k.cc | 2 +- src/libspu/mpc/ref2k/ref2k_test.cc | 6 +- src/libspu/mpc/securenn/protocol.cc | 2 +- src/libspu/mpc/securenn/protocol_test.cc | 12 +- src/libspu/mpc/securenn/type.cc | 16 + src/libspu/mpc/securenn/type.h | 14 +- src/libspu/mpc/semi2k/arithmetic.cc | 4 +- .../mpc/semi2k/beaver/beaver_interface.h | 3 +- src/libspu/mpc/semi2k/exp.cc | 3 +- src/libspu/mpc/semi2k/protocol.cc | 6 +- src/libspu/mpc/semi2k/protocol_test.cc | 46 +- src/libspu/mpc/semi2k/state.h | 35 +- src/libspu/mpc/semi2k/type.cc | 16 + src/libspu/mpc/semi2k/type.h | 14 +- .../mpc/spdz2k/abprotocol_spdz2k_test.cc | 32 +- src/libspu/mpc/spdz2k/protocol.cc | 2 +- src/libspu/mpc/spdz2k/protocol_ab_test.cc | 12 +- src/libspu/mpc/spdz2k/protocol_api_test.cc | 6 +- src/libspu/mpc/spdz2k/state.h | 8 +- src/libspu/mpc/spdz2k/type.cc | 22 + src/libspu/mpc/spdz2k/type.h | 18 +- src/libspu/mpc/tools/benchmark.h | 2 +- src/libspu/mpc/tools/complexity.cc | 4 +- src/libspu/mpc/utils/tiling_util.h | 12 +- src/libspu/spu.cc | 279 +++++++++ src/libspu/spu.h | 560 ++++++++++++++++++ src/libspu/spu.proto | 2 +- src/libspu/version.h | 2 +- version.bzl | 2 +- 180 files changed, 2591 insertions(+), 1266 deletions(-) create mode 100644 spu/libspu.pyi delete mode 100644 spu/spu_pb2.py create mode 100644 src/libspu/spu.cc create mode 100644 src/libspu/spu.h diff --git a/.vscode/settings.json b/.vscode/settings.json index 02434d073..ee97546d4 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -41,5 +41,11 @@ "[python]": { "editor.defaultFormatter": "ms-python.black-formatter" }, - "mlir.server_path": "bazel-bin/libspu/compiler/tools/spu-lsp" + "mlir.server_path": "bazel-bin/libspu/compiler/tools/spu-lsp", + "files.exclude": { + // "**/bazel-*/**": true, + "external":true, + ".cache":true, + "**/__pycache__":true + } } \ No newline at end of file diff --git a/MODULE.bazel b/MODULE.bazel index 8a60b7d76..10cd6feff 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -21,7 +21,7 @@ module( name = "spu", - version = "0.9.4.dev20250123", + version = "0.9.4.dev20250209", compatibility_level = 1, ) @@ -32,13 +32,7 @@ local_path_override( path = "src", ) -bazel_dep(name = "psi") -git_override( - module_name = "psi", - commit = "8ead92f1bb10329c7e7e56d541fecb3dcd47ee03", - remote = "https://github.com/secretflow/psi.git", -) - +bazel_dep(name = "psi", version = "0.6.0.dev250123") bazel_dep(name = "yacl", version = "20241212.0-871832a") bazel_dep(name = "grpc", version = "1.66.0.bcr.3") single_version_override( diff --git a/MODULE.bazel.lock b/MODULE.bazel.lock index a20103508..39fe53a2f 100644 --- a/MODULE.bazel.lock +++ b/MODULE.bazel.lock @@ -76,9 +76,8 @@ "https://bcr.bazel.build/modules/boost.detail/1.83.0.bcr.2/MODULE.bazel": "0016cbb9f265ebd4efcbb5f1aff14ccce5754bd98c3d9ff375d3a39f4173f8de", "https://bcr.bazel.build/modules/boost.detail/1.83.0.bcr.2/source.json": "0044cf130a00ccde75322449d2a525e1ca95af2bd61b49873125943c4d3ff7ad", "https://bcr.bazel.build/modules/boost.detail/1.83.0/MODULE.bazel": "9bbd10b5c30d0c17908fd0c94e9a26f4899540d1c1a97088911d1195feb98da9", - "https://bcr.bazel.build/modules/boost.dynamic_bitset/1.83.0.bcr.1/MODULE.bazel": "2093a9d09159b97ada67ed370492f61c775ae7df5018c8ba9c69053ed7813dab", - "https://bcr.bazel.build/modules/boost.dynamic_bitset/1.83.0.bcr.1/source.json": "c8470feff0d94ef0c950dfa6b214cc119bbaae95355e27838b9d295c7b222b21", "https://bcr.bazel.build/modules/boost.dynamic_bitset/1.83.0/MODULE.bazel": "8d2f0d7289219e8fd6391d1c64b1beefdd39ccffdf750ce63b83cccfe81ce9a1", + "https://bcr.bazel.build/modules/boost.dynamic_bitset/1.83.0/source.json": "54e15d39881a0b7217b0480953c158c02e026ba78d329ee6e07fbefb9c31b0b0", "https://bcr.bazel.build/modules/boost.exception/1.83.0.bcr.1/MODULE.bazel": "60d6418bbe6246f71c1ff211c08f187d975fe12a0f6845c64c6127977c51375b", "https://bcr.bazel.build/modules/boost.exception/1.83.0.bcr.1/source.json": "54a9e9984fccf6f49385b28c301ff3af0e7485ac7e5281c5d249f1b5b1225c9c", "https://bcr.bazel.build/modules/boost.function/1.83.0.bcr.1/MODULE.bazel": "b555cccb955cc4bd8d548f81ec29fdd33284ae04afd9ed92b254abe6357a5ed7", @@ -119,9 +118,8 @@ "https://bcr.bazel.build/modules/boost.mpl/1.83.0/MODULE.bazel": "3110db5c9c5000140076797d61592f8abffd50e91e04a153ee24bf47d30e38ec", "https://bcr.bazel.build/modules/boost.multiprecision/1.83.0/MODULE.bazel": "4746848bf172eec5c9f9fc02dd812178b9d6d05d5f4ffa40859759553a326472", "https://bcr.bazel.build/modules/boost.multiprecision/1.83.0/source.json": "24d1a97beaa7eed8e2cf5c688f8c7e531dbf65911c498075b32d89dee4ea4c16", - "https://bcr.bazel.build/modules/boost.numeric_conversion/1.83.0.bcr.1/MODULE.bazel": "1920d71f656a0433947579a86c22e0dd897816eadee3ae20d5c51dab9505deb3", - "https://bcr.bazel.build/modules/boost.numeric_conversion/1.83.0.bcr.1/source.json": "fc35a1dbb8810812f05fa54667f0edeac142f5fc144f9426236694a1d01e875c", "https://bcr.bazel.build/modules/boost.numeric_conversion/1.83.0/MODULE.bazel": "34fbcfe4eab607de5285b543f63142042f98c0895d765476f251c7bb45acf33d", + "https://bcr.bazel.build/modules/boost.numeric_conversion/1.83.0/source.json": "1bd2a8a7fd4595a0c9050ca5d7a4eedb231d2dfdd95f3255c80d564393c8efa0", "https://bcr.bazel.build/modules/boost.optional/1.83.0.bcr.1/MODULE.bazel": "971e786a19e31dfbd8d311dff3bbffc7c1e16f348e019dfca50620fd9c475092", "https://bcr.bazel.build/modules/boost.optional/1.83.0.bcr.1/source.json": "486e1173f8b24120078f072f434a79fb3aa8c6b0dafbc19c56b7780ea2336595", "https://bcr.bazel.build/modules/boost.optional/1.83.0/MODULE.bazel": "d2d9c9afa139aa075212e451570aee269d73d5ccda31ab0ab33a9a4b406c0cf7", @@ -131,9 +129,8 @@ "https://bcr.bazel.build/modules/boost.preprocessor/1.83.0.bcr.1/MODULE.bazel": "edb9fb7900ea7002cbefffd97302b071d7cbd8f948b51c7b1a75043bd2985eba", "https://bcr.bazel.build/modules/boost.preprocessor/1.83.0.bcr.1/source.json": "69dc4f6fc76305c21c4a651c94ccfdc8a76d8fbae1151e7c1d1a4599dffc0f03", "https://bcr.bazel.build/modules/boost.preprocessor/1.83.0/MODULE.bazel": "5d1096729ebd16d2679c798110c0896720be23959c59afa36547d04815e255c8", - "https://bcr.bazel.build/modules/boost.random/1.83.0.bcr.1/MODULE.bazel": "cabe3ba820c9588a9ca22548b88ad8c4b307be085691b286ff0dae8a46b25fa0", - "https://bcr.bazel.build/modules/boost.random/1.83.0.bcr.1/source.json": "5cc5c8c13e525c5b1c12ea5b31359a16bfaf25b750c0601c0624b9342872e05d", "https://bcr.bazel.build/modules/boost.random/1.83.0/MODULE.bazel": "e7b05549fc0fad578911a63c05ffe5c6102cb7bc399d8cc762aa3d5ad73ba30c", + "https://bcr.bazel.build/modules/boost.random/1.83.0/source.json": "be53b27e24ea411296b92b29cab9c93aa1acfdd09d2cac8f5e0416beb0e12d57", "https://bcr.bazel.build/modules/boost.range/1.83.0.bcr.1/MODULE.bazel": "136d623462d1d5c7cf79df83b5ce17a8582a92abb116da9d88c5e5594e5a7d92", "https://bcr.bazel.build/modules/boost.range/1.83.0.bcr.1/source.json": "f99062101034f19d9bf2bef3e07cc99bf192640b72560663227e5375efd1e144", "https://bcr.bazel.build/modules/boost.range/1.83.0/MODULE.bazel": "ceba0feb376949eecb77944bc37ac434261dede457fc449958164fca5d1430db", @@ -146,14 +143,11 @@ "https://bcr.bazel.build/modules/boost.static_assert/1.83.0.bcr.1/MODULE.bazel": "2b605adc483c6241865f1e862437331bc6f56c0d376769908b70ba18d3da1f07", "https://bcr.bazel.build/modules/boost.static_assert/1.83.0.bcr.1/source.json": "a0eac8de976fff7efdf498933d7494df30eff471c51c1edfc822007069697ed7", "https://bcr.bazel.build/modules/boost.static_assert/1.83.0/MODULE.bazel": "680325e3252ae8306555bcf0539d16dcf9ccf9656d8781dfa3449a554d8da016", - "https://bcr.bazel.build/modules/boost.system/1.83.0.bcr.1/MODULE.bazel": "5f905d0fbb1ce99231f3fa278b2e5999aa7395c6393ac42d479ae21824adf03f", - "https://bcr.bazel.build/modules/boost.system/1.83.0.bcr.1/source.json": "0676ab63c01c5ddf731a5cf54667ffc6560e9fb52401a2a9ac6a10c5a9909019", "https://bcr.bazel.build/modules/boost.system/1.83.0/MODULE.bazel": "76354b72be5998bb286e9a4d8e439a04c5b29ebd0e51bfe669258dcd5c3a8c4f", + "https://bcr.bazel.build/modules/boost.system/1.83.0/source.json": "3fbc84c45cbf732cde3913371a077e22ecd3f0096cffb003c24a9ac304988ee4", "https://bcr.bazel.build/modules/boost.throw_exception/1.83.0.bcr.1/MODULE.bazel": "b757c832f5f5f818d87c9eaa993d3eb211554197321c3edf641e2c8821cf19c2", "https://bcr.bazel.build/modules/boost.throw_exception/1.83.0.bcr.1/source.json": "c752d584840e9183141f9d53f07f2051016c16771a973cdd1487f9585980c2e5", "https://bcr.bazel.build/modules/boost.throw_exception/1.83.0/MODULE.bazel": "5df92502378293277ca48837e41f33805ede9e6165acefbf83d96b861919e56e", - "https://bcr.bazel.build/modules/boost.tti/1.83.0.bcr.1/MODULE.bazel": "86dd0d443379e67bb41e9b8c9097d652699ddfc0986bd2fb0462f6f5294ee84d", - "https://bcr.bazel.build/modules/boost.tti/1.83.0.bcr.1/source.json": "660900b6e3615af5b222cf699ff220787d0550e380e28408e475a6a0d354d794", "https://bcr.bazel.build/modules/boost.tuple/1.83.0.bcr.1/MODULE.bazel": "1d540b5efd3b65eeabd3621e5187a799e21bfa9ffc6afd7d4ad307cc4a27a6d4", "https://bcr.bazel.build/modules/boost.tuple/1.83.0.bcr.1/source.json": "7aa33ec2aaae45605049ea0ec1c1de9517a1e1278a22d0a521521d4023f9ad87", "https://bcr.bazel.build/modules/boost.tuple/1.83.0/MODULE.bazel": "96640d5e7abec507a5dd63a9f8a6f90e45bc02d7fed6a0fb51e5e869bba9fecd", @@ -168,14 +162,10 @@ "https://bcr.bazel.build/modules/boost.utility/1.83.0.bcr.1/MODULE.bazel": "1346dc27d6c8b7ced10896224ed3e406adac3fd79c8450d78c291228f1b9075d", "https://bcr.bazel.build/modules/boost.utility/1.83.0.bcr.1/source.json": "15636369b5452784e7bd04f7ae52c751e591dd4ebb852688fccc66234d452929", "https://bcr.bazel.build/modules/boost.utility/1.83.0/MODULE.bazel": "e122ee2a63d4e76dec8d2f81b13f95b7638fcbcd15f752610a3343f13bdb97fd", - "https://bcr.bazel.build/modules/boost.uuid/1.83.0.bcr.1/MODULE.bazel": "0ec51572f062cfb4795cf57c84a16b1b61890954dead42dd55e8616e09159c37", - "https://bcr.bazel.build/modules/boost.uuid/1.83.0.bcr.1/source.json": "d53f1653d8c2062f42884311de23220573e287b1eb224c465845357ded8c5a89", - "https://bcr.bazel.build/modules/boost.variant2/1.83.0.bcr.1/MODULE.bazel": "c60baa3b8923712a156197ffaf5cf9972bf35e44d00a90f7019a06761f391d3e", - "https://bcr.bazel.build/modules/boost.variant2/1.83.0.bcr.1/source.json": "041f94707bd509bc1f93782ccb6c38d491ac0dca25d26011f2047639d3a2bcd1", "https://bcr.bazel.build/modules/boost.variant2/1.83.0/MODULE.bazel": "483c9fd260c0ff24177d8c54324dcff8784783ca37a710f90db8e7664b20fa8d", - "https://bcr.bazel.build/modules/boost.winapi/1.83.0.bcr.1/MODULE.bazel": "faf78b50dae672a38b77db545a460428cfe47a8d79466455ef397d76037e9e40", - "https://bcr.bazel.build/modules/boost.winapi/1.83.0.bcr.1/source.json": "0957b4dabe425e7f9d8d02db63969b808a280c11388ad7814e50b0779ae592cc", + "https://bcr.bazel.build/modules/boost.variant2/1.83.0/source.json": "7898383e5af63e9813d19acdc7557a9ee0f3385fa53cbf7239882dd800dccbc2", "https://bcr.bazel.build/modules/boost.winapi/1.83.0/MODULE.bazel": "2521094214a0182fea27fb2551be0447611b0422fba98ab5cf66fe6fd0d00809", + "https://bcr.bazel.build/modules/boost.winapi/1.83.0/source.json": "9ea22a6d3bd3e6512b4b469d0805add0a6df5d2e83da13becf4b08b5b108a548", "https://bcr.bazel.build/modules/boringssl/0.0.0-20230215-5c22014/MODULE.bazel": "4b03dc0d04375fa0271174badcd202ed249870c8e895b26664fd7298abea7282", "https://bcr.bazel.build/modules/boringssl/0.0.0-20230215-5c22014/source.json": "f90873cd3d891bb63ece55a527d97366da650f84c79c2109bea29c17629bee20", "https://bcr.bazel.build/modules/brotli/1.1.0/MODULE.bazel": "3b5b90488995183419c4b5c9b063a164f6c0bc4d0d6b40550a612a5e860cc0fe", @@ -227,6 +217,8 @@ "https://bcr.bazel.build/modules/libpfm/4.11.0/source.json": "caaffb3ac2b59b8aac456917a4ecf3167d40478ee79f15ab7a877ec9273937c9", "https://bcr.bazel.build/modules/lz4/1.9.4/MODULE.bazel": "e3d307b1d354d70f6c809167eafecf5d622c3f27e3971ab7273410f429c7f83a", "https://bcr.bazel.build/modules/lz4/1.9.4/source.json": "233f0bdfc21f254e3dda14683ddc487ca68c6a3a83b7d5db904c503f85bd089b", + "https://bcr.bazel.build/modules/magic_enum/0.9.6/MODULE.bazel": "2b8db5bbd5d456dfb1e05cafd4a572374d461ffd2e0bd6970b9060dca2200618", + "https://bcr.bazel.build/modules/magic_enum/0.9.6/source.json": "abac9e9c84a47db89960a4c5a585d607bdfe51a60d7e3285dfe1a3dca50d1107", "https://bcr.bazel.build/modules/msgpack-c/6.1.0/MODULE.bazel": "2822ba864146468b3128216ad416f8b39b511395e88d896d472c9c6b30b1ceb2", "https://bcr.bazel.build/modules/msgpack-c/6.1.0/source.json": "b412dd4c8290ea0cce122616076e62ffe1b0799cebd6422608c407608193c1c9", "https://bcr.bazel.build/modules/nlohmann_json/3.11.3/MODULE.bazel": "87023db2f55fc3a9949c7b08dc711fae4d4be339a80a99d04453c4bb3998eefc", @@ -414,7 +406,6 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.describe/1.83.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.detail/1.83.0.bcr.2/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.detail/1.83.0/MODULE.bazel": "not found", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.dynamic_bitset/1.83.0.bcr.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.dynamic_bitset/1.83.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.exception/1.83.0.bcr.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.function/1.83.0.bcr.1/MODULE.bazel": "not found", @@ -441,7 +432,6 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.mpl/1.83.0.bcr.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.mpl/1.83.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.multiprecision/1.83.0/MODULE.bazel": "not found", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.numeric_conversion/1.83.0.bcr.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.numeric_conversion/1.83.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.optional/1.83.0.bcr.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.optional/1.83.0/MODULE.bazel": "not found", @@ -449,7 +439,6 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.predef/1.83.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.preprocessor/1.83.0.bcr.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.preprocessor/1.83.0/MODULE.bazel": "not found", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.random/1.83.0.bcr.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.random/1.83.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.range/1.83.0.bcr.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.range/1.83.0/MODULE.bazel": "not found", @@ -459,11 +448,9 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.smart_ptr/1.83.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.static_assert/1.83.0.bcr.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.static_assert/1.83.0/MODULE.bazel": "not found", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.system/1.83.0.bcr.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.system/1.83.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.throw_exception/1.83.0.bcr.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.throw_exception/1.83.0/MODULE.bazel": "not found", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.tti/1.83.0.bcr.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.tuple/1.83.0.bcr.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.tuple/1.83.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.type_traits/1.83.0.bcr.1/MODULE.bazel": "not found", @@ -473,10 +460,7 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.unordered/1.83.0.bcr.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.utility/1.83.0.bcr.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.utility/1.83.0/MODULE.bazel": "not found", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.uuid/1.83.0.bcr.1/MODULE.bazel": "not found", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.variant2/1.83.0.bcr.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.variant2/1.83.0/MODULE.bazel": "not found", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.winapi/1.83.0.bcr.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boost.winapi/1.83.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/boringssl/0.0.0-20230215-5c22014/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/brotli/1.1.0/MODULE.bazel": "not found", @@ -540,6 +524,7 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/log4cplus/2.1.1/MODULE.bazel": "76862ff5868200f79b0d5ba641fe9eab2dba35324b4ec21559572568148348cc", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/log4cplus/2.1.1/source.json": "67bb02b692db0c21fb73ee30b69f06a481cca7d1eab084a003db272b735a09d5", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/lz4/1.9.4/MODULE.bazel": "not found", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/magic_enum/0.9.6/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/mcl/1.99/MODULE.bazel": "e2bf3654186853610a74833e398fc3b6de6d9ccbe8fa67eaa3ae58d3344940ef", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/mcl/1.99/source.json": "d38d4c7dbd9fb31bcabcc55c0336d82044828b27548c928389cdb6fba05029bd", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/msgpack-c/6.1.0/MODULE.bazel": "not found", @@ -563,6 +548,8 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/platforms/0.0.9/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/prometheus-cpp/1.2.4/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/protobuf/27.3/MODULE.bazel": "not found", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/psi/0.6.0.dev250123/MODULE.bazel": "50a6b02c8227fdd1c555fb578759e796b277930a0be8a81bd060c3681b5d2d97", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/psi/0.6.0.dev250123/source.json": "7d1c793f60df15d3f65e640faceb933f7dcd4a686104b862e7442e77be1a01f9", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/pybind11_bazel/2.11.1.bzl.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/pybind11_bazel/2.11.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/pybind11_bazel/2.12.0/MODULE.bazel": "not found", @@ -960,41 +947,6 @@ ] } }, - "@@grpc~//bazel:grpc_python_deps.bzl%grpc_python_deps_ext": { - "general": { - "bzlTransitiveDigest": "I1aLu6/WXl6aoKVzpM9MA+NKV6ciLTR8aaO7bH7eQmM=", - "usagesDigest": "mC5Q6fSQ6w2MmjAgOJ9ywsgwtICfbdjQHtboeGbueQ4=", - "recordedFileInputs": {}, - "recordedDirentsInputs": {}, - "envVariables": {}, - "generatedRepoSpecs": { - "cython": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_archive", - "attributes": { - "build_file": "@@grpc~//third_party:cython.BUILD", - "sha256": "a2da56cc22be823acf49741b9aa3aa116d4f07fa8e8b35a3cb08b8447b37c607", - "strip_prefix": "cython-0.29.35", - "urls": [ - "https://github.com/cython/cython/archive/0.29.35.tar.gz" - ] - } - } - }, - "recordedRepoMappingEntries": [ - [ - "grpc~", - "bazel_tools", - "bazel_tools" - ], - [ - "grpc~", - "com_github_grpc_grpc", - "grpc~" - ] - ] - } - }, "@@platforms//host:extension.bzl%host_platform": { "general": { "bzlTransitiveDigest": "xelQcPZH8+tmuOHVjL9vDxMnnQNMlwj0SlvgoqBkm4U=", @@ -1014,8 +966,8 @@ }, "@@psi~//bazel:defs.bzl%non_module_dependencies": { "general": { - "bzlTransitiveDigest": "iZicro3ric+m7Yx6rZ4AMTZFmI5E1fKjWlTYtccq0c0=", - "usagesDigest": "YazK0WcOM6F8vKHNAJ4v0sbDZtdcCet+BI71Sqswmg8=", + "bzlTransitiveDigest": "Lof+AoGR1XS+eNcPLWJjUy4ZzoavywKWa/MBlWLYpK8=", + "usagesDigest": "142x+mCv0x6bYMxZuDGK5hwEWv5b5QOMfJWGct/Ctls=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, "envVariables": {}, @@ -1684,7 +1636,7 @@ "@@spulib~//bazel:defs.bzl%non_module_dependencies": { "general": { "bzlTransitiveDigest": "JT8ZLEUdrYXN19gijrHtztFq/cEAhJlRlNjhtQUlDIE=", - "usagesDigest": "rAVxfb9Rb1soDTUGU4hQlLVbAbaoneI/pDfev4cmTz4=", + "usagesDigest": "H/unVLwVf9lkXyD6DNR/nT2UydRB+nHhr+iO46XfbSk=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, "envVariables": {}, diff --git a/examples/cpp/utils.cc b/examples/cpp/utils.cc index 4ff45b94d..fe8fb787e 100644 --- a/examples/cpp/utils.cc +++ b/examples/cpp/utils.cc @@ -19,8 +19,6 @@ #include "libspu/core/config.h" -#include "libspu/spu.pb.h" - llvm::cl::opt Parties( "parties", llvm::cl::init("127.0.0.1:61530,127.0.0.1:61531"), llvm::cl::desc("server list, format: host1:port1[,host2:port2, ...]")); @@ -52,13 +50,13 @@ std::unique_ptr MakeSPUContext() { auto lctx = MakeLink(Parties.getValue(), Rank.getValue()); spu::RuntimeConfig config; - config.set_protocol(static_cast(ProtocolKind.getValue())); - config.set_field(static_cast(Field.getValue())); + config.protocol = static_cast(ProtocolKind.getValue()); + config.field = static_cast(Field.getValue()); populateRuntimeConfig(config); - config.set_enable_action_trace(EngineTrace.getValue()); - config.set_enable_type_checker(EngineTrace.getValue()); + config.enable_action_trace = EngineTrace.getValue(); + config.enable_type_checker = EngineTrace.getValue(); return std::make_unique(config, lctx); } diff --git a/examples/python/ir_dump/ir_dump.py b/examples/python/ir_dump/ir_dump.py index 4e85eacec..98230ea1b 100644 --- a/examples/python/ir_dump/ir_dump.py +++ b/examples/python/ir_dump/ir_dump.py @@ -28,7 +28,7 @@ import jax.numpy as jnp import numpy as np -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.distributed as ppd logging.basicConfig(level=logging.INFO) @@ -48,7 +48,7 @@ dump_path = os.path.join(os.path.expanduser("~"), args.dir) logging.info(f"Dump path: {dump_path}") # refer to spu.proto for more detailed configuration -copts = spu_pb2.CompilerOptions() +copts = libspu.CompilerOptions() copts.enable_pretty_print = True copts.pretty_print_dump_dir = dump_path copts.xla_pp_kind = 2 diff --git a/examples/python/ml/flax_llama7b/flax_llama7b.py b/examples/python/ml/flax_llama7b/flax_llama7b.py index 1feb28a3d..a81a48b7e 100644 --- a/examples/python/ml/flax_llama7b/flax_llama7b.py +++ b/examples/python/ml/flax_llama7b/flax_llama7b.py @@ -31,7 +31,7 @@ from transformers import LlamaTokenizer import spu.intrinsic as intrinsic -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.distributed as ppd parser = argparse.ArgumentParser(description='distributed driver.') @@ -46,7 +46,7 @@ ppd.init(conf["nodes"], conf["devices"]) -copts = spu_pb2.CompilerOptions() +copts = libspu.CompilerOptions() copts.enable_pretty_print = False copts.xla_pp_kind = 2 # enable x / broadcast(y) -> x * broadcast(1/y) diff --git a/examples/python/ml/flax_llama7b_split/flax_llama7b_split.py b/examples/python/ml/flax_llama7b_split/flax_llama7b_split.py index 1249d3740..0349fd217 100644 --- a/examples/python/ml/flax_llama7b_split/flax_llama7b_split.py +++ b/examples/python/ml/flax_llama7b_split/flax_llama7b_split.py @@ -39,7 +39,7 @@ from flax.linen.linear import Array from transformers import LlamaTokenizer -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.distributed as ppd parser = argparse.ArgumentParser(description='distributed driver.') @@ -59,7 +59,7 @@ ppd.init(conf["nodes"], conf["devices"]) -copts = spu_pb2.CompilerOptions() +copts = libspu.CompilerOptions() copts.enable_pretty_print = False copts.xla_pp_kind = 2 # enable x / broadcast(y) -> x * broadcast(1/y) diff --git a/examples/python/ml/flax_whisper/flax_whisper.py b/examples/python/ml/flax_whisper/flax_whisper.py index b6c429ef9..cf9b221b5 100644 --- a/examples/python/ml/flax_whisper/flax_whisper.py +++ b/examples/python/ml/flax_whisper/flax_whisper.py @@ -27,7 +27,7 @@ from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor import spu.utils.distributed as ppd -from spu import spu_pb2 +from spu import libspu parser = argparse.ArgumentParser(description='distributed driver.') parser.add_argument( @@ -68,7 +68,7 @@ def run_on_spu(): inputs_ids = processor(ds[0]["audio"]["array"], return_tensors="np") # Enable rewrite for better performance - copts = spu_pb2.CompilerOptions() + copts = libspu.CompilerOptions() copts.enable_optimize_denominator_with_broadcast = True input_ids = ppd.device("P1")(lambda x: x)(inputs_ids.input_features) diff --git a/experimental/squirrel/objectives.cc b/experimental/squirrel/objectives.cc index baf5f08dd..46567597e 100644 --- a/experimental/squirrel/objectives.cc +++ b/experimental/squirrel/objectives.cc @@ -291,11 +291,11 @@ namespace { spu::SPUContext* ctx, const spu::Value& _x, float threshold, spu::FieldType working_ft) { namespace sk = spu::kernel; - auto src_field = ctx->config().field(); + auto src_field = ctx->config().field; spu::Value x(CastRing(_x.data(), working_ft), _x.dtype()); // FIXME(lwj): dirty hack - const_cast(&ctx->config())->set_field(working_ft); + const_cast(&ctx->config())->field = working_ft; ctx->getState()->setField(working_ft); const auto ONE = sk::hal::_constant(ctx, 1, x.shape()); @@ -314,7 +314,7 @@ namespace { auto is_too_large = sk::hal::_xor( ctx, True, sk::hal::_or(ctx, is_neg, is_inside_range)); // x > t // FIXME(lwj): dirty hack - const_cast(&ctx->config())->set_field(src_field); + const_cast(&ctx->config())->field = src_field; ctx->getState()->setField(src_field); is_neg = spu::Value(CastRing(is_neg.data(), src_field), spu::DT_I1); diff --git a/experimental/squirrel/objectives_test.cc b/experimental/squirrel/objectives_test.cc index b2d042b88..46742217c 100644 --- a/experimental/squirrel/objectives_test.cc +++ b/experimental/squirrel/objectives_test.cc @@ -110,9 +110,9 @@ TEST_P(ObjectivesTest, MaxGain) { spu::mpc::utils::simulate(2, [&](std::shared_ptr lctx) { spu::RuntimeConfig rt_config; - rt_config.set_protocol(ProtocolKind::REF2K); - rt_config.set_field(field); - rt_config.set_fxp_fraction_bits(16); + rt_config.protocol = ProtocolKind::REF2K; + rt_config.field = field; + rt_config.fxp_fraction_bits = 16; auto _ctx = std::make_unique(rt_config, lctx); auto ctx = _ctx.get(); @@ -170,13 +170,12 @@ TEST_P(ObjectivesTest, Logistic) { spu::mpc::utils::simulate(2, [&](std::shared_ptr lctx) { spu::RuntimeConfig rt_config; - rt_config.set_protocol(ProtocolKind::CHEETAH); - rt_config.mutable_cheetah_2pc_config()->set_ot_kind( - CheetahOtKind::YACL_Softspoken); - rt_config.set_field(field); - rt_config.set_fxp_fraction_bits(17); - rt_config.set_enable_hal_profile(true); - rt_config.set_enable_pphlo_profile(true); + rt_config.protocol = ProtocolKind::CHEETAH; + rt_config.cheetah_2pc_config.ot_kind = CheetahOtKind::YACL_Softspoken; + rt_config.field = field; + rt_config.fxp_fraction_bits = 17; + rt_config.enable_hal_profile = true; + rt_config.enable_pphlo_profile = true; auto _ctx = std::make_unique(rt_config, lctx); auto ctx = _ctx.get(); @@ -198,7 +197,7 @@ TEST_P(ObjectivesTest, Logistic) { return; } - double fxp = std::pow(2., rt_config.fxp_fraction_bits()); + double fxp = std::pow(2., rt_config.fxp_fraction_bits); double max_err = 0.; for (int64_t i = 0; i < logistic.numel(); ++i) { @@ -241,13 +240,12 @@ TEST_P(ObjectivesTest, Sigmoid) { spu::mpc::utils::simulate(2, [&](std::shared_ptr lctx) { spu::RuntimeConfig rt_config; - rt_config.set_protocol(ProtocolKind::CHEETAH); - rt_config.mutable_cheetah_2pc_config()->set_ot_kind( - CheetahOtKind::YACL_Softspoken); - rt_config.set_field(field); - rt_config.set_fxp_fraction_bits(17); - rt_config.set_enable_hal_profile(true); - rt_config.set_enable_pphlo_profile(true); + rt_config.protocol = ProtocolKind::CHEETAH; + rt_config.cheetah_2pc_config.ot_kind = CheetahOtKind::YACL_Softspoken; + rt_config.field = field; + rt_config.fxp_fraction_bits = 17; + rt_config.enable_hal_profile = true; + rt_config.enable_pphlo_profile = true; auto _ctx = std::make_unique(rt_config, lctx); auto ctx = _ctx.get(); @@ -269,7 +267,7 @@ TEST_P(ObjectivesTest, Sigmoid) { return; } - double fxp = std::pow(2., rt_config.fxp_fraction_bits()); + double fxp = std::pow(2., rt_config.fxp_fraction_bits); double max_err = 0.; for (int64_t i = 0; i < logistic.numel(); ++i) { diff --git a/experimental/squirrel/squirrel_demo_main.cc b/experimental/squirrel/squirrel_demo_main.cc index 19276119b..850f143e4 100644 --- a/experimental/squirrel/squirrel_demo_main.cc +++ b/experimental/squirrel/squirrel_demo_main.cc @@ -118,11 +118,11 @@ std::unique_ptr MakeSPUContext() { auto lctx = MakeLink(Parties.getValue(), Rank.getValue()); spu::RuntimeConfig config; - config.set_protocol(spu::ProtocolKind::CHEETAH); - config.set_field(static_cast(Field.getValue())); - config.set_fxp_fraction_bits(18); - config.set_fxp_div_goldschmidt_iters(1); - config.set_enable_hal_profile(EngineTrace.getValue()); + config.protocol = spu::ProtocolKind::CHEETAH; + config.field = static_cast(Field.getValue()); + config.fxp_fraction_bits = 18; + config.fxp_div_goldschmidt_iters = 1; + config.enable_hal_profile = EngineTrace.getValue(); auto hctx = std::make_unique(config, lctx); spu::mpc::Factory::RegisterProtocol(hctx.get(), lctx); return hctx; @@ -218,7 +218,7 @@ void RunTest(spu::SPUContext* hctx, squirrel::XGBTreeBuilder& builder, } const int64_t nsamples = dframe.shape(0); - const double fxp = std::pow(2., hctx->config().fxp_fraction_bits()); + const double fxp = std::pow(2., hctx->config().fxp_fraction_bits); SPDLOG_DEBUG("Computing inference on testing set ..."); @@ -293,7 +293,7 @@ int main(int argc, char** argv) { bucket_size, nfeatures, peer_nfeatures); worker->BuildMap(dframe); - worker->Setup(8 * spu::SizeOf(hctx->config().field()), hctx->lctx()); + worker->Setup(8 * spu::SizeOf(hctx->config().field), hctx->lctx()); std::string act = Activation.getValue(); SPU_ENFORCE(act == "log" or act == "sig", "invalid activation type={}", act); @@ -335,7 +335,7 @@ int main(int argc, char** argv) { } // Test on train set - double fxp = std::pow(2., hctx->config().fxp_fraction_bits()); + double fxp = std::pow(2., hctx->config().fxp_fraction_bits); int32_t correct = 0; SPDLOG_DEBUG("Computing inference on training set ..."); for (int64_t i = 0; i < (int64_t)nsamples; ++i) { diff --git a/experimental/squirrel/tree_builder.cc b/experimental/squirrel/tree_builder.cc index 06abbdd51..e9ea46606 100644 --- a/experimental/squirrel/tree_builder.cc +++ b/experimental/squirrel/tree_builder.cc @@ -589,7 +589,7 @@ double XGBTreeBuilder::DEBUG_OpenObjects( Gsum = hal::reveal(ctx, Gsum); Hsum = hal::reveal(ctx, Hsum); weights = hal::reveal(ctx, weights); - const double fxp = std::pow(2., ctx->config().fxp_fraction_bits()); + const double fxp = std::pow(2., ctx->config().fxp_fraction_bits); double object = 0.0; for (int64_t i = 0; i < weights.numel(); ++i) { double G = Gsum.data().at(i) / fxp; @@ -606,7 +606,7 @@ double XGBTreeBuilder::DEBUG_OpenLoss(spu::SPUContext* ctx, using namespace spu::kernel; auto _pred = hal::reveal(ctx, pred); auto _label = hal::reveal(ctx, label); - double fxp = std::pow(2., ctx->config().fxp_fraction_bits()); + double fxp = std::pow(2., ctx->config().fxp_fraction_bits); double loss = 0.; for (int64_t i = 0; i < _pred.numel(); ++i) { double y = _label.data().at(i) / fxp; diff --git a/experimental/squirrel/utils.cc b/experimental/squirrel/utils.cc index 58a98d70e..8a7de3a38 100644 --- a/experimental/squirrel/utils.cc +++ b/experimental/squirrel/utils.cc @@ -127,7 +127,7 @@ spu::Value ArgMax(spu::SPUContext* ctx, const spu::Value& x, int axis, spu::Value MulArithShareWithPrivateBoolean(spu::SPUContext* ctx, const spu::Value& ashr) { - SPU_ENFORCE(ctx->config().protocol() == spu::ProtocolKind::CHEETAH); + SPU_ENFORCE(ctx->config().protocol == spu::ProtocolKind::CHEETAH); SPU_ENFORCE(ashr.isSecret()); spu::KernelEvalContext kctx(ctx); @@ -143,7 +143,7 @@ spu::Value MulArithShareWithPrivateBoolean(spu::SPUContext* ctx, spu::Value MulArithShareWithPrivateBoolean( spu::SPUContext* ctx, const spu::Value& ashr, absl::Span prv_boolean) { - SPU_ENFORCE(ctx->config().protocol() == spu::ProtocolKind::CHEETAH); + SPU_ENFORCE(ctx->config().protocol == spu::ProtocolKind::CHEETAH); SPU_ENFORCE(ashr.isSecret()); SPU_ENFORCE_EQ(ashr.numel(), (int64_t)prv_boolean.size()); @@ -160,9 +160,9 @@ spu::Value MulArithShareWithPrivateBoolean( spu::Value MulPrivateArithWithPrivateBoolean(spu::SPUContext* ctx, const spu::Value& arith) { using namespace spu; - SPU_ENFORCE(ctx->config().protocol() == spu::ProtocolKind::CHEETAH); + SPU_ENFORCE(ctx->config().protocol == spu::ProtocolKind::CHEETAH); spu::KernelEvalContext kctx(ctx); - auto ft = ctx->config().field(); + auto ft = ctx->config().field; auto out = mpc::cheetah::TiledDispatchOTFunc( &kctx, arith.data(), [&](const NdArrayRef& input, @@ -184,11 +184,11 @@ spu::Value MulPrivateArithWithPrivateBoolean(spu::SPUContext* ctx, const spu::DataType dtype, const spu::Shape& shape) { using namespace spu; - SPU_ENFORCE(ctx->config().protocol() == spu::ProtocolKind::CHEETAH); + SPU_ENFORCE(ctx->config().protocol == spu::ProtocolKind::CHEETAH); SPU_ENFORCE_EQ(boolean.size(), (size_t)shape.numel()); spu::KernelEvalContext kctx(ctx); - auto ft = ctx->config().field(); + auto ft = ctx->config().field; auto out = mpc::cheetah::TiledDispatchOTFunc( &kctx, boolean, [&](absl::Span input, @@ -210,10 +210,10 @@ spu::Value MulArithShareWithANDBoolShare(spu::SPUContext* ctx, SPU_ENFORCE(ashr.isSecret()); SPU_ENFORCE_EQ(ashr.numel(), (int64_t)bshr.size()); - SPU_ENFORCE(ctx->config().protocol() == spu::ProtocolKind::CHEETAH); + SPU_ENFORCE(ctx->config().protocol == spu::ProtocolKind::CHEETAH); spu::KernelEvalContext kctx(ctx); - auto ft = ctx->config().field(); + auto ft = ctx->config().field; int rank = ctx->lctx()->Rank(); auto out = mpc::cheetah::TiledDispatchOTFunc( diff --git a/experimental/squirrel/utils_test.cc b/experimental/squirrel/utils_test.cc index 8cde97fdb..6e81aeaf1 100644 --- a/experimental/squirrel/utils_test.cc +++ b/experimental/squirrel/utils_test.cc @@ -65,9 +65,9 @@ TEST_F(UtilsTest, ReduceSum) { spu::mpc::utils::simulate(2, [&](std::shared_ptr lctx) { spu::RuntimeConfig rt_config; - rt_config.set_protocol(ProtocolKind::REF2K); - rt_config.set_field(field); - rt_config.set_fxp_fraction_bits(16); + rt_config.protocol = ProtocolKind::REF2K; + rt_config.field = field; + rt_config.fxp_fraction_bits = 16; auto _ctx = std::make_unique(rt_config, lctx); auto* ctx = _ctx.get(); @@ -82,7 +82,7 @@ TEST_F(UtilsTest, ReduceSum) { ASSERT_EQ(expected.numel(), got.numel()); if (lctx->Rank() == 0) { - const double fxp = std::pow(2., rt_config.fxp_fraction_bits()); + const double fxp = std::pow(2., rt_config.fxp_fraction_bits); auto flatten = got.data().reshape({got.numel()}); DISPATCH_ALL_FIELDS(field, [&]() { @@ -117,9 +117,9 @@ TEST_F(UtilsTest, ArgMax) { spu::mpc::utils::simulate(2, [&](std::shared_ptr lctx) { spu::RuntimeConfig rt_config; - rt_config.set_protocol(ProtocolKind::REF2K); - rt_config.set_field(field); - rt_config.set_fxp_fraction_bits(16); + rt_config.protocol = ProtocolKind::REF2K; + rt_config.field = field; + rt_config.fxp_fraction_bits = 16; auto _ctx = std::make_unique(rt_config, lctx); auto* ctx = _ctx.get(); @@ -170,13 +170,12 @@ TEST_F(UtilsTest, MulA1BV) { spu::mpc::utils::simulate(2, [&](std::shared_ptr lctx) { spu::RuntimeConfig rt_config; - rt_config.set_protocol(ProtocolKind::CHEETAH); - rt_config.mutable_cheetah_2pc_config()->set_ot_kind( - CheetahOtKind::YACL_Softspoken); - rt_config.set_field(field); - rt_config.set_fxp_fraction_bits(16); - rt_config.set_experimental_enable_colocated_optimization(true); - rt_config.set_enable_hal_profile(true); + rt_config.protocol = ProtocolKind::CHEETAH; + rt_config.field = field; + rt_config.fxp_fraction_bits = 16; + rt_config.experimental_enable_colocated_optimization = true; + rt_config.enable_hal_profile = true; + rt_config.cheetah_2pc_config.ot_kind = CheetahOtKind::YACL_Softspoken; auto _ctx = std::make_unique(rt_config, lctx); auto* ctx = _ctx.get(); @@ -193,7 +192,7 @@ TEST_F(UtilsTest, MulA1BV) { c = hlo::Reveal(ctx, c); if (lctx->Rank() == 0) { - double scale = std::pow(2., rt_config.fxp_fraction_bits()); + double scale = std::pow(2., rt_config.fxp_fraction_bits); for (int64_t i = 0; i < c.numel(); ++i) { if (ind[i]) { ASSERT_NEAR(c.data().at(i) / scale, _x[i], 2. / scale); @@ -232,13 +231,12 @@ TEST_F(UtilsTest, MulA1B_AND_style) { spu::mpc::utils::simulate(2, [&](std::shared_ptr lctx) { spu::RuntimeConfig rt_config; - rt_config.set_protocol(ProtocolKind::CHEETAH); - rt_config.set_field(field); - rt_config.mutable_cheetah_2pc_config()->set_ot_kind( - CheetahOtKind::YACL_Softspoken); - rt_config.set_fxp_fraction_bits(16); - rt_config.set_experimental_enable_colocated_optimization(true); - rt_config.set_enable_hal_profile(true); + rt_config.protocol = ProtocolKind::CHEETAH; + rt_config.field = field; + rt_config.fxp_fraction_bits = 16; + rt_config.experimental_enable_colocated_optimization = true; + rt_config.enable_hal_profile = true; + rt_config.cheetah_2pc_config.ot_kind = CheetahOtKind::YACL_Softspoken; auto _ctx = std::make_unique(rt_config, lctx); auto* ctx = _ctx.get(); @@ -253,7 +251,7 @@ TEST_F(UtilsTest, MulA1B_AND_style) { c = hlo::Reveal(ctx, c); if (lctx->Rank() == 0) { - double scale = std::pow(2., rt_config.fxp_fraction_bits()); + double scale = std::pow(2., rt_config.fxp_fraction_bits); for (int64_t i = 0; i < c.numel(); ++i) { if (ind[0][i] & ind[1][i]) { ASSERT_NEAR(c.data().at(i) / scale, _x[i], 2. / scale); @@ -291,13 +289,12 @@ TEST_F(UtilsTest, BatchMulA1B_AND_style) { spu::mpc::utils::simulate(2, [&](std::shared_ptr lctx) { spu::RuntimeConfig rt_config; - rt_config.set_protocol(ProtocolKind::CHEETAH); - rt_config.mutable_cheetah_2pc_config()->set_ot_kind( - CheetahOtKind::YACL_Softspoken); - rt_config.set_field(field); - rt_config.set_fxp_fraction_bits(16); - rt_config.set_experimental_enable_colocated_optimization(true); - rt_config.set_enable_hal_profile(true); + rt_config.protocol = ProtocolKind::CHEETAH; + rt_config.field = field; + rt_config.fxp_fraction_bits = 16; + rt_config.experimental_enable_colocated_optimization = true; + rt_config.enable_hal_profile = true; + rt_config.cheetah_2pc_config.ot_kind = CheetahOtKind::YACL_Softspoken; auto _ctx = std::make_unique(rt_config, lctx); auto* ctx = _ctx.get(); @@ -309,7 +306,7 @@ TEST_F(UtilsTest, BatchMulA1B_AND_style) { c = hlo::Reveal(ctx, c); if (lctx->Rank() == 0) { - double scale = std::pow(2., rt_config.fxp_fraction_bits()); + double scale = std::pow(2., rt_config.fxp_fraction_bits); for (int64_t k = 0; k < c.numel(); k += batch_size) { for (int64_t i = 0; i < batch_size; ++i) { diff --git a/pyrightconfig.json b/pyrightconfig.json index fbcb11539..244479187 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -1,5 +1,10 @@ { "executionEnvironments": [ {"root": "."} + ], + "exclude":[ + ".cache", + "external", + "**/bazel-*" ] } diff --git a/sml/cluster/tests/BUILD.bazel b/sml/cluster/tests/BUILD.bazel index 413069c74..764dd21d1 100644 --- a/sml/cluster/tests/BUILD.bazel +++ b/sml/cluster/tests/BUILD.bazel @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@spu_pip_dev//:requirements.bzl", "requirement") load("//bazel:spu.bzl", "spu_py_test") package(default_visibility = ["//visibility:public"]) @@ -23,6 +24,6 @@ spu_py_test( "//sml/cluster:kmeans", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//scikit_learn:pkg", + requirement("scikit-learn"), ], ) diff --git a/sml/cluster/tests/kmeans_test.py b/sml/cluster/tests/kmeans_test.py index 0503a3db5..2800e9f94 100644 --- a/sml/cluster/tests/kmeans_test.py +++ b/sml/cluster/tests/kmeans_test.py @@ -19,16 +19,14 @@ import numpy as np from sklearn.datasets import make_blobs -import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.libspu as libspu # type: ignore import spu.utils.simulation as spsim from sml.cluster.kmeans import KMEANS class UnitTests(unittest.TestCase): def test_kmeans(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def proc(x1, x2): x = jnp.concatenate((x1, x2), axis=1) @@ -57,9 +55,7 @@ def load_data(): print("sklearn:\n", model.fit(X).predict(X)) def test_kmeans_kmeans_plus_plus(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) X = jnp.array([[-4, -3, -2, -1], [-4, -3, -2, -1]]).T @@ -96,9 +92,7 @@ def proc(x): np.testing.assert_allclose(result, sk_result, rtol=0, atol=1e-4) def test_kmeans_init_array(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def proc(x, init): model = KMEANS( @@ -123,9 +117,7 @@ def proc(x, init): np.testing.assert_allclose(result, sk_result, rtol=0, atol=1e-4) def test_kmeans_random(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) X = jnp.array([[-4, -3, -2, -1], [-4, -3, -2, -1]]).T diff --git a/sml/decomposition/nmf.py b/sml/decomposition/nmf.py index fff3c6109..d704c70a9 100644 --- a/sml/decomposition/nmf.py +++ b/sml/decomposition/nmf.py @@ -151,9 +151,9 @@ def fit(self, X): ----- (1) To prevent overflow error when using large data sets or get more accurate results, you can modify the definition of simulator as follows: - config = spu_pb2.RuntimeConfig( - protocol=spu_pb2.ProtocolKind.ABY3, - field=spu_pb2.FieldType.FM128, + config = libspu.RuntimeConfig( + protocol=libspu.ProtocolKind.ABY3, + field=libspu.FieldType.FM128, fxp_fraction_bits=30, ) sim_aby = spsim.Simulator(3, config) diff --git a/sml/decomposition/pca.py b/sml/decomposition/pca.py index 6c942ed83..3ef8ea7fb 100644 --- a/sml/decomposition/pca.py +++ b/sml/decomposition/pca.py @@ -117,9 +117,9 @@ def fit(self, X): When use rsvd, there are a large number of continuous matrix multiplies inside, which will make the value expand rapidly and overflow, we can solve it in the following ways. Step 0: Modify the definition of simulator as follows: - config = spu_pb2.RuntimeConfig( - protocol=spu_pb2.ProtocolKind.ABY3, - field=spu_pb2.FieldType.FM128, + config = libspu.RuntimeConfig( + protocol=libspu.ProtocolKind.ABY3, + field=libspu.FieldType.FM128, fxp_fraction_bits=30, ) sim_aby = spsim.Simulator(3, config) diff --git a/sml/decomposition/tests/BUILD.bazel b/sml/decomposition/tests/BUILD.bazel index ae5784aa7..dcdc2e06d 100644 --- a/sml/decomposition/tests/BUILD.bazel +++ b/sml/decomposition/tests/BUILD.bazel @@ -13,6 +13,7 @@ # limitations under the License. load("@rules_python//python:defs.bzl", "py_test") +load("@spu_pip_dev//:requirements.bzl", "requirement") package(default_visibility = ["//visibility:public"]) @@ -23,7 +24,7 @@ py_test( "//sml/decomposition:pca", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//scikit_learn:pkg", + requirement("scikit-learn"), ], ) @@ -34,6 +35,6 @@ py_test( "//sml/decomposition:nmf", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//scikit_learn:pkg", + requirement("scikit-learn"), ], ) diff --git a/sml/decomposition/tests/nmf_test.py b/sml/decomposition/tests/nmf_test.py index 114bd4435..21f2d4504 100644 --- a/sml/decomposition/tests/nmf_test.py +++ b/sml/decomposition/tests/nmf_test.py @@ -19,7 +19,7 @@ import numpy as np from sklearn.decomposition import NMF as SklearnNMF -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as spsim # Add the sml directory to the path @@ -35,9 +35,9 @@ def setUpClass(cls): cls.random_seed = 0 np.random.seed(cls.random_seed) # NMF must use FM128 now, for heavy use of non-linear & matrix operations - config = spu_pb2.RuntimeConfig( - protocol=spu_pb2.ProtocolKind.ABY3, - field=spu_pb2.FieldType.FM128, + config = libspu.RuntimeConfig( + protocol=libspu.ProtocolKind.ABY3, + field=libspu.FieldType.FM128, fxp_fraction_bits=30, ) cls.sim = spsim.Simulator(3, config) diff --git a/sml/decomposition/tests/pca_test.py b/sml/decomposition/tests/pca_test.py index 8251537a3..fa578b121 100644 --- a/sml/decomposition/tests/pca_test.py +++ b/sml/decomposition/tests/pca_test.py @@ -21,7 +21,7 @@ from jax import random from sklearn.decomposition import PCA as SklearnPCA -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as spsim # Add the sml directory to the path @@ -39,11 +39,11 @@ def setUpClass(cls): # 1. init sim cls.sim64 = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + 3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64 ) - config128 = spu_pb2.RuntimeConfig( - protocol=spu_pb2.ProtocolKind.ABY3, - field=spu_pb2.FieldType.FM128, + config128 = libspu.RuntimeConfig( + protocol=libspu.ProtocolKind.ABY3, + field=libspu.FieldType.FM128, fxp_fraction_bits=30, ) cls.sim128 = spsim.Simulator(3, config128) diff --git a/sml/ensemble/tests/BUILD.bazel b/sml/ensemble/tests/BUILD.bazel index 5b068b310..a9d789740 100644 --- a/sml/ensemble/tests/BUILD.bazel +++ b/sml/ensemble/tests/BUILD.bazel @@ -13,6 +13,7 @@ # limitations under the License. load("@rules_python//python:defs.bzl", "py_test") +load("@spu_pip_dev//:requirements.bzl", "requirement") package(default_visibility = ["//visibility:public"]) @@ -23,7 +24,7 @@ py_test( "//sml/ensemble:adaboost", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//scikit_learn:pkg", + requirement("scikit-learn"), ], ) @@ -34,6 +35,6 @@ py_test( "//sml/ensemble:forest", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//scikit_learn:pkg", + requirement("scikit-learn"), ], ) diff --git a/sml/ensemble/tests/adaboost_test.py b/sml/ensemble/tests/adaboost_test.py index 71f6b92d3..4921c51f4 100644 --- a/sml/ensemble/tests/adaboost_test.py +++ b/sml/ensemble/tests/adaboost_test.py @@ -18,7 +18,7 @@ from sklearn.ensemble import AdaBoostClassifier from sklearn.tree import DecisionTreeClassifier -import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.libspu as libspu # type: ignore import spu.utils.simulation as spsim from sml.ensemble.adaboost import AdaBoostClassifier as sml_Adaboost from sml.tree.tree import DecisionTreeClassifier as sml_dtc @@ -66,9 +66,7 @@ def load_data(): X, y = new_features[:, ::3], iris_label[:] return X, y - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) X, y = load_data() n_samples, n_features = X.shape diff --git a/sml/ensemble/tests/forest_test.py b/sml/ensemble/tests/forest_test.py index 6c9d3280c..ccbad71de 100644 --- a/sml/ensemble/tests/forest_test.py +++ b/sml/ensemble/tests/forest_test.py @@ -17,7 +17,7 @@ from sklearn.datasets import load_iris from sklearn.ensemble import RandomForestClassifier -import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.libspu as libspu import spu.utils.simulation as spsim from sml.ensemble.forest import RandomForestClassifier as sml_rfc @@ -71,9 +71,7 @@ def load_data(): return X, y # bandwidth and latency only work for docker mode - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) # load mock data X, y = load_data() diff --git a/sml/faq.md b/sml/faq.md index 97ba95bf8..c8818b959 100644 --- a/sml/faq.md +++ b/sml/faq.md @@ -11,9 +11,9 @@ ```python # for simulator - config = spu_pb2.RuntimeConfig( - protocol=spu_pb2.ProtocolKind.ABY3, - field=spu_pb2.FieldType.FM128, # change filed size here + config = libspu.RuntimeConfig( + protocol=libspu.ProtocolKind.ABY3, + field=libspu.FieldType.FM128, # change filed size here fxp_fraction_bits=30, # change fxp here ) sim = spsim.Simulator(3, config) diff --git a/sml/feature_selection/tests/BUILD.bazel b/sml/feature_selection/tests/BUILD.bazel index 4a570e135..26d85c6c6 100644 --- a/sml/feature_selection/tests/BUILD.bazel +++ b/sml/feature_selection/tests/BUILD.bazel @@ -13,6 +13,7 @@ # limitations under the License. load("@rules_python//python:defs.bzl", "py_test") +load("@spu_pip_dev//:requirements.bzl", "requirement") package(default_visibility = ["//visibility:public"]) @@ -23,6 +24,6 @@ py_test( "//sml/feature_selection:univariate_selection", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//scikit_learn:pkg", + requirement("scikit-learn"), ], ) diff --git a/sml/feature_selection/tests/chi2_test.py b/sml/feature_selection/tests/chi2_test.py index 51435eab5..35a86140d 100644 --- a/sml/feature_selection/tests/chi2_test.py +++ b/sml/feature_selection/tests/chi2_test.py @@ -17,11 +17,12 @@ import unittest import numpy as np -import spu.spu_pb2 as spu_pb2 -import spu.utils.simulation as spsim from sklearn.datasets import load_iris from sklearn.feature_selection import chi2 as chi2_sklearn +import spu.libspu as libspu +import spu.utils.simulation as spsim + sys.path.append(os.path.join(os.path.dirname(__file__), '../../../')) from sml.feature_selection.univariate_selection import chi2 @@ -29,7 +30,7 @@ class UnitTests(unittest.TestCase): def test_chi2(self): sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128 + 3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM128 ) def proc(x, y, num_class, max_iter, compute_p_value): diff --git a/sml/gaussian_process/tests/BUILD.bazel b/sml/gaussian_process/tests/BUILD.bazel index da0c809df..d6dd443bd 100644 --- a/sml/gaussian_process/tests/BUILD.bazel +++ b/sml/gaussian_process/tests/BUILD.bazel @@ -13,6 +13,7 @@ # limitations under the License. load("@rules_python//python:defs.bzl", "py_test") +load("@spu_pip_dev//:requirements.bzl", "requirement") package(default_visibility = ["//visibility:public"]) @@ -23,6 +24,6 @@ py_test( "//sml/gaussian_process:_gpc", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//scikit_learn:pkg", + requirement("scikit-learn"), ], ) diff --git a/sml/gaussian_process/tests/gpc_test.py b/sml/gaussian_process/tests/gpc_test.py index 9155a5b60..05358161b 100644 --- a/sml/gaussian_process/tests/gpc_test.py +++ b/sml/gaussian_process/tests/gpc_test.py @@ -20,7 +20,7 @@ import jax.numpy as jnp from sklearn.datasets import load_iris -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as spsim # Add the library directory to the path @@ -31,7 +31,7 @@ class UnitTests(unittest.TestCase): def test_gpc(self): sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128 + 3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM128 ) # Test GaussianProcessClassifier diff --git a/sml/linear_model/tests/BUILD.bazel b/sml/linear_model/tests/BUILD.bazel index 327be63d9..46ec0f425 100644 --- a/sml/linear_model/tests/BUILD.bazel +++ b/sml/linear_model/tests/BUILD.bazel @@ -13,6 +13,8 @@ # limitations under the License. load("@rules_python//python:defs.bzl", "py_test") +load("@spu_pip//:requirements.bzl", "requirement") +load("@spu_pip_dev//:requirements.bzl", dev_requirement = "requirement") package(default_visibility = ["//visibility:public"]) @@ -27,8 +29,8 @@ py_test( "//sml/linear_model:sgd_classifier", "//spu:init", "//spu/utils:simulation", - "@spu_pip//jax:pkg", - "@spu_pip_dev//scikit_learn:pkg", + requirement("jax"), + dev_requirement("scikit-learn"), ], ) @@ -39,8 +41,8 @@ py_test( "//sml/linear_model:logistic", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//pandas:pkg", - "@spu_pip_dev//scikit_learn:pkg", + dev_requirement("pandas"), + dev_requirement("scikit-learn"), ], ) @@ -52,7 +54,7 @@ py_test( "//sml/linear_model:ridge", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//scikit_learn:pkg", + dev_requirement("scikit-learn"), ], ) @@ -63,9 +65,9 @@ py_test( "//sml/linear_model:pla", "//spu:init", "//spu/utils:simulation", - "@spu_pip//jax:pkg", - "@spu_pip_dev//pandas:pkg", - "@spu_pip_dev//scikit_learn:pkg", + requirement("jax"), + dev_requirement("pandas"), + dev_requirement("scikit-learn"), ], ) @@ -76,7 +78,7 @@ py_test( "//sml/linear_model:glm", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//scikit_learn:pkg", + dev_requirement("scikit-learn"), ], ) @@ -87,8 +89,8 @@ py_test( "//sml/linear_model:quantile", "//spu:init", "//spu/utils:simulation", - "@spu_pip//jax:pkg", - "@spu_pip_dev//pandas:pkg", - "@spu_pip_dev//scikit_learn:pkg", + requirement("jax"), + dev_requirement("pandas"), + dev_requirement("scikit-learn"), ], ) diff --git a/sml/linear_model/tests/glm_test.py b/sml/linear_model/tests/glm_test.py index f42b09b6e..81f81312b 100644 --- a/sml/linear_model/tests/glm_test.py +++ b/sml/linear_model/tests/glm_test.py @@ -23,7 +23,7 @@ _GeneralizedLinearRegressor as std__GeneralizedLinearRegressor, ) -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as spsim from sml.linear_model.glm import ( GammaRegressor, @@ -61,7 +61,7 @@ def generate_data(): X, y, coef, sample_weight = generate_data() exp_y = jnp.exp(y) round_exp_y = jnp.round(exp_y) -sim = spsim.Simulator.simple(3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128) +sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM128) def accuracy_test(model, std_model, y, coef, num=5): diff --git a/sml/linear_model/tests/logistic_test.py b/sml/linear_model/tests/logistic_test.py index e440d4b83..12293629f 100644 --- a/sml/linear_model/tests/logistic_test.py +++ b/sml/linear_model/tests/logistic_test.py @@ -21,7 +21,7 @@ from sklearn.metrics import roc_auc_score from sklearn.preprocessing import MinMaxScaler -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as spsim # Add the library directory to the path @@ -32,9 +32,7 @@ class UnitTests(unittest.TestCase): @staticmethod def load_data(multi_class="binary"): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) # Create dataset if multi_class == "binary": X, y = load_breast_cancer(return_X_y=True, as_frame=True) diff --git a/sml/linear_model/tests/pla_test.py b/sml/linear_model/tests/pla_test.py index 1224fd4f0..f8b1f3011 100644 --- a/sml/linear_model/tests/pla_test.py +++ b/sml/linear_model/tests/pla_test.py @@ -20,16 +20,14 @@ import sklearn.linear_model as sk from sklearn.datasets import load_iris -import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.libspu as libspu # type: ignore import spu.utils.simulation as spsim from sml.linear_model.pla import Perceptron class UnitTests(unittest.TestCase): def test_pla(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def proc(x, y): model = Perceptron( diff --git a/sml/linear_model/tests/quantile_test.py b/sml/linear_model/tests/quantile_test.py index d9d12f1f8..461810748 100644 --- a/sml/linear_model/tests/quantile_test.py +++ b/sml/linear_model/tests/quantile_test.py @@ -17,7 +17,7 @@ import jax.numpy as jnp from sklearn.linear_model import QuantileRegressor as SklearnQuantileRegressor -import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.libspu as libspu import spu.utils.simulation as spsim from sml.linear_model.quantile import QuantileRegressor as SmlQuantileRegressor @@ -58,9 +58,7 @@ def generate_data(): return X, y # bandwidth and latency only work for docker mode - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) X, y = generate_data() diff --git a/sml/linear_model/tests/ridge_test.py b/sml/linear_model/tests/ridge_test.py index 9fb01a52b..c34f743d9 100644 --- a/sml/linear_model/tests/ridge_test.py +++ b/sml/linear_model/tests/ridge_test.py @@ -17,7 +17,7 @@ from sklearn.linear_model import Ridge as skRidge import examples.python.utils.dataset_utils as dsutil -import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.libspu as libspu import spu.utils.simulation as spsim from sml.linear_model.ridge import Ridge @@ -27,9 +27,7 @@ def test_ridge(self): solver_list = ['cholesky', 'svd'] print(f"solver_list={solver_list}") - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def proc(x1, x2, y, solver): model = Ridge(alpha=1.0, max_iter=100, solver=solver) diff --git a/sml/linear_model/tests/sgd_classifier_test.py b/sml/linear_model/tests/sgd_classifier_test.py index 44f1f5deb..f7452b369 100644 --- a/sml/linear_model/tests/sgd_classifier_test.py +++ b/sml/linear_model/tests/sgd_classifier_test.py @@ -19,16 +19,14 @@ # TODO: unify this. import examples.python.utils.dataset_utils as dsutil -import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.libspu as libspu import spu.utils.simulation as spsim from sml.linear_model.sgd_classifier import SGDClassifier class UnitTests(unittest.TestCase): def test_sgd(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def proc(x1, x2, y): model = SGDClassifier( diff --git a/sml/metrics/classification/BUILD.bazel b/sml/metrics/classification/BUILD.bazel index 446294dbf..d97a7fc6d 100644 --- a/sml/metrics/classification/BUILD.bazel +++ b/sml/metrics/classification/BUILD.bazel @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@spu_pip_dev//:requirements.bzl", "requirement") load("//bazel:spu.bzl", "spu_py_library", "spu_py_test") package(default_visibility = ["//visibility:public"]) @@ -39,7 +40,7 @@ spu_py_test( ":classification", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//scikit_learn:pkg", + requirement("scikit-learn"), ], ) diff --git a/sml/metrics/classification/classification_test.py b/sml/metrics/classification/classification_test.py index aeda67214..006e6569f 100644 --- a/sml/metrics/classification/classification_test.py +++ b/sml/metrics/classification/classification_test.py @@ -21,7 +21,7 @@ import numpy as np from sklearn import metrics -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as spsim # add ops dir to the path @@ -44,9 +44,7 @@ class UnitTests(unittest.TestCase): def test_auc(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def bin_count(y_true, y_pred, bin_size): thresholds = equal_obs(y_pred, bin_size) @@ -89,7 +87,7 @@ def digitize(y_pred, thresholds): def test_classification(self): sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128 + 3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM128 ) def proc( @@ -159,7 +157,7 @@ def check(spu_result, sk_result): def test_average_precision_score(self): sim = spsim.Simulator.simple( - 2, spu_pb2.ProtocolKind.SEMI2K, spu_pb2.FieldType.FM64 + 2, libspu.ProtocolKind.SEMI2K, libspu.FieldType.FM64 ) def proc(y_true, y_score, **kwargs): diff --git a/sml/metrics/regression/BUILD.bazel b/sml/metrics/regression/BUILD.bazel index 7e9ef59c4..5a00aed64 100644 --- a/sml/metrics/regression/BUILD.bazel +++ b/sml/metrics/regression/BUILD.bazel @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@spu_pip_dev//:requirements.bzl", "requirement") load("//bazel:spu.bzl", "spu_py_binary", "spu_py_library", "spu_py_test") package(default_visibility = ["//visibility:public"]) @@ -28,7 +29,7 @@ spu_py_test( ":regression", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//scikit_learn:pkg", + requirement("scikit-learn"), ], ) diff --git a/sml/metrics/regression/regression_test.py b/sml/metrics/regression/regression_test.py index dbd1b7479..b1eeb3ea8 100644 --- a/sml/metrics/regression/regression_test.py +++ b/sml/metrics/regression/regression_test.py @@ -20,25 +20,25 @@ import numpy as np from sklearn import metrics -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as spsim # add ops dir to the path sys.path.append(os.path.join(os.path.dirname(__file__), '../../../')) from sml.metrics.regression.regression import ( + d2_tweedie_score, explained_variance_score, - mean_squared_error, - mean_poisson_deviance, mean_gamma_deviance, - d2_tweedie_score, + mean_poisson_deviance, + mean_squared_error, ) class UnitTests(unittest.TestCase): def test_d2_tweedie_score(self): sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128 + 3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM128 ) power_list = [-1, 0, 1, 2, 3] @@ -63,7 +63,7 @@ def test_d2_tweedie_score(self): def test_explained_variance_score(self): sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128 + 3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM128 ) weight_list = [ @@ -90,7 +90,7 @@ def test_explained_variance_score(self): def test_mean_squared_error(self): sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128 + 3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM128 ) weight_list = [ @@ -113,7 +113,7 @@ def test_mean_squared_error(self): def test_mean_poisson_deviance(self): sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128 + 3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM128 ) weight_list = [ @@ -136,7 +136,7 @@ def test_mean_poisson_deviance(self): def test_mean_gamma_deviance(self): sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128 + 3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM128 ) weight_list = [ diff --git a/sml/naive_bayes/tests/BUILD.bazel b/sml/naive_bayes/tests/BUILD.bazel index 8d4680fc8..0e4efcc75 100644 --- a/sml/naive_bayes/tests/BUILD.bazel +++ b/sml/naive_bayes/tests/BUILD.bazel @@ -13,6 +13,7 @@ # limitations under the License. load("@rules_python//python:defs.bzl", "py_test") +load("@spu_pip_dev//:requirements.bzl", "requirement") package(default_visibility = ["//visibility:public"]) @@ -23,6 +24,6 @@ py_test( "//sml/naive_bayes:gnb", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//scikit_learn:pkg", + requirement("scikit-learn"), ], ) diff --git a/sml/naive_bayes/tests/gnb_test.py b/sml/naive_bayes/tests/gnb_test.py index 3af454008..8b93bd004 100644 --- a/sml/naive_bayes/tests/gnb_test.py +++ b/sml/naive_bayes/tests/gnb_test.py @@ -21,7 +21,7 @@ from sklearn import datasets from sklearn.naive_bayes import GaussianNB as SklearnGaussianNB -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as spsim # Add the sml directory to the path @@ -37,11 +37,11 @@ def setUpClass(cls): # 1. init sim cls.sim64 = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + 3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64 ) - config128 = spu_pb2.RuntimeConfig( - protocol=spu_pb2.ProtocolKind.ABY3, - field=spu_pb2.FieldType.FM128, + config128 = libspu.RuntimeConfig( + protocol=libspu.ProtocolKind.ABY3, + field=libspu.FieldType.FM128, fxp_fraction_bits=30, ) cls.sim128 = spsim.Simulator(3, config128) diff --git a/sml/neighbors/tests/BUILD.bazel b/sml/neighbors/tests/BUILD.bazel index c294ab560..93467c547 100644 --- a/sml/neighbors/tests/BUILD.bazel +++ b/sml/neighbors/tests/BUILD.bazel @@ -13,6 +13,7 @@ # limitations under the License. load("@rules_python//python:defs.bzl", "py_test") +load("@spu_pip_dev//:requirements.bzl", "requirement") package(default_visibility = ["//visibility:public"]) @@ -23,6 +24,6 @@ py_test( "//sml/neighbors:knn", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//scikit_learn:pkg", + requirement("scikit-learn"), ], ) diff --git a/sml/neighbors/tests/knn_test.py b/sml/neighbors/tests/knn_test.py index 59ff6dd9a..cf4a5bf85 100644 --- a/sml/neighbors/tests/knn_test.py +++ b/sml/neighbors/tests/knn_test.py @@ -20,7 +20,7 @@ import numpy as np from sklearn.neighbors import KNeighborsClassifier -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as spsim # Add the sml directory to the path @@ -31,9 +31,7 @@ class UnitTests(unittest.TestCase): def test_knn(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) # Test fit_predict def proc_predict( diff --git a/sml/preprocessing/tests/BUILD.bazel b/sml/preprocessing/tests/BUILD.bazel index ffb24ef30..f0a2360a5 100644 --- a/sml/preprocessing/tests/BUILD.bazel +++ b/sml/preprocessing/tests/BUILD.bazel @@ -13,6 +13,7 @@ # limitations under the License. load("@rules_python//python:defs.bzl", "py_test") +load("@spu_pip_dev//:requirements.bzl", "requirement") package(default_visibility = ["//visibility:public"]) @@ -23,6 +24,6 @@ py_test( "//sml/preprocessing", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//scikit_learn:pkg", + requirement("scikit-learn"), ], ) diff --git a/sml/preprocessing/tests/preprocessing_test.py b/sml/preprocessing/tests/preprocessing_test.py index 19fffec44..093d27bf9 100644 --- a/sml/preprocessing/tests/preprocessing_test.py +++ b/sml/preprocessing/tests/preprocessing_test.py @@ -18,7 +18,7 @@ import numpy as np from sklearn import preprocessing -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as spsim from sml.preprocessing.preprocessing import ( Binarizer, @@ -32,9 +32,7 @@ class UnitTests(unittest.TestCase): def test_labelbinarizer(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def labelbinarize(X, Y): transformer = LabelBinarizer(neg_label=-2, pos_label=3) @@ -63,9 +61,7 @@ def labelbinarize(X, Y): ) def test_labelbinarizer_binary(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def labelbinarize(X): transformer = LabelBinarizer() @@ -91,9 +87,7 @@ def labelbinarize(X): ) def test_labelbinarizer_unseen(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def labelbinarize(X, Y): transformer = LabelBinarizer() @@ -114,9 +108,7 @@ def labelbinarize(X, Y): np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0) def test_binarizer(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def binarize(X): transformer = Binarizer() @@ -134,9 +126,7 @@ def binarize(X): np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=0) def test_normalizer(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def normalize_l1(X): transformer = Normalizer(norm="l1") @@ -174,9 +164,7 @@ def normalize_max(X): np.testing.assert_allclose(sk_result_max, spu_result_max, rtol=0, atol=1e-4) def test_minmaxscaler(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def minmaxscale(X, Y): transformer = MinMaxScaler() @@ -201,9 +189,7 @@ def minmaxscale(X, Y): np.testing.assert_allclose(sk_result_2, spu_result_2, rtol=0, atol=1e-4) def test_minmaxscaler_partial_fit(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def minmaxscale(X): transformer = MinMaxScaler() @@ -241,9 +227,7 @@ def minmaxscale(X): np.testing.assert_allclose(sk_result_max, spu_result_max, rtol=0, atol=1e-4) def test_minmaxscaler_zero_variance(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def minmaxscale(X, X_new): transformer = MinMaxScaler() @@ -281,9 +265,7 @@ def minmaxscale(X, X_new): ) def test_maxabsscaler(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def maxabsscale(X): transformer = MaxAbsScaler() @@ -302,9 +284,7 @@ def maxabsscale(X): np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4) def test_maxabsscaler_zero_maxabs(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def maxabsscale(X, X_new): transformer = MaxAbsScaler() @@ -344,9 +324,7 @@ def maxabsscale(X, X_new): ) def test_kbinsdiscretizer_uniform(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def kbinsdiscretize(X): transformer = KBinsDiscretizer(n_bins=5, strategy='uniform') @@ -376,9 +354,7 @@ def kbinsdiscretize(X): ) def test_kbinsdiscretizer_uniform_diverse_n_bins(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def kbinsdiscretize(X, n_bins): transformer = KBinsDiscretizer( @@ -412,9 +388,7 @@ def kbinsdiscretize(X, n_bins): ) def test_kbinsdiscretizer_uniform_diverse_n_bins_no_vectorize(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) # When you set vectorize to False, diverse_n_bins should be public. def kbinsdiscretize(X): @@ -449,9 +423,7 @@ def kbinsdiscretize(X): ) def test_kbinsdiscretizer_quantile(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def kbinsdiscretize(X): transformer = KBinsDiscretizer(n_bins=5, strategy='quantile') @@ -482,9 +454,7 @@ def kbinsdiscretize(X): ) def test_kbinsdiscretizer_quantile_diverse_n_bins(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def kbinsdiscretize(X, n_bins): transformer = KBinsDiscretizer( @@ -519,9 +489,7 @@ def kbinsdiscretize(X, n_bins): ) def test_kbinsdiscretizer_quantile_diverse_n_bins2(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def kbinsdiscretize(X, n_bins): transformer = KBinsDiscretizer( @@ -556,9 +524,7 @@ def kbinsdiscretize(X, n_bins): ) def test_kbinsdiscretizer_quantile_diverse_n_bins_no_vectorize(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def kbinsdiscretize(X): transformer = KBinsDiscretizer( @@ -593,9 +559,7 @@ def kbinsdiscretize(X): ) def test_kbinsdiscretizer_quantile_eliminate(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def kbinsdiscretize(X): transformer = KBinsDiscretizer(n_bins=2, strategy='quantile') @@ -630,9 +594,7 @@ def kbinsdiscretize(X): ) def test_kbinsdiscretizer_quantile_sample_weight(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def kbinsdiscretize(X, sample_weight): transformer = KBinsDiscretizer(n_bins=2, strategy='quantile') @@ -673,9 +635,7 @@ def kbinsdiscretize(X, sample_weight): ) def test_kbinsdiscretizer_quantile_sample_weight_diverse_n_bins(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def kbinsdiscretize(X, n_bins, sample_weight): transformer = KBinsDiscretizer( @@ -721,9 +681,7 @@ def kbinsdiscretize(X, n_bins, sample_weight): ) def test_kbinsdiscretizer_quantile_sample_weight_diverse_n_bins2(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def kbinsdiscretize(X, n_bins, sample_weight): transformer = KBinsDiscretizer( @@ -770,9 +728,7 @@ def kbinsdiscretize(X, n_bins, sample_weight): ) def test_kbinsdiscretizer_quantile_sample_weight_diverse_n_bins_no_vectorize(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def kbinsdiscretize(X, sample_weight): transformer = KBinsDiscretizer( @@ -819,9 +775,7 @@ def kbinsdiscretize(X, sample_weight): ) def test_kbinsdiscretizer_kmeans(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def kbinsdiscretize(X): transformer = KBinsDiscretizer(n_bins=4, strategy='kmeans') @@ -851,9 +805,7 @@ def kbinsdiscretize(X): ) def test_kbinsdiscretizer_kmeans_diverse_n_bins_no_vectorize(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def kbinsdiscretize(X): transformer = KBinsDiscretizer( diff --git a/sml/svm/emulations/svm_emul.py b/sml/svm/emulations/svm_emul.py index 465b67bf7..47c8127be 100644 --- a/sml/svm/emulations/svm_emul.py +++ b/sml/svm/emulations/svm_emul.py @@ -21,7 +21,7 @@ from sklearn.svm import SVC import sml.utils.emulation as emulation -import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.libspu as libspu # type: ignore from sml.svm.svm import SVM diff --git a/sml/svm/tests/BUILD.bazel b/sml/svm/tests/BUILD.bazel index 011da2277..acedaf636 100644 --- a/sml/svm/tests/BUILD.bazel +++ b/sml/svm/tests/BUILD.bazel @@ -13,6 +13,7 @@ # limitations under the License. load("@rules_python//python:defs.bzl", "py_test") +load("@spu_pip_dev//:requirements.bzl", "requirement") package(default_visibility = ["//visibility:public"]) @@ -23,6 +24,6 @@ py_test( "//sml/svm", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//scikit_learn:pkg", + requirement("scikit-learn"), ], ) diff --git a/sml/svm/tests/svm_test.py b/sml/svm/tests/svm_test.py index c2b7a5800..68d3d519c 100644 --- a/sml/svm/tests/svm_test.py +++ b/sml/svm/tests/svm_test.py @@ -21,16 +21,14 @@ from sklearn.model_selection import train_test_split from sklearn.svm import SVC -import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.libspu as libspu # type: ignore import spu.utils.simulation as spsim from sml.svm.svm import SVM class UnitTests(unittest.TestCase): def test_svm(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def proc(x0, x1, y0): rbf_svm = SVM(kernel="rbf", max_iter=102) diff --git a/sml/tree/tests/BUILD.bazel b/sml/tree/tests/BUILD.bazel index e96e9140a..dff1b8cef 100644 --- a/sml/tree/tests/BUILD.bazel +++ b/sml/tree/tests/BUILD.bazel @@ -13,6 +13,7 @@ # limitations under the License. load("@rules_python//python:defs.bzl", "py_test") +load("@spu_pip_dev//:requirements.bzl", "requirement") package(default_visibility = ["//visibility:public"]) @@ -23,6 +24,6 @@ py_test( "//sml/tree", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//scikit_learn:pkg", + requirement("scikit-learn"), ], ) diff --git a/sml/tree/tests/tree_test.py b/sml/tree/tests/tree_test.py index 470065f93..ef946b76f 100644 --- a/sml/tree/tests/tree_test.py +++ b/sml/tree/tests/tree_test.py @@ -19,7 +19,7 @@ from sklearn.datasets import load_iris from sklearn.tree import DecisionTreeClassifier -import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.libspu as libspu # type: ignore import spu.utils.simulation as spsim from sml.tree.tree import DecisionTreeClassifier as sml_dtc @@ -55,9 +55,7 @@ def load_data(): return X, y # bandwidth and latency only work for docker mode - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) # load mock data X, y = load_data() diff --git a/sml/utils/BUILD.bazel b/sml/utils/BUILD.bazel index 8982e378e..9309343b9 100644 --- a/sml/utils/BUILD.bazel +++ b/sml/utils/BUILD.bazel @@ -20,7 +20,7 @@ spu_py_library( name = "emulation", srcs = [ "emulation.py", - "@spulib//libspu:spu_py_proto", + # "@spulib//libspu:spu_py_proto", ], data = [ "//examples/python/conf", # FIXME diff --git a/sml/utils/emulation.py b/sml/utils/emulation.py index 1685af21a..d2e97b0e4 100644 --- a/sml/utils/emulation.py +++ b/sml/utils/emulation.py @@ -28,7 +28,7 @@ import yaml import spu.utils.distributed as ppd -from spu import spu_pb2 +from spu import libspu from spu.utils.polyfill import Process CLUSTER_ABY3_3PC = "examples/python/conf/3pc.json" @@ -158,7 +158,7 @@ def run( self, func: Callable, static_argnums=(), - copts=spu_pb2.CompilerOptions(), + copts=libspu.CompilerOptions(), ): def wrapper(*args, **kwargs): # run the func on SPU. diff --git a/sml/utils/tests/extmath_test.py b/sml/utils/tests/extmath_test.py index b1608fe03..e96fa41f9 100644 --- a/sml/utils/tests/extmath_test.py +++ b/sml/utils/tests/extmath_test.py @@ -16,7 +16,7 @@ import jax.numpy as jnp import numpy as np -import spu.spu_pb2 as spu_pb2 # type: ignore +import spu.libspu as libspu # type: ignore import spu.utils.simulation as spsim from sml.utils.extmath import randomized_svd, svd @@ -55,14 +55,14 @@ def setUpClass(cls): np.random.seed(0) # 2. init simulator - config64 = spu_pb2.RuntimeConfig( - protocol=spu_pb2.ProtocolKind.ABY3, - field=spu_pb2.FieldType.FM64, + config64 = libspu.RuntimeConfig( + protocol=libspu.ProtocolKind.ABY3, + field=libspu.FieldType.FM64, fxp_fraction_bits=18, ) - config128 = spu_pb2.RuntimeConfig( - protocol=spu_pb2.ProtocolKind.ABY3, - field=spu_pb2.FieldType.FM128, + config128 = libspu.RuntimeConfig( + protocol=libspu.ProtocolKind.ABY3, + field=libspu.FieldType.FM128, fxp_fraction_bits=30, ) sim64 = spsim.Simulator(3, config64) diff --git a/spu/BUILD.bazel b/spu/BUILD.bazel index eb12e261a..47b4295bc 100644 --- a/spu/BUILD.bazel +++ b/spu/BUILD.bazel @@ -76,8 +76,8 @@ spu_py_library( name = "api", srcs = [ "api.py", - "spu_pb2.py", - "@spulib//libspu:spu_py_proto", + # "spu_pb2.py", + # "@spulib//libspu:spu_py_proto", ], data = [ ":libspu.so", diff --git a/spu/__init__.py b/spu/__init__.py index cdc83563f..ad7586c3c 100644 --- a/spu/__init__.py +++ b/spu/__init__.py @@ -16,13 +16,13 @@ from . import experimental, psi from .api import Io, Runtime, check_cpu_feature, compile from .intrinsic import * -from .spu_pb2 import ( # type: ignore +from .libspu import ( # type: ignore; + CompilationSource, CompilerOptions, DataType, ExecutableProto, FieldType, ProtocolKind, - PtType, RuntimeConfig, ShapeProto, Visibility, @@ -32,16 +32,16 @@ __all__ = [ "__version__", - # spu_pb2 + # libspu + "CompilationSource", + "CompilerOptions", "DataType", "Visibility", - "PtType", "ProtocolKind", "FieldType", "ShapeProto", "RuntimeConfig", "ExecutableProto", - "CompilerOptions", # spu_api "Io", "Runtime", diff --git a/spu/api.py b/spu/api.py index c70206d98..47b5cdf31 100644 --- a/spu/api.py +++ b/spu/api.py @@ -19,29 +19,28 @@ from cachetools import LRUCache, cached from . import libspu # type: ignore -from . import spu_pb2 class Runtime(object): """The SPU Virtual Machine Slice.""" - def __init__(self, link: libspu.link.Context, config: spu_pb2.RuntimeConfig): + def __init__(self, link: libspu.link.Context, config: libspu.RuntimeConfig): """Constructor of an SPU Runtime. Args: link (libspu.link.Context): Link context. - config (spu_pb2.RuntimeConfig): SPU Runtime Config. + config (libspu.RuntimeConfig): SPU Runtime Config. """ - self._vm = libspu.RuntimeWrapper(link, config.SerializeToString()) + self._vm = libspu.RuntimeWrapper(link, config) - def run(self, executable: spu_pb2.ExecutableProto) -> None: + def run(self, executable: libspu.ExecutableProto) -> None: """Run an SPU executable. Args: - executable (spu_pb2.ExecutableProto): executable. + executable (libspu.ExecutableProto): executable. """ - return self._vm.Run(executable.SerializeToString()) + return self._vm.Run(executable) def set_var(self, name: str, value: libspu.Share) -> None: """Set an SPU value. @@ -75,18 +74,16 @@ def get_var_chunk_count(self, name: str) -> int: """ return self._vm.GetVarChunksCount(name) - def get_var_meta(self, name: str) -> spu_pb2.ValueMetaProto: + def get_var_meta(self, name: str) -> libspu.ValueMetaProto: """Get an SPU value without content. Args: name (str): Id of value. Returns: - spu_pb2.ValueMeta: Data meta with out content. + libspu.ValueMeta: Data meta with out content. """ - ret = spu_pb2.ValueMetaProto() - ret.ParseFromString(self._vm.GetVarMeta(name)) - return ret + return self._vm.GetVarMeta(name) def del_var(self, name: str) -> None: """Delete an SPU value. @@ -104,28 +101,28 @@ def clear(self) -> None: class Io(object): """The SPU IO interface.""" - def __init__(self, world_size: int, config: spu_pb2.RuntimeConfig): + def __init__(self, world_size: int, config: libspu.RuntimeConfig): """Constructor of an SPU Io. Args: world_size (int): # of participants of SPU Device. - config (spu_pb2.RuntimeConfig): SPU Runtime Config. + config (libspu.RuntimeConfig): SPU Runtime Config. """ - self._io = libspu.IoWrapper(world_size, config.SerializeToString()) + self._io = libspu.IoWrapper(world_size, config) def get_share_chunk_count( - self, x: 'np.ndarray', vtype: spu_pb2.Visibility, owner_rank: int = -1 + self, x: 'np.ndarray', vtype: libspu.Visibility, owner_rank: int = -1 ) -> int: return self._io.GetShareChunkCount(x, vtype, owner_rank) def make_shares( - self, x: 'np.ndarray', vtype: spu_pb2.Visibility, owner_rank: int = -1 + self, x: 'np.ndarray', vtype: libspu.Visibility, owner_rank: int = -1 ) -> List[libspu.Share]: """Convert from NumPy array to list of SPU value(s). Args: x (np.ndarray): input. - vtype (spu_pb2.Visibility): visibility. + vtype (libspu.Visibility): visibility. owner_rank (int): the index of the trusted piece. if >= 0, colocation optimization may be applied. Returns: @@ -146,22 +143,24 @@ def reconstruct(self, shares: List[libspu.Share]) -> 'np.ndarray': @cached(cache=LRUCache(maxsize=128)) -def _spu_compilation(source: str, options_str: str): - return libspu.compile(source, options_str) +def _spu_compilation( + source: libspu.CompilationSource, options: libspu.CompilerOptions +) -> bytes: + return libspu.compile(source, options) -def compile(source: spu_pb2.CompilationSource, copts: spu_pb2.CompilerOptions) -> str: +def compile(source: libspu.CompilationSource, copts: libspu.CompilerOptions) -> bytes: """Compile from textual HLO/MHLO IR to SPU bytecode. Args: - source (spu_pb2.CompilationSource): input to compiler. - copts (spu_pb2.CompilerOptions): compiler options. + source (libspu.CompilationSource): input to compiler. + copts (libspu.CompilerOptions): compiler options. Returns: - [spu_pb2.ValueProto]: output. + [libspu.ValueProto]: output. """ - return _spu_compilation(source.SerializeToString(), copts.SerializeToString()) + return _spu_compilation(source, copts) def check_cpu_feature(): diff --git a/spu/libspu.cc b/spu/libspu.cc index 4259442f4..f9e2a7f34 100644 --- a/spu/libspu.cc +++ b/spu/libspu.cc @@ -13,6 +13,7 @@ // limitations under the License. #include +#include #include #include "pybind11/iostream.h" @@ -28,7 +29,6 @@ #include "yacl/link/factory.h" #include "libspu/compiler/compile.h" -#include "libspu/core/config.h" #include "libspu/core/context.h" #include "libspu/core/logging.h" #include "libspu/core/value.h" @@ -37,6 +37,7 @@ #include "libspu/device/pphlo/pphlo_executor.h" #include "libspu/device/symbol_table.h" #include "libspu/mpc/factory.h" +#include "libspu/spu.h" #include "libspu/version.h" #ifdef CHECK_AVX @@ -310,11 +311,11 @@ struct PyBindShare { static spu::Value ValueFromPyBindShare(const PyBindShare& py_share) { spu::ValueProto value; - spu::ValueMetaProto meta; + pb::ValueMetaProto meta; SPU_ENFORCE(meta.ParseFromString(py_share.meta)); value.meta.Swap(&meta); for (const auto& s : py_share.share_chunks) { - spu::ValueChunkProto chunk; + pb::ValueChunkProto chunk; SPU_ENFORCE(chunk.ParseFromString(s)); value.chunks.emplace_back(std::move(chunk)); } @@ -345,25 +346,19 @@ class RuntimeWrapper { public: explicit RuntimeWrapper(const std::shared_ptr& lctx, - const std::string& config_pb) { - spu::RuntimeConfig config; - SPU_ENFORCE(config.ParseFromString(config_pb)); - + const RuntimeConfig& config) { // first, fill protobuf default value with implementation defined value. - populateRuntimeConfig(config); + // populateRuntimeConfig(config); sctx_ = std::make_unique(config, lctx); mpc::Factory::RegisterProtocol(sctx_.get(), lctx); - max_chunk_size_ = config.share_max_chunk_size(); + max_chunk_size_ = config.share_max_chunk_size; if (max_chunk_size_ == 0) { max_chunk_size_ = 128UL * 1024 * 1024; } } - void Run(const py::bytes& exec_pb) { - spu::ExecutableProto exec; - SPU_ENFORCE(exec.ParseFromString(exec_pb)); - + void Run(const spu::ExecutableProto& exec) { spu::device::pphlo::PPHloExecutor executor; spu::device::execute(&executor, sctx_.get(), exec, &env_); } @@ -380,8 +375,8 @@ class RuntimeWrapper { return env_.getVar(name).chunksCount(max_chunk_size_); } - py::bytes GetVarMeta(const std::string& name) const { - return env_.getVar(name).toMetaProto().SerializeAsString(); + pb::ValueMetaProto GetVarMeta(const std::string& name) const { + return env_.getVar(name).toMetaProto(); } void DelVar(const std::string& name) { env_.delVar(name); } @@ -460,12 +455,12 @@ class IoWrapper { size_t max_chunk_size_; public: - IoWrapper(size_t world_size, const std::string& config_pb) { - spu::RuntimeConfig config; - SPU_ENFORCE(config.ParseFromString(config_pb)); + IoWrapper(size_t world_size, const spu::RuntimeConfig& config) { + // spu::RuntimeConfig config; + // SPU_ENFORCE(config.ParseFromString(config_pb)); ptr_ = std::make_unique(world_size, config); - max_chunk_size_ = config.share_max_chunk_size(); + max_chunk_size_ = config.share_max_chunk_size; if (max_chunk_size_ == 0) { max_chunk_size_ = 128UL * 1024 * 1024; } @@ -548,6 +543,367 @@ class IoWrapper { } }; +void BindSPU(py::module& m) { + m.doc() = R"pbdoc( + SPU Library + )pbdoc"; + + // bind enum + py::enum_(m, "DataType") + .value("DT_INVALID", DataType::DT_INVALID) + .value("DT_I1", DataType::DT_I1) + .value("DT_I8", DataType::DT_I8) + .value("DT_U8", DataType::DT_U8) + .value("DT_I16", DataType::DT_I16) + .value("DT_U16", DataType::DT_U16) + .value("DT_I32", DataType::DT_I32) + .value("DT_U32", DataType::DT_U32) + .value("DT_I64", DataType::DT_I64) + .value("DT_U64", DataType::DT_U64) + .value("DT_F16", DataType::DT_F16) + .value("DT_F32", DataType::DT_F32) + .value("DT_F64", DataType::DT_F64) + .export_values(); + + py::enum_(m, "Visibility") + .value("VIS_INVALID", Visibility::VIS_INVALID) + .value("VIS_PUBLIC", Visibility::VIS_PUBLIC) + .value("VIS_SECRET", Visibility::VIS_SECRET) + .value("VIS_PRIVATE", Visibility::VIS_PRIVATE) + .export_values(); + + py::enum_(m, "FieldType") + .value("FT_INVALID", FieldType::FT_INVALID) + .value("FM32", FieldType::FM32) + .value("FM64", FieldType::FM64) + .value("FM128", FieldType::FM128) + .export_values(); + + py::enum_(m, "ProtocolKind") + .value("PROT_INVALID", ProtocolKind::PROT_INVALID) + .value("REF2K", ProtocolKind::REF2K) + .value("SEMI2K", ProtocolKind::SEMI2K) + .value("ABY3", ProtocolKind::ABY3) + .value("CHEETAH", ProtocolKind::CHEETAH) + .value("SECURENN", ProtocolKind::SECURENN) + .export_values(); + + // bind RuntimeConfig + py::class_(m, "ClientSSLConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("certificate") = "", py::arg("private_key") = "", + py::arg("ca_file_path") = "", py::arg("verify_depth") = 0) + .def_readwrite("certificate", &ClientSSLConfig::certificate) + .def_readwrite("private_key", &ClientSSLConfig::private_key) + .def_readwrite("ca_file_path", &ClientSSLConfig::ca_file_path) + .def_readwrite("verify_depth", &ClientSSLConfig::verify_depth); + + py::class_(m, "TTPBeaverConfig") + .def(py::init<>()) + .def(py::init>(), + py::arg("server_host") = "", py::arg("adjust_rank") = 0, + py::arg("asym_crypto_schema") = "", + py::arg("server_public_key") = "", + py::arg("transport_protocol") = "", py::arg("ssl_config") = nullptr) + .def_readwrite("server_host", &TTPBeaverConfig::server_host) + .def_readwrite("adjust_rank", &TTPBeaverConfig::adjust_rank) + .def_readwrite("asym_crypto_schema", &TTPBeaverConfig::asym_crypto_schema) + .def_readwrite("server_public_key", &TTPBeaverConfig::server_public_key) + .def_readwrite("transport_protocol", &TTPBeaverConfig::transport_protocol) + .def_readwrite("ssl_config", &TTPBeaverConfig::ssl_config); + + py::class_(m, "CheetahConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("disable_matmul_pack") = false, + py::arg("enable_mul_lsb_error") = false, py::arg("ot_kind") = 0) + .def_readwrite("disable_matmul_pack", &CheetahConfig::disable_matmul_pack) + .def_readwrite("enable_mul_lsb_error", + &CheetahConfig::enable_mul_lsb_error) + .def_readwrite("ot_kind", &CheetahConfig::ot_kind); + + py::class_ rt_cls(m, "RuntimeConfig"); + + py::enum_(rt_cls, "SortMethod") + .value("SORT_DEFAULT", RuntimeConfig::SORT_DEFAULT) + .value("SORT_RADIX", RuntimeConfig::SORT_RADIX) + .value("SORT_QUICK", RuntimeConfig::SORT_QUICK) + .value("SORT_NETWORK", RuntimeConfig::SORT_NETWORK) + .export_values(); + + py::enum_(rt_cls, "ExpMode") + .value("EXP_DEFAULT", RuntimeConfig::EXP_DEFAULT) + .value("EXP_PADE", RuntimeConfig::EXP_PADE) + .value("EXP_TAYLOR", RuntimeConfig::EXP_TAYLOR) + .value("EXP_PRIME", RuntimeConfig::EXP_PRIME) + .export_values(); + + py::enum_(rt_cls, "LogMode") + .value("LOG_DEFAULT", RuntimeConfig::LOG_DEFAULT) + .value("LOG_PADE", RuntimeConfig::LOG_PADE) + .value("LOG_NEWTON", RuntimeConfig::LOG_NEWTON) + .value("LOG_MINMAX", RuntimeConfig::LOG_MINMAX) + .export_values(); + + py::enum_(rt_cls, "SigmoidMode") + .value("SIGMOID_DEFAULT", RuntimeConfig::SIGMOID_DEFAULT) + .value("SIGMOID_MM1", RuntimeConfig::SIGMOID_MM1) + .value("SIGMOID_SEG3", RuntimeConfig::SIGMOID_SEG3) + .value("SIGMOID_REAL", RuntimeConfig::SIGMOID_REAL) + .export_values(); + + py::enum_(rt_cls, "BeaverType") + .value("TrustedFirstParty", RuntimeConfig::TrustedFirstParty) + .value("TrustedThirdParty", RuntimeConfig::TrustedThirdParty) + .value("MultiParty", RuntimeConfig::MultiParty) + .export_values(); + + rt_cls.def(py::init<>()) + .def(py::init()) + .def(py::init(), py::arg("protocol"), + py::arg("field"), py::arg("fxp_fraction_bits") = 0) + .def(py::init()) + .def("ParseFromJsonString", &RuntimeConfig::ParseFromJsonString) + .def("ParseFromString", &RuntimeConfig::ParseFromString) + .def("SerializeToString", + [](const RuntimeConfig& self) { + return py::bytes(self.SerializeAsString()); + }) + .def_readwrite("protocol", &RuntimeConfig::protocol) + .def_readwrite("field", &RuntimeConfig::field) + .def_readwrite("fxp_fraction_bits", &RuntimeConfig::fxp_fraction_bits) + .def_readwrite("max_concurrency", &RuntimeConfig::max_concurrency) + .def_readwrite("enable_action_trace", &RuntimeConfig::enable_action_trace) + .def_readwrite("enable_type_checker", &RuntimeConfig::enable_type_checker) + .def_readwrite("enable_pphlo_trace", &RuntimeConfig::enable_pphlo_trace) + .def_readwrite("enable_runtime_snapshot", + &RuntimeConfig::enable_runtime_snapshot) + .def_readwrite("snapshot_dump_dir", &RuntimeConfig::snapshot_dump_dir) + .def_readwrite("enable_pphlo_profile", + &RuntimeConfig::enable_pphlo_profile) + .def_readwrite("enable_hal_profile", &RuntimeConfig::enable_hal_profile) + .def_readwrite("public_random_seed", &RuntimeConfig::public_random_seed) + .def_readwrite("share_max_chunk_size", + &RuntimeConfig::share_max_chunk_size) + .def_readwrite("sort_method", &RuntimeConfig::sort_method) + .def_readwrite("quick_sort_threshold", + &RuntimeConfig::quick_sort_threshold) + .def_readwrite("fxp_div_goldschmidt_iters", + &RuntimeConfig::fxp_div_goldschmidt_iters) + .def_readwrite("fxp_exp_mode", &RuntimeConfig::fxp_exp_mode) + .def_readwrite("fxp_exp_iters", &RuntimeConfig::fxp_exp_iters) + .def_readwrite("fxp_log_mode", &RuntimeConfig::fxp_log_mode) + .def_readwrite("fxp_log_iters", &RuntimeConfig::fxp_log_iters) + .def_readwrite("fxp_log_orders", &RuntimeConfig::fxp_log_orders) + .def_readwrite("sigmoid_mode", &RuntimeConfig::sigmoid_mode) + .def_readwrite("enable_lower_accuracy_rsqrt", + &RuntimeConfig::enable_lower_accuracy_rsqrt) + .def_readwrite("sine_cosine_iters", &RuntimeConfig::sine_cosine_iters) + .def_readwrite("beaver_type", &RuntimeConfig::beaver_type) + .def_readwrite("ttp_beaver_config", &RuntimeConfig::ttp_beaver_config) + .def_readwrite("cheetah_2pc_config", &RuntimeConfig::cheetah_2pc_config) + .def_readwrite("trunc_allow_msb_error", + &RuntimeConfig::trunc_allow_msb_error) + .def_readwrite("experimental_disable_mmul_split", + &RuntimeConfig::experimental_disable_mmul_split) + .def_readwrite("experimental_enable_inter_op_par", + &RuntimeConfig::experimental_enable_inter_op_par) + .def_readwrite("experimental_enable_intra_op_par", + &RuntimeConfig::experimental_enable_intra_op_par) + .def_readwrite("experimental_disable_vectorization", + &RuntimeConfig::experimental_disable_vectorization) + .def_readwrite("experimental_inter_op_concurrency", + &RuntimeConfig::experimental_inter_op_concurrency) + .def_readwrite("experimental_enable_colocated_optimization", + &RuntimeConfig::experimental_enable_colocated_optimization) + .def_readwrite("experimental_enable_exp_prime", + &RuntimeConfig::experimental_enable_exp_prime) + .def_readwrite("experimental_exp_prime_offset", + &RuntimeConfig::experimental_exp_prime_offset) + .def_readwrite("experimental_exp_prime_disable_lower_bound", + &RuntimeConfig::experimental_exp_prime_disable_lower_bound) + .def_readwrite("experimental_exp_prime_enable_upper_bound", + &RuntimeConfig::experimental_exp_prime_enable_upper_bound); + + // Compiler + py::enum_(m, "SourceIRType") + .value("XLA", SourceIRType::XLA) + .value("STABLEHLO", SourceIRType::STABLEHLO) + .export_values(); + + py::class_(m, "CompilationSource") + .def(py::init<>()) + .def(py::init>(), + py::arg("ir_type") = SourceIRType::XLA, py::arg("ir_txt") = "", + py::arg("input_visibility") = std::vector{}) + .def("__hash__", + [](const CompilationSource& self) { + return std::hash{}(self); + }) + .def("__eq__", + [](const CompilationSource& self, const CompilationSource& other) { + return self == other; + }) + .def_readwrite("ir_type", &CompilationSource::ir_type) + .def_property( + "ir_txt", + [](const CompilationSource& self) { return py::bytes(self.ir_txt); }, + [](CompilationSource& self, const py::bytes& bytes) { + self.ir_txt = std::string(bytes); + }) + .def_readwrite("input_visibility", &CompilationSource::input_visibility); + + py::enum_(m, "XLAPrettyPrintKind") + .value("TEXT", XLAPrettyPrintKind::TEXT) + .value("DOT", XLAPrettyPrintKind::DOT) + .value("HTML", XLAPrettyPrintKind::HTML) + .export_values(); + + py::class_(m, "CompilerOptions") + .def(py::init<>()) + .def(py::init(), + py::arg("enable_pretty_print") = false, + py::arg("pretty_print_dump_dir") = "", + py::arg("xla_pp_kind") = XLAPrettyPrintKind::TEXT, + py::arg("disable_sqrt_plus_epsilon_rewrite") = false, + py::arg("disable_div_sqrt_rewrite") = false, + py::arg("disable_reduce_truncation_optimization") = false, + py::arg("disable_maxpooling_optimization") = false, + py::arg("disallow_mix_types_opts") = false, + py::arg("disable_select_optimization") = false, + py::arg("enable_optimize_denominator_with_broadcast") = false, + py::arg("disable_deallocation_insertion") = false, + py::arg("disable_partial_sort_optimization") = false) + .def("__hash__", + [](const CompilerOptions& self) { + return std::hash{}(self); + }) + .def("__eq__", [](const CompilerOptions& self, + const CompilerOptions& other) { return self == other; }) + .def_readwrite("enable_pretty_print", + &CompilerOptions::enable_pretty_print) + .def_readwrite("pretty_print_dump_dir", + &CompilerOptions::pretty_print_dump_dir) + .def_readwrite("xla_pp_kind", &CompilerOptions::xla_pp_kind) + .def_readwrite("disable_sqrt_plus_epsilon_rewrite", + &CompilerOptions::disable_sqrt_plus_epsilon_rewrite) + .def_readwrite("disable_div_sqrt_rewrite", + &CompilerOptions::disable_div_sqrt_rewrite) + .def_readwrite("disable_reduce_truncation_optimization", + &CompilerOptions::disable_reduce_truncation_optimization) + .def_readwrite("disable_maxpooling_optimization", + &CompilerOptions::disable_maxpooling_optimization) + .def_readwrite("disallow_mix_types_opts", + &CompilerOptions::disallow_mix_types_opts) + .def_readwrite("disable_select_optimization", + &CompilerOptions::disable_select_optimization) + .def_readwrite( + "enable_optimize_denominator_with_broadcast", + &CompilerOptions::enable_optimize_denominator_with_broadcast) + .def_readwrite("disable_deallocation_insertion", + &CompilerOptions::disable_deallocation_insertion) + .def_readwrite("disable_partial_sort_optimization", + &CompilerOptions::disable_partial_sort_optimization); + + py::class_(m, "ExecutableProto") + .def(py::init<>()) + .def(py::init, + std::vector, std::string>(), + py::arg("name") = "", + py::arg("input_names") = std::vector{}, + py::arg("output_names") = std::vector{}, + py::arg("code") = "") + .def("ParseFromString", &ExecutableProto::ParseFromString) + .def("SerializeToString", + [](const ExecutableProto& self) { + return py::bytes(self.SerializeAsString()); + }) + .def_readwrite("name", &ExecutableProto::name) + .def_readwrite("input_names", &ExecutableProto::input_names) + .def_readwrite("output_names", &ExecutableProto::output_names) + .def_property( + "code", + [](const ExecutableProto& self) { return py::bytes(self.code); }, + [](ExecutableProto& self, const py::bytes& bytes) { + self.code = std::string(bytes); + }); + + py::class_(m, "ShapeProto") + .def(py::init<>()) + .def_property_readonly("dims", [](const pb::ShapeProto& self) { + return std::vector(self.dims().begin(), self.dims().end()); + }); + + py::class_(m, "ValueMetaProto") + .def(py::init<>()) + .def("ParseFromString", &pb::ValueMetaProto::ParseFromString) + .def_property_readonly("data_type", + [](const pb::ValueMetaProto& self) { + return DataType(self.data_type()); + }) + .def_property_readonly("is_complex", &pb::ValueMetaProto::is_complex) + .def_property_readonly("visibility", + [](const pb::ValueMetaProto& self) { + return Visibility(self.visibility()); + }) + .def_property_readonly("shape", &pb::ValueMetaProto::shape) + .def_property_readonly("storage_type", &pb::ValueMetaProto::storage_type); + + py::class_(m, "Share", "Share in python runtime") + .def(py::init<>()) + .def_readwrite("share_chunks", &PyBindShare::share_chunks, "share chunks") + .def_readwrite("meta", &PyBindShare::meta, "meta of share") + .def(py::pickle( + [](const PyBindShare& s) { // dump + return py::make_tuple(s.meta, s.share_chunks); + }, + [](const py::tuple& t) { // load + return PyBindShare{t[0].cast(), + t[1].cast>()}; + })); + + // bind spu virtual machine. + py::class_(m, "RuntimeWrapper", "SPU virtual device") + .def(py::init, + const spu::RuntimeConfig&>(), + NO_GIL) + .def("Run", &RuntimeWrapper::Run, NO_GIL) + .def("SetVar", + &RuntimeWrapper:: + SetVar) // https://github.com/pybind/pybind11/issues/1782 + // SetVar & GetVar are using + // py::byte, so they must acquire gil... + .def("GetVar", &RuntimeWrapper::GetVar) + .def("GetVarChunksCount", &RuntimeWrapper::GetVarChunksCount) + .def("GetVarMeta", &RuntimeWrapper::GetVarMeta) + .def("DelVar", &RuntimeWrapper::DelVar); + + // bind spu io suite. + py::class_(m, "IoWrapper", "SPU VM IO") + .def(py::init()) + .def("MakeShares", &IoWrapper::MakeShares, "Create secret shares", + py::arg("arr"), py::arg("visibility"), py::arg("owner_rank") = -1) + .def("GetShareChunkCount", &IoWrapper::GetShareChunkCount, py::arg("arr"), + py::arg("visibility"), py::arg("owner_rank") = -1) + .def("Reconstruct", &IoWrapper::Reconstruct); + + // bind compiler. + m.def( + "compile", + [](const spu::CompilationSource& source, + const spu::CompilerOptions& copts) { + py::scoped_ostream_redirect stream( + std::cout, // std::ostream& + py::module_::import("sys").attr("stdout") // Python output + ); + return py::bytes(spu::compiler::compile(source, copts)); + }, + "spu compile.", py::arg("source"), py::arg("copts")); +} + void BindLogging(py::module& m) { m.doc() = R"pbdoc( SPU Logging Library @@ -624,62 +980,7 @@ PYBIND11_MODULE(libspu, m) { } }); - py::class_(m, "Share", "Share in python runtime") - .def(py::init<>()) - .def_readwrite("share_chunks", &PyBindShare::share_chunks, "share chunks") - .def_readwrite("meta", &PyBindShare::meta, "meta of share") - .def(py::pickle( - [](const PyBindShare& s) { // dump - return py::make_tuple(s.meta, s.share_chunks); - }, - [](const py::tuple& t) { // load - return PyBindShare{t[0].cast(), - t[1].cast>()}; - })); - - // bind spu virtual machine. - py::class_(m, "RuntimeWrapper", "SPU virtual device") - .def(py::init, std::string>(), - NO_GIL) - .def("Run", &RuntimeWrapper::Run, NO_GIL) - .def("SetVar", - &RuntimeWrapper:: - SetVar) // https://github.com/pybind/pybind11/issues/1782 - // SetVar & GetVar are using - // py::byte, so they must acquire gil... - .def("GetVar", &RuntimeWrapper::GetVar) - .def("GetVarChunksCount", &RuntimeWrapper::GetVarChunksCount) - .def("GetVarMeta", &RuntimeWrapper::GetVarMeta) - .def("DelVar", &RuntimeWrapper::DelVar); - - // bind spu io suite. - py::class_(m, "IoWrapper", "SPU VM IO") - .def(py::init()) - .def("MakeShares", &IoWrapper::MakeShares, "Create secret shares", - py::arg("arr"), py::arg("visibility"), py::arg("owner_rank") = -1) - .def("GetShareChunkCount", &IoWrapper::GetShareChunkCount, py::arg("arr"), - py::arg("visibility"), py::arg("owner_rank") = -1) - .def("Reconstruct", &IoWrapper::Reconstruct); - - // bind compiler. - m.def( - "compile", - [](const py::bytes& serialized_src, const std::string& serialized_copts) { - py::scoped_ostream_redirect stream( - std::cout, // std::ostream& - py::module_::import("sys").attr("stdout") // Python output - ); - - spu::CompilerOptions copts; - SPU_ENFORCE(copts.ParseFromString(serialized_copts), - "Parse compiler options failed"); - - spu::CompilationSource src; - SPU_ENFORCE(src.ParseFromString(serialized_src), "Parse source failed"); - - return py::bytes(spu::compiler::compile(src, copts)); - }, - "spu compile.", py::arg("source"), py::arg("copts")); + BindSPU(m); // bind spu libs. py::module link_m = m.def_submodule("link"); diff --git a/spu/libspu.pyi b/spu/libspu.pyi new file mode 100644 index 000000000..31894a009 --- /dev/null +++ b/spu/libspu.pyi @@ -0,0 +1,291 @@ +# Copyright 2024 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +from typing import overload + +class DataType(enum.IntEnum): + DT_INVALID = 0 + DT_I1 = 1 + DT_I8 = 2 + DT_U8 = 3 + DT_I16 = 4 + DT_U16 = 5 + DT_I32 = 6 + DT_U32 = 7 + DT_I64 = 8 + DT_U64 = 9 + DT_F16 = 10 + DT_F32 = 11 + DT_F64 = 12 + +class Visibility(enum.IntEnum): + VIS_INVALID = 0 + VIS_SECRET = 1 + VIS_PUBLIC = 2 + VIS_PRIVATE = 3 + +class FieldType(enum.IntEnum): + FT_INVALID = 0 + FM32 = 1 + FM64 = 2 + FM128 = 3 + +class ProtocolKind(enum.IntEnum): + PROT_INVALID = 0 + REF2K = 1 + SEMI2K = 2 + ABY3 = 3 + CHEETAH = 4 + SECURENN = 5 + +class ClientSSLConfig: + def __init__( + self, + certificate: str = "", + private_key: str = "", + ca_file_path: str = "", + verify_depth: str = "", + ): + self.certificate = certificate + self.private_key = private_key + self.ca_file_path = ca_file_path + self.verify_depth = verify_depth + +class TTPBeaverConfig: + def __init__( + self, + server_host: str = "", + adjust_rank: int = 0, + asym_crypto_schema: str = "", + server_public_key: str = "", + transport_protocol: str = "", + ssl_config: ClientSSLConfig | None = None, + ): + self.server_host = server_host + self.adjust_rank = adjust_rank + self.asym_crypto_schema = asym_crypto_schema + self.server_public_key = server_public_key + self.transport_protocol = transport_protocol + self.ssl_config = ssl_config + +class CheetahOtKind(enum.IntEnum): + YACL_Ferret = 0 + YACL_Softspoken = 1 + EMP_Ferret = 2 + +class CheetahConfig: + def __init__( + self, + disable_matmul_pack: bool, + enable_mul_lsb_error: bool, + ot_kind: CheetahOtKind = CheetahOtKind.YACL_Ferret, + ): + self.disable_matmul_pack = disable_matmul_pack + self.enable_mul_lsb_error = enable_mul_lsb_error + self.ot_kind = ot_kind + +class RuntimeConfig: + class SortMethod(enum.IntEnum): + SORT_DEFAULT = 0 + SORT_RADIX = 1 + SORT_QUICK = 2 + SORT_NETWORK = 3 + + class ExpMode(enum.IntEnum): + EXP_DEFAULT = 0 + EXP_PADE = 1 + EXP_TAYLOR = 2 + EXP_PRIME = 3 + + class LogMode(enum.IntEnum): + LOG_DEFAULT = 0 + LOG_PADE = 1 + LOG_NEWTON = 2 + LOG_MINMAX = 3 + + class SigmoidMode(enum.IntEnum): + SIGMOID_DEFAULT = 0 + SIGMOID_MM1 = 1 + SIGMOID_SEG3 = 2 + SIGMOID_REAL = 3 + + class BeaverType(enum.IntEnum): + TrustedFirstParty = 0 + TrustedThirdParty = 1 + MultiParty = 2 + + protocol: ProtocolKind + field: FieldType + fxp_fraction_bits: int + max_concurrency: int + enable_action_trace: bool + enable_type_checker: bool + enable_pphlo_trace: bool + enable_runtime_snapshot: bool + snapshot_dump_dir: str + enable_pphlo_profile: bool + enable_hal_profile: bool + public_random_seed: int + share_max_chunk_size: int + sort_method: SortMethod + quick_sort_threshold: int + fxp_div_goldschmidt_iters: int + fxp_exp_mode: ExpMode + fxp_exp_iters: int + fxp_log_mode: LogMode + fxp_log_iters: int + fxp_log_orders: int + sigmoid_mode: SigmoidMode + enable_lower_accuracy_rsqrt: bool + sine_cosine_iters: int + beaver_type: BeaverType + ttp_beaver_config: TTPBeaverConfig + cheetah_2pc_config: CheetahConfig + trunc_allow_msb_error: bool + experimental_disable_mmul_split: bool + experimental_enable_inter_op_par: bool + experimental_enable_intra_op_par: bool + experimental_disable_vectorization: bool + experimental_inter_op_concurrency: int + experimental_enable_colocated_optimization: bool + experimental_enable_exp_prime: bool + experimental_exp_prime_offset: int + experimental_exp_prime_disable_lower_bound: bool + experimental_exp_prime_enable_upper_bound: bool + + # @staticmethod + # def makeFromJson(json: str) -> 'RuntimeConfig': ... + @overload + def __init__(self): ... + @overload + def __init__( + self, + protocol: ProtocolKind = ProtocolKind.PROT_INVALID, + field: FieldType = FieldType.FT_INVALID, + fxp_fraction_bits: int = 0, + ): + self.protocol = protocol + self.field = field + self.fxp_fraction_bits = fxp_fraction_bits + + @overload + def __init__(self, other: 'RuntimeConfig'): ... + def ParseFromJsonString(self, data: str) -> bool: ... + def ParseFromString(self, data: bytes) -> bool: ... + def SerializeToString(self) -> bytes: ... + +class SourceIRType(enum.IntEnum): + XLA = 0 + STABLEHLO = 1 + +class CompilationSource: + def __init__( + self, + ir_type: SourceIRType = SourceIRType.XLA, + ir_txt: bytes = b"", + input_visibility: list[Visibility] = [], + ): + self.ir_type = ir_type + self.ir_txt = ir_txt + self.input_visibility = input_visibility + +class CompilerOptions: + def __init__( + self, + enable_pretty_print: bool = False, + pretty_print_dump_dir: str = "", + xla_pp_kind: RuntimeConfig.SortMethod = RuntimeConfig.SortMethod.SORT_DEFAULT, + disable_sqrt_plus_epsilon_rewrite=False, + disable_div_sqrt_rewrite=False, + disable_reduce_truncation_optimization=False, + disable_maxpooling_optimization=False, + disallow_mix_types_opts=False, + disable_select_optimization=False, + enable_optimize_denominator_with_broadcast=False, + disable_deallocation_insertion=False, + disable_partial_sort_optimization=False, + ): + self.enable_pretty_print = enable_pretty_print + self.pretty_print_dump_dir = pretty_print_dump_dir + self.xla_pp_kind = xla_pp_kind + self.disable_sqrt_plus_epsilon_rewrite = disable_sqrt_plus_epsilon_rewrite + self.disable_div_sqrt_rewrite = disable_div_sqrt_rewrite + self.disable_reduce_truncation_optimization = ( + disable_reduce_truncation_optimization + ) + self.disable_maxpooling_optimization = disable_maxpooling_optimization + self.disallow_mix_types_opts = disallow_mix_types_opts + self.disable_select_optimization = disable_select_optimization + self.enable_optimize_denominator_with_broadcast = ( + enable_optimize_denominator_with_broadcast + ) + self.disable_deallocation_insertion = disable_deallocation_insertion + self.disable_partial_sort_optimization = disable_partial_sort_optimization + +class ExecutableProto: + def __init__( + self, + name: str = "", + input_names: list[str] = [], + output_names: list[str] = [], + code: str | bytes = b"", + ): + self.name = name + self.input_names = input_names + self.output_names = output_names + self.code = code + + def ParseFromString(self, data: bytes) -> bool: ... + def SerializeToString(self) -> bytes: ... + +class Share: + meta: bytes + share_chunks: list[bytes] + +class ShapeProto: + def __init__(self, dims: list[int]): + self.dims = dims + +class ValueMetaProto: + data_type: DataType + is_complex: bool + visibility: Visibility + shape: ShapeProto + storage_type: str + + def ParseFromString(self, data: bytes) -> bool: ... + +class RuntimeWrapper: + def __init__(self, link: link.Context, config: RuntimeConfig): ... + def Run(self, executable: ExecutableProto): ... + def SetVar(self, name: str, value: Share): ... + def GetVar(self, name: str) -> Share: ... + def GetVarChunksCount(self, name: str) -> int: ... + def GetVarMeta(self, name: str) -> ValueMetaProto: ... + def DelVar(self, name: str): ... + def Clear(self): ... + +class IoWrapper: + def __init__(self, link: link.Context, config: RuntimeConfig): ... + def MakeShares( + self, arr: bytes, visibility: int, owner_rank: int = -1 + ) -> list[Share]: ... + def GetShareChunkCount( + self, arr: bytes, visibility: int, owner_rank: int = -1 + ) -> int: ... + def Reconstruct(self, vals: list[Share]) -> bytes: ... + +def _check_cpu_features(): ... +def compile(source: CompilationSource, copts: CompilerOptions) -> bytes: ... diff --git a/spu/ops/groupby/BUILD.bazel b/spu/ops/groupby/BUILD.bazel index b38f10a40..850fb0515 100644 --- a/spu/ops/groupby/BUILD.bazel +++ b/spu/ops/groupby/BUILD.bazel @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@spu_pip_dev//:requirements.bzl", "requirement") load("//bazel:spu.bzl", "spu_py_library", "spu_py_test") package(default_visibility = ["//visibility:public"]) @@ -90,6 +91,6 @@ spu_py_test( ":segmentation", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//pandas:pkg", + requirement("pandas"), ], ) diff --git a/spu/ops/groupby/groupby_test.py b/spu/ops/groupby/groupby_test.py index e1994f08a..284c7eee6 100644 --- a/spu/ops/groupby/groupby_test.py +++ b/spu/ops/groupby/groupby_test.py @@ -18,7 +18,7 @@ import numpy as np import pandas as pd -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as spsim from spu.ops.groupby.aggregation import groupby_count, groupby_count_cleartext from spu.ops.groupby.groupby_via_shuffle import ( @@ -51,7 +51,7 @@ def groupby_agg_fun(agg): def test_fn(agg): - sim = spsim.Simulator.simple(3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def proc(x1, x2, y): return groupby([x1[:, 2], x2[:, 3]], [y]) @@ -153,9 +153,7 @@ def test_var(self): test_fn('var') def test_count(self): - sim = spsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 - ) + sim = spsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) def proc(x1, x2, y): return groupby([x1[:, 2], x2[:, 3]], [y]) diff --git a/spu/spu_pb2.py b/spu/spu_pb2.py deleted file mode 100644 index a33acd309..000000000 --- a/spu/spu_pb2.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2024 Ant Group Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from libspu.spu_pb2 import * diff --git a/spu/tests/BUILD.bazel b/spu/tests/BUILD.bazel index 3c0289316..d2c7a6799 100644 --- a/spu/tests/BUILD.bazel +++ b/spu/tests/BUILD.bazel @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +load("@spu_pip_dev//:requirements.bzl", "requirement") load("//bazel:spu.bzl", "spu_py_binary", "spu_py_library", "spu_py_test") package(default_visibility = ["//visibility:public"]) @@ -23,7 +24,7 @@ spu_py_library( "//spu:api", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//absl_py:pkg", + requirement("absl-py"), ], ) @@ -200,7 +201,7 @@ spu_py_test( srcs = ["jax_sanity_test.py"], deps = [ ":jnp_testbase", - "@spu_pip_dev//scikit_learn:pkg", + requirement("scikit-learn"), ], ) @@ -221,7 +222,7 @@ spu_py_test( "//spu:api", "//spu:init", "//spu/utils:simulation", - "@spu_pip_dev//absl_py:pkg", + requirement("absl-py"), ], ) @@ -301,7 +302,7 @@ spu_py_test( ":utils", "//spu:init", "//spu:psi", - "@spu_pip_dev//pandas:pkg", + requirement("pandas"), ], ) @@ -313,10 +314,10 @@ spu_py_test( "//spu/utils:frontend", ] + select({ "@bazel_tools//src/conditions:linux_x86_64": [ - "@spu_pip_dev//tensorflow_cpu:pkg", + requirement("tensorflow-cpu"), ], "//conditions:default": [ - "@spu_pip_dev//tensorflow:pkg", + requirement("tensorflow"), ], }), ) @@ -333,13 +334,13 @@ spu_py_test( ":utils", "//spu:init", "//spu/utils:distributed", - "@spu_pip_dev//grpcio:pkg", + requirement("grpcio"), ] + select({ "@bazel_tools//src/conditions:linux_x86_64": [ - "@spu_pip_dev//tensorflow_cpu:pkg", + requirement("tensorflow-cpu"), ], "//conditions:default": [ - "@spu_pip_dev//tensorflow:pkg", + requirement("tensorflow"), ], }), ) @@ -349,6 +350,6 @@ spu_py_test( srcs = ["jax_compile_test.py"], deps = [ ":jnp_testbase", - "@spu_pip_dev//flax:pkg", + requirement("flax"), ], ) diff --git a/spu/tests/distributed_test.py b/spu/tests/distributed_test.py index 54415b099..0a3eb23c0 100644 --- a/spu/tests/distributed_test.py +++ b/spu/tests/distributed_test.py @@ -22,7 +22,7 @@ import tensorflow as tf import spu.utils.distributed as ppd -from spu import spu_pb2 +from spu import libspu from spu.tests.utils import get_free_port from spu.utils.polyfill import Process @@ -153,15 +153,15 @@ def test_basic_pyu(self): def test_basic_spu_jax(self): a = ppd.device("SPU")(no_in_one_out)() self.assertTrue(isinstance(a, ppd.SPU.Object)) - self.assertEqual(a.vtype, spu_pb2.VIS_PUBLIC) + self.assertEqual(a.vtype, libspu.Visibility.VIS_PUBLIC) npt.assert_equal(ppd.get(a), np.array([1, 2])) # no in, two out a, b = ppd.device("SPU")(no_in_two_out)() self.assertTrue(isinstance(a, ppd.SPU.Object)) self.assertTrue(isinstance(b, ppd.SPU.Object)) - self.assertEqual(a.vtype, spu_pb2.VIS_PUBLIC) - self.assertEqual(b.vtype, spu_pb2.VIS_PUBLIC) + self.assertEqual(a.vtype, libspu.Visibility.VIS_PUBLIC) + self.assertEqual(b.vtype, libspu.Visibility.VIS_PUBLIC) npt.assert_equal(ppd.get(a), np.array([1, 2])) npt.assert_equal(ppd.get(b), np.array([3.0, 4.0])) @@ -170,8 +170,8 @@ def test_basic_spu_jax(self): self.assertEqual(len(l), 2) self.assertTrue(isinstance(l[0], ppd.SPU.Object)) self.assertTrue(isinstance(l[1], ppd.SPU.Object)) - self.assertEqual(l[0].vtype, spu_pb2.VIS_PUBLIC) - self.assertEqual(l[1].vtype, spu_pb2.VIS_PUBLIC) + self.assertEqual(l[0].vtype, libspu.Visibility.VIS_PUBLIC) + self.assertEqual(l[1].vtype, libspu.Visibility.VIS_PUBLIC) npt.assert_equal(ppd.get(l[0]), np.array([1, 2])) npt.assert_equal(ppd.get(l[1]), np.array([3.0, 4.0])) @@ -179,21 +179,21 @@ def test_basic_spu_jax(self): d = ppd.device("SPU")(no_in_dict_out)() self.assertTrue(isinstance(d["first"], ppd.SPU.Object)) self.assertTrue(isinstance(d["second"], ppd.SPU.Object)) - self.assertEqual(d["first"].vtype, spu_pb2.VIS_PUBLIC) - self.assertEqual(d["second"].vtype, spu_pb2.VIS_PUBLIC) + self.assertEqual(d["first"].vtype, libspu.Visibility.VIS_PUBLIC) + self.assertEqual(d["second"].vtype, libspu.Visibility.VIS_PUBLIC) npt.assert_equal(ppd.get(d["first"]), np.array([1, 2])) npt.assert_equal(ppd.get(d["second"]), np.array([3.0, 4.0])) # immediate input from driver e = ppd.device("SPU")(jnp.add)(np.array([1, 2]), np.array([3, 4])) self.assertTrue(isinstance(e, ppd.SPU.Object)) - self.assertEqual(e.vtype, spu_pb2.VIS_PUBLIC) + self.assertEqual(e.vtype, libspu.Visibility.VIS_PUBLIC) npt.assert_equal(ppd.get(e), np.array([4, 6])) # reuse inputs from SPU c = ppd.device("SPU")(jnp.add)(a, b) self.assertTrue(isinstance(c, ppd.SPU.Object)) - self.assertEqual(c.vtype, spu_pb2.VIS_PUBLIC) + self.assertEqual(c.vtype, libspu.Visibility.VIS_PUBLIC) self.assertTrue(c.device is ppd.current().devices["SPU"]) npt.assert_equal(ppd.get(c), np.array([4.0, 6.0])) @@ -201,7 +201,7 @@ def test_basic_spu_jax(self): x = ppd.device("P1")(no_in_one_out)() c = ppd.device("SPU")(jnp.add)(a, x) self.assertTrue(isinstance(c, ppd.SPU.Object)) - self.assertEqual(c.vtype, spu_pb2.VIS_SECRET) + self.assertEqual(c.vtype, libspu.Visibility.VIS_SECRET) self.assertTrue(c.device is ppd.current().devices["SPU"]) npt.assert_equal(ppd.get(c), np.array([2, 4])) @@ -222,15 +222,15 @@ def test_basic_spu_tf(self): ppd.set_framework(ppd.Framework.EXP_TF) a = ppd.device("SPU")(no_in_one_out)() self.assertTrue(isinstance(a, ppd.SPU.Object)) - self.assertEqual(a.vtype, spu_pb2.VIS_PUBLIC) + self.assertEqual(a.vtype, libspu.Visibility.VIS_PUBLIC) npt.assert_equal(ppd.get(a), np.array([1, 2])) # no in, two out a, b = ppd.device("SPU")(no_in_two_out)() self.assertTrue(isinstance(a, ppd.SPU.Object)) self.assertTrue(isinstance(b, ppd.SPU.Object)) - self.assertEqual(a.vtype, spu_pb2.VIS_PUBLIC) - self.assertEqual(b.vtype, spu_pb2.VIS_PUBLIC) + self.assertEqual(a.vtype, libspu.Visibility.VIS_PUBLIC) + self.assertEqual(b.vtype, libspu.Visibility.VIS_PUBLIC) npt.assert_equal(ppd.get(a), np.array([1, 2])) npt.assert_equal(ppd.get(b), np.array([3.0, 4.0])) @@ -239,8 +239,8 @@ def test_basic_spu_tf(self): self.assertEqual(len(l), 2) self.assertTrue(isinstance(l[0], ppd.SPU.Object)) self.assertTrue(isinstance(l[1], ppd.SPU.Object)) - self.assertEqual(l[0].vtype, spu_pb2.VIS_PUBLIC) - self.assertEqual(l[1].vtype, spu_pb2.VIS_PUBLIC) + self.assertEqual(l[0].vtype, libspu.Visibility.VIS_PUBLIC) + self.assertEqual(l[1].vtype, libspu.Visibility.VIS_PUBLIC) npt.assert_equal(ppd.get(l[0]), np.array([1, 2])) npt.assert_equal(ppd.get(l[1]), np.array([3.0, 4.0])) @@ -248,21 +248,21 @@ def test_basic_spu_tf(self): d = ppd.device("SPU")(no_in_dict_out)() self.assertTrue(isinstance(d["first"], ppd.SPU.Object)) self.assertTrue(isinstance(d["second"], ppd.SPU.Object)) - self.assertEqual(d["first"].vtype, spu_pb2.VIS_PUBLIC) - self.assertEqual(d["second"].vtype, spu_pb2.VIS_PUBLIC) + self.assertEqual(d["first"].vtype, libspu.Visibility.VIS_PUBLIC) + self.assertEqual(d["second"].vtype, libspu.Visibility.VIS_PUBLIC) npt.assert_equal(ppd.get(d["first"]), np.array([1, 2])) npt.assert_equal(ppd.get(d["second"]), np.array([3.0, 4.0])) # immediate input from driver e = ppd.device("SPU")(tf_fun)(np.array([1, 2]), np.array([3, 4])) self.assertTrue(isinstance(e, ppd.SPU.Object)) - self.assertEqual(e.vtype, spu_pb2.VIS_PUBLIC) + self.assertEqual(e.vtype, libspu.Visibility.VIS_PUBLIC) npt.assert_equal(ppd.get(e), np.array([4, 6])) # reuse inputs from SPU c = ppd.device("SPU")(tf_fun)(a, a) self.assertTrue(isinstance(c, ppd.SPU.Object)) - self.assertEqual(c.vtype, spu_pb2.VIS_PUBLIC) + self.assertEqual(c.vtype, libspu.Visibility.VIS_PUBLIC) self.assertTrue(c.device is ppd.current().devices["SPU"]) npt.assert_equal(ppd.get(c), np.array([2, 4])) @@ -270,7 +270,7 @@ def test_basic_spu_tf(self): x = ppd.device("P1")(no_in_one_out)() c = ppd.device("SPU")(tf_fun)(a, x) self.assertTrue(isinstance(c, ppd.SPU.Object)) - self.assertEqual(c.vtype, spu_pb2.VIS_SECRET) + self.assertEqual(c.vtype, libspu.Visibility.VIS_SECRET) self.assertTrue(c.device is ppd.current().devices["SPU"]) npt.assert_equal(ppd.get(c), np.array([2, 4])) diff --git a/spu/tests/frontend_test.py b/spu/tests/frontend_test.py index 86efd97b5..815f5e9e4 100644 --- a/spu/tests/frontend_test.py +++ b/spu/tests/frontend_test.py @@ -19,8 +19,8 @@ import numpy as np import tensorflow as tf +import spu.libspu as libspu import spu.utils.frontend as spu_fe -from spu import spu_pb2 def test_jax_add(*args, **kwargs): @@ -41,10 +41,10 @@ def test_jax_compile_static_args(self): {"in3": 2, "in4": np.array([2, 4])}, ["in1", "in2", "in3", "in4"], [ - spu_pb2.VIS_PUBLIC, - spu_pb2.VIS_PUBLIC, - spu_pb2.VIS_PUBLIC, - spu_pb2.VIS_PUBLIC, + libspu.Visibility.VIS_PUBLIC, + libspu.Visibility.VIS_PUBLIC, + libspu.Visibility.VIS_PUBLIC, + libspu.Visibility.VIS_PUBLIC, ], lambda out_flat: [f'test-out{idx}' for idx in range(len(out_flat))], static_argnums=(0,), @@ -77,7 +77,7 @@ def test_jax_compile(self): (np.array([1, 2]), np.array([2, 4])), {}, ["in1", "in2"], - [spu_pb2.VIS_PUBLIC, spu_pb2.VIS_PUBLIC], + [libspu.Visibility.VIS_PUBLIC, libspu.Visibility.VIS_PUBLIC], lambda out_flat: [f'test-out{idx}' for idx in range(len(out_flat))], ) self.assertEqual(executable.name, "add") @@ -101,7 +101,7 @@ def foo(x, y): (np.array([1, 2]), np.array([2, 4])), {}, ["in1", "in2"], - [spu_pb2.VIS_PUBLIC, spu_pb2.VIS_PUBLIC], + [libspu.Visibility.VIS_PUBLIC, libspu.Visibility.VIS_PUBLIC], lambda out_flat: [f'test-out{idx}' for idx in range(len(out_flat))], ) self.assertEqual(executable.name, "foo") diff --git a/spu/tests/jax_compile_test.py b/spu/tests/jax_compile_test.py index 06da946bc..e4b0bb5a9 100644 --- a/spu/tests/jax_compile_test.py +++ b/spu/tests/jax_compile_test.py @@ -19,7 +19,7 @@ from jax import numpy as jnp from jax import random -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as ppsim @@ -31,7 +31,7 @@ class UnitTests(unittest.TestCase): # https://github.com/secretflow/spu/issues/428 def test_cache_with_static_argnums(self): sim = ppsim.Simulator.simple( - 1, spu_pb2.ProtocolKind.REF2K, spu_pb2.FieldType.FM64 + 1, libspu.ProtocolKind.REF2K, libspu.FieldType.FM64 ) power_list = [-1, 0, 1, 2, 3] @@ -52,7 +52,7 @@ def test_cache_with_static_argnums(self): # https://github.com/secretflow/spu/issues/306 def test_compile_nn_layer(self): sim = ppsim.Simulator.simple( - 1, spu_pb2.ProtocolKind.REF2K, spu_pb2.FieldType.FM64 + 1, libspu.ProtocolKind.REF2K, libspu.FieldType.FM64 ) class LinearModel(nn.Module): diff --git a/spu/tests/jax_sanity_test.py b/spu/tests/jax_sanity_test.py index c07f2ebf5..0c6ea012a 100644 --- a/spu/tests/jax_sanity_test.py +++ b/spu/tests/jax_sanity_test.py @@ -23,7 +23,7 @@ from sklearn import metrics from sklearn.datasets import load_breast_cancer -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as ppsim @@ -103,18 +103,18 @@ def body_fun(_, loop_carry): @parameterized.product( wsize=(2, 3), prot=( - spu_pb2.ProtocolKind.SEMI2K, - spu_pb2.ProtocolKind.ABY3, - spu_pb2.ProtocolKind.CHEETAH, + libspu.ProtocolKind.SEMI2K, + libspu.ProtocolKind.ABY3, + libspu.ProtocolKind.CHEETAH, ), - field=(spu_pb2.FieldType.FM64, spu_pb2.FieldType.FM128), + field=(libspu.FieldType.FM64, libspu.FieldType.FM128), ) class UnitTests(parameterized.TestCase): def test_sslr(self, wsize, prot, field): - if prot == spu_pb2.ProtocolKind.ABY3 and wsize != 3: + if prot == libspu.ProtocolKind.ABY3 and wsize != 3: return - if prot == spu_pb2.ProtocolKind.CHEETAH and ( - wsize != 2 or field != spu_pb2.FieldType.FM64 + if prot == libspu.ProtocolKind.CHEETAH and ( + wsize != 2 or field != libspu.FieldType.FM64 ): return diff --git a/spu/tests/jnp_aby3_r128_test.py b/spu/tests/jnp_aby3_r128_test.py index f9383905a..3ced4bcc9 100644 --- a/spu/tests/jnp_aby3_r128_test.py +++ b/spu/tests/jnp_aby3_r128_test.py @@ -17,7 +17,7 @@ import numpy as np -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as ppsim from spu.tests.jnp_testbase import JnpTests @@ -25,7 +25,7 @@ class JnpTestAby3FM128(JnpTests.JnpTestBase): def setUp(self): self._sim = ppsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128 + 3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM128 ) self._rng = np.random.RandomState() diff --git a/spu/tests/jnp_aby3_r64_test.py b/spu/tests/jnp_aby3_r64_test.py index 6595ea3cd..7691ad9ef 100644 --- a/spu/tests/jnp_aby3_r64_test.py +++ b/spu/tests/jnp_aby3_r64_test.py @@ -17,7 +17,7 @@ import numpy as np -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as ppsim from spu.tests.jnp_testbase import JnpTests @@ -25,7 +25,7 @@ class JnpTestAby3FM64(JnpTests.JnpTestBase): def setUp(self): self._sim = ppsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64 + 3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64 ) self._rng = np.random.RandomState() diff --git a/spu/tests/jnp_cheetah_r64_test.py b/spu/tests/jnp_cheetah_r64_test.py index 10c02a53e..23c4fcffb 100644 --- a/spu/tests/jnp_cheetah_r64_test.py +++ b/spu/tests/jnp_cheetah_r64_test.py @@ -17,7 +17,7 @@ import numpy as np -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as ppsim from spu.tests.jnp_testbase import JnpTests @@ -25,7 +25,7 @@ class JnpTestCheetahFM64(JnpTests.JnpTestBase): def setUp(self): self._sim = ppsim.Simulator.simple( - 2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType.FM64 + 2, libspu.ProtocolKind.CHEETAH, libspu.FieldType.FM64 ) self._rng = np.random.RandomState() diff --git a/spu/tests/jnp_debug.py b/spu/tests/jnp_debug.py index 4757c555c..5e6d22860 100644 --- a/spu/tests/jnp_debug.py +++ b/spu/tests/jnp_debug.py @@ -15,7 +15,7 @@ import jax.numpy as jnp import numpy as np -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as ppsim if __name__ == "__main__": @@ -24,8 +24,8 @@ Please DONT commit it unless it will cause build break. """ - sim = ppsim.Simulator.simple(3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64) - copts = spu_pb2.CompilerOptions() + sim = ppsim.Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) + copts = libspu.CompilerOptions() # Tweak compiler options copts.disable_div_sqrt_rewrite = True diff --git a/spu/tests/jnp_ref2k_r64_test.py b/spu/tests/jnp_ref2k_r64_test.py index 3d536a144..95f28730f 100644 --- a/spu/tests/jnp_ref2k_r64_test.py +++ b/spu/tests/jnp_ref2k_r64_test.py @@ -17,7 +17,7 @@ import numpy as np -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as ppsim from spu.tests.jnp_testbase import JnpTests @@ -25,7 +25,7 @@ class JnpTestRef2kFM64(JnpTests.JnpTestBase): def setUp(self): self._sim = ppsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.REF2K, spu_pb2.FieldType.FM64 + 3, libspu.ProtocolKind.REF2K, libspu.FieldType.FM64 ) self._rng = np.random.RandomState() diff --git a/spu/tests/jnp_semi2k_r128_test.py b/spu/tests/jnp_semi2k_r128_test.py index 7ce5c8251..8e0ed64f0 100644 --- a/spu/tests/jnp_semi2k_r128_test.py +++ b/spu/tests/jnp_semi2k_r128_test.py @@ -17,7 +17,7 @@ import numpy as np -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as ppsim from spu.tests.jnp_testbase import JnpTests @@ -25,21 +25,21 @@ class JnpTestSemi2kFM128(JnpTests.JnpTestBase): def setUp(self): self._sim = ppsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.SEMI2K, spu_pb2.FieldType.FM128 + 3, libspu.ProtocolKind.SEMI2K, libspu.FieldType.FM128 ) self._rng = np.random.RandomState() class JnpTestSemi2kFM128TwoParty(JnpTests.JnpTestBase): def setUp(self): - config = spu_pb2.RuntimeConfig( - protocol=spu_pb2.ProtocolKind.SEMI2K, field=spu_pb2.FieldType.FM128 + config = libspu.RuntimeConfig( + protocol=libspu.ProtocolKind.SEMI2K, field=libspu.FieldType.FM128 ) config.experimental_enable_exp_prime = True config.experimental_exp_prime_enable_upper_bound = True config.experimental_exp_prime_offset = 13 config.experimental_exp_prime_disable_lower_bound = False - config.fxp_exp_mode = spu_pb2.RuntimeConfig.ExpMode.EXP_PRIME + config.fxp_exp_mode = libspu.RuntimeConfig.ExpMode.EXP_PRIME self._sim = ppsim.Simulator(2, config) self._rng = np.random.RandomState() diff --git a/spu/tests/jnp_semi2k_r64_test.py b/spu/tests/jnp_semi2k_r64_test.py index dac8c371b..aed6361bc 100644 --- a/spu/tests/jnp_semi2k_r64_test.py +++ b/spu/tests/jnp_semi2k_r64_test.py @@ -17,7 +17,7 @@ import numpy as np -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu import spu.utils.simulation as ppsim from spu.tests.jnp_testbase import JnpTests @@ -25,7 +25,7 @@ class JnpTestSemi2kFM64(JnpTests.JnpTestBase): def setUp(self): self._sim = ppsim.Simulator.simple( - 3, spu_pb2.ProtocolKind.SEMI2K, spu_pb2.FieldType.FM64 + 3, libspu.ProtocolKind.SEMI2K, libspu.FieldType.FM64 ) self._rng = np.random.RandomState() diff --git a/spu/tests/spu_compiler_test.py b/spu/tests/spu_compiler_test.py index 4e65befa6..688aac5d1 100644 --- a/spu/tests/spu_compiler_test.py +++ b/spu/tests/spu_compiler_test.py @@ -19,7 +19,6 @@ import numpy as np import numpy.testing as npt -import spu.spu_pb2 as spu_pb2 import spu.utils.frontend as spu_fe diff --git a/spu/tests/spu_io_test.py b/spu/tests/spu_io_test.py index d88eefd41..9e02be715 100644 --- a/spu/tests/spu_io_test.py +++ b/spu/tests/spu_io_test.py @@ -19,11 +19,11 @@ from absl.testing import absltest, parameterized import spu.api as ppapi -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu def _bytes_to_pb(msg: bytes): - ret = spu_pb2.ValueMetaProto() + ret = libspu.ValueMetaProto() ret.ParseFromString(msg) return ret @@ -31,36 +31,36 @@ def _bytes_to_pb(msg: bytes): @parameterized.product( wsize=(2, 3, 5), prot=( - spu_pb2.ProtocolKind.REF2K, - spu_pb2.ProtocolKind.SEMI2K, - spu_pb2.ProtocolKind.ABY3, + libspu.ProtocolKind.REF2K, + libspu.ProtocolKind.SEMI2K, + libspu.ProtocolKind.ABY3, ), - field=(spu_pb2.FieldType.FM64, spu_pb2.FieldType.FM128), + field=(libspu.FieldType.FM64, libspu.FieldType.FM128), chunk_size=(4, 11, 33, 67, 127, 65535), ) class UnitTests(parameterized.TestCase): def test_io(self, wsize, prot, field, chunk_size): - if prot == spu_pb2.ProtocolKind.ABY3 and wsize != 3: + if prot == libspu.ProtocolKind.ABY3 and wsize != 3: return - config = spu_pb2.RuntimeConfig( + config = libspu.RuntimeConfig( protocol=prot, field=field, fxp_fraction_bits=18, - share_max_chunk_size=chunk_size, ) + config.share_max_chunk_size = chunk_size io = ppapi.Io(wsize, config) # SINT x = np.random.randint(10, size=(3, 4, 5)) - xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) + xs = io.make_shares(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs), wsize) self.assertEqual( - _bytes_to_pb(xs[0].meta).shape, - spu_pb2.ShapeProto(dims=(3, 4, 5)), + _bytes_to_pb(xs[0].meta).shape.dims, + [3, 4, 5], ) - chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + chunk_count = io.get_share_chunk_count(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs[0].share_chunks), chunk_count) y = io.reconstruct(xs) @@ -69,13 +69,13 @@ def test_io(self, wsize, prot, field, chunk_size): # SFXP x = np.random.rand(3, 4, 5) - xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) + xs = io.make_shares(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs), wsize) self.assertEqual( - _bytes_to_pb(xs[0].meta).shape, - spu_pb2.ShapeProto(dims=(3, 4, 5)), + _bytes_to_pb(xs[0].meta).shape.dims, + [3, 4, 5], ) - chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + chunk_count = io.get_share_chunk_count(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs[0].share_chunks), chunk_count) y = io.reconstruct(xs) @@ -84,13 +84,13 @@ def test_io(self, wsize, prot, field, chunk_size): # PFXP x = np.random.rand(3, 4, 5) - xs = io.make_shares(x, spu_pb2.Visibility.VIS_PUBLIC) + xs = io.make_shares(x, libspu.Visibility.VIS_PUBLIC) self.assertEqual(len(xs), wsize) self.assertEqual( - _bytes_to_pb(xs[0].meta).shape, - spu_pb2.ShapeProto(dims=(3, 4, 5)), + _bytes_to_pb(xs[0].meta).shape.dims, + [3, 4, 5], ) - chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_PUBLIC) + chunk_count = io.get_share_chunk_count(x, libspu.Visibility.VIS_PUBLIC) self.assertEqual(len(xs[0].share_chunks), chunk_count) y = io.reconstruct(xs) @@ -99,13 +99,13 @@ def test_io(self, wsize, prot, field, chunk_size): # empty x = np.random.rand(1, 0) - xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) + xs = io.make_shares(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs), wsize) self.assertEqual( - _bytes_to_pb(xs[0].meta).shape, - spu_pb2.ShapeProto(dims=(1, 0)), + _bytes_to_pb(xs[0].meta).shape.dims, + [1, 0], ) - chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + chunk_count = io.get_share_chunk_count(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs[0].share_chunks), chunk_count) y = io.reconstruct(xs) @@ -113,28 +113,24 @@ def test_io(self, wsize, prot, field, chunk_size): npt.assert_almost_equal(x, y, decimal=5) def test_io_strides(self, wsize, prot, field, chunk_size): - if prot == spu_pb2.ProtocolKind.ABY3 and wsize != 3: + if prot == libspu.ProtocolKind.ABY3 and wsize != 3: return - config = spu_pb2.RuntimeConfig( - protocol=prot, - field=field, - fxp_fraction_bits=18, - share_max_chunk_size=chunk_size, - ) + config = libspu.RuntimeConfig(protocol=prot, field=field, fxp_fraction_bits=18) + config.share_max_chunk_size = chunk_size io = ppapi.Io(wsize, config) # SINT x = np.random.randint(10, size=(6, 7, 8)) x = x[0:5:2, 0:7:2, 0:8:2] - xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) + xs = io.make_shares(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs), wsize) self.assertEqual( - _bytes_to_pb(xs[0].meta).shape, - spu_pb2.ShapeProto(dims=(3, 4, 4)), + _bytes_to_pb(xs[0].meta).shape.dims, + [3, 4, 4], ) - chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + chunk_count = io.get_share_chunk_count(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs[0].share_chunks), chunk_count) y = io.reconstruct(xs) @@ -144,14 +140,14 @@ def test_io_strides(self, wsize, prot, field, chunk_size): x = np.random.rand(6, 7, 8) x = x[0:5:2, 0:7:2, 0:8:2] - xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) + xs = io.make_shares(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs), wsize) self.assertEqual( - _bytes_to_pb(xs[0].meta).shape, - spu_pb2.ShapeProto(dims=(3, 4, 4)), + _bytes_to_pb(xs[0].meta).shape.dims, + [3, 4, 4], ) y = io.reconstruct(xs) - chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + chunk_count = io.get_share_chunk_count(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs[0].share_chunks), chunk_count) npt.assert_almost_equal(x, y, decimal=5) @@ -160,38 +156,34 @@ def test_io_strides(self, wsize, prot, field, chunk_size): x = np.random.rand(6, 7, 8) x = x[0:5:2, 0:7:2, 0:8:2] - xs = io.make_shares(x, spu_pb2.Visibility.VIS_PUBLIC) + xs = io.make_shares(x, libspu.Visibility.VIS_PUBLIC) self.assertEqual(len(xs), wsize) self.assertEqual( - _bytes_to_pb(xs[0].meta).shape, - spu_pb2.ShapeProto(dims=(3, 4, 4)), + _bytes_to_pb(xs[0].meta).shape.dims, + [3, 4, 4], ) y = io.reconstruct(xs) - chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_PUBLIC) + chunk_count = io.get_share_chunk_count(x, libspu.Visibility.VIS_PUBLIC) self.assertEqual(len(xs[0].share_chunks), chunk_count) npt.assert_almost_equal(x, y, decimal=5) def test_io_scalar(self, wsize, prot, field, chunk_size): - if prot == spu_pb2.ProtocolKind.ABY3 and wsize != 3: + if prot == libspu.ProtocolKind.ABY3 and wsize != 3: return - config = spu_pb2.RuntimeConfig( - protocol=prot, - field=field, - fxp_fraction_bits=18, - share_max_chunk_size=chunk_size, - ) + config = libspu.RuntimeConfig(protocol=prot, field=field, fxp_fraction_bits=18) + config.share_max_chunk_size = chunk_size io = ppapi.Io(wsize, config) # SINT x = np.random.randint(10, size=()) - xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) + xs = io.make_shares(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs), wsize) - self.assertEqual(_bytes_to_pb(xs[0].meta).shape, spu_pb2.ShapeProto(dims=())) + self.assertEqual(_bytes_to_pb(xs[0].meta).shape.dims, []) y = io.reconstruct(xs) - chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + chunk_count = io.get_share_chunk_count(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs[0].share_chunks), chunk_count) npt.assert_equal(x, y) @@ -199,11 +191,11 @@ def test_io_scalar(self, wsize, prot, field, chunk_size): # SFXP x = np.random.random(size=()) - xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) + xs = io.make_shares(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs), wsize) - self.assertEqual(_bytes_to_pb(xs[0].meta).shape, spu_pb2.ShapeProto(dims=())) + self.assertEqual(_bytes_to_pb(xs[0].meta).shape.dims, []) y = io.reconstruct(xs) - chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + chunk_count = io.get_share_chunk_count(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs[0].share_chunks), chunk_count) npt.assert_almost_equal(x, y, decimal=5) @@ -211,100 +203,94 @@ def test_io_scalar(self, wsize, prot, field, chunk_size): # PFXP x = np.random.random(size=()) - xs = io.make_shares(x, spu_pb2.Visibility.VIS_PUBLIC) + xs = io.make_shares(x, libspu.Visibility.VIS_PUBLIC) self.assertEqual(len(xs), wsize) - self.assertEqual(_bytes_to_pb(xs[0].meta).shape, spu_pb2.ShapeProto(dims=())) + self.assertEqual(_bytes_to_pb(xs[0].meta).shape.dims, []) y = io.reconstruct(xs) - chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_PUBLIC) + chunk_count = io.get_share_chunk_count(x, libspu.Visibility.VIS_PUBLIC) self.assertEqual(len(xs[0].share_chunks), chunk_count) npt.assert_almost_equal(x, y, decimal=5) def test_io_single_complex(self, wsize, prot, field, chunk_size): - if prot == spu_pb2.ProtocolKind.ABY3 and wsize != 3: + if prot == libspu.ProtocolKind.ABY3 and wsize != 3: return - config = spu_pb2.RuntimeConfig( - protocol=prot, - field=field, - share_max_chunk_size=chunk_size, - ) + config = libspu.RuntimeConfig(protocol=prot, field=field) + config.share_max_chunk_size = chunk_size io = ppapi.Io(wsize, config) # SFXP x = np.array([1 + 2j, 3 + 4j, 5 + 6j]).astype('complex64') - xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) + xs = io.make_shares(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs), wsize) y = io.reconstruct(xs) print(xs) - chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + chunk_count = io.get_share_chunk_count(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs[0].share_chunks), 2 * chunk_count) npt.assert_almost_equal(x, y, decimal=5) # PFXP - xs = io.make_shares(x, spu_pb2.Visibility.VIS_PUBLIC) + xs = io.make_shares(x, libspu.Visibility.VIS_PUBLIC) self.assertEqual(len(xs), wsize) y = io.reconstruct(xs) - chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_PUBLIC) + chunk_count = io.get_share_chunk_count(x, libspu.Visibility.VIS_PUBLIC) self.assertEqual(len(xs[0].share_chunks), 2 * chunk_count) npt.assert_almost_equal(x, y, decimal=5) def test_io_double_complex(self, wsize, prot, field, chunk_size): - if prot == spu_pb2.ProtocolKind.ABY3 and wsize != 3: + if prot == libspu.ProtocolKind.ABY3 and wsize != 3: return - config = spu_pb2.RuntimeConfig( - protocol=prot, - field=field, - share_max_chunk_size=chunk_size, - ) + config = libspu.RuntimeConfig(protocol=prot, field=field) + config.share_max_chunk_size = chunk_size io = ppapi.Io(wsize, config) # SFXP x = np.array([1 + 2j, 3 + 4j, 5 + 6j]) - xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) + xs = io.make_shares(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs), wsize) y = io.reconstruct(xs) print(xs) - chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_SECRET) + chunk_count = io.get_share_chunk_count(x, libspu.Visibility.VIS_SECRET) self.assertEqual(len(xs[0].share_chunks), 2 * chunk_count) npt.assert_almost_equal(x, y, decimal=5) # PFXP - xs = io.make_shares(x, spu_pb2.Visibility.VIS_PUBLIC) + xs = io.make_shares(x, libspu.Visibility.VIS_PUBLIC) self.assertEqual(len(xs), wsize) y = io.reconstruct(xs) - chunk_count = io.get_share_chunk_count(x, spu_pb2.Visibility.VIS_PUBLIC) + chunk_count = io.get_share_chunk_count(x, libspu.Visibility.VIS_PUBLIC) self.assertEqual(len(xs[0].share_chunks), 2 * chunk_count) npt.assert_almost_equal(x, y, decimal=5) def test_colocated_io(self, wsize, prot, field, chunk_size): - if prot == spu_pb2.ProtocolKind.ABY3 and wsize != 3: + if prot == libspu.ProtocolKind.ABY3 and wsize != 3: return - if prot == spu_pb2.ProtocolKind.REF2K: + if prot == libspu.ProtocolKind.REF2K: return - config = spu_pb2.RuntimeConfig( + config = libspu.RuntimeConfig( protocol=prot, field=field, - share_max_chunk_size=chunk_size, - experimental_enable_colocated_optimization=True, ) + config.share_max_chunk_size = chunk_size + config.experimental_enable_colocated_optimization = True io = ppapi.Io(wsize, config) # PrivINT x = np.random.randint(10, size=()) - xs = io.make_shares(x, spu_pb2.Visibility.VIS_SECRET, owner_rank=1) + xs = io.make_shares(x, libspu.Visibility.VIS_SECRET, owner_rank=1) self.assertIn('Priv2k', _bytes_to_pb(xs[0].meta).storage_type) y = io.reconstruct(xs) diff --git a/spu/tests/spu_runtime_test.py b/spu/tests/spu_runtime_test.py index 84d1956b0..4cccca437 100644 --- a/spu/tests/spu_runtime_test.py +++ b/spu/tests/spu_runtime_test.py @@ -18,16 +18,16 @@ import numpy as np import numpy.testing as npt -import spu.spu_pb2 as spu_pb2 +import spu.libspu as libspu from spu.utils.simulation import Simulator class UnitTests(unittest.TestCase): def test_no_io(self): wsize = 3 - config = spu_pb2.RuntimeConfig( - protocol=spu_pb2.ProtocolKind.SEMI2K, - field=spu_pb2.FieldType.FM128, + config = libspu.RuntimeConfig( + protocol=libspu.ProtocolKind.SEMI2K, + field=libspu.FieldType.FM128, fxp_fraction_bits=18, ) @@ -42,16 +42,16 @@ def test_no_io(self): pphlo.custom_call @spu.dbg_print (%1) {has_side_effect = true} : (tensor<2x2x!pphlo.secret>)->() return %1 : tensor<2x2x!pphlo.secret> }""" - executable = spu_pb2.ExecutableProto( + executable = libspu.ExecutableProto( name="test", input_names=["in0"], output_names=["out0"], code=code.encode() ) sim(executable, x) def test_raise(self): wsize = 3 - config = spu_pb2.RuntimeConfig( - protocol=spu_pb2.ProtocolKind.SEMI2K, - field=spu_pb2.FieldType.FM128, + config = libspu.RuntimeConfig( + protocol=libspu.ProtocolKind.SEMI2K, + field=libspu.FieldType.FM128, fxp_fraction_bits=18, ) @@ -66,7 +66,7 @@ def test_raise(self): %0 = pphlo.dot %arg0, %arg1 : (tensor<2x3x!pphlo.secret>, tensor<12x13x!pphlo.secret>) -> tensor<2x2x!pphlo.secret> return %0 : tensor<2x2x!pphlo.secret> }""" - executable = spu_pb2.ExecutableProto( + executable = libspu.ExecutableProto( name="test", input_names=["in0", "in1"], output_names=["out0"], @@ -78,9 +78,9 @@ def test_raise(self): def test_wrong_version(self): wsize = 1 - config = spu_pb2.RuntimeConfig( - protocol=spu_pb2.ProtocolKind.REF2K, - field=spu_pb2.FieldType.FM64, + config = libspu.RuntimeConfig( + protocol=libspu.ProtocolKind.REF2K, + field=libspu.FieldType.FM64, ) sim = Simulator(wsize, config) @@ -93,7 +93,7 @@ def test_wrong_version(self): return %0 : tensor<1xf32> } }""" - executable = spu_pb2.ExecutableProto( + executable = libspu.ExecutableProto( name="test", code=code.encode(), ) diff --git a/spu/utils/BUILD.bazel b/spu/utils/BUILD.bazel index 2dc766850..5ddadd44c 100644 --- a/spu/utils/BUILD.bazel +++ b/spu/utils/BUILD.bazel @@ -50,7 +50,7 @@ spu_py_library( srcs = [ "distributed_impl.py", ":distributed_py_proto_grpc", - "@spulib//libspu:spu_py_proto", + # "@spulib//libspu:spu_py_proto", ], deps = [ ":frontend", diff --git a/spu/utils/distributed_impl.py b/spu/utils/distributed_impl.py index ebd6a7038..ba5e2511b 100644 --- a/spu/utils/distributed_impl.py +++ b/spu/utils/distributed_impl.py @@ -44,10 +44,10 @@ import numpy as np from termcolor import colored - from .. import api as spu_api from .. import libspu # type: ignore -from .. import spu_pb2 + +# from .. import spu_pb2 from . import distributed_pb2 # type: ignore from . import distributed_pb2_grpc # type: ignore @@ -373,17 +373,17 @@ def shape_spu_to_np(spu_shape): def dtype_spu_to_np(spu_dtype): MAP = { - spu_pb2.DataType.DT_F32: np.float32, - spu_pb2.DataType.DT_F64: np.float64, - spu_pb2.DataType.DT_I1: np.bool_, - spu_pb2.DataType.DT_I8: np.int8, - spu_pb2.DataType.DT_U8: np.uint8, - spu_pb2.DataType.DT_I16: np.int16, - spu_pb2.DataType.DT_U16: np.uint16, - spu_pb2.DataType.DT_I32: np.int32, - spu_pb2.DataType.DT_U32: np.uint32, - spu_pb2.DataType.DT_I64: np.int64, - spu_pb2.DataType.DT_U64: np.uint64, + libspu.DataType.DT_F32: np.float32, + libspu.DataType.DT_F64: np.float64, + libspu.DataType.DT_I1: np.bool_, + libspu.DataType.DT_I8: np.int8, + libspu.DataType.DT_U8: np.uint8, + libspu.DataType.DT_I16: np.int16, + libspu.DataType.DT_U16: np.uint16, + libspu.DataType.DT_I32: np.int32, + libspu.DataType.DT_U32: np.uint32, + libspu.DataType.DT_I64: np.int64, + libspu.DataType.DT_U64: np.uint64, } return MAP.get(spu_dtype) @@ -550,7 +550,7 @@ def builtin_spu_init( for rank, addr in enumerate(addrs): desc.add_party(f"r{rank}", addr) link = libspu.link.create_brpc(desc, my_rank) - spu_config = spu_pb2.RuntimeConfig() + spu_config = libspu.RuntimeConfig() spu_config.ParseFromString(spu_config_str) if my_rank != 0: spu_config.enable_action_trace = False @@ -574,7 +574,7 @@ def builtin_spu_run( rt = server._locals[f"{device_name}-rt"] io = server._locals[f"{device_name}-io"] - spu_exec = spu_pb2.ExecutableProto() + spu_exec = libspu.ExecutableProto() spu_exec.ParseFromString(exec_str) # do infeed. @@ -582,7 +582,7 @@ def builtin_spu_run( if isinstance(arg, ValueWrapper): rt.set_var(spu_exec.input_names[idx], arg.spu_share) else: - fst, *_ = io.make_shares(arg, spu_pb2.Visibility.VIS_PUBLIC) + fst, *_ = io.make_shares(arg, libspu.Visibility.VIS_PUBLIC) rt.set_var(spu_exec.input_names[idx], fst) # run @@ -611,7 +611,7 @@ def builtin_spu_run( return rets -from spu import spu_pb2 +# from spu import spu_pb2 class SPUObjectMetadata: @@ -619,7 +619,7 @@ def __init__( self, shape: Sequence[int], dtype: np.dtype, - vtype: spu_pb2.Visibility, + vtype: libspu.Visibility, ) -> None: self.dtype = dtype self.shape = shape @@ -636,7 +636,7 @@ def __init__( refs: Sequence[ObjectRef], shape: Sequence[int], dtype: np.dtype, - vtype: spu_pb2.Visibility, + vtype: libspu.Visibility, ): super().__init__(device) assert all(isObjectRef(ref) for ref in refs) @@ -728,7 +728,7 @@ def mock_parameters(obj: Union[SPU.Object, np.ndarray]): except ImportError: import jax.linear_util as lu # fallback from jax._src import api_util as japi_util - from jax.tree_util import tree_map, tree_flatten + from jax.tree_util import tree_flatten, tree_map mock_args, mock_kwargs = tree_map(mock_parameters, (args, kwargs)) @@ -742,7 +742,7 @@ def mock_parameters(obj: Union[SPU.Object, np.ndarray]): ( arg.vtype if isinstance(arg, SPU.Object) - else spu_pb2.Visibility.VIS_PUBLIC + else libspu.Visibility.VIS_PUBLIC ) for arg in args_flat ] @@ -775,7 +775,7 @@ class TensorFlowFunction(Device.Function): device: SPU def __init__( - self, device: Device, pyfunc: Callable, copts: spu_pb2.CompilerOptions + self, device: Device, pyfunc: Callable, copts: libspu.CompilerOptions ): super().__init__(device, pyfunc) self.copts = copts @@ -845,7 +845,7 @@ def mock_parameters(obj: Union[SPU.Object, np.ndarray]): ( arg.vtype if isinstance(arg, SPU.Object) - else spu_pb2.Visibility.VIS_PUBLIC + else libspu.Visibility.VIS_PUBLIC ) for arg in args_flat ] @@ -873,7 +873,7 @@ class TorchFunction(Device.Function): device: SPU def __init__( - self, device: Device, pyfunc: Callable, copts: spu_pb2.CompilerOptions + self, device: Device, pyfunc: Callable, copts: libspu.CompilerOptions ): super().__init__(device, pyfunc) self.state_dict = None @@ -957,7 +957,7 @@ def mock_parameters(obj: Union[SPU.Object, np.ndarray]): fn, torch.nn.Module ), "currently only torch.nn.Module is supported" - from jax.tree_util import tree_map, tree_flatten + from jax.tree_util import tree_flatten, tree_map mock_args, mock_kwargs = tree_map(mock_parameters, (args, kwargs)) @@ -992,11 +992,12 @@ def __init__( self.internal_addrs = internal_addrs assert len(internal_addrs) == len(node_clients) - from google.protobuf import json_format - - self.runtime_config = json_format.Parse( - json.dumps(runtime_config), spu_pb2.RuntimeConfig() - ) + # from google.protobuf import json_format + # self.runtime_config = json_format.Parse( + # json.dumps(runtime_config), libspu.RuntimeConfig() + # ) + self.runtime_config = libspu.RuntimeConfig() + self.runtime_config.ParseFromJsonString(json.dumps(runtime_config)) with concurrent.futures.ThreadPoolExecutor() as executor: futures = [ @@ -1025,7 +1026,7 @@ def compile( self, fn: Callable, static_argnums=(), - copts=spu_pb2.CompilerOptions(), + copts=libspu.CompilerOptions(), ) -> Callable: if _FRAMEWORK == Framework.EXP_TF: return SPU.TensorFlowFunction(self, fn, copts) @@ -1282,7 +1283,7 @@ def PYU2PYU(to: PYU, obj: PYU.Object): def SPU2PYU(to: PYU, obj: SPU.Object): # tell PYU the object refs, and run reconstruct on it. def reconstruct(wsize: int, spu_config_str: str, shares: List[ValueWrapper]): - spu_config = spu_pb2.RuntimeConfig() + spu_config = libspu.RuntimeConfig() spu_config.ParseFromString(spu_config_str) spu_io = spu_api.Io(wsize, spu_config) @@ -1295,12 +1296,12 @@ def reconstruct(wsize: int, spu_config_str: str, shares: List[ValueWrapper]): ) -def PYU2SPU(to: SPU, obj: PYU.Object, vtype=spu_pb2.Visibility.VIS_SECRET): +def PYU2SPU(to: SPU, obj: PYU.Object, vtype=libspu.Visibility.VIS_SECRET): # make shares on PYU, and tell SPU the object refs. def make_shares( wsize: int, spu_config_str: str, x: np.ndarray, owner_rank: int = -1 ): - spu_config = spu_pb2.RuntimeConfig() + spu_config = libspu.RuntimeConfig() spu_config.ParseFromString(spu_config_str) spu_io = spu_api.Io(wsize, spu_config) diff --git a/spu/utils/frontend.py b/spu/utils/frontend.py index fc20cd0c8..e67a3e471 100644 --- a/spu/utils/frontend.py +++ b/spu/utils/frontend.py @@ -22,7 +22,7 @@ from numpy import ndarray from .. import api as spu_api -from .. import spu_pb2 +from .. import libspu _jax_lock = Lock() @@ -225,7 +225,7 @@ def compile( outputNameGen: Callable, static_argnums=(), static_argnames=None, - copts=spu_pb2.CompilerOptions(), + copts=libspu.CompilerOptions(), ): if kind == Kind.JAX: import jax @@ -262,13 +262,10 @@ def compile( else: raise NameError(f"Unknown frontend type {kind}") - source = spu_pb2.CompilationSource() - source.ir_txt = ir_text - source.ir_type = spu_pb2.SourceIRType.XLA - source.input_visibility.extend(input_vis) + source = libspu.CompilationSource(libspu.SourceIRType.XLA, ir_text, input_vis) name = fn.func.__name__ if isinstance(fn, functools.partial) else fn.__name__ mlir = spu_api.compile(source, copts) - executable = spu_pb2.ExecutableProto( + executable = libspu.ExecutableProto( name=name, input_names=input_names, output_names=output_names, @@ -282,7 +279,7 @@ def torch_compile( args_flat: List, m_args_flat: List, state_dict: collections.OrderedDict, - copts=spu_pb2.CompilerOptions(), + copts=libspu.CompilerOptions(), ): import os @@ -312,9 +309,7 @@ def torch_compile( ] output_tree = method.meta.output_pytree_spec - source = spu_pb2.CompilationSource() - source.ir_txt = ir_text - source.ir_type = spu_pb2.SourceIRType.STABLEHLO + source = libspu.CompilationSource(libspu.SourceIRType.STABLEHLO, ir_text, []) args_params_flat = [] for loc in method.meta.input_locations: @@ -335,13 +330,13 @@ def torch_compile( ( arg.vtype if isinstance(arg, distributed.SPU.Object) - else spu_pb2.Visibility.VIS_PUBLIC + else libspu.Visibility.VIS_PUBLIC ) for arg in args_params_flat ] ) mlir = spu_api.compile(source, copts) - executable = spu_pb2.ExecutableProto( + executable = libspu.ExecutableProto( name=name, input_names=input_names, output_names=output_names, diff --git a/spu/utils/simulation.py b/spu/utils/simulation.py index 6a6c220ab..f16881859 100644 --- a/spu/utils/simulation.py +++ b/spu/utils/simulation.py @@ -29,7 +29,6 @@ from .. import api as spu_api from .. import libspu # type: ignore -from .. import spu_pb2 from . import frontend as spu_fe @@ -50,30 +49,30 @@ def join(self): class Simulator(object): - def __init__(self, wsize: int, rt_config: spu_pb2.RuntimeConfig): + def __init__(self, wsize: int, rt_config: libspu.RuntimeConfig): self.wsize = wsize self.rt_config = rt_config self.io = spu_api.Io(wsize, rt_config) @classmethod - def simple(cls, wsize: int, prot: spu_pb2.ProtocolKind, field: spu_pb2.FieldType): + def simple(cls, wsize: int, prot: libspu.ProtocolKind, field: libspu.FieldType): """helper method to create an SPU Simulator Args: wsize (int): the world size. - prot (spu_pb2.ProtocolKind): protocol. + prot (libspu.ProtocolKind): protocol. - field (spu_pb2.FieldType): field type. + field (libspu.FieldType): field type. Returns: A SPU Simulator """ - config = spu_pb2.RuntimeConfig(protocol=prot, field=field) + config = libspu.RuntimeConfig(protocol=prot, field=field) - if prot == spu_pb2.ProtocolKind.CHEETAH: + if prot == libspu.ProtocolKind.CHEETAH: # config.cheetah_2pc_config.enable_mul_lsb_error = True - # config.cheetah_2pc_config.ot_kind = spu_pb2.CheetahOtKind.YACL_Softspoken + # config.cheetah_2pc_config.ot_kind = libspu.CheetahOtKind.YACL_Softspoken pass # config.enable_hal_profile = True # config.enable_pphlo_profile = True @@ -82,10 +81,10 @@ def simple(cls, wsize: int, prot: spu_pb2.ProtocolKind, field: spu_pb2.FieldType # config.enable_type_checker = True return cls(wsize, config) - def __call__(self, executable, *flat_args): + def __call__(self, executable: libspu.ExecutableProto, *flat_args): flat_args = [np.array(jnp.array(x)) for x in flat_args] params = [ - self.io.make_shares(x, spu_pb2.Visibility.VIS_SECRET) for x in flat_args + self.io.make_shares(x, libspu.Visibility.VIS_SECRET) for x in flat_args ] lctx_desc = libspu.link.Desc() @@ -94,8 +93,7 @@ def __call__(self, executable, *flat_args): def wrapper(rank): lctx = libspu.link.create_mem(lctx_desc, rank) - rank_config = spu_pb2.RuntimeConfig() - rank_config.CopyFrom(self.rt_config) + rank_config = libspu.RuntimeConfig(self.rt_config) if rank != 0: # rank_config.enable_pphlo_trace = False rank_config.enable_action_trace = False @@ -129,12 +127,12 @@ def sim_jax( sim: Simulator, fun: Callable, static_argnums=(), - copts=spu_pb2.CompilerOptions(), + copts=libspu.CompilerOptions(), ): """ Decorates a jax numpy fn that simulated on SPU. - >>> sim = Simulator.simple(3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM64) + >>> sim = Simulator.simple(3, libspu.ProtocolKind.ABY3, libspu.FieldType.FM64) >>> spu_fn = sim_jax(sim, jnp.add) Then we can call spu_fn like normal jnp fn. @@ -163,7 +161,7 @@ def outputNameGen(out_flat): args, kwargs, in_names, - [spu_pb2.Visibility.VIS_SECRET] * len(args_flat), + [libspu.Visibility.VIS_SECRET] * len(args_flat), outputNameGen, static_argnums=static_argnums, copts=copts, diff --git a/src/MODULE.bazel b/src/MODULE.bazel index 9bb3576f3..aff04b8e8 100644 --- a/src/MODULE.bazel +++ b/src/MODULE.bazel @@ -21,7 +21,7 @@ module( name = "spulib", - version = "0.9.4.dev20250123", + version = "0.9.4.dev20250209", compatibility_level = 1, ) @@ -58,6 +58,7 @@ bazel_dep(name = "spdlog", version = "1.14.1") bazel_dep(name = "fmt", version = "11.0.2") bazel_dep(name = "abseil-cpp", version = "20240722.0") bazel_dep(name = "rules_python", version = "0.29.0") +bazel_dep(name = "magic_enum", version = "0.9.6") python = use_extension("@rules_python//python/extensions:python.bzl", "python") python.toolchain( diff --git a/src/MODULE.bazel.lock b/src/MODULE.bazel.lock index e8f6c363a..173a5fbab 100644 --- a/src/MODULE.bazel.lock +++ b/src/MODULE.bazel.lock @@ -117,6 +117,8 @@ "https://bcr.bazel.build/modules/jsoncpp/1.9.5/source.json": "4108ee5085dd2885a341c7fab149429db457b3169b86eb081fa245eadf69169d", "https://bcr.bazel.build/modules/libpfm/4.11.0/MODULE.bazel": "45061ff025b301940f1e30d2c16bea596c25b176c8b6b3087e92615adbd52902", "https://bcr.bazel.build/modules/libpfm/4.11.0/source.json": "caaffb3ac2b59b8aac456917a4ecf3167d40478ee79f15ab7a877ec9273937c9", + "https://bcr.bazel.build/modules/magic_enum/0.9.6/MODULE.bazel": "2b8db5bbd5d456dfb1e05cafd4a572374d461ffd2e0bd6970b9060dca2200618", + "https://bcr.bazel.build/modules/magic_enum/0.9.6/source.json": "abac9e9c84a47db89960a4c5a585d607bdfe51a60d7e3285dfe1a3dca50d1107", "https://bcr.bazel.build/modules/msgpack-c/6.1.0/MODULE.bazel": "2822ba864146468b3128216ad416f8b39b511395e88d896d472c9c6b30b1ceb2", "https://bcr.bazel.build/modules/msgpack-c/6.1.0/source.json": "b412dd4c8290ea0cce122616076e62ffe1b0799cebd6422608c407608193c1c9", "https://bcr.bazel.build/modules/nlohmann_json/3.11.3/MODULE.bazel": "87023db2f55fc3a9949c7b08dc711fae4d4be339a80a99d04453c4bb3998eefc", @@ -332,6 +334,7 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/libtommath/0.0.0-20240407-42b3fb0/source.json": "ffeac1c889d6d8eb4c6ed0d7243bfd298a94f0fb3a3670a31e88045b8c95fc02", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/llvm-raw/20240809.0-35f55f5/MODULE.bazel": "569ffa91a1497f0cbfc2a39b03b3155d2f7e7c360e2fdf5b31418868a1a9f1b8", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/llvm-raw/20240809.0-35f55f5/source.json": "eccb7418d37d4ffc5ccf7c5b737f051d11299b7510fc85895f4cd5f8b948a5a6", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/magic_enum/0.9.6/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/mcl/1.99/MODULE.bazel": "e2bf3654186853610a74833e398fc3b6de6d9ccbe8fa67eaa3ae58d3344940ef", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/mcl/1.99/source.json": "d38d4c7dbd9fb31bcabcc55c0336d82044828b27548c928389cdb6fba05029bd", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/msgpack-c/6.1.0/MODULE.bazel": "not found", @@ -446,7 +449,7 @@ "//bazel:defs.bzl%non_module_dependencies": { "general": { "bzlTransitiveDigest": "JT8ZLEUdrYXN19gijrHtztFq/cEAhJlRlNjhtQUlDIE=", - "usagesDigest": "6n60USwsGB5Z94VzGaJhZHMgjavixB/T5boxrw8IhDU=", + "usagesDigest": "De6xcaUoo/nee6n8unzzT9fZQjzw6vxabjE95MVezUY=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, "envVariables": {}, diff --git a/src/libspu/BUILD.bazel b/src/libspu/BUILD.bazel index 570e8d3c7..dfa0ba495 100644 --- a/src/libspu/BUILD.bazel +++ b/src/libspu/BUILD.bazel @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@protobuf//bazel:py_proto_library.bzl", "py_proto_library") load("@rules_cc//cc:defs.bzl", "cc_proto_library") load("@rules_proto//proto:defs.bzl", "proto_library") load("//bazel:spu.bzl", "spu_cc_library") @@ -29,9 +28,14 @@ cc_proto_library( deps = [":spu_proto"], ) -py_proto_library( - name = "spu_py_proto", - deps = ["//libspu:spu_proto"], +spu_cc_library( + name = "spu", + srcs = ["spu.cc"], + hdrs = ["spu.h"], + deps = [ + ":spu_cc_proto", + "@protobuf", + ], ) spu_cc_library( diff --git a/src/libspu/compiler/common/BUILD.bazel b/src/libspu/compiler/common/BUILD.bazel index 8d64134c3..e2669623d 100644 --- a/src/libspu/compiler/common/BUILD.bazel +++ b/src/libspu/compiler/common/BUILD.bazel @@ -33,7 +33,7 @@ spu_cc_library( hdrs = ["compilation_context.h"], deps = [ ":ir_printer_config", - "//libspu:spu_cc_proto", + "//libspu:spu", "//libspu/core:prelude", ], ) diff --git a/src/libspu/compiler/common/compilation_context.cc b/src/libspu/compiler/common/compilation_context.cc index 553fe9137..8593ddc88 100644 --- a/src/libspu/compiler/common/compilation_context.cc +++ b/src/libspu/compiler/common/compilation_context.cc @@ -32,9 +32,9 @@ namespace spu::compiler { CompilationContext::CompilationContext(CompilerOptions options) : options_(std::move(options)) { - if (options_.enable_pretty_print()) { + if (options_.enable_pretty_print) { pp_config_ = std::make_unique( - options_.pretty_print_dump_dir()); + options_.pretty_print_dump_dir); } // Set an error handler diff --git a/src/libspu/compiler/common/compilation_context.h b/src/libspu/compiler/common/compilation_context.h index 8abfed96b..d52532245 100644 --- a/src/libspu/compiler/common/compilation_context.h +++ b/src/libspu/compiler/common/compilation_context.h @@ -20,7 +20,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" -#include "libspu/spu.pb.h" +#include "libspu/spu.h" namespace mlir { class PassManager; @@ -41,10 +41,10 @@ class CompilationContext { const CompilerOptions &getCompilerOptions() const { return options_; } - bool hasPrettyPrintEnabled() const { return options_.enable_pretty_print(); } + bool hasPrettyPrintEnabled() const { return options_.enable_pretty_print; } XLAPrettyPrintKind getXlaPrettyPrintKind() const { - return options_.xla_pp_kind(); + return options_.xla_pp_kind; } std::filesystem::path getPrettyPrintDir() const; diff --git a/src/libspu/compiler/compile.h b/src/libspu/compiler/compile.h index d50f3c185..255210686 100644 --- a/src/libspu/compiler/compile.h +++ b/src/libspu/compiler/compile.h @@ -14,7 +14,7 @@ #pragma once -#include "libspu/spu.pb.h" +#include "libspu/spu.h" namespace spu::compiler { diff --git a/src/libspu/compiler/core/core.cc b/src/libspu/compiler/core/core.cc index b7c46d8f7..3fa91892d 100644 --- a/src/libspu/compiler/core/core.cc +++ b/src/libspu/compiler/core/core.cc @@ -45,7 +45,7 @@ void Core::buildPipeline(mlir::PassManager *pm) { // lowering auto &optPM = pm->nest(); - if (!options.disable_maxpooling_optimization()) { + if (!options.disable_maxpooling_optimization) { // Need a cse before maxpooling optPM.addPass(mlir::createCSEPass()); optPM.addPass(mlir::spu::pphlo::createOptimizeMaxPoolingPass()); @@ -53,7 +53,7 @@ void Core::buildPipeline(mlir::PassManager *pm) { optPM.addPass(mlir::spu::pphlo::createDecomposeOps()); optPM.addPass(mlir::spu::pphlo::createSortLowering()); - if (!options.disable_partial_sort_optimization()) { + if (!options.disable_partial_sort_optimization) { optPM.addPass(mlir::spu::pphlo::createPartialSortToTopK()); } @@ -61,17 +61,17 @@ void Core::buildPipeline(mlir::PassManager *pm) { optPM.addPass(mlir::spu::pphlo::createInlineSecretControlFlow()); - if (!options.disable_sqrt_plus_epsilon_rewrite()) { + if (!options.disable_sqrt_plus_epsilon_rewrite) { optPM.addPass(mlir::spu::pphlo::createOptimizeSqrtPlusEps()); } optPM.addPass(mlir::spu::pphlo::createExpandSecretGatherPass()); - if (!options.disable_div_sqrt_rewrite()) { + if (!options.disable_div_sqrt_rewrite) { optPM.addPass(mlir::spu::pphlo::createRewriteDivSqrtPatterns()); } - if (options.enable_optimize_denominator_with_broadcast()) { + if (options.enable_optimize_denominator_with_broadcast) { optPM.addPass(mlir::spu::pphlo::createOptimizeDenominatorWithBroadcast()); } @@ -79,17 +79,17 @@ void Core::buildPipeline(mlir::PassManager *pm) { optPM.addPass(mlir::spu::pphlo::createConvertPushDownPass()); - if (!options.disable_reduce_truncation_optimization()) { + if (!options.disable_reduce_truncation_optimization) { optPM.addPass(mlir::spu::pphlo::createReduceTruncationPass()); } - if (!options.disallow_mix_types_opts()) { + if (!options.disallow_mix_types_opts) { optPM.addPass(mlir::spu::pphlo::createLowerMixedTypeOpPass()); } optPM.addPass(mlir::createCanonicalizerPass()); - if (!options.disable_select_optimization()) { + if (!options.disable_select_optimization) { optPM.addPass(mlir::spu::pphlo::createOptimizeSelectPass()); } @@ -97,7 +97,7 @@ void Core::buildPipeline(mlir::PassManager *pm) { optPM.addPass(mlir::spu::pphlo::createRegionAccessFixture()); optPM.addPass(mlir::createCSEPass()); - if (!options.disable_deallocation_insertion()) { + if (!options.disable_deallocation_insertion) { optPM.addPass(mlir::spu::pphlo::createInsertDeallocationOp()); } } diff --git a/src/libspu/compiler/front_end/BUILD.bazel b/src/libspu/compiler/front_end/BUILD.bazel index 1cbca11b6..04ec731ce 100644 --- a/src/libspu/compiler/front_end/BUILD.bazel +++ b/src/libspu/compiler/front_end/BUILD.bazel @@ -76,6 +76,7 @@ spu_cc_library( "//libspu/dialect/pphlo/transforms:all_passes", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:Parser", + "@magic_enum", "@xla//xla/mlir_hlo:mhlo_passes", "@xla//xla/translate/mhlo_to_hlo:translate", ], diff --git a/src/libspu/compiler/front_end/fe.cc b/src/libspu/compiler/front_end/fe.cc index 7f9529e5a..20998f4d0 100644 --- a/src/libspu/compiler/front_end/fe.cc +++ b/src/libspu/compiler/front_end/fe.cc @@ -15,6 +15,7 @@ #include "libspu/compiler/front_end/fe.h" #include "fmt/ranges.h" +#include "magic_enum.hpp" #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" @@ -33,7 +34,6 @@ #include "libspu/core/prelude.h" #include "libspu/dialect/pphlo/IR/dialect.h" #include "libspu/dialect/pphlo/transforms/passes.h" - namespace spu::compiler { FE::FE(CompilationContext *ctx) : ctx_(ctx) { @@ -50,8 +50,8 @@ mlir::OwningOpRef FE::doit(const CompilationSource &source) { HloImporter importer(ctx_); mlir::OwningOpRef module; - if (source.ir_type() == spu::SourceIRType::STABLEHLO) { - module = mlir::parseSourceString(source.ir_txt(), + if (source.ir_type == spu::SourceIRType::STABLEHLO) { + module = mlir::parseSourceString(source.ir_txt, ctx_->getMLIRContext()); SPU_ENFORCE(module, "MLIR parser failure"); @@ -72,17 +72,17 @@ mlir::OwningOpRef FE::doit(const CompilationSource &source) { out.flush(); module = importer.parseXlaModuleFromString(xla_text); } - } else if (source.ir_type() == spu::SourceIRType::XLA) { - module = importer.parseXlaModuleFromString(source.ir_txt()); + } else if (source.ir_type == spu::SourceIRType::XLA) { + module = importer.parseXlaModuleFromString(source.ir_txt); } else { - SPU_THROW("Unhandled IR type = {}", source.ir_type()); + SPU_THROW("Unhandled IR type = {}", source.ir_type); } std::string input_vis_str; { std::vector input_vis; - for (const auto &v : source.input_visibility()) { - input_vis.emplace_back(Visibility_Name(v)); + for (const auto &v : source.input_visibility) { + input_vis.emplace_back(magic_enum::enum_name(v)); } input_vis_str = fmt::format("input_vis_list={}", fmt::join(input_vis, ",")); } diff --git a/src/libspu/compiler/front_end/fe.h b/src/libspu/compiler/front_end/fe.h index c520a124a..c80f230a4 100644 --- a/src/libspu/compiler/front_end/fe.h +++ b/src/libspu/compiler/front_end/fe.h @@ -18,7 +18,7 @@ #include "mlir/IR/OwningOpRef.h" -#include "libspu/spu.pb.h" +#include "libspu/spu.h" namespace mlir { diff --git a/src/libspu/compiler/tools/spu-translate.cc b/src/libspu/compiler/tools/spu-translate.cc index 899033405..7be89ea63 100644 --- a/src/libspu/compiler/tools/spu-translate.cc +++ b/src/libspu/compiler/tools/spu-translate.cc @@ -34,8 +34,7 @@ #include "libspu/dialect/utils/utils.h" #include "libspu/kernel/test_util.h" #include "libspu/mpc/utils/simulate.h" - -#include "libspu/spu.pb.h" +#include "libspu/spu.h" #define EXPOSE_PIPELINE_BUILDER #include "libspu/compiler/core/core.h" @@ -217,33 +216,33 @@ void evalModule(ModuleOp module) { runPasses(module); ::spu::RuntimeConfig conf; - conf.set_field(::spu::FM64); - conf.set_enable_type_checker(true); + conf.field = ::spu::FM64; + conf.enable_type_checker = true; int numParties = 1; switch (ProtocolKind.getValue()) { case 1: { - conf.set_protocol(::spu::REF2K); + conf.protocol = ::spu::REF2K; numParties = 1; break; } case 2: { - conf.set_protocol(::spu::SEMI2K); + conf.protocol = ::spu::SEMI2K; numParties = 2; break; } case 3: { - conf.set_protocol(::spu::ABY3); + conf.protocol = ::spu::ABY3; numParties = 3; break; } case 4: { - conf.set_protocol(::spu::CHEETAH); + conf.protocol = ::spu::CHEETAH; numParties = 2; break; } case 5: { - conf.set_protocol(::spu::SECURENN); + conf.protocol = ::spu::SECURENN; numParties = 3; break; } diff --git a/src/libspu/core/BUILD.bazel b/src/libspu/core/BUILD.bazel index b5c95acfd..b6c1907ad 100644 --- a/src/libspu/core/BUILD.bazel +++ b/src/libspu/core/BUILD.bazel @@ -40,7 +40,7 @@ spu_cc_library( srcs = [], hdrs = ["prelude.h"], deps = [ - "//libspu:spu_cc_proto", + "//libspu:spu", "@yacl//yacl/base:exception", "@yacl//yacl/utils:scope_guard", ], @@ -52,8 +52,9 @@ spu_cc_library( hdrs = ["type_util.h"], deps = [ ":half", - "//libspu:spu_cc_proto", + "//libspu:spu", "//libspu/core:prelude", + "@magic_enum", "@yacl//yacl/base:int128", ], ) @@ -75,7 +76,7 @@ spu_cc_library( srcs = ["config.cc"], hdrs = ["config.h"], deps = [ - "//libspu:spu_cc_proto", + "//libspu:spu", "//libspu/core:prelude", "@yacl//yacl/utils:parallel", ], @@ -116,6 +117,7 @@ spu_cc_library( deps = [ ":type_util", "//libspu/core:prelude", + "@magic_enum", ], ) diff --git a/src/libspu/core/config.cc b/src/libspu/core/config.cc index c40c4c3e1..20ff5d957 100644 --- a/src/libspu/core/config.cc +++ b/src/libspu/core/config.cc @@ -41,78 +41,77 @@ size_t defaultFxpBits(FieldType field) { void populateRuntimeConfig(RuntimeConfig& cfg) { // mandatory fields. - SPU_ENFORCE(cfg.protocol() != ProtocolKind::PROT_INVALID); - SPU_ENFORCE(cfg.field() != FieldType::FT_INVALID); + SPU_ENFORCE(cfg.protocol != ProtocolKind::PROT_INVALID); + SPU_ENFORCE(cfg.field != FieldType::FT_INVALID); - if (cfg.max_concurrency() == 0) { - cfg.set_max_concurrency(yacl::get_num_threads()); + if (cfg.max_concurrency == 0) { + cfg.max_concurrency = yacl::get_num_threads(); } // - if (cfg.fxp_fraction_bits() == 0) { - cfg.set_fxp_fraction_bits(defaultFxpBits(cfg.field())); + if (cfg.fxp_fraction_bits == 0) { + cfg.fxp_fraction_bits = defaultFxpBits(cfg.field); } - if (cfg.fxp_div_goldschmidt_iters() == 0) { - cfg.set_fxp_div_goldschmidt_iters(2); + if (cfg.fxp_div_goldschmidt_iters == 0) { + cfg.fxp_div_goldschmidt_iters = 2; } // sort - if (cfg.quick_sort_threshold() == 0) { - cfg.set_quick_sort_threshold(32); + if (cfg.quick_sort_threshold == 0) { + cfg.quick_sort_threshold = 32; } // fxp exponent config { - if (cfg.fxp_exp_mode() == RuntimeConfig::EXP_DEFAULT) { - cfg.set_fxp_exp_mode(RuntimeConfig::EXP_TAYLOR); + if (cfg.fxp_exp_mode == RuntimeConfig::EXP_DEFAULT) { + cfg.fxp_exp_mode = RuntimeConfig::EXP_TAYLOR; } - if (cfg.fxp_exp_mode() == RuntimeConfig::EXP_PRIME) { + if (cfg.fxp_exp_mode == RuntimeConfig::EXP_PRIME) { // 0 offset is not supported - if (cfg.experimental_exp_prime_offset() == 0) { + if (cfg.experimental_exp_prime_offset == 0) { // For FM128 default offset is 13 - if (cfg.field() == FieldType::FM128) { - cfg.set_experimental_exp_prime_offset(13); + if (cfg.field == FieldType::FM128) { + cfg.experimental_exp_prime_offset = 13; } // TODO: set defaults for other fields, currently only FM128 is // supported } } - if (cfg.fxp_exp_iters() == 0) { - cfg.set_fxp_exp_iters(8); + if (cfg.fxp_exp_iters == 0) { + cfg.fxp_exp_iters = 8; } } // fxp log config { - if (cfg.fxp_log_mode() == RuntimeConfig::LOG_DEFAULT) { - cfg.set_fxp_log_mode(RuntimeConfig::LOG_MINMAX); + if (cfg.fxp_log_mode == RuntimeConfig::LOG_DEFAULT) { + cfg.fxp_log_mode = RuntimeConfig::LOG_MINMAX; } - if (cfg.fxp_log_iters() == 0) { - cfg.set_fxp_log_iters(3); + if (cfg.fxp_log_iters == 0) { + cfg.fxp_log_iters = 3; } - if (cfg.fxp_log_orders() == 0) { - cfg.set_fxp_log_orders(8); + if (cfg.fxp_log_orders == 0) { + cfg.fxp_log_orders = 8; } } - if (cfg.sine_cosine_iters() == 0) { - cfg.set_sine_cosine_iters(10); // Default + if (cfg.sine_cosine_iters == 0) { + cfg.sine_cosine_iters = 10; // Default } // inter op concurrency - if (cfg.experimental_enable_inter_op_par()) { - cfg.set_experimental_inter_op_concurrency( - cfg.experimental_inter_op_concurrency() == 0 - ? 8 - : cfg.experimental_inter_op_concurrency()); + if (cfg.experimental_enable_inter_op_par) { + if (cfg.experimental_inter_op_concurrency == 0) { + cfg.experimental_inter_op_concurrency = 8; + } } - if (cfg.sigmoid_mode() == RuntimeConfig::SIGMOID_DEFAULT) { - cfg.set_sigmoid_mode(RuntimeConfig::SIGMOID_REAL); + if (cfg.sigmoid_mode == RuntimeConfig::SIGMOID_DEFAULT) { + cfg.sigmoid_mode = RuntimeConfig::SIGMOID_REAL; } // MPC related configurations diff --git a/src/libspu/core/config.h b/src/libspu/core/config.h index c1d21a7cb..1eb8b94bc 100644 --- a/src/libspu/core/config.h +++ b/src/libspu/core/config.h @@ -14,7 +14,7 @@ #pragma once -#include "libspu/spu.pb.h" +#include "libspu/spu.h" namespace spu { diff --git a/src/libspu/core/context.cc b/src/libspu/core/context.cc index 8243fa6b4..abdde9a5e 100644 --- a/src/libspu/core/context.cc +++ b/src/libspu/core/context.cc @@ -17,6 +17,7 @@ #include "yacl/link/algorithm/allgather.h" #include "yacl/utils/parallel.h" +#include "libspu/core/config.h" #include "libspu/core/trace.h" namespace spu { @@ -40,11 +41,12 @@ SPUContext::SPUContext(const RuntimeConfig& config, prot_(std::make_unique(genRootObjectId(lctx))), lctx_(lctx), max_cluster_level_concurrency_(yacl::get_num_threads()) { + populateRuntimeConfig(config_); // Limit number of threads - if (config.max_concurrency() > 0) { - yacl::set_num_threads(config.max_concurrency()); + if (config.max_concurrency > 0) { + yacl::set_num_threads(config.max_concurrency); max_cluster_level_concurrency_ = std::min( - max_cluster_level_concurrency_, config.max_concurrency()); + max_cluster_level_concurrency_, config.max_concurrency); } if (lctx_) { @@ -70,17 +72,17 @@ std::unique_ptr SPUContext::fork() const { void setupTrace(spu::SPUContext* sctx, const spu::RuntimeConfig& rt_config) { int64_t tr_flag = 0; // TODO: Support tracing for parallel op execution - if (rt_config.enable_action_trace() && - !rt_config.experimental_enable_intra_op_par()) { + if (rt_config.enable_action_trace && + !rt_config.experimental_enable_intra_op_par) { tr_flag |= TR_LOG; } - if (rt_config.enable_pphlo_profile()) { + if (rt_config.enable_pphlo_profile) { tr_flag |= TR_HLO; tr_flag |= TR_REC; } - if (rt_config.enable_hal_profile()) { + if (rt_config.enable_hal_profile) { tr_flag |= TR_HAL | TR_MPC; tr_flag |= TR_REC; } diff --git a/src/libspu/core/context.h b/src/libspu/core/context.h index e174a872b..b9aef30aa 100644 --- a/src/libspu/core/context.h +++ b/src/libspu/core/context.h @@ -23,8 +23,7 @@ #include "libspu/core/object.h" #include "libspu/core/prelude.h" #include "libspu/core/value.h" - -#include "libspu/spu.pb.h" +#include "libspu/spu.h" namespace spu { @@ -57,13 +56,13 @@ class SPUContext final { // Return current working fixed point fractional bits. size_t getFxpBits() const { - const auto fbits = config_.fxp_fraction_bits(); + const auto fbits = config_.fxp_fraction_bits; SPU_ENFORCE(fbits != 0); return fbits; } // Return current working field of MPC engine. - FieldType getField() const { return config_.field(); } + FieldType getField() const { return config_.field; } // Return current working runtime config. const RuntimeConfig& config() const { return config_; } diff --git a/src/libspu/core/ndarray_ref.cc b/src/libspu/core/ndarray_ref.cc index 2a8c6ed58..7fc7bc482 100644 --- a/src/libspu/core/ndarray_ref.cc +++ b/src/libspu/core/ndarray_ref.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include namespace spu { diff --git a/src/libspu/core/prelude.h b/src/libspu/core/prelude.h index cfb6ce1ce..cc7fe8c7e 100644 --- a/src/libspu/core/prelude.h +++ b/src/libspu/core/prelude.h @@ -58,7 +58,7 @@ // Format #include "fmt/ostream.h" -#include "libspu/spu.pb.h" +#include "libspu/spu.h" namespace fmt { @@ -78,21 +78,21 @@ template <> struct formatter : ostream_formatter {}; template <> -struct formatter : ostream_formatter {}; +struct formatter : ostream_formatter {}; template <> -struct formatter : ostream_formatter {}; +struct formatter : ostream_formatter {}; template <> -struct formatter : ostream_formatter {}; +struct formatter : ostream_formatter {}; template <> -struct formatter : ostream_formatter {}; +struct formatter : ostream_formatter {}; template <> struct formatter : ostream_formatter {}; template <> -struct formatter : ostream_formatter {}; +struct formatter : ostream_formatter {}; } // namespace fmt diff --git a/src/libspu/core/pt_buffer_view.h b/src/libspu/core/pt_buffer_view.h index cc5df06e6..809249b91 100644 --- a/src/libspu/core/pt_buffer_view.h +++ b/src/libspu/core/pt_buffer_view.h @@ -21,8 +21,7 @@ #include "libspu/core/ndarray_ref.h" #include "libspu/core/prelude.h" #include "libspu/core/shape.h" - -#include "libspu/spu.pb.h" // PtType +#include "libspu/spu.h" // PtType namespace spu { namespace detail { diff --git a/src/libspu/core/type.cc b/src/libspu/core/type.cc index ae2767742..430811974 100644 --- a/src/libspu/core/type.cc +++ b/src/libspu/core/type.cc @@ -16,6 +16,8 @@ #include +#include "magic_enum.hpp" + namespace spu { Type::Type() @@ -103,4 +105,40 @@ Type Type::fromString(std::string_view repr) { return Type(fctor(details.substr(0, details.length() - 1))); } +void PtTy::fromString(std::string_view detail) { + auto pt_type = magic_enum::enum_cast(detail); + SPU_ENFORCE(pt_type.has_value(), "parse failed from={}", detail); + pt_type_ = pt_type.value(); +} + +std::string PtTy::toString() const { + return std::string(magic_enum::enum_name(pt_type_)); +} + +void RingTy::fromString(std::string_view detail) { + auto field = magic_enum::enum_cast(detail); + SPU_ENFORCE(field.has_value(), "parse failed from={}", detail); + field_ = field.value(); +} + +std::string RingTy::toString() const { + return std::string(magic_enum::enum_name(field())); +} + +void GfmpTy::fromString(std::string_view detail) { + auto comma = detail.find_first_of(','); + auto field_str = detail.substr(0, comma); + auto mp_exp_str = detail.substr(comma + 1); + auto field = magic_enum::enum_cast(field_str); + SPU_ENFORCE(field.has_value(), "parse failed from={}", detail); + field_ = field.value(); + mersenne_prime_exp_ = std::stoul(std::string(mp_exp_str)); + prime_ = (static_cast(1) << mersenne_prime_exp_) - 1; +} + +std::string GfmpTy::toString() const { + return fmt::format("{},{}", magic_enum::enum_name(field()), + mersenne_prime_exp_); +} + } // namespace spu diff --git a/src/libspu/core/type.h b/src/libspu/core/type.h index 2bcf68190..b0cb1a1ef 100644 --- a/src/libspu/core/type.h +++ b/src/libspu/core/type.h @@ -327,12 +327,8 @@ class PtTy : public TypeImpl { size_t size() const override { return SizeOf(pt_type_); } - std::string toString() const override { return PtType_Name(pt_type_); } - - void fromString(std::string_view detail) override { - SPU_ENFORCE(PtType_Parse(std::string(detail), &pt_type_), - "parse failed from={}", detail); - } + void fromString(std::string_view detail) override; + std::string toString() const override; }; inline Type makePtType(PtType etype) { return makeType(etype); } @@ -380,12 +376,8 @@ class RingTy : public TypeImpl { return SizeOf(GetStorageType(field_)); } - void fromString(std::string_view detail) override { - SPU_ENFORCE(FieldType_Parse(std::string(detail), &field_), - "parse failed from={}", detail); - }; - - std::string toString() const override { return FieldType_Name(field()); } + void fromString(std::string_view detail) override; + std::string toString() const override; bool equals(TypeObject const* other) const override { auto const* derived_other = dynamic_cast(other); @@ -420,19 +412,8 @@ class GfmpTy : public TypeImpl { size_t mp_exp() const { return mersenne_prime_exp_; } - void fromString(std::string_view detail) override { - auto comma = detail.find_first_of(','); - auto field_str = detail.substr(0, comma); - auto mp_exp_str = detail.substr(comma + 1); - SPU_ENFORCE(FieldType_Parse(std::string(field_str), &field_), - "parse failed from={}", detail); - mersenne_prime_exp_ = std::stoul(std::string(mp_exp_str)); - prime_ = (static_cast(1) << mersenne_prime_exp_) - 1; - } - - std::string toString() const override { - return fmt::format("{},{}", FieldType_Name(field()), mersenne_prime_exp_); - } + void fromString(std::string_view detail) override; + std::string toString() const override; bool equals(TypeObject const* other) const override { auto const* derived_other = dynamic_cast(other); diff --git a/src/libspu/core/type_util.cc b/src/libspu/core/type_util.cc index 7eac207e9..cda1ced71 100644 --- a/src/libspu/core/type_util.cc +++ b/src/libspu/core/type_util.cc @@ -14,6 +14,7 @@ #include "libspu/core/type_util.h" +#include "magic_enum.hpp" namespace spu { ////////////////////////////////////////////////////////////// @@ -95,7 +96,7 @@ std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // Plaintext types. ////////////////////////////////////////////////////////////// std::ostream& operator<<(std::ostream& os, const PtType& pt_type) { - os << PtType_Name(pt_type); + os << magic_enum::enum_name(pt_type); return os; } @@ -118,7 +119,7 @@ size_t SizeOf(PtType ptt) { // ProtocolKind utils ////////////////////////////////////////////////////////////// std::ostream& operator<<(std::ostream& os, ProtocolKind protocol) { - os << ProtocolKind_Name(protocol); + os << magic_enum::enum_name(protocol); return os; } @@ -142,7 +143,7 @@ size_t GetMersennePrimeExp(FieldType field) { // Field 2k types, TODO(jint) support Zq ////////////////////////////////////////////////////////////// std::ostream& operator<<(std::ostream& os, FieldType field) { - os << FieldType_Name(field); + os << magic_enum::enum_name(field); return os; } diff --git a/src/libspu/core/type_util.h b/src/libspu/core/type_util.h index 84b04f453..9cf2bc270 100644 --- a/src/libspu/core/type_util.h +++ b/src/libspu/core/type_util.h @@ -23,8 +23,7 @@ #include "libspu/core/half.h" #include "libspu/core/prelude.h" - -#include "libspu/spu.pb.h" +#include "libspu/spu.h" namespace spu { diff --git a/src/libspu/core/value.cc b/src/libspu/core/value.cc index cc327985b..86ce8415b 100644 --- a/src/libspu/core/value.cc +++ b/src/libspu/core/value.cc @@ -85,7 +85,7 @@ ValueProto Value::toProto(size_t max_chunk_size) const { size_t chunk_size = std::min(max_chunk_size, size - i * max_chunk_size); size_t offset = i * max_chunk_size; - ValueChunkProto chunk; + pb::ValueChunkProto chunk; chunk.set_total_bytes(size); chunk.set_chunk_offset(offset); if (chunk_size > 0) { @@ -113,18 +113,19 @@ ValueProto Value::toProto(size_t max_chunk_size) const { if (imag_) { array_to_chunks(*imag_); } + ret.meta.CopyFrom(toMetaProto()); return ret; } -ValueMetaProto Value::toMetaProto() const { +pb::ValueMetaProto Value::toMetaProto() const { SPU_ENFORCE(dtype_ != DT_INVALID && vtype() != VIS_INVALID); - ValueMetaProto proto; - proto.set_data_type(dtype_); + pb::ValueMetaProto proto; + proto.set_data_type(::spu::pb::DataType(dtype_)); proto.set_is_complex(isComplex()); - proto.set_visibility(vtype()); + proto.set_visibility(::spu::pb::Visibility(vtype())); for (const auto& d : shape()) { proto.mutable_shape()->add_dims(d); } @@ -136,7 +137,7 @@ Value Value::fromProto(const ValueProto& value) { const auto& meta = value.meta; if (meta.is_complex()) { // real - ValueMetaProto partial = value.meta; + pb::ValueMetaProto partial = value.meta; partial.set_is_complex(false); ValueProto partial_proto; partial_proto.meta = partial; @@ -154,12 +155,14 @@ Value Value::fromProto(const ValueProto& value) { const auto eltype = Type::fromString(meta.storage_type()); - SPU_ENFORCE(meta.data_type() != DT_INVALID, "invalid data type={}", - meta.data_type()); + auto data_type = DataType(meta.data_type()); + SPU_ENFORCE(data_type != DataType::DT_INVALID, "invalid data type={}", + data_type); // vtype is deduced from storage_type. - SPU_ENFORCE(meta.visibility() == getVisibilityFromType(eltype), - "visibility {} does not match storage_type {}", meta.visibility(), + auto visibility = Visibility(meta.visibility()); + SPU_ENFORCE(visibility == getVisibilityFromType(eltype), + "visibility {} does not match storage_type {}", visibility, eltype); Shape shape(meta.shape().dims().begin(), meta.shape().dims().end()); @@ -167,7 +170,7 @@ Value Value::fromProto(const ValueProto& value) { const auto& chunks = value.chunks; const size_t total_bytes = chunks.empty() ? 0 : chunks[0].total_bytes(); - std::map ordered_chunks; + std::map ordered_chunks; for (const auto& s : chunks) { SPU_ENFORCE(ordered_chunks.insert({s.chunk_offset(), &s}).second, "Repeated chunk_offset {} found", s.chunk_offset()); @@ -187,7 +190,7 @@ Value Value::fromProto(const ValueProto& value) { SPU_ENFORCE(total_bytes == chunk_end_pos); - return Value(data, meta.data_type()); + return Value(data, spu::DataType(meta.data_type())); } Value Value::clone() const { diff --git a/src/libspu/core/value.h b/src/libspu/core/value.h index 1eccd384d..1f7b45876 100644 --- a/src/libspu/core/value.h +++ b/src/libspu/core/value.h @@ -24,6 +24,8 @@ #include "libspu/core/type_util.h" #include "libspu/core/vectorize.h" +#include "libspu/spu.pb.h" + namespace spu { // In order to prevent a single protobuf from being larger than 2gb, a spu @@ -31,8 +33,8 @@ namespace spu { // std::vector is used to organize multiple chunks instead of repeated in // protobuf. struct ValueProto { - ValueMetaProto meta; - std::vector chunks; + pb::ValueMetaProto meta; + std::vector chunks; }; class Value final { @@ -90,7 +92,7 @@ class Value final { // Serialize to protobuf. ValueProto toProto(size_t max_chunk_size) const; size_t chunksCount(size_t max_chunk_size) const; - ValueMetaProto toMetaProto() const; + pb::ValueMetaProto toMetaProto() const; // Deserialize from protobuf. static Value fromProto(const ValueProto& value); diff --git a/src/libspu/device/BUILD.bazel b/src/libspu/device/BUILD.bazel index a821980c9..839cf43d8 100644 --- a/src/libspu/device/BUILD.bazel +++ b/src/libspu/device/BUILD.bazel @@ -30,7 +30,7 @@ spu_cc_library( hdrs = ["io.h"], deps = [ ":symbol_table", - "//libspu:spu_cc_proto", + "//libspu:spu", "//libspu/core:context", "//libspu/core:pt_buffer_view", "//libspu/core:value", @@ -55,7 +55,6 @@ spu_cc_library( deps = [ ":intrinsic_table", ":symbol_table", - "//libspu:spu_cc_proto", "//libspu/core:context", "//libspu/core:value", "//libspu/dialect/pphlo/IR:dialect", @@ -85,7 +84,7 @@ spu_cc_library( deps = [ ":io", ":symbol_table", - "//libspu:spu_cc_proto", + "//libspu:spu", "//libspu/core:ndarray_ref", "//libspu/mpc/utils:simulate", ], diff --git a/src/libspu/device/api.cc b/src/libspu/device/api.cc index bae9b462e..0abcaa92f 100644 --- a/src/libspu/device/api.cc +++ b/src/libspu/device/api.cc @@ -124,12 +124,12 @@ struct ActionStats { void takeSnapshot(size_t rank, const RuntimeConfig &rt_config, const ExecutableProto &executable, const SymbolTable &env) { - const std::string &dump_dir = rt_config.snapshot_dump_dir(); + const std::string &dump_dir = rt_config.snapshot_dump_dir; // Naming convention for dumped files must align with debug runner. std::filesystem::path dump_folder(dump_dir); std::filesystem::create_directories(dump_folder); - // Dump executable + // Dump config { std::ofstream config_file(getConfigFilePath(dump_folder), std::ios::binary | std::ios::out); @@ -270,15 +270,15 @@ void executeImpl(OpExecutor *executor, spu::SPUContext *sctx, std::vector inputs; { TimeitGuard timeit(exec_stats.infeed_time); - inputs.reserve(executable.input_names_size()); - for (int32_t idx = 0; idx < executable.input_names_size(); idx++) { - inputs.emplace_back(env->getVar(executable.input_names(idx))); + inputs.reserve(executable.input_names.size()); + for (size_t idx = 0; idx < executable.input_names.size(); idx++) { + inputs.emplace_back(env->getVar(executable.input_names[idx])); } } const RuntimeConfig rt_config = sctx->config(); - if (rt_config.enable_runtime_snapshot()) { + if (rt_config.enable_runtime_snapshot) { const bool isRefHal = sctx->lctx() == nullptr; const size_t rank = isRefHal ? 0 : sctx->lctx()->Rank(); takeSnapshot(rank, rt_config, executable, *env); @@ -298,7 +298,7 @@ void executeImpl(OpExecutor *executor, spu::SPUContext *sctx, [&](mlir::Diagnostic &diag) { SPDLOG_ERROR(diag.str()); }); auto moduleOpRef = - mlir::parseSourceString(executable.code(), &mlir_ctx); + mlir::parseSourceString(executable.code, &mlir_ctx); SPU_ENFORCE(moduleOpRef, "MLIR parser failure"); @@ -322,11 +322,11 @@ void executeImpl(OpExecutor *executor, spu::SPUContext *sctx, SPU_ENFORCE(entry_function, "main module not found"); ExecutionOptions opts; - opts.do_type_check = rt_config.enable_type_checker(); - opts.do_log_execution = rt_config.enable_pphlo_trace(); - opts.do_parallel = rt_config.experimental_enable_inter_op_par(); + opts.do_type_check = rt_config.enable_type_checker; + opts.do_log_execution = rt_config.enable_pphlo_trace; + opts.do_parallel = rt_config.experimental_enable_inter_op_par; if (opts.do_parallel) { - opts.concurrency = rt_config.experimental_inter_op_concurrency(); + opts.concurrency = rt_config.experimental_inter_op_concurrency; mlir_ctx.enableMultithreading(); mlir_ctx.enterMultiThreadedExecution(); } @@ -341,14 +341,14 @@ void executeImpl(OpExecutor *executor, spu::SPUContext *sctx, // sync output to environment. { TimeitGuard timeit(exec_stats.outfeed_time); - for (int32_t idx = 0; idx < executable.output_names_size(); idx++) { - env->setVar(executable.output_names(idx), outputs[idx]); + for (size_t idx = 0; idx < executable.output_names.size(); idx++) { + env->setVar(executable.output_names[idx], outputs[idx]); } } comm_stats.diff(sctx->lctx()); if ((getGlobalTraceFlag(sctx->id()) & TR_REC) != 0) { - printProfilingData(sctx, executable.name(), exec_stats, comm_stats); + printProfilingData(sctx, executable.name, exec_stats, comm_stats); } } @@ -361,12 +361,7 @@ void execute(OpExecutor *executor, spu::SPUContext *sctx, const std::string &text, const std::vector &input_names, const std::vector &output_names, SymbolTable *env) { - ExecutableProto executable; - executable.set_name("unnamed"); - *executable.mutable_input_names() = {input_names.begin(), input_names.end()}; - *executable.mutable_output_names() = {output_names.begin(), - output_names.end()}; - executable.set_code(text); + ExecutableProto executable("unnamed", input_names, output_names, text); return executeImpl(executor, sctx, executable, env); } diff --git a/src/libspu/device/api.h b/src/libspu/device/api.h index 215390aab..fac39e0c5 100644 --- a/src/libspu/device/api.h +++ b/src/libspu/device/api.h @@ -21,11 +21,9 @@ #include "libspu/core/value.h" #include "libspu/device/executor.h" #include "libspu/device/symbol_table.h" - -#include "libspu/spu.pb.h" +#include "libspu/spu.h" namespace spu::device { - void execute(OpExecutor *executor, SPUContext *sctx, const ExecutableProto &executable, SymbolTable *env); @@ -34,5 +32,4 @@ void execute(OpExecutor *executor, spu::SPUContext *sctx, const std::string &text, const std::vector &input_names, const std::vector &output_names, SymbolTable *env); - } // namespace spu::device diff --git a/src/libspu/device/io.cc b/src/libspu/device/io.cc index 955746d5e..a9f0def1f 100644 --- a/src/libspu/device/io.cc +++ b/src/libspu/device/io.cc @@ -43,7 +43,7 @@ size_t IoClient::getShareSize(const PtBufferView &bv, Visibility vtype, std::vector IoClient::makeShares(const PtBufferView &bv, Visibility vtype, int owner_rank) { - const size_t fxp_bits = config_.fxp_fraction_bits(); + const size_t fxp_bits = config_.fxp_fraction_bits; SPU_ENFORCE(fxp_bits != 0, "fxp should never be zero, please check default"); if (bv.pt_type == PT_I1 && vtype == VIS_SECRET && @@ -86,10 +86,10 @@ std::vector IoClient::makeShares(const PtBufferView &bv, // encode to ring. DataType dtype; - NdArrayRef encoded = encodeToRing(bv, config_.field(), fxp_bits, &dtype); + NdArrayRef encoded = encodeToRing(bv, config_.field, fxp_bits, &dtype); // make shares. - if (!config_.experimental_enable_colocated_optimization()) { + if (!config_.experimental_enable_colocated_optimization) { owner_rank = -1; } std::vector shares = @@ -154,7 +154,7 @@ void IoClient::combineShares(absl::Span values, return; } - const size_t fxp_bits = config_.fxp_fraction_bits(); + const size_t fxp_bits = config_.fxp_fraction_bits; SPU_ENFORCE(fxp_bits != 0, "fxp should never be zero, please check default"); // reconstruct to ring buffer. diff --git a/src/libspu/device/io_test.cc b/src/libspu/device/io_test.cc index b67642e47..0c38e5b78 100644 --- a/src/libspu/device/io_test.cc +++ b/src/libspu/device/io_test.cc @@ -30,8 +30,8 @@ TEST_P(IoClientTest, Float) { const Visibility kVisibility = std::get<3>(GetParam()); RuntimeConfig hconf; - hconf.set_protocol(std::get<1>(GetParam())); - hconf.set_field(std::get<2>(GetParam())); + hconf.protocol = std::get<1>(GetParam()); + hconf.field = std::get<2>(GetParam()); IoClient io(kWorldSize, hconf); xt::xarray in_data({{1, -2, 3, 0}}); @@ -52,8 +52,8 @@ TEST_P(IoClientTest, Int) { const Visibility kVisibility = std::get<3>(GetParam()); RuntimeConfig hconf; - hconf.set_protocol(std::get<1>(GetParam())); - hconf.set_field(std::get<2>(GetParam())); + hconf.protocol = std::get<1>(GetParam()); + hconf.field = std::get<2>(GetParam()); IoClient io(kWorldSize, hconf); xt::xarray in_data({{1, -2, 3, 0}}); @@ -103,8 +103,8 @@ TEST_P(ColocatedIoTest, Works) { const Visibility kVisibility = std::get<3>(GetParam()); RuntimeConfig hconf; - hconf.set_protocol(std::get<1>(GetParam())); - hconf.set_field(std::get<2>(GetParam())); + hconf.protocol = std::get<1>(GetParam()); + hconf.field = std::get<2>(GetParam()); mpc::utils::simulate(kWorldSize, [&](auto lctx) { SPUContext sctx(hconf, lctx); @@ -135,9 +135,9 @@ TEST(ColocatedIoTest, PrivateWorks) { const size_t kWorldSize = 2; RuntimeConfig hconf; - hconf.set_protocol(ProtocolKind::SEMI2K); - hconf.set_field(FieldType::FM64); - hconf.set_experimental_enable_colocated_optimization(true); + hconf.protocol = ProtocolKind::SEMI2K; + hconf.field = FieldType::FM64; + hconf.experimental_enable_colocated_optimization = true; mpc::utils::simulate(kWorldSize, [&](auto lctx) { SPUContext sctx(hconf, lctx); diff --git a/src/libspu/device/pphlo/pphlo_executor_test.cc b/src/libspu/device/pphlo/pphlo_executor_test.cc index ebc837632..50f557270 100644 --- a/src/libspu/device/pphlo/pphlo_executor_test.cc +++ b/src/libspu/device/pphlo/pphlo_executor_test.cc @@ -48,7 +48,7 @@ TEST_P(ExecutorTest, BoolSplatConstant) { Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()), std::get<2>(GetParam())); - r.getConfig().set_enable_type_checker(false); + r.getConfig().enable_type_checker = false; r.run(R"( func.func @main() -> (tensor) { @@ -80,7 +80,7 @@ TEST_P(ExecutorTest, BoolConstant) { Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()), std::get<2>(GetParam())); - r.getConfig().set_enable_type_checker(false); + r.getConfig().enable_type_checker = false; r.run(R"( func.func @main() -> (tensor<2xi32>) { @@ -99,7 +99,7 @@ TEST_P(ExecutorTest, ComplexConstant) { Runner r(std::get<0>(GetParam()), std::get<1>(GetParam()), std::get<2>(GetParam())); - r.getConfig().set_enable_type_checker(false); + r.getConfig().enable_type_checker = false; r.run(R"( func.func @main() -> (tensor<2xcomplex>) { diff --git a/src/libspu/device/pphlo/pphlo_intrinsic_executor.cc b/src/libspu/device/pphlo/pphlo_intrinsic_executor.cc index 7bfafc3b3..687804194 100644 --- a/src/libspu/device/pphlo/pphlo_intrinsic_executor.cc +++ b/src/libspu/device/pphlo/pphlo_intrinsic_executor.cc @@ -136,7 +136,7 @@ std::vector intrinsic_dispatcher(SPUContext* ctx, } if (name == PREFER_A) { - if (ctx->config().protocol() == ProtocolKind::CHEETAH) { + if (ctx->config().protocol == ProtocolKind::CHEETAH) { // NOTE(juhou): For 2PC, MulAB uses COT which is efficient and accurate // than MulAA that needs HE. Thus we just by-pass the PreferAOp for 2PC. return {inputs[0]}; diff --git a/src/libspu/device/pphlo/pphlo_verifier_test.cc b/src/libspu/device/pphlo/pphlo_verifier_test.cc index 8efd4f312..b41c7f09a 100644 --- a/src/libspu/device/pphlo/pphlo_verifier_test.cc +++ b/src/libspu/device/pphlo/pphlo_verifier_test.cc @@ -38,8 +38,8 @@ void runner(const OpFcn &f, absl::Span> inputs, absl::Span> positives, absl::Span> negatives) { RuntimeConfig conf; - conf.set_field(FM64); - conf.set_protocol(SEMI2K); + conf.protocol = SEMI2K; + conf.field = FM64; std::unique_ptr io_ = std::make_unique(2, conf); for (size_t idx = 0; idx < inputs.size(); ++idx) { io_->InFeed(fmt::format("in{}", idx), inputs[idx], VIS_SECRET); @@ -261,8 +261,8 @@ TEST(Verify, Greater) { TEST(Verify, Select) { RuntimeConfig conf; - conf.set_field(FM64); - conf.set_protocol(SEMI2K); + conf.protocol = SEMI2K; + conf.field = FM64; std::unique_ptr io_ = std::make_unique(2, conf); io_->InFeed("in0", xt::xarray{false, true, true, false}, VIS_SECRET); io_->InFeed("in1", xt::xarray{5, 2, 7, 4}, VIS_SECRET); diff --git a/src/libspu/device/test_utils.h b/src/libspu/device/test_utils.h index 2e8e852b4..120eb77fc 100644 --- a/src/libspu/device/test_utils.h +++ b/src/libspu/device/test_utils.h @@ -19,8 +19,7 @@ #include "libspu/core/config.h" #include "libspu/device/io.h" #include "libspu/device/symbol_table.h" - -#include "libspu/spu.pb.h" +#include "libspu/spu.h" namespace spu::device { diff --git a/src/libspu/device/utils/pphlo_executor_debug_runner.cc b/src/libspu/device/utils/pphlo_executor_debug_runner.cc index f06e13afe..d778e430a 100644 --- a/src/libspu/device/utils/pphlo_executor_debug_runner.cc +++ b/src/libspu/device/utils/pphlo_executor_debug_runner.cc @@ -29,6 +29,8 @@ #include "libspu/mpc/factory.h" #include "libspu/mpc/utils/simulate.h" +#include "libspu/spu.pb.h" + llvm::cl::opt SnapshotDir( "snapshot_dir", llvm::cl::desc("folder contains core snapshot files"), llvm::cl::init(".")); @@ -80,10 +82,10 @@ spu::RuntimeConfig parseRuntimeConfig( SPDLOG_INFO("Read config file from {}", config_file.c_str()); std::ifstream stream(config_file, std::ios::binary); - spu::RuntimeConfig config; + spu::pb::RuntimeConfig config; SPU_ENFORCE(config.ParseFromIstream(&stream), "Parse serialized config file {} failed", config_file.c_str()); - return config; + return spu::RuntimeConfig(config); } spu::ExecutableProto parseExecutable( @@ -94,10 +96,15 @@ spu::ExecutableProto parseExecutable( SPDLOG_INFO("Read config file from {}", code_file.c_str()); std::ifstream stream(code_file, std::ios::binary); - spu::ExecutableProto code; + spu::pb::ExecutableProto code; SPU_ENFORCE(code.ParseFromIstream(&stream), "Parse serialized code file {} failed", code_file.c_str()); - return code; + auto input_names = std::vector(code.input_names().begin(), + code.input_names().end()); + auto output_names = std::vector(code.output_names().begin(), + code.output_names().end()); + return spu::ExecutableProto(code.name(), input_names, output_names, + code.code()); } spu::device::SymbolTable parseSymbolTable( @@ -162,7 +169,7 @@ void MemBasedRunner(const std::filesystem::path &snapshot_dir) { SPDLOG_INFO("world size = {}", world_size); auto rt_config = parseRuntimeConfig(snapshot_dir); - rt_config.set_enable_runtime_snapshot(false); + rt_config.enable_runtime_snapshot = false; spu::mpc::utils::simulate( world_size, [&](const std::shared_ptr<::yacl::link::Context> &lctx) { diff --git a/src/libspu/device/utils/pphlo_executor_test_runner.cc b/src/libspu/device/utils/pphlo_executor_test_runner.cc index 2b4880bc8..998ce53d5 100644 --- a/src/libspu/device/utils/pphlo_executor_test_runner.cc +++ b/src/libspu/device/utils/pphlo_executor_test_runner.cc @@ -25,21 +25,16 @@ namespace spu::device::pphlo::test { Runner::Runner(size_t world_size, FieldType field, ProtocolKind protocol) : world_size_(world_size) { - config_.set_field(field); - config_.set_protocol(protocol); - config_.set_enable_type_checker(true); - config_.set_experimental_enable_colocated_optimization(true); + config_.field = field; + config_.protocol = protocol; + config_.enable_type_checker = true; + config_.experimental_enable_colocated_optimization = true; io_ = std::make_unique(world_size_, config_); } std::string Runner::compileMHlo(const std::string &mhlo, const std::vector &vis) { - CompilationSource source; - source.set_ir_type(SourceIRType::STABLEHLO); - source.set_ir_txt(mhlo); - for (const auto v : vis) { - source.add_input_visibility(v); - } + CompilationSource source(SourceIRType::STABLEHLO, mhlo, vis); CompilerOptions copts; return compiler::compile(source, copts); @@ -47,13 +42,12 @@ std::string Runner::compileMHlo(const std::string &mhlo, void Runner::run(const std::string &mlir, size_t num_output) { for (size_t idx = 0; idx < num_output; ++idx) { - executable_.add_output_names(fmt::format("output{}", idx)); + executable_.output_names.emplace_back(fmt::format("output{}", idx)); } - executable_.set_code(mlir); + executable_.code = mlir; ::spu::mpc::utils::simulate( world_size_, [&](const std::shared_ptr &lctx) { - RuntimeConfig conf; - conf.CopyFrom(config_); + RuntimeConfig conf(config_); if (lctx->Rank() == 0) { // conf.set_enable_action_trace(true); } diff --git a/src/libspu/device/utils/pphlo_executor_test_runner.h b/src/libspu/device/utils/pphlo_executor_test_runner.h index 9263cc877..e61850b25 100644 --- a/src/libspu/device/utils/pphlo_executor_test_runner.h +++ b/src/libspu/device/utils/pphlo_executor_test_runner.h @@ -17,8 +17,7 @@ #include "fmt/format.h" // IWYU pragma: keep #include "libspu/device/test_utils.h" - -#include "libspu/spu.pb.h" +#include "libspu/spu.h" namespace spu::device::pphlo::test { @@ -39,7 +38,7 @@ class Runner { int owner_rank = -1) { const std::string name = fmt::format("input{}", input_idx_++); io_->InFeed(name, input, vis, owner_rank); - executable_.add_input_names(name); + executable_.input_names.emplace_back(name); } std::string compileMHlo(const std::string &mhlo, diff --git a/src/libspu/kernel/hal/fxp_approx.cc b/src/libspu/kernel/hal/fxp_approx.cc index 34667e84b..6fa420861 100644 --- a/src/libspu/kernel/hal/fxp_approx.cc +++ b/src/libspu/kernel/hal/fxp_approx.cc @@ -125,7 +125,7 @@ Value log2_pade_normalized(SPUContext* ctx, const Value& x) { Value log2_pade(SPUContext* ctx, const Value& x) { SPU_TRACE_HAL_DISP(ctx, x); - const size_t bit_width = SizeOf(ctx->config().field()) * 8; + const size_t bit_width = SizeOf(ctx->config().field) * 8; auto k = _popcount(ctx, _prefix_or(ctx, x), bit_width); const size_t num_fxp_bits = ctx->getFxpBits(); @@ -165,7 +165,7 @@ Value log_householder(SPUContext* ctx, const Value& x) { Value y = f_add(ctx, f_sub(ctx, term_1, term_2), constant(ctx, 3.0, x.dtype(), x.shape())); - const size_t fxp_log_orders = ctx->config().fxp_log_orders(); + const size_t fxp_log_orders = ctx->config().fxp_log_orders; SPU_ENFORCE(fxp_log_orders != 0, "fxp_log_orders should not be {}", fxp_log_orders); std::vector coeffs{0.0}; @@ -173,7 +173,7 @@ Value log_householder(SPUContext* ctx, const Value& x) { coeffs.emplace_back(1.0 / (1.0 + i)); } - const size_t num_iters = ctx->config().fxp_log_iters(); + const size_t num_iters = ctx->config().fxp_log_iters; SPU_ENFORCE(num_iters != 0, "fxp_log_iters should not be {}", num_iters); for (size_t i = 0; i < num_iters; i++) { Value h = f_sub(ctx, constant(ctx, 1.0, x.dtype(), x.shape()), @@ -187,7 +187,7 @@ Value log_householder(SPUContext* ctx, const Value& x) { // see https://lvdmaaten.github.io/publications/papers/crypten.pdf // exp(x) = (1 + x / n) ^ n, when n is infinite large. Value exp_taylor(SPUContext* ctx, const Value& x) { - const size_t fxp_exp_iters = ctx->config().fxp_exp_iters(); + const size_t fxp_exp_iters = ctx->config().fxp_exp_iters; SPU_ENFORCE(fxp_exp_iters != 0, "fxp_exp_iters should not be {}", fxp_exp_iters); @@ -203,9 +203,9 @@ Value exp_taylor(SPUContext* ctx, const Value& x) { Value exp_prime(SPUContext* ctx, const Value& x) { auto clamped_x = x; - auto offset = ctx->config().experimental_exp_prime_offset(); + auto offset = ctx->config().experimental_exp_prime_offset; auto fxp = ctx->getFxpBits(); - if (!ctx->config().experimental_exp_prime_disable_lower_bound()) { + if (!ctx->config().experimental_exp_prime_disable_lower_bound) { // currently the bound is tied to FM128 SPU_ENFORCE_EQ(ctx->getField(), FieldType::FM128); auto lower_bound = (48.0 - offset - 2.0 * fxp) / M_LOG2E; @@ -213,7 +213,7 @@ Value exp_prime(SPUContext* ctx, const Value& x) { constant(ctx, lower_bound, x.dtype(), x.shape())) .setDtype(x.dtype()); } - if (ctx->config().experimental_exp_prime_enable_upper_bound()) { + if (ctx->config().experimental_exp_prime_enable_upper_bound) { // currently the bound is tied to FM128 SPU_ENFORCE_EQ(ctx->getField(), FieldType::FM128); auto upper_bound = (124.0 - 2.0 * fxp - offset) / M_LOG2E; @@ -457,7 +457,7 @@ Value f_exp(SPUContext* ctx, const Value& x) { return f_exp_p(ctx, x); } - switch (ctx->config().fxp_exp_mode()) { + switch (ctx->config().fxp_exp_mode) { case RuntimeConfig::EXP_DEFAULT: case RuntimeConfig::EXP_TAYLOR: return detail::exp_taylor(ctx, x); @@ -482,7 +482,7 @@ Value f_exp(SPUContext* ctx, const Value& x) { } default: SPU_THROW("unexpected exp approximation method {}", - ctx->config().fxp_exp_mode()); + ctx->config().fxp_exp_mode); } } @@ -495,7 +495,7 @@ Value f_log(SPUContext* ctx, const Value& x) { return f_log_p(ctx, x); } - switch (ctx->config().fxp_log_mode()) { + switch (ctx->config().fxp_log_mode) { // Note: // Generally, MINMAX approximation is a fast and precise DEFAULT option // which gives very high precision (avg error < 3e-5) when x is between ( 0, @@ -519,7 +519,7 @@ Value f_log(SPUContext* ctx, const Value& x) { return detail::log_householder(ctx, x); default: SPU_THROW("unexpected log approximation method {}", - ctx->config().fxp_log_mode()); + ctx->config().fxp_log_mode); } } @@ -575,7 +575,7 @@ static Value rsqrt_init_guess(SPUContext* ctx, const Value& x, const Value& z) { // let rsqrt(u) = 26.02942339 * u^4 - 49.86605845 * u^3 + 38.4714796 * u^2 // - 15.47994394 * u + 4.14285016 spu::Value r; - if (!ctx->config().enable_lower_accuracy_rsqrt()) { + if (!ctx->config().enable_lower_accuracy_rsqrt) { auto coeffs = {0.0F, -15.47994394F, 38.4714796F, -49.86605845F, 26.02942339F}; r = f_add(ctx, @@ -666,7 +666,7 @@ Value f_rsqrt(SPUContext* ctx, const Value& x) { // TODO: we should avoid fork context in hal layer, it will make global // scheduling harder and also make profiling harder. - if (ctx->config().experimental_enable_intra_op_par()) { + if (ctx->config().experimental_enable_intra_op_par) { auto sub_ctx = ctx->fork(); auto r = std::async(rsqrt_init_guess, dynamic_cast(sub_ctx.get()), x, z); @@ -746,7 +746,7 @@ Value f_sigmoid(SPUContext* ctx, const Value& x) { SPU_ENFORCE(x.isFxp()); - switch (ctx->config().sigmoid_mode()) { + switch (ctx->config().sigmoid_mode) { case RuntimeConfig::SIGMOID_DEFAULT: case RuntimeConfig::SIGMOID_MM1: { return sigmoid_mm1(ctx, x); diff --git a/src/libspu/kernel/hal/fxp_approx_test.cc b/src/libspu/kernel/hal/fxp_approx_test.cc index d540eb2bd..c85db164d 100644 --- a/src/libspu/kernel/hal/fxp_approx_test.cc +++ b/src/libspu/kernel/hal/fxp_approx_test.cc @@ -83,13 +83,13 @@ TEST(FxpTest, ExponentialPrime) { std::cout << "test exp_prime" << std::endl; spu::mpc::utils::simulate(2, [&](std::shared_ptr lctx) { RuntimeConfig conf; - conf.set_protocol(ProtocolKind::SEMI2K); - conf.set_field(FieldType::FM128); - conf.set_fxp_fraction_bits(40); - conf.set_experimental_enable_exp_prime(true); + conf.protocol = ProtocolKind::SEMI2K; + conf.field = FieldType::FM128; + conf.fxp_fraction_bits = 40; + conf.experimental_enable_exp_prime = true; SPUContext ctx = test::makeSPUContext(conf, lctx); - auto offset = ctx.config().experimental_exp_prime_offset(); + auto offset = ctx.config().experimental_exp_prime_offset; auto fxp = ctx.getFxpBits(); auto lower_bound = (48.0 - offset - 2.0 * fxp) / M_LOG2E; auto upper_bound = (124.0 - 2.0 * fxp - offset) / M_LOG2E; @@ -262,9 +262,9 @@ TEST(FxpTest, Rsqrt) { // fxp_fraction_bits = 17 { RuntimeConfig config; - config.set_protocol(ProtocolKind::REF2K); - config.set_field(FieldType::FM64); - config.set_fxp_fraction_bits(17); + config.protocol = ProtocolKind::REF2K; + config.field = FieldType::FM64; + config.fxp_fraction_bits = 17; SPUContext ctx = test::makeSPUContext(config, nullptr); Value a = test::makeValue(&ctx, x, VIS_SECRET); @@ -279,9 +279,9 @@ TEST(FxpTest, Rsqrt) { { RuntimeConfig config; - config.set_protocol(ProtocolKind::REF2K); - config.set_field(FieldType::FM64); - config.set_fxp_fraction_bits(16); + config.protocol = ProtocolKind::REF2K; + config.field = FieldType::FM64; + config.fxp_fraction_bits = 16; SPUContext ctx = test::makeSPUContext(config, nullptr); xt::random::seed(0); @@ -320,9 +320,9 @@ TEST(FxpTest, Sqrt) { // fxp_fraction_bits = 17 { RuntimeConfig config; - config.set_protocol(ProtocolKind::REF2K); - config.set_field(FieldType::FM64); - config.set_fxp_fraction_bits(17); + config.protocol = ProtocolKind::REF2K; + config.field = FieldType::FM64; + config.fxp_fraction_bits = 17; SPUContext ctx = test::makeSPUContext(config, nullptr); Value a = test::makeValue(&ctx, x, VIS_SECRET); diff --git a/src/libspu/kernel/hal/fxp_base.cc b/src/libspu/kernel/hal/fxp_base.cc index 486f66b43..63f9bd239 100644 --- a/src/libspu/kernel/hal/fxp_base.cc +++ b/src/libspu/kernel/hal/fxp_base.cc @@ -118,7 +118,7 @@ Value reciprocal_goldschmidt_normalized_approx(SPUContext* ctx, auto r = w; auto e = f_sub(ctx, k1_, f_mul(ctx, c, w, SignType::Positive)); - size_t num_iters = ctx->config().fxp_div_goldschmidt_iters(); + size_t num_iters = ctx->config().fxp_div_goldschmidt_iters; if (ctx->getFxpBits() >= 30) { // default 2 iters of goldschmidt can only get precision about 14 bits. // so if fxp>=30, we use 3 iters by default, which get about 28 bits diff --git a/src/libspu/kernel/hal/permute.cc b/src/libspu/kernel/hal/permute.cc index d40115766..81f1b8a31 100644 --- a/src/libspu/kernel/hal/permute.cc +++ b/src/libspu/kernel/hal/permute.cc @@ -29,8 +29,7 @@ #include "libspu/kernel/hal/shape_ops.h" #include "libspu/kernel/hal/type_cast.h" #include "libspu/kernel/hal/utils.h" - -#include "libspu/spu.pb.h" +#include "libspu/spu.h" namespace spu::kernel::hal { @@ -396,7 +395,7 @@ bool Partition(SPUContext *ctx, const int64_t num_keys, return false; } - int64_t quick_sort_thres = ctx->config().quick_sort_threshold(); + int64_t quick_sort_thres = ctx->config().quick_sort_threshold; int64_t lo; // left end of current interval int64_t hi; // right end of current interval @@ -751,7 +750,7 @@ std::vector PrepareInput(SPUContext *ctx, const Value &input, if (!config.value_only) { auto dt = - ctx->config().field() == FieldType::FM32 ? spu::DT_I32 : spu::DT_I64; + ctx->config().field == FieldType::FM32 ? spu::DT_I32 : spu::DT_I64; // shuffle index with the same permutation as values inp.push_back( _perm_ss(ctx, _p2s(ctx, hal::iota(ctx, dt, input.numel())), rand_perm) @@ -1010,8 +1009,7 @@ spu::Value _gen_inv_perm_s(SPUContext *ctx, absl::Span keys, SPU_ENFORCE_GT(bv.size(), 0U); // 2. generate natural permutation for initialization - auto dt = - ctx->config().field() == FieldType::FM32 ? spu::DT_I32 : spu::DT_I64; + auto dt = ctx->config().field == FieldType::FM32 ? spu::DT_I32 : spu::DT_I64; auto init_perm = iota(ctx, dt, keys[0].numel()); auto shared_perm = _p2s(ctx, init_perm); @@ -1315,8 +1313,7 @@ spu::Value _apply_inv_perm(SPUContext *ctx, const spu::Value &x, // Given a permutation, generate its inverse permutation // ret[perm[i]] = i spu::Value _inverse(SPUContext *ctx, const Value &perm) { - auto dt = - ctx->config().field() == FieldType::FM32 ? spu::DT_I32 : spu::DT_I64; + auto dt = ctx->config().field == FieldType::FM32 ? spu::DT_I32 : spu::DT_I64; auto iota_perm = iota(ctx, dt, perm.numel()); return _apply_inv_perm(ctx, iota_perm, perm); } @@ -1394,8 +1391,7 @@ spu::Value _merge_pub_pri_keys(SPUContext *ctx, auto cur_inv_perm = _gen_inv_perm(ctx, cur_key_hat, is_ascending); inv_perm = _compose_perm(ctx, inv_perm, cur_inv_perm); } - auto dt = - ctx->config().field() == FieldType::FM32 ? spu::DT_I32 : spu::DT_I64; + auto dt = ctx->config().field == FieldType::FM32 ? spu::DT_I32 : spu::DT_I64; std::vector permed_keys; for (const auto &key : keys) { permed_keys.emplace_back(_apply_inv_perm(ctx, key, inv_perm)); @@ -1577,7 +1573,7 @@ std::vector simple_sort1d(SPUContext *ctx, "num_keys {} is not valid", num_keys); std::vector ret; - const auto sort_method = ctx->config().sort_method(); + const auto sort_method = ctx->config().sort_method; // There are multiple sort methods supported by SPU, we will try to seek the // best method in the following order if the user does not specify the method @@ -1768,7 +1764,7 @@ std::vector topk_1d(SPUContext *ctx, const spu::Value &input, ret.push_back(internal::_permute_1d(ctx, input, topk_indices)); if (!config.value_only) { auto dt = - ctx->config().field() == FieldType::FM32 ? spu::DT_I32 : spu::DT_I64; + ctx->config().field == FieldType::FM32 ? spu::DT_I32 : spu::DT_I64; ret.push_back(constant(ctx, topk_indices, dt, {static_cast(topk_indices.size())})); } @@ -1805,7 +1801,7 @@ std::vector topk_1d(SPUContext *ctx, const spu::Value &input, "kernels are not supported"); auto dt = - ctx->config().field() == FieldType::FM32 ? spu::DT_I32 : spu::DT_I64; + ctx->config().field == FieldType::FM32 ? spu::DT_I32 : spu::DT_I64; std::vector inp; inp.push_back(input); diff --git a/src/libspu/kernel/hal/polymorphic_test.cc b/src/libspu/kernel/hal/polymorphic_test.cc index 45c7e8777..e59c99eb5 100644 --- a/src/libspu/kernel/hal/polymorphic_test.cc +++ b/src/libspu/kernel/hal/polymorphic_test.cc @@ -682,9 +682,9 @@ class LogisticTest TEST_P(LogisticTest, Logistic) { // GIVEN RuntimeConfig config; - config.set_protocol(ProtocolKind::REF2K); - config.set_field(FieldType::FM64); - config.set_sigmoid_mode(GetParam()); + config.protocol = ProtocolKind::REF2K; + config.field = FieldType::FM64; + config.sigmoid_mode = GetParam(); SPUContext ctx = test::makeSPUContext(config, nullptr); xt::xarray x{{1.0, 2.0}, {0.5, 1.8}}; diff --git a/src/libspu/kernel/hal/ring.cc b/src/libspu/kernel/hal/ring.cc index 725fd498f..1c5895bab 100644 --- a/src/libspu/kernel/hal/ring.cc +++ b/src/libspu/kernel/hal/ring.cc @@ -321,7 +321,7 @@ Value _mmul(SPUContext* ctx, const Value& x, const Value& y) { auto [m_step, n_step, k_step] = calcMmulTilingSize(m, n, k, x.elsize(), 256UL * 1024 * 1024); - if (ctx->config().experimental_disable_mmul_split() || + if (ctx->config().experimental_disable_mmul_split || (m_step == m && n_step == n && k_step == k)) { // no split return _mmul_impl(ctx, x, y); diff --git a/src/libspu/kernel/hlo/BUILD.bazel b/src/libspu/kernel/hlo/BUILD.bazel index 80c1c9bde..c6140fee2 100644 --- a/src/libspu/kernel/hlo/BUILD.bazel +++ b/src/libspu/kernel/hlo/BUILD.bazel @@ -306,6 +306,7 @@ spu_cc_test( "//libspu/kernel:test_util", "//libspu/kernel/hal:polymorphic", "//libspu/mpc/utils:simulate", + "@magic_enum", ], ) diff --git a/src/libspu/kernel/hlo/const_test.cc b/src/libspu/kernel/hlo/const_test.cc index cbcf1181b..cae008f9e 100644 --- a/src/libspu/kernel/hlo/const_test.cc +++ b/src/libspu/kernel/hlo/const_test.cc @@ -41,7 +41,7 @@ TEST(ConstTest, Epsilon) { auto v = hal::dump_public_as(&sctx, eps); - EXPECT_FLOAT_EQ(v[0], 1 / (std::pow(2, sctx.config().fxp_fraction_bits()))); + EXPECT_FLOAT_EQ(v[0], 1 / (std::pow(2, sctx.config().fxp_fraction_bits))); } } // namespace spu::kernel::hlo diff --git a/src/libspu/kernel/hlo/sort_test.cc b/src/libspu/kernel/hlo/sort_test.cc index 89519b9e6..85f6f87df 100644 --- a/src/libspu/kernel/hlo/sort_test.cc +++ b/src/libspu/kernel/hlo/sort_test.cc @@ -20,6 +20,7 @@ #include #include "gtest/gtest.h" +#include "magic_enum.hpp" #include "xtensor/xio.hpp" #include "libspu/kernel/hal/polymorphic.h" @@ -28,11 +29,10 @@ #include "libspu/kernel/hlo/const.h" #include "libspu/kernel/test_util.h" #include "libspu/mpc/utils/simulate.h" - // to print method name std::ostream &operator<<(std::ostream &os, - spu::RuntimeConfig_SortMethod method) { - os << spu::RuntimeConfig::SortMethod_Name(method); + spu::RuntimeConfig::SortMethod method) { + os << magic_enum::enum_name(method); return os; } namespace spu::kernel::hlo { @@ -248,10 +248,11 @@ TEST_P(SimpleSortTest, MultiOperands) { mpc::utils::simulate( npc, [&](const std::shared_ptr &lctx) { RuntimeConfig cfg; - cfg.set_protocol(prot); - cfg.set_field(field); - cfg.set_enable_action_trace(false); - cfg.set_sort_method(method); + cfg.protocol = prot; + cfg.field = field; + cfg.enable_action_trace = false; + cfg.sort_method = method; + SPUContext ctx = test::makeSPUContext(cfg, lctx); xt::xarray k1 = {7, 6, 5, 5, 4, 4, 4, 1, 3, 3}; @@ -292,10 +293,10 @@ TEST_P(SimpleSortTest, SingleKeyWithPayload) { mpc::utils::simulate( npc, [&](const std::shared_ptr &lctx) { RuntimeConfig cfg; - cfg.set_protocol(prot); - cfg.set_field(field); - cfg.set_enable_action_trace(false); - cfg.set_sort_method(method); + cfg.protocol = prot; + cfg.field = field; + cfg.enable_action_trace = false; + cfg.sort_method = method; SPUContext ctx = test::makeSPUContext(cfg, lctx); xt::xarray k1 = {7, 6, 5, 4, 1, 3, 2}; diff --git a/src/libspu/kernel/test_util.cc b/src/libspu/kernel/test_util.cc index 3375eada7..c55408976 100644 --- a/src/libspu/kernel/test_util.cc +++ b/src/libspu/kernel/test_util.cc @@ -34,9 +34,9 @@ SPUContext makeSPUContext(RuntimeConfig config, SPUContext makeSPUContext(ProtocolKind prot_kind, FieldType field, const std::shared_ptr& lctx) { RuntimeConfig cfg; - cfg.set_protocol(prot_kind); - cfg.set_field(field); - cfg.set_enable_action_trace(false); + cfg.protocol = prot_kind; + cfg.field = field; + cfg.enable_action_trace = false; return makeSPUContext(cfg, lctx); } diff --git a/src/libspu/mpc/BUILD.bazel b/src/libspu/mpc/BUILD.bazel index 712284906..5e2145c40 100644 --- a/src/libspu/mpc/BUILD.bazel +++ b/src/libspu/mpc/BUILD.bazel @@ -45,7 +45,7 @@ spu_cc_library( srcs = ["factory.cc"], hdrs = ["factory.h"], deps = [ - "//libspu:spu_cc_proto", + "//libspu:spu", "//libspu/mpc/aby3", "//libspu/mpc/cheetah", "//libspu/mpc/ref2k", diff --git a/src/libspu/mpc/ab_api_test.cc b/src/libspu/mpc/ab_api_test.cc index 82d802b3b..0fe930af9 100644 --- a/src/libspu/mpc/ab_api_test.cc +++ b/src/libspu/mpc/ab_api_test.cc @@ -111,7 +111,7 @@ bool verifyCost(Kernel* kernel, std::string_view name, FieldType field, /* THEN */ \ EXPECT_VALUE_EQ(re, rp); \ EXPECT_TRUE(verifyCost(obj->prot()->getKernel(#OP "_aa"), #OP "_aa", \ - conf.field(), kShape, npc, cost)); \ + conf.field, kShape, npc, cost)); \ }); \ } @@ -141,7 +141,7 @@ bool verifyCost(Kernel* kernel, std::string_view name, FieldType field, /* THEN */ \ EXPECT_VALUE_EQ(re, rp); \ EXPECT_TRUE(verifyCost(obj->prot()->getKernel(#OP "_ap"), #OP "_ap", \ - conf.field(), kShape, npc, cost)); \ + conf.field, kShape, npc, cost)); \ }); \ } @@ -179,7 +179,7 @@ TEST_P(ArithmeticTest, SquareA) { /* THEN */ EXPECT_VALUE_EQ(r_aa, r_pp); EXPECT_TRUE(verifyCost(obj->prot()->getKernel("square_a"), "square_a", - conf.field(), kShape, npc, cost)); + conf.field, kShape, npc, cost)); }); } @@ -195,13 +195,13 @@ TEST_P(ArithmeticTest, MulA1B) { return; } - const int64_t K = spu::SizeOf(conf.field()) * 8; + const int64_t K = spu::SizeOf(conf.field) * 8; /* GIVEN */ - auto p0 = rand_p(obj.get(), conf.protocol() == ProtocolKind::CHEETAH + auto p0 = rand_p(obj.get(), conf.protocol == ProtocolKind::CHEETAH ? Shape({200, 26}) : kShape); - auto p1 = rand_p(obj.get(), conf.protocol() == ProtocolKind::CHEETAH + auto p1 = rand_p(obj.get(), conf.protocol == ProtocolKind::CHEETAH ? Shape({200, 26}) : kShape); auto a0 = p2a(obj.get(), p0); @@ -223,7 +223,7 @@ TEST_P(ArithmeticTest, MulA1B) { /* THEN */ EXPECT_VALUE_EQ(r_aa, r_pp); EXPECT_TRUE(verifyCost(obj->prot()->getKernel("mul_a1b"), "mul_a1b", - conf.field(), kShape, npc, cost)); + conf.field, kShape, npc, cost)); }); } @@ -239,7 +239,7 @@ TEST_P(ArithmeticTest, MulAV) { return; } - const int64_t K = spu::SizeOf(conf.field()) * 8; + const int64_t K = spu::SizeOf(conf.field) * 8; /* GIVEN */ auto p0 = rand_p(obj.get(), kShape); @@ -259,7 +259,7 @@ TEST_P(ArithmeticTest, MulAV) { /* THEN */ EXPECT_VALUE_EQ(r_aa, r_pp); EXPECT_TRUE(verifyCost(obj->prot()->getKernel("mul_av"), "mul_av", - conf.field(), kShape, npc, cost)); + conf.field, kShape, npc, cost)); }); } @@ -276,7 +276,7 @@ TEST_P(ArithmeticTest, MulA1BV) { return; } - const int64_t K = spu::SizeOf(conf.field()) * 8; + const int64_t K = spu::SizeOf(conf.field) * 8; /* GIVEN */ auto p0 = rand_p(obj.get(), kShape); @@ -300,7 +300,7 @@ TEST_P(ArithmeticTest, MulA1BV) { /* THEN */ EXPECT_VALUE_EQ(r_aa, r_pp); EXPECT_TRUE(verifyCost(obj->prot()->getKernel("mul_a1bv"), "mul_a1bv", - conf.field(), kShape, npc, cost)); + conf.field, kShape, npc, cost)); }); } @@ -335,7 +335,7 @@ TEST_P(ArithmeticTest, MatMulAP) { /* THEN */ EXPECT_VALUE_EQ(r_aa, r_pp); - ce::Params params = {{"K", SizeOf(conf.field()) * 8}, + ce::Params params = {{"K", SizeOf(conf.field) * 8}, {"N", npc}, {"m", M}, {"n", N}, @@ -376,7 +376,7 @@ TEST_P(ArithmeticTest, MatMulAA) { /* THEN */ EXPECT_VALUE_EQ(r_aa, r_pp); - ce::Params params = {{"K", SizeOf(conf.field()) * 8}, + ce::Params params = {{"K", SizeOf(conf.field) * 8}, {"N", npc}, {"m", M}, {"n", N}, @@ -427,7 +427,7 @@ TEST_P(ArithmeticTest, MatMulAV) { /* THEN */ EXPECT_VALUE_EQ(r0_aa, r_pp); EXPECT_VALUE_EQ(r1_aa, r_pp); - ce::Params params = {{"K", SizeOf(conf.field()) * 8}, + ce::Params params = {{"K", SizeOf(conf.field) * 8}, {"N", npc}, {"m", M}, {"n", N}, @@ -460,7 +460,7 @@ TEST_P(ArithmeticTest, NegateA) { /* THEN */ EXPECT_VALUE_EQ(r_p, r_pp); EXPECT_TRUE(verifyCost(obj->prot()->getKernel("negate_a"), "negate_a", - conf.field(), kShape, npc, cost)); + conf.field, kShape, npc, cost)); }); } @@ -491,7 +491,7 @@ TEST_P(ArithmeticTest, LShiftA) { /* THEN */ EXPECT_VALUE_EQ(r_b, r_p); EXPECT_TRUE(verifyCost(obj->prot()->getKernel("lshift_a"), "lshift_a", - conf.field(), kShape, npc, cost)); + conf.field, kShape, npc, cost)); } }); } @@ -518,7 +518,7 @@ TEST_P(ArithmeticTest, TruncA) { } else { // has msb error, only use lowest 10 bits. p0 = arshift_p(obj.get(), p0, - {static_cast(SizeOf(conf.field()) * 8 - 10)}); + {static_cast(SizeOf(conf.field) * 8 - 10)}); } /* GIVEN */ @@ -536,7 +536,7 @@ TEST_P(ArithmeticTest, TruncA) { /* THEN */ EXPECT_VALUE_ALMOST_EQ(r_a, r_p, npc); EXPECT_TRUE(verifyCost(obj->prot()->getKernel("trunc_a"), "trunc_a", - conf.field(), kShape, npc, cost)); + conf.field, kShape, npc, cost)); }); } @@ -559,7 +559,7 @@ TEST_P(ArithmeticTest, P2A) { /* THEN */ EXPECT_VALUE_EQ(p0, p1); - EXPECT_TRUE(verifyCost(obj->prot()->getKernel("p2a"), "p2a", conf.field(), + EXPECT_TRUE(verifyCost(obj->prot()->getKernel("p2a"), "p2a", conf.field, kShape, npc, cost)); }); } @@ -583,7 +583,7 @@ TEST_P(ArithmeticTest, A2P) { /* THEN */ EXPECT_VALUE_EQ(p0, p1); - EXPECT_TRUE(verifyCost(obj->prot()->getKernel("a2p"), "a2p", conf.field(), + EXPECT_TRUE(verifyCost(obj->prot()->getKernel("a2p"), "a2p", conf.field, kShape, npc, cost)); }); } @@ -615,7 +615,7 @@ TEST_P(ArithmeticTest, A2P) { /* THEN */ \ EXPECT_VALUE_EQ(re, rp); \ EXPECT_TRUE(verifyCost(obj->prot()->getKernel(#OP "_bb"), #OP "_bb", \ - conf.field(), kShape, npc, cost)); \ + conf.field, kShape, npc, cost)); \ }); \ } @@ -645,7 +645,7 @@ TEST_P(ArithmeticTest, A2P) { /* THEN */ \ EXPECT_VALUE_EQ(re, rp); \ EXPECT_TRUE(verifyCost(obj->prot()->getKernel(#OP "_bp"), #OP "_bp", \ - conf.field(), kShape, npc, cost)); \ + conf.field, kShape, npc, cost)); \ }); \ } @@ -685,7 +685,7 @@ TEST_BOOLEAN_BINARY_OP(xor) /* THEN */ \ EXPECT_VALUE_EQ(r_b, r_p); \ EXPECT_TRUE(verifyCost(obj->prot()->getKernel(#OP "_b"), #OP "_b", \ - conf.field(), kShape, npc, cost)); \ + conf.field, kShape, npc, cost)); \ } \ }); \ } @@ -713,7 +713,7 @@ TEST_P(BooleanTest, P2B) { /* THEN */ EXPECT_VALUE_EQ(p0, p1); - EXPECT_TRUE(verifyCost(obj->prot()->getKernel("p2b"), "p2b", conf.field(), + EXPECT_TRUE(verifyCost(obj->prot()->getKernel("p2b"), "p2b", conf.field, kShape, npc, cost)); }); } @@ -737,7 +737,7 @@ TEST_P(BooleanTest, B2P) { /* THEN */ EXPECT_VALUE_EQ(p0, p1); - EXPECT_TRUE(verifyCost(obj->prot()->getKernel("b2p"), "b2p", conf.field(), + EXPECT_TRUE(verifyCost(obj->prot()->getKernel("b2p"), "b2p", conf.field, kShape, npc, cost)); }); } @@ -756,8 +756,8 @@ TEST_P(BooleanTest, BitrevB) { /* WHEN */ auto b0 = p2b(obj.get(), p0); - for (size_t i = 0; i < SizeOf(conf.field()); i++) { - for (size_t j = i; j < SizeOf(conf.field()); j++) { + for (size_t i = 0; i < SizeOf(conf.field); i++) { + for (size_t j = i; j < SizeOf(conf.field); j++) { auto prev = obj->prot()->getState()->getStats(); auto b1 = bitrev_b(obj.get(), b0, i, j); auto cost = obj->prot()->getState()->getStats() - prev; @@ -767,7 +767,7 @@ TEST_P(BooleanTest, BitrevB) { EXPECT_VALUE_EQ(p1, pp1); EXPECT_TRUE(verifyCost(obj->prot()->getKernel("bitrev_b"), "bitrev_b", - conf.field(), kShape, npc, cost)); + conf.field, kShape, npc, cost)); } } }); @@ -791,7 +791,7 @@ TEST_P(ConversionTest, A2B) { auto cost = obj->prot()->getState()->getStats() - prev; /* THEN */ - EXPECT_TRUE(verifyCost(obj->prot()->getKernel("a2b"), "a2b", conf.field(), + EXPECT_TRUE(verifyCost(obj->prot()->getKernel("a2b"), "a2b", conf.field, kShape, npc, cost)); EXPECT_VALUE_EQ(p0, b2p(obj.get(), b1)); }); @@ -816,7 +816,7 @@ TEST_P(ConversionTest, B2A) { auto cost = obj->prot()->getState()->getStats() - prev; /* THEN */ - EXPECT_TRUE(verifyCost(obj->prot()->getKernel("b2a"), "b2a", conf.field(), + EXPECT_TRUE(verifyCost(obj->prot()->getKernel("b2a"), "b2a", conf.field, kShape, npc, cost)); EXPECT_VALUE_EQ(p0, a2p(obj.get(), a1)); }); @@ -838,7 +838,7 @@ TEST_P(ConversionTest, MSB) { auto p0 = rand_p(obj.get(), kShape); // SECURENN has an msb input range here - if (conf.protocol() == ProtocolKind::SECURENN) { + if (conf.protocol == ProtocolKind::SECURENN) { p0 = arshift_p(obj.get(), p0, {1}); } @@ -851,10 +851,10 @@ TEST_P(ConversionTest, MSB) { /* THEN */ EXPECT_TRUE(verifyCost(obj->prot()->getKernel("msb_a2b"), "msb_a2b", - conf.field(), kShape, npc, cost)); + conf.field, kShape, npc, cost)); EXPECT_VALUE_EQ( rshift_p(obj.get(), p0, - {static_cast(SizeOf(conf.field()) * 8 - 1)}), + {static_cast(SizeOf(conf.field) * 8 - 1)}), b2p(obj.get(), b1)); }); } @@ -872,13 +872,13 @@ TEST_P(ConversionTest, EqualAA) { } /* GIVEN */ // NOTE(lwj) for Cheetah, set a lager case to test the tield dispatch - auto r0 = rand_p(obj.get(), conf.protocol() == ProtocolKind::CHEETAH + auto r0 = rand_p(obj.get(), conf.protocol == ProtocolKind::CHEETAH ? Shape({10, 20, 30}) : kShape); - auto r1 = rand_p(obj.get(), conf.protocol() == ProtocolKind::CHEETAH + auto r1 = rand_p(obj.get(), conf.protocol == ProtocolKind::CHEETAH ? Shape({10, 20, 30}) : kShape); - auto r2 = rand_p(obj.get(), conf.protocol() == ProtocolKind::CHEETAH + auto r2 = rand_p(obj.get(), conf.protocol == ProtocolKind::CHEETAH ? Shape({10, 20, 30}) : kShape); std::memcpy(r2.data().data(), r0.data().data(), 16); @@ -896,7 +896,7 @@ TEST_P(ConversionTest, EqualAA) { /* THEN */ EXPECT_VALUE_EQ(out_value, t_value); EXPECT_TRUE(verifyCost(obj->prot()->getKernel("equal_aa"), "equal_aa", - conf.field(), kShape, npc, cost)); + conf.field, kShape, npc, cost)); } }); } @@ -931,7 +931,7 @@ TEST_P(ConversionTest, EqualAP) { /* THEN */ EXPECT_VALUE_EQ(out_value, t_value); EXPECT_TRUE(verifyCost(obj->prot()->getKernel("equal_ap"), "equal_ap", - conf.field(), kShape, npc, cost)); + conf.field, kShape, npc, cost)); } }); } diff --git a/src/libspu/mpc/aby3/BUILD.bazel b/src/libspu/mpc/aby3/BUILD.bazel index 29507046b..0b834df0c 100644 --- a/src/libspu/mpc/aby3/BUILD.bazel +++ b/src/libspu/mpc/aby3/BUILD.bazel @@ -191,6 +191,7 @@ spu_cc_library( deps = [ "//libspu/core:type", "//libspu/mpc/common:pv2k", + "@magic_enum", ], ) diff --git a/src/libspu/mpc/aby3/protocol.cc b/src/libspu/mpc/aby3/protocol.cc index e36a32732..f6a173465 100644 --- a/src/libspu/mpc/aby3/protocol.cc +++ b/src/libspu/mpc/aby3/protocol.cc @@ -33,7 +33,7 @@ void regAby3Protocol(SPUContext* ctx, const std::shared_ptr& lctx) { aby3::registerTypes(); - ctx->prot()->addState(ctx->config().field()); + ctx->prot()->addState(ctx->config().field); // add communicator ctx->prot()->addState(lctx); diff --git a/src/libspu/mpc/aby3/protocol_test.cc b/src/libspu/mpc/aby3/protocol_test.cc index ab65344a8..14c744f6d 100644 --- a/src/libspu/mpc/aby3/protocol_test.cc +++ b/src/libspu/mpc/aby3/protocol_test.cc @@ -24,8 +24,8 @@ namespace { RuntimeConfig makeConfig(FieldType field) { RuntimeConfig conf; - conf.set_protocol(ProtocolKind::ABY3); - conf.set_field(field); + conf.protocol = ProtocolKind::ABY3; + conf.field = field; return conf; } @@ -39,7 +39,7 @@ INSTANTIATE_TEST_SUITE_P( makeConfig(FieldType::FM128)), // testing::Values(3)), // [](const testing::TestParamInfo& p) { - return fmt::format("{}x{}", std::get<1>(p.param).field(), + return fmt::format("{}x{}", std::get<1>(p.param).field, std::get<2>(p.param)); }); @@ -51,7 +51,7 @@ INSTANTIATE_TEST_SUITE_P( makeConfig(FieldType::FM128)), // testing::Values(3)), // [](const testing::TestParamInfo& p) { - return fmt::format("{}x{}", std::get<1>(p.param).field(), + return fmt::format("{}x{}", std::get<1>(p.param).field, std::get<2>(p.param)); }); @@ -63,7 +63,7 @@ INSTANTIATE_TEST_SUITE_P( makeConfig(FieldType::FM128)), // testing::Values(3)), // [](const testing::TestParamInfo& p) { - return fmt::format("{}x{}", std::get<1>(p.param).field(), + return fmt::format("{}x{}", std::get<1>(p.param).field, std::get<2>(p.param)); }); @@ -75,7 +75,7 @@ INSTANTIATE_TEST_SUITE_P( makeConfig(FieldType::FM128)), // testing::Values(3)), // [](const testing::TestParamInfo& p) { - return fmt::format("{}x{}", std::get<1>(p.param).field(), + return fmt::format("{}x{}", std::get<1>(p.param).field, std::get<2>(p.param)); }); diff --git a/src/libspu/mpc/aby3/type.cc b/src/libspu/mpc/aby3/type.cc index 45676fd4a..fe83c8265 100644 --- a/src/libspu/mpc/aby3/type.cc +++ b/src/libspu/mpc/aby3/type.cc @@ -16,8 +16,9 @@ #include -#include "libspu/mpc/common/pv2k.h" +#include "magic_enum.hpp" +#include "libspu/mpc/common/pv2k.h" namespace spu::mpc::aby3 { void registerTypes() { @@ -30,4 +31,18 @@ void registerTypes() { }); } +void BShrTy::fromString(std::string_view detail) { + auto comma = detail.find_first_of(','); + auto back_type_str = detail.substr(0, comma); + auto nbits_str = detail.substr(comma + 1); + auto back_type = magic_enum::enum_cast(back_type_str); + SPU_ENFORCE(back_type.has_value(), "parse failed from={}", detail); + back_type_ = back_type.value(); + nbits_ = std::stoul(std::string(nbits_str)); +} + +std::string BShrTy::toString() const { + return fmt::format("{},{}", magic_enum::enum_name(back_type_), nbits_); +} + } // namespace spu::mpc::aby3 diff --git a/src/libspu/mpc/aby3/type.h b/src/libspu/mpc/aby3/type.h index b5a00de9b..f2b541ed7 100644 --- a/src/libspu/mpc/aby3/type.h +++ b/src/libspu/mpc/aby3/type.h @@ -92,18 +92,8 @@ class BShrTy : public TypeImpl { static std::string_view getStaticId() { return "aby3.BShr"; } - void fromString(std::string_view detail) override { - auto comma = detail.find_first_of(','); - auto back_type_str = detail.substr(0, comma); - auto nbits_str = detail.substr(comma + 1); - SPU_ENFORCE(PtType_Parse(std::string(back_type_str), &back_type_), - "parse failed from={}", detail); - nbits_ = std::stoul(std::string(nbits_str)); - } - - std::string toString() const override { - return fmt::format("{},{}", PtType_Name(back_type_), nbits_); - } + void fromString(std::string_view detail) override; + std::string toString() const override; size_t size() const override { return SizeOf(back_type_) * 2; } diff --git a/src/libspu/mpc/api_test.cc b/src/libspu/mpc/api_test.cc index d890258f6..e0dff5e9a 100644 --- a/src/libspu/mpc/api_test.cc +++ b/src/libspu/mpc/api_test.cc @@ -261,7 +261,7 @@ TEST_P(ApiTest, MsbS) { auto p0 = rand_p(sctx.get(), kShape); // SECURENN has an msb input range requirement here - if (conf.protocol() == ProtocolKind::SECURENN) { + if (conf.protocol == ProtocolKind::SECURENN) { p0 = arshift_p(sctx.get(), p0, {1}); } @@ -288,7 +288,7 @@ TEST_P(ApiTest, MsbS) { auto x_s = p2s(sctx.get(), x_p); \ \ for (auto bits : kShiftBits) { \ - if (bits >= SizeOf(conf.field()) * 8) { \ + if (bits >= SizeOf(conf.field) * 8) { \ continue; \ } \ /* WHEN */ \ @@ -318,7 +318,7 @@ TEST_P(ApiTest, MsbS) { auto x_v = p2v(sctx.get(), x_p, rank); \ \ for (auto bits : kShiftBits) { \ - if (bits >= SizeOf(conf.field()) * 8) { \ + if (bits >= SizeOf(conf.field) * 8) { \ continue; \ } \ /* WHEN */ \ @@ -349,7 +349,7 @@ TEST_P(ApiTest, MsbS) { auto p0 = rand_p(sctx.get(), kShape); \ \ for (auto bits : kShiftBits) { /* WHEN */ \ - if (bits >= SizeOf(conf.field()) * 8) { \ + if (bits >= SizeOf(conf.field) * 8) { \ continue; \ } \ auto r_p = OP##_p(sctx.get(), p0, {static_cast(bits)}); \ @@ -379,13 +379,13 @@ TEST_P(ApiTest, TruncS) { auto sctx = factory(conf, lctx); // NOTE(lwj): test Cheetah's TiledDispatch using larger shape - auto p0 = rand_p(sctx.get(), conf.protocol() == ProtocolKind::CHEETAH + auto p0 = rand_p(sctx.get(), conf.protocol == ProtocolKind::CHEETAH ? Shape({300, 20}) : kShape); // TODO: here we assume has msb error, only use lowest 10 bits. p0 = arshift_p(sctx.get(), p0, - {static_cast(SizeOf(conf.field()) * 8 - 10)}); + {static_cast(SizeOf(conf.field) * 8 - 10)}); const size_t bits = 2; auto r_s = s2p(sctx.get(), trunc_s(sctx.get(), p2s(sctx.get(), p0), bits, diff --git a/src/libspu/mpc/api_test.h b/src/libspu/mpc/api_test.h index 3b205d7c6..6ee809628 100644 --- a/src/libspu/mpc/api_test.h +++ b/src/libspu/mpc/api_test.h @@ -18,8 +18,7 @@ #include "yacl/link/link.h" #include "libspu/mpc/api_test_params.h" - -#include "libspu/spu.pb.h" +#include "libspu/spu.h" namespace spu::mpc::test { diff --git a/src/libspu/mpc/cheetah/ot/basic_ot_prot.h b/src/libspu/mpc/cheetah/ot/basic_ot_prot.h index 90ee661c8..2c4aa89a8 100644 --- a/src/libspu/mpc/cheetah/ot/basic_ot_prot.h +++ b/src/libspu/mpc/cheetah/ot/basic_ot_prot.h @@ -17,8 +17,7 @@ #include "libspu/core/ndarray_ref.h" #include "libspu/mpc/cheetah/ot/ferret_ot_interface.h" #include "libspu/mpc/common/communicator.h" - -#include "libspu/spu.pb.h" +#include "libspu/spu.h" namespace spu::mpc::cheetah { diff --git a/src/libspu/mpc/cheetah/protocol.cc b/src/libspu/mpc/cheetah/protocol.cc index 407e6ca2d..b2e81db8c 100644 --- a/src/libspu/mpc/cheetah/protocol.cc +++ b/src/libspu/mpc/cheetah/protocol.cc @@ -42,16 +42,16 @@ void regCheetahProtocol(SPUContext* ctx, ctx->prot()->addState(lctx); // add Z2k state. - ctx->prot()->addState(ctx->config().field()); + ctx->prot()->addState(ctx->config().field); // add Cheetah states ctx->prot()->addState( - lctx, ctx->config().cheetah_2pc_config().enable_mul_lsb_error()); + lctx, ctx->config().cheetah_2pc_config.enable_mul_lsb_error); ctx->prot()->addState( - lctx, ctx->config().cheetah_2pc_config().disable_matmul_pack()); + lctx, ctx->config().cheetah_2pc_config.disable_matmul_pack); ctx->prot()->addState( ctx->getClusterLevelMaxConcurrency(), - ctx->config().cheetah_2pc_config().ot_kind()); + ctx->config().cheetah_2pc_config.ot_kind); // register public kernels. regPV2kKernels(ctx->prot()); diff --git a/src/libspu/mpc/cheetah/protocol_ab_test.cc b/src/libspu/mpc/cheetah/protocol_ab_test.cc index 0b57918c4..f2c3d4e9b 100644 --- a/src/libspu/mpc/cheetah/protocol_ab_test.cc +++ b/src/libspu/mpc/cheetah/protocol_ab_test.cc @@ -20,10 +20,9 @@ namespace { RuntimeConfig makeConfig(FieldType field) { RuntimeConfig conf; - conf.set_protocol(ProtocolKind::CHEETAH); - conf.set_field(field); - conf.mutable_cheetah_2pc_config()->set_ot_kind( - CheetahOtKind::YACL_Softspoken); + conf.protocol = ProtocolKind::CHEETAH; + conf.field = field; + conf.cheetah_2pc_config.ot_kind = CheetahOtKind::YACL_Softspoken; return conf; } @@ -36,7 +35,7 @@ INSTANTIATE_TEST_SUITE_P( makeConfig(FieldType::FM64)), // testing::Values(2)), // [](const testing::TestParamInfo& p) { - return fmt::format("{}x{}", std::get<1>(p.param).field(), + return fmt::format("{}x{}", std::get<1>(p.param).field, std::get<2>(p.param)); }); @@ -48,7 +47,7 @@ INSTANTIATE_TEST_SUITE_P( makeConfig(FieldType::FM128)), // testing::Values(2)), // [](const testing::TestParamInfo& p) { - return fmt::format("{}x{}", std::get<1>(p.param).field(), + return fmt::format("{}x{}", std::get<1>(p.param).field, std::get<2>(p.param)); }); @@ -59,7 +58,7 @@ INSTANTIATE_TEST_SUITE_P( makeConfig(FieldType::FM64)), // testing::Values(2)), // [](const testing::TestParamInfo& p) { - return fmt::format("{}x{}", std::get<1>(p.param).field(), + return fmt::format("{}x{}", std::get<1>(p.param).field, std::get<2>(p.param)); }); diff --git a/src/libspu/mpc/cheetah/protocol_api_test.cc b/src/libspu/mpc/cheetah/protocol_api_test.cc index 5fcfca621..527506668 100644 --- a/src/libspu/mpc/cheetah/protocol_api_test.cc +++ b/src/libspu/mpc/cheetah/protocol_api_test.cc @@ -20,10 +20,9 @@ namespace { RuntimeConfig makeConfig(FieldType field) { RuntimeConfig conf; - conf.set_protocol(ProtocolKind::CHEETAH); - conf.set_field(field); - conf.mutable_cheetah_2pc_config()->set_ot_kind( - CheetahOtKind::YACL_Softspoken); + conf.protocol = ProtocolKind::CHEETAH; + conf.field = field; + conf.cheetah_2pc_config.ot_kind = CheetahOtKind::YACL_Softspoken; return conf; } @@ -37,7 +36,7 @@ INSTANTIATE_TEST_SUITE_P( makeConfig(FieldType::FM128)), // testing::Values(2)), // [](const testing::TestParamInfo& p) { - return fmt::format("{}x{}", std::get<1>(p.param).field(), + return fmt::format("{}x{}", std::get<1>(p.param).field, std::get<2>(p.param)); }); diff --git a/src/libspu/mpc/cheetah/state.h b/src/libspu/mpc/cheetah/state.h index 0ee3e575e..96ea437ed 100644 --- a/src/libspu/mpc/cheetah/state.h +++ b/src/libspu/mpc/cheetah/state.h @@ -24,8 +24,7 @@ #include "libspu/mpc/cheetah/arith/cheetah_mul.h" #include "libspu/mpc/cheetah/ot/basic_ot_prot.h" #include "libspu/mpc/cheetah/rlwe/utils.h" - -#include "libspu/spu.pb.h" +#include "libspu/spu.h" namespace spu::mpc::cheetah { diff --git a/src/libspu/mpc/cheetah/type.cc b/src/libspu/mpc/cheetah/type.cc index 1497189ea..ede298f7f 100644 --- a/src/libspu/mpc/cheetah/type.cc +++ b/src/libspu/mpc/cheetah/type.cc @@ -16,6 +16,8 @@ #include +#include "magic_enum.hpp" + #include "libspu/mpc/common/pv2k.h" namespace spu::mpc::cheetah { @@ -29,4 +31,18 @@ void registerTypes() { }); } +void BShrTy::fromString(std::string_view detail) { + auto comma = detail.find_first_of(','); + auto field_str = detail.substr(0, comma); + auto nbits_str = detail.substr(comma + 1); + auto field = magic_enum::enum_cast(field_str); + SPU_ENFORCE(field.has_value(), "parse failed from={}", detail); + field_ = field.value(); + nbits_ = std::stoul(std::string(nbits_str)); +}; + +std::string BShrTy::toString() const { + return fmt::format("{},{}", magic_enum::enum_name(field()), nbits_); +} + } // namespace spu::mpc::cheetah \ No newline at end of file diff --git a/src/libspu/mpc/cheetah/type.h b/src/libspu/mpc/cheetah/type.h index 607c20783..bd8cb02a7 100644 --- a/src/libspu/mpc/cheetah/type.h +++ b/src/libspu/mpc/cheetah/type.h @@ -42,18 +42,8 @@ class BShrTy : public TypeImpl { static std::string_view getStaticId() { return "cheetah.BShr"; } - void fromString(std::string_view detail) override { - auto comma = detail.find_first_of(','); - auto field_str = detail.substr(0, comma); - auto nbits_str = detail.substr(comma + 1); - SPU_ENFORCE(FieldType_Parse(std::string(field_str), &field_), - "parse failed from={}", detail); - nbits_ = std::stoul(std::string(nbits_str)); - }; - - std::string toString() const override { - return fmt::format("{},{}", FieldType_Name(field()), nbits_); - } + void fromString(std::string_view detail) override; + std::string toString() const override; bool equals(TypeObject const* other) const override { auto const* derived_other = dynamic_cast(other); diff --git a/src/libspu/mpc/common/BUILD.bazel b/src/libspu/mpc/common/BUILD.bazel index f22bc908d..d6805502a 100644 --- a/src/libspu/mpc/common/BUILD.bazel +++ b/src/libspu/mpc/common/BUILD.bazel @@ -25,6 +25,7 @@ spu_cc_library( "//libspu/mpc/common:communicator", "//libspu/mpc/common:prg_state", "//libspu/mpc/utils:ring_ops", + "@magic_enum", ], ) diff --git a/src/libspu/mpc/common/pv2k.cc b/src/libspu/mpc/common/pv2k.cc index ea38bf9db..fac9deddc 100644 --- a/src/libspu/mpc/common/pv2k.cc +++ b/src/libspu/mpc/common/pv2k.cc @@ -17,15 +17,15 @@ #include #include +#include "magic_enum.hpp" + #include "libspu/core/ndarray_ref.h" #include "libspu/mpc/common/communicator.h" #include "libspu/mpc/common/prg_state.h" #include "libspu/mpc/kernel.h" #include "libspu/mpc/utils/ring_ops.h" - namespace spu::mpc { namespace { - inline bool isOwner(KernelEvalContext* ctx, const Type& type) { auto* comm = ctx->getState(); return type.as()->owner() == static_cast(comm->getRank()); @@ -944,6 +944,20 @@ class MergeKeysV : public MergeKeysKernel { } // namespace +void Priv2kTy::fromString(std::string_view str) { + auto comma = str.find_first_of(','); + auto field_str = str.substr(0, comma); + auto owner_str = str.substr(comma + 1); + auto field = magic_enum::enum_cast(field_str); + SPU_ENFORCE(field.has_value(), "parse failed from={}", str); + field_ = field.value(); + owner_ = std::stoll(std::string(owner_str)); +} + +std::string Priv2kTy::toString() const { + return fmt::format("{},{}", magic_enum::enum_name(field()), owner_); +} + void regPV2kTypes() { static std::once_flag flag; std::call_once(flag, []() { diff --git a/src/libspu/mpc/common/pv2k.h b/src/libspu/mpc/common/pv2k.h index 334e04a26..e9f28d20b 100644 --- a/src/libspu/mpc/common/pv2k.h +++ b/src/libspu/mpc/common/pv2k.h @@ -51,18 +51,8 @@ class Priv2kTy : public TypeImpl { return SizeOf(GetStorageType(field_)); } - std::string toString() const override { - return fmt::format("{},{}", FieldType_Name(field()), owner_); - } - - void fromString(std::string_view str) override { - auto comma = str.find_first_of(','); - auto field_str = str.substr(0, comma); - auto owner_str = str.substr(comma + 1); - SPU_ENFORCE(FieldType_Parse(std::string(field_str), &field_), - "parse failed from={}", str); - owner_ = std::stoll(std::string(owner_str)); - } + std::string toString() const override; + void fromString(std::string_view str) override; bool equals(TypeObject const* other) const override { auto const* derived_other = dynamic_cast(other); diff --git a/src/libspu/mpc/factory.cc b/src/libspu/mpc/factory.cc index 5ee0c6901..09f40c95d 100644 --- a/src/libspu/mpc/factory.cc +++ b/src/libspu/mpc/factory.cc @@ -32,7 +32,7 @@ namespace spu::mpc { void Factory::RegisterProtocol( SPUContext* ctx, const std::shared_ptr& lctx) { // TODO: support multi-protocols. - switch (ctx->config().protocol()) { + switch (ctx->config().protocol) { case ProtocolKind::REF2K: { return regRef2kProtocol(ctx, lctx); } @@ -49,31 +49,31 @@ void Factory::RegisterProtocol( return regSecurennProtocol(ctx, lctx); } default: { - SPU_THROW("Invalid protocol kind {}", ctx->config().protocol()); + SPU_THROW("Invalid protocol kind {}", ctx->config().protocol); } } } std::unique_ptr Factory::CreateIO(const RuntimeConfig& conf, size_t npc) { - switch (conf.protocol()) { + switch (conf.protocol) { case ProtocolKind::REF2K: { - return makeRef2kIo(conf.field(), npc); + return makeRef2kIo(conf.field, npc); } case ProtocolKind::SEMI2K: { - return semi2k::makeSemi2kIo(conf.field(), npc); + return semi2k::makeSemi2kIo(conf.field, npc); } case ProtocolKind::ABY3: { - return aby3::makeAby3Io(conf.field(), npc); + return aby3::makeAby3Io(conf.field, npc); } case ProtocolKind::CHEETAH: { - return cheetah::makeCheetahIo(conf.field(), npc); + return cheetah::makeCheetahIo(conf.field, npc); } case ProtocolKind::SECURENN: { - return securenn::makeSecurennIo(conf.field(), npc); + return securenn::makeSecurennIo(conf.field, npc); } default: { - SPU_THROW("Invalid protocol kind {}", conf.protocol()); + SPU_THROW("Invalid protocol kind {}", conf.protocol); } } return nullptr; diff --git a/src/libspu/mpc/factory.h b/src/libspu/mpc/factory.h index 110a35d44..0f4722222 100644 --- a/src/libspu/mpc/factory.h +++ b/src/libspu/mpc/factory.h @@ -20,8 +20,7 @@ #include "libspu/core/context.h" #include "libspu/mpc/io_interface.h" - -#include "libspu/spu.pb.h" +#include "libspu/spu.h" namespace spu::mpc { diff --git a/src/libspu/mpc/ref2k/ref2k.cc b/src/libspu/mpc/ref2k/ref2k.cc index 28b8c1762..17c66a92b 100644 --- a/src/libspu/mpc/ref2k/ref2k.cc +++ b/src/libspu/mpc/ref2k/ref2k.cc @@ -477,7 +477,7 @@ void regRef2kProtocol(SPUContext* ctx, ctx->prot()->addState(lctx); // add Z2k state. - ctx->prot()->addState(ctx->config().field()); + ctx->prot()->addState(ctx->config().field); // register public kernels. regPV2kKernels(ctx->prot()); diff --git a/src/libspu/mpc/ref2k/ref2k_test.cc b/src/libspu/mpc/ref2k/ref2k_test.cc index 56410cb68..8f11f2f14 100644 --- a/src/libspu/mpc/ref2k/ref2k_test.cc +++ b/src/libspu/mpc/ref2k/ref2k_test.cc @@ -22,8 +22,8 @@ namespace { RuntimeConfig makeConfig(FieldType field) { RuntimeConfig conf; - conf.set_protocol(ProtocolKind::REF2K); - conf.set_field(field); + conf.protocol = ProtocolKind::REF2K; + conf.field = field; return conf; } @@ -37,7 +37,7 @@ INSTANTIATE_TEST_SUITE_P( makeConfig(FieldType::FM128)), // testing::Values(1, 2, 3, 5)), // [](const testing::TestParamInfo& p) { - return fmt::format("{}x{}", std::get<1>(p.param).field(), + return fmt::format("{}x{}", std::get<1>(p.param).field, std::get<2>(p.param)); }); diff --git a/src/libspu/mpc/securenn/protocol.cc b/src/libspu/mpc/securenn/protocol.cc index 3140b06cb..9cc8a444e 100644 --- a/src/libspu/mpc/securenn/protocol.cc +++ b/src/libspu/mpc/securenn/protocol.cc @@ -36,7 +36,7 @@ void regSecurennProtocol(SPUContext* ctx, ctx->prot()->addState(lctx); // add Z2k state. - ctx->prot()->addState(ctx->config().field()); + ctx->prot()->addState(ctx->config().field); // register public kernels. regPV2kKernels(ctx->prot()); diff --git a/src/libspu/mpc/securenn/protocol_test.cc b/src/libspu/mpc/securenn/protocol_test.cc index 354a976f4..d72c9a95e 100644 --- a/src/libspu/mpc/securenn/protocol_test.cc +++ b/src/libspu/mpc/securenn/protocol_test.cc @@ -24,8 +24,8 @@ namespace { RuntimeConfig makeConfig(FieldType field) { RuntimeConfig conf; - conf.set_protocol(ProtocolKind::SECURENN); - conf.set_field(field); + conf.protocol = ProtocolKind::SECURENN; + conf.field = field; return conf; } @@ -40,7 +40,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(3)), // [](const testing::TestParamInfo& p) { return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), - std::get<1>(p.param).field(), std::get<2>(p.param)); + std::get<1>(p.param).field, std::get<2>(p.param)); }); INSTANTIATE_TEST_SUITE_P( @@ -52,7 +52,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(3)), // [](const testing::TestParamInfo& p) { return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), - std::get<1>(p.param).field(), std::get<2>(p.param)); + std::get<1>(p.param).field, std::get<2>(p.param)); ; }); @@ -65,7 +65,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(3)), // [](const testing::TestParamInfo& p) { return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), - std::get<1>(p.param).field(), std::get<2>(p.param)); + std::get<1>(p.param).field, std::get<2>(p.param)); ; }); @@ -78,7 +78,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(3)), // [](const testing::TestParamInfo& p) { return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), - std::get<1>(p.param).field(), std::get<2>(p.param)); + std::get<1>(p.param).field, std::get<2>(p.param)); ; }); diff --git a/src/libspu/mpc/securenn/type.cc b/src/libspu/mpc/securenn/type.cc index 5b0b2bb69..f9bb37eeb 100644 --- a/src/libspu/mpc/securenn/type.cc +++ b/src/libspu/mpc/securenn/type.cc @@ -16,6 +16,8 @@ #include +#include "magic_enum.hpp" + #include "libspu/mpc/common/pv2k.h" namespace spu::mpc::securenn { @@ -29,4 +31,18 @@ void registerTypes() { }); } +void BShrTy::fromString(std::string_view detail) { + auto comma = detail.find_first_of(','); + auto field_str = detail.substr(0, comma); + auto nbits_str = detail.substr(comma + 1); + auto field = magic_enum::enum_cast(field_str); + SPU_ENFORCE(field.has_value(), "parse failed from={}", detail); + field_ = field.value(); + nbits_ = std::stoul(std::string(nbits_str)); +}; + +std::string BShrTy::toString() const { + return fmt::format("{},{}", magic_enum::enum_name(field()), nbits_); +} + } // namespace spu::mpc::securenn diff --git a/src/libspu/mpc/securenn/type.h b/src/libspu/mpc/securenn/type.h index 302dadee0..6cc497190 100644 --- a/src/libspu/mpc/securenn/type.h +++ b/src/libspu/mpc/securenn/type.h @@ -42,18 +42,8 @@ class BShrTy : public TypeImpl { static std::string_view getStaticId() { return "securenn.BShr"; } - void fromString(std::string_view detail) override { - auto comma = detail.find_first_of(','); - auto field_str = detail.substr(0, comma); - auto nbits_str = detail.substr(comma + 1); - SPU_ENFORCE(FieldType_Parse(std::string(field_str), &field_), - "parse failed from={}", detail); - nbits_ = std::stoul(std::string(nbits_str)); - }; - - std::string toString() const override { - return fmt::format("{},{}", FieldType_Name(field()), nbits_); - } + void fromString(std::string_view detail) override; + std::string toString() const override; bool equals(TypeObject const* other) const override { auto const* derived_other = dynamic_cast(other); diff --git a/src/libspu/mpc/semi2k/arithmetic.cc b/src/libspu/mpc/semi2k/arithmetic.cc index 25c31934a..07b2ed472 100644 --- a/src/libspu/mpc/semi2k/arithmetic.cc +++ b/src/libspu/mpc/semi2k/arithmetic.cc @@ -227,8 +227,8 @@ std::tuple MulOpen( auto x_hit_cache = x_cache.replay_desc.status != Beaver::Init; auto y_hit_cache = y_cache.replay_desc.status != Beaver::Init; - if (ctx->sctx()->config().experimental_disable_vectorization() || - x_hit_cache || y_hit_cache) { + if (ctx->sctx()->config().experimental_disable_vectorization || x_hit_cache || + y_hit_cache) { if (x_hit_cache) { x_a = std::move(x_cache.open_cache); } else { diff --git a/src/libspu/mpc/semi2k/beaver/beaver_interface.h b/src/libspu/mpc/semi2k/beaver/beaver_interface.h index 7a7b7f6a2..f13d60a40 100644 --- a/src/libspu/mpc/semi2k/beaver/beaver_interface.h +++ b/src/libspu/mpc/semi2k/beaver/beaver_interface.h @@ -20,8 +20,7 @@ #include "libspu/core/shape.h" #include "libspu/mpc/common/prg_tensor.h" - -#include "libspu/spu.pb.h" +#include "libspu/spu.h" namespace spu::mpc::semi2k { diff --git a/src/libspu/mpc/semi2k/exp.cc b/src/libspu/mpc/semi2k/exp.cc index 34dba15fa..2b5cf6ed6 100644 --- a/src/libspu/mpc/semi2k/exp.cc +++ b/src/libspu/mpc/semi2k/exp.cc @@ -49,8 +49,7 @@ NdArrayRef ExpA::proc(KernelEvalContext* ctx, const NdArrayRef& in) const { const ring2k_t exp_conv_scale = std::roundf(M_LOG2E * (1L << kExpFxp)); // offset scale should directly encoded to a fixed point with total_fxp - const ring2k_t offset = - ctx->sctx()->config().experimental_exp_prime_offset(); + const ring2k_t offset = ctx->sctx()->config().experimental_exp_prime_offset; const ring2k_t offset_scaled = offset << total_fxp; NdArrayView _x(x); diff --git a/src/libspu/mpc/semi2k/protocol.cc b/src/libspu/mpc/semi2k/protocol.cc index cc7c8f0d6..6f5368500 100644 --- a/src/libspu/mpc/semi2k/protocol.cc +++ b/src/libspu/mpc/semi2k/protocol.cc @@ -40,7 +40,7 @@ void regSemi2kProtocol(SPUContext* ctx, ctx->prot()->addState(lctx); // add Z2k state. - ctx->prot()->addState(ctx->config().field()); + ctx->prot()->addState(ctx->config().field); // register public kernels. regPV2kKernels(ctx->prot()); @@ -72,7 +72,7 @@ void regSemi2kProtocol(SPUContext* ctx, semi2k::EqualAA, semi2k::EqualAP, // semi2k::BeaverCacheKernel>(); - if (ctx->config().trunc_allow_msb_error()) { + if (ctx->config().trunc_allow_msb_error) { ctx->prot()->regKernel(); } else { ctx->prot()->regKernel(); @@ -84,7 +84,7 @@ void regSemi2kProtocol(SPUContext* ctx, // only supports 2pc fm128 for now if (ctx->getField() == FieldType::FM128 && - ctx->config().experimental_enable_exp_prime()) { + ctx->config().experimental_enable_exp_prime) { ctx->prot()->regKernel(); } } diff --git a/src/libspu/mpc/semi2k/protocol_test.cc b/src/libspu/mpc/semi2k/protocol_test.cc index abf75fffc..605975f9d 100644 --- a/src/libspu/mpc/semi2k/protocol_test.cc +++ b/src/libspu/mpc/semi2k/protocol_test.cc @@ -43,14 +43,14 @@ namespace { RuntimeConfig makeConfig(FieldType field) { RuntimeConfig conf; - conf.set_protocol(ProtocolKind::SEMI2K); - conf.set_field(field); + conf.protocol = ProtocolKind::SEMI2K; + conf.field = field; if (field == FieldType::FM64) { - conf.set_fxp_fraction_bits(17); + conf.fxp_fraction_bits = 17; } else if (field == FieldType::FM128) { - conf.set_fxp_fraction_bits(40); + conf.fxp_fraction_bits = 40; } - conf.set_experimental_enable_exp_prime(true); + conf.experimental_enable_exp_prime = true; return conf; } @@ -74,15 +74,13 @@ void InitBeaverServer() { std::unique_ptr makeTTPSemi2kProtocol( const RuntimeConfig& rt, const std::shared_ptr& lctx) { InitBeaverServer(); - RuntimeConfig ttp_rt = rt; - ttp_rt.set_beaver_type(RuntimeConfig_BeaverType_TrustedThirdParty); - auto* ttp = ttp_rt.mutable_ttp_beaver_config(); - ttp->set_adjust_rank(lctx->WorldSize() - 1); - ttp->set_server_host(server_host); - ttp->set_asym_crypto_schema("SM2"); - ttp->set_server_public_key(key_pair.first.data(), - key_pair.first.size()); + std::string server_public_key(key_pair.first.data(), + key_pair.first.size()); + RuntimeConfig ttp_rt = rt; + ttp_rt.beaver_type = RuntimeConfig::TrustedThirdParty; + ttp_rt.ttp_beaver_config = std::make_shared( + server_host, lctx->WorldSize() - 1, "SM2", server_public_key, ""); return makeSemi2kProtocol(ttp_rt, lctx); } @@ -100,7 +98,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(2, 3, 5)), // [](const testing::TestParamInfo& p) { return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), - std::get<1>(p.param).field(), std::get<2>(p.param)); + std::get<1>(p.param).field, std::get<2>(p.param)); }); INSTANTIATE_TEST_SUITE_P( @@ -114,7 +112,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(2, 3, 5)), // [](const testing::TestParamInfo& p) { return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), - std::get<1>(p.param).field(), std::get<2>(p.param)); + std::get<1>(p.param).field, std::get<2>(p.param)); ; }); @@ -129,7 +127,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(2, 3, 5)), // [](const testing::TestParamInfo& p) { return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), - std::get<1>(p.param).field(), std::get<2>(p.param)); + std::get<1>(p.param).field, std::get<2>(p.param)); ; }); @@ -144,7 +142,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(2, 3, 5)), // [](const testing::TestParamInfo& p) { return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), - std::get<1>(p.param).field(), std::get<2>(p.param)); + std::get<1>(p.param).field, std::get<2>(p.param)); ; }); @@ -161,7 +159,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(2, 3, 5)), // [](const testing::TestParamInfo& p) { return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), - std::get<1>(p.param).field(), std::get<2>(p.param)); + std::get<1>(p.param).field, std::get<2>(p.param)); ; }); @@ -431,7 +429,7 @@ TEST_P(BeaverCacheTest, priv_mul_test) { NdArrayRef ring2k_shr[2]; int64_t numel = 1; - FieldType field = conf.field(); + FieldType field = conf.field; std::vector real_vec(numel); for (int64_t i = 0; i < numel; ++i) { @@ -481,7 +479,7 @@ TEST_P(BeaverCacheTest, priv_mul_test) { TEST_P(BeaverCacheTest, exp_mod_test) { const RuntimeConfig& conf = std::get<1>(GetParam()); - FieldType field = conf.field(); + FieldType field = conf.field; DISPATCH_ALL_FIELDS(field, [&]() { // exponents < 32 @@ -506,15 +504,15 @@ TEST_P(BeaverCacheTest, ExpA) { // only supports FM128 for now // note not using ctx->hasKernel("exp_a") because we are testing kernel // registration as well. - if (npc != 2 || conf.field() != FieldType::FM128) { + if (npc != 2 || conf.field != FieldType::FM128) { return; } - auto fxp = conf.fxp_fraction_bits(); + auto fxp = conf.fxp_fraction_bits; NdArrayRef ring2k_shr[2]; int64_t numel = 100; - FieldType field = conf.field(); + FieldType field = conf.field; // how to define and achieve high pricision for e^20 std::uniform_real_distribution dist(-18.0, 15.0); @@ -598,7 +596,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(2)), // npc [](const testing::TestParamInfo& p) { return fmt::format("{}x{}x{}x{}", std::get<0>(p.param).name(), - std::get<1>(p.param).field(), std::get<2>(p.param), + std::get<1>(p.param).field, std::get<2>(p.param), std::get<3>(p.param)); ; }); diff --git a/src/libspu/mpc/semi2k/state.h b/src/libspu/mpc/semi2k/state.h index 2e88c9f69..e50207640 100644 --- a/src/libspu/mpc/semi2k/state.h +++ b/src/libspu/mpc/semi2k/state.h @@ -37,38 +37,33 @@ class Semi2kState : public State { explicit Semi2kState(const RuntimeConfig& conf, const std::shared_ptr& lctx) { - if (conf.beaver_type() == RuntimeConfig_BeaverType_TrustedFirstParty) { + if (conf.beaver_type == RuntimeConfig::TrustedFirstParty) { beaver_ = std::make_unique(lctx); - } else if (conf.beaver_type() == - RuntimeConfig_BeaverType_TrustedThirdParty) { + } else if (conf.beaver_type == RuntimeConfig::TrustedThirdParty) { semi2k::BeaverTtp::Options ops; SPU_ENFORCE(conf.has_ttp_beaver_config()); - ops.server_host = conf.ttp_beaver_config().server_host(); - ops.adjust_rank = conf.ttp_beaver_config().adjust_rank(); - ops.asym_crypto_schema = conf.ttp_beaver_config().asym_crypto_schema(); + ops.server_host = conf.ttp_beaver_config->server_host; + ops.adjust_rank = conf.ttp_beaver_config->adjust_rank; + ops.asym_crypto_schema = conf.ttp_beaver_config->asym_crypto_schema; { - const auto& key = conf.ttp_beaver_config().server_public_key(); + const auto& key = conf.ttp_beaver_config->server_public_key; ops.server_public_key = yacl::Buffer(key.data(), key.size()); } - if (!conf.ttp_beaver_config().transport_protocol().empty()) { - ops.brpc_channel_protocol = - conf.ttp_beaver_config().transport_protocol(); + if (!conf.ttp_beaver_config->transport_protocol.empty()) { + ops.brpc_channel_protocol = conf.ttp_beaver_config->transport_protocol; } - if (conf.ttp_beaver_config().has_ssl_config()) { + if (conf.ttp_beaver_config->has_ssl_config()) { + auto& ssl_config = conf.ttp_beaver_config->ssl_config; brpc::ChannelSSLOptions ssl_options; - ssl_options.verify.ca_file_path = - conf.ttp_beaver_config().ssl_config().ca_file_path(); - ssl_options.verify.verify_depth = - conf.ttp_beaver_config().ssl_config().verify_depth(); - ssl_options.client_cert.certificate = - conf.ttp_beaver_config().ssl_config().certificate(); - ssl_options.client_cert.private_key = - conf.ttp_beaver_config().ssl_config().private_key(); + ssl_options.verify.ca_file_path = ssl_config->ca_file_path; + ssl_options.verify.verify_depth = ssl_config->verify_depth; + ssl_options.client_cert.certificate = ssl_config->certificate; + ssl_options.client_cert.private_key = ssl_config->private_key; ops.brpc_ssl_options = std::move(ssl_options); } beaver_ = std::make_unique(lctx, std::move(ops)); } else { - SPU_THROW("unsupported beaver type {}", conf.beaver_type()); + SPU_THROW("unsupported beaver type {}", conf.beaver_type); } beaver_cache_ = std::make_unique(); } diff --git a/src/libspu/mpc/semi2k/type.cc b/src/libspu/mpc/semi2k/type.cc index 3c076e76e..331cf64a5 100644 --- a/src/libspu/mpc/semi2k/type.cc +++ b/src/libspu/mpc/semi2k/type.cc @@ -16,6 +16,8 @@ #include +#include "magic_enum.hpp" + #include "libspu/mpc/common/pv2k.h" namespace spu::mpc::semi2k { @@ -29,4 +31,18 @@ void registerTypes() { }); } +void BShrTy::fromString(std::string_view detail) { + auto comma = detail.find_first_of(','); + auto field_str = detail.substr(0, comma); + auto nbits_str = detail.substr(comma + 1); + auto field = magic_enum::enum_cast(field_str); + SPU_ENFORCE(field.has_value(), "parse failed from={}", detail); + field_ = field.value(); + nbits_ = std::stoul(std::string(nbits_str)); +}; + +std::string BShrTy::toString() const { + return fmt::format("{},{}", magic_enum::enum_name(field()), nbits_); +} + } // namespace spu::mpc::semi2k diff --git a/src/libspu/mpc/semi2k/type.h b/src/libspu/mpc/semi2k/type.h index 712f2865d..2f2fd7e52 100644 --- a/src/libspu/mpc/semi2k/type.h +++ b/src/libspu/mpc/semi2k/type.h @@ -42,18 +42,8 @@ class BShrTy : public TypeImpl { static std::string_view getStaticId() { return "semi2k.BShr"; } - void fromString(std::string_view detail) override { - auto comma = detail.find_first_of(','); - auto field_str = detail.substr(0, comma); - auto nbits_str = detail.substr(comma + 1); - SPU_ENFORCE(FieldType_Parse(std::string(field_str), &field_), - "parse failed from={}", detail); - nbits_ = std::stoul(std::string(nbits_str)); - }; - - std::string toString() const override { - return fmt::format("{},{}", FieldType_Name(field()), nbits_); - } + void fromString(std::string_view detail) override; + std::string toString() const override; bool equals(TypeObject const* other) const override { auto const* derived_other = dynamic_cast(other); diff --git a/src/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc b/src/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc index 496a7059c..7a1b99d9b 100644 --- a/src/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc +++ b/src/libspu/mpc/spdz2k/abprotocol_spdz2k_test.cc @@ -98,8 +98,8 @@ TEST_P(BooleanTest, NotB) { /* THEN */ EXPECT_VALUE_EQ(r_p, r_pp); - EXPECT_TRUE(verifyCost(obj->prot()->getKernel("not_b"), "not_b", - conf.field(), kShape, npc, cost)); + EXPECT_TRUE(verifyCost(obj->prot()->getKernel("not_b"), "not_b", conf.field, + kShape, npc, cost)); }); } @@ -129,7 +129,7 @@ TEST_P(ConversionTest, AddBB) { /* THEN */ EXPECT_VALUE_EQ(re, rp); - EXPECT_TRUE(verifyCost(obj->getKernel("add_bb"), "add_bb", conf.field(), + EXPECT_TRUE(verifyCost(obj->getKernel("add_bb"), "add_bb", conf.field, kShape, npc, cost)); }); } @@ -160,7 +160,7 @@ TEST_P(ConversionTest, AddBP) { /* THEN */ EXPECT_VALUE_EQ(re, rp); - EXPECT_TRUE(verifyCost(obj->getKernel("add_bp"), "add_bp", conf.field(), + EXPECT_TRUE(verifyCost(obj->getKernel("add_bp"), "add_bp", conf.field, kShape, npc, cost)); }); } @@ -189,8 +189,8 @@ TEST_P(ConversionTest, Bit2A) { auto p1 = a2p(obj.get(), a); /* THEN */ EXPECT_VALUE_EQ(p0, p1); - EXPECT_TRUE(verifyCost(obj->getKernel("bit2a"), "bit2a", conf.field(), - kShape, npc, cost)); + EXPECT_TRUE(verifyCost(obj->getKernel("bit2a"), "bit2a", conf.field, kShape, + npc, cost)); }); } @@ -218,8 +218,8 @@ TEST_P(ConversionTest, A2Bit) { auto p1 = b2p(obj.get(), b); /* THEN */ EXPECT_VALUE_EQ(p0, p1); - EXPECT_TRUE(verifyCost(obj->getKernel("a2bit"), "a2bit", conf.field(), - kShape, npc, cost)); + EXPECT_TRUE(verifyCost(obj->getKernel("a2bit"), "a2bit", conf.field, kShape, + npc, cost)); }); } @@ -267,7 +267,7 @@ TEST_P(ConversionTest, BitLT) { }); /* THEN */ - EXPECT_TRUE(verifyCost(obj->getKernel("bitlt_bb"), "bitlt_bb", conf.field(), + EXPECT_TRUE(verifyCost(obj->getKernel("bitlt_bb"), "bitlt_bb", conf.field, kShape, npc, cost)); }); } @@ -316,7 +316,7 @@ TEST_P(ConversionTest, BitLE) { }); /* THEN */ - EXPECT_TRUE(verifyCost(obj->getKernel("bitle_bb"), "bitle_bb", conf.field(), + EXPECT_TRUE(verifyCost(obj->getKernel("bitle_bb"), "bitle_bb", conf.field, kShape, npc, cost)); }); } @@ -342,8 +342,8 @@ TEST_P(BooleanTest, BitIntl) { auto pp1 = bitintl_b(obj.get(), p0, stride); /* THEN */ EXPECT_VALUE_EQ(p1, pp1); - EXPECT_TRUE(verifyCost(obj->getKernel("bitintl_b"), "bitintl_b", - conf.field(), kShape, npc, cost)); + EXPECT_TRUE(verifyCost(obj->getKernel("bitintl_b"), "bitintl_b", conf.field, + kShape, npc, cost)); }); } @@ -368,8 +368,8 @@ TEST_P(BooleanTest, BitDeintl) { auto pp1 = bitdeintl_b(obj.get(), p0, stride); /* THEN */ EXPECT_VALUE_EQ(p1, pp1); - EXPECT_TRUE(verifyCost(obj->getKernel("bitintl_b"), "bitintl_b", - conf.field(), kShape, npc, cost)); + EXPECT_TRUE(verifyCost(obj->getKernel("bitintl_b"), "bitintl_b", conf.field, + kShape, npc, cost)); }); } @@ -394,8 +394,8 @@ TEST_P(BooleanTest, BitIntlAndDeintl) { auto p1 = b2p(obj.get(), b1); /* THEN */ EXPECT_VALUE_EQ(p0, p1); - EXPECT_TRUE(verifyCost(obj->getKernel("bitintl_b"), "bitintl_b", - conf.field(), kShape, npc, cost)); + EXPECT_TRUE(verifyCost(obj->getKernel("bitintl_b"), "bitintl_b", conf.field, + kShape, npc, cost)); }); } diff --git a/src/libspu/mpc/spdz2k/protocol.cc b/src/libspu/mpc/spdz2k/protocol.cc index ffbefd558..eff9d3e1e 100644 --- a/src/libspu/mpc/spdz2k/protocol.cc +++ b/src/libspu/mpc/spdz2k/protocol.cc @@ -38,7 +38,7 @@ void regSpdz2kProtocol(SPUContext* ctx, ctx->prot()->addState(lctx); // add Z2k state. - ctx->prot()->addState(ctx->config().field()); + ctx->prot()->addState(ctx->config().field); // register public kernels. regPV2kKernels(ctx->prot()); diff --git a/src/libspu/mpc/spdz2k/protocol_ab_test.cc b/src/libspu/mpc/spdz2k/protocol_ab_test.cc index ba2ca4960..cf5d7cc93 100644 --- a/src/libspu/mpc/spdz2k/protocol_ab_test.cc +++ b/src/libspu/mpc/spdz2k/protocol_ab_test.cc @@ -20,15 +20,15 @@ namespace { RuntimeConfig makeConfig(FieldType field) { RuntimeConfig conf; - conf.set_protocol(ProtocolKind::SEMI2K); // FIXME: - conf.set_field(field); + conf.protocol = ProtocolKind::SEMI2K; + conf.field = field; return conf; } std::unique_ptr makeMpcSpdz2kProtocol( const RuntimeConfig& rt, const std::shared_ptr& lctx) { RuntimeConfig mpc_rt = rt; - mpc_rt.set_beaver_type(RuntimeConfig_BeaverType_MultiParty); + mpc_rt.beaver_type = RuntimeConfig::MultiParty; return makeSpdz2kProtocol(mpc_rt, lctx); } @@ -42,7 +42,7 @@ INSTANTIATE_TEST_SUITE_P( makeConfig(FieldType::FM32), 2}), [](const testing::TestParamInfo& p) { return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), - std::get<1>(p.param).field(), std::get<2>(p.param)); + std::get<1>(p.param).field, std::get<2>(p.param)); }); // TODO : improve performance of boolean share and conversion in offline phase @@ -54,7 +54,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(2)), // [](const testing::TestParamInfo& p) { return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), - std::get<1>(p.param).field(), std::get<2>(p.param)); + std::get<1>(p.param).field, std::get<2>(p.param)); }); INSTANTIATE_TEST_SUITE_P( @@ -65,7 +65,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(2)), // [](const testing::TestParamInfo& p) { return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), - std::get<1>(p.param).field(), std::get<2>(p.param)); + std::get<1>(p.param).field, std::get<2>(p.param)); }); } // namespace spu::mpc::test diff --git a/src/libspu/mpc/spdz2k/protocol_api_test.cc b/src/libspu/mpc/spdz2k/protocol_api_test.cc index 3617b2fc4..add73935b 100644 --- a/src/libspu/mpc/spdz2k/protocol_api_test.cc +++ b/src/libspu/mpc/spdz2k/protocol_api_test.cc @@ -20,8 +20,8 @@ namespace { RuntimeConfig makeConfig(FieldType field) { RuntimeConfig conf; - conf.set_protocol(ProtocolKind::SEMI2K); // FIXME: - conf.set_field(field); + conf.protocol = ProtocolKind::SEMI2K; + conf.field = field; return conf; } @@ -35,7 +35,7 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(2)), // [](const testing::TestParamInfo& p) { return fmt::format("{}x{}x{}", std::get<0>(p.param).name(), - std::get<1>(p.param).field(), std::get<2>(p.param)); + std::get<1>(p.param).field, std::get<2>(p.param)); }); } // namespace spu::mpc::test diff --git a/src/libspu/mpc/spdz2k/state.h b/src/libspu/mpc/spdz2k/state.h index c5d938386..768b498e1 100644 --- a/src/libspu/mpc/spdz2k/state.h +++ b/src/libspu/mpc/spdz2k/state.h @@ -84,13 +84,13 @@ class Spdz2kState : public State { explicit Spdz2kState(const RuntimeConfig& conf, std::shared_ptr lctx) - : data_field_(conf.field()) { - if (conf.beaver_type() == RuntimeConfig_BeaverType_TrustedFirstParty) { + : data_field_(conf.field) { + if (conf.beaver_type == RuntimeConfig::TrustedFirstParty) { beaver_ = std::make_unique(lctx); - } else if (conf.beaver_type() == RuntimeConfig_BeaverType_MultiParty) { + } else if (conf.beaver_type == RuntimeConfig::MultiParty) { beaver_ = std::make_unique(lctx); } else { - SPU_THROW("unsupported beaver type {}", conf.beaver_type()); + SPU_THROW("unsupported beaver type {}", conf.beaver_type); } lctx_ = lctx; runtime_field_ = getRuntimeField(data_field_); diff --git a/src/libspu/mpc/spdz2k/type.cc b/src/libspu/mpc/spdz2k/type.cc index d84b53d99..88363a2cc 100644 --- a/src/libspu/mpc/spdz2k/type.cc +++ b/src/libspu/mpc/spdz2k/type.cc @@ -16,6 +16,8 @@ #include +#include "magic_enum.hpp" + #include "libspu/mpc/common/pv2k.h" namespace spu::mpc::spdz2k { @@ -30,4 +32,24 @@ void registerTypes() { }); } +void BShrTy::fromString(std::string_view detail) { + auto comma = detail.find_first_of(','); + auto last_comma = detail.find_last_of(','); + auto back_type_str = detail.substr(0, comma); + auto nbits_str = detail.substr(comma + 1, last_comma); + auto back_type = magic_enum::enum_cast(back_type_str); + SPU_ENFORCE(back_type.has_value(), "parse failed from={}", detail); + back_type_ = back_type.value(); + nbits_ = std::stoul(std::string(nbits_str)); + auto field_str = detail.substr(last_comma + 1); + auto field = magic_enum::enum_cast(field_str); + SPU_ENFORCE(field.has_value(), "parse failed from={}", field_str); + field_ = field.value(); +}; + +std::string BShrTy::toString() const { + return fmt::format("{},{},{}", magic_enum::enum_name(back_type_), nbits_, + field_); +} + } // namespace spu::mpc::spdz2k diff --git a/src/libspu/mpc/spdz2k/type.h b/src/libspu/mpc/spdz2k/type.h index 8b43301d1..019d507d2 100644 --- a/src/libspu/mpc/spdz2k/type.h +++ b/src/libspu/mpc/spdz2k/type.h @@ -61,22 +61,8 @@ class BShrTy : public TypeImpl { static std::string_view getStaticId() { return "spdz2k.BShr"; } - void fromString(std::string_view detail) override { - auto comma = detail.find_first_of(','); - auto last_comma = detail.find_last_of(','); - auto back_type_str = detail.substr(0, comma); - auto nbits_str = detail.substr(comma + 1, last_comma); - SPU_ENFORCE(PtType_Parse(std::string(back_type_str), &back_type_), - "parse failed from={}", back_type_str); - nbits_ = std::stoul(std::string(nbits_str)); - auto field_str = detail.substr(last_comma + 1); - SPU_ENFORCE(FieldType_Parse(std::string(field_str), &field_), - "parse failed from={}", field_str); - }; - - std::string toString() const override { - return fmt::format("{},{},{}", PtType_Name(back_type_), nbits_, field_); - } + void fromString(std::string_view detail) override; + std::string toString() const override; size_t nbits() const { return nbits_; } diff --git a/src/libspu/mpc/tools/benchmark.h b/src/libspu/mpc/tools/benchmark.h index ce64c17c4..78c96c987 100644 --- a/src/libspu/mpc/tools/benchmark.h +++ b/src/libspu/mpc/tools/benchmark.h @@ -62,7 +62,7 @@ void MPCBenchMark(benchmark::State& state) { const size_t npc = BenchConfig::bench_npc; const auto field = static_cast(state.range(0)); RuntimeConfig conf; - conf.set_field(field); + conf.field = field; auto func = [&](std::shared_ptr lctx) { auto obj = BenchConfig::bench_factory(conf, lctx); if (!obj->hasKernel(OpData::op_name)) { diff --git a/src/libspu/mpc/tools/complexity.cc b/src/libspu/mpc/tools/complexity.cc index e6929d157..57c3ab14f 100644 --- a/src/libspu/mpc/tools/complexity.cc +++ b/src/libspu/mpc/tools/complexity.cc @@ -49,8 +49,8 @@ internal::SingleComplexityReport dumpComplexityReport( fmt::print("{:<15}, {:<20}, {:<20}\n", "name", "latency", "comm"); RuntimeConfig rt_conf; - rt_conf.set_protocol(protocol); - rt_conf.set_field(FM64); + rt_conf.protocol = protocol; + rt_conf.field = FM64; utils::simulate( party_cnt, [&](const std::shared_ptr& lctx) -> void { diff --git a/src/libspu/mpc/utils/tiling_util.h b/src/libspu/mpc/utils/tiling_util.h index 8caaf2af9..53b91a5da 100644 --- a/src/libspu/mpc/utils/tiling_util.h +++ b/src/libspu/mpc/utils/tiling_util.h @@ -30,9 +30,9 @@ namespace spu::mpc { template Value tiled(Fn&& fn, SPUContext* ctx, const Value& x, Args&&... args) { const int64_t kBlockSize = kMinTaskSize; - if (!ctx->config().experimental_enable_intra_op_par() // - || !ctx->prot()->hasLowCostFork() // - || x.numel() <= kBlockSize // + if (!ctx->config().experimental_enable_intra_op_par // + || !ctx->prot()->hasLowCostFork() // + || x.numel() <= kBlockSize // ) { return fn(ctx, x, std::forward(args)...); } @@ -145,9 +145,9 @@ Value tiled(Fn&& fn, SPUContext* ctx, const Value& x, const Value& y, SPU_ENFORCE(x.shape() == y.shape()); const int64_t kBlockSize = kMinTaskSize; - if (!ctx->config().experimental_enable_intra_op_par() // - || !ctx->prot()->hasLowCostFork() // - || x.numel() <= kBlockSize // + if (!ctx->config().experimental_enable_intra_op_par // + || !ctx->prot()->hasLowCostFork() // + || x.numel() <= kBlockSize // ) { return fn(ctx, x, y, std::forward(args)...); } diff --git a/src/libspu/spu.cc b/src/libspu/spu.cc new file mode 100644 index 000000000..3b852130a --- /dev/null +++ b/src/libspu/spu.cc @@ -0,0 +1,279 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "libspu/spu.h" + +#include + +#include "google/protobuf/json/json.h" + +#include "libspu/spu.pb.h" +namespace spu { + +void convertFromPB(const pb::RuntimeConfig& src, RuntimeConfig& dst) { + dst.protocol = ProtocolKind(src.protocol()); + dst.field = FieldType(src.field()); + dst.fxp_fraction_bits = src.fxp_fraction_bits(); + dst.max_concurrency = src.max_concurrency(); + dst.enable_action_trace = src.enable_action_trace(); + dst.enable_type_checker = src.enable_type_checker(); + dst.enable_pphlo_trace = src.enable_pphlo_trace(); + dst.enable_runtime_snapshot = src.enable_runtime_snapshot(); + dst.snapshot_dump_dir = src.snapshot_dump_dir(); + dst.enable_pphlo_profile = src.enable_pphlo_profile(); + dst.enable_hal_profile = src.enable_hal_profile(); + dst.public_random_seed = src.public_random_seed(); + dst.share_max_chunk_size = src.share_max_chunk_size(); + dst.sort_method = RuntimeConfig::SortMethod(src.sort_method()); + dst.quick_sort_threshold = src.quick_sort_threshold(); + dst.fxp_div_goldschmidt_iters = src.fxp_div_goldschmidt_iters(); + dst.fxp_exp_mode = RuntimeConfig::ExpMode(src.fxp_exp_mode()); + dst.fxp_exp_iters = src.fxp_exp_iters(); + dst.fxp_log_mode = RuntimeConfig::LogMode(src.fxp_log_mode()); + dst.fxp_log_iters = src.fxp_log_iters(); + dst.fxp_log_orders = src.fxp_log_orders(); + dst.sigmoid_mode = RuntimeConfig::SigmoidMode(src.sigmoid_mode()); + dst.enable_lower_accuracy_rsqrt = src.enable_lower_accuracy_rsqrt(); + dst.sine_cosine_iters = src.sine_cosine_iters(); + dst.beaver_type = RuntimeConfig::BeaverType(src.beaver_type()); + dst.trunc_allow_msb_error = src.trunc_allow_msb_error(); + dst.experimental_disable_mmul_split = src.experimental_disable_mmul_split(); + dst.experimental_enable_inter_op_par = src.experimental_enable_inter_op_par(); + dst.experimental_enable_intra_op_par = src.experimental_enable_intra_op_par(); + dst.experimental_disable_vectorization = + src.experimental_disable_vectorization(); + dst.experimental_inter_op_concurrency = + src.experimental_inter_op_concurrency(); + dst.experimental_enable_colocated_optimization = + src.experimental_enable_colocated_optimization(); + dst.experimental_enable_exp_prime = src.experimental_enable_exp_prime(); + dst.experimental_exp_prime_offset = src.experimental_exp_prime_offset(); + dst.experimental_exp_prime_disable_lower_bound = + src.experimental_exp_prime_disable_lower_bound(); + dst.experimental_exp_prime_enable_upper_bound = + src.experimental_exp_prime_enable_upper_bound(); + + if (src.has_ttp_beaver_config()) { + auto ttp_conf = src.ttp_beaver_config(); + std::unique_ptr ssl_config; + if (ttp_conf.has_ssl_config()) { + ssl_config = std::make_unique( + ttp_conf.ssl_config().certificate(), + ttp_conf.ssl_config().private_key(), + ttp_conf.ssl_config().ca_file_path(), + ttp_conf.ssl_config().verify_depth()); + } + dst.ttp_beaver_config = std::make_unique( + ttp_conf.server_host(), ttp_conf.adjust_rank(), + ttp_conf.asym_crypto_schema(), ttp_conf.server_public_key(), + ttp_conf.transport_protocol(), std::move(ssl_config)); + } + + if (src.has_cheetah_2pc_config()) { + dst.cheetah_2pc_config = + CheetahConfig(src.cheetah_2pc_config().disable_matmul_pack(), + src.cheetah_2pc_config().enable_mul_lsb_error(), + CheetahOtKind(src.cheetah_2pc_config().ot_kind())); + } +} + +void convertToPB(const RuntimeConfig& src, pb::RuntimeConfig& dst) { + dst.set_protocol(pb::ProtocolKind(src.protocol)); + dst.set_field(pb::FieldType(src.field)); + dst.set_fxp_fraction_bits(src.fxp_fraction_bits); + dst.set_max_concurrency(src.max_concurrency); + dst.set_enable_action_trace(src.enable_action_trace); + dst.set_enable_type_checker(src.enable_type_checker); + dst.set_enable_pphlo_trace(src.enable_pphlo_trace); + dst.set_enable_runtime_snapshot(src.enable_runtime_snapshot); + dst.set_snapshot_dump_dir(src.snapshot_dump_dir); + dst.set_enable_pphlo_profile(src.enable_pphlo_profile); + dst.set_enable_hal_profile(src.enable_hal_profile); + dst.set_public_random_seed(src.public_random_seed); + dst.set_share_max_chunk_size(src.share_max_chunk_size); + dst.set_sort_method(pb::RuntimeConfig::SortMethod(src.sort_method)); + dst.set_quick_sort_threshold(src.quick_sort_threshold); + dst.set_fxp_div_goldschmidt_iters(src.fxp_div_goldschmidt_iters); + dst.set_fxp_exp_mode(pb::RuntimeConfig::ExpMode(src.fxp_exp_mode)); + dst.set_fxp_exp_iters(src.fxp_exp_iters); + dst.set_fxp_log_mode(pb::RuntimeConfig::LogMode(src.fxp_log_mode)); + dst.set_fxp_log_iters(src.fxp_log_iters); + dst.set_fxp_log_orders(src.fxp_log_orders); + dst.set_sigmoid_mode(pb::RuntimeConfig::SigmoidMode(src.sigmoid_mode)); + dst.set_enable_lower_accuracy_rsqrt(src.enable_lower_accuracy_rsqrt); + dst.set_sine_cosine_iters(src.sine_cosine_iters); + dst.set_beaver_type(pb::RuntimeConfig::BeaverType(src.beaver_type)); + if (src.ttp_beaver_config) { + auto ttp_conf = dst.mutable_ttp_beaver_config(); + ttp_conf->set_server_host(src.ttp_beaver_config->server_host); + ttp_conf->set_adjust_rank(src.ttp_beaver_config->adjust_rank); + ttp_conf->set_asym_crypto_schema(src.ttp_beaver_config->asym_crypto_schema); + ttp_conf->set_server_public_key(src.ttp_beaver_config->server_public_key); + ttp_conf->set_transport_protocol(src.ttp_beaver_config->transport_protocol); + if (src.ttp_beaver_config->ssl_config) { + auto ssl_config = ttp_conf->mutable_ssl_config(); + ssl_config->set_certificate( + src.ttp_beaver_config->ssl_config->certificate); + ssl_config->set_private_key( + src.ttp_beaver_config->ssl_config->private_key); + ssl_config->set_ca_file_path( + src.ttp_beaver_config->ssl_config->ca_file_path); + ssl_config->set_verify_depth( + src.ttp_beaver_config->ssl_config->verify_depth); + } + } + if (src.protocol == ProtocolKind::CHEETAH) { + auto cheetah_conf = dst.mutable_cheetah_2pc_config(); + cheetah_conf->set_disable_matmul_pack( + src.cheetah_2pc_config.disable_matmul_pack); + cheetah_conf->set_enable_mul_lsb_error( + src.cheetah_2pc_config.enable_mul_lsb_error); + cheetah_conf->set_ot_kind( + pb::CheetahOtKind(src.cheetah_2pc_config.ot_kind)); + } + dst.set_trunc_allow_msb_error(src.trunc_allow_msb_error); + dst.set_experimental_disable_mmul_split(src.experimental_disable_mmul_split); + dst.set_experimental_enable_inter_op_par( + src.experimental_enable_inter_op_par); + dst.set_experimental_enable_intra_op_par( + src.experimental_enable_intra_op_par); + dst.set_experimental_disable_vectorization( + src.experimental_disable_vectorization); + dst.set_experimental_inter_op_concurrency( + src.experimental_inter_op_concurrency); + dst.set_experimental_enable_colocated_optimization( + src.experimental_enable_colocated_optimization); + dst.set_experimental_enable_exp_prime(src.experimental_enable_exp_prime); + dst.set_experimental_exp_prime_offset(src.experimental_exp_prime_offset); + dst.set_experimental_exp_prime_disable_lower_bound( + src.experimental_exp_prime_disable_lower_bound); + dst.set_experimental_exp_prime_enable_upper_bound( + src.experimental_exp_prime_enable_upper_bound); +} + +RuntimeConfig::RuntimeConfig(const spu::pb::RuntimeConfig& pb_conf) { + convertFromPB(pb_conf, *this); +} + +std::string RuntimeConfig::SerializeAsString() const { + pb::RuntimeConfig pb_conf; + convertToPB(*this, pb_conf); + return pb_conf.SerializeAsString(); +} + +std::string RuntimeConfig::DebugString() const { + pb::RuntimeConfig pb_conf; + convertToPB(*this, pb_conf); + return pb_conf.DebugString(); +} + +bool RuntimeConfig::ParseFromJsonString(std::string_view data) { + pb::RuntimeConfig pb_conf; + auto status = google::protobuf::json::JsonStringToMessage(data, &pb_conf); + if (!status.ok()) return false; + convertFromPB(pb_conf, *this); + return true; +} + +bool RuntimeConfig::ParseFromString(std::string_view data) { + pb::RuntimeConfig pb_conf; + if (!pb_conf.ParseFromString(data)) return false; + convertFromPB(pb_conf, *this); + return true; +} + +bool ExecutableProto::ParseFromString(std::string_view data) { + pb::ExecutableProto pb_exec; + if (!pb_exec.ParseFromString(data)) return false; + name = pb_exec.name(); + input_names = {pb_exec.input_names().begin(), pb_exec.input_names().end()}; + output_names = {pb_exec.output_names().begin(), pb_exec.output_names().end()}; + code = pb_exec.code(); + return true; +} + +std::string ExecutableProto::SerializeAsString() const { + pb::ExecutableProto pb_exec; + pb_exec.set_name(name); + for (const auto& in : input_names) { + pb_exec.add_input_names(in); + } + for (const auto& out : output_names) { + pb_exec.add_output_names(out); + } + pb_exec.set_code(code); + return pb_exec.SerializeAsString(); +} + +#if __cplusplus < 202002L +bool CompilationSource::operator==(const CompilationSource& other) const { + return ir_type == other.ir_type && ir_txt == other.ir_txt && + input_visibility == other.input_visibility; +} + +bool CompilerOptions::operator==(const CompilerOptions& other) const { + return enable_pretty_print == other.enable_pretty_print && + pretty_print_dump_dir == other.pretty_print_dump_dir && + xla_pp_kind == other.xla_pp_kind && + disable_sqrt_plus_epsilon_rewrite == + other.disable_sqrt_plus_epsilon_rewrite && + disable_div_sqrt_rewrite == other.disable_div_sqrt_rewrite && + disable_reduce_truncation_optimization == + other.disable_reduce_truncation_optimization && + disable_maxpooling_optimization == + other.disable_maxpooling_optimization && + disallow_mix_types_opts == other.disallow_mix_types_opts && + disable_select_optimization == other.disable_select_optimization && + enable_optimize_denominator_with_broadcast == + other.enable_optimize_denominator_with_broadcast && + disable_deallocation_insertion == + other.disable_deallocation_insertion && + disable_partial_sort_optimization == + other.disable_partial_sort_optimization; +} +#endif +}; // namespace spu + +namespace std { +template +inline void hash_combine(std::size_t& seed, const T& v, const Rest&... rest) { + seed ^= std::hash{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + (hash_combine(seed, rest), ...); +} + +std::size_t hash::operator()( + const spu::CompilationSource& cs) const { + std::size_t seed = 0; + hash_combine(seed, cs.ir_type, cs.ir_txt); + for (const auto& v : cs.input_visibility) { + hash_combine(seed, v); + } + + return seed; +} + +std::size_t hash::operator()( + const spu::CompilerOptions& co) const { + std::size_t seed = 0; + hash_combine( + seed, co.enable_pretty_print, co.pretty_print_dump_dir, co.xla_pp_kind, + co.disable_sqrt_plus_epsilon_rewrite, co.disable_div_sqrt_rewrite, + co.disable_reduce_truncation_optimization, + co.disable_maxpooling_optimization, co.disallow_mix_types_opts, + co.disable_select_optimization, + co.enable_optimize_denominator_with_broadcast, + co.disable_deallocation_insertion, co.disable_partial_sort_optimization); + return seed; +} +}; // namespace std \ No newline at end of file diff --git a/src/libspu/spu.h b/src/libspu/spu.h new file mode 100644 index 000000000..fd729eea3 --- /dev/null +++ b/src/libspu/spu.h @@ -0,0 +1,560 @@ +// Copyright 2024 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include + +namespace spu { + +namespace pb { +class RuntimeConfig; +} // namespace pb + +// The SPU datatype +enum DataType { + DT_INVALID = 0, + + DT_I1 = 1, // 1bit integer (bool). + DT_I8 = 2, // int8 + DT_U8 = 3, // uint8 + DT_I16 = 4, // int16 + DT_U16 = 5, // uint16 + DT_I32 = 6, // int32 + DT_U32 = 7, // uint32 + DT_I64 = 8, // int64 + DT_U64 = 9, // uint64 + DT_F16 = 10, // half + DT_F32 = 11, // float + DT_F64 = 12, // double +}; + +// The visibility type. +// +// SPU is a secure evaluation runtime, but not all data are secret, some of them +// are publicly known to all parties, marking them as public will improve +// performance significantly. +enum Visibility { + VIS_INVALID = 0, + VIS_SECRET = 1, // Invisible(unknown) for all or some of the parties. + VIS_PUBLIC = 2, // Visible(public) for all parties. + VIS_PRIVATE = 3, // Visible for only one party +}; + +// Plaintext type +// +// SPU runtime does not process with plaintext directly, plaintext type is +// mainly used for IO purposes, when converting a plaintext buffer to an SPU +// buffer, we have to let spu know which type the plaintext buffer is. +enum PtType { + PT_INVALID = 0, // + PT_I8 = 1, // int8_t + PT_U8 = 2, // uint8_t + PT_I16 = 3, // int16_t + PT_U16 = 4, // uint16_t + PT_I32 = 5, // int32_t + PT_U32 = 6, // uint32_t + PT_I64 = 7, // int64_t + PT_U64 = 8, // uint64_t + PT_I128 = 9, // int128_t + PT_U128 = 10, // uint128_t + PT_I1 = 11, // bool + + PT_F16 = 30, // half + PT_F32 = 31, // float + PT_F64 = 32, // double + + PT_CF32 = 50, // complex float + PT_CF64 = 51, // complex double +}; + +// A security parameter type. +// +// The secure evaluation is based on some algebraic structure (ring or field), +enum FieldType { + FT_INVALID = 0, + + FM32 = 1, // Ring 2^32 + FM64 = 2, // Ring 2^64 + FM128 = 3, // Ring 2^128 +}; + +// The protocol kind. +enum ProtocolKind { + // Invalid protocol. + PROT_INVALID = 0, + + // The reference implementation in `ring^2k`, note: this 'protocol' only + // behave-like a fixed point secure protocol without any security guarantee. + // Hence, it should only be selected for debugging purposes. + REF2K = 1, + + // A semi-honest multi-party protocol. This protocol requires a trusted third + // party to generate the offline correlated randoms. Currently, SecretFlow by + // default ships this protocol with a trusted first party. Hence, it should + // only be used for debugging purposes. + SEMI2K = 2, + + // A honest majority 3PC-protocol. SecretFlow provides the semi-honest + // implementation without Yao. + ABY3 = 3, + + // The famous [Cheetah](https://eprint.iacr.org/2022/207) protocol, a very + // fast 2PC protocol. + CHEETAH = 4, + + // A semi-honest 3PC-protocol for Neural Network, P2 as the helper, + // (https://eprint.iacr.org/2018/442) + SECURENN = 5, +}; + +////////////////////////////////////////////////////////////////////////// +// Runtime configuration +////////////////////////////////////////////////////////////////////////// +struct ClientSSLConfig { + // Certificate in PEM format, supported both file path and raw string + std::string certificate; + // Private key in PEM format, supported both file path and raw string based on + // prefix + std::string private_key; + // The trusted CA file to verify the peer's certificate + // If empty, use the system default CA files + std::string ca_file_path; + // Maximum depth of the certificate chain for verification + // If 0, turn off the verification + int32_t verify_depth; + + ClientSSLConfig() = default; + ClientSSLConfig(std::string certificate, std::string private_key, + std::string ca_file_path, int32_t verify_depth) + : certificate(std::move(certificate)), + private_key(std::move(private_key)), + ca_file_path(std::move(ca_file_path)), + verify_depth(verify_depth) {} +}; + +struct TTPBeaverConfig { + // TrustedThirdParty beaver server's remote ip:port or load-balance uri. + std::string server_host; + // which rank do adjust rpc call, usually choose the rank closer to the + // server. + int32_t adjust_rank = 0; + + // asym_crypto_schema: support ["SM2"] + // Will support 25519 in the future, after yacl supported it. + std::string asym_crypto_schema; + // Server's public key in PEM format + std::string server_public_key; + + // Transport protocol, support ["http", "h2"] + std::string transport_protocol; + + // Configurations related to SSL + std::shared_ptr ssl_config; + + bool has_ssl_config() const { return ssl_config != nullptr; } + + TTPBeaverConfig() = default; + TTPBeaverConfig(std::string server_host, int32_t adjust_rank, + std::string asym_crypto_schema, std::string server_public_key, + std::string transport_protocol, + std::shared_ptr ssl_config = nullptr) + : server_host(std::move(server_host)), + adjust_rank(adjust_rank), + asym_crypto_schema(std::move(asym_crypto_schema)), + server_public_key(std::move(server_public_key)), + transport_protocol(std::move(transport_protocol)), + ssl_config(std::move(ssl_config)) {} +}; + +enum CheetahOtKind { YACL_Ferret = 0, YACL_Softspoken = 1, EMP_Ferret = 2 }; + +struct CheetahConfig { + // disable the ciphertext packing for matmul + bool disable_matmul_pack; + // allow least significant bits error for point-wise mul + bool enable_mul_lsb_error; + // Setup for cheetah ot + CheetahOtKind ot_kind; + + CheetahConfig() = default; + CheetahConfig(bool disable_matmul_pack, bool enable_mul_lsb_error, + CheetahOtKind ot_kind) + : disable_matmul_pack(disable_matmul_pack), + enable_mul_lsb_error(enable_mul_lsb_error), + ot_kind(ot_kind) {} +}; + +// The SPU runtime configuration. +struct RuntimeConfig { + static const uint64_t kDefaultShareMaxChunkSize = 128 * 1024 * 1024; + static const int64_t kDefaultQuickSortThreshold = 32; + static const int64_t kDefaultFxpDivGoldschmidtIters = 2; + static const int64_t kDefaultFxpExpIters = 8; + static const int64_t kDefaultFxpLogIters = 3; + static const int64_t kDefaultFxpLogOrders = 8; + static const int64_t kDefaultSineCosineIters = 10; + static const uint64_t kDefaultExperimentalInterOpConcurrency = 8; + /////////////////////////////////////// + // Basic + /////////////////////////////////////// + + // The protocol kind. + ProtocolKind protocol = PROT_INVALID; + + // The field type. + FieldType field = FT_INVALID; + + // Number of fraction bits of fixed-point number. + // 0(default) indicates implementation defined. + int64_t fxp_fraction_bits = 0; + + // Max number of cores + int32_t max_concurrency = 0; + + /////////////////////////////////////// + // Advanced + /////////////////////////////////////// + + // @exclude + // Runtime related, reserved for [10, 50) + + // When enabled, runtime prints verbose info of the call stack, debug purpose + // only. + bool enable_action_trace = false; + + // When enabled, runtime checks runtime type infos against the + // compile-time ones, exceptions are raised if mismatches happen. Note: + // Runtime outputs prefer runtime type infos even when flag is on. + bool enable_type_checker = false; + + // When enabled, runtime prints executed pphlo list, debug purpose only. + bool enable_pphlo_trace = false; + + // When enabled, runtime dumps executed executables in the dump_dir, debug + // purpose only. + bool enable_runtime_snapshot = false; + std::string snapshot_dump_dir; + + // When enabled, runtime records detailed pphlo timing data, debug purpose + // only. + // WARNING: the `send bytes` information is only accurate when + // `experimental_enable_inter_op_par` and `experimental_enable_intra_op_par` + // options are disabled. + bool enable_pphlo_profile = false; + + // When enabled, runtime records detailed hal timing data, debug purpose only. + // WARNING: the `send bytes` information is only accurate when + // `experimental_enable_inter_op_par` and `experimental_enable_intra_op_par` + // options are disabled. + bool enable_hal_profile = false; + + // The public random variable generated by the runtime, the concrete prg + // function is implementation defined. + // Note: this seed only applies to `public variable` only, it has nothing + // to do with security. + uint64_t public_random_seed = 0; + + // max chunk size for Value::toProto + // default: 128 * 1024 * 1024 + uint64_t share_max_chunk_size = kDefaultShareMaxChunkSize; + + enum SortMethod { + SORT_DEFAULT = 0, // Implementation defined. + SORT_RADIX = 1, // The radix sort (stable sort, need efficient shuffle). + SORT_QUICK = 2, // The quick sort (unstable, need efficient shuffle). + SORT_NETWORK = 3, // The odd-even sorting network (unstable, most general). + }; + + // SPU supports multiple sorting algorithms. + // -for 2pc, only sorting network is supported. + // -for 2.5pc or 3pc, all these algorithms are supported. + // -for stable sort, only radix sort is supported. + SortMethod sort_method = SORT_DEFAULT; + + // threshold for quick sort, when the size of the array is less than this + // value, use merge sort instead + int64_t quick_sort_threshold = kDefaultQuickSortThreshold; + + // @exclude + // Fixed-point arithmetic related, reserved for [50, 100) + + // The iterations use in f_div with Goldschmidt method. + // 0(default) indicates implementation defined. + int64_t fxp_div_goldschmidt_iters = kDefaultFxpDivGoldschmidtIters; + + // The exponential approximation method. + enum ExpMode { + EXP_DEFAULT = 0, // Implementation defined. + EXP_PADE = 1, // The pade approximation. + EXP_TAYLOR = 2, // Taylor series approximation. + EXP_PRIME = 3, // exp prime only available for some implementations + }; + + // The exponent approximation method. + ExpMode fxp_exp_mode = EXP_DEFAULT; + + // Number of iterations of `exp` approximation, 0(default) indicates impl + // defined. + int64_t fxp_exp_iters = kDefaultFxpExpIters; + + // The logarithm approximation method. + enum LogMode { + LOG_DEFAULT = 0, // Implementation defined. + LOG_PADE = 1, // The pade approximation. + LOG_NEWTON = 2, // The newton approximation. + LOG_MINMAX = 3, // The minmax approximation. + }; + + // The logarithm approximation method. + LogMode fxp_log_mode = LOG_DEFAULT; + + // Number of iterations of `log` approximation, 0(default) indicates + // impl-defined. + int64_t fxp_log_iters = kDefaultFxpLogIters; + + // Number of orders of `log` approximation, 0(default) indicates impl defined. + int64_t fxp_log_orders = kDefaultFxpLogOrders; + + // The sigmoid approximation method. + enum SigmoidMode { + // Implementation defined. + SIGMOID_DEFAULT = 0, + // Minmax approximation one order. + // f(x) = 0.5 + 0.125 * x + SIGMOID_MM1 = 1, + // Piece-wise simulation. + // f(x) = 0.5 + 0.125x if -4 <= x <= 4 + // 1 if x > 4 + // 0 if -4 > x + SIGMOID_SEG3 = 2, + // The real definition, which depends on exp's accuracy. + // f(x) = 1 / (1 + exp(-x)) + SIGMOID_REAL = 3, + }; + + // The sigmoid function approximation model. + SigmoidMode sigmoid_mode = SIGMOID_DEFAULT; + + // Enable a simpler rsqrt approximation + bool enable_lower_accuracy_rsqrt = false; + + // Sine/Cosine approximation iterations + int64_t sine_cosine_iters = kDefaultSineCosineIters; + + /// - MPC protocol related definitions. + + enum BeaverType { + // Assume first party (rank0) as trusted party to generate beaver triple. + // WARNING: It is NOT SAFE and SHOULD NOT BE used in production. + TrustedFirstParty = 0, + // Generate beaver triple through an additional trusted third party. + TrustedThirdParty = 1, + // Generate beaver triple through multi-party. + MultiParty = 2, + }; + // beaver config, works for semi2k and spdz2k for now. + BeaverType beaver_type = TrustedFirstParty; + + // TrustedThirdParty configs. + std::shared_ptr ttp_beaver_config; + + // Cheetah 2PC configs. + CheetahConfig cheetah_2pc_config; + + // For protocol like SecureML, the most significant bit may have error with + // low probability, which lead to huge calculation error. + bool trunc_allow_msb_error = false; + + /// System related configurations start. + + // Experimental: DO NOT USE + bool experimental_disable_mmul_split = false; + // Inter op parallel, aka, DAG level parallel. + bool experimental_enable_inter_op_par = false; + // Intra op parallel, aka, hal/mpc level parallel. + bool experimental_enable_intra_op_par = false; + // Disable kernel level vectorization. + bool experimental_disable_vectorization = false; + // Inter op concurrency. + uint64_t experimental_inter_op_concurrency = + kDefaultExperimentalInterOpConcurrency; + // Enable use of private type + bool experimental_enable_colocated_optimization = false; + + // enable experimental exp prime method + bool experimental_enable_exp_prime = false; + + // The offset parameter for exp prime methods. + // control the valid range of exp prime method. + // valid range is: + // ((47 - offset - 2fxp)/log_2(e), (125 - 2fxp - offset)/log_2(e)) + // clamp to value would be + // lower bound: (48 - offset - 2fxp)/log_2(e) + // higher bound: (124 - 2fxp - offset)/log_2(e) + // default offset is 13, 0 offset is not supported. + uint32_t experimental_exp_prime_offset = 0; + // whether to apply the clamping lower bound + // default to enable it + bool experimental_exp_prime_disable_lower_bound = false; + // whether to apply the clamping upper bound + // default to disable it + bool experimental_exp_prime_enable_upper_bound = false; + + // static RuntimeConfig makeFromJson(const std::string& json_str); + + RuntimeConfig() = default; + RuntimeConfig(ProtocolKind protocol, FieldType field, + int64_t fxp_fraction_bits = 0) + : protocol(protocol), + field(field), + fxp_fraction_bits(fxp_fraction_bits){}; + RuntimeConfig(const RuntimeConfig& other) = default; + explicit RuntimeConfig(const pb::RuntimeConfig& pb_conf); + + bool has_ttp_beaver_config() const { return ttp_beaver_config != nullptr; } + + bool ParseFromJsonString(std::string_view data); + bool ParseFromString(std::string_view data); + std::string SerializeAsString() const; + std::string DebugString() const; +}; + +////////////////////////////////////////////////////////////////////////// +// Compiler relate definition +////////////////////////////////////////////////////////////////////////// +enum SourceIRType { XLA = 0, STABLEHLO = 1 }; + +struct CompilationSource { + // Input IR type + SourceIRType ir_type; + + // IR + std::string ir_txt; + + // Input visibilities + std::vector input_visibility; + + CompilationSource() = default; + CompilationSource(SourceIRType ir_type, std::string ir_txt, + std::vector input_visibility) + : ir_type(ir_type), + ir_txt(std::move(ir_txt)), + input_visibility(std::move(input_visibility)) {} + +#if __cplusplus >= 202002L + bool operator==(const CompilationSource& other) const = default; +#else + bool operator==(const CompilationSource& other) const; +#endif +}; + +enum XLAPrettyPrintKind { TEXT = 0, DOT = 1, HTML = 2 }; + +struct CompilerOptions { + // Pretty print + bool enable_pretty_print = false; + std::string pretty_print_dump_dir; + XLAPrettyPrintKind xla_pp_kind = XLAPrettyPrintKind::TEXT; + + // Disable sqrt(x) + eps to sqrt(x+eps) rewrite + bool disable_sqrt_plus_epsilon_rewrite = false; + + // Disable x/sqrt(y) to x*rsqrt(y) rewrite + bool disable_div_sqrt_rewrite = false; + + // Disable reduce truncation optimization + bool disable_reduce_truncation_optimization = false; + + // Disable maxpooling optimization + bool disable_maxpooling_optimization = false; + + // Disallow mix type operations + bool disallow_mix_types_opts = false; + + // Disable SelectOp optimization + bool disable_select_optimization = false; + + // Enable optimize x/bcast(y) -> x * bcast(1/y) + bool enable_optimize_denominator_with_broadcast = false; + + // Disable deallocation insertion pass + bool disable_deallocation_insertion = false; + + // Disable sort->topk rewrite when only partial sort is required + bool disable_partial_sort_optimization = false; + +#if __cplusplus >= 202002L + bool operator==(const CompilerOptions& other) const = default; +#else + bool operator==(const CompilerOptions& other) const; +#endif +}; + +// The executable format accepted by SPU runtime. +// +// - Inputs should be prepared before running executable. +// - Output is maintained after execution, and can be fetched by output name. +// +// ```python +// rt = spu.Runtime(...) # create an spu runtime. +// rt.set_var('x', ...) # set variable to the runtime. +// exe = spu.ExecutableProto( # prepare the executable. +// name = 'balabala', +// input_names = ['x'], +// output_names = ['y'], +// code = ...) +// rt.run(exe) # run the executable. +// y = rt.get_var('y') # get the executable from spu runtime. +// ``` +struct ExecutableProto { + // The name of the executable. + std::string name; + + // The input names. + std::vector input_names; + + // The output names. + std::vector output_names; + + // The bytecode of the program, with format IR_MLIR_SPU. + std::string code; + + ExecutableProto() = default; + ExecutableProto(std::string name, std::vector input_names, + std::vector output_names, std::string code) + : name(std::move(name)), + input_names(std::move(input_names)), + output_names(std::move(output_names)), + code(std::move(code)) {} + + bool ParseFromString(std::string_view data); + std::string SerializeAsString() const; +}; +}; // namespace spu + +namespace std { +template <> +struct hash { + std::size_t operator()(const spu::CompilationSource& cs) const; +}; +template <> +struct hash { + std::size_t operator()(const spu::CompilerOptions& co) const; +}; +}; // namespace std diff --git a/src/libspu/spu.proto b/src/libspu/spu.proto index 5a4adbe2e..00ddc3744 100644 --- a/src/libspu/spu.proto +++ b/src/libspu/spu.proto @@ -22,7 +22,7 @@ syntax = "proto3"; -package spu; +package spu.pb; option java_package = "org.secretflow.spu"; diff --git a/src/libspu/version.h b/src/libspu/version.h index 75a1d7f9d..e8f0c4ae9 100644 --- a/src/libspu/version.h +++ b/src/libspu/version.h @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#define SPU_VERSION "0.9.4.dev20250123" +#define SPU_VERSION "0.9.4.dev20250209" #include diff --git a/version.bzl b/version.bzl index 462608359..3330ed92a 100644 --- a/version.bzl +++ b/version.bzl @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -SPU_VERSION = "0.9.4.dev20250123" +SPU_VERSION = "0.9.4.dev20250209"