GEMM/GEMV kernels with Q4_1 quantized weights.
Functions
| Return type | Function | Description |
| float | dot_q4_1 (const void *w_q4_1, const float *x, int K) | |
| void | gemm_nt_q4_1 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K) | GEMM with transposed Q4_1 weights: C = A @ B^T. |
| void | gemm_q4_1 (float *Y, const void *W, const float *X, int M, int N, int K) | Matrix-matrix multiply with Q4_1 weights. |
| void | gemm_q4_1_backward (float *dX, const void *W, const float *dY, int M, int N, int K) | Batched backward pass. |
| void | gemv_q4_1 (float *y, const void *W, const float *x, int M, int K) | Auto-dispatch GEMV. |
| void | gemv_q4_1_backward (float *dX, const void *W, const float *dY, int M, int K) | Auto-dispatch backward. |
| void | gemv_q4_1_backward_ref (float *dX, const void *W, const float *dY, int M, int K) | Backward pass: compute input gradient. |
| void | gemv_q4_1_ref (float *y, const void *W, const float *x, int M, int K) | Matrix-vector multiply with Q4_1 weights (scalar reference). |
After changing these kernels, run: make test && make llamacpp-parity-full
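Q4_1 format: weights are stored in blocks of QK4_1 = 32 values. Each block (block_q4_1) holds an FP16 scale d, an FP16 minimum m, and the 32 packed 4-bit quants qs (two per byte).
Dequantization: w = d * q + m, where q is the unsigned 4-bit quant (0 to 15).
Operations:
Forward: Y = W @ X (W is Q4_1; X and Y are FP32)
Backward: dX = W^T @ dY (gradient w.r.t. the input)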
Definition in file gemm_kernels_q4_1.c.
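The block layout and the dequantization formula can be illustrated with a small scalar dequantizer. This is a minimal sketch, not the code in gemm_kernels_q4_1.c. Several details are assumptions: the field types of block_q4_1 (raw FP16 bits in uint16_t); the helper fp16_to_fp32, which stands in for CK_FP16_TO_FP32; and the nibble order. Here the low nibbles hold elements 0 to 15 and the high nibbles hold elements 16 to 31, as in llama.cpp's reference dequantizer. dequantize_row_q4_1 is a hypothetical name.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

#define QK4_1 32 /* weights per block */

/* Assumed layout, modeled on llama.cpp's block_q4_1 (20 bytes per block). */
typedef struct {
    uint16_t d;             /* FP16 scale, raw bits */
    uint16_t m;             /* FP16 minimum, raw bits */
    uint8_t  qs[QK4_1 / 2]; /* 32 unsigned 4-bit quants, two per byte */
} block_q4_1;

/* Minimal IEEE-754 binary16 -> binary32 conversion (stand-in for CK_FP16_TO_FP32). */
static float fp16_to_fp32(uint16_t h) {
    const uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
    const uint32_t exp  = (h >> 10) & 0x1Fu;
    const uint32_t mant = h & 0x3FFu;
    uint32_t bits;
    if (exp == 0) {                 /* zero or subnormal: mant * 2^-24 */
        const float v = (float)mant * 5.9604644775390625e-8f;
        return sign ? -v : v;
    } else if (exp == 0x1Fu) {      /* infinity or NaN */
        bits = sign | 0x7F800000u | (mant << 13);
    } else {                        /* normal: rebias exponent 15 -> 127 */
        bits = sign | ((exp + 112u) << 23) | (mant << 13);
    }
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Hypothetical helper: expand k weights (k a multiple of QK4_1) as w = d * q + m. */
static void dequantize_row_q4_1(const block_q4_1 *blocks, float *out, int k) {
    assert(k % QK4_1 == 0);
    for (int b = 0; b < k / QK4_1; ++b) {
        const float d = fp16_to_fp32(blocks[b].d);
        const float m = fp16_to_fp32(blocks[b].m);
        float *y = out + b * QK4_1;
        for (int j = 0; j < QK4_1 / 2; ++j) {
            /* Assumed nibble order: low nibble -> element j, high nibble -> element j + 16. */
            y[j]             = d * (float)(blocks[b].qs[j] & 0x0F) + m;
            y[j + QK4_1 / 2] = d * (float)(blocks[b].qs[j] >> 4) + m;
        }
    }
}
```

With d = 0.5 and m = -1.0, a quant of 3 dequantizes to 0.5 and a quant of 15 to 6.5. The minimum m lets Q4_1 represent ranges that are not centered on zero.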
float dot_q4_1 (const void *w_q4_1, const float *x, int K)
Definition at line 299 of file gemm_kernels_q4_1.c.
References gemv_q4_1().
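dot_q4_1 has no description here. Its reference to gemv_q4_1() suggests that it computes the dot product of one Q4_1 row with x. The sketch below shows why such a dot product is cheap. Because w = d * q + m, each block contributes d * Σ q_j x_j + m * Σ x_j, so the minimum m is applied once per block rather than once per weight. dot_q4_1_sketch and fp16_to_fp32 are hypothetical names. The block layout and nibble order are assumed to match llama.cpp's block_q4_1.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

#define QK4_1 32

/* Assumed llama.cpp-style block: FP16 d and m as raw bits, 16 bytes of packed quants. */
typedef struct {
    uint16_t d;
    uint16_t m;
    uint8_t  qs[QK4_1 / 2];
} block_q4_1;

/* Minimal binary16 -> binary32 conversion (stand-in for CK_FP16_TO_FP32). */
static float fp16_to_fp32(uint16_t h) {
    const uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
    const uint32_t exp  = (h >> 10) & 0x1Fu;
    const uint32_t mant = h & 0x3FFu;
    uint32_t bits;
    if (exp == 0) {
        const float v = (float)mant * 5.9604644775390625e-8f; /* mant * 2^-24 */
        return sign ? -v : v;
    } else if (exp == 0x1Fu) {
        bits = sign | 0x7F800000u | (mant << 13);
    } else {
        bits = sign | ((exp + 112u) << 23) | (mant << 13);
    }
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Hypothetical sketch: dot product of one Q4_1 row (K weights) with x. */
static float dot_q4_1_sketch(const void *w_q4_1, const float *x, int K) {
    assert(K % QK4_1 == 0);
    const block_q4_1 *blocks = (const block_q4_1 *)w_q4_1;
    float acc = 0.0f;
    for (int b = 0; b < K / QK4_1; ++b) {
        const float *xb = x + b * QK4_1;
        float qx = 0.0f; /* sum of q_j * x_j over this block */
        float xs = 0.0f; /* sum of x_j over this block */
        for (int j = 0; j < QK4_1 / 2; ++j) {
            const uint8_t byte = blocks[b].qs[j];
            qx += (float)(byte & 0x0F) * xb[j] + (float)(byte >> 4) * xb[j + QK4_1 / 2];
            xs += xb[j] + xb[j + QK4_1 / 2];
        }
        /* sum_j (d * q_j + m) * x_j == d * qx + m * xs */
        acc += fp16_to_fp32(blocks[b].d) * qx + fp16_to_fp32(blocks[b].m) * xs;
    }
    return acc;
}
```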
void gemm_nt_q4_1 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
GEMM with transposed Q4_1 weights: C = A @ B^T.
| A | Input activations [M x K], row-major FP32 |
| B | Weight matrix in Q4_1 format [N x K], row-major quantized |
| bias | Optional bias [N], NULL if not used |
| C | Output [M x N], row-major FP32 |
| M | Batch size (number of tokens) |
| N | Output dimension |
| K | Input dimension |
Definition at line 256 of file gemm_kernels_q4_1.c.
References C, CK_FP16_TO_FP32, block_q4_1::d, block_q4_1::m, QK4_1, and block_q4_1::qs.
Referenced by ck_gemm_nt_quant().
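A direct scalar reading of the parameter table treats row n of B as K/32 consecutive blocks. Each output is then C[i][n] = bias[n] + Σ_k A[i][k] * w(n, k), with the bias skipped when it is NULL. This is a minimal sketch that assumes the llama.cpp-style block layout and nibble order. gemm_nt_q4_1_sketch and fp16_to_fp32 are hypothetical names, and the real kernel may tile or vectorize differently.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define QK4_1 32

/* Assumed llama.cpp-style block: FP16 d and m as raw bits, 16 bytes of packed quants. */
typedef struct {
    uint16_t d;
    uint16_t m;
    uint8_t  qs[QK4_1 / 2];
} block_q4_1;

/* Minimal binary16 -> binary32 conversion (stand-in for CK_FP16_TO_FP32). */
static float fp16_to_fp32(uint16_t h) {
    const uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
    const uint32_t exp  = (h >> 10) & 0x1Fu;
    const uint32_t mant = h & 0x3FFu;
    uint32_t bits;
    if (exp == 0) {
        const float v = (float)mant * 5.9604644775390625e-8f; /* mant * 2^-24 */
        return sign ? -v : v;
    } else if (exp == 0x1Fu) {
        bits = sign | 0x7F800000u | (mant << 13);
    } else {
        bits = sign | ((exp + 112u) << 23) | (mant << 13);
    }
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Hypothetical sketch of C = A @ B^T (+ bias): A is [M x K] FP32, B is [N x K] Q4_1. */
static void gemm_nt_q4_1_sketch(const float *A, const void *B, const float *bias,
                                float *C, int M, int N, int K) {
    assert(K % QK4_1 == 0);
    const int nb = K / QK4_1; /* blocks per row of B */
    const block_q4_1 *Bq = (const block_q4_1 *)B;
    for (int i = 0; i < M; ++i) {
        const float *a = A + (size_t)i * K;            /* row i of A */
        for (int n = 0; n < N; ++n) {
            const block_q4_1 *w = Bq + (size_t)n * nb; /* row n of B */
            float acc = bias ? bias[n] : 0.0f;
            for (int b = 0; b < nb; ++b) {
                const float d = fp16_to_fp32(w[b].d);
                const float m = fp16_to_fp32(w[b].m);
                const float *ab = a + b * QK4_1;
                for (int j = 0; j < QK4_1 / 2; ++j) {
                    acc += (d * (float)(w[b].qs[j] & 0x0F) + m) * ab[j];
                    acc += (d * (float)(w[b].qs[j] >> 4) + m) * ab[j + QK4_1 / 2];
                }
            }
            C[(size_t)i * N + n] = acc;
        }
    }
}
```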
void gemm_q4_1 (float *Y, const void *W, const float *X, int M, int N, int K)
Matrix-matrix multiply with Q4_1 weights.
Definition at line 158 of file gemm_kernels_q4_1.c.
References gemv_q4_1().
void gemm_q4_1_backward (float *dX, const void *W, const float *dY, int M, int N, int K)
Batched backward pass.
Definition at line 231 of file gemm_kernels_q4_1.c.
References gemv_q4_1_backward().
void gemv_q4_1 (float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV.
Definition at line 139 of file gemm_kernels_q4_1.c.
References gemv_q4_1_ref().
Referenced by dot_q4_1(), and gemm_q4_1().
void gemv_q4_1_backward (float *dX, const void *W, const float *dY, int M, int K)
Auto-dispatch backward.
Definition at line 220 of file gemm_kernels_q4_1.c.
References gemv_q4_1_backward_ref().
Referenced by gemm_q4_1_backward().
void gemv_q4_1_backward_ref (float *dX, const void *W, const float *dY, int M, int K)
Backward pass: compute input gradient.
| dX | Gradient w.r.t. the input [K] (output) |
| W | Weight matrix in Q4_1 format [M x K] |
| dY | Gradient w.r.t. output [M] |
| M | Number of output rows |
| K | Number of columns (input dimension) |
Definition at line 181 of file gemm_kernels_q4_1.c.
References CK_FP16_TO_FP32, block_q4_1::d, block_q4_1::m, QK4_1, and block_q4_1::qs.
Referenced by gemv_q4_1_backward().
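The backward pass computes dX[k] = Σ_i w(i, k) * dY[i], that is, W^T @ dY, without materializing W^T. Each row of W is dequantized once and scaled by dY[i]. This scalar sketch assumes the llama.cpp-style block layout and assumes that dX is overwritten (zeroed first) rather than accumulated into. gemv_q4_1_backward_sketch is a hypothetical name.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define QK4_1 32

/* Assumed llama.cpp-style block: FP16 d and m as raw bits, 16 bytes of packed quants. */
typedef struct {
    uint16_t d;
    uint16_t m;
    uint8_t  qs[QK4_1 / 2];
} block_q4_1;

/* Minimal binary16 -> binary32 conversion (stand-in for CK_FP16_TO_FP32). */
static float fp16_to_fp32(uint16_t h) {
    const uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
    const uint32_t exp  = (h >> 10) & 0x1Fu;
    const uint32_t mant = h & 0x3FFu;
    uint32_t bits;
    if (exp == 0) {
        const float v = (float)mant * 5.9604644775390625e-8f; /* mant * 2^-24 */
        return sign ? -v : v;
    } else if (exp == 0x1Fu) {
        bits = sign | 0x7F800000u | (mant << 13);
    } else {
        bits = sign | ((exp + 112u) << 23) | (mant << 13);
    }
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Hypothetical sketch: dX[k] = sum_i w(i, k) * dY[i], with W stored as [M x K] Q4_1. */
static void gemv_q4_1_backward_sketch(float *dX, const void *W, const float *dY, int M, int K) {
    assert(K % QK4_1 == 0);
    const int nb = K / QK4_1;
    const block_q4_1 *rows = (const block_q4_1 *)W;
    for (int k = 0; k < K; ++k) dX[k] = 0.0f; /* assumption: overwrite, not accumulate */
    for (int i = 0; i < M; ++i) {
        const float g = dY[i];
        for (int b = 0; b < nb; ++b) {
            const block_q4_1 *blk = rows + (size_t)i * nb + b;
            const float d = fp16_to_fp32(blk->d);
            const float m = fp16_to_fp32(blk->m);
            float *dx = dX + b * QK4_1;
            for (int j = 0; j < QK4_1 / 2; ++j) {
                dx[j]             += (d * (float)(blk->qs[j] & 0x0F) + m) * g;
                dx[j + QK4_1 / 2] += (d * (float)(blk->qs[j] >> 4) + m) * g;
            }
        }
    }
}
```

Walking W row by row keeps the quantized reads sequential, at the cost of scattered writes into dX.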
void gemv_q4_1_ref (float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q4_1 weights (scalar reference).
| y | Output vector [M] |
| W | Weight matrix in Q4_1 format [M x K] |
| x | Input vector [K] |
| M | Number of output rows |
| K | Number of columns (must be a multiple of 32) |
Definition at line 50 of file gemm_kernels_q4_1.c.
References CK_FP16_TO_FP32, block_q4_1::d, block_q4_1::m, QK4_1, and block_q4_1::qs.
Referenced by gemv_q4_1().
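Below is a scalar sketch of the reference GEMV, y[i] = Σ_k w(i, k) * x[k]. Each row of W is stored as K/32 consecutive blocks. The block layout, the nibble order, and the fp16_to_fp32 helper are assumptions modeled on llama.cpp. gemv_q4_1_sketch is a hypothetical name, not the function in gemm_kernels_q4_1.c.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define QK4_1 32

/* Assumed llama.cpp-style block: FP16 d and m as raw bits, 16 bytes of packed quants. */
typedef struct {
    uint16_t d;
    uint16_t m;
    uint8_t  qs[QK4_1 / 2];
} block_q4_1;

/* Minimal binary16 -> binary32 conversion (stand-in for CK_FP16_TO_FP32). */
static float fp16_to_fp32(uint16_t h) {
    const uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
    const uint32_t exp  = (h >> 10) & 0x1Fu;
    const uint32_t mant = h & 0x3FFu;
    uint32_t bits;
    if (exp == 0) {
        const float v = (float)mant * 5.9604644775390625e-8f; /* mant * 2^-24 */
        return sign ? -v : v;
    } else if (exp == 0x1Fu) {
        bits = sign | 0x7F800000u | (mant << 13);
    } else {
        bits = sign | ((exp + 112u) << 23) | (mant << 13);
    }
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Hypothetical sketch: y = W @ x with W stored as [M x K] Q4_1 (K a multiple of 32). */
static void gemv_q4_1_sketch(float *y, const void *W, const float *x, int M, int K) {
    assert(K % QK4_1 == 0);
    const int nb = K / QK4_1; /* blocks per row */
    const block_q4_1 *rows = (const block_q4_1 *)W;
    for (int i = 0; i < M; ++i) {
        float acc = 0.0f;
        for (int b = 0; b < nb; ++b) {
            const block_q4_1 *blk = rows + (size_t)i * nb + b;
            const float d = fp16_to_fp32(blk->d);
            const float m = fp16_to_fp32(blk->m);
            const float *xb = x + b * QK4_1;
            for (int j = 0; j < QK4_1 / 2; ++j) {
                acc += (d * (float)(blk->qs[j] & 0x0F) + m) * xb[j];
                acc += (d * (float)(blk->qs[j] >> 4) + m) * xb[j + QK4_1 / 2];
            }
        }
        y[i] = acc;
    }
}
```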