d88bc38b - [functorch] fix batching rule for dropout (#92975)

Fixes https://github.com/pytorch/pytorch/issues/92283

The repro now works:

```python
import torch
import torch.func
import torch.nn as nn

x = torch.randn(3, device='cuda')
y = torch.randn(1, 3, device='cuda')

def fn(x, y):
    # Previously the output of dropout was incorrectly [B, 3] (B=1), so `mean(1)` failed.
    # After the fix, the output of dropout is [B, 1, 3] and `mean(1)` works.
    return x + nn.functional.dropout(y, 0.3).mean(1)

o = torch.func.vmap(fn, in_dims=(0, None), randomness='different')(x, y)
```

**NOTE**: `native_dropout_batching_rule(const Tensor& tensor, double p, c10::optional<bool> train)` was called only for CUDA tensors, so this issue affected CUDA tensors but not CPU tensors.

Ref: https://github.com/pytorch/pytorch/blob/a6ac922eabee8fce7a48dedac81e82ac8cfe9a45/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp#L251-L258

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92975

Approved by: https://github.com/Chillee, https://github.com/Skylion007
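For context, a minimal CPU-runnable sketch of the same shape semantics (an assumption on my part: the `device='cuda'` arguments are dropped, since the bug was specific to the CUDA code path, but the logical shapes vmap should produce are the same on CPU):

```python
import torch
import torch.nn as nn
from torch.func import vmap

# Same repro as above, on CPU (assumption: shape semantics match the fixed CUDA path).
x = torch.randn(3)
y = torch.randn(1, 3)

def fn(x, y):
    # Per sample, dropout(y) keeps the logical shape [1, 3]
    # ([B, 1, 3] physically with B=3), so mean(1) reduces the size-3 dim to [1].
    return x + nn.functional.dropout(y, 0.3).mean(1)

o = vmap(fn, in_dims=(0, None), randomness='different')(x, y)
assert o.shape == (3, 1)

# With randomness='different', each of the 3 batch elements draws its own
# dropout mask, even though y itself is shared across the batch.
```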