Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# CHANGELOG

## 1.1.0

- Added support for categorical splits (models trained with
`enable_categorical`). Categorical features are passed as their integer
category codes in the `data` slice, the same encoding XGBoost uses internally.

## 1.0.0 (2026-06-15)

- Refactored the codebase to allow using the code generation functionality as a
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ The following XGBoost objectives are supported:
- Regression: `reg:logistic`, `reg:squarederror`, `reg:linear`,
`reg:absoluteerror`, `reg:pseudohubererror`, and `reg:quantileerror`.

Both numeric and categorical splits (models trained with `enable_categorical`)
are supported. Categorical features must be passed to the generated function as
their integer category codes, the same encoding XGBoost uses internally; a
missing feature is represented by a `nil` entry in the `data` slice.
Comment thread
coderabbitai[bot] marked this conversation as resolved.

## Supported Languages

Currently `xgb2code` supports generating Go code.
Expand Down
184 changes: 182 additions & 2 deletions gen/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,19 @@ type xgbTree struct {
RightChildren []int `json:"right_children"`
SplitConditions []float64 `json:"split_conditions"`
SplitIndices []int `json:"split_indices"`
TreeParam struct {
// SplitType marks each node's split kind: 0 = numeric, 1 = categorical. It
// is absent in models trained without categorical features, in which case
// every split is numeric.
SplitType []int `json:"split_type"`
// The fields below describe categorical splits. Categories is the flattened
// list of category values across all categorical nodes; for each entry k in
// CategoriesNodes (a node ID), the values that route to that node's right
// child are Categories[CategoriesSegments[k] : CategoriesSegments[k]+CategoriesSizes[k]].
Categories []int `json:"categories"`
CategoriesNodes []int `json:"categories_nodes"`
CategoriesSegments []int `json:"categories_segments"`
CategoriesSizes []int `json:"categories_sizes"`
TreeParam struct {
NumNodes json.Number `json:"num_nodes"`
} `json:"tree_param"`
}
Expand Down Expand Up @@ -161,6 +173,11 @@ type nodeData struct {
ID int64
SplitCondition float64
SplitIndex int
// Categorical reports whether this is a categorical split. When true,
// Categories holds the category values that route to the right child and
// SplitCondition is unused (XGBoost stores a dummy threshold there).
Categorical bool
Categories []int
}

func parseTreeInfo(xt xgbTree) (*node, error) {
Expand All @@ -172,21 +189,48 @@ func parseTreeInfo(xt xgbTree) (*node, error) {
)
}

if err := checkNodeArrays(xt, numNodes); err != nil {
return nil, err
}

categories, err := categorySets(xt, numNodes)
if err != nil {
return nil, err
}

nodes := make([]node, numNodes)
for i := range numNodes {
cats, categorical := categories[i]
nodes[i].data = nodeData{
DefaultLeft: xt.DefaultLeft[i],
ID: i,
SplitCondition: xt.SplitConditions[i],
SplitIndex: xt.SplitIndices[i],
Categorical: categorical,
Categories: cats,
}

left := xt.LeftChildren[i]
right := xt.RightChildren[i]

if left == -1 { // No child
// A leaf has no children (both -1). Any other combination, including a
// half-wired node with exactly one child, is malformed: the codegen
// treats a node with a nil child as a leaf and would silently drop the
// other subtree, so reject it rather than mispredict.
if left == -1 && right == -1 {
continue
}
if left < 0 || int64(left) >= numNodes ||
right < 0 || int64(right) >= numNodes {
return nil, fmt.Errorf(
"node %d has out-of-range children "+
"(left=%d, right=%d, num_nodes=%d)",
i,
left,
right,
numNodes,
)
}

nodes[i].left = &nodes[left]
nodes[i].right = &nodes[right]
Expand All @@ -195,6 +239,142 @@ func parseTreeInfo(xt xgbTree) (*node, error) {
return &nodes[0], nil // Root node
}

// checkNodeArrays verifies that every per-node array has exactly one entry per
// node, so the indexing in parseTreeInfo cannot panic on a truncated or
// inconsistent model. num_nodes comes from a separate JSON field and so is not
// inherently consistent with the arrays it describes.
func checkNodeArrays(xt xgbTree, numNodes int64) error {
arrays := []struct {
name string
n int
}{
{"default_left", len(xt.DefaultLeft)},
{"left_children", len(xt.LeftChildren)},
{"right_children", len(xt.RightChildren)},
{"split_conditions", len(xt.SplitConditions)},
{"split_indices", len(xt.SplitIndices)},
}
for _, a := range arrays {
if int64(a.n) != numNodes {
return fmt.Errorf(
"%s has %d entries but num_nodes is %d",
a.name,
a.n,
numNodes,
)
}
}
return nil
}

// categorySets maps each categorical node's ID to the category values that
// route to its right child, decoding XGBoost's flattened categories/segments/
// sizes representation. It returns an empty map for models trained without
// categorical features. It validates the arrays rather than trusting them: a
// malformed or inconsistent encoding would otherwise cause a categorical node
// to be silently emitted as a numeric split on its dummy threshold, producing
// wrong predictions.
func categorySets(xt xgbTree, numNodes int64) (map[int64][]int, error) {
n := len(xt.CategoriesNodes)
if len(xt.CategoriesSegments) != n || len(xt.CategoriesSizes) != n {
return nil, fmt.Errorf(
"inconsistent categorical arrays: categories_nodes=%d, "+
"categories_segments=%d, categories_sizes=%d",
n,
len(xt.CategoriesSegments),
len(xt.CategoriesSizes),
)
}

sets := make(map[int64][]int, n)
for k := range n {
start := xt.CategoriesSegments[k]
size := xt.CategoriesSizes[k]
if start < 0 || size < 0 || start > len(xt.Categories)-size {
return nil, fmt.Errorf(
"categorical segment [%d:%d+%d] out of range for "+
"categories of length %d",
start,
start,
size,
len(xt.Categories),
)
}
nodeID := int64(xt.CategoriesNodes[k])
if nodeID < 0 || nodeID >= numNodes {
return nil, fmt.Errorf(
"categories_nodes[%d] = %d out of range for num_nodes %d",
k,
nodeID,
numNodes,
)
}
cats := make([]int, size)
copy(cats, xt.Categories[start:start+size])
sets[nodeID] = cats
}

// split_type is the only independent signal of which nodes are categorical,
// so it is what lets us verify that every categorical node was decoded.
// Without it we cannot make that check, and a categorical node missing from
// categories_nodes would be silently emitted as a numeric split on its dummy
// threshold. Real XGBoost models always include split_type when they have
// categorical data, so reject categorical data that lacks it rather than
// risk a wrong prediction.
if len(xt.SplitType) == 0 {
if len(sets) > 0 {
return nil, errors.New(
"model has categorical data (categories_nodes) but no split_type",
)
}
return sets, nil
}

if int64(len(xt.SplitType)) != numNodes {
return nil, fmt.Errorf(
"split_type length %d does not match num_nodes %d",
len(xt.SplitType),
numNodes,
)
}

// Every node that split_type marks as categorical must have a decoded set,
// and vice versa; a mismatch means we would treat a node as the wrong split
// kind. Any split_type other than 0 (numeric) or 1 (categorical) is an
// encoding we do not understand, so reject it rather than defaulting it to
// numeric.
for i := range numNodes {
switch xt.SplitType[i] {
case 0, 1:
default:
return nil, fmt.Errorf(
"node %d has unsupported split_type %d",
i,
xt.SplitType[i],
)
}
_, hasSet := sets[i]
isCategorical := xt.SplitType[i] == 1
if hasSet != isCategorical {
return nil, fmt.Errorf(
"node %d has split_type %d but %s in categories_nodes",
i,
xt.SplitType[i],
presence(hasSet),
)
}
}

return sets, nil
}

func presence(present bool) string {
if present {
return "is present"
}
return "is absent"
}

func readModelMeta(x *xgbModel) (modelMeta, error) {
obj := objective(x.Learner.Objective.Name)

Expand Down
Loading
Loading