Batch Norm, Layer Norm, Instance Norm, Group Norm

Since Batch Normalization was proposed by Google in 2015, many more normalization methods have appeared, such as Layer Normalization, Instance Normalization, and Group Normalization. These methods differ in their use cases and effects, but they share the same core idea: compute the mean and variance of the input over certain dimensions, normalize, and finally map the normalized features with learnable parameters. This can be written in a unified form:

\[ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta \]
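Before looking at the concrete layers, here is a minimal sketch of this shared recipe (the function normalize below is illustrative, not part of any library): pick the dimensions to reduce over, compute the biased mean and variance there, normalize, and apply the affine map.

import torch

def normalize(x, dims, gamma, beta, eps=1e-5):
    # Statistics over the chosen dimensions, kept for broadcasting
    var, mean = torch.var_mean(x, dim=dims, keepdim=True, unbiased=False)
    x_hat = (x - mean) / torch.sqrt(var + eps)
    # Learnable affine map (gamma and beta must broadcast against x_hat)
    # e.g. dims=(0, 2, 3) gives Batch Norm, dims=(2, 3) gives Instance Norm
    return x_hat * gamma + beta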

Take image data as an example: given an input \(x \in (N, C, H, W)\), where \(N, C, H, W\) are the batch size, the number of channels, and the image height and width, respectively.

[Figure: Normalization methods]

As shown in the figure above, BN computes the mean and variance over the \(N, H, W\) dimensions, LN over the \(C, H, W\) dimensions, IN over the \(H, W\) dimensions, and GN over the \(C', H, W\) dimensions, where \(C'\) is the number of channels in each group.

The dimensions over which the statistics are computed are the only difference between these methods, and it is exactly this difference that gives them their different effects and properties.

We can easily implement an equivalent version of each method in PyTorch:

import torch
import torch.nn as nn

# Batch Normalization: statistics over (N, H, W), one mean/var per channel
inputs = torch.randn(5, 256, 32, 32)  # (N, C, H, W)
bn = nn.BatchNorm2d(256)  # Weight shape: (C, )

# weight default is 1
bn.weight.data = torch.rand_like(bn.weight)

# compute on (N, H, W)
var, mean = torch.var_mean(inputs, dim=(0, 2, 3), keepdim=True, unbiased=False)  # (1, C, 1, 1)
std = (var + bn.eps).sqrt()  # (1, C, 1, 1)
norm = (inputs - mean) / std  # (N, C, H, W)

print(torch.allclose(
    norm * bn.weight.view(1, 256, 1, 1) + bn.bias.view(1, 256, 1, 1),
    bn(inputs))
)  # True
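One caveat (the check below is a small addition, reusing the tensors defined above): the equivalence holds because a freshly constructed BatchNorm2d is in training mode and normalizes with the statistics of the current batch; after calling bn.eval() it normalizes with the accumulated running_mean and running_var instead:

# In eval mode, BatchNorm2d normalizes with its running statistics
bn.eval()
running_std = (bn.running_var + bn.eps).sqrt().view(1, 256, 1, 1)
eval_norm = (inputs - bn.running_mean.view(1, 256, 1, 1)) / running_std
print(torch.allclose(
    eval_norm * bn.weight.view(1, 256, 1, 1) + bn.bias.view(1, 256, 1, 1),
    bn(inputs))
)  # True
bn.train()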

# Instance Normalization: statistics over (H, W), per sample and per channel
inputs = torch.randn(5, 256, 32, 32)  # (N, C, H, W)
ins = nn.InstanceNorm2d(256, affine=True)  # Weight shape: (C, )

# weight default is 1
ins.weight.data = torch.rand_like(ins.weight)

# compute on (H, W)
var, mean = torch.var_mean(inputs, dim=(2, 3), keepdim=True, unbiased=False)  # (N, C, 1, 1)
std = (var + ins.eps).sqrt()  # (N, C, 1, 1)
norm = (inputs - mean) / std  # (N, C, H, W)

print(torch.allclose(
    norm * ins.weight.view(1, 256, 1, 1) + ins.bias.view(1, 256, 1, 1),
    ins(inputs))
)  # True

# Layer Normalization: statistics over (C, H, W), one mean/var per sample
inputs = torch.randn(5, 256, 32, 32)  # (N, C, H, W)
normalized_shape = inputs.shape[1:]  # Normalize on (C, H, W)
ln = nn.LayerNorm(normalized_shape)

# weight default is 1
ln.weight.data = torch.rand_like(ln.weight)

# compute on (C, H, W)
var, mean = torch.var_mean(inputs, dim=(1, 2, 3), keepdim=True, unbiased=False)  # (N, 1, 1, 1)
std = (var + ln.eps).sqrt()  # (N, 1, 1, 1)
norm = (inputs - mean) / std  # (N, C, H, W)

print(torch.allclose(norm * ln.weight + ln.bias, ln(inputs)))  # True
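The same LayerNorm module is more commonly applied over just the last feature dimension, e.g. on (batch, sequence, feature) activations in Transformers; that is simply a different choice of normalized_shape (the sizes below are made up for illustration):

x = torch.randn(5, 100, 512)  # (batch, sequence length, feature dim)
ln_last = nn.LayerNorm(512)   # normalized_shape = (512, ), weight/bias shape: (512, )

# compute on the last dimension only
var, mean = torch.var_mean(x, dim=-1, keepdim=True, unbiased=False)  # (5, 100, 1)
norm = (x - mean) / (var + ln_last.eps).sqrt()

print(torch.allclose(norm * ln_last.weight + ln_last.bias, ln_last(x)))  # True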

# Group Normalization: split the C channels into groups, statistics over (C', H, W)
inputs = torch.randn(5, 256, 32, 32)  # (N, C, H, W)
num_groups = 32
gn = nn.GroupNorm(num_groups=num_groups, num_channels=256)  # Weight shape: (C, )

# weight default is 1
gn.weight.data = torch.rand_like(gn.weight)

grouped_inputs = inputs.view(5, num_groups, 256 // num_groups, 32, 32)  # (N, G, C', H, W)

# compute on (C', H, W)
var, mean = torch.var_mean(grouped_inputs, dim=(2, 3, 4), keepdim=True, unbiased=False)  # (N, G, 1, 1, 1)
std = (var + gn.eps).sqrt()  # (N, G, 1, 1, 1)
norm = (grouped_inputs - mean) / std  # (N, G, C', H, W)

print(torch.allclose(
    norm.view(5, 256, 32, 32) * gn.weight.view(1, 256, 1, 1) + gn.bias.view(1, 256, 1, 1),
    gn(inputs))
)  # True
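Finally, a sanity check on how these methods relate (my own addition; affine is disabled so the different affine parameterizations do not interfere): with one group per channel GroupNorm reduces to Instance Norm, and with a single group it reduces to Layer Norm over (C, H, W):

x = torch.randn(5, 256, 32, 32)  # (N, C, H, W)

# num_groups == num_channels: one channel per group, i.e. Instance Norm
print(torch.allclose(
    nn.GroupNorm(num_groups=256, num_channels=256, affine=False)(x),
    nn.InstanceNorm2d(256, affine=False)(x),
    atol=1e-6)
)  # True

# num_groups == 1: a single group spanning all channels, i.e. Layer Norm over (C, H, W)
print(torch.allclose(
    nn.GroupNorm(num_groups=1, num_channels=256, affine=False)(x),
    nn.LayerNorm(x.shape[1:], elementwise_affine=False)(x),
    atol=1e-6)
)  # True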