← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_head_major_output.c
Go to the documentation of this file.
1 /**
2  * @file gemm_head_major_output.c
3  * @brief Output projection from head-major attention (NO LAYOUT CONVERSION)
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. NO memcpy for layout - use strided access, not copies
10  * 4. API must define: inputs, outputs, workspace, and memory layouts
11  * 5. Pure computation - deterministic, no side effects
12  *
13  * After changes: make test && make llamacpp-parity-full
14  *
15  * PROBLEM THIS SOLVES:
16  * ====================
17  * The standard mega_fused_attention_prefill has a bottleneck:
18  * attn_out [num_heads, tokens, head_dim] (head-major)
19  * → flatten_head_major() - 448 memcpy calls for 32 tokens × 14 heads!
20  * → token-major buffer
21  * → GEMM output projection
22  *
23  * This kernel eliminates the flatten by reading head-major data directly with
24  * strided access. The output projection computes:
25  *
26  * output[t, n] = bias[n] + sum_h wo[n, h*head_dim:(h+1)*head_dim] @ attn_out[h, t, :]
27  *
28  * where wo is Q5_0 quantized [embed_dim, embed_dim] and attn_out is head-major.
29  *
30  * Expected speedup: 1.5-2x by eliminating 448 small memcpy calls.
31  */
32 
33 #include <stdint.h>
34 #include <stddef.h>
35 #include <string.h>
36 #include "ckernel_quant.h"
37 #include "ckernel_dtype.h"
38 
39 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
40 #include <immintrin.h>
41 #endif
42 
43 /* Forward declaration from dequant_kernels.c */
44 void dequant_q5_0_row(const void *src, float *dst, size_t n_elements);
45 
46 /* ============================================================================
47  * Scalar reference: Output projection from head-major attention
48  * ============================================================================ */
49 
50 /**
51  * @brief Output projection reading head-major attention output (Q5_0 weights)
52  *
53  * @param output Output [tokens, embed_dim] (token-major, written contiguously)
54  * @param attn_out Attention output [num_heads, tokens, head_dim] (head-major, strided)
55  * @param wo Output weights in Q5_0 format [embed_dim, embed_dim]
56  * @param bias Optional bias [embed_dim]
57  * @param tokens Number of tokens
58  * @param embed_dim Output embedding dimension
59  * @param num_heads Number of attention heads
60  * @param head_dim Head dimension (must be multiple of 32 for Q5_0)
61  */
62 void gemv_nt_q5_0_head_major_output(float *output,
63  const float *attn_out,
64  const void *wo,
65  const float *bias,
66  int tokens,
67  int embed_dim,
68  int num_heads,
69  int head_dim)
70 {
71  if (!output || !attn_out || !wo) return;
72  if (tokens <= 0 || embed_dim <= 0 || num_heads <= 0 || head_dim <= 0) return;
73 
74  const int blocks_per_head = head_dim / QK5_0;
75  const int blocks_per_row = embed_dim / QK5_0;
76  const block_q5_0 *weights = (const block_q5_0 *)wo;
77 
78  /* Strides for head-major layout */
79  const size_t token_stride = head_dim; /* attn_out[h][t] offset */
80  const size_t head_stride = (size_t)tokens * token_stride; /* attn_out[h] offset */
81 
82  /* Initialize output with bias (if provided) */
83  if (bias) {
84  for (int t = 0; t < tokens; t++) {
85  float *out_row = output + (size_t)t * embed_dim;
86  for (int n = 0; n < embed_dim; n++) {
87  out_row[n] = bias[n];
88  }
89  }
90  } else {
91  memset(output, 0, (size_t)tokens * embed_dim * sizeof(float));
92  }
93 
94  /* Accumulate contributions from each head */
95  for (int h = 0; h < num_heads; h++) {
96  const float *head_data = attn_out + (size_t)h * head_stride;
97 
98  /* For each output row (n) corresponding to this head's slice */
99  const int head_offset = h * blocks_per_head;
100 
101  for (int n_block = 0; n_block < blocks_per_head; n_block++) {
102  for (int n = 0; n < embed_dim; n++) {
103  const block_q5_0 *w_row = weights + (size_t)n * blocks_per_row + head_offset + n_block;
104  const float d = CK_FP16_TO_FP32(w_row->d);
105 
106  /* Get high bits */
107  uint32_t qh;
108  memcpy(&qh, w_row->qh, sizeof(qh));
109 
110  /* Accumulate for all tokens at once (better cache reuse) */
111  for (int t = 0; t < tokens; t++) {
112  const float *token_vec = head_data + (size_t)t * token_stride + (size_t)n_block * QK5_0;
113  float sum = 0.0f;
114 
115  /* Q5_0 dot product for this block */
116  for (int j = 0; j < QK5_0 / 2; j++) {
117  const uint8_t packed = w_row->qs[j];
118  const int lo = (packed & 0x0F);
119  const int hi = (packed >> 4);
120  const int xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
121  const int xh_1 = ((qh >> (j + 12))) & 0x10;
122  const int q0 = (lo | xh_0) - 16;
123  const int q1 = (hi | xh_1) - 16;
124 
125  sum += d * (float)q0 * token_vec[j];
126  sum += d * (float)q1 * token_vec[j + 16];
127  }
128 
129  output[(size_t)t * embed_dim + n] += sum;
130  }
131  }
132  }
133  }
134 }
135 
136 /* ============================================================================
137  * Vectorized version with AVX (8 floats at a time)
138  * ============================================================================ */
139 
140 #if defined(__AVX__) && defined(__F16C__)
141 #include <immintrin.h>
142 
143 /**
144  * @brief Optimized version with AVX SIMD
145  *
146  * Key optimizations:
147  * 1. Process 8 output rows at a time using AVX
148  * 2. Accumulate across heads for better cache utilization
149  * 3. Use FMAC for multiply-accumulate
150  */
151 void gemv_nt_q5_0_head_major_output_avx(float *output,
152  const float *attn_out,
153  const void *wo,
154  const float *bias,
155  int tokens,
156  int embed_dim,
157  int num_heads,
158  int head_dim)
159 {
160  if (!output || !attn_out || !wo) return;
161  if (tokens <= 0 || embed_dim <= 0 || num_heads <= 0 || head_dim <= 0) return;
162 
163  const int blocks_per_head = head_dim / QK5_0;
164  const int blocks_per_row = embed_dim / QK5_0;
165  const block_q5_0 *weights = (const block_q5_0 *)wo;
166 
167  const size_t token_stride = head_dim;
168  const size_t head_stride = (size_t)tokens * token_stride;
169 
170  /* Initialize output */
171  if (bias) {
172  for (int t = 0; t < tokens; t++) {
173  float *out_row = output + (size_t)t * embed_dim;
174  for (int n = 0; n < embed_dim; n++) {
175  out_row[n] = bias[n];
176  }
177  }
178  } else {
179  memset(output, 0, (size_t)tokens * embed_dim * sizeof(float));
180  }
181 
182  /* Process heads sequentially, accumulating into output */
183  for (int h = 0; h < num_heads; h++) {
184  const float *head_data = attn_out + (size_t)h * head_stride;
185  const int head_offset = h * blocks_per_head;
186 
187  /* Process output rows in chunks of 8 for AVX */
188  int n = 0;
189  for (; n + 7 < embed_dim; n += 8) {
190  /* Process 8 output rows at once */
191  for (int n_block = 0; n_block < blocks_per_head; n_block++) {
192  const size_t w_offset = (size_t)(n + head_offset + n_block) * blocks_per_row + n_block;
193 
194  __m256 acc0 = _mm256_setzero_ps();
195  __m256 acc1 = _mm256_setzero_ps();
196  __m256 acc2 = _mm256_setzero_ps();
197  __m256 acc3 = _mm256_setzero_ps();
198  __m256 acc4 = _mm256_setzero_ps();
199  __m256 acc5 = _mm256_setzero_ps();
200  __m256 acc6 = _mm256_setzero_ps();
201  __m256 acc7 = _mm256_setzero_ps();
202 
203  /* For each token */
204  for (int t = 0; t < tokens; t++) {
205  const float *token_vec = head_data + (size_t)t * token_stride + (size_t)n_block * QK5_0;
206 
207  /* Load 8 weight blocks */
208  const block_q5_0 *w0 = weights + w_offset;
209  const block_q5_0 *w1 = w0 + blocks_per_row;
210  const block_q5_0 *w2 = w1 + blocks_per_row;
211  const block_q5_0 *w3 = w2 + blocks_per_row;
212  const block_q5_0 *w4 = w3 + blocks_per_row;
213  const block_q5_0 *w5 = w4 + blocks_per_row;
214  const block_q5_0 *w6 = w5 + blocks_per_row;
215  const block_q5_0 *w7 = w6 + blocks_per_row;
216 
217  const float d0 = CK_FP16_TO_FP32(w0->d);
218  const float d1 = CK_FP16_TO_FP32(w1->d);
219  const float d2 = CK_FP16_TO_FP32(w2->d);
220  const float d3 = CK_FP16_TO_FP32(w3->d);
221  const float d4 = CK_FP16_TO_FP32(w4->d);
222  const float d5 = CK_FP16_TO_FP32(w5->d);
223  const float d6 = CK_FP16_TO_FP32(w6->d);
224  const float d7 = CK_FP16_TO_FP32(w7->d);
225 
226  /* Dot products for each output row */
227  for (int j = 0; j < 16; j++) {
228  const uint8_t p0 = w0->qs[j];
229  const uint8_t p1 = w1->qs[j];
230  const uint8_t p2 = w2->qs[j];
231  const uint8_t p3 = w3->qs[j];
232  const uint8_t p4 = w4->qs[j];
233  const uint8_t p5 = w5->qs[j];
234  const uint8_t p6 = w6->qs[j];
235  const uint8_t p7 = w7->qs[j];
236 
237  const float tv0 = token_vec[j];
238  const float tv1 = token_vec[j + 16];
239 
240  /* Extract low nibbles */
241  const int lo0 = (p0 & 0x0F) - 8;
242  const int lo1 = (p1 & 0x0F) - 8;
243  const int lo2 = (p2 & 0x0F) - 8;
244  const int lo3 = (p3 & 0x0F) - 8;
245  const int lo4 = (p4 & 0x0F) - 8;
246  const int lo5 = (p5 & 0x0F) - 8;
247  const int lo6 = (p6 & 0x0F) - 8;
248  const int lo7 = (p7 & 0x0F) - 8;
249 
250  __m256 xv = _mm256_set1_ps(tv0);
251  __m256 qw = _mm256_setr_ps(lo0, lo1, lo2, lo3, lo4, lo5, lo6, lo7);
252  __m256 vw = _mm256_setr_ps(d0, d1, d2, d3, d4, d5, d6, d7);
253  acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(_mm256_mul_ps(qw, vw), xv));
254 
255  /* Extract high nibbles */
256  const int hi0 = (p0 >> 4) - 8;
257  const int hi1 = (p1 >> 4) - 8;
258  const int hi2 = (p2 >> 4) - 8;
259  const int hi3 = (p3 >> 4) - 8;
260  const int hi4 = (p4 >> 4) - 8;
261  const int hi5 = (p5 >> 4) - 8;
262  const int hi6 = (p6 >> 4) - 8;
263  const int hi7 = (p7 >> 4) - 8;
264 
265  xv = _mm256_set1_ps(tv1);
266  qw = _mm256_setr_ps(hi0, hi1, hi2, hi3, hi4, hi5, hi6, hi7);
267  vw = _mm256_setr_ps(d0, d1, d2, d3, d4, d5, d6, d7);
268  acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(_mm256_mul_ps(qw, vw), xv));
269  }
270 
271  /* Combine low and high accumulators */
272  __m256 total = _mm256_add_ps(acc0, acc1);
273 
274  /* Store to output */
275  float *out_row = output + (size_t)t * embed_dim + n;
276  __m256 out_val = _mm256_loadu_ps(out_row);
277  out_val = _mm256_add_ps(out_val, total);
278  _mm256_storeu_ps(out_row, out_val);
279  }
280  }
281  }
282 
283  /* Handle remaining output rows with scalar */
284  for (; n < embed_dim; n++) {
285  for (int n_block = 0; n_block < blocks_per_head; n_block++) {
286  const block_q5_0 *w_row = weights + (size_t)(n + head_offset + n_block) * blocks_per_row + n_block;
287  const float d = CK_FP16_TO_FP32(w_row->d);
288 
289  uint32_t qh;
290  memcpy(&qh, w_row->qh, sizeof(qh));
291 
292  for (int t = 0; t < tokens; t++) {
293  const float *token_vec = head_data + (size_t)t * token_stride + (size_t)n_block * QK5_0;
294  float sum = 0.0f;
295 
296  for (int j = 0; j < QK5_0 / 2; j++) {
297  const uint8_t packed = w_row->qs[j];
298  const int lo = (packed & 0x0F);
299  const int hi = (packed >> 4);
300  const int xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
301  const int xh_1 = ((qh >> (j + 12))) & 0x10;
302  const int q0 = (lo | xh_0) - 16;
303  const int q1 = (hi | xh_1) - 16;
304 
305  sum += d * (float)q0 * token_vec[j];
306  sum += d * (float)q1 * token_vec[j + 16];
307  }
308 
309  output[(size_t)t * embed_dim + n] += sum;
310  }
311  }
312  }
313  }
314 }
315 
316 #endif /* __AVX__ */
317 
318 /* ============================================================================
319  * Generic dispatch
320  * ============================================================================ */
321 
/**
 * @brief Output projection from head-major attention, Q5_0 weights (auto-dispatch)
 *
 * Drop-in replacement for the flatten_head_major() + ck_gemm_nt_quant() pair:
 * a single strided-access kernel reads the head-major attention output in
 * place. Selects the AVX path when compiled with AVX + F16C support,
 * otherwise falls back to the scalar reference.
 *
 * @param attn_out  Attention output [num_heads, tokens, head_dim] (head-major)
 * @param wo        Output weights in Q5_0 format [embed_dim, embed_dim]
 * @param bias      Optional bias [embed_dim], may be NULL
 * @param output    Output [tokens, embed_dim] (token-major)
 * @param tokens    Number of tokens
 * @param embed_dim Output embedding dimension
 * @param num_heads Number of attention heads
 * @param head_dim  Head dimension
 */
