Overcoming Forgetting in Federated Learning on Non-IID Data

Edgify’s Collaborative Method for Distributed Learning, to be Fully Released at NeurIPS this Year!


There is a growing interest today in training deep learning models on the edge. Algorithms such as Federated Averaging [1] (FedAvg) allow training on devices with high network latency by performing many local gradient steps before communicating their weights. However, the very nature of this setting is such that there is no control over the way the data is distributed on the devices.
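To make the setting concrete, here is a minimal sketch of a single FedAvg round in PyTorch-style code. The model, client data loaders and hyperparameters are placeholders, and real implementations typically weight each client's update by its dataset size; this is only an illustration of "many local steps, then average".

```python
import copy
import torch

def fedavg_round(global_model, client_loaders, local_epochs, lr=0.01):
    """One FedAvg round (sketch): every client trains locally for several
    epochs, then the server averages the resulting weights."""
    client_states = []
    criterion = torch.nn.CrossEntropyLoss()
    for loader in client_loaders:
        local_model = copy.deepcopy(global_model)   # start from the global weights
        optimizer = torch.optim.SGD(local_model.parameters(), lr=lr)
        for _ in range(local_epochs):               # many local steps before communicating
            for x, y in loader:
                optimizer.zero_grad()
                criterion(local_model(x), y).backward()
                optimizer.step()
        client_states.append(local_model.state_dict())
    # Server side: coordinate-wise average of the clients' weights
    avg_state = {k: torch.stack([s[k].float() for s in client_states]).mean(dim=0)
                 for k in client_states[0]}
    global_model.load_state_dict(avg_state)
    return global_model
```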

Consider, for instance, a smart checkout scale at a supermarket that has a camera mounted on it and some processing power. You want each scale to collect images of the fruits and vegetables being weighed, and you want the scales to collectively train a neural network to recognize these fruits and vegetables. Such an unconstrained environment almost always means that not all edge devices (in this case, scales) will have data from all the classes (in this case, fruits and vegetables). This is commonly referred to as a non-IID data distribution.

Training with FedAvg on non-IID data leads the locally trained models to “forget” the classes for which they have little or no data.

In a recent paper [2], which will appear at NeurIPS’s Federated Learning for Data Privacy and Confidentiality workshop, we present Federated Curvature (FedCurv), an algorithm for Federated Learning on non-IID data. In this paper, we build on ideas from Lifelong Learning to prevent knowledge forgetting in Federated Learning.

In Lifelong Learning, the challenge is to learn task A and then continue on to learn task B using the same model, but without “forgetting” task A, i.e. without severely hurting the performance on that task. More generally, the goal is to learn tasks A1, A2, … in sequence without forgetting previously learnt tasks whose samples are no longer presented.

In the Elastic Weight Consolidation (EWC) paper [3], the authors propose an algorithm for sequentially training a model on new tasks without forgetting old ones.

The idea behind EWC is to prevent forgetting by identifying the coordinates of the network parameters that are the most informative for a learnt task A, and then, while task B is being learned, penalizing the learner for changing these parameters. The basic assumption is that deep neural networks are sufficiently over-parameterized that there is a good chance of finding an optimal solution θ*_B for task B in the neighborhood of the previously learned θ*_A. They depict the idea with the following diagram:

[Diagram from [3]: the parameter space, with low-error regions for tasks A and B; EWC steers the learning of task B toward a solution that stays inside task A’s low-error region.]

In order to “select” the parameters that are important for the previous task, the authors use the diagonal of the Fisher information matrix. This diagonal has the same shape as the model’s parameter tensor, and each entry’s value correlates with the “importance” of the matching model parameter.
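As a sketch of what this looks like in practice, the diagonal can be estimated from the squared gradients of the log-likelihood over task A’s data. The snippet below assumes a PyTorch classifier and a data loader that yields single examples, so the squared gradient matches the per-sample definition; names are illustrative.

```python
import torch
import torch.nn.functional as F

def fisher_diagonal(model, data_loader):
    """Estimate the diagonal of the empirical Fisher information matrix:
    the average, over the data, of the squared gradient of the log-likelihood
    with respect to each model parameter."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    model.eval()
    n_samples = 0
    for x, y in data_loader:                       # assumed batch size of 1
        model.zero_grad()
        log_probs = F.log_softmax(model(x), dim=1)
        F.nll_loss(log_probs, y).backward()        # negative log-likelihood of the label
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
        n_samples += 1
    return {n: f / n_samples for n, f in fisher.items()}
```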

The authors enforce this penalty by adding a term to the optimization objective, forcing model parameters with high Fisher information for task A to preserve their values while learning task B. This is captured by the following objective:

L(θ) = L_B(θ) + (λ/2) Σ_i F_i (θ_i − θ*_{A,i})²

where L_B is the loss on task B, F is the Fisher information diagonal computed on task A, θ*_A are the parameters learned for task A, and λ controls how important the old task is relative to the new one.

This loss adjustment extends to multiple tasks by making the penalty term a sum over all previous tasks.
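Concretely, the penalty could look something like the following sketch, where `old_params` and `fisher_diag` are the parameters and Fisher diagonal saved after finishing task A (the names are illustrative, not from any particular library):

```python
def ewc_penalty(model, old_params, fisher_diag, lam):
    """EWC-style quadratic penalty: parameters with a high Fisher value for the
    previous task are pulled back toward the values they had after that task."""
    penalty = 0.0
    for n, p in model.named_parameters():
        penalty = penalty + (fisher_diag[n] * (p - old_params[n]) ** 2).sum()
    return 0.5 * lam * penalty

# While learning task B:
#   loss = task_b_loss + ewc_penalty(model, params_after_A, fisher_A, lam)
# For several previous tasks, simply sum one such penalty per task.
```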

For Federated Learning, we adapt EWC from a sequential algorithm into a parallel one. We keep communicating and averaging the local models, just like in FedAvg, but we also add an EWC-style penalty that forces each local model to preserve the knowledge of all other devices. During communication, each device sends its model together with the diagonal of its Fisher information matrix. Mathematically, we get:

L̃_{t,s}(θ) = L_s(θ) + λ Σ_{j ≠ s} (θ − θ̂_{t−1,j})ᵀ diag(F̂_{t−1,j}) (θ − θ̂_{t−1,j})

where L_s is the local loss of device s, θ̂_{t−1,j} and F̂_{t−1,j} are the parameters and Fisher diagonal received from device j at the previous round, and λ tunes the strength of the penalty.

This way, we enable training on local data, without forgetting the knowledge gained from data of other devices (such as other classes).
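In code, the per-device penalty could look roughly like the sketch below, where `other_params` and `other_fishers` hold the parameters and Fisher diagonals received from the other devices at the last communication round (names are illustrative):

```python
def fedcurv_penalty(model, other_params, other_fishers, lam):
    """FedCurv-style penalty on one device (sketch): pull the local model toward
    every other device's last-communicated parameters, weighted element-wise by
    that device's Fisher information diagonal."""
    penalty = 0.0
    for params_j, fisher_j in zip(other_params, other_fishers):
        for n, p in model.named_parameters():
            penalty = penalty + (fisher_j[n] * (p - params_j[n]) ** 2).sum()
    return lam * penalty

# Local objective on device s during a round:
#   loss = local_task_loss + fedcurv_penalty(model, other_params, other_fishers, lam)
```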

At first glance, the number of new terms added to the loss would seem to grow linearly with the number of edge devices. However, as we show in [2], simple arithmetic manipulations reduce this to a constant number of terms that depend only on sums of the Fisher information matrices, making the loss function scalable: the number of terms does not depend on the number of edge devices. This also means that while each edge device sends its model and its Fisher information matrix diagonal to the central point, the central point only needs to send back the aggregation of the individual models and of their Fisher information matrix diagonals. Note that FedCurv only sends gradient-related information aggregated over local data to the central point, so in terms of privacy it is not significantly different from the classical FedAvg algorithm.
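A rough sketch of how this aggregation could work is shown below. Expanding the quadratic penalty shows that, up to an additive constant, each device only needs the element-wise sums Σ_j F_j and Σ_j F_j·θ_j, minus its own contribution; the function and variable names here are illustrative, not the paper’s reference implementation.

```python
def server_aggregate(client_params, client_fishers):
    """Server side (sketch): instead of forwarding every client's model and Fisher
    diagonal, send back only the element-wise sums  sum_j F_j  and  sum_j F_j * theta_j."""
    names = client_params[0].keys()
    sum_fisher = {n: sum(f[n] for f in client_fishers) for n in names}
    sum_fisher_theta = {n: sum(f[n] * p[n] for f, p in zip(client_fishers, client_params))
                        for n in names}
    return sum_fisher, sum_fisher_theta

def penalty_from_aggregates(model, sum_fisher, sum_fisher_theta,
                            own_fisher, own_params, lam):
    """Device side (sketch): the penalty over all *other* devices, written with a
    constant number of terms. Up to an additive constant,
      sum_{j != s} F_j (theta - theta_j)^2  =  A * theta^2 - 2 * theta * B + const,
    where A = sum_j F_j - F_s and B = sum_j F_j * theta_j - F_s * theta_s."""
    penalty = 0.0
    for n, p in model.named_parameters():
        A = sum_fisher[n] - own_fisher[n]
        B = sum_fisher_theta[n] - own_fisher[n] * own_params[n]
        penalty = penalty + (A * p ** 2 - 2 * p * B).sum()
    return lam * penalty
```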


We tested FedCurv on a set of 96 edge devices. We used MNIST for the experiment and divided the data so that every device has images from exactly two classes (images that no other device sees). We compared FedCurv to FedAvg and to FedProx [4] (a central existing solution, whose description is beyond the scope of this blog post).
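For reference, a non-IID split of this flavour can be simulated roughly as follows. This is an illustrative sketch only; the exact assignment of class pairs to devices in the paper may differ.

```python
import numpy as np

def two_classes_per_device(labels, num_devices=96, num_classes=10, seed=0):
    """Sketch of a non-IID split: each device gets a pair of classes (round-robin),
    and each class's images are partitioned disjointly among the devices holding it."""
    rng = np.random.default_rng(seed)
    pairs = [((2 * d) % num_classes, (2 * d + 1) % num_classes) for d in range(num_devices)]
    holders = {c: [d for d, p in enumerate(pairs) if c in p] for c in range(num_classes)}
    shards = {c: np.array_split(rng.permutation(np.where(labels == c)[0]), len(holders[c]))
              for c in range(num_classes)}
    device_indices = {d: [] for d in range(num_devices)}
    for c in range(num_classes):
        for shard, d in zip(shards[c], holders[c]):
            device_indices[d].extend(shard.tolist())
    return device_indices    # device id -> list of sample indices
```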

Since the main benefit of our algorithm is that it allows less frequent communication, we expect that as the number of local epochs E between consecutive communication rounds increases, the advantages of using FedCurv will become more apparent, i.e. FedCurv will need fewer rounds to reach a desired accuracy.

The results, presented in Table 1, show that for 50 local epochs FedCurv reached 90% accuracy three times faster than FedAvg. Figures 1 and 2 show that both FedProx and FedCurv do well at the beginning of the training process. However, while FedCurv retains enough flexibility to reach high accuracy by the end of the process, the stiffness of the parameters in FedProx comes at the expense of final accuracy.

Figure 1 — Learning Curves, E=50
Figure 2 — Learning Curves, E=10

We presented the problem of non-IID data in Federated Learning. We showed how it relates to the problem of forgetting in Lifelong Learning and presented FedCurv, a novel approach for training in this setting. We showed that FedCurv can be implemented efficiently, without a substantial increase in bandwidth.

Make sure to follow Edgify for updates from the NeurIPS conference this week.


Edgify.ai has been researching distributed edge training for four years. We are building a platform (framework) that enables the training and deployment of machine learning models directly on edge devices, such as smartphones, IoT devices, connected cars, healthcare equipment, smart dishwashers and more. We are committed to revolutionising the privacy, information security, latency and costs associated with AI.

[1] H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, et al. Communication-efficient learning of deep networks from decentralized data. arXiv preprint arXiv:1602.05629, 2016.

[2] Neta Shoham, Tomer Avidor, Aviv Keren, Nadav Israel, Daniel Benditkis, Liron Mor-Yosef, Itai Zeitak. Overcoming Forgetting in Federated Learning on Non-IID Data. arXiv preprint arXiv:1910.07796, 2019.

[3] James Kirkpatrick, Razvan Pascanu, Neil Rabinowitz, Joel Veness, Guillaume Desjardins, Andrei A. Rusu, Kieran Milan, John Quan, Tiago Ramalho, Agnieszka Grabska-Barwinska, et al. Overcoming catastrophic forgetting in neural networks. Proceedings of the National Academy of Sciences, 114(13):3521–3526, 2017.

[4] Anit Kumar Sahu, Tian Li, Maziar Sanjabi, Manzil Zaheer, Ameet Talwalkar, Virginia Smith. Federated Optimization in Heterogeneous Networks. arXiv preprint arXiv:1812.06127, 2018.

