swift
0641a22b - [AutoDiff] Fix 'autodiff_function_extract' operand ownership kind and memory leaks. (#27199)

Commit
5 years ago
[AutoDiff] Fix 'autodiff_function_extract' operand ownership kind and memory leaks. (#27199) The `autodiff_function_extract` instruction behaves like `tuple_extract`, where it extracts some element from an aggregate. Its operand should have the same ownership kind as that of `tuple_extract`. That is, it should be defined as `CONSTANT_OR_TRIVIAL_OWNERSHIP_INST(Guaranteed, AutoDiffFunctionInst)` in ValueOwnershipKindClassifier. However, this is currently defined wrongly as `FORWARDING_OWNERSHIP_INST(AutoDiffFunctionExtract)`, which caused a bug in the differentiation transform to be uncaught: VJPEmitter and JVPEmitter in the differentiation transform is performing `autodiff_function_extract` on an `@owned` `@differentiable` function, which caused associated functions that are not extracted to be not released: ``` %f = autodiff_function %original %f_vjp = autodiff_function_extract [vjp] %f ... // %f is not released, and not caught by ownership verification! ``` After we fix the operand ownership kind for `autodiff_function_extract`, all these cases are now caught by ownership verification. The reproducer in [TF-795](https://bugs.swift.org/browse/TF-795) and most differentiation tests are failing to compile because ownership verification caught the bug in AD-generated code. The existing AD test suite is serving as good test cases for this ownership error. To fix this, VJPEmitter and JVPEmitter are now changed to emit borrows of `@differentiable` functions and copies of associated functions and property destroying the `@differentiable` function: ``` %f = autodiff_function %original %f_borrowed = begin_borrow %f %f_vjp_extracted = autodiff_function_extract [vjp] %f_borrowed %f_vjp = copy_value %f_vjp_extracted end_borrow %f_borrowed destroy_value %f ``` Fixes [TF-795](https://bugs.swift.org/browse/TF-795).
Author
Parents
  • lib
    • SIL
      • File
        ValueOwnership.cpp
    • SILOptimizer/Mandatory
      • File
        Differentiation.cpp