Fix thread-allocation in `_vec_log_softmax_lastdim` (#85398)
## Problem history
There seem to have always been bugs in `_vec_log_softmax_lastdim`.
In particular, it had two issues:
#### Bug 1
Before AVX512 support was added, `CHUNK_SIZE` had been heuristically chosen in `_vec_log_softmax_lastdim`:
`CHUNK_SIZE = (128 / sizeof(scalar_t)) * Vec::size();`
It was `256` for float32, bfloat16, and float16.
When AVX512 support was added, `CHUNK_SIZE` became `512`.
The rationale behind this choice of `CHUNK_SIZE` was never documented, and it seems flawed, since the number of OpenMP threads used ends up depending on it.
#### Bug 2
`grain_size` had been defined as `internal::GRAIN_SIZE / (16 * dim_size * CHUNK_SIZE)`, which simplifies to `8 / dim_size`, so it was almost always 0 and was therefore replaced by `CHUNK_SIZE`, viz. 256.
Since `at::parallel_for` thus always received a `grain_size` of `256`, too few threads were used in certain cases.
#### Problem caused by bugs
With an `outer_size` of, say, 700, only 3 threads would have been used with AVX2, irrespective of the value of `dim_size`!
When AVX512 support was added, since `CHUNK_SIZE` became `512`, only 2 threads were used for the same `outer_size` of 700.
In the Transformers training example, `log_softmax` was computed on the last dim of a tensor of shape `(700, 23258)`.
AVX512 thus appeared considerably slower, masking the real issue: even AVX2 performance for this kernel was poor, due to inefficient work distribution among OpenMP threads.
## Solution
Distribute work across threads more efficiently. This yields higher performance for both AVX2 and AVX512 than before,
and fixes the regression observed with AVX512 (the AVX512 kernel is now faster than its AVX2 counterpart).
## Benchmarks
##### Machine-config:
Intel(R) Xeon(R) Platinum 8371HC CPU (Cooper Lake)
One socket of 26 physical cores was used.
Intel OpenMP & tcmalloc were preloaded.
Example of a command to run benchmark:
`ATEN_CPU_CAPABILITY=avx512 KMP_AFFINITY=granularity=fine,verbose,compact,1,0 KMP_BLOCKTIME=1 KMP_SETTINGS=1 MKL_NUM_THREADS=26 OMP_NUM_THREADS=26 numactl --membind=0 --cpunodebind=0 python3.8 -m pt.softmax_test --test_name LogSoftmax_N1024_seq_len23258_dim1_cpu`
Benchmark | Old implementation time (us) | New implementation time (us) | Speedup ratio (old/new)
-- | -- | -- | --
LogSoftmax_N1024_seq_len23258_dim1_cpu AVX2 | 11069.281 | 2651.186 | 4.17x
LogSoftmax_N1024_seq_len23258_dim1_cpu AVX512 | 18292.928 | 2586.550 | 7.07x
LogSoftmax_N700_seq_len23258_dim1_cpu AVX2 | 9611.902 | 1762.833 | 5.45x
LogSoftmax_N700_seq_len23258_dim1_cpu AVX512 | 12168.371 | 1717.824 | 7.08x
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85398
Approved by: https://github.com/jgong5, https://github.com/mingfeima, https://github.com/peterbell10, https://github.com/lezcano