Distributed Training on Edge Devices: Large Batch vs. Federated Learning

(1/3) An Edgify Research Team Publication

One central motivation for training at the edge: the costs of, and constraints on, transmitting all of that data to a central server for training purposes.

The Basic Idea and Its Challenges

Our goal is to train a high-quality model while the training data remains distributed over a large number of edge devices. Two main challenges stand in the way:

  1. The distribution of data among those endpoints (an uneven, non-IID distribution presents a deep challenge to such approaches; see the toy illustration after this list)
  2. Network bandwidth limitations (not all approaches lend themselves equally well to compression schemes)
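To see why the first challenge bites, here is a toy illustration (ours, not Edgify's pipeline): sort a labelled dataset by class before dealing it out, and each "device" ends up seeing only a sliver of the label space.

```python
import numpy as np

rng = np.random.default_rng(0)
labels = rng.integers(0, 10, size=50_000)   # a CIFAR-10-sized label set

order = np.argsort(labels)                  # sort the samples by class...
shards = np.array_split(order, 100)         # ...then deal them out to 100 devices

# Each device now holds at most one or two classes -- far from IID:
for i, shard in enumerate(shards[:3]):
    print(f"device {i}: classes {np.unique(labels[shard])}")
```

An update averaged over such shards is dominated by whichever classes happen to sit on the participating devices, which is exactly the non-IID difficulty referred to above.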
A few clarifications about the framework we assume throughout:

  1. Synchronised algorithms: the simplest implementation of the distribution framework, requiring each device to send a model, model gradients, or model update back to the server during each round.
  2. The machine learning algorithmic framework is that of Deep Learning (though other methods also lend themselves to such distribution schemes).
  3. The Deep Learning training uses a standard optimizer, Stochastic Gradient Descent (SGD).

Large Batch

In classical, non-distributed training, a batch of samples is used to perform each stochastic gradient descent optimization step. Where large batches can be accommodated (which is not always simple, or even feasible), it is rather straightforward to turn this process into a distributed one, running over many server GPUs or, in our case, edge devices:

  1. Download the current shared model.
  2. Run a single pass of forward prediction and backward error propagation.
  3. Compress the gradients.
  4. Upload the update to a central server.
    — — — — — — — — server phase — — — — — — — —
  5. Decompress the incoming gradients and compute their average.
  6. Perform a single step of SGD to update the shared model.
  7. Go to step 1 (edge training phase).
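As a concrete (if simplified) illustration, here is a single-process PyTorch simulation of one such round; the names (`large_batch_round`, `device_batches`) are ours, and the compression and networking steps are deliberately left out:

```python
import copy
import torch

def large_batch_round(shared_model, device_batches, loss_fn, lr=0.1):
    """One synchronised round: every device contributes one gradient."""
    grads = []
    for x, y in device_batches:               # edge phase, one device at a time
        local = copy.deepcopy(shared_model)   # step 1: download the shared model
        loss_fn(local(x), y).backward()       # step 2: forward and backward pass
        grads.append([p.grad.clone() for p in local.parameters()])
        # steps 3-4 (compression and upload) are elided in this sketch

    # server phase
    with torch.no_grad():
        for i, p in enumerate(shared_model.parameters()):
            avg_grad = torch.stack([g[i] for g in grads]).mean(dim=0)  # step 5
            p -= lr * avg_grad                                         # step 6

# e.g. four "devices", each holding one 16-sample batch of 10-feature inputs
model = torch.nn.Linear(10, 2)
batches = [(torch.randn(16, 10), torch.randint(0, 2, (16,))) for _ in range(4)]
large_batch_round(model, batches, torch.nn.functional.cross_entropy)
```

In effect, the averaged gradient is the one a single machine would compute over the concatenation of all the devices' batches, hence the name Large Batch.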

Federated Learning

In Large Batch, in every round, each device performs a single forward-backward pass and immediately communicates the gradient. In Federated Learning, in contrast, in every round each edge device performs some independent training on its local data (that is, without communicating with the other devices) for several iterations. Only then does it send its model update to the server, to be aggregated with those of the other devices into an improved shared model (from which the process can then repeat).

  1. Download the current shared model.
  2. Run a number of SGD iterations (based on forward and backward passes upon the local data).
  3. Compress the updated model weights (optional).
  4. Upload the update to the central server.
    — — — — — — — — server phase — — — — — — — —
  5. Decompress the incoming model weights, if needed.
  6. Compute their average.
  7. Update the shared model.
  8. Go to step 1 (edge training phase).
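Here is the matching single-process sketch of one Federated Learning round, again with our own illustrative names and with compression and networking elided:

```python
import copy
import torch

def federated_round(shared_model, device_datasets, loss_fn, lr=0.01, local_steps=5):
    """One round: every device trains locally, the server averages the weights."""
    states = []
    for batches in device_datasets:                  # edge phase
        local = copy.deepcopy(shared_model)          # step 1: download
        opt = torch.optim.SGD(local.parameters(), lr=lr)
        for _ in range(local_steps):                 # step 2: several local
            for x, y in batches:                     #         SGD iterations
                opt.zero_grad()
                loss_fn(local(x), y).backward()
                opt.step()
        states.append(local.state_dict())            # steps 3-4: "upload"
        # compression / decompression (steps 3 and 5) are elided here

    # server phase, steps 6-7: average the weights into the shared model
    avg = {k: torch.stack([s[k] for s in states]).mean(dim=0) for k in states[0]}
    shared_model.load_state_dict(avg)
```

Calling `federated_round` repeatedly on the same `shared_model` reproduces the loop above; note that averaging raw weights like this assumes every tensor in the model's `state_dict` is floating point.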

The Two Approaches in Comparison

[Figure: Federated Learning has a hard time handling batches that are unevenly distributed across devices.]

A Preliminary Empirical Comparison

We set out to benchmark the two approaches, pitting them against each other and against classical, single-server training. Our basic comparative philosophy was to run each methodology with its own fitting or "standard" parameters, to the extent possible. Large Batch requires comparatively larger learning rates, for example, so simply using the same parameters across all methodologies wouldn't do. On the other hand, for the by-epoch comparison to make any kind of sense, we had to keep some uniformity. In particular, since momentum didn't fit our optimization of the Federated Learning run, we avoided using momentum for the centralised and Large Batch training as well.

[Table 1: The parameters of the different methods.]
[Figure 1: A graph of the three compared approaches: Large Batch, Federated Learning, and classic single-server SGD training.]
[Table 2: The number of communication rounds required to reach 80% accuracy in the experiment above. For Federated Learning, which synchronises once per epoch, this is simply the number of epochs to 80% accuracy. For Large Batch, it is the number of epochs to 80% accuracy times the number of batches per epoch (3,072 batches over the 50,000 samples).]
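To make Table 2's accounting concrete, here is a minimal sketch (our helper names, with a placeholder epoch count rather than the measured results):

```python
# How Table 2 counts communication rounds (illustrative helpers; the epoch
# count passed below is a placeholder, not a measured result).

BATCHES_PER_EPOCH = 3072  # batches making up one epoch of 50,000 samples

def rounds_federated(epochs_to_target: int) -> int:
    # Federated Learning synchronises once per epoch.
    return epochs_to_target

def rounds_large_batch(epochs_to_target: int) -> int:
    # Large Batch synchronises on every single batch.
    return epochs_to_target * BATCHES_PER_EPOCH

print(rounds_federated(10), rounds_large_batch(10))  # 10 vs. 30,720 rounds
```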

Conclusion

Distributing the training among many edge devices provides groundbreaking advantages. It also means, however, that communication cost becomes an important factor, one that has to be taken into account and managed. The number of rounds is but one fundamental aspect of it; the other side of the coin is the amount of data that has to be sent in each round. That is the topic of our next post.

[Photo of the team, from left to right: Nadav, Itay, Aviv, Neta, Daniel, Tomer, Liron.]