← Back to C-Kernel-Engine Docs Doxygen Source Documentation
prefill_fused_gemm.c File Reference

Fused kernels for prefill phase with proper 2D tiling. More...

#include "ckernel_engine.h"
#include "ckernel_quant.h"
#include <math.h>
#include <string.h>
#include <stddef.h>
#include <stdio.h>

Go to the source code of this file.

Macros

#define PREFILL_TILE_M   64
 
#define PREFILL_TILE_N   256
 

Functions

static void add_bias_tile (float *out, const float *bias, int tile_m, int out_dim)
 
static size_t align_up_size (size_t value, size_t align)
 
void fused_mlp_swiglu_prefill (const float *x, const float *W_gate, const float *W_up, const float *W_down, float *output, int seq_len, int hidden, int intermediate, float *scratch)
 Fused MLP (Gate + Up + SwiGLU + Down) for prefill. More...
 
void fused_mlp_swiglu_prefill_bias (const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *B_gate, const float *B_up, const float *B_down, float *output, int seq_len, int hidden, int intermediate, float *scratch)
 Fused MLP for prefill with proper tiling. More...
 
void fused_mlp_swiglu_prefill_w1w2_quant (const float *x, const void *W1, const float *B1, CKDataType w1_dt, const void *W2, const float *B2, CKDataType w2_dt, float *output, int seq_len, int embed_dim, int aligned_embed_dim, int intermediate_dim, int aligned_intermediate_dim, void *scratch)
 Quantized fused MLP for prefill (W1=gate+up, W2=down) More...
 
size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size (int aligned_embed_dim, int aligned_intermediate_dim)
 Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant. More...
 
size_t fused_mlp_swiglu_scratch_size (int intermediate)
 Get scratch size for fused MLP. More...
 
static void fused_rmsnorm_gemm_2d_tiled (const float *x, const float *gamma, const float *W, float *output, int seq_len, int hidden, int out_dim, float eps, float *x_norm_scratch)
 Fused RMSNorm + single GEMM with 2D tiling (weight reuse) More...
 
void fused_rmsnorm_qkv_prefill (const float *x, const float *gamma, const float *Wq, const float *Wk, const float *Wv, float *Q, float *K, float *V, int seq_len, int hidden, int q_dim, int kv_dim, float eps, float *scratch)
 Fused RMSNorm + QKV projection for prefill (v3 optimized) More...
 
void fused_rmsnorm_qkv_prefill_head_major (const float *x, const float *gamma, const float *Wq, const float *Bq, const float *Wk, const float *Bk, const float *Wv, const float *Bv, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, float *scratch)
 Fused RMSNorm + QKV projection for prefill (head-major outputs) More...
 
void fused_rmsnorm_qkv_prefill_head_major_quant (const float *x, const float *gamma, const void *Wq, const float *Bq, CKDataType wq_dt, const void *Wk, const float *Bk, CKDataType wk_dt, const void *Wv, const float *Bv, CKDataType wv_dt, float *Q, float *K, float *V, int seq_len, int embed_dim, int aligned_embed_dim, int num_heads, int num_kv_heads, int head_dim, int aligned_head_dim, int kv_stride_tokens, float eps, void *scratch)
 Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations) More...
 
size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size (int aligned_embed_dim)
 Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant. More...
 
size_t fused_rmsnorm_qkv_scratch_size (int hidden)
 Get scratch size for fused prefill. More...
 
static void gemm_nt_q8_0_dispatch (const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt)
 
static void gemm_nt_q8_0_mlp_dispatch (const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt)
 
static void gemm_nt_q8_k_mlp_dispatch (const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt)
 
static void gemm_nt_q8_k_qkv_dispatch (const void *A_q8k, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt)
 
static void gemm_tile_nt_strided (const float *A, const float *B_tile, float *C, int tile_m, int tile_n, int K, int C_stride)
 GEMM tile with N-dimension tiling (weight reuse) More...
 
static int mlp_q8_0_dtype_supported (CKDataType dt)
 
static int mlp_q8_k_dtype_supported (CKDataType dt)
 
static int qkv_q8_0_dtype_supported (CKDataType dt)
 
static int qkv_q8_k_dtype_supported (CKDataType dt)
 
static void rmsnorm_tile (const float *input, const float *gamma, float *output, int tile_m, int embed_dim, int aligned_embed_dim, float eps)
 Compute RMSNorm for a tile of tokens. More...
 
static float silu_prefill (float x)
 
void unfused_rmsnorm_qkv_prefill (const float *x, const float *gamma, const float *Wq, const float *Wk, const float *Wv, float *x_norm, float *Q, float *K, float *V, int seq_len, int hidden, int q_dim, int kv_dim, float eps)
 Unfused version for comparison. More...
 

Detailed Description

Fused kernels for prefill phase with proper 2D tiling.

CK-ENGINE KERNEL RULES:

  1. NO malloc/free - memory via bump allocator, pointers passed in
  2. NO OpenMP - parallelization at orchestrator/codegen layer
  3. API must define: inputs, outputs, workspace, and memory layouts
  4. Pure computation - deterministic, no side effects

After changes: make test && make llamacpp-parity-full

KEY INSIGHT:

Naive M-dimension tiling (token tiles) causes weight reloading:

  • 32 token tiles × 4MB weights = 128MB DRAM reads!

Correct approach: Tile along N (output/weight) dimension OUTER, M (token) dimension INNER. This way:

  • Load weight tile once
  • Process ALL tokens against that weight tile
  • Weight tile stays in cache while streaming through tokens

TILING STRATEGY:

For C[M,N] = RMSNorm(A[M,K]) × B[N,K]^T:

for n_tile in [0, N, TILE_N]:            # Outer: weight tiles
    load B[n_tile:n_tile+TILE_N, :] into L3
    for m_tile in [0, M, TILE_M]:        # Inner: token tiles
        x_norm = rmsnorm(A[m_tile])      # x_norm in L2
        C[m_tile, n_tile] = x_norm × B_tile   # Consumes B from L3

Cache behavior:

  • Weight tile (TILE_N × K × 4 bytes) fits in L3
  • x_norm tile (TILE_M × K × 4 bytes) fits in L2
  • Weights loaded once per tile, reused across all token tiles

Definition in file prefill_fused_gemm.c.

Macro Definition Documentation

◆ PREFILL_TILE_M

#define PREFILL_TILE_M   64

Definition at line 64 of file prefill_fused_gemm.c.

◆ PREFILL_TILE_N

#define PREFILL_TILE_N   256

Definition at line 65 of file prefill_fused_gemm.c.

Function Documentation

◆ add_bias_tile()

static void add_bias_tile ( float *  out,
const float *  bias,
int  tile_m,
int  out_dim 
)
static

Definition at line 310 of file prefill_fused_gemm.c.

314 {
315  if (!out || !bias) {
316  return;
317  }
318  for (int i = 0; i < tile_m; ++i) {
319  float *row = out + (size_t)i * (size_t)out_dim;
320  for (int j = 0; j < out_dim; ++j) {
321  row[j] += bias[j];
322  }
323  }
324 }

Referenced by fused_mlp_swiglu_prefill_bias(), and fused_rmsnorm_qkv_prefill_head_major().

◆ align_up_size()

static size_t align_up_size ( size_t  value,
size_t  align 
)
static

◆ fused_mlp_swiglu_prefill()

