[BLOOM] Clean modeling code (#18344)
* Cleanup some code
* Improve signatures
* Try to reduce the number of reshape/copies
* I don't think we actually need the layer_num scaling trick
* No need for duplication
* Try to fix beam_search
* Fix beam search
* Removing layer num normalization seems to be breaking
* Not sure self.layer_number normalization actually matters
* Try and be backward compatible
* Try to fix beam_search
* Revert attempt to be backward compatible
* Improve documentation on past_key_values format
* Optimize the device allocation in case of hidden_states in multiple devices
* No need to manually cast the values to a specific device
* Rename with long version of variables
* Improve type hinting
* Add comment that explains that some methods return views
* Actually i think the attention casting only makes sense when we use torch.float16
* We don't actually need layer_number to be passed anymore
* Fix FX test
* Bypass torch.baddbmm
* Apply suggestions from code review
* Add comment about support for torchScript v1.11
* fix ONNX support for bloom (#18456)
Co-authored-by: Niklas Muennighoff <n.muennighoff@gmail.com>
Co-authored-by: Nouamane Tazi <nouamane98@gmail.com>