Fix dense Embedding to work with double backward (#9078)
Summary:
Fixes : #6469
1. `ATen/native/native_functions.yml` had [dispatch](https://github.com/pytorch/pytorch/blob/03e7953a98875c0164cb8e2c19b45800e85f4347/aten/src/ATen/native/native_functions.yaml#L451-L455) variants for for `embedding_dense_backward` , however `embedding_backward` explicitly made [call](https://github.com/pytorch/pytorch/blob/03e7953a98875c0164cb8e2c19b45800e85f4347/aten/src/ATen/native/Embedding.cpp#L35-L45) to it, thus leading to error.
2. In case of CUDA type tensor, the function crashed used to crash on dereferencing of indices's data [pointer](https://github.com/pytorch/pytorch/blob/03e7953a98875c0164cb8e2c19b45800e85f4347/aten/src/ATen/native/Embedding.cpp#L93).
Both have been solved and checked against (on CUDA and CPU)
1. As mentioned in the issue
```
import torch
class Test(torch.nn.Module):
def __init__(self):
super(Test,self).__init__()
self.embd = torch.nn.Embedding(1000, 100)
self.dense = torch.nn.Linear(100, 1)
def forward(self, inp):
inp = self.embd(inp)
return self.dense(inp)
test = Test()
inp = torch.tensor([0,1,2,1,1])
out = test(inp)
raw_loss = out.mean(dim=0)
loss_grad = torch.autograd.grad(outputs=raw_loss,
inputs=list(test.parameters()),
retain_graph=True, create_graph=True, only_inputs=True)
norm = sum([param.norm()**2 for param in loss_grad])
loss = raw_loss + norm
loss.backward(retain_graph=True)
print(test.embd.weight.grad)
```
2. Test Script
```
import torch
import time
start = time.time()
l = [1,1]*100
input = torch.tensor([[1,0],[1,0]],device='cpu')
embedding_matrix = torch.tensor([[1.0,3.0],[2.0,4]],requires_grad=True,device='cpu')
sq = embedding_matrix * embedding_matrix
emb = torch.nn.functional.embedding(input, sq,scale_grad_by_freq=False)
print('Embedding Matrix')
print(embedding_matrix)
print('-----------------')
sum_ = emb.sum()#prod.sum()
loss_grad, = torch.autograd.grad(outputs=sum_,inputs=embedding_matrix,create_graph=True)
print('Gradient')
print(loss_grad)
print('-----------------')
sum2_ = sum_ + loss_grad.sum()
print(sum2_)
sum2_.backward()
print(embedding_matrix.grad)
print(time.time() - start)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9078
Reviewed By: ezyang
Differential Revision: D14691901
Pulled By: soumith
fbshipit-source-id: 78e2612ba39080be564c876311671eb5a0119a0f