add qr_backward functionality for wide case (#42216)
Summary:
Unblocks implementation of https://github.com/pytorch/pytorch/issues/27036. Note that this PR ***does not*** fix #{27036}.
Currently QR decomposition only has support for square and tall (a.k.a. skinny) case.
This PR adds functionality for wide A matrix/tensors, includes 3 unit tests for the new case
and restructures the `qr_backward` method to use the same Walther method as a helper.
cc albanD t-vi
I don't have a gpu machine so haven't tested on cuda but everything passes on my local machine in cpu.
The basic idea of the PR is noted in the comments in the `Functions.cpp` file but I'll note here too for clarity:
let <img src="https://render.githubusercontent.com/render/math?math=A_{m,n}"> be a matrix and <img src="https://render.githubusercontent.com/render/math?math=m < n"> then partition <img src="https://render.githubusercontent.com/render/math?math=A_{m, n}"> as <img src="https://render.githubusercontent.com/render/math?math=A_{m,n} = [ X_{m,m} |\ Y_{m, n-m} ]">
and take QR of <img src="https://render.githubusercontent.com/render/math?math=X"> and call that one
<img src="https://render.githubusercontent.com/render/math?math=X=QU"> the <img src="https://render.githubusercontent.com/render/math?math=Q"> here from <img src="https://render.githubusercontent.com/render/math?math=X"> is the same as the <img src="https://render.githubusercontent.com/render/math?math=Q"> from <img src="https://render.githubusercontent.com/render/math?math=QR"> on entire <img src="https://render.githubusercontent.com/render/math?math=A"> matrix. Then transform <img src="https://render.githubusercontent.com/render/math?math=Y"> with the <img src="https://render.githubusercontent.com/render/math?math=Q"> rotation got from <img src="https://render.githubusercontent.com/render/math?math=X"> to get <img src="https://render.githubusercontent.com/render/math?math=V=Q^{T}Y"> now <img src="https://render.githubusercontent.com/render/math?math=R= [U |\ V] "> and similarly for the grads of each piece, e.g. if <img src="https://render.githubusercontent.com/render/math?math=\bar{A}"> is `grad_A` then
<img src="https://render.githubusercontent.com/render/math?math=\bar{A} = [ \bar{X} |\ \bar{Y}]"> and <img src="https://render.githubusercontent.com/render/math?math=\bar{R} = [ \bar{U} |\ \bar{V}]"> and then
<img src="https://render.githubusercontent.com/render/math?math=\bar{Y} = Q\bar{V}"> and
<img src="https://render.githubusercontent.com/render/math?math=\bar{V}"> is the `narrow()` of `grad_R`.
<img src="https://render.githubusercontent.com/render/math?math=\bar{X}"> is calculated very similar to the original Walther formula (exactly the same in the tall and square cases) but is slightly modified here for wide case matrices.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42216
Reviewed By: glaringlee
Differential Revision: D23373118
Pulled By: albanD
fbshipit-source-id: 3702ba7e7e23923868c02cdb7e10a96036052344