void ck_gemm_nt_head_major_q5_0(const float *attn_out, /* [num_heads, tokens, head_dim] */
                                const void *wo,
                                const float *bias,
                                float *output, /* [tokens, embed_dim] */
                                int tokens,
                                int embed_dim,
                                int num_heads,
                                int head_dim)
{
#if defined(__AVX__) && defined(__F16C__)
    /* SIMD path: 8 output rows per iteration */
    gemv_nt_q5_0_head_major_output_avx(output, attn_out, wo, bias,
                                       tokens, embed_dim, num_heads, head_dim);
#else
    /* Portable scalar fallback */
    gemv_nt_q5_0_head_major_output(output, attn_out, wo, bias,
                                   tokens, embed_dim, num_heads, head_dim);
#endif
}
345 
346 /* ============================================================================
347  * Q8_0 variant (for V weights which are often Q8_0)
348  * ============================================================================ */
349 
350 /**
351  * @brief Output projection from head-major attention (Q8_0 weights)
352  */
353 void ck_gemm_nt_head_major_q8_0(const float *attn_out,
354  const void *wo,
355  const float *bias,
356  float *output,
357  int tokens,
358  int embed_dim,
359  int num_heads,
360  int head_dim)
361 {
362  if (!output || !attn_out || !wo) return;
363  if (tokens <= 0 || embed_dim <= 0 || num_heads <= 0 || head_dim <= 0) return;
364 
365  const int blocks_per_head = head_dim / QK8_0;
366  const int blocks_per_row = embed_dim / QK8_0;
367  const block_q8_0 *weights = (const block_q8_0 *)wo;
368 
369  const size_t token_stride = head_dim;
370  const size_t head_stride = (size_t)tokens * token_stride;
371 
372  /* Initialize output */
373  if (bias) {
374  for (int t = 0; t < tokens; t++) {
375  float *out_row = output + (size_t)t * embed_dim;
376  for (int n = 0; n < embed_dim; n++) {
377  out_row[n] = bias[n];
378  }
379  }
380  } else {
381  memset(output, 0, (size_t)tokens * embed_dim * sizeof(float));
382  }
383 
384  /* Accumulate from each head */
385  for (int h = 0; h < num_heads; h++) {
386  const float *head_data = attn_out + (size_t)h * head_stride;
387  const int head_offset = h * blocks_per_head;
388 
389  for (int n_block = 0; n_block < blocks_per_head; n_block++) {
390  for (int n = 0; n < embed_dim; n++) {
391  const block_q8_0 *w_row = weights + (size_t)n * blocks_per_row + head_offset + n_block;
392  const float d = CK_FP16_TO_FP32(w_row->d);
393 
394  for (int t = 0; t < tokens; t++) {
395  const float *token_vec = head_data + (size_t)t * token_stride + (size_t)n_block * QK8_0;
396  float sum = 0.0f;
397 
398  for (int j = 0; j < QK8_0; j++) {
399  sum += d * (float)w_row->qs[j] * token_vec[j];
400  }
401 
402  output[(size_t)t * embed_dim + n] += sum;
403  }
404  }
405  }
406  }
407 }
Quantization block structures for weight-only quantization.
#define QK5_0
Definition: ckernel_quant.h:67
#define CK_FP16_TO_FP32(x)
#define QK8_0
void ck_gemm_nt_head_major_q8_0(const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (Q8_0 weights)
void ck_gemm_nt_head_major_q5_0(const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection from head-major attention (auto-dispatch)
void dequant_q5_0_row(const void *src, float *dst, size_t n_elements)
Dequantize Q5_0 row (multiple blocks)
void gemv_nt_q5_0_head_major_output(float *output, const float *attn_out, const void *wo, const float *bias, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection reading head-major attention output (Q5_0 weights)
ck_half d
Definition: ckernel_quant.h:70
uint8_t qh[4]
Definition: ckernel_quant.h:71
uint8_t qs[32/2]
Definition: ckernel_quant.h:72
int8_t qs[32]