RFC-0009: GBDT Inference#
Status: Implemented
Created: 2025-12-15
Updated: 2026-01-02
Scope: Tree ensemble prediction pipeline
Summary#
GBDT inference traverses each tree with input features to accumulate leaf values. Block-based processing and optional tree unrolling optimize throughput.
Why Sample-Major for Prediction?#
Training uses feature-major layout (features as columns) for histogram building. Prediction needs sample-major (samples as rows) because:
Access Pattern |
Feature-Major |
Sample-Major |
|---|---|---|
Traversal |
Random access per feature |
Sequential per sample |
Cache behavior |
Poor (jump between features) |
Good (all features together) |
Tree traversal accesses features in split order (unpredictable), so keeping each sample’s features contiguous improves cache hits.
Block-Based Layout Conversion#
Rather than converting the entire dataset upfront (expensive for large data),
we use buffer_samples() to convert feature-major to sample-major in blocks:
// Feature-major input: [n_features, n_samples]
// Sample-major buffer: [block_size, n_features]
let mut buffer = Array2::<f32>::zeros((block_size, n_features));
for block_start in (0..n_samples).step_by(block_size) {
let samples = dataset.buffer_samples(&mut buffer, block_start);
// samples is now sample-major, contiguous in memory
for sample in samples.iter() {
traverse_all_trees(sample);
}
}
This approach:
Converts only what fits in L2 cache (~256KB = 64 samples × 1000 features × 4 bytes)
Amortizes conversion cost across multiple trees
Reuses the buffer across blocks (no repeated allocation)
Works with sparse features (fills zeros appropriately)
Layers#
High Level#
Users call GBDTModel::predict:
let predictions = model.predict(&test_data, n_threads);
This creates a Predictor internally and returns raw predictions.
Medium Level (Predictor)#
pub struct Predictor<'f, T: TreeTraversal<ScalarLeaf>> {
forest: &'f Forest<ScalarLeaf>,
tree_states: Box<[T::TreeState]>,
block_size: usize,
}
impl<'f, T: TreeTraversal<ScalarLeaf>> Predictor<'f, T> {
pub fn new(forest: &'f Forest<ScalarLeaf>) -> Self;
pub fn predict(&self, data: &Dataset, parallelism: Parallelism) -> Array2<f32>;
}
The TreeTraversal trait parameter allows different traversal strategies.
Medium Level (TreeTraversal)#
pub trait TreeTraversal<L: LeafValue>: Clone {
type TreeState: Clone + Send + Sync;
fn build_tree_state(tree: &Tree<L>) -> Self::TreeState;
fn traverse_tree(tree: &Tree<L>, state: &Self::TreeState, sample: ArrayView1<f32>) -> NodeId;
fn traverse_block(tree: &Tree<L>, state: &Self::TreeState, data: &SamplesView, out: &mut [NodeId]);
}
Low Level (Traversal Strategies)#
// Simple: no precomputation, traverse nodes one by one
pub struct StandardTraversal;
// Unrolled: precompute flattened layout for first D levels
pub struct UnrolledTraversal<D: UnrollDepth>;
pub type UnrolledTraversal4 = UnrolledTraversal<Depth4>;
pub type UnrolledTraversal6 = UnrolledTraversal<Depth6>;
Block Processing#
Rather than traverse one tree per sample, we process blocks of samples through all trees:
Block of 64 samples
│
├─► Tree 0 ─► 64 leaf indices ─► accumulate
├─► Tree 1 ─► 64 leaf indices ─► accumulate
└─► Tree N ─► 64 leaf indices ─► accumulate
Benefits:
Tree data (splits, thresholds) stays in L1/L2 cache
Multiple samples amortize cache miss cost
Predictable memory access for leaf accumulation
Default block size: 64 (matches XGBoost).
Unrolled Traversal#
Traditional traversal:
loop {
if is_leaf(node) { return node; }
node = if feature[split_idx] < threshold { left } else { right };
}
Branch misprediction penalty is ~15-20 cycles per miss. Unrolled traversal precomputes a flattened array for the first D levels:
pub struct UnrolledTreeLayout<D: UnrollDepth> {
unrolled: Box<[UnrolledNode]>, // 2^D - 1 nodes
subtree_offsets: Box<[NodeId]>, // Map to original tree
}
For depth-6 unrolling (63 nodes), we do ~6 branchless comparisons, then continue from the subtree root. 2-3× faster for large batches.
Multi-Output Handling#
For K-class classification, the forest contains K trees per round. Output
shape is [n_samples, n_groups]. Trees are assigned to groups round-robin,
and predictions accumulate per-group.
for tree_idx in 0..forest.n_trees() {
let group = forest.tree_group(tree_idx);
for sample in 0..n_samples {
output[[sample, group]] += leaf_value;
}
}
Prediction Outputs#
Raw predictions are logits/scores. The predict() method applies the appropriate
transformation based on the objective:
// predict() applies transformation (sigmoid for binary, softmax for multiclass)
let predictions = model.predict(&data, n_threads);
// predict_raw() returns raw margin scores (no transformation)
let raw = model.predict_raw(&data, n_threads);
Note: There is no predict_proba() method on GBDTModel. The predict() method
already applies probability transformations. Use sklearn wrappers if you need
separate predict() (class labels) and predict_proba() (probabilities).
Parallelization#
With Parallelism::Parallel(n_threads):
// Parallelize over sample blocks
samples.par_chunks(block_size).for_each(|block| {
for tree in forest.trees() {
traverse_block(tree, block, &mut local_output);
}
});
Thread overhead is amortized by block size. For small datasets, sequential is often faster (no thread spawn cost).
Files#
Path |
Contents |
|---|---|
|
|
|
|
|
|
|
Module exports |
Design Decisions#
DD-1: Block size 64. Empirically optimal across CPU architectures. Matches
XGBoost default. Can be customized via with_block_size().
DD-2: Traversal as generic parameter. Compile-time strategy selection enables inlining and specialization. No virtual dispatch overhead.
DD-3: Unroll depth 6. Covers 63 nodes (first 6 levels). Most trees fit entirely or have only deep leaves remaining. Depth 8 (255 nodes) showed minimal additional benefit with larger memory footprint.
DD-4: Dataset reuse for prediction. The same Dataset type is used for
training and prediction. For training, binning converts to BinnedDataset.
For prediction, buffer_samples() provides sample-major blocks on-demand.
Benchmarks#
Inference throughput (Covertype, 54 features, 100 trees, depth 6):
Configuration |
Throughput |
|---|---|
Single-threaded, standard |
~2M samples/sec |
Single-threaded, unrolled |
~3M samples/sec |
8 threads, unrolled |
~15M samples/sec |
Testing Strategy#
Category |
Tests |
|---|---|
Traversal correctness |
Compare standard vs unrolled (same results) |
XGBoost compatibility |
Load XGBoost model, compare predictions |
LightGBM compatibility |
Load LightGBM model, compare predictions |
Edge cases |
Empty forest, single tree, max depth trees |
Parallelism |
Same results with 1, 4, 8 threads |