Introduction

When running distributed training with $B$ data points on each of $N$ machines, in theory at least, most layers will behave identically to running a batch of size $N * B$ on a single machine, because they process the elements of a batch independently. However, layers that operate at the batch level, particularly during training, may give different results in the single-machine and distributed settings. In this post we will look at an example of how BatchNorm behaves when running distributed training with TensorFlow on TPUs using tf.keras.layers. The post assumes you know how batch normalisation works. Here is one tutorial, but there is an endless number of them if you search on Google.

You can use TPUs for free on Colab, and you can run this code as a notebook here. If using Colab, make sure you select TPU under Hardware accelerator by going to Runtime->Change Runtime Type in the top menu. The purpose of this post is to demonstrate how distributed batch normalisation works, so we won't go into the details of using TPUs, and the code here might not reflect best practices for training with TPUs. Please consult TensorFlow's guide if you want to learn more about how to use TPUs.

Setting up

import tensorflow as tf
import numpy as np

First some setup code taken from the guide.

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
print("All devices: ", tf.config.list_logical_devices('TPU'))
INFO:tensorflow:Clearing out eager caches
WARNING:tensorflow:TPU system grpc://10.98.27.170:8470 has already been initialized. Reinitializing the TPU can cause previously created variables on TPU to be lost.
INFO:tensorflow:Initializing the TPU system: grpc://10.98.27.170:8470
INFO:tensorflow:Finished initializing TPU system.
All devices:  [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')]
tpu_strategy = tf.distribute.TPUStrategy(resolver)
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)

Now create a BatchNormExperiment class that lets us get results easily. The experiment class creates a dataset and runs forward passes in the appropriate way depending on whether or not we are running in a distributed fashion.

class BatchNormExperiment(object):
  def __init__(self, data, bn_layer, strategy=None):
    self.bn_layer = bn_layer
    self.strategy = strategy
    if self.strategy is not None:
      # Per-replica batch of 32; with 8 TPU cores the global batch is 32 * 8.
      self.ds = strategy.experimental_distribute_datasets_from_function(
          lambda _: tf.data.Dataset.from_tensor_slices(data).batch(32)
      )
    else:
      # A single device processes the full global batch of 32 * 8 at once.
      self.ds = tf.data.Dataset.from_tensor_slices(data).batch(32 * 8)

  @tf.function
  def apply_bn_dist(self, iterator, training):
    def _f(x, training):
      return self.bn_layer(x, training=training)

    # Run the layer on every replica, then stitch the per-replica outputs
    # back together into a single tensor.
    result = self.strategy.run(_f, args=(next(iterator), training))
    return tf.concat(result.values, axis=0)

  @tf.function
  def apply_bn(self, iterator, training):
    return self.bn_layer(next(iterator), training=training)

  def get_results(self):
    bn = self.apply_bn if self.strategy is None else self.apply_bn_dist
    iterator = iter(self.ds)
    result = []
    # First STEPS batches in training mode, then one final batch in inference mode.
    for _ in range(STEPS):
      result.append(bn(iterator, training=True))
    val_result = bn(iterator, training=False)
    return tf.concat(result, axis=0), val_result

A simple function to calculate the max difference between two arrays / tensors, since we will be comparing tensors that are close but not identical.

def max_diff(x, y):
  if isinstance(x, tf.Tensor):
    x = x.numpy()
  if isinstance(y, tf.Tensor):
    y = y.numpy()

  return np.abs(np.array(x) - np.array(y)).max()

Now create a fake dataset consisting of 101 random 256 x 64 matrices, the first 100 of which will be used to call batchnorm in training mode and the last one in inference mode. Note that we are not doing any training here. Training and inference refer to how the normalisation is done (a minimal sketch of both modes follows this list):

  • In training mode, normalisation is done with the stats obtained from the input batch.
  • In inference mode, normalisation is done with moving stats accumulated from earlier batches that were run in training mode.
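To make the two modes concrete, here is a minimal single-device sketch of what a BatchNorm layer computes in each case. This is an illustration of the standard batchnorm equations, not the Keras implementation; gamma, beta, momentum and eps stand in for the layer's scale, offset, momentum and epsilon.

def bn_train_step(x, moving_mean, moving_var, gamma, beta,
                  momentum=0.99, eps=1e-3):
  batch_mean = tf.reduce_mean(x, axis=0)
  batch_var = tf.math.reduce_variance(x, axis=0)
  # The moving stats are an exponential moving average of the batch stats.
  moving_mean.assign(momentum * moving_mean + (1.0 - momentum) * batch_mean)
  moving_var.assign(momentum * moving_var + (1.0 - momentum) * batch_var)
  # Training mode normalises with the current batch's stats.
  return gamma * (x - batch_mean) / tf.sqrt(batch_var + eps) + beta

def bn_inference_step(x, moving_mean, moving_var, gamma, beta, eps=1e-3):
  # Inference mode normalises with the accumulated moving stats instead.
  return gamma * (x - moving_mean) / tf.sqrt(moving_var + eps) + beta

With that distinction in mind, here is the setup for the experiment: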
tf.keras.backend.clear_session()
STEPS = 100
data = tf.random.normal([32 * 8 * (STEPS + 1), 64])
bn_layer = tf.keras.layers.BatchNormalization()
bn_layer.build([None, 64])
init = bn_layer.get_weights()

Distributed BatchNorm

This layer normalises the per-device sub-batches separately on each device in training mode. During inference, however, every example should be treated the same way, so the same moving stats need to be used on every device. The layer therefore averages the moving stats across devices. In the source code the tf.Variable instances for the moving mean and variance are defined with these two settings (a standalone sketch of such a declaration follows the list):

  • synchronization=tf.VariableSynchronization.ON_READ - according to the documentation here, this means the per-device values are aggregated when the variable is read, which includes when applying batch normalisation in inference mode

  • aggregation=tf.VariableAggregation.MEAN - indicates which aggregation should be used when reading; here it is the mean
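As an illustration, a variable with these properties can be declared directly inside the strategy scope. The name and shape below are made up for this sketch; only the two keyword arguments are the ones described above.

# Hypothetical standalone declaration (not taken from the Keras source):
# each replica keeps its own copy, and reading the variable outside a
# replica context returns the mean of the per-replica copies.
with tpu_strategy.scope():
  moving_stat_like = tf.Variable(
      tf.zeros([64]),
      trainable=False,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.MEAN,
  )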

Let us create a distributed batchnorm layer and initialise it with the same weights as bn_layer.

with tpu_strategy.scope():
  bn_dist = tf.keras.layers.BatchNormalization()
  bn_dist.build([None, 64])
  bn_dist.set_weights(init)

Set up the experiments and get the results.

exp_regular = BatchNormExperiment(data, bn_layer, strategy=None)
exp_dist = BatchNormExperiment(data, bn_dist, tpu_strategy)
trn_regular, val_regular = exp_regular.get_results()
trn_dist, val_dist = exp_dist.get_results()

The training results differ between the regular and distributed batchnorm, since in the distributed case each sub-batch is normalised using stats calculated separately on each device.

max_diff(trn_regular, trn_dist)
1.4629191

The val results will also be different, since the batch norm stats are first calculated per sub-batch and then averaged across the devices. The average of the per-device means equals the mean over the whole batch, but the average of the per-device variances is not the variance over the whole batch.

max_diff(val_regular, val_dist)
0.09319067
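To see why averaging per-device variances understates the full-batch variance, here is a small NumPy check; splitting a batch of 256 rows into 8 chunks of 32 mirrors the 8 TPU cores, and the variable names are made up for this sketch.

x_np = np.random.normal(size=(256, 64)).astype(np.float32)
chunks = np.split(x_np, 8)                     # 8 per-device sub-batches of 32
full_batch_var = x_np.var(axis=0)              # variance over all 256 rows
mean_of_chunk_vars = np.mean([c.var(axis=0) for c in chunks], axis=0)
# By the law of total variance, the per-chunk variances miss the variance of
# the per-chunk means, so their average is systematically smaller.
print(max_diff(full_batch_var, mean_of_chunk_vars))  # clearly > 0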

We observe that whilst the moving mean is close to the non-distributed one, the moving variance is different.

mov_mean_regular, mov_var_regular = bn_layer.get_weights()[-2:]
mov_mean_dist, mov_var_dist = bn_dist.get_weights()[-2:]

print('Moving mean diff', max_diff(mov_mean_regular, mov_mean_dist))
print('Moving var diff', max_diff(mov_var_regular, mov_var_dist))
Moving mean diff 3.1432137e-09
Moving var diff 0.019875884

Note that get_weights() does the aggregation and returns a single value for each weight, whereas accessing a variable directly, e.g. bn_dist.moving_variance below, returns the per-device values.

mov_var_dist_per_device = bn_dist.moving_variance
len(mov_var_dist_per_device.values)
8

Observe that the mean of these and mov_var_dist are identical.

max_diff(tf.reduce_mean(mov_var_dist_per_device.values, axis=0), mov_var_dist)
0.0
# To avoid errors below delete everything corresponding to the distributed batch norm layer
del exp_dist, bn_dist, trn_dist, val_dist, mov_mean_dist, mov_var_dist, mov_var_dist_per_device 

Synchronised BatchNorm

In synchronised BatchNorm, as implemented in tf.keras.layers.experimental.SyncBatchNormalization, the batch stats in training mode are aggregated across the devices and each sub-batch is normalised with the resulting values. From the source code we see that each replica first computes sum(x), sum(x^2) and its batch size:

        local_sum = tf.reduce_sum(y, axis=axes, keepdims=True)
        local_squared_sum = tf.reduce_sum(tf.square(y), axis=axes,
                                          keepdims=True)
        batch_size = tf.cast(tf.shape(y)[axes[0]], tf.float32)

Then you aggregate these across replicas using all_reduce.

        y_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_sum)
        y_squared_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM,
                                               local_squared_sum)
        global_batch_size = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM,
                                                   batch_size)

At this point each replica has a copy of the global sum, squared sum and global batch size, which it can use to normalise its own subset of the data:

        axes_vals = [(tf.shape(y))[axes[i]]
                     for i in range(1, len(axes))]
        multiplier = tf.cast(tf.reduce_prod(axes_vals),
                                   tf.float32)
        multiplier = multiplier * global_batch_size

        mean = y_sum / multiplier
        y_squared_mean = y_squared_sum / multiplier
        # var = E(x^2) - E(x)^2
        variance = y_squared_mean - tf.square(mean)
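As a sanity check of this recipe outside of TensorFlow, combining per-replica sums this way recovers the full-batch stats up to floating point error. Here is a small NumPy sketch, reusing the hypothetical 8-chunk split from earlier:

x_np = np.random.normal(size=(256, 64)).astype(np.float32)
chunks = np.split(x_np, 8)                      # 8 per-replica sub-batches
# Per-replica partial results, then an "all_reduce" by summing over replicas.
y_sum = sum(c.sum(axis=0) for c in chunks)
y_squared_sum = sum((c ** 2).sum(axis=0) for c in chunks)
global_batch_size = sum(c.shape[0] for c in chunks)
mean = y_sum / global_batch_size
variance = y_squared_sum / global_batch_size - mean ** 2  # E(x^2) - E(x)^2
print(max_diff(mean, x_np.mean(axis=0)))      # close to zero
print(max_diff(variance, x_np.var(axis=0)))   # close to zero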

Let us initialise and run a synchronised BatchNorm layer.

with tpu_strategy.scope():
  sync_bn_dist = tf.keras.layers.experimental.SyncBatchNormalization()
  sync_bn_dist.build([None, 64])
  sync_bn_dist.set_weights(init)
exp_dist_sync = BatchNormExperiment(data, sync_bn_dist, tpu_strategy)
trn_dist_sync, val_dist_sync = exp_dist_sync.get_results()

Now the results should be the same up to a small epsilon in both training and inference modes, assuming the same data has been used in each case.

max_diff(trn_regular, trn_dist_sync)
1.9073486e-06
max_diff(val_regular, val_dist_sync)
7.1525574e-07

Unsurprisingly, both moving stats are close this time.

mov_mean_dist_sync, mov_var_dist_sync = sync_bn_dist.get_weights()[-2:]

print('Moving mean diff', max_diff(mov_mean_regular, mov_mean_dist_sync))
print('Moving var diff', max_diff(mov_var_regular, mov_var_dist_sync))
Moving mean diff 4.4237822e-09
Moving var diff 2.9802322e-07
del exp_dist_sync, sync_bn_dist, trn_dist_sync, val_dist_sync, mov_mean_dist_sync, mov_var_dist_sync

Finally, we can see that even with non-synchronised distributed batchnorm, the same moving stats are applied in each replica during inference since, as noted above, they are averaged across replicas when read. We copy the weights from the regular batchnorm layer, bn_layer, whose moving stats were updated during the training-mode steps above, and then run a non-synchronised distributed batchnorm step in inference mode.

with tpu_strategy.scope():
  bn_copy = tf.keras.layers.BatchNormalization()
  bn_copy.build([None, 64])
  bn_copy.set_weights(bn_layer.get_weights())
exp_copy = BatchNormExperiment(data, bn_copy, tpu_strategy)
# Create iterator and advance to last batch which was used to get the "val" results
itr = iter(exp_copy.ds)
for i in range(STEPS):
  _ = next(itr)
val_copy = exp_copy.apply_bn_dist(itr, training=False)

As expected this yields nearly the same values as val_regular.

max_diff(val_regular, val_copy)
4.7683716e-07