Rethinking modeling Alzheimer's disease progression from a multi-task learning perspective with deep recurrent neural network

Alzheimer's disease (AD) is a severe neurodegenerative disorder that usually starts slowly and progressively worsens. Predicting AD progression through longitudinal analysis of time series data has recently received increasing attention. However, training an accurate progression model for brain networks faces two major challenges: missing features and the small sample size of follow-up studies. We analyze the correlations among the multiple predictive tasks of AD progression at multiple time points and, on that basis, propose a new perspective on predicting AD progression: a multi-task learning framework that adaptively imputes missing values and predicts future progression over time from a subject's historical measurements. Progression is measured in terms of MRI volumetric measurements, trajectories of a cognitive score, and clinical status. Within this paradigm, we hypothesize that inherent correlations exist among (i) the prediction tasks of clinical diagnosis, cognition, and ventricular volume at each time point; (ii) the tasks of imputation and prediction; and (iii) the prediction tasks at multiple future time points. Based on these task correlations, we develop an end-to-end deep multi-task learning method that jointly improves missing-value imputation and prediction, and we design a balanced multi-task dynamic weight optimization for training it. With in-depth analysis and empirical evidence on the Alzheimer's Disease Neuroimaging Initiative (ADNI) dataset, we show the benefits and flexibility of the proposed multi-task learning model, especially for prediction at the M60 time point. The proposed approach achieves improvements of 5.6%, 5.7%, 4.0%, and 11.8% in mAUC, BCA, and MAE (ADAS-Cog13 and Ventricles), respectively.
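
The abstract does not spell out how the dynamic task weighting or the imputation step works, so the following is only a minimal sketch of the general idea, not the authors' method. It assumes a Dynamic Weight Average style rule (a known weighting scheme swapped in here for illustration) and a simple model-fill imputation in which a missing measurement is replaced by the model's own prediction. The function names `dwa_weights`, `fill_missing`, and `weighted_total`, and the `temperature` parameter, are hypothetical.

```python
import math


def dwa_weights(prev_losses, prev_prev_losses, temperature=2.0):
    """Dynamic Weight Average style weights (an assumed stand-in scheme).

    Each task's weight grows with the ratio of its last two losses, so a
    task whose loss is falling more slowly receives a larger weight.
    The weights are scaled to sum to the number of tasks.
    """
    ratios = [p / pp for p, pp in zip(prev_losses, prev_prev_losses)]
    exps = [math.exp(r / temperature) for r in ratios]
    total = sum(exps)
    num_tasks = len(ratios)
    return [num_tasks * e / total for e in exps]


def fill_missing(observed, predicted):
    """Model-fill imputation: keep observed values, and use the model's
    prediction wherever the observation is missing (encoded as None)."""
    return [p if o is None else o for o, p in zip(observed, predicted)]


def weighted_total(losses, weights):
    """Combine per-task losses into one training objective."""
    return sum(w * l for w, l in zip(weights, losses))


# Three illustrative tasks: diagnosis, cognitive score, ventricular volume.
# The loss values below are made up for the example.
weights = dwa_weights(prev_losses=[0.9, 0.5, 0.7],
                      prev_prev_losses=[1.0, 1.0, 1.0])
objective = weighted_total([0.9, 0.5, 0.7], weights)

# A visit where the second feature was not measured.
inputs = fill_missing(observed=[1.2, None, 0.4], predicted=[1.0, 0.8, 0.5])
```

In this sketch the first task, whose loss dropped least, ends up with the largest weight, which is one way to keep a single task from dominating joint training.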
