Coffee Chat Brewing AI Knowledge


[Paper] Self-Supervised Transformer for Sparse and Irregularly Sampled Multivariate Clinical Time-Series (ACM 2022)

Tipirneni, Sindhu, and Chandan K. Reddy. "Self-supervised transformer for sparse and irregularly sampled multivariate clinical time-series." ACM Transactions on Knowledge Discovery from Data (TKDD) 16.6 (2022): 1-17.

Paper Link

Points

Self-supervised Transformer for Time-Series (STraTS) model

  • Using observation triplets as time-series components: avoids the problems faced by aggregation and imputation methods for sparse and sporadic multivariate time-series

  • Continuous Value Embedding: encodes continuous time and variable values without the need for discretization

  • Transformer-based model: learns contextual triplet embeddings

  • Time series forecasting as a proxy task: leverages unlabeled data to learn more generalizable representations

Background

Problems

  • Multivariate time-series data are frequently observed in critical care settings and are typically characterized by sparsity (missing information) and irregular time intervals.
  • Existing approaches, such as aggregation or imputation of values, suppress the fine-grained information and add undesirable noise/overhead into the model.
  • The limited availability of labeled data is a common problem in healthcare applications.

The clinical domain presents a unique set of challenges:

  • Missingness and Sparsity: Not all the variables are observed for every patient. Also, the time-series matrices are very sparse.
  • Irregular time intervals and Sporadicity: Not all clinical variables are measured at regular time intervals; the measurements may occur sporadically in time.
  • Limited labeled data: labels are expensive to obtain and even scarcer for specific tasks.

Existing methods

  • Aggregation: could suppress important fine-grained information
  • Imputation/Interpolation: not reasonable, as it ignores the domain knowledge about each variable

Method

Self-supervised Transformer for Time-Series (STraTS)

Embeddings

Triplet Embeddings = Feature embedding + Value embedding + Time embedding \(T=\{(t_i, f_i, v_i)\}_{i=1}^{n},\qquad e_i=e_i^f+e_i^v+e_i^t\)

Continuous Value Embedding (CVE)

Encodes continuous feature values and observation times.

A one-to-many feed-forward network: \(\mathrm{FFN}(x) = U\tanh(Wx+b)\)

  • Feature embeddings $e_i^f$: obtained from a simple lookup table

  • Value embeddings $e_i^v$ and Time embeddings $e_i^t$: obtained through CVE (sketched below)
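
A minimal PyTorch sketch of these embeddings, assuming toy dimensions; `CVE` and `TripletEmbedding` are hypothetical module names, not the authors' code:

```python
import torch
import torch.nn as nn

class CVE(nn.Module):
    """Continuous Value Embedding: a one-to-many FFN, FFN(x) = U tanh(Wx + b)."""
    def __init__(self, emb_dim: int, hidden_dim: int):
        super().__init__()
        self.W = nn.Linear(1, hidden_dim)                    # Wx + b
        self.U = nn.Linear(hidden_dim, emb_dim, bias=False)  # U tanh(.)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n) scalars (values or times) -> (batch, n, emb_dim)
        return self.U(torch.tanh(self.W(x.unsqueeze(-1))))

class TripletEmbedding(nn.Module):
    """e_i = e_i^f + e_i^v + e_i^t for each observation triplet (t_i, f_i, v_i)."""
    def __init__(self, num_features: int, emb_dim: int, hidden_dim: int):
        super().__init__()
        self.feat_emb = nn.Embedding(num_features, emb_dim)  # lookup table for e^f
        self.value_cve = CVE(emb_dim, hidden_dim)            # e^v
        self.time_cve = CVE(emb_dim, hidden_dim)             # e^t

    def forward(self, times, feats, values):
        # times, values: (batch, n) floats; feats: (batch, n) integer feature ids
        return self.feat_emb(feats) + self.value_cve(values) + self.time_cve(times)
```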

Demographics Embedding

The prediction models performed better when demographics were processed separately. \(e^d = \tanh(W_2^d\,\tanh(W_1^d \mathbf{d} + b_1^d) + b_2^d) \in \mathbb{R}^d\), where the hidden layer has a dimension of 2d
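
A sketch of this two-layer demographics encoder; the 2d hidden size follows the formula above, while the module name and dimensions are assumptions:

```python
import torch
import torch.nn as nn

class DemographicsEmbedding(nn.Module):
    """e^d = tanh(W2^d tanh(W1^d d + b1^d) + b2^d), hidden layer of size 2d."""
    def __init__(self, demo_dim: int, d: int):
        super().__init__()
        self.fc1 = nn.Linear(demo_dim, 2 * d)  # W1^d, b1^d (hidden size 2d)
        self.fc2 = nn.Linear(2 * d, d)         # W2^d, b2^d

    def forward(self, demo: torch.Tensor) -> torch.Tensor:
        # demo: (batch, demo_dim) -> e^d: (batch, d)
        return torch.tanh(self.fc2(torch.tanh(self.fc1(demo))))
```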

Self-Supervision

Pre-training Tasks: Both masking and forecasting were explored as pretext tasks for providing self-supervision

Forecasting improved the results on the target tasks

The loss is: \(L_{ss}=\frac{1}{N'}\sum_{k=1}^{N'}\sum_{j=1}^{|F|}m_j^k\big(\tilde{z}_j^k-z_j^k\big)^2\)
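
A minimal sketch of this masked MSE loss, where `mask` plays the role of \(m_j^k\) and zeroes out variables that were unobserved in the forecast window:

```python
import torch

def self_supervision_loss(z_pred: torch.Tensor,
                          z_true: torch.Tensor,
                          mask: torch.Tensor) -> torch.Tensor:
    """Masked MSE over N' pretraining samples and |F| variables.

    z_pred, z_true, mask: (N', |F|) tensors; mask[k, j] is 1 only when
    variable j was actually observed in the forecast window of sample k.
    """
    per_sample = (mask * (z_pred - z_true) ** 2).sum(dim=1)  # inner sum over j
    return per_sample.sum() / z_pred.shape[0]                # mean over the N' samples
```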

Interpretability

I-STraTS: an interpretable version of STraTS

  • The output can be expressed using a linear combination of components that are derived from individual features

Differences from STraTS

  • Combines the initial triplet embeddings (rather than the contextual ones) in the fusion self-attention module
  • Directly uses the raw demographics vector as the demographics embedding
\[\tilde{y}=\mathrm{sigmoid}\Big(\sum_{j=1}^{D}\mathbf{w}_o[j]\,\mathbf{d}[j]+\sum_{i=1}^{n}\sum_{j=1}^{d}\alpha_i\,\mathbf{w}_o[j+D]\,\mathbf{e}_i[j]+b_o\Big)\]
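
A sketch of this interpretable output layer, which also returns the per-component contribution scores that make I-STraTS explainable; tensor shapes and the function name are assumptions:

```python
import torch

def istrats_output(d: torch.Tensor, e: torch.Tensor, alpha: torch.Tensor,
                   w_o: torch.Tensor, b_o: torch.Tensor):
    """Interpretable output layer for one sample.

    d:     (D,)        raw demographics vector
    e:     (n, dim)    initial triplet embeddings
    alpha: (n,)        fusion self-attention weights
    w_o:   (D + dim,)  output weights; b_o: scalar bias
    """
    D = d.shape[0]
    demo_contrib = w_o[:D] * d            # one contribution per demographic feature
    trip_contrib = alpha * (e @ w_o[D:])  # alpha_i * sum_j w_o[j+D] e_i[j], per triplet
    y = torch.sigmoid(demo_contrib.sum() + trip_contrib.sum() + b_o)
    return y, demo_contrib, trip_contrib
```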

Experiments

Target Task: Prediction of in-hospital mortality

Datasets: 2 EHR datasets; MIMIC-III and PhysioNet Challenge 2012

  • MIMIC-III: 46,000 patients
  • PhysioNet-2012: 11,988 patients

Baselines: Gated Recurrent Unit (GRU), Temporal Convolutional Network (TCN), Simply Attend and Diagnose (SAnD), GRU with trainable Decays (GRU-D), Interpolation-prediction Network (InterpNet), Set Functions for Time Series (SeFT)

  • For the baselines, demographics were encoded with 2 dense layers
  • The encoded demographics were concatenated to the time-series representation before the last dense layer

Metrics

  • ROC-AUC: Area under ROC curve
  • PR-AUC: Area under precision-recall curve
  • min(Re, Pr): the maximum, over all classification thresholds, of the minimum of recall and precision (a code sketch follows this list)
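
All three metrics can be computed with scikit-learn and NumPy; a minimal sketch:

```python
import numpy as np
from sklearn.metrics import (average_precision_score, precision_recall_curve,
                             roc_auc_score)

def evaluate(y_true: np.ndarray, y_score: np.ndarray) -> dict:
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    return {
        "ROC-AUC": roc_auc_score(y_true, y_score),
        # average precision is the standard estimator of the PR-curve area
        "PR-AUC": average_precision_score(y_true, y_score),
        # best achievable min(recall, precision) across all thresholds
        "min(Re, Pr)": float(np.max(np.minimum(precision, recall))),
    }
```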

Prediction Performance

  • Trained each model using 10 different random samplings of 50% labeled data from the train and validation sets
  • STraTS uses the entire labeled data and additional unlabeled data if available
  • STraTS achieves the best performance
  • GRU showed better performance than interpolation-based models (GRU-D, InterpNet) on the MIMIC-III dataset, which was not expected

Generalizability test of models

Lower proportions of labeled data can be observed in the real world when there are several right-censored samples.

  • STraTS has an advantage compared to others in scarce labeled data settings, which can be attributed to self-supervision

Ablation Study

Compared STraTS and I-STraTS with and without self-supervision: 'ss+' and 'ss-' indicate each case

  • I-STraTS showed slightly worse performance, as its representations are constrained
  • Adding self-supervision improves performance of both models
  • I-STraTS(ss+) outperforms STraTS(ss-): self-supervision can compensate for the performance that may be lost by introducing interpretability

Interpretability

How I-STraTS explains its predictions

A case study: an 85-year-old female patient from MIMIC-III

  • expired on the 6th day after ICU admission
  • had 380 measurements corresponding to 58 time-series variables

The model predicts the probability of her in-hospital mortality as 0.94 using only the data collected on the first day

  • Average contribution score: the average score for each variable, reported with the value range when a variable has multiple observations, or with the single value when it has only one
  • The top 5 variables are the most important factors the model observed in predicting that she is at high risk of mortality

→ Can be helpful to identify high-risk patients, understand the contributing factors, and make better diagnoses, especially at the early stages of treatment