xla
d7f680c0 - [Functionalization] Fix SPMD tests (#4494)

Commit
2 years ago
[Functionalization] Fix SPMD tests (#4494) Summary: This pull request tries to fix three SPMD tests that fails with Functionalization. 1. test_model_weight_metrics: Before Functionalization, model.weight is an at::tensor stored in the XLATensor, but it's an IRValue now. Therefore, the test no longer holds true. 2. test_optimizer_step_with_sharding/test_inplace_add_with_sharding: we propagate the sharding spec via _propagate_xla_data which is an operator that help us during in-place op functionalized pass. Test Plan: CI.
Author
Committer
Parents
Loading