Fix bad use of channels last kernel in sync batch norm backward (#64100)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/64039
There are two distinct problems here.
1. If `grad_output` is channels last but not input, then input would be read as-if it were channels last. So reading the wrong values.
2. `use_channels_last_kernels` doesn't guarunte that `suggest_memory_format` will actually return channels last, so use `empty_like` instead so the strides always match.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64100
Reviewed By: mruberry
Differential Revision: D30622127
Pulled By: ngimel
fbshipit-source-id: e28cc57215596817f1432fcdd6c49d69acfedcf2