Add support for attention weight in SparseLookup (#26748)
Summary:
Support attention weights as an input to SparseLookup. In attention sum pooling, if the attention weights can be pre-computed before the embedding lookup, they can be passed to SparseLookup and processed by the SparseLengthsWeightedSum op. One example is id_score attention sum pooling.
Essentially the net is converted from:
LengthsSum(Mul(Gather(keys, w), att_weight))
to:
SparseLengthsWeightedSum(keys, w, att_weight)
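For illustration, a minimal Caffe2 sketch of the two equivalent formulations (blob names, shapes, and values here are hypothetical and not taken from the SparseLookup implementation):

    import numpy as np
    from caffe2.python import core, workspace

    emb  = np.random.rand(10, 4).astype(np.float32)          # embedding table ("w")
    ids  = np.array([1, 3, 3, 7], dtype=np.int64)            # sparse keys
    att  = np.array([0.5, 1.0, 2.0, 0.25], dtype=np.float32) # pre-computed attention weights
    lens = np.array([2, 2], dtype=np.int32)                  # segment lengths

    for name, val in [("w", emb), ("keys", ids), ("att_weight", att), ("lengths", lens)]:
        workspace.FeedBlob(name, val)

    # Before: Gather -> Mul (row-wise broadcast of the weights) -> LengthsSum
    net_a = core.Net("gather_mul_sum")
    net_a.Gather(["w", "keys"], "rows")
    net_a.Mul(["rows", "att_weight"], "weighted", broadcast=1, axis=0)
    net_a.LengthsSum(["weighted", "lengths"], "pooled_a")

    # After: one fused op (inputs are data, weights, indices, lengths)
    net_b = core.Net("fused")
    net_b.SparseLengthsWeightedSum(["w", "att_weight", "keys", "lengths"], "pooled_b")

    workspace.RunNetOnce(net_a)
    workspace.RunNetOnce(net_b)
    assert np.allclose(workspace.FetchBlob("pooled_a"), workspace.FetchBlob("pooled_b"))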
This unblocks potential efficiency gains in distributed training.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26748
Test Plan: unit test
Reviewed By: chocjy
Differential Revision: D17553345
Pulled By: wheatkit
fbshipit-source-id: 60cc3c4b0bc1eade5459ac598e85286f3849a412