-
Notifications
You must be signed in to change notification settings - Fork 1
/
parse.go
165 lines (137 loc) · 3.75 KB
/
parse.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
package xgbshap
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
)
// XXX Some of this code is similar to parse.go in xgb2code
// (https://github.com/maxmind/xgb2code).
// XGBModel corresponds to an XGBoost JSON model.
//
// It is the root value that parseModel unmarshals a model file into.
type XGBModel struct {
// Learner is decoded from the top-level "learner" key.
Learner Learner `json:"learner"`
}
// Learner is the top level part of an XGBoost model.
type Learner struct {
// Attributes is decoded from the "attributes" key.
Attributes Attributes `json:"attributes"`
// GradientBooster is decoded from the "gradient_booster" key. The trees
// that parseModel converts live underneath it.
GradientBooster GradientBooster `json:"gradient_booster"`
}
// Attributes holds attributes from an XGBoost model.
type Attributes struct {
// BestNtreeLimit is decoded from the "best_ntree_limit" key. It is kept as
// a json.Number, so the conversion to a numeric type is left to the caller.
// NOTE(review): json.Number presumably is used because XGBoost may write
// this value as a quoted string — confirm against a real model file.
BestNtreeLimit json.Number `json:"best_ntree_limit"`
}
// GradientBooster holds the XGBoost model.
type GradientBooster struct {
// Model is decoded from the "model" key.
Model Model `json:"model"`
}
// Model is the XGBoost model.
type Model struct {
// Trees is decoded from the "trees" key. parseModel converts every entry
// with parseTree, preserving the order.
Trees []XGBTree `json:"trees"`
}
// XGBTree is one tree in an XGBoost model as decoded from JSON.
//
// The nodes are stored as parallel arrays: parseTree reads element i of every
// slice below to build the node with ID i, for i in [0, TreeParam.NumNodes).
type XGBTree struct {
// BaseWeights becomes NodeData.BaseWeight, which is what Node.LeafValue
// returns.
BaseWeights []float32 `json:"base_weights"`
// DefaultLeft is a 0/1 flag per node; parseTree treats exactly 1 as true.
DefaultLeft []int `json:"default_left"`
// LeftChildren holds the index of each node's left child. A value of -1
// means the node has no children (it is a leaf).
LeftChildren []int `json:"left_children"`
// RightChildren holds the index of each node's right child. parseTree only
// reads it when the matching LeftChildren entry is not -1.
RightChildren []int `json:"right_children"`
// SplitConditions becomes NodeData.SplitCondition.
// NOTE(review): presumably the split threshold — confirm against the
// XGBoost model schema.
SplitConditions []float32 `json:"split_conditions"`
// SplitIndices becomes NodeData.SplitIndex.
// NOTE(review): presumably the index of the feature split on — confirm
// against the XGBoost model schema.
SplitIndices []int `json:"split_indices"`
// SumHessian becomes NodeData.SumHessian.
SumHessian []float32 `json:"sum_hessian"`
// TreeParam carries the node count used to size the parsed Tree.
TreeParam TreeParam `json:"tree_param"`
}
// TreeParam holds tree parameters.
type TreeParam struct {
// NumNodes is decoded from the "num_nodes" key. parseTree converts it with
// Int64 and uses the result as the number of nodes to build.
NumNodes json.Number `json:"num_nodes"`
}
// Tree is one tree in an XGBoost model. It's the representation we process
// XGBTree into.
//
// parseTree builds it so that a node's position in Nodes equals its ID, and
// the Left/Right pointers of each Node point at other elements of Nodes.
type Tree struct {
Nodes []Node // Index 0 is the root.
// NumNodes is the node count taken from the source XGBTree; parseTree
// allocates Nodes with exactly this length.
NumNodes int
}
// Node is a single vertex of a Tree.
//
// An internal node links to both of its children through Left and Right. A
// leaf has a nil Left pointer (see IsLeaf).
type Node struct {
	Left  *Node
	Right *Node
	Data  NodeData
}

// NodeData carries the per-node values copied out of the XGBoost model.
type NodeData struct {
	ID             int
	SplitIndex     int
	SplitCondition float32
	SumHessian     float32
	BaseWeight     float32
	DefaultLeft    bool
}

// IsLeaf reports whether the Node has no children. Only the Left pointer is
// inspected.
//
// This mirrors IsLeaf() in xgboost (tree_model.h).
func (n *Node) IsLeaf() bool {
	return n.Left == nil
}

// LeafValue returns the value stored at this leaf, which is the node's base
// weight.
//
// This mirrors LeafValue() in xgboost (tree_model.h).
func (n *Node) LeafValue() float32 {
	weight := n.Data.BaseWeight
	return weight
}

