Allow None to pass through for vmap (#65565)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65565
Does jax allow this?
Test Plan: Imported from OSS
Reviewed By: mrshenli
Differential Revision: D31236258
Pulled By: soulitzer
fbshipit-source-id: 80460b355fc32ecbba8151e1f3179f076a927f9d