# RFC-0008: GBDT Training

**Status:** Implemented
**Created:** 2025-12-01
**Updated:** 2026-01-02
**Scope:** Gradient boosted decision tree training pipeline

## Summary
GBDT training combines objective-driven gradient computation, histogram-based split finding, and tree growing into an iterative boosting loop. The design separates orchestration (`GBDTTrainer`) from tree mechanics (`TreeGrower`).
## Why Histogram-Based Training?

Traditional exact split finding scans all samples for every feature at every node, costing O(n × m × depth) per tree. Histogram-based training bins features up front:
| Aspect | Exact | Histogram-Based |
|---|---|---|
| Split candidates | All unique values | Up to 256 bins |
| Memory per node | Full sample refs | Histogram only |
| Parallelism | Limited | Easy: each feature independent |
XGBoost, LightGBM, and CatBoost all use histogram-based training.
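As a minimal sketch of the idea (the `Bin` struct and `build_histogram` function below are illustrative, not the crate's actual types): once every feature value is mapped to a `u8` bin, a node's per-feature histogram is a fixed 256-slot accumulation of gradient/hessian sums, so split finding no longer depends on the number of distinct raw values.

```rust
/// One histogram slot: accumulated gradient, hessian, and sample count.
#[derive(Clone, Copy, Default)]
struct Bin {
    grad: f64,
    hess: f64,
    count: u32,
}

/// Build the 256-bin histogram of one feature over the rows owned by a node.
/// `binned` holds the precomputed u8 bin index of every sample for this feature.
fn build_histogram(binned: &[u8], rows: &[u32], grads: &[f32], hess: &[f32]) -> [Bin; 256] {
    let mut hist = [Bin::default(); 256];
    for &r in rows {
        let bin = &mut hist[binned[r as usize] as usize];
        bin.grad += grads[r as usize] as f64;
        bin.hess += hess[r as usize] as f64;
        bin.count += 1;
    }
    hist
}
```

Later sketches in this document reuse this `Bin` type.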
## Layers

### High Level

Users call `GBDTModel::train`:

```rust
let model = GBDTModel::train(&dataset, Some(&eval_set), config, seed)?;
```

This constructs a `GBDTTrainer` internally and runs the boosting loop.
### Quick Start

```rust
use boosters::{GBDTModel, GBDTConfig, Dataset};

let dataset = Dataset::from_array(features.view(), Some(targets.view()), None)?;
let config = GBDTConfig::default(); // 100 trees, lr=0.3, max_depth=6
let model = GBDTModel::train(&dataset, None, config, 42)?;
let preds = model.predict(&dataset, 4); // 4 threads
```
### Medium Level (Trainer)

```rust
pub struct GBDTTrainer<O: ObjectiveFn, M: MetricFn> {
    objective: O,
    metric: M,
    params: GBDTParams,
}

impl<O, M> GBDTTrainer<O, M> {
    pub fn train<W, T>(
        &self,
        dataset: &BinnedDataset,
        targets: T,
        weights: W,
        eval_set: Option<(&Dataset, T)>,
        parallelism: Parallelism,
    ) -> Result<Forest<ScalarLeaf>, TrainingError>;
}
```
Boosting loop (conceptual):
1. Compute gradients from the objective: `objective.gradients(preds, targets, grads)`
2. Grow a tree from the gradients: `grower.grow(dataset, grads)`
3. Update predictions: `preds += learning_rate * tree_outputs`
4. Evaluate and check early stopping
5. Repeat for `n_trees` rounds
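A stripped-down, self-contained sketch of that loop for squared error (single output, no sampling or early stopping; `grow_stump` is a hypothetical stand-in for `TreeGrower::grow`):

```rust
/// Stand-in for TreeGrower::grow: a single-leaf "tree" whose output is the
/// Newton step over all rows, -sum(g) / sum(h). A real tree partitions rows
/// and emits per-sample outputs.
fn grow_stump(grads: &[f32], hess: &[f32]) -> f32 {
    let g: f32 = grads.iter().sum();
    let h: f32 = hess.iter().sum();
    -g / h
}

fn boost(targets: &[f32], n_trees: usize, learning_rate: f32) -> Vec<f32> {
    // Squared error: start from the mean response.
    let base = targets.iter().sum::<f32>() / targets.len() as f32;
    let mut preds = vec![base; targets.len()];
    let mut grads = vec![0.0f32; targets.len()];
    let hess = vec![1.0f32; targets.len()]; // constant hessian for squared error

    for _ in 0..n_trees {
        // 1. Gradient of 0.5 * (pred - target)^2 with respect to pred.
        for i in 0..targets.len() {
            grads[i] = preds[i] - targets[i];
        }
        // 2. Fit a tree to the gradients, then 3. apply shrinkage.
        let output = grow_stump(&grads, &hess);
        for p in preds.iter_mut() {
            *p += learning_rate * output;
        }
        // 4. Evaluation and early stopping would run here.
    }
    preds
}
```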
### Medium Level (Grower)

```rust
pub struct TreeGrower {
    params: GrowerParams,
    histogram_pool: HistogramPool,
    partitioner: RowPartitioner,
    tree_builder: MutableTree<ScalarLeaf>,
    histogram_builder: HistogramBuilder,
    // ... feature metadata, samplers
}

impl TreeGrower {
    pub fn grow(
        &mut self,
        dataset: &BinnedDataset,
        gradients: &Gradients,
        parallelism: Parallelism,
    ) -> Tree<ScalarLeaf>;
}
```
### Low Level (Split Finding)
Split finding evaluates every feature at every bin boundary to find the optimal split point.
#### Gain Formula

Split gain uses the XGBoost formula with regularization:

\[
\text{Gain} = \frac{1}{2}\left[\frac{G_L^2}{H_L + \lambda} + \frac{G_R^2}{H_R + \lambda} - \frac{G_P^2}{H_P + \lambda}\right] - \gamma
\]

Where:

- \(G_L, G_R, G_P\) = gradient sums for left, right, parent
- \(H_L, H_R, H_P\) = hessian sums for left, right, parent
- \(\lambda\) = L2 regularization (`reg_lambda`)
- \(\gamma\) = minimum gain threshold (`min_split_gain`)
Optimization: Parent score is precomputed once per node, reducing from 3 divisions to 2 per candidate split.
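A direct transcription of the formula as a standalone helper (illustrative only; the real splitter operates on histogram accumulators rather than loose floats):

```rust
/// Split gain for one candidate, following the formula above.
/// g_* / h_* are gradient / hessian sums; `lambda` is reg_lambda,
/// `gamma` is min_split_gain.
fn split_gain(
    g_l: f64, h_l: f64,
    g_r: f64, h_r: f64,
    g_p: f64, h_p: f64,
    lambda: f64, gamma: f64,
) -> f64 {
    let score = |g: f64, h: f64| g * g / (h + lambda);
    // The parent score is computed once per node and reused for every
    // candidate, leaving two divisions per candidate split.
    let parent = score(g_p, h_p);
    0.5 * (score(g_l, h_l) + score(g_r, h_r) - parent) - gamma
}
```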
#### Leaf Weight Formula

\[
w = -\frac{\operatorname{sign}(G)\,\max(|G| - \alpha,\ 0)}{H + \lambda}
\]

Where \(\alpha\) = L1 regularization (`reg_alpha`), applied by soft-thresholding the gradient sum. When \(\alpha = 0\), this simplifies to the Newton step: \(w = -G/(H + \lambda)\).
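The same computation as a small helper (a sketch under the assumption that L1 is applied by soft-thresholding, consistent with the formula above):

```rust
/// Leaf weight with L1 (alpha) and L2 (lambda) regularization.
/// With alpha = 0 this is exactly the Newton step -g / (h + lambda).
fn leaf_weight(g: f64, h: f64, lambda: f64, alpha: f64) -> f64 {
    // Soft-threshold the gradient sum by alpha.
    let thresholded = if g > alpha {
        g - alpha
    } else if g < -alpha {
        g + alpha
    } else {
        0.0
    };
    -thresholded / (h + lambda)
}
```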
#### Splitter API

```rust
pub struct GreedySplitter {
    gain_params: GainParams,
    max_onehot_cats: u32,
    parallelism: Parallelism,
}

impl GreedySplitter {
    pub fn find_split(
        &self,
        histogram: HistogramView<'_>,
        parent_stats: GradsTuple,
        feature_indices: &[u32],
    ) -> Option<SplitInfo>;
}
```
Scan strategies:

- Numerical: bidirectional scan for optimal missing-value handling
- Categorical one-hot: each category as a singleton left partition
- Categorical sorted: sort by grad/hess ratio, scan partition point
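A sketch of the numerical scan, reusing the `Bin` and `split_gain` helpers from the earlier sketches (one direction only; the mirrored right-to-left pass that routes missing values to the other side is symmetric):

```rust
/// Scan a feature's histogram left-to-right, evaluating a split after each bin.
/// Right-side statistics come from subtracting the running left sums from the
/// parent totals.
fn best_numerical_split(
    hist: &[Bin; 256],
    g_parent: f64,
    h_parent: f64,
    lambda: f64,
    gamma: f64,
) -> Option<(usize, f64)> {
    let (mut g_l, mut h_l) = (0.0, 0.0);
    let mut best: Option<(usize, f64)> = None;
    for bin in 0..255 {
        g_l += hist[bin].grad;
        h_l += hist[bin].hess;
        let (g_r, h_r) = (g_parent - g_l, h_parent - h_l);
        let gain = split_gain(g_l, h_l, g_r, h_r, g_parent, h_parent, lambda, gamma);
        if gain > 0.0 && best.map_or(true, |(_, g)| gain > g) {
            best = Some((bin, gain)); // rows with bin index <= `bin` go left
        }
    }
    best
}
```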
## Key Design Decisions

### DD-1: Subtraction Trick

When the parent's histogram and one child's histogram exist, the sibling's histogram is computed by subtraction: `sibling = parent - child`. This reduces histogram builds by roughly 50%.

```text
Parent (computed)
├── Left (computed)          ← Smaller child
└── Right = Parent - Left    ← Subtraction
```

The histogram is always built for the smaller child (fewer samples to aggregate).
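A bin-by-bin sketch of the subtraction, reusing the `Bin` type from the histogram sketch above:

```rust
/// Derive the sibling histogram as parent - child, bin by bin. Building the
/// smaller child directly and subtracting is about half the work of building
/// both children from raw rows.
fn subtract_histogram(parent: &[Bin; 256], child: &[Bin; 256]) -> [Bin; 256] {
    let mut sibling = [Bin::default(); 256];
    for b in 0..256 {
        sibling[b].grad = parent[b].grad - child[b].grad;
        sibling[b].hess = parent[b].hess - child[b].hess;
        sibling[b].count = parent[b].count - child[b].count;
    }
    sibling
}
```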
### DD-2: Growth Strategies

```rust
pub enum GrowthStrategy {
    DepthWise { max_depth: u32 },  // XGBoost-style: level by level
    LeafWise { max_leaves: u32 },  // LightGBM-style: best-gain first
}
```

Both rely on the same per-node split search and differ only in which leaves they expand; leaf-wise often converges faster but risks overfitting without early stopping.
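A minimal sketch of the leaf-wise policy (the frontier bookkeeping and names here are illustrative, not the grower's actual data structures):

```rust
/// Leaf-wise growth skeleton: repeatedly split the frontier leaf with the
/// highest gain until `max_leaves` is reached or no split has positive gain.
/// `frontier` holds (leaf_id, best_gain) pairs for leaves that can still split.
fn grow_leaf_wise(mut frontier: Vec<(usize, f64)>, max_leaves: u32) {
    let mut n_leaves = 1u32;
    while n_leaves < max_leaves {
        // Pick the candidate with the largest gain (best-gain first).
        let best_idx = frontier
            .iter()
            .enumerate()
            .filter(|(_, &(_, gain))| gain > 0.0)
            .max_by(|a, b| a.1 .1.partial_cmp(&b.1 .1).unwrap())
            .map(|(i, _)| i);
        let best_idx = match best_idx {
            Some(i) => i,
            None => break, // no remaining split improves the objective
        };
        let (_leaf_id, _gain) = frontier.swap_remove(best_idx);
        // ... apply the stored split for _leaf_id, then push its two children
        // onto `frontier` with their own best gains ...
        n_leaves += 1;
    }
}
```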
### DD-3: Row Partitioning

Samples are partitioned into node-specific index ranges as the tree grows. Benefits:

- Gradient gathering is sequential (cache-friendly)
- Histogram building accesses contiguous memory
- Child counts are known up front for the subtraction trick

The partitioner uses a double-buffer swap to avoid allocating on every split, as sketched below.
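A simplified version of that partition step (the buffer-swap bookkeeping around it is omitted; names are illustrative):

```rust
/// Stable partition of a node's row indices into left/right children using a
/// reusable scratch buffer (the "double buffer"); the two buffers are swapped
/// after each split instead of allocating new ones.
/// `goes_left[row]` is the routing decision from the chosen split.
fn partition_rows(rows: &[u32], scratch: &mut Vec<u32>, goes_left: &[bool]) -> usize {
    scratch.clear();
    // Left rows first, preserving order for cache-friendly gradient gathering.
    scratch.extend(rows.iter().copied().filter(|&r| goes_left[r as usize]));
    let n_left = scratch.len();
    scratch.extend(rows.iter().copied().filter(|&r| !goes_left[r as usize]));
    n_left // child counts are known here, enabling the subtraction trick
}
```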
### DD-4: Ordered Gradients
Before histogram building, gradients are gathered into contiguous buffers per node, ordered by sample index within that node. This enables vectorized histogram kernels.
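In sketch form (illustrative names; the real builder gathers gradient/hessian pairs):

```rust
/// Gather a node's gradients into a contiguous buffer, in the node's row order,
/// so the histogram kernel reads `ordered` sequentially instead of hopping
/// through the full gradient array.
fn gather_ordered(rows: &[u32], grads: &[f32], ordered: &mut Vec<f32>) {
    ordered.clear();
    ordered.extend(rows.iter().map(|&r| grads[r as usize]));
}
```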
### DD-5: LRU Histogram Cache
Large trees may exceed memory if all histograms are kept. HistogramPool uses
LRU eviction, keeping only recently used histograms for the subtraction trick.
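A minimal fixed-capacity pool with LRU eviction, reusing the `Bin` type from the earlier sketch; this illustrates the policy only and is not the crate's `HistogramPool`:

```rust
use std::collections::HashMap;

/// Fixed-capacity histogram pool with LRU eviction, keyed by node id.
/// Each access bumps a logical clock; when full, the stalest entry is dropped
/// (its histogram must later be rebuilt from raw rows if needed again).
struct LruHistogramPool {
    capacity: usize,
    tick: u64,
    entries: HashMap<usize, (u64, Vec<Bin>)>, // node_id -> (last_use, histogram)
}

impl LruHistogramPool {
    fn insert(&mut self, node_id: usize, hist: Vec<Bin>) {
        if self.entries.len() >= self.capacity {
            // Evict the least recently used node.
            let stalest = self.entries.iter().min_by_key(|(_, (t, _))| *t).map(|(&id, _)| id);
            if let Some(id) = stalest {
                self.entries.remove(&id);
            }
        }
        self.tick += 1;
        self.entries.insert(node_id, (self.tick, hist));
    }

    fn get(&mut self, node_id: usize) -> Option<&Vec<Bin>> {
        self.tick += 1;
        let now = self.tick;
        self.entries.get_mut(&node_id).map(|(t, hist)| {
            *t = now; // refresh recency on access
            &*hist
        })
    }
}
```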
### DD-6: Multi-Output via Tree Groups
For K-class classification, we train K trees per round (one per class). Each tree sees class-specific gradients. Trees are grouped in the forest. This matches XGBoost/LightGBM behavior.
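A sketch of how grouped trees map back to class scores at prediction time (assuming trees are stored round by round, class-major within each round):

```rust
/// With K classes, round r contributes trees [r*K, r*K + K), one per class, so
/// tree i in the forest belongs to class i % K. Prediction sums each tree's
/// output for a sample into its class's score slot.
fn class_scores(per_tree_output: &[f32], n_classes: usize) -> Vec<f32> {
    let mut scores = vec![0.0f32; n_classes];
    for (tree_idx, out) in per_tree_output.iter().enumerate() {
        scores[tree_idx % n_classes] += out;
    }
    scores // raw margins; apply softmax to get class probabilities
}
```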
## Objective and Metric Traits

```rust
pub trait ObjectiveFn: Send + Sync {
    fn n_outputs(&self) -> usize;
    fn init_predictions(&self, targets: &[f32], out: &mut [f32]);
    fn gradients(&self, preds: &[f32], targets: &[f32], grads: &mut [GradsTuple]);
}

pub trait MetricFn: Send + Sync {
    fn name(&self) -> &str;
    fn score(&self, preds: &[f32], targets: &[f32]) -> f64;
    fn higher_is_better(&self) -> bool;
}
```
Built-in objectives: `SquaredError`, `LogLoss`, `Softmax`.
Built-in metrics: `RMSE`, `MAE`, `LogLoss`, `AUC`, `Accuracy`.
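For illustration, a squared-error objective and an RMSE metric written against the traits above (the `GradsTuple` layout shown is an assumption; the crate's actual struct may differ):

```rust
/// Hypothetical stand-in for the crate's GradsTuple (gradient + hessian pair).
#[derive(Clone, Copy, Default)]
pub struct GradsTuple {
    pub grad: f32,
    pub hess: f32,
}

pub struct SquaredError;

impl ObjectiveFn for SquaredError {
    fn n_outputs(&self) -> usize { 1 }

    fn init_predictions(&self, targets: &[f32], out: &mut [f32]) {
        // Start every sample at the mean target (the optimal constant).
        let mean = targets.iter().sum::<f32>() / targets.len() as f32;
        out.fill(mean);
    }

    fn gradients(&self, preds: &[f32], targets: &[f32], grads: &mut [GradsTuple]) {
        for i in 0..targets.len() {
            grads[i] = GradsTuple { grad: preds[i] - targets[i], hess: 1.0 };
        }
    }
}

pub struct Rmse;

impl MetricFn for Rmse {
    fn name(&self) -> &str { "rmse" }

    fn score(&self, preds: &[f32], targets: &[f32]) -> f64 {
        let mse: f64 = preds
            .iter()
            .zip(targets)
            .map(|(p, t)| ((p - t) as f64).powi(2))
            .sum::<f64>()
            / targets.len() as f64;
        mse.sqrt()
    }

    fn higher_is_better(&self) -> bool { false }
}
```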
## Sampling

### Row Sampling

```rust
pub enum RowSamplingParams {
    None,
    Uniform { subsample: f32 },
    GOSS { top_rate: f32, other_rate: f32 }, // Gradient-based
}
```
GOSS (Gradient-based One-Side Sampling) keeps all high-gradient samples and subsamples the low-gradient ones, re-weighting the sampled remainder so gradient sums stay unbiased. Adopted from LightGBM, it preserves quality better than uniform subsampling.
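A sketch of the selection step (a deterministic stride stands in for the seeded uniform draw; names are illustrative):

```rust
/// GOSS sketch: keep the top_rate fraction of rows by |gradient|, sample
/// other_rate of the remainder, and up-weight the sampled remainder by
/// (1 - top_rate) / other_rate. Returns (selected rows, per-row weights).
fn goss_sample(grads: &[f32], top_rate: f32, other_rate: f32) -> (Vec<u32>, Vec<f32>) {
    let n = grads.len();
    let mut order: Vec<u32> = (0..n as u32).collect();
    // Sort rows by descending absolute gradient.
    order.sort_by(|&a, &b| {
        grads[b as usize]
            .abs()
            .partial_cmp(&grads[a as usize].abs())
            .unwrap()
    });

    let n_top = (n as f32 * top_rate) as usize;
    let n_other = (n as f32 * other_rate) as usize;
    let amplify = (1.0 - top_rate) / other_rate;

    let mut rows = Vec::with_capacity(n_top + n_other);
    let mut weights = Vec::with_capacity(n_top + n_other);
    // High-gradient rows are always kept with weight 1.
    rows.extend_from_slice(&order[..n_top]);
    weights.extend(std::iter::repeat(1.0).take(n_top));
    // Low-gradient rows: a deterministic stride stands in for a seeded uniform
    // draw; sampled rows are up-weighted to keep gradient sums unbiased.
    let step = ((n - n_top) / n_other.max(1)).max(1);
    for &r in order[n_top..].iter().step_by(step).take(n_other) {
        rows.push(r);
        weights.push(amplify);
    }
    (rows, weights)
}
```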
### Column Sampling

```rust
pub enum ColSamplingParams {
    None,
    ByTree { colsample: f32 },
    ByLevel { colsample: f32 },
    ByNode { colsample: f32 },
}
```
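A per-tree column sample could look like the following sketch (a tiny xorshift PRNG stands in for the crate's seeded RNG; all names are illustrative):

```rust
/// Choose a colsample fraction of feature indices for one tree.
fn sample_columns(n_features: u32, colsample: f32, seed: u64) -> Vec<u32> {
    let mut state = seed.max(1); // xorshift state must be nonzero
    let mut cols: Vec<u32> = (0..n_features).collect();
    // Fisher-Yates shuffle driven by xorshift64, then keep the first `keep`.
    for i in (1..cols.len()).rev() {
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        cols.swap(i, (state % (i as u64 + 1)) as usize);
    }
    let keep = ((n_features as f32 * colsample).ceil() as usize).max(1);
    cols.truncate(keep);
    cols.sort_unstable(); // scan features in index order for predictable access
    cols
}
```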
## Parameters

```rust
pub struct GBDTParams {
    pub n_trees: u32,        // Boosting rounds (default: 100)
    pub learning_rate: f32,  // Shrinkage (default: 0.3)
    pub growth_strategy: GrowthStrategy,
    pub gain: GainParams,    // Regularization
    pub row_sampling: RowSamplingParams,
    pub col_sampling: ColSamplingParams,
    pub cache_size: usize,   // Histogram cache slots
    pub early_stopping_rounds: u32,
    pub verbosity: Verbosity,
    pub seed: u64,
    pub linear_leaves: Option<LinearLeafConfig>,
}
```
DART (Dropout Trees): Not currently implemented. DART adds dropout regularization by randomly dropping trees during training. Deferred to future work.
## Early Stopping

```rust
// In the boosting loop
for round in 0..n_trees {
    // ... train tree ...
    if let Some(eval) = &eval_set {
        let score = metric.score(preds, targets);
        if early_stopper.should_stop(round, score) {
            break; // Stop training, keep best model
        }
    }
}
```
Early stopping monitors the validation metric and halts training when it has not improved for `early_stopping_rounds` consecutive rounds; the best-scoring model is kept.
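A minimal stopper consistent with that loop (illustrative; everything beyond the `should_stop` call shown above is an assumption):

```rust
/// Track the best validation score seen so far and how many rounds have passed
/// without improvement; stop once the patience window is exhausted.
struct EarlyStopper {
    patience: u32,          // early_stopping_rounds
    higher_is_better: bool, // from MetricFn::higher_is_better
    best_score: f64,
    best_round: u32,
}

impl EarlyStopper {
    fn new(patience: u32, higher_is_better: bool) -> Self {
        let best_score = if higher_is_better { f64::NEG_INFINITY } else { f64::INFINITY };
        Self { patience, higher_is_better, best_score, best_round: 0 }
    }

    fn should_stop(&mut self, round: u32, score: f64) -> bool {
        let improved = if self.higher_is_better {
            score > self.best_score
        } else {
            score < self.best_score
        };
        if improved {
            self.best_score = score;
            self.best_round = round;
        }
        round - self.best_round >= self.patience
    }
}
```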
The regularization parameters referenced by the gain and leaf-weight formulas:

```rust
pub struct GainParams {
    pub reg_lambda: f32,  // L2 regularization
    pub reg_alpha: f32,   // L1 regularization (pruning)
    pub min_child_weight: f32,
    pub min_samples_leaf: u32,
    pub min_split_gain: f32,
}
```
## Testing Strategy
Training correctness is validated through:
| Category | Location |
|---|---|
| Unit tests | Inline in |
| Integration tests | |
| Quality benchmarks | |
| Reference models | |
## Files

| Path | Contents |
|---|---|
| | Row and column samplers |