diff --git a/contributions.go b/contributions.go index ff86c01..b7010a4 100644 --- a/contributions.go +++ b/contributions.go @@ -280,7 +280,7 @@ func treeShap( isMissing := features[splitIndex] == nil // nil means missing. hotIndex := getNextNode( hasMissing, - node, + &node, nodeIndex, features[splitIndex], isMissing, diff --git a/parse.go b/parse.go index 1135336..626065f 100644 --- a/parse.go +++ b/parse.go @@ -56,25 +56,25 @@ type TreeParam struct { // Tree is one tree in an XGBoost model. It's the representation we process // XGBTree into. type Tree struct { - Nodes []*Node // Index 0 is the root. + Nodes []Node // Index 0 is the root. NumNodes int } // Node is a node in the Tree. type Node struct { - Data NodeData Left *Node Right *Node + Data NodeData } // NodeData is a Node's data. type NodeData struct { - BaseWeight float32 - DefaultLeft bool ID int - SplitCondition float32 SplitIndex int + SplitCondition float32 SumHessian float32 + BaseWeight float32 + DefaultLeft bool } // IsLeaf returns whether the Node is a leaf. @@ -136,11 +136,7 @@ func parseTree( return nil, fmt.Errorf("getting num nodes as int64: %w", err) } - var nodes []*Node - for i := 0; i < int(numNodes); i++ { - nodes = append(nodes, &Node{}) - } - + nodes := make([]Node, numNodes) for i := 0; i < int(numNodes); i++ { nodes[i].Data = NodeData{ BaseWeight: xt.BaseWeights[i], @@ -158,8 +154,8 @@ func parseTree( continue } - nodes[i].Left = nodes[left] - nodes[i].Right = nodes[right] + nodes[i].Left = &nodes[left] + nodes[i].Right = &nodes[right] } return &Tree{ diff --git a/parse_test.go b/parse_test.go new file mode 100644 index 0000000..69f2ae8 --- /dev/null +++ b/parse_test.go @@ -0,0 +1,14 @@ +package xgbshap + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func BenchmarkParseModel(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _, err := parseModel("testdata/small-model/model.json") + require.NoError(b, err) + } +}