Notes on Causal Survival Forest 🛟🌲

In this post, I provide summary notes on the paper “Estimating Heterogeneous Treatment Effects with Right-Censored Data via Causal Survival Forests” by Cui et al. (2023).

Motivation

How can we estimate heterogeneous treatment effects with right-censored data?
  1. Heterogeneous treatment effect (HTE) estimation plays a central role in data-driven personalization

  2. Existing methods often can’t handle censored survival outcomes, which are common in medical and business applications

Causal Survival Forests (CSF)

To address this challenge, the paper proposes causal survival forests (CSF):

  • An adaptation of the causal forest algorithm of Athey et al. (2019)

  • It adjusts for censoring using doubly robust estimating equations developed in the survival analysis literature

Advantages

  • Robust and computationally tractable; outperforms available baselines in the paper’s experiments

  • Good statistical properties – UCAN

Statistical Setting

Assume i.i.d. tuples $\{X_i, T_i, C_i, W_i\}$, where

  • $X_i \in \mathcal{X}$ denotes the covariates
  • $T_i \in \mathbb{R}_{+}$ is the survival time of the $i$th unit
  • $C_i \in \mathbb{R}_{+}$ is the censoring time (the time at which the $i$th unit is censored)
  • $W_i \in \{0,1\}$ denotes the treatment assignment

Using the potential outcomes framework, we posit potential outcomes $\{T_i(1), T_i(0)\}$ such that $T_i = T_i(W_i)$. Our goal is to estimate the conditional average treatment effect (CATE)

$$ \tau(x)=\mathbb{E}\left[y(T_i(1)) - y(T_i(0)) \mid X_i=x\right], $$

where $y(\cdot)$ is an outcome transformation, e.g.

  • $y(T) = T$

  • $y(T) = T \wedge h$ for the restricted mean survival time (RMST), where $h$ is a chosen maximum follow-up time (horizon)

  • $y(T) = 1\{T \ge h\}$ for the survival probability
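The three transformations can be written as a minimal NumPy sketch. The function names are illustrative and do not come from the paper or from any package. Only the last two satisfy the finite-horizon condition (Assumption 1 below), because $y(T) = T$ keeps changing for $T \ge h$.

```python
import numpy as np

def y_identity(T):
    """y(T) = T: the raw survival time (violates the finite-horizon assumption)."""
    return np.asarray(T, dtype=float)

def y_rmst(T, h):
    """y(T) = min(T, h): its mean is the restricted mean survival time up to h."""
    return np.minimum(np.asarray(T, dtype=float), h)

def y_survival(T, h):
    """y(T) = 1{T >= h}: its mean is the survival probability at time h."""
    return (np.asarray(T, dtype=float) >= h).astype(float)

# Example: three hypothetical survival times with horizon h = 3
T = np.array([0.5, 2.0, 7.0])
rmst_outcomes = y_rmst(T, h=3.0)
surv_outcomes = y_survival(T, h=3.0)
```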

To estimate $\tau(x)$, the main challenge is that $T_i$ is not always observable. We only observe:

  • the observed (possibly censored) time: $U_i = T_i \wedge C_i$

  • the non-censoring indicator: $\Delta_i = 1\{T_i \le C_i\}$

Based on Assumption 1 (see below), we define the effective non-censoring indicator as follows:

\begin{align} \Delta_i^h &= 1\{T_i \wedge h \le C_i\} \\ &\overset{(2)}{=} \Delta_i \vee 1\{U_i \ge h\} \end{align}

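Note that everything in eqn (2) is observed. We regard an observation with $\Delta_i^h = 1$ as a complete observation. Either its survival time is observed, or it is known to survive past $h$; in the second case $y(T_i) = y(h)$ by Assumption 1.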
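We can check eqn (2) numerically. The sketch below computes $\Delta_i^h$ from the observed $(U_i, \Delta_i)$ alone. It then compares the result with the latent definition $1\{T_i \wedge h \le C_i\}$ on simulated data. The exponential survival and censoring distributions and the helper name `effective_indicator` are assumptions made for illustration only.

```python
import numpy as np

def effective_indicator(U, delta, h):
    """Delta^h = Delta OR 1{U >= h}, computed from observed data only (eqn 2)."""
    U = np.asarray(U, dtype=float)
    return np.maximum(np.asarray(delta, dtype=int), (U >= h).astype(int))

# Hypothetical simulation: exponential survival and censoring times
rng = np.random.default_rng(0)
n, h = 10_000, 2.0
T = rng.exponential(scale=1.5, size=n)  # latent survival times
C = rng.exponential(scale=3.0, size=n)  # latent censoring times
U = np.minimum(T, C)                    # observed time
delta = (T <= C).astype(int)            # non-censoring indicator

delta_h = effective_indicator(U, delta, h)
# The observable version agrees exactly with the latent definition 1{min(T, h) <= C}
assert np.array_equal(delta_h, (np.minimum(T, h) <= C).astype(int))
print("complete fraction:", delta_h.mean(), "raw uncensored fraction:", delta.mean())
```

Using $\Delta_i^h$ instead of $\Delta_i$ can only increase the number of complete observations. Units censored after $h$ count as complete.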

Assumptions

In order to identify treatment effects, we need to rely on two sets of assumptions.

  • Assumptions 2-4 enable us to identify the causal effect of $W_i$ on $T_i$ in the absence of censoring

  • Assumptions 5-6 guarantee that censoring due to $C_i$ does not break these identification results

Assumption 1 (Finite Horizon)

$$ y(t) = y(h), \quad \forall \ t \ge h, \ 0<h<\infty $$

Assumption 2 (Potential Outcomes)

$$ \{T_i(1), T_i(0)\} \quad s.t. \quad T_i = T_i(W_i) \quad a.s. $$

Assumption 3 (Ignorability)

$$ \{T_i(1), T_i(0)\} \perp W_i \mid X_i $$

Assumption 4 (Overlap)

Propensity score $e(x) = \mathbb{P}(W_i = 1\mid X_i = x)$ is uniformly bounded away from 0 and 1,

$$ \eta_e \le e(x)\le 1- \eta_e, \quad 0< \eta_e \le \frac{1}{2} $$

Assumption 5 (Ignorable censoring)

Censoring is independent of survival time conditionally on treatment and covariates,

$$ T_i \perp C_i \mid W_i, X_i $$

Assumption 6 (Positivity)

$$ \mathbb{P}(C_i < h \mid W_i, X_i) \le 1- \eta_c, \quad 0<\eta_c\le1 $$

Causal Forests Without Censoring

How does causal forest work?

Essentially, we are running a “forest”-localized version of Robinson’s regression

$$ \tau(x):=\operatorname{lm}\left(Y_i-\hat{\mu}^{(-i)}\left(X_i\right) \sim W_i-\hat{e}^{(-i)}\left(X_i\right), \text { weights }=\textcolor{blue}{\alpha_i(x)}\right), $$

where $\textcolor{blue}{\alpha_i(x)}$ captures how “similar” the target point $x$ is to each training sample $X_i$
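In generalized random forests (Athey et al. 2019), these weights are $\alpha_i(x) = \frac{1}{B}\sum_{b=1}^B \frac{1\{X_i \in L_b(x)\}}{|L_b(x)|}$, where $L_b(x)$ is the leaf of tree $b$ that contains $x$. The sketch below computes them from leaf indices. The input arrays are hypothetical. It leaves out grf's honest sample splitting and causal splitting rules.

```python
import numpy as np

def forest_weights(train_leaves, x_leaves):
    """alpha_i(x) = (1/B) * sum_b 1{X_i in L_b(x)} / |L_b(x)|.

    train_leaves: (B, n) array, leaf id of each training sample in each tree.
    x_leaves:     (B,) array, leaf id of the target point x in each tree.
    """
    train_leaves = np.asarray(train_leaves)
    B, n = train_leaves.shape
    alpha = np.zeros(n)
    for b in range(B):
        same_leaf = train_leaves[b] == x_leaves[b]
        size = same_leaf.sum()
        if size > 0:  # skip trees whose leaf for x holds no training samples
            alpha[same_leaf] += 1.0 / size
    return alpha / B

# Two toy trees over four training samples; x lands in leaf 0 of tree 1 and leaf 1 of tree 2
alpha = forest_weights([[0, 0, 1, 1], [0, 1, 1, 1]], [0, 1])
```

The weights sum to one whenever every leaf containing $x$ is non-empty. In scikit-learn, `RandomForestRegressor.apply(X)` returns leaf indices with shape `(n_samples, n_estimators)`. Transposing that output gives the `train_leaves` layout used here. Note that such forests split on prediction error rather than on treatment-effect heterogeneity.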

Using the notation from the previous section, we estimate $\tau(x)$ by solving the following equation,

$$ \sum_{i=1}^n \alpha_i(x) \psi_{\hat{\tau}(x)}^{(c)}\left(X_i, y\left(T_i\right), W_i ; \hat{e}, \hat{m}\right)=0 \tag{cf} $$

where,

  • $$\psi_\tau^{(c)}\left(X_i, y\left(T_i\right), W_i ; \hat{e}, \hat{m}\right)=\left[W_i-\hat{e}\left(X_i\right)\right]\left[y\left(T_i\right)-\hat{m}\left(X_i\right)-\tau\left(W_i-\hat{e}\left(X_i\right)\right)\right]$$ is the orthogonal score function

  • $e(x) = \mathbb{P}(W_i = 1\mid X_i = x)$

  • $m(x) = \mathbb{E}[y(T_i) \mid X_i = x]$

  • $\hat{e}(X_i)$ and $\hat{m}(X_i)$ are estimates derived via cross-fitting
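Because $\psi^{(c)}_\tau$ is linear in $\tau$, eqn (cf) has the closed-form solution $\hat\tau(x) = \sum_i \alpha_i(x)(W_i-\hat e(X_i))(y(T_i)-\hat m(X_i)) \big/ \sum_i \alpha_i(x)(W_i-\hat e(X_i))^2$. Here is a minimal sketch, assuming the forest weights and cross-fit nuisance estimates are already available as arrays. The name `solve_cf` is hypothetical.

```python
import numpy as np

def solve_cf(alpha, W, y, e_hat, m_hat):
    """Solve sum_i alpha_i * (W_i - e_i) * (y_i - m_i - tau * (W_i - e_i)) = 0 for tau."""
    alpha, W, y, e_hat, m_hat = map(np.asarray, (alpha, W, y, e_hat, m_hat))
    w_res = W - e_hat  # treatment residual W_i - e_hat(X_i)
    y_res = y - m_hat  # outcome residual y(T_i) - m_hat(X_i)
    return np.sum(alpha * w_res * y_res) / np.sum(alpha * w_res ** 2)

# Toy check: if y - m_hat = 1.5 * (W - e_hat) exactly, the solution is tau = 1.5
rng = np.random.default_rng(0)
W = rng.integers(0, 2, size=100)
e_hat = np.full(100, 0.4)
m_hat = rng.normal(size=100)
y = m_hat + 1.5 * (W - e_hat)
tau_hat = solve_cf(np.full(100, 0.01), W, y, e_hat, m_hat)
```

This follows the displayed estimating equation literally, with no intercept. The weighted `lm` form above includes an intercept, so it would instead center both residuals with their $\alpha$-weighted means before taking the ratio.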

Adjusting for Censoring via Weighting

In the presence of censoring, the $T_i$ in equation (cf) is no longer observable.

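Simply ignoring censoring and building models only on complete observations (i.e., $\Delta_i^h = 1$) would lead to bias. Whether a unit is complete depends on its survival time, because units with longer $T_i \wedge h$ are more likely to be censored first.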

Simple Censoring Adjustment via IPCW

Define the conditional survival function as $$ S_w^C(s \mid x)=\mathbb{P}\left[C_i \geq s \mid W_i=w, X_i=x\right] $$ We have, $$ \mathbb{P}\left[\Delta_i^h=1 \mid X_i, W_i, T_i\right]=S_{W_i}^C\left(T_i \wedge h \mid X_i\right) $$

  • the LHS is the conditional probability of a complete observation (i.e., $\Delta_i^h = 1$)

  • the RHS is the conditional probability that the censoring time is at least $T_i \wedge h$ (this uses Assumption 5)

  • the above $\mathbb{P}(\Delta_i^h = 1 \mid \cdots)$ plays the same role as a propensity score: it is the probability of being observed as a complete case, just as $e(x)$ is the probability of being treated
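A quick Monte Carlo check of this identity follows. It assumes exponential censoring with a hypothetical rate $c$, independent of $T_i$, so that $S^C(s) = e^{-cs}$. For a fixed survival time $t$, the share of complete cases should then be close to $e^{-c\,(t \wedge h)}$.

```python
import numpy as np

rng = np.random.default_rng(0)
rate_c, h, n = 0.3, 2.0, 200_000  # hypothetical censoring rate, horizon, sample size

def complete_case_rate(t):
    """Monte Carlo estimate of P(Delta^h = 1 | T = t) under exponential censoring."""
    C = rng.exponential(scale=1.0 / rate_c, size=n)
    return np.mean(np.minimum(t, h) <= C)

for t in [0.5, 1.5, 4.0]:
    # simulated rate vs. the closed form S^C(min(t, h)) = exp(-rate_c * min(t, h))
    print(t, complete_case_rate(t), np.exp(-rate_c * min(t, h)))
```

Every $t \ge h$ gives the same rate. A long-surviving unit only needs to be followed up to the horizon $h$ to count as complete.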

The main idea of IPCW estimation is to only consider complete cases, but up-weight all complete observations by $1/S_{W_i}^C\left(T_i \wedge h \mid X_i\right)$ to compensate for censoring.

As a result, provided the censoring survival function $S_w^C$ is estimated consistently, IPCW estimators eliminate the censoring bias.
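Here is a minimal sketch of the weights $1/\hat S^C(T_i \wedge h)$. The censoring distribution is estimated by Kaplan-Meier, with censoring ($\Delta_i = 0$) treated as the event. This is a simplifying assumption: the marginal estimate ignores $W_i$ and $X_i$, whereas the paper's $S^C_w(s \mid x)$ is conditional. The helper names are hypothetical. On complete cases, $T_i \wedge h = U_i \wedge h$, so the weight is computable from observed data.

```python
import numpy as np

def censoring_km(U, delta):
    """Kaplan-Meier estimate of S^C(s) = P(C >= s), treating censoring (delta == 0) as the event."""
    U = np.asarray(U, dtype=float)
    cens = np.asarray(delta) == 0                   # censoring is the "event" here
    times = np.unique(U[cens])                      # sorted distinct censoring times
    if times.size == 0:                             # no censoring observed: S^C = 1
        return lambda s: np.ones_like(np.asarray(s, dtype=float))
    at_risk = np.array([np.sum(U >= t) for t in times])
    events = np.array([np.sum((U == t) & cens) for t in times])
    surv = np.cumprod(1.0 - events / at_risk)       # estimate of P(C > times[k])

    def S_c(s):
        k = np.searchsorted(times, s, side="left")  # number of censoring times strictly before s
        return np.where(k == 0, 1.0, surv[np.maximum(k - 1, 0)])
    return S_c

def ipcw_weights(U, delta, h):
    """1 / S^C(min(U, h)) for complete cases (Delta^h = 1), and 0 otherwise."""
    U = np.asarray(U, dtype=float)
    delta_h = np.maximum(np.asarray(delta), (U >= h).astype(int))
    S_c = censoring_km(U, delta)
    w = np.zeros(len(U))
    complete = delta_h == 1
    # On complete cases min(T, h) = min(U, h), so only observed data are needed
    w[complete] = 1.0 / S_c(np.minimum(U[complete], h))
    return w

weights = ipcw_weights([1, 2, 3, 4], [1, 0, 1, 0], h=3.5)
```

In this toy example, unit 2 is censored before $h$ and gets weight zero. Its probability mass is redistributed to the later complete cases.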

With IPCW, we estimate $\tau(x)$ by solving the following equation,

$$ \sum_{\left\{i: \Delta_i^h=1\right\}} \frac{\alpha_i(x)}{\hat{S}_{W_i}^C\left(T_i \wedge h \mid X_i\right)} \psi_{\hat{\tau}(x)}^{(c)}\left(X_i, y\left(T_i\right), W_i ; \hat{e}, \hat{m}\right)=0 \tag{IPCW} $$

Let’s compare eqn (cf) vs. eqn (IPCW):

  • For eqn (cf), we sum over all observations; for eqn (IPCW), we only sum over complete observations
  • For eqn (IPCW), we multiply the forest weight $\alpha_i(x)$ by $1/\hat{S}_{W_i}^C\left(T_i \wedge h \mid X_i\right)$
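Since the IPCW equation is still linear in $\tau$, it too has a closed form: the same ratio as for eqn (cf), restricted to complete cases and with weights $\alpha_i(x)/\hat S^C_{W_i}$. The sketch below assumes the IPCW weights are precomputed, set to 0 for incomplete cases. It also assumes incomplete outcomes may be stored as `NaN`. The function name is hypothetical.

```python
import numpy as np

def solve_ipcw(alpha, ipcw_w, W, y, e_hat, m_hat):
    """Solve the IPCW estimating equation for tau.

    ipcw_w: 1 / S_hat^C(min(T_i, h) | X_i) for complete cases and 0 otherwise,
    so incomplete cases (whose y(T_i) may be unknown) drop out of both sums.
    """
    alpha, ipcw_w, W, e_hat, m_hat = map(np.asarray, (alpha, ipcw_w, W, e_hat, m_hat))
    y = np.asarray(y, dtype=float)
    keep = ipcw_w > 0                        # complete cases only
    weight = alpha[keep] * ipcw_w[keep]      # alpha_i(x) / S_hat^C(...)
    w_res = W[keep] - e_hat[keep]
    y_res = y[keep] - m_hat[keep]
    return np.sum(weight * w_res * y_res) / np.sum(weight * w_res ** 2)

# Toy check: complete cases satisfy y - m_hat = 3 * (W - e_hat); incomplete outcomes are NaN
W = np.array([0, 1, 0, 1, 1, 0])
e_hat = np.full(6, 0.5)
m_hat = np.zeros(6)
y = 3.0 * (W - e_hat)
y[[1, 4]] = np.nan
ipcw_w = np.array([1.0, 0.0, 1.2, 1.5, 0.0, 2.0])
tau_hat = solve_ipcw(np.full(6, 1 / 6), ipcw_w, W, y, e_hat, m_hat)
```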

A Doubly Robust Correction

Two limitations of IPCW approach:

  1. They use only complete observations and throw away all observations with $\Delta_i^h = 0$, which may hurt efficiency
  2. IPCW-type methods are generally not robust to errors in the estimated censoring model; the Neyman orthogonality condition does not hold (Chernozhukov et al. 2018)

CSF Method

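The CSF method does not rely on plain IPCW. Instead, it makes the estimating equation robust to censoring with a doubly robust correction from the survival analysis literature, which also uses the observations with $\Delta_i^h = 0$.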

Recall that in the simplest case (without censoring), we solve $$ \sum_{i=1}^n \alpha_i(x) \psi_{\hat{\tau}(x)}^{(c)}\left(X_i, y\left(T_i\right), W_i ; \hat{e}, \hat{m}\right)=0 \tag{cf} $$ With censoring, we instead estimate $\tau(x)$ by solving the following equation,

$$ \sum_{i=1}^n \alpha_i(x) \psi_{\hat{\tau}(x)}\left(X_i, y\left(U_i\right), U_i \wedge h, W_i, \Delta_i^h ; \hat{e}, \hat{m}, \hat{\lambda}_w^C, \hat{S}_w^C, \hat{Q}_w\right)=0, $$ where the score function is

$$ \begin{aligned} \psi_\tau( & \left.X_i, y\left(U_i\right), U_i \wedge h, W_i, \Delta_i^h ; \hat{e}, \hat{m}, \hat{\lambda}_w^C, \hat{S}_w^C, \hat{Q}_w\right) \\ = & \left(\frac{\hat{Q}_{W_i}\left(U_i \wedge h \mid X_i\right)+\Delta_i^h\left[y\left(U_i\right)-\hat{Q}_{W_i}\left(U_i \wedge h \mid X_i\right)\right]-\hat{m}\left(X_i\right)-\tau\left(W_i-\hat{e}\left(X_i\right)\right)}{\hat{S}_{W_i}^C\left(U_i \wedge h \mid X_i\right)}\right. \\ & \left.-\int_0^{U_i \wedge h} \frac{\hat{\lambda}_{W_i}^C\left(s \mid X_i\right)}{\hat{S}_{W_i}^C\left(s \mid X_i\right)}\left[\hat{Q}_{W_i}\left(s \mid X_i\right)-\hat{m}\left(X_i\right)-\tau\left(W_i-\hat{e}\left(X_i\right)\right)\right] d s\right)\left(W_i-\hat{e}\left(X_i\right)\right) \end{aligned} $$

Here $Q_w$ is the conditional expectation of the transformed survival time,

$$ Q_w(s \mid x)=\mathbb{E}\left[y\left(T_i\right) \mid X_i=x, W_i=w, T_i \wedge h>s\right], $$

$\lambda_w^C$ is the conditional hazard function of the censoring time,

$$ \lambda_w^{\mathrm{C}}(s \mid x)=-\frac{d}{d s} \log S_w^{\mathrm{C}}(s \mid x), $$

and $\hat{Q}_w(s \mid x)$, $\hat{S}_w^C(s \mid x)$ and $\hat{\lambda}_w^C(s \mid x)$ are cross-fit nuisance parameter estimates.
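The CSF score is also linear in $\tau$, so we can write $\psi_\tau = a_i - \tau b_i$ and solve $\hat\tau(x) = \sum_i \alpha_i(x) a_i / \sum_i \alpha_i(x) b_i$. The sketch below evaluates $a_i$ and $b_i$ term by term from the displayed score. It uses trapezoidal integration on a grid over $[0, U_i \wedge h]$. The nuisance callables are hypothetical stand-ins: exponential censoring and a placeholder $Q$, not fitted models. This is an illustration, not the paper's or grf's implementation.

A quick calculation shows that when $\hat\lambda^C = -\frac{d}{ds}\log \hat S^C$ and $\hat S^C(0) = 1$, we get $\int_0^t \hat\lambda^C/\hat S^C\,ds = 1/\hat S^C(t) - 1$. Hence $b_i$ reduces to $(W_i - \hat e(X_i))^2$, and the example reproduces this numerically.

```python
import numpy as np

def _trapezoid(f, s):
    """Trapezoidal rule for samples f on grid s (avoids version-specific NumPy helpers)."""
    return float(np.sum((f[1:] + f[:-1]) * np.diff(s)) / 2.0)

def csf_score_terms(U, delta_h, W, y_U, e_hat, m_hat, S_c, lam_c, Q, h, n_grid=400):
    """Split the CSF score as psi_tau(i) = a[i] - tau * b[i].

    S_c(s, i), lam_c(s, i), Q(s, i) are hypothetical callables returning the cross-fit
    estimates S^C_{W_i}(s | X_i), lambda^C_{W_i}(s | X_i), Q_{W_i}(s | X_i); s may be an array.
    """
    n = len(U)
    a, b = np.empty(n), np.empty(n)
    for i in range(n):
        t = min(U[i], h)                          # U_i "wedge" h
        d = W[i] - e_hat[i]                       # W_i - e_hat(X_i)
        s = np.linspace(0.0, t, n_grid)
        ratio = lam_c(s, i) / S_c(s, i)           # lambda^C / S^C on the grid
        q_t = Q(t, i)
        y_dr = q_t + delta_h[i] * (y_U[i] - q_t)  # Q + Delta^h * (y(U) - Q)
        a[i] = d * ((y_dr - m_hat[i]) / S_c(t, i) - _trapezoid(ratio * (Q(s, i) - m_hat[i]), s))
        b[i] = d * d * (1.0 / S_c(t, i) - _trapezoid(ratio, s))
    return a, b

# Hypothetical nuisances: exponential censoring (rate 0.3) and a placeholder Q
rate_c, h = 0.3, 2.0
S_c = lambda s, i: np.exp(-rate_c * np.asarray(s, dtype=float))
lam_c = lambda s, i: np.full_like(np.asarray(s, dtype=float), rate_c)
Q = lambda s, i: np.minimum(np.asarray(s, dtype=float) + 1.0, h)

U = np.array([0.4, 1.0, 2.5, 3.0])
delta_h = np.array([1, 0, 1, 1])
W = np.array([1, 0, 1, 0])
y_U = np.minimum(U, h)                            # RMST transformation y(U) = min(U, h)
e_hat, m_hat = np.full(4, 0.5), np.full(4, 1.2)
a, b = csf_score_terms(U, delta_h, W, y_U, e_hat, m_hat, S_c, lam_c, Q, h)
alpha = np.full(4, 0.25)                          # placeholder forest weights
tau_hat = np.sum(alpha * a) / np.sum(alpha * b)
```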

References

Athey, Susan, Julie Tibshirani, and Stefan Wager. 2019. “Generalized Random Forests.” The Annals of Statistics 47 (2): 1148–78. https://doi.org/10.1214/18-AOS1709.

Chernozhukov, Victor, Denis Chetverikov, Mert Demirer, Esther Duflo, Christian Hansen, Whitney Newey, and James Robins. 2018. “Double/Debiased Machine Learning for Treatment and Structural Parameters.” The Econometrics Journal 21 (1): C1–68. https://doi.org/10.1111/ectj.12097.

Cui, Yifan, Michael R Kosorok, Erik Sverdrup, Stefan Wager, and Ruoqing Zhu. 2023. “Estimating Heterogeneous Treatment Effects with Right-Censored Data via Causal Survival Forests.” Journal of the Royal Statistical Society Series B: Statistical Methodology 85 (2): 179–211. https://doi.org/10.1093/jrsssb/qkac001.

Chen Xing
Founder & Data Scientist

Enjoy Life & Enjoy Work!