// MaxDepth returns the height of the subtree rooted at this node: 0 for a
// leaf, otherwise one more than the deeper of the two child subtrees.
//
// This mirrors MaxDepth() in xgboost (tree_model.h).
func (n *Node) MaxDepth() int {
	if !n.IsLeaf() {
		// Left is evaluated before Right, as in a plain sequential walk.
		return 1 + max(n.Left.MaxDepth(), n.Right.MaxDepth())
	}
	return 0
}
// parseModel reads the XGBoost JSON model stored at file and converts its
// trees into the linked Tree representation.
//
// It returns the decoded model unchanged, plus one *Tree per entry of
// learner.gradient_booster.model.trees in the same order. On any failure
// (reading the file, decoding the JSON, or converting a tree) both results are
// nil and the error says which step failed; a tree conversion error also
// includes the index of the offending tree.
func parseModel(
	file string,
) (*XGBModel, []*Tree, error) {
	buf, err := os.ReadFile(filepath.Clean(file))
	if err != nil {
		return nil, nil, fmt.Errorf("reading file: %w", err)
	}

	var xm XGBModel
	if err := json.Unmarshal(buf, &xm); err != nil {
		return nil, nil, fmt.Errorf("unmarshaling: %w", err)
	}

	xts := xm.Learner.GradientBooster.Model.Trees
	trees := make([]*Tree, 0, len(xts))
	// Index rather than range by value so each XGBTree is not copied into a
	// loop variable before being passed on.
	for i := range xts {
		tree, err := parseTree(xts[i])
		if err != nil {
			return nil, nil, fmt.Errorf("parsing tree %d: %w", i, err)
		}
		trees = append(trees, tree)
	}

	return &xm, trees, nil
}
// parseTree converts one decoded XGBTree into a linked Tree.
//
// The XGBTree stores its nodes as parallel arrays indexed by node ID. This
// builds one Node per ID in [0, num_nodes), copies the per-node values into
// NodeData, and points Left/Right at the child elements of the returned
// Tree's Nodes slice. A left child index of -1 marks a leaf, in which case the
// right child index is not consulted.
//
// Malformed input is reported as an error instead of a panic. That covers a
// num_nodes value that is not an integer, is negative, or does not fit in an
// int; a per-node array with fewer than num_nodes entries; and a child index
// outside [0, num_nodes).
func parseTree(
	xt XGBTree,
) (*Tree, error) {
	numNodes, err := xt.TreeParam.NumNodes.Int64()
	if err != nil {
		return nil, fmt.Errorf("getting num nodes as int64: %w", err)
	}

	// The round trip through int also rejects values that would be truncated
	// on platforms where int is 32 bits wide.
	n := int(numNodes)
	if n < 0 || int64(n) != numNodes {
		return nil, fmt.Errorf("invalid num nodes: %d", numNodes)
	}

	// Every per-node array is indexed up to n-1 below, so each must hold at
	// least n entries. Checking first also bounds the allocation by the size
	// of the data that was actually decoded.
	for _, arr := range []struct {
		name string
		size int
	}{
		{"base_weights", len(xt.BaseWeights)},
		{"default_left", len(xt.DefaultLeft)},
		{"left_children", len(xt.LeftChildren)},
		{"right_children", len(xt.RightChildren)},
		{"split_conditions", len(xt.SplitConditions)},
		{"split_indices", len(xt.SplitIndices)},
		{"sum_hessian", len(xt.SumHessian)},
	} {
		if arr.size < n {
			return nil, fmt.Errorf(
				"%s has %d entries but num nodes is %d",
				arr.name, arr.size, n,
			)
		}
	}

	nodes := make([]Node, n)
	for i := range nodes {
		nodes[i].Data = NodeData{
			BaseWeight:     xt.BaseWeights[i],
			DefaultLeft:    xt.DefaultLeft[i] == 1,
			ID:             i,
			SplitCondition: xt.SplitConditions[i],
			SplitIndex:     xt.SplitIndices[i],
			SumHessian:     xt.SumHessian[i],
		}

		left, right := xt.LeftChildren[i], xt.RightChildren[i]
		if left == -1 { // No child
			continue
		}

		if left < 0 || left >= n || right < 0 || right >= n {
			return nil, fmt.Errorf(
				"node %d has out-of-range children (left %d, right %d, num nodes %d)",
				i, left, right, n,
			)
		}

		// Children are addressed inside the same backing slice, so the
		// pointers stay valid for as long as the Tree is alive.
		nodes[i].Left = &nodes[left]
		nodes[i].Right = &nodes[right]
	}

	return &Tree{
		Nodes:    nodes,
		NumNodes: n,
	}, nil
}