pytorch
2455cc2a - Address case when layout of tangent is not same as base (#66292)

Commit
2 years ago
Address case when layout of tangent is not same as base (#66292) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66292 In this PR: 1. Fix the case when tangent has a different layout from the base when `set_fw_grad` by adding a native function and its batching rule. For (1) we replace the following: ``` Tensor new_with_same_meta(const Variable& base) { int64_t nelement_in_storage = base.storage().nbytes() / base.itemsize(); auto new_tensor = at::zeros({nelement_in_storage}, base.options()); auto res = new_tensor.as_strided(base.sizes(), base.strides(), base.storage_offset()); return res; } ``` with a native function as to enable a batching rule to alter its behavior. This new function will be similar to `new_zeros_strided` except we also require the `storage_offset` and `storage_numel` arguments. Possible concerns: - Why have redundant logic? Why not add new args `new_zeros_strided`? This is probably a niche use case, so it's better not to complicate the current API. - Previously the created tensor inherits the TensorOptions of the primal. Now we inherit from the TensorOptions of the tangent. - Probably fine. Likely, no one relies on this because the behavior is only triggered when tangent/base have different layouts. - Why pass in exploded size, stride, and offset? It is possible in the non-batched case to pass in a tensor directly, but not possible when we'd like to have a batching rule. The size, stride, and offset we'd be passing won't belong to any live tensor. Test Plan: Imported from OSS Reviewed By: zou3519, albanD Differential Revision: D31842019 Pulled By: soulitzer fbshipit-source-id: a58433d814fd173bc43a2c550b395377dba40de2
Author
Parents
Loading