pytorch
cc940f35 - ns for fx: change dtype cast from once per N node to once per node

Commit
4 years ago
ns for fx: change dtype cast from once per N node to once per node Summary: This PR ensures that when we do a dtype cast for a shadow module, we insert N dtype casts for N nodes, instead of combining N nodes into a single dtype cast. An example where this occurs is `cat([x, y], dim=0)` ``` // original graph [x, y] -> cat_b -> output // shadow graph with a single dtype cast, before this PR dtype_cast -> cat_a_shadow -> output_a_shadow / [x, y] -> cat_b -> output_b // shadow graph with multiple dtype casts, after this PR [dtype_cast_x, dtype_cast_y] -> cat_a_shadow -> output_a_shadow / [x, y] -> cat_b -> output_b ``` The reason things worked before this PR is because `torch.dequantize` can take either a single tensor or a list of tensors. We are changing this to make an upcoming addition of input loggers easier. Test Plan: ``` python test/test_quantization.py TestFXNumericSuiteCoreAPIs.test_prepare_model_with_stubs_multiple_dtype_casts ``` Imported from OSS Differential Revision: D26931226 Reviewed By: hx89 Pulled By: vkuzo fbshipit-source-id: e9c7d4c7942e0f59c952094d2e446b1e2c838396
Author
Parents
Loading