Weighted decay with frequency (count-based) (#60382)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60382
Instead of setting weight_decay w uniformly for all ids, for each row i in the sparse embedding table, the actual weight_decay `w_i` becomes `w*freq_i` where `freq_i = halflife/counter_i \in [\log(2), halflife]`. Counter is from `rowwise_counter` with definition `counter_i = 1 + \exp(-iter_{\delta}*\rho)*counter_i`.
Test Plan:
buck test //caffe2/caffe2/python/operator_test:adagrad_test -- test_row_wise_sparse_adagrad
buck test caffe2/caffe2/fb/dper/layer_models/tests/split_1:sparse_nn_test_weight_decay
Reviewed By: 0x10cxR1
Differential Revision: D25581030
fbshipit-source-id: 54b3831b20516c76c559b13d8deb809e2ee3b446