← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_kernels_q6k.c
Go to the documentation of this file.
1 /**
2  * @file gemm_kernels_q6k.c
3  * @brief GEMM/GEMV kernels with Q6_K quantized weights
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * Implements matrix multiplication where:
15  * - Activations (input): FP32
16  * - Weights: Q6_K (6-bit k-quant, int8 scales)
17  * - Output: FP32
18  *
19  * Q6_K Format (256 weights per block):
20  * - d: FP16 super-block scale
21  * - ql: 128 bytes (low 4 bits of each weight)
22  * - qh: 64 bytes (high 2 bits of each weight)
23  * - scales: 16 int8 sub-block scales
24  */
25 
26 #include <stdint.h>
27 #include <stddef.h>
28 #include "ckernel_quant.h"
29 
30 /* Include SIMD headers based on available extensions */
31 #if defined(__AVX512F__) || defined(__AVX2__) || defined(__AVX__) || defined(__SSE4_1__)
32 #include <immintrin.h>
33 #endif
34 
35 /* ============================================================================
36  * GEMV: y = W @ x (W is Q6_K, x and y are FP32)
37  * ============================================================================ */
38 
39 static float dot_q6_k_ref(const block_q6_K *w,
40  const float *x,
41  int K)
42 {
43  const int blocks_per_row = K / QK_K;
44  float sum = 0.0f;
45 
46  for (int b = 0; b < blocks_per_row; ++b) {
47  const block_q6_K *block = &w[b];
48  const float d = GGML_FP16_TO_FP32(block->d);
49 
50  const uint8_t *ql = block->ql;
51  const uint8_t *qh = block->qh;
52  const int8_t *sc = block->scales;
53  const float *xp = x + (size_t)b * QK_K;
54 
55  for (int n = 0; n < QK_K; n += 128) {
56  for (int l = 0; l < 32; ++l) {
57  const int is = l / 16;
58  const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
59  const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
60  const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
61  const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
62 
63  sum += (d * (float)sc[is + 0] * (float)q1) * xp[l + 0];
64  sum += (d * (float)sc[is + 2] * (float)q2) * xp[l + 32];
65  sum += (d * (float)sc[is + 4] * (float)q3) * xp[l + 64];
66  sum += (d * (float)sc[is + 6] * (float)q4) * xp[l + 96];
67  }
68  xp += 128;
69  ql += 64;
70  qh += 32;
71  sc += 8;
72  }
73  }
74 
75  return sum;
76 }
77 
78 /* ============================================================================
79  * AVX Implementation (256-bit)
80  *
81  * Q6_K is complex: 6-bit weights packed as 4+2 bits, with 16 sub-block scales.
82  * Uses AVX for the final accumulation while keeping scalar dequantization
83  * for simplicity and correctness.
84  * ============================================================================ */
85 
#if defined(__AVX__) && !defined(__AVX512F__)
/*
 * AVX (256-bit) dot product: one Q6_K row against an FP32 vector.
 *
 * Dequantization of the 4+2-bit packed weights stays scalar for
 * correctness; only the multiply-accumulate is vectorized, with four
 * independent accumulators (one per 32-float quadrant) for ILP.
 *
 * Fix vs. previous revision: removed the unused local `is` that was
 * declared per 8-lane group and then shadowed by the per-element
 * recomputation (`is_i`), which provoked -Wunused-variable.
 */
