Skip to content

Commit

Permalink
[feat] write query result into multi-parties via "select ... into out…
Browse files Browse the repository at this point in the history
…file" statement (#391)
  • Loading branch information
ancongxue authored Nov 15, 2024
1 parent 94a3b93 commit 7d53bf5
Show file tree
Hide file tree
Showing 10 changed files with 4,664 additions and 4,225 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ docs/_build
# go
bin/*
tool-bin/*
/vendor

# cloudide
.cloudide/
Expand All @@ -27,3 +28,8 @@ logs/scdbserver.log
*.pyc

.venv

# 排除 ide 文件
.idea/
# 排除mac本地文件
*.DS_Store
6 changes: 3 additions & 3 deletions pkg/interpreter/graph/graph_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,9 @@ const (
_quotingAllValid int64 = 2
)

func (plan *GraphBuilder) AddDumpFileNode(name string, in []*Tensor, out []*Tensor, intoOpt *ast.SelectIntoOption) error {
func (plan *GraphBuilder) AddDumpFileNodeForParty(name string, in []*Tensor, out []*Tensor, intoOpt *ast.SelectIntoOption, partyFile *ast.PartyFile) error {
fp := &Attribute{}
fp.SetString(intoOpt.FileName)
fp.SetString(partyFile.FileName)
terminator := &Attribute{}
terminator.SetString(intoOpt.LinesInfo.Terminated)
del := &Attribute{}
Expand Down Expand Up @@ -418,7 +418,7 @@ func (plan *GraphBuilder) AddDumpFileNode(name string, in []*Tensor, out []*Tens
operator.FieldDeliminatorAttr: del,
operator.QuotingStyleAttr: qs,
operator.LineTerminatorAttr: terminator,
}, []string{intoOpt.PartyCode})
}, []string{partyFile.PartyCode})
if err != nil {
return fmt.Errorf("AddDumpFileNode: %v", err)
}
Expand Down
28 changes: 22 additions & 6 deletions pkg/interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"crypto/sha256"
"fmt"
"golang.org/x/exp/slices"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -68,14 +69,18 @@ func (intr *Interpreter) Compile(ctx context.Context, req *pb.CompileQueryReques
if err != nil {
return nil, err
}
selectIntoIssuer := false
issuerAsParticipant := req.GetIssuerAsParticipant()
var intoPartyCodes []string
if lp.IntoOpt() != nil {
if lp.IntoOpt().PartyCode != req.GetIssuer().GetCode() {
return nil, fmt.Errorf("expect select into issuer party code %s but got %s", req.GetIssuer().GetCode(), lp.IntoOpt().PartyCode)
if len(lp.IntoOpt().PartyFiles) == 1 && lp.IntoOpt().PartyFiles[0].PartyCode == "" {
lp.IntoOpt().PartyFiles[0].PartyCode = req.GetIssuer().GetCode()
}
selectIntoIssuer = true
for _, partyFile := range lp.IntoOpt().PartyFiles {
intoPartyCodes = append(intoPartyCodes, partyFile.PartyCode)
}
issuerAsParticipant = true
}
enginesInfo, err := buildEngineInfo(lp, req.GetCatalog(), req.GetDbName(), req.GetIssuer().GetCode(), req.GetIssuerAsParticipant() || selectIntoIssuer)
enginesInfo, err := buildEngineInfo(lp, req.GetCatalog(), req.GetDbName(), req.GetIssuer().GetCode(), issuerAsParticipant, intoPartyCodes)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -290,7 +295,7 @@ func collectDataSourceNode(lp core.LogicalPlan) []*core.DataSource {
return nil
}

func buildEngineInfo(lp core.LogicalPlan, catalog *pb.Catalog, currentDb string, queryIssuer string, issuerAsParticipant bool) (*graph.EnginesInfo, error) {
func buildEngineInfo(lp core.LogicalPlan, catalog *pb.Catalog, currentDb string, queryIssuer string, issuerAsParticipant bool, intoPartyCodes []string) (*graph.EnginesInfo, error) {
// construct catalog map
catalogMap := make(map[string]*pb.TableEntry)
for _, table := range catalog.GetTables() {
Expand Down Expand Up @@ -363,10 +368,21 @@ func buildEngineInfo(lp core.LogicalPlan, catalog *pb.Catalog, currentDb string,
}
}

if len(intoPartyCodes) > 0 {
for _, partyCode := range intoPartyCodes {
parties = append(parties, &graph.Participant{
PartyCode: partyCode,
})
}
}

// sort parties by party code for deterministic in p2p
sort.Slice(parties, func(i, j int) bool {
return parties[i].PartyCode < parties[j].PartyCode
})
parties = slices.CompactFunc(parties, func(i, j *graph.Participant) bool {
return i.PartyCode == j.PartyCode
})

partyInfo := graph.NewPartyInfo(parties)

Expand Down
204 changes: 204 additions & 0 deletions pkg/interpreter/interpreter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,108 @@ var testCases = []compileTestCase{
ok: true,
workPartyNum: 2,
},
{
req: &proto.CompileQueryRequest{
Query: "SELECT tb.data as avg_amount FROM tb into outfile '/data/output11.txt' fields terminated BY ','",
DbName: "",
Issuer: &proto.PartyId{
Code: "alice",
},
IssuerAsParticipant: false,
SecurityConf: &proto.SecurityConfig{
ColumnControlList: []*proto.SecurityConfig_ColumnControl{
{
PartyCode: "alice",
Visibility: proto.SecurityConfig_ColumnControl_PLAINTEXT,
DatabaseName: "",
TableName: "tb",
ColumnName: "data",
},
{
PartyCode: "bob",
Visibility: proto.SecurityConfig_ColumnControl_PLAINTEXT,
DatabaseName: "",
TableName: "tb",
ColumnName: "data",
},
},
},
Catalog: &proto.Catalog{
Tables: []*proto.TableEntry{
{
TableName: "tb",
Columns: []*proto.TableEntry_Column{
{
Name: "data",
Type: "string",
},
},
IsView: false,
RefTable: "bob.user_stats",
DbType: "mysql",
Owner: &proto.PartyId{
Code: "bob",
},
},
},
},
CompileOpts: &proto.CompileOptions{SecurityCompromise: &proto.SecurityCompromiseConfig{GroupByThreshold: 4}},
// TODO: add RuntimeConfig
},
ok: true,
workPartyNum: 2,
},
{
req: &proto.CompileQueryRequest{
Query: "SELECT tb.data as avg_amount FROM tb into outfile party_code 'bob' '/data/output11.txt' fields terminated BY ','",
DbName: "",
Issuer: &proto.PartyId{
Code: "alice",
},
IssuerAsParticipant: false,
SecurityConf: &proto.SecurityConfig{
ColumnControlList: []*proto.SecurityConfig_ColumnControl{
{
PartyCode: "alice",
Visibility: proto.SecurityConfig_ColumnControl_PLAINTEXT,
DatabaseName: "",
TableName: "tb",
ColumnName: "data",
},
{
PartyCode: "bob",
Visibility: proto.SecurityConfig_ColumnControl_PLAINTEXT,
DatabaseName: "",
TableName: "tb",
ColumnName: "data",
},
},
},
Catalog: &proto.Catalog{
Tables: []*proto.TableEntry{
{
TableName: "tb",
Columns: []*proto.TableEntry_Column{
{
Name: "data",
Type: "string",
},
},
IsView: false,
RefTable: "bob.user_stats",
DbType: "mysql",
Owner: &proto.PartyId{
Code: "bob",
},
},
},
},
CompileOpts: &proto.CompileOptions{SecurityCompromise: &proto.SecurityCompromiseConfig{GroupByThreshold: 4}},
// TODO: add RuntimeConfig
},
ok: true,
workPartyNum: 2,
},
{
req: &proto.CompileQueryRequest{
Query: "SELECT tb.data as avg_amount FROM tb into outfile party_code 'bob' '/data/output11.txt' fields terminated BY ','",
Expand Down Expand Up @@ -878,6 +980,108 @@ var testCases = []compileTestCase{
ok: true,
workPartyNum: 1,
},
{
req: &proto.CompileQueryRequest{
Query: "SELECT tb.data as avg_amount FROM tb into outfile party_code 'alice' '/data/output11.txt' fields terminated BY ','",
DbName: "",
Issuer: &proto.PartyId{
Code: "bob",
},
IssuerAsParticipant: false,
SecurityConf: &proto.SecurityConfig{
ColumnControlList: []*proto.SecurityConfig_ColumnControl{
{
PartyCode: "alice",
Visibility: proto.SecurityConfig_ColumnControl_PLAINTEXT,
DatabaseName: "",
TableName: "tb",
ColumnName: "data",
},
{
PartyCode: "bob",
Visibility: proto.SecurityConfig_ColumnControl_PLAINTEXT,
DatabaseName: "",
TableName: "tb",
ColumnName: "data",
},
},
},
Catalog: &proto.Catalog{
Tables: []*proto.TableEntry{
{
TableName: "tb",
Columns: []*proto.TableEntry_Column{
{
Name: "data",
Type: "string",
},
},
IsView: false,
RefTable: "bob.user_stats",
DbType: "mysql",
Owner: &proto.PartyId{
Code: "bob",
},
},
},
},
CompileOpts: &proto.CompileOptions{SecurityCompromise: &proto.SecurityCompromiseConfig{GroupByThreshold: 4}},
// TODO: add RuntimeConfig
},
ok: true,
workPartyNum: 2,
},
{
req: &proto.CompileQueryRequest{
Query: "SELECT tb.data as avg_amount FROM tb into outfile party_code 'alice' '/data/output11.txt' party_code 'bob' '/data/output11.txt' fields terminated BY ','",
DbName: "",
Issuer: &proto.PartyId{
Code: "bob",
},
IssuerAsParticipant: false,
SecurityConf: &proto.SecurityConfig{
ColumnControlList: []*proto.SecurityConfig_ColumnControl{
{
PartyCode: "alice",
Visibility: proto.SecurityConfig_ColumnControl_PLAINTEXT,
DatabaseName: "",
TableName: "tb",
ColumnName: "data",
},
{
PartyCode: "bob",
Visibility: proto.SecurityConfig_ColumnControl_PLAINTEXT,
DatabaseName: "",
TableName: "tb",
ColumnName: "data",
},
},
},
Catalog: &proto.Catalog{
Tables: []*proto.TableEntry{
{
TableName: "tb",
Columns: []*proto.TableEntry_Column{
{
Name: "data",
Type: "string",
},
},
IsView: false,
RefTable: "bob.user_stats",
DbType: "mysql",
Owner: &proto.PartyId{
Code: "bob",
},
},
},
},
CompileOpts: &proto.CompileOptions{SecurityCompromise: &proto.SecurityCompromiseConfig{GroupByThreshold: 4}},
// TODO: add RuntimeConfig
},
ok: true,
workPartyNum: 2,
},
{
req: &proto.CompileQueryRequest{
Query: "SELECT ta.credit_rank, COUNT(*) as cnt, AVG(ta.income) as avg_income, AVG(tb.order_amount) as avg_amount FROM ta INNER JOIN tb ON ta.ID = tb.ID WHERE ta.age >= 20 AND ta.age <= 30 AND tb.is_active = 1 GROUP BY ta.credit_rank",
Expand Down
57 changes: 38 additions & 19 deletions pkg/interpreter/translator/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,26 @@ func (t *translator) Translate(lp core.LogicalPlan) (*graph.Graph, error) {
if err != nil {
return nil, err
}
// Check if the result is visible to the issuerPartyCode
for i, col := range ln.Schema().Columns {
cc := ln.CCL()[col.UniqueID]
if !cc.IsVisibleFor(t.issuerPartyCode) {
return nil, status.New(
proto.Code_CCL_CHECK_FAILED,
fmt.Sprintf("ccl check failed: the %dth column %s in the result is not visibile (%s) to party %s", i+1, col.OrigName, cc.LevelFor(t.issuerPartyCode).String(), t.issuerPartyCode))
// Check if the result is visible to the intoOpt PartyCode
cclCheckPartyList := []string{t.issuerPartyCode}
if lp.IntoOpt() != nil {
if len(lp.IntoOpt().PartyFiles) == 1 && lp.IntoOpt().PartyFiles[0].PartyCode == "" {
lp.IntoOpt().PartyFiles[0].PartyCode = t.issuerPartyCode
}
for _, partyFile := range lp.IntoOpt().PartyFiles {
if partyFile.PartyCode != t.issuerPartyCode {
cclCheckPartyList = append(cclCheckPartyList, partyFile.PartyCode)
}
}
}
for _, partyCode := range cclCheckPartyList {
for i, col := range ln.Schema().Columns {
cc := ln.CCL()[col.UniqueID]
if !cc.IsVisibleFor(partyCode) {
return nil, status.New(
proto.Code_CCL_CHECK_FAILED,
fmt.Sprintf("ccl check failed: the %dth column %s in the result is not visibile (%s) to party %s", i+1, col.OrigName, cc.LevelFor(partyCode).String(), partyCode))
}
}
}
// find one of the qualified computation parties to act as the query issuer
Expand Down Expand Up @@ -268,26 +281,32 @@ func (t *translator) addPublishNode(ln logicalNode) error {

func (t *translator) addDumpFileNode(ln logicalNode) error {
intoOpt := ln.IntoOpt()
// issuer party code can see all outputs
if intoOpt.PartyCode == "" {
intoOpt.PartyCode = t.issuerPartyCode
}
// if into party code is not equal to issuer, refuse this query
if intoOpt.PartyCode != t.issuerPartyCode {
return fmt.Errorf("failed to check select into party code (%s) which is not equal to (%s)", intoOpt.PartyCode, t.issuerPartyCode)

if len(intoOpt.PartyFiles) == 1 && intoOpt.PartyFiles[0].PartyCode == "" {
intoOpt.PartyFiles[0].PartyCode = t.issuerPartyCode
}

input, output, err := t.prepareResultNodeIo(ln)
if err != nil {
return fmt.Errorf("addDumpFileNode: prepare io failed: %v", err)
for _, partyFile := range intoOpt.PartyFiles {
input, output, err := t.prepareResultNodeIoForParty(ln, partyFile.PartyCode)
if err != nil {
return fmt.Errorf("addDumpFileNode: prepare io failed: %v", err)
}
err = t.ep.AddDumpFileNodeForParty("dump_file", input, output, intoOpt, partyFile)
if err != nil {
return fmt.Errorf("AddDumpFileNode: %v", err)
}
}
return t.ep.AddDumpFileNode("dump_file", input, output, intoOpt)
return nil
}

func (t *translator) prepareResultNodeIo(ln logicalNode) (input, output []*graph.Tensor, err error) {
return t.prepareResultNodeIoForParty(ln, t.issuerPartyCode)
}

func (t *translator) prepareResultNodeIoForParty(ln logicalNode, partyCode string) (input, output []*graph.Tensor, err error) {
for i, it := range ln.ResultTable() {
// Reveal tensor to into party code
it, err = t.converter.convertTo(it, &privatePlacement{partyCode: t.issuerPartyCode})
it, err = t.converter.convertTo(it, &privatePlacement{partyCode: partyCode})
if err != nil {
return
}
Expand Down
Loading

0 comments on commit 7d53bf5

Please sign in to comment.