remove serial_exec from scatter/gather kernel (#36181)
Summary:
Since the indexed dimension in `scatter/gather` is traversed inside the kernel, the sets of memory locations written by different threads are mutually disjoint, so no two threads ever write to the same memory.
See [this comment](https://github.com/pytorch/pytorch/issues/33389#issuecomment-590017938) for a graphical explanation. More formal description:
Suppose we deal with 3D tensors and `dim=0`, so that the `scatter_add` operations are
```
self[index[i][j][k]][j][k] += src[i][j][k],
...
self[index[i'][j'][k']][j'][k'] += src[i'][j'][k'],
...
```
Clearly, two such operations read/write the same memory location if and only if:
```
index[i][j][k] = index[i'][j'][k'],
j = j',
k = k'.
```
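To make this condition concrete, here is a minimal pure-Python sketch (not code from this PR; `write_targets` and the sample `index` are hypothetical names and data chosen for illustration). It groups source positions by the element of `self` they write to for `dim=0` and shows that colliding writes always share the same `(j, k)`:

```python
from itertools import product


def write_targets(index):
    """Group source positions (i, j, k) by the element of `self` they write to (dim=0)."""
    n_i, n_j, n_k = len(index), len(index[0]), len(index[0][0])
    targets = {}
    for i, j, k in product(range(n_i), range(n_j), range(n_k)):
        dest = (index[i][j][k], j, k)  # self[index[i][j][k]][j][k]
        targets.setdefault(dest, []).append((i, j, k))
    return targets


# Hypothetical 2x2x2 index tensor, written as nested lists.
index = [
    [[0, 1], [1, 1]],
    [[0, 0], [1, 0]],
]

targets = write_targets(index)
# Destinations that receive more than one write.
conflicts = {dest: srcs for dest, srcs in targets.items() if len(srcs) > 1}
# Every group of colliding writes lives in a single (j, k) column.
same_column = all(len({(j, k) for _, j, k in srcs}) == 1 for srcs in conflicts.values())
```

In this example the only collisions are between positions that differ in `i` alone, which is the dimension the kernel walks serially.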
Since the reduction over `dim=0` happens inside the kernel, the threads partition `dim=1,2`. It means that any two distinct threads `t` and `t'` receive indices
```
I = {(*, j, k) sent to the thread t},
I' = {(*, j', k') sent to the thread t'},
I intersection with I' = the empty set.
```
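Consequently, the conflict condition
```
index[i][j][k] = index[i'][j'][k'],
j = j',
k = k',
```
holds if and only if there exists a single thread which receives indices `K` with
`(*,j,k),(*,j',k') in K`, i.e. both writes are issued by the same thread and are therefore executed one after the other.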
Therefore it is possible to make `scatter_add` parallel and remove `serial_exec` from the `scatter_gather_base_kernel`.
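
As an illustration of the argument (a sketch only, not the ATen implementation: it uses Python threads and nested lists, and `scatter_add_serial`, `scatter_add_parallel`, and the sample data are hypothetical), each task below owns one `(j, k)` column and walks `dim=0` itself, so no two tasks touch the same output element:

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import product


def _copy3d(t):
    """Deep-copy a 3D nested list."""
    return [[row[:] for row in plane] for plane in t]


def scatter_add_serial(dst, index, src):
    """Reference: out[index[i][j][k]][j][k] += src[i][j][k], one element at a time."""
    out = _copy3d(dst)
    n_i, n_j, n_k = len(index), len(index[0]), len(index[0][0])
    for i, j, k in product(range(n_i), range(n_j), range(n_k)):
        out[index[i][j][k]][j][k] += src[i][j][k]
    return out


def scatter_add_parallel(dst, index, src, workers=4):
    """Parallel over (j, k); the loop over dim=0 stays inside each task."""
    out = _copy3d(dst)
    n_i, n_j, n_k = len(index), len(index[0]), len(index[0][0])

    def run_column(jk):
        j, k = jk
        # Only this task writes to out[*][j][k], so no locking is needed.
        for i in range(n_i):
            out[index[i][j][k]][j][k] += src[i][j][k]

    with ThreadPoolExecutor(max_workers=workers) as pool:
        # list(...) forces completion and re-raises any exception from a task.
        list(pool.map(run_column, product(range(n_j), range(n_k))))
    return out


# Hypothetical 2x2x2 inputs.
index = [[[0, 1], [1, 1]], [[0, 0], [1, 0]]]
src = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
dst = [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]

serial_result = scatter_add_serial(dst, index, src)
parallel_result = scatter_add_parallel(dst, index, src)
```

Both variants should produce the same output here, because the serial order along `dim=0` is preserved within each column and the columns are independent of one another.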
CC v0dro
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36181
Differential Revision: D21716167
Pulled By: ngimel
fbshipit-source-id: 49aee2de43779a1f0b359c22c8589c0702ee68a2