[Paper] Self-Supervised Transformer for Sparse and Irregularly Sampled Multivariate Clinical Time-Series (ACM 2022)
15 Sep 2023 #bio #ehr #transformer
Tipirneni, Sindhu, and Chandan K. Reddy. "Self-supervised transformer for sparse and irregularly sampled multivariate clinical time-series." ACM Transactions on Knowledge Discovery from Data (TKDD) 16.6 (2022): 1-17.
Points
Self-supervised Transformer for Time-Series (STraTS) model
- Using observation triplets as time-series components: avoids the problems faced by aggregation and imputation methods for sparse and sporadic multivariate time-series
- Continuous Value Embedding: encodes continuous time and variable values without the need for discretization
- Transformer-based model: learns contextual triplet embeddings
- Time series forecasting as a proxy task: leverages unlabeled data to learn better generalized representations
Background
Problems
- Multivariate time-series data are frequently observed in critical care settings and are typically characterized by sparsity (missing information) and irregular time intervals.
- Existing approaches, such as aggregation or imputation of values, suppress the fine-grained information and add undesirable noise/overhead into the model.
- The limited availability of labeled data is a common problem in healthcare applications.
The clinical domain presents a unique set of challenges:
- Missingness and Sparsity: Not all the variables are observed for every patient. Also, the time-series matrices are very sparse.
- Irregular time intervals and Sporadicity: not all clinical variables are measured at regular time intervals, and measurements may occur sporadically in time depending on the patient's condition.
- Limited labeled data: labels are expensive to obtain and even scarcer for specific tasks.
Existing methods
- Aggregation: could suppress important fine-grained information
- Imputation/Interpolation: not reasonable because it does not take into account the domain knowledge about each variable
Method
Self-supervised Transformer for Time-Series (STraTS)
Embeddings
Triplet Embeddings = Feature embedding + Value embedding + Time embedding \(T=\{(t_i, f_i, v_i)\}^n_{i=1},\ e_i=e_i^f+e_i^v+e_i^t\)
Continuous Value Embedding (CVE)
For continuous feature values and times
A one-to-many feed-forward network \(\text{FFN}(x) = U\tanh(Wx+b)\)
- Feature embeddings $e_i^f$: obtained from a simple lookup table
- Value embeddings $e_i^v$ and time embeddings $e_i^t$: obtained through CVE (see the sketch below)
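A minimal PyTorch sketch of CVE and the triplet embedding, assuming a single hidden layer as in the FFN above; the class names and the hidden width are my choices, not the paper's:

```python
import math
import torch
import torch.nn as nn

class CVE(nn.Module):
    """Continuous Value Embedding: a one-to-many FFN(x) = U tanh(Wx + b)
    that maps a scalar (a value or a time) to a d-dimensional embedding."""
    def __init__(self, d_model: int, hidden: int = None):
        super().__init__()
        hidden = hidden or max(1, int(math.sqrt(d_model)))  # width is a free choice
        self.W = nn.Linear(1, hidden)                    # Wx + b
        self.U = nn.Linear(hidden, d_model, bias=False)  # U tanh(.)

    def forward(self, x):
        # x: (batch, n) scalars -> (batch, n, d_model) embeddings
        return self.U(torch.tanh(self.W(x.unsqueeze(-1))))

class TripletEmbedding(nn.Module):
    """e_i = e_i^f + e_i^v + e_i^t"""
    def __init__(self, n_features: int, d_model: int):
        super().__init__()
        self.feat = nn.Embedding(n_features, d_model)  # simple lookup table
        self.value = CVE(d_model)
        self.time = CVE(d_model)

    def forward(self, t, f, v):
        # t, v: (batch, n) floats; f: (batch, n) integer feature indices
        return self.feat(f) + self.value(v) + self.time(t)
```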
Demographics Embedding
The prediction models performed better when demographics were processed separately. \(e^d = \tanh(W^d_2\tanh(W^d_1\mathbf{d} + b^d_1) + b^d_2) \in \mathbb{R}^d\), where the hidden layer has a dimension of \(2d\)
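As a sketch, this encoder is just a two-layer tanh MLP with hidden width \(2d\) (the function name is mine):

```python
import torch.nn as nn

def demographics_encoder(d_in: int, d: int) -> nn.Sequential:
    # e^d = tanh(W2 tanh(W1 x + b1) + b2), hidden layer of dimension 2d
    return nn.Sequential(
        nn.Linear(d_in, 2 * d), nn.Tanh(),
        nn.Linear(2 * d, d), nn.Tanh(),
    )
```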
Self-Supervision
Pre-training tasks: both masking and forecasting were tried as pretext tasks for providing self-supervision
Forecasting improved the results on the target tasks
The loss is: \(L_{ss}=\frac{1}{N'}\sum_{k=1}^{N'}\sum_{j=1}^{|F|}m_j^k\big(\tilde{z}_j^k-z_j^k\big)^2\), where \(m_j^k\) masks out variables that are not observed in the forecast window
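A minimal sketch of this masked-MSE forecasting loss (the tensor names are mine):

```python
import torch

def self_supervised_loss(z_pred: torch.Tensor,
                         z_true: torch.Tensor,
                         mask: torch.Tensor) -> torch.Tensor:
    # z_pred, z_true, mask: (N', |F|); the mask m_j^k is 1 only for
    # variables actually observed in the forecast window.
    se = mask * (z_pred - z_true) ** 2   # masked squared error
    return se.sum(dim=1).mean()          # sum over features, mean over samples
```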
Interpretability
I-STraTS: an interpretable version of STraTS
- The output can be expressed using a linear combination of components that are derived from individual features
Differences from STraTS
- Combines the initial triplet embeddings, rather than the contextual ones, in the fusion self-attention module (see the sketch below)
- Directly uses the raw demographics vector as the demographics embedding
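A hedged sketch of how such a linearly decomposable output could look (module and variable names are mine, not the paper's): attention weights pool the initial triplet embeddings, and with a linear output head the logit splits into per-observation contribution scores.

```python
import torch
import torch.nn as nn

class InterpretableHead(nn.Module):
    """Pools initial triplet embeddings with attention; the logit is
    sum_i a_i * (w . e_i + b), i.e., one contribution score per triplet."""
    def __init__(self, d_model: int):
        super().__init__()
        self.att = nn.Linear(d_model, 1)  # attention scorer
        self.out = nn.Linear(d_model, 1)  # linear output head

    def forward(self, e: torch.Tensor, pad_mask: torch.Tensor):
        # e: (batch, n, d) initial triplet embeddings
        # pad_mask: (batch, n) bool, True for real observations
        s = self.att(e).squeeze(-1).masked_fill(~pad_mask, float("-inf"))
        a = torch.softmax(s, dim=-1)           # fusion attention weights
        contrib = a * self.out(e).squeeze(-1)  # per-triplet contributions
        return contrib.sum(-1), contrib        # logit and its decomposition
```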
Experiments
Target Task: Prediction of in-hospital mortality
Datasets: 2 EHR datasets; MIMIC-III and PhysioNet Challenge 2012
- MIMIC-III: 46,000 patients
- PhysioNet-2012: 11,988 patients
Baselines: Gated Recurrent Unit (GRU), Temporal Convolutional Network (TCN), Simply Attend and Diagnose (SAnD), GRU with trainable Decays (GRU-D), Interpolation-Prediction Network (InterpNet), Set Functions for Time Series (SeFT)
- Each baseline used 2 dense layers for demographics encoding
- The demographics encoding was concatenated to the time-series representation before the last dense layer
Metrics
- ROC-AUC: Area under ROC curve
- PR-AUC: Area under precision-recall curve
- min(Re, Pr): the maximum of min(recall, precision) across all thresholds (see the sketch below)
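A quick sketch of the min(Re, Pr) metric using scikit-learn (the function name is mine):

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

def min_re_pr(y_true, y_score) -> float:
    # Maximum over all thresholds of min(recall, precision)
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    return float(np.max(np.minimum(precision, recall)))
```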
Prediction Performance
- Trained each model using 10 different random samplings of 50% labeled data from the train and validation sets
- STraTS uses the entire labeled data and additional unlabeled data if available
- STraTS achieves the best performance
- GRU showed better performance than interpolation-based models (GRU-D, InterpNet) on the MIMIC-III dataset, which was not expected
Generalizability test of models
Lower proportions of labeled data can be observed in real-world settings when there are many right-censored samples.
- STraTS has an advantage compared to others in scarce labeled data settings, which can be attributed to self-supervision
Ablation Study
Compared STraTS and I-STraTS with and without self-supervision: "ss+" and "ss-" indicate each case
- I-STraTS showed slightly worse performance because interpretability constrains its representations
- Adding self-supervision improves the performance of both models
- I-STraTS(ss+) outperforms STraTS(ss-): self-supervision can compensate for the performance loss introduced by interpretability
Interpretability
How I-STraTS explains its predictions
A case study: an 85-year-old female patient from MIMIC-III
- expired on the 6th day after ICU admission
- had 380 measurements corresponding to 58 time-series variables
The model predicts her probability of in-hospital mortality as 0.94 using only the data collected on the first day
- Average contribution score: reported along with the value range for variables with multiple observations, or with the single value for variables observed only once
- The top 5 variables are the most important factors the model observed in predicting that she is at high risk of mortality
→ Can be helpful for identifying high-risk patients, understanding the contributing factors, and making better diagnoses, especially at the early stages of treatment