Distributed Reduction (#18206)
This PR implements distributed reduction for llama 2. This version does not handle any cases that require re-sharding, because we have not seen such use cases yet.
Intuitive examples (see the sketch after this list):
- [supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[0]) -> [1,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1]
- [supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[1]) -> [2,1,6]-tensor with spec=RRS[0] and device_mesh=[0,1]
- [not supported] [2,4,6]-tensor with spec=RRS[0] and device_mesh=[0,1] -> Reduce(axes=[2]) -> [2,4,1]-tensor with spec=RRS[0] and device_mesh=[0,1]
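The following is a minimal NumPy sketch (illustration only, not the kernel code in this PR) of why the first example is supported: with spec=RRS[0] the [2,4,6] tensor is sharded along axis 2 across the two devices in device_mesh=[0,1], so reducing over axis 0 never crosses a shard boundary and each device can reduce its local shard independently.

```python
import numpy as np

full = np.arange(2 * 4 * 6, dtype=np.float32).reshape(2, 4, 6)

# Shard along the last axis (the "S[0]" axis) across the 2 devices in the mesh.
shards = np.split(full, 2, axis=2)          # each shard is [2, 4, 3]

# Each device reduces its own shard over axis 0 (a replicated "R" axis).
local_outputs = [s.sum(axis=0, keepdims=True) for s in shards]  # each [1, 4, 3]

# Stitching the local results back along the sharded axis reproduces the
# reduction of the unsharded tensor, so the output keeps spec RRS[0] and
# no communication or re-sharding is needed.
reassembled = np.concatenate(local_outputs, axis=2)             # [1, 4, 6]
expected = full.sum(axis=0, keepdims=True)
assert np.allclose(reassembled, expected)
```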
Algorithm:
When the reduced axes are not sharded, each device can run the reduction directly on its local shard, and the output sharding spec is identical to the input sharding spec. When a reduced axis is sharded, the output spec would differ from the input spec and re-sharding would be required, so we currently throw in that case. A sketch of this rule is shown below.
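Below is a hedged Python sketch of that rule (a hypothetical helper for illustration, not the implementation in 753d9af): if every reduced axis is replicated, the input spec is forwarded to the output; otherwise we raise, since supporting that case would require re-sharding.

```python
from typing import Sequence


def infer_reduce_sharding(input_spec: Sequence[str], reduce_axes: Sequence[int]) -> list[str]:
    """input_spec is a per-axis spec such as ['R', 'R', 'S[0]'] (i.e. RRS[0])."""
    for axis in reduce_axes:
        if input_spec[axis] != 'R':
            # e.g. Reduce(axes=[2]) on RRS[0]: the reduced axis is sharded,
            # so local reductions cannot be combined without communication.
            raise RuntimeError(
                f"Reducing sharded axis {axis} requires re-sharding, which is not supported."
            )
    # All reduced axes are replicated, so each device reduces independently
    # and the output spec equals the input spec.
    return list(input_spec)


# RRS[0] with Reduce(axes=[0]) or axes=[1] is supported; axes=[2] throws.
print(infer_reduce_sharding(['R', 'R', 'S[0]'], [0]))   # -> ['R', 'R', 'S[0]']
```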
Review guidelines:
- Check 97b8d2f for the new ops' schemas and how the new ops are registered.
- Read the tests in 2450f93 to get familiar with the behavior of these ops.
- Check the implementation details in 753d9af.