Normalization for Better Generalization and Faster Training
Different types of normalization layers (BatchNorm, LayerNorm)
Batch Normalization
Training deep neural networks is complicated by the fact that the distribution of each layer's inputs changes during training as the parameters of the previous layers change. This slows down training by requiring lower learning rates and careful parameter initialization, and it makes it notoriously hard to train models with saturating nonlinearities. To overcome this, we can apply a normalization after some layers, as below.
It calculates the batch mean and standard deviation and uses them to normalize the data; it also maintains a running mean and standard deviation, which are used at inference. One intuition about why BatchNorm works is that it reduces internal covariate shift. You can check that in the video below.
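To make the mechanics concrete, here is a minimal NumPy sketch of the training-time computation; the function name and the epsilon/momentum values are illustrative, not TensorFlow's API.

import numpy as np

# Minimal sketch of a BN layer at training time for an input of shape
# (batch, features). gamma and beta are the learned scale/offset; the
# running statistics are what the layer will use at inference time.
def batch_norm_train(x, gamma, beta, running_mean, running_var,
                     momentum=0.99, eps=1e-3):
    mu = x.mean(axis=0)                    # per-feature batch mean
    var = x.var(axis=0)                    # per-feature batch variance
    x_hat = (x - mu) / np.sqrt(var + eps)  # normalize with batch statistics
    # Update the running statistics used at inference.
    new_mean = momentum * running_mean + (1 - momentum) * mu
    new_var = momentum * running_var + (1 - momentum) * var
    return gamma * x_hat + beta, new_mean, new_var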
Another intuition:
Batch Normalization normalizes the activations in the intermediate layers. BN primarily enables training with a larger learning rate, which leads to faster convergence and better generalization.
Training with a larger batch size may converge to sharp minima, and converging to sharp minima can reduce generalization capacity, so the noise in SGD plays an important role in regularizing the network. Similarly, a higher learning rate biases the network towards wider minima, which gives better generalization. However, training with a higher learning rate may cause the updates to explode.
If we compare the gradients with and without batch normalization, the gradients of the network without batch norm are larger and heavier-tailed, as shown below, so with BN we can train with larger learning rates.
You can check the figure below, taken from a paper that compares BN in CV and NLP: the differences between the running mean/variance and the batch mean/variance exhibit very high variance, with extreme outliers, in Transformers.
import tensorflow as tf
input_layer = tf.keras.Input(shape=(6,))
bn_layer = tf.keras.layers.BatchNormalization()
bn_layer_out = bn_layer(input_layer)
# Prints 4: gamma, beta, moving_mean and moving_variance, each of length 6.
print('Number of weights is', len(bn_layer.get_weights()))
If we have $n$ features as input to the BN layer, the layer holds four weight vectors of length $n$: gamma, beta, moving_mean and moving_variance (initialized via gamma_initializer, beta_initializer, moving_mean_initializer and moving_variance_initializer). Only gamma and beta are trainable; the moving statistics are updated during training and used at inference.
Please read the TensorFlow documentation to learn more about the training and inference modes of the BN layer; handling the mode correctly is very important.
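For example, here is a small sketch of the two modes when calling the layer directly (the random input is just for illustration): with training=True the layer normalizes with the current batch statistics and updates its moving averages, while with training=False it uses the stored moving mean/variance.

import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = np.random.randn(8, 6).astype('float32')

# Training mode: normalize with this batch's mean/variance and update
# the moving averages.
y_train = bn(x, training=True)

# Inference mode: normalize with the stored moving mean and variance
# (this is also what model.predict() uses).
y_infer = bn(x, training=False)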
Layer Normalization
Unlike Batch Normalization, it normalizes horizontally, i.e. it normalizes each data point, so $\mu$ and $\sigma$ do not depend on the batch. Layer Normalization therefore does not need a running mean or running variance.
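Concretely, for a single sample with $H$ features $a_1, \dots, a_H$ (following the notation of the Layer Normalization paper), the statistics are computed per sample:

$$\mu = \frac{1}{H}\sum_{i=1}^{H} a_i, \qquad \sigma = \sqrt{\frac{1}{H}\sum_{i=1}^{H}\left(a_i - \mu\right)^2}, \qquad \hat{a}_i = \frac{a_i - \mu}{\sigma}$$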
It gives better results because of the gradients with respect to $\mu$ and $\sigma$ in Layer Normalization: the derivative with respect to $\mu$ re-centers the network gradients to zero, and the derivative with respect to $\sigma$ reduces the variance of the network gradients, which can be seen as a kind of re-scaling.
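As a quick numerical check of the per-sample behaviour described above (a sketch with random input; epsilon is set explicitly so the manual computation matches the layer):

import numpy as np
import tensorflow as tf

x = np.random.randn(4, 6).astype('float32')

# LayerNormalization computes mean/variance over the feature axis of each
# sample, so every row is normalized independently of the rest of the batch.
ln = tf.keras.layers.LayerNormalization(center=False, scale=False, epsilon=1e-3)
y = ln(x)

# The same computation by hand, row by row.
mu = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
y_manual = (x - mu) / np.sqrt(var + 1e-3)
print(np.allclose(y.numpy(), y_manual, atol=1e-5))  # expected: True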
The center and scale parameters in TensorFlow control whether the layer creates the beta (offset) and gamma (scale) weights:
import tensorflow as tf
input_layer = tf.keras.Input(shape=(6,))
norm_layer = tf.keras.layers.LayerNormalization(scale=False, center=False)
norm_layer_out = norm_layer(input_layer)
# Prints 0: with scale=False and center=False, no gamma/beta weights are created.
print('Number of weights is', len(norm_layer.get_weights()))
import tensorflow as tf
input_layer = tf.keras.Input(shape=(10,), batch_size=1)
norm_layer = tf.keras.layers.LayerNormalization(scale=True, center=True)
norm_layer_out = norm_layer(input_layer)
# Prints 2: gamma and beta, each of length 10.
print('Number of weights is', len(norm_layer.get_weights()))