scikit-survival 0.21.0 released

Today marks the release of scikit-survival 0.21.0. This release comes with exciting new features and significant performance improvements:

  • Pointwise confidence intervals for the Kaplan-Meier estimator.
  • Early stopping in GradientBoostingSurvivalAnalysis.
  • Improved performance of fitting SurvivalTree and RandomSurvivalForest.
  • Reduced memory footprint of concordance_index_censored.

Pointwise Confidence Intervals for the Kaplan-Meier Estimator

kaplan_meier_estimator() can now estimate pointwise confidence intervals by specifying the conf_type parameter.

import matplotlib.pyplot as plt
from sksurv.datasets import load_veterans_lung_cancer
from sksurv.nonparametric import kaplan_meier_estimator

_, y = load_veterans_lung_cancer()

# passing conf_type="log-log" additionally returns pointwise confidence intervals
time, survival_prob, conf_int = kaplan_meier_estimator(
    y["Status"], y["Survival_in_days"], conf_type="log-log"
)
plt.step(time, survival_prob, where="post")
# conf_int[0] holds the lower bounds, conf_int[1] the upper bounds
plt.fill_between(time, conf_int[0], conf_int[1], alpha=0.25, step="post")
plt.ylim(0, 1)
# raw string so that "\h" in $\hat{S}(t)$ is not treated as an escape sequence
plt.ylabel(r"est. probability of survival $\hat{S}(t)$")
plt.xlabel("time $t$")
Kaplan-Meier curve with pointwise confidence intervals.
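
For readers curious about what conf_type="log-log" computes, here is a from-scratch NumPy sketch of a Kaplan-Meier estimate with pointwise log-log intervals. It assumes the textbook construction: Greenwood's formula for the variance of log S(t), carried over to log(-log S(t)) via the delta method, giving a lower bound S(t)^exp(z·se) and an upper bound S(t)^exp(-z·se). scikit-survival's own implementation may differ in details such as edge cases, and the function name km_loglog_ci is hypothetical.

```python
import numpy as np
from statistics import NormalDist


def km_loglog_ci(event, time, conf_level=0.95):
    """Kaplan-Meier estimate with pointwise log-log confidence intervals.

    Hypothetical helper; a sketch assuming Greenwood's variance formula.
    """
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    uniq = np.unique(time)  # sorted distinct time points
    # number of events and number of samples at risk at each time point
    n_events = np.array([np.sum(event & (time == t)) for t in uniq])
    n_at_risk = np.array([np.sum(time >= t) for t in uniq])

    surv = np.cumprod(1.0 - n_events / n_at_risk)

    with np.errstate(divide="ignore", invalid="ignore"):
        # Greenwood's sum: sum over t_i <= t of d_i / (n_i * (n_i - d_i))
        greenwood = np.cumsum(n_events / (n_at_risk * (n_at_risk - n_events)))
        # standard error of log(-log S(t)) via the delta method
        se = np.sqrt(greenwood) / np.abs(np.log(surv))

    z = NormalDist().inv_cdf(0.5 + conf_level / 2.0)
    with np.errstate(invalid="ignore", over="ignore"):
        lower = surv ** np.exp(z * se)   # larger exponent -> smaller value
        upper = surv ** np.exp(-z * se)
    return uniq, surv, np.vstack([lower, upper])
```

Because the interval is built on the log(-log) scale and then transformed back, both bounds stay within [0, 1], unlike a plain symmetric interval around S(t).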

Early Stopping in GradientBoostingSurvivalAnalysis

Early stopping enables us to determine when the model is sufficiently complex. This is usually done by continuously evaluating the model on held-out data. For GradientBoostingSurvivalAnalysis, the easiest way to achieve this is to set n_iter_no_change and, optionally, validation_fraction, the fraction of the training data held out for validation (defaults to 0.1).

from sksurv.datasets import load_whas500
from sksurv.ensemble import GradientBoostingSurvivalAnalysis

X, y = load_whas500()

model = GradientBoostingSurvivalAnalysis(
    n_estimators=1000, max_depth=2, subsample=0.8, n_iter_no_change=10, random_state=0,
)

model.fit(X, y)
print(model.n_estimators_)

In this example, model.n_estimators_ indicates that fitting stopped after 73 iterations, instead of the maximum 1000 iterations.
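
The rule behind n_iter_no_change can be illustrated without fitting any model. The sketch below assumes a patience rule in the spirit of scikit-learn's gradient boosting, on which GradientBoostingSurvivalAnalysis builds: training continues as long as the current validation loss is lower, by more than a tolerance (called tol in scikit-learn), than at least one of the last n_iter_no_change validation losses. The helper name and the precomputed list of losses are hypothetical.

```python
import numpy as np


def n_iterations_fitted(val_losses, n_iter_no_change=10, tol=1e-4):
    """Number of boosting iterations fitted before a patience rule stops training.

    Hypothetical helper: `val_losses` holds one validation loss per iteration.
    """
    # the last n_iter_no_change validation losses, initialised to +inf
    history = np.full(n_iter_no_change, np.inf)
    for i, loss in enumerate(val_losses):
        if np.any(loss + tol < history):
            # still improves on at least one recent loss: remember it, continue
            history[i % n_iter_no_change] = loss
        else:
            # no improvement: stop after this iteration (i is 0-based)
            return i + 1
    return len(val_losses)
```

For example, with n_iter_no_change=3 and tol=0, the loss sequence 1.0, 0.9, 0.8, 0.8, 0.8, 0.8 and so on yields 6 fitted iterations: the fourth and fifth losses still beat the older values 1.0 and 0.9 kept in the history, while the sixth does not.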

Alternatively, one can pass a custom callback function to the fit method via its monitor argument. If the callback returns True, training is stopped.

model = GradientBoostingSurvivalAnalysis(
    n_estimators=1000, max_depth=2, subsample=0.8, random_state=0,
)

def early_stopping_monitor(iteration, model, args):
    """Stop training if there was no improvement in the last 10 iterations"""
    start = max(0, iteration - 10)
    end = iteration + 1
    oob_improvement = model.oob_improvement_[start:end]
    return all(oob_improvement < 0)

model.fit(X, y, monitor=early_stopping_monitor)
print(model.n_estimators_)

In the example above, early stopping is determined by checking the entries of the oob_improvement_ attribute for the current and the preceding 10 iterations. This attribute contains the improvement in loss on the out-of-bag samples relative to the previous iteration. Out-of-bag estimates are only available when subsample is set to a value smaller than 1, here 0.8. Using this approach, training stopped after 114 iterations.
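
The callback above looks at a window of the current iteration and up to 10 preceding ones, so during the first iterations it may stop as soon as the few available entries are all negative. A variant that waits until a full window of exactly `patience` iterations is available can be written as a callable class. This is a sketch: the class name OOBEarlyStopping is hypothetical, and it relies only on the oob_improvement_ attribute and the (iteration, model, args) call signature used above.

```python
import numpy as np


class OOBEarlyStopping:
    """Monitor for fit(..., monitor=...) that stops training once the
    out-of-bag improvement was negative in each of the last `patience`
    iterations. Hypothetical helper, not part of scikit-survival."""

    def __init__(self, patience=10):
        self.patience = patience

    def __call__(self, iteration, model, args):
        # wait until `patience` iterations (0 .. iteration) have been fitted
        if iteration + 1 < self.patience:
            return False
        start = iteration + 1 - self.patience
        recent = model.oob_improvement_[start : iteration + 1]
        # returning True tells the estimator to stop training
        return bool(np.all(recent < 0))
```

It would be passed like the function above, e.g. model.fit(X, y, monitor=OOBEarlyStopping(patience=10)). Because its window differs from the original callback, the number of fitted iterations may differ from the 114 reported above.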

Improved Performance of SurvivalTree and RandomSurvivalForest

Another exciting feature of scikit-survival 0.21.0 is a rewrite of the training routine of SurvivalTree, which makes training roughly 3x faster.

Runtime comparison of fitting SurvivalTree.

The plot above compares the time required to fit a single SurvivalTree on data with 25 features and a varying number of samples. The performance difference becomes noticeable for datasets with 1,000 samples or more.

Note that this improvement also speeds up fitting RandomSurvivalForest and ExtraSurvivalTrees.
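
To run a comparison like the one in the plot on your own hardware, a small timing harness is enough. The sketch below is not the benchmark behind the plot: make_synthetic_survival and median_fit_seconds are hypothetical helpers, the data are random features with exponentially distributed event and censoring times, and timings use time.perf_counter.

```python
import time

import numpy as np


def make_synthetic_survival(n_samples, n_features=25, seed=0):
    """Random features plus exponential event times with random censoring.

    Hypothetical data generator for benchmarking only.
    """
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n_samples, n_features))
    # the hazard depends on the first feature so trees have something to split on
    event_time = rng.exponential(scale=np.exp(-0.5 * X[:, 0]))
    censor_time = rng.exponential(scale=1.0, size=n_samples)
    event = event_time <= censor_time
    observed = np.minimum(event_time, censor_time)
    return X, event, observed


def median_fit_seconds(make_estimator, X, y, repeats=3):
    """Median wall-clock time of make_estimator().fit(X, y) over `repeats` runs."""
    timings = []
    for _ in range(repeats):
        est = make_estimator()
        start = time.perf_counter()
        est.fit(X, y)
        timings.append(time.perf_counter() - start)
    return float(np.median(timings))
```

With scikit-survival installed, the survival outcome would be assembled with sksurv.util.Surv.from_arrays(event, observed), and median_fit_seconds(SurvivalTree, X, y) could then be called for a range of sample sizes, once per installed version, to compare releases.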

Improved Concordance Index

Another performance improvement comes from Christine Poerschke, who significantly reduced the memory footprint of concordance_index_censored(). With scikit-survival 0.21.0, memory usage scales linearly, instead of quadratically, with the number of samples, making performance evaluation on large datasets much more manageable.
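
The difference between quadratic and linear memory can be illustrated with Harrell's concordance index. Instead of materialising an n × n matrix of pairwise comparisons, one can loop over the samples with an observed event and compare each one against all other samples at once, so only arrays of length n are alive at any time. The sketch below assumes a simplified definition: a pair (i, j) is comparable if time_i < time_j and sample i experienced an event, and ties in the predicted risk count as 0.5. It ignores tied event times, which scikit-survival handles, and it is not the library's implementation; the function name is hypothetical.

```python
import numpy as np


def concordance_index_linear_memory(event, time, risk):
    """Harrell's C with O(n) extra memory (simplified sketch, ignores tied times).

    A pair (i, j) is comparable if time[i] < time[j] and event[i] is True;
    it is concordant if risk[i] > risk[j]. Tied risks count as 0.5.
    """
    event = np.asarray(event, dtype=bool)
    time = np.asarray(time, dtype=float)
    risk = np.asarray(risk, dtype=float)

    concordant = 0.0
    comparable = 0
    for i in np.flatnonzero(event):
        # boolean mask of length n: samples that outlived sample i
        later = time > time[i]
        n_pairs = int(later.sum())
        if n_pairs == 0:
            continue
        comparable += n_pairs
        concordant += np.sum(risk[i] > risk[later])
        concordant += 0.5 * np.sum(risk[i] == risk[later])
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable
```

Runtime is still quadratic in this sketch; only the memory needed for the pairwise comparisons shrinks from n × n to n.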

For a full list of changes in scikit-survival 0.21.0, please see the release notes.

Install

Pre-built packages are available for Linux, macOS (Intel), and Windows, either via pip:

pip install scikit-survival

or via conda:

conda install -c sebp scikit-survival
Sebastian Pölsterl
AI Researcher

My research interests include machine learning for time-to-event analysis, causal inference and biomedical applications.
