We demonstrate the effectiveness of VTD through experiments on a synthetic dataset, the MIMIC-III dataset, and the NACC dataset. The empirical results of these experiments suggest that VTD reduces confounding bias in ITE estimation. We compare VTD with three causal inference approaches: the G-formula, DSW, and TSD.
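We report the Root Mean Square Error (RMSE) between predicted and ground-truth outcomes to measure the models' performance on conventional prediction tasks. To evaluate ITE estimation, the most common measure is the Precision in Estimation of Heterogeneous Effect (PEHE) [33], defined as the mean squared error between the ground-truth and estimated ITEs, i.e.,
$$\epsilon_{\text{PEHE}}=\frac{1}{N}\sum_{i=1}^{N}\left(\tau^{(i)}-\hat{\tau}^{(i)}\right)^{2},$$
where \(\tau^{(i)}\) and \(\hat{\tau}^{(i)}\) denote the ground-truth and estimated ITE of patient \(i\), and \(N\) is the number of patients.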
However, in real-world datasets the counterfactual outcomes are never observed, so we use the influence-function-based PEHE (IF-PEHE), which approximates the true PEHE using functional derivatives (influence functions) of the PEHE [34].
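For concreteness, the two metrics can be sketched in plain Python as below. This is a minimal illustration, not the evaluation code used in the experiments: `rmse` and `pehe` are hypothetical helper names, the toy values are made up, and PEHE is computed exactly as the mean squared error over per-patient ITEs, which requires both potential outcomes and is therefore only possible on simulated data. IF-PEHE is not implemented here, and some works report the square root of this PEHE instead.

```python
import math


def rmse(y_true, y_pred):
    """Root mean square error between ground-truth and predicted outcomes."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("inputs must be non-empty and of equal length")
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def pehe(ite_true, ite_pred):
    """PEHE as the mean squared error between ground-truth and estimated ITEs.

    Needs both potential outcomes per patient, so it is only computable on simulated data.
    """
    if len(ite_true) != len(ite_pred) or not ite_true:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum((t - p) ** 2 for t, p in zip(ite_true, ite_pred)) / len(ite_true)


# Toy example (hypothetical values): ITE = y(1) - y(0) for each patient.
y1_true, y0_true = [3.0, 2.0, 5.0], [1.0, 2.0, 4.0]
y1_pred, y0_pred = [2.5, 2.0, 5.5], [1.0, 1.5, 4.0]
ite_true = [a - b for a, b in zip(y1_true, y0_true)]
ite_pred = [a - b for a, b in zip(y1_pred, y0_pred)]
print(round(rmse(y1_true, y1_pred), 4))  # 0.4082
print(pehe(ite_true, ite_pred))          # 0.25
```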
3.1 Datasets
3.1.1 The synthetic data
As introduced in the problem formulation section, the treatment assignment \(a_t^{(i)}\) at each time step \(t\) is determined by the confounders \(q_t^{(i)}\), which comprise the previous hidden confounders \(z_{t-1}^{(i)}\), the current time-varying covariates \(x_t^{(i)}\), and the static features \(c^{(i)}\). For each patient, \(x_t^{(i)}\) and \(z_t^{(i)}\) are generated at each time \(t\) through an autoregressive process that accounts for both their own history and the influence of previous treatment assignments. We therefore define the following equations to generate the covariates \(x\) and the hidden confounders \(z\):
$$\begin{aligned} x_{t,j}^{(i)} &= \frac{1}{p}\sum_{r=1}^{p}\left(\alpha_{r,j}\,x_{t-r,j}^{(i)}+\beta_{r}\,a_{t-r}^{(i)}\right)+\eta_{t},\\ z_{t,j}^{(i)} &= \frac{1}{p}\sum_{r=1}^{p}\left(\mu_{r,j}\,z_{t-r,j}^{(i)}+v_{r}\,a_{t-r}^{(i)}\right)+\epsilon_{t}, \end{aligned} \tag{16}$$
where \(x_{t,j}^{(i)}\) and \(z_{t,j}^{(i)}\) denote the \(j\)-th feature (column) of \(x_t^{(i)}\) and \(z_t^{(i)}\), respectively. For each \(j\), we draw \(\alpha_{r,j}, \mu_{r,j} \sim \mathcal{N}\left(1-(r/p), (1/p)^2\right)\), which control how much historical information from the last \(p\) time stamps is incorporated into the current representations; \(\beta_r, v_r \sim \mathcal{N}\left(0, 0.02^2\right)\) control the influence of previous treatment assignments; and \(\eta_t, \epsilon_t \sim \mathcal{N}\left(0, 0.01^2\right)\) are randomly sampled noise terms. We generate 1,000 treated samples and 3,000 control samples: for each treated sample, treatment starts at a randomly chosen time step, whereas all treatments are set to 0 for control samples. The confounders \(q_t^{(i)}\) at each time stamp \(t\) and the outcome \(y_{T+\tau}^{(i)}\) are then generated from the hidden confounders and the current covariates as follows:
$$\begin{aligned} q_t^{(i)} &= \gamma\,\frac{1}{t}\sum_{r=1}^{t} z_r^{(i)} + (1-\gamma)\, g\left(\left[x_t^{(i)}, c^{(i)}\right]\right),\\ y_{T+\tau}^{(i)} &= w^{\top} q_T^{(i)} + b, \end{aligned} \tag{17}$$
where the confounding factor \(\gamma\) controls the influence of the hidden confounders, and \(w \sim \mathcal{U}(-1, 1)\) and \(b \sim \mathcal{N}(0, 0.1)\) are the weights and bias of a linear model. The function \(g(\cdot)\) maps the concatenated feature vector \(\left[x_t^{(i)}, c^{(i)}\right]\) into the hidden space. In this study, we generate the samples with confounding factor \(\gamma = 0.1\), 100 covariates, and 10 time steps.
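The generative process in Eqs. (16) and (17) can be sketched in NumPy as below. This is a minimal illustration under stated assumptions, not the authors' generator. The following choices are hypothetical: the lag order \(p = 3\), a hidden dimension of 10 for \(z\), 5 static features, a random initial history, noise drawn independently per patient and feature, the rule that a treated sample stays treated once treatment starts, and \(g(\cdot)\) as a fixed random linear map followed by tanh. The 1,000/3,000 treated/control split, 100 covariates, 10 time steps, \(\gamma = 0.1\), and the coefficient distributions follow the text. The same `simulate_ar` function is reused for \(x\) (with \(\alpha, \beta, \eta\)) and for \(z\) (with \(\mu, v, \epsilon\)).

```python
import numpy as np


def simulate_treatments(n_treated, n_control, T, rng):
    """Binary treatments a[i, t]; control rows stay 0.

    Assumption: a treated sample starts treatment at a random step and stays treated.
    """
    a = np.zeros((n_treated + n_control, T))
    for i, s in enumerate(rng.integers(0, T, size=n_treated)):
        a[i, s:] = 1.0
    return a


def simulate_ar(a, dim, p, lag_sd, noise_sd, rng):
    """Eq. (16) for one feature block (x with alpha/beta/eta, or z with mu/v/epsilon)."""
    n, T = a.shape
    r = np.arange(1, p + 1)
    coef = rng.normal((1 - r / p)[:, None], 1 / p, size=(p, dim))  # alpha_{r,j} ~ N(1 - r/p, (1/p)^2)
    lag_w = rng.normal(0.0, lag_sd, size=p)                          # beta_r ~ N(0, lag_sd^2)
    v = np.zeros((n, T + p, dim))
    v[:, :p] = rng.normal(size=(n, p, dim))                # assumed random initial history
    a_pad = np.concatenate([np.zeros((n, p)), a], axis=1)  # no treatment before t = 0
    for t in range(p, T + p):
        hist = v[:, t - p:t][:, ::-1]        # lags r = 1..p, shape (n, p, dim)
        a_hist = a_pad[:, t - p:t][:, ::-1]  # lags r = 1..p, shape (n, p)
        ar_term = np.einsum("rj,nrj->nj", coef, hist)
        treat_term = (a_hist @ lag_w)[:, None]
        v[:, t] = (ar_term + treat_term) / p + rng.normal(0.0, noise_sd, size=(n, dim))
    return v[:, p:]


def confounders_and_outcome(z, x, c, gamma, rng):
    """Eq. (17): q_t = gamma * mean_{r<=t} z_r + (1 - gamma) * g([x_t, c]); y = w^T q_T + b."""
    n, T, d_z = z.shape
    z_mean = np.cumsum(z, axis=1) / np.arange(1, T + 1)[None, :, None]
    xc = np.concatenate([x, np.repeat(c[:, None, :], T, axis=1)], axis=-1)
    # Hypothetical g: a fixed random linear map followed by tanh.
    W_g = rng.normal(0.0, 1.0 / np.sqrt(xc.shape[-1]), size=(xc.shape[-1], d_z))
    q = gamma * z_mean + (1 - gamma) * np.tanh(xc @ W_g)
    w = rng.uniform(-1.0, 1.0, size=d_z)  # w ~ U(-1, 1)
    b = rng.normal(0.0, 0.1)              # b ~ N(0, 0.1), 0.1 read as a standard deviation
    return q, q[:, -1] @ w + b            # outcome is driven by q_T


rng = np.random.default_rng(0)
T, P, D_Z, D_C = 10, 3, 10, 5  # P, D_Z, D_C are assumed values
a = simulate_treatments(1000, 3000, T, rng)
x = simulate_ar(a, dim=100, p=P, lag_sd=0.02, noise_sd=0.01, rng=rng)
z = simulate_ar(a, dim=D_Z, p=P, lag_sd=0.02, noise_sd=0.01, rng=rng)
c = rng.normal(size=(a.shape[0], D_C))  # assumed static features
q, y = confounders_and_outcome(z, x, c, gamma=0.1, rng=rng)
print(x.shape, z.shape, q.shape, y.shape)  # (4000, 10, 100) (4000, 10, 10) (4000, 10, 10) (4000,)
```

Because \(w\) and \(b\) are drawn once, the outcome is a single linear readout of \(q_T\) shared by all patients. Setting `gamma=1.0` makes \(q_t\) equal to the running mean of \(z\), which is a quick sanity check on Eq. (17).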
3.1.2 The MIMIC-III dataset
Following a setting similar to that of Bica et al. [23], we constructed a dataset from the Medical Information Mart for Intensive Care III (MIMIC-III) database [35]. MIMIC-III contains more than 61,000 ICU admissions from 2001 to 2012, recording patients' demographics and temporal information, including vital signs, lab tests, and treatment decisions. From MIMIC-III, we extracted a cohort of 11,715 adult sepsis patients who fulfill the Sepsis-3 criteria [36].
We obtain 27 time-varying variables, namely vital signs (temperature, heart rate, systolic blood pressure, mean blood pressure (MBP), diastolic blood pressure, respiratory rate, and oxygen saturation (SpO2)) and lab tests (sodium, chloride, magnesium, glucose, blood urea nitrogen, creatinine, urine output, Glasgow coma scale, white blood cell count, bands, C-reactive protein, hemoglobin, hematocrit, anion gap, platelet count, partial thromboplastin time, prothrombin time, international normalized ratio, bicarbonate, and lactate), and 8 static variables (age, gender, race, metastatic cancer, diabetes, height, weight, and body mass index). We design two causal inference tasks based on two available treatments: vasopressors and mechanical ventilator (MV). For each treatment, we separately evaluate its causal effect on an important outcome of interest: MBP for vasopressors and SpO2 for mechanical ventilator. The remaining variables are treated as observed covariates.
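The task construction described above can be written as a small specification, shown here as a minimal Python sketch. The snake_case variable names are hypothetical shorthand for the listed measurements, not MIMIC-III item identifiers, and extraction from the raw tables is not shown; the sketch only encodes that each task pairs one treatment with one outcome and uses the remaining time-varying variables as covariates, alongside the 8 static variables.

```python
# Hypothetical shorthand names for the 27 time-varying variables listed in the text.
VITAL_SIGNS = ["temperature", "heart_rate", "sbp", "mbp", "dbp", "resp_rate", "spo2"]
LAB_TESTS = [
    "sodium", "chloride", "magnesium", "glucose", "bun", "creatinine", "urine_output",
    "gcs", "wbc", "bands", "crp", "hemoglobin", "hematocrit", "anion_gap", "platelets",
    "ptt", "pt", "inr", "bicarbonate", "lactate",
]
STATIC = ["age", "gender", "race", "metastatic_cancer", "diabetes", "height", "weight", "bmi"]
TIME_VARYING = VITAL_SIGNS + LAB_TESTS

# Each task pairs one treatment with one outcome variable.
TASKS = {
    "Vasopressor-MBP": ("vasopressor", "mbp"),
    "MV-SpO2": ("mechanical_ventilator", "spo2"),
}


def task_spec(name):
    """Treatment, outcome, and covariates (all other time-varying variables) for a task."""
    treatment, outcome = TASKS[name]
    covariates = [v for v in TIME_VARYING if v != outcome]
    return {"treatment": treatment, "outcome": outcome,
            "covariates": covariates, "static": list(STATIC)}


for name in TASKS:
    spec = task_spec(name)
    print(name, spec["treatment"], spec["outcome"], len(spec["covariates"]))
# Vasopressor-MBP vasopressor mbp 26
# MV-SpO2 mechanical_ventilator spo2 26
```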
3.1.3 The NACC dataset
Following a similar process, we constructed longitudinal data from the National Alzheimer's Coordinating Center (NACC) Uniform Data Set (UDS) [37]. The NACC-UDS is a database that, since 2005, has collected demographic, clinical, diagnostic, and neuropsychological data at 29 Alzheimer's Disease Centers (ADCs) from participants who had normal cognition, mild cognitive impairment (MCI), or dementia at baseline and are followed up annually. Using NACC-UDS data collected between June 2005 and June 2021, we formed two separate datasets of patients with different baseline conditions: (1) baseline-1, patients diagnosed with MCI and aged above 50; and (2) baseline-2, patients with normal cognition and aged above 65. We extracted 2,401 and 5,555 patients for baseline-1 and baseline-2, respectively, with over 268 variables; details of the variables are provided in Appendix A. We considered three treatments, namely statins, anti-hypertensives, and non-steroidal anti-inflammatory drugs (NSAIDs), and aim to estimate their effects on reducing the risk of Alzheimer's disease (AD).
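The two cohort definitions reduce to simple inclusion rules. The Python sketch below assumes a hypothetical per-participant record holding the baseline age and baseline diagnosis; the field names and diagnosis labels are illustrative rather than NACC-UDS variable codes, and the restriction to the June 2005 to June 2021 window is assumed to have been applied upstream.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Participant:
    """Hypothetical baseline record; not the NACC-UDS schema."""
    participant_id: str
    age_at_baseline: float
    baseline_diagnosis: str  # assumed labels: "normal", "mci", or "dementia"


def assign_cohort(p: Participant) -> Optional[str]:
    """Return the cohort a participant belongs to, or None if excluded."""
    if p.baseline_diagnosis == "mci" and p.age_at_baseline > 50:
        return "baseline-1"  # MCI at baseline, aged above 50
    if p.baseline_diagnosis == "normal" and p.age_at_baseline > 65:
        return "baseline-2"  # normal cognition at baseline, aged above 65
    return None


participants = [
    Participant("A", 55, "mci"),
    Participant("B", 60, "normal"),
    Participant("C", 72, "normal"),
    Participant("D", 80, "dementia"),
]
print({p.participant_id: assign_cohort(p) for p in participants})
# {'A': 'baseline-1', 'B': None, 'C': 'baseline-2', 'D': None}
```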
3.1.4 Results
Table 1 shows that VTD achieves the best performance on the synthetic data in terms of both RMSE (2.07 ± 0.12) and IF-PEHE (8.31 ± 1.59). This suggests that VTD's variational embedding effectively captures the information of hidden confounders within the temporal structure, resulting in more accurate ITE estimation. Furthermore, the deep representation-based models (DSW, TSD, and VTD) substantially improve over the G-formula baseline, which we attribute to their ability to handle complex, high-dimensional data by using neural networks as the underlying architecture.
Table 1
Performance comparison on the synthetic dataset in terms of RMSE and IF-PEHE.
| Model | RMSE | IF-PEHE |
|---|---|---|
| G-formula | 5.46 ± 0.11 | 30.42 ± 4.64 |
| DSW | 2.63 ± 0.05 | 10.28 ± 1.06 |
| TSD | 3.06 ± 0.14 | 23.65 ± 2.23 |
| VTD (ours) | 2.07 ± 0.12 | 8.31 ± 1.59 |
Table 2 evaluates VTD's deconfounding effectiveness under different strengths of the hidden confounders \(Z\), obtained by varying the confounding factor \(\gamma\). The setting is otherwise the same as in the previous synthetic experiment, and we report the RMSE of outcome prediction as the performance metric. As \(\gamma\) increases, the RMSE of every model increases, but VTD achieves the lowest RMSE at every \(\gamma\), and its margin over the strongest baseline (DSW) is larger under stronger confounding (3.07 at \(\gamma = 0.4\) and 1.55 at \(\gamma = 0.6\), versus 0.17 at \(\gamma = 0\)). Because all models are evaluated on the same data, we attribute this gain to VTD's more effective modeling of hidden confounders. These results suggest that conditioning on the hidden embedding learned by VTD yields more robust outcome predictions, which is consistent with reduced bias in ITE estimation.
Table 2
Performance comparison (i.e., RMSE) of models with different confounding factor γ on the synthetic data.
| Model | γ = 0 | γ = 0.2 | γ = 0.4 | γ = 0.6 |
|---|---|---|---|---|
| G-formula | 3.86 ± 0.12 | 7.41 ± 1.53 | 13.72 ± 3.29 | 16.43 ± 4.24 |
| DSW | 1.95 ± 0.05 | 3.28 ± 0.14 | 7.23 ± 0.35 | 9.17 ± 0.61 |
| TSD | 2.89 ± 0.10 | 5.51 ± 0.21 | 9.79 ± 0.78 | 11.65 ± 1.54 |
| VTD (ours) | 1.78 ± 0.08 | 2.96 ± 0.16 | 4.16 ± 0.21 | 7.62 ± 0.69 |
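VTD's margin over the best baseline at each \(\gamma\) can be recomputed directly from the reported means. The short Python sketch below uses only the mean RMSE values in Table 2 (the standard deviations are ignored, so a positive margin does not by itself establish a significant difference); `margin_over_best_baseline` is a hypothetical helper name.

```python
# Mean RMSE values from Table 2, keyed by model and confounding factor gamma.
RMSE = {
    "G-formula": {0.0: 3.86, 0.2: 7.41, 0.4: 13.72, 0.6: 16.43},
    "DSW": {0.0: 1.95, 0.2: 3.28, 0.4: 7.23, 0.6: 9.17},
    "TSD": {0.0: 2.89, 0.2: 5.51, 0.4: 9.79, 0.6: 11.65},
    "VTD": {0.0: 1.78, 0.2: 2.96, 0.4: 4.16, 0.6: 7.62},
}


def margin_over_best_baseline(table, ours="VTD"):
    """For each gamma, return (best baseline, its RMSE minus ours), rounded to 2 decimals."""
    result = {}
    for gamma, our_rmse in table[ours].items():
        baselines = {m: v[gamma] for m, v in table.items() if m != ours}
        best = min(baselines, key=baselines.get)
        result[gamma] = (best, round(baselines[best] - our_rmse, 2))
    return result


for gamma, (best, gap) in margin_over_best_baseline(RMSE).items():
    print(f"gamma={gamma}: best baseline {best}, VTD lower by {gap}")
# gaps: 0.17, 0.32, 3.07, 1.55 (DSW is the best baseline at every gamma)
```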
We also evaluate VTD on the benchmark MIMIC-III dataset. Because it is a real-world dataset, the true hidden confounders are unknown. Table 3 shows that VTD outperforms both TSD and the G-formula on all metrics. Compared with DSW, VTD achieves a lower RMSE in the Vasopressor-MBP setting (8.36 ± 0.14 vs. 8.55 ± 0.22) and a comparable RMSE in the MV-SpO2 setting (1.12 ± 0.13 vs. 1.06 ± 0.11), although DSW attains a lower IF-PEHE in both settings. We hypothesize that VTD's gains in outcome prediction stem from its time-aware Transformer backbone, which can learn patterns in the irregular elapsed time between consecutive events.
Table 3
Performance comparison on the MIMIC-III dataset.
| Model | Vasopressor-MBP RMSE | Vasopressor-MBP IF-PEHE | MV-SpO2 RMSE | MV-SpO2 IF-PEHE |
|---|---|---|---|---|
| G-formula | 12.53 ± 0.27 | 63.35 ± 5.43 | 1.57 ± 0.14 | 53.28 ± 5.21 |
| DSW | 8.55 ± 0.22 | 12.01 ± 2.33 | 1.06 ± 0.11 | 8.68 ± 1.54 |
| TSD | 9.34 ± 0.10 | 57.26 ± 4.71 | 1.23 ± 0.07 | 35.21 ± 4.85 |
| VTD (ours) | 8.36 ± 0.14 | 20.16 ± 2.10 | 1.12 ± 0.13 | 17.25 ± 1.78 |

Vasopressor-MBP: vasopressor on mean blood pressure
MV-SpO2: mechanical ventilator on SpO2
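Tables 4 and 5 report the performance of the four models on baseline-1 and baseline-2, respectively. VTD outperforms TSD and the G-formula on all metrics in both cohorts, and it achieves the lowest RMSE in four of the six tasks (DSW is slightly better on AD-Statin in both cohorts). For IF-PEHE, however, VTD does not consistently outperform DSW: DSW achieves the lowest IF-PEHE in five of the six tasks, and VTD is lowest only on AD-Statin in baseline-2.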
Table 4
Model performance comparison on the NACC dataset for the baseline-1* setting.

| Model | AD-Statin RMSE | AD-Statin IF-PEHE | AD-Anti-hypertensive RMSE | AD-Anti-hypertensive IF-PEHE | AD-NSAID RMSE | AD-NSAID IF-PEHE |
|---|---|---|---|---|---|---|
| G-formula | 22.16 ± 1.24 | 102.11 ± 12.21 | 21.46 ± 0.98 | 95.87 ± 10.67 | 25.23 ± 1.27 | 96.89 ± 11.54 |
| DSW | 9.43 ± 0.16 | 37.45 ± 2.12 | 11.37 ± 0.11 | 40.28 ± 3.27 | 9.25 ± 0.13 | 39.42 ± 3.54 |
| TSD | 13.72 ± 0.18 | 73.43 ± 3.21 | 13.27 ± 0.17 | 85.27 ± 3.69 | 15.26 ± 0.15 | 65.43 ± 4.78 |
| VTD (ours) | 9.79 ± 0.15 | 42.43 ± 3.25 | 9.42 ± 0.12 | 58.86 ± 4.18 | 7.43 ± 0.11 | 46.59 ± 5.35 |

*baseline-1: patients diagnosed with mild cognitive impairment (MCI) and aged above 50.
AD: Alzheimer's disease
NSAID: non-steroidal anti-inflammatory drug
Table 5
Model performance comparison on the NACC dataset for the baseline-2* setting.

| Model | AD-Statin RMSE | AD-Statin IF-PEHE | AD-Anti-hypertensive RMSE | AD-Anti-hypertensive IF-PEHE | AD-NSAID RMSE | AD-NSAID IF-PEHE |
|---|---|---|---|---|---|---|
| G-formula | 20.16 ± 1.19 | 84.28 ± 15.28 | 27.16 ± 1.35 | 89.29 ± 11.37 | 19.67 ± 1.14 | 99.17 ± 15.93 |
| DSW | 7.28 ± 0.31 | 43.57 ± 2.15 | 10.46 ± 0.19 | 38.27 ± 2.57 | 7.81 ± 0.16 | 25.48 ± 2.18 |
| TSD | 9.79 ± 0.15 | 64.35 ± 3.67 | 12.67 ± 0.21 | 74.28 ± 2.21 | 9.26 ± 0.13 | 59.37 ± 3.51 |
| VTD (ours) | 7.38 ± 0.24 | 37.25 ± 3.16 | 10.39 ± 0.27 | 40.74 ± 3.29 | 6.98 ± 0.12 | 32.52 ± 3.24 |

*baseline-2: patients with normal cognition and aged above 65.
AD: Alzheimer's disease
NSAID: non-steroidal anti-inflammatory drug
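To see how often each model attains the lowest error across the six NACC tasks, the best model per column of Tables 4 and 5 can be tallied from the reported means. The Python sketch below does this with a hypothetical helper, `tally_best`; like any comparison of means, it ignores the reported standard deviations.

```python
from collections import Counter

MODELS = ("G-formula", "DSW", "TSD", "VTD")
# Mean values from Tables 4 and 5, in MODELS order.
RESULTS = {
    ("baseline-1", "AD-Statin", "RMSE"): (22.16, 9.43, 13.72, 9.79),
    ("baseline-1", "AD-Statin", "IF-PEHE"): (102.11, 37.45, 73.43, 42.43),
    ("baseline-1", "AD-Anti-hypertensive", "RMSE"): (21.46, 11.37, 13.27, 9.42),
    ("baseline-1", "AD-Anti-hypertensive", "IF-PEHE"): (95.87, 40.28, 85.27, 58.86),
    ("baseline-1", "AD-NSAID", "RMSE"): (25.23, 9.25, 15.26, 7.43),
    ("baseline-1", "AD-NSAID", "IF-PEHE"): (96.89, 39.42, 65.43, 46.59),
    ("baseline-2", "AD-Statin", "RMSE"): (20.16, 7.28, 9.79, 7.38),
    ("baseline-2", "AD-Statin", "IF-PEHE"): (84.28, 43.57, 64.35, 37.25),
    ("baseline-2", "AD-Anti-hypertensive", "RMSE"): (27.16, 10.46, 12.67, 10.39),
    ("baseline-2", "AD-Anti-hypertensive", "IF-PEHE"): (89.29, 38.27, 74.28, 40.74),
    ("baseline-2", "AD-NSAID", "RMSE"): (19.67, 7.81, 9.26, 6.98),
    ("baseline-2", "AD-NSAID", "IF-PEHE"): (99.17, 25.48, 59.37, 32.52),
}


def tally_best(results):
    """Count, per metric, on how many tasks each model has the lowest mean error."""
    counts = {}
    for (_cohort, _task, metric), values in results.items():
        winner = MODELS[values.index(min(values))]
        counts.setdefault(metric, Counter())[winner] += 1
    return counts


for metric, counter in tally_best(RESULTS).items():
    print(metric, dict(counter))
# RMSE {'DSW': 2, 'VTD': 4}
# IF-PEHE {'DSW': 5, 'VTD': 1}
```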