← Back to C-Kernel-Engine Docs Doxygen Source Documentation
gemm_head_major_output.c File Reference

Output projection from head-major attention (NO LAYOUT CONVERSION) More...

#include <stdint.h>
#include <stddef.h>
#include <string.h>
#include "ckernel_quant.h"
#include "ckernel_dtype.h"

Go to the source code of this file.

Functions

void ck_gemm_nt_head_major_q5_0 (const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
 Output projection from head-major attention (auto-dispatch) More...
 
void ck_gemm_nt_head_major_q8_0 (const float *attn_out, const void *wo, const float *bias, float *output, int tokens, int embed_dim, int num_heads, int head_dim)
 Output projection from head-major attention (Q8_0 weights) More...
 
void dequant_q5_0_row (const void *src, float *dst, size_t n_elements)
 Dequantize Q5_0 row (multiple blocks) More...
 
void gemv_nt_q5_0_head_major_output (float *output, const float *attn_out, const void *wo, const float *bias, int tokens, int embed_dim, int num_heads, int head_dim)
 Output projection reading head-major attention output (Q5_0 weights) More...
 

Detailed Description

Output projection from head-major attention (NO LAYOUT CONVERSION)

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. NO memcpy for layout - use strided access, not copies
  4. API must define: inputs, outputs, workspace, and memory layouts
  5. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

PROBLEM THIS SOLVES:

The standard mega_fused_attention_prefill has a bottleneck: attention output is produced head-major as attn_out [num_heads, tokens, head_dim], and flatten_head_major() then performs one memcpy per (head, token) pair — 448 calls for 32 tokens × 14 heads — just to build a token-major buffer before the GEMM output projection.

This kernel eliminates the flatten by reading head-major data directly with strided access. The output projection computes:

output[t, n] = bias[n] + sum_h wo[n, h*head_dim:(h+1)*head_dim] @ attn_out[h, t, :]

where wo is Q5_0 quantized [embed_dim, embed_dim] and attn_out is head-major.

Expected speedup: 1.5-2x by eliminating 448 small memcpy calls.

Definition in file gemm_head_major_output.c.

Function Documentation

◆ ck_gemm_nt_head_major_q5_0()

/**
 * @brief Output projection from head-major attention (auto-dispatch, Q5_0).
 *
 * Replaces flatten_head_major() + ck_gemm_nt_quant() with a single
 * strided-access kernel that reads the head-major attention output directly.
 * Selects the AVX/F16C path when those ISA extensions are enabled at compile
 * time, otherwise falls back to the portable scalar kernel.
 *
 * @param attn_out  Attention output [num_heads, tokens, head_dim] (head-major)
 * @param wo        Output weights in Q5_0 format [embed_dim, embed_dim]
 * @param bias      Optional bias [embed_dim] (may be NULL)
 * @param output    Output [tokens, embed_dim] (token-major)
 * @param tokens    Number of tokens
 * @param embed_dim Output embedding dimension
 * @param num_heads Number of attention heads
 * @param head_dim  Head dimension (must be a multiple of 32 for Q5_0)
 */
void ck_gemm_nt_head_major_q5_0(const float *attn_out, const void *wo,
                                const float *bias, float *output,
                                int tokens, int embed_dim,
                                int num_heads, int head_dim)
{
#if defined(__AVX__) && defined(__F16C__)
    gemv_nt_q5_0_head_major_output_avx(output, attn_out, wo, bias,
                                       tokens, embed_dim, num_heads, head_dim);
#else
    gemv_nt_q5_0_head_major_output(output, attn_out, wo, bias,
                                   tokens, embed_dim, num_heads, head_dim);
#endif
}
void gemv_nt_q5_0_head_major_output(float *output, const float *attn_out, const void *wo, const float *bias, int tokens, int embed_dim, int num_heads, int head_dim)
Output projection reading head-major attention output (Q5_0 weights)

References gemv_nt_q5_0_head_major_output().

Referenced by mega_fused_attention_prefill().

◆ ck_gemm_nt_head_major_q8_0()

void ck_gemm_nt_head_major_q8_0 ( const float *  attn_out,
const void *  wo,
const float *  bias,
float *  output,
int  tokens,
int  embed_dim,
int  num_heads,
int  head_dim 
)

Output projection from head-major attention (Q8_0 weights)

Definition at line 353 of file gemm_head_major_output.c.

361 {
362  if (!output || !attn_out || !wo) return;
363  if (tokens <= 0 || embed_dim <= 0 || num_heads <= 0 || head_dim <= 0) return;
364 
365  const int blocks_per_head = head_dim / QK8_0;
366  const int blocks_per_row = embed_dim / QK8_0;
367  const block_q8_0 *weights = (const block_q8_0 *)wo;
368 
369  const size_t token_stride = head_dim;
370  const size_t head_stride = (size_t)tokens * token_stride;
371 
372  /* Initialize output */
373  if (bias) {
374  for (int t = 0; t < tokens; t++) {
375  float *out_row = output + (size_t)t * embed_dim;
376  for (int n = 0; n < embed_dim; n++) {
377  out_row[n] = bias[n];
378  }
379  }
380  } else {
381  memset(output, 0, (size_t)tokens * embed_dim * sizeof(float));
382  }
383 
384  /* Accumulate from each head */
385  for (int h = 0; h < num_heads; h++) {
386  const float *head_data = attn_out + (size_t)h * head_stride;
387  const int head_offset = h * blocks_per_head;
388 
389  for (int n_block = 0; n_block < blocks_per_head; n_block++) {
390  for (int n = 0; n < embed_dim; n++) {
391  const block_q8_0 *w_row = weights + (size_t)n * blocks_per_row + head_offset + n_block;
392  const float d = CK_FP16_TO_FP32(w_row->d);
393 
394  for (int t = 0; t < tokens; t++) {
395  const float *token_vec = head_data + (size_t)t * token_stride + (size_t)n_block * QK8_0;
396  float sum = 0.0f;
397 
398  for (int j = 0; j < QK8_0; j++) {
399  sum += d * (float)w_row->qs[j] * token_vec[j];
400  }
401 
402  output[(size_t)t * embed_dim + n] += sum;
403  }
404  }
405  }
406  }
407 }
#define CK_FP16_TO_FP32(x)
#define QK8_0
int8_t qs[32]

References CK_FP16_TO_FP32, block_q8_0::d, QK8_0, and block_q8_0::qs.

Referenced by mega_fused_attention_prefill().

◆ dequant_q5_0_row()

void dequant_q5_0_row ( const void *  src,
float *  dst,
size_t  n_elements 
)

Dequantize Q5_0 row (multiple blocks)

Definition at line 196 of file dequant_kernels.c.

197 {
198  const block_q5_0 *blocks = (const block_q5_0 *)src;
199  const size_t n_blocks = n_elements / QK5_0;
200 
201  for (size_t b = 0; b < n_blocks; b++) {
202  dequant_q5_0_block(&blocks[b], &dst[b * QK5_0]);
203  }
204 }
#define QK5_0
Definition: ckernel_quant.h:67
void dequant_q5_0_block(const block_q5_0 *block, float *output)
Dequantize a single Q5_0 block to FP32.

◆ gemv_nt_q5_0_head_major_output()

void gemv_nt_q5_0_head_major_output ( float *  output,
const float *  attn_out,
const void *  wo,
const float *  bias,
int  tokens,
int  embed_dim,
int  num_heads,
int  head_dim 
)

Output projection reading head-major attention output (Q5_0 weights)

Parameters
outputOutput [tokens, embed_dim] (token-major, written contiguously)
attn_outAttention output [num_heads, tokens, head_dim] (head-major, strided)
woOutput weights in Q5_0 format [embed_dim, embed_dim]
biasOptional bias [embed_dim]
tokensNumber of tokens
embed_dimOutput embedding dimension
num_headsNumber of attention heads
head_dimHead dimension (must be multiple of 32 for Q5_0)

Definition at line 62 of file gemm_head_major_output.c.

70 {
71  if (!output || !attn_out || !wo) return;
72  if (tokens <= 0 || embed_dim <= 0 || num_heads <= 0 || head_dim <= 0) return;
73 
74  const int blocks_per_head = head_dim / QK5_0;
75  const int blocks_per_row = embed_dim / QK5_0;
76  const block_q5_0 *weights = (const block_q5_0 *)wo;
77 
78  /* Strides for head-major layout */
79  const size_t token_stride = head_dim; /* attn_out[h][t] offset */
80  const size_t head_stride = (size_t)tokens * token_stride; /* attn_out[h] offset */
81 
82  /* Initialize output with bias (if provided) */
83  if (bias) {
84  for (int t = 0; t < tokens; t++) {
85  float *out_row = output + (size_t)t * embed_dim;
86  for (int n = 0; n < embed_dim; n++) {
87  out_row[n] = bias[n];
88  }
89  }
90  } else {
91  memset(output, 0, (size_t)tokens * embed_dim * sizeof(float));
92  }
93 
94  /* Accumulate contributions from each head */
95  for (int h = 0; h < num_heads; h++) {
96  const float *head_data = attn_out + (size_t)h * head_stride;
97 
98  /* For each output row (n) corresponding to this head's slice */
99  const int head_offset = h * blocks_per_head;
100 
101  for (int n_block = 0; n_block < blocks_per_head; n_block++) {
102  for (int n = 0; n < embed_dim; n++) {
103  const block_q5_0 *w_row = weights + (size_t)n * blocks_per_row + head_offset + n_block;
104  const float d = CK_FP16_TO_FP32(w_row->d);
105 
106  /* Get high bits */
107  uint32_t qh;
108  memcpy(&qh, w_row->qh, sizeof(qh));
109 
110  /* Accumulate for all tokens at once (better cache reuse) */
111  for (int t = 0; t < tokens; t++) {
112  const float *token_vec = head_data + (size_t)t * token_stride + (size_t)n_block * QK5_0;
113  float sum = 0.0f;
114 
115  /* Q5_0 dot product for this block */
116  for (int j = 0; j < QK5_0 / 2; j++) {
117  const uint8_t packed = w_row->qs[j];
118  const int lo = (packed & 0x0F);
119  const int hi = (packed >> 4);
120  const int xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
121  const int xh_1 = ((qh >> (j + 12))) & 0x10;
122  const int q0 = (lo | xh_0) - 16;
123  const int q1 = (hi | xh_1) - 16;
124 
125  sum += d * (float)q0 * token_vec[j];
126  sum += d * (float)q1 * token_vec[j + 16];
127  }
128 
129  output[(size_t)t * embed_dim + n] += sum;
130  }
131  }
132  }
133  }
134 }
ck_half d
Definition: ckernel_quant.h:70
uint8_t qh[4]
Definition: ckernel_quant.h:71
uint8_t qs[32/2]
Definition: ckernel_quant.h:72

References CK_FP16_TO_FP32, block_q5_0::d, block_q5_0::qh, QK5_0, and block_q5_0::qs.

Referenced by ck_gemm_nt_head_major_q5_0().