julia
73622edd - inference: Add a hook for users to be able to specify custom recursion relations

Commit
4 years ago
inference: Add a hook for users to be able to specify custom recursion relations When infering recursive functions, we try to detect cases where the recursion can be shown to be terminating and in such cases allow significantly more inference among the recursive functions than in cases where we cannot prove termination. The rationale for this is to allow inference of functions that are recursive over structures, while not spending useless (or infinite) compile time chasing down recursions that build up infinitely large types (those get expensive *fast*). Our built-in recursion relation here allows recursion to proceed if argument types are syntactic subsets, if they are decreasing integers, tuples of decreasing length, and a few other cases. It is worth noting that similar considerations come in in significantly more static system, where non-termination is (statically) disallowed. E.g. in Coq, for non-syntactic recursion, a proof needs to be provided that a recursive function does indeed terminate, before one is allowed to define it [1]. More modern languages like Dafny allow the user to specify a predicate over the incoming values that is shown to decrease over subsequent function calls [2]. My motivation comes from Diffractor, where we get call chains like: ∂⃖{1}(sin'', ...) -> ∂⃖{2}(sin', ...) -> ∂⃖{3}(sin, ...) -> ∂⃖{2}(rrule, sin, ...) -> ∂⃖{1}(rrule, sin) In this example, the first two calls are both the same method, as are the last two. Unfortunately, particularly for the first two calls, there isn't really a good way to express the recursion rule here in a way that is generic. Thus, to address these cases, this PR adds a per-method field, similar to a generator that allows packages to provide arbitrary recursion relations that take advantage of the (known) special semantics of those methods to expand the allowed set of recursions. Originally I had hoped to use this hook in place of the existing `type_more_complex` check. However, our code currently requires transitivity of the `type_more_complex` check for soundness of the termination analysis. This runs into problems in the specified use case, because we may have interleaved chains of calls, that are both the same method, but are not actually part of a cycle as such because their ultimate underlying methods are different (in particular this happens when chaining two-Cassette like generated functions). We do not currently express enough about the semantics of these Cassette-like methods in order for inference to reasonably compute whether two instances are part of the same cycle or not (we have `method_for_inference_heuristics` of course, which takes care of one level of this, but does not take care of the nested case). By having this hook, but not requring transitivity, it is legal for the hook to compute whether the ultimate underlying method is the same (by using its knowledge of what the methods actually do) and answering accordingly. In the long run, I would like to bring these Cassette-like capabilities more closely into the compiler, at which point inference itself may have enough information to compute the cycles and we'd be able to get away with requiring transitivity. All that said, this mechanism is quite simple and achieves its goal. I don't think it is particularly pretty and should definitely be considered unstable. I'm not providing any user-facing APIs for this, so those in the know will have to manually poke the methods. I do think a more general language-level framework for proving termination could be useful, particularly as part of more rigurous definitions of when various constant propagation happens, but this is not that, yet. I've tried this in Diffractor and with appropriate definitions of the recursion relation for the relevant functions, Diffractor becomes nicely inferable: ``` julia> using Diffractor: var"'", ∂⃖ julia> Base.return_types(sin''', Tuple{Float64}) 1-element Vector{Any}: Float64 julia> Base.return_types(sin'''', Tuple{Float64}) 1-element Vector{Any}: Float64 julia> Base.return_types(sin''''', Tuple{Float64}) 1-element Vector{Any}: Float64 julia> Base.return_types(sin'''''', Tuple{Float64}) 1-element Vector{Any}: Float64 ``` Diffractor's Phase 1 design goal was to infer fine at 3rd and 4th order - which this meets. The fact that it also infers at higher orders is nice, but inference times also increase to impractical levels for real-world functions, e.g. 5th order above takes a few seconds to infer just `sin` and 6th order takes about 20s or so. Of course that is still better than Zyogte, even at second order, but fixing this properly will be part of Diffractor Phase 2. With this (plus some additional tweaks to constprop heuristics for OpaqueClosure that I'll be putting up separately), we do also generate very nice code: ``` julia> @code_typed sin'''(1.0) CodeInfo( 1 ─ %1 = invoke ChainRules.sincos(_2::Float64)::Tuple{Float64, Float64} │ %2 = Base.getfield(%1, 1)::Float64 │ %3 = Base.getfield(%1, 2)::Float64 │ %4 = Diffractor.getfield(%1, 1)::Float64 │ %5 = Diffractor.getfield(%1, 2)::Float64 │ %6 = Diffractor.getfield(%1, 2)::Float64 │ %7 = Base.mul_float(%6, 1.0)::Float64 │ %8 = Base.mul_float(0.0, %7)::Float64 │ %9 = Base.neg_float(%4)::Float64 │ %10 = Base.mul_float(%9, 1.0)::Float64 │ %11 = Base.mul_float(1.0, %8)::Float64 │ %12 = Base.mul_float(%5, 1.0)::Float64 │ %13 = Base.mul_float(%12, %7)::Float64 │ %14 = Base.mul_float(0.0, %12)::Float64 │ %15 = Base.mul_float(0.0, %13)::Float64 │ %16 = Base.mul_float(%14, 1.0)::Float64 │ %17 = Base.mul_float(%6, %14)::Float64 │ %18 = Base.mul_float(%10, 1.0)::Float64 │ %19 = Base.mul_float(1.0, %10)::Float64 │ %20 = Base.add_float(%17, %18)::Float64 │ %21 = Base.mul_float(0.0, %20)::Float64 │ %22 = Base.mul_float(%21, 1.0)::Float64 │ %23 = Base.mul_float(%6, %21)::Float64 │ %24 = Base.add_float(%22, %16)::Float64 │ %25 = Base.add_float(%23, %19)::Float64 │ %26 = Base.mul_float(0.0, %25)::Float64 │ %27 = Base.add_float(%15, %26)::Float64 │ %28 = Base.add_float(%24, %11)::Float64 │ %29 = Base.add_float(%27, -1.0)::Float64 │ %30 = Base.neg_float(%2)::Float64 │ %31 = Base.mul_float(%3, %29)::Float64 │ %32 = Base.muladd_float(%30, %28, %31)::Float64 └── return %32 ) => Float64 ``` [1] http://adam.chlipala.net/cpdt/html/GeneralRec.html [2] https://rise4fun.com/Dafny/tutorial/Termination
Author
Committer
Parents
Loading