From df0db8570e6deaa04935faa083244f198ec2e6cc Mon Sep 17 00:00:00 2001 From: Noel Dawe Date: Thu, 25 Jun 2026 22:45:53 -0400 Subject: [PATCH 1/3] Missingness splits with -Infinity thresholds require special handling --- CHANGELOG.md | 9 ++ gen/generate.go | 43 +++++- gen/generate_test.go | 94 ++++++++++++ gen/parse.go | 199 +++++++++++++++++++++++--- gen/parse_test.go | 132 ++++++++++++++++- gen/templates.go | 6 + gen/templates/decision_node.tmpl | 4 +- gen/testdata/neg-inf-split/model.json | 42 ++++++ gen/testdata/neg-inf-split/preds.csv | 3 + gen/testdata/neg-inf-split/xtest.csv | 3 + main_test.go | 1 + version.go | 2 +- 12 files changed, 513 insertions(+), 25 deletions(-) create mode 100644 gen/generate_test.go create mode 100644 gen/testdata/neg-inf-split/model.json create mode 100644 gen/testdata/neg-inf-split/preds.csv create mode 100644 gen/testdata/neg-inf-split/xtest.csv diff --git a/CHANGELOG.md b/CHANGELOG.md index 24510fd..2e6e695 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,14 @@ # CHANGELOG +## 1.2.0 + +- Added support for `-Infinity` values in a tree's `split_conditions`, which + XGBoost's histogram tree method emits for splits that route on missingness + (missing values follow the node's default direction and every present value + goes to the other child). Such models previously failed to parse. Other + non-finite split values (`+Infinity` or `NaN`) and non-finite leaf values now + produce a clear error. + ## 1.1.0 (2026-06-17) - Added support for categorical splits (models trained with diff --git a/gen/generate.go b/gen/generate.go index f704c45..fb905a9 100644 --- a/gen/generate.go +++ b/gen/generate.go @@ -4,6 +4,7 @@ package gen import ( "fmt" "go/format" + "math" "os" ) @@ -41,9 +42,12 @@ func generateSource( return "", err } - // We run the code through formatting to check for syntax errors. We don't - // return the formatted code since we intend what we generate to already be - // well formatted. + // We run the code through formatting to catch structural/syntax errors. We + // don't return the formatted code since we intend what we generate to already + // be well formatted. Note this only validates syntax, not values: a stray + // non-finite literal like -Inf is a valid Go identifier and would pass here, + // so checkSplitCondition (not this) is what guarantees no such literal is + // ever emitted. if _, err := format.Source([]byte(code)); err != nil { return "", fmt.Errorf("error formatting code: %w", err) } @@ -57,6 +61,37 @@ func codegenTree(r *renderer, tree *node, level int) (string, error) { return r.executeTerminalNode(tree, level) } + // A -Infinity threshold on a numeric split makes "*data[i] < threshold" + // false for every present value, so the split routes purely on whether the + // feature is missing. Rendering it literally would emit the uncompilable + // "*data[i] < -Inf", so collapse it to the equivalent missingness branch. + if !tree.data.Categorical && math.IsInf(tree.data.SplitCondition, -1) { + return codegenMissingnessSplit(r, tree, level) + } + + left, err := codegenTree(r, tree.left, level+1) + if err != nil { + return "", err + } + right, err := codegenTree(r, tree.right, level+1) + if err != nil { + return "", err + } + + return r.executeDecisionNode(tree, level, left, right, false) +} + +// codegenMissingnessSplit emits code for a numeric node whose -Infinity threshold +// makes it route on missingness alone: missing values go left (if default_left) +// and every present value goes right. +func codegenMissingnessSplit(r *renderer, tree *node, level int) (string, error) { + // With default_left == 0, missing also routes right, so the node reduces to + // its right subtree. The left subtree is unreachable and dropped; that is safe + // because parseTreeInfo has already validated every node before codegen runs. + if tree.data.DefaultLeft == 0 { + return codegenTree(r, tree.right, level) + } + left, err := codegenTree(r, tree.left, level+1) if err != nil { return "", err @@ -66,7 +101,7 @@ func codegenTree(r *renderer, tree *node, level int) (string, error) { return "", err } - return r.executeDecisionNode(tree, level, left, right) + return r.executeDecisionNode(tree, level, left, right, true) } // GenerateFile generates a .go file containing a function that implements the XGB model. diff --git a/gen/generate_test.go b/gen/generate_test.go new file mode 100644 index 0000000..1577051 --- /dev/null +++ b/gen/generate_test.go @@ -0,0 +1,94 @@ +package gen + +import ( + "math" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCodegenMissingnessSplit(t *testing.T) { + // leaf builds a terminal node with the given output value. + leaf := func(v float64) *node { + return &node{data: nodeData{SplitCondition: v}} + } + + // negInfRoot is a decision node whose -Infinity threshold makes it route on + // missingness alone. defaultLeft controls where missing values go. + negInfRoot := func(defaultLeft int) *node { + return &node{ + data: nodeData{ + SplitCondition: math.Inf(-1), + SplitIndex: 0, + DefaultLeft: defaultLeft, + }, + left: leaf(10), + right: leaf(20), + } + } + + t.Run("default_left routes missing left, present right", func(t *testing.T) { + r, err := newRenderer() + require.NoError(t, err) + + code, err := codegenTree(r, negInfRoot(1), 0) + require.NoError(t, err) + + // Verify branch placement, not just presence: missing (nil) routes to + // the left leaf (10) and every present value to the right leaf (20). An + // inverted condition or swapped leaves would slip past a presence-only + // check but not these ordering assertions. + nilIdx := strings.Index(code, "if data[0] == nil {") + elseIdx := strings.Index(code, "} else {") + leftIdx := strings.Index(code, "sum += 10") + rightIdx := strings.Index(code, "sum += 20") + require.GreaterOrEqual(t, nilIdx, 0, "missingness branch must test data[0] == nil") + require.Greater(t, elseIdx, nilIdx) + assert.True(t, leftIdx > nilIdx && leftIdx < elseIdx, "missing must add the left leaf (10)") + assert.Greater(t, rightIdx, elseIdx, "present must add the right leaf (20)") + // The uncompilable literal comparison must never be emitted. + assert.NotContains(t, code, "Inf") + }) + + t.Run("no default_left collapses to the right subtree", func(t *testing.T) { + r, err := newRenderer() + require.NoError(t, err) + + code, err := codegenTree(r, negInfRoot(0), 0) + require.NoError(t, err) + + // Present and missing both route right, so the node disappears entirely. + assert.Equal(t, "sum += 20", strings.TrimSpace(code)) + }) + + t.Run("no default_left collapses into a right decision subtree", func(t *testing.T) { + r, err := newRenderer() + require.NoError(t, err) + + // The surviving right child is itself a decision node. The collapse must + // render it at the dropped node's level (not level+1) and omit the + // unreachable left subtree. A leaf-only right child (the case above) + // cannot catch a re-indentation bug because it has no indentation. + root := &node{ + data: nodeData{SplitCondition: math.Inf(-1), DefaultLeft: 0, SplitIndex: 0}, + left: leaf(10), + right: &node{ + data: nodeData{SplitCondition: 1.5, SplitIndex: 1, DefaultLeft: 1}, + left: leaf(20), + right: leaf(30), + }, + } + code, err := codegenTree(r, root, 0) + require.NoError(t, err) + + assert.NotContains(t, code, "sum += 10", "dropped left subtree must not appear") + // Rendered at the parent's level (0) => a single leading tab, not two. + assert.True( + t, + strings.HasPrefix(code, "\tif data[1] == nil || *data[1] < 1.5 {"), + "right subtree must render at the parent level; got:\n%s", code, + ) + }) +} diff --git a/gen/parse.go b/gen/parse.go index cfe0cc7..4f46b35 100644 --- a/gen/parse.go +++ b/gen/parse.go @@ -1,6 +1,7 @@ package gen import ( + "bytes" "encoding/json" "errors" "fmt" @@ -61,11 +62,11 @@ type xgbModel struct { } type xgbTree struct { - DefaultLeft []int `json:"default_left"` - LeftChildren []int `json:"left_children"` - RightChildren []int `json:"right_children"` - SplitConditions []float64 `json:"split_conditions"` - SplitIndices []int `json:"split_indices"` + DefaultLeft []int `json:"default_left"` + LeftChildren []int `json:"left_children"` + RightChildren []int `json:"right_children"` + SplitConditions []xgbFloat `json:"split_conditions"` + SplitIndices []int `json:"split_indices"` // SplitType marks each node's split kind: 0 = numeric, 1 = categorical. It // is absent in models trained without categorical features, in which case // every split is numeric. @@ -83,21 +84,121 @@ type xgbTree struct { } `json:"tree_param"` } +// xgbFloat is a float64 decoded from XGBoost's JSON, where a number may appear +// either as a normal JSON number or as one of the non-finite tokens Infinity, +// -Infinity, and NaN that XGBoost emits but standard JSON forbids. readModel +// rewrites those tokens to quoted strings before decoding (see +// sanitizeNonFiniteNumbers), so this unmarshaler accepts a JSON number or any +// quoted float literal (which includes the rewritten tokens). It is currently +// used only for the split_conditions field, the one place these tokens occur. +type xgbFloat float64 + +func (s *xgbFloat) UnmarshalJSON(b []byte) error { + if len(b) > 0 && b[0] == '"' { + var str string + if err := json.Unmarshal(b, &str); err != nil { + return fmt.Errorf("decoding split_condition string: %w", err) + } + f, err := strconv.ParseFloat(str, 64) + if err != nil { + return fmt.Errorf("invalid split_condition %q: %w", str, err) + } + *s = xgbFloat(f) + return nil + } + var f float64 + if err := json.Unmarshal(b, &f); err != nil { + return fmt.Errorf("decoding split_condition number: %w", err) + } + *s = xgbFloat(f) + return nil +} + func readModel(inputJSON string) (*xgbModel, error) { - fh, err := os.Open(filepath.Clean(inputJSON)) + data, err := os.ReadFile(filepath.Clean(inputJSON)) if err != nil { return nil, fmt.Errorf("error opening file: %w", err) } - defer fh.Close() var x xgbModel - if err := json.NewDecoder(fh).Decode(&x); err != nil { + if err := json.Unmarshal(sanitizeNonFiniteNumbers(data), &x); err != nil { return nil, fmt.Errorf("error decoding JSON: %w", err) } return &x, nil } +// sanitizeNonFiniteNumbers rewrites the JSON-incompatible literals XGBoost emits +// for non-finite floats (Infinity, -Infinity, NaN) into quoted strings, anywhere +// they appear outside a JSON string literal (never inside string contents). In +// well-formed XGBoost output these tokens only ever appear as numeric values. +// encoding/json rejects these tokens at the lexer level, before any custom +// unmarshaler can see them, so they must be rewritten in the raw bytes; +// xgbFloat.UnmarshalJSON then accepts the quoted form. The input is returned +// unchanged when no such token is present. +// +// XGBoost writes these exact tokens from PrintSpecialFloat in +// src/common/charconv.cc, reached when JsonWriter serializes each float of a +// tree's split_conditions array; its own (non-standard) JSON parser reads them +// back, so they are a deliberate, round-trippable encoding rather than +// corruption. They occur only in the text .json format; the binary UBJSON +// (.ubj) format stores raw IEEE bytes instead. +func sanitizeNonFiniteNumbers(data []byte) []byte { + // "Infinity" is a substring of "-Infinity", so this also detects the latter. + if !bytes.Contains(data, []byte("Infinity")) && + !bytes.Contains(data, []byte("NaN")) { + return data + } + + // Checked longest-first so -Infinity is matched before Infinity. + tokens := [][]byte{[]byte("-Infinity"), []byte("Infinity"), []byte("NaN")} + + out := make([]byte, 0, len(data)) + inString := false + for i := 0; i < len(data); { + c := data[i] + if inString { + out = append(out, c) + // Skip the escaped character so an escaped quote does not end the + // string prematurely. + if c == '\\' && i+1 < len(data) { + out = append(out, data[i+1]) + i += 2 + continue + } + if c == '"' { + inString = false + } + i++ + continue + } + if c == '"' { + inString = true + out = append(out, c) + i++ + continue + } + matched := false + for _, tok := range tokens { + if !bytes.HasPrefix(data[i:], tok) { + continue + } + out = append(out, '"') + out = append(out, tok...) + out = append(out, '"') + i += len(tok) + matched = true + break + } + if matched { + continue + } + out = append(out, c) + i++ + } + return out +} + func readTrees(x *xgbModel) ([]*node, error) { var trees []*node for i := range x.Learner.GradientBooster.Model.Trees { @@ -169,8 +270,12 @@ type node struct { } type nodeData struct { - DefaultLeft int - ID int64 + DefaultLeft int + ID int64 + // SplitCondition is kept as float64 through parsing but is narrowed back to + // float32 when rendered into the generated comparison. That is lossless + // because XGBoost stores split conditions as float32 in the first place; the + // float64 here is only the widening from the JSON round-trip. SplitCondition float64 SplitIndex int // Categorical reports whether this is a categorical split. When true, @@ -201,23 +306,42 @@ func parseTreeInfo(xt xgbTree) (*node, error) { nodes := make([]node, numNodes) for i := range numNodes { cats, categorical := categories[i] + sc := float64(xt.SplitConditions[i]) + + left := xt.LeftChildren[i] + right := xt.RightChildren[i] + isLeaf := left == -1 && right == -1 + + if err := checkSplitCondition(i, sc, categorical, isLeaf); err != nil { + return nil, err + } + + // default_left is a boolean encoded as 0 or 1. The decision-node template + // and the -Infinity collapse both treat any non-zero as "missing goes + // left", so a corrupt value would be silently normalized; reject it. + defaultLeft := xt.DefaultLeft[i] + if defaultLeft != 0 && defaultLeft != 1 { + return nil, fmt.Errorf( + "node %d has invalid default_left %d (want 0 or 1)", + i, + defaultLeft, + ) + } + nodes[i].data = nodeData{ - DefaultLeft: xt.DefaultLeft[i], + DefaultLeft: defaultLeft, ID: i, - SplitCondition: xt.SplitConditions[i], + SplitCondition: sc, SplitIndex: xt.SplitIndices[i], Categorical: categorical, Categories: cats, } - left := xt.LeftChildren[i] - right := xt.RightChildren[i] - // A leaf has no children (both -1). Any other combination, including a // half-wired node with exactly one child, is malformed: the codegen // treats a node with a nil child as a leaf and would silently drop the // other subtree, so reject it rather than mispredict. - if left == -1 && right == -1 { + if isLeaf { continue } if left < 0 || int64(left) >= numNodes || @@ -239,6 +363,49 @@ func parseTreeInfo(xt xgbTree) (*node, error) { return &nodes[0], nil // Root node } +// checkSplitCondition validates a node's split_conditions value. The value is a +// numeric threshold for a numeric decision node, a leaf's output value for a +// leaf, and a dummy (ignored) value for a categorical node. Only finite values +// are supported, with one exception: a -Infinity threshold on a numeric decision +// node. Such a threshold makes "value < threshold" false for every present +// value, so the node routes all present values to its right child and missing +// values per default_left. That is a "missingness split", which codegenTree +// collapses to a clean branch. +Infinity and NaN are rejected because they +// cannot be rendered as Go float literals and, unlike -Infinity, have no +// unambiguous collapsed form (a present +Infinity feature value, for instance, +// would still route differently). +// +// The -Infinity comes from XGBoost's histogram-based split finder (the hist +// tree method, including when fed a QuantileDMatrix): a split that isolates +// missing values lands at a feature's minimum, and the lower bound of the +// lowest histogram bin is -inf (NumericBinLowerBound in XGBoost's +// src/common/hist_util.h). +func checkSplitCondition(id int64, sc float64, categorical, isLeaf bool) error { + if categorical { + // The value is a dummy for categorical nodes and is never used. + return nil + } + // A -Infinity threshold on a decision node is the supported missingness + // split; everything else non-finite is rejected. + if !isLeaf && math.IsInf(sc, -1) { + return nil + } + if math.IsInf(sc, 0) || math.IsNaN(sc) { + kind := "split threshold" + if isLeaf { + kind = "leaf value" + } + return fmt.Errorf( + "node %d has a non-finite %s (%v); only finite values are "+ + "supported, except a -Infinity threshold on a decision node", + id, + kind, + sc, + ) + } + return nil +} + // checkNodeArrays verifies that every per-node array has exactly one entry per // node, so the indexing in parseTreeInfo cannot panic on a truncated or // inconsistent model. num_nodes comes from a separate JSON field and so is not diff --git a/gen/parse_test.go b/gen/parse_test.go index e6e13ed..4f13b81 100644 --- a/gen/parse_test.go +++ b/gen/parse_test.go @@ -137,7 +137,7 @@ func TestParseTreeInfoCategorical(t *testing.T) { DefaultLeft: []int{1, 0, 0}, LeftChildren: []int{1, -1, -1}, RightChildren: []int{2, -1, -1}, - SplitConditions: []float64{1e-45, 0.5, -0.5}, + SplitConditions: []xgbFloat{1e-45, 0.5, -0.5}, SplitIndices: []int{1, 0, 0}, SplitType: []int{1, 0, 0}, Categories: []int{0, 2}, @@ -299,7 +299,7 @@ func TestParseTreeInfoValidation(t *testing.T) { DefaultLeft: []int{0, 0, 0}, LeftChildren: []int{1, -1, -1}, RightChildren: []int{2, -1, -1}, - SplitConditions: []float64{0.5, 0, 0}, + SplitConditions: []xgbFloat{0.5, 0, 0}, SplitIndices: []int{0, 0, 0}, } xt.TreeParam.NumNodes = "3" @@ -345,6 +345,12 @@ func TestParseTreeInfoValidation(t *testing.T) { xt.RightChildren = []int{-1, -1, -1} }, }, + { + name: "default_left outside {0, 1}", + mutate: func(xt *xgbTree) { + xt.DefaultLeft = []int{2, 0, 0} + }, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -356,6 +362,126 @@ func TestParseTreeInfoValidation(t *testing.T) { } } +func TestSanitizeNonFiniteNumbers(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + { + name: "no tokens is unchanged", + in: `{"split_conditions":[1.5,-2.0]}`, + want: `{"split_conditions":[1.5,-2.0]}`, + }, + { + name: "quotes the non-finite tokens", + in: `[1.5, -Infinity, Infinity, NaN, 2.0]`, + want: `[1.5, "-Infinity", "Infinity", "NaN", 2.0]`, + }, + { + // -Infinity must be matched as a whole so the minus sign is not left + // behind as a separate token. + name: "negative infinity keeps its sign inside the quotes", + in: `[-Infinity]`, + want: `["-Infinity"]`, + }, + { + // A feature literally named with a token must not be rewritten; it is + // inside a JSON string, not a numeric value. + name: "tokens inside strings are left alone", + in: `{"feature_names":["NaN","Infinity"],"split_conditions":[NaN]}`, + want: `{"feature_names":["NaN","Infinity"],"split_conditions":["NaN"]}`, + }, + { + name: "escaped quote does not end the string early", + in: `{"name":"a\"NaN","v":[NaN]}`, + want: `{"name":"a\"NaN","v":["NaN"]}`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assert.Equal( + t, + test.want, + string(sanitizeNonFiniteNumbers([]byte(test.in))), + ) + }) + } +} + +func TestSplitConditionUnmarshalJSON(t *testing.T) { + // After sanitizeNonFiniteNumbers runs, split_conditions is a mix of JSON + // numbers and quoted non-finite tokens; both forms must decode. + var got []xgbFloat + err := json.Unmarshal( + []byte(`[1.5, "-Infinity", "Infinity", "NaN", -2.0]`), + &got, + ) + require.NoError(t, err) + + require.Len(t, got, 5) + assert.InDelta(t, 1.5, float64(got[0]), 1e-10) + assert.True(t, math.IsInf(float64(got[1]), -1)) + assert.True(t, math.IsInf(float64(got[2]), 1)) + assert.True(t, math.IsNaN(float64(got[3]))) + assert.InDelta(t, -2.0, float64(got[4]), 1e-10) + + // Malformed elements must surface a decode error, not silently decode as 0: + // a quoted token strconv cannot parse, and a bare token that is neither a + // number nor (after sanitizeNonFiniteNumbers) a quoted non-finite token. + t.Run("quoted non-numeric string is rejected", func(t *testing.T) { + err := json.Unmarshal([]byte(`["bogus"]`), &[]xgbFloat{}) + require.Error(t, err) + }) + t.Run("non-numeric bare token is rejected", func(t *testing.T) { + err := json.Unmarshal([]byte(`[true]`), &[]xgbFloat{}) + require.Error(t, err) + }) +} + +func TestParseTreeInfoNonFinite(t *testing.T) { + // baseTree is a valid numeric three-node tree (root with two leaves) whose + // root threshold each case overrides. + baseTree := func(rootCond xgbFloat, defaultLeft int) xgbTree { + xt := xgbTree{ + DefaultLeft: []int{defaultLeft, 0, 0}, + LeftChildren: []int{1, -1, -1}, + RightChildren: []int{2, -1, -1}, + SplitConditions: []xgbFloat{rootCond, 10, 20}, + SplitIndices: []int{0, 0, 0}, + } + xt.TreeParam.NumNodes = "3" + return xt + } + + negInf := xgbFloat(math.Inf(-1)) + posInf := xgbFloat(math.Inf(1)) + nan := xgbFloat(math.NaN()) + + t.Run("negative-infinity threshold is accepted", func(t *testing.T) { + _, err := parseTreeInfo(baseTree(negInf, 1)) + require.NoError(t, err) + }) + + t.Run("positive-infinity threshold is rejected", func(t *testing.T) { + _, err := parseTreeInfo(baseTree(posInf, 1)) + require.Error(t, err) + }) + + t.Run("NaN threshold is rejected", func(t *testing.T) { + _, err := parseTreeInfo(baseTree(nan, 1)) + require.Error(t, err) + }) + + t.Run("non-finite leaf value is rejected", func(t *testing.T) { + xt := baseTree(0.5, 0) + xt.SplitConditions[1] = negInf + _, err := parseTreeInfo(xt) + require.Error(t, err) + }) +} + func TestReadModelMeta(t *testing.T) { tests := []struct { name string @@ -580,7 +706,7 @@ func TestReadTreesNumParallelTree(t *testing.T) { DefaultLeft: []int{0}, LeftChildren: []int{-1}, RightChildren: []int{-1}, - SplitConditions: []float64{0}, + SplitConditions: []xgbFloat{0}, SplitIndices: []int{0}, } tree.TreeParam.NumNodes = "1" diff --git a/gen/templates.go b/gen/templates.go index f0bc477..8cb6e7e 100644 --- a/gen/templates.go +++ b/gen/templates.go @@ -90,6 +90,10 @@ type decisionNodeParams struct { // is true when the feature value is *not* in the node's category set (and so // routes left). It mirrors the numeric "*data[i] < threshold" predicate. CategoryTest string + // MissingOnly marks a numeric node whose -Infinity threshold routes on + // missingness alone, so the branch tests only "data[i] == nil" rather than + // the threshold comparison. See codegenMissingnessSplit. + MissingOnly bool } func (r *renderer) executeDecisionNode( @@ -97,6 +101,7 @@ func (r *renderer) executeDecisionNode( level int, left, right string, + missingOnly bool, ) (string, error) { var buf bytes.Buffer err := r.template.ExecuteTemplate( @@ -108,6 +113,7 @@ func (r *renderer) executeDecisionNode( nodeData: tree.data, Right: right, CategoryTest: categoryTest(tree.data), + MissingOnly: missingOnly, }, ) if err != nil { diff --git a/gen/templates/decision_node.tmpl b/gen/templates/decision_node.tmpl index a5b49bd..743b148 100644 --- a/gen/templates/decision_node.tmpl +++ b/gen/templates/decision_node.tmpl @@ -1,4 +1,6 @@ -{{if .DefaultLeft -}} +{{if .MissingOnly -}} +{{indent .Level}}if data[{{.SplitIndex}}] == nil { +{{else if .DefaultLeft -}} {{indent .Level}}if data[{{.SplitIndex}}] == nil || {{if .Categorical}}{{.CategoryTest}}{{else}}*data[{{.SplitIndex}}] < {{.SplitCondition}}{{end}} { {{else -}} {{indent .Level}}if data[{{.SplitIndex}}] != nil && {{if .Categorical}}{{.CategoryTest}}{{else}}*data[{{.SplitIndex}}] < {{.SplitCondition}}{{end}} { diff --git a/gen/testdata/neg-inf-split/model.json b/gen/testdata/neg-inf-split/model.json new file mode 100644 index 0000000..3547b9e --- /dev/null +++ b/gen/testdata/neg-inf-split/model.json @@ -0,0 +1,42 @@ +{ + "learner": { + "gradient_booster": { + "model": { + "gbtree_model_param": { + "num_parallel_tree": "1", + "num_trees": "2" + }, + "trees": [ + { + "default_left": [1, 0, 0], + "left_children": [1, -1, -1], + "right_children": [2, -1, -1], + "split_conditions": [-Infinity, 10.0, 20.0], + "split_indices": [0, 0, 0], + "tree_param": { + "num_nodes": "3" + } + }, + { + "default_left": [0, 0, 0], + "left_children": [1, -1, -1], + "right_children": [2, -1, -1], + "split_conditions": [-Infinity, 100.0, 200.0], + "split_indices": [0, 0, 0], + "tree_param": { + "num_nodes": "3" + } + } + ] + } + }, + "learner_model_param": { + "base_score": "0E0", + "num_class": "0", + "num_target": "1" + }, + "objective": { + "name": "reg:squarederror" + } + } +} diff --git a/gen/testdata/neg-inf-split/preds.csv b/gen/testdata/neg-inf-split/preds.csv new file mode 100644 index 0000000..059037c --- /dev/null +++ b/gen/testdata/neg-inf-split/preds.csv @@ -0,0 +1,3 @@ +210 +220 +220 diff --git a/gen/testdata/neg-inf-split/xtest.csv b/gen/testdata/neg-inf-split/xtest.csv new file mode 100644 index 0000000..f0e4d20 --- /dev/null +++ b/gen/testdata/neg-inf-split/xtest.csv @@ -0,0 +1,3 @@ +,1.0 +3.5,1.0 +-99999,1.0 diff --git a/main_test.go b/main_test.go index 75946f8..63b6388 100644 --- a/main_test.go +++ b/main_test.go @@ -23,6 +23,7 @@ func TestGenerateAndRunModels(t *testing.T) { {model: "binary-logitraw"}, {model: "reg-quantileerror"}, {model: "categorical"}, + {model: "neg-inf-split"}, } for _, test := range tests { diff --git a/version.go b/version.go index 8e54e85..740be4e 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package main -const version = "1.1.0" +const version = "1.2.0" From 2c8add0f75daad61d7d389df05a059485bf643c0 Mon Sep 17 00:00:00 2001 From: Noel Dawe Date: Fri, 26 Jun 2026 12:01:02 -0400 Subject: [PATCH 2/3] Reordered the checks in checkSplitCondition so the leaf case is validated first --- gen/parse.go | 28 ++++++++++++++++++++-------- gen/parse_test.go | 18 ++++++++++++++++++ 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/gen/parse.go b/gen/parse.go index 4f46b35..0686d39 100644 --- a/gen/parse.go +++ b/gen/parse.go @@ -381,25 +381,37 @@ func parseTreeInfo(xt xgbTree) (*node, error) { // lowest histogram bin is -inf (NumericBinLowerBound in XGBoost's // src/common/hist_util.h). func checkSplitCondition(id int64, sc float64, categorical, isLeaf bool) error { + // A leaf's value is its output, emitted verbatim as "sum += value" by the + // terminal_node template regardless of any (malformed) categorical marking, + // so it must always be finite. Check this before the categorical exemption + // below, which only applies to a categorical node's dummy threshold. + if isLeaf { + if math.IsInf(sc, 0) || math.IsNaN(sc) { + return fmt.Errorf( + "node %d has a non-finite leaf value (%v); only finite "+ + "values are supported, except a -Infinity threshold on a "+ + "decision node", + id, + sc, + ) + } + return nil + } if categorical { // The value is a dummy for categorical nodes and is never used. return nil } // A -Infinity threshold on a decision node is the supported missingness // split; everything else non-finite is rejected. - if !isLeaf && math.IsInf(sc, -1) { + if math.IsInf(sc, -1) { return nil } if math.IsInf(sc, 0) || math.IsNaN(sc) { - kind := "split threshold" - if isLeaf { - kind = "leaf value" - } return fmt.Errorf( - "node %d has a non-finite %s (%v); only finite values are "+ - "supported, except a -Infinity threshold on a decision node", + "node %d has a non-finite split threshold (%v); only finite "+ + "values are supported, except a -Infinity threshold on a "+ + "decision node", id, - kind, sc, ) } diff --git a/gen/parse_test.go b/gen/parse_test.go index 4f13b81..9d5568f 100644 --- a/gen/parse_test.go +++ b/gen/parse_test.go @@ -480,6 +480,24 @@ func TestParseTreeInfoNonFinite(t *testing.T) { _, err := parseTreeInfo(xt) require.Error(t, err) }) + + // A malformed tree can mark a leaf categorical (split_type 1 with a decoded + // category set). The leaf value is still emitted verbatim by the terminal + // node template, so the categorical dummy-value exemption must not let a + // non-finite leaf value through. + t.Run("non-finite value on a categorical leaf is rejected", func(t *testing.T) { + for _, sc := range []xgbFloat{negInf, posInf, nan} { + xt := baseTree(0.5, 0) + xt.SplitConditions[1] = sc + xt.SplitType = []int{0, 1, 0} + xt.Categories = []int{0} + xt.CategoriesNodes = []int{1} + xt.CategoriesSegments = []int{0} + xt.CategoriesSizes = []int{1} + _, err := parseTreeInfo(xt) + require.Error(t, err) + } + }) } func TestReadModelMeta(t *testing.T) { From 7a63d98f74e57b39be629ecba20f29943a8e450c Mon Sep 17 00:00:00 2001 From: Noel Dawe Date: Fri, 26 Jun 2026 12:10:09 -0400 Subject: [PATCH 3/3] Update inaccurate comment in gen/parse.go --- gen/parse.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/gen/parse.go b/gen/parse.go index 0686d39..519ef27 100644 --- a/gen/parse.go +++ b/gen/parse.go @@ -272,10 +272,12 @@ type node struct { type nodeData struct { DefaultLeft int ID int64 - // SplitCondition is kept as float64 through parsing but is narrowed back to - // float32 when rendered into the generated comparison. That is lossless - // because XGBoost stores split conditions as float32 in the first place; the - // float64 here is only the widening from the JSON round-trip. + // SplitCondition is kept as float64 through parsing and rendered as a + // full-precision decimal literal. In the generated comparison it meets + // *data[i] (a float32), so the Go compiler converts the literal back to + // float32 at compile time. That round-trip is lossless because XGBoost + // stores split conditions as float32 in the first place; the float64 here is + // only the widening from the JSON parse. SplitCondition float64 SplitIndex int // Categorical reports whether this is a categorical split. When true,