[functorch] Specialization cache (pytorch/functorch#99)
Benchmark:
https://gist.github.com/zou3519/f7691a94f8570b27cccc8e16fc8ed13b
This doesn't appear to add much overhead. If the overhead
becomes a problem, we can move the cache into C++.
The cache works by specializing on the shape/stride/dtype/device of
each input tensor and on any concrete (non-tensor) values.
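As a rough illustration of the specialization key, here is a minimal
sketch in Python. The helper names (`make_cache_key`, `TensorSpec`) are
hypothetical, not the actual functorch implementation:

```python
import torch
from typing import Any, NamedTuple, Tuple

class TensorSpec(NamedTuple):
    shape: Tuple[int, ...]
    stride: Tuple[int, ...]
    dtype: torch.dtype
    device: torch.device

def make_cache_key(*args: Any) -> Tuple[Any, ...]:
    key = []
    for arg in args:
        if isinstance(arg, torch.Tensor):
            # Specialize on tensor metadata only, never on contents.
            key.append(TensorSpec(tuple(arg.shape), arg.stride(),
                                  arg.dtype, arg.device))
        else:
            # Concrete (non-tensor) values are baked into the key.
            key.append(arg)
    return tuple(key)
```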
NB: because concrete values are part of the cache key, changing an
integer argument to the function triggers a recompile. In the future,
when we add static_argnums, we should change the cache to not
specialize on specific integers.
Test Plan:
- run tests