swift
6335b991 - [AutoDiff] Update `TypeBase::getAutoDiffAssociatedVectorSpace` to handle function types. (#22166)

Commit
6 years ago
[AutoDiff] Update `TypeBase::getAutoDiffAssociatedVectorSpace` to handle function types. (#22166) Formally, when a type `T where T : Differentiable` gets abstracted as a function `(X...) -> T` for any `X...`, the differentiability of the abstracted type depends entirely on the differentiability of `T`. Since structural types cannot conform to protocols yet in Swift, we need to handle in AD-associated type calculation the same way we handle tuples. The type calculation rules are better described as code, in imaginary syntax where parameterized extensions, variadic generic parameters, and protocol conformances for structural types are supported. ```swift extension<T..., U> ((T...) -> U) : Differentiable where U : Differentiable { public typealias TangentVector = (T...) -> U.TangentVector public typealias CotangentVector = (T...) -> U.CotangentVector public func moved(along direction: TangentVector) -> (T...) -> U { return { (x...) in self(x...).moved(along: direction(x...)) } } public func tangentVector(from cotangent: CotangentVector) -> TangentVector { return { (x...) in self(x...).tangentVector(from: cotangent(x...)) } } } ``` This is a crucial step towards the correct typing of curried differentiable functions, which helps us differentiate through curry thunks for methods. ```swift func curry<T : Differentiable, U : Differentiable>( f: @autodiff (T, U) -> V ) -> @autodiff (T) -> @autodiff (U) -> V { return { x in { y in f(x, y) } } } ``` Partially resolves [SR-9448](https://bugs.swift.org/browse/SR-9448), which needs this patch to be able to calculate the associate vector space of a curried function.
Author
Parents
Loading