void fused_mlp_swiglu_prefill ( const float *  x,
const float *  W_gate,
const float *  W_up,
const float *  W_down,
float *  output,
int  seq_len,
int  hidden,
int  intermediate,
float *  scratch 
)

Fused MLP (Gate + Up + SwiGLU + Down) for prefill.

Tiles along token dimension to keep gate/up/hidden in L3 cache.

Parameters
scratchTemporary buffer from fused_mlp_swiglu_scratch_size()

Definition at line 879 of file prefill_fused_gemm.c.

889 {
890  fused_mlp_swiglu_prefill_bias(x, W_gate, W_up, W_down,
891  NULL, NULL, NULL,
892  output, seq_len, hidden, intermediate,
893  scratch);
894 }
void fused_mlp_swiglu_prefill_bias(const float *x, const float *W_gate, const float *W_up, const float *W_down, const float *B_gate, const float *B_up, const float *B_down, float *output, int seq_len, int hidden, int intermediate, float *scratch)
Fused MLP for prefill with proper tiling.

References fused_mlp_swiglu_prefill_bias().

◆ fused_mlp_swiglu_prefill_bias()

void fused_mlp_swiglu_prefill_bias ( const float *  x,
const float *  W_gate,
const float *  W_up,
const float *  W_down,
const float *  B_gate,
const float *  B_up,
const float *  B_down,
float *  output,
int  seq_len,
int  hidden,
int  intermediate,
float *  scratch 
)

Fused MLP for prefill with proper tiling.

Fused MLP (Gate + Up + SwiGLU + Down) for prefill with biases.

Definition at line 746 of file prefill_fused_gemm.c.

759 {
760  /* MLP is more complex because we have:
761  * gate = x @ W_gate
762  * up = x @ W_up
763  * hidden = silu(gate) * up
764  * out = hidden @ W_down
765  *
766  * The intermediate (gate, up, hidden) is large: seq_len × intermediate
767  * For Qwen2-0.5B: 1024 × 4864 × 4 = 19.4MB (way bigger than L3!)
768  *
769  * Strategy: Tile along intermediate dimension for gate/up,
770  * then fuse SwiGLU, then tile down projection.
771  */
772 
773  /* scratch layout:
774  * [gate_tile: TILE_M × TILE_N_INTER]
775  * [up_tile: TILE_M × TILE_N_INTER]
776  */
777  const int TILE_N_INTER = 512; /* Intermediate tile size */
778  float *gate_tile = scratch;
779  float *up_tile = scratch + (size_t)PREFILL_TILE_M * TILE_N_INTER;
780  float *hidden_tile = gate_tile; /* Reuse gate_tile for hidden after SwiGLU */
781 
782  /* For each chunk of intermediate dimension */
783  for (int inter_start = 0; inter_start < intermediate; inter_start += TILE_N_INTER) {
784  int tile_inter = (inter_start + TILE_N_INTER <= intermediate)
785  ? TILE_N_INTER : (intermediate - inter_start);
786 
787  const float *W_gate_tile = W_gate + (size_t)inter_start * hidden;
788  const float *W_up_tile = W_up + (size_t)inter_start * hidden;
789 
790  /* For each chunk of tokens */
791  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
792  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
793  ? PREFILL_TILE_M : (seq_len - m_start);
794 
795  const float *x_tile = x + (size_t)m_start * hidden;
796 
797  /* Compute gate and up projections for this tile */
798  gemm_tile_nt_strided(x_tile, W_gate_tile, gate_tile,
799  tile_m, tile_inter, hidden, tile_inter);
800  gemm_tile_nt_strided(x_tile, W_up_tile, up_tile,
801  tile_m, tile_inter, hidden, tile_inter);
802  if (B_gate) {
803  add_bias_tile(gate_tile, B_gate + inter_start, tile_m, tile_inter);
804  }
805  if (B_up) {
806  add_bias_tile(up_tile, B_up + inter_start, tile_m, tile_inter);
807  }
808 
809  /* Fused SwiGLU: hidden = silu(gate) * up */
810  for (int i = 0; i < tile_m; ++i) {
811  float *g = gate_tile + (size_t)i * tile_inter;
812  float *u = up_tile + (size_t)i * tile_inter;
813  for (int j = 0; j < tile_inter; ++j) {
814  float gv = g[j];
815  float silu = gv / (1.0f + expf(-gv));
816  g[j] = silu * u[j]; /* hidden_tile = gate_tile */
817  }
818  }
819 
820  /* Down projection: accumulate into output
821  * out[m_start:, :] += hidden_tile @ W_down[inter_start:, :]^T
822  */
823  const float *W_down_slice = W_down + (size_t)inter_start; /* Column slice */
824  float *out_tile = output + (size_t)m_start * hidden;
825 
826  /* This is trickier - W_down is [hidden × intermediate]
827  * We have hidden_tile[tile_m × tile_inter]
828  * We want out[tile_m × hidden] += hidden_tile × W_down[:, inter_start:inter_start+tile_inter]^T
829  *
830  * For proper accumulation, need to handle this carefully.
831  * For now, use a simpler approach: accumulate partial results.
832  */
833  for (int i = 0; i < tile_m; ++i) {
834  float *h = hidden_tile + (size_t)i * tile_inter;
835  float *o = out_tile + (size_t)i * hidden;
836 
837  for (int d = 0; d < hidden; ++d) {
838  const float *w_row = W_down + (size_t)d * intermediate + inter_start;
839  float sum = (inter_start == 0)
840  ? (B_down ? B_down[d] : 0.0f)
841  : o[d];
842 
843 #if defined(__AVX512F__)
844  __m512 acc = _mm512_setzero_ps();
845  int j = 0;
846  for (; j + 16 <= tile_inter; j += 16) {
847  __m512 hv = _mm512_loadu_ps(h + j);
848  __m512 wv = _mm512_loadu_ps(w_row + j);
849  acc = _mm512_fmadd_ps(hv, wv, acc);
850  }
851  sum += _mm512_reduce_add_ps(acc);
852  for (; j < tile_inter; ++j) {
853  sum += h[j] * w_row[j];
854  }
855 #elif defined(__AVX__)
856  __m256 acc = _mm256_setzero_ps();
857  int j = 0;
858  for (; j + 8 <= tile_inter; j += 8) {
859  __m256 hv = _mm256_loadu_ps(h + j);
860  __m256 wv = _mm256_loadu_ps(w_row + j);
861  acc = _mm256_add_ps(acc, _mm256_mul_ps(hv, wv));
862  }
863  sum += hsum256_prefill(acc);
864  for (; j < tile_inter; ++j) {
865  sum += h[j] * w_row[j];
866  }
867 #else
868  for (int j = 0; j < tile_inter; ++j) {
869  sum += h[j] * w_row[j];
870  }
871 #endif
872  o[d] = sum;
873  }
874  }
875  }
876  }
877 }
#define PREFILL_TILE_M
static void add_bias_tile(float *out, const float *bias, int tile_m, int out_dim)
static void gemm_tile_nt_strided(const float *A, const float *B_tile, float *C, int tile_m, int tile_n, int K, int C_stride)
GEMM tile with N-dimension tiling (weight reuse)
static void silu(float *x, int n)
Definition: v6_simple.c:159

References add_bias_tile(), gemm_tile_nt_strided(), and PREFILL_TILE_M. (Note: SiLU is computed inline via expf() here, not via the unrelated silu() in v6_simple.c.)

