Skip to content

Commit

Permalink
Reduce parseModel() allocs by 29.8%
Browse files Browse the repository at this point in the history
  • Loading branch information
marselester committed Jul 29, 2024
1 parent aa464ae commit 7d5765e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
2 changes: 1 addition & 1 deletion contributions.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ func treeShap(
isMissing := features[splitIndex] == nil // nil means missing.
hotIndex := getNextNode(
hasMissing,
node,
&node,
nodeIndex,
features[splitIndex],
isMissing,
Expand Down
12 changes: 4 additions & 8 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ type TreeParam struct {
// Tree is one tree in an XGBoost model. It's the representation we process
// XGBTree into.
type Tree struct {
Nodes []*Node // Index 0 is the root.
Nodes []Node // Index 0 is the root.
NumNodes int
}

Expand Down Expand Up @@ -136,11 +136,7 @@ func parseTree(
return nil, fmt.Errorf("getting num nodes as int64: %w", err)
}

var nodes []*Node
for i := 0; i < int(numNodes); i++ {
nodes = append(nodes, &Node{})
}

nodes := make([]Node, numNodes)
for i := 0; i < int(numNodes); i++ {
nodes[i].Data = NodeData{
BaseWeight: xt.BaseWeights[i],
Expand All @@ -158,8 +154,8 @@ func parseTree(
continue
}

nodes[i].Left = nodes[left]
nodes[i].Right = nodes[right]
nodes[i].Left = &nodes[left]
nodes[i].Right = &nodes[right]
}

return &Tree{
Expand Down
14 changes: 14 additions & 0 deletions parse_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package xgbshap

import (
"testing"

"github.com/stretchr/testify/require"
)

func BenchmarkParseModel(b *testing.B) {
for i := 0; i < b.N; i++ {
_, _, err := parseModel("testdata/small-model/model.json")
require.NoError(b, err)
}
}

0 comments on commit 7d5765e

Please sign in to comment.