DOC Improve documentation for LayerNorm (#59178)
Summary:
Closes https://github.com/pytorch/pytorch/issues/51455
I think the current implementation is aggregating over the correct dimensions. The shape of `normalized_shape` is only used to determine the dimensions to aggregate over. The actual values of `normalized_shape` are used when `elementwise_affine=True` to initialize the weights and biases.
This PR updates the docstring to clarify how `normalized_shape` is used. Here is a short script comparing the implementations for tensorflow and pytorch:
```python
import torch
import torch.nn as nn
import tensorflow as tf
from tensorflow.keras.layers import LayerNormalization
rng = np.random.RandomState()
x = rng.randn(10, 20, 64, 64).astype(np.float32)
# slightly non-trival
x[:, :10, ...] = x[:, :10, ...] * 10 + 20
x[:, 10:, ...] = x[:, 10:, ...] * 30 - 100
# Tensorflow Layer norm
x_tf = tf.convert_to_tensor(x)
layer_norm_tf = LayerNormalization(axis=[-3, -2, -1], epsilon=1e-5)
output_tf = layer_norm_tf(x_tf)
output_tf_np = output_tf.numpy()
# PyTorch Layer norm
x_torch = torch.as_tensor(x)
layer_norm_torch = nn.LayerNorm([20, 64, 64], elementwise_affine=False)
output_torch = layer_norm_torch(x_torch)
output_torch_np = output_torch.detach().numpy()
# check tensorflow and pytorch
torch.testing.assert_allclose(output_tf_np, output_torch_np)
# manual comutation
manual_output = ((x_torch - x_torch.mean(dim=(-3, -2, -1), keepdims=True)) /
(x_torch.var(dim=(-3, -2, -1), keepdims=True, unbiased=False) + 1e-5).sqrt())
torch.testing.assert_allclose(output_torch, manual_output)
```
To get to the layer normalization as shown here:
<img width="157" alt="Screen Shot 2021-05-29 at 2 13 52 PM" src="https://user-images.githubusercontent.com/5402633/120080691-1e37f100-c088-11eb-9060-4f263e4cd093.png">
One needs to pass in `normalized_shape` with shape `x.dim() - 1` with the size of the channels and all spatial dimensions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59178
Reviewed By: ejguan
Differential Revision: D28931877
Pulled By: jbschlosser
fbshipit-source-id: 193e05205b9085bb190c221428c96d2ca29f2a70