Referenced by fused_mlp_swiglu_prefill().

◆ fused_mlp_swiglu_prefill_w1w2_quant()

void fused_mlp_swiglu_prefill_w1w2_quant ( const float *  x,
const void *  W1,
const float *  B1,
CKDataType  w1_dt,
const void *  W2,
const float *  B2,
CKDataType  w2_dt,
float *  output,
int  seq_len,
int  embed_dim,
int  aligned_embed_dim,
int  intermediate_dim,
int  aligned_intermediate_dim,
void *  scratch 
)

Quantized fused MLP for prefill (W1=gate+up, W2=down)

Uses Q8_0 activations for W1 (Q5_0/Q8_0 weights) and Q8_K activations for W2 (Q4_K/Q6_K weights).

Definition at line 965 of file prefill_fused_gemm.c.

980 {
981  if (!x || !W1 || !W2 || !output || !scratch) {
982  return;
983  }
984  if (seq_len <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
985  intermediate_dim <= 0 || aligned_intermediate_dim <= 0) {
986  return;
987  }
988  if (aligned_embed_dim < embed_dim || aligned_intermediate_dim < intermediate_dim) {
989  return;
990  }
991  if ((aligned_embed_dim % 32) != 0 || (aligned_intermediate_dim % 256) != 0) {
992  return;
993  }
994  if (!mlp_q8_0_dtype_supported(w1_dt) || !mlp_q8_k_dtype_supported(w2_dt)) {
995  return;
996  }
997 
998  const int tile_m_max = PREFILL_TILE_M;
999  const int inter = aligned_intermediate_dim;
1000  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0, (size_t)aligned_embed_dim);
1001  const size_t q8k_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_K, (size_t)aligned_intermediate_dim);
1002  const size_t w1_row_bytes = ck_dtype_row_bytes(w1_dt, (size_t)aligned_embed_dim);
1003 
1004  uint8_t *scratch_bytes = (uint8_t *)scratch;
1005  size_t q8_bytes = (size_t)tile_m_max * q8_row_bytes;
1006  size_t gate_bytes = (size_t)tile_m_max * (size_t)inter * sizeof(float);
1007  size_t up_bytes = gate_bytes;
1008  size_t gate_offset = align_up_size(q8_bytes, 64);
1009  size_t up_offset = gate_offset + align_up_size(gate_bytes, 64);
1010  size_t q8k_offset = up_offset + align_up_size(up_bytes, 64);
1011 
1012  uint8_t *q8_tile = scratch_bytes;
1013  float *gate_tile = (float *)(scratch_bytes + gate_offset);
1014  float *up_tile = (float *)(scratch_bytes + up_offset);
1015  uint8_t *q8k_tile = scratch_bytes + q8k_offset;
1016 
1017  const uint8_t *w1_base = (const uint8_t *)W1;
1018  const uint8_t *w_gate = w1_base;
1019  const uint8_t *w_up = w1_base + (size_t)inter * w1_row_bytes;
1020 
1021  const float *b_gate = B1;
1022  const float *b_up = B1 ? (B1 + (size_t)inter) : NULL;
1023 
1024  for (int m_start = 0; m_start < seq_len; m_start += tile_m_max) {
1025  int tile_m = (m_start + tile_m_max <= seq_len)
1026  ? tile_m_max : (seq_len - m_start);
1027 
1028  const float *x_tile = x + (size_t)m_start * (size_t)aligned_embed_dim;
1029  float *out_tile = output + (size_t)m_start * (size_t)aligned_embed_dim;
1030 
1031  for (int t = 0; t < tile_m; ++t) {
1032  const float *row = x_tile + (size_t)t * (size_t)aligned_embed_dim;
1033  quantize_row_q8_0(row,
1034  q8_tile + (size_t)t * q8_row_bytes,
1035  aligned_embed_dim);
1036  }
1037 
1038  gemm_nt_q8_0_mlp_dispatch(q8_tile, w_gate, b_gate, gate_tile,
1039  tile_m, inter, aligned_embed_dim, w1_dt);
1040  gemm_nt_q8_0_mlp_dispatch(q8_tile, w_up, b_up, up_tile,
1041  tile_m, inter, aligned_embed_dim, w1_dt);
1042 
1043  for (int i = 0; i < tile_m; ++i) {
1044  float *g = gate_tile + (size_t)i * (size_t)inter;
1045  float *u = up_tile + (size_t)i * (size_t)inter;
1046  for (int j = 0; j < inter; ++j) {
1047  g[j] = silu_prefill(g[j]) * u[j];
1048  }
1049  }
1050 
1051  for (int i = 0; i < tile_m; ++i) {
1052  const float *row = gate_tile + (size_t)i * (size_t)inter;
1053  quantize_row_q8_k(row,
1054  q8k_tile + (size_t)i * q8k_row_bytes,
1055  aligned_intermediate_dim);
1056  }
1057 
1058  gemm_nt_q8_k_mlp_dispatch(q8k_tile, W2, B2, out_tile,
1059  tile_m, aligned_embed_dim, aligned_intermediate_dim, w2_dt);
1060  }
1061 }
@ CK_DT_Q8_0
Definition: ckernel_dtype.h:42
@ CK_DT_Q8_K
Definition: ckernel_dtype.h:43
static size_t ck_dtype_row_bytes(CKDataType dt, size_t n_elements)
Calculate total bytes for n_elements of given dtype.
void quantize_row_q8_k(const float *x, void *y, int k)
void quantize_row_q8_0(const float *x, void *y, int k)
Quantize FP32 to Q8_0 format (scalar reference)
static void gemm_nt_q8_0_mlp_dispatch(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt)
static size_t align_up_size(size_t value, size_t align)
static int mlp_q8_k_dtype_supported(CKDataType dt)
static float silu_prefill(float x)
static int mlp_q8_0_dtype_supported(CKDataType dt)
static void gemm_nt_q8_k_mlp_dispatch(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt)

References align_up_size(), CK_DT_Q8_0, CK_DT_Q8_K, ck_dtype_row_bytes(), gemm_nt_q8_0_mlp_dispatch(), gemm_nt_q8_k_mlp_dispatch(), mlp_q8_0_dtype_supported(), mlp_q8_k_dtype_supported(), PREFILL_TILE_M, quantize_row_q8_0(), quantize_row_q8_k(), and silu_prefill().

Referenced by mega_fused_outproj_mlp_prefill().

◆ fused_mlp_swiglu_prefill_w1w2_quant_scratch_size()

size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size ( int  aligned_embed_dim,
int  aligned_intermediate_dim 
)

Get scratch buffer size for fused_mlp_swiglu_prefill_w1w2_quant.

Definition at line 1063 of file prefill_fused_gemm.c.

1065 {
1066  if (aligned_embed_dim <= 0 || aligned_intermediate_dim <= 0) {
1067  return 0;
1068  }
1069  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0, (size_t)aligned_embed_dim);
1070  const size_t q8k_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_K, (size_t)aligned_intermediate_dim);
1071  const size_t q8_bytes = (size_t)PREFILL_TILE_M * q8_row_bytes;
1072  const size_t gate_bytes = (size_t)PREFILL_TILE_M * (size_t)aligned_intermediate_dim * sizeof(float);
1073  const size_t up_bytes = gate_bytes;
1074  const size_t q8k_bytes = (size_t)PREFILL_TILE_M * q8k_row_bytes;
1075 
1076  return align_up_size(q8_bytes, 64) +
1077  align_up_size(gate_bytes, 64) +
1078  align_up_size(up_bytes, 64) +
1079  align_up_size(q8k_bytes, 64);
1080 }