static float dot_q6_k_avx(const block_q6_K *w,
                          const float *x,
                          int K)
{
    const int blocks_per_row = K / QK_K;

    /* Four accumulators, one per quadrant stream, combined at the end. */
    __m256 acc0 = _mm256_setzero_ps();
    __m256 acc1 = _mm256_setzero_ps();
    __m256 acc2 = _mm256_setzero_ps();
    __m256 acc3 = _mm256_setzero_ps();

    for (int b = 0; b < blocks_per_row; ++b) {
        const block_q6_K *block = &w[b];
        const float d = GGML_FP16_TO_FP32(block->d);

        const uint8_t *ql = block->ql;
        const uint8_t *qh = block->qh;
        const int8_t *sc = block->scales;
        const float *xp = x + (size_t)b * QK_K;

        /* Each 256-weight super-block is two independent 128-weight halves. */
        for (int n = 0; n < QK_K; n += 128) {
            /* 8 lanes at a time: one __m256 load per quadrant stream. */
            for (int l = 0; l < 32; l += 8) {
                /* Scalar dequantization of 8 weights per stream. */
                float dq0[8], dq1[8], dq2[8], dq3[8];

                for (int i = 0; i < 8; i++) {
                    const int idx = l + i;
                    /* Low 4 bits from ql, high 2 bits from qh; bias by -32. */
                    const int8_t q1 = (int8_t)((ql[idx + 0] & 0xF) | (((qh[idx] >> 0) & 3) << 4)) - 32;
                    const int8_t q2 = (int8_t)((ql[idx + 32] & 0xF) | (((qh[idx] >> 2) & 3) << 4)) - 32;
                    const int8_t q3 = (int8_t)((ql[idx + 0] >> 4) | (((qh[idx] >> 4) & 3) << 4)) - 32;
                    const int8_t q4 = (int8_t)((ql[idx + 32] >> 4) | (((qh[idx] >> 6) & 3) << 4)) - 32;

                    /* One int8 scale per 16 weights within this half. */
                    const int is_i = idx / 16;
                    dq0[i] = d * (float)sc[is_i + 0] * (float)q1;
                    dq1[i] = d * (float)sc[is_i + 2] * (float)q2;
                    dq2[i] = d * (float)sc[is_i + 4] * (float)q3;
                    dq3[i] = d * (float)sc[is_i + 6] * (float)q4;
                }

                __m256 vw0 = _mm256_loadu_ps(dq0);
                __m256 vw1 = _mm256_loadu_ps(dq1);
                __m256 vw2 = _mm256_loadu_ps(dq2);
                __m256 vw3 = _mm256_loadu_ps(dq3);

                __m256 vx0 = _mm256_loadu_ps(&xp[l + 0]);
                __m256 vx1 = _mm256_loadu_ps(&xp[l + 32]);
                __m256 vx2 = _mm256_loadu_ps(&xp[l + 64]);
                __m256 vx3 = _mm256_loadu_ps(&xp[l + 96]);

                acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(vw0, vx0));
                acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(vw1, vx1));
                acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(vw2, vx2));
                acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(vw3, vx3));
            }
            xp += 128;
            ql += 64;
            qh += 32;
            sc += 8;
        }
    }

    /* Combine accumulators. */
    __m256 sum01 = _mm256_add_ps(acc0, acc1);
    __m256 sum23 = _mm256_add_ps(acc2, acc3);
    __m256 sum = _mm256_add_ps(sum01, sum23);

    /* Horizontal sum of 8 floats. */
    __m128 hi = _mm256_extractf128_ps(sum, 1);
    __m128 lo = _mm256_castps256_ps128(sum);
    __m128 sum128 = _mm_add_ps(hi, lo);
    sum128 = _mm_hadd_ps(sum128, sum128);
    sum128 = _mm_hadd_ps(sum128, sum128);

    return _mm_cvtss_f32(sum128);
}
#endif /* __AVX__ && !__AVX512F__ */
168 
169 void gemv_q6_k(float *y,
170  const void *W,
171  const float *x,
172  int M, int K)
173 {
174  if (!y || !W || !x) {
175  return;
176  }
177  if (M <= 0 || K <= 0) {
178  return;
179  }
180  // TEMPORARILY DISABLE NEW AVX KERNELS - USE REFERENCE ONLY
181 
182  const block_q6_K *blocks = (const block_q6_K *)W;
183  const int blocks_per_row = K / QK_K;
184 
185  for (int row = 0; row < M; ++row) {
186  const block_q6_K *w_row = blocks + (size_t)row * (size_t)blocks_per_row;
187 #if defined(__AVX__) && !defined(__AVX512F__)
188  y[row] = dot_q6_k_avx(w_row, x, K);
189 #else
190  y[row] = dot_q6_k_ref(w_row, x, K);
191 #endif
192  }
193 }
194 
/**
 * Batched GEMV: Y[n] = W @ X[n] for each of N input columns.
 *
 * @param Y  Output: N contiguous vectors of length M (written).
 * @param W  Q6_K weight blocks, shared across all N columns.
 * @param X  Input: N contiguous vectors of length K.
 * @param M  Output rows per column.
 * @param N  Number of batch columns.
 * @param K  Inner dimension (multiple of QK_K expected).
 *
 * Fix: offsets n*M and n*K are computed in size_t so they cannot
 * overflow int for large batches, consistent with the size_t row
 * arithmetic already used in gemv_q6_k.
 */
