Vectorized CPU code implementing left shift operator. (#88607)
This PR adds a vectorized implementation of the CPU version of the left shift operator.
All of the tests run by `pytest test/test_ops.py -vk left_shift` pass.
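For context, here is a minimal usage sketch of the operator in question; the `<<` operator on integer tensors and `torch.bitwise_left_shift` perform the same elementwise shift:
```python
import torch

x = torch.tensor([1, 2, 3, 4], dtype=torch.int32)
y = torch.tensor([1, 2, 3, 4], dtype=torch.int32)

# Elementwise left shift; both spellings compute the same result.
print(x << y)                          # tensor([ 2,  8, 24, 64], dtype=torch.int32)
print(torch.bitwise_left_shift(x, y))  # tensor([ 2,  8, 24, 64], dtype=torch.int32)
```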
Here are some additional details:
<details>
<summary>
Benchmarking script (written by Philip, with small tweaks by Mario) comparing left shifts with multiplications (they are on par now)
</summary>
```python
import torch
from torch.utils.benchmark import Timer, Compare
from itertools import product

# These functions exist because torch.jit.script does not support `torch.iinfo`.
def _num_value_bits(dtype):
    if dtype == torch.uint8:
        return 8
    else:  # torch.int32
        return 31

def _max_value(dtype):
    if dtype == torch.uint8:
        return 255
    else:  # torch.int32
        return 2147483647

def bitshift(image, dtype):
    num_value_bits_input = _num_value_bits(image.dtype)
    num_value_bits_output = _num_value_bits(dtype)
    return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)

def mul(image, dtype):
    input_max = float(_max_value(image.dtype))
    output_max = float(_max_value(dtype))
    factor = int((output_max + 1) // (input_max + 1))
    image = image.to(dtype)
    return image * factor

size = 256
image = torch.randint(0, 256, (3, size, size), dtype=torch.uint8)
dtype = torch.int32

def gen_inputs():
    devices = ("cpu",)
    fns = (mul, bitshift)
    threads = (1,)
    for device, fn, num_threads in product(devices, fns, threads):
        yield f"Bitshift {device} {image.dtype}", str(tuple(image.shape)), num_threads, fn, image, dtype

def benchmark(label, sub_label, threads, f, *args, **kwargs):
    return Timer("f(*args, **kwargs)",
                 globals=locals(),
                 label=label,
                 description=f.__name__,
                 sub_label=sub_label,
                 num_threads=threads).blocked_autorange()

results = []
for args in gen_inputs():
    results.append(benchmark(*args))

compare = Compare(results)
compare.trim_significant_figures()
compare.print()
```
</details>
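As a quick sanity check that the benchmark compares equivalent computations, one can append the following to the script above (it relies on the `mul`, `bitshift`, `image`, and `dtype` definitions from that script): for the uint8 → int32 conversion, multiplying by 2**23 and shifting left by 23 bits must agree.
```python
# Both paths should produce the identical uint8 -> int32 conversion:
# multiplying by 2**23 equals shifting left by 23 bits.
assert torch.equal(mul(image, dtype), bitshift(image, dtype))
```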
<details>
<summary>
Test script exercising a large number of combinations of left shift operands, used for further testing (validates results by comparing against results generated by NumPy)
</summary>
```python
import numpy as np
import torch

# Testing shifting of non-negative numbers only, but will test all
# possible RHS shift values for the given type. For int8 and int16,
# we'll test shifting all of the non-negative values representable by
# the type. For the rest of the data types, we'll test shifting some
# random numbers in the corresponding range.
def _create_inputs(dtype):
    info = torch.iinfo(dtype)
    if dtype == torch.int8 or dtype == torch.int16:
        ntests = info.max + 1
        x = torch.arange(info.max + 1, dtype=dtype, device="cpu", requires_grad=False)
    else:
        ntests = 100000
        x = torch.randint(info.max + 1 if dtype != torch.int64 else info.max,
                          (ntests,), dtype=dtype, device="cpu", requires_grad=False)
    y = torch.tensor(range(info.bits), dtype=dtype, device="cpu", requires_grad=False)
    xy = torch.cartesian_prod(x, y)
    return (xy[:, 0], xy[:, 1])

torch.manual_seed(0)

# Perform testing for each supported datatype, and compare results
# with ones generated by NumPy.
for dtype in (torch.int8, torch.int16, torch.int32, torch.int64):
    (x, y) = _create_inputs(dtype)
    z = x << y
    xnp = x.numpy()
    ynp = y.numpy()
    znp = z.numpy()
    assert (znp == (xnp << ynp)).all()
```
</details>
<details>
<summary>
Benchmarking script running the left shift operator on tensors of different lengths (and with a varying number of bits to shift)
</summary>
```python
import torch
import pickle
import itertools
from torch.utils.benchmark import Timer, Compare

torch.manual_seed(0)

# Edit this part if needed.
lengths = [1024, 4096, 16384, 65536]
rhss = [1, 2, 7, 8, 15, 16, 31, 32, 63, 64]
benchmark_name = "lshift"
label = ""
dtypes = [torch.int8, torch.int16, torch.int32, torch.int64]

results = []

# Create an argument pair for testing. Arguments are tensors of the
# given datatype and length; the LHS of each shift operation is a
# random number, and the RHS is a given value that is the same for
# all of them.
def _make_args(dtype, length, rhs):
    info = torch.iinfo(dtype)
    return (torch.randint(info.max, (length,), dtype=dtype, device="cpu", requires_grad=False),
            rhs * torch.ones((length,), dtype=dtype, device="cpu", requires_grad=False))

# Run the shift operation for vectors of the given lengths and for the
# given numbers of bits to be shifted, and remember the timings.
for dtype, length, rhs in itertools.product(dtypes, lengths, rhss):
    x, y = _make_args(dtype, length, rhs)
    timer = Timer("x << y",
                  globals=globals(),
                  label=benchmark_name,
                  description=label,
                  sub_label=f"dtype={dtype},length={length}",
                  num_threads=1)
    results.append(timer.blocked_autorange())

# Print results.
compare = Compare(results)
compare.trim_significant_figures()
compare.print()

# Save results.
with open("{}.pickle".format(label), "wb") as f:
    pickle.dump(results, f)
```
</details>
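The table in the next section was produced by merging the pickled results of two such runs. A minimal merging sketch, assuming `label` was set to `"master"` and `"mybranch"` respectively (so the script above wrote `master.pickle` and `mybranch.pickle`):
```python
import pickle

from torch.utils.benchmark import Compare

results = []
for name in ("master", "mybranch"):
    # Hypothetical file names; each comes from one run of the script
    # above with `label` set accordingly.
    with open(f"{name}.pickle", "rb") as f:
        results.extend(pickle.load(f))

# The `description` field of each measurement (set from `label`)
# becomes a column, so the merged table shows both runs side by side.
compare = Compare(results)
compare.trim_significant_figures()
compare.print()
```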
<details>
<summary>
Results of running the above benchmarking script, with results manually merged from runs of viable/strict (labeled "master" in the table below) and of my branch (labeled "mybranch" in the table below)
</summary>
```
[------------------- lshift -------------------------------]
| master | mybranch
1 threads: ------------------------------------------------
dtype=torch.int8,length=1024 | 3 | 3
dtype=torch.int8,length=4096 | 5 | 3
dtype=torch.int8,length=16384 | 14 | 5
dtype=torch.int8,length=65536 | 51 | 15
dtype=torch.int16,length=1024 | 3 | 3
dtype=torch.int16,length=4096 | 4 | 3
dtype=torch.int16,length=16384 | 11 | 5
dtype=torch.int16,length=65536 | 39 | 13
dtype=torch.int32,length=1024 | 3 | 2
dtype=torch.int32,length=4096 | 4 | 3
dtype=torch.int32,length=16384 | 10 | 4
dtype=torch.int32,length=65536 | 35 | 12
dtype=torch.int64,length=1024 | 3 | 3
dtype=torch.int64,length=4096 | 4 | 3
dtype=torch.int64,length=16384 | 11 | 6
dtype=torch.int64,length=65536 | 36 | 20
Times are in microseconds (us).
```
</details>
All of the testing/benchmarking was conducted on qpu3, which supports AVX2 only. For basic validation of the AVX-512 update of the left shift implementation for 8-bit operands (the only case that is non-trivial for AVX-512), [Compiler Explorer](https://godbolt.org/) was used, with GCC trunk and the `-mavx512f -mavx512bw` flags added. Here are further details:
<details>
<summary>
C program used for basic validation of the AVX-512 vectorized version for 8-bit operands
</summary>
```c
#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include <immintrin.h>

static void print_m512i_int8(const __m512i* x)
{
    int8_t val[64];
    memcpy(val, x, sizeof(val));
    for (int i = 0; i < 64; ++i) {
        if (i > 0)
            printf(", ");
        printf("%d", (int)val[i]);
    }
    printf("\n");
}

int main()
{
    __m512i a = _mm512_set_epi8(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                                1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
    __m512i b = _mm512_set_epi8(7, 7, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6,
                                5, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4,
                                3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2,
                                1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0);

    // ------- Copied code from vec512_int.h
    // Mask used to set the upper 8 bits of each 16-bit value to 0, and
    // keep the lower 8 bits.
    __m512i mask = _mm512_set_epi16(0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                                    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                                    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
                                    0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff);

    // Convert 8-bit operands from the lower lanes to 16-bit values, and
    // perform the vectorized shift. Make sure that the upper 8 bits of
    // the 16-bit results are all 0.
    __m256i a_lo_8 = _mm512_extracti64x4_epi64(a, 0);
    __m256i b_lo_8 = _mm512_extracti64x4_epi64(b, 0);
    __m512i a_lo_16 = _mm512_cvtepi8_epi16(a_lo_8);
    __m512i b_lo_16 = _mm512_cvtepi8_epi16(b_lo_8);
    __m512i c_lo_16 = _mm512_and_si512(_mm512_sllv_epi16(a_lo_16, b_lo_16), mask);

    // Convert 8-bit operands from the upper lanes to 16-bit values, and
    // perform the vectorized shift. Make sure that the upper 8 bits of
    // the 16-bit results are all 0.
    __m256i a_hi_8 = _mm512_extracti64x4_epi64(a, 1);
    __m256i b_hi_8 = _mm512_extracti64x4_epi64(b, 1);
    __m512i a_hi_16 = _mm512_cvtepi8_epi16(a_hi_8);
    __m512i b_hi_16 = _mm512_cvtepi8_epi16(b_hi_8);
    __m512i c_hi_16 = _mm512_and_si512(_mm512_sllv_epi16(a_hi_16, b_hi_16), mask);

    // Cast the 16-bit results back into 8-bit values and merge them
    // together (using unsigned saturation with the upper 8 bits set to 0
    // above ensures that the results are correct). Values are merged per
    // lane, so this is not yet the final result.
    __m512i c_perm = _mm512_packus_epi16(c_lo_16, c_hi_16);

    // Permute values so that the final result is produced.
    __m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
    __m512i c = _mm512_permutexvar_epi64(idx, c_perm);
    // ------- End copied

    print_m512i_int8(&c);
    // Expected output: 1(x8), 2(x8), 4(x8), 8(x8), 16(x8), 32(x8),
    // 64(x8), -128(x8) (128 wraps to -128 in int8)
    return 0;
}
```
</details>
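The same widen-shift-mask-narrow idea can also be checked without AVX-512 hardware. Below is a scalar NumPy sketch of it (not the actual `vec512_int.h` code), using the same inputs and expected output as the C program above:
```python
import numpy as np

# Same inputs as the C program: LHS is all ones, RHS shifts each group
# of eight lanes by 0..7 bits.
a = np.ones(64, dtype=np.int8)
b = np.repeat(np.arange(8, dtype=np.int8), 8)

# Widen 8-bit lanes to 16 bits, shift, zero the upper byte of each
# 16-bit result, then narrow back to 8-bit lanes.
c16 = (a.astype(np.int16) << b.astype(np.int16)) & 0x00FF
c = c16.astype(np.uint8).view(np.int8)

print(c)  # 1(x8), 2(x8), 4(x8), 8(x8), 16(x8), 32(x8), 64(x8), -128(x8)
```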
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88607
Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/peterbell10