Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
# CHANGELOG

## 1.2.0

- Added support for `-Infinity` values in a tree's `split_conditions`, which
XGBoost's histogram tree method emits for splits that route on missingness
(missing values follow the node's default direction and every present value
goes to the other child). Such models previously failed to parse. Other
non-finite split values (`+Infinity` or `NaN`) and non-finite leaf values now
produce a clear error.

## 1.1.0 (2026-06-17)

- Added support for categorical splits (models trained with
Expand Down
43 changes: 39 additions & 4 deletions gen/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package gen
import (
"fmt"
"go/format"
"math"
"os"
)

Expand Down Expand Up @@ -41,9 +42,12 @@ func generateSource(
return "", err
}

// We run the code through formatting to check for syntax errors. We don't
// return the formatted code since we intend what we generate to already be
// well formatted.
// We run the code through formatting to catch structural/syntax errors. We
// don't return the formatted code since we intend what we generate to already
// be well formatted. Note this only validates syntax, not values: a stray
// non-finite literal like -Inf still parses (as unary minus applied to an
// identifier named Inf) and would pass here, so checkSplitCondition (not
// this) is what guarantees no such literal is ever emitted.
if _, err := format.Source([]byte(code)); err != nil {
return "", fmt.Errorf("error formatting code: %w", err)
}
Expand All @@ -57,6 +61,37 @@ func codegenTree(r *renderer, tree *node, level int) (string, error) {
return r.executeTerminalNode(tree, level)
}

// A -Infinity threshold on a numeric split makes "*data[i] < threshold"
// false for every present value, so the split routes purely on whether the
// feature is missing. Rendering it literally would emit the uncompilable
// "*data[i] < -Inf", so collapse it to the equivalent missingness branch.
if !tree.data.Categorical && math.IsInf(tree.data.SplitCondition, -1) {
return codegenMissingnessSplit(r, tree, level)
}

left, err := codegenTree(r, tree.left, level+1)
if err != nil {
return "", err
}
right, err := codegenTree(r, tree.right, level+1)
if err != nil {
return "", err
}

return r.executeDecisionNode(tree, level, left, right, false)
}

// codegenMissingnessSplit emits code for a numeric node whose -Infinity threshold
// makes it route on missingness alone: missing values go left (if default_left)
// and every present value goes right.
func codegenMissingnessSplit(r *renderer, tree *node, level int) (string, error) {
// With default_left == 0, missing also routes right, so the node reduces to
// its right subtree. The left subtree is unreachable and dropped; that is safe
// because parseTreeInfo has already validated every node before codegen runs.
if tree.data.DefaultLeft == 0 {
return codegenTree(r, tree.right, level)
}

left, err := codegenTree(r, tree.left, level+1)
if err != nil {
return "", err
Expand All @@ -66,7 +101,7 @@ func codegenTree(r *renderer, tree *node, level int) (string, error) {
return "", err
}

return r.executeDecisionNode(tree, level, left, right)
return r.executeDecisionNode(tree, level, left, right, true)
}

// GenerateFile generates a .go file containing a function that implements the XGB model.
Expand Down
94 changes: 94 additions & 0 deletions gen/generate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package gen

import (
"math"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestCodegenMissingnessSplit checks the code codegenTree produces for numeric
// decision nodes whose threshold is -Infinity, i.e. nodes that can only route
// on whether the feature is missing.
func TestCodegenMissingnessSplit(t *testing.T) {
	// terminal returns a childless node carrying v as its output value.
	terminal := func(v float64) *node {
		n := node{data: nodeData{SplitCondition: v}}
		return &n
	}

	// missingnessRoot returns a decision node on feature 0 with a -Infinity
	// threshold and leaves 10 (left) and 20 (right). defaultLeft selects the
	// direction taken by missing values.
	missingnessRoot := func(defaultLeft int) *node {
		root := &node{left: terminal(10), right: terminal(20)}
		root.data = nodeData{
			SplitCondition: math.Inf(-1),
			SplitIndex:     0,
			DefaultLeft:    defaultLeft,
		}
		return root
	}

	// render generates the code for root at level 0 with a fresh renderer,
	// aborting the calling (sub)test if either step fails.
	render := func(t *testing.T, root *node) string {
		t.Helper()

		r, err := newRenderer()
		require.NoError(t, err)

		out, err := codegenTree(r, root, 0)
		require.NoError(t, err)

		return out
	}

	t.Run("default_left routes missing left, present right", func(t *testing.T) {
		code := render(t, missingnessRoot(1))

		// Positions matter here, not mere presence: the nil check must come
		// first, the left leaf (10) must sit between it and the else, and the
		// right leaf (20) must follow the else. A flipped condition or swapped
		// leaves would still contain all four fragments.
		var (
			nilPos   = strings.Index(code, "if data[0] == nil {")
			elsePos  = strings.Index(code, "} else {")
			leftPos  = strings.Index(code, "sum += 10")
			rightPos = strings.Index(code, "sum += 20")
		)
		require.GreaterOrEqual(t, nilPos, 0, "missingness branch must test data[0] == nil")
		require.Greater(t, elsePos, nilPos)
		assert.True(t, nilPos < leftPos && leftPos < elsePos, "missing must add the left leaf (10)")
		assert.Greater(t, rightPos, elsePos, "present must add the right leaf (20)")
		// No trace of the uncompilable infinite literal may remain.
		assert.NotContains(t, code, "Inf")
	})

	t.Run("no default_left collapses to the right subtree", func(t *testing.T) {
		code := render(t, missingnessRoot(0))

		// Missing and present values both go right, so only the right leaf
		// is left of the node.
		assert.Equal(t, "sum += 20", strings.TrimSpace(code))
	})

	t.Run("no default_left collapses into a right decision subtree", func(t *testing.T) {
		// Here the surviving right child is a decision node of its own. It
		// has to be rendered at the level of the node it replaces (not one
		// deeper), and the unreachable left subtree has to vanish. The
		// leaf-only case above carries no indentation and therefore cannot
		// detect a wrong level.
		survivor := &node{
			data:  nodeData{SplitCondition: 1.5, SplitIndex: 1, DefaultLeft: 1},
			left:  terminal(20),
			right: terminal(30),
		}
		root := &node{
			data:  nodeData{SplitCondition: math.Inf(-1), DefaultLeft: 0, SplitIndex: 0},
			left:  terminal(10),
			right: survivor,
		}

		code := render(t, root)

		assert.NotContains(t, code, "sum += 10", "dropped left subtree must not appear")
		// Level 0 means exactly one leading tab in front of the condition.
		atParentLevel := strings.HasPrefix(code, "\tif data[1] == nil || *data[1] < 1.5 {")
		assert.True(
			t,
			atParentLevel,
			"right subtree must render at the parent level; got:\n%s", code,
		)
	})
}
Loading
Loading