Tutorial 05: Early Stopping#

🟡 Intermediate — Familiarity with ML concepts helpful

Learn how to use early stopping to prevent overfitting and find the optimal number of trees.

What you’ll learn#

  1. Understand the overfitting problem

  2. Use validation-based early stopping

  3. Visualize training curves

  4. Choose early stopping parameters

[1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

from boosters.sklearn import GBDTRegressor

The Overfitting Problem#

Without early stopping, adding more trees can lead to overfitting:

[2]:
# Generate data with some noise
X, y = make_regression(n_samples=500, n_features=10, noise=10.0, random_state=42)

# Split data
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

print(f"Train: {len(X_train)}, Validation: {len(X_val)}, Test: {len(X_test)}")
Train: 300, Validation: 100, Test: 100

Training Without Early Stopping#

[3]:
# Train with many trees
model_no_es = GBDTRegressor(
    n_estimators=500,
    max_depth=6,
    learning_rate=0.1,
)
model_no_es.fit(X_train, y_train)

# Evaluate
train_rmse = np.sqrt(mean_squared_error(y_train, model_no_es.predict(X_train)))
val_rmse = np.sqrt(mean_squared_error(y_val, model_no_es.predict(X_val)))
test_rmse = np.sqrt(mean_squared_error(y_test, model_no_es.predict(X_test)))

print(f"Without early stopping:")
print(f"  Train RMSE: {train_rmse:.4f}")
print(f"  Val RMSE:   {val_rmse:.4f}")
print(f"  Test RMSE:  {test_rmse:.4f}")
Without early stopping:
  Train RMSE: 0.0747
  Val RMSE:   62.7528
  Test RMSE:  64.3740
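
The near-zero training RMSE next to a validation RMSE of roughly 63 is the signature of overfitting: the model has effectively memorized the training set. A quick way to quantify this is the gap between the two scores (a minimal check on the values computed above, not an original cell of this tutorial):

[ ]:
# Quantify the generalization gap using the RMSE values from the previous cell
gap = val_rmse - train_rmse
ratio = val_rmse / max(train_rmse, 1e-12)  # guard against a zero training error

print(f"Generalization gap (val - train): {gap:.4f}")
print(f"Val/train RMSE ratio: {ratio:.1f}x")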

Early Stopping in Action#

Use early stopping to find the optimal number of iterations.

Note: Native early stopping support depends on the API version. Here we demonstrate the concept by training a separate model for each candidate number of trees; a sketch of what a native early-stopping API looks like follows the results below:

[4]:
# Track errors across iterations
n_estimators_list = [10, 25, 50, 100, 150, 200, 300, 400, 500]
train_errors = []
val_errors = []

for n_est in n_estimators_list:
    model = GBDTRegressor(
        n_estimators=n_est,
        max_depth=6,
        learning_rate=0.1,
    )
    model.fit(X_train, y_train)

    train_rmse = np.sqrt(mean_squared_error(y_train, model.predict(X_train)))
    val_rmse = np.sqrt(mean_squared_error(y_val, model.predict(X_val)))

    train_errors.append(train_rmse)
    val_errors.append(val_rmse)

# Find best iteration
best_idx = np.argmin(val_errors)
best_n_estimators = n_estimators_list[best_idx]
print(f"Best n_estimators: {best_n_estimators} (val RMSE: {val_errors[best_idx]:.4f})")
Best n_estimators: 500 (val RMSE: 62.7528)
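
Retraining from scratch for each candidate works, but it is expensive: libraries with native early stopping monitor a validation metric during a single training run and stop once it stops improving. As a point of comparison only (this uses scikit-learn's own estimator, not the boosters API), GradientBoostingRegressor exposes this through validation_fraction, n_iter_no_change, and tol:

[ ]:
from sklearn.ensemble import GradientBoostingRegressor

# Native early stopping: hold out 20% of the training data internally and stop
# once the validation score has not improved by `tol` for 10 consecutive rounds.
gbr = GradientBoostingRegressor(
    n_estimators=500,
    max_depth=6,
    learning_rate=0.1,
    validation_fraction=0.2,
    n_iter_no_change=10,
    tol=1e-4,
    random_state=42,
)
gbr.fit(X_train, y_train)

print(f"Trees actually fit: {gbr.n_estimators_}")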

Visualize Training Curves#

[5]:
plt.figure(figsize=(10, 6))
plt.plot(n_estimators_list, train_errors, 'b-', label='Train RMSE', marker='o')
plt.plot(n_estimators_list, val_errors, 'r-', label='Validation RMSE', marker='s')
plt.axvline(x=best_n_estimators, color='g', linestyle='--', label=f'Best: {best_n_estimators}')
plt.xlabel('Number of Trees')
plt.ylabel('RMSE')
plt.title('Training vs Validation Error')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
[Figure: train and validation RMSE versus number of trees, with a dashed vertical line marking the best n_estimators]

Train with Optimal Number of Trees#

[6]:
# Train with optimal n_estimators
model_optimal = GBDTRegressor(
    n_estimators=best_n_estimators,
    max_depth=6,
    learning_rate=0.1,
)
model_optimal.fit(X_train, y_train)

# Evaluate on test
test_rmse_optimal = np.sqrt(mean_squared_error(y_test, model_optimal.predict(X_test)))

print(f"With optimal trees ({best_n_estimators}):")
print(f"  Test RMSE: {test_rmse_optimal:.4f}")
print(f"\nImprovement over 500 trees: {test_rmse - test_rmse_optimal:.4f}")
With optimal trees (500):
  Test RMSE: 64.3740

Improvement over 500 trees: 0.0000
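
In this run the validation error was still lowest at the largest setting tried, so the "optimal" count coincided with 500 trees and the improvement is zero; with noisier data, deeper trees, or a higher learning rate the validation curve typically turns back up well before the last iteration. The parameters most libraries expose for this decision are a patience (often called early_stopping_rounds) and a minimum-improvement tolerance. Below is a minimal sketch of such a rule applied to the val_errors recorded earlier; the function name and default values are illustrative, not part of any library API:

[ ]:
def best_with_patience(errors, patience=2, tol=1e-4):
    """Return the index of the best error, stopping once `patience`
    consecutive evaluations fail to improve on it by at least `tol`."""
    best_idx, best_err, stale = 0, errors[0], 0
    for i, err in enumerate(errors[1:], start=1):
        if err < best_err - tol:
            best_idx, best_err, stale = i, err, 0
        else:
            stale += 1
            if stale >= patience:
                break
    return best_idx

idx = best_with_patience(val_errors, patience=2)
print(f"Patience rule selects n_estimators = {n_estimators_list[idx]}")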

Summary#

In this tutorial, you learned:

  1. ✅ Adding more trees doesn’t always improve performance

  2. ✅ How to find the optimal number of iterations using validation data

  3. ✅ How to visualize train/validation curves

  4. ✅ The importance of a held-out validation set

Next Steps#