Add should_traverse_fn to torch.fx.node.map_aggregate (#81510)
Adds an optional callback that checks if map_aggregate should continue recursive traversal. The main motivation is to not traverse torch.Size which is tuple
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81510
Approved by: https://github.com/SherlockNoMad, https://github.com/jamesr66a