Batch Normalization

Batch Normalization (BN) normalizes the pre-activations (or post-activations) of each layer to have zero mean and unit variance within a mini-batch, then applies a learned affine transformation. It accelerates training, reduces sensitivity to initialization, and acts as a mild regularizer.

Forward Pass

For a mini-batch $\mathcal{B} = \{z_1, \ldots, z_B\}$ of pre-activations at one layer (treating each feature independently):

Step 1: Batch statistics

\[\mu_\mathcal{B} = \frac{1}{B}\sum_{i=1}^B z_i, \quad \sigma^2_\mathcal{B} = \frac{1}{B}\sum_{i=1}^B (z_i - \mu_\mathcal{B})^2\]

Step 2: Normalize

\[\hat{z}_i = \frac{z_i - \mu_\mathcal{B}}{\sqrt{\sigma^2_\mathcal{B} + \epsilon}}\]

where $\epsilon$ is a small constant (commonly $10^{-5}$) that prevents division by zero when the batch variance is near zero.
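A small worked example of Steps 1 and 2. It assumes a single feature with $B = 4$ values $(z_1, z_2, z_3, z_4) = (1, 2, 3, 6)$ and drops $\epsilon$, which is negligible at this scale:

\[\mu_\mathcal{B} = \frac{1 + 2 + 3 + 6}{4} = 3, \quad \sigma^2_\mathcal{B} = \frac{(-2)^2 + (-1)^2 + 0^2 + 3^2}{4} = \frac{14}{4} = 3.5\]

\[(\hat{z}_1, \hat{z}_2, \hat{z}_3, \hat{z}_4) = \frac{(-2, -1, 0, 3)}{\sqrt{3.5}} \approx (-1.069, -0.535, 0, 1.604)\]

The normalized values have mean 0 and variance $(1.143 + 0.286 + 0 + 2.571)/4 = 1$.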

Step 3: Scale and shift (affine transform)

\[y_i = \gamma \hat{z}_i + \beta\]

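$\gamma$ (scale) and $\beta$ (shift) are learned parameters per feature. They let the network undo the normalization when that is optimal: for a given batch, choosing $\gamma = \sqrt{\sigma^2_\mathcal{B} + \epsilon}$ and $\beta = \mu_\mathcal{B}$ gives $y_i = z_i$.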
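The three steps map directly onto array operations. Below is a minimal training-mode sketch in NumPy, assuming a `(B, D)` input of B samples and D features. The function name `batchnorm_forward` and the random test data are illustrative choices, not part of any library. It also checks that the identity choice of $\gamma$ and $\beta$ recovers the input.

```python
import numpy as np

def batchnorm_forward(z, gamma, beta, eps=1e-5):
    """Training-mode BN forward pass on a (B, D) array of B samples and D features."""
    mu = z.mean(axis=0)                     # Step 1: per-feature batch mean, shape (D,)
    var = z.var(axis=0)                     # Step 1: biased (1/B) batch variance, shape (D,)
    z_hat = (z - mu) / np.sqrt(var + eps)   # Step 2: normalize
    y = gamma * z_hat + beta                # Step 3: per-feature scale and shift
    return y, mu, var

rng = np.random.default_rng(0)
z = rng.normal(loc=5.0, scale=3.0, size=(32, 4))   # 32 samples, 4 features
y, mu, var = batchnorm_forward(z, gamma=np.ones(4), beta=np.zeros(4))
print(y.mean(axis=0))   # approximately 0 for each feature
print(y.var(axis=0))    # approximately 1 for each feature (slightly below, due to eps)

# Choosing gamma = sqrt(var + eps) and beta = mu undoes the normalization exactly
y_id, _, _ = batchnorm_forward(z, gamma=np.sqrt(var + 1e-5), beta=mu)
print(np.allclose(y_id, z))   # True
```

`np.var` defaults to the biased $1/B$ estimate (`ddof=0`), which matches the formula in Step 1.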

Inference (Running Statistics)

During training, BN also maintains exponential moving averages of the batch statistics, with a decay $\alpha$ close to 1 (e.g. 0.9 or 0.99):

\[\mu_\text{run} \leftarrow \alpha \mu_\text{run} + (1-\alpha)\mu_\mathcal{B}\] \[\sigma^2_\text{run} \leftarrow \alpha \sigma^2_\text{run} + (1-\alpha)\sigma^2_\mathcal{B}\]

At inference, BN uses the fixed $\mu_\text{run}$ and $\sigma^2_\text{run}$ in place of the batch statistics. Each sample's output is then deterministic and independent of the batch size and of the other samples in the batch.
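A sketch of a BN layer with a training/inference switch, assuming $\alpha = 0.9$ and running statistics initialized to 0 and 1. `SimpleBatchNorm` is a hypothetical helper written for illustration, not a library class. After training on batches from a fixed distribution, the running mean settles near that distribution's mean, and in inference mode a sample's output is the same whether it is processed alone or inside a batch.

```python
import numpy as np

class SimpleBatchNorm:
    """Minimal BN layer with running statistics (hypothetical helper, not a library API)."""

    def __init__(self, num_features, alpha=0.9, eps=1e-5):
        self.gamma = np.ones(num_features)
        self.beta = np.zeros(num_features)
        self.mu_run = np.zeros(num_features)    # running mean, initialized to 0
        self.var_run = np.ones(num_features)    # running variance, initialized to 1
        self.alpha = alpha                      # EMA decay; 0.9 is an assumed default
        self.eps = eps
        self.training = True

    def __call__(self, z):
        if self.training:
            mu, var = z.mean(axis=0), z.var(axis=0)
            # Exponential moving averages of the batch statistics
            self.mu_run = self.alpha * self.mu_run + (1 - self.alpha) * mu
            self.var_run = self.alpha * self.var_run + (1 - self.alpha) * var
        else:
            # Inference: fixed running statistics, independent of the batch
            mu, var = self.mu_run, self.var_run
        z_hat = (z - mu) / np.sqrt(var + self.eps)
        return self.gamma * z_hat + self.beta

rng = np.random.default_rng(0)
bn = SimpleBatchNorm(num_features=3)
for _ in range(200):                            # training: batches drawn from N(2, 0.5^2)
    bn(rng.normal(2.0, 0.5, size=(16, 3)))

bn.training = False                             # switch to inference mode
x = rng.normal(2.0, 0.5, size=(1, 3))
alone = bn(x)
in_batch = bn(np.vstack([x, rng.normal(size=(7, 3))]))[:1]
print(np.allclose(alone, in_batch))   # True: the output no longer depends on batch-mates
print(bn.mu_run)                      # close to 2.0 for each feature
```

In PyTorch's `torch.nn.BatchNorm1d`, the `momentum` argument plays the role of $1 - \alpha$ here (default 0.1). PyTorch also updates the running variance with the unbiased $1/(B-1)$ estimate rather than the $1/B$ estimate used in this sketch.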

Why It Works

Original motivation (Ioffe & Szegedy 2015): reduce internal covariate shift, meaning the change in the distribution of a layer's inputs as the parameters of earlier layers update, which was argued to slow convergence. BN re-centers and re-scales each layer's inputs to stabilize these distributions.

Current understanding (more nuanced): empirical and theoretical work suggests that the primary benefit is a smoother loss landscape. BN makes gradient magnitudes more predictable and reduces sensitivity to the learning rate, which permits larger learning rates and faster convergence.

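Gradient flow: because each feature is divided by its batch standard deviation, BN's output does not change when the preceding layer's weights are rescaled by a positive factor, and the backward pass divides gradients by that same standard deviation. This helps keep gradient magnitudes from vanishing or exploding through many layers.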
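The weight-scale invariance behind BN's effect on gradients can be checked numerically. This sketch verifies two things for a positive factor `a`: BN's output is unchanged (up to $\epsilon$) when the preceding weights `W` become `a * W`, and the gradient with respect to the scaled weights is smaller by a factor of $1/a$, so larger weights receive smaller updates. The toy loss, the fixed upstream gradient `c`, and the finite-difference helper `numerical_grad` are illustrative assumptions.

```python
import numpy as np

def bn_normalize(z, eps=1e-5):
    # BN Steps 1-2 (gamma = 1, beta = 0), per feature over the batch axis
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 10))   # a batch of inputs
W = rng.normal(size=(10, 5))    # weights of the preceding linear layer
c = rng.normal(size=(64, 5))    # fixed "upstream gradient" defining a toy loss

def loss(W):
    # Toy scalar loss that is linear in the BN output
    return np.sum(c * bn_normalize(x @ W))

def numerical_grad(f, W, h=1e-6):
    # Central finite differences, one weight at a time
    g = np.zeros_like(W)
    for idx in np.ndindex(W.shape):
        E = np.zeros_like(W)
        E[idx] = h
        g[idx] = (f(W + E) - f(W - E)) / (2 * h)
    return g

a = 10.0
print(np.allclose(bn_normalize(x @ W), bn_normalize(x @ (a * W)), atol=1e-6))  # True
g1, ga = numerical_grad(loss, W), numerical_grad(loss, a * W)
print(np.allclose(ga, g1 / a, rtol=1e-3, atol=1e-6))  # True: gradient shrinks by 1/a
```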

Regularization Effect

The batch statistics $\mu_\mathcal{B}, \sigma^2_\mathcal{B}$ depend on the other samples in the mini-batch, introducing stochasticity similar to dropout. This noise:

  • Prevents over-reliance on exact activation values.
  • Often reduces the need for dropout in convolutional layers.

BN’s regularization effect diminishes with larger batch sizes (less noise in statistics).
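To see this noise concretely, the sketch below holds one training example fixed, redraws its batch-mates 1,000 times, and measures how much that example's normalized value varies. The standard-normal data, the 4 features, and the two batch sizes are arbitrary choices for illustration; the spread should be much larger at $B = 4$ than at $B = 256$.

```python
import numpy as np

def normalize(z, eps=1e-5):
    # BN Steps 1-2 (gamma = 1, beta = 0) on a (B, D) batch
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)

rng = np.random.default_rng(0)
sample = rng.normal(size=(1, 4))   # one fixed training example with 4 features

spread = {}
for B in (4, 256):
    outs = []
    for _ in range(1000):
        mates = rng.normal(size=(B - 1, 4))                    # freshly drawn batch-mates
        outs.append(normalize(np.vstack([sample, mates]))[0])  # the fixed example's output
    # How much the SAME example's normalized value varies from batch to batch
    spread[B] = np.std(outs, axis=0).mean()
    print(f"B={B}: average spread {spread[B]:.3f}")
```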

Where to Apply BN

Standard placement, before the activation function: Linear → BN → ReLU

Post-activation placement, used in some architectures: Linear → ReLU → BN
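A minimal training-mode sketch of the standard Linear → BN → ReLU ordering, assuming random weights and a batch of 32 inputs with 8 features. It also checks a consequence of Step 1: any bias added by the linear layer is removed by the mean subtraction, so that bias is redundant when BN follows, and $\beta$ takes over its role. The function names are illustrative.

```python
import numpy as np

def batchnorm(z, gamma, beta, eps=1e-5):
    # Training-mode BN over the batch axis
    mu, var = z.mean(axis=0), z.var(axis=0)
    return gamma * (z - mu) / np.sqrt(var + eps) + beta

def linear_bn_relu(x, W, b, gamma, beta):
    z = x @ W + b                   # Linear (pre-activation)
    y = batchnorm(z, gamma, beta)   # BN on the pre-activation
    return np.maximum(y, 0.0)       # ReLU

rng = np.random.default_rng(0)
x = rng.normal(size=(32, 8))
W = rng.normal(size=(8, 6))
gamma, beta = np.ones(6), np.zeros(6)

out_no_bias = linear_bn_relu(x, W, np.zeros(6), gamma, beta)
out_bias = linear_bn_relu(x, W, 5.0 * rng.normal(size=6), gamma, beta)
print(np.allclose(out_no_bias, out_bias))   # True: the bias cancels in (z - mu)
```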

Pre-norm vs. Post-norm in Transformers:

  • Post-norm (original Transformer): LN applied after the residual addition. Less stable for deep models.
  • Pre-norm (GPT-2, most modern LLMs): LN applied to the input of each sublayer, inside the residual branch. More stable.
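The two placements differ only in where LN sits relative to the residual addition. In this sketch, `sublayer` is a hypothetical stand-in for attention or the MLP, and LN's $\gamma$ and $\beta$ are omitted for brevity. Post-norm re-normalizes every token after the addition, while pre-norm lets the input `x` pass through the residual path untouched.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Per-token LN over the last (feature) axis; gamma and beta omitted for brevity
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def post_norm_block(x, sublayer):
    # Original Transformer: add the residual first, then normalize the sum
    return layer_norm(x + sublayer(x))

def pre_norm_block(x, sublayer):
    # Pre-norm: normalize only the sublayer's input; the residual path is untouched
    return x + sublayer(layer_norm(x))

rng = np.random.default_rng(0)
d = 16
W = rng.normal(scale=0.1, size=(d, d))

def sublayer(h):
    # Hypothetical stand-in for attention or the MLP
    return np.tanh(h @ W)

x = rng.normal(loc=1.0, scale=3.0, size=(4, d))   # 4 tokens with d features each
post = post_norm_block(x, sublayer)
pre = pre_norm_block(x, sublayer)
print(post.mean(axis=-1), post.std(axis=-1))           # every token re-normalized: ~0 and ~1
print(np.allclose(pre - x, sublayer(layer_norm(x))))   # True: x reaches the output unchanged
```

Because the pre-norm residual stream is never normalized inside the blocks, GPT-2 applies one additional LN after the final block.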

BN and Batch Size

BN statistics are unreliable for small batch sizes (roughly $B < 8$): the variance estimate is noisy, and the batch statistics may not represent the population well.
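How noisy the variance estimate is can be measured directly. This sketch assumes standard-normal data (true variance 1) and draws 10,000 independent batches per batch size, computing the biased $1/B$ variance that BN uses. The batch sizes shown are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
means, stds = {}, {}
for B in (4, 8, 16, 64):
    # 10,000 independent batches of size B from N(0, 1), so the true variance is 1
    batch_vars = rng.normal(size=(10_000, B)).var(axis=1)   # biased (1/B) estimate, as in BN
    means[B], stds[B] = batch_vars.mean(), batch_vars.std()
    print(f"B={B:2d}  mean of estimate={means[B]:.3f}  spread={stds[B]:.3f}")
```

For Gaussian data the biased estimate has mean $\frac{B-1}{B}\sigma^2$ and standard deviation $\frac{\sqrt{2(B-1)}}{B}\sigma^2$. At $B = 4$ that is a mean of 0.75 and a spread of about 0.61 times the true variance; at $B = 64$ it is about 0.98 and 0.18.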

Alternatives for small batches:

| Normalization | Statistics computed over | Notes |
|---|---|---|
| Batch Norm (BN) | Batch dimension, per feature (in CNNs: batch and spatial dimensions, per channel) | Ideally $B \geq 16$ |
| Layer Norm (LN) | Feature dimension, per sample | Used in Transformers; batch-size independent |
| Instance Norm (IN) | Spatial dimensions, per sample and per channel | Used in style transfer |
| Group Norm (GN) | Groups of channels plus spatial dimensions, per sample | Good for object detection (small batches) |
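For image tensors, the four methods differ only in which axes the mean and variance are taken over. This sketch assumes the `(N, C, H, W)` layout (batch, channels, height, width), omits $\gamma$ and $\beta$, and assumes the number of groups divides `C`. The helpers `normalize` and `group_norm` are illustrative, not library functions.

```python
import numpy as np

def normalize(x, axes, eps=1e-5):
    # Subtract the mean and divide by the std, both computed over the given axes
    mu = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def group_norm(x, num_groups, eps=1e-5):
    N, C, H, W = x.shape
    g = x.reshape(N, num_groups, C // num_groups, H, W)   # split channels into groups
    return normalize(g, axes=(2, 3, 4), eps=eps).reshape(N, C, H, W)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8, 5, 5))   # (N, C, H, W)

bn = normalize(x, axes=(0, 2, 3))    # BN: over batch + spatial, per channel
ln = normalize(x, axes=(1, 2, 3))    # LN: over all features of one sample
in_ = normalize(x, axes=(2, 3))      # IN: over spatial dims, per sample and channel
gn = group_norm(x, num_groups=4)     # GN: over channel groups + spatial, per sample
```

GN interpolates between the other per-sample methods: with one group it equals LN, and with one channel per group it equals IN.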

Layer Normalization

Most common alternative; the standard choice in Transformer-based models such as BERT and GPT (LLaMA uses the RMSNorm variant described below):

\[\text{LN}(\mathbf{z}) = \gamma \frac{\mathbf{z} - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta\]

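where $\mu = \frac{1}{d}\sum_{j=1}^d z_j$ and $\sigma^2 = \frac{1}{d}\sum_{j=1}^d (z_j - \mu)^2$ are computed across the $d$ features of a single sample, and $\gamma, \beta$ are learned per-feature vectors applied elementwise.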

Advantages over BN:

  • No dependence on batch size or batch composition.
  • Works identically at training and inference.
  • Well-suited for variable-length sequences.
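A minimal NumPy sketch of LN over the last axis, assuming a feature size of 8. The function name `layer_norm` and the test data are illustrative. It checks that a sample's output is identical whether it is normalized alone or in a batch, and that the same function applies unchanged to a `(batch, seq_len, d)` tensor of token features.

```python
import numpy as np

def layer_norm(z, gamma, beta, eps=1e-5):
    """LN over the last axis: each row (sample or token) uses its own mu and sigma^2."""
    mu = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    return gamma * (z - mu) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
d = 8
gamma, beta = np.ones(d), np.zeros(d)
batch = rng.normal(3.0, 2.0, size=(5, d))

full = layer_norm(batch, gamma, beta)
single = layer_norm(batch[:1], gamma, beta)
print(np.allclose(full[:1], single))   # True: no dependence on other samples
print(full.mean(axis=-1))              # approximately 0 for every row

tokens = rng.normal(size=(2, 7, d))    # 2 sequences of 7 tokens each
print(layer_norm(tokens, gamma, beta).shape)   # (2, 7, 8)
```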

Disadvantage: typically less effective than BN in CNNs for image tasks.

RMSNorm

Simplified layer normalization used in LLaMA, Mistral, and other modern LLMs. Skips the mean subtraction (re-centering):

\[\text{RMSNorm}(\mathbf{z}) = \frac{\mathbf{z}}{\sqrt{\frac{1}{d}\sum_j z_j^2 + \epsilon}} \cdot \gamma\]

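Performs comparably to LN in language models while being cheaper, since it skips computing the mean; the original RMSNorm paper reports running-time reductions of roughly 7% to 64% depending on the model.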
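A sketch comparing RMSNorm with LN, with no $\beta$ term, matching the RMSNorm formula above. Both function names are illustrative. RMSNorm scales each row to unit root mean square, and the two coincide on rows that already have zero mean, which gives some intuition for why dropping the re-centering step often costs little.

```python
import numpy as np

def rms_norm(z, gamma, eps=1e-5):
    """RMSNorm over the last axis: divide by the root mean square, no mean subtraction."""
    rms = np.sqrt(np.mean(z ** 2, axis=-1, keepdims=True) + eps)
    return z / rms * gamma

def layer_norm(z, gamma, eps=1e-5):
    # LN without beta, for comparison
    mu = z.mean(axis=-1, keepdims=True)
    return gamma * (z - mu) / np.sqrt(z.var(axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
d = 16
gamma = np.ones(d)
z = rng.normal(size=(3, d))

y = rms_norm(z, gamma)
print(np.sqrt(np.mean(y ** 2, axis=-1)))   # RMS of each output row is approximately 1

z0 = z - z.mean(axis=-1, keepdims=True)    # rows with zero mean
print(np.allclose(rms_norm(z0, gamma), layer_norm(z0, gamma)))          # True
print(np.allclose(rms_norm(z + 3.0, gamma), layer_norm(z + 3.0, gamma)))  # False: mean matters
```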

Practical Considerations

  • Use a normalization layer (BN, LN, or a variant) in most deep networks; it reduces the need for careful initialization.
  • For CNNs: apply BN after each convolutional layer, before activation.
  • For Transformers: use Layer Norm (pre-norm for stability).
  • Use Group Norm for detection/segmentation models with small batch sizes.
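  • BN stores 4 values per feature: the learned parameters $\gamma$ and $\beta$, plus the running statistics $\mu_\text{run}$ and $\sigma^2_\text{run}$, which are updated by the moving average rather than by backprop.
  • In PyTorch, call model.eval() before inference so BN uses its running statistics and stops updating them; forgetting this is a common bug. Call model.train() to switch back.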