GEMM/GEMV kernels with Q4_0 quantized weights.
Functions

| Return type | Function | Description |
|---|---|---|
| float | dot_q4_0 (const void *w_q4_0, const float *x, int K) | |
| void | gemm_nt_q4_0 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K) | Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias. |
| void | gemm_q4_0 (float *Y, const void *W, const float *X, int M, int N, int K) | Matrix-matrix multiply with Q4_0 weights. |
| void | gemm_q4_0_backward (float *dX, const void *W, const float *dY, int M, int N, int K) | Batched backward pass. |
| void | gemv_q4_0 (float *y, const void *W, const float *x, int M, int K) | Auto-dispatch GEMV. |
| void | gemv_q4_0_backward (float *dX, const void *W, const float *dY, int M, int K) | Auto-dispatch backward. |
| void | gemv_q4_0_backward_ref (float *dX, const void *W, const float *dY, int M, int K) | Backward pass: compute input gradient. |
| void | gemv_q4_0_ref (float *y, const void *W, const float *x, int M, int K) | Matrix-vector multiply with Q4_0 weights (scalar reference). |
GEMM/GEMV kernels with Q4_0 quantized weights.
After changing this file, run: make test && make llamacpp-parity-full
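Q4_0 Format: weights are grouped into blocks of QK4_0 = 32 consecutive values along K (hence K must be a multiple of 32). Each block_q4_0 stores a half-precision scale d, converted with CK_FP16_TO_FP32, and the packed 4-bit quantized values qs.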
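A minimal sketch of how a Q4_0 row could be decoded, assuming the block layout matches llama.cpp's Q4_0 (suggested by the llamacpp-parity target, not confirmed on this page): a 16-bit half-precision scale d followed by 16 bytes qs holding 32 unsigned 4-bit codes, with each weight reconstructed as (q - 8) * d. The nibble order, the offset of 8, and the names fp16_to_fp32 and dequantize_row_q4_0 are assumptions; fp16_to_fp32 is a hypothetical stand-in for CK_FP16_TO_FP32.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

#define QK4_0 32  /* weights per block; K must be a multiple of this */

/* Assumed llama.cpp-style layout: 2 + 16 = 18 bytes per 32 weights. */
typedef struct {
    uint16_t d;              /* block scale, raw IEEE-754 half-precision bits */
    uint8_t  qs[QK4_0 / 2];  /* 32 unsigned 4-bit codes, two per byte */
} block_q4_0;

/* Hypothetical stand-in for CK_FP16_TO_FP32: IEEE-754 half bits -> float. */
static float fp16_to_fp32(uint16_t h) {
    uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
    uint32_t exp = (h >> 10) & 0x1Fu, mant = h & 0x3FFu, bits;
    if (exp == 0x1Fu) {
        bits = sign | 0x7F800000u | (mant << 13);          /* Inf / NaN */
    } else if (exp != 0) {
        bits = sign | ((exp + 112u) << 23) | (mant << 13); /* normal: rebias 15 -> 127 */
    } else if (mant == 0) {
        bits = sign;                                       /* signed zero */
    } else {                                               /* subnormal: renormalize */
        exp = 113u;
        while (!(mant & 0x400u)) { mant <<= 1; exp--; }
        bits = sign | (exp << 23) | ((mant & 0x3FFu) << 13);
    }
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Expand one Q4_0 row of K weights into FP32: w = (q - 8) * d.
 * Assumed nibble order: the low nibble of qs[j] is element j of the block,
 * the high nibble is element j + 16. */
static void dequantize_row_q4_0(const block_q4_0 *row, float *out, int K) {
    assert(K % QK4_0 == 0);
    for (int b = 0; b < K / QK4_0; b++) {
        const float d = fp16_to_fp32(row[b].d);
        float *o = out + b * QK4_0;
        for (int j = 0; j < QK4_0 / 2; j++) {
            o[j]             = (float)((row[b].qs[j] & 0x0F) - 8) * d;
            o[j + QK4_0 / 2] = (float)((row[b].qs[j] >> 4) - 8) * d;
        }
    }
}
```

Under this layout a block costs 18 bytes for 32 weights, i.e. 4.5 bits per weight.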
Operations:
Forward: Y = W @ X (W is Q4_0; X and Y are FP32)
Backward: dX = W^T @ dY (gradient w.r.t. the input X)
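The backward formula follows from the chain rule: each input x[k] reaches every output y[i] through W[i][k], so dL/dx[k] = sum over i of W[i][k] * dY[i], which is (W^T @ dY)[k]. The dense FP32 sketch below (no quantization; the function names are hypothetical) checks this against a finite difference of the linear surrogate loss L(x) = dot(dY, W @ x), whose gradient with respect to x is exactly W^T @ dY.

```c
#include <assert.h>
#include <stddef.h>

/* dX[K] = W^T @ dY[M] for a dense row-major W[M x K]:
 * row i of W scatters dY[i] * W[i][k] into dX[k]. */
static void dense_backward(float *dX, const float *W, const float *dY, int M, int K) {
    assert(M > 0 && K > 0);
    for (int k = 0; k < K; k++) dX[k] = 0.0f;
    for (int i = 0; i < M; i++)
        for (int k = 0; k < K; k++)
            dX[k] += W[(size_t)i * K + k] * dY[i];
}

/* Linear surrogate loss L(x) = dot(dY, W @ x); its exact gradient is W^T @ dY. */
static float surrogate_loss(const float *W, const float *x, const float *dY, int M, int K) {
    float L = 0.0f;
    for (int i = 0; i < M; i++) {
        float yi = 0.0f; /* y[i] = (W @ x)[i] */
        for (int k = 0; k < K; k++) yi += W[(size_t)i * K + k] * x[k];
        L += dY[i] * yi;
    }
    return L;
}
```

Because the surrogate loss is linear in x, a unit step in x[k] changes L by exactly dX[k].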
Note: Weight gradients are not computed for quantized weights. For fine-tuning, use LoRA adapters, which maintain FP32 gradients separately.
Definition in file gemm_kernels_q4_0.c.
float dot_q4_0 (const void *w_q4_0, const float *x, int K)
Definition at line 347 of file gemm_kernels_q4_0.c.
References gemv_q4_0().
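The page gives no description for dot_q4_0, but its signature and its call to gemv_q4_0() suggest a single-row (M = 1) product of one Q4_0 weight row with an FP32 vector. A minimal sketch of that computation, assuming the llama.cpp-style block layout (FP16 scale d, 16 bytes of nibbles, weight = (q - 8) * d); dot_q4_0_sketch and fp16_to_fp32 are hypothetical names. Each block's scale is factored out of its inner sum, so it is converted and multiplied once per 32 weights rather than once per weight.

```c
#include <assert.h>
#include <stdint.h>
#include <string.h>

#define QK4_0 32
typedef struct { uint16_t d; uint8_t qs[QK4_0 / 2]; } block_q4_0; /* assumed layout */

/* Hypothetical stand-in for CK_FP16_TO_FP32: IEEE-754 half bits -> float. */
static float fp16_to_fp32(uint16_t h) {
    uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
    uint32_t exp = (h >> 10) & 0x1Fu, mant = h & 0x3FFu, bits;
    if (exp == 0x1Fu) {
        bits = sign | 0x7F800000u | (mant << 13);          /* Inf / NaN */
    } else if (exp != 0) {
        bits = sign | ((exp + 112u) << 23) | (mant << 13); /* normal */
    } else if (mant == 0) {
        bits = sign;                                       /* signed zero */
    } else {                                               /* subnormal */
        exp = 113u;
        while (!(mant & 0x400u)) { mant <<= 1; exp--; }
        bits = sign | (exp << 23) | ((mant & 0x3FFu) << 13);
    }
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* dot(one Q4_0 row of K weights, x[K]) = sum over blocks of d * sum_j (q_j - 8) * x_j. */
static float dot_q4_0_sketch(const void *w_q4_0, const float *x, int K) {
    assert(K % QK4_0 == 0);
    const block_q4_0 *blocks = (const block_q4_0 *)w_q4_0;
    float sum = 0.0f;
    for (int b = 0; b < K / QK4_0; b++) {
        const float *xb = x + b * QK4_0;
        float partial = 0.0f; /* unscaled sum for this block */
        for (int j = 0; j < QK4_0 / 2; j++) {
            partial += (float)((blocks[b].qs[j] & 0x0F) - 8) * xb[j];
            partial += (float)((blocks[b].qs[j] >> 4) - 8) * xb[j + QK4_0 / 2];
        }
        sum += fp16_to_fp32(blocks[b].d) * partial; /* apply scale once per block */
    }
    return sum;
}
```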
void gemm_nt_q4_0 (const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
Matrix-matrix multiply: C[M,N] = A[M,K] @ B[N,K]^T + bias.
| A | Input matrix [M x K], row-major FP32 |
| B | Weight matrix in Q4_0 format, [N x K] stored row-major |
| bias | Optional bias [N], NULL if not used |
| C | Output [M x N], row-major FP32 |
| M | Batch size (number of tokens) |
| N | Output dimension (number of rows in B) |
| K | Input dimension |
Definition at line 176 of file gemm_kernels_q4_0.c.
References C, CK_FP16_TO_FP32, block_q4_0::d, QK4_0, and block_q4_0::qs.
Referenced by ck_gemm_nt_quant().
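A minimal scalar sketch of the documented contract C[m][n] = dot(A[m, :], B[n, :]) + bias[n], with the bias skipped when it is NULL. It assumes each row of B is stored as K / 32 consecutive llama.cpp-style Q4_0 blocks (FP16 scale, 16 bytes of nibbles, weight = (q - 8) * d); the helper names are hypothetical, and the kernel defined at line 176 may be organized differently.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define QK4_0 32
typedef struct { uint16_t d; uint8_t qs[QK4_0 / 2]; } block_q4_0; /* assumed layout */

/* Hypothetical stand-in for CK_FP16_TO_FP32: IEEE-754 half bits -> float. */
static float fp16_to_fp32(uint16_t h) {
    uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
    uint32_t exp = (h >> 10) & 0x1Fu, mant = h & 0x3FFu, bits;
    if (exp == 0x1Fu) {
        bits = sign | 0x7F800000u | (mant << 13);          /* Inf / NaN */
    } else if (exp != 0) {
        bits = sign | ((exp + 112u) << 23) | (mant << 13); /* normal */
    } else if (mant == 0) {
        bits = sign;                                       /* signed zero */
    } else {                                               /* subnormal */
        exp = 113u;
        while (!(mant & 0x400u)) { mant <<= 1; exp--; }
        bits = sign | (exp << 23) | ((mant & 0x3FFu) << 13);
    }
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* dot(one Q4_0 row of K weights, x[K]); the block scale is applied once per block. */
static float row_dot_q4_0(const block_q4_0 *row, const float *x, int K) {
    float sum = 0.0f;
    for (int b = 0; b < K / QK4_0; b++) {
        const float *xb = x + b * QK4_0;
        float partial = 0.0f;
        for (int j = 0; j < QK4_0 / 2; j++) {
            partial += (float)((row[b].qs[j] & 0x0F) - 8) * xb[j];
            partial += (float)((row[b].qs[j] >> 4) - 8) * xb[j + QK4_0 / 2];
        }
        sum += fp16_to_fp32(row[b].d) * partial;
    }
    return sum;
}

/* C[M x N] = A[M x K] @ B[N x K]^T + bias[N]; bias may be NULL. */
static void gemm_nt_q4_0_sketch(const float *A, const void *B, const float *bias,
                                float *C, int M, int N, int K) {
    assert(K % QK4_0 == 0);
    const block_q4_0 *rows = (const block_q4_0 *)B;
    const int blocks_per_row = K / QK4_0;
    for (int m = 0; m < M; m++) {
        for (int n = 0; n < N; n++) {
            float acc = row_dot_q4_0(rows + (size_t)n * blocks_per_row,
                                     A + (size_t)m * K, K);
            C[(size_t)m * N + n] = acc + (bias ? bias[n] : 0.0f);
        }
    }
}
```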
void gemm_q4_0 (float *Y, const void *W, const float *X, int M, int N, int K)
Matrix-matrix multiply with Q4_0 weights.
Definition at line 151 of file gemm_kernels_q4_0.c.
References gemv_q4_0().
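The page documents no parameters for gemm_q4_0, but it references gemv_q4_0(), which suggests the batched product is computed as one GEMV per input vector. A minimal sketch under stated assumptions: W is [M x K] in llama.cpp-style Q4_0 blocks, X holds N input vectors of K floats stored back to back, and Y receives N output vectors of M floats. These layouts and the *_sketch helper names are hypothetical.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define QK4_0 32
typedef struct { uint16_t d; uint8_t qs[QK4_0 / 2]; } block_q4_0; /* assumed layout */

/* Hypothetical stand-in for CK_FP16_TO_FP32: IEEE-754 half bits -> float. */
static float fp16_to_fp32(uint16_t h) {
    uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
    uint32_t exp = (h >> 10) & 0x1Fu, mant = h & 0x3FFu, bits;
    if (exp == 0x1Fu) {
        bits = sign | 0x7F800000u | (mant << 13);          /* Inf / NaN */
    } else if (exp != 0) {
        bits = sign | ((exp + 112u) << 23) | (mant << 13); /* normal */
    } else if (mant == 0) {
        bits = sign;                                       /* signed zero */
    } else {                                               /* subnormal */
        exp = 113u;
        while (!(mant & 0x400u)) { mant <<= 1; exp--; }
        bits = sign | (exp << 23) | ((mant & 0x3FFu) << 13);
    }
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* dot(one Q4_0 row of K weights, x[K]). */
static float row_dot_q4_0(const block_q4_0 *row, const float *x, int K) {
    float sum = 0.0f;
    for (int b = 0; b < K / QK4_0; b++) {
        float partial = 0.0f;
        for (int j = 0; j < QK4_0 / 2; j++) {
            partial += (float)((row[b].qs[j] & 0x0F) - 8) * x[b * QK4_0 + j];
            partial += (float)((row[b].qs[j] >> 4) - 8) * x[b * QK4_0 + j + QK4_0 / 2];
        }
        sum += fp16_to_fp32(row[b].d) * partial;
    }
    return sum;
}

/* y[M] = W[M x K] @ x[K]: one row dot product per output element. */
static void gemv_q4_0_sketch(float *y, const void *W, const float *x, int M, int K) {
    assert(K % QK4_0 == 0);
    const block_q4_0 *rows = (const block_q4_0 *)W;
    for (int m = 0; m < M; m++)
        y[m] = row_dot_q4_0(rows + (size_t)m * (K / QK4_0), x, K);
}

/* Batched forward: N independent GEMVs that share the same quantized W. */
static void gemm_q4_0_sketch(float *Y, const void *W, const float *X, int M, int N, int K) {
    for (int n = 0; n < N; n++)
        gemv_q4_0_sketch(Y + (size_t)n * M, W, X + (size_t)n * K, M, K);
}
```

Looping over GEMV keeps the code simple, but each input vector re-reads and re-decodes all of W; a kernel that handles several vectors per pass over W could amortize that decoding.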
void gemm_q4_0_backward (float *dX, const void *W, const float *dY, int M, int N, int K)
Batched backward pass.
Definition at line 333 of file gemm_kernels_q4_0.c.
References gemv_q4_0_backward().
void gemv_q4_0 (float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV.
Definition at line 132 of file gemm_kernels_q4_0.c.
References gemv_q4_0_ref().
Referenced by dot_q4_0(), and gemm_q4_0().
void gemv_q4_0_backward (float *dX, const void *W, const float *dY, int M, int K)
Auto-dispatch backward.
Definition at line 318 of file gemm_kernels_q4_0.c.
References gemv_q4_0_backward_ref().
Referenced by gemm_q4_0_backward().
void gemv_q4_0_backward_ref (float *dX, const void *W, const float *dY, int M, int K)
Backward pass: compute input gradient.
| dX | Output gradient w.r.t. input [K] |
| W | Weight matrix in Q4_0 format [M x K] |
| dY | Gradient w.r.t. output [M] |
| M | Number of output rows |
| K | Number of columns (input dimension) |
Definition at line 230 of file gemm_kernels_q4_0.c.
References CK_FP16_TO_FP32, block_q4_0::d, QK4_0, and block_q4_0::qs.
Referenced by gemv_q4_0_backward().
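A minimal sketch of the scalar backward pass, assuming the llama.cpp-style block layout (FP16 scale d, 16 bytes of nibbles, weight = (q - 8) * d). Instead of materializing W^T, each row m is decoded block by block and scatters dY[m] * W[m][k] into dX[k], so W is still read in its stored row-major order. The function names are hypothetical.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define QK4_0 32
typedef struct { uint16_t d; uint8_t qs[QK4_0 / 2]; } block_q4_0; /* assumed layout */

/* Hypothetical stand-in for CK_FP16_TO_FP32: IEEE-754 half bits -> float. */
static float fp16_to_fp32(uint16_t h) {
    uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
    uint32_t exp = (h >> 10) & 0x1Fu, mant = h & 0x3FFu, bits;
    if (exp == 0x1Fu) {
        bits = sign | 0x7F800000u | (mant << 13);          /* Inf / NaN */
    } else if (exp != 0) {
        bits = sign | ((exp + 112u) << 23) | (mant << 13); /* normal */
    } else if (mant == 0) {
        bits = sign;                                       /* signed zero */
    } else {                                               /* subnormal */
        exp = 113u;
        while (!(mant & 0x400u)) { mant <<= 1; exp--; }
        bits = sign | (exp << 23) | ((mant & 0x3FFu) << 13);
    }
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* dX[K] = W^T @ dY[M] without forming W^T: row m adds dY[m] * W[m][:] to dX. */
static void gemv_q4_0_backward_sketch(float *dX, const void *W, const float *dY, int M, int K) {
    assert(K % QK4_0 == 0);
    const block_q4_0 *blocks = (const block_q4_0 *)W;
    const int nb = K / QK4_0; /* blocks per row */
    for (int k = 0; k < K; k++) dX[k] = 0.0f;
    for (int m = 0; m < M; m++) {
        for (int b = 0; b < nb; b++) {
            const block_q4_0 *blk = &blocks[(size_t)m * nb + b];
            /* Fold the block scale into the upstream gradient once. */
            const float g = fp16_to_fp32(blk->d) * dY[m];
            float *dx = dX + b * QK4_0;
            for (int j = 0; j < QK4_0 / 2; j++) {
                dx[j]             += (float)((blk->qs[j] & 0x0F) - 8) * g;
                dx[j + QK4_0 / 2] += (float)((blk->qs[j] >> 4) - 8) * g;
            }
        }
    }
}
```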
void gemv_q4_0_ref (float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q4_0 weights (scalar reference).
| y | Output vector [M] |
| W | Weight matrix in Q4_0 format [M x K] |
| x | Input vector [K] |
| M | Number of output rows |
| K | Number of columns (must be a multiple of 32) |
Definition at line 49 of file gemm_kernels_q4_0.c.
References CK_FP16_TO_FP32, block_q4_0::d, QK4_0, and block_q4_0::qs.
Referenced by gemv_q4_0().
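A scalar reference kernel usually favors clarity over speed so that it can serve as the correctness oracle for faster dispatched variants. A minimal sketch in that spirit, assuming the llama.cpp-style block layout (FP16 scale d, 16 bytes of nibbles, weight = (q - 8) * d): each block is expanded into a 32-float temporary and then dotted densely with x. This dequantize-then-dot strategy is an illustration, not necessarily how gemv_q4_0_ref is written; the helper names are hypothetical.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define QK4_0 32
typedef struct { uint16_t d; uint8_t qs[QK4_0 / 2]; } block_q4_0; /* assumed layout */

/* Hypothetical stand-in for CK_FP16_TO_FP32: IEEE-754 half bits -> float. */
static float fp16_to_fp32(uint16_t h) {
    uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
    uint32_t exp = (h >> 10) & 0x1Fu, mant = h & 0x3FFu, bits;
    if (exp == 0x1Fu) {
        bits = sign | 0x7F800000u | (mant << 13);          /* Inf / NaN */
    } else if (exp != 0) {
        bits = sign | ((exp + 112u) << 23) | (mant << 13); /* normal */
    } else if (mant == 0) {
        bits = sign;                                       /* signed zero */
    } else {                                               /* subnormal */
        exp = 113u;
        while (!(mant & 0x400u)) { mant <<= 1; exp--; }
        bits = sign | (exp << 23) | ((mant & 0x3FFu) << 13);
    }
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* y[M] = W[M x K] @ x[K], dequantizing one 32-weight block at a time. */
static void gemv_q4_0_ref_sketch(float *y, const void *W, const float *x, int M, int K) {
    assert(K % QK4_0 == 0);
    const block_q4_0 *blocks = (const block_q4_0 *)W;
    const int nb = K / QK4_0; /* blocks per row */
    for (int m = 0; m < M; m++) {
        float acc = 0.0f;
        for (int b = 0; b < nb; b++) {
            const block_q4_0 *blk = &blocks[(size_t)m * nb + b];
            const float d = fp16_to_fp32(blk->d);
            float w[QK4_0]; /* this block's weights, dequantized */
            for (int j = 0; j < QK4_0 / 2; j++) {
                w[j]             = (float)((blk->qs[j] & 0x0F) - 8) * d;
                w[j + QK4_0 / 2] = (float)((blk->qs[j] >> 4) - 8) * d;
            }
            for (int j = 0; j < QK4_0; j++) acc += w[j] * x[b * QK4_0 + j];
        }
        y[m] = acc;
    }
}
```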