pytorch
fdcb203e - Identify weights and bias by argument position in aten call (#29147)

Commit
5 years ago
Identify weights and bias by argument position in aten call (#29147) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29147 Previously we use a vector of weight and bias to record the values of weight/bias and we assume we'll get them by GetAttr nodes, then we propagate these values through the function calls However, it doesn't work if we also do some transformations on these values right now, we'll need to mark all the values that's produced by weight/bias as weight/bias, e.g. ``` %w = GetAttr[name="weight"](%conv) %wt = aten::transpose(%w) %r = aten::conv2d(..., %wt, ...) ``` we'll mark both %w and %wt as weight. This is a bit over compilicated to support this. Alternatively, we can identify weights by argument positions, e.g. for call %r = aten::conv2d(..., %w, ...), we know the argument 1 is weight, argument 2 is bias. Test Plan: test_jit.py Imported from OSS Differential Revision: D18362839 fbshipit-source-id: afbf07f48bab8d01c5be1c882561a0255730a6b9
Author
Parents
Loading