Add extra_repr to Linear classes for debugging purposes (#6954)
**Summary**
This PR adds an `extra_repr` method to some Linear classes so that
additional information is shown when these modules are printed. This is
useful for debugging.
Affected modules:
- LinearLayer
- LinearAllreduce
- LmHeadLinearAllreduce
The `extra_repr` method reports the following information (a sketch of the approach is shown after this list):
- in_features
- out_features
- bias (true or false)
- dtype
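For illustration, here is a minimal sketch of what such an `extra_repr` could look like, assuming the module keeps its parameters in `self.weight` (shaped `[out_features, in_features]`) and `self.bias`. The class below is a simplified stand-in, not the exact DeepSpeed implementation; attribute names are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearLayer(nn.Module):
    # Simplified stand-in for the DeepSpeed LinearLayer, for illustration only.
    def __init__(self, weight: torch.Tensor, bias: torch.Tensor = None):
        super().__init__()
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias) if bias is not None else None

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

    def extra_repr(self):
        # nn.Module.__repr__ places this string inside the parentheses
        # after the class name, e.g. LinearLayer(in_features=..., ...).
        out_features, in_features = self.weight.shape
        return "in_features={}, out_features={}, bias={}, dtype={}".format(
            in_features, out_features, self.bias is not None, self.weight.dtype)
```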
**Example**
Printing the llama-2-7b model on rank 0 after `init_inference` with world size
= 2.
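The printouts below can be reproduced with a script along these lines; the model id, the launcher command, and the exact `init_inference` arguments are illustrative assumptions, not part of this PR.

```python
# Illustrative sketch; launch with two processes, e.g.:
#   deepspeed --num_gpus 2 print_model.py
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             torch_dtype=torch.bfloat16)
engine = deepspeed.init_inference(model,
                                  tensor_parallel={"tp_size": 2},  # world size = 2
                                  dtype=torch.bfloat16)
if torch.distributed.get_rank() == 0:
    print(engine)  # prints the nested module tree shown below
```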
Previously, we only got the class names of these modules:
```
InferenceEngine(
  (module): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(32000, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): LinearLayer()
            (k_proj): LinearLayer()
            (v_proj): LinearLayer()
            (o_proj): LinearAllreduce()
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): LinearLayer()
            (up_proj): LinearLayer()
            (down_proj): LinearAllreduce()
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((4096,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (lm_head): LmHeadLinearAllreduce()
  )
)
```
Now we get more useful info; note that the reported shapes are the per-rank shards, e.g. `q_proj` shows `out_features=2048` rather than 4096 because the weights are split across the two tensor-parallel ranks:
```
InferenceEngine(
  (module): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(32000, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): LinearLayer(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16)
            (k_proj): LinearLayer(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16)
            (v_proj): LinearLayer(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16)
            (o_proj): LinearAllreduce(in_features=2048, out_features=4096, bias=False, dtype=torch.bfloat16)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): LinearLayer(in_features=4096, out_features=5504, bias=False, dtype=torch.bfloat16)
            (up_proj): LinearLayer(in_features=4096, out_features=5504, bias=False, dtype=torch.bfloat16)
            (down_proj): LinearAllreduce(in_features=5504, out_features=4096, bias=False, dtype=torch.bfloat16)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((4096,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (lm_head): LmHeadLinearAllreduce(in_features=2048, out_features=32000, bias=False, dtype=torch.bfloat16)
  )
)
```