void gemm_q6_k(float *Y,
               const void *W,
               const float *X,
               int M, int N, int K)
{
    if (!Y || !W || !X) {
        return;
    }
    if (M <= 0 || N <= 0 || K <= 0) {
        return;
    }

    for (int n = 0; n < N; ++n) {
        gemv_q6_k(&Y[(size_t)n * (size_t)M], W, &X[(size_t)n * (size_t)K], M, K);
    }
}
211 
/**
 * C = A @ B^T (+ bias), with A [M x K] FP32 activations, B Q6_K weights
 * holding N output channels of K inputs each, and C [M x N] FP32 output.
 *
 * @param A     Activations, M rows of K floats.
 * @param B     Q6_K weight blocks, N rows of K/QK_K blocks.
 * @param bias  Optional per-channel bias of length N; may be NULL.
 * @param C     Output, M rows of N floats (written).
 */
void gemm_nt_q6_k(const float *A,
                  const void *B,
                  const float *bias,
                  float *C,
                  int M, int N, int K)
{
    if (A == NULL || B == NULL || C == NULL) {
        return;
    }
    if (M < 1 || N < 1 || K < 1) {
        return;
    }

    /* Reuse the batched kernel: each of the M input rows (tokens) acts as
     * one batch column, and the N output channels become the GEMV rows. */
    gemm_q6_k(C, B, A, /*M_out=*/N, /*N_batch=*/M, K);

    if (bias != NULL) {
        for (int r = 0; r < M; ++r) {
            float *out = C + (size_t)r * (size_t)N;
            for (int c = 0; c < N; ++c) {
                out[c] += bias[c];
            }
        }
    }
}
241 
/* Reference-named entry point - currently forwards unchanged to gemm_nt_q6_k */
/**
 * Compatibility alias for gemm_nt_q6_k; same parameters and semantics.
 *
 * NOTE(review): the name suggests a dedicated scalar reference path, but it
 * forwards unconditionally to gemm_nt_q6_k (which itself dispatches per
 * build flags) — confirm whether callers rely on a scalar-only guarantee.
 */
void gemm_nt_q6_k_ref(const float *A,
                      const void *B,
                      const float *bias,
                      float *C,
                      int M, int N, int K)
{
    gemm_nt_q6_k(A, B, bias, C, M, N, K);
}
Quantization block structures for weight-only quantization.
#define GGML_FP16_TO_FP32
#define QK_K
void gemm_q6_k(float *Y, const void *W, const float *X, int M, int N, int K)
void gemv_q6_k(float *y, const void *W, const float *x, int M, int K)
void gemm_nt_q6_k_ref(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
void gemm_nt_q6_k(const float *A, const void *B, const float *bias, float *C, int M, int N, int K)
static float dot_q6_k_ref(const block_q6_K *w, const float *x, int K)
#define C(color)
Definition: show_config.c:39
uint8_t ql[256/2]
int8_t scales[256/16]
uint8_t qh[256/4]