Optimizing Federated Learning on Non-IID Data with Reinforcement Learning

The widespread deployment of machine learning applications in ubiquitous environments has sparked interest in exploiting the vast amount of data stored on mobile devices. To preserve data privacy, federated learning has been proposed to learn a shared model by performing distributed training locally on participating devices and aggregating the local models into a global one. However, due to the limited network connectivity of mobile devices, it is not practical for federated learning to perform model updates and aggregation on all participating devices in parallel. Moreover, data samples across devices are usually not independent and identically distributed (IID), which poses additional challenges to the convergence and speed of federated learning. In this paper, we propose Favor, an experience-driven control framework that intelligently chooses the client devices participating in each round of federated learning to counterbalance the bias introduced by non-IID data and to speed up convergence. Through both empirical and mathematical analysis, we observe an implicit connection between the distribution of training data on a device and the model weights trained on those data, which enables us to profile a device's data distribution from its uploaded model weights. We then propose a mechanism based on deep Q-learning that learns to select a subset of devices in each communication round so as to maximize a reward that encourages gains in validation accuracy and penalizes the use of more communication rounds. With extensive experiments performed in PyTorch, we show that the number of communication rounds required in federated learning can be reduced by up to 49% on the MNIST dataset, 23% on Fashion-MNIST, and 42% on CIFAR-10, compared with the Federated Averaging algorithm.
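The abstract describes Favor's core mechanism: a deep Q-learning agent that, in each communication round, scores candidate devices and selects the subset expected to raise validation accuracy fastest, under a reward that grows with accuracy and penalizes every extra round. The sketch below illustrates one way such a selection step could look in PyTorch; the state encoding, the exponential reward shape, the constants (base 64, target accuracy 0.99), and the helper names (QNet, round_reward, select_devices) are illustrative assumptions rather than the paper's exact formulation.

```python
import torch
import torch.nn as nn


class QNet(nn.Module):
    """Q-network: maps a state summarizing the global/local model weights
    to one Q-value per candidate device (assumed state encoding)."""

    def __init__(self, state_dim, num_devices, hidden=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, num_devices),
        )

    def forward(self, state):
        return self.net(state)


def round_reward(val_acc, target_acc=0.99, base=64.0):
    # Assumed reward shape: grows exponentially as validation accuracy
    # approaches the target; the trailing -1 penalizes every extra round.
    return base ** (val_acc - target_acc) - 1.0


def select_devices(qnet, state, k, epsilon=0.1):
    """Epsilon-greedy selection of the k devices with the highest Q-values."""
    with torch.no_grad():
        q_values = qnet(state).flatten()
    if torch.rand(1).item() < epsilon:
        return torch.randperm(q_values.numel())[:k].tolist()
    return torch.topk(q_values, k).indices.tolist()


# Toy usage: score 100 candidate devices from a 128-dimensional state
# and pick 10 of them for the next round of federated averaging.
qnet = QNet(state_dim=128, num_devices=100)
chosen = select_devices(qnet, torch.randn(128), k=10)
```

In a full system, the server would build the state from the weights uploaded by devices (for example, after dimensionality reduction), aggregate the chosen devices' updates with Federated Averaging, compute the reward from the new global model's validation accuracy, and store the transition for standard DQN training with a replay buffer and target network.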
