Skip to content

Commit

Permalink
Merge pull request #14 from maxmind/parse-model-perf
Browse files Browse the repository at this point in the history
`parseModel()` performance improvements
  • Loading branch information
horgh authored Aug 1, 2024
2 parents fe0dca2 + 7d5765e commit ed3a1c8
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 13 deletions.
2 changes: 1 addition & 1 deletion contributions.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ func treeShap(
isMissing := features[splitIndex] == nil // nil means missing.
hotIndex := getNextNode(
hasMissing,
node,
&node,
nodeIndex,
features[splitIndex],
isMissing,
Expand Down
20 changes: 8 additions & 12 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,25 +56,25 @@ type TreeParam struct {
// Tree is one tree in an XGBoost model. It's the representation we process
// XGBTree into.
type Tree struct {
Nodes []*Node // Index 0 is the root.
Nodes []Node // Index 0 is the root.
NumNodes int
}

// Node is a node in the Tree.
type Node struct {
Data NodeData
Left *Node
Right *Node
Data NodeData
}

// NodeData is a Node's data.
type NodeData struct {
BaseWeight float32
DefaultLeft bool
ID int
SplitCondition float32
SplitIndex int
SplitCondition float32
SumHessian float32
BaseWeight float32
DefaultLeft bool
}

// IsLeaf returns whether the Node is a leaf.
Expand Down Expand Up @@ -136,11 +136,7 @@ func parseTree(
return nil, fmt.Errorf("getting num nodes as int64: %w", err)
}

var nodes []*Node
for i := 0; i < int(numNodes); i++ {
nodes = append(nodes, &Node{})
}

nodes := make([]Node, numNodes)
for i := 0; i < int(numNodes); i++ {
nodes[i].Data = NodeData{
BaseWeight: xt.BaseWeights[i],
Expand All @@ -158,8 +154,8 @@ func parseTree(
continue
}

nodes[i].Left = nodes[left]
nodes[i].Right = nodes[right]
nodes[i].Left = &nodes[left]
nodes[i].Right = &nodes[right]
}

return &Tree{
Expand Down
14 changes: 14 additions & 0 deletions parse_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package xgbshap

import (
"testing"

"github.com/stretchr/testify/require"
)

func BenchmarkParseModel(b *testing.B) {
for i := 0; i < b.N; i++ {
_, _, err := parseModel("testdata/small-model/model.json")
require.NoError(b, err)
}
}

0 comments on commit ed3a1c8

Please sign in to comment.