pytorch
89b371bc - [quant] Add support for 2D indices for quantized embedding operators (#47766)

Commit
4 years ago
[quant] Add support for 2D indices for quantized embedding operators (#47766) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47766 The operator now supports accepting 2D indices as inputs. For embedding operators, we set the default offsets in the op since the FBGEMM kernel expects it to be set Output shape depends on the shape if the indices. For embedding_bag operator, if indices is 2D (B, N) then offsets should be set to None by user. In this case the input is interpreted as B bags each of fixed length N. Output shape is still 2-D in this case. Test Plan: python test/test_quantization.py TestQuantizedEmbeddingOps.test_embedding_bag_2d_indices python test/test_quantization.py TestQuantizedEmbeddingOps.test_embedding_2d_indices Imported from OSS Reviewed By: jerryzh168 Differential Revision: D24895048 fbshipit-source-id: 2020910e1d85ed8673eedee2e504611ba260d801
Author
Parents
Loading