Passing LinearPackedParamBase Capsule as a saved_data to backward stage (#96269)
Summary:
Initial implementation was unpacking for original weight in custom furward function which will double weight tensor in memory 2x bigger.
Hence we better unpack weight in backward function.
store Capsule object in saved_data storage and unpack in backward function.
Detail :
https://github.com/pytorch/pytorch/pull/94432#discussion_r1126669178
Test Plan: buck2 run //scripts/kwanghoon/pytorch:torch_playground - [D43809980](https://www.internalfb.com/diff/D43809980)
You can plug and play with above script.
Differential Revision: D43895790
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96269
Approved by: https://github.com/kimishpatel