pytorch
66939e3b - [acc_tracer] Add test coverage for retracing (#71752)

Commit
2 years ago
[acc_tracer] Add test coverage for retracing (#71752) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71752 Added coverage for reshape specifically which required a fix. The problem for `acc_ops.reshape` as best as I understand: - `torch.reshape` requires the `shape` arg to be a `tuple` of `ints` - If `torch.reshape` is passed a `tuple` where the first element is not an `int`, it throws a TypeError e.g. `TypeError: reshape(): argument 'shape' (position 2) must be tuple of ints, not tuple` - If the `shape` we're reshaping to is an FX Proxy then this type error will be thrown. This happens when the first element of the `shape` tuple is a Proxy because it's input-dependent. - As a workaround we use `tensor.reshape` instead of `torch.reshape`, which doesn't do equivalent type checking for a `tuple` of `ints`. Also remove unnecessary `acc_utils.get_field_from_acc_out_ty()` with cast to `TensorMetadata`. Test Plan: Added test coverage Reviewed By: yinghai Differential Revision: D33760455 fbshipit-source-id: bff5563bf9e3d9e9318901b56211151d2c0e4eb2 (cherry picked from commit d5c1b9732a208dd305a3215920f1ea23e2f327f7)
Author
Committer
Parents
Loading