[NNC] cacheAccesses transform (cache_reads + cache_writes) (#45869)
Summary:
Adds a new transform, `cacheAccesses`, to the NNC compiler, which supports caching of buffer accesses. All accesses within a provided scope are redirected to a cache, which is initialized or written back as necessary at the boundaries of that scope. For TVM fans, this is essentially a combination of cache_reads and cache_writes. E.g. it can do this kind of thing:
Before:
```
for (int i = 0; i < 64; i++) {
  for (int j = 0; j < 64; j++) {
    A[i, j] = i * j;
  }
}
for (int i_1 = 0; i_1 < 20; i_1++) {
  for (int j_1 = 0; j_1 < 10; j_1++) {
    B[i_1, j_1] = (A(i_1 + 30, j_1 + 40)) + (A(i_1 + 31, j_1 + 41));
  }
}
```
After `l.cacheAccesses(A->buf(), "A_local", j_loop);`:
```
for (int i = 0; i < 64; i++) {
  for (int j = 0; j < 64; j++) {
    A[i, j] = i * j;
  }
}
for (int i_1 = 0; i_1 < 20; i_1++) {
  for (int i_2 = 0; i_2 < 2; i_2++) {
    for (int j_1 = 0; j_1 < 11; j_1++) {
      A_local[i_2, j_1] = A[(i_2 + i_1) + 30, j_1 + 40];
    }
  }
  for (int j_2 = 0; j_2 < 10; j_2++) {
    B[i_1, j_2] = (A_local[1, j_2 + 1]) + (A_local[0, j_2]);
  }
}
```
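For reference, here is a minimal sketch of how the first example might be driven through the tensorexpr C++ API of this era. `KernelScope`, `Compute`, `LoopNest`, and `getLoopStmtsFor` are existing helpers, but treat the details (the `call` form of the reads, the loop index) as illustrative rather than exact test code:
```cpp
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

using namespace torch::jit::tensorexpr;

void cacheAccessesExample() {
  KernelScope kernel_scope;  // arena for IR nodes in this API revision

  // Producer: A[i, j] = i * j.
  Tensor* A = Compute(
      "A", {{64, "i"}, {64, "j"}},
      [](const VarHandle& i, const VarHandle& j) { return i * j; });

  // Consumer reads two offset windows of A.
  Tensor* B = Compute(
      "B", {{20, "i_1"}, {10, "j_1"}},
      [&](const VarHandle& i, const VarHandle& j) {
        return A->call(i + 30, j + 40) + A->call(i + 31, j + 41);
      });

  LoopNest l({A, B});

  // Scope for the cache: the inner (j_1) loop of B's nest.
  For* j_loop = l.getLoopStmtsFor(B)[1];
  l.cacheAccesses(A->buf(), "A_local", j_loop);

  l.prepareForCodegen();
  Stmt* s = IRSimplifier::simplify(l.root_stmt());
  (void)s;  // codegen would consume `s` from here
}
```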
Or this reduction:
```
for (int l1 = 0; l1 < 4; l1++) {
  sum[l1] = 0.f;
  for (int n1_1 = 0; n1_1 < 3; n1_1++) {
    for (int m1_1 = 0; m1_1 < 2; m1_1++) {
      sum[l1] = (sum[l1]) + (scale[(6 * l1 + 2 * n1_1) + m1_1]);
    }
  }
}
```
After `l.cacheAccesses(d->buf(), "d_local", n_loop);`:
```
for (int l1 = 0; l1 < 4; l1++) {
  Allocate(d_local, float, {1});
  sum[l1] = 0.f;
  d_local[0] = 0.f;
  for (int n1_1 = 0; n1_1 < 3; n1_1++) {
    for (int m1_1 = 0; m1_1 < 2; m1_1++) {
      d_local[0] = (d_local[0]) + (scale[(6 * l1 + 2 * n1_1) + m1_1]);
    }
  }
  sum[l1] = (sum[l1]) + (d_local[0]);
  Free(d_local);
}
```
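And a similar sketch for the reduction case, assuming the `Placeholder`/`Reduce`/`Sum` helpers from the same API; note the tensor is named `d` but writes to the buffer printed as `sum` above, which is why the call site uses `d->buf()`:
```cpp
// Continues the setup above (same includes and KernelScope).
Placeholder scale("scale", kFloat, {24});

// sum[l1] = reduce(+) over n1 in [0, 3), m1 in [0, 2).
Tensor* d = Reduce(
    "sum", {{4, "l1"}}, Sum(),
    [&](const VarHandle& l1, const VarHandle& n1, const VarHandle& m1) {
      return scale.load(6 * l1 + 2 * n1 + m1);
    },
    {{3, "n1"}, {2, "m1"}});

LoopNest l({d});

// Loops for d are [l1, n1_1, m1_1]; caching at the n1_1 loop keeps the
// accumulator in d_local for the whole reduction, as shown above.
For* n_loop = l.getLoopStmtsFor(d)[1];
l.cacheAccesses(d->buf(), "d_local", n_loop);
```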
I had originally planned to write `cacheReads` and `cacheWrites` wrappers so they could be used just like their TVM cousins, but they ended up being little more than large blocks of checks verifying that no reads (or writes) were present in the scope. That didn't seem useful enough to keep, so I removed them; let me know if you'd prefer them back.
This is based on bounds inference and inherits a few bugs present in that functionality, which I will address in a follow-up.
While working on this I realized that it overlaps heavily with `computeAt`, which is really just `cacheReads` + `computeInline`. I'm considering refactoring `computeAt` to be a wrapper around those two transforms; see the sketch below. @ZolotukhinM, opinions on this?
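To make that concrete, the wrapper might look roughly like this. This is purely hypothetical and not part of this PR; `p_local` is an arbitrary name and the `computeInline` argument is elided since it depends on how the producer's store is located:
```cpp
// Hypothetical computeAt built from the two existing transforms.
void computeAtViaCache(LoopNest& l, Tensor* producer, For* consumer_loop) {
  // Redirect the consumer's accesses to `producer` into a scope-sized cache...
  l.cacheAccesses(producer->buf(), "p_local", consumer_loop);
  // ...then inline the original full-size computation away:
  // l.computeInline(/* producer's original store */);
}
```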
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45869
Reviewed By: mruberry
Differential Revision: D24195276
Pulled By: nickgg
fbshipit-source-id: 36a58ae265f346903187ebc4923637b628048155