Address the case when the layout of the tangent is not the same as the base (#66292)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66292
In this PR:
1. Fix the case where the tangent has a different layout from the base in `set_fw_grad` by adding a native function and its batching rule.
For (1) we replace the following:
```
Tensor new_with_same_meta(const Variable& base) {
  // Allocate a flat buffer covering the base's entire storage, then
  // restride it to match the base's size/stride/offset metadata.
  int64_t nelement_in_storage = base.storage().nbytes() / base.itemsize();
  auto new_tensor = at::zeros({nelement_in_storage}, base.options());
  auto res = new_tensor.as_strided(base.sizes(), base.strides(), base.storage_offset());
  return res;
}
```
with a native function so that a batching rule can alter its behavior.
This new function is similar to `new_zeros_strided`, except that it also requires `storage_offset` and `storage_numel` arguments.
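A minimal sketch of what such a native function could look like; the name `new_zeros_with_same_meta` and the exact signature here are assumptions for illustration, not the actual implementation:
```
#include <ATen/ATen.h>

// Hypothetical name/signature, for illustration only: allocate a flat zero
// buffer large enough to back the requested view (storage_numel elements,
// using the tangent's dtype/device), then restride it to the base's metadata.
at::Tensor new_zeros_with_same_meta(
    const at::Tensor& tangent,
    at::IntArrayRef sizes,
    at::IntArrayRef strides,
    int64_t storage_offset,
    int64_t storage_numel) {
  auto buffer = at::zeros({storage_numel}, tangent.options());
  return buffer.as_strided(sizes, strides, storage_offset);
}
```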
Possible concerns:
- Why have redundant logic? Why not add new args to `new_zeros_strided`? This is probably a niche use case, so it's better not to complicate the existing API.
- Previously the created tensor inherited the TensorOptions of the primal; now it inherits the TensorOptions of the tangent.
  - Probably fine. It's unlikely anyone relies on this, because the behavior is only triggered when the tangent and base have different layouts.
- Why pass in exploded size, stride, and offset? In the non-batched case we could pass in a tensor directly, but that is not possible once we want a batching rule: the size, stride, and offset we'd be passing won't belong to any live tensor (see the sketch after this list).
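A rough sketch of how a batching rule can consume the exploded metadata. The helpers `unwrapBatched`/`rewrapBatched` are hypothetical stand-ins for the vmap machinery, and the function name is an assumption; only the overall shape of the logic is meant to be illustrative:
```
#include <ATen/ATen.h>
#include <utility>
#include <vector>

// Hypothetical helpers standing in for the vmap machinery; not real APIs.
std::pair<at::Tensor, int64_t> unwrapBatched(const at::Tensor& t);  // physical tensor + batch size
at::Tensor rewrapBatched(const at::Tensor& t, int64_t bdim);

// There is no per-example tangent tensor to hand the rule, only the
// exploded sizes/strides/offset plus the required storage size.
at::Tensor new_zeros_with_same_meta_batching_rule(
    const at::Tensor& batched_tangent,
    at::IntArrayRef sizes,
    at::IntArrayRef strides,
    int64_t storage_offset,
    int64_t storage_numel) {
  auto unwrapped = unwrapBatched(batched_tangent);
  const auto& physical = unwrapped.first;
  const int64_t batch_size = unwrapped.second;

  // One flat buffer per batch element: shape [B, storage_numel].
  auto buffer = at::zeros({batch_size, storage_numel}, physical.options());

  // Prepend the batch dimension; each example's view lives in its own
  // storage_numel-sized slice of the buffer.
  std::vector<int64_t> new_sizes{batch_size};
  new_sizes.insert(new_sizes.end(), sizes.begin(), sizes.end());
  std::vector<int64_t> new_strides{storage_numel};
  new_strides.insert(new_strides.end(), strides.begin(), strides.end());

  auto result = buffer.as_strided(new_sizes, new_strides, storage_offset);
  return rewrapBatched(result, /*bdim=*/0);
}
```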
Test Plan: Imported from OSS
Reviewed By: zou3519, albanD
Differential Revision: D31842019
Pulled By: soulitzer
fbshipit-source-id: a58433d814fd173bc43a2c550b395377dba40de2