Skip to content

Commit

Permalink
repo-sync-2025-02-06T20:02:27+0800 (#988)
Browse files Browse the repository at this point in the history
  • Loading branch information
w-gc authored Feb 10, 2025
1 parent 93c27d4 commit ab4dbad
Show file tree
Hide file tree
Showing 180 changed files with 2,591 additions and 1,266 deletions.
8 changes: 7 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,11 @@
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
},
"mlir.server_path": "bazel-bin/libspu/compiler/tools/spu-lsp"
"mlir.server_path": "bazel-bin/libspu/compiler/tools/spu-lsp",
"files.exclude": {
// "**/bazel-*/**": true,
"external":true,
".cache":true,
"**/__pycache__":true
}
}
10 changes: 2 additions & 8 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

module(
name = "spu",
version = "0.9.4.dev20250123",
version = "0.9.4.dev20250209",
compatibility_level = 1,
)

Expand All @@ -32,13 +32,7 @@ local_path_override(
path = "src",
)

bazel_dep(name = "psi")
git_override(
module_name = "psi",
commit = "8ead92f1bb10329c7e7e56d541fecb3dcd47ee03",
remote = "https://github.com/secretflow/psi.git",
)

bazel_dep(name = "psi", version = "0.6.0.dev250123")
bazel_dep(name = "yacl", version = "20241212.0-871832a")
bazel_dep(name = "grpc", version = "1.66.0.bcr.3")
single_version_override(
Expand Down
76 changes: 14 additions & 62 deletions MODULE.bazel.lock

Large diffs are not rendered by default.

10 changes: 4 additions & 6 deletions examples/cpp/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

#include "libspu/core/config.h"

#include "libspu/spu.pb.h"

llvm::cl::opt<std::string> Parties(
"parties", llvm::cl::init("127.0.0.1:61530,127.0.0.1:61531"),
llvm::cl::desc("server list, format: host1:port1[,host2:port2, ...]"));
Expand Down Expand Up @@ -52,13 +50,13 @@ std::unique_ptr<spu::SPUContext> MakeSPUContext() {
auto lctx = MakeLink(Parties.getValue(), Rank.getValue());

spu::RuntimeConfig config;
config.set_protocol(static_cast<spu::ProtocolKind>(ProtocolKind.getValue()));
config.set_field(static_cast<spu::FieldType>(Field.getValue()));
config.protocol = static_cast<spu::ProtocolKind>(ProtocolKind.getValue());
config.field = static_cast<spu::FieldType>(Field.getValue());

populateRuntimeConfig(config);

config.set_enable_action_trace(EngineTrace.getValue());
config.set_enable_type_checker(EngineTrace.getValue());
config.enable_action_trace = EngineTrace.getValue();
config.enable_type_checker = EngineTrace.getValue();

return std::make_unique<spu::SPUContext>(config, lctx);
}
4 changes: 2 additions & 2 deletions examples/python/ir_dump/ir_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import jax.numpy as jnp
import numpy as np

import spu.spu_pb2 as spu_pb2
import spu.libspu as libspu
import spu.utils.distributed as ppd

logging.basicConfig(level=logging.INFO)
Expand All @@ -48,7 +48,7 @@
dump_path = os.path.join(os.path.expanduser("~"), args.dir)
logging.info(f"Dump path: {dump_path}")
# refer to spu.proto for more detailed configuration
copts = spu_pb2.CompilerOptions()
copts = libspu.CompilerOptions()
copts.enable_pretty_print = True
copts.pretty_print_dump_dir = dump_path
copts.xla_pp_kind = 2
Expand Down
4 changes: 2 additions & 2 deletions examples/python/ml/flax_llama7b/flax_llama7b.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from transformers import LlamaTokenizer

import spu.intrinsic as intrinsic
import spu.spu_pb2 as spu_pb2
import spu.libspu as libspu
import spu.utils.distributed as ppd

parser = argparse.ArgumentParser(description='distributed driver.')
Expand All @@ -46,7 +46,7 @@

ppd.init(conf["nodes"], conf["devices"])

copts = spu_pb2.CompilerOptions()
copts = libspu.CompilerOptions()
copts.enable_pretty_print = False
copts.xla_pp_kind = 2
# enable x / broadcast(y) -> x * broadcast(1/y)
Expand Down
4 changes: 2 additions & 2 deletions examples/python/ml/flax_llama7b_split/flax_llama7b_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from flax.linen.linear import Array
from transformers import LlamaTokenizer

import spu.spu_pb2 as spu_pb2
import spu.libspu as libspu
import spu.utils.distributed as ppd

parser = argparse.ArgumentParser(description='distributed driver.')
Expand All @@ -59,7 +59,7 @@

ppd.init(conf["nodes"], conf["devices"])

copts = spu_pb2.CompilerOptions()
copts = libspu.CompilerOptions()
copts.enable_pretty_print = False
copts.xla_pp_kind = 2
# enable x / broadcast(y) -> x * broadcast(1/y)
Expand Down
4 changes: 2 additions & 2 deletions examples/python/ml/flax_whisper/flax_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor

import spu.utils.distributed as ppd
from spu import spu_pb2
from spu import libspu

parser = argparse.ArgumentParser(description='distributed driver.')
parser.add_argument(
Expand Down Expand Up @@ -68,7 +68,7 @@ def run_on_spu():
inputs_ids = processor(ds[0]["audio"]["array"], return_tensors="np")

# Enable rewrite for better performance
copts = spu_pb2.CompilerOptions()
copts = libspu.CompilerOptions()
copts.enable_optimize_denominator_with_broadcast = True

input_ids = ppd.device("P1")(lambda x: x)(inputs_ids.input_features)
Expand Down
6 changes: 3 additions & 3 deletions experimental/squirrel/objectives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,11 +291,11 @@ namespace {
spu::SPUContext* ctx, const spu::Value& _x, float threshold,
spu::FieldType working_ft) {
namespace sk = spu::kernel;
auto src_field = ctx->config().field();
auto src_field = ctx->config().field;

spu::Value x(CastRing(_x.data(), working_ft), _x.dtype());
// FIXME(lwj): dirty hack
const_cast<spu::RuntimeConfig*>(&ctx->config())->set_field(working_ft);
const_cast<spu::RuntimeConfig*>(&ctx->config())->field = working_ft;
ctx->getState<spu::mpc::Z2kState>()->setField(working_ft);

const auto ONE = sk::hal::_constant(ctx, 1, x.shape());
Expand All @@ -314,7 +314,7 @@ namespace {
auto is_too_large = sk::hal::_xor(
ctx, True, sk::hal::_or(ctx, is_neg, is_inside_range)); // x > t
// FIXME(lwj): dirty hack
const_cast<spu::RuntimeConfig*>(&ctx->config())->set_field(src_field);
const_cast<spu::RuntimeConfig*>(&ctx->config())->field = src_field;
ctx->getState<spu::mpc::Z2kState>()->setField(src_field);

is_neg = spu::Value(CastRing(is_neg.data(), src_field), spu::DT_I1);
Expand Down
36 changes: 17 additions & 19 deletions experimental/squirrel/objectives_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ TEST_P(ObjectivesTest, MaxGain) {

spu::mpc::utils::simulate(2, [&](std::shared_ptr<yacl::link::Context> lctx) {
spu::RuntimeConfig rt_config;
rt_config.set_protocol(ProtocolKind::REF2K);
rt_config.set_field(field);
rt_config.set_fxp_fraction_bits(16);
rt_config.protocol = ProtocolKind::REF2K;
rt_config.field = field;
rt_config.fxp_fraction_bits = 16;

auto _ctx = std::make_unique<spu::SPUContext>(rt_config, lctx);
auto ctx = _ctx.get();
Expand Down Expand Up @@ -170,13 +170,12 @@ TEST_P(ObjectivesTest, Logistic) {

spu::mpc::utils::simulate(2, [&](std::shared_ptr<yacl::link::Context> lctx) {
spu::RuntimeConfig rt_config;
rt_config.set_protocol(ProtocolKind::CHEETAH);
rt_config.mutable_cheetah_2pc_config()->set_ot_kind(
CheetahOtKind::YACL_Softspoken);
rt_config.set_field(field);
rt_config.set_fxp_fraction_bits(17);
rt_config.set_enable_hal_profile(true);
rt_config.set_enable_pphlo_profile(true);
rt_config.protocol = ProtocolKind::CHEETAH;
rt_config.cheetah_2pc_config.ot_kind = CheetahOtKind::YACL_Softspoken;
rt_config.field = field;
rt_config.fxp_fraction_bits = 17;
rt_config.enable_hal_profile = true;
rt_config.enable_pphlo_profile = true;

auto _ctx = std::make_unique<spu::SPUContext>(rt_config, lctx);
auto ctx = _ctx.get();
Expand All @@ -198,7 +197,7 @@ TEST_P(ObjectivesTest, Logistic) {
return;
}

double fxp = std::pow(2., rt_config.fxp_fraction_bits());
double fxp = std::pow(2., rt_config.fxp_fraction_bits);
double max_err = 0.;

for (int64_t i = 0; i < logistic.numel(); ++i) {
Expand Down Expand Up @@ -241,13 +240,12 @@ TEST_P(ObjectivesTest, Sigmoid) {

spu::mpc::utils::simulate(2, [&](std::shared_ptr<yacl::link::Context> lctx) {
spu::RuntimeConfig rt_config;
rt_config.set_protocol(ProtocolKind::CHEETAH);
rt_config.mutable_cheetah_2pc_config()->set_ot_kind(
CheetahOtKind::YACL_Softspoken);
rt_config.set_field(field);
rt_config.set_fxp_fraction_bits(17);
rt_config.set_enable_hal_profile(true);
rt_config.set_enable_pphlo_profile(true);
rt_config.protocol = ProtocolKind::CHEETAH;
rt_config.cheetah_2pc_config.ot_kind = CheetahOtKind::YACL_Softspoken;
rt_config.field = field;
rt_config.fxp_fraction_bits = 17;
rt_config.enable_hal_profile = true;
rt_config.enable_pphlo_profile = true;

auto _ctx = std::make_unique<spu::SPUContext>(rt_config, lctx);
auto ctx = _ctx.get();
Expand All @@ -269,7 +267,7 @@ TEST_P(ObjectivesTest, Sigmoid) {
return;
}

double fxp = std::pow(2., rt_config.fxp_fraction_bits());
double fxp = std::pow(2., rt_config.fxp_fraction_bits);
double max_err = 0.;

for (int64_t i = 0; i < logistic.numel(); ++i) {
Expand Down
16 changes: 8 additions & 8 deletions experimental/squirrel/squirrel_demo_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,11 @@ std::unique_ptr<spu::SPUContext> MakeSPUContext() {
auto lctx = MakeLink(Parties.getValue(), Rank.getValue());

spu::RuntimeConfig config;
config.set_protocol(spu::ProtocolKind::CHEETAH);
config.set_field(static_cast<spu::FieldType>(Field.getValue()));
config.set_fxp_fraction_bits(18);
config.set_fxp_div_goldschmidt_iters(1);
config.set_enable_hal_profile(EngineTrace.getValue());
config.protocol = spu::ProtocolKind::CHEETAH;
config.field = static_cast<spu::FieldType>(Field.getValue());
config.fxp_fraction_bits = 18;
config.fxp_div_goldschmidt_iters = 1;
config.enable_hal_profile = EngineTrace.getValue();
auto hctx = std::make_unique<spu::SPUContext>(config, lctx);
spu::mpc::Factory::RegisterProtocol(hctx.get(), lctx);
return hctx;
Expand Down Expand Up @@ -218,7 +218,7 @@ void RunTest(spu::SPUContext* hctx, squirrel::XGBTreeBuilder& builder,
}

const int64_t nsamples = dframe.shape(0);
const double fxp = std::pow(2., hctx->config().fxp_fraction_bits());
const double fxp = std::pow(2., hctx->config().fxp_fraction_bits);

SPDLOG_DEBUG("Computing inference on testing set ...");

Expand Down Expand Up @@ -293,7 +293,7 @@ int main(int argc, char** argv) {
bucket_size, nfeatures, peer_nfeatures);

worker->BuildMap(dframe);
worker->Setup(8 * spu::SizeOf(hctx->config().field()), hctx->lctx());
worker->Setup(8 * spu::SizeOf(hctx->config().field), hctx->lctx());

std::string act = Activation.getValue();
SPU_ENFORCE(act == "log" or act == "sig", "invalid activation type={}", act);
Expand Down Expand Up @@ -335,7 +335,7 @@ int main(int argc, char** argv) {
}

// Test on train set
double fxp = std::pow(2., hctx->config().fxp_fraction_bits());
double fxp = std::pow(2., hctx->config().fxp_fraction_bits);
int32_t correct = 0;
SPDLOG_DEBUG("Computing inference on training set ...");
for (int64_t i = 0; i < (int64_t)nsamples; ++i) {
Expand Down
4 changes: 2 additions & 2 deletions experimental/squirrel/tree_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ double XGBTreeBuilder::DEBUG_OpenObjects(
Gsum = hal::reveal(ctx, Gsum);
Hsum = hal::reveal(ctx, Hsum);
weights = hal::reveal(ctx, weights);
const double fxp = std::pow(2., ctx->config().fxp_fraction_bits());
const double fxp = std::pow(2., ctx->config().fxp_fraction_bits);
double object = 0.0;
for (int64_t i = 0; i < weights.numel(); ++i) {
double G = Gsum.data().at<int64_t>(i) / fxp;
Expand All @@ -606,7 +606,7 @@ double XGBTreeBuilder::DEBUG_OpenLoss(spu::SPUContext* ctx,
using namespace spu::kernel;
auto _pred = hal::reveal(ctx, pred);
auto _label = hal::reveal(ctx, label);
double fxp = std::pow(2., ctx->config().fxp_fraction_bits());
double fxp = std::pow(2., ctx->config().fxp_fraction_bits);
double loss = 0.;
for (int64_t i = 0; i < _pred.numel(); ++i) {
double y = _label.data().at<int64_t>(i) / fxp;
Expand Down
16 changes: 8 additions & 8 deletions experimental/squirrel/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ spu::Value ArgMax(spu::SPUContext* ctx, const spu::Value& x, int axis,

spu::Value MulArithShareWithPrivateBoolean(spu::SPUContext* ctx,
const spu::Value& ashr) {
SPU_ENFORCE(ctx->config().protocol() == spu::ProtocolKind::CHEETAH);
SPU_ENFORCE(ctx->config().protocol == spu::ProtocolKind::CHEETAH);
SPU_ENFORCE(ashr.isSecret());

spu::KernelEvalContext kctx(ctx);
Expand All @@ -143,7 +143,7 @@ spu::Value MulArithShareWithPrivateBoolean(spu::SPUContext* ctx,
spu::Value MulArithShareWithPrivateBoolean(
spu::SPUContext* ctx, const spu::Value& ashr,
absl::Span<const uint8_t> prv_boolean) {
SPU_ENFORCE(ctx->config().protocol() == spu::ProtocolKind::CHEETAH);
SPU_ENFORCE(ctx->config().protocol == spu::ProtocolKind::CHEETAH);
SPU_ENFORCE(ashr.isSecret());
SPU_ENFORCE_EQ(ashr.numel(), (int64_t)prv_boolean.size());

Expand All @@ -160,9 +160,9 @@ spu::Value MulArithShareWithPrivateBoolean(
spu::Value MulPrivateArithWithPrivateBoolean(spu::SPUContext* ctx,
const spu::Value& arith) {
using namespace spu;
SPU_ENFORCE(ctx->config().protocol() == spu::ProtocolKind::CHEETAH);
SPU_ENFORCE(ctx->config().protocol == spu::ProtocolKind::CHEETAH);
spu::KernelEvalContext kctx(ctx);
auto ft = ctx->config().field();
auto ft = ctx->config().field;
auto out = mpc::cheetah::TiledDispatchOTFunc(
&kctx, arith.data(),
[&](const NdArrayRef& input,
Expand All @@ -184,11 +184,11 @@ spu::Value MulPrivateArithWithPrivateBoolean(spu::SPUContext* ctx,
const spu::DataType dtype,
const spu::Shape& shape) {
using namespace spu;
SPU_ENFORCE(ctx->config().protocol() == spu::ProtocolKind::CHEETAH);
SPU_ENFORCE(ctx->config().protocol == spu::ProtocolKind::CHEETAH);
SPU_ENFORCE_EQ(boolean.size(), (size_t)shape.numel());

spu::KernelEvalContext kctx(ctx);
auto ft = ctx->config().field();
auto ft = ctx->config().field;
auto out = mpc::cheetah::TiledDispatchOTFunc(
&kctx, boolean,
[&](absl::Span<const uint8_t> input,
Expand All @@ -210,10 +210,10 @@ spu::Value MulArithShareWithANDBoolShare(spu::SPUContext* ctx,
SPU_ENFORCE(ashr.isSecret());
SPU_ENFORCE_EQ(ashr.numel(), (int64_t)bshr.size());

SPU_ENFORCE(ctx->config().protocol() == spu::ProtocolKind::CHEETAH);
SPU_ENFORCE(ctx->config().protocol == spu::ProtocolKind::CHEETAH);

spu::KernelEvalContext kctx(ctx);
auto ft = ctx->config().field();
auto ft = ctx->config().field;
int rank = ctx->lctx()->Rank();

auto out = mpc::cheetah::TiledDispatchOTFunc(
Expand Down
Loading

0 comments on commit ab4dbad

Please sign in to comment.