Make nn functions configurable for different scalar types (#20729)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20729
Currently there is no way to specify what scalar types each nn function will support.
This change will allow to specify supported scalar types for each function/backward function and device. By default each function will support Float, Double, Half.
If you want to scpecify any extra supported scalar types, other then default, you will need to change nn.yalm:
- name: _some_func(Tensor self)
cname: SomeFunction
CPU:
forward_scalar_types: ['Float', 'Double', 'Long']
backward_scalar_types: ['Float', 'Double']
Differential Revision: D15423752
fbshipit-source-id: b3c157316d6e629bc39c1b377a3b23c71b1656cf