RFC-0013: Explainability#
Status: Implemented
Created: 2025-12-20
Updated: 2026-01-02
Scope: Feature importance and SHAP values
Summary#
Explainability features help users understand model behavior: which features matter (importance) and how they affect individual predictions (SHAP).
Why Explainability?#
| Need | Solution |
|---|---|
| “Which features matter?” | Feature importance |
| “Why this prediction?” | SHAP values |
| “Model debugging” | Both |
| “Regulatory compliance” | Both |
Feature Importance#
ImportanceType Enum#
```rust
pub enum ImportanceType {
    /// Number of times each feature is used in splits.
    Split, // Default
    /// Total gain from splits using each feature.
    Gain, // Requires gain storage
    /// Average gain per split (Gain / Split count).
    AverageGain, // Requires gain storage
    /// Total cover (hessian sum) at nodes splitting on each feature.
    Cover, // Requires cover storage
    /// Average cover per split (Cover / Split count).
    AverageCover, // Requires cover storage
}
```
FeatureImportance Container#
The FeatureImportance struct wraps raw importance values with utility methods:
```rust
impl FeatureImportance {
    pub fn values(&self) -> &[f32];             // Raw values
    pub fn normalized(&self) -> Self;           // Sum to 1.0
    pub fn sorted_indices(&self) -> Vec<usize>; // Descending order
    pub fn top_k(&self, k: usize) -> Vec<(usize, Option<String>, f32)>;
    pub fn get_by_name(&self, name: &str) -> Option<f32>;
}
```
GBDTModel API#
```rust
impl GBDTModel {
    pub fn feature_importance(
        &self,
        importance_type: ImportanceType,
    ) -> Result<FeatureImportance, ExplainError>;
}
```
Returns `ExplainError::MissingNodeStats` if gain- or cover-based importance types are requested but the model lacks those statistics.
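A minimal usage sketch, assuming a trained `model: GBDTModel` whose trees store gain statistics:

```rust
// Rank features by total gain and print the top three.
let importance = model.feature_importance(ImportanceType::Gain)?;
for (index, name, score) in importance.normalized().top_k(3) {
    println!("feature {index} ({name:?}): {score:.3}");
}
```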
SHAP Values#
Shapley Value Formula#
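For a model \(f\) over feature set \(F\), the Shapley value of feature \(i\) for a sample \(x\) averages its marginal contribution across all subsets of the remaining features:

\[
\phi_i = \sum_{S \subseteq F \setminus \{i\}} \frac{|S|!\,(|F| - |S| - 1)!}{|F|!} \left[ f_{S \cup \{i\}}(x) - f_S(x) \right]
\]

where \(f_S(x)\) is the expected model output when only the features in \(S\) are observed. Evaluating this directly is exponential in \(|F|\); TreeSHAP (below) computes the same values in polynomial time by exploiting the tree structure.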
ShapValues Container#
SHAP output is stored in a 3D array:
```rust
pub struct ShapValues(Array3<f32>); // [n_samples, n_features + 1, n_outputs]

impl ShapValues {
    pub fn get(&self, sample: usize, feature: usize, output: usize) -> f32;
    pub fn base_value(&self, sample: usize, output: usize) -> f32;
    pub fn feature_shap(&self, sample_idx: usize, output: usize) -> Vec<f32>;
    pub fn verify(&self, predictions: &[f32], tolerance: f32) -> bool;
}
```
The verify() method checks the sum property: sum(SHAP) + base ≈ prediction.
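A sketch of the indexing conventions, assuming a 4-feature, single-output model and a `shap: ShapValues` already computed by one of the explainers below:

```rust
// Per-feature attributions for sample 0, output 0 (length == n_features).
let contribs = shap.feature_shap(0, 0);
// The base value lives at feature index n_features (here 4).
let base = shap.base_value(0, 0); // equivalent to shap.get(0, 4, 0)
// Sum property: attributions plus base reconstruct the raw prediction.
let reconstructed: f32 = contribs.iter().sum::<f32>() + base;
```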
TreeSHAP#
The TreeExplainer computes exact SHAP values for tree ensembles in
polynomial time using the algorithm from Lundberg et al. (2020).
API#
```rust
pub struct TreeExplainer<'a> {
    forest: &'a Forest<ScalarLeaf>,
    base_value: f64,
    block_size: usize,
    max_depth: usize,
}

impl<'a> TreeExplainer<'a> {
    pub fn new(forest: &'a Forest<ScalarLeaf>) -> Result<Self, ExplainError>;
    pub fn base_value(&self) -> f64;
    pub fn shap_values(&self, dataset: &Dataset, parallelism: Parallelism) -> ShapValues;
}
```
Requires cover statistics: `ExplainError::MissingNodeStats` is returned if trees don't have covers. Models trained with booste-rs automatically store covers.
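A minimal usage sketch (variable names are illustrative, and the `Parallelism::Parallel` variant follows the complexity notes below):

```rust
// Exact TreeSHAP over a dataset; `new` fails if the forest lacks covers.
let explainer = TreeExplainer::new(&forest)?;
let shap = explainer.shap_values(&dataset, Parallelism::Parallel);
println!("expected value: {}", explainer.base_value());
```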
Algorithm#
TreeSHAP tracks a path state through the tree:
- PathState: tracks features seen, zero fractions, and one fractions
- Recursive traversal: at each node, extend the path and recurse into both branches
- At leaves: compute contributions for all features on the path
- Unwound sum: weight contributions by path weights
```rust
// Simplified sketch: full TreeSHAP also unwinds duplicate features already
// on the path; helper names like `hot_cold_children` are illustrative.
fn tree_shap(&self, tree: &Tree, sample: ArrayView1<f32>, path: &mut PathState, node: u32, output: &mut [f64]) {
    if tree.is_leaf(node) {
        // Distribute the leaf value across every feature on the current path.
        self.compute_contributions(output, path, tree.leaf_value(node));
        return;
    }
    let feature = tree.split_feature(node);
    // Hot child = the branch this sample takes; cold child = the other branch.
    let (hot, cold) = self.hot_cold_children(tree, node, &sample);
    // Zero fractions: share of training cover flowing into each child.
    let zero_fraction = tree.cover(hot) / tree.cover(node);
    let cold_zero_fraction = tree.cover(cold) / tree.cover(node);
    // Hot path (sample's direction): one_fraction = 1.0.
    path.extend(feature, zero_fraction, 1.0);
    self.tree_shap(tree, sample, path, hot, output);
    path.unwind();
    // Cold path (other direction): one_fraction = 0.0.
    path.extend(feature, cold_zero_fraction, 0.0);
    self.tree_shap(tree, sample, path, cold, output);
    path.unwind();
}
```
TreeSHAP Complexity#
- Per sample: O(T × L × D²), where T = trees, L = leaves, D = depth
- Parallelization: over samples via `Parallelism::Parallel`
- Memory: O(D) path state per thread
GBDTModel.shap_values#
```rust
impl GBDTModel {
    pub fn shap_values(&self, data: &Dataset) -> Result<ShapValues, ExplainError>;
}
```
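An end-to-end sketch of the convenience API; the `predict` call is an assumption, since this RFC does not define the prediction surface:

```rust
// Compute SHAP values and check the sum property against raw predictions.
let shap = model.shap_values(&data)?;
let preds = model.predict(&data); // hypothetical prediction API
assert!(shap.verify(&preds, 1e-4));
```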
Linear SHAP#
For `LinearModel`, SHAP values have a closed-form solution:

\[
\phi_i(x) = w_i \,(x_i - \mu_i)
\]

where \(w_i\) is the feature weight and \(\mu_i\) is the background mean.
LinearExplainer API#
```rust
pub struct LinearExplainer<'a> {
    model: &'a LinearModel,
    feature_means: Vec<f64>,
}

impl<'a> LinearExplainer<'a> {
    pub fn new(model: &'a LinearModel, feature_means: Vec<f64>) -> Result<Self, ExplainError>;
    pub fn with_zero_means(model: &'a LinearModel) -> Self;
    pub fn base_value(&self, output: usize) -> f64;
    pub fn shap_values(&self, dataset: &Dataset) -> ShapValues;
}
```
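A minimal usage sketch; the training-set means are assumed to be computed elsewhere, and variable names are illustrative:

```rust
// Linear SHAP against a background of training-set feature means
// (`linear_model: LinearModel`, `means: Vec<f64>`, `dataset: Dataset`).
let explainer = LinearExplainer::new(&linear_model, means)?;
let shap = explainer.shap_values(&dataset);

// With zero means, attributions reduce to w_i * x_i.
let raw = LinearExplainer::with_zero_means(&linear_model);
```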
Base Value#
The expected prediction given the background distribution, computed per output:

\[
\mathrm{base} = \sum_i w_i \,\mu_i + b
\]

where \(b\) is the model's intercept (zero if the model has none).
Linear SHAP Complexity#
- O(n_features × n_samples): linear in data size
- No additional statistics required
File Structure#
| Path | Contents |
|---|---|
| | Module exports |
| | Feature importance types and computation |
| | SHAP submodule |
| | ShapValues container |
| | PathState for TreeSHAP |
| | TreeSHAP implementation |
| | Linear SHAP implementation |
Design Decisions#
DD-1: Separate from inference. Explainability is optional and adds computation. Keep it in dedicated module, not integrated into prediction.
DD-2: f64 for intermediate, f32 for output. TreeSHAP uses f64 internally for numerical stability, but stores results as f32 for consistency with predictions and memory efficiency.
DD-3: Base value stored per sample. The base value is stored in `ShapValues` at index `n_features` for each sample/output, enabling the `verify()` method to check correctness.
DD-4: Tree-path SHAP (not interventional). We implement the faster tree-path SHAP variant. Interventional SHAP is more theoretically sound but O(2^d) expensive.
DD-5: Require covers for TreeSHAP. Rather than estimating covers, we require them to be present. Models trained with booste-rs automatically store covers during training.
DD-6: Block-based sample buffering. TreeExplainer uses the same block buffering approach as prediction (buffer_samples) for cache efficiency when converting feature-major to sample-major layout.
Error Handling#
```rust
pub enum ExplainError {
    /// Node statistics (gains/covers) are required but not available.
    MissingNodeStats(&'static str),
    /// Feature statistics (means) are required but not available.
    MissingFeatureStats,
    /// Empty model (no trees).
    EmptyModel,
}
```
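A sketch of graceful degradation on missing statistics; the fallback policy here is illustrative, not prescribed by this RFC:

```rust
// Fall back to split-count importance when gain statistics are absent.
let importance = match model.feature_importance(ImportanceType::Gain) {
    Ok(imp) => imp,
    Err(ExplainError::MissingNodeStats(what)) => {
        eprintln!("missing {what} statistics; falling back to split counts");
        model.feature_importance(ImportanceType::Split)?
    }
    Err(e) => return Err(e),
};
```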
Testing Strategy#
| Category | Tests |
|---|---|
| Sum property | sum(SHAP) + base ≈ prediction within tolerance (via verify()) |
| Splitting feature | Non-zero contribution for features that split |
| Linear exactness | Linear SHAP matches closed-form |
| Multi-output | Correct shape for multiclass |
| Missing stats | Proper error when covers unavailable |
| Importance types | All five types compute correctly |
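A hedged sketch of the sum-property test; `train_small_model`, `small_dataset`, and `predict` are hypothetical fixtures, not APIs defined in this RFC:

```rust
#[test]
fn shap_plus_base_matches_prediction() {
    let model = train_small_model(); // hypothetical fixture
    let data = small_dataset();      // hypothetical fixture
    let shap = model.shap_values(&data).unwrap();
    let preds = model.predict(&data); // hypothetical prediction API
    assert!(shap.verify(&preds, 1e-4));
}
```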
Future Work#
- SHAP interaction values: feature × feature interactions
- Approximate SHAP: sampling-based, for faster computation on large models
- Examples and API polish: add end-to-end Python examples and stabilize the public surface (without changing the underlying algorithms)