inference: Add a hook for users to be able to specify custom recursion relations
When infering recursive functions, we try to detect cases where
the recursion can be shown to be terminating and in such cases
allow significantly more inference among the recursive functions
than in cases where we cannot prove termination. The rationale for
this is to allow inference of functions that are recursive over
structures, while not spending useless (or infinite) compile time
chasing down recursions that build up infinitely large types (those
get expensive *fast*). Our built-in recursion relation here allows
recursion to proceed if argument types are syntactic subsets, if
they are decreasing integers, tuples of decreasing length, and
a few other cases. It is worth noting that similar considerations
come in in significantly more static system, where non-termination
is (statically) disallowed. E.g. in Coq, for non-syntactic recursion,
a proof needs to be provided that a recursive function does indeed
terminate, before one is allowed to define it [1]. More modern languages
like Dafny allow the user to specify a predicate over the incoming
values that is shown to decrease over subsequent function calls [2].
My motivation comes from Diffractor, where we get call chains like:
∂⃖{1}(sin'', ...) -> ∂⃖{2}(sin', ...) -> ∂⃖{3}(sin, ...) -> ∂⃖{2}(rrule, sin, ...) ->
∂⃖{1}(rrule, sin)
In this example, the first two calls are both the same method, as are the
last two. Unfortunately, particularly for the first two calls, there isn't
really a good way to express the recursion rule here in a way that is
generic.
Thus, to address these cases, this PR adds a per-method field, similar
to a generator that allows packages to provide arbitrary recursion
relations that take advantage of the (known) special semantics of those
methods to expand the allowed set of recursions.
Originally I had hoped to use this hook in place of the existing
`type_more_complex` check. However, our code currently requires
transitivity of the `type_more_complex` check for soundness of
the termination analysis.
This runs into problems in the specified use case, because we may
have interleaved chains of calls, that are both the same method,
but are not actually part of a cycle as such because their ultimate
underlying methods are different (in particular this happens when
chaining two-Cassette like generated functions). We do not
currently express enough about the semantics of these Cassette-like
methods in order for inference to reasonably compute whether
two instances are part of the same cycle or not (we have
`method_for_inference_heuristics` of course, which takes
care of one level of this, but does not take care of the nested
case).
By having this hook, but not requring transitivity, it is legal
for the hook to compute whether the ultimate underlying method
is the same (by using its knowledge of what the methods actually
do) and answering accordingly. In the long run, I would like to
bring these Cassette-like capabilities more closely into the
compiler, at which point inference itself may have enough
information to compute the cycles and we'd be able to get away
with requiring transitivity.
All that said, this mechanism is quite simple and achieves its
goal. I don't think it is particularly pretty and should
definitely be considered unstable. I'm not providing any
user-facing APIs for this, so those in the know will have to
manually poke the methods. I do think a more general
language-level framework for proving termination could be
useful, particularly as part of more rigurous definitions
of when various constant propagation happens, but this is
not that, yet.
I've tried this in Diffractor and with appropriate definitions
of the recursion relation for the relevant functions, Diffractor
becomes nicely inferable:
```
julia> using Diffractor: var"'", ∂⃖
julia> Base.return_types(sin''', Tuple{Float64})
1-element Vector{Any}:
Float64
julia> Base.return_types(sin'''', Tuple{Float64})
1-element Vector{Any}:
Float64
julia> Base.return_types(sin''''', Tuple{Float64})
1-element Vector{Any}:
Float64
julia> Base.return_types(sin'''''', Tuple{Float64})
1-element Vector{Any}:
Float64
```
Diffractor's Phase 1 design goal was to infer fine at 3rd and
4th order - which this meets. The fact that it also infers
at higher orders is nice, but inference times also increase
to impractical levels for real-world functions, e.g. 5th
order above takes a few seconds to infer just `sin` and
6th order takes about 20s or so. Of course that is still
better than Zyogte, even at second order, but fixing this
properly will be part of Diffractor Phase 2.
With this (plus some additional tweaks to constprop
heuristics for OpaqueClosure that I'll be putting up
separately), we do also generate very nice code:
```
julia> @code_typed sin'''(1.0)
CodeInfo(
1 ─ %1 = invoke ChainRules.sincos(_2::Float64)::Tuple{Float64, Float64}
│ %2 = Base.getfield(%1, 1)::Float64
│ %3 = Base.getfield(%1, 2)::Float64
│ %4 = Diffractor.getfield(%1, 1)::Float64
│ %5 = Diffractor.getfield(%1, 2)::Float64
│ %6 = Diffractor.getfield(%1, 2)::Float64
│ %7 = Base.mul_float(%6, 1.0)::Float64
│ %8 = Base.mul_float(0.0, %7)::Float64
│ %9 = Base.neg_float(%4)::Float64
│ %10 = Base.mul_float(%9, 1.0)::Float64
│ %11 = Base.mul_float(1.0, %8)::Float64
│ %12 = Base.mul_float(%5, 1.0)::Float64
│ %13 = Base.mul_float(%12, %7)::Float64
│ %14 = Base.mul_float(0.0, %12)::Float64
│ %15 = Base.mul_float(0.0, %13)::Float64
│ %16 = Base.mul_float(%14, 1.0)::Float64
│ %17 = Base.mul_float(%6, %14)::Float64
│ %18 = Base.mul_float(%10, 1.0)::Float64
│ %19 = Base.mul_float(1.0, %10)::Float64
│ %20 = Base.add_float(%17, %18)::Float64
│ %21 = Base.mul_float(0.0, %20)::Float64
│ %22 = Base.mul_float(%21, 1.0)::Float64
│ %23 = Base.mul_float(%6, %21)::Float64
│ %24 = Base.add_float(%22, %16)::Float64
│ %25 = Base.add_float(%23, %19)::Float64
│ %26 = Base.mul_float(0.0, %25)::Float64
│ %27 = Base.add_float(%15, %26)::Float64
│ %28 = Base.add_float(%24, %11)::Float64
│ %29 = Base.add_float(%27, -1.0)::Float64
│ %30 = Base.neg_float(%2)::Float64
│ %31 = Base.mul_float(%3, %29)::Float64
│ %32 = Base.muladd_float(%30, %28, %31)::Float64
└── return %32
) => Float64
```
[1] http://adam.chlipala.net/cpdt/html/GeneralRec.html
[2] https://rise4fun.com/Dafny/tutorial/Termination