[Pytorch] General broadcast for arithmetic operators (#104718)
Summary:
Currently, broadcast is supported for 4D tensors only in a limited form: if the batch or channel dimensions are not equal, then both the batch and channel of one tensor must be 1, e.g.:
```
tensorA NCHW:
5, 2, 3, 3
tensorB NCHW:
1, 1, 3, 3 --> batch=1, channel=1
```
This diff adds broadcast support for 4D tensors whose batch and channel dimensions differ independently between the two tensors, e.g.:
```
tensorA NCHW:
5, 1, 3, 3
tensorB NCHW:
1, 5, 3, 3
```
Broadcast rules (for each dimension x, at least one must hold):
```
- tensorA.dim()[x] == tensorB.dim()[x]
- tensorA.dim()[x] == 1 || tensorB.dim()[x] == 1
- tensorA.dim()[x] does not exist || tensorB.dim()[x] does not exist
```
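As an illustration of the rules above, here is a minimal sketch of a shape compatibility check. `broadcast_shape` is a hypothetical helper written for this note, not a function from PyTorch or from this diff, and it assumes dimensions are aligned from the trailing (innermost) side, with a missing dimension treated as size 1.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper (not part of PyTorch): returns the broadcast output
// shape, or std::nullopt when the two shapes violate the rules above.
std::optional<std::vector<int64_t>> broadcast_shape(
    const std::vector<int64_t>& a,
    const std::vector<int64_t>& b) {
  const std::size_t ndim = std::max(a.size(), b.size());
  std::vector<int64_t> out(ndim);
  for (std::size_t i = 0; i < ndim; ++i) {
    // Walk from the trailing dimension; a dimension that does not exist
    // behaves like size 1.
    const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    // Sizes must match, or one of them must be 1.
    if (da != db && da != 1 && db != 1) {
      return std::nullopt;
    }
    out[ndim - 1 - i] = (da == 1) ? db : da;
  }
  return out;
}
```

Under these assumptions, shapes `{5, 1, 3, 3}` and `{1, 5, 3, 3}` are compatible and yield `{5, 5, 3, 3}`, while `{2, 3}` and `{4, 3}` are rejected.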
Broadcast method:
1. Pass `output`, `input` and `other` tensors to the shader
2. Iterate through the output texture to calculate the value of each texel (no repeating)
3. Map NHW positions using modulo
4. Map the C position: divide pos.z by ceil(C/4) to map to the original tensor range
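The modulo mapping can be illustrated with a CPU-side reference. This is only a sketch under stated assumptions: `Shape4`, `offset`, and `add_broadcast` are hypothetical names, the buffers are contiguous NCHW arrays, and the code does not model the Vulkan texture layout or the `pos.z` channel packing from step 4. It relies on each input size being either 1 or equal to the output size, so `index % size` resolves to 0 or to the index itself.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical 4D shape in NCHW order.
struct Shape4 {
  int64_t n, c, h, w;
};

// Flat offset of (n, c, h, w) in a contiguous NCHW buffer.
int64_t offset(const Shape4& s, int64_t n, int64_t c, int64_t h, int64_t w) {
  return ((n * s.c + c) * s.h + h) * s.w + w;
}

// Reference broadcast add: visits every output element exactly once and
// maps each output index back into the inputs with modulo.
std::vector<float> add_broadcast(
    const std::vector<float>& a, const Shape4& sa,
    const std::vector<float>& b, const Shape4& sb,
    const Shape4& so) {
  std::vector<float> out(static_cast<std::size_t>(so.n * so.c * so.h * so.w));
  for (int64_t n = 0; n < so.n; ++n) {
    for (int64_t c = 0; c < so.c; ++c) {
      for (int64_t h = 0; h < so.h; ++h) {
        for (int64_t w = 0; w < so.w; ++w) {
          // A size-1 input dimension always maps to index 0.
          const float va = a[offset(sa, n % sa.n, c % sa.c, h % sa.h, w % sa.w)];
          const float vb = b[offset(sb, n % sb.n, c % sb.c, h % sb.h, w % sb.w)];
          out[offset(so, n, c, h, w)] = va + vb;
        }
      }
    }
  }
  return out;
}
```

Because the loop is driven by the output shape, no input is ever materialized in expanded form, which mirrors the "no repeating" note in step 2.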
---
Also refactors some tests to reduce repeated setup code.
Test Plan:
New tests:
Add
```
[ RUN ] VulkanAPITest.add_broadcast5
[ OK ] VulkanAPITest.add_broadcast5 (0 ms)
[ RUN ] VulkanAPITest.add_broadcast6
[ OK ] VulkanAPITest.add_broadcast6 (0 ms)
```
Sub
```
[ RUN ] VulkanAPITest.sub_broadcast5
[ OK ] VulkanAPITest.sub_broadcast5 (0 ms)
[ RUN ] VulkanAPITest.sub_broadcast6
[ OK ] VulkanAPITest.sub_broadcast6 (0 ms)
```
Mul
```
[ RUN ] VulkanAPITest.mul_broadcast5
[ OK ] VulkanAPITest.mul_broadcast5 (1 ms)
[ RUN ] VulkanAPITest.mul_broadcast6
[ OK ] VulkanAPITest.mul_broadcast6 (1 ms)
```
Div
```
[ RUN ] VulkanAPITest.div_broadcast5
[ OK ] VulkanAPITest.div_broadcast5 (1 ms)
[ RUN ] VulkanAPITest.div_broadcast6
[ OK ] VulkanAPITest.div_broadcast6 (2 ms)
```
All tests:
https://www.internalfb.com/phabricator/paste/view/P781794761
Run clang-format on glsl files and Arithmetic.cpp
Differential Revision: D46874508
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104718
Approved by: https://github.com/SS-JIA