onnxruntime
ce8796dd - GatherBlockQuantized supports zero points and 8 bits for uint8 dtype (#25214)

Commit
164 days ago
GatherBlockQuantized supports zero points and 8 bits for uint8 dtype (#25214) Add support for unit8 GatherBlockQuantized for the following two areas: * Allow zero points. * Add bits attribute and support bits=8. Major change is to update shape inference; and update unit tests to cover these. Note that only CPU implementation, and CUDA implementation will be added later in another PR. ### Motivation and Context Previously, zero points are not supported when dtype is uint8. Only 4 bit quantization without zero points were supported. This change is to share weights of lm_head with 8 bit quantization between GatherBlockQuantized and MatMulNBits. For example, when K is multiple of `block_size`, typical input and output shapes are like the following: * data has shape (N, K) for 8 bits, or (N, K / 2) for 4 bits. * scales has shape (N, k_blocks), where k_blocks = (K / block_size). * zero_points has shape (N, k_blocks) for 8 bits, (N, (k_blocks + 1) / 2) for 4 bits. * output will have shape (..., K), where ... is the shape of `indices`.
Author
Parents
Loading