Support 0-batch size for nn.Linear. (#27211)
Summary:
At the current moment of time nn.Linear (an it's interal functional code), will
fail in THBlas:
RuntimeError: invalid argument 8: lda should be at least max(1, 0), but have 0 at caffe2/aten/src/TH/generic/THBlas.cpp:363
This diff is trying to fix this bug.
As of now I was able to identify 2 possible places where changes needs to be done based on current dispatcher logic:
1. The file touched in this diff
2. caffe2/aten/src/THC/generic/THCTensorMathBlas.cu
At the moment I didn't find a better places comparing to injecting logic to those files:
the only non-generated function for forward pass, this + mm_mat2_backward function family on a backward pass.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27211
Test Plan: New unit-tests are passing. Code that was failing earlier works. Need to test other backends.
Differential Revision: D17599915
Pulled By: kennyhorror
fbshipit-source-id: 78894ce602d96aac2d6bf8c16a3fab43973e2d53