fix calculate_extra_mem_bytes_needed_for (#48102)
Summary:
This PR fixes a bug in calculate_extra_mem_bytes_needed_for in get_device_to_partitions_mapping
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48102
Reviewed By: gcatron
Differential Revision: D25029059
Pulled By: scottxu0730
fbshipit-source-id: 7447b70e8da96b3dc2c5922cf9b62eb306877317