GEMM/GEMV kernels with Q4_K quantized weights.
Functions

float dot_q4_k(const void *w_q4k, const float *x, int K)
    Compute dot product of a Q4_K row with an FP32 vector.

void gemm_nt_q4_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)

void gemm_q4_k(float *Y, const void *W, const float *X, int M, int N, int K)
    Auto-dispatch GEMM based on available SIMD.

void gemm_q4_k_backward(float *dX, const void *W, const float *dY, int M, int N, int K)
    Batched backward pass.

void gemm_q4_k_ref(float *Y, const void *W, const float *X, int M, int N, int K)
    Matrix-matrix multiply with Q4_K weights (scalar reference).

void gemv_q4_k(float *y, const void *W, const float *x, int M, int K)
    Auto-dispatch GEMV based on available SIMD.

void gemv_q4_k_backward(float *dX, const void *W, const float *dY, int M, int K)
    Auto-dispatch backward pass.

void gemv_q4_k_backward_ref(float *dX, const void *W, const float *dY, int M, int K)
    Backward pass: compute input gradient (scalar reference).

void gemv_q4_k_ref(float *y, const void *W, const float *x, int M, int K)
    Matrix-vector multiply with Q4_K weights (scalar reference).
GEMM/GEMV kernels with Q4_K quantized weights.

After changes, run: make test && make llamacpp-parity-full

Implements matrix multiplication Y = W * X, where the weight matrix W is stored in Q4_K format and the activations X and Y are FP32.

Key optimization: fused dequantization. Weights are dequantized in registers and immediately consumed by FMA; dequantized values are never written to memory.

Operations: forward GEMV/GEMM, backward (input-gradient) GEMV/GEMM, and a single-row dot product.

Definition in file gemm_kernels_q4k.c.
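The cross-references below mention block_q4_K and its fields (d, dmin, qs, scales) together with QK_K, which match llama.cpp's Q4_K super-block. A sketch of that layout, assuming the standard llama.cpp definition (FP16 fields shown as raw uint16_t):

```c
#include <stdint.h>

#define QK_K 256  /* weights per super-block */

/* Sketch of llama.cpp's block_q4_K layout: 144 bytes per 256 weights,
 * i.e. 4.5 bits per weight. FP16 values are stored as raw uint16_t. */
typedef struct {
    uint16_t d;            /* super-block scale (FP16) */
    uint16_t dmin;         /* super-block min scale (FP16) */
    uint8_t  scales[12];   /* packed 6-bit sub-block scales and mins */
    uint8_t  qs[QK_K / 2]; /* 4-bit quants, two per byte */
} block_q4_K;
```

Each weight dequantizes as d*sc*q - dmin*m, with the per-sub-block 6-bit sc and m unpacked from scales (the unpack_q4_k_scales() helper referenced below).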
float dot_q4_k(const void *w_q4k, const float *x, int K)
Compute dot product of Q4_K row with FP32 vector.
Parameters:
    w_q4k  Q4_K blocks for one row
    x      FP32 input vector
    K      Vector length (must be a multiple of 256)
Definition at line 484 of file gemm_kernels_q4k.c.
References gemv_q4_k().
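The cross-reference suggests dot_q4_k is a thin wrapper over the GEMV kernel with M = 1. A self-contained sketch of that relationship, using hypothetical plain-FP32 stand-ins (gemv_f32, dot_f32) since the real kernel reads Q4_K blocks:

```c
#include <stddef.h>

/* Stand-in GEMV over plain FP32 weights: y[i] = sum_k W[i,k] * x[k].
 * The real gemv_q4_k reads Q4_K blocks instead of floats. */
static void gemv_f32(float *y, const float *W, const float *x, int M, int K) {
    for (int i = 0; i < M; i++) {
        float acc = 0.0f;
        for (int k = 0; k < K; k++)
            acc += W[(size_t)i * K + k] * x[k];
        y[i] = acc;
    }
}

/* Dot product expressed as a single-row GEMV, mirroring how
 * dot_q4_k presumably forwards to gemv_q4_k with M = 1. */
static float dot_f32(const float *w_row, const float *x, int K) {
    float y;
    gemv_f32(&y, w_row, x, /*M=*/1, K);
    return y;
}
```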
void gemm_nt_q4_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
Definition at line 683 of file gemm_kernels_q4k.c.
References C and gemm_q4_k().
Referenced by ck_attention_project_head_major_q4_k(), ck_gemm_nt_quant(), ck_layer_forward_rmsnorm_swiglu_decode_q4_k(), ck_mlp_swiglu_forward_q4_k(), ck_qkv_project_head_major_q4_k(), ck_qkv_project_head_major_token_q4_k(), model_decode_token(), model_layer_0_decode(), model_layer_0_prefill(), model_layer_10_decode(), model_layer_10_prefill(), model_layer_11_decode(), model_layer_11_prefill(), model_layer_12_decode(), model_layer_12_prefill(), model_layer_13_decode(), model_layer_13_prefill(), model_layer_14_decode(), model_layer_14_prefill(), model_layer_15_decode(), model_layer_15_prefill(), model_layer_16_decode(), model_layer_16_prefill(), model_layer_17_decode(), model_layer_17_prefill(), model_layer_18_decode(), model_layer_18_prefill(), model_layer_19_decode(), model_layer_19_prefill(), model_layer_1_decode(), model_layer_1_prefill(), model_layer_20_decode(), model_layer_20_prefill(), model_layer_21_decode(), model_layer_21_prefill(), model_layer_22_decode(), model_layer_22_prefill(), model_layer_23_decode(), model_layer_23_prefill(), model_layer_2_decode(), model_layer_2_prefill(), model_layer_3_decode(), model_layer_3_prefill(), model_layer_4_decode(), model_layer_4_prefill(), model_layer_5_decode(), model_layer_5_prefill(), model_layer_6_decode(), model_layer_6_prefill(), model_layer_7_decode(), model_layer_7_prefill(), model_layer_8_decode(), model_layer_8_prefill(), model_layer_9_decode(), model_layer_9_prefill(), qwen2_0_5b_decode_decode_token(), qwen2_0_5b_decode_layer_0_decode(), qwen2_0_5b_decode_layer_0_prefill(), qwen2_0_5b_decode_layer_10_decode(), qwen2_0_5b_decode_layer_10_prefill(), qwen2_0_5b_decode_layer_11_decode(), qwen2_0_5b_decode_layer_11_prefill(), qwen2_0_5b_decode_layer_12_decode(), qwen2_0_5b_decode_layer_12_prefill(), qwen2_0_5b_decode_layer_13_decode(), qwen2_0_5b_decode_layer_13_prefill(), qwen2_0_5b_decode_layer_14_decode(), qwen2_0_5b_decode_layer_14_prefill(), qwen2_0_5b_decode_layer_15_decode(), 
qwen2_0_5b_decode_layer_15_prefill(), qwen2_0_5b_decode_layer_16_decode(), qwen2_0_5b_decode_layer_16_prefill(), qwen2_0_5b_decode_layer_17_decode(), qwen2_0_5b_decode_layer_17_prefill(), qwen2_0_5b_decode_layer_18_decode(), qwen2_0_5b_decode_layer_18_prefill(), qwen2_0_5b_decode_layer_19_decode(), qwen2_0_5b_decode_layer_19_prefill(), qwen2_0_5b_decode_layer_1_decode(), qwen2_0_5b_decode_layer_1_prefill(), qwen2_0_5b_decode_layer_20_decode(), qwen2_0_5b_decode_layer_20_prefill(), qwen2_0_5b_decode_layer_21_decode(), qwen2_0_5b_decode_layer_21_prefill(), qwen2_0_5b_decode_layer_22_decode(), qwen2_0_5b_decode_layer_22_prefill(), qwen2_0_5b_decode_layer_23_decode(), qwen2_0_5b_decode_layer_23_prefill(), qwen2_0_5b_decode_layer_2_decode(), qwen2_0_5b_decode_layer_2_prefill(), qwen2_0_5b_decode_layer_3_decode(), qwen2_0_5b_decode_layer_3_prefill(), qwen2_0_5b_decode_layer_4_decode(), qwen2_0_5b_decode_layer_4_prefill(), qwen2_0_5b_decode_layer_5_decode(), qwen2_0_5b_decode_layer_5_prefill(), qwen2_0_5b_decode_layer_6_decode(), qwen2_0_5b_decode_layer_6_prefill(), qwen2_0_5b_decode_layer_7_decode(), qwen2_0_5b_decode_layer_7_prefill(), qwen2_0_5b_decode_layer_8_decode(), qwen2_0_5b_decode_layer_8_prefill(), qwen2_0_5b_decode_layer_9_decode(), and qwen2_0_5b_decode_layer_9_prefill().
void gemm_q4_k(float *Y, const void *W, const float *X, int M, int N, int K)
Auto-dispatch GEMM based on available SIMD.
Definition at line 461 of file gemm_kernels_q4k.c.
References gemm_q4_k_ref().
Referenced by gemm_nt_q4_k().
void gemm_q4_k_backward(float *dX, const void *W, const float *dY, int M, int N, int K)
Batched backward pass.
Definition at line 656 of file gemm_kernels_q4k.c.
References gemv_q4_k_backward().
void gemm_q4_k_ref(float *Y, const void *W, const float *X, int M, int N, int K)
Matrix-matrix multiply with Q4_K weights (scalar reference)
Parameters:
    Y  Output matrix [M x N]
    W  Weight matrix in Q4_K format [M x K]
    X  Input matrix [K x N] (column-major for cache efficiency)
    M  Number of output rows
    N  Batch size (number of columns)
    K  Hidden dimension
Definition at line 316 of file gemm_kernels_q4k.c.
References gemv_q4_k().
Referenced by gemm_q4_k().
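With X stored column-major, the reference GEMM presumably reduces to one GEMV per input column, consistent with the gemv_q4_k() cross-reference above. A sketch using hypothetical plain-FP32 stand-ins (gemm_f32, gemv_f32), and assuming Y is also column-major:

```c
#include <stddef.h>

/* Stand-in GEMV over plain FP32 weights (the real kernel reads Q4_K). */
static void gemv_f32(float *y, const float *W, const float *x, int M, int K) {
    for (int i = 0; i < M; i++) {
        float acc = 0.0f;
        for (int k = 0; k < K; k++)
            acc += W[(size_t)i * K + k] * x[k];
        y[i] = acc;
    }
}

/* X column-major [K x N]: column n is contiguous at X + n*K, so each
 * output column Y(:,n) is a single GEMV against the same weights.
 * Y is assumed column-major [M x N] here. */
static void gemm_f32(float *Y, const float *W, const float *X,
                     int M, int N, int K) {
    for (int n = 0; n < N; n++)
        gemv_f32(Y + (size_t)n * M, W, X + (size_t)n * K, M, K);
}
```

This structure lets every batch column reuse the dequantized-in-register weight stream, which is where the cache-efficiency claim for column-major X comes from.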
void gemv_q4_k(float *y, const void *W, const float *x, int M, int K)
Auto-dispatch GEMV based on available SIMD.
Definition at line 285 of file gemm_kernels_q4k.c.
References gemv_q4_k_ref().
Referenced by attention_mlp_fused_q4k(), dot_q4_k(), gemm_q4_k_ref(), layer_fused_attn_mlp_qkv_q4k(), and rmsnorm_qkv_q4k_fused().
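A common shape for such auto-dispatch wrappers, sketched here with compile-time feature macros (this is not the file's actual dispatch logic, and gemv_q4_k_pick / gemv_q4_k_ref_sketch are hypothetical names; runtime CPUID checks are an equally common choice):

```c
#include <stddef.h>

typedef void (*gemv_fn)(float *, const void *, const float *, int, int);

/* Placeholder scalar fallback; the real gemv_q4_k_ref dequantizes
 * Q4_K blocks and accumulates. This stub only zeroes the output. */
static void gemv_q4_k_ref_sketch(float *y, const void *W, const float *x,
                                 int M, int K) {
    (void)W; (void)x; (void)K;
    for (int i = 0; i < M; i++) y[i] = 0.0f;
}

/* Resolve the best available kernel; all branches return the scalar
 * stub here, where a real build would return SIMD variants. */
static gemv_fn gemv_q4_k_pick(void) {
#if defined(__AVX2__)
    return gemv_q4_k_ref_sketch;   /* would be the AVX2 kernel */
#elif defined(__ARM_NEON)
    return gemv_q4_k_ref_sketch;   /* would be the NEON kernel */
#else
    return gemv_q4_k_ref_sketch;   /* scalar fallback */
#endif
}
```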
void gemv_q4_k_backward(float *dX, const void *W, const float *dY, int M, int K)
Auto-dispatch backward.
Definition at line 641 of file gemm_kernels_q4k.c.
References gemv_q4_k_backward_ref().
Referenced by gemm_q4_k_backward().
void gemv_q4_k_backward_ref(float *dX, const void *W, const float *dY, int M, int K)
Backward pass: compute input gradient (scalar reference)
Parameters:
    dX  Output gradient w.r.t. input [K]
    W   Weight matrix in Q4_K format [M x K]
    dY  Gradient w.r.t. output [M]
    M   Number of output rows
    K   Number of columns (input dimension)
Definition at line 511 of file gemm_kernels_q4k.c.
References CK_FP16_TO_FP32, block_q4_K::d, block_q4_K::dmin, QK_K, block_q4_K::qs, block_q4_K::scales, and unpack_q4_k_scales().
Referenced by gemv_q4_k_backward().
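The backward pass computes dX = W^T * dY. A scalar sketch over plain FP32 weights showing the access pattern (gemv_backward_f32 is a hypothetical stand-in; the real reference dequantizes each Q4_K weight on the fly, as the field references above indicate):

```c
#include <stddef.h>

/* dX[k] = sum_i W[i,k] * dY[i], i.e. dX = W^T * dY with W row-major
 * [M x K]. Iterating rows outermost keeps the weight reads sequential,
 * matching the forward kernel's streaming access over Q4_K blocks. */
static void gemv_backward_f32(float *dX, const float *W, const float *dY,
                              int M, int K) {
    for (int k = 0; k < K; k++)
        dX[k] = 0.0f;
    for (int i = 0; i < M; i++) {
        const float g = dY[i];
        for (int k = 0; k < K; k++)
            dX[k] += W[(size_t)i * K + k] * g;
    }
}
```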
void gemv_q4_k_ref(float *y, const void *W, const float *x, int M, int K)
Matrix-vector multiply with Q4_K weights (scalar reference)
Parameters:
    y  Output vector [M]
    W  Weight matrix in Q4_K format [M x K], stored row-major
    x  Input vector [K]
    M  Number of output rows
    K  Number of columns (must be a multiple of 256)
Definition at line 53 of file gemm_kernels_q4k.c.
References block_q4_K::d, block_q4_K::dmin, GGML_FP16_TO_FP32, QK_K, block_q4_K::qs, block_q4_K::scales, and unpack_q4_k_scales().
Referenced by gemv_q4_k().
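The fused dequantize-and-FMA loop can be sketched with a simplified 4-bit format: one FP32 scale/min pair per 32-weight block instead of Q4_K's FP16 super-block scales with packed 6-bit sub-scales (blk4 and gemv_q4_sketch are hypothetical names, and the low-nibble-first packing is an assumption of this sketch). Each weight is rebuilt in a register as scale*q + min and consumed immediately:

```c
#include <stdint.h>
#include <stddef.h>

#define BLK 32  /* simplified: 32 weights per block (Q4_K uses 256) */

/* Simplified 4-bit block: one FP32 scale/min per 32 weights. */
typedef struct {
    float   scale;
    float   min;
    uint8_t qs[BLK / 2];  /* two 4-bit quants per byte, low nibble first */
} blk4;

/* y[i] = sum_k dequant(W[i,k]) * x[k]. Each weight is dequantized in a
 * register and fed straight into the accumulator, never stored back to
 * memory. K must be a multiple of BLK. */
static void gemv_q4_sketch(float *y, const blk4 *W, const float *x,
                           int M, int K) {
    const int nb = K / BLK;
    for (int i = 0; i < M; i++) {
        float acc = 0.0f;
        for (int b = 0; b < nb; b++) {
            const blk4 *blk = &W[(size_t)i * nb + b];
            const float *xb = x + b * BLK;
            for (int j = 0; j < BLK / 2; j++) {
                const uint8_t byte = blk->qs[j];
                const float w0 = blk->scale * (float)(byte & 0x0F) + blk->min;
                const float w1 = blk->scale * (float)(byte >> 4)   + blk->min;
                acc += w0 * xb[2 * j] + w1 * xb[2 * j + 1];
            }
        }
        y[i] = acc;
    }
}
```

The real kernel follows the same shape but dequantizes as d*sc*q - dmin*m per sub-block, unpacking sc and m via unpack_q4_k_scales().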