References align_up_size(), CK_DT_Q8_0, CK_DT_Q8_K, ck_dtype_row_bytes(), and PREFILL_TILE_M.

Referenced by mega_fused_outproj_mlp_prefill_scratch_size().

◆ fused_mlp_swiglu_scratch_size()

size_t fused_mlp_swiglu_scratch_size ( int  intermediate)

Get scratch size for fused MLP.

Get scratch buffer size for fused_mlp_swiglu_prefill.

Definition at line 899 of file prefill_fused_gemm.c.

899  {
900  const int TILE_N_INTER = 512;
901  /* gate_tile + up_tile */
902  return 2 * (size_t)PREFILL_TILE_M * TILE_N_INTER * sizeof(float);
903 }

References PREFILL_TILE_M.

◆ fused_rmsnorm_gemm_2d_tiled()

static void fused_rmsnorm_gemm_2d_tiled ( const float *  x,
const float *  gamma,
const float *  W,
float *  output,
int  seq_len,
int  hidden,
int  out_dim,
float  eps,
float *  x_norm_scratch 
)
static

Fused RMSNorm + single GEMM with 2D tiling (weight reuse)

Tiles along N (weights) OUTER, M (tokens) INNER. Weight tiles are reused across all token tiles.

Definition at line 332 of file prefill_fused_gemm.c.

342 {
343  /* Outer loop: tile along output dimension (N) - weight tiles */
344  for (int n_start = 0; n_start < out_dim; n_start += PREFILL_TILE_N) {
345  int tile_n = (n_start + PREFILL_TILE_N <= out_dim)
346  ? PREFILL_TILE_N
347  : (out_dim - n_start);
348 
349  /* Weight tile pointer - this tile stays in L3 cache */
350  const float *W_tile = W + (size_t)n_start * hidden;
351 
352  /* Inner loop: tile along token dimension (M) */
353  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
354  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
355  ? PREFILL_TILE_M
356  : (seq_len - m_start);
357 
358  const float *x_tile = x + (size_t)m_start * hidden;
359  float *out_tile = output + (size_t)m_start * out_dim + n_start;
360 
361  /* Compute RMSNorm for this token tile (only on first weight tile) */
362  if (n_start == 0) {
363  rmsnorm_tile(x_tile, gamma, x_norm_scratch, tile_m, hidden, hidden, eps);
364  } else {
365  /* Recompute x_norm for this tile (we can't cache all of it) */
366  /* TODO: For very large N, consider caching x_norm chunks */
367  rmsnorm_tile(x_tile, gamma, x_norm_scratch, tile_m, hidden, hidden, eps);
368  }
369 
370  /* GEMM: x_norm_tile × W_tile^T → output tile */
371  gemm_tile_nt_strided(x_norm_scratch, W_tile, out_tile,
372  tile_m, tile_n, hidden, out_dim);
373  }
374  }
375 }
#define PREFILL_TILE_N
static void rmsnorm_tile(const float *input, const float *gamma, float *output, int tile_m, int embed_dim, int aligned_embed_dim, float eps)
Compute RMSNorm for a tile of tokens.

References gemm_tile_nt_strided(), PREFILL_TILE_M, PREFILL_TILE_N, and rmsnorm_tile().

◆ fused_rmsnorm_qkv_prefill()

void fused_rmsnorm_qkv_prefill ( const float *  x,
const float *  gamma,
const float *  Wq,
const float *  Wk,
const float *  Wv,
float *  Q,
float *  K,
float *  V,
int  seq_len,
int  hidden,
int  q_dim,
int  kv_dim,
float  eps,
float *  scratch 
)

Fused RMSNorm + QKV projection for prefill (v3 optimized)

Fused RMSNorm + QKV projection for prefill.

KEY INSIGHT: For Qwen2-0.5B, all QKV weights fit in L3: Wq (896×896) + Wk (128×896) + Wv (128×896) = 4.1MB < 6MB L3

So we use M-tiling (tokens) only:

  1. For each token tile:
     a. Compute RMSNorm ONCE into scratch (x_norm stays in L2)
     b. Do all three GEMMs (Q, K, V) against cached x_norm
     c. Weights stay hot in L3 across all token tiles

This avoids both:

  • Large x_norm intermediate buffer (only TILE_M × hidden in L2)
  • RMSNorm recomputation (done once per token tile, used 3×)

Definition at line 393 of file prefill_fused_gemm.c.

408 {
409  /* scratch is x_norm tile: [TILE_M × hidden] fits in L2 */
410 
411  /* Process token tiles - weights stay in L3 across all tiles */
412  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
413  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
414  ? PREFILL_TILE_M : (seq_len - m_start);
415 
416  const float *x_tile = x + (size_t)m_start * hidden;
417 
418  /* Step 1: RMSNorm for this token tile (computed ONCE, used 3×) */
419  rmsnorm_tile(x_tile, gamma, scratch, tile_m, hidden, hidden, eps);
420 
421  /* Step 2: Q projection - x_norm is hot in L2, Wq hot in L3 */
422  float *Q_tile = Q + (size_t)m_start * q_dim;
423  gemm_tile_nt_strided(scratch, Wq, Q_tile, tile_m, q_dim, hidden, q_dim);
424 
425  /* Step 3: K projection - x_norm still hot, Wk displaces some Wq */
426  float *K_tile = K + (size_t)m_start * kv_dim;
427  gemm_tile_nt_strided(scratch, Wk, K_tile, tile_m, kv_dim, hidden, kv_dim);
428 
429  /* Step 4: V projection - x_norm still hot, Wv displaces Wk */
430  float *V_tile = V + (size_t)m_start * kv_dim;
431  gemm_tile_nt_strided(scratch, Wv, V_tile, tile_m, kv_dim, hidden, kv_dim);
432  }
433 }

References gemm_tile_nt_strided(), PREFILL_TILE_M, and rmsnorm_tile().

◆ fused_rmsnorm_qkv_prefill_head_major()

void fused_rmsnorm_qkv_prefill_head_major ( const float *  x,
const float *  gamma,
const float *  Wq,
const float *  Bq,
const float *  Wk,
const float *  Bk,
const float *  Wv,
const float *  Bv,
float *  Q,
float *  K,
float *  V,
int  seq_len,
int  embed_dim,
int  aligned_embed_dim,
int  num_heads,
int  num_kv_heads,
int  head_dim,
int  aligned_head_dim,
int  kv_stride_tokens,
float  eps,
float *  scratch 
)

Fused RMSNorm + QKV projection for prefill (head-major outputs)

Q is written as [num_heads, seq_len, aligned_head_dim]. K/V are written with kv_stride_tokens for KV-cache compatibility.

Definition at line 441 of file prefill_fused_gemm.c.

