prefill_fused_gemm.c
1 /**
2  * @file prefill_fused_gemm.c
3  * @brief Fused kernels for prefill phase with proper 2D tiling
4  *
5  * CK-ENGINE KERNEL RULES:
6  * =======================
7  * 1. NO malloc/free - memory via bump allocator, pointers passed in
8  * 2. NO OpenMP - parallelization at orchestrator/codegen layer
9  * 3. API must define: inputs, outputs, workspace, and memory layouts
10  * 4. Pure computation - deterministic, no side effects
11  *
12  * After changes: make test && make llamacpp-parity-full
13  *
14  * KEY INSIGHT:
15  * ------------
16  * Naive M-dimension tiling (token tiles) causes weight reloading:
17  * - 32 token tiles × 4MB weights = 128MB DRAM reads!
18  *
19  * Correct approach: Tile along N (output/weight) dimension OUTER,
20  * M (token) dimension INNER. This way:
21  * - Load weight tile once
22  * - Process ALL tokens against that weight tile
23  * - Weight tile stays in cache while streaming through tokens
24  *
25  * TILING STRATEGY:
26  * ----------------
27  * For C[M,N] = RMSNorm(A[M,K]) × B[N,K]^T:
28  *
29  * for n_tile in [0, N, TILE_N]: # Outer: weight tiles
30  * load B[n_tile:n_tile+TILE_N, :] into L3
31  * for m_tile in [0, M, TILE_M]: # Inner: token tiles
32  * x_norm = rmsnorm(A[m_tile]) # x_norm in L2
33  * C[m_tile, n_tile] = x_norm × B_tile # Consumes B from L3
34  *
35  * Cache behavior:
36  * - Weight tile (TILE_N × K × 4 bytes) fits in L3
37  * - x_norm tile (TILE_M × K × 4 bytes) fits in L2
38  * - Weights loaded once per tile, reused across all token tiles
39  */
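/*
 * Rough traffic estimate (a sketch using the numbers above, assuming a
 * 2048-token prefill split into 32 tiles of PREFILL_TILE_M = 64 and a
 * 4MB fp32 weight matrix):
 *
 *   M-outer (naive): every token tile re-streams the full weight matrix
 *                    32 tiles x 4MB  = 128MB of weight reads from DRAM
 *   N-outer (here):  each weight tile is loaded once and reused across
 *                    all token tiles =   4MB of weight reads from DRAM
 *
 * Activations (seq_len x hidden x 4 bytes) are small by comparison and are
 * re-read once per weight tile.
 */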
40 
41 #include "ckernel_engine.h"
42 #include "ckernel_quant.h"
43 #include <math.h>
44 #include <string.h>
45 #include <stddef.h>
46 #include <stdio.h>
47 
48 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
49 #include <immintrin.h>
50 #endif
51 
52 #ifdef _OPENMP
53 #include <omp.h>
54 #endif
55 
56 /* Tile sizes chosen for the target cache hierarchy:
57  * - L2 = 256KB: x_norm tile = TILE_M × hidden × 4
58  * - L3 = 6MB: weight tile = TILE_N × hidden × 4
59  *
60  * For hidden=896:
61  * TILE_M = 64 → 64×896×4 = 224KB (fits L2)
62  * TILE_N = 256 → 256×896×4 = 896KB (fits L3 with room for x_norm)
63  */
64 #define PREFILL_TILE_M 64
65 #define PREFILL_TILE_N 256
66 
67 static size_t align_up_size(size_t value, size_t align) {
68  return (value + align - 1) & ~(align - 1);
69 }
70 
71 /* Helper: horizontal sum for AVX */
72 #if defined(__AVX__) && !defined(__AVX512F__)
73 static inline float hsum256_prefill(__m256 v) {
74  __m128 lo = _mm256_castps256_ps128(v);
75  __m128 hi = _mm256_extractf128_ps(v, 1);
76  __m128 sum128 = _mm_add_ps(lo, hi);
77  sum128 = _mm_hadd_ps(sum128, sum128);
78  sum128 = _mm_hadd_ps(sum128, sum128);
79  return _mm_cvtss_f32(sum128);
80 }
81 #endif
82 
83 /**
84  * @brief Compute RMSNorm for a tile of tokens
85  */
86 static void rmsnorm_tile(const float *input,
87  const float *gamma,
88  float *output,
89  int tile_m,
90  int embed_dim,
91  int aligned_embed_dim,
92  float eps)
93 {
94  for (int t = 0; t < tile_m; ++t) {
95  const float *x = input + (size_t)t * (size_t)aligned_embed_dim;
96  float *y = output + (size_t)t * (size_t)aligned_embed_dim;
97 
98 #if defined(__AVX512F__)
99  __m512 sum_sq_vec = _mm512_setzero_ps();
100  int d = 0;
101  for (; d + 16 <= embed_dim; d += 16) {
102  __m512 xv = _mm512_loadu_ps(&x[d]);
103  sum_sq_vec = _mm512_fmadd_ps(xv, xv, sum_sq_vec);
104  }
105  float sum_sq = _mm512_reduce_add_ps(sum_sq_vec);
106  for (; d < embed_dim; ++d) {
107  sum_sq += x[d] * x[d];
108  }
109 
110  float rstd = 1.0f / sqrtf(sum_sq / (float)embed_dim + eps);
111  __m512 rstd_vec = _mm512_set1_ps(rstd);
112 
113  d = 0;
114  for (; d + 16 <= embed_dim; d += 16) {
115  __m512 xv = _mm512_loadu_ps(&x[d]);
116  __m512 gv = gamma ? _mm512_loadu_ps(&gamma[d]) : _mm512_set1_ps(1.0f);
117  __m512 yv = _mm512_mul_ps(_mm512_mul_ps(xv, rstd_vec), gv);
118  _mm512_storeu_ps(&y[d], yv);
119  }
120  for (; d < embed_dim; ++d) {
121  float g = gamma ? gamma[d] : 1.0f;
122  y[d] = x[d] * rstd * g;
123  }
124 
125 #elif defined(__AVX__)
126  __m256 sum_sq_vec = _mm256_setzero_ps();
127  int d = 0;
128  for (; d + 8 <= embed_dim; d += 8) {
129  __m256 xv = _mm256_loadu_ps(&x[d]);
130  sum_sq_vec = _mm256_add_ps(sum_sq_vec, _mm256_mul_ps(xv, xv));
131  }
132  float sum_sq = hsum256_prefill(sum_sq_vec);
133  for (; d < embed_dim; ++d) {
134  sum_sq += x[d] * x[d];
135  }
136 
137  float rstd = 1.0f / sqrtf(sum_sq / (float)embed_dim + eps);
138  __m256 rstd_vec = _mm256_set1_ps(rstd);
139 
140  d = 0;
141  for (; d + 8 <= embed_dim; d += 8) {
142  __m256 xv = _mm256_loadu_ps(&x[d]);
143  __m256 gv = gamma ? _mm256_loadu_ps(&gamma[d]) : _mm256_set1_ps(1.0f);
144  __m256 yv = _mm256_mul_ps(_mm256_mul_ps(xv, rstd_vec), gv);
145  _mm256_storeu_ps(&y[d], yv);
146  }
147  for (; d < embed_dim; ++d) {
148  float g = gamma ? gamma[d] : 1.0f;
149  y[d] = x[d] * rstd * g;
150  }
151 #else
152  float sum_sq = 0.0f;
153  for (int d = 0; d < embed_dim; ++d) {
154  sum_sq += x[d] * x[d];
155  }
156  float rstd = 1.0f / sqrtf(sum_sq / (float)embed_dim + eps);
157  for (int d = 0; d < embed_dim; ++d) {
158  float g = gamma ? gamma[d] : 1.0f;
159  y[d] = x[d] * rstd * g;
160  }
161 #endif
162 
163  for (int d = embed_dim; d < aligned_embed_dim; ++d) {
164  y[d] = 0.0f;
165  }
166  }
167 }
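/*
 * Per-row scalar reference for the computation above (the SIMD paths
 * produce the same result, up to floating-point reduction order):
 *
 *   rms  = sqrtf(sum(x[d] * x[d], d = 0..embed_dim-1) / embed_dim + eps)
 *   y[d] = x[d] / rms * (gamma ? gamma[d] : 1.0f)
 *
 * Columns [embed_dim, aligned_embed_dim) of each output row are zero-padded.
 */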
168 
169 static int qkv_q8_0_dtype_supported(CKDataType dt) {
170  return (dt == CK_DT_Q5_0 || dt == CK_DT_Q8_0);
171 }
172 
173 static int qkv_q8_k_dtype_supported(CKDataType dt) {
174  return (dt == CK_DT_Q4_K || dt == CK_DT_Q6_K);
175 }
176 
177 static void gemm_nt_q8_0_dispatch(const void *A_q8,
178  const void *B,
179  const float *bias,
180  float *C,
181  int M,
182  int N,
183  int K,
184  CKDataType dt)
185 {
186  switch (dt) {
187  case CK_DT_Q5_0:
188  gemm_nt_q5_0_q8_0(A_q8, B, bias, C, M, N, K);
189  break;
190  case CK_DT_Q8_0:
191  gemm_nt_q8_0_q8_0(A_q8, B, bias, C, M, N, K);
192  break;
193  default:
194  break;
195  }
196 }
197 
198 static void gemm_nt_q8_k_qkv_dispatch(const void *A_q8k,
199  const void *B,
200  const float *bias,
201  float *C,
202  int M,
203  int N,
204  int K,
205  CKDataType dt)
206 {
207  switch (dt) {
208  case CK_DT_Q4_K:
209  gemm_nt_q4_k_q8_k(A_q8k, B, bias, C, M, N, K);
210  break;
211  case CK_DT_Q6_K:
212  gemm_nt_q6_k_q8_k(A_q8k, B, bias, C, M, N, K);
213  break;
214  default:
215  break;
216  }
217 }
218 
219 /**
220  * @brief GEMM tile with N-dimension tiling (weight reuse)
221  *
222  * Computes: C[tile_m × tile_n] = A[tile_m × K] × B[tile_n × K]^T
223  * where B_tile is a slice of rows from the weight matrix.
224  *
225  * Uses MKL if available for optimal performance.
226  *
227  * @param A Input tile [tile_m × K]
228  * @param B_tile Weight tile [tile_n × K] (transposed layout)
229  * @param C Output tile [tile_m × tile_n] (column slice of full output)
230  * @param C_stride Stride between rows of C (= full N dimension)
231  */
232 #ifdef USE_MKL
233 #include <mkl.h>
234 #endif
235 
236 static void gemm_tile_nt_strided(const float *A,
237  const float *B_tile,
238  float *C,
239  int tile_m,
240  int tile_n,
241  int K,
242  int C_stride)
243 {
244 #ifdef USE_MKL
245  /* Use MKL SGEMM: C = A × B^T
246  * But MKL expects contiguous output, so we need to handle strided output.
247  * For now, if C_stride == tile_n (contiguous), use MKL directly.
248  * Otherwise, fall back to per-row SGEMV. */
249  if (C_stride == tile_n) {
250  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
251  tile_m, tile_n, K,
252  1.0f, A, K, B_tile, K,
253  0.0f, C, tile_n);
254  return;
255  }
256  /* Strided output - use MKL per row */
257  for (int i = 0; i < tile_m; ++i) {
258  cblas_sgemv(CblasRowMajor, CblasNoTrans,
259  tile_n, K,
260  1.0f, B_tile, K, A + (size_t)i * K, 1,
261  0.0f, C + (size_t)i * C_stride, 1);
262  }
263 #else
264 #ifdef _OPENMP
265 #pragma omp parallel for schedule(static)
266 #endif
267  for (int i = 0; i < tile_m; ++i) {
268  const float *a_row = A + (size_t)i * K;
269  float *c_row = C + (size_t)i * C_stride;
270 
271  for (int j = 0; j < tile_n; ++j) {
272  const float *b_row = B_tile + (size_t)j * K;
273  float sum = 0.0f;
274 
275 #if defined(__AVX512F__)
276  __m512 acc = _mm512_setzero_ps();
277  int k = 0;
278  for (; k + 16 <= K; k += 16) {
279  __m512 av = _mm512_loadu_ps(a_row + k);
280  __m512 bv = _mm512_loadu_ps(b_row + k);
281  acc = _mm512_fmadd_ps(av, bv, acc);
282  }
283  sum = _mm512_reduce_add_ps(acc);
284  for (; k < K; ++k) {
285  sum += a_row[k] * b_row[k];
286  }
287 #elif defined(__AVX__)
288  __m256 acc = _mm256_setzero_ps();
289  int k = 0;
290  for (; k + 8 <= K; k += 8) {
291  __m256 av = _mm256_loadu_ps(a_row + k);
292  __m256 bv = _mm256_loadu_ps(b_row + k);
293  acc = _mm256_add_ps(acc, _mm256_mul_ps(av, bv));
294  }
295  sum = hsum256_prefill(acc);
296  for (; k < K; ++k) {
297  sum += a_row[k] * b_row[k];
298  }
299 #else
300  for (int k = 0; k < K; ++k) {
301  sum += a_row[k] * b_row[k];
302  }
303 #endif
304  c_row[j] = sum;
305  }
306  }
307 #endif
308 }
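/*
 * Usage sketch: writing one [tile_m x tile_n] column slice of a full
 * C[M x N] output, as the callers below do. C_stride is the full row
 * length N, so rows of the slice land at the right offsets:
 *
 *   gemm_tile_nt_strided(x_norm,                        // [tile_m x K]
 *                        W + (size_t)n_start * K,       // weight rows n_start..
 *                        C + (size_t)m_start * N + n_start,
 *                        tile_m, tile_n, K, N);
 */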
309 
310 static void add_bias_tile(float *out,
311  const float *bias,
312  int tile_m,
313  int out_dim)
314 {
315  if (!out || !bias) {
316  return;
317  }
318  for (int i = 0; i < tile_m; ++i) {
319  float *row = out + (size_t)i * (size_t)out_dim;
320  for (int j = 0; j < out_dim; ++j) {
321  row[j] += bias[j];
322  }
323  }
324 }
325 
326 /**
327  * @brief Fused RMSNorm + single GEMM with 2D tiling (weight reuse)
328  *
329  * Tiles along N (weights) OUTER, M (tokens) INNER.
330  * Weight tiles are reused across all token tiles.
331  */
332 static void fused_rmsnorm_gemm_2d_tiled(
333  const float *x, /* [seq_len × hidden] input */
334  const float *gamma, /* [hidden] RMSNorm weights */
335  const float *W, /* [out_dim × hidden] weight matrix (transposed) */
336  float *output, /* [seq_len × out_dim] output */
337  int seq_len,
338  int hidden,
339  int out_dim,
340  float eps,
341  float *x_norm_scratch) /* [TILE_M × hidden] scratch for normalized tile */
342 {
343  /* Outer loop: tile along output dimension (N) - weight tiles */
344  for (int n_start = 0; n_start < out_dim; n_start += PREFILL_TILE_N) {
345  int tile_n = (n_start + PREFILL_TILE_N <= out_dim)
346  ? PREFILL_TILE_N
347  : (out_dim - n_start);
348 
349  /* Weight tile pointer - this tile stays in L3 cache */
350  const float *W_tile = W + (size_t)n_start * hidden;
351 
352  /* Inner loop: tile along token dimension (M) */
353  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
354  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
355  ? PREFILL_TILE_M
356  : (seq_len - m_start);
357 
358  const float *x_tile = x + (size_t)m_start * hidden;
359  float *out_tile = output + (size_t)m_start * out_dim + n_start;
360 
361  /* Compute RMSNorm for this token tile (recomputed for every weight tile) */
362  if (n_start == 0) {
363  rmsnorm_tile(x_tile, gamma, x_norm_scratch, tile_m, hidden, hidden, eps);
364  } else {
365  /* Recompute x_norm for this tile (we can't cache all of it) */
366  /* TODO: For very large N, consider caching x_norm chunks */
367  rmsnorm_tile(x_tile, gamma, x_norm_scratch, tile_m, hidden, hidden, eps);
368  }
369 
370  /* GEMM: x_norm_tile × W_tile^T → output tile */
371  gemm_tile_nt_strided(x_norm_scratch, W_tile, out_tile,
372  tile_m, tile_n, hidden, out_dim);
373  }
374  }
375 }
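/*
 * Caller sketch (scratch comes from the engine's bump allocator per the
 * kernel rules; `arena_alloc` is a placeholder name, not a real API):
 *
 *   size_t scratch_bytes = (size_t)PREFILL_TILE_M * hidden * sizeof(float);
 *   float *x_norm_scratch = (float *)arena_alloc(arena, scratch_bytes);
 *   fused_rmsnorm_gemm_2d_tiled(x, gamma, W, out,
 *                               seq_len, hidden, out_dim, eps, x_norm_scratch);
 */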
376 
377 /**
378  * @brief Fused RMSNorm + QKV projection for prefill (v3 optimized)
379  *
380  * KEY INSIGHT: For Qwen2-0.5B, all QKV weights fit in L3:
381  * Wq (896×896) + Wk (128×896) + Wv (128×896) = 4.1MB < 6MB L3
382  *
383  * So we use M-tiling (tokens) only:
384  * 1. For each token tile:
385  * a. Compute RMSNorm ONCE into scratch (x_norm stays in L2)
386  * b. Do all three GEMMs (Q, K, V) against cached x_norm
387  * c. Weights stay hot in L3 across all token tiles
388  *
389  * This avoids both:
390  * - Large x_norm intermediate buffer (only TILE_M × hidden in L2)
391  * - RMSNorm recomputation (done once per token tile, used 3×)
392  */
393 void fused_rmsnorm_qkv_prefill(
394  const float *x,
395  const float *gamma,
396  const float *Wq,
397  const float *Wk,
398  const float *Wv,
399  float *Q,
400  float *K,
401  float *V,
402  int seq_len,
403  int hidden,
404  int q_dim,
405  int kv_dim,
406  float eps,
407  float *scratch)
408 {
409  /* scratch is x_norm tile: [TILE_M × hidden] fits in L2 */
410 
411  /* Process token tiles - weights stay in L3 across all tiles */
412  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
413  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
414  ? PREFILL_TILE_M : (seq_len - m_start);
415 
416  const float *x_tile = x + (size_t)m_start * hidden;
417 
418  /* Step 1: RMSNorm for this token tile (computed ONCE, used 3×) */
419  rmsnorm_tile(x_tile, gamma, scratch, tile_m, hidden, hidden, eps);
420 
421  /* Step 2: Q projection - x_norm is hot in L2, Wq hot in L3 */
422  float *Q_tile = Q + (size_t)m_start * q_dim;
423  gemm_tile_nt_strided(scratch, Wq, Q_tile, tile_m, q_dim, hidden, q_dim);
424 
425  /* Step 3: K projection - x_norm still hot, Wk displaces some Wq */
426  float *K_tile = K + (size_t)m_start * kv_dim;
427  gemm_tile_nt_strided(scratch, Wk, K_tile, tile_m, kv_dim, hidden, kv_dim);
428 
429  /* Step 4: V projection - x_norm still hot, Wv displaces Wk */
430  float *V_tile = V + (size_t)m_start * kv_dim;
431  gemm_tile_nt_strided(scratch, Wv, V_tile, tile_m, kv_dim, hidden, kv_dim);
432  }
433 }
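/*
 * Caller sketch for the Qwen2-0.5B shapes mentioned above
 * (hidden = 896, q_dim = 896, kv_dim = 128); `arena_alloc` is a
 * placeholder for the engine's bump allocator:
 *
 *   float *scratch = (float *)arena_alloc(arena,
 *                        fused_rmsnorm_qkv_scratch_size(896));
 *   fused_rmsnorm_qkv_prefill(x, gamma, Wq, Wk, Wv, Q, K, V,
 *                             seq_len, 896, 896, 128, eps, scratch);
 */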
434 
435 /**
436  * @brief Fused RMSNorm + QKV projection for prefill (head-major outputs)
437  *
438  * Q is written as [num_heads, seq_len, aligned_head_dim].
439  * K/V are written with kv_stride_tokens for KV-cache compatibility.
440  */
441 void fused_rmsnorm_qkv_prefill_head_major(
442  const float *x,
443  const float *gamma,
444  const float *Wq, const float *Bq,
445  const float *Wk, const float *Bk,
446  const float *Wv, const float *Bv,
447  float *Q,
448  float *K,
449  float *V,
450  int seq_len,
451  int embed_dim,
452  int aligned_embed_dim,
453  int num_heads,
454  int num_kv_heads,
455  int head_dim,
456  int aligned_head_dim,
457  int kv_stride_tokens,
458  float eps,
459  float *scratch)
460 {
461  if (!x || !gamma || !Wq || !Wk || !Wv || !Q || !K || !V || !scratch) {
462  return;
463  }
464  if (seq_len <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
465  head_dim <= 0 || aligned_head_dim <= 0 ||
466  num_heads <= 0 || num_kv_heads <= 0) {
467  return;
468  }
469  if (kv_stride_tokens < seq_len) {
470  return;
471  }
472 
473  const size_t q_head_stride = (size_t)seq_len * (size_t)aligned_head_dim;
474  const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
475  const size_t head_w_stride = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
476 
477  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
478  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
479  ? PREFILL_TILE_M : (seq_len - m_start);
480 
481  const float *x_tile = x + (size_t)m_start * (size_t)aligned_embed_dim;
482  rmsnorm_tile(x_tile, gamma, scratch, tile_m, embed_dim, aligned_embed_dim, eps);
483 
484  for (int h = 0; h < num_heads; ++h) {
485  const float *wq_h = Wq + (size_t)h * head_w_stride;
486  const float *bq_h = Bq ? (Bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
487  float *q_h = Q + (size_t)h * q_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
488 
489  gemm_tile_nt_strided(scratch, wq_h, q_h,
490  tile_m, aligned_head_dim, aligned_embed_dim, aligned_head_dim);
491  add_bias_tile(q_h, bq_h, tile_m, aligned_head_dim);
492  }
493 
494  for (int h = 0; h < num_kv_heads; ++h) {
495  const float *wk_h = Wk + (size_t)h * head_w_stride;
496  const float *wv_h = Wv + (size_t)h * head_w_stride;
497  const float *bk_h = Bk ? (Bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
498  const float *bv_h = Bv ? (Bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
499  float *k_h = K + (size_t)h * kv_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
500  float *v_h = V + (size_t)h * kv_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
501 
502  gemm_tile_nt_strided(scratch, wk_h, k_h,
503  tile_m, aligned_head_dim, aligned_embed_dim, aligned_head_dim);
504  add_bias_tile(k_h, bk_h, tile_m, aligned_head_dim);
505 
506  gemm_tile_nt_strided(scratch, wv_h, v_h,
507  tile_m, aligned_head_dim, aligned_embed_dim, aligned_head_dim);
508  add_bias_tile(v_h, bv_h, tile_m, aligned_head_dim);
509  }
510  }
511 }
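/*
 * Output layout sketch: element (head h, token t, channel d) of Q lives at
 *
 *   Q[(size_t)h * seq_len * aligned_head_dim
 *     + (size_t)t * aligned_head_dim + d]
 *
 * K and V use kv_stride_tokens in place of seq_len, so the rows written
 * here line up with the KV cache's token stride.
 */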
512 
513 /**
514  * @brief Fused RMSNorm + QKV projection for prefill (head-major, Q8 activations)
515  *
516  * Supports Q5_0 or Q8_0 weights with Q8_0 activations.
517  * Writes K/V directly into KV cache layout (kv_stride_tokens).
518  */
519 void fused_rmsnorm_qkv_prefill_head_major_quant(
520  const float *x,
521  const float *gamma,
522  const void *Wq, const float *Bq, CKDataType wq_dt,
523  const void *Wk, const float *Bk, CKDataType wk_dt,
524  const void *Wv, const float *Bv, CKDataType wv_dt,
525  float *Q,
526  float *K,
527  float *V,
528  int seq_len,
529  int embed_dim,
530  int aligned_embed_dim,
531  int num_heads,
532  int num_kv_heads,
533  int head_dim,
534  int aligned_head_dim,
535  int kv_stride_tokens,
536  float eps,
537  void *scratch)
538 {
539  if (!x || !gamma || !Wq || !Wk || !Wv || !Q || !K || !V || !scratch) {
540  return;
541  }
542  if (seq_len <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
543  head_dim <= 0 || aligned_head_dim <= 0 ||
544  num_heads <= 0 || num_kv_heads <= 0) {
545  return;
546  }
547  if (aligned_embed_dim % 32 != 0) {
548  return;
549  }
550  if (kv_stride_tokens < seq_len) {
551  return;
552  }
553  /* Determine quantization path: Q8_0 activations for Q5_0/Q8_0 weights,
554  * Q8_K activations for Q4_K/Q6_K weights. All QKV weights must use
555  * the same quantization family. */
556  int use_q8_k_path = qkv_q8_k_dtype_supported(wq_dt);
557  int use_q8_0_path = qkv_q8_0_dtype_supported(wq_dt);
558 
559  if (!use_q8_k_path && !use_q8_0_path) {
560  /* Unsupported dtype for wq */
561  return;
562  }
563 
564  /* Verify all dtypes are from the same family */
565  if (use_q8_k_path) {
566  if (!qkv_q8_k_dtype_supported(wk_dt) || !qkv_q8_k_dtype_supported(wv_dt)) {
567  return; /* Mixed Q8_K and Q8_0 paths not supported */
568  }
569  } else {
570  if (!qkv_q8_0_dtype_supported(wk_dt) || !qkv_q8_0_dtype_supported(wv_dt)) {
571  return;
572  }
573  }
574 
575  const size_t float_bytes = (size_t)PREFILL_TILE_M * (size_t)aligned_embed_dim * sizeof(float);
576  /* Q8_K has larger blocks (256) than Q8_0 (32), so use appropriate size */
577  const CKDataType act_quant_type = use_q8_k_path ? CK_DT_Q8_K : CK_DT_Q8_0;
578  const size_t q8_row_bytes = ck_dtype_row_bytes(act_quant_type, (size_t)aligned_embed_dim);
579  const size_t q8_bytes = (size_t)PREFILL_TILE_M * q8_row_bytes;
580  const size_t q8_offset = align_up_size(float_bytes, 64);
581 
582  float *normed = (float *)scratch;
583  uint8_t *q8_tile = (uint8_t *)scratch + q8_offset;
584  (void)q8_bytes;
585 
586  const size_t q_head_stride = (size_t)seq_len * (size_t)aligned_head_dim;
587  const size_t kv_head_stride = (size_t)kv_stride_tokens * (size_t)aligned_head_dim;
588  const size_t head_w_elems = (size_t)aligned_head_dim * (size_t)aligned_embed_dim;
589  const size_t wq_head_bytes = ck_dtype_row_bytes(wq_dt, head_w_elems);
590  const size_t wk_head_bytes = ck_dtype_row_bytes(wk_dt, head_w_elems);
591  const size_t wv_head_bytes = ck_dtype_row_bytes(wv_dt, head_w_elems);
592 
593  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
594  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
595  ? PREFILL_TILE_M : (seq_len - m_start);
596 
597  const float *x_tile = x + (size_t)m_start * (size_t)aligned_embed_dim;
598  rmsnorm_tile(x_tile, gamma, normed, tile_m, embed_dim, aligned_embed_dim, eps);
599 
600  /* Quantize activations to appropriate format */
601  for (int t = 0; t < tile_m; ++t) {
602  const float *row = normed + (size_t)t * (size_t)aligned_embed_dim;
603  if (use_q8_k_path) {
604  quantize_row_q8_k(row,
605  q8_tile + (size_t)t * q8_row_bytes,
606  aligned_embed_dim);
607  } else {
608  quantize_row_q8_0(row,
609  q8_tile + (size_t)t * q8_row_bytes,
610  aligned_embed_dim);
611  }
612  }
613 
614  for (int h = 0; h < num_heads; ++h) {
615  const uint8_t *wq_h = (const uint8_t *)Wq + (size_t)h * wq_head_bytes;
616  const float *bq_h = Bq ? (Bq + (size_t)h * (size_t)aligned_head_dim) : NULL;
617  float *q_h = Q + (size_t)h * q_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
618 
619  if (use_q8_k_path) {
620  gemm_nt_q8_k_qkv_dispatch(q8_tile, wq_h, bq_h, q_h,
621  tile_m, aligned_head_dim, aligned_embed_dim, wq_dt);
622  } else {
623  gemm_nt_q8_0_dispatch(q8_tile, wq_h, bq_h, q_h,
624  tile_m, aligned_head_dim, aligned_embed_dim, wq_dt);
625  }
626  }
627 
628  for (int h = 0; h < num_kv_heads; ++h) {
629  const uint8_t *wk_h = (const uint8_t *)Wk + (size_t)h * wk_head_bytes;
630  const uint8_t *wv_h = (const uint8_t *)Wv + (size_t)h * wv_head_bytes;
631  const float *bk_h = Bk ? (Bk + (size_t)h * (size_t)aligned_head_dim) : NULL;
632  const float *bv_h = Bv ? (Bv + (size_t)h * (size_t)aligned_head_dim) : NULL;
633  float *k_h = K + (size_t)h * kv_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
634  float *v_h = V + (size_t)h * kv_head_stride + (size_t)m_start * (size_t)aligned_head_dim;
635 
636  if (use_q8_k_path) {
637  gemm_nt_q8_k_qkv_dispatch(q8_tile, wk_h, bk_h, k_h,
638  tile_m, aligned_head_dim, aligned_embed_dim, wk_dt);
639  gemm_nt_q8_k_qkv_dispatch(q8_tile, wv_h, bv_h, v_h,
640  tile_m, aligned_head_dim, aligned_embed_dim, wv_dt);
641  } else {
642  gemm_nt_q8_0_dispatch(q8_tile, wk_h, bk_h, k_h,
643  tile_m, aligned_head_dim, aligned_embed_dim, wk_dt);
644  gemm_nt_q8_0_dispatch(q8_tile, wv_h, bv_h, v_h,
645  tile_m, aligned_head_dim, aligned_embed_dim, wv_dt);
646  }
647  }
648  }
649 }
650 
651 size_t fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(int aligned_embed_dim) {
652  if (aligned_embed_dim <= 0) {
653  return 0;
654  }
655  const size_t float_bytes = (size_t)PREFILL_TILE_M * (size_t)aligned_embed_dim * sizeof(float);
656  /* Use max of Q8_0 and Q8_K sizes to support both paths */
657  const size_t q8_0_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0, (size_t)aligned_embed_dim);
658  const size_t q8_k_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_K, (size_t)aligned_embed_dim);
659  const size_t q8_row_bytes = (q8_k_row_bytes > q8_0_row_bytes) ? q8_k_row_bytes : q8_0_row_bytes;
660  const size_t q8_bytes = (size_t)PREFILL_TILE_M * q8_row_bytes;
661  return align_up_size(float_bytes, 64) + q8_bytes;
662 }
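/*
 * Caller sketch for the quantized head-major path. The dtypes shown
 * (Q4_K for Q, Q6_K for K/V) are just one valid combination within the
 * Q8_K family; `arena_alloc` is a placeholder for the bump allocator:
 *
 *   size_t sz = fused_rmsnorm_qkv_prefill_head_major_quant_scratch_size(aligned_embed_dim);
 *   void *scratch = arena_alloc(arena, sz);
 *   fused_rmsnorm_qkv_prefill_head_major_quant(
 *       x, gamma,
 *       Wq, Bq, CK_DT_Q4_K,
 *       Wk, Bk, CK_DT_Q6_K,
 *       Wv, Bv, CK_DT_Q6_K,
 *       Q, K, V,
 *       seq_len, embed_dim, aligned_embed_dim,
 *       num_heads, num_kv_heads, head_dim, aligned_head_dim,
 *       kv_stride_tokens, eps, scratch);
 */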
663 
664 /**
665  * @brief Unfused version for comparison
666  */
667 void unfused_rmsnorm_qkv_prefill(
668  const float *x,
669  const float *gamma,
670  const float *Wq,
671  const float *Wk,
672  const float *Wv,
673  float *x_norm,
674  float *Q,
675  float *K,
676  float *V,
677  int seq_len,
678  int hidden,
679  int q_dim,
680  int kv_dim,
681  float eps)
682 {
683  /* Step 1: Full RMSNorm → writes x_norm to memory */
684  rmsnorm_tile(x, gamma, x_norm, seq_len, hidden, hidden, eps);
685 
686  /* Step 2: Separate GEMMs with N-outer tiling for weight reuse */
687  /* Q projection */
688  for (int n_start = 0; n_start < q_dim; n_start += PREFILL_TILE_N) {
689  int tile_n = (n_start + PREFILL_TILE_N <= q_dim)
690  ? PREFILL_TILE_N : (q_dim - n_start);
691  const float *W_tile = Wq + (size_t)n_start * hidden;
692 
693  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
694  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
695  ? PREFILL_TILE_M : (seq_len - m_start);
696  const float *x_tile = x_norm + (size_t)m_start * hidden;
697  float *out_tile = Q + (size_t)m_start * q_dim + n_start;
698  gemm_tile_nt_strided(x_tile, W_tile, out_tile,
699  tile_m, tile_n, hidden, q_dim);
700  }
701  }
702 
703  /* K projection */
704  for (int n_start = 0; n_start < kv_dim; n_start += PREFILL_TILE_N) {
705  int tile_n = (n_start + PREFILL_TILE_N <= kv_dim)
706  ? PREFILL_TILE_N : (kv_dim - n_start);
707  const float *W_tile = Wk + (size_t)n_start * hidden;
708 
709  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
710  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
711  ? PREFILL_TILE_M : (seq_len - m_start);
712  const float *x_tile = x_norm + (size_t)m_start * hidden;
713  float *out_tile = K + (size_t)m_start * kv_dim + n_start;
714  gemm_tile_nt_strided(x_tile, W_tile, out_tile,
715  tile_m, tile_n, hidden, kv_dim);
716  }
717  }
718 
719  /* V projection */
720  for (int n_start = 0; n_start < kv_dim; n_start += PREFILL_TILE_N) {
721  int tile_n = (n_start + PREFILL_TILE_N <= kv_dim)
722  ? PREFILL_TILE_N : (kv_dim - n_start);
723  const float *W_tile = Wv + (size_t)n_start * hidden;
724 
725  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
726  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
727  ? PREFILL_TILE_M : (seq_len - m_start);
728  const float *x_tile = x_norm + (size_t)m_start * hidden;
729  float *out_tile = V + (size_t)m_start * kv_dim + n_start;
730  gemm_tile_nt_strided(x_tile, W_tile, out_tile,
731  tile_m, tile_n, hidden, kv_dim);
732  }
733  }
734 }
735 
736 /**
737  * @brief Get scratch size for fused prefill
738  */
739 size_t fused_rmsnorm_qkv_scratch_size(int hidden) {
740  return (size_t)PREFILL_TILE_M * hidden * sizeof(float);
741 }
742 
743 /**
744  * @brief Fused MLP for prefill with proper tiling
745  */
746 void fused_mlp_swiglu_prefill_bias(
747  const float *x,
748  const float *W_gate,
749  const float *W_up,
750  const float *W_down,
751  const float *B_gate,
752  const float *B_up,
753  const float *B_down,
754  float *output,
755  int seq_len,
756  int hidden,
757  int intermediate,
758  float *scratch)
759 {
760  /* MLP is more complex because we have:
761  * gate = x @ W_gate
762  * up = x @ W_up
763  * hidden = silu(gate) * up
764  * out = hidden @ W_down
765  *
766  * The intermediate (gate, up, hidden) is large: seq_len × intermediate
767  * For Qwen2-0.5B: 1024 × 4864 × 4 ≈ 19.9MB (way bigger than L3!)
768  *
769  * Strategy: Tile along intermediate dimension for gate/up,
770  * then fuse SwiGLU, then tile down projection.
771  */
772 
773  /* scratch layout:
774  * [gate_tile: TILE_M × TILE_N_INTER]
775  * [up_tile: TILE_M × TILE_N_INTER]
776  */
777  const int TILE_N_INTER = 512; /* Intermediate tile size */
778  float *gate_tile = scratch;
779  float *up_tile = scratch + (size_t)PREFILL_TILE_M * TILE_N_INTER;
780  float *hidden_tile = gate_tile; /* Reuse gate_tile for hidden after SwiGLU */
781 
782  /* For each chunk of intermediate dimension */
783  for (int inter_start = 0; inter_start < intermediate; inter_start += TILE_N_INTER) {
784  int tile_inter = (inter_start + TILE_N_INTER <= intermediate)
785  ? TILE_N_INTER : (intermediate - inter_start);
786 
787  const float *W_gate_tile = W_gate + (size_t)inter_start * hidden;
788  const float *W_up_tile = W_up + (size_t)inter_start * hidden;
789 
790  /* For each chunk of tokens */
791  for (int m_start = 0; m_start < seq_len; m_start += PREFILL_TILE_M) {
792  int tile_m = (m_start + PREFILL_TILE_M <= seq_len)
793  ? PREFILL_TILE_M : (seq_len - m_start);
794 
795  const float *x_tile = x + (size_t)m_start * hidden;
796 
797  /* Compute gate and up projections for this tile */
798  gemm_tile_nt_strided(x_tile, W_gate_tile, gate_tile,
799  tile_m, tile_inter, hidden, tile_inter);
800  gemm_tile_nt_strided(x_tile, W_up_tile, up_tile,
801  tile_m, tile_inter, hidden, tile_inter);
802  if (B_gate) {
803  add_bias_tile(gate_tile, B_gate + inter_start, tile_m, tile_inter);
804  }
805  if (B_up) {
806  add_bias_tile(up_tile, B_up + inter_start, tile_m, tile_inter);
807  }
808 
809  /* Fused SwiGLU: hidden = silu(gate) * up */
810  for (int i = 0; i < tile_m; ++i) {
811  float *g = gate_tile + (size_t)i * tile_inter;
812  float *u = up_tile + (size_t)i * tile_inter;
813  for (int j = 0; j < tile_inter; ++j) {
814  float gv = g[j];
815  float silu = gv / (1.0f + expf(-gv));
816  g[j] = silu * u[j]; /* hidden_tile = gate_tile */
817  }
818  }
819 
820  /* Down projection: accumulate into output
821  * out[m_start:, :] += hidden_tile @ W_down[inter_start:, :]^T
822  */
823  /* W_down rows are indexed directly below at columns [inter_start, inter_start+tile_inter) */
824  float *out_tile = output + (size_t)m_start * hidden;
825 
826  /* This is trickier - W_down is [hidden × intermediate]
827  * We have hidden_tile[tile_m × tile_inter]
828  * We want out[tile_m × hidden] += hidden_tile × W_down[:, inter_start:inter_start+tile_inter]^T
829  *
830  * For proper accumulation, need to handle this carefully.
831  * For now, use a simpler approach: accumulate partial results.
832  */
833  for (int i = 0; i < tile_m; ++i) {
834  float *h = hidden_tile + (size_t)i * tile_inter;
835  float *o = out_tile + (size_t)i * hidden;
836 
837  for (int d = 0; d < hidden; ++d) {
838  const float *w_row = W_down + (size_t)d * intermediate + inter_start;
839  float sum = (inter_start == 0)
840  ? (B_down ? B_down[d] : 0.0f)
841  : o[d];
842 
843 #if defined(__AVX512F__)
844  __m512 acc = _mm512_setzero_ps();
845  int j = 0;
846  for (; j + 16 <= tile_inter; j += 16) {
847  __m512 hv = _mm512_loadu_ps(h + j);
848  __m512 wv = _mm512_loadu_ps(w_row + j);
849  acc = _mm512_fmadd_ps(hv, wv, acc);
850  }
851  sum += _mm512_reduce_add_ps(acc);
852  for (; j < tile_inter; ++j) {
853  sum += h[j] * w_row[j];
854  }
855 #elif defined(__AVX__)
856  __m256 acc = _mm256_setzero_ps();
857  int j = 0;
858  for (; j + 8 <= tile_inter; j += 8) {
859  __m256 hv = _mm256_loadu_ps(h + j);
860  __m256 wv = _mm256_loadu_ps(w_row + j);
861  acc = _mm256_add_ps(acc, _mm256_mul_ps(hv, wv));
862  }
863  sum += hsum256_prefill(acc);
864  for (; j < tile_inter; ++j) {
865  sum += h[j] * w_row[j];
866  }
867 #else
868  for (int j = 0; j < tile_inter; ++j) {
869  sum += h[j] * w_row[j];
870  }
871 #endif
872  o[d] = sum;
873  }
874  }
875  }
876  }
877 }
878 
879 void fused_mlp_swiglu_prefill(
880  const float *x,
881  const float *W_gate,
882  const float *W_up,
883  const float *W_down,
884  float *output,
885  int seq_len,
886  int hidden,
887  int intermediate,
888  float *scratch)
889 {
890  fused_mlp_swiglu_prefill_bias(x, W_gate, W_up, W_down,
891  NULL, NULL, NULL,
892  output, seq_len, hidden, intermediate,
893  scratch);
894 }
895 
896 /**
897  * @brief Get scratch size for fused MLP
898  */
899 size_t fused_mlp_swiglu_scratch_size(int intermediate) {
900  const int TILE_N_INTER = 512;
901  /* gate_tile + up_tile */
902  return 2 * (size_t)PREFILL_TILE_M * TILE_N_INTER * sizeof(float);
903 }
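/*
 * Caller sketch for the fp32 MLP path (`arena_alloc` is a placeholder for
 * the engine's bump allocator):
 *
 *   float *scratch = (float *)arena_alloc(arena,
 *                        fused_mlp_swiglu_scratch_size(intermediate));
 *   fused_mlp_swiglu_prefill(x, W_gate, W_up, W_down, out,
 *                            seq_len, hidden, intermediate, scratch);
 *
 * Note: `output` is fully overwritten (the first intermediate chunk seeds
 * each element with the bias or zero), so it need not be cleared beforehand.
 */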
904 
905 static inline float silu_prefill(float x) {
906  return x / (1.0f + expf(-x));
907 }
908 
909 static int mlp_q8_0_dtype_supported(CKDataType dt) {
910  return (dt == CK_DT_Q5_0 || dt == CK_DT_Q8_0);
911 }
912 
913 static int mlp_q8_k_dtype_supported(CKDataType dt) {
914  return (dt == CK_DT_Q4_K || dt == CK_DT_Q6_K);
915 }
916 
917 static void gemm_nt_q8_0_mlp_dispatch(const void *A_q8,
918  const void *B,
919  const float *bias,
920  float *C,
921  int M,
922  int N,
923  int K,
924  CKDataType dt)
925 {
926  switch (dt) {
927  case CK_DT_Q5_0:
928  gemm_nt_q5_0_q8_0(A_q8, B, bias, C, M, N, K);
929  break;
930  case CK_DT_Q8_0:
931  gemm_nt_q8_0_q8_0(A_q8, B, bias, C, M, N, K);
932  break;
933  default:
934  break;
935  }
936 }
937 
938 static void gemm_nt_q8_k_mlp_dispatch(const void *A_q8,
939  const void *B,
940  const float *bias,
941  float *C,
942  int M,
943  int N,
944  int K,
945  CKDataType dt)
946 {
947  switch (dt) {
948  case CK_DT_Q4_K:
949  gemm_nt_q4_k_q8_k(A_q8, B, bias, C, M, N, K);
950  break;
951  case CK_DT_Q6_K:
952  gemm_nt_q6_k_q8_k(A_q8, B, bias, C, M, N, K);
953  break;
954  default:
955  break;
956  }
957 }
958 
959 /**
960  * @brief Quantized fused MLP for prefill (W1=gate+up, W2=down)
961  *
962  * Uses Q8_0 activations for W1 (Q5_0/Q8_0 weights) and Q8_K activations
963  * for W2 (Q4_K/Q6_K weights).
964  */
965 void fused_mlp_swiglu_prefill_w1w2_quant(
966  const float *x,
967  const void *W1,
968  const float *B1,
969  CKDataType w1_dt,
970  const void *W2,
971  const float *B2,
972  CKDataType w2_dt,
973  float *output,
974  int seq_len,
975  int embed_dim,
976  int aligned_embed_dim,
977  int intermediate_dim,
978  int aligned_intermediate_dim,
979  void *scratch)
980 {
981  if (!x || !W1 || !W2 || !output || !scratch) {
982  return;
983  }
984  if (seq_len <= 0 || embed_dim <= 0 || aligned_embed_dim <= 0 ||
985  intermediate_dim <= 0 || aligned_intermediate_dim <= 0) {
986  return;
987  }
988  if (aligned_embed_dim < embed_dim || aligned_intermediate_dim < intermediate_dim) {
989  return;
990  }
991  if ((aligned_embed_dim % 32) != 0 || (aligned_intermediate_dim % 256) != 0) {
992  return;
993  }
994  if (!mlp_q8_0_dtype_supported(w1_dt) || !mlp_q8_k_dtype_supported(w2_dt)) {
995  return;
996  }
997 
998  const int tile_m_max = PREFILL_TILE_M;
999  const int inter = aligned_intermediate_dim;
1000  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0, (size_t)aligned_embed_dim);
1001  const size_t q8k_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_K, (size_t)aligned_intermediate_dim);
1002  const size_t w1_row_bytes = ck_dtype_row_bytes(w1_dt, (size_t)aligned_embed_dim);
1003 
1004  uint8_t *scratch_bytes = (uint8_t *)scratch;
1005  size_t q8_bytes = (size_t)tile_m_max * q8_row_bytes;
1006  size_t gate_bytes = (size_t)tile_m_max * (size_t)inter * sizeof(float);
1007  size_t up_bytes = gate_bytes;
1008  size_t gate_offset = align_up_size(q8_bytes, 64);
1009  size_t up_offset = gate_offset + align_up_size(gate_bytes, 64);
1010  size_t q8k_offset = up_offset + align_up_size(up_bytes, 64);
1011 
1012  uint8_t *q8_tile = scratch_bytes;
1013  float *gate_tile = (float *)(scratch_bytes + gate_offset);
1014  float *up_tile = (float *)(scratch_bytes + up_offset);
1015  uint8_t *q8k_tile = scratch_bytes + q8k_offset;
1016 
1017  const uint8_t *w1_base = (const uint8_t *)W1;
1018  const uint8_t *w_gate = w1_base;
1019  const uint8_t *w_up = w1_base + (size_t)inter * w1_row_bytes;
1020 
1021  const float *b_gate = B1;
1022  const float *b_up = B1 ? (B1 + (size_t)inter) : NULL;
1023 
1024  for (int m_start = 0; m_start < seq_len; m_start += tile_m_max) {
1025  int tile_m = (m_start + tile_m_max <= seq_len)
1026  ? tile_m_max : (seq_len - m_start);
1027 
1028  const float *x_tile = x + (size_t)m_start * (size_t)aligned_embed_dim;
1029  float *out_tile = output + (size_t)m_start * (size_t)aligned_embed_dim;
1030 
1031  for (int t = 0; t < tile_m; ++t) {
1032  const float *row = x_tile + (size_t)t * (size_t)aligned_embed_dim;
1033  quantize_row_q8_0(row,
1034  q8_tile + (size_t)t * q8_row_bytes,
1035  aligned_embed_dim);
1036  }
1037 
1038  gemm_nt_q8_0_mlp_dispatch(q8_tile, w_gate, b_gate, gate_tile,
1039  tile_m, inter, aligned_embed_dim, w1_dt);
1040  gemm_nt_q8_0_mlp_dispatch(q8_tile, w_up, b_up, up_tile,
1041  tile_m, inter, aligned_embed_dim, w1_dt);
1042 
1043  for (int i = 0; i < tile_m; ++i) {
1044  float *g = gate_tile + (size_t)i * (size_t)inter;
1045  float *u = up_tile + (size_t)i * (size_t)inter;
1046  for (int j = 0; j < inter; ++j) {
1047  g[j] = silu_prefill(g[j]) * u[j];
1048  }
1049  }
1050 
1051  for (int i = 0; i < tile_m; ++i) {
1052  const float *row = gate_tile + (size_t)i * (size_t)inter;
1053  quantize_row_q8_k(row,
1054  q8k_tile + (size_t)i * q8k_row_bytes,
1055  aligned_intermediate_dim);
1056  }
1057 
1058  gemm_nt_q8_k_mlp_dispatch(q8k_tile, W2, B2, out_tile,
1059  tile_m, aligned_embed_dim, aligned_intermediate_dim, w2_dt);
1060  }
1061 }
1062 
1063 size_t fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(int aligned_embed_dim,
1064  int aligned_intermediate_dim)
1065 {
1066  if (aligned_embed_dim <= 0 || aligned_intermediate_dim <= 0) {
1067  return 0;
1068  }
1069  const size_t q8_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_0, (size_t)aligned_embed_dim);
1070  const size_t q8k_row_bytes = ck_dtype_row_bytes(CK_DT_Q8_K, (size_t)aligned_intermediate_dim);
1071  const size_t q8_bytes = (size_t)PREFILL_TILE_M * q8_row_bytes;
1072  const size_t gate_bytes = (size_t)PREFILL_TILE_M * (size_t)aligned_intermediate_dim * sizeof(float);
1073  const size_t up_bytes = gate_bytes;
1074  const size_t q8k_bytes = (size_t)PREFILL_TILE_M * q8k_row_bytes;
1075 
1076  return align_up_size(q8_bytes, 64) +
1077  align_up_size(gate_bytes, 64) +
1078  align_up_size(up_bytes, 64) +
1079  align_up_size(q8k_bytes, 64);
1080 }
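/*
 * Caller sketch for the quantized MLP path. W1 packs the gate rows followed
 * by the up rows ([2 * aligned_intermediate_dim, aligned_embed_dim] in the
 * weight dtype) and B1, if present, holds the gate then up biases; W2 is the
 * down projection. The dtypes shown (Q5_0 / Q6_K) are one valid combination;
 * `arena_alloc` is a placeholder for the bump allocator:
 *
 *   size_t sz = fused_mlp_swiglu_prefill_w1w2_quant_scratch_size(
 *                   aligned_embed_dim, aligned_intermediate_dim);
 *   void *scratch = arena_alloc(arena, sz);
 *   fused_mlp_swiglu_prefill_w1w2_quant(x, W1, B1, CK_DT_Q5_0,
 *                                       W2, B2, CK_DT_Q6_K,
 *                                       out, seq_len,
 *                                       embed_dim, aligned_embed_dim,
 *                                       intermediate_dim, aligned_intermediate_dim,
 *                                       scratch);
 */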