460 {
461  if (!x || !gamma || !Wq || !Wk || !Wv || !Q || !K || !V || !scratch) {
462  return;
463  }
464  if (seq_len <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
465  head_dim <= 0 || aligned_head_dim <= 0 ||
466  num_heads <= 0 || num_kv_heads <= 0) {
467  return;
468  }
469  if (kv_stride_tokens < seq_len) {
470  return;
471  }
472 
473  const size_t q_head_stride = (size_t)seq_len * (size_t)aligned_head_dim;
474  const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
475  const size_t head_w_stride = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
476 
477  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
478  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
479  ? PREFILL_TILE_M : (seq_len - m_start);
480 
481  const float *x_tile = x + (size_t)m_start * (size_t)aligned_embed_dim;
482  rmsnorm_tile(x_tile, gamma, scratch, tile_m, embed_dim, aligned_embed_dim, eps);
483 
484  for (int h = 0; h < num_heads; ++h) {
485  const float *wq_h = Wq + (size_t)h * head_w_stride;
486  const float *bq_h = Bq ? (Bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
487  float *q_h = Q + (size_t)h * q_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
488 
489  gemm_tile_nt_strided(scratch, wq_h, q_h,
490  tile_m, aligned_head_dim, aligned_embed_dim, aligned_head_dim);
491  add_bias_tile(q_h, bq_h, tile_m, aligned_head_dim);
492  }
493 
494  for (int h = 0; h < num_kv_heads; ++h) {
495  const float *wk_h = Wk + (size_t)h * head_w_stride;
496  const float *wv_h = Wv + (size_t)h * head_w_stride;
497  const float *bk_h = Bk ? (Bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
498  const float *bv_h = Bv ? (Bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
499  float *k_h = K + (size_t)h * kv_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
500  float *v_h = V + (size_t)h * kv_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
501 
502  gemm_tile_nt_strided(scratch, wk_h, k_h,
503  tile_m, aligned_head_dim, aligned_embed_dim, aligned_head_dim);
504  add_bias_tile(k_h, bk_h, tile_m, aligned_head_dim);
505 
506  gemm_tile_nt_strided(scratch, wv_h, v_h,
507  tile_m, aligned_head_dim, aligned_embed_dim, aligned_head_dim);
508  add_bias_tile(v_h, bv_h, tile_m, aligned_head_dim);
509  }
510  }
511 }

References add_bias_tile(), gemm_tile_nt_strided(), PREFILL_TILE_M, and rmsnorm_tile().

Referenced by mega_fused_attention_prefill(), and mega_fused_attention_prefill_q8_0().

◆ fused_rmsnorm_qkv_prefill_head_major_quant()

void fused_rmsnorm_qkv_prefill_head_major_quant ( const float *  x,
const float *  gamma,
const void *  Wq,
const float *  Bq,
CKDataType  wq_dt,
const void *  Wk,
const float *  Bk,
CKDataType  wk_dt,
const void *  Wv,
const float *  Bv,
CKDataType  wv_dt,
float *  Q,
float *  K,
float *  V,
int  seq_len,
int  embed_dim,
int  aligned_embed_dim,
int  num_heads,
int  num_kv_heads,
int  head_dim,
int  aligned_head_dim,
int  kv_stride_tokens,
float  eps,
void *  scratch 
)

Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations)

Supports Q5_0 or Q8_0 weights with Q8_0 activations. Writes K/V directly into KV cache layout (kv_stride_tokens).

Definition at line 519 of file prefill_fused_gemm.c.

538 {
539  if (!x || !gamma || !Wq || !Wk || !Wv || !Q || !K || !V || !scratch) {
540  return;
541  }
542  if (seq_len <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
543  head_dim <= 0 || aligned_head_dim <= 0 ||
544  num_heads <= 0 || num_kv_heads <= 0) {
545  return;
546  }
547  if (aligned_embed_dim % 32 != 0) {
548  return;
549  }
550  if (kv_stride_tokens < seq_len) {
551  return;
552  }
553  /* Determine quantization path: Q8_0 activations for Q5_0/Q8_0 weights,
554  * Q8_K activations for Q4_K/Q6_K weights. All QKV weights must use
555  * the same quantization family. */
556  int use_q8_k_path = qkv_q8_k_dtype_supported(wq_dt);
557  int use_q8_0_path = qkv_q8_0_dtype_supported(wq_dt);
558 
559  if (!use_q8_k_path && !use_q8_0_path) {
560  /* Unsupported dtype for wq */
561  return;
562  }
563 
564  /* Verify all dtypes are from the same family */
565  if (use_q8_k_path) {
566  if (!qkv_q8_k_dtype_supported(wk_dt) || !qkv_q8_k_dtype_supported(wv_dt)) {
567  return; /* Mixed Q8_K and Q8_0 paths not supported */
568  }
569  } else {
570  if (!qkv_q8_0_dtype_supported(wk_dt) || !qkv_q8_0_dtype_supported(wv_dt)) {
571  return;
572  }
573  }
574 
575  const size_t float_bytes = (size_t)PREFILL_TILE_M * (size_t)aligned_embed_dim * sizeof(float);
576  /* Q8_K has larger blocks (256) than Q8_0 (32), so use appropriate size */
577  const CKDataType act_quant_type = use_q8_k_path ? CK_DT_Q8_K : CK_DT_Q8_0;
578  const size_t q8_row_bytes = ck_dtype_row_bytes(act_quant_type, (size_t)aligned_embed_dim);
579  const size_t q8_bytes = (size_t)PREFILL_TILE_M * q8_row_bytes;
580  const size_t q8_offset = align_up_size(float_bytes, 64);
581 
582  float *normed = (float *)scratch;
583  uint8_t *q8_tile = (uint8_t *)scratch + q8_offset;
584  (void)q8_bytes;
585 
586  const size_t q_head_stride = (size_t)seq_len * (size_t)aligned_head_dim;
587  const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
588  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
589  const size_t wq_head_bytes = ck_dtype_row_bytes(wq_dt, head_w_elems);
590  const size_t wk_head_bytes = ck_dtype_row_bytes(wk_dt, head_w_elems);
591  const size_t wv_head_bytes = ck_dtype_row_bytes(wv_dt, head_w_elems);
592 
593  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
594  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
595  ? PREFILL_TILE_M : (seq_len - m_start);
596 
597  const float *x_tile = x + (size_t)m_start * (size_t)aligned_embed_dim;
598  rmsnorm_tile(x_tile, gamma, normed, tile_m, embed_dim, aligned_embed_dim, eps);
599 
600  /* Quantize activations to appropriate format */
601  for (int t = 0; t < tile_m; ++t) {
602  const float *row = normed + (size_t)t * (size_t)aligned_embed_dim;
603  if (use_q8_k_path) {
604  quantize_row_q8_k(row,
605  q8_tile + (size_t)t * q8_row_bytes,
606  aligned_embed_dim);
607  } else {
608  quantize_row_q8_0(row,
609  q8_tile + (size_t)t * q8_row_bytes,
610  aligned_embed_dim);
611  }
612  }
613 
614  for (int h = 0; h < num_heads; ++h) {
615  const uint8_t *wq_h = (const uint8_t *)Wq + (size_t)h * wq_head_bytes;
616  const float *bq_h = Bq ? (Bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
617  float *q_h = Q + (size_t)h * q_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
618 
619  if (use_q8_k_path) {
620  gemm_nt_q8_k_qkv_dispatch(q8_tile, wq_h, bq_h, q_h,
621  tile_m, aligned_head_dim, aligned_embed_dim, wq_dt);
622  } else {
623  gemm_nt_q8_0_dispatch(q8_tile, wq_h, bq_h, q_h,
624  tile_m, aligned_head_dim, aligned_embed_dim, wq_dt);
625  }
626  }
627 
628  for (int h = 0; h < num_kv_heads; ++h) {
629  const uint8_t *wk_h = (const uint8_t *)Wk + (size_t)h * wk_head_bytes;
630  const uint8_t *wv_h = (const uint8_t *)Wv + (size_t)h * wv_head_bytes;
631  const float *bk_h = Bk ? (Bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
632  const float *bv_h = Bv ? (Bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
633  float *k_h = K + (size_t)h * kv_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
634  float *v_h = V + (size_t)h * kv_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
635 
636  if (use_q8_k_path) {
637  gemm_nt_q8_k_qkv_dispatch(q8_tile, wk_h, bk_h, k_h,
638  tile_m, aligned_head_dim, aligned_embed_dim, wk_dt);
639  gemm_nt_q8_k_qkv_dispatch(q8_tile, wv_h, bv_h, v_h,
640  tile_m, aligned_head_dim, aligned_embed_dim, wv_dt);
641  } else {
642  gemm_nt_q8_0_dispatch(q8_tile, wk_h, bk_h, k_h,
643  tile_m, aligned_head_dim, aligned_embed_dim, wk_dt);
644  gemm_nt_q8_0_dispatch(q8_tile, wv_h, bv_h, v_h,
645  tile_m, aligned_head_dim, aligned_embed_dim, wv_dt);
646  }
647  }
648  }
649 }
CKDataType
Supported data types in C-Kernel-Engine.
Definition: ckernel_dtype.h:27
static void gemm_nt_q8_0_dispatch(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt)
static int qkv_q8_k_dtype_supported(CKDataType dt)
static void gemm_nt_q8_k_qkv_dispatch(const void *A_q8k, const void *B, const float *bias, float *C, int M, int N, int K, CKDataType dt)
static int qkv_q8_0_dtype_supported(CKDataType dt)

References align_up_size(), CK_DT_Q8_0, CK_DT_Q8_K, ck_dtype_row_bytes(), gemm_nt_q8_0_dispatch(), gemm_nt_q8_k_qkv_dispatch(), PREFILL_TILE_M, qkv_q8_0_dtype_supported(), qkv_q8_k_dtype_supported(), quantize_row_q8_0(), quantize_row_q8_k(), and rmsnorm_tile().

Referenced by mega_fused_attention_prefill(), and mega_fused_attention_prefill_q8_0().

◆ fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size()

size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size ( int  aligned_embed_dim)

Get scratch buffer size for fused_rmsnorm_qkv_prefill_head_major_quant.

Definition at line 651 of file prefill_fused_gemm.c.

651  {
652  if (aligned_embed_dim <= 0) {
653  return 0;
654  }
655  const size_t float_bytes = (size_t)PREFILL_TILE_M * (size_t)aligned_embed_dim * sizeof(float);
656  /* Use max of Q8_0 and Q8_K sizes to support both paths */
657  const size_t q8_0_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0, (size_t)aligned_embed_dim);
658  const size_t q8_k_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_K, (size_t)aligned_embed_dim);
659  const size_t q8_row_bytes = (q8_k_row_bytes > q8_0_row_bytes) ? q8_k_row_bytes : q8_0_row_bytes;
660  const size_t q8_bytes = (size_t)PREFILL_TILE_M * q8_row_bytes;
661  return align_up_size(float_bytes, 64) + q8_bytes;
662 }

References align_up_size(), CK_DT_Q8_0, CK_DT_Q8_K, ck_dtype_row_bytes(), and PREFILL_TILE_M.

Referenced by mega_fused_attention_prefill(), mega_fused_attention_prefill_q8_0(), mega_fused_attention_prefill_q8_0_scratch_size(), and mega_fused_attention_prefill_scratch_size().

◆ fused_rmsnorm_qkv_scratch_size()

size_t fused_rmsnorm_qkv_scratch_size ( int  hidden)

Get scratch size for fused prefill.

Get scratch buffer size for fused_rmsnorm_qkv_prefill.

Definition at line 739 of file prefill_fused_gemm.c.

739  {
740  return (size_t)PREFILL_TILE_M * hidden * sizeof(float);
741 }

References PREFILL_TILE_M.

◆ gemm_nt_q8_0_dispatch()

static void gemm_nt_q8_0_dispatch ( const void *  A_q8,
const void *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K,
CKDataType  dt 
)
static

Definition at line 177 of file prefill_fused_gemm.c.

185 {
186  switch (dt) {
187  case CK_DT_Q5_0:
188  gemm_nt_q5_0_q8_0(A_q8, B, bias, C, M, N, K);
189  break;
190  case CK_DT_Q8_0:
191  gemm_nt_q8_0_q8_0(A_q8, B, bias, C, M, N, K);
192  break;
193  default:
194  break;
195  }
196 }
@ CK_DT_Q5_0
Definition: ckernel_dtype.h:44
void gemm_nt_q8_0_q8_0(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
gemm_nt_q8_0_q8_0 with optional bias (matches header signature)
void gemm_nt_q5_0_q8_0(const void *A_q8, const void *B_q5, const float *bias, float *C, int M, int N, int K)
Batch GEMM with Q5_0 weights and Q8_0 activations for prefill.
#define C(color)
Definition: show_config.c:39

References C, CK_DT_Q5_0, CK_DT_Q8_0, gemm_nt_q5_0_q8_0(), and gemm_nt_q8_0_q8_0().

Referenced by fused_rmsnorm_qkv_prefill_head_major_quant().

◆ gemm_nt_q8_0_mlp_dispatch()

static void gemm_nt_q8_0_mlp_dispatch ( const void *  A_q8,
const void *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K,
CKDataType  dt 
)
static

Definition at line 917 of file prefill_fused_gemm.c.

925 {
926  switch (dt) {
927  case CK_DT_Q5_0:
928  gemm_nt_q5_0_q8_0(A_q8, B, bias, C, M, N, K);
929  break;
930  case CK_DT_Q8_0:
931  gemm_nt_q8_0_q8_0(A_q8, B, bias, C, M, N, K);
932  break;
933  default:
934  break;
935  }
936 }

References C, CK_DT_Q5_0, CK_DT_Q8_0, gemm_nt_q5_0_q8_0(), and gemm_nt_q8_0_q8_0().

Referenced by fused_mlp_swiglu_prefill_w1w2_quant().

◆ gemm_nt_q8_k_mlp_dispatch()

static void gemm_nt_q8_k_mlp_dispatch ( const void *  A_q8,
const void *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K,
CKDataType  dt 
)
static

Definition at line 938 of file prefill_fused_gemm.c.

946 {
947  switch (dt) {
948  case CK_DT_Q4_K:
949  gemm_nt_q4_k_q8_k(A_q8, B, bias, C, M, N, K);
950  break;
951  case CK_DT_Q6_K:
952  gemm_nt_q6_k_q8_k(A_q8, B, bias, C, M, N, K);
953  break;
954  default:
955  break;
956  }
957 }
@ CK_DT_Q4_K
Definition: ckernel_dtype.h:40
@ CK_DT_Q6_K
Definition: ckernel_dtype.h:41
void gemm_nt_q4_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
void gemm_nt_q6_k_q8_k(const void *A_q8, const void *B, const float *bias, float *C, int M, int N, int K)
NT GEMM: C = A @ B^T where A is Q8_K and B is Q6_K.

References C, CK_DT_Q4_K, CK_DT_Q6_K, gemm_nt_q4_k_q8_k(), and gemm_nt_q6_k_q8_k().

Referenced by fused_mlp_swiglu_prefill_w1w2_quant().

◆ gemm_nt_q8_k_qkv_dispatch()

static void gemm_nt_q8_k_qkv_dispatch ( const void *  A_q8k,
const void *  B,
const float *  bias,
float *  C,
int  M,
int  N,
int  K,
CKDataType  dt 
)
static

Definition at line 198 of file prefill_fused_gemm.c.

206 {
207  switch (dt) {
208  case CK_DT_Q4_K:
209  gemm_nt_q4_k_q8_k(A_q8k, B, bias, C, M, N, K);
210  break;
211  case CK_DT_Q6_K:
212  gemm_nt_q6_k_q8_k(A_q8k, B, bias, C, M, N, K);
213  break;
214  default:
215  break;
216  }
217 }

References C, CK_DT_Q4_K, CK_DT_Q6_K, gemm_nt_q4_k_q8_k(), and gemm_nt_q6_k_q8_k().

Referenced by fused_rmsnorm_qkv_prefill_head_major_quant().

◆ gemm_tile_nt_strided()

static void gemm_tile_nt_strided ( const float *  A,
const float *  B_tile,
float *  C,
int  tile_m,
int  tile_n,
int  K,
int  C_stride 
)
static

GEMM tile with N-dimension tiling (weight reuse)

Computes: C[tile_m × tile_n] = A[tile_m × K] × B[tile_n × K]^T where B_tile is a slice of rows from the weight matrix.

Uses MKL if available for optimal performance.

Parameters
A	Input tile [tile_m × K]
B_tile	Weight tile [tile_n × K] (transposed layout)
C	Output tile [tile_m × tile_n] (column slice of full output)
C_stride	Stride between rows of C (= full N dimension)

Definition at line 236 of file prefill_fused_gemm.c.

#ifndef USE_MKL
/* Dot product of two length-K float rows; vectorized when the ISA allows. */
static inline float prefill_row_dot(const float *a, const float *b, int K)
{
#if defined(__AVX512F__)
    __m512 acc = _mm512_setzero_ps();
    int k = 0;
    for (; k + 16 <= K; k += 16) {
        acc = _mm512_fmadd_ps(_mm512_loadu_ps(a + k), _mm512_loadu_ps(b + k), acc);
    }
    float sum = _mm512_reduce_add_ps(acc);
    for (; k < K; ++k) {
        sum += a[k] * b[k];
    }
    return sum;
#elif defined(__AVX__)
    __m256 acc = _mm256_setzero_ps();
    int k = 0;
    for (; k + 8 <= K; k += 8) {
        acc = _mm256_add_ps(acc, _mm256_mul_ps(_mm256_loadu_ps(a + k),
                                               _mm256_loadu_ps(b + k)));
    }
    float sum = hsum256_prefill(acc);
    for (; k < K; ++k) {
        sum += a[k] * b[k];
    }
    return sum;
#else
    float sum = 0.0f;
    for (int k = 0; k < K; ++k) {
        sum += a[k] * b[k];
    }
    return sum;
#endif
}
#endif /* !USE_MKL */

/**
 * @brief GEMM tile with N-dimension tiling (weight reuse).
 *
 * Computes C[tile_m x tile_n] = A[tile_m x K] * B_tile[tile_n x K]^T, where
 * B_tile is a row slice of the full weight matrix and C is written with a row
 * stride of C_stride (the full output N dimension).
 *
 * @param A        Input tile [tile_m x K]
 * @param B_tile   Weight tile [tile_n x K] (transposed layout)
 * @param C        Output tile [tile_m x tile_n], a column slice of the output
 * @param C_stride Stride between rows of C (= full N dimension)
 */
static void gemm_tile_nt_strided(const float *A, const float *B_tile, float *C,
                                 int tile_m, int tile_n, int K, int C_stride)
{
#ifdef USE_MKL
    if (C_stride == tile_n) {
        /* Contiguous output: a single SGEMM covers the whole tile. */
        cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                    tile_m, tile_n, K,
                    1.0f, A, K, B_tile, K,
                    0.0f, C, tile_n);
        return;
    }
    /* Strided output: MKL SGEMM needs contiguous C, so do one SGEMV per row. */
    for (int i = 0; i < tile_m; ++i) {
        cblas_sgemv(CblasRowMajor, CblasNoTrans,
                    tile_n, K,
                    1.0f, B_tile, K, A + (size_t)i * K, 1,
                    0.0f, C + (size_t)i * C_stride, 1);
    }
#else
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
    for (int i = 0; i < tile_m; ++i) {
        const float *a_row = A + (size_t)i * K;
        float *c_row = C + (size_t)i * C_stride;
        for (int j = 0; j < tile_n; ++j) {
            c_row[j] = prefill_row_dot(a_row, B_tile + (size_t)j * K, K);
        }
    }
#endif
}

References C.

Referenced by fused_mlp_swiglu_prefill_bias(), fused_rmsnorm_gemm_2d_tiled(), fused_rmsnorm_qkv_prefill(), fused_rmsnorm_qkv_prefill_head_major(), and unfused_rmsnorm_qkv_prefill().

◆ mlp_q8_0_dtype_supported()

static int mlp_q8_0_dtype_supported ( CKDataType  dt)
static

Definition at line 909 of file prefill_fused_gemm.c.

909  {
910  return (dt == CK_DT_Q5_0 || dt == CK_DT_Q8_0);
911 }

References CK_DT_Q5_0, and CK_DT_Q8_0.

Referenced by fused_mlp_swiglu_prefill_w1w2_quant().

◆ mlp_q8_k_dtype_supported()

static int mlp_q8_k_dtype_supported ( CKDataType  dt)
static

Definition at line 913 of file prefill_fused_gemm.c.

913  {
914  return (dt == CK_DT_Q4_K || dt == CK_DT_Q6_K);
915 }

References CK_DT_Q4_K, and CK_DT_Q6_K.

Referenced by fused_mlp_swiglu_prefill_w1w2_quant().

◆ qkv_q8_0_dtype_supported()

static int qkv_q8_0_dtype_supported ( CKDataType  dt)
static

Definition at line 169 of file prefill_fused_gemm.c.

169  {
170  return (dt == CK_DT_Q5_0 || dt == CK_DT_Q8_0);
171 }

References CK_DT_Q5_0, and CK_DT_Q8_0.

Referenced by fused_rmsnorm_qkv_prefill_head_major_quant().

◆ qkv_q8_k_dtype_supported()

static int qkv_q8_k_dtype_supported ( CKDataType  dt)
static

Definition at line 173 of file prefill_fused_gemm.c.

173  {
174  return (dt == CK_DT_Q4_K || dt == CK_DT_Q6_K);
175 }

References CK_DT_Q4_K, and CK_DT_Q6_K.

Referenced by fused_rmsnorm_qkv_prefill_head_major_quant().

◆ rmsnorm_tile()

static void rmsnorm_tile ( const float *  input,
const float *  gamma,
float *  output,
int  tile_m,
int  embed_dim,
int  aligned_embed_dim,
float  eps 
)
static

Compute RMSNorm for a tile of tokens.

Definition at line 86 of file prefill_fused_gemm.c.

93 {
94  for (int t = 0; t < tile_m; ++t) {
95  const float *x = input + (size_t)t * (size_t)aligned_embed_dim;
96  float *y = output + (size_t)t * (size_t)aligned_embed_dim;
97 
98 #if defined(__AVX512F__)
99  __m512 sum_sq_vec = _mm512_setzero_ps();
100  int d = 0;
101  for (; d + 16 <= embed_dim; d += 16) {
102  __m512 xv = _mm512_loadu_ps(&x[d]);
103  sum_sq_vec = _mm512_fmadd_ps(xv, xv, sum_sq_vec);
104  }
105  float sum_sq = _mm512_reduce_add_ps(sum_sq_vec);
106  for (; d < embed_dim; ++d) {
107  sum_sq += x[d] * x[d];
108  }
109 
110  float rstd = 1.0f / sqrtf(sum_sq / (float)embed_dim + eps);
111  __m512 rstd_vec = _mm512_set1_ps(rstd);
112 
113  d = 0;
114  for (; d + 16 <= embed_dim; d += 16) {
115  __m512 xv = _mm512_loadu_ps(&x[d]);
116  __m512 gv = gamma ? _mm512_loadu_ps(&gamma[d]) : _mm512_set1_ps(1.0f);
117  __m512 yv = _mm512_mul_ps(_mm512_mul_ps(xv, rstd_vec), gv);
118  _mm512_storeu_ps(&y[d], yv);
119  }
120  for (; d < embed_dim; ++d) {
121  float g = gamma ? gamma[d] : 1.0f;
122  y[d] = x[d] * rstd * g;
123  }
124 
125 #elif defined(__AVX__)
126  __m256 sum_sq_vec = _mm256_setzero_ps();
127  int d = 0;
128  for (; d + 8 <= embed_dim; d += 8) {
129  __m256 xv = _mm256_loadu_ps(&x[d]);
130  sum_sq_vec = _mm256_add_ps(sum_sq_vec, _mm256_mul_ps(xv, xv));
131  }
132  float sum_sq = hsum256_prefill(sum_sq_vec);
133  for (; d < embed_dim; ++d) {
134  sum_sq += x[d] * x[d];
135  }
136 
137  float rstd = 1.0f / sqrtf(sum_sq / (float)embed_dim + eps);
138  __m256 rstd_vec = _mm256_set1_ps(rstd);
139 
140  d = 0;
141  for (; d + 8 <= embed_dim; d += 8) {
142  __m256 xv = _mm256_loadu_ps(&x[d]);
143  __m256 gv = gamma ? _mm256_loadu_ps(&gamma[d]) : _mm256_set1_ps(1.0f);
144  __m256 yv = _mm256_mul_ps(_mm256_mul_ps(xv, rstd_vec), gv);
145  _mm256_storeu_ps(&y[d], yv);
146  }
147  for (; d < embed_dim; ++d) {
148  float g = gamma ? gamma[d] : 1.0f;
149  y[d] = x[d] * rstd * g;
150  }
151 #else
152  float sum_sq = 0.0f;
153  for (int d = 0; d < embed_dim; ++d) {
154  sum_sq += x[d] * x[d];
155  }
156  float rstd = 1.0f / sqrtf(sum_sq / (float)embed_dim + eps);
157  for (int d = 0; d < embed_dim; ++d) {
158  float g = gamma ? gamma[d] : 1.0f;
159  y[d] = x[d] * rstd * g;
160  }
161 #endif
162 
163  for (int d = embed_dim; d < aligned_embed_dim; ++d) {
164  y[d] = 0.0f;
165  }
166  }
167 }

Referenced by fused_rmsnorm_gemm_2d_tiled(), fused_rmsnorm_qkv_prefill(), fused_rmsnorm_qkv_prefill_head_major(), fused_rmsnorm_qkv_prefill_head_major_quant(), and unfused_rmsnorm_qkv_prefill().

◆ silu_prefill()

static float silu_prefill ( float  x)
inlinestatic

Definition at line 905 of file prefill_fused_gemm.c.

905  {
906  return x / (1.0f + expf(-x));
907 }

Referenced by fused_mlp_swiglu_prefill_w1w2_quant().

◆ unfused_rmsnorm_qkv_prefill()

void unfused_rmsnorm_qkv_prefill ( const float *  x,
const float *  gamma,
const float *  Wq,
const float *  Wk,
const float *  Wv,
float *  x_norm,
float *  Q,
float *  K,
float *  V,
int  seq_len,
int  hidden,
int  q_dim,
int  kv_dim,
float  eps 
)

Unfused version for comparison.

Unfused version for benchmarking comparison.

Definition at line 667 of file prefill_fused_gemm.c.

682 {
683  /* Step 1: Full RMSNorm → writes x_norm to memory */
684  rmsnorm_tile(x, gamma, x_norm, seq_len, hidden, hidden, eps);
685 
686  /* Step 2: Separate GEMMs with N-outer tiling for weight reuse */
687  /* Q projection */
688  for (int n_start = 0; n_start < q_dim; n_start += PREFILL_TILE_N) {
689  int tile_n = (n_start + PREFILL_TILE_N <= q_dim)
690  ? PREFILL_TILE_N : (q_dim - n_start);
691  const float *W_tile = Wq + (size_t)n_start * hidden;
692 
693  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
694  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
695  ? PREFILL_TILE_M : (seq_len - m_start);
696  const float *x_tile = x_norm + (size_t)m_start * hidden;
697  float *out_tile = Q + (size_t)m_start * q_dim + n_start;
698  gemm_tile_nt_strided(x_tile, W_tile, out_tile,
699  tile_m, tile_n, hidden, q_dim);
700  }
701  }
702 
703  /* K projection */
704  for (int n_start = 0; n_start < kv_dim; n_start += PREFILL_TILE_N) {
705  int tile_n = (n_start + PREFILL_TILE_N <= kv_dim)
706  ? PREFILL_TILE_N : (kv_dim - n_start);
707  const float *W_tile = Wk + (size_t)n_start * hidden;
708 
709  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
710  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
711  ? PREFILL_TILE_M : (seq_len - m_start);
712  const float *x_tile = x_norm + (size_t)m_start * hidden;
713  float *out_tile = K + (size_t)m_start * kv_dim + n_start;
714  gemm_tile_nt_strided(x_tile, W_tile, out_tile,
715  tile_m, tile_n, hidden, kv_dim);
716  }
717  }
718 
719  /* V projection */
720  for (int n_start = 0; n_start < kv_dim; n_start += PREFILL_TILE_N) {
721  int tile_n = (n_start + PREFILL_TILE_N <= kv_dim)
722  ? PREFILL_TILE_N : (kv_dim - n_start);
723  const float *W_tile = Wv + (size_t)n_start * hidden;
724 
725  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
726  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
727  ? PREFILL_TILE_M : (seq_len - m_start);
728  const float *x_tile = x_norm + (size_t)m_start * hidden;
729  float *out_tile = V + (size_t)m_start * kv_dim + n_start;
730  gemm_tile_nt_strided(x_tile, W_tile, out_tile,
731  tile_m, tile_n, hidden, kv_dim);
732  }
733  }
734 }

References gemm_tile_nt_strided(), PREFILL_TILE_M, PREFILL_TILE_N, and rmsnorm_tile().